diff --git a/src/tint/resolver/const_eval_binary_op_test.cc b/src/tint/resolver/const_eval_binary_op_test.cc index 8708aa8b2f..35c5c2ea10 100644 --- a/src/tint/resolver/const_eval_binary_op_test.cc +++ b/src/tint/resolver/const_eval_binary_op_test.cc @@ -14,6 +14,8 @@ #include "src/tint/resolver/const_eval_test.h" +#include "src/tint/utils/result.h" + using namespace tint::number_suffixes; // NOLINT using ::testing::HasSubstr; @@ -24,10 +26,16 @@ namespace { using resolver::operator<<; struct Case { + struct Success { + Types value; + }; + struct Failure { + std::string error; + }; + Types lhs; Types rhs; - Types expected; - bool overflow; + utils::Result expected; }; struct ErrorCase { @@ -37,20 +45,37 @@ struct ErrorCase { /// Creates a Case with Values of any type template -Case C(Value lhs, Value rhs, Value expected, bool overflow = false) { - return Case{std::move(lhs), std::move(rhs), std::move(expected), overflow}; +Case C(Value lhs, Value rhs, Value expected) { + return Case{std::move(lhs), std::move(rhs), Case::Success{std::move(expected)}}; } /// Convenience overload that creates a Case with just scalars template >> -Case C(T lhs, U rhs, V expected, bool overflow = false) { - return Case{Val(lhs), Val(rhs), Val(expected), overflow}; +Case C(T lhs, U rhs, V expected) { + return Case{Val(lhs), Val(rhs), Case::Success{Val(expected)}}; +} + +/// Creates an failure Case with Values of any type +template +Case E(Value lhs, Value rhs, std::string error) { + return Case{std::move(lhs), std::move(rhs), Case::Failure{std::move(error)}}; +} + +/// Convenience overload that creates an error Case with just scalars +template >> +Case E(T lhs, U rhs, std::string error) { + return Case{Val(lhs), Val(rhs), Case::Failure{std::move(error)}}; } /// Prints Case to ostream static std::ostream& operator<<(std::ostream& o, const Case& c) { - o << "lhs: " << c.lhs << ", rhs: " << c.rhs << ", expected: " << c.expected - << ", overflow: " << c.overflow; + o << "lhs: " << c.lhs << ", rhs: " << c.rhs << ", expected: "; + if (c.expected) { + auto s = c.expected.Get(); + o << s.value; + } else { + o << "[ERROR: " << c.expected.Failure().error << "]"; + } return o; } @@ -66,38 +91,37 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) { auto op = std::get<0>(GetParam()); auto& c = std::get<1>(GetParam()); - auto* expected = ToValueBase(c.expected); - if (expected->IsAbstract() && c.overflow) { - // Overflow is not allowed for abstract types. This is tested separately. - return; - } - - auto* lhs = ToValueBase(c.lhs); - auto* rhs = ToValueBase(c.rhs); - - auto* lhs_expr = lhs->Expr(*this); - auto* rhs_expr = rhs->Expr(*this); - auto* expr = create(op, lhs_expr, rhs_expr); + auto* lhs_expr = ToValueBase(c.lhs)->Expr(*this); + auto* rhs_expr = ToValueBase(c.rhs)->Expr(*this); + auto* expr = create(Source{{12, 34}}, op, lhs_expr, rhs_expr); GlobalConst("C", expr); - ASSERT_TRUE(r()->Resolve()) << r()->error(); - auto* sem = Sem().Get(expr); - const sem::Constant* value = sem->ConstantValue(); - ASSERT_NE(value, nullptr); - EXPECT_TYPE(value->Type(), sem->Type()); + if (c.expected) { + ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto expected_case = c.expected.Get(); + auto* expected = ToValueBase(expected_case.value); - auto values_flat = ScalarArgsFrom(value); - auto expected_values_flat = expected->Args(); - ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length()); - for (size_t i = 0; i < values_flat.values.Length(); ++i) { - auto& a = values_flat.values[i]; - auto& b = expected_values_flat.values[i]; - EXPECT_EQ(a, b); - if (expected->IsIntegral()) { - // Check that the constant's integer doesn't contain unexpected - // data in the MSBs that are outside of the bit-width of T. - EXPECT_EQ(builder::As(a), builder::As(b)); + auto* sem = Sem().Get(expr); + const sem::Constant* value = sem->ConstantValue(); + ASSERT_NE(value, nullptr); + EXPECT_TYPE(value->Type(), sem->Type()); + + auto values_flat = ScalarArgsFrom(value); + auto expected_values_flat = expected->Args(); + ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length()); + for (size_t i = 0; i < values_flat.values.Length(); ++i) { + auto& a = values_flat.values[i]; + auto& b = expected_values_flat.values[i]; + EXPECT_EQ(a, b); + if (expected->IsIntegral()) { + // Check that the constant's integer doesn't contain unexpected + // data in the MSBs that are outside of the bit-width of T. + EXPECT_EQ(builder::As(a), builder::As(b)); + } } + } else { + ASSERT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), c.expected.Failure().error); } } @@ -113,28 +137,53 @@ INSTANTIATE_TEST_SUITE_P(MixedAbstractArgs, template std::vector OpAddIntCases() { static_assert(IsIntegral); - return { + auto r = std::vector{ C(T{0}, T{0}, T{0}), C(T{1}, T{2}, T{3}), C(T::Lowest(), T{1}, T{T::Lowest() + 1}), C(T::Highest(), Negate(T{1}), T{T::Highest() - 1}), C(T::Lowest(), T::Highest(), Negate(T{1})), - C(T::Highest(), T{1}, T::Lowest(), true), - C(T::Lowest(), Negate(T{1}), T::Highest(), true), }; + ConcatIntoIf>( // + r, std::vector{ + C(T::Highest(), T{1}, T::Lowest()), + C(T::Lowest(), Negate(T{1}), T::Highest()), + }); + + auto error_msg = [](auto a, auto b) { + return "12:34 error: " + OverflowErrorMessage(a, "+", b); + }; + ConcatIntoIf>( // + r, std::vector{ + E(T::Highest(), T{1}, error_msg(T::Highest(), T{1})), + E(T::Lowest(), Negate(T{1}), error_msg(T::Lowest(), Negate(T{1}))), + }); + return r; } template std::vector OpAddFloatCases() { static_assert(IsFloatingPoint); - return { + auto r = std::vector{ C(T{0}, T{0}, T{0}), C(T{1}, T{2}, T{3}), C(T::Lowest(), T{1}, T{T::Lowest() + 1}), C(T::Highest(), Negate(T{1}), T{T::Highest() - 1}), C(T::Lowest(), T::Highest(), T{0}), - C(T::Highest(), T::Highest(), T::Inf(), true), - C(T::Lowest(), Negate(T::Highest()), -T::Inf(), true), }; + ConcatIntoIf>( // + r, std::vector{ + C(T::Highest(), T::Highest(), T::Inf()), + C(T::Lowest(), Negate(T::Highest()), -T::Inf()), + }); + auto error_msg = [](auto a, auto b) { + return "12:34 error: " + OverflowErrorMessage(a, "+", b); + }; + ConcatIntoIf>( // + r, std::vector{ + E(T::Highest(), T::Highest(), error_msg(T::Highest(), T::Highest())), + E(T::Lowest(), Negate(T::Highest()), error_msg(T::Lowest(), Negate(T::Highest()))), + }); + return r; } INSTANTIATE_TEST_SUITE_P(Add, ResolverConstEvalBinaryOpTest, @@ -150,28 +199,52 @@ INSTANTIATE_TEST_SUITE_P(Add, template std::vector OpSubIntCases() { static_assert(IsIntegral); - return { + auto r = std::vector{ C(T{0}, T{0}, T{0}), C(T{3}, T{2}, T{1}), C(T{T::Lowest() + 1}, T{1}, T::Lowest()), C(T{T::Highest() - 1}, Negate(T{1}), T::Highest()), C(Negate(T{1}), T::Highest(), T::Lowest()), - C(T::Lowest(), T{1}, T::Highest(), true), - C(T::Highest(), Negate(T{1}), T::Lowest(), true), }; + ConcatIntoIf>( // + r, std::vector{ + C(T::Lowest(), T{1}, T::Highest()), + C(T::Highest(), Negate(T{1}), T::Lowest()), + }); + auto error_msg = [](auto a, auto b) { + return "12:34 error: " + OverflowErrorMessage(a, "-", b); + }; + ConcatIntoIf>( // + r, std::vector{ + E(T::Lowest(), T{1}, error_msg(T::Lowest(), T{1})), + E(T::Highest(), Negate(T{1}), error_msg(T::Highest(), Negate(T{1}))), + }); + return r; } template std::vector OpSubFloatCases() { static_assert(IsFloatingPoint); - return { + auto r = std::vector{ C(T{0}, T{0}, T{0}), C(T{3}, T{2}, T{1}), C(T::Highest(), T{1}, T{T::Highest() - 1}), C(T::Lowest(), Negate(T{1}), T{T::Lowest() + 1}), C(T{0}, T::Highest(), T::Lowest()), - C(T::Highest(), Negate(T::Highest()), T::Inf(), true), - C(T::Lowest(), T::Highest(), -T::Inf(), true), }; + ConcatIntoIf>( // + r, std::vector{ + C(T::Highest(), Negate(T::Highest()), T::Inf()), + C(T::Lowest(), T::Highest(), -T::Inf()), + }); + auto error_msg = [](auto a, auto b) { + return "12:34 error: " + OverflowErrorMessage(a, "-", b); + }; + ConcatIntoIf>( // + r, std::vector{ + E(T::Highest(), Negate(T::Highest()), error_msg(T::Highest(), Negate(T::Highest()))), + E(T::Lowest(), T::Highest(), error_msg(T::Lowest(), T::Highest())), + }); + return r; } INSTANTIATE_TEST_SUITE_P(Sub, ResolverConstEvalBinaryOpTest, @@ -186,21 +259,35 @@ INSTANTIATE_TEST_SUITE_P(Sub, template std::vector OpMulScalarCases() { - return { + auto r = std::vector{ C(T{0}, T{0}, T{0}), C(T{1}, T{2}, T{2}), C(T{2}, T{3}, T{6}), C(Negate(T{2}), T{3}, Negate(T{6})), C(T::Highest(), T{1}, T::Highest()), C(T::Lowest(), T{1}, T::Lowest()), - C(T::Highest(), T::Highest(), Mul(T::Highest(), T::Highest()), true), - C(T::Lowest(), T::Lowest(), Mul(T::Lowest(), T::Lowest()), true), }; + ConcatIntoIf>( // + r, std::vector{ + C(T::Highest(), T::Highest(), Mul(T::Highest(), T::Highest())), + C(T::Lowest(), T::Lowest(), Mul(T::Lowest(), T::Lowest())), + }); + auto error_msg = [](auto a, auto b) { + return "12:34 error: " + OverflowErrorMessage(a, "*", b); + }; + ConcatIntoIf>( // + r, std::vector{ + E(T::Highest(), T::Highest(), error_msg(T::Highest(), T::Highest())), + E(T::Lowest(), T::Lowest(), error_msg(T::Lowest(), T::Lowest())), + E(T::Highest(), T{2}, error_msg(T::Highest(), T{2})), + E(T::Lowest(), Negate(T{2}), error_msg(T::Lowest(), Negate(T{2}))), + }); + return r; } template std::vector OpMulVecCases() { - return { + auto r = std::vector{ // s * vec3 = vec3 C(Val(T{2.0}), Vec(T{1.25}, T{2.25}, T{3.25}), Vec(T{2.5}, T{4.5}, T{6.5})), // vec3 * s = vec3 @@ -208,11 +295,20 @@ std::vector OpMulVecCases() { // vec3 * vec3 = vec3 C(Vec(T{1.25}, T{2.25}, T{3.25}), Vec(T{2.0}, T{2.0}, T{2.0}), Vec(T{2.5}, T{4.5}, T{6.5})), }; + auto error_msg = [](auto a, auto b) { + return "12:34 error: " + OverflowErrorMessage(a, "*", b); + }; + ConcatIntoIf>( // + r, std::vector{ + E(Val(T::Highest()), Vec(T{2}, T{1}), error_msg(T::Highest(), T{2})), + E(Val(T::Lowest()), Vec(Negate(T{2}), T{1}), error_msg(T::Lowest(), Negate(T{2}))), + }); + return r; } template std::vector OpMulMatCases() { - return { + auto r = std::vector{ // s * mat3x2 = mat3x2 C(Val(T{2.25}), Mat({T{1.0}, T{4.0}}, // @@ -248,6 +344,68 @@ std::vector OpMulMatCases() { Mat({T{24.25}, T{31.0}}, // {T{51.25}, T{67.0}})), // }; + auto error_msg = [](auto a, const char* op, auto b) { + return "12:34 error: " + OverflowErrorMessage(a, op, b); + }; + ConcatIntoIf>( // + r, std::vector{ + // vector-matrix multiply + + // Overflow from first multiplication of dot product of vector and matrix column 0 + // i.e. (v[0] * m[0][0] + v[1] * m[0][1]) + // ^ + E(Vec(T::Highest(), T{1.0}), // + Mat({T{2.0}, T{1.0}}, // + {T{1.0}, T{1.0}}), // + error_msg(T{2}, "*", T::Highest())), + + // Overflow from second multiplication of dot product of vector and matrix column 0 + // i.e. (v[0] * m[0][0] + v[1] * m[0][1]) + // ^ + E(Vec(T{1.0}, T::Highest()), // + Mat({T{1.0}, T{2.0}}, // + {T{1.0}, T{1.0}}), // + error_msg(T{2}, "*", T::Highest())), + + // Overflow from addition of dot product of vector and matrix column 0 + // i.e. (v[0] * m[0][0] + v[1] * m[0][1]) + // ^ + E(Vec(T::Highest(), T::Highest()), // + Mat({T{1.0}, T{1.0}}, // + {T{1.0}, T{1.0}}), // + error_msg(T::Highest(), "+", T::Highest())), + + // matrix-matrix multiply + + // Overflow from first multiplication of dot product of lhs row 0 and rhs column 0 + // i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0] + // ^ + E(Mat({T::Highest(), T{1.0}}, // + {T{1.0}, T{1.0}}), // + Mat({T{2.0}, T{1.0}}, // + {T{1.0}, T{1.0}}), // + error_msg(T::Highest(), "*", T{2.0})), + + // Overflow from second multiplication of dot product of lhs row 0 and rhs column 0 + // i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0] + // ^ + E(Mat({T{1.0}, T::Highest()}, // + {T{1.0}, T{1.0}}), // + Mat({T{1.0}, T{1.0}}, // + {T{2.0}, T{1.0}}), // + error_msg(T::Highest(), "*", T{2.0})), + + // Overflow from addition of dot product of lhs row 0 and rhs column 0 + // i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0] + // ^ + E(Mat({T::Highest(), T{1.0}}, // + {T::Highest(), T{1.0}}), // + Mat({T{1.0}, T{1.0}}, // + {T{1.0}, T{1.0}}), // + error_msg(T::Highest(), "+", T::Highest())), + }); + + return r; } INSTANTIATE_TEST_SUITE_P(Mul, @@ -285,22 +443,22 @@ std::vector OpDivIntCases() { C(Val(T{0}), Val(T::Highest()), Val(T{0})), C(Val(T{0}), Val(T::Lowest()), Val(T{0})), }; - ConcatIntoIf>( // + ConcatIntoIf && IsIntegral>( // r, std::vector{ // e1, when e2 is zero. - C(T{123}, T{0}, T{123}, true), + C(T{123}, T{0}, T{123}), }); - ConcatIntoIf>( // + ConcatIntoIf && IsSignedIntegral>( // r, std::vector{ // e1, when e1 is the most negative value in T, and e2 is -1. - C(T::Smallest(), T{-1}, T::Smallest(), true), + C(T::Smallest(), T{-1}, T::Smallest()), }); return r; } template std::vector OpDivFloatCases() { - return { + std::vector r = { C(Val(T{0}), Val(T{1}), Val(T{0})), C(Val(T{1}), Val(T{1}), Val(T{1})), C(Val(T{1}), Val(T{1}), Val(T{1})), @@ -311,11 +469,33 @@ std::vector OpDivFloatCases() { C(Val(T::Highest()), Val(T::Highest()), Val(T{1})), C(Val(T{0}), Val(T::Highest()), Val(T{0})), C(Val(T{0}), Val(T::Lowest()), Val(-T{0})), - C(T{123}, T{0}, T::Inf(), true), - C(T{-123}, -T{0}, T::Inf(), true), - C(T{-123}, T{0}, -T::Inf(), true), - C(T{123}, -T{0}, -T::Inf(), true), }; + ConcatIntoIf>( // + r, std::vector{ + C(T{123}, T{0}, T::Inf()), + C(T{-123}, -T{0}, T::Inf()), + C(T{-123}, T{0}, -T::Inf()), + C(T{123}, -T{0}, -T::Inf()), + }); + auto error_msg = [](auto a, auto b) { + return "12:34 error: " + OverflowErrorMessage(a, "/", b); + }; + ConcatIntoIf>( // + r, std::vector{ + // Divide by zero + E(T{123}, T{0}, error_msg(T{123}, T{0})), + E(Negate(T{123}), Negate(T{0}), error_msg(Negate(T{123}), Negate(T{0}))), + E(Negate(T{123}), T{0}, error_msg(Negate(T{123}), T{0})), + E(T{123}, Negate(T{0}), error_msg(T{123}, Negate(T{0}))), + }); + + ConcatIntoIf>( // + r, std::vector{ + // Most negative value divided by -1 + E(AInt::Lowest(), -1_a, error_msg(AInt::Lowest(), -1_a)), + }); + + return r; } INSTANTIATE_TEST_SUITE_P(Div, ResolverConstEvalBinaryOpTest, @@ -653,8 +833,8 @@ std::vector ShiftLeftCases() { C(Negate(T{0}), ST{u32::Highest()}, Negate(T{0})), // }); - // Cases that are fine for signed values (no sign change), but would overflow unsigned values. - // See ResolverConstEvalBinaryOpTest_Overflow for negative tests. + // Cases that are fine for signed values (no sign change), but would overflow + // unsigned values. See below for negative tests. ConcatIntoIf>( // r, std::vector{ C(B::TwoLeftMost, ST{1}, B::LeftMost), // @@ -678,6 +858,39 @@ std::vector ShiftLeftCases() { C(B::AllButLeftMost, ST{1}, B::AllButRightMost), }); + auto error_msg = [](auto a, auto b) { + return "12:34 error: " + OverflowErrorMessage(a, "<<", b); + }; + ConcatIntoIf>( // + r, std::vector{ + // ShiftLeft of AInts that result in values not representable as AInts. + // Note that for i32/u32, these would error because shift value is larger than 32. + E(B::All, T{B::NumBits}, error_msg(B::All, T{B::NumBits})), + E(B::RightMost, T{B::NumBits}, error_msg(B::RightMost, T{B::NumBits})), + E(B::AllButLeftMost, T{B::NumBits}, error_msg(B::AllButLeftMost, T{B::NumBits})), + E(B::AllButLeftMost, T{B::NumBits + 1}, + error_msg(B::AllButLeftMost, T{B::NumBits + 1})), + E(B::AllButLeftMost, T{B::NumBits + 1000}, + error_msg(B::AllButLeftMost, T{B::NumBits + 1000})), + }); + ConcatIntoIf>( // + r, std::vector{ + // ShiftLeft of u32s that overflow (non-zero bits get shifted out) + E(T{0b00010}, T{31}, error_msg(T{0b00010}, T{31})), + E(T{0b00100}, T{30}, error_msg(T{0b00100}, T{30})), + E(T{0b01000}, T{29}, error_msg(T{0b01000}, T{29})), + E(T{0b10000}, T{28}, error_msg(T{0b10000}, T{28})), + //... + E(T{1 << 28}, T{4}, error_msg(T{1 << 28}, T{4})), + E(T{1 << 29}, T{3}, error_msg(T{1 << 29}, T{3})), + E(T{1 << 30}, T{2}, error_msg(T{1 << 30}, T{2})), + E(T{1u << 31}, T{1}, error_msg(T{1u << 31}, T{1})), + + // And some more + E(B::All, T{1}, error_msg(B::All, T{1})), + E(B::AllButLeftMost, T{2}, error_msg(B::AllButLeftMost, T{2})), + }); + return r; } INSTANTIATE_TEST_SUITE_P(ShiftLeft, @@ -688,153 +901,6 @@ INSTANTIATE_TEST_SUITE_P(ShiftLeft, ShiftLeftCases(), // ShiftLeftCases())))); -// Tests for errors on overflow/underflow of binary operations with abstract numbers -struct OverflowCase { - ast::BinaryOp op; - Types lhs; - Types rhs; -}; - -static std::ostream& operator<<(std::ostream& o, const OverflowCase& c) { - o << ast::FriendlyName(c.op) << ", lhs: " << c.lhs << ", rhs: " << c.rhs; - return o; -} -using ResolverConstEvalBinaryOpTest_Overflow = ResolverTestWithParam; -TEST_P(ResolverConstEvalBinaryOpTest_Overflow, Test) { - Enable(ast::Extension::kF16); - auto& c = GetParam(); - auto* lhs = ToValueBase(c.lhs); - auto* rhs = ToValueBase(c.rhs); - auto* lhs_expr = lhs->Expr(*this); - auto* rhs_expr = rhs->Expr(*this); - auto* expr = create(Source{{1, 1}}, c.op, lhs_expr, rhs_expr); - GlobalConst("C", expr); - ASSERT_FALSE(r()->Resolve()); - EXPECT_THAT(r()->error(), HasSubstr("1:1 error: '")); - EXPECT_THAT(r()->error(), HasSubstr("' cannot be represented as '" + lhs->TypeName() + "'")); -} -INSTANTIATE_TEST_SUITE_P( - Test, - ResolverConstEvalBinaryOpTest_Overflow, - testing::Values( - - // scalar-scalar add - OverflowCase{ast::BinaryOp::kAdd, Val(AInt::Highest()), Val(1_a)}, - OverflowCase{ast::BinaryOp::kAdd, Val(AInt::Lowest()), Val(-1_a)}, - OverflowCase{ast::BinaryOp::kAdd, Val(AFloat::Highest()), Val(AFloat::Highest())}, - OverflowCase{ast::BinaryOp::kAdd, Val(AFloat::Lowest()), Val(AFloat::Lowest())}, - // scalar-scalar subtract - OverflowCase{ast::BinaryOp::kSubtract, Val(AInt::Lowest()), Val(1_a)}, - OverflowCase{ast::BinaryOp::kSubtract, Val(AInt::Highest()), Val(-1_a)}, - OverflowCase{ast::BinaryOp::kSubtract, Val(AFloat::Highest()), Val(AFloat::Lowest())}, - OverflowCase{ast::BinaryOp::kSubtract, Val(AFloat::Lowest()), Val(AFloat::Highest())}, - - // scalar-scalar multiply - OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Highest()), Val(2_a)}, - OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Lowest()), Val(-2_a)}, - - // scalar-vector multiply - OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Highest()), Vec(2_a, 1_a)}, - OverflowCase{ast::BinaryOp::kMultiply, Val(AInt::Lowest()), Vec(-2_a, 1_a)}, - - // vector-matrix multiply - - // Overflow from first multiplication of dot product of vector and matrix column 0 - // i.e. (v[0] * m[0][0] + v[1] * m[0][1]) - // ^ - OverflowCase{ast::BinaryOp::kMultiply, // - Vec(AFloat::Highest(), 1.0_a), // - Mat({2.0_a, 1.0_a}, // - {1.0_a, 1.0_a})}, - - // Overflow from second multiplication of dot product of vector and matrix column 0 - // i.e. (v[0] * m[0][0] + v[1] * m[0][1]) - // ^ - OverflowCase{ast::BinaryOp::kMultiply, // - Vec(1.0_a, AFloat::Highest()), // - Mat({1.0_a, 2.0_a}, // - {1.0_a, 1.0_a})}, - - // Overflow from addition of dot product of vector and matrix column 0 - // i.e. (v[0] * m[0][0] + v[1] * m[0][1]) - // ^ - OverflowCase{ast::BinaryOp::kMultiply, // - Vec(AFloat::Highest(), AFloat::Highest()), // - Mat({1.0_a, 1.0_a}, // - {1.0_a, 1.0_a})}, - - // matrix-matrix multiply - - // Overflow from first multiplication of dot product of lhs row 0 and rhs column 0 - // i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0] - // ^ - OverflowCase{ast::BinaryOp::kMultiply, // - Mat({AFloat::Highest(), 1.0_a}, // - {1.0_a, 1.0_a}), // - Mat({2.0_a, 1.0_a}, // - {1.0_a, 1.0_a})}, - - // Overflow from second multiplication of dot product of lhs row 0 and rhs column 0 - // i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0] - // ^ - OverflowCase{ast::BinaryOp::kMultiply, // - Mat({1.0_a, AFloat::Highest()}, // - {1.0_a, 1.0_a}), // - Mat({1.0_a, 1.0_a}, // - {2.0_a, 1.0_a})}, - - // Overflow from addition of dot product of lhs row 0 and rhs column 0 - // i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0] - // ^ - OverflowCase{ast::BinaryOp::kMultiply, // - Mat({AFloat::Highest(), 1.0_a}, // - {AFloat::Highest(), 1.0_a}), // - Mat({1.0_a, 1.0_a}, // - {1.0_a, 1.0_a})}, - - // Divide by zero - OverflowCase{ast::BinaryOp::kDivide, Val(123_a), Val(0_a)}, - OverflowCase{ast::BinaryOp::kDivide, Val(-123_a), Val(-0_a)}, - OverflowCase{ast::BinaryOp::kDivide, Val(-123_a), Val(0_a)}, - OverflowCase{ast::BinaryOp::kDivide, Val(123_a), Val(-0_a)}, - - // Most negative value divided by -1 - OverflowCase{ast::BinaryOp::kDivide, Val(AInt::Lowest()), Val(-1_a)}, - - // ShiftLeft of AInts that result in values not representable as AInts. - // Note that for i32/u32, these would error because shift value is larger than 32. - OverflowCase{ast::BinaryOp::kShiftLeft, // - Val(AInt{BitValues::All}), // - Val(AInt{BitValues::NumBits})}, // - OverflowCase{ast::BinaryOp::kShiftLeft, // - Val(AInt{BitValues::RightMost}), // - Val(AInt{BitValues::NumBits})}, // - OverflowCase{ast::BinaryOp::kShiftLeft, // - Val(AInt{BitValues::AllButLeftMost}), // - Val(AInt{BitValues::NumBits})}, // - OverflowCase{ast::BinaryOp::kShiftLeft, // - Val(AInt{BitValues::AllButLeftMost}), // - Val(AInt{BitValues::NumBits + 1})}, // - OverflowCase{ast::BinaryOp::kShiftLeft, // - Val(AInt{BitValues::AllButLeftMost}), // - Val(AInt{BitValues::NumBits + 1000})}, - - // ShiftLeft of u32s that overflow (non-zero bits get shifted out) - OverflowCase{ast::BinaryOp::kShiftLeft, Val(0b00010_u), Val(31_u)}, - OverflowCase{ast::BinaryOp::kShiftLeft, Val(0b00100_u), Val(30_u)}, - OverflowCase{ast::BinaryOp::kShiftLeft, Val(0b01000_u), Val(29_u)}, - OverflowCase{ast::BinaryOp::kShiftLeft, Val(0b10000_u), Val(28_u)}, - // ... - OverflowCase{ast::BinaryOp::kShiftLeft, Val(u32(1u << 28)), Val(4_u)}, - OverflowCase{ast::BinaryOp::kShiftLeft, Val(u32(1u << 29)), Val(3_u)}, - OverflowCase{ast::BinaryOp::kShiftLeft, Val(u32(1u << 30)), Val(2_u)}, - OverflowCase{ast::BinaryOp::kShiftLeft, Val(u32(1u << 31)), Val(1_u)}, - // And some more - OverflowCase{ast::BinaryOp::kShiftLeft, Val(BitValues::All), Val(1_u)}, - OverflowCase{ast::BinaryOp::kShiftLeft, Val(BitValues::AllButLeftMost), Val(2_u)} - - )); - TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AInt) { GlobalConst("c", Add(Source{{1, 1}}, Expr(AInt::Highest()), 1_a)); EXPECT_FALSE(r()->Resolve()); @@ -853,15 +919,16 @@ TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AFloat) { GlobalConst("c", Add(Source{{1, 1}}, Expr(AFloat::Highest()), AFloat::Highest())); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), - "1:1 error: '1.7976931348623157081e+308 + 1.7976931348623157081e+308' cannot be represented as 'abstract-float'"); + "1:1 error: '1.7976931348623157081e+308 + 1.7976931348623157081e+308' cannot be " + "represented as 'abstract-float'"); } TEST_F(ResolverConstEvalTest, BinaryAbstractAddUnderflow_AFloat) { GlobalConst("c", Add(Source{{1, 1}}, Expr(AFloat::Lowest()), AFloat::Lowest())); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ( - r()->error(), - "1:1 error: '-1.7976931348623157081e+308 + -1.7976931348623157081e+308' cannot be represented as 'abstract-float'"); + EXPECT_EQ(r()->error(), + "1:1 error: '-1.7976931348623157081e+308 + -1.7976931348623157081e+308' cannot be " + "represented as 'abstract-float'"); } // Mixed AInt and AFloat args to test implicit conversion to AFloat diff --git a/src/tint/resolver/const_eval_builtin_test.cc b/src/tint/resolver/const_eval_builtin_test.cc index 9437296c88..2fdee3ce34 100644 --- a/src/tint/resolver/const_eval_builtin_test.cc +++ b/src/tint/resolver/const_eval_builtin_test.cc @@ -54,7 +54,7 @@ struct Case { bool float_compare = false; }; struct Failure { - std::string error = nullptr; + std::string error; }; utils::Vector args; @@ -108,16 +108,6 @@ static Case E(std::initializer_list sargs, std::string err) { return Case{std::move(args), std::move(err)}; } -/// Returns the overflow error message for binary ops -template -std::string OverflowErrorMessage(NumberT lhs, const char* op, NumberT rhs) { - std::stringstream ss; - ss << std::setprecision(20); - ss << "'" << lhs.value << " " << op << " " << rhs.value << "' cannot be represented as '" - << FriendlyName() << "'"; - return ss.str(); -} - using ResolverConstEvalBuiltinTest = ResolverTestWithParam>; TEST_P(ResolverConstEvalBuiltinTest, Test) { @@ -132,13 +122,12 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) { } auto* expr = Call(Source{{12, 34}}, sem::str(builtin), std::move(args)); - GlobalConst("C", expr); if (c.expected) { - auto expected = c.expected.Get(); + auto expected_case = c.expected.Get(); - auto* expected_expr = ToValueBase(expected.value)->Expr(*this); + auto* expected_expr = ToValueBase(expected_case.value)->Expr(*this); GlobalConst("E", expected_expr); ASSERT_TRUE(r()->Resolve()) << r()->error(); @@ -168,21 +157,21 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) { if (std::isnan(e)) { EXPECT_TRUE(std::isnan(v)); } else { - auto vf = (expected.pos_or_neg ? Abs(v) : v); - if (expected.float_compare) { + auto vf = (expected_case.pos_or_neg ? Abs(v) : v); + if (expected_case.float_compare) { EXPECT_FLOAT_EQ(vf, e); } else { EXPECT_EQ(vf, e); } } } else { - EXPECT_EQ((expected.pos_or_neg ? Abs(v) : v), e); + EXPECT_EQ((expected_case.pos_or_neg ? Abs(v) : v), e); // Check that the constant's integer doesn't contain unexpected // data in the MSBs that are outside of the bit-width of T. EXPECT_EQ(a->As(), b->As()); } }, - expected.value); + expected_case.value); return HasFailure() ? Action::kStop : Action::kContinue; }); diff --git a/src/tint/resolver/const_eval_test.h b/src/tint/resolver/const_eval_test.h index 2f000f82b6..c64a86b9fb 100644 --- a/src/tint/resolver/const_eval_test.h +++ b/src/tint/resolver/const_eval_test.h @@ -16,6 +16,7 @@ #define SRC_TINT_RESOLVER_CONST_EVAL_TEST_H_ #include +#include #include #include "gmock/gmock.h" @@ -134,6 +135,16 @@ inline void ConcatIntoIf([[maybe_unused]] Vec& v1, [[maybe_unused]] Vecs&&... vs } } +/// Returns the overflow error message for binary ops +template +inline std::string OverflowErrorMessage(NumberT lhs, const char* op, NumberT rhs) { + std::stringstream ss; + ss << std::setprecision(20); + ss << "'" << lhs.value << " " << op << " " << rhs.value << "' cannot be represented as '" + << FriendlyName() << "'"; + return ss.str(); +} + using builder::IsValue; using builder::Mat; using builder::Val;