diff --git a/src/tint/clone_context.cc b/src/tint/clone_context.cc index 12252942ad..457522bf8b 100644 --- a/src/tint/clone_context.cc +++ b/src/tint/clone_context.cc @@ -27,9 +27,6 @@ Cloneable::Cloneable() = default; Cloneable::Cloneable(Cloneable&&) = default; Cloneable::~Cloneable() = default; -CloneContext::ListTransforms::ListTransforms() = default; -CloneContext::ListTransforms::~ListTransforms() = default; - CloneContext::CloneContext(ProgramBuilder* to, Program const* from, bool auto_clone_symbols) : dst(to), src(from) { if (auto_clone_symbols) { @@ -48,7 +45,7 @@ Symbol CloneContext::Clone(Symbol s) { if (!src) { return s; // In-place clone } - return utils::GetOrCreate(cloned_symbols_, s, [&]() -> Symbol { + return cloned_symbols_.GetOrCreate(s, [&]() -> Symbol { if (symbol_transform_) { return symbol_transform_(s); } @@ -76,9 +73,8 @@ const tint::Cloneable* CloneContext::CloneCloneable(const Cloneable* object) { } // Was Replace() called for this object? - auto it = replacements_.find(object); - if (it != replacements_.end()) { - return it->second(); + if (auto* fn = replacements_.Find(object)) { + return (*fn)(); } // Attempt to clone using the registered replacer functions. diff --git a/src/tint/clone_context.h b/src/tint/clone_context.h index e7e2d52c6b..3c5f6ec5d3 100644 --- a/src/tint/clone_context.h +++ b/src/tint/clone_context.h @@ -18,8 +18,6 @@ #include #include #include -#include -#include #include #include @@ -28,6 +26,7 @@ #include "src/tint/program_id.h" #include "src/tint/symbol.h" #include "src/tint/traits.h" +#include "src/tint/utils/hashmap.h" #include "src/tint/utils/vector.h" // Forward declarations @@ -201,56 +200,49 @@ class CloneContext { void Clone(utils::Vector& to, const utils::Vector& from) { to.Reserve(from.Length()); - auto list_transform_it = list_transforms_.find(&from); - if (list_transform_it != list_transforms_.end()) { - const auto& transforms = list_transform_it->second; - for (auto* o : transforms.insert_front_) { + auto transforms = list_transforms_.Find(&from); + + if (transforms) { + for (auto* o : transforms->insert_front_) { to.Push(CheckedCast(o)); } for (auto& el : from) { - auto insert_before_it = transforms.insert_before_.find(el); - if (insert_before_it != transforms.insert_before_.end()) { - for (auto insert : insert_before_it->second) { + if (auto* insert_before = transforms->insert_before_.Find(el)) { + for (auto insert : *insert_before) { to.Push(CheckedCast(insert)); } } - if (transforms.remove_.count(el) == 0) { + if (!transforms->remove_.Contains(el)) { to.Push(Clone(el)); } - auto insert_after_it = transforms.insert_after_.find(el); - if (insert_after_it != transforms.insert_after_.end()) { - for (auto insert : insert_after_it->second) { + if (auto* insert_after = transforms->insert_after_.Find(el)) { + for (auto insert : *insert_after) { to.Push(CheckedCast(insert)); } } } - for (auto* o : transforms.insert_back_) { + for (auto* o : transforms->insert_back_) { to.Push(CheckedCast(o)); } } else { for (auto& el : from) { to.Push(Clone(el)); - // Clone(el) may have inserted after - list_transform_it = list_transforms_.find(&from); - if (list_transform_it != list_transforms_.end()) { - const auto& transforms = list_transform_it->second; - - auto insert_after_it = transforms.insert_after_.find(el); - if (insert_after_it != transforms.insert_after_.end()) { - for (auto insert : insert_after_it->second) { + // Clone(el) may have updated the transformation list, adding an `insert_after` + // transform for `from`. + if (transforms) { + if (auto* insert_after = transforms->insert_after_.Find(el)) { + for (auto insert : *insert_after) { to.Push(CheckedCast(insert)); } } } } - // Clone(el)s may have inserted back - list_transform_it = list_transforms_.find(&from); - if (list_transform_it != list_transforms_.end()) { - const auto& transforms = list_transform_it->second; - - for (auto* o : transforms.insert_back_) { + // Clone(el) may have updated the transformation list, adding an `insert_back_` + // transform for `from`. + if (transforms) { + for (auto* o : transforms->insert_back_) { to.Push(CheckedCast(o)); } } @@ -358,7 +350,7 @@ class CloneContext { CloneContext& Replace(const WHAT* what, const WITH* with) { TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, what); TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, with); - replacements_[what] = [with]() -> const Cloneable* { return with; }; + replacements_.Add(what, [with]() -> const Cloneable* { return with; }); return *this; } @@ -378,7 +370,7 @@ class CloneContext { template > CloneContext& Replace(const WHAT* what, WITH&& with) { TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, what); - replacements_[what] = with; + replacements_.Add(what, with); return *this; } @@ -396,7 +388,7 @@ class CloneContext { return *this; } - list_transforms_[&vector].remove_.emplace(object); + list_transforms_.Edit(&vector).remove_.Add(object); return *this; } @@ -408,9 +400,7 @@ class CloneContext { template CloneContext& InsertFront(const utils::Vector& vector, OBJECT* object) { TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, object); - auto& transforms = list_transforms_[&vector]; - auto& list = transforms.insert_front_; - list.Push(object); + list_transforms_.Edit(&vector).insert_front_.Push(object); return *this; } @@ -422,9 +412,7 @@ class CloneContext { template CloneContext& InsertBack(const utils::Vector& vector, OBJECT* object) { TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, object); - auto& transforms = list_transforms_[&vector]; - auto& list = transforms.insert_back_; - list.Push(object); + list_transforms_.Edit(&vector).insert_back_.Push(object); return *this; } @@ -446,9 +434,7 @@ class CloneContext { return *this; } - auto& transforms = list_transforms_[&vector]; - auto& list = transforms.insert_before_[before]; - list.Push(object); + list_transforms_.Edit(&vector).insert_before_.GetOrZero(before).Push(object); return *this; } @@ -470,9 +456,7 @@ class CloneContext { return *this; } - auto& transforms = list_transforms_[&vector]; - auto& list = transforms.insert_after_[after]; - list.Push(object); + list_transforms_.Edit(&vector).insert_after_.GetOrZero(after).Push(object); return *this; } @@ -502,6 +486,31 @@ class CloneContext { std::function function; }; + /// A vector of const Cloneable* + using CloneableList = utils::Vector; + + /// Transformations to be applied to a list (vector) + struct ListTransforms { + /// A map of object in #src to omit when cloned into #dst. + utils::Hashset remove_; + + /// A list of objects in #dst to insert before any others when the vector is cloned. + CloneableList insert_front_; + + /// A list of objects in #dst to insert after all others when the vector is cloned. + CloneableList insert_back_; + + /// A map of object in #src to the list of cloned objects in #dst. + /// Clone(const utils::Vector& v) will use this to insert the map-value + /// list into the target vector before cloning and inserting the map-key. + utils::Hashmap insert_before_; + + /// A map of object in #src to the list of cloned objects in #dst. + /// Clone(const utils::Vector& v) will use this to insert the map-value + /// list into the target vector after cloning and inserting the map-key. + utils::Hashmap insert_after_; + }; + CloneContext(const CloneContext&) = delete; CloneContext& operator=(const CloneContext&) = delete; @@ -530,50 +539,78 @@ class CloneContext { /// @returns the diagnostic list of #dst diag::List& Diagnostics() const; - /// A vector of const Cloneable* - using CloneableList = utils::Vector; + /// VectorListTransforms is a map of utils::Vector pointer to transforms for that list + struct VectorListTransforms { + /// An accessor to the VectorListTransforms map. + /// Index caches the last map lookup, and will only re-search the map if the transform map + /// was modified since the last lookup. + struct Index { + /// @returns true if the map now holds a value for the index + operator bool() { + Update(); + return cached_; + } - /// Transformations to be applied to a list (vector) - struct ListTransforms { - /// Constructor - ListTransforms(); - /// Destructor - ~ListTransforms(); + /// @returns a pointer to the indexed map entry + const ListTransforms* operator->() { + Update(); + return cached_; + } - /// A map of object in #src to omit when cloned into #dst. - std::unordered_set remove_; + private: + friend VectorListTransforms; - /// A list of objects in #dst to insert before any others when the vector is - /// cloned. - CloneableList insert_front_; + Index(const void* list, + VectorListTransforms& vlt, + uint32_t generation, + const ListTransforms* cached) + : list_(list), vlt_(vlt), generation_(generation), cached_(cached) {} - /// A list of objects in #dst to insert befor after any others when the - /// vector is cloned. - CloneableList insert_back_; + void Update() { + if (vlt_.generation_ != generation_) { + cached_ = vlt_.map_.Find(list_); + generation_ = vlt_.generation_; + } + } - /// A map of object in #src to the list of cloned objects in #dst. - /// Clone(const utils::Vector& v) will use this to insert the map-value - /// list into the target vector before cloning and inserting the map-key. - std::unordered_map insert_before_; + const void* list_; + VectorListTransforms& vlt_; + uint32_t generation_; + const ListTransforms* cached_; + }; - /// A map of object in #src to the list of cloned objects in #dst. - /// Clone(const utils::Vector& v) will use this to insert the map-value - /// list into the target vector after cloning and inserting the map-key. - std::unordered_map insert_after_; + /// Edit returns a reference to the ListTransforms for the given vector pointer and + /// increments #list_transform_generation_ signalling that the list transforms have been + /// modified. + inline ListTransforms& Edit(const void* list) { + generation_++; + return map_.GetOrZero(list); + } + + /// @returns an Index to the transforms for the given list. + inline Index Find(const void* list) { + return Index{list, *this, generation_, map_.Find(list)}; + } + + private: + /// The map of vector pointer to ListTransforms + utils::Hashmap map_; + + /// A counter that's incremented each time list transforms are modified. + uint32_t generation_ = 0; }; - /// A map of object in #src to functions that create their replacement in - /// #dst - std::unordered_map> replacements_; + /// A map of object in #src to functions that create their replacement in #dst + utils::Hashmap, 8> replacements_; /// A map of symbol in #src to their cloned equivalent in #dst - std::unordered_map cloned_symbols_; + utils::Hashmap cloned_symbols_; /// Cloneable transform functions registered with ReplaceAll() utils::Vector transforms_; - /// Map of utils::Vector pointer to transforms for that list - std::unordered_map list_transforms_; + /// Transformations to apply to vectors + VectorListTransforms list_transforms_; /// Symbol transform registered with ReplaceAll() SymbolTransform symbol_transform_;