From 7c6e229a18cab18b1a069cb83f91809819e32828 Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Wed, 23 Nov 2022 21:04:25 +0000 Subject: [PATCH] tint/utils: Make Hashmap::Find() safer to use Don't return a raw pointer to the map entry's value, instead return a new Reference which re-looks up the entry if the map is mutated. Change-Id: I031749785faeac98e2a129a776493cb0371a5cb9 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110540 Reviewed-by: Antonio Maiorano Kokoro: Kokoro Commit-Queue: Ben Clayton --- src/tint/clone_context.cc | 2 +- src/tint/clone_context.h | 77 +++-------------- src/tint/reader/spirv/function.cc | 2 +- src/tint/reader/spirv/parser_impl.h | 2 +- src/tint/resolver/dependency_graph.cc | 2 +- src/tint/resolver/dependency_graph_test.cc | 12 +-- src/tint/resolver/resolver.cc | 2 +- src/tint/resolver/sem_helper.h | 2 +- src/tint/resolver/uniformity.cc | 8 +- src/tint/resolver/validator.cc | 4 +- src/tint/scope_stack.h | 4 +- .../transform/decompose_strided_matrix.cc | 4 +- src/tint/transform/simplify_pointers.cc | 2 +- src/tint/transform/std140.cc | 4 +- .../truncate_interstage_variables.cc | 2 +- src/tint/transform/unshadow.cc | 2 +- .../transform/utils/hoist_to_decl_before.cc | 22 ++--- src/tint/utils/hashmap.h | 82 +++++++++++++++++-- src/tint/utils/hashmap_test.cc | 65 +++++++++++++++ 19 files changed, 187 insertions(+), 113 deletions(-) diff --git a/src/tint/clone_context.cc b/src/tint/clone_context.cc index 457522bf8b..fe94a14488 100644 --- a/src/tint/clone_context.cc +++ b/src/tint/clone_context.cc @@ -73,7 +73,7 @@ const tint::Cloneable* CloneContext::CloneCloneable(const Cloneable* object) { } // Was Replace() called for this object? - if (auto* fn = replacements_.Find(object)) { + if (auto fn = replacements_.Find(object)) { return (*fn)(); } diff --git a/src/tint/clone_context.h b/src/tint/clone_context.h index 05f4868b92..f1db6174e9 100644 --- a/src/tint/clone_context.h +++ b/src/tint/clone_context.h @@ -208,7 +208,7 @@ class CloneContext { to.Push(CheckedCast(builder())); } for (auto& el : from) { - if (auto* insert_before = transforms->insert_before_.Find(el)) { + if (auto insert_before = transforms->insert_before_.Find(el)) { for (auto& builder : *insert_before) { to.Push(CheckedCast(builder())); } @@ -216,7 +216,7 @@ class CloneContext { if (!transforms->remove_.Contains(el)) { to.Push(Clone(el)); } - if (auto* insert_after = transforms->insert_after_.Find(el)) { + if (auto insert_after = transforms->insert_after_.Find(el)) { for (auto& builder : *insert_after) { to.Push(CheckedCast(builder())); } @@ -232,7 +232,7 @@ class CloneContext { // Clone(el) may have updated the transformation list, adding an `insert_after` // transform for `from`. if (transforms) { - if (auto* insert_after = transforms->insert_after_.Find(el)) { + if (auto insert_after = transforms->insert_after_.Find(el)) { for (auto& builder : *insert_after) { to.Push(CheckedCast(builder())); } @@ -389,7 +389,7 @@ class CloneContext { return *this; } - list_transforms_.Edit(&vector).remove_.Add(object); + list_transforms_.GetOrZero(&vector)->remove_.Add(object); return *this; } @@ -411,7 +411,7 @@ class CloneContext { /// @returns this CloneContext so calls can be chained template CloneContext& InsertFront(const utils::Vector& vector, BUILDER&& builder) { - list_transforms_.Edit(&vector).insert_front_.Push(std::forward(builder)); + list_transforms_.GetOrZero(&vector)->insert_front_.Push(std::forward(builder)); return *this; } @@ -434,7 +434,7 @@ class CloneContext { /// @returns this CloneContext so calls can be chained template CloneContext& InsertBack(const utils::Vector& vector, BUILDER&& builder) { - list_transforms_.Edit(&vector).insert_back_.Push(std::forward(builder)); + list_transforms_.GetOrZero(&vector)->insert_back_.Push(std::forward(builder)); return *this; } @@ -456,7 +456,7 @@ class CloneContext { return *this; } - list_transforms_.Edit(&vector).insert_before_.GetOrZero(before).Push( + list_transforms_.GetOrZero(&vector)->insert_before_.GetOrZero(before)->Push( [object] { return object; }); return *this; } @@ -475,7 +475,7 @@ class CloneContext { CloneContext& InsertBefore(const utils::Vector& vector, const BEFORE* before, BUILDER&& builder) { - list_transforms_.Edit(&vector).insert_before_.GetOrZero(before).Push( + list_transforms_.GetOrZero(&vector)->insert_before_.GetOrZero(before)->Push( std::forward(builder)); return *this; } @@ -498,7 +498,7 @@ class CloneContext { return *this; } - list_transforms_.Edit(&vector).insert_after_.GetOrZero(after).Push( + list_transforms_.GetOrZero(&vector)->insert_after_.GetOrZero(after)->Push( [object] { return object; }); return *this; } @@ -517,7 +517,7 @@ class CloneContext { CloneContext& InsertAfter(const utils::Vector& vector, const AFTER* after, BUILDER&& builder) { - list_transforms_.Edit(&vector).insert_after_.GetOrZero(after).Push( + list_transforms_.GetOrZero(&vector)->insert_after_.GetOrZero(after)->Push( std::forward(builder)); return *this; } @@ -601,61 +601,6 @@ class CloneContext { /// @returns the diagnostic list of #dst diag::List& Diagnostics() const; - /// VectorListTransforms is a map of utils::Vector pointer to transforms for that list - struct VectorListTransforms { - using Map = utils::Hashmap; - - /// An accessor to the VectorListTransforms map. - /// Index caches the last map lookup, and will only re-search the map if the transform map - /// was modified since the last lookup. - struct Index { - /// @returns true if the map now holds a value for the index - operator bool() { - Update(); - return cached_; - } - - /// @returns a pointer to the indexed map entry - const ListTransforms* operator->() { - Update(); - return cached_; - } - - private: - friend VectorListTransforms; - - Index(const void* list, Map* map) - : list_(list), - map_(map), - generation_(map->Generation()), - cached_(map_->Find(list)) {} - - void Update() { - if (map_->Generation() != generation_) { - cached_ = map_->Find(list_); - generation_ = map_->Generation(); - } - } - - const void* list_; - Map* map_; - uint64_t generation_; - const ListTransforms* cached_; - }; - - /// Edit returns a reference to the ListTransforms for the given vector pointer and - /// increments #list_transform_generation_ signalling that the list transforms have been - /// modified. - inline ListTransforms& Edit(const void* list) { return map_.GetOrZero(list); } - - /// @returns an Index to the transforms for the given list. - inline Index Find(const void* list) { return Index{list, &map_}; } - - private: - /// The map of vector pointer to ListTransforms - Map map_; - }; - /// A map of object in #src to functions that create their replacement in #dst utils::Hashmap, 8> replacements_; @@ -666,7 +611,7 @@ class CloneContext { utils::Vector transforms_; /// Transformations to apply to vectors - VectorListTransforms list_transforms_; + utils::Hashmap list_transforms_; /// Symbol transform registered with ReplaceAll() SymbolTransform symbol_transform_; diff --git a/src/tint/reader/spirv/function.cc b/src/tint/reader/spirv/function.cc index f97d169c35..accd692e34 100644 --- a/src/tint/reader/spirv/function.cc +++ b/src/tint/reader/spirv/function.cc @@ -3522,7 +3522,7 @@ bool FunctionEmitter::EmitStatementsInBasicBlock(const BlockInfo& block_info, const auto phi_id = assignment.phi_id; auto* const lhs_expr = builder_.Expr(namer_.Name(phi_id)); // If RHS value is actually a phi we just cpatured, then use it. - auto* const copy_sym = copied_phis.Find(assignment.value_id); + auto copy_sym = copied_phis.Find(assignment.value_id); auto* const rhs_expr = copy_sym ? builder_.Expr(*copy_sym) : MakeExpression(assignment.value_id).expr; AddStatement(builder_.Assign(lhs_expr, rhs_expr)); diff --git a/src/tint/reader/spirv/parser_impl.h b/src/tint/reader/spirv/parser_impl.h index 1780ca3b9b..9ff9033b28 100644 --- a/src/tint/reader/spirv/parser_impl.h +++ b/src/tint/reader/spirv/parser_impl.h @@ -666,7 +666,7 @@ class ParserImpl : Reader { /// @param id a SPIR-V ID /// @returns the AST variable or null. const ast::Var* GetModuleVariable(uint32_t id) { - auto* entry = module_variable_.Find(id); + auto entry = module_variable_.Find(id); return entry ? *entry : nullptr; } diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc index 1fda54fdee..796412b5d7 100644 --- a/src/tint/resolver/dependency_graph.cc +++ b/src/tint/resolver/dependency_graph.cc @@ -471,7 +471,7 @@ class DependencyScanner { } } - if (auto* global = globals_.Find(to); global && (*global)->node == resolved) { + if (auto global = globals_.Find(to); global && (*global)->node == resolved) { if (dependency_edges_.Add(DependencyEdge{current_global_, *global}, DependencyInfo{from->source, action})) { current_global_->deps.Push(*global); diff --git a/src/tint/resolver/dependency_graph_test.cc b/src/tint/resolver/dependency_graph_test.cc index 2cc4a3ab52..81e79ff0d0 100644 --- a/src/tint/resolver/dependency_graph_test.cc +++ b/src/tint/resolver/dependency_graph_test.cc @@ -1128,8 +1128,8 @@ TEST_P(ResolverDependencyGraphResolvedSymbolTest, Test) { if (expect_pass) { // Check that the use resolves to the declaration - auto* resolved_symbol = graph.resolved_symbols.Find(use); - ASSERT_NE(resolved_symbol, nullptr); + auto resolved_symbol = graph.resolved_symbols.Find(use); + ASSERT_TRUE(resolved_symbol); EXPECT_EQ(*resolved_symbol, decl) << "resolved: " << (*resolved_symbol ? (*resolved_symbol)->TypeInfo().name : "") << "\n" @@ -1179,8 +1179,8 @@ TEST_P(ResolverDependencyShadowTest, Test) { helper.Build(); auto shadows = Build().shadows; - auto* shadow = shadows.Find(inner_var); - ASSERT_NE(shadow, nullptr); + auto shadow = shadows.Find(inner_var); + ASSERT_TRUE(shadow); EXPECT_EQ(*shadow, outer); } @@ -1310,8 +1310,8 @@ TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) { auto graph = Build(); for (auto use : symbol_uses) { - auto* resolved_symbol = graph.resolved_symbols.Find(use.use); - ASSERT_NE(resolved_symbol, nullptr) << use.where; + auto resolved_symbol = graph.resolved_symbols.Find(use.use); + ASSERT_TRUE(resolved_symbol) << use.where; EXPECT_EQ(*resolved_symbol, use.decl) << use.where; } } diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 05f7675fa2..0fdd78c3e0 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -2481,7 +2481,7 @@ sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) { if (loop_block->FirstContinue()) { // If our identifier is in loop_block->decls, make sure its index is // less than first_continue - if (auto* decl = loop_block->Decls().Find(symbol)) { + if (auto decl = loop_block->Decls().Find(symbol)) { if (decl->order >= loop_block->NumDeclsAtFirstContinue()) { AddError("continue statement bypasses declaration of '" + builder_->Symbols().NameFor(symbol) + "'", diff --git a/src/tint/resolver/sem_helper.h b/src/tint/resolver/sem_helper.h index 12ef4a2ea4..2e557d906d 100644 --- a/src/tint/resolver/sem_helper.h +++ b/src/tint/resolver/sem_helper.h @@ -54,7 +54,7 @@ class SemHelper { /// @param node the node to retrieve template SEM* ResolvedSymbol(const ast::Node* node) const { - auto* resolved = dependencies_.resolved_symbols.Find(node); + auto resolved = dependencies_.resolved_symbols.Find(node); return resolved ? const_cast(builder_->Sem().Get(*resolved)) : nullptr; } diff --git a/src/tint/resolver/uniformity.cc b/src/tint/resolver/uniformity.cc index afdbdfb1b6..ffedaf6240 100644 --- a/src/tint/resolver/uniformity.cc +++ b/src/tint/resolver/uniformity.cc @@ -600,7 +600,7 @@ class UniformityGraph { } // Add an edge from the variable's loop input node to its value at this point. - auto** in_node = info.var_in_nodes.Find(var); + auto in_node = info.var_in_nodes.Find(var); TINT_ASSERT(Resolver, in_node != nullptr); auto* out_node = current_function_->variables.Get(var); if (out_node != *in_node) { @@ -1334,7 +1334,7 @@ class UniformityGraph { [&](const sem::Function* func) { // We must have already analyzed the user-defined function since we process // functions in dependency order. - auto* info = functions_.Find(func->Declaration()); + auto info = functions_.Find(func->Declaration()); TINT_ASSERT(Resolver, info != nullptr); callsite_tag = info->callsite_tag; function_tag = info->function_tag; @@ -1466,7 +1466,7 @@ class UniformityGraph { } else if (auto* user = target->As()) { // This is a call to a user-defined function, so inspect the functions called by that // function and look for one whose node has an edge from the RequiredToBeUniform node. - auto* target_info = functions_.Find(user->Declaration()); + auto target_info = functions_.Find(user->Declaration()); for (auto* call_node : target_info->required_to_be_uniform->edges) { if (call_node->type == Node::kRegular) { auto* child_call = call_node->ast->As(); @@ -1636,7 +1636,7 @@ class UniformityGraph { // If this is a call to a user-defined function, add a note to show the reason that the // parameter is required to be uniform. if (auto* user = target->As()) { - auto* next_function = functions_.Find(user->Declaration()); + auto next_function = functions_.Find(user->Declaration()); Node* next_cause = next_function->parameters[cause->arg_index].init_value; MakeError(*next_function, next_cause, true); } diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc index db93bbe6e7..2fc8f21085 100644 --- a/src/tint/resolver/validator.cc +++ b/src/tint/resolver/validator.cc @@ -719,7 +719,7 @@ bool Validator::AtomicVariable( return false; } } else if (type->IsAnyOf()) { - if (auto* found = atomic_composite_info.Find(type)) { + if (auto found = atomic_composite_info.Find(type)) { if (address_space != ast::AddressSpace::kStorage && address_space != ast::AddressSpace::kWorkgroup) { AddError("atomic variables must have or address space", @@ -798,7 +798,7 @@ bool Validator::Override( for (auto* attr : decl->attributes) { if (attr->Is()) { auto id = v->OverrideId(); - if (auto* var = override_ids.Find(id); var && *var != v) { + if (auto var = override_ids.Find(id); var && *var != v) { AddError("@id values must be unique", attr->source); AddNote( "a override with an ID of " + std::to_string(id.value) + diff --git a/src/tint/scope_stack.h b/src/tint/scope_stack.h index a2da4dde6f..75c50b48c6 100644 --- a/src/tint/scope_stack.h +++ b/src/tint/scope_stack.h @@ -44,7 +44,7 @@ class ScopeStack { /// stack, otherwise the zero initializer for type T. V Set(const K& key, V val) { auto& back = stack_.Back(); - if (auto* el = back.Find(key)) { + if (auto el = back.Find(key)) { std::swap(val, *el); return val; } @@ -57,7 +57,7 @@ class ScopeStack { /// @returns the value, or the zero initializer if the value was not found V Get(const K& key) const { for (auto iter = stack_.rbegin(); iter != stack_.rend(); ++iter) { - if (auto* val = iter->Find(key)) { + if (auto val = iter->Find(key)) { return *val; } } diff --git a/src/tint/transform/decompose_strided_matrix.cc b/src/tint/transform/decompose_strided_matrix.cc index 5494ca246f..b7fd7c2d6e 100644 --- a/src/tint/transform/decompose_strided_matrix.cc +++ b/src/tint/transform/decompose_strided_matrix.cc @@ -129,7 +129,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src, std::unordered_map mat_to_arr; ctx.ReplaceAll([&](const ast::AssignmentStatement* stmt) -> const ast::Statement* { if (auto* access = src->Sem().Get(stmt->lhs)) { - if (auto* info = decomposed.Find(access->Member()->Declaration())) { + if (auto info = decomposed.Find(access->Member()->Declaration())) { auto fn = utils::GetOrCreate(mat_to_arr, *info, [&] { auto name = b.Symbols().New("mat" + std::to_string(info->matrix->columns()) + "x" + @@ -168,7 +168,7 @@ Transform::ApplyResult DecomposeStridedMatrix::Apply(const Program* src, std::unordered_map arr_to_mat; ctx.ReplaceAll([&](const ast::MemberAccessorExpression* expr) -> const ast::Expression* { if (auto* access = src->Sem().Get(expr)) { - if (auto* info = decomposed.Find(access->Member()->Declaration())) { + if (auto info = decomposed.Find(access->Member()->Declaration())) { auto fn = utils::GetOrCreate(arr_to_mat, *info, [&] { auto name = b.Symbols().New("arr_to_mat" + std::to_string(info->matrix->columns()) + diff --git a/src/tint/transform/simplify_pointers.cc b/src/tint/transform/simplify_pointers.cc index b2b99ed470..a0855b7089 100644 --- a/src/tint/transform/simplify_pointers.cc +++ b/src/tint/transform/simplify_pointers.cc @@ -140,7 +140,7 @@ struct SimplifyPointers::State { // variable identifier. ctx.ReplaceAll([&](const ast::Expression* expr) -> const ast::Expression* { // Look to see if we need to swap this Expression with a saved variable. - if (auto* saved_var = saved_vars.Find(expr)) { + if (auto saved_var = saved_vars.Find(expr)) { return ctx.dst->Expr(*saved_var); } diff --git a/src/tint/transform/std140.cc b/src/tint/transform/std140.cc index 2116b0fcf4..8b566fefcd 100644 --- a/src/tint/transform/std140.cc +++ b/src/tint/transform/std140.cc @@ -401,7 +401,7 @@ struct Std140::State { return Switch( ty, // [&](const sem::Struct* str) -> const ast::Type* { - if (auto* std140 = std140_structs.Find(str)) { + if (auto std140 = std140_structs.Find(str)) { return b.create(*std140); } return nullptr; @@ -695,7 +695,7 @@ struct Std140::State { // call, or by reassembling a std140 matrix from column vector members. utils::Vector args; for (auto* member : str->Members()) { - if (auto* col_members = std140_mat_members.Find(member)) { + if (auto col_members = std140_mat_members.Find(member)) { // std140 decomposed matrix. Reassemble. auto* mat_ty = CreateASTTypeFor(ctx, member->Type()); auto mat_args = diff --git a/src/tint/transform/truncate_interstage_variables.cc b/src/tint/transform/truncate_interstage_variables.cc index a5e7256566..30237bc85d 100644 --- a/src/tint/transform/truncate_interstage_variables.cc +++ b/src/tint/transform/truncate_interstage_variables.cc @@ -161,7 +161,7 @@ Transform::ApplyResult TruncateInterstageVariables::Apply(const Program* src, ctx.ReplaceAll( [&](const ast::ReturnStatement* return_statement) -> const ast::ReturnStatement* { auto* return_sem = sem.Get(return_statement); - if (auto* mapping_fn_sym = + if (auto mapping_fn_sym = entry_point_functions_to_truncate_functions.Find(return_sem->Function())) { return b.Return(return_statement->source, b.Call(*mapping_fn_sym, ctx.Clone(return_statement->value))); diff --git a/src/tint/transform/unshadow.cc b/src/tint/transform/unshadow.cc index 93ce595a13..8d2b876381 100644 --- a/src/tint/transform/unshadow.cc +++ b/src/tint/transform/unshadow.cc @@ -97,7 +97,7 @@ struct Unshadow::State { ctx.ReplaceAll( [&](const ast::IdentifierExpression* ident) -> const tint::ast::IdentifierExpression* { if (auto* user = sem.Get(ident)) { - if (auto* renamed = renamed_to.Find(user->Variable())) { + if (auto renamed = renamed_to.Find(user->Variable())) { return b.Expr(*renamed); } } diff --git a/src/tint/transform/utils/hoist_to_decl_before.cc b/src/tint/transform/utils/hoist_to_decl_before.cc index d4db655538..ede1986c94 100644 --- a/src/tint/transform/utils/hoist_to_decl_before.cc +++ b/src/tint/transform/utils/hoist_to_decl_before.cc @@ -135,7 +135,7 @@ struct HoistToDeclBefore::State { /// automatically called. /// @warning the returned reference is invalid if this is called a second time, or the /// #for_loops map is mutated. - LoopInfo& ForLoop(const sem::ForLoopStatement* for_loop) { + auto ForLoop(const sem::ForLoopStatement* for_loop) { if (for_loops.IsEmpty()) { RegisterForLoopTransform(); } @@ -147,7 +147,7 @@ struct HoistToDeclBefore::State { /// automatically called. /// @warning the returned reference is invalid if this is called a second time, or the /// #for_loops map is mutated. - LoopInfo& WhileLoop(const sem::WhileStatement* while_loop) { + auto WhileLoop(const sem::WhileStatement* while_loop) { if (while_loops.IsEmpty()) { RegisterWhileLoopTransform(); } @@ -159,7 +159,7 @@ struct HoistToDeclBefore::State { /// automatically called. /// @warning the returned reference is invalid if this is called a second time, or the /// #else_ifs map is mutated. - ElseIfInfo& ElseIf(const ast::IfStatement* else_if) { + auto ElseIf(const ast::IfStatement* else_if) { if (else_ifs.IsEmpty()) { RegisterElseIfTransform(); } @@ -172,7 +172,7 @@ struct HoistToDeclBefore::State { auto& sem = ctx.src->Sem(); if (auto* fl = sem.Get(stmt)) { - if (auto* info = for_loops.Find(fl)) { + if (auto info = for_loops.Find(fl)) { auto* for_loop = fl->Declaration(); // For-loop needs to be decomposed to a loop. // Build the loop body's statements. @@ -222,7 +222,7 @@ struct HoistToDeclBefore::State { auto& sem = ctx.src->Sem(); if (auto* w = sem.Get(stmt)) { - if (auto* info = while_loops.Find(w)) { + if (auto info = while_loops.Find(w)) { auto* while_loop = w->Declaration(); // While needs to be decomposed to a loop. // Build the loop body's statements. @@ -259,7 +259,7 @@ struct HoistToDeclBefore::State { void RegisterElseIfTransform() const { // Decompose 'else-if' statements into 'else { if }' blocks. ctx.ReplaceAll([&](const ast::IfStatement* stmt) -> const ast::Statement* { - if (auto* info = else_ifs.Find(stmt)) { + if (auto info = else_ifs.Find(stmt)) { // Build the else block's body statements, starting with let decls for the // conditional expression. auto body_stmts = Build(info->cond_decls); @@ -291,10 +291,10 @@ struct HoistToDeclBefore::State { if (else_if && else_if->Parent()->Is()) { // Insertion point is an 'else if' condition. // Need to convert 'else if' to 'else { if }'. - auto& else_if_info = ElseIf(else_if->Declaration()); + auto else_if_info = ElseIf(else_if->Declaration()); // Index the map to convert this else if, even if `stmt` is nullptr. - auto& decls = else_if_info.cond_decls; + auto& decls = else_if_info->cond_decls; if constexpr (!std::is_same_v) { decls.Push(std::forward(builder)); } @@ -306,7 +306,7 @@ struct HoistToDeclBefore::State { // For-loop needs to be decomposed to a loop. // Index the map to convert this for-loop, even if `stmt` is nullptr. - auto& decls = ForLoop(fl).cond_decls; + auto& decls = ForLoop(fl)->cond_decls; if constexpr (!std::is_same_v) { decls.Push(std::forward(builder)); } @@ -318,7 +318,7 @@ struct HoistToDeclBefore::State { // While needs to be decomposed to a loop. // Index the map to convert this while, even if `stmt` is nullptr. - auto& decls = WhileLoop(w).cond_decls; + auto& decls = WhileLoop(w)->cond_decls; if constexpr (!std::is_same_v) { decls.Push(std::forward(builder)); } @@ -354,7 +354,7 @@ struct HoistToDeclBefore::State { // For-loop needs to be decomposed to a loop. // Index the map to convert this for-loop, even if `stmt` is nullptr. - auto& decls = ForLoop(fl).cont_decls; + auto& decls = ForLoop(fl)->cont_decls; if constexpr (!std::is_same_v) { decls.Push(std::forward(builder)); } diff --git a/src/tint/utils/hashmap.h b/src/tint/utils/hashmap.h index 1040e0cbbb..19a5abebc8 100644 --- a/src/tint/utils/hashmap.h +++ b/src/tint/utils/hashmap.h @@ -47,6 +47,67 @@ class Hashmap : public HashmapBase { /// Result of Add() using AddResult = typename Base::PutResult; + /// Reference is returned by Hashmap::Find(), and performs dynamic Hashmap lookups. + /// The value returned by the Reference reflects the current state of the Hashmap, and so the + /// referenced value may change, or transition between valid or invalid based on the current + /// state of the Hashmap. + template + class ReferenceT { + /// `const Value` if IS_CONST, or `Value` if !IS_CONST + using T = std::conditional_t; + + /// `const Hashmap` if IS_CONST, or `Hashmap` if !IS_CONST + using Map = std::conditional_t; + + public: + /// @returns true if the reference is valid. + operator bool() const { return Get() != nullptr; } + + /// @returns the pointer to the Value, or nullptr if the reference is invalid. + operator T*() const { return Get(); } + + /// @returns the pointer to the Value + /// @warning if the Hashmap does not contain a value for the reference, then this will + /// trigger a TINT_ASSERT, or invalid pointer dereference. + T* operator->() const { + auto* hashmap_reference_lookup = Get(); + TINT_ASSERT(Utils, hashmap_reference_lookup != nullptr); + return hashmap_reference_lookup; + } + + /// @returns the pointer to the Value, or nullptr if the reference is invalid. + T* Get() const { + auto generation = map_.Generation(); + if (generation_ != generation) { + cached_ = map_.Lookup(key_); + generation_ = generation; + } + return cached_; + } + + private: + friend Hashmap; + + /// Constructor + ReferenceT(Map& map, const Key& key) + : map_(map), key_(key), cached_(nullptr), generation_(map.Generation() - 1) {} + + /// Constructor + ReferenceT(Map& map, const Key& key, T* value) + : map_(map), key_(key), cached_(value), generation_(map.Generation()) {} + + Map& map_; + const Key key_; + mutable T* cached_ = nullptr; + mutable size_t generation_ = 0; + }; + + /// A mutable reference returned by Find() + using Reference = ReferenceT; + + /// An immutable reference returned by Find() + using ConstReference = ReferenceT; + /// Adds a value to the map, if the map does not already contain an entry with the key @p key. /// @param key the entry key. /// @param value the value of the entry to add to the map. @@ -108,25 +169,28 @@ class Hashmap : public HashmapBase { /// @param key the entry's key value to search for. /// @returns the value of the entry. template - Value& GetOrZero(K&& key) { + Reference GetOrZero(K&& key) { auto res = Add(std::forward(key), Value{}); - return *res.value; + return Reference(*this, key, res.value); } /// @param key the key to search for. - /// @returns a pointer to the entry that is equal to the given value, or nullptr if the map does - /// not contain the given value. - const Value* Find(const Key& key) const { + /// @returns a reference to the entry that is equal to the given value. + Reference Find(const Key& key) { return Reference(*this, key); } + + /// @param key the key to search for. + /// @returns a reference to the entry that is equal to the given value. + ConstReference Find(const Key& key) const { return ConstReference(*this, key); } + + private: + Value* Lookup(const Key& key) { if (auto [found, index] = this->IndexOf(key); found) { return &this->slots_[index].entry->value; } return nullptr; } - /// @param key the key to search for. - /// @returns a pointer to the entry that is equal to the given value, or nullptr if the map does - /// not contain the given value. - Value* Find(const Key& key) { + const Value* Lookup(const Key& key) const { if (auto [found, index] = this->IndexOf(key); found) { return &this->slots_[index].entry->value; } diff --git a/src/tint/utils/hashmap_test.cc b/src/tint/utils/hashmap_test.cc index 77421cf76c..34a93e6c30 100644 --- a/src/tint/utils/hashmap_test.cc +++ b/src/tint/utils/hashmap_test.cc @@ -90,6 +90,71 @@ TEST(Hashmap, Generation) { EXPECT_EQ(map.Generation(), 5u); } +TEST(Hashmap, Index) { + Hashmap map; + auto zero = map.Find(0); + EXPECT_FALSE(zero); + + map.Add(3, "three"); + auto three = map.Find(3); + map.Add(2, "two"); + auto two = map.Find(2); + map.Add(4, "four"); + auto four = map.Find(4); + map.Add(8, "eight"); + auto eight = map.Find(8); + + EXPECT_FALSE(zero); + ASSERT_TRUE(three); + ASSERT_TRUE(two); + ASSERT_TRUE(four); + ASSERT_TRUE(eight); + + EXPECT_EQ(*three, "three"); + EXPECT_EQ(*two, "two"); + EXPECT_EQ(*four, "four"); + EXPECT_EQ(*eight, "eight"); + + map.Add(0, "zero"); // Note: Find called before Add() is okay! + + map.Add(5, "five"); + auto five = map.Find(5); + map.Add(6, "six"); + auto six = map.Find(6); + map.Add(1, "one"); + auto one = map.Find(1); + map.Add(7, "seven"); + auto seven = map.Find(7); + + ASSERT_TRUE(zero); + ASSERT_TRUE(three); + ASSERT_TRUE(two); + ASSERT_TRUE(four); + ASSERT_TRUE(eight); + ASSERT_TRUE(five); + ASSERT_TRUE(six); + ASSERT_TRUE(one); + ASSERT_TRUE(seven); + + EXPECT_EQ(*zero, "zero"); + EXPECT_EQ(*three, "three"); + EXPECT_EQ(*two, "two"); + EXPECT_EQ(*four, "four"); + EXPECT_EQ(*eight, "eight"); + EXPECT_EQ(*five, "five"); + EXPECT_EQ(*six, "six"); + EXPECT_EQ(*one, "one"); + EXPECT_EQ(*seven, "seven"); + + map.Remove(2); + map.Remove(8); + map.Remove(1); + + EXPECT_FALSE(two); + EXPECT_FALSE(eight); + EXPECT_FALSE(one); +} + TEST(Hashmap, Iterator) { using Map = Hashmap; using Entry = typename Map::Entry;