Add const-eval for sign.

This CL adds const-eval for the `sign` builtin.

Bug: tint:1581
Change-Id: I5d9bfd3f3f742bcba69fbb0d7f47dc57ce18e134
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/107460
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
dan sinclair
2022-10-28 14:56:50 +00:00
committed by Dawn LUCI CQ
parent fd1b5a893d
commit c395660ee6
102 changed files with 2512 additions and 187 deletions

View File

@@ -527,8 +527,8 @@ fn round<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
@const("select_bool") fn select<T: scalar>(T, T, bool) -> T
@const("select_bool") fn select<T: scalar, N: num>(vec<N, T>, vec<N, T>, bool) -> vec<N, T>
@const("select_boolvec") fn select<N: num, T: scalar>(vec<N, T>, vec<N, T>, vec<N, bool>) -> vec<N, T>
fn sign<T: f32_f16>(T) -> T
fn sign<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
@const fn sign<T: fa_f32_f16>(T) -> T
@const fn sign<N: num, T: fa_f32_f16>(vec<N, T>) -> vec<N, T>
fn sin<T: f32_f16>(T) -> T
fn sin<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn sinh<T: f32_f16>(T) -> T

View File

@@ -1659,6 +1659,28 @@ ConstEval::Result ConstEval::select_boolvec(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0], args[1]);
}
ConstEval::Result ConstEval::sign(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {
auto transform = [&](const sem::Constant* c0) {
auto create = [&](auto e) -> ImplResult {
using NumberT = decltype(e);
NumberT result;
NumberT zero{0.0};
if (e.value < zero) {
result = NumberT{-1.0};
} else if (e.value > zero) {
result = NumberT{1.0};
} else {
result = zero;
}
return CreateElement(builder, c0->Type(), result);
};
return Dispatch_fa_f32_f16(create, c0);
};
return TransformElements(builder, ty, transform, args[0]);
}
ConstEval::Result ConstEval::step(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {

View File

@@ -458,6 +458,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// sign builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result sign(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// step builtin
/// @param ty the expression type
/// @param args the input arguments

View File

@@ -531,6 +531,32 @@ INSTANTIATE_TEST_SUITE_P( //
SelectCases<f16>(),
SelectBoolCases()))));
template <typename T>
std::vector<Case> SignCases() {
return {
C({-T(1)}, -T(1)),
C({-T(0.5)}, -T(1)),
C({T(0)}, T(0)),
C({-T(0)}, T(0)),
C({T(0.5)}, T(1)),
C({T(1)}, T(1)),
C({T::Highest()}, T(1.0)),
C({T::Lowest()}, -T(1.0)),
// Vector tests
C({Vec(-T(0.5), T(0), T(0.5))}, Vec(-T(1.0), T(0.0), T(1.0))),
C({Vec(T::Highest(), T::Lowest())}, Vec(T(1.0), -T(1.0))),
};
}
INSTANTIATE_TEST_SUITE_P( //
Sign,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kSign),
testing::ValuesIn(Concat(SignCases<AFloat>(), //
SignCases<f32>(),
SignCases<f16>()))));
template <typename T>
std::vector<Case> StepCases() {
return {

View File

@@ -12455,24 +12455,24 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[25],
/* template types */ &kTemplateTypes[24],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[904],
/* return matcher indices */ &kMatcherIndices[1],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::sign,
},
{
/* [345] */
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[25],
/* template types */ &kTemplateTypes[24],
/* template numbers */ &kTemplateNumbers[5],
/* parameters */ &kParameters[905],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::sign,
},
{
/* [346] */
@@ -14442,8 +14442,8 @@ constexpr IntrinsicInfo kBuiltins[] = {
},
{
/* [68] */
/* fn sign<T : f32_f16>(T) -> T */
/* fn sign<N : num, T : f32_f16>(vec<N, T>) -> vec<N, T> */
/* fn sign<T : fa_f32_f16>(T) -> T */
/* fn sign<N : num, T : fa_f32_f16>(vec<N, T>) -> vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[344],
},