// Copyright 2022 The Tint Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "src/tint/resolver/const_eval_test.h" #include "src/tint/utils/result.h" using namespace tint::number_suffixes; // NOLINT using ::testing::HasSubstr; namespace tint::resolver { namespace { // Bring in std::ostream& operator<<(std::ostream& o, const Types& types) using resolver::operator<<; struct Case { struct Success { Types value; }; struct Failure { std::string error; }; Types lhs; Types rhs; utils::Result expected; }; struct ErrorCase { Types lhs; Types rhs; }; /// Creates a Case with Values of any type template Case C(Value lhs, Value rhs, Value expected) { return Case{std::move(lhs), std::move(rhs), Case::Success{std::move(expected)}}; } /// Convenience overload that creates a Case with just scalars template >> Case C(T lhs, U rhs, V expected) { return Case{Val(lhs), Val(rhs), Case::Success{Val(expected)}}; } /// Creates an failure Case with Values of any type template Case E(Value lhs, Value rhs, std::string error) { return Case{std::move(lhs), std::move(rhs), Case::Failure{std::move(error)}}; } /// Convenience overload that creates an error Case with just scalars template >> Case E(T lhs, U rhs, std::string error) { return Case{Val(lhs), Val(rhs), Case::Failure{std::move(error)}}; } /// Prints Case to ostream static std::ostream& operator<<(std::ostream& o, const Case& c) { o << "lhs: " << c.lhs << ", rhs: " << c.rhs << ", expected: "; if (c.expected) { auto s = c.expected.Get(); o << s.value; } else { o << "[ERROR: " << c.expected.Failure().error << "]"; } return o; } /// Prints ErrorCase to ostream std::ostream& operator<<(std::ostream& o, const ErrorCase& c) { o << c.lhs << ", " << c.rhs; return o; } using ResolverConstEvalBinaryOpTest = ResolverTestWithParam>; TEST_P(ResolverConstEvalBinaryOpTest, Test) { Enable(ast::Extension::kF16); auto op = std::get<0>(GetParam()); auto& c = std::get<1>(GetParam()); auto* lhs_expr = ToValueBase(c.lhs)->Expr(*this); auto* rhs_expr = ToValueBase(c.rhs)->Expr(*this); auto* expr = create(Source{{12, 34}}, op, lhs_expr, rhs_expr); GlobalConst("C", expr); if (c.expected) { ASSERT_TRUE(r()->Resolve()) << r()->error(); auto expected_case = c.expected.Get(); auto* expected = ToValueBase(expected_case.value); auto* sem = Sem().Get(expr); const sem::Constant* value = sem->ConstantValue(); ASSERT_NE(value, nullptr); EXPECT_TYPE(value->Type(), sem->Type()); CheckConstant(value, expected); } else { ASSERT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), c.expected.Failure().error); } } INSTANTIATE_TEST_SUITE_P(MixedAbstractArgs, ResolverConstEvalBinaryOpTest, testing::Combine(testing::Values(ast::BinaryOp::kAdd), testing::ValuesIn(std::vector{ // Mixed abstract type args C(1_a, 2.3_a, 3.3_a), C(2.3_a, 1_a, 3.3_a), }))); template std::vector OpAddIntCases() { static_assert(IsIntegral); auto r = std::vector{ C(T{0}, T{0}, T{0}), C(T{1}, T{2}, T{3}), C(T::Lowest(), T{1}, T{T::Lowest() + 1}), C(T::Highest(), Negate(T{1}), T{T::Highest() - 1}), C(T::Lowest(), T::Highest(), Negate(T{1})), }; if constexpr (IsAbstract) { auto error_msg = [](auto a, auto b) { return "12:34 error: " + OverflowErrorMessage(a, "+", b); }; ConcatInto( // r, std::vector{ E(T::Highest(), T{1}, error_msg(T::Highest(), T{1})), E(T::Lowest(), Negate(T{1}), error_msg(T::Lowest(), Negate(T{1}))), }); } else { ConcatInto( // r, std::vector{ C(T::Highest(), T{1}, T::Lowest()), C(T::Lowest(), Negate(T{1}), T::Highest()), }); } return r; } template std::vector OpAddFloatCases() { static_assert(IsFloatingPoint); auto error_msg = [](auto a, auto b) { return "12:34 error: " + OverflowErrorMessage(a, "+", b); }; return std::vector{ C(T{0}, T{0}, T{0}), C(T{1}, T{2}, T{3}), C(T::Lowest(), T{1}, T{T::Lowest() + 1}), C(T::Highest(), Negate(T{1}), T{T::Highest() - 1}), C(T::Lowest(), T::Highest(), T{0}), E(T::Highest(), T::Highest(), error_msg(T::Highest(), T::Highest())), E(T::Lowest(), Negate(T::Highest()), error_msg(T::Lowest(), Negate(T::Highest()))), }; } INSTANTIATE_TEST_SUITE_P(Add, ResolverConstEvalBinaryOpTest, testing::Combine(testing::Values(ast::BinaryOp::kAdd), testing::ValuesIn(Concat( // OpAddIntCases(), OpAddIntCases(), OpAddIntCases(), OpAddFloatCases(), OpAddFloatCases(), OpAddFloatCases())))); template std::vector OpSubIntCases() { static_assert(IsIntegral); auto r = std::vector{ C(T{0}, T{0}, T{0}), C(T{3}, T{2}, T{1}), C(T{T::Lowest() + 1}, T{1}, T::Lowest()), C(T{T::Highest() - 1}, Negate(T{1}), T::Highest()), C(Negate(T{1}), T::Highest(), T::Lowest()), }; if constexpr (IsAbstract) { auto error_msg = [](auto a, auto b) { return "12:34 error: " + OverflowErrorMessage(a, "-", b); }; ConcatInto( // r, std::vector{ E(T::Lowest(), T{1}, error_msg(T::Lowest(), T{1})), E(T::Highest(), Negate(T{1}), error_msg(T::Highest(), Negate(T{1}))), }); } else { ConcatInto( // r, std::vector{ C(T::Lowest(), T{1}, T::Highest()), C(T::Highest(), Negate(T{1}), T::Lowest()), }); } return r; } template std::vector OpSubFloatCases() { static_assert(IsFloatingPoint); auto error_msg = [](auto a, auto b) { return "12:34 error: " + OverflowErrorMessage(a, "-", b); }; return std::vector{ C(T{0}, T{0}, T{0}), C(T{3}, T{2}, T{1}), C(T::Highest(), T{1}, T{T::Highest() - 1}), C(T::Lowest(), Negate(T{1}), T{T::Lowest() + 1}), C(T{0}, T::Highest(), T::Lowest()), E(T::Highest(), Negate(T::Highest()), error_msg(T::Highest(), Negate(T::Highest()))), E(T::Lowest(), T::Highest(), error_msg(T::Lowest(), T::Highest())), }; } INSTANTIATE_TEST_SUITE_P(Sub, ResolverConstEvalBinaryOpTest, testing::Combine(testing::Values(ast::BinaryOp::kSubtract), testing::ValuesIn(Concat( // OpSubIntCases(), OpSubIntCases(), OpSubIntCases(), OpSubFloatCases(), OpSubFloatCases(), OpSubFloatCases())))); template std::vector OpMulScalarCases() { auto r = std::vector{ C(T{0}, T{0}, T{0}), C(T{1}, T{2}, T{2}), C(T{2}, T{3}, T{6}), C(Negate(T{2}), T{3}, Negate(T{6})), C(T::Highest(), T{1}, T::Highest()), C(T::Lowest(), T{1}, T::Lowest()), }; if constexpr (IsAbstract || IsFloatingPoint) { auto error_msg = [](auto a, auto b) { return "12:34 error: " + OverflowErrorMessage(a, "*", b); }; ConcatInto( // r, std::vector{ // Fail if result is +/-inf E(T::Highest(), T::Highest(), error_msg(T::Highest(), T::Highest())), E(T::Lowest(), T::Lowest(), error_msg(T::Lowest(), T::Lowest())), E(T::Highest(), T{2}, error_msg(T::Highest(), T{2})), E(T::Lowest(), Negate(T{2}), error_msg(T::Lowest(), Negate(T{2}))), }); } else { ConcatInto( // r, std::vector{ C(T::Highest(), T::Highest(), Mul(T::Highest(), T::Highest())), C(T::Lowest(), T::Lowest(), Mul(T::Lowest(), T::Lowest())), }); } return r; } template std::vector OpMulVecCases() { auto r = std::vector{ // s * vec3 = vec3 C(Val(T{2.0}), Vec(T{1.25}, T{2.25}, T{3.25}), Vec(T{2.5}, T{4.5}, T{6.5})), // vec3 * s = vec3 C(Vec(T{1.25}, T{2.25}, T{3.25}), Val(T{2.0}), Vec(T{2.5}, T{4.5}, T{6.5})), // vec3 * vec3 = vec3 C(Vec(T{1.25}, T{2.25}, T{3.25}), Vec(T{2.0}, T{2.0}, T{2.0}), Vec(T{2.5}, T{4.5}, T{6.5})), }; if constexpr (IsAbstract || IsFloatingPoint) { auto error_msg = [](auto a, auto b) { return "12:34 error: " + OverflowErrorMessage(a, "*", b); }; ConcatInto( // r, std::vector{ // Fail if result is +/-inf E(Val(T::Highest()), Vec(T{2}, T{1}), error_msg(T::Highest(), T{2})), E(Val(T::Lowest()), Vec(Negate(T{2}), T{1}), error_msg(T::Lowest(), Negate(T{2}))), }); } else { ConcatInto( // r, std::vector{ C(Val(T::Highest()), Vec(T{2}, T{1}), Vec(T{-2}, T::Highest())), C(Val(T::Lowest()), Vec(Negate(T{2}), T{1}), Vec(T{0}, T{T::Lowest()})), }); } return r; } template std::vector OpMulMatCases() { auto r = std::vector{ // s * mat3x2 = mat3x2 C(Val(T{2.25}), Mat({T{1.0}, T{4.0}}, // {T{2.0}, T{5.0}}, // {T{3.0}, T{6.0}}), Mat({T{2.25}, T{9.0}}, // {T{4.5}, T{11.25}}, // {T{6.75}, T{13.5}})), // mat3x2 * s = mat3x2 C(Mat({T{1.0}, T{4.0}}, // {T{2.0}, T{5.0}}, // {T{3.0}, T{6.0}}), Val(T{2.25}), Mat({T{2.25}, T{9.0}}, // {T{4.5}, T{11.25}}, // {T{6.75}, T{13.5}})), // vec3 * mat2x3 = vec2 C(Vec(T{1.25}, T{2.25}, T{3.25}), // Mat({T{1.0}, T{2.0}, T{3.0}}, // {T{4.0}, T{5.0}, T{6.0}}), // Vec(T{15.5}, T{35.75})), // mat2x3 * vec2 = vec3 C(Mat({T{1.0}, T{2.0}, T{3.0}}, // {T{4.0}, T{5.0}, T{6.0}}), // Vec(T{1.25}, T{2.25}), // Vec(T{10.25}, T{13.75}, T{17.25})), // mat3x2 * mat2x3 = mat2x2 C(Mat({T{1.0}, T{2.0}}, // {T{3.0}, T{4.0}}, // {T{5.0}, T{6.0}}), // Mat({T{1.25}, T{2.25}, T{3.25}}, // {T{4.25}, T{5.25}, T{6.25}}), // Mat({T{24.25}, T{31.0}}, // {T{51.25}, T{67.0}})), // }; auto error_msg = [](auto a, const char* op, auto b) { return "12:34 error: " + OverflowErrorMessage(a, op, b); }; ConcatIntoIf || IsFloatingPoint>( // r, std::vector{ // vector-matrix multiply // Overflow from first multiplication of dot product of vector and matrix column 0 // i.e. (v[0] * m[0][0] + v[1] * m[0][1]) // ^ E(Vec(T::Highest(), T{1.0}), // Mat({T{2.0}, T{1.0}}, // {T{1.0}, T{1.0}}), // error_msg(T{2}, "*", T::Highest())), // Overflow from second multiplication of dot product of vector and matrix column 0 // i.e. (v[0] * m[0][0] + v[1] * m[0][1]) // ^ E(Vec(T{1.0}, T::Highest()), // Mat({T{1.0}, T{2.0}}, // {T{1.0}, T{1.0}}), // error_msg(T{2}, "*", T::Highest())), // Overflow from addition of dot product of vector and matrix column 0 // i.e. (v[0] * m[0][0] + v[1] * m[0][1]) // ^ E(Vec(T::Highest(), T::Highest()), // Mat({T{1.0}, T{1.0}}, // {T{1.0}, T{1.0}}), // error_msg(T::Highest(), "+", T::Highest())), // matrix-matrix multiply // Overflow from first multiplication of dot product of lhs row 0 and rhs column 0 // i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0] // ^ E(Mat({T::Highest(), T{1.0}}, // {T{1.0}, T{1.0}}), // Mat({T{2.0}, T{1.0}}, // {T{1.0}, T{1.0}}), // error_msg(T::Highest(), "*", T{2.0})), // Overflow from second multiplication of dot product of lhs row 0 and rhs column 0 // i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0] // ^ E(Mat({T{1.0}, T{1.0}}, // {T::Highest(), T{1.0}}), // Mat({T{1.0}, T{2.0}}, // {T{1.0}, T{1.0}}), // error_msg(T::Highest(), "*", T{2.0})), // Overflow from addition of dot product of lhs row 0 and rhs column 0 // i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0] // ^ E(Mat({T::Highest(), T{1.0}}, // {T::Highest(), T{1.0}}), // Mat({T{1.0}, T{1.0}}, // {T{1.0}, T{1.0}}), // error_msg(T::Highest(), "+", T::Highest())), }); return r; } INSTANTIATE_TEST_SUITE_P(Mul, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kMultiply), testing::ValuesIn(Concat( // OpMulScalarCases(), OpMulScalarCases(), OpMulScalarCases(), OpMulScalarCases(), OpMulScalarCases(), OpMulScalarCases(), OpMulVecCases(), OpMulVecCases(), OpMulVecCases(), OpMulVecCases(), OpMulVecCases(), OpMulVecCases(), OpMulMatCases(), OpMulMatCases(), OpMulMatCases())))); template std::vector OpDivIntCases() { std::vector r = { C(Val(T{0}), Val(T{1}), Val(T{0})), C(Val(T{1}), Val(T{1}), Val(T{1})), C(Val(T{1}), Val(T{1}), Val(T{1})), C(Val(T{2}), Val(T{1}), Val(T{2})), C(Val(T{4}), Val(T{2}), Val(T{2})), C(Val(T::Highest()), Val(T{1}), Val(T::Highest())), C(Val(T::Lowest()), Val(T{1}), Val(T::Lowest())), C(Val(T::Highest()), Val(T::Highest()), Val(T{1})), C(Val(T{0}), Val(T::Highest()), Val(T{0})), C(Val(T{0}), Val(T::Lowest()), Val(T{0})), }; ConcatIntoIf && IsIntegral>( // r, std::vector{ // e1, when e2 is zero. C(T{123}, T{0}, T{123}), }); ConcatIntoIf && IsSignedIntegral>( // r, std::vector{ // e1, when e1 is the most negative value in T, and e2 is -1. C(T::Smallest(), T{-1}, T::Smallest()), }); auto error_msg = [](auto a, auto b) { return "12:34 error: " + OverflowErrorMessage(a, "/", b); }; ConcatIntoIf>( // r, std::vector{ // Most negative value divided by -1 E(AInt::Lowest(), -1_a, error_msg(AInt::Lowest(), -1_a)), }); return r; } template std::vector OpDivFloatCases() { auto error_msg = [](auto a, auto b) { return "12:34 error: " + OverflowErrorMessage(a, "/", b); }; std::vector r = { C(Val(T{0}), Val(T{1}), Val(T{0})), C(Val(T{1}), Val(T{1}), Val(T{1})), C(Val(T{1}), Val(T{1}), Val(T{1})), C(Val(T{2}), Val(T{1}), Val(T{2})), C(Val(T{4}), Val(T{2}), Val(T{2})), C(Val(T::Highest()), Val(T{1}), Val(T::Highest())), C(Val(T::Lowest()), Val(T{1}), Val(T::Lowest())), C(Val(T::Highest()), Val(T::Highest()), Val(T{1})), C(Val(T{0}), Val(T::Highest()), Val(T{0})), C(Val(T{0}), Val(T::Lowest()), Val(-T{0})), // Divide by zero E(T{123}, T{0}, error_msg(T{123}, T{0})), E(Negate(T{123}), Negate(T{0}), error_msg(Negate(T{123}), Negate(T{0}))), E(Negate(T{123}), T{0}, error_msg(Negate(T{123}), T{0})), E(T{123}, Negate(T{0}), error_msg(T{123}, Negate(T{0}))), }; return r; } INSTANTIATE_TEST_SUITE_P(Div, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kDivide), testing::ValuesIn(Concat( // OpDivIntCases(), OpDivIntCases(), OpDivIntCases(), OpDivFloatCases(), OpDivFloatCases(), OpDivFloatCases())))); template std::vector OpEqualCases() { return { C(Val(T{0}), Val(T{0}), Val(true == equals)), C(Val(T{0}), Val(T{1}), Val(false == equals)), C(Val(T{1}), Val(T{0}), Val(false == equals)), C(Val(T{1}), Val(T{1}), Val(true == equals)), C(Vec(T{0}, T{0}), Vec(T{0}, T{0}), Vec(true == equals, true == equals)), C(Vec(T{1}, T{0}), Vec(T{0}, T{1}), Vec(false == equals, false == equals)), C(Vec(T{1}, T{1}), Vec(T{0}, T{1}), Vec(false == equals, true == equals)), }; } INSTANTIATE_TEST_SUITE_P(Equal, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kEqual), testing::ValuesIn(Concat( // OpEqualCases(), OpEqualCases(), OpEqualCases(), OpEqualCases(), OpEqualCases(), OpEqualCases(), OpEqualCases())))); INSTANTIATE_TEST_SUITE_P(NotEqual, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kNotEqual), testing::ValuesIn(Concat( // OpEqualCases(), OpEqualCases(), OpEqualCases(), OpEqualCases(), OpEqualCases(), OpEqualCases(), OpEqualCases())))); template std::vector OpLessThanCases() { return { C(Val(T{0}), Val(T{0}), Val(false == less_than)), C(Val(T{0}), Val(T{1}), Val(true == less_than)), C(Val(T{1}), Val(T{0}), Val(false == less_than)), C(Val(T{1}), Val(T{1}), Val(false == less_than)), C(Vec(T{0}, T{0}), Vec(T{0}, T{0}), Vec(false == less_than, false == less_than)), C(Vec(T{0}, T{0}), Vec(T{1}, T{1}), Vec(true == less_than, true == less_than)), C(Vec(T{1}, T{1}), Vec(T{0}, T{0}), Vec(false == less_than, false == less_than)), C(Vec(T{1}, T{0}), Vec(T{0}, T{1}), Vec(false == less_than, true == less_than)), }; } INSTANTIATE_TEST_SUITE_P(LessThan, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kLessThan), testing::ValuesIn(Concat( // OpLessThanCases(), OpLessThanCases(), OpLessThanCases(), OpLessThanCases(), OpLessThanCases(), OpLessThanCases())))); INSTANTIATE_TEST_SUITE_P(GreaterThanEqual, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kGreaterThanEqual), testing::ValuesIn(Concat( // OpLessThanCases(), OpLessThanCases(), OpLessThanCases(), OpLessThanCases(), OpLessThanCases(), OpLessThanCases())))); template std::vector OpGreaterThanCases() { return { C(Val(T{0}), Val(T{0}), Val(false == greater_than)), C(Val(T{0}), Val(T{1}), Val(false == greater_than)), C(Val(T{1}), Val(T{0}), Val(true == greater_than)), C(Val(T{1}), Val(T{1}), Val(false == greater_than)), C(Vec(T{0}, T{0}), Vec(T{0}, T{0}), Vec(false == greater_than, false == greater_than)), C(Vec(T{1}, T{1}), Vec(T{0}, T{0}), Vec(true == greater_than, true == greater_than)), C(Vec(T{0}, T{0}), Vec(T{1}, T{1}), Vec(false == greater_than, false == greater_than)), C(Vec(T{1}, T{0}), Vec(T{0}, T{1}), Vec(true == greater_than, false == greater_than)), }; } INSTANTIATE_TEST_SUITE_P(GreaterThan, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kGreaterThan), testing::ValuesIn(Concat( // OpGreaterThanCases(), OpGreaterThanCases(), OpGreaterThanCases(), OpGreaterThanCases(), OpGreaterThanCases(), OpGreaterThanCases())))); INSTANTIATE_TEST_SUITE_P(LessThanEqual, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kLessThanEqual), testing::ValuesIn(Concat( // OpGreaterThanCases(), OpGreaterThanCases(), OpGreaterThanCases(), OpGreaterThanCases(), OpGreaterThanCases(), OpGreaterThanCases())))); static std::vector OpAndBoolCases() { return { C(true, true, true), C(true, false, false), C(false, true, false), C(false, false, false), C(Vec(true, true), Vec(true, false), Vec(true, false)), C(Vec(true, true), Vec(false, true), Vec(false, true)), C(Vec(true, false), Vec(true, false), Vec(true, false)), C(Vec(false, true), Vec(true, false), Vec(false, false)), C(Vec(false, false), Vec(true, false), Vec(false, false)), }; } template std::vector OpAndIntCases() { using B = BitValues; return { C(T{0b1010}, T{0b1111}, T{0b1010}), C(T{0b1010}, T{0b0000}, T{0b0000}), C(T{0b1010}, T{0b0011}, T{0b0010}), C(T{0b1010}, T{0b1100}, T{0b1000}), C(T{0b1010}, T{0b0101}, T{0b0000}), C(B::All, B::All, B::All), C(B::LeftMost, B::LeftMost, B::LeftMost), C(B::RightMost, B::RightMost, B::RightMost), C(B::All, T{0}, T{0}), C(T{0}, B::All, T{0}), C(B::LeftMost, B::AllButLeftMost, T{0}), C(B::AllButLeftMost, B::LeftMost, T{0}), C(B::RightMost, B::AllButRightMost, T{0}), C(B::AllButRightMost, B::RightMost, T{0}), C(Vec(B::All, B::LeftMost, B::RightMost), // Vec(B::All, B::All, B::All), // Vec(B::All, B::LeftMost, B::RightMost)), // C(Vec(B::All, B::LeftMost, B::RightMost), // Vec(T{0}, T{0}, T{0}), // Vec(T{0}, T{0}, T{0})), // C(Vec(B::LeftMost, B::RightMost), // Vec(B::AllButLeftMost, B::AllButRightMost), // Vec(T{0}, T{0})), }; } INSTANTIATE_TEST_SUITE_P(And, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kAnd), testing::ValuesIn( // Concat(OpAndBoolCases(), // OpAndIntCases(), OpAndIntCases(), OpAndIntCases())))); static std::vector OpOrBoolCases() { return { C(true, true, true), C(true, false, true), C(false, true, true), C(false, false, false), C(Vec(true, true), Vec(true, false), Vec(true, true)), C(Vec(true, true), Vec(false, true), Vec(true, true)), C(Vec(true, false), Vec(true, false), Vec(true, false)), C(Vec(false, true), Vec(true, false), Vec(true, true)), C(Vec(false, false), Vec(true, false), Vec(true, false)), }; } template std::vector OpOrIntCases() { using B = BitValues; return { C(T{0b1010}, T{0b1111}, T{0b1111}), C(T{0b1010}, T{0b0000}, T{0b1010}), C(T{0b1010}, T{0b0011}, T{0b1011}), C(T{0b1010}, T{0b1100}, T{0b1110}), C(T{0b1010}, T{0b0101}, T{0b1111}), C(B::All, B::All, B::All), C(B::LeftMost, B::LeftMost, B::LeftMost), C(B::RightMost, B::RightMost, B::RightMost), C(B::All, T{0}, B::All), C(T{0}, B::All, B::All), C(B::LeftMost, B::AllButLeftMost, B::All), C(B::AllButLeftMost, B::LeftMost, B::All), C(B::RightMost, B::AllButRightMost, B::All), C(B::AllButRightMost, B::RightMost, B::All), C(Vec(B::All, B::LeftMost, B::RightMost), // Vec(B::All, B::All, B::All), // Vec(B::All, B::All, B::All)), // C(Vec(B::All, B::LeftMost, B::RightMost), // Vec(T{0}, T{0}, T{0}), // Vec(B::All, B::LeftMost, B::RightMost)), // C(Vec(B::LeftMost, B::RightMost), // Vec(B::AllButLeftMost, B::AllButRightMost), // Vec(B::All, B::All)), }; } INSTANTIATE_TEST_SUITE_P(Or, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kOr), testing::ValuesIn(Concat(OpOrBoolCases(), OpOrIntCases(), OpOrIntCases(), OpOrIntCases())))); TEST_F(ResolverConstEvalTest, NotAndOrOfVecs) { // const C = !((vec2(true, true) & vec2(true, false)) | vec2(false, true)); auto v1 = Vec(true, true).Expr(*this); auto v2 = Vec(true, false).Expr(*this); auto v3 = Vec(false, true).Expr(*this); auto expr = Not(Or(And(v1, v2), v3)); GlobalConst("C", expr); auto expected_expr = Vec(false, false).Expr(*this); GlobalConst("E", expected_expr); EXPECT_TRUE(r()->Resolve()) << r()->error(); auto* sem = Sem().Get(expr); const sem::Constant* value = sem->ConstantValue(); ASSERT_NE(value, nullptr); EXPECT_TYPE(value->Type(), sem->Type()); auto* expected_sem = Sem().Get(expected_expr); const sem::Constant* expected_value = expected_sem->ConstantValue(); ASSERT_NE(expected_value, nullptr); EXPECT_TYPE(expected_value->Type(), expected_sem->Type()); ForEachElemPair(value, expected_value, [&](const sem::Constant* a, const sem::Constant* b) { EXPECT_EQ(a->As(), b->As()); return HasFailure() ? Action::kStop : Action::kContinue; }); } template std::vector XorCases() { using B = BitValues; return { C(T{0b1010}, T{0b1111}, T{0b0101}), C(T{0b1010}, T{0b0000}, T{0b1010}), C(T{0b1010}, T{0b0011}, T{0b1001}), C(T{0b1010}, T{0b1100}, T{0b0110}), C(T{0b1010}, T{0b0101}, T{0b1111}), C(B::All, B::All, T{0}), C(B::LeftMost, B::LeftMost, T{0}), C(B::RightMost, B::RightMost, T{0}), C(B::All, T{0}, B::All), C(T{0}, B::All, B::All), C(B::LeftMost, B::AllButLeftMost, B::All), C(B::AllButLeftMost, B::LeftMost, B::All), C(B::RightMost, B::AllButRightMost, B::All), C(B::AllButRightMost, B::RightMost, B::All), C(Vec(B::All, B::LeftMost, B::RightMost), // Vec(B::All, B::All, B::All), // Vec(T{0}, B::AllButLeftMost, B::AllButRightMost)), // C(Vec(B::All, B::LeftMost, B::RightMost), // Vec(T{0}, T{0}, T{0}), // Vec(B::All, B::LeftMost, B::RightMost)), // C(Vec(B::LeftMost, B::RightMost), // Vec(B::AllButLeftMost, B::AllButRightMost), // Vec(B::All, B::All)), }; } INSTANTIATE_TEST_SUITE_P(Xor, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kXor), testing::ValuesIn(Concat(XorCases(), // XorCases(), // XorCases())))); template std::vector ShiftLeftCases() { // Shift type is u32 for non-abstract using ST = std::conditional_t, T, u32>; using B = BitValues; auto r = std::vector{ C(T{0b1010}, ST{0}, T{0b0000'0000'1010}), // C(T{0b1010}, ST{1}, T{0b0000'0001'0100}), // C(T{0b1010}, ST{2}, T{0b0000'0010'1000}), // C(T{0b1010}, ST{3}, T{0b0000'0101'0000}), // C(T{0b1010}, ST{4}, T{0b0000'1010'0000}), // C(T{0b1010}, ST{5}, T{0b0001'0100'0000}), // C(T{0b1010}, ST{6}, T{0b0010'1000'0000}), // C(T{0b1010}, ST{7}, T{0b0101'0000'0000}), // C(T{0b1010}, ST{8}, T{0b1010'0000'0000}), // C(B::LeftMost, ST{0}, B::LeftMost), // C(Vec(T{0b1010}, T{0b1010}), // Vec(ST{0}, ST{1}), // Vec(T{0b0000'0000'1010}, T{0b0000'0001'0100})), // C(Vec(T{0b1010}, T{0b1010}), // Vec(ST{2}, ST{3}), // Vec(T{0b0000'0010'1000}, T{0b0000'0101'0000})), // C(Vec(T{0b1010}, T{0b1010}), // Vec(ST{4}, ST{5}), // Vec(T{0b0000'1010'0000}, T{0b0001'0100'0000})), // C(Vec(T{0b1010}, T{0b1010}, T{0b1010}), // Vec(ST{6}, ST{7}, ST{8}), // Vec(T{0b0010'1000'0000}, T{0b0101'0000'0000}, T{0b1010'0000'0000})), // }; // Abstract 0 can be shifted by any u32 value (0 to 2^32), whereas concrete 0 (or any number) // can only by shifted by a value less than the number of bits of the lhs. // (see ResolverConstEvalShiftLeftConcreteGeqBitWidthError for negative tests) ConcatIntoIf>( // r, std::vector{ C(T{0}, ST{64}, T{0}), // C(T{0}, ST{65}, T{0}), // C(T{0}, ST{65}, T{0}), // C(T{0}, ST{10000}, T{0}), // C(T{0}, ST{u32::Highest()}, T{0}), // C(Negate(T{0}), ST{64}, Negate(T{0})), // C(Negate(T{0}), ST{65}, Negate(T{0})), // C(Negate(T{0}), ST{65}, Negate(T{0})), // C(Negate(T{0}), ST{10000}, Negate(T{0})), // C(Negate(T{0}), ST{u32::Highest()}, Negate(T{0})), // }); // Cases that are fine for signed values (no sign change), but would overflow // unsigned values. See below for negative tests. ConcatIntoIf>( // r, std::vector{ C(B::TwoLeftMost, ST{1}, B::LeftMost), // C(B::All, ST{1}, B::AllButRightMost), // C(B::All, ST{B::NumBits - 1}, B::LeftMost) // }); // Cases that are fine for unsigned values, but would overflow (sign change) signed // values. See ShiftLeftSignChangeErrorCases() for negative tests. ConcatIntoIf>( // r, std::vector{ C(T{0b0001}, ST{B::NumBits - 1}, B::Lsh(0b0001, B::NumBits - 1)), C(T{0b0010}, ST{B::NumBits - 2}, B::Lsh(0b0010, B::NumBits - 2)), C(T{0b0100}, ST{B::NumBits - 3}, B::Lsh(0b0100, B::NumBits - 3)), C(T{0b1000}, ST{B::NumBits - 4}, B::Lsh(0b1000, B::NumBits - 4)), C(T{0b0011}, ST{B::NumBits - 2}, B::Lsh(0b0011, B::NumBits - 2)), C(T{0b0110}, ST{B::NumBits - 3}, B::Lsh(0b0110, B::NumBits - 3)), C(T{0b1100}, ST{B::NumBits - 4}, B::Lsh(0b1100, B::NumBits - 4)), C(B::AllButLeftMost, ST{1}, B::AllButRightMost), }); auto error_msg = [](auto a, auto b) { return "12:34 error: " + OverflowErrorMessage(a, "<<", b); }; ConcatIntoIf>( // r, std::vector{ // ShiftLeft of AInts that result in values not representable as AInts. // Note that for i32/u32, these would error because shift value is larger than 32. E(B::All, T{B::NumBits}, error_msg(B::All, T{B::NumBits})), E(B::RightMost, T{B::NumBits}, error_msg(B::RightMost, T{B::NumBits})), E(B::AllButLeftMost, T{B::NumBits}, error_msg(B::AllButLeftMost, T{B::NumBits})), E(B::AllButLeftMost, T{B::NumBits + 1}, error_msg(B::AllButLeftMost, T{B::NumBits + 1})), E(B::AllButLeftMost, T{B::NumBits + 1000}, error_msg(B::AllButLeftMost, T{B::NumBits + 1000})), }); ConcatIntoIf>( // r, std::vector{ // ShiftLeft of u32s that overflow (non-zero bits get shifted out) E(T{0b00010}, T{31}, error_msg(T{0b00010}, T{31})), E(T{0b00100}, T{30}, error_msg(T{0b00100}, T{30})), E(T{0b01000}, T{29}, error_msg(T{0b01000}, T{29})), E(T{0b10000}, T{28}, error_msg(T{0b10000}, T{28})), //... E(T{1 << 28}, T{4}, error_msg(T{1 << 28}, T{4})), E(T{1 << 29}, T{3}, error_msg(T{1 << 29}, T{3})), E(T{1 << 30}, T{2}, error_msg(T{1 << 30}, T{2})), E(T{1u << 31}, T{1}, error_msg(T{1u << 31}, T{1})), // And some more E(B::All, T{1}, error_msg(B::All, T{1})), E(B::AllButLeftMost, T{2}, error_msg(B::AllButLeftMost, T{2})), }); return r; } INSTANTIATE_TEST_SUITE_P(ShiftLeft, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kShiftLeft), testing::ValuesIn(Concat(ShiftLeftCases(), // ShiftLeftCases(), // ShiftLeftCases())))); TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AInt) { GlobalConst("c", Add(Source{{1, 1}}, Expr(AInt::Highest()), 1_a)); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "1:1 error: '9223372036854775807 + 1' cannot be represented as 'abstract-int'"); } TEST_F(ResolverConstEvalTest, BinaryAbstractAddUnderflow_AInt) { GlobalConst("c", Add(Source{{1, 1}}, Expr(AInt::Lowest()), -1_a)); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "1:1 error: '-9223372036854775808 + -1' cannot be represented as 'abstract-int'"); } TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AFloat) { GlobalConst("c", Add(Source{{1, 1}}, Expr(AFloat::Highest()), AFloat::Highest())); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "1:1 error: '1.7976931348623157081e+308 + 1.7976931348623157081e+308' cannot be " "represented as 'abstract-float'"); } TEST_F(ResolverConstEvalTest, BinaryAbstractAddUnderflow_AFloat) { GlobalConst("c", Add(Source{{1, 1}}, Expr(AFloat::Lowest()), AFloat::Lowest())); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "1:1 error: '-1.7976931348623157081e+308 + -1.7976931348623157081e+308' cannot be " "represented as 'abstract-float'"); } // Mixed AInt and AFloat args to test implicit conversion to AFloat INSTANTIATE_TEST_SUITE_P( AbstractMixed, ResolverConstEvalBinaryOpTest, testing::Combine( testing::Values(ast::BinaryOp::kAdd), testing::Values(C(Val(1_a), Val(2.3_a), Val(3.3_a)), C(Val(2.3_a), Val(1_a), Val(3.3_a)), C(Val(1_a), Vec(2.3_a, 2.3_a, 2.3_a), Vec(3.3_a, 3.3_a, 3.3_a)), C(Vec(2.3_a, 2.3_a, 2.3_a), Val(1_a), Vec(3.3_a, 3.3_a, 3.3_a)), C(Vec(2.3_a, 2.3_a, 2.3_a), Val(1_a), Vec(3.3_a, 3.3_a, 3.3_a)), C(Val(1_a), Vec(2.3_a, 2.3_a, 2.3_a), Vec(3.3_a, 3.3_a, 3.3_a)), C(Mat({1_a, 2_a}, // {1_a, 2_a}, // {1_a, 2_a}), // Mat({1.2_a, 2.3_a}, // {1.2_a, 2.3_a}, // {1.2_a, 2.3_a}), // Mat({2.2_a, 4.3_a}, // {2.2_a, 4.3_a}, // {2.2_a, 4.3_a})), // C(Mat({1.2_a, 2.3_a}, // {1.2_a, 2.3_a}, // {1.2_a, 2.3_a}), // Mat({1_a, 2_a}, // {1_a, 2_a}, // {1_a, 2_a}), // Mat({2.2_a, 4.3_a}, // {2.2_a, 4.3_a}, // {2.2_a, 4.3_a})) // ))); // AInt left shift negative value -> error TEST_F(ResolverConstEvalTest, BinaryAbstractShiftLeftByNegativeValue_Error) { GlobalConst("c", Shl(Expr(1_a), Expr(Source{{1, 1}}, -1_a))); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "1:1 error: value -1 cannot be represented as 'u32'"); } // AInt left shift by AInt or u32 always results in an AInt TEST_F(ResolverConstEvalTest, BinaryAbstractShiftLeftRemainsAbstract) { auto* expr1 = Shl(Expr(1_a), Expr(1_u)); GlobalConst("c1", expr1); auto* expr2 = Shl(Expr(1_a), Expr(1_a)); GlobalConst("c2", expr2); EXPECT_TRUE(r()->Resolve()) << r()->error(); auto* sem1 = Sem().Get(expr1); ASSERT_NE(sem1, nullptr); auto* sem2 = Sem().Get(expr2); ASSERT_NE(sem2, nullptr); auto aint_ty = create(); EXPECT_EQ(sem1->Type(), aint_ty); EXPECT_EQ(sem2->Type(), aint_ty); } // i32/u32 left shift by >= 32 -> error using ResolverConstEvalShiftLeftConcreteGeqBitWidthError = ResolverTestWithParam; TEST_P(ResolverConstEvalShiftLeftConcreteGeqBitWidthError, Test) { auto* lhs_expr = ToValueBase(GetParam().lhs)->Expr(*this); auto* rhs_expr = ToValueBase(GetParam().rhs)->Expr(*this); GlobalConst("c", Shl(Source{{1, 1}}, lhs_expr, rhs_expr)); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ( r()->error(), "1:1 error: shift left value must be less than the bit width of the lhs, which is 32"); } INSTANTIATE_TEST_SUITE_P(Test, ResolverConstEvalShiftLeftConcreteGeqBitWidthError, testing::Values( // ErrorCase{Val(0_u), Val(32_u)}, // ErrorCase{Val(0_u), Val(33_u)}, // ErrorCase{Val(0_u), Val(34_u)}, // ErrorCase{Val(0_u), Val(10000_u)}, // ErrorCase{Val(0_u), Val(u32::Highest())}, // ErrorCase{Val(0_i), Val(32_u)}, // ErrorCase{Val(0_i), Val(33_u)}, // ErrorCase{Val(0_i), Val(34_u)}, // ErrorCase{Val(0_i), Val(10000_u)}, // ErrorCase{Val(0_i), Val(u32::Highest())}, // ErrorCase{Val(Negate(0_u)), Val(32_u)}, // ErrorCase{Val(Negate(0_u)), Val(33_u)}, // ErrorCase{Val(Negate(0_u)), Val(34_u)}, // ErrorCase{Val(Negate(0_u)), Val(10000_u)}, // ErrorCase{Val(Negate(0_u)), Val(u32::Highest())}, // ErrorCase{Val(Negate(0_i)), Val(32_u)}, // ErrorCase{Val(Negate(0_i)), Val(33_u)}, // ErrorCase{Val(Negate(0_i)), Val(34_u)}, // ErrorCase{Val(Negate(0_i)), Val(10000_u)}, // ErrorCase{Val(Negate(0_i)), Val(u32::Highest())}, // ErrorCase{Val(1_i), Val(32_u)}, // ErrorCase{Val(1_i), Val(33_u)}, // ErrorCase{Val(1_i), Val(34_u)}, // ErrorCase{Val(1_i), Val(10000_u)}, // ErrorCase{Val(1_i), Val(u32::Highest())}, // ErrorCase{Val(1_u), Val(32_u)}, // ErrorCase{Val(1_u), Val(33_u)}, // ErrorCase{Val(1_u), Val(34_u)}, // ErrorCase{Val(1_u), Val(10000_u)}, // ErrorCase{Val(1_u), Val(u32::Highest())} // )); // AInt left shift results in sign change error using ResolverConstEvalShiftLeftSignChangeError = ResolverTestWithParam; TEST_P(ResolverConstEvalShiftLeftSignChangeError, Test) { auto* lhs_expr = ToValueBase(GetParam().lhs)->Expr(*this); auto* rhs_expr = ToValueBase(GetParam().rhs)->Expr(*this); GlobalConst("c", Shl(Source{{1, 1}}, lhs_expr, rhs_expr)); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "1:1 error: shift left operation results in sign change"); } template std::vector ShiftLeftSignChangeErrorCases() { // Shift type is u32 for non-abstract using ST = std::conditional_t, T, u32>; using B = BitValues; return { {Val(T{0b0001}), Val(ST{B::NumBits - 1})}, {Val(T{0b0010}), Val(ST{B::NumBits - 2})}, {Val(T{0b0100}), Val(ST{B::NumBits - 3})}, {Val(T{0b1000}), Val(ST{B::NumBits - 4})}, {Val(T{0b0011}), Val(ST{B::NumBits - 2})}, {Val(T{0b0110}), Val(ST{B::NumBits - 3})}, {Val(T{0b1100}), Val(ST{B::NumBits - 4})}, {Val(B::AllButLeftMost), Val(ST{1})}, {Val(B::AllButLeftMost), Val(ST{B::NumBits - 1})}, {Val(B::LeftMost), Val(ST{1})}, {Val(B::LeftMost), Val(ST{B::NumBits - 1})}, }; } INSTANTIATE_TEST_SUITE_P(Test, ResolverConstEvalShiftLeftSignChangeError, testing::ValuesIn(Concat( // ShiftLeftSignChangeErrorCases(), ShiftLeftSignChangeErrorCases()))); } // namespace } // namespace tint::resolver