tint: const eval of binary right shift

Bug: tint:1581
Change-Id: I3f40454559c4fc36565de1a11a6e6c8c394fd0cc
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/112620
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Antonio Maiorano 2022-12-02 15:25:20 +00:00 committed by Dawn LUCI CQ
parent 7423496da6
commit 42ada5f248
6 changed files with 227 additions and 13 deletions

View File

@ -993,8 +993,8 @@ op || (bool, bool) -> bool
@const op << <T: ia_iu32>(T, u32) -> T
@const op << <T: ia_iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T>
op >> <T: iu32>(T, u32) -> T
op >> <T: iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T>
@const op >> <T: ia_iu32>(T, u32) -> T
@const op >> <T: ia_iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T>
////////////////////////////////////////////////////////////////////////////////
// Tint internal builtins //

View File

@ -1950,6 +1950,70 @@ ConstEval::Result ConstEval::OpShiftLeft(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0], args[1]);
}
ConstEval::Result ConstEval::OpShiftRight(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto e1, auto e2) -> ImplResult {
using NumberT = decltype(e1);
using T = UnwrapNumber<NumberT>;
using UT = std::make_unsigned_t<T>;
constexpr size_t bit_width = BitWidth<NumberT>;
const UT e1u = static_cast<UT>(e1);
const UT e2u = static_cast<UT>(e2);
auto signed_shift_right = [&] {
// In C++, right shift of a signed negative number is implementation-defined.
// Although most implementations sign-extend, we do it manually to ensure it works
// correctly on all implementations.
const UT msb = UT{1} << (bit_width - 1);
UT sign_ext = 0;
if (e1u & msb) {
// Set e2 + 1 bits to 1
UT num_shift_bits_mask = ((UT{1} << e2u) - UT{1});
sign_ext = (num_shift_bits_mask << (bit_width - e2u - UT{1})) | msb;
}
return static_cast<T>((e1u >> e2u) | sign_ext);
};
T result = 0;
if constexpr (IsAbstract<NumberT>) {
if (static_cast<size_t>(e2) >= bit_width) {
result = T{0};
} else {
result = signed_shift_right();
}
} else {
if (static_cast<size_t>(e2) >= bit_width) {
// At shader/pipeline-creation time, it is an error to shift by the bit width of
// the lhs or greater. NOTE: At runtime, we shift by e2 % (bit width of e1).
AddError(
"shift right value must be less than the bit width of the lhs, which is " +
std::to_string(bit_width),
source);
return utils::Failure;
}
if constexpr (std::is_signed_v<T>) {
result = signed_shift_right();
} else {
result = e1 >> e2;
}
}
return CreateElement(builder, source, sem::Type::DeepestElementOf(ty), NumberT{result});
};
return Dispatch_ia_iu32(create, c0, c1);
};
if (!sem::Type::DeepestElementOf(args[1]->Type())->Is<sem::U32>()) {
TINT_ICE(Resolver, builder.Diagnostics())
<< "Element type of rhs of ShiftLeft must be a u32";
return utils::Failure;
}
return TransformElements(builder, ty, transform, args[0], args[1]);
}
ConstEval::Result ConstEval::abs(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {

View File

@ -382,6 +382,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// Bitwise shift right operator '<<'
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location
/// @return the result value, or null if the value cannot be calculated
Result OpShiftRight(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
////////////////////////////////////////////////////////////////////////////
// Builtins
////////////////////////////////////////////////////////////////////////////

View File

@ -917,8 +917,7 @@ INSTANTIATE_TEST_SUITE_P(Xor,
template <typename T>
std::vector<Case> ShiftLeftCases() {
// Shift type is u32 for non-abstract
using ST = std::conditional_t<IsAbstract<T>, T, u32>;
using ST = u32; // Shift type is u32
using B = BitValues<T>;
auto r = std::vector<Case>{
C(T{0b1010}, ST{0}, T{0b0000'0000'1010}), //
@ -1200,5 +1199,144 @@ INSTANTIATE_TEST_SUITE_P(Test,
ShiftLeftSignChangeErrorCases<AInt>(),
ShiftLeftSignChangeErrorCases<i32>())));
template <typename T>
std::vector<Case> ShiftRightCases() {
using B = BitValues<T>;
auto r = std::vector<Case>{
C(T{0b10101100}, u32{0}, T{0b10101100}), //
C(T{0b10101100}, u32{1}, T{0b01010110}), //
C(T{0b10101100}, u32{2}, T{0b00101011}), //
C(T{0b10101100}, u32{3}, T{0b00010101}), //
C(T{0b10101100}, u32{4}, T{0b00001010}), //
C(T{0b10101100}, u32{5}, T{0b00000101}), //
C(T{0b10101100}, u32{6}, T{0b00000010}), //
C(T{0b10101100}, u32{7}, T{0b00000001}), //
C(T{0b10101100}, u32{8}, T{0b00000000}), //
C(T{0b10101100}, u32{9}, T{0b00000000}), //
C(B::LeftMost, u32{0}, B::LeftMost), //
};
// msb not set, same for all types: inserted bit is 0
ConcatInto( //
r, std::vector<Case>{
C(T{0b01000000000000000000000010101100}, u32{0}, //
T{0b01000000000000000000000010101100}),
C(T{0b01000000000000000000000010101100}, u32{1}, //
T{0b00100000000000000000000001010110}),
C(T{0b01000000000000000000000010101100}, u32{2}, //
T{0b00010000000000000000000000101011}),
C(T{0b01000000000000000000000010101100}, u32{3}, //
T{0b00001000000000000000000000010101}),
C(T{0b01000000000000000000000010101100}, u32{4}, //
T{0b00000100000000000000000000001010}),
C(T{0b01000000000000000000000010101100}, u32{5}, //
T{0b00000010000000000000000000000101}),
C(T{0b01000000000000000000000010101100}, u32{6}, //
T{0b00000001000000000000000000000010}),
C(T{0b01000000000000000000000010101100}, u32{7}, //
T{0b00000000100000000000000000000001}),
C(T{0b01000000000000000000000010101100}, u32{8}, //
T{0b00000000010000000000000000000000}),
C(T{0b01000000000000000000000010101100}, u32{9}, //
T{0b00000000001000000000000000000000}),
});
// msb set, result differs for i32 and u32
if constexpr (std::is_same_v<T, u32>) {
// If unsigned, insert zero bits at the most significant positions.
ConcatInto( //
r, std::vector<Case>{
C(T{0b10000000000000000000000010101100}, u32{0},
T{0b10000000000000000000000010101100}),
C(T{0b10000000000000000000000010101100}, u32{1},
T{0b01000000000000000000000001010110}),
C(T{0b10000000000000000000000010101100}, u32{2},
T{0b00100000000000000000000000101011}),
C(T{0b10000000000000000000000010101100}, u32{3},
T{0b00010000000000000000000000010101}),
C(T{0b10000000000000000000000010101100}, u32{4},
T{0b00001000000000000000000000001010}),
C(T{0b10000000000000000000000010101100}, u32{5},
T{0b00000100000000000000000000000101}),
C(T{0b10000000000000000000000010101100}, u32{6},
T{0b00000010000000000000000000000010}),
C(T{0b10000000000000000000000010101100}, u32{7},
T{0b00000001000000000000000000000001}),
C(T{0b10000000000000000000000010101100}, u32{8},
T{0b00000000100000000000000000000000}),
C(T{0b10000000000000000000000010101100}, u32{9},
T{0b00000000010000000000000000000000}),
// msb shifted by bit width - 1
C(T{0b10000000000000000000000000000000}, u32{31},
T{0b00000000000000000000000000000001}),
});
} else if constexpr (std::is_same_v<T, i32>) {
// If signed, each inserted bit is 1, so the result is negative.
ConcatInto( //
r, std::vector<Case>{
C(T{0b10000000000000000000000010101100}, u32{0},
T{0b10000000000000000000000010101100}), //
C(T{0b10000000000000000000000010101100}, u32{1},
T{0b11000000000000000000000001010110}), //
C(T{0b10000000000000000000000010101100}, u32{2},
T{0b11100000000000000000000000101011}), //
C(T{0b10000000000000000000000010101100}, u32{3},
T{0b11110000000000000000000000010101}), //
C(T{0b10000000000000000000000010101100}, u32{4},
T{0b11111000000000000000000000001010}), //
C(T{0b10000000000000000000000010101100}, u32{5},
T{0b11111100000000000000000000000101}), //
C(T{0b10000000000000000000000010101100}, u32{6},
T{0b11111110000000000000000000000010}), //
C(T{0b10000000000000000000000010101100}, u32{7},
T{0b11111111000000000000000000000001}), //
C(T{0b10000000000000000000000010101100}, u32{8},
T{0b11111111100000000000000000000000}), //
C(T{0b10000000000000000000000010101100}, u32{9},
T{0b11111111110000000000000000000000}), //
// msb shifted by bit width - 1
C(T{0b10000000000000000000000000000000}, u32{31},
T{0b11111111111111111111111111111111}),
});
}
// Test shift right by bit width or more
if constexpr (IsAbstract<T>) {
// For abstract int, no error, result is 0
ConcatInto( //
r, std::vector<Case>{
C(T{0}, u32{B::NumBits}, T{0}),
C(T{0}, u32{B::NumBits + 1}, T{0}),
C(T{0}, u32{B::NumBits + 1000}, T{0}),
C(T{42}, u32{B::NumBits}, T{0}),
C(T{42}, u32{B::NumBits + 1}, T{0}),
C(T{42}, u32{B::NumBits + 1000}, T{0}),
});
} else {
// For concretes, error
const char* error_msg =
"12:34 error: shift right value must be less than the bit width of the lhs, which is "
"32";
ConcatInto( //
r, std::vector<Case>{
E(T{0}, u32{B::NumBits}, error_msg),
E(T{0}, u32{B::NumBits + 1}, error_msg),
E(T{0}, u32{B::NumBits + 1000}, error_msg),
E(T{42}, u32{B::NumBits}, error_msg),
E(T{42}, u32{B::NumBits + 1}, error_msg),
E(T{42}, u32{B::NumBits + 1000}, error_msg),
});
}
return r;
}
INSTANTIATE_TEST_SUITE_P(ShiftRight,
ResolverConstEvalBinaryOpTest,
testing::Combine( //
testing::Values(ast::BinaryOp::kShiftRight),
testing::ValuesIn(Concat(ShiftRightCases<AInt>(), //
ShiftRightCases<i32>(), //
ShiftRightCases<u32>()))));
} // namespace
} // namespace tint::resolver

View File

@ -890,10 +890,11 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_NonConst) {
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_x) {
// @compute @workgroup_size(1 << 2 + 4)
// fn main() {}
GlobalVar("x", ty.i32(), ast::AddressSpace::kPrivate, Expr(0_i));
Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), "x")),
});
EXPECT_FALSE(r()->Resolve());
@ -905,10 +906,11 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_x) {
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_y) {
// @compute @workgroup_size(1, 1 << 2 + 4)
// fn main() {}
GlobalVar("x", ty.i32(), ast::AddressSpace::kPrivate, Expr(0_i));
Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), "x")),
});
EXPECT_FALSE(r()->Resolve());
@ -920,10 +922,11 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_y) {
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_z) {
// @compute @workgroup_size(1, 1, 1 << 2 + 4)
// fn main() {}
GlobalVar("x", ty.i32(), ast::AddressSpace::kPrivate, Expr(0_i));
Func("main", utils::Empty, ty.void_(), utils::Empty,
utils::Vector{
Stage(ast::PipelineStage::kCompute),
WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), "x")),
});
EXPECT_FALSE(r()->Resolve());

View File

@ -13469,24 +13469,24 @@ constexpr OverloadInfo kOverloads[] = {
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 0,
/* template types */ &kTemplateTypes[25],
/* template types */ &kTemplateTypes[28],
/* template numbers */ &kTemplateNumbers[10],
/* parameters */ &kParameters[778],
/* return matcher indices */ &kMatcherIndices[3],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpShiftRight,
},
{
/* [432] */
/* num parameters */ 2,
/* num template types */ 1,
/* num template numbers */ 1,
/* template types */ &kTemplateTypes[25],
/* template types */ &kTemplateTypes[28],
/* template numbers */ &kTemplateNumbers[4],
/* parameters */ &kParameters[780],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::OpShiftRight,
},
{
/* [433] */
@ -14975,8 +14975,8 @@ constexpr IntrinsicInfo kBinaryOperators[] = {
},
{
/* [17] */
/* op >><T : iu32>(T, u32) -> T */
/* op >><T : iu32, N : num>(vec<N, T>, vec<N, u32>) -> vec<N, T> */
/* op >><T : ia_iu32>(T, u32) -> T */
/* op >><T : ia_iu32, N : num>(vec<N, T>, vec<N, u32>) -> vec<N, T> */
/* num overloads */ 2,
/* overloads */ &kOverloads[431],
},