tint: const eval of binary left shift

Bug: tint:1581
Change-Id: I8c1b01bcae2a205e712b8004573cc26a3366785a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/103061
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
Antonio Maiorano 2022-09-23 21:58:29 +00:00 committed by Dawn LUCI CQ
parent b46c6d7c1d
commit 5f33facbc1
7 changed files with 4400 additions and 4091 deletions

View File

@ -956,8 +956,10 @@ op || (bool, bool) -> bool
@const op >= <T: fia_fiu32_f16>(T, T) -> bool @const op >= <T: fia_fiu32_f16>(T, T) -> bool
@const op >= <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool> @const op >= <T: fiu32_f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
op << <T: iu32>(T, u32) -> T @const op << <T: iu32>(T, u32) -> T
op << <T: iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T> @const op << <T: iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T>
@const op << <T: ia>(T, ia) -> T
@const op << <T: ia, N: num> (vec<N, T>, vec<N, ia>) -> vec<N, T>
op >> <T: iu32>(T, u32) -> T op >> <T: iu32>(T, u32) -> T
op >> <T: iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T> op >> <T: iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T>

View File

@ -95,6 +95,10 @@ constexpr bool IsUnsignedIntegral =
template <typename T> template <typename T>
constexpr bool IsNumeric = IsIntegral<T> || IsFloatingPoint<T>; constexpr bool IsNumeric = IsIntegral<T> || IsFloatingPoint<T>;
/// Returns the bit width of T
template <typename T>
constexpr size_t BitWidth = sizeof(UnwrapNumber<T>) * 8;
/// NumberBase is a CRTP base class for Number<T> /// NumberBase is a CRTP base class for Number<T>
template <typename NumberT> template <typename NumberT>
struct NumberBase { struct NumberBase {
@ -259,6 +263,10 @@ using f32 = Number<float>;
/// However since C++ don't have native binary16 type, the value is stored as float. /// However since C++ don't have native binary16 type, the value is stored as float.
using f16 = Number<detail::NumberKindF16>; using f16 = Number<detail::NumberKindF16>;
/// True iff T is an abstract number type
template <typename T>
constexpr bool IsAbstract = std::is_same_v<T, AInt> || std::is_same_v<T, AFloat>;
/// @returns the friendly name of Number type T /// @returns the friendly name of Number type T
template <typename T, traits::EnableIf<IsNumber<T>>* = nullptr> template <typename T, traits::EnableIf<IsNumber<T>>* = nullptr>
const char* FriendlyName() { const char* FriendlyName() {

View File

@ -2125,6 +2125,17 @@ class ProgramBuilder {
ast::BinaryOp::kShiftLeft, Expr(std::forward<LHS>(lhs)), Expr(std::forward<RHS>(rhs))); ast::BinaryOp::kShiftLeft, Expr(std::forward<LHS>(lhs)), Expr(std::forward<RHS>(rhs)));
} }
/// @param source the source information
/// @param lhs the left hand argument to the bit shift left operation
/// @param rhs the right hand argument to the bit shift left operation
/// @returns a `ast::BinaryExpression` bit shifting left `lhs` by `rhs`
template <typename LHS, typename RHS>
const ast::BinaryExpression* Shl(const Source& source, LHS&& lhs, RHS&& rhs) {
return create<ast::BinaryExpression>(source, ast::BinaryOp::kShiftLeft,
Expr(std::forward<LHS>(lhs)),
Expr(std::forward<RHS>(rhs)));
}
/// @param lhs the left hand argument to the xor operation /// @param lhs the left hand argument to the xor operation
/// @param rhs the right hand argument to the xor operation /// @param rhs the right hand argument to the xor operation
/// @returns a `ast::BinaryExpression` bitwise xor-ing `lhs` and `rhs` /// @returns a `ast::BinaryExpression` bitwise xor-ing `lhs` and `rhs`

View File

@ -244,7 +244,7 @@ struct Element : ImplConstant {
// Conversion success // Conversion success
return builder.create<Element<TO>>(target_ty, conv.Get()); return builder.create<Element<TO>>(target_ty, conv.Get());
// --- Below this point are the failure cases --- // --- Below this point are the failure cases ---
} else if constexpr (std::is_same_v<T, AInt> || std::is_same_v<T, AFloat>) { } else if constexpr (IsAbstract<T>) {
// [abstract-numeric -> x] - materialization failure // [abstract-numeric -> x] - materialization failure
std::stringstream ss; std::stringstream ss;
ss << "value " << value << " cannot be represented as "; ss << "value " << value << " cannot be represented as ";
@ -576,7 +576,7 @@ ConstEval::ConstEval(ProgramBuilder& b) : builder(b) {}
template <typename NumberT> template <typename NumberT>
utils::Result<NumberT> ConstEval::Add(NumberT a, NumberT b) { utils::Result<NumberT> ConstEval::Add(NumberT a, NumberT b) {
NumberT result; NumberT result;
if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) { if constexpr (IsAbstract<NumberT>) {
// Check for over/underflow for abstract values // Check for over/underflow for abstract values
if (auto r = CheckedAdd(a, b)) { if (auto r = CheckedAdd(a, b)) {
result = r->value; result = r->value;
@ -604,7 +604,7 @@ template <typename NumberT>
utils::Result<NumberT> ConstEval::Mul(NumberT a, NumberT b) { utils::Result<NumberT> ConstEval::Mul(NumberT a, NumberT b) {
using T = UnwrapNumber<NumberT>; using T = UnwrapNumber<NumberT>;
NumberT result; NumberT result;
if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) { if constexpr (IsAbstract<NumberT>) {
// Check for over/underflow for abstract values // Check for over/underflow for abstract values
if (auto r = CheckedMul(a, b)) { if (auto r = CheckedMul(a, b)) {
result = r->value; result = r->value;
@ -1029,7 +1029,7 @@ ConstEval::Result ConstEval::OpMinus(const sem::Type* ty,
auto create = [&](auto i, auto j) -> ImplResult { auto create = [&](auto i, auto j) -> ImplResult {
using NumberT = decltype(i); using NumberT = decltype(i);
NumberT result; NumberT result;
if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) { if constexpr (IsAbstract<NumberT>) {
// Check for over/underflow for abstract values // Check for over/underflow for abstract values
if (auto r = CheckedSub(i, j)) { if (auto r = CheckedSub(i, j)) {
result = r->value; result = r->value;
@ -1244,7 +1244,7 @@ ConstEval::Result ConstEval::OpDivide(const sem::Type* ty,
auto create = [&](auto i, auto j) -> ImplResult { auto create = [&](auto i, auto j) -> ImplResult {
using NumberT = decltype(i); using NumberT = decltype(i);
NumberT result; NumberT result;
if constexpr (std::is_same_v<NumberT, AInt> || std::is_same_v<NumberT, AFloat>) { if constexpr (IsAbstract<NumberT>) {
// Check for over/underflow for abstract values // Check for over/underflow for abstract values
if (auto r = CheckedDiv(i, j)) { if (auto r = CheckedDiv(i, j)) {
result = r->value; result = r->value;
@ -1416,6 +1416,80 @@ ConstEval::Result ConstEval::OpXor(const sem::Type* ty,
return r; return r;
} }
ConstEval::Result ConstEval::OpShiftLeft(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto transform = [&](const sem::Constant* c0, const sem::Constant* c1) {
auto create = [&](auto e1, auto e2) -> const ImplConstant* {
using NumberT = decltype(e1);
using T = UnwrapNumber<NumberT>;
using UT = std::make_unsigned_t<T>;
constexpr size_t bit_width = BitWidth<NumberT>;
UT e1u = static_cast<UT>(e1);
UT e2u = static_cast<UT>(e2);
if constexpr (IsAbstract<NumberT>) {
// NOTE: Concrete shift left requires an unsigned rhs, so this check only applies
// for abstracts.
if (e2 < 0) {
AddError("cannot shift left by a negative value", source);
return nullptr;
}
// The e2 + 1 most significant bits of e1 must have the same bit value, otherwise
// sign change (overflow) would occur.
// Check sign change only if e2 is less than bit width of e1. If e1 is larger
// than bit width, we check for non-representable value below.
if (e2u < bit_width) {
size_t must_match_msb = e2u + 1;
UT mask = ~UT{0} << (bit_width - must_match_msb);
if ((e1u & mask) != 0 && (e1u & mask) != mask) {
AddError("shift left operation results in sign change", source);
return nullptr;
}
} else {
// If shift value >= bit_width, then any non-zero value would overflow
if (e1 != 0) {
AddError(OverflowErrorMessage(e1, "<<", e2), source);
return nullptr;
}
}
} else {
if (static_cast<size_t>(e2) >= bit_width) {
// At shader/pipeline-creation time, it is an error to shift by the bit width of
// the lhs or greater.
// NOTE: At runtime, we shift by e2 % (bit width of e1).
AddError(
"shift left value must be less than the bit width of the lhs, which is " +
std::to_string(bit_width),
source);
return nullptr;
}
// The e2 + 1 most significant bits of e1 must have the same bit value, otherwise
// sign change (overflow) would occur.
size_t must_match_msb = e2u + 1;
UT mask = ~UT{0} << (bit_width - must_match_msb);
if ((e1u & mask) != 0 && (e1u & mask) != mask) {
AddError("shift left operation results in sign change", source);
return nullptr;
}
}
// Avoid UB by left shifting as unsigned value
auto result = static_cast<T>(static_cast<UT>(e1) << e2);
return CreateElement(builder, sem::Type::DeepestElementOf(ty), NumberT{result});
};
return Dispatch_ia_iu32(create, c0, c1);
};
auto r = TransformElements(builder, ty, transform, args[0], args[1]);
if (builder.Diagnostics().contains_errors()) {
return utils::Failure;
}
return r;
}
ConstEval::Result ConstEval::atan2(const sem::Type* ty, ConstEval::Result ConstEval::atan2(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source&) { const Source&) {

View File

@ -361,8 +361,17 @@ class ConstEval {
/// @param source the source location of the conversion /// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated /// @return the result value, or null if the value cannot be calculated
Result OpXor(const sem::Type* ty, Result OpXor(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args, utils::VectorRef<const sem::Constant*> args,
const Source& source); const Source& source);
/// Bitwise shift left operator '<<'
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result OpShiftLeft(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
//////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////
// Builtins // Builtins

View File

@ -2947,15 +2947,34 @@ Action ForEachElemPair(const sem::Constant* a, const sem::Constant* b, Func&& f)
return Action::kContinue; return Action::kContinue;
} }
template <typename T> template <typename NumberT>
struct BitValues { struct BitValues {
using UT = UnwrapNumber<T>; using T = UnwrapNumber<NumberT>;
static constexpr size_t NumBits = sizeof(UT) * 8; struct detail {
static inline const T All = T{~T{0}}; using UT = std::make_unsigned_t<T>;
static inline const T LeftMost = T{T{1} << (NumBits - 1u)}; static constexpr size_t NumBits = sizeof(T) * 8;
static inline const T AllButLeftMost = T{~LeftMost}; static constexpr T All = T{~T{0}};
static inline const T RightMost = T{1}; static constexpr T LeftMost = static_cast<T>(UT{1} << (NumBits - 1u));
static inline const T AllButRightMost = T{~RightMost}; static constexpr T AllButLeftMost = T{~LeftMost};
static constexpr T TwoLeftMost = static_cast<T>(UT{0b11} << (NumBits - 2u));
static constexpr T AllButTwoLeftMost = T{~TwoLeftMost};
static constexpr T RightMost = T{1};
static constexpr T AllButRightMost = T{~RightMost};
};
static inline const size_t NumBits = detail::NumBits;
static inline const NumberT All = NumberT{detail::All};
static inline const NumberT LeftMost = NumberT{detail::LeftMost};
static inline const NumberT AllButLeftMost = NumberT{detail::AllButLeftMost};
static inline const NumberT TwoLeftMost = NumberT{detail::TwoLeftMost};
static inline const NumberT AllButTwoLeftMost = NumberT{detail::AllButTwoLeftMost};
static inline const NumberT RightMost = NumberT{detail::RightMost};
static inline const NumberT AllButRightMost = NumberT{detail::AllButRightMost};
template <typename U, typename V>
static constexpr NumberT Lsh(U val, V shiftBy) {
return NumberT{T{val} << T{shiftBy}};
}
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
@ -3182,7 +3201,7 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) {
GlobalConst("C", expr); GlobalConst("C", expr);
auto* expected_expr = expected.Expr(*this); auto* expected_expr = expected.Expr(*this);
GlobalConst("E", expected_expr); GlobalConst("E", expected_expr);
EXPECT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr); auto* sem = Sem().Get(expr);
const sem::Constant* value = sem->ConstantValue(); const sem::Constant* value = sem->ConstantValue();
@ -3712,6 +3731,48 @@ INSTANTIATE_TEST_SUITE_P(Xor,
XorCases<i32>(), // XorCases<i32>(), //
XorCases<u32>())))); XorCases<u32>()))));
template <typename T>
std::vector<Case> ShiftLeftCases() {
// Shift type is u32 for non-abstract
using ST = std::conditional_t<IsAbstract<T>, T, u32>;
using B = BitValues<T>;
return {
C(T{0b1010}, ST{0}, T{0b0000'0000'1010}), //
C(T{0b1010}, ST{1}, T{0b0000'0001'0100}), //
C(T{0b1010}, ST{2}, T{0b0000'0010'1000}), //
C(T{0b1010}, ST{3}, T{0b0000'0101'0000}), //
C(T{0b1010}, ST{4}, T{0b0000'1010'0000}), //
C(T{0b1010}, ST{5}, T{0b0001'0100'0000}), //
C(T{0b1010}, ST{6}, T{0b0010'1000'0000}), //
C(T{0b1010}, ST{7}, T{0b0101'0000'0000}), //
C(T{0b1010}, ST{8}, T{0b1010'0000'0000}), //
C(B::LeftMost, ST{0}, B::LeftMost), //
C(B::TwoLeftMost, ST{1}, B::LeftMost), // No overflow
C(B::All, ST{1}, B::AllButRightMost), // No overflow
C(B::All, ST{B::NumBits - 1}, B::LeftMost), // No overflow
C(Vec(T{0b1010}, T{0b1010}), //
Vec(ST{0}, ST{1}), //
Vec(T{0b0000'0000'1010}, T{0b0000'0001'0100})), //
C(Vec(T{0b1010}, T{0b1010}), //
Vec(ST{2}, ST{3}), //
Vec(T{0b0000'0010'1000}, T{0b0000'0101'0000})), //
C(Vec(T{0b1010}, T{0b1010}), //
Vec(ST{4}, ST{5}), //
Vec(T{0b0000'1010'0000}, T{0b0001'0100'0000})), //
C(Vec(T{0b1010}, T{0b1010}, T{0b1010}), //
Vec(ST{6}, ST{7}, ST{8}), //
Vec(T{0b0010'1000'0000}, T{0b0101'0000'0000}, T{0b1010'0000'0000})), //
};
}
INSTANTIATE_TEST_SUITE_P(ShiftLeft,
ResolverConstEvalBinaryOpTest,
testing::Combine( //
testing::Values(ast::BinaryOp::kShiftLeft),
testing::ValuesIn(Concat(ShiftLeftCases<AInt>(), //
ShiftLeftCases<i32>(), //
ShiftLeftCases<u32>()))));
// Tests for errors on overflow/underflow of binary operations with abstract numbers // Tests for errors on overflow/underflow of binary operations with abstract numbers
struct OverflowCase { struct OverflowCase {
ast::BinaryOp op; ast::BinaryOp op;
@ -3829,7 +3890,25 @@ INSTANTIATE_TEST_SUITE_P(
OverflowCase{ast::BinaryOp::kDivide, Val(123_a), Val(-0_a)}, OverflowCase{ast::BinaryOp::kDivide, Val(123_a), Val(-0_a)},
// Most negative value divided by -1 // Most negative value divided by -1
OverflowCase{ast::BinaryOp::kDivide, Val(AInt::Lowest()), Val(-1_a)} OverflowCase{ast::BinaryOp::kDivide, Val(AInt::Lowest()), Val(-1_a)},
// ShiftLeft of AInts that result in values not representable as AInts.
// Note that for i32/u32, these would error because shift value is larger than 32.
OverflowCase{ast::BinaryOp::kShiftLeft, //
Val(AInt{BitValues<AInt>::All}), //
Val(AInt{BitValues<AInt>::NumBits})}, //
OverflowCase{ast::BinaryOp::kShiftLeft, //
Val(AInt{BitValues<AInt>::RightMost}), //
Val(AInt{BitValues<AInt>::NumBits})}, //
OverflowCase{ast::BinaryOp::kShiftLeft, //
Val(AInt{BitValues<AInt>::AllButLeftMost}), //
Val(AInt{BitValues<AInt>::NumBits})}, //
OverflowCase{ast::BinaryOp::kShiftLeft, //
Val(AInt{BitValues<AInt>::AllButLeftMost}), //
Val(AInt{BitValues<AInt>::NumBits + 1})}, //
OverflowCase{ast::BinaryOp::kShiftLeft, //
Val(AInt{BitValues<AInt>::AllButLeftMost}), //
Val(AInt{BitValues<AInt>::NumBits + 1000})}
)); ));
@ -3893,6 +3972,78 @@ INSTANTIATE_TEST_SUITE_P(
{2.2_a, 4.3_a}, // {2.2_a, 4.3_a}, //
{2.2_a, 4.3_a})) // {2.2_a, 4.3_a})) //
))); )));
// AInt left shift negative value -> error
TEST_F(ResolverConstEvalTest, BinaryAbstractShiftLeftByNegativeValue_Error) {
GlobalConst("c", Shl(Source{{1, 1}}, Expr(1_a), Expr(-1_a)));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "1:1 error: cannot shift left by a negative value");
}
// i32/u32 left shift by >= 32 -> error
using ResolverConstEvalShiftLeftConcreteGeqBitWidthError =
ResolverTestWithParam<std::tuple<Types, Types>>;
TEST_P(ResolverConstEvalShiftLeftConcreteGeqBitWidthError, Test) {
auto* lhs_expr =
std::visit([&](auto&& value) { return value.Expr(*this); }, std::get<0>(GetParam()));
auto* rhs_expr =
std::visit([&](auto&& value) { return value.Expr(*this); }, std::get<1>(GetParam()));
GlobalConst("c", Shl(Source{{1, 1}}, lhs_expr, rhs_expr));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
"1:1 error: shift left value must be less than the bit width of the lhs, which is 32");
}
INSTANTIATE_TEST_SUITE_P(Test,
ResolverConstEvalShiftLeftConcreteGeqBitWidthError,
testing::Values( //
std::make_tuple(Val(1_i), Val(32_u)), //
std::make_tuple(Val(1_i), Val(33_u)), //
std::make_tuple(Val(1_i), Val(34_u)), //
std::make_tuple(Val(1_i), Val(99999999_u)), //
std::make_tuple(Val(1_u), Val(32_u)), //
std::make_tuple(Val(1_u), Val(33_u)), //
std::make_tuple(Val(1_u), Val(34_u)), //
std::make_tuple(Val(1_u), Val(99999999_u)) //
));
// AInt left shift results in sign change error
using ResolverConstEvalShiftLeftSignChangeError = ResolverTestWithParam<std::tuple<Types, Types>>;
TEST_P(ResolverConstEvalShiftLeftSignChangeError, Test) {
auto* lhs_expr =
std::visit([&](auto&& value) { return value.Expr(*this); }, std::get<0>(GetParam()));
auto* rhs_expr =
std::visit([&](auto&& value) { return value.Expr(*this); }, std::get<1>(GetParam()));
GlobalConst("c", Shl(Source{{1, 1}}, lhs_expr, rhs_expr));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "1:1 error: shift left operation results in sign change");
}
template <typename T>
std::vector<std::tuple<Types, Types>> ShiftLeftSignChangeErrorCases() {
// Shift type is u32 for non-abstract
using ST = std::conditional_t<IsAbstract<T>, T, u32>;
using B = BitValues<T>;
return {
{Val(T{0b0001}), Val(ST{B::NumBits - 1})},
{Val(T{0b0010}), Val(ST{B::NumBits - 2})},
{Val(T{0b0100}), Val(ST{B::NumBits - 3})},
{Val(T{0b1000}), Val(ST{B::NumBits - 4})},
{Val(T{0b0011}), Val(ST{B::NumBits - 2})},
{Val(T{0b0110}), Val(ST{B::NumBits - 3})},
{Val(T{0b1100}), Val(ST{B::NumBits - 4})},
{Val(B::AllButLeftMost), Val(ST{1})},
{Val(B::AllButLeftMost), Val(ST{B::NumBits - 1})},
{Val(B::LeftMost), Val(ST{1})},
{Val(B::LeftMost), Val(ST{B::NumBits - 1})},
};
}
INSTANTIATE_TEST_SUITE_P(Test,
ResolverConstEvalShiftLeftSignChangeError,
testing::ValuesIn(Concat( //
ShiftLeftSignChangeErrorCases<AInt>(),
ShiftLeftSignChangeErrorCases<i32>(),
ShiftLeftSignChangeErrorCases<u32>())));
} // namespace binary_op } // namespace binary_op
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////

File diff suppressed because it is too large Load Diff