tint: const eval of extractBits

Bug: tint:1581
Change-Id: I56e9b7de9aef803eaf6304c122f40e5a0c4dce67
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/108203
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Antonio Maiorano
2022-11-03 13:03:08 +00:00
committed by Dawn LUCI CQ
parent 58eca19f33
commit 11f0c52bfb
46 changed files with 395 additions and 738 deletions

View File

@@ -1726,6 +1726,61 @@ ConstEval::Result ConstEval::countTrailingZeros(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0]);
}
ConstEval::Result ConstEval::extractBits(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source) {
auto transform = [&](const sem::Constant* c0) {
auto create = [&](auto in_e) -> ImplResult {
using NumberT = decltype(in_e);
using T = UnwrapNumber<NumberT>;
using UT = std::make_unsigned_t<T>;
using NumberUT = Number<UT>;
// Read args that are always scalar
NumberUT in_offset = args[1]->As<NumberUT>();
NumberUT in_count = args[2]->As<NumberUT>();
constexpr UT w = sizeof(UT) * 8;
if ((in_offset + in_count) > w) {
AddError("'offset + 'count' must be less than or equal to the bit width of 'e'",
source);
return utils::Failure;
}
// Cast all to unsigned
UT e = static_cast<UT>(in_e);
UT o = static_cast<UT>(in_offset);
UT c = static_cast<UT>(in_count);
NumberT result;
if (c == UT{0}) {
// The result is 0 if c is 0
result = NumberT{0};
} else if (c == w) {
// The result is e if c is w
result = NumberT{e};
} else {
// Otherwise, bits 0..c - 1 of the result are copied from bits o..o + c - 1 of e.
UT src_mask = ((UT{1} << c) - UT{1}) << o;
UT r = (e & src_mask) >> o;
if constexpr (IsSignedIntegral<NumberT>) {
// Other bits of the result are the same as bit c - 1 of the result.
// Only need to set other bits if bit at c - 1 of result is 1
if ((r & (UT{1} << (c - UT{1}))) != UT{0}) {
UT dst_mask = src_mask >> o;
r = r | (~UT{0} & ~dst_mask);
}
}
result = NumberT{r};
}
return CreateElement(builder, c0->Type(), result);
};
return Dispatch_iu32(create, c0);
};
return TransformElements(builder, ty, transform, args[0]);
}
ConstEval::Result ConstEval::firstLeadingBit(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {

View File

@@ -476,6 +476,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// extractBits builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result extractBits(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// firstLeadingBit builtin
/// @param ty the expression type
/// @param args the input arguments

View File

@@ -830,6 +830,109 @@ INSTANTIATE_TEST_SUITE_P(InsertBits,
std::make_tuple(1000, 1000), //
std::make_tuple(u32::Highest(), u32::Highest())));
template <typename T>
std::vector<Case> ExtractBitsCases() {
using UT = Number<std::make_unsigned_t<UnwrapNumber<T>>>;
// If T is signed, fills most significant bits of `val` with 1s
auto set_msbs_if_signed = [](T val) {
if constexpr (IsSignedIntegral<T>) {
T result = T(~0);
for (size_t b = 0; val; ++b) {
if ((val & 1) == 0) {
result = result & ~(1 << b); // Clear bit b
}
val = val >> 1;
}
return result;
} else {
return val;
}
};
auto e = T(0b10100011110001011010001111000101);
auto f = T(0b01010101010101010101010101010101);
auto g = T(0b11111010001111000101101000111100);
auto r = std::vector<Case>{
// args: e, offset, count
// If count is 0, result is 0
C({e, UT(0), UT(0)}, T(0)), //
C({e, UT(1), UT(0)}, T(0)), //
C({e, UT(2), UT(0)}, T(0)), //
C({e, UT(3), UT(0)}, T(0)),
// ...
C({e, UT(29), UT(0)}, T(0)), //
C({e, UT(30), UT(0)}, T(0)), //
C({e, UT(31), UT(0)}, T(0)),
// Extract at offset 0, varying counts
C({e, UT(0), UT(1)}, set_msbs_if_signed(T(0b1))), //
C({e, UT(0), UT(2)}, T(0b01)), //
C({e, UT(0), UT(3)}, set_msbs_if_signed(T(0b101))), //
C({e, UT(0), UT(4)}, T(0b0101)), //
C({e, UT(0), UT(5)}, T(0b00101)), //
C({e, UT(0), UT(6)}, T(0b000101)), //
// ...
C({e, UT(0), UT(28)}, T(0b0011110001011010001111000101)), //
C({e, UT(0), UT(29)}, T(0b00011110001011010001111000101)), //
C({e, UT(0), UT(30)}, set_msbs_if_signed(T(0b100011110001011010001111000101))), //
C({e, UT(0), UT(31)}, T(0b0100011110001011010001111000101)), //
C({e, UT(0), UT(32)}, T(0b10100011110001011010001111000101)), //
// Extract at varying offsets and counts
C({e, UT(0), UT(1)}, set_msbs_if_signed(T(0b1))), //
C({e, UT(31), UT(1)}, set_msbs_if_signed(T(0b1))), //
C({e, UT(3), UT(5)}, set_msbs_if_signed(T(0b11000))), //
C({e, UT(4), UT(7)}, T(0b0111100)), //
C({e, UT(10), UT(16)}, set_msbs_if_signed(T(0b1111000101101000))), //
C({e, UT(10), UT(22)}, set_msbs_if_signed(T(0b1010001111000101101000))),
// Vector tests
C({Vec(e, f, g), //
Val(UT(5)), Val(UT(8))}, //
Vec(T(0b00011110), //
set_msbs_if_signed(T(0b10101010)), //
set_msbs_if_signed(T(0b11010001)))),
};
return r;
}
INSTANTIATE_TEST_SUITE_P( //
ExtractBits,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kExtractBits),
testing::ValuesIn(Concat(ExtractBitsCases<i32>(), //
ExtractBitsCases<u32>()))));
using ResolverConstEvalBuiltinTest_ExtractBits_InvalidOffsetAndCount =
ResolverTestWithParam<std::tuple<size_t, size_t>>;
TEST_P(ResolverConstEvalBuiltinTest_ExtractBits_InvalidOffsetAndCount, Test) {
auto& p = GetParam();
auto* expr = Call(Source{{12, 24}}, sem::str(sem::BuiltinType::kExtractBits), Expr(1_u),
Expr(u32(std::get<0>(p))), Expr(u32(std::get<1>(p))));
GlobalConst("C", expr);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:24 error: 'offset + 'count' must be less than or equal to the bit width of 'e'");
}
INSTANTIATE_TEST_SUITE_P(ExtractBits,
ResolverConstEvalBuiltinTest_ExtractBits_InvalidOffsetAndCount,
testing::Values( //
std::make_tuple(33, 0), //
std::make_tuple(34, 0), //
std::make_tuple(1000, 0), //
std::make_tuple(u32::Highest(), 0), //
std::make_tuple(0, 33), //
std::make_tuple(0, 34), //
std::make_tuple(0, 1000), //
std::make_tuple(0, u32::Highest()), //
std::make_tuple(33, 33), //
std::make_tuple(34, 34), //
std::make_tuple(1000, 1000), //
std::make_tuple(u32::Highest(), u32::Highest())));
template <typename T>
std::vector<Case> SaturateCases() {
return {

View File

@@ -12230,7 +12230,7 @@ constexpr OverloadInfo kOverloads[] = {
/* parameters */ &kParameters[594],
/* return matcher indices */ &kMatcherIndices[1],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::extractBits,
},
{
/* [325] */
@@ -12242,7 +12242,7 @@ constexpr OverloadInfo kOverloads[] = {
/* parameters */ &kParameters[459],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::extractBits,
},
{
/* [326] */