From babc21f9316878f04a8968bef870453dabd8ee3c Mon Sep 17 00:00:00 2001 From: Antonio Maiorano Date: Wed, 30 Nov 2022 16:02:49 +0000 Subject: [PATCH] tint: const eval of modulo operator Bug: tint:1581 Change-Id: Icf9aaab29f45a41e6c367f60bf98ccc8958a56c7 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/112322 Commit-Queue: Antonio Maiorano Reviewed-by: Ben Clayton Kokoro: Kokoro --- src/tint/intrinsics.def | 8 +- src/tint/resolver/const_eval.cc | 53 +++++- src/tint/resolver/const_eval.h | 24 +++ .../resolver/const_eval_binary_op_test.cc | 155 ++++++++++++++++++ src/tint/resolver/intrinsic_table.inl | 24 +-- 5 files changed, 246 insertions(+), 18 deletions(-) diff --git a/src/tint/intrinsics.def b/src/tint/intrinsics.def index fb75c979fb..6933e57bf7 100644 --- a/src/tint/intrinsics.def +++ b/src/tint/intrinsics.def @@ -932,10 +932,10 @@ init mat4x4(vec4, vec4, vec4, vec4) -> mat4x4 @const op / (vec, T) -> vec @const op / (T, vec) -> vec -op % (T, T) -> T -op % (vec, vec) -> vec -op % (vec, T) -> vec -op % (T, vec) -> vec +@const op % (T, T) -> T +@const op % (vec, vec) -> vec +@const op % (vec, T) -> vec +@const op % (T, vec) -> vec @const op ^ (T, T) -> T @const op ^ (vec, vec) -> vec diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc index 719582458e..8170adbe08 100644 --- a/src/tint/resolver/const_eval.cc +++ b/src/tint/resolver/const_eval.cc @@ -744,7 +744,6 @@ utils::Result ConstEval::Mul(const Source& source, NumberT a, NumberT b using T = UnwrapNumber; NumberT result; if constexpr (IsAbstract || IsFloatingPoint) { - // Check for over/underflow for abstract values if (auto r = CheckedMul(a, b)) { result = r->value; } else { @@ -770,7 +769,6 @@ template utils::Result ConstEval::Div(const Source& source, NumberT a, NumberT b) { NumberT result; if constexpr (IsAbstract || IsFloatingPoint) { - // Check for over/underflow for abstract values if (auto r = CheckedDiv(a, b)) { result = r->value; } else { @@ -799,6 +797,38 @@ utils::Result ConstEval::Div(const Source& source, NumberT a, NumberT b return result; } +template +utils::Result ConstEval::Mod(const Source& source, NumberT a, NumberT b) { + NumberT result; + if constexpr (IsAbstract || IsFloatingPoint) { + if (auto r = CheckedMod(a, b)) { + result = r->value; + } else { + AddError(OverflowErrorMessage(a, "%", b), source); + return utils::Failure; + } + } else { + using T = UnwrapNumber; + auto lhs = a.value; + auto rhs = b.value; + if (rhs == 0) { + // lhs % 0 is an error + AddError(OverflowErrorMessage(a, "%", b), source); + return utils::Failure; + } + if constexpr (std::is_signed_v) { + // For signed integers, lhs % -1 where lhs is the + // most negative value is an error + if (rhs == -1 && lhs == std::numeric_limits::min()) { + AddError(OverflowErrorMessage(a, "%", b), source); + return utils::Failure; + } + } + result = lhs % rhs; + } + return result; +} + template utils::Result ConstEval::Dot2(const Source& source, NumberT a1, @@ -1111,6 +1141,15 @@ auto ConstEval::DivFunc(const Source& source, const sem::Type* elem_ty) { }; } +auto ConstEval::ModFunc(const Source& source, const sem::Type* elem_ty) { + return [=](auto a1, auto a2) -> ImplResult { + if (auto r = Mod(source, a1, a2)) { + return CreateElement(builder, source, elem_ty, r.Get()); + } + return utils::Failure; + }; +} + auto ConstEval::Dot2Func(const Source& source, const sem::Type* elem_ty) { return [=](auto a1, auto a2, auto b1, auto b2) -> ImplResult { if (auto r = Dot2(source, a1, a2, b1, b2)) { @@ -1683,6 +1722,16 @@ ConstEval::Result ConstEval::OpDivide(const sem::Type* ty, return TransformBinaryElements(builder, ty, transform, args[0], args[1]); } +ConstEval::Result ConstEval::OpModulo(const sem::Type* ty, + utils::VectorRef args, + const Source& source) { + auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { + return Dispatch_fia_fiu32_f16(ModFunc(source, c0->Type()), c0, c1); + }; + + return TransformBinaryElements(builder, ty, transform, args[0], args[1]); +} + ConstEval::Result ConstEval::OpEqual(const sem::Type* ty, utils::VectorRef args, const Source& source) { diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h index d92dc3fedc..9905036da0 100644 --- a/src/tint/resolver/const_eval.h +++ b/src/tint/resolver/const_eval.h @@ -283,6 +283,15 @@ class ConstEval { utils::VectorRef args, const Source& source); + /// Modulo operator '%' + /// @param ty the expression type + /// @param args the input arguments + /// @param source the source location + /// @return the result value, or null if the value cannot be calculated + Result OpModulo(const sem::Type* ty, + utils::VectorRef args, + const Source& source); + /// Equality operator '==' /// @param ty the expression type /// @param args the input arguments @@ -1040,6 +1049,14 @@ class ConstEval { template utils::Result Div(const Source& source, NumberT a, NumberT b); + /// Returns the (signed) remainder of the division of two Numbers + /// @param source the source location + /// @param a the lhs number + /// @param b the rhs number + /// @returns the result number on success, or logs an error and returns Failure + template + utils::Result Mod(const Source& source, NumberT a, NumberT b); + /// Returns the dot product of (a1,a2) with (b1,b2) /// @param source the source location /// @param a1 component 1 of lhs vector @@ -1216,6 +1233,13 @@ class ConstEval { /// @returns the callable function auto DivFunc(const Source& source, const sem::Type* elem_ty); + /// Returns a callable that calls Mod, and creates a Constant with its result of type `elem_ty` + /// if successful, or returns Failure otherwise. + /// @param source the source location + /// @param elem_ty the element type of the Constant to create on success + /// @returns the callable function + auto ModFunc(const Source& source, const sem::Type* elem_ty); + /// Returns a callable that calls Dot2, and creates a Constant with its result of type `elem_ty` /// if successful, or returns Failure otherwise. /// @param source the source location diff --git a/src/tint/resolver/const_eval_binary_op_test.cc b/src/tint/resolver/const_eval_binary_op_test.cc index 2a13360202..2590466fbf 100644 --- a/src/tint/resolver/const_eval_binary_op_test.cc +++ b/src/tint/resolver/const_eval_binary_op_test.cc @@ -485,6 +485,161 @@ INSTANTIATE_TEST_SUITE_P(Div, OpDivFloatCases(), OpDivFloatCases())))); +template +std::vector OpModCases() { + auto error_msg = [](auto a, auto b) { + return "12:34 error: " + OverflowErrorMessage(a, "%", b); + }; + + // Common cases for all types + std::vector r = { + C(T{0}, T{1}, T{0}), // + C(T{1}, T{1}, T{0}), // + C(T{10}, T{1}, T{0}), // + C(T{10}, T{2}, T{0}), // + C(T{10}, T{3}, T{1}), // + C(T{10}, T{4}, T{2}), // + C(T{10}, T{5}, T{0}), // + C(T{10}, T{6}, T{4}), // + C(T{10}, T{5}, T{0}), // + C(T{10}, T{8}, T{2}), // + C(T{10}, T{9}, T{1}), // + C(T{10}, T{10}, T{0}), // + + // Error on divide by zero + E(T{123}, T{0}, error_msg(T{123}, T{0})), + E(T::Highest(), T{0}, error_msg(T::Highest(), T{0})), + E(T::Lowest(), T{0}, error_msg(T::Lowest(), T{0})), + }; + + if constexpr (IsIntegral) { + ConcatInto( // + r, std::vector{ + C(T::Highest(), T{T::Highest() - T{1}}, T{1}), + }); + } + + if constexpr (IsSignedIntegral) { + ConcatInto( // + r, std::vector{ + C(T::Lowest(), T{T::Lowest() + T{1}}, -T(1)), + + // Error on most negative integer divided by -1 + E(T::Lowest(), T{-1}, error_msg(T::Lowest(), T{-1})), + }); + } + + // Negative values (both signed integrals and floating point) + if constexpr (IsSignedIntegral || IsFloatingPoint) { + ConcatInto( // + r, std::vector{ + C(-T{1}, T{1}, T{0}), // + + // lhs negative, rhs positive + C(-T{10}, T{1}, T{0}), // + C(-T{10}, T{2}, T{0}), // + C(-T{10}, T{3}, -T{1}), // + C(-T{10}, T{4}, -T{2}), // + C(-T{10}, T{5}, T{0}), // + C(-T{10}, T{6}, -T{4}), // + C(-T{10}, T{5}, T{0}), // + C(-T{10}, T{8}, -T{2}), // + C(-T{10}, T{9}, -T{1}), // + C(-T{10}, T{10}, T{0}), // + + // lhs positive, rhs negative + C(T{10}, -T{1}, T{0}), // + C(T{10}, -T{2}, T{0}), // + C(T{10}, -T{3}, T{1}), // + C(T{10}, -T{4}, T{2}), // + C(T{10}, -T{5}, T{0}), // + C(T{10}, -T{6}, T{4}), // + C(T{10}, -T{5}, T{0}), // + C(T{10}, -T{8}, T{2}), // + C(T{10}, -T{9}, T{1}), // + C(T{10}, -T{10}, T{0}), // + + // lhs negative, rhs negative + C(-T{10}, -T{1}, T{0}), // + C(-T{10}, -T{2}, T{0}), // + C(-T{10}, -T{3}, -T{1}), // + C(-T{10}, -T{4}, -T{2}), // + C(-T{10}, -T{5}, T{0}), // + C(-T{10}, -T{6}, -T{4}), // + C(-T{10}, -T{5}, T{0}), // + C(-T{10}, -T{8}, -T{2}), // + C(-T{10}, -T{9}, -T{1}), // + C(-T{10}, -T{10}, T{0}), // + }); + } + + // Float values + if constexpr (IsFloatingPoint) { + ConcatInto( // + r, std::vector{ + C(T{10.5}, T{1}, T{0.5}), // + C(T{10.5}, T{2}, T{0.5}), // + C(T{10.5}, T{3}, T{1.5}), // + C(T{10.5}, T{4}, T{2.5}), // + C(T{10.5}, T{5}, T{0.5}), // + C(T{10.5}, T{6}, T{4.5}), // + C(T{10.5}, T{5}, T{0.5}), // + C(T{10.5}, T{8}, T{2.5}), // + C(T{10.5}, T{9}, T{1.5}), // + C(T{10.5}, T{10}, T{0.5}), // + + // lhs negative, rhs positive + C(-T{10.5}, T{1}, -T{0.5}), // + C(-T{10.5}, T{2}, -T{0.5}), // + C(-T{10.5}, T{3}, -T{1.5}), // + C(-T{10.5}, T{4}, -T{2.5}), // + C(-T{10.5}, T{5}, -T{0.5}), // + C(-T{10.5}, T{6}, -T{4.5}), // + C(-T{10.5}, T{5}, -T{0.5}), // + C(-T{10.5}, T{8}, -T{2.5}), // + C(-T{10.5}, T{9}, -T{1.5}), // + C(-T{10.5}, T{10}, -T{0.5}), // + + // lhs positive, rhs negative + C(T{10.5}, -T{1}, T{0.5}), // + C(T{10.5}, -T{2}, T{0.5}), // + C(T{10.5}, -T{3}, T{1.5}), // + C(T{10.5}, -T{4}, T{2.5}), // + C(T{10.5}, -T{5}, T{0.5}), // + C(T{10.5}, -T{6}, T{4.5}), // + C(T{10.5}, -T{5}, T{0.5}), // + C(T{10.5}, -T{8}, T{2.5}), // + C(T{10.5}, -T{9}, T{1.5}), // + C(T{10.5}, -T{10}, T{0.5}), // + + // lhs negative, rhs negative + C(-T{10.5}, -T{1}, -T{0.5}), // + C(-T{10.5}, -T{2}, -T{0.5}), // + C(-T{10.5}, -T{3}, -T{1.5}), // + C(-T{10.5}, -T{4}, -T{2.5}), // + C(-T{10.5}, -T{5}, -T{0.5}), // + C(-T{10.5}, -T{6}, -T{4.5}), // + C(-T{10.5}, -T{5}, -T{0.5}), // + C(-T{10.5}, -T{8}, -T{2.5}), // + C(-T{10.5}, -T{9}, -T{1.5}), // + C(-T{10.5}, -T{10}, -T{0.5}), // + }); + } + + return r; +} +INSTANTIATE_TEST_SUITE_P(Mod, + ResolverConstEvalBinaryOpTest, + testing::Combine( // + testing::Values(ast::BinaryOp::kModulo), + testing::ValuesIn(Concat( // + OpModCases(), + OpModCases(), + OpModCases(), + OpModCases(), + OpModCases(), + OpModCases())))); + template std::vector OpEqualCases() { return { diff --git a/src/tint/resolver/intrinsic_table.inl b/src/tint/resolver/intrinsic_table.inl index d0752daf93..89611da4e4 100644 --- a/src/tint/resolver/intrinsic_table.inl +++ b/src/tint/resolver/intrinsic_table.inl @@ -11322,48 +11322,48 @@ constexpr OverloadInfo kOverloads[] = { /* num parameters */ 2, /* num template types */ 1, /* num template numbers */ 0, - /* template types */ &kTemplateTypes[30], + /* template types */ &kTemplateTypes[22], /* template numbers */ &kTemplateNumbers[10], /* parameters */ &kParameters[718], /* return matcher indices */ &kMatcherIndices[3], /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), - /* const eval */ nullptr, + /* const eval */ &ConstEval::OpModulo, }, { /* [251] */ /* num parameters */ 2, /* num template types */ 1, /* num template numbers */ 1, - /* template types */ &kTemplateTypes[30], + /* template types */ &kTemplateTypes[22], /* template numbers */ &kTemplateNumbers[4], /* parameters */ &kParameters[720], /* return matcher indices */ &kMatcherIndices[30], /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), - /* const eval */ nullptr, + /* const eval */ &ConstEval::OpModulo, }, { /* [252] */ /* num parameters */ 2, /* num template types */ 1, /* num template numbers */ 1, - /* template types */ &kTemplateTypes[30], + /* template types */ &kTemplateTypes[22], /* template numbers */ &kTemplateNumbers[4], /* parameters */ &kParameters[722], /* return matcher indices */ &kMatcherIndices[30], /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), - /* const eval */ nullptr, + /* const eval */ &ConstEval::OpModulo, }, { /* [253] */ /* num parameters */ 2, /* num template types */ 1, /* num template numbers */ 1, - /* template types */ &kTemplateTypes[30], + /* template types */ &kTemplateTypes[22], /* template numbers */ &kTemplateNumbers[4], /* parameters */ &kParameters[724], /* return matcher indices */ &kMatcherIndices[30], /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), - /* const eval */ nullptr, + /* const eval */ &ConstEval::OpModulo, }, { /* [254] */ @@ -14929,10 +14929,10 @@ constexpr IntrinsicInfo kBinaryOperators[] = { }, { /* [4] */ - /* op %(T, T) -> T */ - /* op %(vec, vec) -> vec */ - /* op %(vec, T) -> vec */ - /* op %(T, vec) -> vec */ + /* op %(T, T) -> T */ + /* op %(vec, vec) -> vec */ + /* op %(vec, T) -> vec */ + /* op %(T, vec) -> vec */ /* num overloads */ 4, /* overloads */ &kOverloads[250], },