mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-21 18:59:21 +00:00
tint: const eval of extractBits
Bug: tint:1581 Change-Id: I56e9b7de9aef803eaf6304c122f40e5a0c4dce67 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/108203 Reviewed-by: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
committed by
Dawn LUCI CQ
parent
58eca19f33
commit
11f0c52bfb
@@ -1726,6 +1726,61 @@ ConstEval::Result ConstEval::countTrailingZeros(const sem::Type* ty,
|
||||
return TransformElements(builder, ty, transform, args[0]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::extractBits(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source) {
|
||||
auto transform = [&](const sem::Constant* c0) {
|
||||
auto create = [&](auto in_e) -> ImplResult {
|
||||
using NumberT = decltype(in_e);
|
||||
using T = UnwrapNumber<NumberT>;
|
||||
using UT = std::make_unsigned_t<T>;
|
||||
using NumberUT = Number<UT>;
|
||||
|
||||
// Read args that are always scalar
|
||||
NumberUT in_offset = args[1]->As<NumberUT>();
|
||||
NumberUT in_count = args[2]->As<NumberUT>();
|
||||
|
||||
constexpr UT w = sizeof(UT) * 8;
|
||||
if ((in_offset + in_count) > w) {
|
||||
AddError("'offset + 'count' must be less than or equal to the bit width of 'e'",
|
||||
source);
|
||||
return utils::Failure;
|
||||
}
|
||||
|
||||
// Cast all to unsigned
|
||||
UT e = static_cast<UT>(in_e);
|
||||
UT o = static_cast<UT>(in_offset);
|
||||
UT c = static_cast<UT>(in_count);
|
||||
|
||||
NumberT result;
|
||||
if (c == UT{0}) {
|
||||
// The result is 0 if c is 0
|
||||
result = NumberT{0};
|
||||
} else if (c == w) {
|
||||
// The result is e if c is w
|
||||
result = NumberT{e};
|
||||
} else {
|
||||
// Otherwise, bits 0..c - 1 of the result are copied from bits o..o + c - 1 of e.
|
||||
UT src_mask = ((UT{1} << c) - UT{1}) << o;
|
||||
UT r = (e & src_mask) >> o;
|
||||
if constexpr (IsSignedIntegral<NumberT>) {
|
||||
// Other bits of the result are the same as bit c - 1 of the result.
|
||||
// Only need to set other bits if bit at c - 1 of result is 1
|
||||
if ((r & (UT{1} << (c - UT{1}))) != UT{0}) {
|
||||
UT dst_mask = src_mask >> o;
|
||||
r = r | (~UT{0} & ~dst_mask);
|
||||
}
|
||||
}
|
||||
|
||||
result = NumberT{r};
|
||||
}
|
||||
return CreateElement(builder, c0->Type(), result);
|
||||
};
|
||||
return Dispatch_iu32(create, c0);
|
||||
};
|
||||
return TransformElements(builder, ty, transform, args[0]);
|
||||
}
|
||||
|
||||
ConstEval::Result ConstEval::firstLeadingBit(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source&) {
|
||||
|
||||
@@ -476,6 +476,15 @@ class ConstEval {
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// extractBits builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
/// @param source the source location of the conversion
|
||||
/// @return the result value, or null if the value cannot be calculated
|
||||
Result extractBits(const sem::Type* ty,
|
||||
utils::VectorRef<const sem::Constant*> args,
|
||||
const Source& source);
|
||||
|
||||
/// firstLeadingBit builtin
|
||||
/// @param ty the expression type
|
||||
/// @param args the input arguments
|
||||
|
||||
@@ -830,6 +830,109 @@ INSTANTIATE_TEST_SUITE_P(InsertBits,
|
||||
std::make_tuple(1000, 1000), //
|
||||
std::make_tuple(u32::Highest(), u32::Highest())));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> ExtractBitsCases() {
|
||||
using UT = Number<std::make_unsigned_t<UnwrapNumber<T>>>;
|
||||
|
||||
// If T is signed, fills most significant bits of `val` with 1s
|
||||
auto set_msbs_if_signed = [](T val) {
|
||||
if constexpr (IsSignedIntegral<T>) {
|
||||
T result = T(~0);
|
||||
for (size_t b = 0; val; ++b) {
|
||||
if ((val & 1) == 0) {
|
||||
result = result & ~(1 << b); // Clear bit b
|
||||
}
|
||||
val = val >> 1;
|
||||
}
|
||||
return result;
|
||||
} else {
|
||||
return val;
|
||||
}
|
||||
};
|
||||
|
||||
auto e = T(0b10100011110001011010001111000101);
|
||||
auto f = T(0b01010101010101010101010101010101);
|
||||
auto g = T(0b11111010001111000101101000111100);
|
||||
|
||||
auto r = std::vector<Case>{
|
||||
// args: e, offset, count
|
||||
|
||||
// If count is 0, result is 0
|
||||
C({e, UT(0), UT(0)}, T(0)), //
|
||||
C({e, UT(1), UT(0)}, T(0)), //
|
||||
C({e, UT(2), UT(0)}, T(0)), //
|
||||
C({e, UT(3), UT(0)}, T(0)),
|
||||
// ...
|
||||
C({e, UT(29), UT(0)}, T(0)), //
|
||||
C({e, UT(30), UT(0)}, T(0)), //
|
||||
C({e, UT(31), UT(0)}, T(0)),
|
||||
|
||||
// Extract at offset 0, varying counts
|
||||
C({e, UT(0), UT(1)}, set_msbs_if_signed(T(0b1))), //
|
||||
C({e, UT(0), UT(2)}, T(0b01)), //
|
||||
C({e, UT(0), UT(3)}, set_msbs_if_signed(T(0b101))), //
|
||||
C({e, UT(0), UT(4)}, T(0b0101)), //
|
||||
C({e, UT(0), UT(5)}, T(0b00101)), //
|
||||
C({e, UT(0), UT(6)}, T(0b000101)), //
|
||||
// ...
|
||||
C({e, UT(0), UT(28)}, T(0b0011110001011010001111000101)), //
|
||||
C({e, UT(0), UT(29)}, T(0b00011110001011010001111000101)), //
|
||||
C({e, UT(0), UT(30)}, set_msbs_if_signed(T(0b100011110001011010001111000101))), //
|
||||
C({e, UT(0), UT(31)}, T(0b0100011110001011010001111000101)), //
|
||||
C({e, UT(0), UT(32)}, T(0b10100011110001011010001111000101)), //
|
||||
|
||||
// Extract at varying offsets and counts
|
||||
C({e, UT(0), UT(1)}, set_msbs_if_signed(T(0b1))), //
|
||||
C({e, UT(31), UT(1)}, set_msbs_if_signed(T(0b1))), //
|
||||
C({e, UT(3), UT(5)}, set_msbs_if_signed(T(0b11000))), //
|
||||
C({e, UT(4), UT(7)}, T(0b0111100)), //
|
||||
C({e, UT(10), UT(16)}, set_msbs_if_signed(T(0b1111000101101000))), //
|
||||
C({e, UT(10), UT(22)}, set_msbs_if_signed(T(0b1010001111000101101000))),
|
||||
|
||||
// Vector tests
|
||||
C({Vec(e, f, g), //
|
||||
Val(UT(5)), Val(UT(8))}, //
|
||||
Vec(T(0b00011110), //
|
||||
set_msbs_if_signed(T(0b10101010)), //
|
||||
set_msbs_if_signed(T(0b11010001)))),
|
||||
};
|
||||
|
||||
return r;
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
ExtractBits,
|
||||
ResolverConstEvalBuiltinTest,
|
||||
testing::Combine(testing::Values(sem::BuiltinType::kExtractBits),
|
||||
testing::ValuesIn(Concat(ExtractBitsCases<i32>(), //
|
||||
ExtractBitsCases<u32>()))));
|
||||
|
||||
using ResolverConstEvalBuiltinTest_ExtractBits_InvalidOffsetAndCount =
|
||||
ResolverTestWithParam<std::tuple<size_t, size_t>>;
|
||||
TEST_P(ResolverConstEvalBuiltinTest_ExtractBits_InvalidOffsetAndCount, Test) {
|
||||
auto& p = GetParam();
|
||||
auto* expr = Call(Source{{12, 24}}, sem::str(sem::BuiltinType::kExtractBits), Expr(1_u),
|
||||
Expr(u32(std::get<0>(p))), Expr(u32(std::get<1>(p))));
|
||||
GlobalConst("C", expr);
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(),
|
||||
"12:24 error: 'offset + 'count' must be less than or equal to the bit width of 'e'");
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(ExtractBits,
|
||||
ResolverConstEvalBuiltinTest_ExtractBits_InvalidOffsetAndCount,
|
||||
testing::Values( //
|
||||
std::make_tuple(33, 0), //
|
||||
std::make_tuple(34, 0), //
|
||||
std::make_tuple(1000, 0), //
|
||||
std::make_tuple(u32::Highest(), 0), //
|
||||
std::make_tuple(0, 33), //
|
||||
std::make_tuple(0, 34), //
|
||||
std::make_tuple(0, 1000), //
|
||||
std::make_tuple(0, u32::Highest()), //
|
||||
std::make_tuple(33, 33), //
|
||||
std::make_tuple(34, 34), //
|
||||
std::make_tuple(1000, 1000), //
|
||||
std::make_tuple(u32::Highest(), u32::Highest())));
|
||||
|
||||
template <typename T>
|
||||
std::vector<Case> SaturateCases() {
|
||||
return {
|
||||
|
||||
@@ -12230,7 +12230,7 @@ constexpr OverloadInfo kOverloads[] = {
|
||||
/* parameters */ &kParameters[594],
|
||||
/* return matcher indices */ &kMatcherIndices[1],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::extractBits,
|
||||
},
|
||||
{
|
||||
/* [325] */
|
||||
@@ -12242,7 +12242,7 @@ constexpr OverloadInfo kOverloads[] = {
|
||||
/* parameters */ &kParameters[459],
|
||||
/* return matcher indices */ &kMatcherIndices[30],
|
||||
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
|
||||
/* const eval */ nullptr,
|
||||
/* const eval */ &ConstEval::extractBits,
|
||||
},
|
||||
{
|
||||
/* [326] */
|
||||
|
||||
Reference in New Issue
Block a user