// Copyright 2022 The Tint Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "src/tint/resolver/const_eval_test.h" #include "src/tint/reader/wgsl/parser.h" #include "src/tint/utils/result.h" using namespace tint::number_suffixes; // NOLINT using ::testing::HasSubstr; namespace tint::resolver { namespace { struct Case { struct Success { Value value; }; struct Failure { std::string error; }; Value lhs; Value rhs; utils::Result expected; }; struct ErrorCase { Value lhs; Value rhs; }; /// Creates a Case with Values of any type Case C(Value lhs, Value rhs, Value expected) { return Case{std::move(lhs), std::move(rhs), Case::Success{std::move(expected)}}; } /// Convenience overload that creates a Case with just scalars template >> Case C(T lhs, U rhs, V expected) { return Case{Val(lhs), Val(rhs), Case::Success{Val(expected)}}; } /// Creates an failure Case with Values of any type Case E(Value lhs, Value rhs, std::string error) { return Case{std::move(lhs), std::move(rhs), Case::Failure{std::move(error)}}; } /// Convenience overload that creates an error Case with just scalars template >> Case E(T lhs, U rhs, std::string error) { return Case{Val(lhs), Val(rhs), Case::Failure{std::move(error)}}; } /// Prints Case to ostream static std::ostream& operator<<(std::ostream& o, const Case& c) { o << "lhs: " << c.lhs << ", rhs: " << c.rhs << ", expected: "; if (c.expected) { auto& s = c.expected.Get(); o << s.value; } else { o << "[ERROR: " << 
c.expected.Failure().error << "]"; } return o; } /// Prints ErrorCase to ostream std::ostream& operator<<(std::ostream& o, const ErrorCase& c) { o << c.lhs << ", " << c.rhs; return o; } using ResolverConstEvalBinaryOpTest = ResolverTestWithParam>; TEST_P(ResolverConstEvalBinaryOpTest, Test) { Enable(builtin::Extension::kF16); auto op = std::get<0>(GetParam()); auto& c = std::get<1>(GetParam()); auto* lhs_expr = c.lhs.Expr(*this); auto* rhs_expr = c.rhs.Expr(*this); auto* expr = create(Source{{12, 34}}, op, lhs_expr, rhs_expr); GlobalConst("C", expr); if (c.expected) { ASSERT_TRUE(r()->Resolve()) << r()->error(); auto expected_case = c.expected.Get(); auto& expected = expected_case.value; auto* sem = Sem().Get(expr); const constant::Value* value = sem->ConstantValue(); ASSERT_NE(value, nullptr); EXPECT_TYPE(value->Type(), sem->Type()); CheckConstant(value, expected); } else { ASSERT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), c.expected.Failure().error); } } INSTANTIATE_TEST_SUITE_P(MixedAbstractArgs, ResolverConstEvalBinaryOpTest, testing::Combine(testing::Values(ast::BinaryOp::kAdd), testing::ValuesIn(std::vector{ // Mixed abstract type args C(1_a, 2.3_a, 3.3_a), C(2.3_a, 1_a, 3.3_a), }))); template std::vector OpAddIntCases() { static_assert(IsIntegral); auto r = std::vector{ C(T{0}, T{0}, T{0}), C(T{1}, T{2}, T{3}), C(T::Lowest(), T{1}, T{T::Lowest() + 1}), C(T::Highest(), Negate(T{1}), T{T::Highest() - 1}), C(T::Lowest(), T::Highest(), Negate(T{1})), }; if constexpr (IsAbstract) { auto error_msg = [](auto a, auto b) { return "12:34 error: " + OverflowErrorMessage(a, "+", b); }; ConcatInto( // r, std::vector{ E(T::Highest(), T{1}, error_msg(T::Highest(), T{1})), E(T::Lowest(), Negate(T{1}), error_msg(T::Lowest(), Negate(T{1}))), }); } else { ConcatInto( // r, std::vector{ C(T::Highest(), T{1}, T::Lowest()), C(T::Lowest(), Negate(T{1}), T::Highest()), }); } return r; } template std::vector OpAddFloatCases() { static_assert(IsFloatingPoint); auto error_msg = 
[](auto a, auto b) { return "12:34 error: " + OverflowErrorMessage(a, "+", b); }; return std::vector{ C(T{0}, T{0}, T{0}), C(T{1}, T{2}, T{3}), C(T::Lowest(), T{1}, T{T::Lowest() + 1}), C(T::Highest(), Negate(T{1}), T{T::Highest() - 1}), C(T::Lowest(), T::Highest(), T{0}), E(T::Highest(), T::Highest(), error_msg(T::Highest(), T::Highest())), E(T::Lowest(), Negate(T::Highest()), error_msg(T::Lowest(), Negate(T::Highest()))), }; } INSTANTIATE_TEST_SUITE_P(Add, ResolverConstEvalBinaryOpTest, testing::Combine(testing::Values(ast::BinaryOp::kAdd), testing::ValuesIn(Concat( // OpAddIntCases(), OpAddIntCases(), OpAddIntCases(), OpAddFloatCases(), OpAddFloatCases(), OpAddFloatCases())))); template std::vector OpSubIntCases() { static_assert(IsIntegral); auto r = std::vector{ C(T{0}, T{0}, T{0}), C(T{3}, T{2}, T{1}), C(T{T::Lowest() + 1}, T{1}, T::Lowest()), C(T{T::Highest() - 1}, Negate(T{1}), T::Highest()), C(Negate(T{1}), T::Highest(), T::Lowest()), }; if constexpr (IsAbstract) { auto error_msg = [](auto a, auto b) { return "12:34 error: " + OverflowErrorMessage(a, "-", b); }; ConcatInto( // r, std::vector{ E(T::Lowest(), T{1}, error_msg(T::Lowest(), T{1})), E(T::Highest(), Negate(T{1}), error_msg(T::Highest(), Negate(T{1}))), }); } else { ConcatInto( // r, std::vector{ C(T::Lowest(), T{1}, T::Highest()), C(T::Highest(), Negate(T{1}), T::Lowest()), }); } return r; } template std::vector OpSubFloatCases() { static_assert(IsFloatingPoint); auto error_msg = [](auto a, auto b) { return "12:34 error: " + OverflowErrorMessage(a, "-", b); }; return std::vector{ C(T{0}, T{0}, T{0}), C(T{3}, T{2}, T{1}), C(T::Highest(), T{1}, T{T::Highest() - 1}), C(T::Lowest(), Negate(T{1}), T{T::Lowest() + 1}), C(T{0}, T::Highest(), T::Lowest()), E(T::Highest(), Negate(T::Highest()), error_msg(T::Highest(), Negate(T::Highest()))), E(T::Lowest(), T::Highest(), error_msg(T::Lowest(), T::Highest())), }; } INSTANTIATE_TEST_SUITE_P(Sub, ResolverConstEvalBinaryOpTest, 
testing::Combine(testing::Values(ast::BinaryOp::kSubtract), testing::ValuesIn(Concat( // OpSubIntCases(), OpSubIntCases(), OpSubIntCases(), OpSubFloatCases(), OpSubFloatCases(), OpSubFloatCases())))); template std::vector OpMulScalarCases() { auto r = std::vector{ C(T{0}, T{0}, T{0}), C(T{1}, T{2}, T{2}), C(T{2}, T{3}, T{6}), C(Negate(T{2}), T{3}, Negate(T{6})), C(T::Highest(), T{1}, T::Highest()), C(T::Lowest(), T{1}, T::Lowest()), }; if constexpr (IsAbstract || IsFloatingPoint) { auto error_msg = [](auto a, auto b) { return "12:34 error: " + OverflowErrorMessage(a, "*", b); }; ConcatInto( // r, std::vector{ // Fail if result is +/-inf E(T::Highest(), T::Highest(), error_msg(T::Highest(), T::Highest())), E(T::Lowest(), T::Lowest(), error_msg(T::Lowest(), T::Lowest())), E(T::Highest(), T{2}, error_msg(T::Highest(), T{2})), E(T::Lowest(), Negate(T{2}), error_msg(T::Lowest(), Negate(T{2}))), }); } else { ConcatInto( // r, std::vector{ C(T::Highest(), T::Highest(), Mul(T::Highest(), T::Highest())), C(T::Lowest(), T::Lowest(), Mul(T::Lowest(), T::Lowest())), }); } return r; } template std::vector OpMulVecCases() { auto r = std::vector{ // s * vec3 = vec3 C(Val(T{2.0}), Vec(T{1.25}, T{2.25}, T{3.25}), Vec(T{2.5}, T{4.5}, T{6.5})), // vec3 * s = vec3 C(Vec(T{1.25}, T{2.25}, T{3.25}), Val(T{2.0}), Vec(T{2.5}, T{4.5}, T{6.5})), // vec3 * vec3 = vec3 C(Vec(T{1.25}, T{2.25}, T{3.25}), Vec(T{2.0}, T{2.0}, T{2.0}), Vec(T{2.5}, T{4.5}, T{6.5})), }; if constexpr (IsAbstract || IsFloatingPoint) { auto error_msg = [](auto a, auto b) { return "12:34 error: " + OverflowErrorMessage(a, "*", b); }; ConcatInto( // r, std::vector{ // Fail if result is +/-inf E(Val(T::Highest()), Vec(T{2}, T{1}), error_msg(T::Highest(), T{2})), E(Val(T::Lowest()), Vec(Negate(T{2}), T{1}), error_msg(T::Lowest(), Negate(T{2}))), }); } else { ConcatInto( // r, std::vector{ C(Val(T::Highest()), Vec(T{2}, T{1}), Vec(T{-2}, T::Highest())), C(Val(T::Lowest()), Vec(Negate(T{2}), T{1}), Vec(T{0}, 
T{T::Lowest()})), }); } return r; } template std::vector OpMulMatCases() { auto r = std::vector{ // s * mat3x2 = mat3x2 C(Val(T{2.25}), Mat({T{1.0}, T{4.0}}, // {T{2.0}, T{5.0}}, // {T{3.0}, T{6.0}}), Mat({T{2.25}, T{9.0}}, // {T{4.5}, T{11.25}}, // {T{6.75}, T{13.5}})), // mat3x2 * s = mat3x2 C(Mat({T{1.0}, T{4.0}}, // {T{2.0}, T{5.0}}, // {T{3.0}, T{6.0}}), Val(T{2.25}), Mat({T{2.25}, T{9.0}}, // {T{4.5}, T{11.25}}, // {T{6.75}, T{13.5}})), // vec3 * mat2x3 = vec2 C(Vec(T{1.25}, T{2.25}, T{3.25}), // Mat({T{1.0}, T{2.0}, T{3.0}}, // {T{4.0}, T{5.0}, T{6.0}}), // Vec(T{15.5}, T{35.75})), // mat2x3 * vec2 = vec3 C(Mat({T{1.0}, T{2.0}, T{3.0}}, // {T{4.0}, T{5.0}, T{6.0}}), // Vec(T{1.25}, T{2.25}), // Vec(T{10.25}, T{13.75}, T{17.25})), // mat3x2 * mat2x3 = mat2x2 C(Mat({T{1.0}, T{2.0}}, // {T{3.0}, T{4.0}}, // {T{5.0}, T{6.0}}), // Mat({T{1.25}, T{2.25}, T{3.25}}, // {T{4.25}, T{5.25}, T{6.25}}), // Mat({T{24.25}, T{31.0}}, // {T{51.25}, T{67.0}})), // }; auto error_msg = [](auto a, const char* op, auto b) { return "12:34 error: " + OverflowErrorMessage(a, op, b); }; ConcatIntoIf || IsFloatingPoint>( // r, std::vector{ // vector-matrix multiply // Overflow from first multiplication of dot product of vector and matrix column 0 // i.e. (v[0] * m[0][0] + v[1] * m[0][1]) // ^ E(Vec(T::Highest(), T{1.0}), // Mat({T{2.0}, T{1.0}}, // {T{1.0}, T{1.0}}), // error_msg(T{2}, "*", T::Highest())), // Overflow from second multiplication of dot product of vector and matrix column 0 // i.e. (v[0] * m[0][0] + v[1] * m[0][1]) // ^ E(Vec(T{1.0}, T::Highest()), // Mat({T{1.0}, T{2.0}}, // {T{1.0}, T{1.0}}), // error_msg(T{2}, "*", T::Highest())), // Overflow from addition of dot product of vector and matrix column 0 // i.e. 
(v[0] * m[0][0] + v[1] * m[0][1]) // ^ E(Vec(T::Highest(), T::Highest()), // Mat({T{1.0}, T{1.0}}, // {T{1.0}, T{1.0}}), // error_msg(T::Highest(), "+", T::Highest())), // matrix-matrix multiply // Overflow from first multiplication of dot product of lhs row 0 and rhs column 0 // i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0] // ^ E(Mat({T::Highest(), T{1.0}}, // {T{1.0}, T{1.0}}), // Mat({T{2.0}, T{1.0}}, // {T{1.0}, T{1.0}}), // error_msg(T::Highest(), "*", T{2.0})), // Overflow from second multiplication of dot product of lhs row 0 and rhs column 0 // i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0] // ^ E(Mat({T{1.0}, T{1.0}}, // {T::Highest(), T{1.0}}), // Mat({T{1.0}, T{2.0}}, // {T{1.0}, T{1.0}}), // error_msg(T::Highest(), "*", T{2.0})), // Overflow from addition of dot product of lhs row 0 and rhs column 0 // i.e. m1[0][0] * m2[0][0] + m1[0][1] * m[1][0] // ^ E(Mat({T::Highest(), T{1.0}}, // {T::Highest(), T{1.0}}), // Mat({T{1.0}, T{1.0}}, // {T{1.0}, T{1.0}}), // error_msg(T::Highest(), "+", T::Highest())), }); return r; } INSTANTIATE_TEST_SUITE_P(Mul, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kMultiply), testing::ValuesIn(Concat( // OpMulScalarCases(), OpMulScalarCases(), OpMulScalarCases(), OpMulScalarCases(), OpMulScalarCases(), OpMulScalarCases(), OpMulVecCases(), OpMulVecCases(), OpMulVecCases(), OpMulVecCases(), OpMulVecCases(), OpMulVecCases(), OpMulMatCases(), OpMulMatCases(), OpMulMatCases())))); template std::vector OpDivIntCases() { auto error_msg = [](auto a, auto b) { return "12:34 error: " + OverflowErrorMessage(a, "/", b); }; std::vector r = { C(T{0}, T{1}, T{0}), C(T{1}, T{1}, T{1}), C(T{1}, T{1}, T{1}), C(T{2}, T{1}, T{2}), C(T{4}, T{2}, T{2}), C(T::Highest(), T{1}, T::Highest()), C(T::Lowest(), T{1}, T::Lowest()), C(T::Highest(), T::Highest(), T{1}), C(T{0}, T::Highest(), T{0}), // Divide by zero E(T{123}, T{0}, error_msg(T{123}, T{0})), E(T::Highest(), T{0}, error_msg(T::Highest(), T{0})), 
E(T::Lowest(), T{0}, error_msg(T::Lowest(), T{0})), }; // Error on most negative divided by -1 ConcatIntoIf>( // r, std::vector{ E(T::Lowest(), T{-1}, error_msg(T::Lowest(), T{-1})), }); return r; } template std::vector OpDivFloatCases() { auto error_msg = [](auto a, auto b) { return "12:34 error: " + OverflowErrorMessage(a, "/", b); }; std::vector r = { C(T{0}, T{1}, T{0}), C(T{1}, T{1}, T{1}), C(T{1}, T{1}, T{1}), C(T{2}, T{1}, T{2}), C(T{4}, T{2}, T{2}), C(T::Highest(), T{1}, T::Highest()), C(T::Lowest(), T{1}, T::Lowest()), C(T::Highest(), T::Highest(), T{1}), C(T{0}, T::Highest(), T{0}), C(T{0}, T::Lowest(), -T{0}), // Divide by zero E(T{123}, T{0}, error_msg(T{123}, T{0})), E(Negate(T{123}), Negate(T{0}), error_msg(Negate(T{123}), Negate(T{0}))), E(Negate(T{123}), T{0}, error_msg(Negate(T{123}), T{0})), E(T{123}, Negate(T{0}), error_msg(T{123}, Negate(T{0}))), }; return r; } INSTANTIATE_TEST_SUITE_P(Div, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kDivide), testing::ValuesIn(Concat( // OpDivIntCases(), OpDivIntCases(), OpDivIntCases(), OpDivFloatCases(), OpDivFloatCases(), OpDivFloatCases())))); template std::vector OpModCases() { auto error_msg = [](auto a, auto b) { return "12:34 error: " + OverflowErrorMessage(a, "%", b); }; // Common cases for all types std::vector r = { C(T{0}, T{1}, T{0}), // C(T{1}, T{1}, T{0}), // C(T{10}, T{1}, T{0}), // C(T{10}, T{2}, T{0}), // C(T{10}, T{3}, T{1}), // C(T{10}, T{4}, T{2}), // C(T{10}, T{5}, T{0}), // C(T{10}, T{6}, T{4}), // C(T{10}, T{5}, T{0}), // C(T{10}, T{8}, T{2}), // C(T{10}, T{9}, T{1}), // C(T{10}, T{10}, T{0}), // // Error on divide by zero E(T{123}, T{0}, error_msg(T{123}, T{0})), E(T::Highest(), T{0}, error_msg(T::Highest(), T{0})), E(T::Lowest(), T{0}, error_msg(T::Lowest(), T{0})), }; if constexpr (IsIntegral) { ConcatInto( // r, std::vector{ C(T::Highest(), T{T::Highest() - T{1}}, T{1}), }); } if constexpr (IsSignedIntegral) { ConcatInto( // r, std::vector{ 
C(T::Lowest(), T{T::Lowest() + T{1}}, -T(1)), // Error on most negative integer divided by -1 E(T::Lowest(), T{-1}, error_msg(T::Lowest(), T{-1})), }); } // Negative values (both signed integrals and floating point) if constexpr (IsSignedIntegral || IsFloatingPoint) { ConcatInto( // r, std::vector{ C(-T{1}, T{1}, T{0}), // // lhs negative, rhs positive C(-T{10}, T{1}, T{0}), // C(-T{10}, T{2}, T{0}), // C(-T{10}, T{3}, -T{1}), // C(-T{10}, T{4}, -T{2}), // C(-T{10}, T{5}, T{0}), // C(-T{10}, T{6}, -T{4}), // C(-T{10}, T{5}, T{0}), // C(-T{10}, T{8}, -T{2}), // C(-T{10}, T{9}, -T{1}), // C(-T{10}, T{10}, T{0}), // // lhs positive, rhs negative C(T{10}, -T{1}, T{0}), // C(T{10}, -T{2}, T{0}), // C(T{10}, -T{3}, T{1}), // C(T{10}, -T{4}, T{2}), // C(T{10}, -T{5}, T{0}), // C(T{10}, -T{6}, T{4}), // C(T{10}, -T{5}, T{0}), // C(T{10}, -T{8}, T{2}), // C(T{10}, -T{9}, T{1}), // C(T{10}, -T{10}, T{0}), // // lhs negative, rhs negative C(-T{10}, -T{1}, T{0}), // C(-T{10}, -T{2}, T{0}), // C(-T{10}, -T{3}, -T{1}), // C(-T{10}, -T{4}, -T{2}), // C(-T{10}, -T{5}, T{0}), // C(-T{10}, -T{6}, -T{4}), // C(-T{10}, -T{5}, T{0}), // C(-T{10}, -T{8}, -T{2}), // C(-T{10}, -T{9}, -T{1}), // C(-T{10}, -T{10}, T{0}), // }); } // Float values if constexpr (IsFloatingPoint) { ConcatInto( // r, std::vector{ C(T{10.5}, T{1}, T{0.5}), // C(T{10.5}, T{2}, T{0.5}), // C(T{10.5}, T{3}, T{1.5}), // C(T{10.5}, T{4}, T{2.5}), // C(T{10.5}, T{5}, T{0.5}), // C(T{10.5}, T{6}, T{4.5}), // C(T{10.5}, T{5}, T{0.5}), // C(T{10.5}, T{8}, T{2.5}), // C(T{10.5}, T{9}, T{1.5}), // C(T{10.5}, T{10}, T{0.5}), // // lhs negative, rhs positive C(-T{10.5}, T{1}, -T{0.5}), // C(-T{10.5}, T{2}, -T{0.5}), // C(-T{10.5}, T{3}, -T{1.5}), // C(-T{10.5}, T{4}, -T{2.5}), // C(-T{10.5}, T{5}, -T{0.5}), // C(-T{10.5}, T{6}, -T{4.5}), // C(-T{10.5}, T{5}, -T{0.5}), // C(-T{10.5}, T{8}, -T{2.5}), // C(-T{10.5}, T{9}, -T{1.5}), // C(-T{10.5}, T{10}, -T{0.5}), // // lhs positive, rhs negative C(T{10.5}, -T{1}, T{0.5}), // 
C(T{10.5}, -T{2}, T{0.5}), // C(T{10.5}, -T{3}, T{1.5}), // C(T{10.5}, -T{4}, T{2.5}), // C(T{10.5}, -T{5}, T{0.5}), // C(T{10.5}, -T{6}, T{4.5}), // C(T{10.5}, -T{5}, T{0.5}), // C(T{10.5}, -T{8}, T{2.5}), // C(T{10.5}, -T{9}, T{1.5}), // C(T{10.5}, -T{10}, T{0.5}), // // lhs negative, rhs negative C(-T{10.5}, -T{1}, -T{0.5}), // C(-T{10.5}, -T{2}, -T{0.5}), // C(-T{10.5}, -T{3}, -T{1.5}), // C(-T{10.5}, -T{4}, -T{2.5}), // C(-T{10.5}, -T{5}, -T{0.5}), // C(-T{10.5}, -T{6}, -T{4.5}), // C(-T{10.5}, -T{5}, -T{0.5}), // C(-T{10.5}, -T{8}, -T{2.5}), // C(-T{10.5}, -T{9}, -T{1.5}), // C(-T{10.5}, -T{10}, -T{0.5}), // }); } return r; } INSTANTIATE_TEST_SUITE_P(Mod, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kModulo), testing::ValuesIn(Concat( // OpModCases(), OpModCases(), OpModCases(), OpModCases(), OpModCases(), OpModCases())))); template std::vector OpEqualCases() { return { C(T{0}, T{0}, true == equals), C(T{0}, T{1}, false == equals), C(T{1}, T{0}, false == equals), C(T{1}, T{1}, true == equals), C(Vec(T{0}, T{0}), Vec(T{0}, T{0}), Vec(true == equals, true == equals)), C(Vec(T{1}, T{0}), Vec(T{0}, T{1}), Vec(false == equals, false == equals)), C(Vec(T{1}, T{1}), Vec(T{0}, T{1}), Vec(false == equals, true == equals)), }; } INSTANTIATE_TEST_SUITE_P(Equal, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kEqual), testing::ValuesIn(Concat( // OpEqualCases(), OpEqualCases(), OpEqualCases(), OpEqualCases(), OpEqualCases(), OpEqualCases(), OpEqualCases())))); INSTANTIATE_TEST_SUITE_P(NotEqual, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kNotEqual), testing::ValuesIn(Concat( // OpEqualCases(), OpEqualCases(), OpEqualCases(), OpEqualCases(), OpEqualCases(), OpEqualCases(), OpEqualCases())))); template std::vector OpLessThanCases() { return { C(T{0}, T{0}, false == less_than), C(T{0}, T{1}, true == less_than), C(T{1}, T{0}, false == less_than), C(T{1}, T{1}, 
false == less_than), C(Vec(T{0}, T{0}), Vec(T{0}, T{0}), Vec(false == less_than, false == less_than)), C(Vec(T{0}, T{0}), Vec(T{1}, T{1}), Vec(true == less_than, true == less_than)), C(Vec(T{1}, T{1}), Vec(T{0}, T{0}), Vec(false == less_than, false == less_than)), C(Vec(T{1}, T{0}), Vec(T{0}, T{1}), Vec(false == less_than, true == less_than)), }; } INSTANTIATE_TEST_SUITE_P(LessThan, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kLessThan), testing::ValuesIn(Concat( // OpLessThanCases(), OpLessThanCases(), OpLessThanCases(), OpLessThanCases(), OpLessThanCases(), OpLessThanCases())))); INSTANTIATE_TEST_SUITE_P(GreaterThanEqual, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kGreaterThanEqual), testing::ValuesIn(Concat( // OpLessThanCases(), OpLessThanCases(), OpLessThanCases(), OpLessThanCases(), OpLessThanCases(), OpLessThanCases())))); template std::vector OpGreaterThanCases() { return { C(T{0}, T{0}, false == greater_than), C(T{0}, T{1}, false == greater_than), C(T{1}, T{0}, true == greater_than), C(T{1}, T{1}, false == greater_than), C(Vec(T{0}, T{0}), Vec(T{0}, T{0}), Vec(false == greater_than, false == greater_than)), C(Vec(T{1}, T{1}), Vec(T{0}, T{0}), Vec(true == greater_than, true == greater_than)), C(Vec(T{0}, T{0}), Vec(T{1}, T{1}), Vec(false == greater_than, false == greater_than)), C(Vec(T{1}, T{0}), Vec(T{0}, T{1}), Vec(true == greater_than, false == greater_than)), }; } INSTANTIATE_TEST_SUITE_P(GreaterThan, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kGreaterThan), testing::ValuesIn(Concat( // OpGreaterThanCases(), OpGreaterThanCases(), OpGreaterThanCases(), OpGreaterThanCases(), OpGreaterThanCases(), OpGreaterThanCases())))); INSTANTIATE_TEST_SUITE_P(LessThanEqual, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kLessThanEqual), testing::ValuesIn(Concat( // OpGreaterThanCases(), OpGreaterThanCases(), 
OpGreaterThanCases(), OpGreaterThanCases(), OpGreaterThanCases(), OpGreaterThanCases())))); static std::vector OpLogicalAndCases() { return { C(true, true, true), C(true, false, false), C(false, true, false), C(false, false, false), }; } INSTANTIATE_TEST_SUITE_P(LogicalAnd, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kLogicalAnd), testing::ValuesIn(OpLogicalAndCases()))); static std::vector OpLogicalOrCases() { return { C(true, true, true), C(true, false, true), C(false, true, true), C(false, false, false), }; } INSTANTIATE_TEST_SUITE_P(LogicalOr, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kLogicalOr), testing::ValuesIn(OpLogicalOrCases()))); static std::vector OpAndBoolCases() { return { C(true, true, true), C(true, false, false), C(false, true, false), C(false, false, false), C(Vec(true, true), Vec(true, false), Vec(true, false)), C(Vec(true, true), Vec(false, true), Vec(false, true)), C(Vec(true, false), Vec(true, false), Vec(true, false)), C(Vec(false, true), Vec(true, false), Vec(false, false)), C(Vec(false, false), Vec(true, false), Vec(false, false)), }; } template std::vector OpAndIntCases() { using B = BitValues; return { C(T{0b1010}, T{0b1111}, T{0b1010}), C(T{0b1010}, T{0b0000}, T{0b0000}), C(T{0b1010}, T{0b0011}, T{0b0010}), C(T{0b1010}, T{0b1100}, T{0b1000}), C(T{0b1010}, T{0b0101}, T{0b0000}), C(B::All, B::All, B::All), C(B::LeftMost, B::LeftMost, B::LeftMost), C(B::RightMost, B::RightMost, B::RightMost), C(B::All, T{0}, T{0}), C(T{0}, B::All, T{0}), C(B::LeftMost, B::AllButLeftMost, T{0}), C(B::AllButLeftMost, B::LeftMost, T{0}), C(B::RightMost, B::AllButRightMost, T{0}), C(B::AllButRightMost, B::RightMost, T{0}), C(Vec(B::All, B::LeftMost, B::RightMost), // Vec(B::All, B::All, B::All), // Vec(B::All, B::LeftMost, B::RightMost)), // C(Vec(B::All, B::LeftMost, B::RightMost), // Vec(T{0}, T{0}, T{0}), // Vec(T{0}, T{0}, T{0})), // C(Vec(B::LeftMost, B::RightMost), // 
Vec(B::AllButLeftMost, B::AllButRightMost), // Vec(T{0}, T{0})), }; } INSTANTIATE_TEST_SUITE_P(And, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kAnd), testing::ValuesIn( // Concat(OpAndBoolCases(), // OpAndIntCases(), OpAndIntCases(), OpAndIntCases())))); static std::vector OpOrBoolCases() { return { C(true, true, true), C(true, false, true), C(false, true, true), C(false, false, false), C(Vec(true, true), Vec(true, false), Vec(true, true)), C(Vec(true, true), Vec(false, true), Vec(true, true)), C(Vec(true, false), Vec(true, false), Vec(true, false)), C(Vec(false, true), Vec(true, false), Vec(true, true)), C(Vec(false, false), Vec(true, false), Vec(true, false)), }; } template std::vector OpOrIntCases() { using B = BitValues; return { C(T{0b1010}, T{0b1111}, T{0b1111}), C(T{0b1010}, T{0b0000}, T{0b1010}), C(T{0b1010}, T{0b0011}, T{0b1011}), C(T{0b1010}, T{0b1100}, T{0b1110}), C(T{0b1010}, T{0b0101}, T{0b1111}), C(B::All, B::All, B::All), C(B::LeftMost, B::LeftMost, B::LeftMost), C(B::RightMost, B::RightMost, B::RightMost), C(B::All, T{0}, B::All), C(T{0}, B::All, B::All), C(B::LeftMost, B::AllButLeftMost, B::All), C(B::AllButLeftMost, B::LeftMost, B::All), C(B::RightMost, B::AllButRightMost, B::All), C(B::AllButRightMost, B::RightMost, B::All), C(Vec(B::All, B::LeftMost, B::RightMost), // Vec(B::All, B::All, B::All), // Vec(B::All, B::All, B::All)), // C(Vec(B::All, B::LeftMost, B::RightMost), // Vec(T{0}, T{0}, T{0}), // Vec(B::All, B::LeftMost, B::RightMost)), // C(Vec(B::LeftMost, B::RightMost), // Vec(B::AllButLeftMost, B::AllButRightMost), // Vec(B::All, B::All)), }; } INSTANTIATE_TEST_SUITE_P(Or, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kOr), testing::ValuesIn(Concat(OpOrBoolCases(), OpOrIntCases(), OpOrIntCases(), OpOrIntCases())))); TEST_F(ResolverConstEvalTest, NotAndOrOfVecs) { auto v1 = Vec(true, true).Expr(*this); auto v2 = Vec(true, false).Expr(*this); auto v3 = Vec(false, 
true).Expr(*this); auto expr = Not(Or(And(v1, v2), v3)); GlobalConst("C", expr); auto expected_expr = Vec(false, false).Expr(*this); GlobalConst("E", expected_expr); EXPECT_TRUE(r()->Resolve()) << r()->error(); auto* sem = Sem().Get(expr); const constant::Value* value = sem->ConstantValue(); ASSERT_NE(value, nullptr); EXPECT_TYPE(value->Type(), sem->Type()); auto* expected_sem = Sem().GetVal(expected_expr); const constant::Value* expected_value = expected_sem->ConstantValue(); ASSERT_NE(expected_value, nullptr); EXPECT_TYPE(expected_value->Type(), expected_sem->Type()); ForEachElemPair(value, expected_value, [&](const constant::Value* a, const constant::Value* b) { EXPECT_EQ(a->ValueAs(), b->ValueAs()); return HasFailure() ? Action::kStop : Action::kContinue; }); } template std::vector XorCases() { using B = BitValues; return { C(T{0b1010}, T{0b1111}, T{0b0101}), C(T{0b1010}, T{0b0000}, T{0b1010}), C(T{0b1010}, T{0b0011}, T{0b1001}), C(T{0b1010}, T{0b1100}, T{0b0110}), C(T{0b1010}, T{0b0101}, T{0b1111}), C(B::All, B::All, T{0}), C(B::LeftMost, B::LeftMost, T{0}), C(B::RightMost, B::RightMost, T{0}), C(B::All, T{0}, B::All), C(T{0}, B::All, B::All), C(B::LeftMost, B::AllButLeftMost, B::All), C(B::AllButLeftMost, B::LeftMost, B::All), C(B::RightMost, B::AllButRightMost, B::All), C(B::AllButRightMost, B::RightMost, B::All), C(Vec(B::All, B::LeftMost, B::RightMost), // Vec(B::All, B::All, B::All), // Vec(T{0}, B::AllButLeftMost, B::AllButRightMost)), // C(Vec(B::All, B::LeftMost, B::RightMost), // Vec(T{0}, T{0}, T{0}), // Vec(B::All, B::LeftMost, B::RightMost)), // C(Vec(B::LeftMost, B::RightMost), // Vec(B::AllButLeftMost, B::AllButRightMost), // Vec(B::All, B::All)), }; } INSTANTIATE_TEST_SUITE_P(Xor, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kXor), testing::ValuesIn(Concat(XorCases(), // XorCases(), // XorCases())))); template std::vector ShiftLeftCases() { using ST = u32; // Shift type is u32 using B = BitValues; auto r = 
std::vector{ C(T{0b1010}, ST{0}, T{0b0000'0000'1010}), // C(T{0b1010}, ST{1}, T{0b0000'0001'0100}), // C(T{0b1010}, ST{2}, T{0b0000'0010'1000}), // C(T{0b1010}, ST{3}, T{0b0000'0101'0000}), // C(T{0b1010}, ST{4}, T{0b0000'1010'0000}), // C(T{0b1010}, ST{5}, T{0b0001'0100'0000}), // C(T{0b1010}, ST{6}, T{0b0010'1000'0000}), // C(T{0b1010}, ST{7}, T{0b0101'0000'0000}), // C(T{0b1010}, ST{8}, T{0b1010'0000'0000}), // C(B::LeftMost, ST{0}, B::LeftMost), // C(Vec(T{0b1010}, T{0b1010}), // Vec(ST{0}, ST{1}), // Vec(T{0b0000'0000'1010}, T{0b0000'0001'0100})), // C(Vec(T{0b1010}, T{0b1010}), // Vec(ST{2}, ST{3}), // Vec(T{0b0000'0010'1000}, T{0b0000'0101'0000})), // C(Vec(T{0b1010}, T{0b1010}), // Vec(ST{4}, ST{5}), // Vec(T{0b0000'1010'0000}, T{0b0001'0100'0000})), // C(Vec(T{0b1010}, T{0b1010}, T{0b1010}), // Vec(ST{6}, ST{7}, ST{8}), // Vec(T{0b0010'1000'0000}, T{0b0101'0000'0000}, T{0b1010'0000'0000})), // }; // Abstract 0 can be shifted by any u32 value (0 to 2^32), whereas concrete 0 (or any number) // can only by shifted by a value less than the number of bits of the lhs. // (see ResolverConstEvalShiftLeftConcreteGeqBitWidthError for negative tests) ConcatIntoIf>( // r, std::vector{ C(T{0}, ST{64}, T{0}), // C(T{0}, ST{65}, T{0}), // C(T{0}, ST{65}, T{0}), // C(T{0}, ST{10000}, T{0}), // C(T{0}, ST{u32::Highest()}, T{0}), // C(Negate(T{0}), ST{64}, Negate(T{0})), // C(Negate(T{0}), ST{65}, Negate(T{0})), // C(Negate(T{0}), ST{65}, Negate(T{0})), // C(Negate(T{0}), ST{10000}, Negate(T{0})), // C(Negate(T{0}), ST{u32::Highest()}, Negate(T{0})), // }); // Cases that are fine for signed values (no sign change), but would overflow // unsigned values. See below for negative tests. ConcatIntoIf>( // r, std::vector{ C(B::TwoLeftMost, ST{1}, B::LeftMost), // C(B::All, ST{1}, B::AllButRightMost), // C(B::All, ST{B::NumBits - 1}, B::LeftMost) // }); // Cases that are fine for unsigned values, but would overflow (sign change) signed // values. 
See ShiftLeftSignChangeErrorCases() for negative tests. ConcatIntoIf>( // r, std::vector{ C(T{0b0001}, ST{B::NumBits - 1}, B::Lsh(0b0001, B::NumBits - 1)), C(T{0b0010}, ST{B::NumBits - 2}, B::Lsh(0b0010, B::NumBits - 2)), C(T{0b0100}, ST{B::NumBits - 3}, B::Lsh(0b0100, B::NumBits - 3)), C(T{0b1000}, ST{B::NumBits - 4}, B::Lsh(0b1000, B::NumBits - 4)), C(T{0b0011}, ST{B::NumBits - 2}, B::Lsh(0b0011, B::NumBits - 2)), C(T{0b0110}, ST{B::NumBits - 3}, B::Lsh(0b0110, B::NumBits - 3)), C(T{0b1100}, ST{B::NumBits - 4}, B::Lsh(0b1100, B::NumBits - 4)), C(B::AllButLeftMost, ST{1}, B::AllButRightMost), }); auto error_msg = [](auto a, auto b) { return "12:34 error: " + OverflowErrorMessage(a, "<<", b); }; ConcatIntoIf>( // r, std::vector{ // ShiftLeft of AInts that result in values not representable as AInts. // Note that for i32/u32, these would error because shift value is larger than 32. E(B::All, T{B::NumBits}, error_msg(B::All, T{B::NumBits})), E(B::RightMost, T{B::NumBits}, error_msg(B::RightMost, T{B::NumBits})), E(B::AllButLeftMost, T{B::NumBits}, error_msg(B::AllButLeftMost, T{B::NumBits})), E(B::AllButLeftMost, T{B::NumBits + 1}, error_msg(B::AllButLeftMost, T{B::NumBits + 1})), E(B::AllButLeftMost, T{B::NumBits + 1000}, error_msg(B::AllButLeftMost, T{B::NumBits + 1000})), }); ConcatIntoIf>( // r, std::vector{ // ShiftLeft of u32s that overflow (non-zero bits get shifted out) E(T{0b00010}, T{31}, error_msg(T{0b00010}, T{31})), E(T{0b00100}, T{30}, error_msg(T{0b00100}, T{30})), E(T{0b01000}, T{29}, error_msg(T{0b01000}, T{29})), E(T{0b10000}, T{28}, error_msg(T{0b10000}, T{28})), //... 
E(T{1 << 28}, T{4}, error_msg(T{1 << 28}, T{4})), E(T{1 << 29}, T{3}, error_msg(T{1 << 29}, T{3})), E(T{1 << 30}, T{2}, error_msg(T{1 << 30}, T{2})), E(T{1u << 31}, T{1}, error_msg(T{1u << 31}, T{1})), // And some more E(B::All, T{1}, error_msg(B::All, T{1})), E(B::AllButLeftMost, T{2}, error_msg(B::AllButLeftMost, T{2})), }); return r; } INSTANTIATE_TEST_SUITE_P(ShiftLeft, ResolverConstEvalBinaryOpTest, testing::Combine( // testing::Values(ast::BinaryOp::kShiftLeft), testing::ValuesIn(Concat(ShiftLeftCases(), // ShiftLeftCases(), // ShiftLeftCases())))); TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AInt) { GlobalConst("c", Add(Source{{1, 1}}, Expr(AInt::Highest()), 1_a)); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "1:1 error: '9223372036854775807 + 1' cannot be represented as 'abstract-int'"); } TEST_F(ResolverConstEvalTest, BinaryAbstractAddUnderflow_AInt) { GlobalConst("c", Add(Source{{1, 1}}, Expr(AInt::Lowest()), -1_a)); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "1:1 error: '-9223372036854775808 + -1' cannot be represented as 'abstract-int'"); } TEST_F(ResolverConstEvalTest, BinaryAbstractAddOverflow_AFloat) { GlobalConst("c", Add(Source{{1, 1}}, Expr(AFloat::Highest()), AFloat::Highest())); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "1:1 error: '1.7976931348623157081e+308 + 1.7976931348623157081e+308' cannot be " "represented as 'abstract-float'"); } TEST_F(ResolverConstEvalTest, BinaryAbstractAddUnderflow_AFloat) { GlobalConst("c", Add(Source{{1, 1}}, Expr(AFloat::Lowest()), AFloat::Lowest())); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "1:1 error: '-1.7976931348623157081e+308 + -1.7976931348623157081e+308' cannot be " "represented as 'abstract-float'"); } // Mixed AInt and AFloat args to test implicit conversion to AFloat INSTANTIATE_TEST_SUITE_P( AbstractMixed, ResolverConstEvalBinaryOpTest, testing::Combine( testing::Values(ast::BinaryOp::kAdd), testing::Values(C(Val(1_a), Val(2.3_a), Val(3.3_a)), 
C(Val(2.3_a), Val(1_a), Val(3.3_a)), C(Val(1_a), Vec(2.3_a, 2.3_a, 2.3_a), Vec(3.3_a, 3.3_a, 3.3_a)), C(Vec(2.3_a, 2.3_a, 2.3_a), Val(1_a), Vec(3.3_a, 3.3_a, 3.3_a)), C(Vec(2.3_a, 2.3_a, 2.3_a), Val(1_a), Vec(3.3_a, 3.3_a, 3.3_a)), C(Val(1_a), Vec(2.3_a, 2.3_a, 2.3_a), Vec(3.3_a, 3.3_a, 3.3_a)), C(Mat({1_a, 2_a}, // {1_a, 2_a}, // {1_a, 2_a}), // Mat({1.2_a, 2.3_a}, // {1.2_a, 2.3_a}, // {1.2_a, 2.3_a}), // Mat({2.2_a, 4.3_a}, // {2.2_a, 4.3_a}, // {2.2_a, 4.3_a})), // C(Mat({1.2_a, 2.3_a}, // {1.2_a, 2.3_a}, // {1.2_a, 2.3_a}), // Mat({1_a, 2_a}, // {1_a, 2_a}, // {1_a, 2_a}), // Mat({2.2_a, 4.3_a}, // {2.2_a, 4.3_a}, // {2.2_a, 4.3_a})) // ))); // AInt left shift negative value -> error TEST_F(ResolverConstEvalTest, BinaryAbstractShiftLeftByNegativeValue_Error) { GlobalConst("c", Shl(Expr(1_a), Expr(Source{{1, 1}}, -1_a))); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "1:1 error: value -1 cannot be represented as 'u32'"); } // AInt left shift by AInt or u32 always results in an AInt TEST_F(ResolverConstEvalTest, BinaryAbstractShiftLeftRemainsAbstract) { auto* expr1 = Shl(Expr(1_a), Expr(1_u)); GlobalConst("c1", expr1); auto* expr2 = Shl(Expr(1_a), Expr(1_a)); GlobalConst("c2", expr2); EXPECT_TRUE(r()->Resolve()) << r()->error(); auto* sem1 = Sem().Get(expr1); ASSERT_NE(sem1, nullptr); auto* sem2 = Sem().Get(expr2); ASSERT_NE(sem2, nullptr); auto aint_ty = create(); EXPECT_EQ(sem1->Type(), aint_ty); EXPECT_EQ(sem2->Type(), aint_ty); } // i32/u32 left shift by >= 32 -> error using ResolverConstEvalShiftLeftConcreteGeqBitWidthError = ResolverTestWithParam; TEST_P(ResolverConstEvalShiftLeftConcreteGeqBitWidthError, Test) { auto* lhs_expr = GetParam().lhs.Expr(*this); auto* rhs_expr = GetParam().rhs.Expr(*this); GlobalConst("c", Shl(Source{{1, 1}}, lhs_expr, rhs_expr)); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ( r()->error(), "1:1 error: shift left value must be less than the bit width of the lhs, which is 32"); } INSTANTIATE_TEST_SUITE_P(Test, 
ResolverConstEvalShiftLeftConcreteGeqBitWidthError, testing::Values( // ErrorCase{Val(0_u), Val(32_u)}, // ErrorCase{Val(0_u), Val(33_u)}, // ErrorCase{Val(0_u), Val(34_u)}, // ErrorCase{Val(0_u), Val(10000_u)}, // ErrorCase{Val(0_u), Val(u32::Highest())}, // ErrorCase{Val(0_i), Val(32_u)}, // ErrorCase{Val(0_i), Val(33_u)}, // ErrorCase{Val(0_i), Val(34_u)}, // ErrorCase{Val(0_i), Val(10000_u)}, // ErrorCase{Val(0_i), Val(u32::Highest())}, // ErrorCase{Val(Negate(0_u)), Val(32_u)}, // ErrorCase{Val(Negate(0_u)), Val(33_u)}, // ErrorCase{Val(Negate(0_u)), Val(34_u)}, // ErrorCase{Val(Negate(0_u)), Val(10000_u)}, // ErrorCase{Val(Negate(0_u)), Val(u32::Highest())}, // ErrorCase{Val(Negate(0_i)), Val(32_u)}, // ErrorCase{Val(Negate(0_i)), Val(33_u)}, // ErrorCase{Val(Negate(0_i)), Val(34_u)}, // ErrorCase{Val(Negate(0_i)), Val(10000_u)}, // ErrorCase{Val(Negate(0_i)), Val(u32::Highest())}, // ErrorCase{Val(1_i), Val(32_u)}, // ErrorCase{Val(1_i), Val(33_u)}, // ErrorCase{Val(1_i), Val(34_u)}, // ErrorCase{Val(1_i), Val(10000_u)}, // ErrorCase{Val(1_i), Val(u32::Highest())}, // ErrorCase{Val(1_u), Val(32_u)}, // ErrorCase{Val(1_u), Val(33_u)}, // ErrorCase{Val(1_u), Val(34_u)}, // ErrorCase{Val(1_u), Val(10000_u)}, // ErrorCase{Val(1_u), Val(u32::Highest())} // )); // AInt left shift results in sign change error using ResolverConstEvalShiftLeftSignChangeError = ResolverTestWithParam; TEST_P(ResolverConstEvalShiftLeftSignChangeError, Test) { auto* lhs_expr = GetParam().lhs.Expr(*this); auto* rhs_expr = GetParam().rhs.Expr(*this); GlobalConst("c", Shl(Source{{1, 1}}, lhs_expr, rhs_expr)); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "1:1 error: shift left operation results in sign change"); } template std::vector ShiftLeftSignChangeErrorCases() { // Shift type is u32 for non-abstract using ST = std::conditional_t, T, u32>; using B = BitValues; return { {Val(T{0b0001}), Val(ST{B::NumBits - 1})}, {Val(T{0b0010}), Val(ST{B::NumBits - 2})}, {Val(T{0b0100}), 
Val(ST{B::NumBits - 3})}, {Val(T{0b1000}), Val(ST{B::NumBits - 4})}, {Val(T{0b0011}), Val(ST{B::NumBits - 2})}, {Val(T{0b0110}), Val(ST{B::NumBits - 3})}, {Val(T{0b1100}), Val(ST{B::NumBits - 4})}, {Val(B::AllButLeftMost), Val(ST{1})}, {Val(B::AllButLeftMost), Val(ST{B::NumBits - 1})}, {Val(B::LeftMost), Val(ST{1})}, {Val(B::LeftMost), Val(ST{B::NumBits - 1})}, }; } INSTANTIATE_TEST_SUITE_P(Test, ResolverConstEvalShiftLeftSignChangeError, testing::ValuesIn(Concat( // ShiftLeftSignChangeErrorCases(), ShiftLeftSignChangeErrorCases()))); template std::vector ShiftRightCases() { using B = BitValues; auto r = std::vector{ C(T{0b10101100}, u32{0}, T{0b10101100}), // C(T{0b10101100}, u32{1}, T{0b01010110}), // C(T{0b10101100}, u32{2}, T{0b00101011}), // C(T{0b10101100}, u32{3}, T{0b00010101}), // C(T{0b10101100}, u32{4}, T{0b00001010}), // C(T{0b10101100}, u32{5}, T{0b00000101}), // C(T{0b10101100}, u32{6}, T{0b00000010}), // C(T{0b10101100}, u32{7}, T{0b00000001}), // C(T{0b10101100}, u32{8}, T{0b00000000}), // C(T{0b10101100}, u32{9}, T{0b00000000}), // C(B::LeftMost, u32{0}, B::LeftMost), // }; // msb not set, same for all types: inserted bit is 0 ConcatInto( // r, std::vector{ C(T{0b01000000000000000000000010101100}, u32{0}, // T{0b01000000000000000000000010101100}), C(T{0b01000000000000000000000010101100}, u32{1}, // T{0b00100000000000000000000001010110}), C(T{0b01000000000000000000000010101100}, u32{2}, // T{0b00010000000000000000000000101011}), C(T{0b01000000000000000000000010101100}, u32{3}, // T{0b00001000000000000000000000010101}), C(T{0b01000000000000000000000010101100}, u32{4}, // T{0b00000100000000000000000000001010}), C(T{0b01000000000000000000000010101100}, u32{5}, // T{0b00000010000000000000000000000101}), C(T{0b01000000000000000000000010101100}, u32{6}, // T{0b00000001000000000000000000000010}), C(T{0b01000000000000000000000010101100}, u32{7}, // T{0b00000000100000000000000000000001}), C(T{0b01000000000000000000000010101100}, u32{8}, // 
T{0b00000000010000000000000000000000}), C(T{0b01000000000000000000000010101100}, u32{9}, // T{0b00000000001000000000000000000000}), }); // msb set, result differs for i32 and u32 if constexpr (std::is_same_v) { // If unsigned, insert zero bits at the most significant positions. ConcatInto( // r, std::vector{ C(T{0b10000000000000000000000010101100}, u32{0}, T{0b10000000000000000000000010101100}), C(T{0b10000000000000000000000010101100}, u32{1}, T{0b01000000000000000000000001010110}), C(T{0b10000000000000000000000010101100}, u32{2}, T{0b00100000000000000000000000101011}), C(T{0b10000000000000000000000010101100}, u32{3}, T{0b00010000000000000000000000010101}), C(T{0b10000000000000000000000010101100}, u32{4}, T{0b00001000000000000000000000001010}), C(T{0b10000000000000000000000010101100}, u32{5}, T{0b00000100000000000000000000000101}), C(T{0b10000000000000000000000010101100}, u32{6}, T{0b00000010000000000000000000000010}), C(T{0b10000000000000000000000010101100}, u32{7}, T{0b00000001000000000000000000000001}), C(T{0b10000000000000000000000010101100}, u32{8}, T{0b00000000100000000000000000000000}), C(T{0b10000000000000000000000010101100}, u32{9}, T{0b00000000010000000000000000000000}), // msb shifted by bit width - 1 C(T{0b10000000000000000000000000000000}, u32{31}, T{0b00000000000000000000000000000001}), }); } else if constexpr (std::is_same_v) { // If signed, each inserted bit is 1, so the result is negative. 
ConcatInto( // r, std::vector{ C(T{0b10000000000000000000000010101100}, u32{0}, T{0b10000000000000000000000010101100}), // C(T{0b10000000000000000000000010101100}, u32{1}, T{0b11000000000000000000000001010110}), // C(T{0b10000000000000000000000010101100}, u32{2}, T{0b11100000000000000000000000101011}), // C(T{0b10000000000000000000000010101100}, u32{3}, T{0b11110000000000000000000000010101}), // C(T{0b10000000000000000000000010101100}, u32{4}, T{0b11111000000000000000000000001010}), // C(T{0b10000000000000000000000010101100}, u32{5}, T{0b11111100000000000000000000000101}), // C(T{0b10000000000000000000000010101100}, u32{6}, T{0b11111110000000000000000000000010}), // C(T{0b10000000000000000000000010101100}, u32{7}, T{0b11111111000000000000000000000001}), // C(T{0b10000000000000000000000010101100}, u32{8}, T{0b11111111100000000000000000000000}), // C(T{0b10000000000000000000000010101100}, u32{9}, T{0b11111111110000000000000000000000}), // // msb shifted by bit width - 1 C(T{0b10000000000000000000000000000000}, u32{31}, T{0b11111111111111111111111111111111}), }); } // Test shift right by bit width or more if constexpr (IsAbstract) { // For abstract int, no error, result is 0 ConcatInto( // r, std::vector{ C(T{0}, u32{B::NumBits}, T{0}), C(T{0}, u32{B::NumBits + 1}, T{0}), C(T{0}, u32{B::NumBits + 1000}, T{0}), C(T{42}, u32{B::NumBits}, T{0}), C(T{42}, u32{B::NumBits + 1}, T{0}), C(T{42}, u32{B::NumBits + 1000}, T{0}), }); } else { // For concretes, error const char* error_msg = "12:34 error: shift right value must be less than the bit width of the lhs, which is " "32"; ConcatInto( // r, std::vector{ E(T{0}, u32{B::NumBits}, error_msg), E(T{0}, u32{B::NumBits + 1}, error_msg), E(T{0}, u32{B::NumBits + 1000}, error_msg), E(T{42}, u32{B::NumBits}, error_msg), E(T{42}, u32{B::NumBits + 1}, error_msg), E(T{42}, u32{B::NumBits + 1000}, error_msg), }); } return r; } INSTANTIATE_TEST_SUITE_P(ShiftRight, ResolverConstEvalBinaryOpTest, testing::Combine( // 
testing::Values(ast::BinaryOp::kShiftRight), testing::ValuesIn(Concat(ShiftRightCases(), // ShiftRightCases(), // ShiftRightCases())))); namespace LogicalShortCircuit { /// Validates that `binary` is a short-circuiting logical and expression static void ValidateAnd(const sem::Info& sem, const ast::BinaryExpression* binary) { auto* lhs = binary->lhs; auto* rhs = binary->rhs; auto* lhs_sem = sem.GetVal(lhs); ASSERT_TRUE(lhs_sem->ConstantValue()); EXPECT_EQ(lhs_sem->ConstantValue()->ValueAs(), false); EXPECT_EQ(lhs_sem->Stage(), sem::EvaluationStage::kConstant); auto* rhs_sem = sem.GetVal(rhs); EXPECT_EQ(rhs_sem->ConstantValue(), nullptr); EXPECT_EQ(rhs_sem->Stage(), sem::EvaluationStage::kNotEvaluated); auto* binary_sem = sem.Get(binary); ASSERT_TRUE(binary_sem->ConstantValue()); EXPECT_EQ(binary_sem->ConstantValue()->ValueAs(), false); EXPECT_EQ(binary_sem->Stage(), sem::EvaluationStage::kConstant); } /// Validates that `binary` is a short-circuiting logical or expression static void ValidateOr(const sem::Info& sem, const ast::BinaryExpression* binary) { auto* lhs = binary->lhs; auto* rhs = binary->rhs; auto* lhs_sem = sem.GetVal(lhs); ASSERT_TRUE(lhs_sem->ConstantValue()); EXPECT_EQ(lhs_sem->ConstantValue()->ValueAs(), true); EXPECT_EQ(lhs_sem->Stage(), sem::EvaluationStage::kConstant); auto* rhs_sem = sem.GetVal(rhs); EXPECT_EQ(rhs_sem->ConstantValue(), nullptr); EXPECT_EQ(rhs_sem->Stage(), sem::EvaluationStage::kNotEvaluated); auto* binary_sem = sem.Get(binary); ASSERT_TRUE(binary_sem->ConstantValue()); EXPECT_EQ(binary_sem->ConstantValue()->ValueAs(), true); EXPECT_EQ(binary_sem->Stage(), sem::EvaluationStage::kConstant); } // Naming convention for tests below: // // [Non]ShortCircuit_[And|Or]_[Error|Invalid]_ // // Where: // ShortCircuit: the rhs will not be const-evaluated // NonShortCircuitL the rhs will be const-evaluated // // And/Or: type of binary expression // // Error: a non-const evaluation error (e.g. 
parser or validation error) // Invalid: a const-evaluation error // // the type of operation on the rhs that may or may not be short-circuited. //////////////////////////////////////////////// // Short-Circuit Unary //////////////////////////////////////////////// // NOTE: Cannot demonstrate short-circuiting an invalid unary op as const eval of unary does not // fail. TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Unary) { // const one = 1; // const result = (one == 0) && (!0); GlobalConst("one", Expr(1_a)); auto* lhs = Equal("one", 0_a); auto* rhs = Not(Source{{12, 34}}, 0_a); GlobalConst("result", LogicalAnd(lhs, rhs)); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload for operator ! (abstract-int) 2 candidate operators: operator ! (bool) -> bool operator ! (vecN) -> vecN )"); } TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Unary) { // const one = 1; // const result = (one == 1) || (!0); GlobalConst("one", Expr(1_a)); auto* lhs = Equal("one", 1_a); auto* rhs = Not(Source{{12, 34}}, 0_a); GlobalConst("result", LogicalOr(lhs, rhs)); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload for operator ! (abstract-int) 2 candidate operators: operator ! (bool) -> bool operator ! 
(vecN) -> vecN )"); } //////////////////////////////////////////////// // Short-Circuit Binary //////////////////////////////////////////////// TEST_F(ResolverConstEvalTest, ShortCircuit_And_Invalid_Binary) { // const one = 1; // const result = (one == 0) && ((2 / 0) == 0); GlobalConst("one", Expr(1_a)); auto* lhs = Equal("one", 0_a); auto* rhs = Equal(Div(2_a, 0_a), 0_a); auto* binary = LogicalAnd(lhs, rhs); GlobalConst("result", binary); EXPECT_TRUE(r()->Resolve()) << r()->error(); ValidateAnd(Sem(), binary); } TEST_F(ResolverConstEvalTest, NonShortCircuit_And_Invalid_Binary) { // const one = 1; // const result = (one == 1) && ((2 / 0) == 0); GlobalConst("one", Expr(1_a)); auto* lhs = Equal("one", 1_a); auto* rhs = Equal(Div(Source{{12, 34}}, 2_a, 0_a), 0_a); auto* binary = LogicalAnd(lhs, rhs); GlobalConst("result", binary); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 error: '2 / 0' cannot be represented as 'abstract-int'"); } TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Binary) { // const one = 1; // const result = (one == 0) && (2 / 0); GlobalConst("one", Expr(1_a)); auto* lhs = Equal("one", 0_a); auto* rhs = Div(2_a, 0_a); auto* binary = LogicalAnd(Source{{12, 34}}, lhs, rhs); GlobalConst("result", binary); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload for operator && (bool, abstract-int) 1 candidate operator: operator && (bool, bool) -> bool )"); } TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Invalid_Binary) { // const one = 1; // const result = (one == 1) || ((2 / 0) == 0); GlobalConst("one", Expr(1_a)); auto* lhs = Equal("one", 1_a); auto* rhs = Equal(Div(2_a, 0_a), 0_a); auto* binary = LogicalOr(lhs, rhs); GlobalConst("result", binary); EXPECT_TRUE(r()->Resolve()) << r()->error(); ValidateOr(Sem(), binary); } TEST_F(ResolverConstEvalTest, NonShortCircuit_Or_Invalid_Binary) { // const one = 1; // const result = (one == 0) || ((2 / 0) == 0); GlobalConst("one", Expr(1_a)); auto* lhs = 
Equal("one", 0_a); auto* rhs = Equal(Div(Source{{12, 34}}, 2_a, 0_a), 0_a); auto* binary = LogicalOr(lhs, rhs); GlobalConst("result", binary); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 error: '2 / 0' cannot be represented as 'abstract-int'"); } TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Binary) { // const one = 1; // const result = (one == 1) || (2 / 0); GlobalConst("one", Expr(1_a)); auto* lhs = Equal("one", 1_a); auto* rhs = Div(2_a, 0_a); auto* binary = LogicalOr(Source{{12, 34}}, lhs, rhs); GlobalConst("result", binary); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload for operator || (bool, abstract-int) 1 candidate operator: operator || (bool, bool) -> bool )"); } //////////////////////////////////////////////// // Short-Circuit Materialize //////////////////////////////////////////////// TEST_F(ResolverConstEvalTest, ShortCircuit_And_Invalid_Materialize) { // const one = 1; // const result = (one == 0) && (1.7976931348623157e+308 == 0.0f); GlobalConst("one", Expr(1_a)); auto* lhs = Equal("one", 0_a); auto* rhs = Equal(Expr(1.7976931348623157e+308_a), 0_f); auto* binary = LogicalAnd(lhs, rhs); GlobalConst("result", binary); EXPECT_TRUE(r()->Resolve()) << r()->error(); ValidateAnd(Sem(), binary); } TEST_F(ResolverConstEvalTest, NonShortCircuit_And_Invalid_Materialize) { // const one = 1; // const result = (one == 1) && (1.7976931348623157e+308 == 0.0f); GlobalConst("one", Expr(1_a)); auto* lhs = Equal("one", 1_a); auto* rhs = Equal(Expr(Source{{12, 34}}, 1.7976931348623157e+308_a), 0_f); auto* binary = LogicalAnd(lhs, rhs); GlobalConst("result", binary); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 error: value 1.7976931348623157081e+308 cannot be represented as 'f32'"); } TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Materialize) { // const one = 1; // const result = (one == 0) && (1.7976931348623157e+308 == 0i); GlobalConst("one", Expr(1_a)); auto* lhs = Equal("one", 
0_a); auto* rhs = Equal(Source{{12, 34}}, 1.7976931348623157e+308_a, 0_i); auto* binary = LogicalAnd(lhs, rhs); GlobalConst("result", binary); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload for operator == (abstract-float, i32) 2 candidate operators: operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool operator == (vecN, vecN) -> vecN where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool )"); } TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Invalid_Materialize) { // const one = 1; // const result = (one == 1) || (1.7976931348623157e+308 == 0.0f); GlobalConst("one", Expr(1_a)); auto* lhs = Equal("one", 1_a); auto* rhs = Equal(1.7976931348623157e+308_a, 0_f); auto* binary = LogicalOr(lhs, rhs); GlobalConst("result", binary); EXPECT_TRUE(r()->Resolve()) << r()->error(); ValidateOr(Sem(), binary); } TEST_F(ResolverConstEvalTest, NonShortCircuit_Or_Invalid_Materialize) { // const one = 1; // const result = (one == 0) || (1.7976931348623157e+308 == 0.0f); GlobalConst("one", Expr(1_a)); auto* lhs = Equal("one", 0_a); auto* rhs = Equal(Expr(Source{{12, 34}}, 1.7976931348623157e+308_a), 0_f); auto* binary = LogicalOr(lhs, rhs); GlobalConst("result", binary); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 error: value 1.7976931348623157081e+308 cannot be represented as 'f32'"); } TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Materialize) { // const one = 1; // const result = (one == 1) || (1.7976931348623157e+308 == 0i); GlobalConst("one", Expr(1_a)); auto* lhs = Equal("one", 1_a); auto* rhs = Equal(Source{{12, 34}}, Expr(1.7976931348623157e+308_a), 0_i); auto* binary = LogicalOr(lhs, rhs); GlobalConst("result", binary); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload for operator == (abstract-float, i32) 2 candidate operators: operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 
or bool operator == (vecN, vecN) -> vecN where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool )"); } //////////////////////////////////////////////// // Short-Circuit Index //////////////////////////////////////////////// TEST_F(ResolverConstEvalTest, ShortCircuit_And_Invalid_Index) { // const one = 1; // const a = array(1i, 2i, 3i); // const i = 4; // const result = (one == 0) && (a[i] == 0); GlobalConst("one", Expr(1_a)); GlobalConst("a", array(1_i, 2_i, 3_i)); GlobalConst("i", Expr(4_a)); auto* lhs = Equal("one", 0_a); auto* rhs = Equal(IndexAccessor("a", "i"), 0_a); auto* binary = LogicalAnd(lhs, rhs); GlobalConst("result", binary); EXPECT_TRUE(r()->Resolve()) << r()->error(); ValidateAnd(Sem(), binary); } TEST_F(ResolverConstEvalTest, NonShortCircuit_And_Invalid_Index) { // const one = 1; // const a = array(1i, 2i, 3i); // const i = 3; // const result = (one == 1) && (a[i] == 0); GlobalConst("one", Expr(1_a)); GlobalConst("a", array(1_i, 2_i, 3_i)); GlobalConst("i", Expr(3_a)); auto* lhs = Equal("one", 1_a); auto* rhs = Equal(IndexAccessor("a", Expr(Source{{12, 34}}, "i")), 0_a); auto* binary = LogicalAnd(lhs, rhs); GlobalConst("result", binary); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 error: index 3 out of bounds [0..2]"); } TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Index) { // const one = 1; // const a = array(1i, 2i, 3i); // const i = 3; // const result = (one == 0) && (a[i] == 0.0f); GlobalConst("one", Expr(1_a)); GlobalConst("a", array(1_i, 2_i, 3_i)); GlobalConst("i", Expr(3_a)); auto* lhs = Equal("one", 0_a); auto* rhs = Equal(Source{{12, 34}}, IndexAccessor("a", "i"), 0.0_f); auto* binary = LogicalAnd(lhs, rhs); GlobalConst("result", binary); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload for operator == (i32, f32) 2 candidate operators: operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool operator == (vecN, vecN) -> 
vecN where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool )"); } TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Invalid_Index) { // const one = 1; // const a = array(1i, 2i, 3i); // const i = 4; // const result = (one == 1) || (a[i] == 0); GlobalConst("one", Expr(1_a)); GlobalConst("a", array(1_i, 2_i, 3_i)); GlobalConst("i", Expr(4_a)); auto* lhs = Equal("one", 1_a); auto* rhs = Equal(IndexAccessor("a", "i"), 0_a); auto* binary = LogicalOr(lhs, rhs); GlobalConst("result", binary); EXPECT_TRUE(r()->Resolve()) << r()->error(); ValidateOr(Sem(), binary); } TEST_F(ResolverConstEvalTest, NonShortCircuit_Or_Invalid_Index) { // const one = 1; // const a = array(1i, 2i, 3i); // const i = 3; // const result = (one == 0) || (a[i] == 0); GlobalConst("one", Expr(1_a)); GlobalConst("a", array(1_i, 2_i, 3_i)); GlobalConst("i", Expr(3_a)); auto* lhs = Equal("one", 0_a); auto* rhs = Equal(IndexAccessor("a", Expr(Source{{12, 34}}, "i")), 0_a); auto* binary = LogicalOr(lhs, rhs); GlobalConst("result", binary); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 error: index 3 out of bounds [0..2]"); } TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Index) { // const one = 1; // const a = array(1i, 2i, 3i); // const i = 3; // const result = (one == 1) || (a[i] == 0.0f); GlobalConst("one", Expr(1_a)); GlobalConst("a", array(1_i, 2_i, 3_i)); GlobalConst("i", Expr(3_a)); auto* lhs = Equal("one", 1_a); auto* rhs = Equal(Source{{12, 34}}, IndexAccessor("a", "i"), 0.0_f); auto* binary = LogicalOr(lhs, rhs); GlobalConst("result", binary); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload for operator == (i32, f32) 2 candidate operators: operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool operator == (vecN, vecN) -> vecN where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool )"); } //////////////////////////////////////////////// // Short-Circuit Bitcast 
//////////////////////////////////////////////// TEST_F(ResolverConstEvalTest, ShortCircuit_And_Invalid_Bitcast) { // const one = 1; // const a = 0x7F800000; // const result = (one == 0) && (bitcast(a) == 0.0); GlobalConst("one", Expr(1_a)); GlobalConst("a", Expr(0x7F800000_a)); auto* lhs = Equal("one", 0_a); auto* rhs = Equal(Bitcast("a"), 0.0_a); auto* binary = LogicalAnd(lhs, rhs); GlobalConst("result", binary); EXPECT_TRUE(r()->Resolve()) << r()->error(); ValidateAnd(Sem(), binary); } TEST_F(ResolverConstEvalTest, NonShortCircuit_And_Invalid_Bitcast) { // const one = 1; // const a = 0x7F800000; // const result = (one == 1) && (bitcast(a) == 0.0); GlobalConst("one", Expr(1_a)); GlobalConst("a", Expr(0x7F800000_a)); auto* lhs = Equal("one", 1_a); auto* rhs = Equal(Bitcast(Source{{12, 34}}, ty.f32(), "a"), 0.0_a); auto* binary = LogicalAnd(lhs, rhs); GlobalConst("result", binary); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 error: value inf cannot be represented as 'f32'"); } TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Bitcast) { // const one = 1; // const a = 0x7F800000; // const result = (one == 0) && (bitcast(a) == 0i); GlobalConst("one", Expr(1_a)); GlobalConst("a", Expr(0x7F800000_a)); auto* lhs = Equal("one", 0_a); auto* rhs = Equal(Source{{12, 34}}, Bitcast("a"), 0_i); auto* binary = LogicalAnd(lhs, rhs); GlobalConst("result", binary); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload for operator == (f32, i32) 2 candidate operators: operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool operator == (vecN, vecN) -> vecN where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool )"); } TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Invalid_Bitcast) { // const one = 1; // const a = 0x7F800000; // const result = (one == 1) || (bitcast(a) == 0.0); GlobalConst("one", Expr(1_a)); GlobalConst("a", Expr(0x7F800000_a)); auto* lhs = Equal("one", 1_a); 
auto* rhs = Equal(Bitcast("a"), 0.0_a); auto* binary = LogicalOr(lhs, rhs); GlobalConst("result", binary); EXPECT_TRUE(r()->Resolve()) << r()->error(); ValidateOr(Sem(), binary); } TEST_F(ResolverConstEvalTest, NonShortCircuit_Or_Invalid_Bitcast) { // const one = 1; // const a = 0x7F800000; // const result = (one == 0) || (bitcast(a) == 0.0); GlobalConst("one", Expr(1_a)); GlobalConst("a", Expr(0x7F800000_a)); auto* lhs = Equal("one", 0_a); auto* rhs = Equal(Bitcast(Source{{12, 34}}, ty.f32(), "a"), 0.0_a); auto* binary = LogicalOr(lhs, rhs); GlobalConst("result", binary); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 error: value inf cannot be represented as 'f32'"); } TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Bitcast) { // const one = 1; // const a = 0x7F800000; // const result = (one == 1) || (bitcast(a) == 0i); GlobalConst("one", Expr(1_a)); GlobalConst("a", Expr(0x7F800000_a)); auto* lhs = Equal("one", 1_a); auto* rhs = Equal(Source{{12, 34}}, Bitcast("a"), 0_i); auto* binary = LogicalOr(lhs, rhs); GlobalConst("result", binary); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload for operator == (f32, i32) 2 candidate operators: operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool operator == (vecN, vecN) -> vecN where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool )"); } //////////////////////////////////////////////// // Short-Circuit value construction / conversion //////////////////////////////////////////////// // NOTE: Cannot demonstrate short-circuiting an invalid init/convert as const eval of init/convert // always succeeds. 
TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Init) { // const one = 1; // const result = (one == 0) && (vec2(1.0, true).x == 0.0); GlobalConst("one", Expr(1_a)); auto* lhs = Equal("one", 0_a); auto* rhs = Equal(MemberAccessor(vec2(Source{{12, 34}}, 1.0_a, Expr(true)), "x"), 0.0_a); GlobalConst("result", LogicalAnd(lhs, rhs)); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), R"(12:34 error: no matching constructor for vec2(abstract-float, bool) 4 candidate constructors: vec2(x: T, y: T) -> vec2 where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool vec2(T) -> vec2 where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool vec2(vec2) -> vec2 where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool vec2() -> vec2 where: T is f32, f16, i32, u32 or bool 5 candidate conversions: vec2(vec2) -> vec2 where: T is f32, U is abstract-int, abstract-float, i32, f16, u32 or bool vec2(vec2) -> vec2 where: T is f16, U is abstract-int, abstract-float, f32, i32, u32 or bool vec2(vec2) -> vec2 where: T is i32, U is abstract-int, abstract-float, f32, f16, u32 or bool vec2(vec2) -> vec2 where: T is u32, U is abstract-int, abstract-float, f32, f16, i32 or bool vec2(vec2) -> vec2 where: T is bool, U is abstract-int, abstract-float, f32, f16, i32 or u32 )"); } TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Init) { // const one = 1; // const result = (one == 1) || (vec2(1.0, true).x == 0.0); GlobalConst("one", Expr(1_a)); auto* lhs = Equal("one", 0_a); auto* rhs = Equal(MemberAccessor(vec2(Source{{12, 34}}, 1.0_a, Expr(true)), "x"), 0.0_a); GlobalConst("result", LogicalOr(lhs, rhs)); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), R"(12:34 error: no matching constructor for vec2(abstract-float, bool) 4 candidate constructors: vec2(x: T, y: T) -> vec2 where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool vec2(T) -> vec2 where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool vec2(vec2) -> vec2 where: T is 
abstract-int, abstract-float, f32, f16, i32, u32 or bool vec2() -> vec2 where: T is f32, f16, i32, u32 or bool 5 candidate conversions: vec2(vec2) -> vec2 where: T is f32, U is abstract-int, abstract-float, i32, f16, u32 or bool vec2(vec2) -> vec2 where: T is f16, U is abstract-int, abstract-float, f32, i32, u32 or bool vec2(vec2) -> vec2 where: T is i32, U is abstract-int, abstract-float, f32, f16, u32 or bool vec2(vec2) -> vec2 where: T is u32, U is abstract-int, abstract-float, f32, f16, i32 or bool vec2(vec2) -> vec2 where: T is bool, U is abstract-int, abstract-float, f32, f16, i32 or u32 )"); } //////////////////////////////////////////////// // Short-Circuit Array/Struct Init //////////////////////////////////////////////// // NOTE: Cannot demonstrate short-circuiting an invalid array/struct init as const eval of // array/struct init always succeeds. TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_StructInit) { // struct S { // a : i32, // b : f32, // } // const one = 1; // const result = (one == 0) && Foo(1, true).a == 0; Structure("S", utils::Vector{Member("a", ty.i32()), Member("b", ty.f32())}); GlobalConst("one", Expr(1_a)); auto* lhs = Equal("one", 0_a); auto* rhs = Equal(MemberAccessor(Call("S", Expr(1_a), Expr(Source{{12, 34}}, true)), "a"), 0_a); GlobalConst("result", LogicalAnd(lhs, rhs)); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 error: type in structure constructor does not match struct member type: " "expected 'f32', found 'bool'"); } TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_StructInit) { // struct S { // a : i32, // b : f32, // } // const one = 1; // const result = (one == 1) || Foo(1, true).a == 0; Structure("S", utils::Vector{Member("a", ty.i32()), Member("b", ty.f32())}); GlobalConst("one", Expr(1_a)); auto* lhs = Equal("one", 1_a); auto* rhs = Equal(MemberAccessor(Call("S", Expr(1_a), Expr(Source{{12, 34}}, true)), "a"), 0_a); GlobalConst("result", LogicalOr(lhs, rhs)); EXPECT_FALSE(r()->Resolve()); 
EXPECT_EQ(r()->error(),
              "12:34 error: type in structure constructor does not match struct member type: "
              "expected 'f32', found 'bool'");
}

////////////////////////////////////////////////
// Short-Circuit Builtin Call
////////////////////////////////////////////////
// "Invalid" tests: the rhs is well-typed, but constant-evaluating it fails (extractBits with an
// offset + count larger than the bit width). Short-circuiting must suppress that failure; the
// "NonShortCircuit" variants check the failure IS reported when the rhs is evaluated.
// "Error" tests: the rhs has a type error (i32 compared with abstract-float), which must be
// reported even when the rhs is short-circuited.

// `(one == 0)` is false, so `&&` skips the failing extractBits call and resolution succeeds.
TEST_F(ResolverConstEvalTest, ShortCircuit_And_Invalid_BuiltinCall) {
    // const one = 1;
    // const result = (one == 0) && (extractBits(1, 0, 99) == 0);
    GlobalConst("one", Expr(1_a));
    auto* lhs = Equal("one", 0_a);
    auto* rhs = Equal(Call("extractBits", 1_a, 0_a, 99_a), 0_a);
    auto* binary = LogicalAnd(lhs, rhs);
    GlobalConst("result", binary);

    EXPECT_TRUE(r()->Resolve()) << r()->error();
    ValidateAnd(Sem(), binary);
}

// `(one == 1)` is true, so `&&` must evaluate the rhs and the extractBits failure is reported.
TEST_F(ResolverConstEvalTest, NonShortCircuit_And_Invalid_BuiltinCall) {
    // const one = 1;
    // const result = (one == 1) && (extractBits(1, 0, 99) == 0);
    GlobalConst("one", Expr(1_a));
    auto* lhs = Equal("one", 1_a);
    auto* rhs = Equal(Call(Source{{12, 34}}, "extractBits", 1_a, 0_a, 99_a), 0_a);
    auto* binary = LogicalAnd(lhs, rhs);
    GlobalConst("result", binary);

    EXPECT_FALSE(r()->Resolve());
    // NOTE(review): the unbalanced quote before "offset" mirrors the diagnostic text this test
    // expects; presumably the resolver message should read 'offset' + 'count' - fix both together.
    EXPECT_EQ(r()->error(),
              "12:34 error: 'offset + 'count' must be less than or equal to the bit width of 'e'");
}

// The rhs is short-circuited, but its type error (no `==` overload for i32 and abstract-float)
// must still be reported.
TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_BuiltinCall) {
    // const one = 1;
    // const result = (one == 0) && (extractBits(1, 0, 99) == 0.0);
    GlobalConst("one", Expr(1_a));
    auto* lhs = Equal("one", 0_a);
    auto* rhs = Equal(Source{{12, 34}}, Call("extractBits", 1_a, 0_a, 99_a), 0.0_a);
    auto* binary = LogicalAnd(lhs, rhs);
    GlobalConst("result", binary);

    EXPECT_FALSE(r()->Resolve());
    EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload for operator == (i32, abstract-float) 2 candidate operators: operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool operator == (vecN, vecN) -> vecN where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool )");
}

// `(one == 1)` is true, so `||` skips the failing extractBits call and resolution succeeds.
TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Invalid_BuiltinCall) {
    // const one = 1;
    // const result = (one == 1) || (extractBits(1, 0, 99) == 0);
    GlobalConst("one", Expr(1_a));
    auto* lhs = Equal("one", 1_a);
    auto* rhs = Equal(Call("extractBits", 1_a, 0_a, 99_a), 0_a);
    auto* binary = LogicalOr(lhs, rhs);
    GlobalConst("result", binary);

    EXPECT_TRUE(r()->Resolve()) << r()->error();
    ValidateOr(Sem(), binary);
}

// `(one == 0)` is false, so `||` must evaluate the rhs and the extractBits failure is reported.
TEST_F(ResolverConstEvalTest, NonShortCircuit_Or_Invalid_BuiltinCall) {
    // const one = 1;
    // const result = (one == 0) || (extractBits(1, 0, 99) == 0);
    GlobalConst("one", Expr(1_a));
    auto* lhs = Equal("one", 0_a);
    auto* rhs = Equal(Call(Source{{12, 34}}, "extractBits", 1_a, 0_a, 99_a), 0_a);
    auto* binary = LogicalOr(lhs, rhs);
    GlobalConst("result", binary);

    EXPECT_FALSE(r()->Resolve());
    // NOTE(review): see the unbalanced-quote note on NonShortCircuit_And_Invalid_BuiltinCall's
    // expected message; the same text is expected here.
    EXPECT_EQ(r()->error(),
              "12:34 error: 'offset + 'count' must be less than or equal to the bit width of 'e'");
}

// The rhs is short-circuited, but its type error (no `==` overload for i32 and abstract-float)
// must still be reported.
TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_BuiltinCall) {
    // const one = 1;
    // const result = (one == 1) || (extractBits(1, 0, 99) == 0.0);
    GlobalConst("one", Expr(1_a));
    auto* lhs = Equal("one", 1_a);
    auto* rhs = Equal(Source{{12, 34}}, Call("extractBits", 1_a, 0_a, 99_a), 0.0_a);
    auto* binary = LogicalOr(lhs, rhs);
    GlobalConst("result", binary);

    EXPECT_FALSE(r()->Resolve());
    EXPECT_EQ(r()->error(), R"(12:34 error: no matching overload for operator == (i32, abstract-float) 2 candidate operators: operator == (T, T) -> bool where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool operator == (vecN, vecN) -> vecN where: T is abstract-int, abstract-float, f32, f16, i32, u32 or bool )");
}

////////////////////////////////////////////////
// Short-Circuit Literal
////////////////////////////////////////////////

// NOTE: Cannot demonstrate short-circuiting an invalid literal as const eval of a literal does not
// fail.

#if TINT_BUILD_WGSL_READER
// An out-of-range i32 literal on the short-circuited rhs of `&&` is still rejected; the error is
// raised while parsing, before short-circuiting could apply.
TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Literal) {
    // NOTE: This fails parsing rather than resolving, which is why we can't use the ProgramBuilder
    // for this test.
    auto src = R"( const one = 1; const result = (one == 0) && (1111111111111111111111111111111i == 0); )";
    auto file = std::make_unique("test", src);
    auto program = reader::wgsl::Parse(file.get());
    EXPECT_FALSE(program.IsValid());

    // Format diagnostics without a trailing newline so they compare against the literal below.
    diag::Formatter::Style style;
    style.print_newline_at_end = false;
    auto error = diag::Formatter(style).format(program.Diagnostics());
    EXPECT_EQ(error, R"(test:3:31 error: value cannot be represented as 'i32' const result = (one == 0) && (1111111111111111111111111111111i == 0); ^ )");
}

// Same as ShortCircuit_And_Error_Literal, but the literal is on the short-circuited rhs of `||`.
TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Literal) {
    // NOTE: This fails parsing rather than resolving, which is why we can't use the ProgramBuilder
    // for this test.
    auto src = R"( const one = 1; const result = (one == 1) || (1111111111111111111111111111111i == 0); )";
    auto file = std::make_unique("test", src);
    auto program = reader::wgsl::Parse(file.get());
    EXPECT_FALSE(program.IsValid());

    // Format diagnostics without a trailing newline so they compare against the literal below.
    diag::Formatter::Style style;
    style.print_newline_at_end = false;
    auto error = diag::Formatter(style).format(program.Diagnostics());
    EXPECT_EQ(error, R"(test:3:31 error: value cannot be represented as 'i32' const result = (one == 1) || (1111111111111111111111111111111i == 0); ^ )");
}
#endif  // TINT_BUILD_WGSL_READER

////////////////////////////////////////////////
// Short-Circuit Member Access
////////////////////////////////////////////////

// NOTE: Cannot demonstrate short-circuiting an invalid member access as const eval of member access
// always succeeds.
// Accessing a non-existent struct member on the short-circuited rhs of `&&` is still an error.
TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_MemberAccess) {
    // WGSL being built:
    //   struct S {
    //     a : i32,
    //     b : f32,
    //   }
    //   const s = S(1, 2.0);
    //   const one = 1;
    //   const result = (one == 0) && (s.c == 0);
    Structure("S", utils::Vector{
                       Member("a", ty.i32()),
                       Member("b", ty.f32()),
                   });
    GlobalConst("s", Call("S", Expr(1_a), Expr(2.0_a)));
    GlobalConst("one", Expr(1_a));

    auto* is_zero = Equal("one", 0_a);
    auto* missing_member = MemberAccessor(Source{{12, 34}}, "s", "c");
    auto* compare_missing = Equal(missing_member, 0_a);
    GlobalConst("result", LogicalAnd(is_zero, compare_missing));

    EXPECT_FALSE(r()->Resolve());
    EXPECT_EQ(r()->error(), "12:34 error: struct member c not found");
}

// Accessing a non-existent struct member on the short-circuited rhs of `||` is still an error.
TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_MemberAccess) {
    // WGSL being built:
    //   struct S {
    //     a : i32,
    //     b : f32,
    //   }
    //   const s = S(1, 2.0);
    //   const one = 1;
    //   const result = (one == 1) || (s.c == 0);
    Structure("S", utils::Vector{
                       Member("a", ty.i32()),
                       Member("b", ty.f32()),
                   });
    GlobalConst("s", Call("S", Expr(1_a), Expr(2.0_a)));
    GlobalConst("one", Expr(1_a));

    auto* is_one = Equal("one", 1_a);
    auto* missing_member = MemberAccessor(Source{{12, 34}}, "s", "c");
    auto* compare_missing = Equal(missing_member, 0_a);
    GlobalConst("result", LogicalOr(is_one, compare_missing));

    EXPECT_FALSE(r()->Resolve());
    EXPECT_EQ(r()->error(), "12:34 error: struct member c not found");
}

////////////////////////////////////////////////
// Short-Circuit Swizzle
////////////////////////////////////////////////

// NOTE: A short-circuited *invalid* swizzle cannot be demonstrated, because constant evaluation of
// a swizzle always succeeds.
TEST_F(ResolverConstEvalTest, ShortCircuit_And_Error_Swizzle) { // const one = 1; // const result = (one == 0) && (vec2(1, 2).z == 0); GlobalConst("one", Expr(1_a)); auto* lhs = Equal("one", 0_a); auto* rhs = Equal(MemberAccessor(vec2(1_a, 2_a), Ident(Source{{12, 34}}, "z")), 0_a); GlobalConst("result", LogicalAnd(lhs, rhs)); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 error: invalid vector swizzle member"); } TEST_F(ResolverConstEvalTest, ShortCircuit_Or_Error_Swizzle) { // const one = 1; // const result = (one == 1) || (vec2(1, 2).z == 0); GlobalConst("one", Expr(1_a)); auto* lhs = Equal("one", 1_a); auto* rhs = Equal(MemberAccessor(vec2(1_a, 2_a), Ident(Source{{12, 34}}, "z")), 0_a); GlobalConst("result", LogicalOr(lhs, rhs)); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 error: invalid vector swizzle member"); } //////////////////////////////////////////////// // Short-Circuit Mixed Constant and Runtime //////////////////////////////////////////////// TEST_F(ResolverConstEvalTest, ShortCircuit_And_MixedConstantAndRuntime) { // var j : i32; // let result = false && j < (0 - 8); auto* j = Decl(Var("j", ty.i32())); auto* binary = LogicalAnd(Expr(false), LessThan("j", Sub(0_a, 8_a))); auto* result = Let("result", binary); WrapInFunction(j, result); EXPECT_TRUE(r()->Resolve()) << r()->error(); ValidateAnd(Sem(), binary); } TEST_F(ResolverConstEvalTest, ShortCircuit_Or_MixedConstantAndRuntime) { // var j : i32; // let result = true || j < (0 - 8); auto* j = Decl(Var("j", ty.i32())); auto* binary = LogicalOr(Expr(true), LessThan("j", Sub(0_a, 8_a))); auto* result = Let("result", binary); WrapInFunction(j, result); EXPECT_TRUE(r()->Resolve()) << r()->error(); ValidateOr(Sem(), binary); } //////////////////////////////////////////////// // Short-Circuit Nested //////////////////////////////////////////////// #if TINT_BUILD_WGSL_READER using ResolverConstEvalTestShortCircuit = ResolverTestWithParam>; 
TEST_P(ResolverConstEvalTestShortCircuit, Test) { const char* expr = std::get<0>(GetParam()); bool should_pass = std::get<1>(GetParam()); auto src = std::string(R"( const one = 1; const result = )"); src = src + expr + ";"; auto file = std::make_unique("test", src); auto program = reader::wgsl::Parse(file.get()); if (should_pass) { diag::Formatter::Style style; style.print_newline_at_end = false; auto error = diag::Formatter(style).format(program.Diagnostics()); EXPECT_TRUE(program.IsValid()) << error; } else { EXPECT_FALSE(program.IsValid()); } } INSTANTIATE_TEST_SUITE_P(Nested, ResolverConstEvalTestShortCircuit, testing::ValuesIn(std::vector>{ // AND nested rhs {"(one == 0) && ((one == 0) && ((2/0)==0))", true}, {"(one == 1) && ((one == 0) && ((2/0)==0))", true}, {"(one == 0) && ((one == 1) && ((2/0)==0))", true}, {"(one == 1) && ((one == 1) && ((2/0)==0))", false}, // AND nested lhs {"((one == 0) && ((2/0)==0)) && (one == 0)", true}, {"((one == 0) && ((2/0)==0)) && (one == 1)", true}, {"((one == 1) && ((2/0)==0)) && (one == 0)", false}, {"((one == 1) && ((2/0)==0)) && (one == 1)", false}, // OR nested rhs {"(one == 1) || ((one == 1) || ((2/0)==0))", true}, {"(one == 0) || ((one == 1) || ((2/0)==0))", true}, {"(one == 1) || ((one == 0) || ((2/0)==0))", true}, {"(one == 0) || ((one == 0) || ((2/0)==0))", false}, // OR nested lhs {"((one == 1) || ((2/0)==0)) || (one == 1)", true}, {"((one == 1) || ((2/0)==0)) || (one == 0)", true}, {"((one == 0) || ((2/0)==0)) || (one == 1)", false}, {"((one == 0) || ((2/0)==0)) || (one == 0)", false}, // AND nested both sides {"((one == 0) && ((2/0)==0)) && ((one == 0) && ((2/0)==0))", true}, {"((one == 0) && ((2/0)==0)) && ((one == 1) && ((2/0)==0))", true}, {"((one == 1) && ((2/0)==0)) && ((one == 0) && ((2/0)==0))", false}, {"((one == 1) && ((2/0)==0)) && ((one == 1) && ((2/0)==0))", false}, // OR nested both sides {"((one == 1) || ((2/0)==0)) && ((one == 1) || ((2/0)==0))", true}, {"((one == 1) || ((2/0)==0)) && ((one == 0) || 
((2/0)==0))", false}, {"((one == 0) || ((2/0)==0)) && ((one == 1) || ((2/0)==0))", false}, {"((one == 0) || ((2/0)==0)) && ((one == 0) || ((2/0)==0))", false}, // AND chained {"(one == 0) && (one == 0) && ((2 / 0) == 0)", true}, {"(one == 1) && (one == 0) && ((2 / 0) == 0)", true}, {"(one == 0) && (one == 1) && ((2 / 0) == 0)", true}, {"(one == 1) && (one == 1) && ((2 / 0) == 0)", false}, // OR chained {"(one == 1) || (one == 1) || ((2 / 0) == 0)", true}, {"(one == 0) || (one == 1) || ((2 / 0) == 0)", true}, {"(one == 1) || (one == 0) || ((2 / 0) == 0)", true}, {"(one == 0) || (one == 0) || ((2 / 0) == 0)", false}, })); #endif // TINT_BUILD_WGSL_READER } // namespace LogicalShortCircuit } // namespace } // namespace tint::resolver