diff --git a/src/resolver/call_validation_test.cc b/src/resolver/call_validation_test.cc index e4a479b9ca..71cca02df0 100644 --- a/src/resolver/call_validation_test.cc +++ b/src/resolver/call_validation_test.cc @@ -265,6 +265,24 @@ TEST_F(ResolverCallValidationTest, CallVariable) { note: 'v' declared here)"); } +TEST_F(ResolverCallValidationTest, CallVariableShadowsFunction) { + // fn x() {} + // fn f() { + // var x : i32; + // x(); + // } + Func("x", {}, ty.void_(), {}); + Func("f", {}, ty.void_(), + { + Decl(Var(Source{{56, 78}}, "x", ty.i32())), + CallStmt(Call(Source{{12, 34}}, "x")), + }); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), R"(error: cannot call variable 'x' +56:78 note: 'x' declared here)"); +} + } // namespace } // namespace resolver } // namespace tint diff --git a/src/resolver/dependency_graph.cc b/src/resolver/dependency_graph.cc index d3ac5f4b6e..cbb574de8f 100644 --- a/src/resolver/dependency_graph.cc +++ b/src/resolver/dependency_graph.cc @@ -173,6 +173,16 @@ class DependencyScanner { /// Traverses the function, performing symbol resolution and determining /// global dependencies. void TraverseFunction(const ast::Function* func) { + // Perform symbol resolution on all the parameter types before registering + // the parameters themselves. This allows the case of declaring a parameter + // with the same identifier as its type. + for (auto* param : func->params) { + TraverseType(param->type); + } + // Resolve the return type + TraverseType(func->return_type); + + // Push the scope stack for the parameters and function body. scope_stack_.Push(); TINT_DEFER(scope_stack_.Pop()); @@ -181,12 +191,10 @@ class DependencyScanner { graph_.shadows.emplace(param, shadows); } Declare(param->symbol, param); - TraverseType(param->type); } if (func->body) { TraverseStatements(func->body->statements); } - TraverseType(func->return_type); } /// Traverses the statements, performing symbol resolution and determining @@ -295,38 +303,15 @@ class DependencyScanner { ast::TraverseExpressions( root, diagnostics_, [&](const ast::Expression* expr) { if (auto* ident = expr->As()) { - auto* node = scope_stack_.Get(ident->symbol); - if (node == nullptr) { - if (!IsBuiltin(ident->symbol)) { - UnknownSymbol(ident->symbol, ident->source, "identifier"); - } - return ast::TraverseAction::Descend; - } - auto global_it = globals_.find(ident->symbol); - if (global_it != globals_.end() && - node == global_it->second->node) { - AddGlobalDependency(ident, ident->symbol, "identifier", - "references"); - } else { - graph_.resolved_symbols.emplace(ident, node); - } + AddDependency(ident, ident->symbol, "identifier", "references"); } if (auto* call = expr->As()) { if (call->target.name) { - if (!IsBuiltin(call->target.name->symbol)) { - AddGlobalDependency(call->target.name, - call->target.name->symbol, "function", - "calls"); - graph_.resolved_symbols.emplace( - call, - utils::Lookup(graph_.resolved_symbols, call->target.name)); - } + AddDependency(call->target.name, call->target.name->symbol, + "function", "calls"); } if (call->target.type) { TraverseType(call->target.type); - graph_.resolved_symbols.emplace( - call, - utils::Lookup(graph_.resolved_symbols, call->target.type)); } } if (auto* cast = expr->As()) { @@ -360,7 +345,7 @@ class DependencyScanner { return; } if (auto* tn = ty->As()) { - AddGlobalDependency(tn, tn->name, "type", "references"); + AddDependency(tn, tn->name, "type", "references"); return; } if (auto* vec = ty->As()) { @@ -416,24 +401,31 @@ class DependencyScanner { UnhandledNode(diagnostics_, attr); } - /// Adds the dependency to the currently processed global - void AddGlobalDependency(const ast::Node* from, - Symbol to, - const char* use, - const char* action) { - auto global_it = globals_.find(to); - if (global_it != globals_.end()) { - auto* global = global_it->second; + /// Adds the dependency from `from` to `to`, erroring if `to` cannot be + /// resolved. + void AddDependency(const ast::Node* from, + Symbol to, + const char* use, + const char* action) { + auto* resolved = scope_stack_.Get(to); + if (!resolved) { + if (!IsBuiltin(to)) { + UnknownSymbol(to, from->source, use); + return; + } + } + + if (auto* global = utils::Lookup(globals_, to); + global && global->node == resolved) { if (dependency_edges_ .emplace(DependencyEdge{current_global_, global}, DependencyInfo{from->source, action}) .second) { current_global_->deps.emplace_back(global); } - graph_.resolved_symbols.emplace(from, global->node); - } else { - UnknownSymbol(to, from->source, use); } + + graph_.resolved_symbols.emplace(from, resolved); } /// @returns true if `name` is the name of a builtin function diff --git a/src/resolver/dependency_graph_test.cc b/src/resolver/dependency_graph_test.cc index 2329d86f04..102d99666e 100644 --- a/src/resolver/dependency_graph_test.cc +++ b/src/resolver/dependency_graph_test.cc @@ -993,9 +993,9 @@ TEST_F(ResolverDependencyGraphCyclicRefTest, Mixed_RecursiveDependencies) { EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), - R"(3:1 error: cyclic dependency found: 'S' -> 'A' -> 'S' -3:10 note: struct 'S' references alias 'A' here + R"(2:1 error: cyclic dependency found: 'A' -> 'S' -> 'A' 2:10 note: alias 'A' references struct 'S' here +3:10 note: struct 'S' references alias 'A' here 4:1 error: cyclic dependency found: 'Z' -> 'L' -> 'Z' 4:10 note: var 'Z' references let 'L' here 5:10 note: let 'L' references var 'Z' here)"); diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 98697559ee..f2f516abcb 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -268,8 +268,9 @@ sem::Type* Resolver::Type(const ast::Type* ty) { return builder_->create(); }, [&](Default) -> sem::Type* { + auto* resolved = ResolvedSymbol(ty); return Switch( - ResolvedSymbol(ty), // + resolved, // [&](sem::Type* type) { return type; }, [&](sem::Variable* var) { auto name = @@ -291,7 +292,10 @@ sem::Type* Resolver::Type(const ast::Type* ty) { }, [&](Default) { TINT_UNREACHABLE(Resolver, diagnostics_) - << "Unhandled ast::Type: " << ty->TypeInfo().name; + << "Unhandled resolved type '" + << (resolved ? resolved->TypeInfo().name : "") + << "' resolved from ast::Type '" << ty->TypeInfo().name + << "'"; return nullptr; }); }); diff --git a/src/utils/map.h b/src/utils/map.h index db1b153e24..a84c9ba637 100644 --- a/src/utils/map.h +++ b/src/utils/map.h @@ -29,9 +29,9 @@ namespace utils { /// @return the map item value, or `if_missing` if the map does not contain the /// given key template -V Lookup(std::unordered_map& map, +V Lookup(const std::unordered_map& map, const KV& key, - const KV& if_missing = {}) { + const V& if_missing = {}) { auto it = map.find(key); return it != map.end() ? it->second : if_missing; }