From afc53fa942b86e219620ec67ae4f9a268c0918e8 Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Wed, 22 Feb 2023 17:15:53 +0000 Subject: [PATCH] tint/resolver: Bring back enum suggestions The dependency graph no longer errors if a symbol cannot be resolved, instead the ResolvedIdentifier now has an unresolved variant. This is required as the second resolve phase only has the full context of the identifier usage, to provide the hints. Also: Split Slice out of the utils/vector.h, so it can be used as a lightweight view over static data. Fixed: tint:1842 Change-Id: I31fa7697790be24c35b7e4fab5ca903c8a7afbba Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/121020 Commit-Queue: Ben Clayton Kokoro: Kokoro Commit-Queue: Ben Clayton Reviewed-by: Dan Sinclair --- src/tint/BUILD.gn | 3 + src/tint/CMakeLists.txt | 3 + src/tint/resolver/dependency_graph.cc | 160 +++++--------- src/tint/resolver/dependency_graph.h | 20 +- src/tint/resolver/dependency_graph_test.cc | 96 ++++---- src/tint/resolver/function_validation_test.cc | 15 +- src/tint/resolver/resolver.cc | 45 +++- src/tint/resolver/resolver.h | 12 + src/tint/resolver/type_validation_test.cc | 2 +- .../resolver/unresolved_identifier_test.cc | 108 +++++++++ src/tint/resolver/validation_test.cc | 8 +- src/tint/utils/slice.h | 205 ++++++++++++++++++ src/tint/utils/slice_test.cc | 131 +++++++++++ src/tint/utils/string.cc | 32 +++ src/tint/utils/string.h | 39 +--- src/tint/utils/string_test.cc | 6 +- src/tint/utils/vector.h | 161 ++------------ src/tint/utils/vector_test.cc | 25 --- 18 files changed, 682 insertions(+), 389 deletions(-) create mode 100644 src/tint/resolver/unresolved_identifier_test.cc create mode 100644 src/tint/utils/slice.h create mode 100644 src/tint/utils/slice_test.cc diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 50b468cc28..227893caf4 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -218,6 +218,7 @@ libtint_source_set("libtint_base_src") { "utils/map.h", "utils/math.h", "utils/scoped_assignment.h", + "utils/slice.h", "utils/string.cc", "utils/string.h", "utils/unique_allocator.h", @@ -1402,6 +1403,7 @@ if (tint_build_unittests) { "resolver/type_initializer_validation_test.cc", "resolver/type_validation_test.cc", "resolver/uniformity_test.cc", + "resolver/unresolved_identifier_test.cc", "resolver/validation_test.cc", "resolver/validator_is_storeable_test.cc", "resolver/variable_test.cc", @@ -1553,6 +1555,7 @@ if (tint_build_unittests) { "utils/result_test.cc", "utils/reverse_test.cc", "utils/scoped_assignment_test.cc", + "utils/slice_test.cc", "utils/string_test.cc", "utils/transform_test.cc", "utils/unique_allocator_test.cc", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index 91c0851ede..b8a803e6f7 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -527,6 +527,7 @@ list(APPEND TINT_LIB_SRCS utils/map.h utils/math.h utils/scoped_assignment.h + utils/slice.h utils/string.cc utils/string.h utils/unique_allocator.h @@ -927,6 +928,7 @@ if(TINT_BUILD_TESTS) resolver/struct_address_space_use_test.cc resolver/type_initializer_validation_test.cc resolver/type_validation_test.cc + resolver/unresolved_identifier_test.cc resolver/validation_test.cc resolver/validator_is_storeable_test.cc resolver/variable_test.cc @@ -981,6 +983,7 @@ if(TINT_BUILD_TESTS) utils/result_test.cc utils/reverse_test.cc utils/scoped_assignment_test.cc + utils/slice_test.cc utils/string_test.cc utils/transform_test.cc utils/unique_allocator_test.cc diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc index e0f0656664..0b8f728934 100644 --- a/src/tint/resolver/dependency_graph.cc +++ b/src/tint/resolver/dependency_graph.cc @@ -79,8 +79,6 @@ struct Global; struct DependencyInfo { /// The source of the symbol that forms the dependency Source source; - /// A string describing how the dependency is referenced. e.g. 'calls' - const char* action = nullptr; }; /// DependencyEdge describes the two Globals used to define a dependency @@ -174,12 +172,12 @@ class DependencyScanner { Declare(str->name->symbol, str); for (auto* member : str->members) { TraverseAttributes(member->attributes); - TraverseTypeExpression(member->type); + TraverseExpression(member->type); } }, [&](const ast::Alias* alias) { Declare(alias->name->symbol, alias); - TraverseTypeExpression(alias->type); + TraverseExpression(alias->type); }, [&](const ast::Function* func) { Declare(func->name->symbol, func); @@ -195,9 +193,7 @@ class DependencyScanner { [&](const ast::Enable*) { // Enable directives do not affect the dependency graph. }, - [&](const ast::ConstAssert* assertion) { - TraverseValueExpression(assertion->condition); - }, + [&](const ast::ConstAssert* assertion) { TraverseExpression(assertion->condition); }, [&](Default) { UnhandledNode(diagnostics_, global->node); }); } @@ -205,12 +201,12 @@ class DependencyScanner { /// Traverses the variable, performing symbol resolution. void TraverseVariable(const ast::Variable* v) { if (auto* var = v->As()) { - TraverseAddressSpaceExpression(var->declared_address_space); - TraverseAccessExpression(var->declared_access); + TraverseExpression(var->declared_address_space); + TraverseExpression(var->declared_access); } - TraverseTypeExpression(v->type); + TraverseExpression(v->type); TraverseAttributes(v->attributes); - TraverseValueExpression(v->initializer); + TraverseExpression(v->initializer); } /// Traverses the function, performing symbol resolution and determining global dependencies. @@ -222,10 +218,10 @@ class DependencyScanner { // with the same identifier as its type. for (auto* param : func->params) { TraverseAttributes(param->attributes); - TraverseTypeExpression(param->type); + TraverseExpression(param->type); } // Resolve the return type - TraverseTypeExpression(func->return_type); + TraverseExpression(func->return_type); // Push the scope stack for the parameters and function body. scope_stack_.Push(); @@ -259,29 +255,29 @@ class DependencyScanner { Switch( stmt, // [&](const ast::AssignmentStatement* a) { - TraverseValueExpression(a->lhs); - TraverseValueExpression(a->rhs); + TraverseExpression(a->lhs); + TraverseExpression(a->rhs); }, [&](const ast::BlockStatement* b) { scope_stack_.Push(); TINT_DEFER(scope_stack_.Pop()); TraverseStatements(b->statements); }, - [&](const ast::BreakIfStatement* b) { TraverseValueExpression(b->condition); }, - [&](const ast::CallStatement* r) { TraverseValueExpression(r->expr); }, + [&](const ast::BreakIfStatement* b) { TraverseExpression(b->condition); }, + [&](const ast::CallStatement* r) { TraverseExpression(r->expr); }, [&](const ast::CompoundAssignmentStatement* a) { - TraverseValueExpression(a->lhs); - TraverseValueExpression(a->rhs); + TraverseExpression(a->lhs); + TraverseExpression(a->rhs); }, [&](const ast::ForLoopStatement* l) { scope_stack_.Push(); TINT_DEFER(scope_stack_.Pop()); TraverseStatement(l->initializer); - TraverseValueExpression(l->condition); + TraverseExpression(l->condition); TraverseStatement(l->continuing); TraverseStatement(l->body); }, - [&](const ast::IncrementDecrementStatement* i) { TraverseValueExpression(i->lhs); }, + [&](const ast::IncrementDecrementStatement* i) { TraverseExpression(i->lhs); }, [&](const ast::LoopStatement* l) { scope_stack_.Push(); TINT_DEFER(scope_stack_.Pop()); @@ -289,18 +285,18 @@ class DependencyScanner { TraverseStatement(l->continuing); }, [&](const ast::IfStatement* i) { - TraverseValueExpression(i->condition); + TraverseExpression(i->condition); TraverseStatement(i->body); if (i->else_statement) { TraverseStatement(i->else_statement); } }, - [&](const ast::ReturnStatement* r) { TraverseValueExpression(r->value); }, + [&](const ast::ReturnStatement* r) { TraverseExpression(r->value); }, [&](const ast::SwitchStatement* s) { - TraverseValueExpression(s->condition); + TraverseExpression(s->condition); for (auto* c : s->body) { for (auto* sel : c->selectors) { - TraverseValueExpression(sel->expr); + TraverseExpression(sel->expr); } TraverseStatement(c->body); } @@ -315,12 +311,10 @@ class DependencyScanner { [&](const ast::WhileStatement* w) { scope_stack_.Push(); TINT_DEFER(scope_stack_.Pop()); - TraverseValueExpression(w->condition); + TraverseExpression(w->condition); TraverseStatement(w->body); }, - [&](const ast::ConstAssert* assertion) { - TraverseValueExpression(assertion->condition); - }, + [&](const ast::ConstAssert* assertion) { TraverseExpression(assertion->condition); }, [&](Default) { if (TINT_UNLIKELY((!stmt->IsAnyOf()))) { @@ -340,70 +334,28 @@ class DependencyScanner { } } - /// Traverses the expression @p root_expr for the intended use as a value, performing symbol - /// resolution and determining global dependencies. - void TraverseValueExpression(const ast::Expression* root) { - TraverseExpression(root, "identifier", "references"); - } - - /// Traverses the expression @p root_expr for the intended use as a type, performing symbol - /// resolution and determining global dependencies. - void TraverseTypeExpression(const ast::Expression* root) { - TraverseExpression(root, "type", "references"); - } - - /// Traverses the expression @p root_expr for the intended use as an address space, performing - /// symbol resolution and determining global dependencies. - void TraverseAddressSpaceExpression(const ast::Expression* root) { - TraverseExpression(root, "address space", "references"); - } - - /// Traverses the expression @p root_expr for the intended use as an access, performing symbol - /// resolution and determining global dependencies. - void TraverseAccessExpression(const ast::Expression* root) { - TraverseExpression(root, "access", "references"); - } - - /// Traverses the expression @p root_expr for the intended use as a call target, performing - /// symbol resolution and determining global dependencies. - void TraverseCallableExpression(const ast::Expression* root) { - TraverseExpression(root, "function", "calls"); - } - /// Traverses the expression @p root_expr, performing symbol resolution and determining global /// dependencies. - void TraverseExpression(const ast::Expression* root_expr, - const char* root_use, - const char* root_action) { + void TraverseExpression(const ast::Expression* root_expr) { if (!root_expr) { return; } - struct Pending { - const ast::Expression* expr; - const char* use; - const char* action; - }; - utils::Vector pending{{root_expr, root_use, root_action}}; + utils::Vector pending{root_expr}; while (!pending.IsEmpty()) { - auto next = pending.Pop(); - ast::TraverseExpressions(next.expr, diagnostics_, [&](const ast::Expression* expr) { + ast::TraverseExpressions(pending.Pop(), diagnostics_, [&](const ast::Expression* expr) { Switch( expr, [&](const ast::IdentifierExpression* e) { - AddDependency(e->identifier, e->identifier->symbol, next.use, next.action); + AddDependency(e->identifier, e->identifier->symbol); if (auto* tmpl_ident = e->identifier->As()) { for (auto* arg : tmpl_ident->arguments) { - pending.Push({arg, "identifier", "references"}); + pending.Push(arg); } } }, - [&](const ast::CallExpression* call) { - TraverseCallableExpression(call->target); - }, - [&](const ast::BitcastExpression* cast) { - TraverseTypeExpression(cast->type); - }); + [&](const ast::CallExpression* call) { TraverseExpression(call->target); }, + [&](const ast::BitcastExpression* cast) { TraverseExpression(cast->type); }); return ast::TraverseAction::Descend; }); } @@ -423,42 +375,42 @@ class DependencyScanner { bool handled = Switch( attr, [&](const ast::BindingAttribute* binding) { - TraverseValueExpression(binding->expr); + TraverseExpression(binding->expr); return true; }, [&](const ast::BuiltinAttribute* builtin) { - TraverseExpression(builtin->builtin, "builtin", "references"); + TraverseExpression(builtin->builtin); return true; }, [&](const ast::GroupAttribute* group) { - TraverseValueExpression(group->expr); + TraverseExpression(group->expr); return true; }, [&](const ast::IdAttribute* id) { - TraverseValueExpression(id->expr); + TraverseExpression(id->expr); return true; }, [&](const ast::InterpolateAttribute* interpolate) { - TraverseExpression(interpolate->type, "interpolation type", "references"); - TraverseExpression(interpolate->sampling, "interpolation sampling", "references"); + TraverseExpression(interpolate->type); + TraverseExpression(interpolate->sampling); return true; }, [&](const ast::LocationAttribute* loc) { - TraverseValueExpression(loc->expr); + TraverseExpression(loc->expr); return true; }, [&](const ast::StructMemberAlignAttribute* align) { - TraverseValueExpression(align->expr); + TraverseExpression(align->expr); return true; }, [&](const ast::StructMemberSizeAttribute* size) { - TraverseValueExpression(size->expr); + TraverseExpression(size->expr); return true; }, [&](const ast::WorkgroupAttribute* wg) { - TraverseValueExpression(wg->x); - TraverseValueExpression(wg->y); - TraverseValueExpression(wg->z); + TraverseExpression(wg->x); + TraverseExpression(wg->y); + TraverseExpression(wg->z); return true; }); if (handled) { @@ -476,10 +428,7 @@ class DependencyScanner { } /// Adds the dependency from @p from to @p to, erroring if @p to cannot be resolved. - void AddDependency(const ast::Identifier* from, - Symbol to, - const char* use, - const char* action) { + void AddDependency(const ast::Identifier* from, Symbol to) { auto* resolved = scope_stack_.Get(to); if (!resolved) { auto s = symbols_.NameFor(to); @@ -521,13 +470,14 @@ class DependencyScanner { return; } - UnknownSymbol(to, from->source, use); + // Unresolved. + graph_.resolved_identifiers.Add(from, UnresolvedIdentifier{s}); return; } if (auto global = globals_.Find(to); global && (*global)->node == resolved) { if (dependency_edges_.Add(DependencyEdge{current_global_, *global}, - DependencyInfo{from->source, action})) { + DependencyInfo{from->source})) { current_global_->deps.Push(*global); } } @@ -535,12 +485,6 @@ class DependencyScanner { graph_.resolved_identifiers.Add(from, ResolvedIdentifier(resolved)); } - /// Appends an error to the diagnostics that the given symbol cannot be resolved. - void UnknownSymbol(Symbol name, Source source, const char* use) { - AddError(diagnostics_, "unknown " + std::string(use) + ": '" + symbols_.NameFor(name) + "'", - source); - } - using VariableMap = utils::Hashmap; const SymbolTable& symbols_; const GlobalMap& globals_; @@ -787,7 +731,7 @@ struct DependencyAnalysis { auto* to = (i + 1 < stack.Length()) ? stack[i + 1] : stack[loop_start]; auto info = DepInfoFor(from, to); AddNote(diagnostics_, - KindOf(from->node) + " '" + NameOf(from->node) + "' " + info.action + " " + + KindOf(from->node) + " '" + NameOf(from->node) + "' references " + KindOf(to->node) + " '" + NameOf(to->node) + "' here", info.source); } @@ -831,8 +775,7 @@ struct DependencyAnalysis { /// Global map, keyed by name. Populated by GatherGlobals(). GlobalMap globals_; - /// Map of DependencyEdge to DependencyInfo. Populated by - /// DetermineDependencies(). + /// Map of DependencyEdge to DependencyInfo. Populated by DetermineDependencies(). DependencyEdges dependency_edges_; /// Globals in declaration order. Populated by GatherGlobals(). @@ -857,9 +800,6 @@ bool DependencyGraph::Build(const ast::Module& module, } std::string ResolvedIdentifier::String(const SymbolTable& symbols, diag::List& diagnostics) const { - if (!Resolved()) { - return ""; - } if (auto* node = Node()) { return Switch( node, @@ -911,6 +851,10 @@ std::string ResolvedIdentifier::String(const SymbolTable& symbols, diag::List& d if (auto fmt = TexelFormat(); fmt != builtin::TexelFormat::kUndefined) { return "texel format '" + utils::ToString(fmt) + "'"; } + if (auto* unresolved = Unresolved()) { + return "unresolved identifier '" + unresolved->name + "'"; + } + TINT_UNREACHABLE(Resolver, diagnostics) << "unhandled ResolvedIdentifier"; return ""; } diff --git a/src/tint/resolver/dependency_graph.h b/src/tint/resolver/dependency_graph.h index da25383f16..429a47afae 100644 --- a/src/tint/resolver/dependency_graph.h +++ b/src/tint/resolver/dependency_graph.h @@ -32,8 +32,15 @@ namespace tint::resolver { +/// UnresolvedIdentifier is the variant value used by ResolvedIdentifier +struct UnresolvedIdentifier { + /// Name of the unresolved identifier + std::string name; +}; + /// ResolvedIdentifier holds the resolution of an ast::Identifier. /// Can hold one of: +/// - UnresolvedIdentifier /// - const ast::TypeDecl* (as const ast::Node*) /// - const ast::Variable* (as const ast::Node*) /// - const ast::Function* (as const ast::Node*) @@ -47,15 +54,18 @@ namespace tint::resolver { /// - builtin::TexelFormat class ResolvedIdentifier { public: - ResolvedIdentifier() = default; - /// Constructor /// @param value the resolved identifier value template ResolvedIdentifier(T value) : value_(value) {} // NOLINT(runtime/explicit) - /// @return true if the ResolvedIdentifier holds a value (successfully resolved) - bool Resolved() const { return !std::holds_alternative(value_); } + /// @return the UnresolvedIdentifier if the identifier was not resolved + const UnresolvedIdentifier* Unresolved() const { + if (auto n = std::get_if(&value_)) { + return n; + } + return nullptr; + } /// @return the node pointer if the ResolvedIdentifier holds an AST node, otherwise nullptr const ast::Node* Node() const { @@ -160,7 +170,7 @@ class ResolvedIdentifier { std::string String(const SymbolTable& symbols, diag::List& diagnostics) const; private: - std::variant"; } -/// @returns the the diagnostic message name used for the given use -std::string DiagString(SymbolUseKind kind) { - switch (kind) { - case SymbolUseKind::GlobalVarType: - case SymbolUseKind::GlobalConstType: - case SymbolUseKind::AliasType: - case SymbolUseKind::StructMemberType: - case SymbolUseKind::ParameterType: - case SymbolUseKind::LocalVarType: - case SymbolUseKind::LocalLetType: - case SymbolUseKind::NestedLocalVarType: - case SymbolUseKind::NestedLocalLetType: - return "type"; - case SymbolUseKind::GlobalVarArrayElemType: - case SymbolUseKind::GlobalVarVectorElemType: - case SymbolUseKind::GlobalVarMatrixElemType: - case SymbolUseKind::GlobalVarSampledTexElemType: - case SymbolUseKind::GlobalVarMultisampledTexElemType: - case SymbolUseKind::GlobalConstArrayElemType: - case SymbolUseKind::GlobalConstVectorElemType: - case SymbolUseKind::GlobalConstMatrixElemType: - case SymbolUseKind::LocalVarArrayElemType: - case SymbolUseKind::LocalVarVectorElemType: - case SymbolUseKind::LocalVarMatrixElemType: - case SymbolUseKind::GlobalVarValue: - case SymbolUseKind::GlobalVarArraySizeValue: - case SymbolUseKind::GlobalConstValue: - case SymbolUseKind::GlobalConstArraySizeValue: - case SymbolUseKind::LocalVarValue: - case SymbolUseKind::LocalVarArraySizeValue: - case SymbolUseKind::LocalLetValue: - case SymbolUseKind::NestedLocalVarValue: - case SymbolUseKind::NestedLocalLetValue: - case SymbolUseKind::WorkgroupSizeValue: - return "identifier"; - case SymbolUseKind::CallFunction: - return "function"; - } - return ""; -} - /// @returns the declaration scope depth for the symbol declaration kind. /// Globals are at depth 0, parameters and locals are at depth 1, /// nested locals are at depth 2. @@ -783,10 +742,16 @@ TEST_P(ResolverDependencyGraphUndeclaredSymbolTest, Test) { // Build a use of a non-existent symbol SymbolTestHelper helper(this); - helper.Add(use_kind, symbol, Source{{56, 78}}); + auto* ident = helper.Add(use_kind, symbol, Source{{56, 78}}); helper.Build(); - Build("56:78 error: unknown " + DiagString(use_kind) + ": 'SYMBOL'"); + auto graph = Build(); + + auto resolved_identifier = graph.resolved_identifiers.Find(ident); + ASSERT_NE(resolved_identifier, nullptr); + auto* unresolved = resolved_identifier->Unresolved(); + ASSERT_NE(unresolved, nullptr); + EXPECT_EQ(unresolved->name, "SYMBOL"); } INSTANTIATE_TEST_SUITE_P(Types, @@ -826,14 +791,28 @@ TEST_F(ResolverDependencyGraphDeclSelfUse, GlobalConst) { TEST_F(ResolverDependencyGraphDeclSelfUse, LocalVar) { const Symbol symbol = Sym("SYMBOL"); - WrapInFunction(Decl(Var(symbol, ty.i32(), Mul(Expr(Source{{12, 34}}, symbol), 123_i)))); - Build("12:34 error: unknown identifier: 'SYMBOL'"); + auto* ident = Ident(Source{{12, 34}}, symbol); + WrapInFunction(Decl(Var(symbol, ty.i32(), Mul(Expr(ident), 123_i)))); + auto graph = Build(); + + auto resolved_identifier = graph.resolved_identifiers.Find(ident); + ASSERT_TRUE(resolved_identifier); + auto* unresolved = resolved_identifier->Unresolved(); + ASSERT_NE(unresolved, nullptr); + EXPECT_EQ(unresolved->name, "SYMBOL"); } TEST_F(ResolverDependencyGraphDeclSelfUse, LocalLet) { const Symbol symbol = Sym("SYMBOL"); - WrapInFunction(Decl(Let(symbol, ty.i32(), Mul(Expr(Source{{12, 34}}, symbol), 123_i)))); - Build("12:34 error: unknown identifier: 'SYMBOL'"); + auto* ident = Ident(Source{{12, 34}}, symbol); + WrapInFunction(Decl(Let(symbol, ty.i32(), Mul(Expr(ident), 123_i)))); + auto graph = Build(); + + auto resolved_identifier = graph.resolved_identifiers.Find(ident); + ASSERT_TRUE(resolved_identifier); + auto* unresolved = resolved_identifier->Unresolved(); + ASSERT_NE(unresolved, nullptr); + EXPECT_EQ(unresolved->name, "SYMBOL"); } } // namespace undeclared_tests @@ -852,7 +831,7 @@ TEST_F(ResolverDependencyGraphCyclicRefTest, DirectCall) { utils::Vector{CallStmt(Call(Ident(Source{{56, 78}}, "main")))}); Build(R"(12:34 error: cyclic dependency found: 'main' -> 'main' -56:78 note: function 'main' calls function 'main' here)"); +56:78 note: function 'main' references function 'main' here)"); } TEST_F(ResolverDependencyGraphCyclicRefTest, IndirectCall) { @@ -876,9 +855,9 @@ TEST_F(ResolverDependencyGraphCyclicRefTest, IndirectCall) { utils::Vector{CallStmt(Call(Ident(Source{{5, 10}}, "c")))}); Build(R"(5:1 error: cyclic dependency found: 'b' -> 'c' -> 'd' -> 'b' -5:10 note: function 'b' calls function 'c' here -4:10 note: function 'c' calls function 'd' here -3:10 note: function 'd' calls function 'b' here)"); +5:10 note: function 'b' references function 'c' here +4:10 note: function 'c' references function 'd' here +3:10 note: function 'd' references function 'b' here)"); } TEST_F(ResolverDependencyGraphCyclicRefTest, Alias_Direct) { @@ -1160,17 +1139,22 @@ TEST_P(ResolverDependencyGraphResolveToUserDeclTest, Test) { // If the declaration is visible to the use, then we expect the analysis to // succeed. - bool expect_pass = ScopeDepth(decl_kind) <= ScopeDepth(use_kind); - auto graph = Build(expect_pass ? "" : "56:78 error: unknown identifier: 'SYMBOL'"); + bool expect_resolved = ScopeDepth(decl_kind) <= ScopeDepth(use_kind); + auto graph = Build(); - if (expect_pass) { + auto resolved_identifier = graph.resolved_identifiers.Find(use); + ASSERT_TRUE(resolved_identifier); + + if (expect_resolved) { // Check that the use resolves to the declaration - auto resolved_identifier = graph.resolved_identifiers.Find(use); - ASSERT_TRUE(resolved_identifier); auto* resolved_node = resolved_identifier->Node(); EXPECT_EQ(resolved_node, decl) << "resolved: " << (resolved_node ? resolved_node->TypeInfo().name : "") << "\n" << "decl: " << decl->TypeInfo().name; + } else { + auto* unresolved = resolved_identifier->Unresolved(); + ASSERT_NE(unresolved, nullptr); + EXPECT_EQ(unresolved->name, "SYMBOL"); } } diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc index b7af242ec3..b18ff2a74f 100644 --- a/src/tint/resolver/function_validation_test.cc +++ b/src/tint/resolver/function_validation_test.cc @@ -1070,11 +1070,14 @@ TEST_P(ResolverFunctionParameterValidationTest, AddressSpaceNoExtension) { ss << param.address_space; EXPECT_FALSE(r()->Resolve()); if (param.expectation == Expectation::kInvalid) { - EXPECT_EQ(r()->error(), "12:34 error: unknown identifier: '" + ss.str() + "'"); + std::string err = R"(12:34 error: unresolved address space '${addr_space}' +12:34 note: Possible values: 'function', 'private', 'push_constant', 'storage', 'uniform', 'workgroup')"; + err = utils::ReplaceAll(err, "${addr_space}", utils::ToString(param.address_space)); + EXPECT_EQ(r()->error(), err); } else { EXPECT_EQ(r()->error(), - "12:34 error: function parameter of pointer type cannot be in '" + ss.str() + - "' address space"); + "12:34 error: function parameter of pointer type cannot be in '" + + utils::ToString(param.address_space) + "' address space"); } } } @@ -1091,8 +1094,10 @@ TEST_P(ResolverFunctionParameterValidationTest, AddressSpaceWithExtension) { } else { EXPECT_FALSE(r()->Resolve()); if (param.expectation == Expectation::kInvalid) { - EXPECT_EQ(r()->error(), "12:34 error: unknown identifier: '" + - utils::ToString(param.address_space) + "'"); + std::string err = R"(12:34 error: unresolved address space '${addr_space}' +12:34 note: Possible values: 'function', 'private', 'push_constant', 'storage', 'uniform', 'workgroup')"; + err = utils::ReplaceAll(err, "${addr_space}", utils::ToString(param.address_space)); + EXPECT_EQ(r()->error(), err); } else { EXPECT_EQ(r()->error(), "12:34 error: function parameter of pointer type cannot be in '" + diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index ca5c46f9cc..27f1881743 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -1418,7 +1418,8 @@ sem::Expression* Resolver::Expression(const ast::Expression* root) { for (auto* expr : utils::Reverse(sorted)) { auto* sem_expr = Switch( - expr, [&](const ast::IndexAccessorExpression* array) { return IndexAccessor(array); }, + expr, // + [&](const ast::IndexAccessorExpression* array) { return IndexAccessor(array); }, [&](const ast::BinaryExpression* bin_op) { return Binary(bin_op); }, [&](const ast::BitcastExpression* bitcast) { return Bitcast(bitcast); }, [&](const ast::CallExpression* call) { return Call(call); }, @@ -1488,10 +1489,12 @@ sem::ValueExpression* Resolver::ValueExpression(const ast::Expression* expr) { } sem::TypeExpression* Resolver::TypeExpression(const ast::Expression* expr) { + identifier_resolve_hint_ = {expr, "type"}; return sem_.AsTypeExpression(Expression(expr)); } sem::FunctionExpression* Resolver::FunctionExpression(const ast::Expression* expr) { + identifier_resolve_hint_ = {expr, "call target"}; return sem_.AsFunctionExpression(Expression(expr)); } @@ -1505,31 +1508,38 @@ type::Type* Resolver::Type(const ast::Expression* ast) { sem::BuiltinEnumExpression* Resolver::AddressSpaceExpression( const ast::Expression* expr) { + identifier_resolve_hint_ = {expr, "address space", builtin::kAddressSpaceStrings}; return sem_.AsAddressSpace(Expression(expr)); } sem::BuiltinEnumExpression* Resolver::BuiltinValueExpression( const ast::Expression* expr) { + identifier_resolve_hint_ = {expr, "builtin value", builtin::kBuiltinValueStrings}; return sem_.AsBuiltinValue(Expression(expr)); } sem::BuiltinEnumExpression* Resolver::TexelFormatExpression( const ast::Expression* expr) { + identifier_resolve_hint_ = {expr, "texel format", builtin::kTexelFormatStrings}; return sem_.AsTexelFormat(Expression(expr)); } sem::BuiltinEnumExpression* Resolver::AccessExpression( const ast::Expression* expr) { + identifier_resolve_hint_ = {expr, "access", builtin::kAccessStrings}; return sem_.AsAccess(Expression(expr)); } sem::BuiltinEnumExpression* Resolver::InterpolationSampling( const ast::Expression* expr) { + identifier_resolve_hint_ = {expr, "interpolation sampling", + builtin::kInterpolationSamplingStrings}; return sem_.AsInterpolationSampling(Expression(expr)); } sem::BuiltinEnumExpression* Resolver::InterpolationType( const ast::Expression* expr) { + identifier_resolve_hint_ = {expr, "interpolation type", builtin::kInterpolationTypeStrings}; return sem_.AsInterpolationType(Expression(expr)); } @@ -2196,6 +2206,11 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { return ty_init_or_conv(ty); } + if (auto* unresolved = resolved->Unresolved()) { + AddError("unresolved call target '" + unresolved->name + "'", expr->source); + return nullptr; + } + ErrorMismatchedResolvedIdentifier(ident->source, *resolved, "call target"); return nullptr; }(); @@ -2541,11 +2556,11 @@ type::Type* Resolver::BuiltinType(builtin::Builtin builtin_ty, const ast::Identi return nullptr; } - auto* format = sem_.AsTexelFormat(Expression(tmpl_ident->arguments[0])); + auto* format = TexelFormatExpression(tmpl_ident->arguments[0]); if (TINT_UNLIKELY(!format)) { return nullptr; } - auto* access = sem_.AsAccess(Expression(tmpl_ident->arguments[1])); + auto* access = AccessExpression(tmpl_ident->arguments[1]); if (TINT_UNLIKELY(!access)) { return nullptr; } @@ -3030,6 +3045,30 @@ sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) { expr, current_statement_, fmt); } + if (auto* unresolved = resolved->Unresolved()) { + if (identifier_resolve_hint_.expression == expr) { + AddError("unresolved " + std::string(identifier_resolve_hint_.usage) + " '" + + unresolved->name + "'", + expr->source); + if (!identifier_resolve_hint_.suggestions.IsEmpty()) { + // Filter out suggestions that have a leading underscore. + utils::Vector filtered; + for (auto* str : identifier_resolve_hint_.suggestions) { + if (str[0] != '_') { + filtered.Push(str); + } + } + std::ostringstream msg; + utils::SuggestAlternatives(unresolved->name, + filtered.Slice().Reinterpret(), msg); + AddNote(msg.str(), expr->source); + } + } else { + AddError("unresolved identifier '" + unresolved->name + "'", expr->source); + } + return nullptr; + } + TINT_UNREACHABLE(Resolver, diagnostics_) << "unhandled resolved identifier: " << resolved->String(builder_->Symbols(), diagnostics_); return nullptr; diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index bbde9e952c..c29a416aca 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -530,6 +530,17 @@ class Resolver { std::unordered_set parameter_reads; }; + /// A hint for the usage of an identifier expression. + /// Used to provide more informative error diagnostics on resolution failure. + struct IdentifierResolveHint { + /// The expression this hint applies to + const ast::Expression* expression = nullptr; + /// The usage of the identifier. + const char* usage = "identifier"; + /// Suggested strings if the identifier failed to resolve + utils::Slice suggestions = utils::Empty; + }; + ProgramBuilder* const builder_; diag::List& diagnostics_; ConstEval const_eval_; @@ -555,6 +566,7 @@ class Resolver { utils::Hashmap logical_binary_lhs_to_parent_; utils::Hashset skip_const_eval_; + IdentifierResolveHint identifier_resolve_hint_; }; } // namespace tint::resolver diff --git a/src/tint/resolver/type_validation_test.cc b/src/tint/resolver/type_validation_test.cc index b9f2edbd84..fe9622b6a1 100644 --- a/src/tint/resolver/type_validation_test.cc +++ b/src/tint/resolver/type_validation_test.cc @@ -1083,7 +1083,7 @@ TEST_P(StorageTextureDimensionTest, All) { EXPECT_TRUE(r()->Resolve()) << r()->error(); } else { EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: unknown type: '" + std::string(params.name) + "'"); + EXPECT_EQ(r()->error(), "12:34 error: unresolved type '" + std::string(params.name) + "'"); } } INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest, diff --git a/src/tint/resolver/unresolved_identifier_test.cc b/src/tint/resolver/unresolved_identifier_test.cc new file mode 100644 index 0000000000..e52b858b27 --- /dev/null +++ b/src/tint/resolver/unresolved_identifier_test.cc @@ -0,0 +1,108 @@ +// Copyright 2023 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gmock/gmock.h" + +#include "src/tint/resolver/resolver_test_helper.h" + +using namespace tint::number_suffixes; // NOLINT + +namespace tint::resolver { +namespace { + +using ResolverUnresolvedIdentifierSuggestions = ResolverTest; + +TEST_F(ResolverUnresolvedIdentifierSuggestions, AddressSpace) { + AST().AddGlobalVariable(create( + Ident("v"), // name + ty.i32(), // type + Expr(Source{{12, 34}}, "privte"), // declared_address_space + nullptr, // declared_access + nullptr, // initializer + utils::Empty // attributes + )); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), R"(12:34 error: unresolved address space 'privte' +12:34 note: Did you mean 'private'? +Possible values: 'function', 'private', 'push_constant', 'storage', 'uniform', 'workgroup')"); +} + +TEST_F(ResolverUnresolvedIdentifierSuggestions, BuiltinValue) { + Func("f", + utils::Vector{ + Param("p", ty.i32(), utils::Vector{Builtin(Expr(Source{{12, 34}}, "positon"))})}, + ty.void_(), utils::Empty); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), R"(12:34 error: unresolved builtin value 'positon' +12:34 note: Did you mean 'position'? +Possible values: 'frag_depth', 'front_facing', 'global_invocation_id', 'instance_index', 'local_invocation_id', 'local_invocation_index', 'num_workgroups', 'position', 'sample_index', 'sample_mask', 'vertex_index', 'workgroup_id')"); +} + +TEST_F(ResolverUnresolvedIdentifierSuggestions, TexelFormat) { + GlobalVar("v", ty("texture_storage_1d", Expr(Source{{12, 34}}, "rba8unorm"), "read")); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), R"(12:34 error: unresolved texel format 'rba8unorm' +12:34 note: Did you mean 'rgba8unorm'? +Possible values: 'bgra8unorm', 'r32float', 'r32sint', 'r32uint', 'rg32float', 'rg32sint', 'rg32uint', 'rgba16float', 'rgba16sint', 'rgba16uint', 'rgba32float', 'rgba32sint', 'rgba32uint', 'rgba8sint', 'rgba8snorm', 'rgba8uint', 'rgba8unorm')"); +} + +TEST_F(ResolverUnresolvedIdentifierSuggestions, AccessMode) { + AST().AddGlobalVariable(create(Ident("v"), // name + ty.i32(), // type + Expr("private"), // declared_address_space + Expr(Source{{12, 34}}, "reed"), // declared_access + nullptr, // initializer + utils::Empty // attributes + )); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), R"(12:34 error: unresolved access 'reed' +12:34 note: Did you mean 'read'? +Possible values: 'read', 'read_write', 'write')"); +} + +TEST_F(ResolverUnresolvedIdentifierSuggestions, InterpolationSampling) { + Structure("s", utils::Vector{ + Member("m", ty.vec4(), + utils::Vector{ + Interpolate(builtin::InterpolationType::kLinear, + Expr(Source{{12, 34}}, "centre")), + }), + }); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), R"(12:34 error: unresolved interpolation sampling 'centre' +12:34 note: Did you mean 'center'? +Possible values: 'center', 'centroid', 'sample')"); +} + +TEST_F(ResolverUnresolvedIdentifierSuggestions, InterpolationType) { + Structure("s", utils::Vector{ + Member("m", ty.vec4(), + utils::Vector{ + Interpolate(Expr(Source{{12, 34}}, "liner")), + }), + }); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), R"(12:34 error: unresolved interpolation type 'liner' +12:34 note: Did you mean 'linear'? +Possible values: 'flat', 'linear', 'perspective')"); +} + +} // namespace +} // namespace tint::resolver diff --git a/src/tint/resolver/validation_test.cc b/src/tint/resolver/validation_test.cc index 1e806d3f1b..6485a6a331 100644 --- a/src/tint/resolver/validation_test.cc +++ b/src/tint/resolver/validation_test.cc @@ -168,7 +168,7 @@ TEST_F(ResolverValidationTest, UsingUndefinedVariable_Fail) { WrapInFunction(assign); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), R"(12:34 error: unknown identifier: 'b')"); + EXPECT_EQ(r()->error(), R"(12:34 error: unresolved identifier 'b')"); } TEST_F(ResolverValidationTest, UsingUndefinedVariableInBlockStatement_Fail) { @@ -183,7 +183,7 @@ TEST_F(ResolverValidationTest, UsingUndefinedVariableInBlockStatement_Fail) { WrapInFunction(body); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), R"(12:34 error: unknown identifier: 'b')"); + EXPECT_EQ(r()->error(), R"(12:34 error: unresolved identifier 'b')"); } TEST_F(ResolverValidationTest, UsingUndefinedVariableGlobalVariable_Pass) { @@ -223,7 +223,7 @@ TEST_F(ResolverValidationTest, UsingUndefinedVariableInnerScope_Fail) { WrapInFunction(outer_body); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), R"(12:34 error: unknown identifier: 'a')"); + EXPECT_EQ(r()->error(), R"(12:34 error: unresolved identifier 'a')"); } TEST_F(ResolverValidationTest, UsingUndefinedVariableOuterScope_Pass) { @@ -263,7 +263,7 @@ TEST_F(ResolverValidationTest, UsingUndefinedVariableDifferentScope_Fail) { WrapInFunction(outer_body); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), R"(12:34 error: unknown identifier: 'a')"); + EXPECT_EQ(r()->error(), R"(12:34 error: unresolved identifier 'a')"); } TEST_F(ResolverValidationTest, AddressSpace_FunctionVariableWorkgroupClass) { diff --git a/src/tint/utils/slice.h b/src/tint/utils/slice.h new file mode 100644 index 0000000000..719a53e5ff --- /dev/null +++ b/src/tint/utils/slice.h @@ -0,0 +1,205 @@ +// Copyright 2023 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TINT_UTILS_SLICE_H_ +#define SRC_TINT_UTILS_SLICE_H_ + +#include +#include + +#include "src/tint/castable.h" +#include "src/tint/traits.h" + +namespace tint::utils { + +/// A type used to indicate an empty array. +struct EmptyType {}; + +/// An instance of the EmptyType. +static constexpr EmptyType Empty; + +/// Mode enumerator for ReinterpretSlice +enum class ReinterpretMode { + /// Only upcasts of pointers are permitted + kSafe, + /// Potentially unsafe downcasts of pointers are also permitted + kUnsafe, +}; + +namespace detail { + +template +static constexpr bool ConstRemoved = std::is_const_v && !std::is_const_v; + +/// Private implementation of tint::utils::CanReinterpretSlice. +/// Specialized for the case of TO equal to FROM, which is the common case, and avoids inspection of +/// the base classes, which can be troublesome if the slice is of an incomplete type. +template +struct CanReinterpretSlice { + private: + using TO_EL = std::remove_pointer_t>; + using FROM_EL = std::remove_pointer_t>; + + public: + /// @see utils::CanReinterpretSlice + static constexpr bool value = + // const can only be applied, not removed + !ConstRemoved && + + // Both TO and FROM are the same type (ignoring const) + (std::is_same_v, std::remove_const_t> || + + // Both TO and FROM are pointers... + ((std::is_pointer_v && std::is_pointer_v)&& + + // const can only be applied to element type, not removed + !ConstRemoved && + + // Either: + // * Both the pointer elements are of the same type (ignoring const) + // * Both the pointer elements are both Castable, and MODE is kUnsafe, or FROM is of, + // or + // derives from TO + (std::is_same_v, std::remove_const_t> || + (IsCastable && + (MODE == ReinterpretMode::kUnsafe || traits::IsTypeOrDerived))))); +}; + +/// Specialization of 'CanReinterpretSlice' for when TO and FROM are equal types. +template +struct CanReinterpretSlice { + /// Always `true` as TO and FROM are the same type. + static constexpr bool value = true; +}; + +} // namespace detail + +/// Evaluates whether a `Slice` and be reinterpreted as a `Slice`. +/// Slices can be reinterpreted if: +/// * TO has the same or more 'constness' than FROM. +/// * And either: +/// * `FROM` and `TO` are pointers to the same type +/// * `FROM` and `TO` are pointers to CastableBase (or derived), and the pointee type of `TO` is of +/// the same type as, or is an ancestor of the pointee type of `FROM`. +template +static constexpr bool CanReinterpretSlice = detail::CanReinterpretSlice::value; + +/// A slice represents a contigious array of elements of type T. +template +struct Slice { + /// Type of `T`. + using value_type = T; + + /// The pointer to the first element in the slice + T* data = nullptr; + + /// The total number of elements in the slice + size_t len = 0; + + /// The total capacity of the backing store for the slice + size_t cap = 0; + + /// Constructor + Slice() = default; + + /// Constructor + Slice(EmptyType) {} // NOLINT + + /// Constructor + /// @param d pointer to the first element in the slice + /// @param l total number of elements in the slice + /// @param c total capacity of the backing store for the slice + Slice(T* d, size_t l, size_t c) : data(d), len(l), cap(c) {} + + /// Constructor + /// @param elements c-array of elements + template + Slice(T (&elements)[N]) // NOLINT + : data(elements), len(N), cap(N) {} + + /// Reinterprets this slice as `const Slice&` + /// @returns the reinterpreted slice + /// @see CanReinterpretSlice + template + const Slice& Reinterpret() const { + static_assert(CanReinterpretSlice); + return *Bitcast*>(this); + } + + /// Reinterprets this slice as `Slice&` + /// @returns the reinterpreted slice + /// @see CanReinterpretSlice + template + Slice& Reinterpret() { + static_assert(CanReinterpretSlice); + return *Bitcast*>(this); + } + + /// @return true if the slice length is zero + bool IsEmpty() const { return len == 0; } + + /// Index operator + /// @param i the element index. Must be less than `len`. + /// @returns a reference to the i'th element. + T& operator[](size_t i) { return data[i]; } + + /// Index operator + /// @param i the element index. Must be less than `len`. + /// @returns a reference to the i'th element. + const T& operator[](size_t i) const { return data[i]; } + + /// @returns a reference to the first element in the vector + T& Front() { return data[0]; } + + /// @returns a reference to the first element in the vector + const T& Front() const { return data[0]; } + + /// @returns a reference to the last element in the vector + T& Back() { return data[len - 1]; } + + /// @returns a reference to the last element in the vector + const T& Back() const { return data[len - 1]; } + + /// @returns a pointer to the first element in the vector + T* begin() { return data; } + + /// @returns a pointer to the first element in the vector + const T* begin() const { return data; } + + /// @returns a pointer to one past the last element in the vector + T* end() { return data + len; } + + /// @returns a pointer to one past the last element in the vector + const T* end() const { return data + len; } + + /// @returns a reverse iterator starting with the last element in the vector + auto rbegin() { return std::reverse_iterator(end()); } + + /// @returns a reverse iterator starting with the last element in the vector + auto rbegin() const { return std::reverse_iterator(end()); } + + /// @returns the end for a reverse iterator + auto rend() { return std::reverse_iterator(begin()); } + + /// @returns the end for a reverse iterator + auto rend() const { return std::reverse_iterator(begin()); } +}; + +/// Deduction guide for Slice from c-array +template +Slice(T (&elements)[N]) -> Slice; + +} // namespace tint::utils + +#endif // SRC_TINT_UTILS_SLICE_H_ diff --git a/src/tint/utils/slice_test.cc b/src/tint/utils/slice_test.cc new file mode 100644 index 0000000000..6a4493cb9c --- /dev/null +++ b/src/tint/utils/slice_test.cc @@ -0,0 +1,131 @@ +// Copyright 2023 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/utils/slice.h" + +#include "gmock/gmock.h" + +namespace tint::utils { +namespace { + +class C0 : public Castable {}; +class C1 : public Castable {}; +class C2a : public Castable {}; +class C2b : public Castable {}; + +//////////////////////////////////////////////////////////////////////////////// +// Static asserts +//////////////////////////////////////////////////////////////////////////////// +// Non-pointer +static_assert(CanReinterpretSlice, "same type"); +static_assert(CanReinterpretSlice, "apply const"); +static_assert(!CanReinterpretSlice, "remove const"); + +// Non-castable pointers +static_assert(CanReinterpretSlice, "apply ptr const"); +static_assert(!CanReinterpretSlice, "remove ptr const"); +static_assert(CanReinterpretSlice, "apply el const"); +static_assert(!CanReinterpretSlice, "remove el const"); + +// Castable +static_assert(CanReinterpretSlice, "apply const"); +static_assert(!CanReinterpretSlice, "remove const"); +static_assert(CanReinterpretSlice, "up cast"); +static_assert(CanReinterpretSlice, "up cast"); +static_assert(CanReinterpretSlice, "up cast, apply const"); +static_assert(!CanReinterpretSlice, + "up cast, remove const"); +static_assert(!CanReinterpretSlice, "down cast"); +static_assert(!CanReinterpretSlice, "down cast"); +static_assert(!CanReinterpretSlice, + "down cast, apply const"); +static_assert(!CanReinterpretSlice, + "down cast, remove const"); +static_assert(!CanReinterpretSlice, + "down cast, apply const"); +static_assert(!CanReinterpretSlice, + "down cast, remove const"); +static_assert(!CanReinterpretSlice, "sideways cast"); +static_assert(!CanReinterpretSlice, + "sideways cast"); +static_assert(!CanReinterpretSlice, + "sideways cast, apply const"); +static_assert(!CanReinterpretSlice, + "sideways cast, remove const"); + +TEST(TintSliceTest, Ctor) { + Slice slice; + EXPECT_EQ(slice.data, nullptr); + EXPECT_EQ(slice.len, 0u); + EXPECT_EQ(slice.cap, 0u); + EXPECT_TRUE(slice.IsEmpty()); +} + +TEST(TintSliceTest, CtorEmpty) { + Slice slice{Empty}; + EXPECT_EQ(slice.data, nullptr); + EXPECT_EQ(slice.len, 0u); + EXPECT_EQ(slice.cap, 0u); + EXPECT_TRUE(slice.IsEmpty()); +} + +TEST(TintSliceTest, CtorCArray) { + int elements[] = {1, 2, 3}; + + auto slice = Slice{elements}; + EXPECT_EQ(slice.data, elements); + EXPECT_EQ(slice.len, 3u); + EXPECT_EQ(slice.cap, 3u); + EXPECT_FALSE(slice.IsEmpty()); +} + +TEST(TintSliceTest, Index) { + int elements[] = {1, 2, 3}; + + auto slice = Slice{elements}; + EXPECT_EQ(slice[0], 1); + EXPECT_EQ(slice[1], 2); + EXPECT_EQ(slice[2], 3); +} + +TEST(TintSliceTest, Front) { + int elements[] = {1, 2, 3}; + auto slice = Slice{elements}; + EXPECT_EQ(slice.Front(), 1); +} + +TEST(TintSliceTest, Back) { + int elements[] = {1, 2, 3}; + auto slice = Slice{elements}; + EXPECT_EQ(slice.Back(), 3); +} + +TEST(TintSliceTest, BeginEnd) { + int elements[] = {1, 2, 3}; + auto slice = Slice{elements}; + EXPECT_THAT(slice, testing::ElementsAre(1, 2, 3)); +} + +TEST(TintSliceTest, ReverseBeginEnd) { + int elements[] = {1, 2, 3}; + auto slice = Slice{elements}; + size_t i = 0; + for (auto it = slice.rbegin(); it != slice.rend(); it++) { + EXPECT_EQ(*it, elements[2 - i]); + i++; + } +} + +} // namespace +} // namespace tint::utils diff --git a/src/tint/utils/string.cc b/src/tint/utils/string.cc index 354ad96062..bccd52c23a 100644 --- a/src/tint/utils/string.cc +++ b/src/tint/utils/string.cc @@ -48,4 +48,36 @@ size_t Distance(std::string_view str_a, std::string_view str_b) { return at(len_a, len_b); } +void SuggestAlternatives(std::string_view got, + Slice strings, + std::ostringstream& ss) { + // If the string typed was within kSuggestionDistance of one of the possible enum values, + // suggest that. Don't bother with suggestions if the string was extremely long. + constexpr size_t kSuggestionDistance = 5; + constexpr size_t kSuggestionMaxLength = 64; + if (!got.empty() && got.size() < kSuggestionMaxLength) { + size_t candidate_dist = kSuggestionDistance; + const char* candidate = nullptr; + for (auto* str : strings) { + auto dist = utils::Distance(str, got); + if (dist < candidate_dist) { + candidate = str; + candidate_dist = dist; + } + } + if (candidate) { + ss << "Did you mean '" << candidate << "'?\n"; + } + } + + // List all the possible enumerator values + ss << "Possible values: "; + for (auto* str : strings) { + if (str != strings[0]) { + ss << ", "; + } + ss << "'" << str << "'"; + } +} + } // namespace tint::utils diff --git a/src/tint/utils/string.h b/src/tint/utils/string.h index 5c05ebfe04..08e0be3138 100644 --- a/src/tint/utils/string.h +++ b/src/tint/utils/string.h @@ -19,6 +19,8 @@ #include #include +#include "src/tint/utils/slice.h" + namespace tint::utils { /// @param str the string to apply replacements to @@ -66,42 +68,13 @@ inline size_t HasPrefix(std::string_view str, std::string_view prefix) { /// @returns the Levenshtein distance between @p a and @p b size_t Distance(std::string_view a, std::string_view b); -/// Suggest alternatives for an unrecognized string from a list of expected values. +/// Suggest alternatives for an unrecognized string from a list of possible values. /// @param got the unrecognized string -/// @param strings the list of expected values +/// @param strings the list of possible values /// @param ss the stream to write the suggest and list of possible values to -template void SuggestAlternatives(std::string_view got, - const char* const (&strings)[N], - std::ostringstream& ss) { - // If the string typed was within kSuggestionDistance of one of the possible enum values, - // suggest that. Don't bother with suggestions if the string was extremely long. - constexpr size_t kSuggestionDistance = 5; - constexpr size_t kSuggestionMaxLength = 64; - if (!got.empty() && got.size() < kSuggestionMaxLength) { - size_t candidate_dist = kSuggestionDistance; - const char* candidate = nullptr; - for (auto* str : strings) { - auto dist = utils::Distance(str, got); - if (dist < candidate_dist) { - candidate = str; - candidate_dist = dist; - } - } - if (candidate) { - ss << "Did you mean '" << candidate << "'?\n"; - } - } - - // List all the possible enumerator values - ss << "Possible values: "; - for (auto* str : strings) { - if (str != strings[0]) { - ss << ", "; - } - ss << "'" << str << "'"; - } -} + Slice strings, + std::ostringstream& ss); } // namespace tint::utils diff --git a/src/tint/utils/string_test.cc b/src/tint/utils/string_test.cc index 6b17dfb409..b9e1ebb842 100644 --- a/src/tint/utils/string_test.cc +++ b/src/tint/utils/string_test.cc @@ -60,14 +60,16 @@ TEST(StringTest, Distance) { TEST(StringTest, SuggestAlternatives) { { + const char* alternatives[] = {"hello world", "Hello World"}; std::ostringstream ss; - SuggestAlternatives("hello wordl", {"hello world", "Hello World"}, ss); + SuggestAlternatives("hello wordl", alternatives, ss); EXPECT_EQ(ss.str(), R"(Did you mean 'hello world'? Possible values: 'hello world', 'Hello World')"); } { + const char* alternatives[] = {"foobar", "something else"}; std::ostringstream ss; - SuggestAlternatives("hello world", {"foobar", "something else"}, ss); + SuggestAlternatives("hello world", alternatives, ss); EXPECT_EQ(ss.str(), R"(Possible values: 'foobar', 'something else')"); } } diff --git a/src/tint/utils/vector.h b/src/tint/utils/vector.h index cca2409f7b..e39854aaf6 100644 --- a/src/tint/utils/vector.h +++ b/src/tint/utils/vector.h @@ -19,14 +19,14 @@ #include #include #include +#include #include #include #include -#include "src/tint/castable.h" -#include "src/tint/traits.h" #include "src/tint/utils/bitcast.h" #include "src/tint/utils/compiler_macros.h" +#include "src/tint/utils/slice.h" #include "src/tint/utils/string.h" namespace tint::utils { @@ -41,137 +41,6 @@ class VectorRef; namespace tint::utils { -/// A type used to indicate an empty array. -struct EmptyType {}; - -/// An instance of the EmptyType. -static constexpr EmptyType Empty; - -/// A slice represents a contigious array of elements of type T. -template -struct Slice { - /// The pointer to the first element in the slice - T* data = nullptr; - - /// The total number of elements in the slice - size_t len = 0; - - /// The total capacity of the backing store for the slice - size_t cap = 0; - - /// Index operator - /// @param i the element index. Must be less than `len`. - /// @returns a reference to the i'th element. - T& operator[](size_t i) { return data[i]; } - - /// Index operator - /// @param i the element index. Must be less than `len`. - /// @returns a reference to the i'th element. - const T& operator[](size_t i) const { return data[i]; } - - /// @returns a reference to the first element in the vector - T& Front() { return data[0]; } - - /// @returns a reference to the first element in the vector - const T& Front() const { return data[0]; } - - /// @returns a reference to the last element in the vector - T& Back() { return data[len - 1]; } - - /// @returns a reference to the last element in the vector - const T& Back() const { return data[len - 1]; } - - /// @returns a pointer to the first element in the vector - T* begin() { return data; } - - /// @returns a pointer to the first element in the vector - const T* begin() const { return data; } - - /// @returns a pointer to one past the last element in the vector - T* end() { return data + len; } - - /// @returns a pointer to one past the last element in the vector - const T* end() const { return data + len; } - - /// @returns a reverse iterator starting with the last element in the vector - auto rbegin() { return std::reverse_iterator(end()); } - - /// @returns a reverse iterator starting with the last element in the vector - auto rbegin() const { return std::reverse_iterator(end()); } - - /// @returns the end for a reverse iterator - auto rend() { return std::reverse_iterator(begin()); } - - /// @returns the end for a reverse iterator - auto rend() const { return std::reverse_iterator(begin()); } -}; - -/// Mode enumerator for ReinterpretSlice -enum class ReinterpretMode { - /// Only upcasts of pointers are permitted - kSafe, - /// Potentially unsafe downcasts of pointers are also permitted - kUnsafe, -}; - -namespace detail { - -/// Private implementation of tint::utils::CanReinterpretSlice. -/// Specialized for the case of TO equal to FROM, which is the common case, and avoids inspection of -/// the base classes, which can be troublesome if the slice is of an incomplete type. -template -struct CanReinterpretSlice { - /// True if a slice of FROM can be reinterpreted as a slice of TO - static constexpr bool value = - // Both TO and FROM are pointers - (std::is_pointer_v && std::is_pointer_v)&& // - // const can only be applied, not removed - (std::is_const_v> || - !std::is_const_v>)&& // - // TO and FROM are both Castable - IsCastable, std::remove_pointer_t> && // - // MODE is kUnsafe, or FROM is of, or derives from TO - (MODE == ReinterpretMode::kUnsafe || - traits::IsTypeOrDerived, std::remove_pointer_t>); -}; - -/// Specialization of 'CanReinterpretSlice' for when TO and FROM are equal types. -template -struct CanReinterpretSlice { - /// Always `true` as TO and FROM are the same type. - static constexpr bool value = true; -}; - -} // namespace detail - -/// Evaluates whether a `vector` and be reinterpreted as a `vector`. -/// Vectors can be reinterpreted if both `FROM` and `TO` are pointers to a type that derives from -/// CastableBase, and the pointee type of `TO` is of the same type as, or is an ancestor of the -/// pointee type of `FROM`. Vectors of non-`const` Castable pointers can be converted to a vector of -/// `const` Castable pointers. -template -static constexpr bool CanReinterpretSlice = detail::CanReinterpretSlice::value; - -/// Reinterprets `const Slice*` as `const Slice*` -/// @param slice a pointer to the slice to reinterpret -/// @returns the reinterpreted slice -/// @see CanReinterpretSlice -template -const Slice* ReinterpretSlice(const Slice* slice) { - static_assert(CanReinterpretSlice); - return Bitcast*>(slice); -} - -/// Reinterprets `Slice*` as `Slice*` -/// @param slice a pointer to the slice to reinterpret -/// @returns the reinterpreted slice -/// @see CanReinterpretSlice -template -Slice* ReinterpretSlice(Slice* slice) { - static_assert(CanReinterpretSlice); - return Bitcast*>(slice); -} - /// Vector is a small-object-optimized, dynamically-sized vector of contigious elements of type T. /// /// Vector will fit `N` elements internally before spilling to heap allocations. If `N` is greater @@ -244,7 +113,7 @@ class Vector { ReinterpretMode MODE, typename = std::enable_if_t>> Vector(const Vector& other) { // NOLINT(runtime/explicit) - Copy(*ReinterpretSlice(&other.impl_.slice)); + Copy(other.impl_.slice.template Reinterpret); } /// Move constructor with covariance / const conversion @@ -518,6 +387,9 @@ class Vector { return !(*this == other); } + /// @returns the internal slice of the vector + utils::Slice Slice() { return impl_.slice; } + private: /// Friend class (differing specializations of this class) template @@ -531,9 +403,6 @@ class Vector { template friend class VectorRef; - /// The slice type used by this vector - using Slice = utils::Slice; - template void AppendVariadic(Ts&&... args) { ((new (&impl_.slice.data[impl_.slice.len++]) T(std::forward(args))), ...); @@ -555,7 +424,7 @@ class Vector { /// Copies all the elements from `other` to this vector, replacing the content of this vector. /// @param other the - void Copy(const Slice& other) { + void Copy(const utils::Slice& other) { if (impl_.slice.cap < other.len) { ClearAndFree(); impl_.Allocate(other.len); @@ -592,7 +461,7 @@ class Vector { /// The internal structure for the vector with a small array. struct ImplWithSmallArray { TStorage small_arr[N]; - Slice slice = {small_arr[0].Get(), 0, N}; + utils::Slice slice = {small_arr[0].Get(), 0, N}; /// Allocates a new vector of `T` either from #small_arr, or from the heap, then assigns the /// pointer it to #slice.data, and updates #slice.cap. @@ -620,7 +489,7 @@ class Vector { /// The internal structure for the vector without a small array. struct ImplWithoutSmallArray { - Slice slice = {nullptr, 0, 0}; + utils::Slice slice = Empty; /// Allocates a new vector of `T` and assigns it to #slice.data, and updates #slice.cap. void Allocate(size_t new_cap) { @@ -759,15 +628,14 @@ class VectorRef { template >> VectorRef(const VectorRef& other) // NOLINT(runtime/explicit) - : slice_(*ReinterpretSlice(&other.slice_)) {} + : slice_(other.slice_.template Reinterpret()) {} /// Move constructor with covariance / const conversion /// @param other the vector reference template >> VectorRef(VectorRef&& other) // NOLINT(runtime/explicit) - : slice_(*ReinterpretSlice(&other.slice_)), - can_move_(other.can_move_) {} + : slice_(other.slice_.template Reinterpret()), can_move_(other.can_move_) {} /// Constructor from a Vector with covariance / const conversion /// @param vector the vector to create a reference of @@ -776,7 +644,7 @@ class VectorRef { size_t N, typename = std::enable_if_t>> VectorRef(Vector& vector) // NOLINT(runtime/explicit) - : slice_(*ReinterpretSlice(&vector.impl_.slice)) {} + : slice_(vector.impl_.slice.template Reinterpret()) {} /// Constructor from a moved Vector with covariance / const conversion /// @param vector the vector to create a reference of @@ -785,8 +653,7 @@ class VectorRef { size_t N, typename = std::enable_if_t>> VectorRef(Vector&& vector) // NOLINT(runtime/explicit) - : slice_(*ReinterpretSlice(&vector.impl_.slice)), - can_move_(vector.impl_.CanMove()) {} + : slice_(vector.impl_.slice.template Reinterpret()), can_move_(vector.impl_.CanMove()) {} /// Index operator /// @param i the element index. Must be less than `len`. @@ -805,7 +672,7 @@ class VectorRef { /// this is a safe operation. template VectorRef ReinterpretCast() const { - return {*ReinterpretSlice(&slice_)}; + return {slice_.template Reinterpret()}; } /// @returns true if the vector is empty. diff --git a/src/tint/utils/vector_test.cc b/src/tint/utils/vector_test.cc index ed9a97a48c..6245476053 100644 --- a/src/tint/utils/vector_test.cc +++ b/src/tint/utils/vector_test.cc @@ -79,31 +79,6 @@ static_assert(std::is_same_v, const C1*>); static_assert(std::is_same_v, const C1*>); static_assert(std::is_same_v, const C1*>); -static_assert(CanReinterpretSlice, "apply const"); -static_assert(!CanReinterpretSlice, "remove const"); -static_assert(CanReinterpretSlice, "up cast"); -static_assert(CanReinterpretSlice, "up cast"); -static_assert(CanReinterpretSlice, "up cast, apply const"); -static_assert(!CanReinterpretSlice, - "up cast, remove const"); -static_assert(!CanReinterpretSlice, "down cast"); -static_assert(!CanReinterpretSlice, "down cast"); -static_assert(!CanReinterpretSlice, - "down cast, apply const"); -static_assert(!CanReinterpretSlice, - "down cast, remove const"); -static_assert(!CanReinterpretSlice, - "down cast, apply const"); -static_assert(!CanReinterpretSlice, - "down cast, remove const"); -static_assert(!CanReinterpretSlice, "sideways cast"); -static_assert(!CanReinterpretSlice, - "sideways cast"); -static_assert(!CanReinterpretSlice, - "sideways cast, apply const"); -static_assert(!CanReinterpretSlice, - "sideways cast, remove const"); - //////////////////////////////////////////////////////////////////////////////// // TintVectorTest ////////////////////////////////////////////////////////////////////////////////