mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-21 18:59:21 +00:00
Add const-eval for round
This CL adds const-eval for `round`. Bug: tint:1581 Change-Id: I16ebb2010969debb8de50479235d9d9f5433c0a1 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110175 Reviewed-by: Antonio Maiorano <amaiorano@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
committed by
Dawn LUCI CQ
parent
c214cbe98b
commit
19e5042ade
@@ -2270,6 +2270,42 @@ ConstEval::Result ConstEval::reverseBits(const sem::Type* ty,
|
||||
return TransformElements(builder, ty, transform, args[0]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::round(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source&) {
|
||||
auto transform = [&](const sem::Constant* c0) {
|
||||
auto create = [&](auto e) {
|
||||
using NumberT = decltype(e);
|
||||
using T = UnwrapNumber<NumberT>;
|
||||
|
||||
auto integral = NumberT(0);
|
||||
auto fract = std::abs(std::modf(e.value, &(integral.value)));
|
||||
// When e lies halfway between integers k and k + 1, the result is k when k is even,
|
||||
// and k + 1 when k is odd.
|
||||
NumberT result = NumberT(0.0);
|
||||
if (fract == NumberT(0.5)) {
|
||||
// If the integral value is negative, then we need to subtract one in order to move
|
||||
// to the correct `k`. The half way check is `k` and `k + 1` which in the positive
|
||||
// case is `x` and `x + 1` but in the negative case is `x - 1` and `x`.
|
||||
T integral_val = integral.value;
|
||||
if (std::signbit(integral_val)) {
|
||||
integral_val = std::abs(integral_val - 1);
|
||||
}
|
||||
if (uint64_t(integral_val) % 2 == 0) {
|
||||
result = NumberT(std::floor(e.value));
|
||||
} else {
|
||||
result = NumberT(std::ceil(e.value));
|
||||
}
|
||||
} else {
|
||||
result = NumberT(std::round(e.value));
|
||||
}
|
||||
return CreateElement(builder, c0->Type(), result);
|
||||
};
|
||||
return Dispatch_fa_f32_f16(create, c0);
|
||||
};
|
||||
return TransformElements(builder, ty, transform, args[0]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::saturate(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source&) {
|
||||
|
||||
@@ -656,6 +656,15 @@ class ConstEval {
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// round builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result round(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// saturate builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
|
||||
@@ -1477,6 +1477,36 @@ INSTANTIATE_TEST_SUITE_P( //
|
||||
testing::ValuesIn(Concat(ReverseBitsCases<i32>(), //
|
||||
ReverseBitsCases<u32>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> RoundCases() {
|
||||
std::vector<Case> cases = {
|
||||
C({T(0.0)}, T(0.0)), //
|
||||
C({-T(0.0)}, -T(0.0)), //
|
||||
C({T(1.5)}, T(2.0)), //
|
||||
C({T(2.5)}, T(2.0)), //
|
||||
C({T(2.4)}, T(2.0)), //
|
||||
C({T(2.6)}, T(3.0)), //
|
||||
C({T(1.49999)}, T(1.0)), //
|
||||
C({T(1.50001)}, T(2.0)), //
|
||||
C({-T(1.5)}, -T(2.0)), //
|
||||
C({-T(2.5)}, -T(2.0)), //
|
||||
C({-T(2.6)}, -T(3.0)), //
|
||||
C({-T(2.4)}, -T(2.0)), //
|
||||
|
||||
// Vector tests
|
||||
C({Vec(T(0.0), T(1.5), T(2.5))}, Vec(T(0.0), T(2.0), T(2.0))),
|
||||
};
|
||||
|
||||
return cases;
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
Round,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kRound),
|
||||
testing::ValuesIn(Concat(RoundCases<AFloat>(), //
|
||||
RoundCases<f32>(),
|
||||
RoundCases<f16>()))));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> SaturateCases() {
|
||||
return {
|
||||
|
||||
@@ -12993,24 +12993,24 @@ constexpr OverloadInfo kOverloads[] = {
|
||||
/* num parameters */ 1,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 0,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[10],
|
||||
/* parameters */ &kParameters[895],
|
||||
/* return matcher indices */ &kMatcherIndices[3],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::round,
|
||||
},
|
||||
{
|
||||
/* [389] */
|
||||
/* num parameters */ 1,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[26],
|
||||
/* template types */ &kTemplateTypes[23],
|
||||
/* template numbers */ &kTemplateNumbers[4],
|
||||
/* parameters */ &kParameters[896],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::round,
|
||||
},
|
||||
{
|
||||
/* [390] */
|
||||
@@ -14461,8 +14461,8 @@ constexpr IntrinsicInfo kBuiltins[] = {
|
||||
},
|
||||
{
|
||||
/* [66] */
|
||||
/* fn round<T : f32_f16>(T) -> T */
|
||||
/* fn round<N : num, T : f32_f16>(vec<N, T>) -> vec<N, T> */
|
||||
/* fn round<T : fa_f32_f16>(@test_value(3.4) T) -> T */
|
||||
/* fn round<N : num, T : fa_f32_f16>(@test_value(3.4) vec<N, T>) -> vec<N, T> */
|
||||
/* num overloads */ 2,
|
||||
/* overloads */ &kOverloads[388],
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user