diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc index 8dc4728a55..a7f5d7d480 100644 --- a/src/tint/resolver/const_eval.cc +++ b/src/tint/resolver/const_eval.cc @@ -2092,19 +2092,19 @@ ConstEval::Result ConstEval::insertBits(const sem::Type* ty, NumberUT in_offset = args[2]->As(); NumberUT in_count = args[3]->As(); - constexpr UT w = sizeof(UT) * 8; - if ((in_offset + in_count) > w) { - AddError("'offset + 'count' must be less than or equal to the bit width of 'e'", - source); - return utils::Failure; - } - // Cast all to unsigned UT e = static_cast(in_e); UT newbits = static_cast(in_newbits); UT o = static_cast(in_offset); UT c = static_cast(in_count); + constexpr UT w = sizeof(UT) * 8; + if (o > w || c > w || (o + c) > w) { + AddError("'offset + 'count' must be less than or equal to the bit width of 'e'", + source); + return utils::Failure; + } + NumberT result; if (c == UT{0}) { // The result is e if c is 0 diff --git a/src/tint/resolver/const_eval_builtin_test.cc b/src/tint/resolver/const_eval_builtin_test.cc index 69a30a4c8f..cc950abe72 100644 --- a/src/tint/resolver/const_eval_builtin_test.cc +++ b/src/tint/resolver/const_eval_builtin_test.cc @@ -1150,6 +1150,26 @@ std::vector InsertBitsCases() { T(0b1010'0101'1010'0101'1010'0111'1111'1101))), }; + const char* error_msg = + "12:34 error: 'offset + 'count' must be less than or equal to the bit width of 'e'"; + ConcatInto( // + r, std::vector{ + E({T(1), T(1), UT(33), UT(0)}, error_msg), // + E({T(1), T(1), UT(34), UT(0)}, error_msg), // + E({T(1), T(1), UT(1000), UT(0)}, error_msg), // + E({T(1), T(1), UT::Highest(), UT()}, error_msg), // + E({T(1), T(1), UT(0), UT(33)}, error_msg), // + E({T(1), T(1), UT(0), UT(34)}, error_msg), // + E({T(1), T(1), UT(0), UT(1000)}, error_msg), // + E({T(1), T(1), UT(0), UT::Highest()}, error_msg), // + E({T(1), T(1), UT(33), UT(33)}, error_msg), // + E({T(1), T(1), UT(34), UT(34)}, error_msg), // + E({T(1), T(1), UT(1000), UT(1000)}, error_msg), // + E({T(1), T(1), UT::Highest(), UT(1)}, error_msg), + E({T(1), T(1), UT(1), UT::Highest()}, error_msg), + E({T(1), T(1), UT::Highest(), u32::Highest()}, error_msg), + }); + return r; } INSTANTIATE_TEST_SUITE_P( // @@ -1253,6 +1273,26 @@ std::vector ExtractBitsCases() { set_msbs_if_signed(T(0b11010001)))), }; + const char* error_msg = + "12:34 error: 'offset + 'count' must be less than or equal to the bit width of 'e'"; + ConcatInto( // + r, std::vector{ + E({T(1), UT(33), UT(0)}, error_msg), + E({T(1), UT(34), UT(0)}, error_msg), + E({T(1), UT(1000), UT(0)}, error_msg), + E({T(1), UT::Highest(), UT(0)}, error_msg), + E({T(1), UT(0), UT(33)}, error_msg), + E({T(1), UT(0), UT(34)}, error_msg), + E({T(1), UT(0), UT(1000)}, error_msg), + E({T(1), UT(0), UT::Highest()}, error_msg), + E({T(1), UT(33), UT(33)}, error_msg), + E({T(1), UT(34), UT(34)}, error_msg), + E({T(1), UT(1000), UT(1000)}, error_msg), + E({T(1), UT::Highest(), UT(1)}, error_msg), + E({T(1), UT(1), UT::Highest()}, error_msg), + E({T(1), UT::Highest(), UT::Highest()}, error_msg), + }); + return r; } INSTANTIATE_TEST_SUITE_P( // @@ -1262,35 +1302,6 @@ INSTANTIATE_TEST_SUITE_P( // testing::ValuesIn(Concat(ExtractBitsCases(), // ExtractBitsCases())))); -using ResolverConstEvalBuiltinTest_ExtractBits_InvalidOffsetAndCount = - ResolverTestWithParam>; -TEST_P(ResolverConstEvalBuiltinTest_ExtractBits_InvalidOffsetAndCount, Test) { - auto& p = GetParam(); - auto* expr = Call(Source{{12, 34}}, sem::str(sem::BuiltinType::kExtractBits), Expr(1_u), - Expr(u32(std::get<0>(p))), Expr(u32(std::get<1>(p)))); - GlobalConst("C", expr); - EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), - "12:34 error: 'offset + 'count' must be less than or equal to the bit width of 'e'"); -} -INSTANTIATE_TEST_SUITE_P(ExtractBits, - ResolverConstEvalBuiltinTest_ExtractBits_InvalidOffsetAndCount, - testing::Values( // - std::make_tuple(33, 0), // - std::make_tuple(34, 0), // - std::make_tuple(1000, 0), // - std::make_tuple(u32::Highest(), 0), // - std::make_tuple(0, 33), // - std::make_tuple(0, 34), // - std::make_tuple(0, 1000), // - std::make_tuple(0, u32::Highest()), // - std::make_tuple(33, 33), // - std::make_tuple(34, 34), // - std::make_tuple(1000, 1000), // - std::make_tuple(u32::Highest(), 1), // - std::make_tuple(1, u32::Highest()), // - std::make_tuple(u32::Highest(), u32::Highest()))); - template std::vector MaxCases() { return {