diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 54baaadc7d..067b4d451a 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -579,6 +579,7 @@ libtint_source_set("libtint_core_all_src") { "utils/enum_set.h", "utils/foreach_macro.h", "utils/hash.h", + "utils/hashmap_base.h", "utils/hashmap.h", "utils/hashset.h", "utils/map.h", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index 8049b3c8f1..bd5e824658 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -474,6 +474,7 @@ set(TINT_LIB_SRCS utils/enum_set.h utils/foreach_macro.h utils/hash.h + utils/hashmap_base.h utils/hashmap.h utils/hashset.h utils/map.h diff --git a/src/tint/clone_context.h b/src/tint/clone_context.h index 029279b240..05f4868b92 100644 --- a/src/tint/clone_context.h +++ b/src/tint/clone_context.h @@ -27,6 +27,7 @@ #include "src/tint/symbol.h" #include "src/tint/traits.h" #include "src/tint/utils/hashmap.h" +#include "src/tint/utils/hashset.h" #include "src/tint/utils/vector.h" // Forward declarations diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc index 4c71607893..d3febf2d70 100644 --- a/src/tint/reader/spirv/function.cc +++ b/src/tint/reader/spirv/function.cc @@ -3419,7 +3419,7 @@ bool FunctionEmitter::EmitStatementsInBasicBlock(const BlockInfo& block_info, utils::Hashmap copied_phis; for (const auto assignment : worklist) { const auto phi_id = assignment.phi_id; - if (read_set.Find(phi_id)) { + if (read_set.Contains(phi_id)) { auto copy_name = namer_.MakeDerivedName(namer_.Name(phi_id) + "_c" + std::to_string(block_info.id)); auto copy_sym = builder_.Symbols().Register(copy_name); diff --git a/src/tint/utils/hashmap.h b/src/tint/utils/hashmap.h index 87a0dd9b9b..1040e0cbbb 100644 --- a/src/tint/utils/hashmap.h +++ b/src/tint/utils/hashmap.h @@ -19,255 +19,119 @@ #include #include +#include "src/tint/debug.h" #include "src/tint/utils/hash.h" -#include "src/tint/utils/hashset.h" +#include 
"src/tint/utils/hashmap_base.h" +#include "src/tint/utils/vector.h" namespace tint::utils { /// An unordered map that uses a robin-hood hashing algorithm. -/// -/// Hashmap internally wraps a Hashset for providing a store for key-value pairs. -/// -/// @see Hashset -template , - typename EQUAL = std::equal_to> -class Hashmap { - /// Entry holds a key and value pair, and is used as the element type of the underlying Hashset. - /// Entries are compared and hashed using only the #key. - /// @see Hasher - /// @see Equality - struct Entry { - /// Constructor from a key and value pair - Entry(K k, V v) : key(std::move(k)), value(std::move(v)) {} - - /// Copy-constructor. - Entry(const Entry&) = default; - - /// Move-constructor. - Entry(Entry&&) = default; - - /// Copy-assignment operator - Entry& operator=(const Entry&) = default; - - /// Move-assignment operator - Entry& operator=(Entry&&) = default; - - K key; /// The map entry key - V value; /// The map entry value - }; - - /// Hash provider for the underlying Hashset. - /// Provides hash functions for an Entry or K. - /// The hash functions only consider the key of an entry. - struct Hasher { - /// Calculates a hash from an Entry - size_t operator()(const Entry& entry) const { return HASH()(entry.key); } - /// Calculates a hash from a K - size_t operator()(const K& key) const { return HASH()(key); } - }; - - /// Equality provider for the underlying Hashset. - /// Provides equality functions for an Entry or K to an Entry. - /// The equality functions only consider the key for equality. - struct Equality { - /// Compares an Entry to an Entry for equality. - bool operator()(const Entry& a, const Entry& b) const { return EQUAL()(a.key, b.key); } - /// Compares a K to an Entry for equality. 
- bool operator()(const K& a, const Entry& b) const { return EQUAL()(a, b.key); } - }; - - /// The underlying set - using Set = Hashset; + typename HASH = Hasher, + typename EQUAL = std::equal_to> +class Hashmap : public HashmapBase { + using Base = HashmapBase; + using PutMode = typename Base::PutMode; public: - /// A Key and Value const-reference pair. - struct KeyValue { - /// key of a map entry - const K& key; - /// value of a map entry - const V& value; + /// The key type + using Key = KEY; + /// The value type + using Value = VALUE; + /// The key-value type for a map entry + using Entry = KeyValue; - /// Equality operator - /// @param other the other KeyValue - /// @returns true if the key and value of this KeyValue are equal to other's. - bool operator==(const KeyValue& other) const { - return key == other.key && value == other.value; - } - }; + /// Result of Add() + using AddResult = typename Base::PutResult; - /// STL-style alias to KeyValue. - /// Used by gmock for the `ElementsAre` checks. - using value_type = KeyValue; - - /// Iterator for the map. - /// Iterators are invalidated if the map is modified. 
- class Iterator { - public: - /// @returns the key of the entry pointed to by this iterator - const K& Key() const { return it->key; } - - /// @returns the value of the entry pointed to by this iterator - const V& Value() const { return it->value; } - - /// Increments the iterator - /// @returns this iterator - Iterator& operator++() { - ++it; - return *this; - } - - /// Equality operator - /// @param other the other iterator to compare this iterator to - /// @returns true if this iterator is equal to other - bool operator==(const Iterator& other) const { return it == other.it; } - - /// Inequality operator - /// @param other the other iterator to compare this iterator to - /// @returns true if this iterator is not equal to other - bool operator!=(const Iterator& other) const { return it != other.it; } - - /// @returns a pair of key and value for the entry pointed to by this iterator - KeyValue operator*() const { return {Key(), Value()}; } - - private: - /// Friend class - friend class Hashmap; - - /// Underlying iterator type - using SetIterator = typename Set::Iterator; - - explicit Iterator(SetIterator i) : it(i) {} - - SetIterator it; - }; - - /// Removes all entries from the map. - void Clear() { set_.Clear(); } - - /// Adds the key-value pair to the map, if the map does not already contain an entry with a key - /// equal to `key`. - /// @param key the entry's key to add to the map - /// @param value the entry's value to add to the map - /// @returns true if the entry was added to the map, false if there was already an entry in the - /// map with a key equal to `key`. - template - bool Add(KEY&& key, VALUE&& value) { - return set_.Add(Entry{std::forward(key), std::forward(value)}); + /// Adds a value to the map, if the map does not already contain an entry with the key @p key. + /// @param key the entry key. + /// @param value the value of the entry to add to the map. 
+ /// @returns An AddResult describing the result of the add + template + AddResult Add(K&& key, V&& value) { + return this->template Put(std::forward(key), std::forward(value)); } - /// Adds the key-value pair to the map, replacing any entry with a key equal to `key`. - /// @param key the entry's key to add to the map - /// @param value the entry's value to add to the map - template - void Replace(KEY&& key, VALUE&& value) { - set_.Replace(Entry{std::forward(key), std::forward(value)}); + /// Adds a new entry to the map, replacing any entry that has a key equal to @p key. + /// @param key the entry key. + /// @param value the value of the entry to add to the map. + /// @returns An AddResult describing the result of the replace + template + AddResult Replace(K&& key, V&& value) { + return this->template Put(std::forward(key), std::forward(value)); } - /// Searches for an entry with the given key value. - /// @param key the entry's key value to search for. - /// @returns the value of the entry with the given key, or no value if the entry was not found. - std::optional Get(const K& key) { - if (auto* entry = set_.Find(key)) { - return entry->value; + /// @param key the key to search for. + /// @returns the value of the entry with a key equal to @p key, or no value if the entry was not + /// found. + std::optional Get(const Key& key) const { + if (auto [found, index] = this->IndexOf(key); found) { + return this->slots_[index].entry->value; } return std::nullopt; } - /// Searches for an entry with the given key value, adding and returning the result of - /// calling `create` if the entry was not found. + /// Searches for an entry with the given key, adding and returning the result of calling + /// @p create if the entry was not found. /// @note: Before calling `create`, the map will insert a zero-initialized value for the given - /// key, which will be replaced with the value returned by `create`. If `create` adds an entry - /// with `key` to this map, it will be replaced.
+ /// key, which will be replaced with the value returned by @p create. If @p create adds an entry + /// with @p key to this map, it will be replaced. /// @param key the entry's key value to search for. /// @param create the create function to call if the map does not contain the key. /// @returns the value of the entry. - template - V& GetOrCreate(const K& key, CREATE&& create) { - auto res = set_.Add(Entry{key, V{}}); - if (res.action == AddAction::kAdded) { - // Store the set generation before calling create() - auto generation = set_.Generation(); + template + Value& GetOrCreate(K&& key, CREATE&& create) { + auto res = Add(key, Value{}); // Not forwarded: key is needed again by Replace() below. + if (res.action == MapAction::kAdded) { + // Store the map generation before calling create() + auto generation = this->Generation(); // Call create(), which might modify this map. auto value = create(); // Was this map mutated? - if (set_.Generation() == generation) { + if (this->Generation() == generation) { // Calling create() did not touch the map. No need to lookup again. - res.entry->value = std::move(value); + *res.value = std::move(value); } else { // Calling create() modified the map. Need to insert again. - res = set_.Replace(Entry{key, std::move(value)}); + res = Replace(key, std::move(value)); } } - return res.entry->value; + return *res.value; } /// Searches for an entry with the given key value, adding and returning a newly created /// zero-initialized value if the entry was not found. /// @param key the entry's key value to search for. /// @returns the value of the entry. - V& GetOrZero(const K& key) { - auto res = set_.Add(Entry{key, V{}}); - return res.entry->value; + template + Value& GetOrZero(K&& key) { + auto res = Add(std::forward(key), Value{}); + return *res.value; } - /// Searches for an entry with the given key value. - /// @param key the entry's key value to search for.
- /// @returns the a pointer to the value of the entry with the given key, or nullptr if the entry - /// was not found. - /// @warning the pointer must not be used after the map is mutated - V* Find(const K& key) { - if (auto* entry = set_.Find(key)) { - return &entry->value; + /// @param key the key to search for. + /// @returns a pointer to the value of the entry with the given key, or nullptr if the map + /// does not contain the key. The pointer must not be used after the map is mutated. + const Value* Find(const Key& key) const { + if (auto [found, index] = this->IndexOf(key); found) { + return &this->slots_[index].entry->value; } return nullptr; } - /// Searches for an entry with the given key value. - /// @param key the entry's key value to search for. - /// @returns the a pointer to the value of the entry with the given key, or nullptr if the entry - /// was not found. - /// @warning the pointer must not be used after the map is mutated - const V* Find(const K& key) const { - if (auto* entry = set_.Find(key)) { - return &entry->value; + /// @param key the key to search for. + /// @returns a pointer to the value of the entry with the given key, or nullptr if the map + /// does not contain the key. The pointer must not be used after the map is mutated. + Value* Find(const Key& key) { + if (auto [found, index] = this->IndexOf(key); found) { + return &this->slots_[index].entry->value; } return nullptr; } - - /// Removes an entry from the set with a key equal to `key`. - /// @param key the entry key value to remove. - /// @returns true if an entry was removed. - bool Remove(const K& key) { return set_.Remove(key); } - - /// Checks whether an entry exists in the map with a key equal to `key`. - /// @param key the entry key value to search for. - /// @returns true if the map contains an entry with the given key. - bool Contains(const K& key) const { return set_.Contains(key); } - - /// Pre-allocates memory so that the map can hold at least `capacity` entries. - /// @param capacity the new capacity of the map.
- void Reserve(size_t capacity) { set_.Reserve(capacity); } - - /// @returns the number of entries in the map. - size_t Count() const { return set_.Count(); } - - /// @returns a monotonic counter which is incremented whenever the map is mutated. - size_t Generation() const { return set_.Generation(); } - - /// @returns true if the map contains no entries. - bool IsEmpty() const { return set_.IsEmpty(); } - - /// @returns an iterator to the start of the map - Iterator begin() const { return Iterator{set_.begin()}; } - - /// @returns an iterator to the end of the map - Iterator end() const { return Iterator{set_.end()}; } - - private: - Set set_; }; } // namespace tint::utils diff --git a/src/tint/utils/hashmap_base.h b/src/tint/utils/hashmap_base.h new file mode 100644 index 0000000000..0460373cff --- /dev/null +++ b/src/tint/utils/hashmap_base.h @@ -0,0 +1,564 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef SRC_TINT_UTILS_HASHMAP_BASE_H_ +#define SRC_TINT_UTILS_HASHMAP_BASE_H_ + +#include +#include +#include +#include +#include + +#include "src/tint/debug.h" +#include "src/tint/utils/hash.h" +#include "src/tint/utils/vector.h" + +namespace tint::utils { + +/// Action taken by a map mutation +enum class MapAction { + /// A new entry was added to the map + kAdded, + /// An existing entry in the map was replaced + kReplaced, + /// No action was taken as the map already contained an entry with the given key + kKeptExisting, +}; + +/// KeyValue is a key-value pair. +template +struct KeyValue { + /// The key type + using Key = KEY; + /// The value type + using Value = VALUE; + + /// The key + Key key; + + /// The value + Value value; + + /// Equality operator + /// @param other the RHS of the operator + /// @returns true if both the key and value of this KeyValue are equal to the key and value + /// of @p other + template + bool operator==(const KeyValue& other) const { + return key == other.key && value == other.value; + } + + /// Inequality operator, defined as the negation of operator==() + /// @param other the RHS of the operator + /// @returns true if either the key or the value of this KeyValue is not equal to the key or + /// value of @p other + template + bool operator!=(const KeyValue& other) const { + return !(*this == other); + } +}; + +/// Writes the KeyValue to the std::ostream. +/// @param out the std::ostream to write to +/// @param key_value the KeyValue to write +/// @returns out so calls can be chained +template +std::ostream& operator<<(std::ostream& out, const KeyValue& key_value) { + return out << "[" << key_value.key << ": " << key_value.value << "]"; +} + +/// A base class for Hashmap and Hashset that uses a robin-hood hashing algorithm.
+/// @see the fantastic tutorial: https://programming.guide/robin-hood-hashing.html +template , + typename EQUAL = std::equal_to> +class HashmapBase { + static constexpr bool ValueIsVoid = std::is_same_v; + + public: + /// The key type + using Key = KEY; + /// The value type + using Value = VALUE; + /// The entry type for the map. + /// This is: + /// - Key when Value is void (used by Hashset) + /// - KeyValue when Value is not void (used by Hashmap) + using Entry = std::conditional_t>; + + /// STL-friendly alias to Entry. Used by gmock. + using value_type = Entry; + + private: + /// @returns the key from an entry + static const Key& KeyOf(const Entry& entry) { + if constexpr (ValueIsVoid) { + return entry; + } else { + return entry.key; + } + } + + /// @returns a pointer to the value from an entry. + static Value* ValueOf(Entry& entry) { + if constexpr (ValueIsVoid) { + return nullptr; // Hashset only has keys + } else { + return &entry.value; + } + } + + /// A slot is a single entry in the underlying vector. + /// A slot can either be empty or filled with a value. If the slot is empty, #hash and #distance + /// will be zero. + struct Slot { + bool Equals(size_t key_hash, const Key& key) const { + return key_hash == hash && EQUAL()(key, KeyOf(*entry)); + } + + /// The slot value. If this does not contain a value, then the slot is vacant. + std::optional entry; + /// The precomputed hash of value. + size_t hash = 0; + size_t distance = 0; + }; + + /// The target length of the underlying vector in relation to the number of entries in + /// the map, expressed as a percentage. For example a value of `150` would mean there would be + /// at least 50% more slots than the number of map entries. + static constexpr size_t kRehashFactor = 150; + + /// @returns the target slot vector size to hold `count` map entries.
+ static constexpr size_t NumSlots(size_t count) { return (count * kRehashFactor) / 100; } + + /// The fixed-size slot vector length, based on N and kRehashFactor. + static constexpr size_t kNumFixedSlots = NumSlots(N); + + /// The minimum number of slots for the map. + static constexpr size_t kMinSlots = std::max(kNumFixedSlots, 4); + + public: + /// Iterator for entries in the map. + /// Iterators are invalidated if the map is modified. + class Iterator { + public: + /// @returns the value pointed to by this iterator + const Entry* operator->() const { return &current->entry.value(); } + + /// @returns a reference to the value at the iterator + const Entry& operator*() const { return current->entry.value(); } + + /// Increments the iterator + /// @returns this iterator + Iterator& operator++() { + if (current == end) { + return *this; + } + current++; + SkipToNextValue(); + return *this; + } + + /// Equality operator + /// @param other the other iterator to compare this iterator to + /// @returns true if this iterator is equal to other + bool operator==(const Iterator& other) const { return current == other.current; } + + /// Inequality operator + /// @param other the other iterator to compare this iterator to + /// @returns true if this iterator is not equal to other + bool operator!=(const Iterator& other) const { return current != other.current; } + + private: + /// Friend class + friend class HashmapBase; + + Iterator(const Slot* c, const Slot* e) : current(c), end(e) { SkipToNextValue(); } + + /// Moves the iterator forward, stopping at the next slot that is not empty.
+ void SkipToNextValue() { + while (current != end && !current->entry.has_value()) { + current++; + } + } + + const Slot* current; /// The slot the iterator is pointing to + const Slot* end; /// One past the last slot in the map + }; + + /// Constructor + HashmapBase() { slots_.Resize(kMinSlots); } + + /// Copy constructor + /// @param other the other HashmapBase to copy + HashmapBase(const HashmapBase& other) = default; + + /// Move constructor + /// @param other the other HashmapBase to move + HashmapBase(HashmapBase&& other) = default; + + /// Destructor + ~HashmapBase() { Clear(); } + + /// Copy-assignment operator + /// @param other the other HashmapBase to copy + /// @returns this so calls can be chained + HashmapBase& operator=(const HashmapBase& other) = default; + + /// Move-assignment operator + /// @param other the other HashmapBase to move + /// @returns this so calls can be chained + HashmapBase& operator=(HashmapBase&& other) = default; + + /// Removes all entries from the map. + void Clear() { + slots_.Clear(); // Destructs all entries + slots_.Resize(kMinSlots); + count_ = 0; + generation_++; + } + + /// Removes an entry from the map. + /// @param key the entry key. + /// @returns true if an entry was removed. + bool Remove(const Key& key) { + const auto [found, start] = IndexOf(key); + if (!found) { + return false; + } + + // Shuffle the entries backwards until we either find a free slot, or a slot that has zero + // distance. + Slot* prev = nullptr; + Scan(start, [&](size_t, size_t index) { + auto& slot = slots_[index]; + if (prev) { + // note: `distance == 0` also includes empty slots. + if (slot.distance == 0) { + // Clear the previous slot, and stop shuffling. + *prev = {}; + return Action::kStop; + } else { + // Shuffle the slot backwards. + prev->entry = std::move(slot.entry); + prev->hash = slot.hash; + prev->distance = slot.distance - 1; + } + } + prev = &slot; + return Action::kContinue; + }); + + // Entry was removed. 
+ count_--; + generation_++; + + return true; + } + + /// Checks whether an entry exists in the map + /// @param key the key to search for. + /// @returns true if the map contains an entry with the given value. + bool Contains(const Key& key) const { + const auto [found, _] = IndexOf(key); + return found; + } + + /// Pre-allocates memory so that the map can hold at least `capacity` entries. + /// @param capacity the new capacity of the map. + void Reserve(size_t capacity) { + // Calculate the number of slots required to hold `capacity` entries. + const size_t num_slots = std::max(NumSlots(capacity), kMinSlots); + if (slots_.Length() >= num_slots) { + // Already have enough slots. + return; + } + + // Move all the values out of the map and into a vector. + Vector entries; + entries.Reserve(count_); + for (auto& slot : slots_) { + if (slot.entry.has_value()) { + entries.Push(std::move(slot.entry.value())); + } + } + + // Clear the map, grow the number of slots. + Clear(); + slots_.Resize(num_slots); + + // As the number of slots has grown, the slot indices will have changed from before, so + // re-add all the entries back into the map. + for (auto& entry : entries) { + if constexpr (ValueIsVoid) { + struct NoValue {}; + Put(std::move(entry), NoValue{}); + } else { + Put(std::move(entry.key), std::move(entry.value)); + } + } + } + + /// @returns the number of entries in the map. + size_t Count() const { return count_; } + + /// @returns true if the map contains no entries. + bool IsEmpty() const { return count_ == 0; } + + /// @returns a monotonic counter which is incremented whenever the map is mutated. + size_t Generation() const { return generation_; } + + /// @returns an iterator to the start of the map. + Iterator begin() const { return Iterator{slots_.begin(), slots_.end()}; } + + /// @returns an iterator to the end of the map. 
+ Iterator end() const { return Iterator{slots_.end(), slots_.end()}; } + + /// A debug function for checking that the map is in good health. + /// Asserts if the map is corrupted. + void ValidateIntegrity() const { + size_t num_alive = 0; + for (size_t slot_idx = 0; slot_idx < slots_.Length(); slot_idx++) { + const auto& slot = slots_[slot_idx]; + if (slot.entry.has_value()) { + num_alive++; + auto const [index, hash] = Hash(KeyOf(*slot.entry)); + TINT_ASSERT(Utils, hash == slot.hash); + TINT_ASSERT(Utils, slot_idx == Wrap(index + slot.distance)); + } + } + TINT_ASSERT(Utils, num_alive == count_); + } + + protected: + /// The behaviour of Put() when an entry already exists with the given key. + enum class PutMode { + /// Do not replace existing entries with the new value. + kAdd, + /// Replace existing entries with the new value. + kReplace, + }; + + /// Result of Put() + struct PutResult { + /// Whether the insert replaced or added a new entry to the map. + MapAction action = MapAction::kAdded; + /// A pointer to the inserted entry value. + Value* value = nullptr; + + /// @returns true if the entry was added to the map, or an existing entry was replaced. + operator bool() const { return action != MapAction::kKeptExisting; } + }; + + /// The common implementation for Add() and Replace() + /// @param key the key of the entry to add to the map. + /// @param value the value of the entry to add to the map. 
+ /// @returns A PutResult describing the result of the insertion + template + PutResult Put(K&& key, V&& value) { + // Ensure the map can fit a new entry + if (ShouldRehash(count_ + 1)) { + Reserve((count_ + 1) * 2); + } + + const auto hash = Hash(key); + + auto make_entry = [&]() { + if constexpr (ValueIsVoid) { + return std::forward(key); + } else { + return Entry{std::forward(key), std::forward(value)}; + } + }; + + PutResult result{}; + Scan(hash.scan_start, [&](size_t distance, size_t index) { + auto& slot = slots_[index]; + if (!slot.entry.has_value()) { + // Found an empty slot. + // Place value directly into the slot, and we're done. + slot.entry.emplace(make_entry()); + slot.hash = hash.code; + slot.distance = distance; + count_++; + generation_++; + result = PutResult{MapAction::kAdded, ValueOf(*slot.entry)}; + return Action::kStop; + } + + // Slot has an entry + + if (slot.Equals(hash.code, key)) { + // Slot is equal to value. Replace or preserve? + if constexpr (MODE == PutMode::kReplace) { + slot.entry = make_entry(); + generation_++; + result = PutResult{MapAction::kReplaced, ValueOf(*slot.entry)}; + } else { + result = PutResult{MapAction::kKeptExisting, ValueOf(*slot.entry)}; + } + return Action::kStop; + } + + if (slot.distance < distance) { + // Existing slot has a closer distance than the value we're attempting to insert. + // Steal from the rich! + // Move the current slot to a temporary (evicted), and put the value into the slot. + Slot evicted{make_entry(), hash.code, distance}; + std::swap(evicted, slot); + + // Find a new home for the evicted slot. + evicted.distance++; // We've already swapped at index. + InsertShuffle(Wrap(index + 1), std::move(evicted)); + + count_++; + generation_++; + result = PutResult{MapAction::kAdded, ValueOf(*slot.entry)}; + + return Action::kStop; + } + return Action::kContinue; + }); + + return result; + } + + /// Return type of the Scan() callback. 
+ enum class Action { + /// Continue scanning for a slot + kContinue, + /// Immediately stop scanning for a slot + kStop, + }; + + /// Sequentially visits each of the slots starting with the slot with the index @p start, + /// calling the callback function @p f for each slot until @p f returns Action::kStop. + /// @param start the index of the first slot to start scanning from. + /// @param f the callback function which: + /// * must be a function with the signature `Action(size_t distance, size_t index)`. + /// * must return Action::kStop within one whole cycle of the slots. + template + void Scan(size_t start, F&& f) const { + size_t index = start; + for (size_t distance = 0; distance < slots_.Length(); distance++) { + if (f(distance, index) == Action::kStop) { + return; + } + index = Wrap(index + 1); + } + tint::diag::List diags; + TINT_ICE(Utils, diags) << "HashmapBase::Scan() looped entire map without finding a slot"; + } + + /// HashResult is the return value of Hash() + struct HashResult { + /// The target (zero-distance) slot index for the key. + size_t scan_start; + /// The calculated hash code of the key. + size_t code; + }; + + /// @param key the key to hash + /// @returns a tuple holding the target slot index for the given value, and the hash of the + /// value, respectively. + HashResult Hash(const Key& key) const { + size_t hash = HASH()(key); + size_t index = Wrap(hash); + return {index, hash}; + } + + /// Looks for the key in the map. + /// @param key the key to search for. + /// @returns a tuple holding a boolean representing whether the key was found in the map, and + /// if found, the index of the slot that holds the key. 
+ std::tuple IndexOf(const Key& key) const { + const auto hash = Hash(key); + + bool found = false; + size_t idx = 0; + + Scan(hash.scan_start, [&](size_t distance, size_t index) { + auto& slot = slots_[index]; + if (!slot.entry.has_value()) { + return Action::kStop; + } + if (slot.Equals(hash.code, key)) { + found = true; + idx = index; + return Action::kStop; + } + if (slot.distance < distance) { + // If the slot distance is less than the current probe distance, then the slot must + // be for an entry that has an index that comes after key. In this situation, we know + // that the map does not contain the key, as it would have been found before this + // slot. The "Lookup" section of https://programming.guide/robin-hood-hashing.html + // suggests that the condition should be inverted, but this is wrong. + return Action::kStop; + } + return Action::kContinue; + }); + + return {found, idx}; + } + + /// Shuffles slots for an insertion that has been placed one slot before `start`. + /// @param start the index of the first slot to start shuffling. + /// @param evicted the slot content that was evicted for the insertion. + void InsertShuffle(size_t start, Slot evicted) { + Scan(start, [&](size_t, size_t index) { + auto& slot = slots_[index]; + + if (!slot.entry.has_value()) { + // Empty slot found for evicted. + slot = std::move(evicted); + return Action::kStop; // We're done. + } + + if (slot.distance < evicted.distance) { + // Occupied slot has shorter distance to evicted. + // Swap slot and evicted. + std::swap(slot, evicted); + } + + // evicted moves further from the target slot... + evicted.distance++; + + return Action::kContinue; + }); + } + + /// @param count the number of new entries in the map + /// @returns true if the map should grow the slot vector, and rehash the items. + bool ShouldRehash(size_t count) const { return NumSlots(count) > slots_.Length(); } + + /// @param index an input value + /// @returns the input value modulo the number of slots.
+ size_t Wrap(size_t index) const { return index % slots_.Length(); } + + /// The vector of slots. The vector length is equal to its capacity. + Vector slots_; + + /// The number of entries in the map. + size_t count_ = 0; + + /// Counter that's incremented with each modification to the map. + size_t generation_ = 0; +}; + +} // namespace tint::utils + +#endif // SRC_TINT_UTILS_HASHMAP_BASE_H_ diff --git a/src/tint/utils/hashmap_test.cc b/src/tint/utils/hashmap_test.cc index 9a5b01e0b4..77421cf76c 100644 --- a/src/tint/utils/hashmap_test.cc +++ b/src/tint/utils/hashmap_test.cc @@ -92,14 +92,14 @@ TEST(Hashmap, Generation) { TEST(Hashmap, Iterator) { using Map = Hashmap; - using KV = typename Map::KeyValue; + using Entry = typename Map::Entry; Map map; map.Add(1, "one"); map.Add(4, "four"); map.Add(3, "three"); map.Add(2, "two"); - EXPECT_THAT(map, testing::UnorderedElementsAre(KV{1, "one"}, KV{2, "two"}, KV{3, "three"}, - KV{4, "four"})); + EXPECT_THAT(map, testing::UnorderedElementsAre(Entry{1, "one"}, Entry{2, "two"}, + Entry{3, "three"}, Entry{4, "four"})); } TEST(Hashmap, AddMany) { diff --git a/src/tint/utils/hashset.h b/src/tint/utils/hashset.h index f7d5efe743..53f71f5d65 100644 --- a/src/tint/utils/hashset.h +++ b/src/tint/utils/hashset.h @@ -23,497 +23,26 @@ #include #include "src/tint/debug.h" -#include "src/tint/utils/hash.h" +#include "src/tint/utils/hashmap.h" #include "src/tint/utils/vector.h" namespace tint::utils { -/// Action taken by Hashset::Insert() -enum class AddAction { - /// Insert() added a new entry to the Hashset - kAdded, - /// Insert() replaced an existing entry in the Hashset - kReplaced, - /// Insert() found an existing entry, which was not replaced. - kKeptExisting, -}; - /// An unordered set that uses a robin-hood hashing algorithm. -/// @see the fantastic tutorial: https://programming.guide/robin-hood-hashing.html -template , typename EQUAL = std::equal_to> -class Hashset { - /// A slot is a single entry in the underlying vector. 
- /// A slot can either be empty or filled with a value. If the slot is empty, #hash and #distance - /// will be zero. - struct Slot { - template - bool Equals(size_t value_hash, const V& val) const { - return value_hash == hash && EQUAL()(val, value.value()); - } - - /// The slot value. If this does not contain a value, then the slot is vacant. - std::optional value; - /// The precomputed hash of value. - size_t hash = 0; - size_t distance = 0; - }; - - /// The target length of the underlying vector length in relation to the number of entries in - /// the set, expressed as a percentage. For example a value of `150` would mean there would be - /// at least 50% more slots than the number of set entries. - static constexpr size_t kRehashFactor = 150; - - /// @returns the target slot vector size to hold `n` set entries. - static constexpr size_t NumSlots(size_t count) { return (count * kRehashFactor) / 100; } - - /// The fixed-size slot vector length, based on N and kRehashFactor. - static constexpr size_t kNumFixedSlots = NumSlots(N); - - /// The minimum number of slots for the set. - static constexpr size_t kMinSlots = std::max(kNumFixedSlots, 4); +template , typename EQUAL = std::equal_to> +class Hashset : public HashmapBase { + using Base = HashmapBase; + using PutMode = typename Base::PutMode; public: - /// Iterator for entries in the set. - /// Iterators are invalidated if the set is modified. 
- class Iterator { - public: - /// @returns the value pointed to by this iterator - const T* operator->() const { return ¤t->value.value(); } - - /// Increments the iterator - /// @returns this iterator - Iterator& operator++() { - if (current == end) { - return *this; - } - current++; - SkipToNextValue(); - return *this; - } - - /// Equality operator - /// @param other the other iterator to compare this iterator to - /// @returns true if this iterator is equal to other - bool operator==(const Iterator& other) const { return current == other.current; } - - /// Inequality operator - /// @param other the other iterator to compare this iterator to - /// @returns true if this iterator is not equal to other - bool operator!=(const Iterator& other) const { return current != other.current; } - - /// @returns a reference to the value at the iterator - const T& operator*() const { return current->value.value(); } - - private: - /// Friend class - friend class Hashset; - - Iterator(const Slot* c, const Slot* e) : current(c), end(e) { SkipToNextValue(); } - - /// Moves the iterator forward, stopping at the next slot that is not empty. - void SkipToNextValue() { - while (current != end && !current->value.has_value()) { - current++; - } - } - - const Slot* current; /// The slot the iterator is pointing to - const Slot* end; /// One past the last slot in the set - }; - - /// Type of `T`. 
- using value_type = T; - - /// Constructor - Hashset() { slots_.Resize(kMinSlots); } - - /// Copy constructor - /// @param other the other Hashset to copy - Hashset(const Hashset& other) = default; - - /// Move constructor - /// @param other the other Hashset to move - Hashset(Hashset&& other) = default; - - /// Destructor - ~Hashset() { Clear(); } - - /// Copy-assignment operator - /// @param other the other Hashset to copy - /// @returns this so calls can be chained - Hashset& operator=(const Hashset& other) = default; - - /// Move-assignment operator - /// @param other the other Hashset to move - /// @returns this so calls can be chained - Hashset& operator=(Hashset&& other) = default; - - /// Removes all entries from the set. - void Clear() { - slots_.Clear(); // Destructs all entries - slots_.Resize(kMinSlots); - count_ = 0; - generation_++; - } - - /// Result of Add() - struct AddResult { - /// Whether the insert replaced or added a new entry to the set. - AddAction action = AddAction::kAdded; - /// A pointer to the inserted entry. - /// @warning do not modify this pointer in a way that would cause the equality or hash of - /// the entry to change. Doing this will corrupt the Hashset. - T* entry = nullptr; - - /// @returns true if the entry was added to the set, or an existing entry was replaced. - operator bool() const { return action != AddAction::kKeptExisting; } - }; - /// Adds a value to the set, if the set does not already contain an entry equal to `value`. /// @param value the value to add to the set. - /// @returns A AddResult describing the result of the add - /// @warning do not modify the inserted entry in a way that would cause the equality of hash of - /// the entry to change. Doing this will corrupt the Hashset. + /// @returns true if the value was added, false if there was an existing value in the set. 
template - AddResult Add(V&& value) { - return Put(std::forward(value)); + bool Add(V&& value) { + struct NoValue {}; + return this->template Put(std::forward(value), NoValue{}); } - - /// Adds a value to the set, replacing any entry equal to `value`. - /// @param value the value to add to the set. - /// @returns A AddResult describing the result of the replace - template - AddResult Replace(V&& value) { - return Put(std::forward(value)); - } - - /// Removes an entry from the set. - /// @param value the value to remove from the set. - /// @returns true if an entry was removed. - template - bool Remove(const V& value) { - const auto [found, start] = IndexOf(value); - if (!found) { - return false; - } - - // Shuffle the entries backwards until we either find a free slot, or a slot that has zero - // distance. - Slot* prev = nullptr; - Scan(start, [&](size_t, size_t index) { - auto& slot = slots_[index]; - if (prev) { - // note: `distance == 0` also includes empty slots. - if (slot.distance == 0) { - // Clear the previous slot, and stop shuffling. - *prev = {}; - return Action::kStop; - } else { - // Shuffle the slot backwards. - prev->value = std::move(slot.value); - prev->hash = slot.hash; - prev->distance = slot.distance - 1; - } - } - prev = &slot; - return Action::kContinue; - }); - - // Entry was removed. - count_--; - generation_++; - - return true; - } - - /// @param value the value to search for. - /// @returns the value of the entry that is equal to `value`, or no value if the entry was not - /// found. - template - std::optional Get(const V& value) const { - if (const auto [found, index] = IndexOf(value); found) { - return slots_[index].value.value(); - } - return std::nullopt; - } - - /// @param value the value to search for. - /// @returns a pointer to the entry that is equal to the given value, or nullptr if the set does - /// not contain the given value. 
- template - const T* Find(const V& value) const { - const auto [found, index] = IndexOf(value); - return found ? &slots_[index].value.value() : nullptr; - } - - /// @param value the value to search for. - /// @returns a pointer to the entry that is equal to the given value, or nullptr if the set does - /// not contain the given value. - /// @warning do not modify the inserted entry in a way that would cause the equality of hash of - /// the entry to change. Doing this will corrupt the Hashset. - template - T* Find(const V& value) { - const auto [found, index] = IndexOf(value); - return found ? &slots_[index].value.value() : nullptr; - } - - /// Checks whether an entry exists in the set - /// @param value the value to search for. - /// @returns true if the set contains an entry with the given value. - template - bool Contains(const V& value) const { - const auto [found, _] = IndexOf(value); - return found; - } - - /// Pre-allocates memory so that the set can hold at least `capacity` entries. - /// @param capacity the new capacity of the set. - void Reserve(size_t capacity) { - // Calculate the number of slots required to hold `capacity` entries. - const size_t num_slots = std::max(NumSlots(capacity), kMinSlots); - if (slots_.Length() >= num_slots) { - // Already have enough slots. - return; - } - - // Move all the values out of the set and into a vector. - Vector values; - values.Reserve(count_); - for (auto& slot : slots_) { - if (slot.value.has_value()) { - values.Push(std::move(slot.value.value())); - } - } - - // Clear the set, grow the number of slots. - Clear(); - slots_.Resize(num_slots); - - // As the number of slots has grown, the slot indices will have changed from before, so - // re-add all the values back into the set. - for (auto& value : values) { - Add(std::move(value)); - } - } - - /// @returns the number of entries in the set. - size_t Count() const { return count_; } - - /// @returns true if the set contains no entries. 
- bool IsEmpty() const { return count_ == 0; } - - /// @returns a monotonic counter which is incremented whenever the set is mutated. - size_t Generation() const { return generation_; } - - /// @returns an iterator to the start of the set. - Iterator begin() const { return Iterator{slots_.begin(), slots_.end()}; } - - /// @returns an iterator to the end of the set. - Iterator end() const { return Iterator{slots_.end(), slots_.end()}; } - - /// A debug function for checking that the set is in good health. - /// Asserts if the set is corrupted. - void ValidateIntegrity() const { - size_t num_alive = 0; - for (size_t slot_idx = 0; slot_idx < slots_.Length(); slot_idx++) { - const auto& slot = slots_[slot_idx]; - if (slot.value.has_value()) { - num_alive++; - auto const [index, hash] = Hash(slot.value.value()); - TINT_ASSERT(Utils, hash == slot.hash); - TINT_ASSERT(Utils, slot_idx == Wrap(index + slot.distance)); - } - } - TINT_ASSERT(Utils, num_alive == count_); - } - - private: - /// The behaviour of Put() when an entry already exists with the given key. - enum class PutMode { - /// Do not replace existing entries with the new value. - kAdd, - /// Replace existing entries with the new value. - kReplace, - }; - /// The common implementation for Add() and Replace() - /// @param value the value to add to the set. - /// @returns A AddResult describing the result of the insertion - template - AddResult Put(V&& value) { - // Ensure the set can fit a new entry - if (ShouldRehash(count_ + 1)) { - Reserve((count_ + 1) * 2); - } - - const auto hash = Hash(value); - - AddResult result{}; - Scan(hash.scan_start, [&](size_t distance, size_t index) { - auto& slot = slots_[index]; - if (!slot.value.has_value()) { - // Found an empty slot. - // Place value directly into the slot, and we're done. 
- slot.value.emplace(std::forward(value)); - slot.hash = hash.value; - slot.distance = distance; - count_++; - generation_++; - result = AddResult{AddAction::kAdded, &slot.value.value()}; - return Action::kStop; - } - - // Slot has an entry - - if (slot.Equals(hash.value, value)) { - // Slot is equal to value. Replace or preserve? - if constexpr (MODE == PutMode::kReplace) { - slot.value = std::forward(value); - generation_++; - result = AddResult{AddAction::kReplaced, &slot.value.value()}; - } else { - result = AddResult{AddAction::kKeptExisting, &slot.value.value()}; - } - return Action::kStop; - } - - if (slot.distance < distance) { - // Existing slot has a closer distance than the value we're attempting to insert. - // Steal from the rich! - // Move the current slot to a temporary (evicted), and put the value into the slot. - Slot evicted{std::forward(value), hash.value, distance}; - std::swap(evicted, slot); - - // Find a new home for the evicted slot. - evicted.distance++; // We've already swapped at index. - InsertShuffle(Wrap(index + 1), std::move(evicted)); - - count_++; - generation_++; - result = AddResult{AddAction::kAdded, &slot.value.value()}; - - return Action::kStop; - } - return Action::kContinue; - }); - - return result; - } - - /// Return type of the Scan() callback. - enum class Action { - /// Continue scanning for a slot - kContinue, - /// Immediately stop scanning for a slot - kStop, - }; - - /// Sequentially visits each of the slots starting with the slot with the index `start`, calling - /// the callback function `f` for each slot until `f` returns Action::kStop. - /// `f` must be a function with the signature `Action(size_t distance, size_t index)`. - /// `f` must return Action::kStop within one whole cycle of the slots. 
- template - void Scan(size_t start, F&& f) const { - size_t index = start; - for (size_t distance = 0; distance < slots_.Length(); distance++) { - if (f(distance, index) == Action::kStop) { - return; - } - index = Wrap(index + 1); - } - tint::diag::List diags; - TINT_ICE(Utils, diags) << "Hashset::Scan() looped entire set without finding a slot"; - } - - /// HashResult is the return value of Hash() - struct HashResult { - /// The target (zero-distance) slot index for the value. - size_t scan_start; - /// The calculated hash of the value. - size_t value; - }; - - /// @returns a tuple holding the target slot index for the given value, and the hash of the - /// value, respectively. - template - HashResult Hash(const V& value) const { - size_t hash = HASH()(value); - size_t index = Wrap(hash); - return {index, hash}; - } - - /// Looks for the value in the set. - /// @returns a tuple holding a boolean representing whether the value was found in the set, and - /// if found, the index of the slot that holds the value. - template - std::tuple IndexOf(const V& value) const { - const auto hash = Hash(value); - - bool found = false; - size_t idx = 0; - - Scan(hash.scan_start, [&](size_t distance, size_t index) { - auto& slot = slots_[index]; - if (!slot.value.has_value()) { - return Action::kStop; - } - if (slot.Equals(hash.value, value)) { - found = true; - idx = index; - return Action::kStop; - } - if (slot.distance < distance) { - // If the slot distance is less than the current probe distance, then the slot must - // be for entry that has an index that comes after value. In this situation, we know - // that the set does not contain the value, as it would have been found before this - // slot. The "Lookup" section of https://programming.guide/robin-hood-hashing.html - // suggests that the condition should inverted, but this is wrong. 
- return Action::kStop; - } - return Action::kContinue; - }); - - return {found, idx}; - } - - /// Shuffles slots for an insertion that has been placed one slot before `start`. - /// @param evicted the slot content that was evicted for the insertion. - void InsertShuffle(size_t start, Slot evicted) { - Scan(start, [&](size_t, size_t index) { - auto& slot = slots_[index]; - - if (!slot.value.has_value()) { - // Empty slot found for evicted. - slot = std::move(evicted); - return Action::kStop; // We're done. - } - - if (slot.distance < evicted.distance) { - // Occupied slot has shorter distance to evicted. - // Swap slot and evicted. - std::swap(slot, evicted); - } - - // evicted moves further from the target slot... - evicted.distance++; - - return Action::kContinue; - }); - } - - /// @returns true if the set should grow the slot vector, and rehash the items. - bool ShouldRehash(size_t count) const { return NumSlots(count) > slots_.Length(); } - - /// Wrap returns the index value modulo the number of slots. - size_t Wrap(size_t index) const { return index % slots_.Length(); } - - /// The vector of slots. The vector length is equal to its capacity. - Vector slots_; - - /// The number of entries in the set. - size_t count_ = 0; - - /// Counter that's incremented with each modification to the set. 
- size_t generation_ = 0; }; } // namespace tint::utils diff --git a/src/tint/utils/hashset_test.cc b/src/tint/utils/hashset_test.cc index 6e8d1cc687..64f0da3518 100644 --- a/src/tint/utils/hashset_test.cc +++ b/src/tint/utils/hashset_test.cc @@ -74,18 +74,12 @@ TEST(Hashset, Generation) { EXPECT_EQ(set.Generation(), 1u); set.Add(1); EXPECT_EQ(set.Generation(), 1u); - set.Replace(1); - EXPECT_EQ(set.Generation(), 2u); set.Add(2); - EXPECT_EQ(set.Generation(), 3u); + EXPECT_EQ(set.Generation(), 2u); set.Remove(1); - EXPECT_EQ(set.Generation(), 4u); + EXPECT_EQ(set.Generation(), 3u); set.Clear(); - EXPECT_EQ(set.Generation(), 5u); - set.Find(2); - EXPECT_EQ(set.Generation(), 5u); - set.Get(2); - EXPECT_EQ(set.Generation(), 5u); + EXPECT_EQ(set.Generation(), 4u); } TEST(Hashset, Iterator) { @@ -103,53 +97,30 @@ TEST(Hashset, Soak) { Hashset set; for (size_t i = 0; i < 1000000; i++) { std::string value = std::to_string(rnd() & 0x100); - switch (rnd() % 8) { + switch (rnd() % 5) { case 0: { // Add auto expected = reference.emplace(value).second; ASSERT_EQ(set.Add(value), expected) << "i: " << i; ASSERT_TRUE(set.Contains(value)) << "i: " << i; break; } - case 1: { // Replace - reference.emplace(value); - set.Replace(value); - ASSERT_TRUE(set.Contains(value)) << "i: " << i; - break; - } - case 2: { // Remove + case 1: { // Remove auto expected = reference.erase(value) != 0; ASSERT_EQ(set.Remove(value), expected) << "i: " << i; ASSERT_FALSE(set.Contains(value)) << "i: " << i; break; } - case 3: { // Contains + case 2: { // Contains auto expected = reference.count(value) != 0; ASSERT_EQ(set.Contains(value), expected) << "i: " << i; break; } - case 4: { // Get - if (reference.count(value) != 0) { - ASSERT_TRUE(set.Get(value).has_value()) << "i: " << i; - ASSERT_EQ(set.Get(value), value) << "i: " << i; - } else { - ASSERT_FALSE(set.Get(value).has_value()) << "i: " << i; - } - break; - } - case 5: { // Find - if (reference.count(value) != 0) { - ASSERT_EQ(*set.Find(value), 
value) << "i: " << i; - } else { - ASSERT_EQ(set.Find(value), nullptr) << "i: " << i; - } - break; - } - case 6: { // Copy / Move + case 3: { // Copy / Move Hashset tmp(set); set = std::move(tmp); break; } - case 7: { // Clear + case 4: { // Clear reference.clear(); set.Clear(); break;