diff --git a/src/tint/sem/constant.cc b/src/tint/sem/constant.cc index 1fa83f5046..a2cfe8fa75 100644 --- a/src/tint/sem/constant.cc +++ b/src/tint/sem/constant.cc @@ -45,9 +45,9 @@ Constant& Constant::operator=(const Constant& rhs) = default; bool Constant::AnyZero() const { return WithElements([&](auto&& vec) { - for (auto scalar : vec) { - using T = std::remove_reference_t; - if (scalar == T(0)) { + using T = typename std::decay_t::value_type; + for (auto el : vec) { + if (el == T(0)) { return true; } } @@ -55,6 +55,32 @@ bool Constant::AnyZero() const { }); } +bool Constant::AllZero() const { + return WithElements([&](auto&& vec) { + using T = typename std::decay_t::value_type; + for (auto el : vec) { + if (el != T(0)) { + return false; + } + } + return true; + }); +} + +bool Constant::AllEqual(size_t start, size_t end) const { + return WithElements([&](auto&& vec) { + if (!vec.empty()) { + auto value = vec[start]; + for (size_t i = start + 1; i < end; i++) { + if (vec[i] != value) { + return false; + } + } + } + return true; + }); +} + const Type* Constant::CheckElemType(const sem::Type* ty, size_t num_elements) { diag::List diag; if (ty->is_abstract_or_scalar() || ty->IsAnyOf()) { diff --git a/src/tint/sem/constant.h b/src/tint/sem/constant.h index 43109f7ccd..c99b97920e 100644 --- a/src/tint/sem/constant.h +++ b/src/tint/sem/constant.h @@ -132,6 +132,17 @@ class Constant { /// @returns true if any element is zero bool AnyZero() const; + /// @returns true if all elements are zero + bool AllZero() const; + + /// @returns true if all elements are the same value + bool AllEqual() const { return AllEqual(0, ElementCount()); } + + /// @param start the first element index + /// @param end one past the last element index + /// @returns true if all elements between `[start, end)` are the same value + bool AllEqual(size_t start, size_t end) const; + /// @param index the index of the element /// @return the element at `index`, which must be of type `T`. template diff --git a/src/tint/sem/constant_test.cc b/src/tint/sem/constant_test.cc index 345ebd807b..6ad3cd6491 100644 --- a/src/tint/sem/constant_test.cc +++ b/src/tint/sem/constant_test.cc @@ -195,5 +195,43 @@ TEST_F(ConstantTest, Element_mat2x3_f16) { EXPECT_EQ(c.ElementCount(), 6u); } +TEST_F(ConstantTest, AnyZero) { + auto* vec3_ai = create(create(), 3u); + EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AnyZero(), false); + EXPECT_EQ(Constant(vec3_ai, {0_a, 2_a, 3_a}).AnyZero(), true); + EXPECT_EQ(Constant(vec3_ai, {1_a, 0_a, 3_a}).AnyZero(), true); + EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 0_a}).AnyZero(), true); + EXPECT_EQ(Constant(vec3_ai, {0_a, 0_a, 0_a}).AnyZero(), true); +} + +TEST_F(ConstantTest, AllZero) { + auto* vec3_ai = create(create(), 3u); + EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AllZero(), false); + EXPECT_EQ(Constant(vec3_ai, {0_a, 2_a, 3_a}).AllZero(), false); + EXPECT_EQ(Constant(vec3_ai, {1_a, 0_a, 3_a}).AllZero(), false); + EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 0_a}).AllZero(), false); + EXPECT_EQ(Constant(vec3_ai, {0_a, 0_a, 0_a}).AllZero(), true); +} + +TEST_F(ConstantTest, AllEqual) { + auto* vec3_ai = create(create(), 3u); + EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AllEqual(), false); + EXPECT_EQ(Constant(vec3_ai, {1_a, 1_a, 3_a}).AllEqual(), false); + EXPECT_EQ(Constant(vec3_ai, {1_a, 3_a, 3_a}).AllEqual(), false); + EXPECT_EQ(Constant(vec3_ai, {1_a, 1_a, 1_a}).AllEqual(), true); + EXPECT_EQ(Constant(vec3_ai, {2_a, 2_a, 2_a}).AllEqual(), true); + EXPECT_EQ(Constant(vec3_ai, {3_a, 3_a, 3_a}).AllEqual(), true); +} + +TEST_F(ConstantTest, AllEqualRange) { + auto* vec3_ai = create(create(), 3u); + EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AllEqual(1, 3), false); + EXPECT_EQ(Constant(vec3_ai, {1_a, 1_a, 3_a}).AllEqual(1, 3), false); + EXPECT_EQ(Constant(vec3_ai, {1_a, 3_a, 3_a}).AllEqual(1, 3), true); + EXPECT_EQ(Constant(vec3_ai, {1_a, 1_a, 1_a}).AllEqual(1, 3), true); + EXPECT_EQ(Constant(vec3_ai, {2_a, 2_a, 2_a}).AllEqual(1, 3), true); + EXPECT_EQ(Constant(vec3_ai, {2_a, 2_a, 3_a}).AllEqual(1, 3), false); +} + } // namespace } // namespace tint::sem