diff --git a/src/tint/intrinsics.def b/src/tint/intrinsics.def index 179cf88334..d9188edfaa 100644 --- a/src/tint/intrinsics.def +++ b/src/tint/intrinsics.def @@ -993,8 +993,8 @@ op || (bool, bool) -> bool @const op << <T: ia_iu32>(T, u32) -> T @const op << <T: ia_iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T> -op >> <T: iu32>(T, u32) -> T -op >> <T: iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T> +@const op >> <T: ia_iu32>(T, u32) -> T +@const op >> <T: ia_iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T> //////////////////////////////////////////////////////////////////////////////// // Tint internal builtins // diff --git a/src/tint/resolver/const_eval.cc b/src/tint/resolver/const_eval.cc index 8170adbe08..45ab2de0fd 100644 --- a/src/tint/resolver/const_eval.cc +++ b/src/tint/resolver/const_eval.cc @@ -1950,6 +1950,70 @@ ConstEval::Result ConstEval::OpShiftLeft(const sem::Type* ty, return TransformElements(builder, ty, transform, args[0], args[1]); } +ConstEval::Result ConstEval::OpShiftRight(const sem::Type* ty, + utils::VectorRef<const sem::Constant*> args, + const Source& source) { + auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) { + auto create = [&](auto e1, auto e2) -> ImplResult { + using NumberT = decltype(e1); + using T = UnwrapNumber<NumberT>; + using UT = std::make_unsigned_t<T>; + constexpr size_t bit_width = BitWidth<NumberT>; + const UT e1u = static_cast<UT>(e1); + const UT e2u = static_cast<UT>(e2); + + auto signed_shift_right = [&] { + // In C++, right shift of a signed negative number is implementation-defined. + // Although most implementations sign-extend, we do it manually to ensure it works + // correctly on all implementations. + const UT msb = UT{1} << (bit_width - 1); + UT sign_ext = 0; + if (e1u & msb) { + // Set e2 + 1 bits to 1 + UT num_shift_bits_mask = ((UT{1} << e2u) - UT{1}); + sign_ext = (num_shift_bits_mask << (bit_width - e2u - UT{1})) | msb; + } + return static_cast<T>((e1u >> e2u) | sign_ext); + }; + + T result = 0; + if constexpr (IsAbstract<NumberT>) { + if (static_cast<size_t>(e2) >= bit_width) { + result = T{0}; + } else { + result = signed_shift_right(); + } + } else { + if (static_cast<size_t>(e2) >= bit_width) { + // At shader/pipeline-creation time, it is an error to shift by the bit width of + // the lhs or greater. NOTE: At runtime, we shift by e2 % (bit width of e1). + AddError( + "shift right value must be less than the bit width of the lhs, which is " + + std::to_string(bit_width), + source); + return utils::Failure; + } + + if constexpr (std::is_signed_v<T>) { + result = signed_shift_right(); + } else { + result = e1 >> e2; + } + } + return CreateElement(builder, source, sem::Type::DeepestElementOf(ty), NumberT{result}); + }; + return Dispatch_ia_iu32(create, c0, c1); + }; + + if (!sem::Type::DeepestElementOf(args[1]->Type())->Is<sem::U32>()) { + TINT_ICE(Resolver, builder.Diagnostics()) + << "Element type of rhs of ShiftLeft must be a u32"; + return utils::Failure; + } + + return TransformElements(builder, ty, transform, args[0], args[1]); +} + ConstEval::Result ConstEval::abs(const sem::Type* ty, utils::VectorRef<const sem::Constant*> args, const Source& source) { diff --git a/src/tint/resolver/const_eval.h b/src/tint/resolver/const_eval.h index 9905036da0..3e6c626fae 100644 --- a/src/tint/resolver/const_eval.h +++ b/src/tint/resolver/const_eval.h @@ -382,6 +382,15 @@ class ConstEval { utils::VectorRef<const sem::Constant*> args, const Source& source); + /// Bitwise shift right operator '<<' + /// @param ty the expression type + /// @param args the input arguments + /// @param source the source location + /// @return the result value, or null if the value cannot be calculated + Result OpShiftRight(const sem::Type* ty, + utils::VectorRef<const sem::Constant*> args, + const Source& source); + //////////////////////////////////////////////////////////////////////////// // Builtins //////////////////////////////////////////////////////////////////////////// diff --git a/src/tint/resolver/const_eval_binary_op_test.cc b/src/tint/resolver/const_eval_binary_op_test.cc index 2590466fbf..a37028522f 100644 --- a/src/tint/resolver/const_eval_binary_op_test.cc +++ b/src/tint/resolver/const_eval_binary_op_test.cc @@ -917,8 +917,7 @@ INSTANTIATE_TEST_SUITE_P(Xor, template <typename T> std::vector<Case> ShiftLeftCases() { - // Shift type is u32 for non-abstract - using ST = std::conditional_t<IsAbstract<T>, T, u32>; + using ST = u32; // Shift type is u32 using B = BitValues<T>; auto r = std::vector<Case>{ C(T{0b1010}, ST{0}, T{0b0000'0000'1010}), // @@ -1200,5 +1199,144 @@ INSTANTIATE_TEST_SUITE_P(Test, ShiftLeftSignChangeErrorCases<AInt>(), ShiftLeftSignChangeErrorCases<i32>()))); +template <typename T> +std::vector<Case> ShiftRightCases() { + using B = BitValues<T>; + auto r = std::vector<Case>{ + C(T{0b10101100}, u32{0}, T{0b10101100}), // + C(T{0b10101100}, u32{1}, T{0b01010110}), // + C(T{0b10101100}, u32{2}, T{0b00101011}), // + C(T{0b10101100}, u32{3}, T{0b00010101}), // + C(T{0b10101100}, u32{4}, T{0b00001010}), // + C(T{0b10101100}, u32{5}, T{0b00000101}), // + C(T{0b10101100}, u32{6}, T{0b00000010}), // + C(T{0b10101100}, u32{7}, T{0b00000001}), // + C(T{0b10101100}, u32{8}, T{0b00000000}), // + C(T{0b10101100}, u32{9}, T{0b00000000}), // + C(B::LeftMost, u32{0}, B::LeftMost), // + }; + + // msb not set, same for all types: inserted bit is 0 + ConcatInto( // + r, std::vector<Case>{ + C(T{0b01000000000000000000000010101100}, u32{0}, // + T{0b01000000000000000000000010101100}), + C(T{0b01000000000000000000000010101100}, u32{1}, // + T{0b00100000000000000000000001010110}), + C(T{0b01000000000000000000000010101100}, u32{2}, // + T{0b00010000000000000000000000101011}), + C(T{0b01000000000000000000000010101100}, u32{3}, // + T{0b00001000000000000000000000010101}), + C(T{0b01000000000000000000000010101100}, u32{4}, // + T{0b00000100000000000000000000001010}), + C(T{0b01000000000000000000000010101100}, u32{5}, // + T{0b00000010000000000000000000000101}), + C(T{0b01000000000000000000000010101100}, u32{6}, // + T{0b00000001000000000000000000000010}), + C(T{0b01000000000000000000000010101100}, u32{7}, // + T{0b00000000100000000000000000000001}), + C(T{0b01000000000000000000000010101100}, u32{8}, // + T{0b00000000010000000000000000000000}), + C(T{0b01000000000000000000000010101100}, u32{9}, // + T{0b00000000001000000000000000000000}), + }); + + // msb set, result differs for i32 and u32 + if constexpr (std::is_same_v<T, u32>) { + // If unsigned, insert zero bits at the most significant positions. + ConcatInto( // + r, std::vector<Case>{ + C(T{0b10000000000000000000000010101100}, u32{0}, + T{0b10000000000000000000000010101100}), + C(T{0b10000000000000000000000010101100}, u32{1}, + T{0b01000000000000000000000001010110}), + C(T{0b10000000000000000000000010101100}, u32{2}, + T{0b00100000000000000000000000101011}), + C(T{0b10000000000000000000000010101100}, u32{3}, + T{0b00010000000000000000000000010101}), + C(T{0b10000000000000000000000010101100}, u32{4}, + T{0b00001000000000000000000000001010}), + C(T{0b10000000000000000000000010101100}, u32{5}, + T{0b00000100000000000000000000000101}), + C(T{0b10000000000000000000000010101100}, u32{6}, + T{0b00000010000000000000000000000010}), + C(T{0b10000000000000000000000010101100}, u32{7}, + T{0b00000001000000000000000000000001}), + C(T{0b10000000000000000000000010101100}, u32{8}, + T{0b00000000100000000000000000000000}), + C(T{0b10000000000000000000000010101100}, u32{9}, + T{0b00000000010000000000000000000000}), + // msb shifted by bit width - 1 + C(T{0b10000000000000000000000000000000}, u32{31}, + T{0b00000000000000000000000000000001}), + }); + } else if constexpr (std::is_same_v<T, i32>) { + // If signed, each inserted bit is 1, so the result is negative. + ConcatInto( // + r, std::vector<Case>{ + C(T{0b10000000000000000000000010101100}, u32{0}, + T{0b10000000000000000000000010101100}), // + C(T{0b10000000000000000000000010101100}, u32{1}, + T{0b11000000000000000000000001010110}), // + C(T{0b10000000000000000000000010101100}, u32{2}, + T{0b11100000000000000000000000101011}), // + C(T{0b10000000000000000000000010101100}, u32{3}, + T{0b11110000000000000000000000010101}), // + C(T{0b10000000000000000000000010101100}, u32{4}, + T{0b11111000000000000000000000001010}), // + C(T{0b10000000000000000000000010101100}, u32{5}, + T{0b11111100000000000000000000000101}), // + C(T{0b10000000000000000000000010101100}, u32{6}, + T{0b11111110000000000000000000000010}), // + C(T{0b10000000000000000000000010101100}, u32{7}, + T{0b11111111000000000000000000000001}), // + C(T{0b10000000000000000000000010101100}, u32{8}, + T{0b11111111100000000000000000000000}), // + C(T{0b10000000000000000000000010101100}, u32{9}, + T{0b11111111110000000000000000000000}), // + // msb shifted by bit width - 1 + C(T{0b10000000000000000000000000000000}, u32{31}, + T{0b11111111111111111111111111111111}), + }); + } + + // Test shift right by bit width or more + if constexpr (IsAbstract<T>) { + // For abstract int, no error, result is 0 + ConcatInto( // + r, std::vector<Case>{ + C(T{0}, u32{B::NumBits}, T{0}), + C(T{0}, u32{B::NumBits + 1}, T{0}), + C(T{0}, u32{B::NumBits + 1000}, T{0}), + C(T{42}, u32{B::NumBits}, T{0}), + C(T{42}, u32{B::NumBits + 1}, T{0}), + C(T{42}, u32{B::NumBits + 1000}, T{0}), + }); + } else { + // For concretes, error + const char* error_msg = + "12:34 error: shift right value must be less than the bit width of the lhs, which is " + "32"; + ConcatInto( // + r, std::vector<Case>{ + E(T{0}, u32{B::NumBits}, error_msg), + E(T{0}, u32{B::NumBits + 1}, error_msg), + E(T{0}, u32{B::NumBits + 1000}, error_msg), + E(T{42}, u32{B::NumBits}, error_msg), + E(T{42}, u32{B::NumBits + 1}, error_msg), + E(T{42}, u32{B::NumBits + 1000}, error_msg), + }); + } + + return r; +} +INSTANTIATE_TEST_SUITE_P(ShiftRight, + ResolverConstEvalBinaryOpTest, + testing::Combine( // + testing::Values(ast::BinaryOp::kShiftRight), + testing::ValuesIn(Concat(ShiftRightCases<AInt>(), // + ShiftRightCases<i32>(), // + ShiftRightCases<u32>())))); + } // namespace } // namespace tint::resolver diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc index 922a1dcea9..ef40bffcc2 100644 --- a/src/tint/resolver/function_validation_test.cc +++ b/src/tint/resolver/function_validation_test.cc @@ -890,10 +890,11 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_NonConst) { TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_x) { // @compute @workgroup_size(1 << 2 + 4) // fn main() {} + GlobalVar("x", ty.i32(), ast::AddressSpace::kPrivate, Expr(0_i)); Func("main", utils::Empty, ty.void_(), utils::Empty, utils::Vector{ Stage(ast::PipelineStage::kCompute), - WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))), + WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), "x")), }); EXPECT_FALSE(r()->Resolve()); @@ -905,10 +906,11 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_x) { TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_y) { // @compute @workgroup_size(1, 1 << 2 + 4) // fn main() {} + GlobalVar("x", ty.i32(), ast::AddressSpace::kPrivate, Expr(0_i)); Func("main", utils::Empty, ty.void_(), utils::Empty, utils::Vector{ Stage(ast::PipelineStage::kCompute), - WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))), + WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), "x")), }); EXPECT_FALSE(r()->Resolve()); @@ -920,10 +922,11 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_y) { TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_z) { // @compute @workgroup_size(1, 1, 1 << 2 + 4) // fn main() {} + GlobalVar("x", ty.i32(), ast::AddressSpace::kPrivate, Expr(0_i)); Func("main", utils::Empty, ty.void_(), utils::Empty, utils::Vector{ Stage(ast::PipelineStage::kCompute), - WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))), + WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), "x")), }); EXPECT_FALSE(r()->Resolve()); diff --git a/src/tint/resolver/intrinsic_table.inl b/src/tint/resolver/intrinsic_table.inl index 00b7557333..3fedcd042c 100644 --- a/src/tint/resolver/intrinsic_table.inl +++ b/src/tint/resolver/intrinsic_table.inl @@ -13469,24 +13469,24 @@ constexpr OverloadInfo kOverloads[] = { /* num parameters */ 2, /* num template types */ 1, /* num template numbers */ 0, - /* template types */ &kTemplateTypes[25], + /* template types */ &kTemplateTypes[28], /* template numbers */ &kTemplateNumbers[10], /* parameters */ &kParameters[778], /* return matcher indices */ &kMatcherIndices[3], /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), - /* const eval */ nullptr, + /* const eval */ &ConstEval::OpShiftRight, }, { /* [432] */ /* num parameters */ 2, /* num template types */ 1, /* num template numbers */ 1, - /* template types */ &kTemplateTypes[25], + /* template types */ &kTemplateTypes[28], /* template numbers */ &kTemplateNumbers[4], /* parameters */ &kParameters[780], /* return matcher indices */ &kMatcherIndices[30], /* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline), - /* const eval */ nullptr, + /* const eval */ &ConstEval::OpShiftRight, }, { /* [433] */ @@ -14975,8 +14975,8 @@ constexpr IntrinsicInfo kBinaryOperators[] = { }, { /* [17] */ - /* op >><T : iu32>(T, u32) -> T */ - /* op >><T : iu32, N : num>(vec<N, T>, vec<N, u32>) -> vec<N, T> */ + /* op >><T : ia_iu32>(T, u32) -> T */ + /* op >><T : ia_iu32, N : num>(vec<N, T>, vec<N, u32>) -> vec<N, T> */ /* num overloads */ 2, /* overloads */ &kOverloads[431], },