tint/sem: Add more helpers to Constant

Many backends can produce cleaner code if all the elements are zero or
the same value.

Bug: tint:1504
Change-Id: Iff3227884473b0be42395e4a637a7fe0b7a1b238
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91966
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-05-31 21:23:29 +00:00 committed by Dawn LUCI CQ
parent 6b52f9d1d4
commit bfb5fd794c
3 changed files with 78 additions and 3 deletions

View File

@ -45,9 +45,9 @@ Constant& Constant::operator=(const Constant& rhs) = default;
bool Constant::AnyZero() const {
return WithElements([&](auto&& vec) {
for (auto scalar : vec) {
using T = std::remove_reference_t<decltype(scalar)>;
if (scalar == T(0)) {
using T = typename std::decay_t<decltype(vec)>::value_type;
for (auto el : vec) {
if (el == T(0)) {
return true;
}
}
@ -55,6 +55,32 @@ bool Constant::AnyZero() const {
});
}
bool Constant::AllZero() const {
return WithElements([&](auto&& vec) {
using T = typename std::decay_t<decltype(vec)>::value_type;
for (auto el : vec) {
if (el != T(0)) {
return false;
}
}
return true;
});
}
bool Constant::AllEqual(size_t start, size_t end) const {
return WithElements([&](auto&& vec) {
if (!vec.empty()) {
auto value = vec[start];
for (size_t i = start + 1; i < end; i++) {
if (vec[i] != value) {
return false;
}
}
}
return true;
});
}
const Type* Constant::CheckElemType(const sem::Type* ty, size_t num_elements) {
diag::List diag;
if (ty->is_abstract_or_scalar() || ty->IsAnyOf<Vector, Matrix>()) {

View File

@ -132,6 +132,17 @@ class Constant {
/// @returns true if any element is zero
bool AnyZero() const;
/// @returns true if all elements are zero
bool AllZero() const;
/// @returns true if all elements are the same value
bool AllEqual() const { return AllEqual(0, ElementCount()); }
/// @param start the first element index
/// @param end one past the last element index
/// @returns true if all elements between `[start, end)` are the same value
bool AllEqual(size_t start, size_t end) const;
/// @param index the index of the element
/// @return the element at `index`, which must be of type `T`.
template <typename T>

View File

@ -195,5 +195,43 @@ TEST_F(ConstantTest, Element_mat2x3_f16) {
EXPECT_EQ(c.ElementCount(), 6u);
}
TEST_F(ConstantTest, AnyZero) {
auto* vec3_ai = create<Vector>(create<AbstractInt>(), 3u);
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AnyZero(), false);
EXPECT_EQ(Constant(vec3_ai, {0_a, 2_a, 3_a}).AnyZero(), true);
EXPECT_EQ(Constant(vec3_ai, {1_a, 0_a, 3_a}).AnyZero(), true);
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 0_a}).AnyZero(), true);
EXPECT_EQ(Constant(vec3_ai, {0_a, 0_a, 0_a}).AnyZero(), true);
}
TEST_F(ConstantTest, AllZero) {
auto* vec3_ai = create<Vector>(create<AbstractInt>(), 3u);
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AllZero(), false);
EXPECT_EQ(Constant(vec3_ai, {0_a, 2_a, 3_a}).AllZero(), false);
EXPECT_EQ(Constant(vec3_ai, {1_a, 0_a, 3_a}).AllZero(), false);
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 0_a}).AllZero(), false);
EXPECT_EQ(Constant(vec3_ai, {0_a, 0_a, 0_a}).AllZero(), true);
}
TEST_F(ConstantTest, AllEqual) {
auto* vec3_ai = create<Vector>(create<AbstractInt>(), 3u);
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AllEqual(), false);
EXPECT_EQ(Constant(vec3_ai, {1_a, 1_a, 3_a}).AllEqual(), false);
EXPECT_EQ(Constant(vec3_ai, {1_a, 3_a, 3_a}).AllEqual(), false);
EXPECT_EQ(Constant(vec3_ai, {1_a, 1_a, 1_a}).AllEqual(), true);
EXPECT_EQ(Constant(vec3_ai, {2_a, 2_a, 2_a}).AllEqual(), true);
EXPECT_EQ(Constant(vec3_ai, {3_a, 3_a, 3_a}).AllEqual(), true);
}
TEST_F(ConstantTest, AllEqualRange) {
auto* vec3_ai = create<Vector>(create<AbstractInt>(), 3u);
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AllEqual(1, 3), false);
EXPECT_EQ(Constant(vec3_ai, {1_a, 1_a, 3_a}).AllEqual(1, 3), false);
EXPECT_EQ(Constant(vec3_ai, {1_a, 3_a, 3_a}).AllEqual(1, 3), true);
EXPECT_EQ(Constant(vec3_ai, {1_a, 1_a, 1_a}).AllEqual(1, 3), true);
EXPECT_EQ(Constant(vec3_ai, {2_a, 2_a, 2_a}).AllEqual(1, 3), true);
EXPECT_EQ(Constant(vec3_ai, {2_a, 2_a, 3_a}).AllEqual(1, 3), false);
}
} // namespace
} // namespace tint::sem