tint/sem: Add more helpers to Constant
Many backends can produce cleaner code if all the elements are zero or the same value. Bug: tint:1504 Change-Id: Iff3227884473b0be42395e4a637a7fe0b7a1b238 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91966 Reviewed-by: Antonio Maiorano <amaiorano@google.com> Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
6b52f9d1d4
commit
bfb5fd794c
|
@ -45,9 +45,9 @@ Constant& Constant::operator=(const Constant& rhs) = default;
|
|||
|
||||
bool Constant::AnyZero() const {
|
||||
return WithElements([&](auto&& vec) {
|
||||
for (auto scalar : vec) {
|
||||
using T = std::remove_reference_t<decltype(scalar)>;
|
||||
if (scalar == T(0)) {
|
||||
using T = typename std::decay_t<decltype(vec)>::value_type;
|
||||
for (auto el : vec) {
|
||||
if (el == T(0)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
@ -55,6 +55,32 @@ bool Constant::AnyZero() const {
|
|||
});
|
||||
}
|
||||
|
||||
bool Constant::AllZero() const {
|
||||
return WithElements([&](auto&& vec) {
|
||||
using T = typename std::decay_t<decltype(vec)>::value_type;
|
||||
for (auto el : vec) {
|
||||
if (el != T(0)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
bool Constant::AllEqual(size_t start, size_t end) const {
|
||||
return WithElements([&](auto&& vec) {
|
||||
if (!vec.empty()) {
|
||||
auto value = vec[start];
|
||||
for (size_t i = start + 1; i < end; i++) {
|
||||
if (vec[i] != value) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
const Type* Constant::CheckElemType(const sem::Type* ty, size_t num_elements) {
|
||||
diag::List diag;
|
||||
if (ty->is_abstract_or_scalar() || ty->IsAnyOf<Vector, Matrix>()) {
|
||||
|
|
|
@ -132,6 +132,17 @@ class Constant {
|
|||
/// @returns true if any element is zero
|
||||
bool AnyZero() const;
|
||||
|
||||
/// @returns true if all elements are zero
|
||||
bool AllZero() const;
|
||||
|
||||
/// @returns true if all elements are the same value
|
||||
bool AllEqual() const { return AllEqual(0, ElementCount()); }
|
||||
|
||||
/// @param start the first element index
|
||||
/// @param end one past the last element index
|
||||
/// @returns true if all elements between `[start, end)` are the same value
|
||||
bool AllEqual(size_t start, size_t end) const;
|
||||
|
||||
/// @param index the index of the element
|
||||
/// @return the element at `index`, which must be of type `T`.
|
||||
template <typename T>
|
||||
|
|
|
@ -195,5 +195,43 @@ TEST_F(ConstantTest, Element_mat2x3_f16) {
|
|||
EXPECT_EQ(c.ElementCount(), 6u);
|
||||
}
|
||||
|
||||
TEST_F(ConstantTest, AnyZero) {
|
||||
auto* vec3_ai = create<Vector>(create<AbstractInt>(), 3u);
|
||||
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AnyZero(), false);
|
||||
EXPECT_EQ(Constant(vec3_ai, {0_a, 2_a, 3_a}).AnyZero(), true);
|
||||
EXPECT_EQ(Constant(vec3_ai, {1_a, 0_a, 3_a}).AnyZero(), true);
|
||||
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 0_a}).AnyZero(), true);
|
||||
EXPECT_EQ(Constant(vec3_ai, {0_a, 0_a, 0_a}).AnyZero(), true);
|
||||
}
|
||||
|
||||
TEST_F(ConstantTest, AllZero) {
|
||||
auto* vec3_ai = create<Vector>(create<AbstractInt>(), 3u);
|
||||
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AllZero(), false);
|
||||
EXPECT_EQ(Constant(vec3_ai, {0_a, 2_a, 3_a}).AllZero(), false);
|
||||
EXPECT_EQ(Constant(vec3_ai, {1_a, 0_a, 3_a}).AllZero(), false);
|
||||
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 0_a}).AllZero(), false);
|
||||
EXPECT_EQ(Constant(vec3_ai, {0_a, 0_a, 0_a}).AllZero(), true);
|
||||
}
|
||||
|
||||
TEST_F(ConstantTest, AllEqual) {
|
||||
auto* vec3_ai = create<Vector>(create<AbstractInt>(), 3u);
|
||||
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AllEqual(), false);
|
||||
EXPECT_EQ(Constant(vec3_ai, {1_a, 1_a, 3_a}).AllEqual(), false);
|
||||
EXPECT_EQ(Constant(vec3_ai, {1_a, 3_a, 3_a}).AllEqual(), false);
|
||||
EXPECT_EQ(Constant(vec3_ai, {1_a, 1_a, 1_a}).AllEqual(), true);
|
||||
EXPECT_EQ(Constant(vec3_ai, {2_a, 2_a, 2_a}).AllEqual(), true);
|
||||
EXPECT_EQ(Constant(vec3_ai, {3_a, 3_a, 3_a}).AllEqual(), true);
|
||||
}
|
||||
|
||||
TEST_F(ConstantTest, AllEqualRange) {
|
||||
auto* vec3_ai = create<Vector>(create<AbstractInt>(), 3u);
|
||||
EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AllEqual(1, 3), false);
|
||||
EXPECT_EQ(Constant(vec3_ai, {1_a, 1_a, 3_a}).AllEqual(1, 3), false);
|
||||
EXPECT_EQ(Constant(vec3_ai, {1_a, 3_a, 3_a}).AllEqual(1, 3), true);
|
||||
EXPECT_EQ(Constant(vec3_ai, {1_a, 1_a, 1_a}).AllEqual(1, 3), true);
|
||||
EXPECT_EQ(Constant(vec3_ai, {2_a, 2_a, 2_a}).AllEqual(1, 3), true);
|
||||
EXPECT_EQ(Constant(vec3_ai, {2_a, 2_a, 3_a}).AllEqual(1, 3), false);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tint::sem
|
||||
|
|
Loading…
Reference in New Issue