tint: const eval of firstLeadingBit

Bug: tint:1581
Change-Id: I33c87ced173938bcd16e00debdd5c6682b4a9426
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/107763
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
Antonio Maiorano
2022-10-31 20:34:36 +00:00
committed by Dawn LUCI CQ
parent 76c21c070b
commit 1abe52dc1c
46 changed files with 362 additions and 1584 deletions

View File

@@ -467,8 +467,8 @@ fn exp2<N: num, T: f32_f16>(vec<N, T>) -> vec<N, T>
fn extractBits<T: iu32>(T, u32, u32) -> T
fn extractBits<N: num, T: iu32>(vec<N, T>, u32, u32) -> vec<N, T>
fn faceForward<N: num, T: f32_f16>(vec<N, T>, vec<N, T>, vec<N, T>) -> vec<N, T>
fn firstLeadingBit<T: iu32>(T) -> T
fn firstLeadingBit<N: num, T: iu32>(vec<N, T>) -> vec<N, T>
@const fn firstLeadingBit<T: iu32>(T) -> T
@const fn firstLeadingBit<N: num, T: iu32>(vec<N, T>) -> vec<N, T>
fn firstTrailingBit<T: iu32>(T) -> T
fn firstTrailingBit<N: num, T: iu32>(vec<N, T>) -> vec<N, T>
fn floor<T: f32_f16>(T) -> T

View File

@@ -192,6 +192,23 @@ std::string OverflowErrorMessage(NumberT lhs, const char* op, NumberT rhs) {
return ss.str();
}
/// @returns the number of consecutive leading bits in `@p e` set to `@p bit_value_to_count`.
template <typename T>
auto CountLeadingBits(T e, T bit_value_to_count) -> std::make_unsigned_t<T> {
using UT = std::make_unsigned_t<T>;
constexpr UT kNumBits = sizeof(UT) * 8;
constexpr UT kLeftMost = UT{1} << (kNumBits - 1);
const UT b = bit_value_to_count == 0 ? UT{0} : kLeftMost;
auto v = static_cast<UT>(e);
auto count = UT{0};
while ((count < kNumBits) && ((v & kLeftMost) == b)) {
++count;
v <<= 1;
}
return count;
}
/// ImplConstant inherits from sem::Constant to add an private implementation method for conversion.
struct ImplConstant : public sem::Constant {
/// Convert attempts to convert the constant value to the given type. On error, Convert()
@@ -1639,17 +1656,7 @@ ConstEval::Result ConstEval::countLeadingZeros(const sem::Type* ty,
auto create = [&](auto e) {
using NumberT = decltype(e);
using T = UnwrapNumber<NumberT>;
using UT = std::make_unsigned_t<T>;
constexpr UT kNumBits = sizeof(UT) * 8;
constexpr UT kLeftMost = UT{1} << (kNumBits - 1);
auto v = static_cast<UT>(e);
auto count = UT{0};
while ((count < kNumBits) && ((v & kLeftMost) == 0)) {
++count;
v <<= 1;
}
auto count = CountLeadingBits(T{e}, T{0});
return CreateElement(builder, c0->Type(), NumberT(count));
};
return Dispatch_iu32(create, c0);
@@ -1706,6 +1713,50 @@ ConstEval::Result ConstEval::countTrailingZeros(const sem::Type* ty,
return TransformElements(builder, ty, transform, args[0]);
}
ConstEval::Result ConstEval::firstLeadingBit(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {
auto transform = [&](const sem::Constant* c0) {
auto create = [&](auto e) {
using NumberT = decltype(e);
using T = UnwrapNumber<NumberT>;
using UT = std::make_unsigned_t<T>;
constexpr UT kNumBits = sizeof(UT) * 8;
NumberT result;
if constexpr (IsUnsignedIntegral<T>) {
if (e == T{0}) {
// T(-1) if e is zero.
result = NumberT(static_cast<T>(-1));
} else {
// Otherwise the position of the most significant 1 bit in e.
static_assert(std::is_same_v<T, UT>);
UT count = CountLeadingBits(UT{e}, UT{0});
UT pos = kNumBits - count - 1;
result = NumberT(pos);
}
} else {
if (e == T{0} || e == T{-1}) {
// -1 if e is 0 or -1.
result = NumberT(-1);
} else {
// Otherwise the position of the most significant bit in e that is different
// from e's sign bit.
UT eu = static_cast<UT>(e);
UT sign_bit = eu >> (kNumBits - 1);
UT count = CountLeadingBits(eu, sign_bit);
UT pos = kNumBits - count - 1;
result = NumberT(pos);
}
}
return CreateElement(builder, c0->Type(), result);
};
return Dispatch_iu32(create, c0);
};
return TransformElements(builder, ty, transform, args[0]);
}
ConstEval::Result ConstEval::saturate(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source&) {

View File

@@ -467,6 +467,15 @@ class ConstEval {
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// firstLeadingBit builtin
/// @param ty the expression type
/// @param args the input arguments
/// @param source the source location of the conversion
/// @return the result value, or null if the value cannot be calculated
Result firstLeadingBit(const sem::Type* ty,
utils::VectorRef<const sem::Constant*> args,
const Source& source);
/// saturate builtin
/// @param ty the expression type
/// @param args the input arguments

View File

@@ -618,6 +618,68 @@ INSTANTIATE_TEST_SUITE_P( //
testing::ValuesIn(Concat(CountOneBitsCases<i32>(), //
CountOneBitsCases<u32>()))));
template <typename T>
std::vector<Case> FirstLeadingBitCases() {
using B = BitValues<T>;
auto r = std::vector<Case>{
// Both signed and unsigned return T(-1) for input 0
C({T(0)}, T(-1)),
C({B::Lsh(1, 30)}, T(30)), //
C({B::Lsh(1, 29)}, T(29)), //
C({B::Lsh(1, 28)}, T(28)),
//...
C({B::Lsh(1, 3)}, T(3)), //
C({B::Lsh(1, 2)}, T(2)), //
C({B::Lsh(1, 1)}, T(1)), //
C({B::Lsh(1, 0)}, T(0)),
C({T(0b0000'0000'0100'1000'1000'1000'0000'0000)}, T(22)),
C({T(0b0000'0000'0000'0100'1000'1000'0000'0000)}, T(18)),
// Vector tests
C({Vec(B::Lsh(1, 30), B::Lsh(1, 29), B::Lsh(1, 28))}, Vec(T(30), T(29), T(28))),
C({Vec(B::Lsh(1, 2), B::Lsh(1, 1), B::Lsh(1, 0))}, Vec(T(2), T(1), T(0))),
};
ConcatIntoIf<IsUnsignedIntegral<T>>( //
r, std::vector<Case>{
C({B::Lsh(1, 31)}, T(31)),
C({T(0b1111'1111'1111'1111'1111'1111'1111'1110)}, T(31)),
C({T(0b1111'1111'1111'1111'1111'1111'1111'1100)}, T(31)),
C({T(0b1111'1111'1111'1111'1111'1111'1111'1000)}, T(31)),
//...
C({T(0b1110'0000'0000'0000'0000'0000'0000'0000)}, T(31)),
C({T(0b1100'0000'0000'0000'0000'0000'0000'0000)}, T(31)),
C({T(0b1000'0000'0000'0000'0000'0000'0000'0000)}, T(31)),
});
ConcatIntoIf<IsSignedIntegral<T>>( //
r, std::vector<Case>{
// Signed returns -1 for input -1
C({T(-1)}, T(-1)),
C({B::Lsh(1, 31)}, T(30)),
C({T(0b1111'1111'1111'1111'1111'1111'1111'1110)}, T(0)),
C({T(0b1111'1111'1111'1111'1111'1111'1111'1100)}, T(1)),
C({T(0b1111'1111'1111'1111'1111'1111'1111'1000)}, T(2)),
//...
C({T(0b1110'0000'0000'0000'0000'0000'0000'0000)}, T(28)),
C({T(0b1100'0000'0000'0000'0000'0000'0000'0000)}, T(29)),
C({T(0b1000'0000'0000'0000'0000'0000'0000'0000)}, T(30)),
});
return r;
}
INSTANTIATE_TEST_SUITE_P( //
FirstLeadingBit,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kFirstLeadingBit),
testing::ValuesIn(Concat(FirstLeadingBitCases<i32>(), //
FirstLeadingBitCases<u32>()))));
template <typename T>
std::vector<Case> SaturateCases() {
return {

View File

@@ -12196,7 +12196,7 @@ constexpr OverloadInfo kOverloads[] = {
/* parameters */ &kParameters[910],
/* return matcher indices */ &kMatcherIndices[1],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::firstLeadingBit,
},
{
/* [323] */
@@ -12208,7 +12208,7 @@ constexpr OverloadInfo kOverloads[] = {
/* parameters */ &kParameters[909],
/* return matcher indices */ &kMatcherIndices[30],
/* flags */ OverloadFlags(OverloadFlag::kIsBuiltin, OverloadFlag::kSupportsVertexPipeline, OverloadFlag::kSupportsFragmentPipeline, OverloadFlag::kSupportsComputePipeline),
/* const eval */ nullptr,
/* const eval */ &ConstEval::firstLeadingBit,
},
{
/* [324] */