Add const-eval for round

This CL adds const-eval for `round`.

Bug: tint:1581
Change-Id: I16ebb2010969debb8de50479235d9d9f5433c0a1
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110175
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair
2022-11-16 22:52:08 +00:00
committed by Dawn LUCI CQ
parent c214cbe98b
commit 19e5042ade
165 changed files with 2543 additions and 310 deletions

View File

@@ -2270,6 +2270,42 @@ ConstEval::Result ConstEval::reverseBits(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0]);
}
ConstEval::Result ConstEval::round(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {
auto transform = [&](const sem::Constant* c0) {
auto create = [&](auto e) {
using NumberT = decltype(e);
using T = UnwrapNumber<NumberT>;
auto integral = NumberT(0);
auto fract = std::abs(std::modf(e.value, &(integral.value)));
// When e lies halfway between integers k and k + 1, the result is k when k is even,
// and k + 1 when k is odd.
NumberT result = NumberT(0.0);
if (fract == NumberT(0.5)) {
// If the integral value is negative, then we need to subtract one in order to move
// to the correct `k`. The half way check is `k` and `k + 1` which in the positive
// case is `x` and `x + 1` but in the negative case is `x - 1` and `x`.
T integral_val = integral.value;
if (std::signbit(integral_val)) {
integral_val = std::abs(integral_val - 1);
}
if (uint64_t(integral_val) % 2 == 0) {
result = NumberT(std::floor(e.value));
} else {
result = NumberT(std::ceil(e.value));
}
} else {
result = NumberT(std::round(e.value));
}
return CreateElement(builder, c0->Type(), result);
};
return Dispatch_fa_f32_f16(create, c0);
};
return TransformElements(builder, ty, transform, args[0]);
}
ConstEval::Result ConstEval::saturate(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {

View File

@@ -656,6 +656,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// round builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result round(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// saturate builtin
/// @param ty the expression type
/// @param args the input arguments

View File

@@ -1477,6 +1477,36 @@ INSTANTIATE_TEST_SUITE_P( //
testing::ValuesIn(Concat(ReverseBitsCases<i32>(), //
ReverseBitsCases<u32>()))));
template <typename T>
std::vector<Case> RoundCases() {
std::vector<Case> cases = {
C({T(0.0)}, T(0.0)), //
C({-T(0.0)}, -T(0.0)), //
C({T(1.5)}, T(2.0)), //
C({T(2.5)}, T(2.0)), //
C({T(2.4)}, T(2.0)), //
C({T(2.6)}, T(3.0)), //
C({T(1.49999)}, T(1.0)), //
C({T(1.50001)}, T(2.0)), //
C({-T(1.5)}, -T(2.0)), //
C({-T(2.5)}, -T(2.0)), //
C({-T(2.6)}, -T(3.0)), //
C({-T(2.4)}, -T(2.0)), //
// Vector tests
C({Vec(T(0.0), T(1.5), T(2.5))}, Vec(T(0.0), T(2.0), T(2.0))),
};
return cases;
}
INSTANTIATE_TEST_SUITE_P( //
Round,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kRound),
testing::ValuesIn(Concat(RoundCases<AFloat>(), //
RoundCases<f32>(),
RoundCases<f16>()))));
template <typename T>
std::vector<Case> SaturateCases() {
return {

View File

@@ -12993,24 +12993,24 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[895],
/* return matcher indices */ &kMatcherIndices[3],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::round,
},
{
/* [389] */
/* num parameters */ 1,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[26],
/* template types */ &kTemplateTypes[23],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[896],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::round,
},
{
/* [390] */
@@ -14461,8 +14461,8 @@ constexpr IntrinsicInfo kBuiltins[] = {
},
{
/* [66] */
/* fn round<T : f32_f16>(T) -> T */
/* fn round<N : num, T : f32_f16>(vec<N, T>) -> vec<N, T> */
/* fn round<T : fa_f32_f16>(@test_value(3.4) T) -> T */
/* fn round<N : num, T : fa_f32_f16>(@test_value(3.4) vec<N, T>) -> vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[388],
},