diff --git a/src/utils/enum_set.h b/src/utils/enum_set.h index 46d93e118d..3a71f78852 100644 --- a/src/utils/enum_set.h +++ b/src/utils/enum_set.h @@ -17,6 +17,7 @@ #include #include +#include #include namespace tint { @@ -62,16 +63,79 @@ struct EnumSet { /// Equality operator /// @param rhs the other EnumSet to compare this to /// @return true if this EnumSet is equal to rhs - inline bool operator==(const EnumSet& rhs) const { return set == rhs.set; } + inline bool operator==(EnumSet rhs) const { return set == rhs.set; } /// Inequality operator /// @param rhs the other EnumSet to compare this to /// @return true if this EnumSet is not equal to rhs - inline bool operator!=(const EnumSet& rhs) const { return set != rhs.set; } + inline bool operator!=(EnumSet rhs) const { return set != rhs.set; } + + /// Equality operator + /// @param rhs the enum to compare this to + /// @return true if this EnumSet only contains `rhs` + inline bool operator==(Enum rhs) const { return set == Bit(rhs); } + + /// Inequality operator + /// @param rhs the enum to compare this to + /// @return false if this EnumSet only contains `rhs` + inline bool operator!=(Enum rhs) const { return set != Bit(rhs); } /// @return the underlying value for the EnumSet inline uint64_t Value() const { return set; } + /// Iterator provides read-only, unidirectional iterator over the enums of an + /// EnumSet. + class Iterator { + static constexpr int8_t kEnd = 63; + + Iterator(uint64_t s, int8_t b) : set(s), pos(b) {} + + /// Make the constructor accessible to the EnumSet. + friend struct EnumSet; + + public: + /// @return the Enum value at this point in the iterator + Enum operator*() const { return static_cast(pos); } + + /// Increments the iterator + /// @returns this iterator + Iterator& operator++() { + while (pos < kEnd) { + pos++; + if (set & (static_cast(1) << static_cast(pos))) { + break; + } + } + return *this; + } + + /// Equality operator + /// @param rhs the Iterator to compare this to + /// @return true if the two iterators are equal + bool operator==(const Iterator& rhs) const { + return set == rhs.set && pos == rhs.pos; + } + + /// Inequality operator + /// @param rhs the Iterator to compare this to + /// @return true if the two iterators are different + bool operator!=(const Iterator& rhs) const { return !(*this == rhs); } + + private: + const uint64_t set; + int8_t pos; + }; + + /// @returns an read-only iterator to the beginning of the set + Iterator begin() { + auto it = Iterator{set, -1}; + ++it; // Move to first set bit + return it; + } + + /// @returns an iterator to the beginning of the set + Iterator end() { return Iterator{set, Iterator::kEnd}; } + private: static constexpr uint64_t Bit(Enum value) { return static_cast(1) << static_cast(value); @@ -87,6 +151,24 @@ struct EnumSet { uint64_t set = 0; }; +/// Writes the EnumSet to the std::ostream. +/// @param out the std::ostream to write to +/// @param set the EnumSet to write +/// @returns out so calls can be chained +template +inline std::ostream& operator<<(std::ostream& out, EnumSet set) { + out << "{"; + bool first = true; + for (auto e : set) { + if (!first) { + out << ", "; + } + first = false; + out << e; + } + return out << "}"; +} + } // namespace utils } // namespace tint diff --git a/src/utils/enum_set_test.cc b/src/utils/enum_set_test.cc index 80c3dcfc2c..9a5186df1a 100644 --- a/src/utils/enum_set_test.cc +++ b/src/utils/enum_set_test.cc @@ -14,13 +14,30 @@ #include "src/utils/enum_set.h" -#include "gtest/gtest.h" +#include +#include + +#include "gmock/gmock.h" namespace tint { namespace utils { namespace { -enum class E { A, B, C }; +using ::testing::ElementsAre; + +enum class E { A = 0, B = 3, C = 7 }; + +std::ostream& operator<<(std::ostream& out, E e) { + switch (e) { + case E::A: + return out << "A"; + case E::B: + return out << "B"; + case E::C: + return out << "C"; + } + return out << "E(" << static_cast(e) << ")"; +} TEST(EnumSetTest, ConstructEmpty) { EnumSet set; @@ -59,16 +76,34 @@ TEST(EnumSetTest, Remove) { EXPECT_FALSE(set.Contains(E::C)); } -TEST(EnumSetTest, Equality) { +TEST(EnumSetTest, EqualitySet) { EXPECT_TRUE(EnumSet(E::A, E::B) == EnumSet(E::A, E::B)); EXPECT_FALSE(EnumSet(E::A, E::B) == EnumSet(E::A, E::C)); } -TEST(EnumSetTest, Inequality) { +TEST(EnumSetTest, InequalitySet) { EXPECT_FALSE(EnumSet(E::A, E::B) != EnumSet(E::A, E::B)); EXPECT_TRUE(EnumSet(E::A, E::B) != EnumSet(E::A, E::C)); } +TEST(EnumSetTest, EqualityEnum) { + EXPECT_TRUE(EnumSet(E::A) == E::A); + EXPECT_FALSE(EnumSet(E::B) == E::A); + EXPECT_FALSE(EnumSet(E::B) == E::C); + EXPECT_FALSE(EnumSet(E::A, E::B) == E::A); + EXPECT_FALSE(EnumSet(E::A, E::B) == E::B); + EXPECT_FALSE(EnumSet(E::A, E::B) == E::C); +} + +TEST(EnumSetTest, InequalityEnum) { + EXPECT_FALSE(EnumSet(E::A) != E::A); + EXPECT_TRUE(EnumSet(E::B) != E::A); + EXPECT_TRUE(EnumSet(E::B) != E::C); + EXPECT_TRUE(EnumSet(E::A, E::B) != E::A); + EXPECT_TRUE(EnumSet(E::A, E::B) != E::B); + EXPECT_TRUE(EnumSet(E::A, E::B) != E::C); +} + TEST(EnumSetTest, Hash) { auto hash = [&](EnumSet s) { return std::hash>()(s); }; EXPECT_EQ(hash(EnumSet(E::A, E::B)), hash(EnumSet(E::A, E::B))); @@ -78,9 +113,44 @@ TEST(EnumSetTest, Hash) { TEST(EnumSetTest, Value) { EXPECT_EQ(EnumSet().Value(), 0u); EXPECT_EQ(EnumSet(E::A).Value(), 1u); - EXPECT_EQ(EnumSet(E::B).Value(), 2u); - EXPECT_EQ(EnumSet(E::C).Value(), 4u); - EXPECT_EQ(EnumSet(E::A, E::C).Value(), 5u); + EXPECT_EQ(EnumSet(E::B).Value(), 8u); + EXPECT_EQ(EnumSet(E::C).Value(), 128u); + EXPECT_EQ(EnumSet(E::A, E::C).Value(), 129u); +} + +TEST(EnumSetTest, Iterator) { + auto set = EnumSet(E::C, E::A); + + auto it = set.begin(); + EXPECT_EQ(*it, E::A); + EXPECT_NE(it, set.end()); + ++it; + EXPECT_EQ(*it, E::C); + EXPECT_NE(it, set.end()); + ++it; + EXPECT_EQ(it, set.end()); +} + +TEST(EnumSetTest, IteratorEmpty) { + auto set = EnumSet(); + EXPECT_EQ(set.begin(), set.end()); +} + +TEST(EnumSetTest, Loop) { + auto set = EnumSet(E::C, E::A); + + std::vector seen; + for (auto e : set) { + seen.emplace_back(e); + } + + EXPECT_THAT(seen, ElementsAre(E::A, E::C)); +} + +TEST(EnumSetTest, Ostream) { + std::stringstream ss; + ss << EnumSet(E::A, E::C); + EXPECT_EQ(ss.str(), "{A, C}"); } } // namespace