tint/utils: Make Hashmap::Find() safer to use

Don't return a raw pointer to the map entry's value, instead return a new Reference which re-looks up the entry if the map is mutated.

Change-Id: I031749785faeac98e2a129a776493cb0371a5cb9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110540
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton
2022-11-23 21:04:25 +00:00
committed by Dawn LUCI CQ
parent 597ad53029
commit 7c6e229a18
19 changed files with 187 additions and 113 deletions

View File

@@ -471,7 +471,7 @@ class DependencyScanner {
}
}
if (auto* global = globals_.Find(to); global && (*global)->node == resolved) {
if (auto global = globals_.Find(to); global && (*global)->node == resolved) {
if (dependency_edges_.Add(DependencyEdge{current_global_, *global},
DependencyInfo{from->source, action})) {
current_global_->deps.Push(*global);

View File

@@ -1128,8 +1128,8 @@ TEST_P(ResolverDependencyGraphResolvedSymbolTest, Test) {
if (expect_pass) {
// Check that the use resolves to the declaration
auto* resolved_symbol = graph.resolved_symbols.Find(use);
ASSERT_NE(resolved_symbol, nullptr);
auto resolved_symbol = graph.resolved_symbols.Find(use);
ASSERT_TRUE(resolved_symbol);
EXPECT_EQ(*resolved_symbol, decl)
<< "resolved: " << (*resolved_symbol ? (*resolved_symbol)->TypeInfo().name : "<null>")
<< "\n"
@@ -1179,8 +1179,8 @@ TEST_P(ResolverDependencyShadowTest, Test) {
helper.Build();
auto shadows = Build().shadows;
auto* shadow = shadows.Find(inner_var);
ASSERT_NE(shadow, nullptr);
auto shadow = shadows.Find(inner_var);
ASSERT_TRUE(shadow);
EXPECT_EQ(*shadow, outer);
}
@@ -1310,8 +1310,8 @@ TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) {
auto graph = Build();
for (auto use : symbol_uses) {
auto* resolved_symbol = graph.resolved_symbols.Find(use.use);
ASSERT_NE(resolved_symbol, nullptr) << use.where;
auto resolved_symbol = graph.resolved_symbols.Find(use.use);
ASSERT_TRUE(resolved_symbol) << use.where;
EXPECT_EQ(*resolved_symbol, use.decl) << use.where;
}
}

View File

@@ -2481,7 +2481,7 @@ sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) {
if (loop_block->FirstContinue()) {
// If our identifier is in loop_block->decls, make sure its index is
// less than first_continue
if (auto* decl = loop_block->Decls().Find(symbol)) {
if (auto decl = loop_block->Decls().Find(symbol)) {
if (decl->order >= loop_block->NumDeclsAtFirstContinue()) {
AddError("continue statement bypasses declaration of '" +
builder_->Symbols().NameFor(symbol) + "'",

View File

@@ -54,7 +54,7 @@ class SemHelper {
/// @param node the node to retrieve
template <typename SEM = sem::Node>
SEM* ResolvedSymbol(const ast::Node* node) const {
auto* resolved = dependencies_.resolved_symbols.Find(node);
auto resolved = dependencies_.resolved_symbols.Find(node);
return resolved ? const_cast<SEM*>(builder_->Sem().Get<SEM>(*resolved)) : nullptr;
}

View File

@@ -600,7 +600,7 @@ class UniformityGraph {
}
// Add an edge from the variable's loop input node to its value at this point.
auto** in_node = info.var_in_nodes.Find(var);
auto in_node = info.var_in_nodes.Find(var);
TINT_ASSERT(Resolver, in_node != nullptr);
auto* out_node = current_function_->variables.Get(var);
if (out_node != *in_node) {
@@ -1334,7 +1334,7 @@ class UniformityGraph {
[&](const sem::Function* func) {
// We must have already analyzed the user-defined function since we process
// functions in dependency order.
auto* info = functions_.Find(func->Declaration());
auto info = functions_.Find(func->Declaration());
TINT_ASSERT(Resolver, info != nullptr);
callsite_tag = info->callsite_tag;
function_tag = info->function_tag;
@@ -1466,7 +1466,7 @@ class UniformityGraph {
} else if (auto* user = target->As<sem::Function>()) {
// This is a call to a user-defined function, so inspect the functions called by that
// function and look for one whose node has an edge from the RequiredToBeUniform node.
auto* target_info = functions_.Find(user->Declaration());
auto target_info = functions_.Find(user->Declaration());
for (auto* call_node : target_info->required_to_be_uniform->edges) {
if (call_node->type == Node::kRegular) {
auto* child_call = call_node->ast->As<ast::CallExpression>();
@@ -1636,7 +1636,7 @@ class UniformityGraph {
// If this is a call to a user-defined function, add a note to show the reason that the
// parameter is required to be uniform.
if (auto* user = target->As<sem::Function>()) {
auto* next_function = functions_.Find(user->Declaration());
auto next_function = functions_.Find(user->Declaration());
Node* next_cause = next_function->parameters[cause->arg_index].init_value;
MakeError(*next_function, next_cause, true);
}

View File

@@ -719,7 +719,7 @@ bool Validator::AtomicVariable(
return false;
}
} else if (type->IsAnyOf<sem::Struct, sem::Array>()) {
if (auto* found = atomic_composite_info.Find(type)) {
if (auto found = atomic_composite_info.Find(type)) {
if (address_space != ast::AddressSpace::kStorage &&
address_space != ast::AddressSpace::kWorkgroup) {
AddError("atomic variables must have <storage> or <workgroup> address space",
@@ -798,7 +798,7 @@ bool Validator::Override(
for (auto* attr : decl->attributes) {
if (attr->Is<ast::IdAttribute>()) {
auto id = v->OverrideId();
if (auto* var = override_ids.Find(id); var && *var != v) {
if (auto var = override_ids.Find(id); var && *var != v) {
AddError("@id values must be unique", attr->source);
AddNote(
"a override with an ID of " + std::to_string(id.value) +