From bfb5fd794c13e846e40a4cb89e7dc20fbc1275a9 Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Tue, 31 May 2022 21:23:29 +0000 Subject: [PATCH] tint/sem: Add more helpers to Constant Many backends can produce cleaner code if all the elements are zero or the same value. Bug: tint:1504 Change-Id: Iff3227884473b0be42395e4a637a7fe0b7a1b238 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91966 Reviewed-by: Antonio Maiorano Commit-Queue: Ben Clayton --- src/tint/sem/constant.cc | 32 ++++++++++++++++++++++++++--- src/tint/sem/constant.h | 11 ++++++++++ src/tint/sem/constant_test.cc | 38 +++++++++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 3 deletions(-) diff --git a/src/tint/sem/constant.cc b/src/tint/sem/constant.cc index 1fa83f5046..a2cfe8fa75 100644 --- a/src/tint/sem/constant.cc +++ b/src/tint/sem/constant.cc @@ -45,9 +45,9 @@ Constant& Constant::operator=(const Constant& rhs) = default; bool Constant::AnyZero() const { return WithElements([&](auto&& vec) { - for (auto scalar : vec) { - using T = std::remove_reference_t; - if (scalar == T(0)) { + using T = typename std::decay_t::value_type; + for (auto el : vec) { + if (el == T(0)) { return true; } } @@ -55,6 +55,32 @@ bool Constant::AnyZero() const { }); } +bool Constant::AllZero() const { + return WithElements([&](auto&& vec) { + using T = typename std::decay_t::value_type; + for (auto el : vec) { + if (el != T(0)) { + return false; + } + } + return true; + }); +} + +bool Constant::AllEqual(size_t start, size_t end) const { + return WithElements([&](auto&& vec) { + if (!vec.empty()) { + auto value = vec[start]; + for (size_t i = start + 1; i < end; i++) { + if (vec[i] != value) { + return false; + } + } + } + return true; + }); +} + const Type* Constant::CheckElemType(const sem::Type* ty, size_t num_elements) { diag::List diag; if (ty->is_abstract_or_scalar() || ty->IsAnyOf()) { diff --git a/src/tint/sem/constant.h b/src/tint/sem/constant.h index 43109f7ccd..c99b97920e 100644 --- a/src/tint/sem/constant.h +++ b/src/tint/sem/constant.h @@ -132,6 +132,17 @@ class Constant { /// @returns true if any element is zero bool AnyZero() const; + /// @returns true if all elements are zero + bool AllZero() const; + + /// @returns true if all elements are the same value + bool AllEqual() const { return AllEqual(0, ElementCount()); } + + /// @param start the first element index + /// @param end one past the last element index + /// @returns true if all elements between `[start, end)` are the same value + bool AllEqual(size_t start, size_t end) const; + /// @param index the index of the element /// @return the element at `index`, which must be of type `T`. template diff --git a/src/tint/sem/constant_test.cc b/src/tint/sem/constant_test.cc index 345ebd807b..6ad3cd6491 100644 --- a/src/tint/sem/constant_test.cc +++ b/src/tint/sem/constant_test.cc @@ -195,5 +195,43 @@ TEST_F(ConstantTest, Element_mat2x3_f16) { EXPECT_EQ(c.ElementCount(), 6u); } +TEST_F(ConstantTest, AnyZero) { + auto* vec3_ai = create(create(), 3u); + EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AnyZero(), false); + EXPECT_EQ(Constant(vec3_ai, {0_a, 2_a, 3_a}).AnyZero(), true); + EXPECT_EQ(Constant(vec3_ai, {1_a, 0_a, 3_a}).AnyZero(), true); + EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 0_a}).AnyZero(), true); + EXPECT_EQ(Constant(vec3_ai, {0_a, 0_a, 0_a}).AnyZero(), true); +} + +TEST_F(ConstantTest, AllZero) { + auto* vec3_ai = create(create(), 3u); + EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AllZero(), false); + EXPECT_EQ(Constant(vec3_ai, {0_a, 2_a, 3_a}).AllZero(), false); + EXPECT_EQ(Constant(vec3_ai, {1_a, 0_a, 3_a}).AllZero(), false); + EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 0_a}).AllZero(), false); + EXPECT_EQ(Constant(vec3_ai, {0_a, 0_a, 0_a}).AllZero(), true); +} + +TEST_F(ConstantTest, AllEqual) { + auto* vec3_ai = create(create(), 3u); + EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AllEqual(), false); + EXPECT_EQ(Constant(vec3_ai, {1_a, 1_a, 3_a}).AllEqual(), false); + EXPECT_EQ(Constant(vec3_ai, {1_a, 3_a, 3_a}).AllEqual(), false); + EXPECT_EQ(Constant(vec3_ai, {1_a, 1_a, 1_a}).AllEqual(), true); + EXPECT_EQ(Constant(vec3_ai, {2_a, 2_a, 2_a}).AllEqual(), true); + EXPECT_EQ(Constant(vec3_ai, {3_a, 3_a, 3_a}).AllEqual(), true); +} + +TEST_F(ConstantTest, AllEqualRange) { + auto* vec3_ai = create(create(), 3u); + EXPECT_EQ(Constant(vec3_ai, {1_a, 2_a, 3_a}).AllEqual(1, 3), false); + EXPECT_EQ(Constant(vec3_ai, {1_a, 1_a, 3_a}).AllEqual(1, 3), false); + EXPECT_EQ(Constant(vec3_ai, {1_a, 3_a, 3_a}).AllEqual(1, 3), true); + EXPECT_EQ(Constant(vec3_ai, {1_a, 1_a, 1_a}).AllEqual(1, 3), true); + EXPECT_EQ(Constant(vec3_ai, {2_a, 2_a, 2_a}).AllEqual(1, 3), true); + EXPECT_EQ(Constant(vec3_ai, {2_a, 2_a, 3_a}).AllEqual(1, 3), false); +} + } // namespace } // namespace tint::sem