diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index cf3ebae488..a8cbbf5efe 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -565,6 +565,8 @@ libtint_source_set("libtint_core_all_src") { "utils/debugger.h", "utils/enum_set.h", "utils/hash.h", + "utils/hashmap.h", + "utils/hashset.h", "utils/map.h", "utils/math.h", "utils/scoped_assignment.h", @@ -1235,6 +1237,8 @@ if (tint_build_unittests) { "utils/hash_test.cc", "utils/io/command_test.cc", "utils/io/tmpfile_test.cc", + "utils/hashmap_test.cc", + "utils/hashset_test.cc", "utils/map_test.cc", "utils/math_test.cc", "utils/result_test.cc", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index 2d253303e5..057a80df3d 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -475,6 +475,8 @@ set(TINT_LIB_SRCS utils/crc32.h utils/enum_set.h utils/hash.h + utils/hashmap.h + utils/hashset.h utils/map.h utils/math.h utils/scoped_assignment.h @@ -863,6 +865,8 @@ if(TINT_BUILD_TESTS) utils/hash_test.cc utils/io/command_test.cc utils/io/tmpfile_test.cc + utils/hashmap_test.cc + utils/hashset_test.cc utils/map_test.cc utils/math_test.cc utils/result_test.cc diff --git a/src/tint/utils/hashmap.h b/src/tint/utils/hashmap.h new file mode 100644 index 0000000000..81bebf2410 --- /dev/null +++ b/src/tint/utils/hashmap.h @@ -0,0 +1,305 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef SRC_TINT_UTILS_HASHMAP_H_ +#define SRC_TINT_UTILS_HASHMAP_H_ + +#include +#include +#include + +#include "src/tint/utils/hashset.h" + +namespace tint::utils { + +/// An unordered map that uses a robin-hood hashing algorithm. +/// +/// Hashmap internally wraps a Hashset for providing a store for key-value pairs. +/// +/// @see Hashset +template , + typename EQUAL = std::equal_to> +class Hashmap { + /// LazyCreator is a transient structure used to late-build the Entry::value, when inserted into + /// the underlying Hashset. + /// + /// LazyCreator holds a #key, and a #create function used to build the final Entry::value. + /// The #create function must be of the signature `V()`. + /// + /// LazyCreator can be compared to Entry and hashed, allowing them to be passed to + /// Hashset::Insert(). If the set does not contain an existing entry with #key, + /// Hashset::Insert() will construct a new Entry passing the rvalue LazyCreator as the + /// constructor argument, which in turn calls the #create function to generate the entry value. + /// + /// @see Entry + /// @see Hasher + /// @see Equality + template + struct LazyCreator { + /// The key of the entry to insert into the map + const K& key; + /// The value creation function + CREATE create; + }; + + /// Entry holds a key and value pair, and is used as the element type of the underlying Hashset. + /// Entries are compared and hashed using only the #key. + /// @see Hasher + /// @see Equality + struct Entry { + /// Constructor from a key and value pair + Entry(K k, V v) : key(std::move(k)), value(std::move(v)) {} + + /// Copy-constructor. + Entry(const Entry&) = default; + + /// Move-constructor. + Entry(Entry&&) = default; + + /// Constructor from a LazyCreator. + /// The constructor invokes the LazyCreator::create function to build the #value. 
+ /// @see LazyCreator + template + Entry(const LazyCreator& creator) // NOLINT(runtime/explicit) + : key(creator.key), value(creator.create()) {} + + /// Assignment operator from a LazyCreator. + /// The assignment invokes the LazyCreator::create function to build the #value. + /// @see LazyCreator + template + Entry& operator=(LazyCreator&& creator) { + key = std::move(creator.key); + value = creator.create(); + return *this; + } + + /// Copy-assignment operator + Entry& operator=(const Entry&) = default; + + /// Move-assignment operator + Entry& operator=(Entry&&) = default; + + K key; /// The map entry key + V value; /// The map entry value + }; + + /// Hash provider for the underlying Hashset. + /// Provides hash functions for an Entry, K or LazyCreator. + /// The hash functions only consider the key of an entry. + struct Hasher { + /// Calculates a hash from an Entry + size_t operator()(const Entry& entry) const { return HASH()(entry.key); } + /// Calculates a hash from a K + size_t operator()(const K& key) const { return HASH()(key); } + /// Calculates a hash from a LazyCreator + template + size_t operator()(const LazyCreator& lc) const { + return HASH()(lc.key); + } + }; + + /// Equality provider for the underlying Hashset. + /// Provides equality functions for an Entry, K or LazyCreator to an Entry. + /// The equality functions only consider the key for equality. + struct Equality { + /// Compares an Entry to an Entry for equality. + bool operator()(const Entry& a, const Entry& b) const { return EQUAL()(a.key, b.key); } + /// Compares a K to an Entry for equality. + bool operator()(const K& a, const Entry& b) const { return EQUAL()(a, b.key); } + /// Compares a LazyCreator to an Entry for equality. + template + bool operator()(const LazyCreator& lc, const Entry& b) const { + return EQUAL()(lc.key, b.key); + } + }; + + /// The underlying set + using Set = Hashset; + + public: + /// A Key and Value const-reference pair. 
+    struct KeyValue {
+        /// The key of the map entry
+        const K& key;
+        /// The value of the map entry
+        const V& value;
+
+        /// Compares this pair against another pair.
+        /// @param other the KeyValue to compare against
+        /// @returns true if both the key and the value compare equal to those of `other`.
+        bool operator==(const KeyValue& other) const {
+            if (!(key == other.key)) {
+                return false;
+            }
+            return value == other.value;
+        }
+    };
+
+    /// STL-style alias to KeyValue.
+    /// Used by gmock for the `ElementsAre` checks.
+    using value_type = KeyValue;
+
+    /// Forward iterator over the entries of the map.
+    /// Wraps an iterator of the underlying set.
+    class Iterator {
+      public:
+        /// @returns the key of the entry this iterator currently refers to
+        const K& Key() const { return it->key; }
+
+        /// @returns the value of the entry this iterator currently refers to
+        const V& Value() const { return it->value; }
+
+        /// Advances the iterator by advancing the wrapped set iterator.
+        /// @returns this iterator
+        Iterator& operator++() {
+            ++it;
+            return *this;
+        }
+
+        /// Equality operator
+        /// @param other the iterator to compare against
+        /// @returns true if the wrapped set iterators compare equal
+        bool operator==(const Iterator& other) const { return it == other.it; }
+
+        /// Inequality operator
+        /// @param other the iterator to compare against
+        /// @returns true if the wrapped set iterators compare unequal
+        bool operator!=(const Iterator& other) const { return it != other.it; }
+
+        /// @returns a KeyValue referencing the key and value of the current entry
+        KeyValue operator*() const { return KeyValue{it->key, it->value}; }
+
+      private:
+        /// Hashmap constructs iterators through the private constructor
+        friend class Hashmap;
+
+        /// Iterator type of the underlying set
+        using SetIterator = typename Set::Iterator;
+
+        /// Constructor
+        /// @param set_it the set iterator to wrap
+        explicit Iterator(SetIterator set_it) : it(set_it) {}
+
+        /// The wrapped set iterator
+        SetIterator it;
+    };
+
+    /// Removes all entries from the map.
+    void Clear() { set_.Clear(); }
+
+    /// Adds the key-value pair to the map, if the map does not already contain an entry with a key
+    /// equal to `key`.
+    /// @param key the entry's key to add to the map
+    /// @param value the entry's value to add to the map
+    /// @returns true if the entry was added to the map, false if there was already an entry in the
+    /// map with a key equal to `key`.
+    template <typename KEY, typename VALUE>
+    bool Add(KEY&& key, VALUE&& value) {
+        return set_.Add(Entry{std::forward<KEY>(key), std::forward<VALUE>(value)});
+    }
+
+    /// Adds the key-value pair to the map, replacing any entry with a key equal to `key`.
+    /// @param key the entry's key to add to the map
+    /// @param value the entry's value to add to the map
+    template <typename KEY, typename VALUE>
+    void Replace(KEY&& key, VALUE&& value) {
+        set_.Replace(Entry{std::forward<KEY>(key), std::forward<VALUE>(value)});
+    }
+
+    /// Searches for an entry with the given key value.
+    /// The map is not modified, so this can be called on a const map.
+    /// @param key the entry's key value to search for.
+    /// @returns a copy of the value of the entry with the given key, or no value if the entry was
+    /// not found.
+    std::optional<V> Get(const K& key) const {
+        if (auto* entry = set_.Find(key)) {
+            return entry->value;
+        }
+        return std::nullopt;
+    }
+
+    /// Searches for an entry with the given key value, adding and returning the result of
+    /// calling `create` if the entry was not found.
+    /// `create` is only called if the map does not already contain the key.
+    /// @param key the entry's key value to search for.
+    /// @param create the create function to call if the map does not contain the key.
+    /// @returns the value of the entry.
+    template <typename CREATE>
+    V& GetOrCreate(const K& key, CREATE&& create) {
+        LazyCreator<CREATE> lc{key, std::forward<CREATE>(create)};
+        auto res = set_.Add(std::move(lc));
+        return res.entry->value;
+    }
+
+    /// Searches for an entry with the given key value, adding and returning a newly created
+    /// zero-initialized value if the entry was not found.
+    /// @param key the entry's key value to search for.
+    /// @returns the value of the entry.
+    V& GetOrZero(const K& key) {
+        return GetOrCreate(key, [] { return V{}; });
+    }
+
+    /// Searches for an entry with the given key value.
+    /// @param key the entry's key value to search for.
+    /// @returns a pointer to the value of the entry with the given key, or nullptr if the entry
+    /// was not found.
+    /// @warning the pointer must not be used after the map is mutated
+    V* Find(const K& key) {
+        auto* entry = set_.Find(key);
+        return entry ? &entry->value : nullptr;
+    }
+
+    /// Searches for an entry with the given key value.
+    /// @param key the entry's key value to search for.
+    /// @returns a pointer to the value of the entry with the given key, or nullptr if the entry
+    /// was not found.
+    /// @warning the pointer must not be used after the map is mutated
+    const V* Find(const K& key) const {
+        auto* entry = set_.Find(key);
+        return entry ? &entry->value : nullptr;
+    }
+
+    /// Removes the entry with a key equal to `key` from the map.
+    /// @param key the entry key value to remove.
+    /// @returns true if an entry was removed.
+    bool Remove(const K& key) {
+        return set_.Remove(key);
+    }
+
+    /// Checks whether an entry exists in the map with a key equal to `key`.
+    /// @param key the entry key value to search for.
+    /// @returns true if the map contains an entry with the given key.
+    bool Contains(const K& key) const {
+        return set_.Contains(key);
+    }
+
+    /// Pre-allocates memory so that the map can hold at least `capacity` entries.
+    /// @param capacity the new capacity of the map.
+    void Reserve(size_t capacity) {
+        set_.Reserve(capacity);
+    }
+
+    /// @returns the number of entries in the map.
+    size_t Count() const {
+        return set_.Count();
+    }
+
+    /// @returns true if the map contains no entries.
+ bool IsEmpty() const { return set_.IsEmpty(); } + + /// @returns an iterator to the start of the map + Iterator begin() const { return Iterator{set_.begin()}; } + + /// @returns an iterator to the end of the map + Iterator end() const { return Iterator{set_.end()}; } + + private: + Set set_; +}; + +} // namespace tint::utils + +#endif // SRC_TINT_UTILS_HASHMAP_H_ diff --git a/src/tint/utils/hashmap_test.cc b/src/tint/utils/hashmap_test.cc new file mode 100644 index 0000000000..45e929b4c0 --- /dev/null +++ b/src/tint/utils/hashmap_test.cc @@ -0,0 +1,179 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "src/tint/utils/hashmap.h" + +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" + +namespace tint::utils { +namespace { + +constexpr std::array kPrimes{ + 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, + 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, + 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, + 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, + 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, +}; + +TEST(Hashmap, Empty) { + Hashmap map; + EXPECT_EQ(map.Count(), 0u); +} + +TEST(Hashmap, AddRemove) { + Hashmap map; + EXPECT_TRUE(map.Add("hello", "world")); + EXPECT_EQ(map.Get("hello"), "world"); + EXPECT_EQ(map.Count(), 1u); + EXPECT_TRUE(map.Contains("hello")); + EXPECT_FALSE(map.Contains("world")); + EXPECT_FALSE(map.Add("hello", "cat")); + EXPECT_EQ(map.Count(), 1u); + EXPECT_TRUE(map.Remove("hello")); + EXPECT_EQ(map.Count(), 0u); + EXPECT_FALSE(map.Contains("hello")); + EXPECT_FALSE(map.Contains("world")); +} + +TEST(Hashmap, ReplaceRemove) { + Hashmap map; + map.Replace("hello", "world"); + EXPECT_EQ(map.Get("hello"), "world"); + EXPECT_EQ(map.Count(), 1u); + EXPECT_TRUE(map.Contains("hello")); + EXPECT_FALSE(map.Contains("world")); + map.Replace("hello", "cat"); + EXPECT_EQ(map.Get("hello"), "cat"); + EXPECT_EQ(map.Count(), 1u); + EXPECT_TRUE(map.Remove("hello")); + EXPECT_EQ(map.Count(), 0u); + EXPECT_FALSE(map.Contains("hello")); + EXPECT_FALSE(map.Contains("world")); +} + +TEST(Hashmap, Iterator) { + using Map = Hashmap; + using KV = typename Map::KeyValue; + Map map; + map.Add(1, "one"); + map.Add(4, "four"); + map.Add(3, "three"); + map.Add(2, "two"); + EXPECT_THAT(map, testing::UnorderedElementsAre(KV{1, "one"}, KV{2, "two"}, KV{3, "three"}, + KV{4, "four"})); +} + +TEST(Hashmap, AddMany) { + Hashmap map; + for (size_t i = 0; i < kPrimes.size(); i++) { + int prime = kPrimes[i]; + 
ASSERT_TRUE(map.Add(prime, std::to_string(prime))) << "i: " << i; + ASSERT_FALSE(map.Add(prime, std::to_string(prime))) << "i: " << i; + ASSERT_EQ(map.Count(), i + 1); + } + ASSERT_EQ(map.Count(), kPrimes.size()); + for (int prime : kPrimes) { + ASSERT_TRUE(map.Contains(prime)) << prime; + ASSERT_EQ(map.Get(prime), std::to_string(prime)) << prime; + } +} + +TEST(Hashmap, GetOrCreate) { + Hashmap map; + EXPECT_EQ(map.GetOrCreate(0, [&] { return "zero"; }), "zero"); + EXPECT_EQ(map.Count(), 1u); + EXPECT_EQ(map.Get(0), "zero"); + + bool create_called = false; + EXPECT_EQ(map.GetOrCreate(0, + [&] { + create_called = true; + return "oh noes"; + }), + "zero"); + EXPECT_FALSE(create_called); + EXPECT_EQ(map.Count(), 1u); + EXPECT_EQ(map.Get(0), "zero"); + + EXPECT_EQ(map.GetOrCreate(1, [&] { return "one"; }), "one"); + EXPECT_EQ(map.Count(), 2u); + EXPECT_EQ(map.Get(1), "one"); +} + +TEST(Hashmap, Soak) { + std::mt19937 rnd; + std::unordered_map reference; + Hashmap map; + for (size_t i = 0; i < 1000000; i++) { + std::string key = std::to_string(rnd() & 64); + std::string value = "V" + key; + switch (rnd() % 7) { + case 0: { // Add + auto expected = reference.emplace(key, value).second; + EXPECT_EQ(map.Add(key, value), expected) << "i:" << i; + EXPECT_EQ(map.Get(key), value) << "i:" << i; + EXPECT_TRUE(map.Contains(key)) << "i:" << i; + break; + } + case 1: { // Replace + reference[key] = value; + map.Replace(key, value); + EXPECT_EQ(map.Get(key), value) << "i:" << i; + EXPECT_TRUE(map.Contains(key)) << "i:" << i; + break; + } + case 2: { // Remove + auto expected = reference.erase(key) != 0; + EXPECT_EQ(map.Remove(key), expected) << "i:" << i; + EXPECT_FALSE(map.Get(key).has_value()) << "i:" << i; + EXPECT_FALSE(map.Contains(key)) << "i:" << i; + break; + } + case 3: { // Contains + auto expected = reference.count(key) != 0; + EXPECT_EQ(map.Contains(key), expected) << "i:" << i; + break; + } + case 4: { // Get + if (reference.count(key) != 0) { + auto expected = 
reference[key]; + EXPECT_EQ(map.Get(key), expected) << "i:" << i; + } else { + EXPECT_FALSE(map.Get(key).has_value()) << "i:" << i; + } + break; + } + case 5: { // Copy / Move + Hashmap tmp(map); + map = std::move(tmp); + break; + } + case 6: { // Clear + reference.clear(); + map.Clear(); + break; + } + } + } +} + +} // namespace +} // namespace tint::utils diff --git a/src/tint/utils/hashset.h b/src/tint/utils/hashset.h new file mode 100644 index 0000000000..f88a304bf9 --- /dev/null +++ b/src/tint/utils/hashset.h @@ -0,0 +1,508 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TINT_UTILS_HASHSET_H_ +#define SRC_TINT_UTILS_HASHSET_H_ + +#include +#include +#include +#include +#include +#include + +#include "src/tint/debug.h" +#include "src/tint/utils/vector.h" + +namespace tint::utils { + +/// Action taken by Hashset::Insert() +enum class AddAction { + /// Insert() added a new entry to the Hashset + kAdded, + /// Insert() replaced an existing entry in the Hashset + kReplaced, + /// Insert() found an existing entry, which was not replaced. + kKeptExisting, +}; + +/// An unordered set that uses a robin-hood hashing algorithm. +/// @see the fantastic tutorial: https://programming.guide/robin-hood-hashing.html +template , typename EQUAL = std::equal_to> +class Hashset { + /// A slot is a single entry in the underlying vector. + /// A slot can either be empty or filled with a value. 
If the slot is empty, #hash and #distance + /// will be zero. + struct Slot { + template + bool Equals(size_t value_hash, const V& val) const { + return value_hash == hash && EQUAL()(val, value.value()); + } + + /// The slot value. If this does not contain a value, then the slot is vacant. + std::optional value; + /// The precomputed hash of value. + size_t hash = 0; + size_t distance = 0; + }; + + /// The target length of the underlying vector length in relation to the number of entries in + /// the set, expressed as a percentage. For example a value of `150` would mean there would be + /// at least 50% more slots than the number of set entries. + static constexpr size_t kRehashFactor = 150; + + /// @returns the target slot vector size to hold `n` set entries. + static constexpr size_t NumSlots(size_t count) { return (count * kRehashFactor) / 100; } + + /// The fixed-size slot vector length, based on N and kRehashFactor. + static constexpr size_t kNumFixedSlots = NumSlots(N); + + /// The minimum number of slots for the set. 
+    static constexpr size_t kMinSlots = std::max<size_t>(kNumFixedSlots, 4);
+
+  public:
+    /// Iterator for entries in the set
+    class Iterator {
+      public:
+        /// @returns a pointer to the value pointed to by this iterator
+        const T* operator->() const { return &current->value.value(); }
+
+        /// Increments the iterator, skipping over any empty slots.
+        /// Incrementing an iterator that is already at the end is a no-op.
+        /// @returns this iterator
+        Iterator& operator++() {
+            if (current == end) {
+                return *this;
+            }
+            ++current;
+            SkipToNextValue();
+            return *this;
+        }
+
+        /// Equality operator
+        /// @param other the other iterator to compare this iterator to
+        /// @returns true if this iterator is equal to other
+        bool operator==(const Iterator& other) const { return current == other.current; }
+
+        /// Inequality operator
+        /// @param other the other iterator to compare this iterator to
+        /// @returns true if this iterator is not equal to other
+        bool operator!=(const Iterator& other) const { return current != other.current; }
+
+        /// @returns a reference to the value at the iterator
+        const T& operator*() const { return current->value.value(); }
+
+      private:
+        /// Friend class
+        friend class Hashset;
+
+        /// Constructor
+        /// @param c the slot to start iterating from. Empty slots are skipped.
+        /// @param e one past the last slot in the set
+        Iterator(const Slot* c, const Slot* e) : current(c), end(e) { SkipToNextValue(); }
+
+        /// Moves the iterator forward, stopping at the next slot that is not empty.
+        void SkipToNextValue() {
+            while (current != end && !current->value.has_value()) {
+                ++current;
+            }
+        }
+
+        const Slot* current;  /// The slot the iterator is pointing to
+        const Slot* end;      /// One past the last slot in the set
+    };
+
+    /// Type of `T`.
+    using value_type = T;
+
+    /// Constructor
+    Hashset() { slots_.Resize(kMinSlots); }
+
+    /// Copy constructor
+    /// @param other the other Hashset to copy
+    Hashset(const Hashset& other) = default;
+
+    /// Move constructor
+    /// @param other the other Hashset to move
+    /// @note NOTE(review): the defaulted move copies `count_` and leaves `other.slots_` in whatever
+    /// state Vector's move produces. Confirm that a moved-from Hashset is only ever destructed or
+    /// assigned to.
+    Hashset(Hashset&& other) = default;
+
+    /// Destructor
+    ~Hashset() { Clear(); }
+
+    /// Copy-assignment operator
+    /// @param other the other Hashset to copy
+    /// @returns this so calls can be chained
+    Hashset& operator=(const Hashset& other) = default;
+
+    /// Move-assignment operator
+    /// @param other the other Hashset to move
+    /// @returns this so calls can be chained
+    Hashset& operator=(Hashset&& other) = default;
+
+    /// Removes all entries from the set, and resets the slot vector to the minimum number of slots.
+    void Clear() {
+        slots_.Clear();  // Destructs all entries
+        slots_.Resize(kMinSlots);
+        count_ = 0;
+    }
+
+    /// Result of Add() and Replace()
+    struct AddResult {
+        /// Whether the insert replaced or added a new entry to the set.
+        AddAction action = AddAction::kAdded;
+        /// A pointer to the inserted entry.
+        /// @warning do not modify this pointer in a way that would cause the equality or hash of
+        /// the entry to change. Doing this will corrupt the Hashset.
+        T* entry = nullptr;
+
+        /// @returns true if the entry was added to the set, or an existing entry was replaced.
+        operator bool() const { return action != AddAction::kKeptExisting; }
+    };
+
+    /// Adds a value to the set, if the set does not already contain an entry equal to `value`.
+    /// @param value the value to add to the set.
+    /// @returns An AddResult describing the result of the add
+    /// @warning do not modify the inserted entry in a way that would cause the equality or hash of
+    /// the entry to change. Doing this will corrupt the Hashset.
+    template <typename V>
+    AddResult Add(V&& value) {
+        return Put<PutMode::kAdd>(std::forward<V>(value));
+    }
+
+    /// Adds a value to the set, replacing any entry equal to `value`.
+    /// @param value the value to add to the set.
+    /// @returns An AddResult describing the result of the replace
+    template <typename V>
+    AddResult Replace(V&& value) {
+        return Put<PutMode::kReplace>(std::forward<V>(value));
+    }
+
+    /// Removes an entry from the set.
+    /// @param value the value to remove from the set.
+    /// @returns true if an entry was removed.
+    template <typename V>
+    bool Remove(const V& value) {
+        const auto [found, start] = IndexOf(value);
+        if (!found) {
+            return false;
+        }
+
+        // Shuffle the entries backwards until we either find a free slot, or a slot that has zero
+        // distance.
+        Slot* prev = nullptr;
+        Scan(start, [&](size_t, size_t index) {
+            auto& slot = slots_[index];
+            if (prev) {
+                // note: `distance == 0` also includes empty slots.
+                if (slot.distance == 0) {
+                    // Clear the previous slot, and stop shuffling.
+                    *prev = {};
+                    return Action::kStop;
+                } else {
+                    // Shuffle the slot backwards. It is now one slot closer to its target slot.
+                    prev->value = std::move(slot.value);
+                    prev->hash = slot.hash;
+                    prev->distance = slot.distance - 1;
+                }
+            }
+            // On the first visit `prev` is null, and `slot` is the slot holding the removed entry.
+            prev = &slot;
+            return Action::kContinue;
+        });
+
+        // Entry was removed.
+        count_--;
+
+        return true;
+    }
+
+    /// @param value the value to search for.
+    /// @returns a copy of the entry that is equal to `value`, or no value if the entry was not
+    /// found.
+    template <typename V>
+    std::optional<T> Get(const V& value) const {
+        if (const auto [found, index] = IndexOf(value); found) {
+            return slots_[index].value.value();
+        }
+        return std::nullopt;
+    }
+
+    /// @param value the value to search for.
+    /// @returns a pointer to the entry that is equal to the given value, or nullptr if the set does
+    /// not contain the given value.
+    template <typename V>
+    const T* Find(const V& value) const {
+        const auto [found, index] = IndexOf(value);
+        return found ? &slots_[index].value.value() : nullptr;
+    }
+
+    /// @param value the value to search for.
+    /// @returns a pointer to the entry that is equal to the given value, or nullptr if the set does
+    /// not contain the given value.
+    /// @warning do not modify the inserted entry in a way that would cause the equality or hash of
+    /// the entry to change. Doing this will corrupt the Hashset.
+    template <typename V>
+    T* Find(const V& value) {
+        const auto [found, index] = IndexOf(value);
+        return found ? &slots_[index].value.value() : nullptr;
+    }
+
+    /// Checks whether an entry exists in the set
+    /// @param value the value to search for.
+    /// @returns true if the set contains an entry with the given value.
+    template <typename V>
+    bool Contains(const V& value) const {
+        const auto [found, _] = IndexOf(value);
+        return found;
+    }
+
+    /// Pre-allocates memory so that the set can hold at least `capacity` entries.
+    /// If the number of slots grows, all existing entries are rehashed.
+    /// @param capacity the new capacity of the set.
+    void Reserve(size_t capacity) {
+        // Calculate the number of slots required to hold `capacity` entries.
+        const size_t num_slots = std::max(NumSlots(capacity), kMinSlots);
+        if (slots_.Length() >= num_slots) {
+            // Already have enough slots.
+            return;
+        }
+
+        // Move all the values out of the set and into a vector.
+        Vector<T, N> values;
+        values.Reserve(count_);
+        for (auto& slot : slots_) {
+            if (slot.value.has_value()) {
+                values.Push(std::move(slot.value.value()));
+            }
+        }
+
+        // Clear the set, grow the number of slots.
+        Clear();
+        slots_.Resize(num_slots);
+
+        // As the number of slots has grown, the slot indices will have changed from before, so
+        // re-add all the values back into the set.
+        for (auto& value : values) {
+            Add(std::move(value));
+        }
+    }
+
+    /// @returns the number of entries in the set.
+    size_t Count() const { return count_; }
+
+    /// @returns true if the set contains no entries.
+    bool IsEmpty() const { return count_ == 0; }
+
+    /// @returns an iterator to the start of the set.
+    Iterator begin() const { return Iterator{slots_.begin(), slots_.end()}; }
+
+    /// @returns an iterator to the end of the set.
+ Iterator end() const { return Iterator{slots_.end(), slots_.end()}; } + + /// A debug function for checking that the set is in good health. + /// Asserts if the set is corrupted. + void ValidateIntegrity() const { + size_t num_alive = 0; + for (size_t slot_idx = 0; slot_idx < slots_.Length(); slot_idx++) { + const auto& slot = slots_[slot_idx]; + if (slot.value.has_value()) { + num_alive++; + auto const [index, hash] = Hash(slot.value.value()); + TINT_ASSERT(Utils, hash == slot.hash); + TINT_ASSERT(Utils, slot_idx == Wrap(index + slot.distance)); + } + } + TINT_ASSERT(Utils, num_alive == count_); + } + + private: + /// The behaviour of Put() when an entry already exists with the given key. + enum class PutMode { + /// Do not replace existing entries with the new value. + kAdd, + /// Replace existing entries with the new value. + kReplace, + }; + /// The common implementation for Add() and Replace() + /// @param value the value to add to the set. + /// @returns A AddResult describing the result of the insertion + template + AddResult Put(V&& value) { + // Ensure the set can fit a new entry + if (ShouldRehash(count_ + 1)) { + Reserve((count_ + 1) * 2); + } + + const auto hash = Hash(value); + + AddResult result{}; + Scan(hash.scan_start, [&](size_t distance, size_t index) { + auto& slot = slots_[index]; + if (!slot.value.has_value()) { + // Found an empty slot. + // Place value directly into the slot, and we're done. + slot.value.emplace(std::forward(value)); + slot.hash = hash.value; + slot.distance = distance; + count_++; + result = AddResult{AddAction::kAdded, &slot.value.value()}; + return Action::kStop; + } + + // Slot has an entry + + if (slot.Equals(hash.value, value)) { + // Slot is equal to value. Replace or preserve? 
+                if constexpr (MODE == PutMode::kReplace) {
+                    slot.value = std::forward(value);
+                    result = AddResult{AddAction::kReplaced, &slot.value.value()};
+                } else {
+                    result = AddResult{AddAction::kKeptExisting, &slot.value.value()};
+                }
+                return Action::kStop;
+            }
+
+            if (slot.distance < distance) {
+                // Existing slot has a closer distance than the value we're attempting to insert.
+                // Steal from the rich!
+                // Move the current slot to a temporary (evicted), and put the value into the slot.
+                Slot evicted{std::forward(value), hash.value, distance};
+                std::swap(evicted, slot);
+
+                // Find a new home for the evicted slot.
+                evicted.distance++;  // We've already swapped at index.
+                InsertShuffle(Wrap(index + 1), std::move(evicted));
+
+                count_++;
+                result = AddResult{AddAction::kAdded, &slot.value.value()};
+
+                return Action::kStop;
+            }
+            return Action::kContinue;
+        });
+
+        return result;
+    }
+
+    /// Return type of the Scan() callback.
+    enum class Action {
+        /// Continue scanning for a slot
+        kContinue,
+        /// Immediately stop scanning for a slot
+        kStop,
+    };
+
+    /// Sequentially visits each of the slots starting with the slot with the index `start`, calling
+    /// the callback function `f` for each slot until `f` returns Action::kStop.
+    /// `f` must be a function with the signature `Action(size_t distance, size_t index)`.
+    /// `f` must return Action::kStop within one whole cycle of the slots, or a TINT_ICE is raised.
+    template 
+    void Scan(size_t start, F&& f) const {
+        size_t index = start;
+        for (size_t distance = 0; distance < slots_.Length(); distance++) {
+            if (f(distance, index) == Action::kStop) {
+                return;
+            }
+            index = Wrap(index + 1);  // Step to the next slot, wrapping past the last one.
+        }
+        tint::diag::List diags;
+        TINT_ICE(Utils, diags) << "Hashset::Scan() looped entire set without finding a slot";
+    }
+
+    /// HashResult is the return value of Hash()
+    struct HashResult {
+        /// The target (zero-distance) slot index for the value.
+        size_t scan_start;
+        /// The calculated hash of the value.
+        size_t value;
+    };
+
+    /// @returns a HashResult holding the target slot index for the given value, and the hash of
+    /// the value, respectively.
+    template 
+    HashResult Hash(const V& value) const {
+        size_t hash = HASH()(value);
+        size_t index = Wrap(hash);  // Target slot is the hash modulo the number of slots.
+        return {index, hash};
+    }
+
+    /// Looks for the value in the set.
+    /// @returns a tuple holding a boolean representing whether the value was found in the set, and
+    /// if found, the index of the slot that holds the value.
+    template 
+    std::tuple IndexOf(const V& value) const {
+        const auto hash = Hash(value);
+
+        bool found = false;
+        size_t idx = 0;
+
+        Scan(hash.scan_start, [&](size_t distance, size_t index) {
+            auto& slot = slots_[index];
+            if (!slot.value.has_value()) {
+                return Action::kStop;
+            }
+            if (slot.Equals(hash.value, value)) {
+                found = true;
+                idx = index;
+                return Action::kStop;
+            }
+            if (slot.distance < distance) {
+                // If the slot distance is less than the current probe distance, then the slot must
+                // be for an entry with an index that comes after value. In this situation, we know
+                // that the set does not contain the value, as it would have been found before this
+                // slot. The "Lookup" section of https://programming.guide/robin-hood-hashing.html
+                // suggests that the condition should be inverted, but this is wrong.
+                return Action::kStop;
+            }
+            return Action::kContinue;
+        });
+
+        return {found, idx};
+    }
+
+    /// Shuffles slots for an insertion that has been placed one slot before `start`.
+    /// @param evicted the slot content that was evicted for the insertion.
+    void InsertShuffle(size_t start, Slot evicted) {
+        Scan(start, [&](size_t, size_t index) {
+            auto& slot = slots_[index];
+
+            if (!slot.value.has_value()) {
+                // Empty slot found for evicted.
+                slot = std::move(evicted);
+                return Action::kStop;  // We're done.
+            }
+
+            if (slot.distance < evicted.distance) {
+                // Occupied slot has a shorter distance than evicted.
+                // Swap slot and evicted.
+                std::swap(slot, evicted);
+            }
+
+            // evicted moves further from the target slot...
+            evicted.distance++;
+
+            return Action::kContinue;
+        });
+    }
+
+    /// @returns true if the set should grow the slot vector, and rehash the items.
+    bool ShouldRehash(size_t count) const { return NumSlots(count) > slots_.Length(); }
+
+    /// Wrap returns the index value modulo the number of slots.
+    size_t Wrap(size_t index) const { return index % slots_.Length(); }
+
+    /// The vector of slots. The vector length is equal to its capacity.
+    Vector slots_;
+
+    /// The number of entries in the set.
+    size_t count_ = 0;
+};
+
+}  // namespace tint::utils
+
+#endif  // SRC_TINT_UTILS_HASHSET_H_
diff --git a/src/tint/utils/hashset_test.cc b/src/tint/utils/hashset_test.cc
new file mode 100644
index 0000000000..4213b32490
--- /dev/null
+++ b/src/tint/utils/hashset_test.cc
@@ -0,0 +1,142 @@
+// Copyright 2022 The Tint Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "src/tint/utils/hashset.h" + +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" + +namespace tint::utils { +namespace { + +constexpr std::array kPrimes{ + 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, + 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113, 127, 131, + 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197, 199, 211, 223, + 227, 229, 233, 239, 241, 251, 257, 263, 269, 271, 277, 281, 283, 293, 307, 311, + 313, 317, 331, 337, 347, 349, 353, 359, 367, 373, 379, 383, 389, 397, 401, 409, +}; + +TEST(Hashset, Empty) { + Hashset set; + EXPECT_EQ(set.Count(), 0u); +} + +TEST(Hashset, AddRemove) { + Hashset set; + EXPECT_TRUE(set.Add("hello")); + EXPECT_EQ(set.Count(), 1u); + EXPECT_TRUE(set.Contains("hello")); + EXPECT_FALSE(set.Contains("world")); + EXPECT_FALSE(set.Add("hello")); + EXPECT_EQ(set.Count(), 1u); + EXPECT_TRUE(set.Remove("hello")); + EXPECT_EQ(set.Count(), 0u); + EXPECT_FALSE(set.Contains("hello")); + EXPECT_FALSE(set.Contains("world")); +} + +TEST(Hashset, AddMany) { + Hashset set; + for (size_t i = 0; i < kPrimes.size(); i++) { + int prime = kPrimes[i]; + ASSERT_TRUE(set.Add(prime)) << "i: " << i; + ASSERT_FALSE(set.Add(prime)) << "i: " << i; + ASSERT_EQ(set.Count(), i + 1); + set.ValidateIntegrity(); + } + ASSERT_EQ(set.Count(), kPrimes.size()); + for (int prime : kPrimes) { + ASSERT_TRUE(set.Contains(prime)) << prime; + } +} + +TEST(Hashset, Iterator) { + Hashset set; + set.Add("one"); + set.Add("four"); + set.Add("three"); + set.Add("two"); + EXPECT_THAT(set, testing::UnorderedElementsAre("one", "two", "three", "four")); +} + +TEST(Hashset, Soak) { + std::mt19937 rnd; + std::unordered_set reference; + Hashset set; + for (size_t i = 0; i < 1000000; i++) { + std::string value = std::to_string(rnd() & 0x100); + switch (rnd() % 8) { + case 0: { // Add + auto expected = reference.emplace(value).second; + ASSERT_EQ(set.Add(value), expected) << "i: " << i; + 
ASSERT_TRUE(set.Contains(value)) << "i: " << i; + break; + } + case 1: { // Replace + reference.emplace(value); + set.Replace(value); + ASSERT_TRUE(set.Contains(value)) << "i: " << i; + break; + } + case 2: { // Remove + auto expected = reference.erase(value) != 0; + ASSERT_EQ(set.Remove(value), expected) << "i: " << i; + ASSERT_FALSE(set.Contains(value)) << "i: " << i; + break; + } + case 3: { // Contains + auto expected = reference.count(value) != 0; + ASSERT_EQ(set.Contains(value), expected) << "i: " << i; + break; + } + case 4: { // Get + if (reference.count(value) != 0) { + ASSERT_TRUE(set.Get(value).has_value()) << "i: " << i; + ASSERT_EQ(set.Get(value), value) << "i: " << i; + } else { + ASSERT_FALSE(set.Get(value).has_value()) << "i: " << i; + } + break; + } + case 5: { // Find + if (reference.count(value) != 0) { + ASSERT_EQ(*set.Find(value), value) << "i: " << i; + } else { + ASSERT_EQ(set.Find(value), nullptr) << "i: " << i; + } + break; + } + case 6: { // Copy / Move + Hashset tmp(set); + set = std::move(tmp); + break; + } + case 7: { // Clear + reference.clear(); + set.Clear(); + break; + } + } + set.ValidateIntegrity(); + } +} + +} // namespace +} // namespace tint::utils