tint: const eval of binary right shift
Bug: tint:1581 Change-Id: I3f40454559c4fc36565de1a11a6e6c8c394fd0cc Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/112620 Reviewed-by: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
parent
7423496da6
commit
42ada5f248
|
@ -993,8 +993,8 @@ op || (bool, bool) -> bool
|
|||
@const op << <T: ia_iu32>(T, u32) -> T
|
||||
@const op << <T: ia_iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T>
|
||||
|
||||
op >> <T: iu32>(T, u32) -> T
|
||||
op >> <T: iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T>
|
||||
@const op >> <T: ia_iu32>(T, u32) -> T
|
||||
@const op >> <T: ia_iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Tint internal builtins //
|
||||
|
|
|
@ -1950,6 +1950,70 @@ ConstEval::Result ConstEval::OpShiftLeft(const sem::Type* ty,
|
|||
return TransformElements(builder, ty, transform, args[0], args[1]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::OpShiftRight(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
|
||||
auto create = [&](auto e1, auto e2) -> ImplResult {
|
||||
using NumberT = decltype(e1);
|
||||
using T = UnwrapNumber<NumberT>;
|
||||
using UT = std::make_unsigned_t<T>;
|
||||
constexpr size_t bit_width = BitWidth<NumberT>;
|
||||
const UT e1u = static_cast<UT>(e1);
|
||||
const UT e2u = static_cast<UT>(e2);
|
||||
|
||||
auto signed_shift_right = [&] {
|
||||
// In C++, right shift of a signed negative number is implementation-defined.
|
||||
// Although most implementations sign-extend, we do it manually to ensure it works
|
||||
// correctly on all implementations.
|
||||
const UT msb = UT{1} << (bit_width - 1);
|
||||
UT sign_ext = 0;
|
||||
if (e1u & msb) {
|
||||
// Set e2 + 1 bits to 1
|
||||
UT num_shift_bits_mask = ((UT{1} << e2u) - UT{1});
|
||||
sign_ext = (num_shift_bits_mask << (bit_width - e2u - UT{1})) | msb;
|
||||
}
|
||||
return static_cast<T>((e1u >> e2u) | sign_ext);
|
||||
};
|
||||
|
||||
T result = 0;
|
||||
if constexpr (IsAbstract<NumberT>) {
|
||||
if (static_cast<size_t>(e2) >= bit_width) {
|
||||
result = T{0};
|
||||
} else {
|
||||
result = signed_shift_right();
|
||||
}
|
||||
} else {
|
||||
if (static_cast<size_t>(e2) >= bit_width) {
|
||||
// At shader/pipeline-creation time, it is an error to shift by the bit width of
|
||||
// the lhs or greater. NOTE: At runtime, we shift by e2 % (bit width of e1).
|
||||
AddError(
|
||||
"shift right value must be less than the bit width of the lhs, which is " +
|
||||
std::to_string(bit_width),
|
||||
source);
|
||||
return utils::Failure;
|
||||
}
|
||||
|
||||
if constexpr (std::is_signed_v<T>) {
|
||||
result = signed_shift_right();
|
||||
} else {
|
||||
result = e1 >> e2;
|
||||
}
|
||||
}
|
||||
return CreateElement(builder, source, sem::Type::DeepestElementOf(ty), NumberT{result});
|
||||
};
|
||||
return Dispatch_ia_iu32(create, c0, c1);
|
||||
};
|
||||
|
||||
if (!sem::Type::DeepestElementOf(args[1]->Type())->Is<sem::U32>()) {
|
||||
TINT_ICE(Resolver, builder.Diagnostics())
|
||||
<< "Element type of rhs of ShiftLeft must be a u32";
|
||||
return utils::Failure;
|
||||
}
|
||||
|
||||
return TransformElements(builder, ty, transform, args[0], args[1]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::abs(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
|
|
|
@ -382,6 +382,15 @@ class ConstEval {
|
|||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// Bitwise shift right operator '<<'
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result OpShiftRight(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
// Builtins
|
||||
////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -917,8 +917,7 @@ INSTANTIATE_TEST_SUITE_P(Xor,
|
|||
|
||||
template <typename T>
|
||||
std::vector<Case> ShiftLeftCases() {
|
||||
// Shift type is u32 for non-abstract
|
||||
using ST = std::conditional_t<IsAbstract<T>, T, u32>;
|
||||
using ST = u32; // Shift type is u32
|
||||
using B = BitValues<T>;
|
||||
auto r = std::vector<Case>{
|
||||
C(T{0b1010}, ST{0}, T{0b0000'0000'1010}), //
|
||||
|
@ -1200,5 +1199,144 @@ INSTANTIATE_TEST_SUITE_P(Test,
|
|||
ShiftLeftSignChangeErrorCases<AInt>(),
|
||||
ShiftLeftSignChangeErrorCases<i32>())));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> ShiftRightCases() {
|
||||
using B = BitValues<T>;
|
||||
auto r = std::vector<Case>{
|
||||
C(T{0b10101100}, u32{0}, T{0b10101100}), //
|
||||
C(T{0b10101100}, u32{1}, T{0b01010110}), //
|
||||
C(T{0b10101100}, u32{2}, T{0b00101011}), //
|
||||
C(T{0b10101100}, u32{3}, T{0b00010101}), //
|
||||
C(T{0b10101100}, u32{4}, T{0b00001010}), //
|
||||
C(T{0b10101100}, u32{5}, T{0b00000101}), //
|
||||
C(T{0b10101100}, u32{6}, T{0b00000010}), //
|
||||
C(T{0b10101100}, u32{7}, T{0b00000001}), //
|
||||
C(T{0b10101100}, u32{8}, T{0b00000000}), //
|
||||
C(T{0b10101100}, u32{9}, T{0b00000000}), //
|
||||
C(B::LeftMost, u32{0}, B::LeftMost), //
|
||||
};
|
||||
|
||||
// msb not set, same for all types: inserted bit is 0
|
||||
ConcatInto( //
|
||||
r, std::vector<Case>{
|
||||
C(T{0b01000000000000000000000010101100}, u32{0}, //
|
||||
T{0b01000000000000000000000010101100}),
|
||||
C(T{0b01000000000000000000000010101100}, u32{1}, //
|
||||
T{0b00100000000000000000000001010110}),
|
||||
C(T{0b01000000000000000000000010101100}, u32{2}, //
|
||||
T{0b00010000000000000000000000101011}),
|
||||
C(T{0b01000000000000000000000010101100}, u32{3}, //
|
||||
T{0b00001000000000000000000000010101}),
|
||||
C(T{0b01000000000000000000000010101100}, u32{4}, //
|
||||
T{0b00000100000000000000000000001010}),
|
||||
C(T{0b01000000000000000000000010101100}, u32{5}, //
|
||||
T{0b00000010000000000000000000000101}),
|
||||
C(T{0b01000000000000000000000010101100}, u32{6}, //
|
||||
T{0b00000001000000000000000000000010}),
|
||||
C(T{0b01000000000000000000000010101100}, u32{7}, //
|
||||
T{0b00000000100000000000000000000001}),
|
||||
C(T{0b01000000000000000000000010101100}, u32{8}, //
|
||||
T{0b00000000010000000000000000000000}),
|
||||
C(T{0b01000000000000000000000010101100}, u32{9}, //
|
||||
T{0b00000000001000000000000000000000}),
|
||||
});
|
||||
|
||||
// msb set, result differs for i32 and u32
|
||||
if constexpr (std::is_same_v<T, u32>) {
|
||||
// If unsigned, insert zero bits at the most significant positions.
|
||||
ConcatInto( //
|
||||
r, std::vector<Case>{
|
||||
C(T{0b10000000000000000000000010101100}, u32{0},
|
||||
T{0b10000000000000000000000010101100}),
|
||||
C(T{0b10000000000000000000000010101100}, u32{1},
|
||||
T{0b01000000000000000000000001010110}),
|
||||
C(T{0b10000000000000000000000010101100}, u32{2},
|
||||
T{0b00100000000000000000000000101011}),
|
||||
C(T{0b10000000000000000000000010101100}, u32{3},
|
||||
T{0b00010000000000000000000000010101}),
|
||||
C(T{0b10000000000000000000000010101100}, u32{4},
|
||||
T{0b00001000000000000000000000001010}),
|
||||
C(T{0b10000000000000000000000010101100}, u32{5},
|
||||
T{0b00000100000000000000000000000101}),
|
||||
C(T{0b10000000000000000000000010101100}, u32{6},
|
||||
T{0b00000010000000000000000000000010}),
|
||||
C(T{0b10000000000000000000000010101100}, u32{7},
|
||||
T{0b00000001000000000000000000000001}),
|
||||
C(T{0b10000000000000000000000010101100}, u32{8},
|
||||
T{0b00000000100000000000000000000000}),
|
||||
C(T{0b10000000000000000000000010101100}, u32{9},
|
||||
T{0b00000000010000000000000000000000}),
|
||||
// msb shifted by bit width - 1
|
||||
C(T{0b10000000000000000000000000000000}, u32{31},
|
||||
T{0b00000000000000000000000000000001}),
|
||||
});
|
||||
} else if constexpr (std::is_same_v<T, i32>) {
|
||||
// If signed, each inserted bit is 1, so the result is negative.
|
||||
ConcatInto( //
|
||||
r, std::vector<Case>{
|
||||
C(T{0b10000000000000000000000010101100}, u32{0},
|
||||
T{0b10000000000000000000000010101100}), //
|
||||
C(T{0b10000000000000000000000010101100}, u32{1},
|
||||
T{0b11000000000000000000000001010110}), //
|
||||
C(T{0b10000000000000000000000010101100}, u32{2},
|
||||
T{0b11100000000000000000000000101011}), //
|
||||
C(T{0b10000000000000000000000010101100}, u32{3},
|
||||
T{0b11110000000000000000000000010101}), //
|
||||
C(T{0b10000000000000000000000010101100}, u32{4},
|
||||
T{0b11111000000000000000000000001010}), //
|
||||
C(T{0b10000000000000000000000010101100}, u32{5},
|
||||
T{0b11111100000000000000000000000101}), //
|
||||
C(T{0b10000000000000000000000010101100}, u32{6},
|
||||
T{0b11111110000000000000000000000010}), //
|
||||
C(T{0b10000000000000000000000010101100}, u32{7},
|
||||
T{0b11111111000000000000000000000001}), //
|
||||
C(T{0b10000000000000000000000010101100}, u32{8},
|
||||
T{0b11111111100000000000000000000000}), //
|
||||
C(T{0b10000000000000000000000010101100}, u32{9},
|
||||
T{0b11111111110000000000000000000000}), //
|
||||
// msb shifted by bit width - 1
|
||||
C(T{0b10000000000000000000000000000000}, u32{31},
|
||||
T{0b11111111111111111111111111111111}),
|
||||
});
|
||||
}
|
||||
|
||||
// Test shift right by bit width or more
|
||||
if constexpr (IsAbstract<T>) {
|
||||
// For abstract int, no error, result is 0
|
||||
ConcatInto( //
|
||||
r, std::vector<Case>{
|
||||
C(T{0}, u32{B::NumBits}, T{0}),
|
||||
C(T{0}, u32{B::NumBits + 1}, T{0}),
|
||||
C(T{0}, u32{B::NumBits + 1000}, T{0}),
|
||||
C(T{42}, u32{B::NumBits}, T{0}),
|
||||
C(T{42}, u32{B::NumBits + 1}, T{0}),
|
||||
C(T{42}, u32{B::NumBits + 1000}, T{0}),
|
||||
});
|
||||
} else {
|
||||
// For concretes, error
|
||||
const char* error_msg =
|
||||
"12:34 error: shift right value must be less than the bit width of the lhs, which is "
|
||||
"32";
|
||||
ConcatInto( //
|
||||
r, std::vector<Case>{
|
||||
E(T{0}, u32{B::NumBits}, error_msg),
|
||||
E(T{0}, u32{B::NumBits + 1}, error_msg),
|
||||
E(T{0}, u32{B::NumBits + 1000}, error_msg),
|
||||
E(T{42}, u32{B::NumBits}, error_msg),
|
||||
E(T{42}, u32{B::NumBits + 1}, error_msg),
|
||||
E(T{42}, u32{B::NumBits + 1000}, error_msg),
|
||||
});
|
||||
}
|
||||
|
||||
return r;
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(ShiftRight,
|
||||
ResolverConstEvalBinaryOpTest,
|
||||
testing::Combine( //
|
||||
testing::Values(ast::BinaryOp::kShiftRight),
|
||||
testing::ValuesIn(Concat(ShiftRightCases<AInt>(), //
|
||||
ShiftRightCases<i32>(), //
|
||||
ShiftRightCases<u32>()))));
|
||||
|
||||
} // namespace
|
||||
} // namespace tint::resolver
|
||||
|
|
|
@ -890,10 +890,11 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_NonConst) {
|
|||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_x) {
|
||||
// @compute @workgroup_size(1 << 2 + 4)
|
||||
// fn main() {}
|
||||
GlobalVar("x", ty.i32(), ast::AddressSpace::kPrivate, Expr(0_i));
|
||||
Func("main", utils::Empty, ty.void_(), utils::Empty,
|
||||
utils::Vector{
|
||||
Stage(ast::PipelineStage::kCompute),
|
||||
WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
|
||||
WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), "x")),
|
||||
});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
|
@ -905,10 +906,11 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_x) {
|
|||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_y) {
|
||||
// @compute @workgroup_size(1, 1 << 2 + 4)
|
||||
// fn main() {}
|
||||
GlobalVar("x", ty.i32(), ast::AddressSpace::kPrivate, Expr(0_i));
|
||||
Func("main", utils::Empty, ty.void_(), utils::Empty,
|
||||
utils::Vector{
|
||||
Stage(ast::PipelineStage::kCompute),
|
||||
WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
|
||||
WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), "x")),
|
||||
});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
|
@ -920,10 +922,11 @@ TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_y) {
|
|||
TEST_F(ResolverFunctionValidationTest, WorkgroupSize_InvalidExpr_z) {
|
||||
// @compute @workgroup_size(1, 1, 1 << 2 + 4)
|
||||
// fn main() {}
|
||||
GlobalVar("x", ty.i32(), ast::AddressSpace::kPrivate, Expr(0_i));
|
||||
Func("main", utils::Empty, ty.void_(), utils::Empty,
|
||||
utils::Vector{
|
||||
Stage(ast::PipelineStage::kCompute),
|
||||
WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), Shr(1_i, Add(2_u, 4_u)))),
|
||||
WorkgroupSize(Construct(Source{{12, 34}}, ty.i32(), "x")),
|
||||
});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
|
|
|
@ -13469,24 +13469,24 @@ constexpr OverloadInfo kOverloads[] = {
|
|||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 0,
|
||||
/* template types */ &kTemplateTypes[25],
|
||||
/* template types */ &kTemplateTypes[28],
|
||||
/* template numbers */ &kTemplateNumbers[10],
|
||||
/* parameters */ &kParameters[778],
|
||||
/* return matcher indices */ &kMatcherIndices[3],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpShiftRight,
|
||||
},
|
||||
{
|
||||
/* [432] */
|
||||
/* num parameters */ 2,
|
||||
/* num template types */ 1,
|
||||
/* num template numbers */ 1,
|
||||
/* template types */ &kTemplateTypes[25],
|
||||
/* template types */ &kTemplateTypes[28],
|
||||
/* template numbers */ &kTemplateNumbers[4],
|
||||
/* parameters */ &kParameters[780],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsOperator, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::OpShiftRight,
|
||||
},
|
||||
{
|
||||
/* [433] */
|
||||
|
@ -14975,8 +14975,8 @@ constexpr IntrinsicInfo kBinaryOperators[] = {
|
|||
},
|
||||
{
|
||||
/* [17] */
|
||||
/* op >><T : iu32>(T, u32) -> T */
|
||||
/* op >><T : iu32, N : num>(vec<N, T>, vec<N, u32>) -> vec<N, T> */
|
||||
/* op >><T : ia_iu32>(T, u32) -> T */
|
||||
/* op >><T : ia_iu32, N : num>(vec<N, T>, vec<N, u32>) -> vec<N, T> */
|
||||
/* num overloads */ 2,
|
||||
/* overloads */ &kOverloads[431],
|
||||
},
|
||||
|
|
Loading…
Reference in New Issue