From cf0e9301b24d47800a5522035ca57fec19f37eac Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Wed, 8 Feb 2023 15:18:43 +0000 Subject: [PATCH] tint: Improve the output of DependencyGraph Add ResolvedIdentifier to hold the resolved AST node, sem::BuiltinType or type::Builtin. Reduces duplicate builtin symbol lookups in Resolver. Bug: tint:1810 Change-Id: Idde2b5f6fa22804b5019adc14c717bebd8342475 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/119041 Kokoro: Kokoro Commit-Queue: Ben Clayton Reviewed-by: Dan Sinclair --- src/tint/program_builder.h | 6 +- src/tint/resolver/builtin_validation_test.cc | 4 +- src/tint/resolver/dependency_graph.cc | 55 ++- src/tint/resolver/dependency_graph.h | 78 +++- src/tint/resolver/dependency_graph_test.cc | 306 ++++++++++++--- src/tint/resolver/resolver.cc | 390 ++++++++++--------- src/tint/resolver/resolver.h | 7 +- src/tint/resolver/sem_helper.cc | 3 +- src/tint/resolver/sem_helper.h | 16 +- src/tint/resolver/validator.cc | 8 +- src/tint/sem/builtin_type.h | 118 ++++++ src/tint/sem/builtin_type.h.tmpl | 7 + 12 files changed, 704 insertions(+), 294 deletions(-) diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h index 396a06029d..90507833e9 100644 --- a/src/tint/program_builder.h +++ b/src/tint/program_builder.h @@ -879,8 +879,7 @@ class ProgramBuilder { /// @returns the type name template > const ast::TypeName* operator()(NAME&& name, ARGS&&... args) const { - return builder->create( - builder->Ident(std::forward(name), std::forward(args)...)); + return (*this)(builder->source_, std::forward(name), std::forward(args)...); } /// Creates a type name @@ -891,7 +890,8 @@ class ProgramBuilder { template const ast::TypeName* operator()(const Source& source, NAME&& name, ARGS&&... args) const { return builder->create( - source, builder->Ident(std::forward(name), std::forward(args)...)); + source, + builder->Ident(source, std::forward(name), std::forward(args)...)); } /// Creates an alias type diff --git a/src/tint/resolver/builtin_validation_test.cc b/src/tint/resolver/builtin_validation_test.cc index 4f5c94ff40..51741d0d3b 100644 --- a/src/tint/resolver/builtin_validation_test.cc +++ b/src/tint/resolver/builtin_validation_test.cc @@ -210,7 +210,7 @@ TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsAliasUsedAsVariable) { WrapInFunction(Decl(Var("v", Expr(Source{{56, 78}}, "mix")))); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), R"(56:78 error: missing '(' for builtin call)"); + EXPECT_EQ(r()->error(), R"(56:78 error: missing '(' for type initializer or cast)"); } TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsAliasUsedAsType) { @@ -242,7 +242,7 @@ TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsStructUsedAsVariable) { WrapInFunction(Decl(Var("v", Expr(Source{{12, 34}}, "mix")))); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), R"(12:34 error: missing '(' for builtin call)"); + EXPECT_EQ(r()->error(), R"(12:34 error: missing '(' for type initializer or cast)"); } TEST_F(ResolverBuiltinValidationTest, BuiltinRedeclaredAsStructUsedAsType) { diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc index 8e99d8ff01..7326f818dc 100644 --- a/src/tint/resolver/dependency_graph.cc +++ b/src/tint/resolver/dependency_graph.cc @@ -354,7 +354,8 @@ class DependencyScanner { Switch( expr, [&](const ast::IdentifierExpression* ident) { - AddDependency(ident, ident->identifier->symbol, "identifier", "references"); + AddDependency(ident->identifier, ident->identifier->symbol, "identifier", + "references"); }, [&](const ast::CallExpression* call) { if (call->target.name) { @@ -392,7 +393,7 @@ class DependencyScanner { TraverseType(ptr->type); }, [&](const ast::TypeName* tn) { // - AddDependency(tn, tn->name->symbol, "type", "references"); + AddDependency(tn->name, tn->name->symbol, "type", "references"); }, [&](const ast::Vector* vec) { // TraverseType(vec->type); @@ -468,15 +469,25 @@ class DependencyScanner { UnhandledNode(diagnostics_, attr); } - /// Adds the dependency from `from` to `to`, erroring if `to` cannot be - /// resolved. - void AddDependency(const ast::Node* from, Symbol to, const char* use, const char* action) { + /// Adds the dependency from @p from to @p to, erroring if @p to cannot be resolved. + void AddDependency(const ast::Identifier* from, + Symbol to, + const char* use, + const char* action) { auto* resolved = scope_stack_.Get(to); if (!resolved) { - if (!IsBuiltin(to)) { - UnknownSymbol(to, from->source, use); + auto s = symbols_.NameFor(to); + if (auto builtin_fn = sem::ParseBuiltinType(s); builtin_fn != sem::BuiltinType::kNone) { + graph_.resolved_identifiers.Add(from, ResolvedIdentifier(builtin_fn)); return; } + if (auto builtin_ty = type::ParseBuiltin(s); builtin_ty != type::Builtin::kUndefined) { + graph_.resolved_identifiers.Add(from, ResolvedIdentifier(builtin_ty)); + return; + } + + UnknownSymbol(to, from->source, use); + return; } if (auto global = globals_.Find(to); global && (*global)->node == resolved) { @@ -486,17 +497,7 @@ class DependencyScanner { } } - graph_.resolved_symbols.Add(from, resolved); - } - - /// @returns true if `name` is the name of a builtin function, or builtin type alias - bool IsBuiltin(Symbol name) const { - auto s = symbols_.NameFor(name); - if (sem::ParseBuiltinType(s) != sem::BuiltinType::kNone || - type::ParseBuiltin(s) != type::Builtin::kUndefined) { - return true; - } - return false; + graph_.resolved_identifiers.Add(from, ResolvedIdentifier(resolved)); } /// Appends an error to the diagnostics that the given symbol cannot be @@ -529,7 +530,7 @@ struct DependencyAnalysis { /// @returns true if analysis found no errors, otherwise false. bool Run(const ast::Module& module) { // Reserve container memory - graph_.resolved_symbols.Reserve(module.GlobalDeclarations().Length()); + graph_.resolved_identifiers.Reserve(module.GlobalDeclarations().Length()); sorted_.Reserve(module.GlobalDeclarations().Length()); // Collect all the named globals from the AST module @@ -821,4 +822,20 @@ bool DependencyGraph::Build(const ast::Module& module, return da.Run(module); } +std::ostream& operator<<(std::ostream& out, const ResolvedIdentifier& ri) { + if (!ri) { + return out << ""; + } + if (auto* node = ri.Node()) { + return out << "'" << node->TypeInfo().name << "' at " << node->source; + } + if (auto builtin_fn = ri.BuiltinFunction(); builtin_fn != sem::BuiltinType::kNone) { + return out << "builtin function '" << builtin_fn << "'"; + } + if (auto builtin_ty = ri.BuiltinType(); builtin_ty != type::Builtin::kUndefined) { + return out << "builtin function '" << builtin_ty << "'"; + } + return out << ""; +} + } // namespace tint::resolver diff --git a/src/tint/resolver/dependency_graph.h b/src/tint/resolver/dependency_graph.h index bc849a0e64..5bbc7a466a 100644 --- a/src/tint/resolver/dependency_graph.h +++ b/src/tint/resolver/dependency_graph.h @@ -19,10 +19,83 @@ #include "src/tint/ast/module.h" #include "src/tint/diagnostic/diagnostic.h" +#include "src/tint/sem/builtin_type.h" +#include "src/tint/type/builtin.h" #include "src/tint/utils/hashmap.h" namespace tint::resolver { +/// ResolvedIdentifier holds the resolution of an ast::Identifier. +/// Can hold one of: +/// - const ast::TypeDecl* (as const ast::Node*) +/// - const ast::Variable* (as const ast::Node*) +/// - const ast::Function* (as const ast::Node*) +/// - sem::BuiltinType +/// - type::Builtin +class ResolvedIdentifier { + public: + ResolvedIdentifier() = default; + + /// Constructor + /// @param value the resolved identifier value + template + ResolvedIdentifier(T value) : value_(value) {} // NOLINT(runtime/explicit) + + /// @return true if the ResolvedIdentifier holds a value (successfully resolved) + operator bool() const { return !std::holds_alternative(value_); } + + /// @return the node pointer if the ResolvedIdentifier holds an AST node, otherwise nullptr + const ast::Node* Node() const { + if (auto n = std::get_if(&value_)) { + return *n; + } + return nullptr; + } + + /// @return the builtin function if the ResolvedIdentifier holds sem::BuiltinType, otherwise + /// sem::BuiltinType::kNone + sem::BuiltinType BuiltinFunction() const { + if (auto n = std::get_if(&value_)) { + return *n; + } + return sem::BuiltinType::kNone; + } + + /// @return the builtin type if the ResolvedIdentifier holds type::Builtin, otherwise + /// type::Builtin::kUndefined + type::Builtin BuiltinType() const { + if (auto n = std::get_if(&value_)) { + return *n; + } + return type::Builtin::kUndefined; + } + + /// @param value the value to compare the ResolvedIdentifier to + /// @return true if the ResolvedIdentifier is equal to @p value + template + bool operator==(const T& value) const { + if (auto n = std::get_if(&value_)) { + return *n == value; + } + return false; + } + + /// @param other the other value to compare to this + /// @return true if this ResolvedIdentifier and @p other are not equal + template + bool operator!=(const T& other) const { + return !(*this == other); + } + + private: + std::variant value_; +}; + +/// @param out the std::ostream to write to +/// @param ri the ResolvedIdentifier +/// @returns `out` so calls can be chained +std::ostream& operator<<(std::ostream& out, const ResolvedIdentifier& ri); + /// DependencyGraph holds information about module-scope declaration dependency /// analysis and symbol resolutions. struct DependencyGraph { @@ -48,9 +121,8 @@ struct DependencyGraph { /// All globals in dependency-sorted order. utils::Vector ordered_globals; - /// Map of ast::IdentifierExpression or ast::TypeName to a type, function, or - /// variable that declares the symbol. - utils::Hashmap resolved_symbols; + /// Map of ast::Identifier to a ResolvedIdentifier + utils::Hashmap resolved_identifiers; /// Map of ast::Variable to a type, function, or variable that is shadowed by /// the variable key. A declaration (X) shadows another (Y) if X and Y use diff --git a/src/tint/resolver/dependency_graph_test.cc b/src/tint/resolver/dependency_graph_test.cc index a584915130..e6c460ec0c 100644 --- a/src/tint/resolver/dependency_graph_test.cc +++ b/src/tint/resolver/dependency_graph_test.cc @@ -103,8 +103,7 @@ static constexpr SymbolDeclKind kFuncDeclKinds[] = { SymbolDeclKind::Function, }; -/// SymbolUseKind is used by parameterized tests to enumerate the different -/// kinds of symbol uses. +/// SymbolUseKind is used by parameterized tests to enumerate the different kinds of symbol uses. enum class SymbolUseKind { GlobalVarType, GlobalVarArrayElemType, @@ -400,19 +399,19 @@ struct SymbolTestHelper { /// Destructor. ~SymbolTestHelper(); - /// Declares a symbol with the given kind - /// @param kind the kind of symbol declaration + /// Declares an identifier with the given kind + /// @param kind the kind of identifier declaration /// @param symbol the symbol to use for the declaration /// @param source the source of the declaration - /// @returns the declaration node + /// @returns the identifier node const ast::Node* Add(SymbolDeclKind kind, Symbol symbol, Source source); - /// Declares a use of a symbol with the given kind + /// Declares a use of an identifier with the given kind /// @param kind the kind of symbol use /// @param symbol the declaration symbol to use /// @param source the source of the use /// @returns the use node - const ast::Node* Add(SymbolUseKind kind, Symbol symbol, Source source); + const ast::Identifier* Add(SymbolUseKind kind, Symbol symbol, Source source); /// Builds a function, if any parameter or local declarations have been added void Build(); @@ -422,7 +421,7 @@ SymbolTestHelper::SymbolTestHelper(ProgramBuilder* b) : builder(b) {} SymbolTestHelper::~SymbolTestHelper() {} -const ast::Node* SymbolTestHelper::Add(SymbolDeclKind kind, Symbol symbol, Source source) { +const ast::Node* SymbolTestHelper::Add(SymbolDeclKind kind, Symbol symbol, Source source = {}) { auto& b = *builder; switch (kind) { case SymbolDeclKind::GlobalVar: @@ -464,88 +463,90 @@ const ast::Node* SymbolTestHelper::Add(SymbolDeclKind kind, Symbol symbol, Sourc return nullptr; } -const ast::Node* SymbolTestHelper::Add(SymbolUseKind kind, Symbol symbol, Source source) { +const ast::Identifier* SymbolTestHelper::Add(SymbolUseKind kind, + Symbol symbol, + Source source = {}) { auto& b = *builder; switch (kind) { case SymbolUseKind::GlobalVarType: { auto* node = b.ty(source, symbol); b.GlobalVar(b.Sym(), node, type::AddressSpace::kPrivate); - return node; + return node->name; } case SymbolUseKind::GlobalVarArrayElemType: { auto* node = b.ty(source, symbol); b.GlobalVar(b.Sym(), b.ty.array(node, 4_i), type::AddressSpace::kPrivate); - return node; + return node->name; } case SymbolUseKind::GlobalVarArraySizeValue: { auto* node = b.Expr(source, symbol); b.GlobalVar(b.Sym(), b.ty.array(b.ty.i32(), node), type::AddressSpace::kPrivate); - return node; + return node->identifier; } case SymbolUseKind::GlobalVarVectorElemType: { auto* node = b.ty(source, symbol); b.GlobalVar(b.Sym(), b.ty.vec3(node), type::AddressSpace::kPrivate); - return node; + return node->name; } case SymbolUseKind::GlobalVarMatrixElemType: { auto* node = b.ty(source, symbol); b.GlobalVar(b.Sym(), b.ty.mat3x4(node), type::AddressSpace::kPrivate); - return node; + return node->name; } case SymbolUseKind::GlobalVarSampledTexElemType: { auto* node = b.ty(source, symbol); b.GlobalVar(b.Sym(), b.ty.sampled_texture(type::TextureDimension::k2d, node)); - return node; + return node->name; } case SymbolUseKind::GlobalVarMultisampledTexElemType: { auto* node = b.ty(source, symbol); b.GlobalVar(b.Sym(), b.ty.multisampled_texture(type::TextureDimension::k2d, node)); - return node; + return node->name; } case SymbolUseKind::GlobalVarValue: { auto* node = b.Expr(source, symbol); b.GlobalVar(b.Sym(), b.ty.i32(), type::AddressSpace::kPrivate, node); - return node; + return node->identifier; } case SymbolUseKind::GlobalConstType: { auto* node = b.ty(source, symbol); b.GlobalConst(b.Sym(), node, b.Expr(1_i)); - return node; + return node->name; } case SymbolUseKind::GlobalConstArrayElemType: { auto* node = b.ty(source, symbol); b.GlobalConst(b.Sym(), b.ty.array(node, 4_i), b.Expr(1_i)); - return node; + return node->name; } case SymbolUseKind::GlobalConstArraySizeValue: { auto* node = b.Expr(source, symbol); b.GlobalConst(b.Sym(), b.ty.array(b.ty.i32(), node), b.Expr(1_i)); - return node; + return node->identifier; } case SymbolUseKind::GlobalConstVectorElemType: { auto* node = b.ty(source, symbol); b.GlobalConst(b.Sym(), b.ty.vec3(node), b.Expr(1_i)); - return node; + return node->name; } case SymbolUseKind::GlobalConstMatrixElemType: { auto* node = b.ty(source, symbol); b.GlobalConst(b.Sym(), b.ty.mat3x4(node), b.Expr(1_i)); - return node; + return node->name; } case SymbolUseKind::GlobalConstValue: { auto* node = b.Expr(source, symbol); b.GlobalConst(b.Sym(), b.ty.i32(), node); - return node; + return node->identifier; } case SymbolUseKind::AliasType: { auto* node = b.ty(source, symbol); b.Alias(b.Sym(), node); - return node; + return node->name; } case SymbolUseKind::StructMemberType: { auto* node = b.ty(source, symbol); b.Structure(b.Sym(), utils::Vector{b.Member("m", node)}); - return node; + return node->name; } case SymbolUseKind::CallFunction: { auto* node = b.Ident(source, symbol); @@ -555,72 +556,72 @@ const ast::Node* SymbolTestHelper::Add(SymbolUseKind kind, Symbol symbol, Source case SymbolUseKind::ParameterType: { auto* node = b.ty(source, symbol); parameters.Push(b.Param(b.Sym(), node)); - return node; + return node->name; } case SymbolUseKind::LocalVarType: { auto* node = b.ty(source, symbol); statements.Push(b.Decl(b.Var(b.Sym(), node))); - return node; + return node->name; } case SymbolUseKind::LocalVarArrayElemType: { auto* node = b.ty(source, symbol); statements.Push(b.Decl(b.Var(b.Sym(), b.ty.array(node, 4_u), b.Expr(1_i)))); - return node; + return node->name; } case SymbolUseKind::LocalVarArraySizeValue: { auto* node = b.Expr(source, symbol); statements.Push(b.Decl(b.Var(b.Sym(), b.ty.array(b.ty.i32(), node), b.Expr(1_i)))); - return node; + return node->identifier; } case SymbolUseKind::LocalVarVectorElemType: { auto* node = b.ty(source, symbol); statements.Push(b.Decl(b.Var(b.Sym(), b.ty.vec3(node)))); - return node; + return node->name; } case SymbolUseKind::LocalVarMatrixElemType: { auto* node = b.ty(source, symbol); statements.Push(b.Decl(b.Var(b.Sym(), b.ty.mat3x4(node)))); - return node; + return node->name; } case SymbolUseKind::LocalVarValue: { auto* node = b.Expr(source, symbol); statements.Push(b.Decl(b.Var(b.Sym(), b.ty.i32(), node))); - return node; + return node->identifier; } case SymbolUseKind::LocalLetType: { auto* node = b.ty(source, symbol); statements.Push(b.Decl(b.Let(b.Sym(), node, b.Expr(1_i)))); - return node; + return node->name; } case SymbolUseKind::LocalLetValue: { auto* node = b.Expr(source, symbol); statements.Push(b.Decl(b.Let(b.Sym(), b.ty.i32(), node))); - return node; + return node->identifier; } case SymbolUseKind::NestedLocalVarType: { auto* node = b.ty(source, symbol); nested_statements.Push(b.Decl(b.Var(b.Sym(), node))); - return node; + return node->name; } case SymbolUseKind::NestedLocalVarValue: { auto* node = b.Expr(source, symbol); nested_statements.Push(b.Decl(b.Var(b.Sym(), b.ty.i32(), node))); - return node; + return node->identifier; } case SymbolUseKind::NestedLocalLetType: { auto* node = b.ty(source, symbol); nested_statements.Push(b.Decl(b.Let(b.Sym(), node, b.Expr(1_i)))); - return node; + return node->name; } case SymbolUseKind::NestedLocalLetValue: { auto* node = b.Expr(source, symbol); nested_statements.Push(b.Decl(b.Let(b.Sym(), b.ty.i32(), node))); - return node; + return node->identifier; } case SymbolUseKind::WorkgroupSizeValue: { auto* node = b.Expr(source, symbol); func_attrs.Push(b.WorkgroupSize(1_i, node, 2_i)); - return node; + return node->identifier; } } return nullptr; @@ -641,7 +642,7 @@ void SymbolTestHelper::Build() { } //////////////////////////////////////////////////////////////////////////////// -// Used-before-declarated tests +// Used-before-declared tests //////////////////////////////////////////////////////////////////////////////// namespace used_before_decl_tests { @@ -1103,14 +1104,14 @@ TEST_F(ResolverDependencyGraphOrderedGlobalsTest, DirectiveFirst) { } // namespace ordered_globals //////////////////////////////////////////////////////////////////////////////// -// Resolved symbols tests +// Resolve to user-declaration tests //////////////////////////////////////////////////////////////////////////////// -namespace resolved_symbols { +namespace resolve_to_user_decl { -using ResolverDependencyGraphResolvedSymbolTest = +using ResolverDependencyGraphResolveToUserDeclTest = ResolverDependencyGraphTestWithParam>; -TEST_P(ResolverDependencyGraphResolvedSymbolTest, Test) { +TEST_P(ResolverDependencyGraphResolveToUserDeclTest, Test) { const Symbol symbol = Sym("SYMBOL"); const auto decl_kind = std::get<0>(GetParam()); const auto use_kind = std::get<1>(GetParam()); @@ -1128,41 +1129,209 @@ TEST_P(ResolverDependencyGraphResolvedSymbolTest, Test) { if (expect_pass) { // Check that the use resolves to the declaration - auto resolved_symbol = graph.resolved_symbols.Find(use); - ASSERT_TRUE(resolved_symbol); - EXPECT_EQ(*resolved_symbol, decl) - << "resolved: " << (*resolved_symbol ? (*resolved_symbol)->TypeInfo().name : "") - << "\n" + auto resolved_identifier = graph.resolved_identifiers.Find(use); + ASSERT_TRUE(resolved_identifier); + auto* resolved_node = resolved_identifier->Node(); + EXPECT_EQ(resolved_node, decl) + << "resolved: " << (resolved_node ? resolved_node->TypeInfo().name : "") << "\n" << "decl: " << decl->TypeInfo().name; } } INSTANTIATE_TEST_SUITE_P(Types, - ResolverDependencyGraphResolvedSymbolTest, + ResolverDependencyGraphResolveToUserDeclTest, testing::Combine(testing::ValuesIn(kTypeDeclKinds), testing::ValuesIn(kTypeUseKinds))); INSTANTIATE_TEST_SUITE_P(Values, - ResolverDependencyGraphResolvedSymbolTest, + ResolverDependencyGraphResolveToUserDeclTest, testing::Combine(testing::ValuesIn(kValueDeclKinds), testing::ValuesIn(kValueUseKinds))); INSTANTIATE_TEST_SUITE_P(Functions, - ResolverDependencyGraphResolvedSymbolTest, + ResolverDependencyGraphResolveToUserDeclTest, testing::Combine(testing::ValuesIn(kFuncDeclKinds), testing::ValuesIn(kFuncUseKinds))); -} // namespace resolved_symbols +} // namespace resolve_to_user_decl + +//////////////////////////////////////////////////////////////////////////////// +// Resolve to builtin func tests +//////////////////////////////////////////////////////////////////////////////// +namespace resolve_to_builtin_func { + +using ResolverDependencyGraphResolveToBuiltinFunc = + ResolverDependencyGraphTestWithParam>; + +TEST_P(ResolverDependencyGraphResolveToBuiltinFunc, Resolve) { + const auto use = std::get<0>(GetParam()); + const auto builtin = std::get<1>(GetParam()); + const auto symbol = Symbols().New(utils::ToString(builtin)); + + SymbolTestHelper helper(this); + auto* ident = helper.Add(use, symbol); + helper.Build(); + + auto resolved = Build().resolved_identifiers.Get(ident); + ASSERT_TRUE(resolved); + EXPECT_EQ(resolved->BuiltinFunction(), builtin) << *resolved; +} + +TEST_P(ResolverDependencyGraphResolveToBuiltinFunc, ShadowedByGlobalVar) { + const auto use = std::get<0>(GetParam()); + const auto builtin = std::get<1>(GetParam()); + const auto symbol = Symbols().New(utils::ToString(builtin)); + + SymbolTestHelper helper(this); + auto* decl = helper.Add(SymbolDeclKind::GlobalVar, symbol); + auto* ident = helper.Add(use, symbol); + helper.Build(); + + auto resolved = Build().resolved_identifiers.Get(ident); + ASSERT_TRUE(resolved); + EXPECT_EQ(resolved->Node(), decl) << *resolved; +} + +TEST_P(ResolverDependencyGraphResolveToBuiltinFunc, ShadowedByStruct) { + const auto use = std::get<0>(GetParam()); + const auto builtin = std::get<1>(GetParam()); + const auto symbol = Symbols().New(utils::ToString(builtin)); + + SymbolTestHelper helper(this); + auto* decl = helper.Add(SymbolDeclKind::Struct, symbol); + auto* ident = helper.Add(use, symbol); + helper.Build(); + + auto resolved = Build().resolved_identifiers.Get(ident); + ASSERT_TRUE(resolved); + EXPECT_EQ(resolved->Node(), decl) << *resolved; +} + +TEST_P(ResolverDependencyGraphResolveToBuiltinFunc, ShadowedByFunc) { + const auto use = std::get<0>(GetParam()); + const auto builtin = std::get<1>(GetParam()); + const auto symbol = Symbols().New(utils::ToString(builtin)); + + SymbolTestHelper helper(this); + auto* decl = helper.Add(SymbolDeclKind::Function, symbol); + auto* ident = helper.Add(use, symbol); + helper.Build(); + + auto resolved = Build().resolved_identifiers.Get(ident); + ASSERT_TRUE(resolved); + EXPECT_EQ(resolved->Node(), decl) << *resolved; +} + +INSTANTIATE_TEST_SUITE_P(Types, + ResolverDependencyGraphResolveToBuiltinFunc, + testing::Combine(testing::ValuesIn(kTypeUseKinds), + testing::ValuesIn(sem::kBuiltinTypes))); + +INSTANTIATE_TEST_SUITE_P(Values, + ResolverDependencyGraphResolveToBuiltinFunc, + testing::Combine(testing::ValuesIn(kValueUseKinds), + testing::ValuesIn(sem::kBuiltinTypes))); + +INSTANTIATE_TEST_SUITE_P(Functions, + ResolverDependencyGraphResolveToBuiltinFunc, + testing::Combine(testing::ValuesIn(kFuncUseKinds), + testing::ValuesIn(sem::kBuiltinTypes))); + +} // namespace resolve_to_builtin_func + +//////////////////////////////////////////////////////////////////////////////// +// Resolve to builtin type tests +//////////////////////////////////////////////////////////////////////////////// +namespace resolve_to_builtin_type { + +using ResolverDependencyGraphResolveToBuiltinType = + ResolverDependencyGraphTestWithParam>; + +TEST_P(ResolverDependencyGraphResolveToBuiltinType, Resolve) { + const auto use = std::get<0>(GetParam()); + const auto name = std::get<1>(GetParam()); + const auto symbol = Symbols().New(name); + + SymbolTestHelper helper(this); + auto* ident = helper.Add(use, symbol); + helper.Build(); + + auto resolved = Build().resolved_identifiers.Get(ident); + ASSERT_TRUE(resolved); + EXPECT_EQ(resolved->BuiltinType(), type::ParseBuiltin(name)) << *resolved; +} + +TEST_P(ResolverDependencyGraphResolveToBuiltinType, ShadowedByGlobalVar) { + const auto use = std::get<0>(GetParam()); + const auto name = std::get<1>(GetParam()); + const auto symbol = Symbols().New(name); + + SymbolTestHelper helper(this); + auto* decl = helper.Add(SymbolDeclKind::GlobalVar, symbol); + auto* ident = helper.Add(use, symbol); + helper.Build(); + + auto resolved = Build().resolved_identifiers.Get(ident); + ASSERT_TRUE(resolved); + EXPECT_EQ(resolved->Node(), decl) << *resolved; +} + +TEST_P(ResolverDependencyGraphResolveToBuiltinType, ShadowedByStruct) { + const auto use = std::get<0>(GetParam()); + const auto name = std::get<1>(GetParam()); + const auto symbol = Symbols().New(name); + + SymbolTestHelper helper(this); + auto* decl = helper.Add(SymbolDeclKind::Struct, symbol); + auto* ident = helper.Add(use, symbol); + helper.Build(); + + auto resolved = Build().resolved_identifiers.Get(ident); + ASSERT_TRUE(resolved); + EXPECT_EQ(resolved->Node(), decl) << *resolved; +} + +TEST_P(ResolverDependencyGraphResolveToBuiltinType, ShadowedByFunc) { + const auto use = std::get<0>(GetParam()); + const auto name = std::get<1>(GetParam()); + const auto symbol = Symbols().New(name); + + SymbolTestHelper helper(this); + auto* decl = helper.Add(SymbolDeclKind::Function, symbol); + auto* ident = helper.Add(use, symbol); + helper.Build(); + + auto resolved = Build().resolved_identifiers.Get(ident); + ASSERT_TRUE(resolved); + EXPECT_EQ(resolved->Node(), decl) << *resolved; +} + +INSTANTIATE_TEST_SUITE_P(Types, + ResolverDependencyGraphResolveToBuiltinType, + testing::Combine(testing::ValuesIn(kTypeUseKinds), + testing::ValuesIn(type::kBuiltinStrings))); + +INSTANTIATE_TEST_SUITE_P(Values, + ResolverDependencyGraphResolveToBuiltinType, + testing::Combine(testing::ValuesIn(kValueUseKinds), + testing::ValuesIn(type::kBuiltinStrings))); + +INSTANTIATE_TEST_SUITE_P(Functions, + ResolverDependencyGraphResolveToBuiltinType, + testing::Combine(testing::ValuesIn(kFuncUseKinds), + testing::ValuesIn(type::kBuiltinStrings))); + +} // namespace resolve_to_builtin_type //////////////////////////////////////////////////////////////////////////////// // Shadowing tests //////////////////////////////////////////////////////////////////////////////// namespace shadowing { -using ResolverDependencyShadowTest = +using ResolverDependencyGraphShadowTest = ResolverDependencyGraphTestWithParam>; -TEST_P(ResolverDependencyShadowTest, Test) { +TEST_P(ResolverDependencyGraphShadowTest, Test) { const Symbol symbol = Sym("SYMBOL"); const auto outer_kind = std::get<0>(GetParam()); const auto inner_kind = std::get<1>(GetParam()); @@ -1185,12 +1354,12 @@ TEST_P(ResolverDependencyShadowTest, Test) { } INSTANTIATE_TEST_SUITE_P(LocalShadowGlobal, - ResolverDependencyShadowTest, + ResolverDependencyGraphShadowTest, testing::Combine(testing::ValuesIn(kGlobalDeclKinds), testing::ValuesIn(kLocalDeclKinds))); INSTANTIATE_TEST_SUITE_P(NestedLocalShadowLocal, - ResolverDependencyShadowTest, + ResolverDependencyGraphShadowTest, testing::Combine(testing::Values(SymbolDeclKind::Parameter, SymbolDeclKind::LocalVar, SymbolDeclKind::LocalLet), @@ -1204,6 +1373,18 @@ INSTANTIATE_TEST_SUITE_P(NestedLocalShadowLocal, //////////////////////////////////////////////////////////////////////////////// namespace ast_traversal { +static const ast::Identifier* IdentifierOf(const ast::IdentifierExpression* expr) { + return expr->identifier; +} + +static const ast::Identifier* IdentifierOf(const ast::TypeName* ty) { + return ty->name; +} + +static const ast::Identifier* IdentifierOf(const ast::Identifier* ident) { + return ident; +} + using ResolverDependencyGraphTraversalTest = ResolverDependencyGraphTest; TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) { @@ -1217,7 +1398,7 @@ TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) { struct SymbolUse { const ast::Node* decl = nullptr; - const ast::Node* use = nullptr; + const ast::Identifier* use = nullptr; std::string where; }; @@ -1225,7 +1406,8 @@ TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) { auto add_use = [&](const ast::Node* decl, auto* use, int line, const char* kind) { symbol_uses.Push( - SymbolUse{decl, use, std::string(__FILE__) + ":" + std::to_string(line) + ": " + kind}); + SymbolUse{decl, IdentifierOf(use), + std::string(__FILE__) + ":" + std::to_string(line) + ": " + kind}); return use; }; #define V add_use(value_decl, Expr(value_sym), __LINE__, "V()") @@ -1310,9 +1492,9 @@ TEST_F(ResolverDependencyGraphTraversalTest, SymbolsReached) { auto graph = Build(); for (auto use : symbol_uses) { - auto resolved_symbol = graph.resolved_symbols.Find(use.use); - ASSERT_TRUE(resolved_symbol) << use.where; - EXPECT_EQ(*resolved_symbol, use.decl) << use.where; + auto resolved_identifier = graph.resolved_identifiers.Find(use.use); + ASSERT_TRUE(resolved_identifier) << use.where; + EXPECT_EQ(*resolved_identifier, use.decl) << use.where; } } diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index df9eace548..947499bb46 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -105,7 +105,7 @@ Resolver::Resolver(ProgramBuilder* builder) diagnostics_(builder->Diagnostics()), const_eval_(*builder), intrinsic_table_(IntrinsicTable::Create(*builder)), - sem_(builder, dependencies_), + sem_(builder), validator_(builder, sem_, enabled_extensions_, @@ -332,41 +332,45 @@ type::Type* Resolver::Type(const ast::Type* ty) { TINT_UNREACHABLE(Resolver, builder_->Diagnostics()) << "TODO(crbug.com/tint/1810)"; } - auto* resolved = sem_.ResolvedSymbol(t); - if (resolved == nullptr) { - if (IsBuiltin(t->name->symbol)) { - auto name = builder_->Symbols().NameFor(t->name->symbol); - AddError("cannot use builtin '" + name + "' as type", ty->source); - return nullptr; - } - return BuiltinType(t->name->symbol, t->source); + auto resolved = dependencies_.resolved_identifiers.Get(t->name); + if (!resolved) { + TINT_ICE(Resolver, builder_->Diagnostics()) << "identifier was not resolved"; + return nullptr; } - return Switch( - resolved, // - [&](type::Type* type) { return type; }, - [&](sem::Variable* var) { - auto name = builder_->Symbols().NameFor(var->Declaration()->symbol); - AddError("cannot use variable '" + name + "' as type", ty->source); - AddNote("'" + name + "' declared here", var->Declaration()->source); - return nullptr; - }, - [&](sem::Function* func) { - auto name = builder_->Symbols().NameFor(func->Declaration()->symbol); - AddError("cannot use function '" + name + "' as type", ty->source); - AddNote("'" + name + "' declared here", func->Declaration()->source); - return nullptr; - }, - [&](Default) -> type::Type* { - TINT_UNREACHABLE(Resolver, diagnostics_) - << "Unhandled resolved type '" - << (resolved ? resolved->TypeInfo().name : "") - << "' resolved from ast::Type '" << ty->TypeInfo().name << "'"; - return nullptr; - }); - }, - [&](Default) { + + if (auto* ast_node = resolved->Node()) { + auto* resolved_node = sem_.Get(ast_node); + return Switch( + resolved_node, // + [&](type::Type* type) { return type; }, + [&](sem::Variable* variable) { + auto name = builder_->Symbols().NameFor(variable->Declaration()->symbol); + AddError("cannot use variable '" + name + "' as type", ty->source); + AddNote("'" + name + "' declared here", variable->Declaration()->source); + return nullptr; + }, + [&](sem::Function* func) { + auto name = builder_->Symbols().NameFor(func->Declaration()->symbol); + AddError("cannot use function '" + name + "' as type", ty->source); + AddNote("'" + name + "' declared here", func->Declaration()->source); + return nullptr; + }); + } + + if (auto builtin_ty = resolved->BuiltinType(); + builtin_ty != type::Builtin::kUndefined) { + return BuiltinType(builtin_ty, t->name); + } + + if (auto builtin_fn = resolved->BuiltinFunction(); + builtin_fn != sem::BuiltinType::kNone) { + auto name = builder_->Symbols().NameFor(t->name->symbol); + AddError("cannot use builtin '" + name + "' as type", ty->source); + return nullptr; + } + TINT_UNREACHABLE(Resolver, diagnostics_) - << "Unhandled type: '" << ty->TypeInfo().name << "'"; + << "unhandled resolved identifier: " << *resolved; return nullptr; }); @@ -947,7 +951,8 @@ sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* v) { return nullptr; } - // Track the pipeline-overridable constants that are transitively referenced by this variable. + // Track the pipeline-overridable constants that are transitively referenced by this + // variable. for (auto* var : transitively_referenced_overrides) { builder_->Sem().AddTransitivelyReferencedOverride(sem, var); } @@ -1176,8 +1181,8 @@ bool Resolver::WorkgroupSize(const ast::Function* func) { "abstract-integer, i32 or u32"; for (size_t i = 0; i < 3; i++) { - // Each argument to this attribute can either be a literal, an identifier for a module-scope - // constants, a const-expression, or nullptr if not specified. + // Each argument to this attribute can either be a literal, an identifier for a + // module-scope constants, a const-expression, or nullptr if not specified. auto* value = values[i]; if (!value) { break; @@ -1604,8 +1609,8 @@ sem::ValueExpression* Resolver::Expression(const ast::Expression* root) { return sem_expr; } - // If we just processed the lhs of a constexpr logical binary expression, mark the rhs for - // short-circuiting. + // If we just processed the lhs of a constexpr logical binary expression, mark the rhs + // for short-circuiting. if (sem_expr->ConstantValue()) { if (auto binary = logical_binary_lhs_to_parent_.Find(expr)) { const bool lhs_is_true = sem_expr->ConstantValue()->ValueAs(); @@ -1678,7 +1683,8 @@ bool Resolver::AliasAnalysis(const sem::Call* call) { auto& target_info = alias_analysis_infos_[target]; auto& caller_info = alias_analysis_infos_[current_function_]; - // Track the set of root identifiers that are read and written by arguments passed in this call. + // Track the set of root identifiers that are read and written by arguments passed in this + // call. std::unordered_map arg_reads; std::unordered_map arg_writes; for (size_t i = 0; i < args.Length(); i++) { @@ -1715,8 +1721,8 @@ bool Resolver::AliasAnalysis(const sem::Call* call) { }, [&](const sem::Parameter* param) { caller_info.parameter_writes.insert(param); }); } else if (target_info.parameter_reads.count(target->Parameters()[i])) { - // Arguments that are read from can alias with arguments or module-scope variables that - // are written to. + // Arguments that are read from can alias with arguments or module-scope variables + // that are written to. if (arg_writes.count(root)) { return make_error(arg, {arg_writes.at(root), Alias::Argument, "write"}); } @@ -2101,8 +2107,8 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { // Constant evaluation failed. // Can happen for expressions that will fail validation (later). // Use the kRuntime EvaluationStage, as kConstant will trigger an assertion in - // the sem::ValueExpression initializer, which checks that kConstant is paired with - // a constant value. + // the sem::ValueExpression initializer, which checks that kConstant is paired + // with a constant value. stage = sem::EvaluationStage::kRuntime; } } @@ -2310,35 +2316,44 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { // conversion. auto* ident = expr->target.name; Mark(ident); - if (auto* resolved = sem_.ResolvedSymbol(ident)) { - // A type initializer or conversions. - // Note: Unlike the code path where we're resolving the call target from an - // ast::Type, all types must already have the element type explicitly specified, - // so there's no need to infer element types. - return ty_init_or_conv(resolved); + + auto resolved = dependencies_.resolved_identifiers.Get(ident); + if (!resolved) { + TINT_ICE(Resolver, builder_->Diagnostics()) << "identifier was not resolved"; + return nullptr; } - auto* resolved = sem_.ResolvedSymbol(ident); - call = Switch( - resolved, // - [&](sem::Function* func) { return FunctionCall(expr, func, args, arg_behaviors); }, - [&](sem::Variable* var) { - auto name = builder_->Symbols().NameFor(var->Declaration()->symbol); - AddError("cannot call variable '" + name + "'", ident->source); - AddNote("'" + name + "' declared here", var->Declaration()->source); - return nullptr; - }, - [&](Default) -> sem::Call* { - auto name = builder_->Symbols().NameFor(ident->symbol); - if (auto builtin_type = sem::ParseBuiltinType(name); - builtin_type != sem::BuiltinType::kNone) { - return BuiltinCall(expr, builtin_type, args); - } - if (auto* alias = BuiltinType(ident->symbol, ident->source)) { - return ty_init_or_conv(alias); - } - return nullptr; - }); + if (auto* ast_node = resolved->Node()) { + auto* resolved_node = sem_.Get(ast_node); + return Switch( + resolved_node, // + [&](const type::Type* ty) { + // A type initializer or conversions. + // Note: Unlike the code path where we're resolving the call target from an + // ast::Type, all types must already have the element type explicitly + // specified, so there's no need to infer element types. + return ty_init_or_conv(ty); + }, + [&](sem::Function* func) { return FunctionCall(expr, func, args, arg_behaviors); }, + [&](sem::Variable* var) { + auto name = builder_->Symbols().NameFor(var->Declaration()->symbol); + AddError("cannot call variable '" + name + "'", ident->source); + AddNote("'" + name + "' declared here", var->Declaration()->source); + return nullptr; + }); + } + + if (auto builtin_fn = resolved->BuiltinFunction(); builtin_fn != sem::BuiltinType::kNone) { + return BuiltinCall(expr, builtin_fn, args); + } + + if (auto builtin_ty = resolved->BuiltinType(); builtin_ty != type::Builtin::kUndefined) { + auto* ty = BuiltinType(builtin_ty, expr->target.name); + return ty ? ty_init_or_conv(ty) : nullptr; + } + + TINT_UNREACHABLE(Resolver, diagnostics_) << "unhandled resolved identifier: " << *resolved; + return nullptr; } if (!call) { @@ -2438,13 +2453,12 @@ sem::Call* Resolver::BuiltinCall(const ast::CallExpression* expr, return call; } -type::Type* Resolver::BuiltinType(Symbol sym, const Source& source) const { - auto name = builder_->Symbols().NameFor(sym); +type::Type* Resolver::BuiltinType(type::Builtin builtin_ty, const ast::Identifier* ident) const { auto& b = *builder_; auto vec_f32 = [&](uint32_t n) { return b.create(b.create(), n); }; auto vec_f16 = [&](uint32_t n) { return b.create(b.create(), n); }; - switch (type::ParseBuiltin(name)) { + switch (builtin_ty) { case type::Builtin::kMat2X2F: return b.create(vec_f32(2u), 2u); case type::Builtin::kMat2X3F: @@ -2464,32 +2478,41 @@ type::Type* Resolver::BuiltinType(Symbol sym, const Source& source) const { case type::Builtin::kMat4X4F: return b.create(vec_f32(4u), 4u); case type::Builtin::kMat2X2H: - return validator_.CheckF16Enabled(source) ? b.create(vec_f16(2u), 2u) - : nullptr; + return validator_.CheckF16Enabled(ident->source) + ? b.create(vec_f16(2u), 2u) + : nullptr; case type::Builtin::kMat2X3H: - return validator_.CheckF16Enabled(source) ? b.create(vec_f16(3u), 2u) - : nullptr; + return validator_.CheckF16Enabled(ident->source) + ? b.create(vec_f16(3u), 2u) + : nullptr; case type::Builtin::kMat2X4H: - return validator_.CheckF16Enabled(source) ? b.create(vec_f16(4u), 2u) - : nullptr; + return validator_.CheckF16Enabled(ident->source) + ? b.create(vec_f16(4u), 2u) + : nullptr; case type::Builtin::kMat3X2H: - return validator_.CheckF16Enabled(source) ? b.create(vec_f16(2u), 3u) - : nullptr; + return validator_.CheckF16Enabled(ident->source) + ? b.create(vec_f16(2u), 3u) + : nullptr; case type::Builtin::kMat3X3H: - return validator_.CheckF16Enabled(source) ? b.create(vec_f16(3u), 3u) - : nullptr; + return validator_.CheckF16Enabled(ident->source) + ? b.create(vec_f16(3u), 3u) + : nullptr; case type::Builtin::kMat3X4H: - return validator_.CheckF16Enabled(source) ? b.create(vec_f16(4u), 3u) - : nullptr; + return validator_.CheckF16Enabled(ident->source) + ? b.create(vec_f16(4u), 3u) + : nullptr; case type::Builtin::kMat4X2H: - return validator_.CheckF16Enabled(source) ? b.create(vec_f16(2u), 4u) - : nullptr; + return validator_.CheckF16Enabled(ident->source) + ? b.create(vec_f16(2u), 4u) + : nullptr; case type::Builtin::kMat4X3H: - return validator_.CheckF16Enabled(source) ? b.create(vec_f16(3u), 4u) - : nullptr; + return validator_.CheckF16Enabled(ident->source) + ? b.create(vec_f16(3u), 4u) + : nullptr; case type::Builtin::kMat4X4H: - return validator_.CheckF16Enabled(source) ? b.create(vec_f16(4u), 4u) - : nullptr; + return validator_.CheckF16Enabled(ident->source) + ? b.create(vec_f16(4u), 4u) + : nullptr; case type::Builtin::kVec2F: return vec_f32(2u); case type::Builtin::kVec3F: @@ -2497,11 +2520,11 @@ type::Type* Resolver::BuiltinType(Symbol sym, const Source& source) const { case type::Builtin::kVec4F: return vec_f32(4u); case type::Builtin::kVec2H: - return validator_.CheckF16Enabled(source) ? vec_f16(2u) : nullptr; + return validator_.CheckF16Enabled(ident->source) ? vec_f16(2u) : nullptr; case type::Builtin::kVec3H: - return validator_.CheckF16Enabled(source) ? vec_f16(3u) : nullptr; + return validator_.CheckF16Enabled(ident->source) ? vec_f16(3u) : nullptr; case type::Builtin::kVec4H: - return validator_.CheckF16Enabled(source) ? vec_f16(4u) : nullptr; + return validator_.CheckF16Enabled(ident->source) ? vec_f16(4u) : nullptr; case type::Builtin::kVec2I: return b.create(b.create(), 2u); case type::Builtin::kVec3I: @@ -2518,7 +2541,8 @@ type::Type* Resolver::BuiltinType(Symbol sym, const Source& source) const { break; } - TINT_ICE(Resolver, diagnostics_) << source << " unhandled builtin type '" << name << "'"; + auto name = builder_->Symbols().NameFor(ident->symbol); + TINT_ICE(Resolver, diagnostics_) << ident->source << " unhandled builtin type '" << name << "'"; return nullptr; } @@ -2685,97 +2709,110 @@ sem::ValueExpression* Resolver::Literal(const ast::LiteralExpression* literal) { sem::ValueExpression* Resolver::Identifier(const ast::IdentifierExpression* expr) { Mark(expr->identifier); - auto symbol = expr->identifier->symbol; - auto* sem_resolved = sem_.ResolvedSymbol(expr); - if (auto* variable = As(sem_resolved)) { - auto* user = builder_->create(expr, current_statement_, variable); - if (current_statement_) { - // If identifier is part of a loop continuing block, make sure it - // doesn't refer to a variable that is bypassed by a continue statement - // in the loop's body block. - if (auto* continuing_block = - current_statement_->FindFirstParent()) { - auto* loop_block = continuing_block->FindFirstParent(); - if (loop_block->FirstContinue()) { - // If our identifier is in loop_block->decls, make sure its index is - // less than first_continue - if (auto decl = loop_block->Decls().Find(symbol)) { - if (decl->order >= loop_block->NumDeclsAtFirstContinue()) { - AddError("continue statement bypasses declaration of '" + - builder_->Symbols().NameFor(symbol) + "'", - loop_block->FirstContinue()->source); - AddNote("identifier '" + builder_->Symbols().NameFor(symbol) + - "' declared here", - decl->variable->Declaration()->source); - AddNote("identifier '" + builder_->Symbols().NameFor(symbol) + - "' referenced in continuing block here", - expr->source); - return nullptr; + auto resolved = dependencies_.resolved_identifiers.Get(expr->identifier); + if (!resolved) { + TINT_ICE(Resolver, builder_->Diagnostics()) << "identifier was not resolved"; + return nullptr; + } + + if (auto* ast_node = resolved->Node()) { + auto* resolved_node = sem_.Get(ast_node); + return Switch( + resolved_node, // + [&](sem::Variable* variable) -> sem::VariableUser* { + auto symbol = expr->identifier->symbol; + auto* user = + builder_->create(expr, current_statement_, variable); + + if (current_statement_) { + // If identifier is part of a loop continuing block, make sure it + // doesn't refer to a variable that is bypassed by a continue statement + // in the loop's body block. + if (auto* continuing_block = + current_statement_ + ->FindFirstParent()) { + auto* loop_block = + continuing_block->FindFirstParent(); + if (loop_block->FirstContinue()) { + // If our identifier is in loop_block->decls, make sure its index is + // less than first_continue + if (auto decl = loop_block->Decls().Find(symbol)) { + if (decl->order >= loop_block->NumDeclsAtFirstContinue()) { + AddError("continue statement bypasses declaration of '" + + builder_->Symbols().NameFor(symbol) + "'", + loop_block->FirstContinue()->source); + AddNote("identifier '" + builder_->Symbols().NameFor(symbol) + + "' declared here", + decl->variable->Declaration()->source); + AddNote("identifier '" + builder_->Symbols().NameFor(symbol) + + "' referenced in continuing block here", + expr->source); + return nullptr; + } + } } } } - } - } - auto* global = variable->As(); - if (current_function_) { - if (global) { - current_function_->AddDirectlyReferencedGlobal(global); - auto* refs = builder_->Sem().TransitivelyReferencedOverrides(global); - if (refs) { - for (auto* var : *refs) { - current_function_->AddTransitivelyReferencedGlobal(var); + auto* global = variable->As(); + if (current_function_) { + if (global) { + current_function_->AddDirectlyReferencedGlobal(global); + auto* refs = builder_->Sem().TransitivelyReferencedOverrides(global); + if (refs) { + for (auto* var : *refs) { + current_function_->AddTransitivelyReferencedGlobal(var); + } + } } - } - } - } else if (variable->Declaration()->Is()) { - if (resolved_overrides_) { - // Track the reference to this pipeline-overridable constant and any other - // pipeline-overridable constants that it references. - resolved_overrides_->Add(global); - auto* refs = builder_->Sem().TransitivelyReferencedOverrides(global); - if (refs) { - for (auto* var : *refs) { - resolved_overrides_->Add(var); + } else if (variable->Declaration()->Is()) { + if (resolved_overrides_) { + // Track the reference to this pipeline-overridable constant and any other + // pipeline-overridable constants that it references. + resolved_overrides_->Add(global); + auto* refs = builder_->Sem().TransitivelyReferencedOverrides(global); + if (refs) { + for (auto* var : *refs) { + resolved_overrides_->Add(var); + } + } } + } else if (variable->Declaration()->Is()) { + // Use of a module-scope 'var' outside of a function. + // Note: The spec is currently vague around the rules here. See + // https://github.com/gpuweb/gpuweb/issues/3081. Remove this comment when + // resolved. + std::string desc = "var '" + builder_->Symbols().NameFor(symbol) + "' "; + AddError(desc + "cannot be referenced at module-scope", expr->source); + AddNote(desc + "declared here", variable->Declaration()->source); + return nullptr; } - } - } else if (variable->Declaration()->Is()) { - // Use of a module-scope 'var' outside of a function. - // Note: The spec is currently vague around the rules here. See - // https://github.com/gpuweb/gpuweb/issues/3081. Remove this comment when resolved. - std::string desc = "var '" + builder_->Symbols().NameFor(symbol) + "' "; - AddError(desc + "cannot be referenced at module-scope", expr->source); - AddNote(desc + "declared here", variable->Declaration()->source); - return nullptr; - } - variable->AddUser(user); - return user; + variable->AddUser(user); + return user; + }, + [&](const type::Type*) { + AddError("missing '(' for type initializer or cast", expr->source.End()); + return nullptr; + }, + [&](const sem::Function*) { + AddError("missing '(' for function call", expr->source.End()); + return nullptr; + }); } - if (Is(sem_resolved)) { - AddError("missing '(' for function call", expr->source.End()); - return nullptr; - } - - if (IsBuiltin(symbol)) { - AddError("missing '(' for builtin call", expr->source.End()); - return nullptr; - } - - if (sem_.ResolvedSymbol(expr) || - type::ParseBuiltin(builder_->Symbols().NameFor(symbol)) != type::Builtin::kUndefined) { + if (resolved->BuiltinType() != type::Builtin::kUndefined) { AddError("missing '(' for type initializer or cast", expr->source.End()); return nullptr; } - // The dependency graph should have errored on this unresolved identifier before reaching here. - TINT_ICE(Resolver, diagnostics_) - << expr->source << " unresolved identifier:\n" - << "resolved: " << (sem_resolved ? sem_resolved->TypeInfo().name : "") << "\n" - << "name: " << builder_->Symbols().NameFor(symbol); + if (resolved->BuiltinFunction() != sem::BuiltinType::kNone) { + AddError("missing '(' for builtin call", expr->source.End()); + return nullptr; + } + + TINT_UNREACHABLE(Resolver, diagnostics_) << "unhandled resolved identifier: " << *resolved; return nullptr; } @@ -2977,8 +3014,8 @@ sem::ValueExpression* Resolver::Binary(const ast::BinaryExpression* expr) { return nullptr; } } else { - // The arguments have constant values, but the operator cannot be const-evaluated. This - // can only be evaluated at runtime. + // The arguments have constant values, but the operator cannot be const-evaluated. + // This can only be evaluated at runtime. stage = sem::EvaluationStage::kRuntime; } } @@ -3966,9 +4003,4 @@ void Resolver::AddNote(const std::string& msg, const Source& source) const { diagnostics_.add_note(diag::System::Resolver, msg, source); } -bool Resolver::IsBuiltin(Symbol symbol) const { - std::string name = builder_->Symbols().NameFor(symbol); - return sem::ParseBuiltinType(name) != sem::BuiltinType::kNone; -} - } // namespace tint::resolver diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index 5adabcdd01..ac6ac5d14f 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -428,12 +428,9 @@ class Resolver { /// Adds the given note message to the diagnostics void AddNote(const std::string& msg, const Source& source) const; - /// @returns true if the symbol is the name of a builtin function. - bool IsBuiltin(Symbol) const; - - /// @returns the builtin type for the symbol @p symbol at @p source + /// @returns the type::Type for the builtin type @p builtin_ty with the identifier @p ident /// @note: Will raise an ICE if @p symbol is not a builtin type. - type::Type* BuiltinType(Symbol symbol, const Source& source) const; + type::Type* BuiltinType(type::Builtin builtin_ty, const ast::Identifier* ident) const; // ArrayInitializerSig represents a unique array initializer signature. // It is a tuple of the array type, number of arguments provided and earliest evaluation stage. diff --git a/src/tint/resolver/sem_helper.cc b/src/tint/resolver/sem_helper.cc index fbabb5e9e0..836fb97ea9 100644 --- a/src/tint/resolver/sem_helper.cc +++ b/src/tint/resolver/sem_helper.cc @@ -18,8 +18,7 @@ namespace tint::resolver { -SemHelper::SemHelper(ProgramBuilder* builder, DependencyGraph& dependencies) - : builder_(builder), dependencies_(dependencies) {} +SemHelper::SemHelper(ProgramBuilder* builder) : builder_(builder) {} SemHelper::~SemHelper() = default; diff --git a/src/tint/resolver/sem_helper.h b/src/tint/resolver/sem_helper.h index c8382db00e..2a7241bf12 100644 --- a/src/tint/resolver/sem_helper.h +++ b/src/tint/resolver/sem_helper.h @@ -29,8 +29,7 @@ class SemHelper { public: /// Constructor /// @param builder the program builder - /// @param dependencies the program dependency graph - explicit SemHelper(ProgramBuilder* builder, DependencyGraph& dependencies); + explicit SemHelper(ProgramBuilder* builder); ~SemHelper(); /// Get is a helper for obtaining the semantic node for the given AST node. @@ -76,18 +75,6 @@ class SemHelper { } } - /// @returns the resolved symbol (function, type or variable) for the given ast::Identifier or - /// ast::TypeName cast to the given semantic type. - /// @param node the node to retrieve - template - sem::Info::GetResultType* ResolvedSymbol(const ast::Node* node) const { - if (auto resolved = dependencies_.resolved_symbols.Find(node)) { - auto* sem = builder_->Sem().Get(*resolved); - return const_cast*>(sem); - } - return nullptr; - } - /// @returns the resolved type of the ast::Expression `expr` /// @param expr the expression type::Type* TypeOf(const ast::Expression* expr) const; @@ -104,7 +91,6 @@ class SemHelper { private: ProgramBuilder* builder_; - DependencyGraph& dependencies_; }; } // namespace tint::resolver diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc index d6a29502c0..02b8f39b1c 100644 --- a/src/tint/resolver/validator.cc +++ b/src/tint/resolver/validator.cc @@ -2344,8 +2344,8 @@ bool Validator::Assignment(const ast::Statement* a, const type::Type* rhs_ty) co // https://gpuweb.github.io/gpuweb/wgsl/#assignment-statement auto const* lhs_ty = sem_.TypeOf(lhs); - if (auto* variable = sem_.ResolvedSymbol(lhs)) { - auto* v = variable->Declaration(); + if (auto* var_user = sem_.Get(lhs)) { + auto* v = var_user->Variable()->Declaration(); const char* err = Switch( v, // [&](const ast::Parameter*) { return "cannot assign to function parameter"; }, @@ -2392,8 +2392,8 @@ bool Validator::IncrementDecrementStatement(const ast::IncrementDecrementStateme // https://gpuweb.github.io/gpuweb/wgsl/#increment-decrement - if (auto* variable = sem_.ResolvedSymbol(lhs)) { - auto* v = variable->Declaration(); + if (auto* var_user = sem_.Get(lhs)) { + auto* v = var_user->Variable()->Declaration(); const char* err = Switch( v, // [&](const ast::Parameter*) { return "cannot modify function parameter"; }, diff --git a/src/tint/sem/builtin_type.h b/src/tint/sem/builtin_type.h index aa1b36ed2c..1328d99c1c 100644 --- a/src/tint/sem/builtin_type.h +++ b/src/tint/sem/builtin_type.h @@ -161,6 +161,124 @@ const char* str(BuiltinType i); /// matches the name in the WGSL spec. std::ostream& operator<<(std::ostream& out, BuiltinType i); +/// All builtin types +constexpr BuiltinType kBuiltinTypes[] = { + BuiltinType::kAbs, + BuiltinType::kAcos, + BuiltinType::kAcosh, + BuiltinType::kAll, + BuiltinType::kAny, + BuiltinType::kArrayLength, + BuiltinType::kAsin, + BuiltinType::kAsinh, + BuiltinType::kAtan, + BuiltinType::kAtan2, + BuiltinType::kAtanh, + BuiltinType::kCeil, + BuiltinType::kClamp, + BuiltinType::kCos, + BuiltinType::kCosh, + BuiltinType::kCountLeadingZeros, + BuiltinType::kCountOneBits, + BuiltinType::kCountTrailingZeros, + BuiltinType::kCross, + BuiltinType::kDegrees, + BuiltinType::kDeterminant, + BuiltinType::kDistance, + BuiltinType::kDot, + BuiltinType::kDot4I8Packed, + BuiltinType::kDot4U8Packed, + BuiltinType::kDpdx, + BuiltinType::kDpdxCoarse, + BuiltinType::kDpdxFine, + BuiltinType::kDpdy, + BuiltinType::kDpdyCoarse, + BuiltinType::kDpdyFine, + BuiltinType::kExp, + BuiltinType::kExp2, + BuiltinType::kExtractBits, + BuiltinType::kFaceForward, + BuiltinType::kFirstLeadingBit, + BuiltinType::kFirstTrailingBit, + BuiltinType::kFloor, + BuiltinType::kFma, + BuiltinType::kFract, + BuiltinType::kFrexp, + BuiltinType::kFwidth, + BuiltinType::kFwidthCoarse, + BuiltinType::kFwidthFine, + BuiltinType::kInsertBits, + BuiltinType::kInverseSqrt, + BuiltinType::kLdexp, + BuiltinType::kLength, + BuiltinType::kLog, + BuiltinType::kLog2, + BuiltinType::kMax, + BuiltinType::kMin, + BuiltinType::kMix, + BuiltinType::kModf, + BuiltinType::kNormalize, + BuiltinType::kPack2X16Float, + BuiltinType::kPack2X16Snorm, + BuiltinType::kPack2X16Unorm, + BuiltinType::kPack4X8Snorm, + BuiltinType::kPack4X8Unorm, + BuiltinType::kPow, + BuiltinType::kQuantizeToF16, + BuiltinType::kRadians, + BuiltinType::kReflect, + BuiltinType::kRefract, + BuiltinType::kReverseBits, + BuiltinType::kRound, + BuiltinType::kSaturate, + BuiltinType::kSelect, + BuiltinType::kSign, + BuiltinType::kSin, + BuiltinType::kSinh, + BuiltinType::kSmoothstep, + BuiltinType::kSqrt, + BuiltinType::kStep, + BuiltinType::kStorageBarrier, + BuiltinType::kTan, + BuiltinType::kTanh, + BuiltinType::kTranspose, + BuiltinType::kTrunc, + BuiltinType::kUnpack2X16Float, + BuiltinType::kUnpack2X16Snorm, + BuiltinType::kUnpack2X16Unorm, + BuiltinType::kUnpack4X8Snorm, + BuiltinType::kUnpack4X8Unorm, + BuiltinType::kWorkgroupBarrier, + BuiltinType::kWorkgroupUniformLoad, + BuiltinType::kTextureDimensions, + BuiltinType::kTextureGather, + BuiltinType::kTextureGatherCompare, + BuiltinType::kTextureNumLayers, + BuiltinType::kTextureNumLevels, + BuiltinType::kTextureNumSamples, + BuiltinType::kTextureSample, + BuiltinType::kTextureSampleBias, + BuiltinType::kTextureSampleCompare, + BuiltinType::kTextureSampleCompareLevel, + BuiltinType::kTextureSampleGrad, + BuiltinType::kTextureSampleLevel, + BuiltinType::kTextureSampleBaseClampToEdge, + BuiltinType::kTextureStore, + BuiltinType::kTextureLoad, + BuiltinType::kAtomicLoad, + BuiltinType::kAtomicStore, + BuiltinType::kAtomicAdd, + BuiltinType::kAtomicSub, + BuiltinType::kAtomicMax, + BuiltinType::kAtomicMin, + BuiltinType::kAtomicAnd, + BuiltinType::kAtomicOr, + BuiltinType::kAtomicXor, + BuiltinType::kAtomicExchange, + BuiltinType::kAtomicCompareExchangeWeak, + BuiltinType::kTintMaterialize, +}; + } // namespace tint::sem #endif // SRC_TINT_SEM_BUILTIN_TYPE_H_ diff --git a/src/tint/sem/builtin_type.h.tmpl b/src/tint/sem/builtin_type.h.tmpl index ef7d436116..7cd715bbf7 100644 --- a/src/tint/sem/builtin_type.h.tmpl +++ b/src/tint/sem/builtin_type.h.tmpl @@ -41,6 +41,13 @@ const char* str(BuiltinType i); /// matches the name in the WGSL spec. std::ostream& operator<<(std::ostream& out, BuiltinType i); +/// All builtin types +constexpr BuiltinType kBuiltinTypes[] = { +{{- range Sem.Builtins }} + BuiltinType::k{{PascalCase .Name}}, +{{- end }} +}; + } // namespace tint::sem #endif // SRC_TINT_SEM_BUILTIN_TYPE_H_