diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc index b7674dcd75..719582458e 100644 --- a/src/tint/resolver/const_eval.cc +++ b/src/tint/resolver/const_eval.cc @@ -779,24 +779,22 @@ utils::Result ConstEval::Div(const Source& source, NumberT a, NumberT b } } else { using T = UnwrapNumber; - auto divide_values = [](T lhs, T rhs) { - if constexpr (std::is_integral_v) { - // For integers, lhs / 0 returns lhs - if (rhs == 0) { - return lhs; - } - - if constexpr (std::is_signed_v) { - // For signed integers, for lhs / -1, return lhs if lhs is the - // most negative value - if (rhs == -1 && lhs == std::numeric_limits::min()) { - return lhs; - } - } + auto lhs = a.value; + auto rhs = b.value; + if (rhs == 0) { + // For integers (as for floats), lhs / 0 is an error + AddError(OverflowErrorMessage(a, "/", b), source); + return utils::Failure; + } + if constexpr (std::is_signed_v) { + // For signed integers, lhs / -1 where lhs is the + // most negative value is an error + if (rhs == -1 && lhs == std::numeric_limits::min()) { + AddError(OverflowErrorMessage(a, "/", b), source); + return utils::Failure; } - return lhs / rhs; - }; - result = divide_values(a.value, b.value); + } + result = lhs / rhs; } return result; } diff --git a/src/tint/resolver/const_eval_binary_op_test.cc b/src/tint/resolver/const_eval_binary_op_test.cc index a35214a86c..f621c61f65 100644 --- a/src/tint/resolver/const_eval_binary_op_test.cc +++ b/src/tint/resolver/const_eval_binary_op_test.cc @@ -419,36 +419,31 @@ INSTANTIATE_TEST_SUITE_P(Mul, template std::vector OpDivIntCases() { - std::vector r = { - C(Val(T{0}), Val(T{1}), Val(T{0})), - C(Val(T{1}), Val(T{1}), Val(T{1})), - C(Val(T{1}), Val(T{1}), Val(T{1})), - C(Val(T{2}), Val(T{1}), Val(T{2})), - C(Val(T{4}), Val(T{2}), Val(T{2})), - C(Val(T::Highest()), Val(T{1}), Val(T::Highest())), - C(Val(T::Lowest()), Val(T{1}), Val(T::Lowest())), - C(Val(T::Highest()), Val(T::Highest()), Val(T{1})), - C(Val(T{0}), Val(T::Highest()), Val(T{0})), - C(Val(T{0}), Val(T::Lowest()), Val(T{0})), - }; - ConcatIntoIf && IsIntegral>( // - r, std::vector{ - // e1, when e2 is zero. - C(T{123}, T{0}, T{123}), - }); - ConcatIntoIf && IsSignedIntegral>( // - r, std::vector{ - // e1, when e1 is the most negative value in T, and e2 is -1. - C(T::Smallest(), T{-1}, T::Smallest()), - }); - auto error_msg = [](auto a, auto b) { return "12:34 error: " + OverflowErrorMessage(a, "/", b); }; - ConcatIntoIf>( // + + std::vector r = { + C(T{0}, T{1}, T{0}), + C(T{1}, T{1}, T{1}), + C(T{1}, T{1}, T{1}), + C(T{2}, T{1}, T{2}), + C(T{4}, T{2}, T{2}), + C(T::Highest(), T{1}, T::Highest()), + C(T::Lowest(), T{1}, T::Lowest()), + C(T::Highest(), T::Highest(), T{1}), + C(T{0}, T::Highest(), T{0}), + + // Divide by zero + E(T{123}, T{0}, error_msg(T{123}, T{0})), + E(T::Highest(), T{0}, error_msg(T::Highest(), T{0})), + E(T::Lowest(), T{0}, error_msg(T::Lowest(), T{0})), + }; + + // Error on most negative divided by -1 + ConcatIntoIf>( // r, std::vector{ - // Most negative value divided by -1 - E(AInt::Lowest(), -1_a, error_msg(AInt::Lowest(), -1_a)), + E(T::Lowest(), T{-1}, error_msg(T::Lowest(), T{-1})), }); return r; }