diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 50b468cc28..227893caf4 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -218,6 +218,7 @@ libtint_source_set("libtint_base_src") { "utils/map.h", "utils/math.h", "utils/scoped_assignment.h", + "utils/slice.h", "utils/string.cc", "utils/string.h", "utils/unique_allocator.h", @@ -1402,6 +1403,7 @@ if (tint_build_unittests) { "resolver/type_initializer_validation_test.cc", "resolver/type_validation_test.cc", "resolver/uniformity_test.cc", + "resolver/unresolved_identifier_test.cc", "resolver/validation_test.cc", "resolver/validator_is_storeable_test.cc", "resolver/variable_test.cc", @@ -1553,6 +1555,7 @@ if (tint_build_unittests) { "utils/result_test.cc", "utils/reverse_test.cc", "utils/scoped_assignment_test.cc", + "utils/slice_test.cc", "utils/string_test.cc", "utils/transform_test.cc", "utils/unique_allocator_test.cc", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index 91c0851ede..b8a803e6f7 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -527,6 +527,7 @@ list(APPEND TINT_LIB_SRCS utils/map.h utils/math.h utils/scoped_assignment.h + utils/slice.h utils/string.cc utils/string.h utils/unique_allocator.h @@ -927,6 +928,7 @@ if(TINT_BUILD_TESTS) resolver/struct_address_space_use_test.cc resolver/type_initializer_validation_test.cc resolver/type_validation_test.cc + resolver/unresolved_identifier_test.cc resolver/validation_test.cc resolver/validator_is_storeable_test.cc resolver/variable_test.cc @@ -981,6 +983,7 @@ if(TINT_BUILD_TESTS) utils/result_test.cc utils/reverse_test.cc utils/scoped_assignment_test.cc + utils/slice_test.cc utils/string_test.cc utils/transform_test.cc utils/unique_allocator_test.cc diff --git a/src/tint/resolver/dependency_graph.cc b/src/tint/resolver/dependency_graph.cc index e0f0656664..0b8f728934 100644 --- a/src/tint/resolver/dependency_graph.cc +++ b/src/tint/resolver/dependency_graph.cc @@ -79,8 +79,6 @@ struct Global; struct DependencyInfo { /// The source of the symbol that forms the dependency Source source; - /// A string describing how the dependency is referenced. e.g. 'calls' - const char* action = nullptr; }; /// DependencyEdge describes the two Globals used to define a dependency @@ -174,12 +172,12 @@ class DependencyScanner { Declare(str->name->symbol, str); for (auto* member : str->members) { TraverseAttributes(member->attributes); - TraverseTypeExpression(member->type); + TraverseExpression(member->type); } }, [&](const ast::Alias* alias) { Declare(alias->name->symbol, alias); - TraverseTypeExpression(alias->type); + TraverseExpression(alias->type); }, [&](const ast::Function* func) { Declare(func->name->symbol, func); @@ -195,9 +193,7 @@ class DependencyScanner { [&](const ast::Enable*) { // Enable directives do not affect the dependency graph. }, - [&](const ast::ConstAssert* assertion) { - TraverseValueExpression(assertion->condition); - }, + [&](const ast::ConstAssert* assertion) { TraverseExpression(assertion->condition); }, [&](Default) { UnhandledNode(diagnostics_, global->node); }); } @@ -205,12 +201,12 @@ class DependencyScanner { /// Traverses the variable, performing symbol resolution. void TraverseVariable(const ast::Variable* v) { if (auto* var = v->As()) { - TraverseAddressSpaceExpression(var->declared_address_space); - TraverseAccessExpression(var->declared_access); + TraverseExpression(var->declared_address_space); + TraverseExpression(var->declared_access); } - TraverseTypeExpression(v->type); + TraverseExpression(v->type); TraverseAttributes(v->attributes); - TraverseValueExpression(v->initializer); + TraverseExpression(v->initializer); } /// Traverses the function, performing symbol resolution and determining global dependencies. @@ -222,10 +218,10 @@ class DependencyScanner { // with the same identifier as its type. for (auto* param : func->params) { TraverseAttributes(param->attributes); - TraverseTypeExpression(param->type); + TraverseExpression(param->type); } // Resolve the return type - TraverseTypeExpression(func->return_type); + TraverseExpression(func->return_type); // Push the scope stack for the parameters and function body. scope_stack_.Push(); @@ -259,29 +255,29 @@ class DependencyScanner { Switch( stmt, // [&](const ast::AssignmentStatement* a) { - TraverseValueExpression(a->lhs); - TraverseValueExpression(a->rhs); + TraverseExpression(a->lhs); + TraverseExpression(a->rhs); }, [&](const ast::BlockStatement* b) { scope_stack_.Push(); TINT_DEFER(scope_stack_.Pop()); TraverseStatements(b->statements); }, - [&](const ast::BreakIfStatement* b) { TraverseValueExpression(b->condition); }, - [&](const ast::CallStatement* r) { TraverseValueExpression(r->expr); }, + [&](const ast::BreakIfStatement* b) { TraverseExpression(b->condition); }, + [&](const ast::CallStatement* r) { TraverseExpression(r->expr); }, [&](const ast::CompoundAssignmentStatement* a) { - TraverseValueExpression(a->lhs); - TraverseValueExpression(a->rhs); + TraverseExpression(a->lhs); + TraverseExpression(a->rhs); }, [&](const ast::ForLoopStatement* l) { scope_stack_.Push(); TINT_DEFER(scope_stack_.Pop()); TraverseStatement(l->initializer); - TraverseValueExpression(l->condition); + TraverseExpression(l->condition); TraverseStatement(l->continuing); TraverseStatement(l->body); }, - [&](const ast::IncrementDecrementStatement* i) { TraverseValueExpression(i->lhs); }, + [&](const ast::IncrementDecrementStatement* i) { TraverseExpression(i->lhs); }, [&](const ast::LoopStatement* l) { scope_stack_.Push(); TINT_DEFER(scope_stack_.Pop()); @@ -289,18 +285,18 @@ class DependencyScanner { TraverseStatement(l->continuing); }, [&](const ast::IfStatement* i) { - TraverseValueExpression(i->condition); + TraverseExpression(i->condition); TraverseStatement(i->body); if (i->else_statement) { TraverseStatement(i->else_statement); } }, - [&](const ast::ReturnStatement* r) { TraverseValueExpression(r->value); }, + [&](const ast::ReturnStatement* r) { TraverseExpression(r->value); }, [&](const ast::SwitchStatement* s) { - TraverseValueExpression(s->condition); + TraverseExpression(s->condition); for (auto* c : s->body) { for (auto* sel : c->selectors) { - TraverseValueExpression(sel->expr); + TraverseExpression(sel->expr); } TraverseStatement(c->body); } @@ -315,12 +311,10 @@ class DependencyScanner { [&](const ast::WhileStatement* w) { scope_stack_.Push(); TINT_DEFER(scope_stack_.Pop()); - TraverseValueExpression(w->condition); + TraverseExpression(w->condition); TraverseStatement(w->body); }, - [&](const ast::ConstAssert* assertion) { - TraverseValueExpression(assertion->condition); - }, + [&](const ast::ConstAssert* assertion) { TraverseExpression(assertion->condition); }, [&](Default) { if (TINT_UNLIKELY((!stmt->IsAnyOf()))) { @@ -340,70 +334,28 @@ class DependencyScanner { } } - /// Traverses the expression @p root_expr for the intended use as a value, performing symbol - /// resolution and determining global dependencies. - void TraverseValueExpression(const ast::Expression* root) { - TraverseExpression(root, "identifier", "references"); - } - - /// Traverses the expression @p root_expr for the intended use as a type, performing symbol - /// resolution and determining global dependencies. - void TraverseTypeExpression(const ast::Expression* root) { - TraverseExpression(root, "type", "references"); - } - - /// Traverses the expression @p root_expr for the intended use as an address space, performing - /// symbol resolution and determining global dependencies. - void TraverseAddressSpaceExpression(const ast::Expression* root) { - TraverseExpression(root, "address space", "references"); - } - - /// Traverses the expression @p root_expr for the intended use as an access, performing symbol - /// resolution and determining global dependencies. - void TraverseAccessExpression(const ast::Expression* root) { - TraverseExpression(root, "access", "references"); - } - - /// Traverses the expression @p root_expr for the intended use as a call target, performing - /// symbol resolution and determining global dependencies. - void TraverseCallableExpression(const ast::Expression* root) { - TraverseExpression(root, "function", "calls"); - } - /// Traverses the expression @p root_expr, performing symbol resolution and determining global /// dependencies. - void TraverseExpression(const ast::Expression* root_expr, - const char* root_use, - const char* root_action) { + void TraverseExpression(const ast::Expression* root_expr) { if (!root_expr) { return; } - struct Pending { - const ast::Expression* expr; - const char* use; - const char* action; - }; - utils::Vector pending{{root_expr, root_use, root_action}}; + utils::Vector pending{root_expr}; while (!pending.IsEmpty()) { - auto next = pending.Pop(); - ast::TraverseExpressions(next.expr, diagnostics_, [&](const ast::Expression* expr) { + ast::TraverseExpressions(pending.Pop(), diagnostics_, [&](const ast::Expression* expr) { Switch( expr, [&](const ast::IdentifierExpression* e) { - AddDependency(e->identifier, e->identifier->symbol, next.use, next.action); + AddDependency(e->identifier, e->identifier->symbol); if (auto* tmpl_ident = e->identifier->As()) { for (auto* arg : tmpl_ident->arguments) { - pending.Push({arg, "identifier", "references"}); + pending.Push(arg); } } }, - [&](const ast::CallExpression* call) { - TraverseCallableExpression(call->target); - }, - [&](const ast::BitcastExpression* cast) { - TraverseTypeExpression(cast->type); - }); + [&](const ast::CallExpression* call) { TraverseExpression(call->target); }, + [&](const ast::BitcastExpression* cast) { TraverseExpression(cast->type); }); return ast::TraverseAction::Descend; }); } @@ -423,42 +375,42 @@ class DependencyScanner { bool handled = Switch( attr, [&](const ast::BindingAttribute* binding) { - TraverseValueExpression(binding->expr); + TraverseExpression(binding->expr); return true; }, [&](const ast::BuiltinAttribute* builtin) { - TraverseExpression(builtin->builtin, "builtin", "references"); + TraverseExpression(builtin->builtin); return true; }, [&](const ast::GroupAttribute* group) { - TraverseValueExpression(group->expr); + TraverseExpression(group->expr); return true; }, [&](const ast::IdAttribute* id) { - TraverseValueExpression(id->expr); + TraverseExpression(id->expr); return true; }, [&](const ast::InterpolateAttribute* interpolate) { - TraverseExpression(interpolate->type, "interpolation type", "references"); - TraverseExpression(interpolate->sampling, "interpolation sampling", "references"); + TraverseExpression(interpolate->type); + TraverseExpression(interpolate->sampling); return true; }, [&](const ast::LocationAttribute* loc) { - TraverseValueExpression(loc->expr); + TraverseExpression(loc->expr); return true; }, [&](const ast::StructMemberAlignAttribute* align) { - TraverseValueExpression(align->expr); + TraverseExpression(align->expr); return true; }, [&](const ast::StructMemberSizeAttribute* size) { - TraverseValueExpression(size->expr); + TraverseExpression(size->expr); return true; }, [&](const ast::WorkgroupAttribute* wg) { - TraverseValueExpression(wg->x); - TraverseValueExpression(wg->y); - TraverseValueExpression(wg->z); + TraverseExpression(wg->x); + TraverseExpression(wg->y); + TraverseExpression(wg->z); return true; }); if (handled) { @@ -476,10 +428,7 @@ class DependencyScanner { } /// Adds the dependency from @p from to @p to, erroring if @p to cannot be resolved. - void AddDependency(const ast::Identifier* from, - Symbol to, - const char* use, - const char* action) { + void AddDependency(const ast::Identifier* from, Symbol to) { auto* resolved = scope_stack_.Get(to); if (!resolved) { auto s = symbols_.NameFor(to); @@ -521,13 +470,14 @@ class DependencyScanner { return; } - UnknownSymbol(to, from->source, use); + // Unresolved. + graph_.resolved_identifiers.Add(from, UnresolvedIdentifier{s}); return; } if (auto global = globals_.Find(to); global && (*global)->node == resolved) { if (dependency_edges_.Add(DependencyEdge{current_global_, *global}, - DependencyInfo{from->source, action})) { + DependencyInfo{from->source})) { current_global_->deps.Push(*global); } } @@ -535,12 +485,6 @@ class DependencyScanner { graph_.resolved_identifiers.Add(from, ResolvedIdentifier(resolved)); } - /// Appends an error to the diagnostics that the given symbol cannot be resolved. - void UnknownSymbol(Symbol name, Source source, const char* use) { - AddError(diagnostics_, "unknown " + std::string(use) + ": '" + symbols_.NameFor(name) + "'", - source); - } - using VariableMap = utils::Hashmap; const SymbolTable& symbols_; const GlobalMap& globals_; @@ -787,7 +731,7 @@ struct DependencyAnalysis { auto* to = (i + 1 < stack.Length()) ? stack[i + 1] : stack[loop_start]; auto info = DepInfoFor(from, to); AddNote(diagnostics_, - KindOf(from->node) + " '" + NameOf(from->node) + "' " + info.action + " " + + KindOf(from->node) + " '" + NameOf(from->node) + "' references " + KindOf(to->node) + " '" + NameOf(to->node) + "' here", info.source); } @@ -831,8 +775,7 @@ struct DependencyAnalysis { /// Global map, keyed by name. Populated by GatherGlobals(). GlobalMap globals_; - /// Map of DependencyEdge to DependencyInfo. Populated by - /// DetermineDependencies(). + /// Map of DependencyEdge to DependencyInfo. Populated by DetermineDependencies(). DependencyEdges dependency_edges_; /// Globals in declaration order. Populated by GatherGlobals(). @@ -857,9 +800,6 @@ bool DependencyGraph::Build(const ast::Module& module, } std::string ResolvedIdentifier::String(const SymbolTable& symbols, diag::List& diagnostics) const { - if (!Resolved()) { - return ""; - } if (auto* node = Node()) { return Switch( node, @@ -911,6 +851,10 @@ std::string ResolvedIdentifier::String(const SymbolTable& symbols, diag::List& d if (auto fmt = TexelFormat(); fmt != builtin::TexelFormat::kUndefined) { return "texel format '" + utils::ToString(fmt) + "'"; } + if (auto* unresolved = Unresolved()) { + return "unresolved identifier '" + unresolved->name + "'"; + } + TINT_UNREACHABLE(Resolver, diagnostics) << "unhandled ResolvedIdentifier"; return ""; } diff --git a/src/tint/resolver/dependency_graph.h b/src/tint/resolver/dependency_graph.h index da25383f16..429a47afae 100644 --- a/src/tint/resolver/dependency_graph.h +++ b/src/tint/resolver/dependency_graph.h @@ -32,8 +32,15 @@ namespace tint::resolver { +/// UnresolvedIdentifier is the variant value used by ResolvedIdentifier +struct UnresolvedIdentifier { + /// Name of the unresolved identifier + std::string name; +}; + /// ResolvedIdentifier holds the resolution of an ast::Identifier. /// Can hold one of: +/// - UnresolvedIdentifier /// - const ast::TypeDecl* (as const ast::Node*) /// - const ast::Variable* (as const ast::Node*) /// - const ast::Function* (as const ast::Node*) @@ -47,15 +54,18 @@ namespace tint::resolver { /// - builtin::TexelFormat class ResolvedIdentifier { public: - ResolvedIdentifier() = default; - /// Constructor /// @param value the resolved identifier value template ResolvedIdentifier(T value) : value_(value) {} // NOLINT(runtime/explicit) - /// @return true if the ResolvedIdentifier holds a value (successfully resolved) - bool Resolved() const { return !std::holds_alternative(value_); } + /// @return the UnresolvedIdentifier if the identifier was not resolved + const UnresolvedIdentifier* Unresolved() const { + if (auto n = std::get_if(&value_)) { + return n; + } + return nullptr; + } /// @return the node pointer if the ResolvedIdentifier holds an AST node, otherwise nullptr const ast::Node* Node() const { @@ -160,7 +170,7 @@ class ResolvedIdentifier { std::string String(const SymbolTable& symbols, diag::List& diagnostics) const; private: - std::variant"; } -/// @returns the the diagnostic message name used for the given use -std::string DiagString(SymbolUseKind kind) { - switch (kind) { - case SymbolUseKind::GlobalVarType: - case SymbolUseKind::GlobalConstType: - case SymbolUseKind::AliasType: - case SymbolUseKind::StructMemberType: - case SymbolUseKind::ParameterType: - case SymbolUseKind::LocalVarType: - case SymbolUseKind::LocalLetType: - case SymbolUseKind::NestedLocalVarType: - case SymbolUseKind::NestedLocalLetType: - return "type"; - case SymbolUseKind::GlobalVarArrayElemType: - case SymbolUseKind::GlobalVarVectorElemType: - case SymbolUseKind::GlobalVarMatrixElemType: - case SymbolUseKind::GlobalVarSampledTexElemType: - case SymbolUseKind::GlobalVarMultisampledTexElemType: - case SymbolUseKind::GlobalConstArrayElemType: - case SymbolUseKind::GlobalConstVectorElemType: - case SymbolUseKind::GlobalConstMatrixElemType: - case SymbolUseKind::LocalVarArrayElemType: - case SymbolUseKind::LocalVarVectorElemType: - case SymbolUseKind::LocalVarMatrixElemType: - case SymbolUseKind::GlobalVarValue: - case SymbolUseKind::GlobalVarArraySizeValue: - case SymbolUseKind::GlobalConstValue: - case SymbolUseKind::GlobalConstArraySizeValue: - case SymbolUseKind::LocalVarValue: - case SymbolUseKind::LocalVarArraySizeValue: - case SymbolUseKind::LocalLetValue: - case SymbolUseKind::NestedLocalVarValue: - case SymbolUseKind::NestedLocalLetValue: - case SymbolUseKind::WorkgroupSizeValue: - return "identifier"; - case SymbolUseKind::CallFunction: - return "function"; - } - return ""; -} - /// @returns the declaration scope depth for the symbol declaration kind. /// Globals are at depth 0, parameters and locals are at depth 1, /// nested locals are at depth 2. @@ -783,10 +742,16 @@ TEST_P(ResolverDependencyGraphUndeclaredSymbolTest, Test) { // Build a use of a non-existent symbol SymbolTestHelper helper(this); - helper.Add(use_kind, symbol, Source{{56, 78}}); + auto* ident = helper.Add(use_kind, symbol, Source{{56, 78}}); helper.Build(); - Build("56:78 error: unknown " + DiagString(use_kind) + ": 'SYMBOL'"); + auto graph = Build(); + + auto resolved_identifier = graph.resolved_identifiers.Find(ident); + ASSERT_NE(resolved_identifier, nullptr); + auto* unresolved = resolved_identifier->Unresolved(); + ASSERT_NE(unresolved, nullptr); + EXPECT_EQ(unresolved->name, "SYMBOL"); } INSTANTIATE_TEST_SUITE_P(Types, @@ -826,14 +791,28 @@ TEST_F(ResolverDependencyGraphDeclSelfUse, GlobalConst) { TEST_F(ResolverDependencyGraphDeclSelfUse, LocalVar) { const Symbol symbol = Sym("SYMBOL"); - WrapInFunction(Decl(Var(symbol, ty.i32(), Mul(Expr(Source{{12, 34}}, symbol), 123_i)))); - Build("12:34 error: unknown identifier: 'SYMBOL'"); + auto* ident = Ident(Source{{12, 34}}, symbol); + WrapInFunction(Decl(Var(symbol, ty.i32(), Mul(Expr(ident), 123_i)))); + auto graph = Build(); + + auto resolved_identifier = graph.resolved_identifiers.Find(ident); + ASSERT_TRUE(resolved_identifier); + auto* unresolved = resolved_identifier->Unresolved(); + ASSERT_NE(unresolved, nullptr); + EXPECT_EQ(unresolved->name, "SYMBOL"); } TEST_F(ResolverDependencyGraphDeclSelfUse, LocalLet) { const Symbol symbol = Sym("SYMBOL"); - WrapInFunction(Decl(Let(symbol, ty.i32(), Mul(Expr(Source{{12, 34}}, symbol), 123_i)))); - Build("12:34 error: unknown identifier: 'SYMBOL'"); + auto* ident = Ident(Source{{12, 34}}, symbol); + WrapInFunction(Decl(Let(symbol, ty.i32(), Mul(Expr(ident), 123_i)))); + auto graph = Build(); + + auto resolved_identifier = graph.resolved_identifiers.Find(ident); + ASSERT_TRUE(resolved_identifier); + auto* unresolved = resolved_identifier->Unresolved(); + ASSERT_NE(unresolved, nullptr); + EXPECT_EQ(unresolved->name, "SYMBOL"); } } // namespace undeclared_tests @@ -852,7 +831,7 @@ TEST_F(ResolverDependencyGraphCyclicRefTest, DirectCall) { utils::Vector{CallStmt(Call(Ident(Source{{56, 78}}, "main")))}); Build(R"(12:34 error: cyclic dependency found: 'main' -> 'main' -56:78 note: function 'main' calls function 'main' here)"); +56:78 note: function 'main' references function 'main' here)"); } TEST_F(ResolverDependencyGraphCyclicRefTest, IndirectCall) { @@ -876,9 +855,9 @@ TEST_F(ResolverDependencyGraphCyclicRefTest, IndirectCall) { utils::Vector{CallStmt(Call(Ident(Source{{5, 10}}, "c")))}); Build(R"(5:1 error: cyclic dependency found: 'b' -> 'c' -> 'd' -> 'b' -5:10 note: function 'b' calls function 'c' here -4:10 note: function 'c' calls function 'd' here -3:10 note: function 'd' calls function 'b' here)"); +5:10 note: function 'b' references function 'c' here +4:10 note: function 'c' references function 'd' here +3:10 note: function 'd' references function 'b' here)"); } TEST_F(ResolverDependencyGraphCyclicRefTest, Alias_Direct) { @@ -1160,17 +1139,22 @@ TEST_P(ResolverDependencyGraphResolveToUserDeclTest, Test) { // If the declaration is visible to the use, then we expect the analysis to // succeed. - bool expect_pass = ScopeDepth(decl_kind) <= ScopeDepth(use_kind); - auto graph = Build(expect_pass ? "" : "56:78 error: unknown identifier: 'SYMBOL'"); + bool expect_resolved = ScopeDepth(decl_kind) <= ScopeDepth(use_kind); + auto graph = Build(); - if (expect_pass) { + auto resolved_identifier = graph.resolved_identifiers.Find(use); + ASSERT_TRUE(resolved_identifier); + + if (expect_resolved) { // Check that the use resolves to the declaration - auto resolved_identifier = graph.resolved_identifiers.Find(use); - ASSERT_TRUE(resolved_identifier); auto* resolved_node = resolved_identifier->Node(); EXPECT_EQ(resolved_node, decl) << "resolved: " << (resolved_node ? resolved_node->TypeInfo().name : "") << "\n" << "decl: " << decl->TypeInfo().name; + } else { + auto* unresolved = resolved_identifier->Unresolved(); + ASSERT_NE(unresolved, nullptr); + EXPECT_EQ(unresolved->name, "SYMBOL"); } } diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc index b7af242ec3..b18ff2a74f 100644 --- a/src/tint/resolver/function_validation_test.cc +++ b/src/tint/resolver/function_validation_test.cc @@ -1070,11 +1070,14 @@ TEST_P(ResolverFunctionParameterValidationTest, AddressSpaceNoExtension) { ss << param.address_space; EXPECT_FALSE(r()->Resolve()); if (param.expectation == Expectation::kInvalid) { - EXPECT_EQ(r()->error(), "12:34 error: unknown identifier: '" + ss.str() + "'"); + std::string err = R"(12:34 error: unresolved address space '${addr_space}' +12:34 note: Possible values: 'function', 'private', 'push_constant', 'storage', 'uniform', 'workgroup')"; + err = utils::ReplaceAll(err, "${addr_space}", utils::ToString(param.address_space)); + EXPECT_EQ(r()->error(), err); } else { EXPECT_EQ(r()->error(), - "12:34 error: function parameter of pointer type cannot be in '" + ss.str() + - "' address space"); + "12:34 error: function parameter of pointer type cannot be in '" + + utils::ToString(param.address_space) + "' address space"); } } } @@ -1091,8 +1094,10 @@ TEST_P(ResolverFunctionParameterValidationTest, AddressSpaceWithExtension) { } else { EXPECT_FALSE(r()->Resolve()); if (param.expectation == Expectation::kInvalid) { - EXPECT_EQ(r()->error(), "12:34 error: unknown identifier: '" + - utils::ToString(param.address_space) + "'"); + std::string err = R"(12:34 error: unresolved address space '${addr_space}' +12:34 note: Possible values: 'function', 'private', 'push_constant', 'storage', 'uniform', 'workgroup')"; + err = utils::ReplaceAll(err, "${addr_space}", utils::ToString(param.address_space)); + EXPECT_EQ(r()->error(), err); } else { EXPECT_EQ(r()->error(), "12:34 error: function parameter of pointer type cannot be in '" + diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index ca5c46f9cc..27f1881743 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -1418,7 +1418,8 @@ sem::Expression* Resolver::Expression(const ast::Expression* root) { for (auto* expr : utils::Reverse(sorted)) { auto* sem_expr = Switch( - expr, [&](const ast::IndexAccessorExpression* array) { return IndexAccessor(array); }, + expr, // + [&](const ast::IndexAccessorExpression* array) { return IndexAccessor(array); }, [&](const ast::BinaryExpression* bin_op) { return Binary(bin_op); }, [&](const ast::BitcastExpression* bitcast) { return Bitcast(bitcast); }, [&](const ast::CallExpression* call) { return Call(call); }, @@ -1488,10 +1489,12 @@ sem::ValueExpression* Resolver::ValueExpression(const ast::Expression* expr) { } sem::TypeExpression* Resolver::TypeExpression(const ast::Expression* expr) { + identifier_resolve_hint_ = {expr, "type"}; return sem_.AsTypeExpression(Expression(expr)); } sem::FunctionExpression* Resolver::FunctionExpression(const ast::Expression* expr) { + identifier_resolve_hint_ = {expr, "call target"}; return sem_.AsFunctionExpression(Expression(expr)); } @@ -1505,31 +1508,38 @@ type::Type* Resolver::Type(const ast::Expression* ast) { sem::BuiltinEnumExpression* Resolver::AddressSpaceExpression( const ast::Expression* expr) { + identifier_resolve_hint_ = {expr, "address space", builtin::kAddressSpaceStrings}; return sem_.AsAddressSpace(Expression(expr)); } sem::BuiltinEnumExpression* Resolver::BuiltinValueExpression( const ast::Expression* expr) { + identifier_resolve_hint_ = {expr, "builtin value", builtin::kBuiltinValueStrings}; return sem_.AsBuiltinValue(Expression(expr)); } sem::BuiltinEnumExpression* Resolver::TexelFormatExpression( const ast::Expression* expr) { + identifier_resolve_hint_ = {expr, "texel format", builtin::kTexelFormatStrings}; return sem_.AsTexelFormat(Expression(expr)); } sem::BuiltinEnumExpression* Resolver::AccessExpression( const ast::Expression* expr) { + identifier_resolve_hint_ = {expr, "access", builtin::kAccessStrings}; return sem_.AsAccess(Expression(expr)); } sem::BuiltinEnumExpression* Resolver::InterpolationSampling( const ast::Expression* expr) { + identifier_resolve_hint_ = {expr, "interpolation sampling", + builtin::kInterpolationSamplingStrings}; return sem_.AsInterpolationSampling(Expression(expr)); } sem::BuiltinEnumExpression* Resolver::InterpolationType( const ast::Expression* expr) { + identifier_resolve_hint_ = {expr, "interpolation type", builtin::kInterpolationTypeStrings}; return sem_.AsInterpolationType(Expression(expr)); } @@ -2196,6 +2206,11 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) { return ty_init_or_conv(ty); } + if (auto* unresolved = resolved->Unresolved()) { + AddError("unresolved call target '" + unresolved->name + "'", expr->source); + return nullptr; + } + ErrorMismatchedResolvedIdentifier(ident->source, *resolved, "call target"); return nullptr; }(); @@ -2541,11 +2556,11 @@ type::Type* Resolver::BuiltinType(builtin::Builtin builtin_ty, const ast::Identi return nullptr; } - auto* format = sem_.AsTexelFormat(Expression(tmpl_ident->arguments[0])); + auto* format = TexelFormatExpression(tmpl_ident->arguments[0]); if (TINT_UNLIKELY(!format)) { return nullptr; } - auto* access = sem_.AsAccess(Expression(tmpl_ident->arguments[1])); + auto* access = AccessExpression(tmpl_ident->arguments[1]); if (TINT_UNLIKELY(!access)) { return nullptr; } @@ -3030,6 +3045,30 @@ sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) { expr, current_statement_, fmt); } + if (auto* unresolved = resolved->Unresolved()) { + if (identifier_resolve_hint_.expression == expr) { + AddError("unresolved " + std::string(identifier_resolve_hint_.usage) + " '" + + unresolved->name + "'", + expr->source); + if (!identifier_resolve_hint_.suggestions.IsEmpty()) { + // Filter out suggestions that have a leading underscore. + utils::Vector filtered; + for (auto* str : identifier_resolve_hint_.suggestions) { + if (str[0] != '_') { + filtered.Push(str); + } + } + std::ostringstream msg; + utils::SuggestAlternatives(unresolved->name, + filtered.Slice().Reinterpret(), msg); + AddNote(msg.str(), expr->source); + } + } else { + AddError("unresolved identifier '" + unresolved->name + "'", expr->source); + } + return nullptr; + } + TINT_UNREACHABLE(Resolver, diagnostics_) << "unhandled resolved identifier: " << resolved->String(builder_->Symbols(), diagnostics_); return nullptr; diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index bbde9e952c..c29a416aca 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -530,6 +530,17 @@ class Resolver { std::unordered_set parameter_reads; }; + /// A hint for the usage of an identifier expression. + /// Used to provide more informative error diagnostics on resolution failure. + struct IdentifierResolveHint { + /// The expression this hint applies to + const ast::Expression* expression = nullptr; + /// The usage of the identifier. + const char* usage = "identifier"; + /// Suggested strings if the identifier failed to resolve + utils::Slice suggestions = utils::Empty; + }; + ProgramBuilder* const builder_; diag::List& diagnostics_; ConstEval const_eval_; @@ -555,6 +566,7 @@ class Resolver { utils::Hashmap logical_binary_lhs_to_parent_; utils::Hashset skip_const_eval_; + IdentifierResolveHint identifier_resolve_hint_; }; } // namespace tint::resolver diff --git a/src/tint/resolver/type_validation_test.cc b/src/tint/resolver/type_validation_test.cc index b9f2edbd84..fe9622b6a1 100644 --- a/src/tint/resolver/type_validation_test.cc +++ b/src/tint/resolver/type_validation_test.cc @@ -1083,7 +1083,7 @@ TEST_P(StorageTextureDimensionTest, All) { EXPECT_TRUE(r()->Resolve()) << r()->error(); } else { EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: unknown type: '" + std::string(params.name) + "'"); + EXPECT_EQ(r()->error(), "12:34 error: unresolved type '" + std::string(params.name) + "'"); } } INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest, diff --git a/src/tint/resolver/unresolved_identifier_test.cc b/src/tint/resolver/unresolved_identifier_test.cc new file mode 100644 index 0000000000..e52b858b27 --- /dev/null +++ b/src/tint/resolver/unresolved_identifier_test.cc @@ -0,0 +1,108 @@ +// Copyright 2023 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gmock/gmock.h" + +#include "src/tint/resolver/resolver_test_helper.h" + +using namespace tint::number_suffixes; // NOLINT + +namespace tint::resolver { +namespace { + +using ResolverUnresolvedIdentifierSuggestions = ResolverTest; + +TEST_F(ResolverUnresolvedIdentifierSuggestions, AddressSpace) { + AST().AddGlobalVariable(create( + Ident("v"), // name + ty.i32(), // type + Expr(Source{{12, 34}}, "privte"), // declared_address_space + nullptr, // declared_access + nullptr, // initializer + utils::Empty // attributes + )); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), R"(12:34 error: unresolved address space 'privte' +12:34 note: Did you mean 'private'? +Possible values: 'function', 'private', 'push_constant', 'storage', 'uniform', 'workgroup')"); +} + +TEST_F(ResolverUnresolvedIdentifierSuggestions, BuiltinValue) { + Func("f", + utils::Vector{ + Param("p", ty.i32(), utils::Vector{Builtin(Expr(Source{{12, 34}}, "positon"))})}, + ty.void_(), utils::Empty); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), R"(12:34 error: unresolved builtin value 'positon' +12:34 note: Did you mean 'position'? +Possible values: 'frag_depth', 'front_facing', 'global_invocation_id', 'instance_index', 'local_invocation_id', 'local_invocation_index', 'num_workgroups', 'position', 'sample_index', 'sample_mask', 'vertex_index', 'workgroup_id')"); +} + +TEST_F(ResolverUnresolvedIdentifierSuggestions, TexelFormat) { + GlobalVar("v", ty("texture_storage_1d", Expr(Source{{12, 34}}, "rba8unorm"), "read")); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), R"(12:34 error: unresolved texel format 'rba8unorm' +12:34 note: Did you mean 'rgba8unorm'? +Possible values: 'bgra8unorm', 'r32float', 'r32sint', 'r32uint', 'rg32float', 'rg32sint', 'rg32uint', 'rgba16float', 'rgba16sint', 'rgba16uint', 'rgba32float', 'rgba32sint', 'rgba32uint', 'rgba8sint', 'rgba8snorm', 'rgba8uint', 'rgba8unorm')"); +} + +TEST_F(ResolverUnresolvedIdentifierSuggestions, AccessMode) { + AST().AddGlobalVariable(create(Ident("v"), // name + ty.i32(), // type + Expr("private"), // declared_address_space + Expr(Source{{12, 34}}, "reed"), // declared_access + nullptr, // initializer + utils::Empty // attributes + )); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), R"(12:34 error: unresolved access 'reed' +12:34 note: Did you mean 'read'? +Possible values: 'read', 'read_write', 'write')"); +} + +TEST_F(ResolverUnresolvedIdentifierSuggestions, InterpolationSampling) { + Structure("s", utils::Vector{ + Member("m", ty.vec4(), + utils::Vector{ + Interpolate(builtin::InterpolationType::kLinear, + Expr(Source{{12, 34}}, "centre")), + }), + }); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), R"(12:34 error: unresolved interpolation sampling 'centre' +12:34 note: Did you mean 'center'? +Possible values: 'center', 'centroid', 'sample')"); +} + +TEST_F(ResolverUnresolvedIdentifierSuggestions, InterpolationType) { + Structure("s", utils::Vector{ + Member("m", ty.vec4(), + utils::Vector{ + Interpolate(Expr(Source{{12, 34}}, "liner")), + }), + }); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), R"(12:34 error: unresolved interpolation type 'liner' +12:34 note: Did you mean 'linear'? +Possible values: 'flat', 'linear', 'perspective')"); +} + +} // namespace +} // namespace tint::resolver diff --git a/src/tint/resolver/validation_test.cc b/src/tint/resolver/validation_test.cc index 1e806d3f1b..6485a6a331 100644 --- a/src/tint/resolver/validation_test.cc +++ b/src/tint/resolver/validation_test.cc @@ -168,7 +168,7 @@ TEST_F(ResolverValidationTest, UsingUndefinedVariable_Fail) { WrapInFunction(assign); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), R"(12:34 error: unknown identifier: 'b')"); + EXPECT_EQ(r()->error(), R"(12:34 error: unresolved identifier 'b')"); } TEST_F(ResolverValidationTest, UsingUndefinedVariableInBlockStatement_Fail) { @@ -183,7 +183,7 @@ TEST_F(ResolverValidationTest, UsingUndefinedVariableInBlockStatement_Fail) { WrapInFunction(body); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), R"(12:34 error: unknown identifier: 'b')"); + EXPECT_EQ(r()->error(), R"(12:34 error: unresolved identifier 'b')"); } TEST_F(ResolverValidationTest, UsingUndefinedVariableGlobalVariable_Pass) { @@ -223,7 +223,7 @@ TEST_F(ResolverValidationTest, UsingUndefinedVariableInnerScope_Fail) { WrapInFunction(outer_body); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), R"(12:34 error: unknown identifier: 'a')"); + EXPECT_EQ(r()->error(), R"(12:34 error: unresolved identifier 'a')"); } TEST_F(ResolverValidationTest, UsingUndefinedVariableOuterScope_Pass) { @@ -263,7 +263,7 @@ TEST_F(ResolverValidationTest, UsingUndefinedVariableDifferentScope_Fail) { WrapInFunction(outer_body); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), R"(12:34 error: unknown identifier: 'a')"); + EXPECT_EQ(r()->error(), R"(12:34 error: unresolved identifier 'a')"); } TEST_F(ResolverValidationTest, AddressSpace_FunctionVariableWorkgroupClass) { diff --git a/src/tint/utils/slice.h b/src/tint/utils/slice.h new file mode 100644 index 0000000000..719a53e5ff --- /dev/null +++ b/src/tint/utils/slice.h @@ -0,0 +1,205 @@ +// Copyright 2023 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TINT_UTILS_SLICE_H_ +#define SRC_TINT_UTILS_SLICE_H_ + +#include +#include + +#include "src/tint/castable.h" +#include "src/tint/traits.h" + +namespace tint::utils { + +/// A type used to indicate an empty array. +struct EmptyType {}; + +/// An instance of the EmptyType. +static constexpr EmptyType Empty; + +/// Mode enumerator for ReinterpretSlice +enum class ReinterpretMode { + /// Only upcasts of pointers are permitted + kSafe, + /// Potentially unsafe downcasts of pointers are also permitted + kUnsafe, +}; + +namespace detail { + +template +static constexpr bool ConstRemoved = std::is_const_v && !std::is_const_v; + +/// Private implementation of tint::utils::CanReinterpretSlice. +/// Specialized for the case of TO equal to FROM, which is the common case, and avoids inspection of +/// the base classes, which can be troublesome if the slice is of an incomplete type. +template +struct CanReinterpretSlice { + private: + using TO_EL = std::remove_pointer_t>; + using FROM_EL = std::remove_pointer_t>; + + public: + /// @see utils::CanReinterpretSlice + static constexpr bool value = + // const can only be applied, not removed + !ConstRemoved && + + // Both TO and FROM are the same type (ignoring const) + (std::is_same_v, std::remove_const_t> || + + // Both TO and FROM are pointers... + ((std::is_pointer_v && std::is_pointer_v)&& + + // const can only be applied to element type, not removed + !ConstRemoved && + + // Either: + // * Both the pointer elements are of the same type (ignoring const) + // * Both the pointer elements are both Castable, and MODE is kUnsafe, or FROM is of, + // or + // derives from TO + (std::is_same_v, std::remove_const_t> || + (IsCastable && + (MODE == ReinterpretMode::kUnsafe || traits::IsTypeOrDerived))))); +}; + +/// Specialization of 'CanReinterpretSlice' for when TO and FROM are equal types. +template +struct CanReinterpretSlice { + /// Always `true` as TO and FROM are the same type. + static constexpr bool value = true; +}; + +} // namespace detail + +/// Evaluates whether a `Slice` and be reinterpreted as a `Slice`. +/// Slices can be reinterpreted if: +/// * TO has the same or more 'constness' than FROM. +/// * And either: +/// * `FROM` and `TO` are pointers to the same type +/// * `FROM` and `TO` are pointers to CastableBase (or derived), and the pointee type of `TO` is of +/// the same type as, or is an ancestor of the pointee type of `FROM`. +template +static constexpr bool CanReinterpretSlice = detail::CanReinterpretSlice::value; + +/// A slice represents a contigious array of elements of type T. +template +struct Slice { + /// Type of `T`. + using value_type = T; + + /// The pointer to the first element in the slice + T* data = nullptr; + + /// The total number of elements in the slice + size_t len = 0; + + /// The total capacity of the backing store for the slice + size_t cap = 0; + + /// Constructor + Slice() = default; + + /// Constructor + Slice(EmptyType) {} // NOLINT + + /// Constructor + /// @param d pointer to the first element in the slice + /// @param l total number of elements in the slice + /// @param c total capacity of the backing store for the slice + Slice(T* d, size_t l, size_t c) : data(d), len(l), cap(c) {} + + /// Constructor + /// @param elements c-array of elements + template + Slice(T (&elements)[N]) // NOLINT + : data(elements), len(N), cap(N) {} + + /// Reinterprets this slice as `const Slice&` + /// @returns the reinterpreted slice + /// @see CanReinterpretSlice + template + const Slice& Reinterpret() const { + static_assert(CanReinterpretSlice); + return *Bitcast*>(this); + } + + /// Reinterprets this slice as `Slice&` + /// @returns the reinterpreted slice + /// @see CanReinterpretSlice + template + Slice& Reinterpret() { + static_assert(CanReinterpretSlice); + return *Bitcast*>(this); + } + + /// @return true if the slice length is zero + bool IsEmpty() const { return len == 0; } + + /// Index operator + /// @param i the element index. Must be less than `len`. + /// @returns a reference to the i'th element. + T& operator[](size_t i) { return data[i]; } + + /// Index operator + /// @param i the element index. Must be less than `len`. + /// @returns a reference to the i'th element. + const T& operator[](size_t i) const { return data[i]; } + + /// @returns a reference to the first element in the vector + T& Front() { return data[0]; } + + /// @returns a reference to the first element in the vector + const T& Front() const { return data[0]; } + + /// @returns a reference to the last element in the vector + T& Back() { return data[len - 1]; } + + /// @returns a reference to the last element in the vector + const T& Back() const { return data[len - 1]; } + + /// @returns a pointer to the first element in the vector + T* begin() { return data; } + + /// @returns a pointer to the first element in the vector + const T* begin() const { return data; } + + /// @returns a pointer to one past the last element in the vector + T* end() { return data + len; } + + /// @returns a pointer to one past the last element in the vector + const T* end() const { return data + len; } + + /// @returns a reverse iterator starting with the last element in the vector + auto rbegin() { return std::reverse_iterator(end()); } + + /// @returns a reverse iterator starting with the last element in the vector + auto rbegin() const { return std::reverse_iterator(end()); } + + /// @returns the end for a reverse iterator + auto rend() { return std::reverse_iterator(begin()); } + + /// @returns the end for a reverse iterator + auto rend() const { return std::reverse_iterator(begin()); } +}; + +/// Deduction guide for Slice from c-array +template +Slice(T (&elements)[N]) -> Slice; + +} // namespace tint::utils + +#endif // SRC_TINT_UTILS_SLICE_H_ diff --git a/src/tint/utils/slice_test.cc b/src/tint/utils/slice_test.cc new file mode 100644 index 0000000000..6a4493cb9c --- /dev/null +++ b/src/tint/utils/slice_test.cc @@ -0,0 +1,131 @@ +// Copyright 2023 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/utils/slice.h" + +#include "gmock/gmock.h" + +namespace tint::utils { +namespace { + +class C0 : public Castable {}; +class C1 : public Castable {}; +class C2a : public Castable {}; +class C2b : public Castable {}; + +//////////////////////////////////////////////////////////////////////////////// +// Static asserts +//////////////////////////////////////////////////////////////////////////////// +// Non-pointer +static_assert(CanReinterpretSlice, "same type"); +static_assert(CanReinterpretSlice, "apply const"); +static_assert(!CanReinterpretSlice, "remove const"); + +// Non-castable pointers +static_assert(CanReinterpretSlice, "apply ptr const"); +static_assert(!CanReinterpretSlice, "remove ptr const"); +static_assert(CanReinterpretSlice, "apply el const"); +static_assert(!CanReinterpretSlice, "remove el const"); + +// Castable +static_assert(CanReinterpretSlice, "apply const"); +static_assert(!CanReinterpretSlice, "remove const"); +static_assert(CanReinterpretSlice, "up cast"); +static_assert(CanReinterpretSlice, "up cast"); +static_assert(CanReinterpretSlice, "up cast, apply const"); +static_assert(!CanReinterpretSlice, + "up cast, remove const"); +static_assert(!CanReinterpretSlice, "down cast"); +static_assert(!CanReinterpretSlice, "down cast"); +static_assert(!CanReinterpretSlice, + "down cast, apply const"); +static_assert(!CanReinterpretSlice, + "down cast, remove const"); +static_assert(!CanReinterpretSlice, + "down cast, apply const"); +static_assert(!CanReinterpretSlice, + "down cast, remove const"); +static_assert(!CanReinterpretSlice, "sideways cast"); +static_assert(!CanReinterpretSlice, + "sideways cast"); +static_assert(!CanReinterpretSlice, + "sideways cast, apply const"); +static_assert(!CanReinterpretSlice, + "sideways cast, remove const"); + +TEST(TintSliceTest, Ctor) { + Slice slice; + EXPECT_EQ(slice.data, nullptr); + EXPECT_EQ(slice.len, 0u); + EXPECT_EQ(slice.cap, 0u); + EXPECT_TRUE(slice.IsEmpty()); +} + +TEST(TintSliceTest, CtorEmpty) { + Slice slice{Empty}; + EXPECT_EQ(slice.data, nullptr); + EXPECT_EQ(slice.len, 0u); + EXPECT_EQ(slice.cap, 0u); + EXPECT_TRUE(slice.IsEmpty()); +} + +TEST(TintSliceTest, CtorCArray) { + int elements[] = {1, 2, 3}; + + auto slice = Slice{elements}; + EXPECT_EQ(slice.data, elements); + EXPECT_EQ(slice.len, 3u); + EXPECT_EQ(slice.cap, 3u); + EXPECT_FALSE(slice.IsEmpty()); +} + +TEST(TintSliceTest, Index) { + int elements[] = {1, 2, 3}; + + auto slice = Slice{elements}; + EXPECT_EQ(slice[0], 1); + EXPECT_EQ(slice[1], 2); + EXPECT_EQ(slice[2], 3); +} + +TEST(TintSliceTest, Front) { + int elements[] = {1, 2, 3}; + auto slice = Slice{elements}; + EXPECT_EQ(slice.Front(), 1); +} + +TEST(TintSliceTest, Back) { + int elements[] = {1, 2, 3}; + auto slice = Slice{elements}; + EXPECT_EQ(slice.Back(), 3); +} + +TEST(TintSliceTest, BeginEnd) { + int elements[] = {1, 2, 3}; + auto slice = Slice{elements}; + EXPECT_THAT(slice, testing::ElementsAre(1, 2, 3)); +} + +TEST(TintSliceTest, ReverseBeginEnd) { + int elements[] = {1, 2, 3}; + auto slice = Slice{elements}; + size_t i = 0; + for (auto it = slice.rbegin(); it != slice.rend(); it++) { + EXPECT_EQ(*it, elements[2 - i]); + i++; + } +} + +} // namespace +} // namespace tint::utils diff --git a/src/tint/utils/string.cc b/src/tint/utils/string.cc index 354ad96062..bccd52c23a 100644 --- a/src/tint/utils/string.cc +++ b/src/tint/utils/string.cc @@ -48,4 +48,36 @@ size_t Distance(std::string_view str_a, std::string_view str_b) { return at(len_a, len_b); } +void SuggestAlternatives(std::string_view got, + Slice strings, + std::ostringstream& ss) { + // If the string typed was within kSuggestionDistance of one of the possible enum values, + // suggest that. Don't bother with suggestions if the string was extremely long. + constexpr size_t kSuggestionDistance = 5; + constexpr size_t kSuggestionMaxLength = 64; + if (!got.empty() && got.size() < kSuggestionMaxLength) { + size_t candidate_dist = kSuggestionDistance; + const char* candidate = nullptr; + for (auto* str : strings) { + auto dist = utils::Distance(str, got); + if (dist < candidate_dist) { + candidate = str; + candidate_dist = dist; + } + } + if (candidate) { + ss << "Did you mean '" << candidate << "'?\n"; + } + } + + // List all the possible enumerator values + ss << "Possible values: "; + for (auto* str : strings) { + if (str != strings[0]) { + ss << ", "; + } + ss << "'" << str << "'"; + } +} + } // namespace tint::utils diff --git a/src/tint/utils/string.h b/src/tint/utils/string.h index 5c05ebfe04..08e0be3138 100644 --- a/src/tint/utils/string.h +++ b/src/tint/utils/string.h @@ -19,6 +19,8 @@ #include #include +#include "src/tint/utils/slice.h" + namespace tint::utils { /// @param str the string to apply replacements to @@ -66,42 +68,13 @@ inline size_t HasPrefix(std::string_view str, std::string_view prefix) { /// @returns the Levenshtein distance between @p a and @p b size_t Distance(std::string_view a, std::string_view b); -/// Suggest alternatives for an unrecognized string from a list of expected values. +/// Suggest alternatives for an unrecognized string from a list of possible values. /// @param got the unrecognized string -/// @param strings the list of expected values +/// @param strings the list of possible values /// @param ss the stream to write the suggest and list of possible values to -template void SuggestAlternatives(std::string_view got, - const char* const (&strings)[N], - std::ostringstream& ss) { - // If the string typed was within kSuggestionDistance of one of the possible enum values, - // suggest that. Don't bother with suggestions if the string was extremely long. - constexpr size_t kSuggestionDistance = 5; - constexpr size_t kSuggestionMaxLength = 64; - if (!got.empty() && got.size() < kSuggestionMaxLength) { - size_t candidate_dist = kSuggestionDistance; - const char* candidate = nullptr; - for (auto* str : strings) { - auto dist = utils::Distance(str, got); - if (dist < candidate_dist) { - candidate = str; - candidate_dist = dist; - } - } - if (candidate) { - ss << "Did you mean '" << candidate << "'?\n"; - } - } - - // List all the possible enumerator values - ss << "Possible values: "; - for (auto* str : strings) { - if (str != strings[0]) { - ss << ", "; - } - ss << "'" << str << "'"; - } -} + Slice strings, + std::ostringstream& ss); } // namespace tint::utils diff --git a/src/tint/utils/string_test.cc b/src/tint/utils/string_test.cc index 6b17dfb409..b9e1ebb842 100644 --- a/src/tint/utils/string_test.cc +++ b/src/tint/utils/string_test.cc @@ -60,14 +60,16 @@ TEST(StringTest, Distance) { TEST(StringTest, SuggestAlternatives) { { + const char* alternatives[] = {"hello world", "Hello World"}; std::ostringstream ss; - SuggestAlternatives("hello wordl", {"hello world", "Hello World"}, ss); + SuggestAlternatives("hello wordl", alternatives, ss); EXPECT_EQ(ss.str(), R"(Did you mean 'hello world'? Possible values: 'hello world', 'Hello World')"); } { + const char* alternatives[] = {"foobar", "something else"}; std::ostringstream ss; - SuggestAlternatives("hello world", {"foobar", "something else"}, ss); + SuggestAlternatives("hello world", alternatives, ss); EXPECT_EQ(ss.str(), R"(Possible values: 'foobar', 'something else')"); } } diff --git a/src/tint/utils/vector.h b/src/tint/utils/vector.h index cca2409f7b..e39854aaf6 100644 --- a/src/tint/utils/vector.h +++ b/src/tint/utils/vector.h @@ -19,14 +19,14 @@ #include #include #include +#include #include #include #include -#include "src/tint/castable.h" -#include "src/tint/traits.h" #include "src/tint/utils/bitcast.h" #include "src/tint/utils/compiler_macros.h" +#include "src/tint/utils/slice.h" #include "src/tint/utils/string.h" namespace tint::utils { @@ -41,137 +41,6 @@ class VectorRef; namespace tint::utils { -/// A type used to indicate an empty array. -struct EmptyType {}; - -/// An instance of the EmptyType. -static constexpr EmptyType Empty; - -/// A slice represents a contigious array of elements of type T. -template -struct Slice { - /// The pointer to the first element in the slice - T* data = nullptr; - - /// The total number of elements in the slice - size_t len = 0; - - /// The total capacity of the backing store for the slice - size_t cap = 0; - - /// Index operator - /// @param i the element index. Must be less than `len`. - /// @returns a reference to the i'th element. - T& operator[](size_t i) { return data[i]; } - - /// Index operator - /// @param i the element index. Must be less than `len`. - /// @returns a reference to the i'th element. - const T& operator[](size_t i) const { return data[i]; } - - /// @returns a reference to the first element in the vector - T& Front() { return data[0]; } - - /// @returns a reference to the first element in the vector - const T& Front() const { return data[0]; } - - /// @returns a reference to the last element in the vector - T& Back() { return data[len - 1]; } - - /// @returns a reference to the last element in the vector - const T& Back() const { return data[len - 1]; } - - /// @returns a pointer to the first element in the vector - T* begin() { return data; } - - /// @returns a pointer to the first element in the vector - const T* begin() const { return data; } - - /// @returns a pointer to one past the last element in the vector - T* end() { return data + len; } - - /// @returns a pointer to one past the last element in the vector - const T* end() const { return data + len; } - - /// @returns a reverse iterator starting with the last element in the vector - auto rbegin() { return std::reverse_iterator(end()); } - - /// @returns a reverse iterator starting with the last element in the vector - auto rbegin() const { return std::reverse_iterator(end()); } - - /// @returns the end for a reverse iterator - auto rend() { return std::reverse_iterator(begin()); } - - /// @returns the end for a reverse iterator - auto rend() const { return std::reverse_iterator(begin()); } -}; - -/// Mode enumerator for ReinterpretSlice -enum class ReinterpretMode { - /// Only upcasts of pointers are permitted - kSafe, - /// Potentially unsafe downcasts of pointers are also permitted - kUnsafe, -}; - -namespace detail { - -/// Private implementation of tint::utils::CanReinterpretSlice. -/// Specialized for the case of TO equal to FROM, which is the common case, and avoids inspection of -/// the base classes, which can be troublesome if the slice is of an incomplete type. -template -struct CanReinterpretSlice { - /// True if a slice of FROM can be reinterpreted as a slice of TO - static constexpr bool value = - // Both TO and FROM are pointers - (std::is_pointer_v && std::is_pointer_v)&& // - // const can only be applied, not removed - (std::is_const_v> || - !std::is_const_v>)&& // - // TO and FROM are both Castable - IsCastable, std::remove_pointer_t> && // - // MODE is kUnsafe, or FROM is of, or derives from TO - (MODE == ReinterpretMode::kUnsafe || - traits::IsTypeOrDerived, std::remove_pointer_t>); -}; - -/// Specialization of 'CanReinterpretSlice' for when TO and FROM are equal types. -template -struct CanReinterpretSlice { - /// Always `true` as TO and FROM are the same type. - static constexpr bool value = true; -}; - -} // namespace detail - -/// Evaluates whether a `vector` and be reinterpreted as a `vector`. -/// Vectors can be reinterpreted if both `FROM` and `TO` are pointers to a type that derives from -/// CastableBase, and the pointee type of `TO` is of the same type as, or is an ancestor of the -/// pointee type of `FROM`. Vectors of non-`const` Castable pointers can be converted to a vector of -/// `const` Castable pointers. -template -static constexpr bool CanReinterpretSlice = detail::CanReinterpretSlice::value; - -/// Reinterprets `const Slice*` as `const Slice*` -/// @param slice a pointer to the slice to reinterpret -/// @returns the reinterpreted slice -/// @see CanReinterpretSlice -template -const Slice* ReinterpretSlice(const Slice* slice) { - static_assert(CanReinterpretSlice); - return Bitcast*>(slice); -} - -/// Reinterprets `Slice*` as `Slice*` -/// @param slice a pointer to the slice to reinterpret -/// @returns the reinterpreted slice -/// @see CanReinterpretSlice -template -Slice* ReinterpretSlice(Slice* slice) { - static_assert(CanReinterpretSlice); - return Bitcast*>(slice); -} - /// Vector is a small-object-optimized, dynamically-sized vector of contigious elements of type T. /// /// Vector will fit `N` elements internally before spilling to heap allocations. If `N` is greater @@ -244,7 +113,7 @@ class Vector { ReinterpretMode MODE, typename = std::enable_if_t>> Vector(const Vector& other) { // NOLINT(runtime/explicit) - Copy(*ReinterpretSlice(&other.impl_.slice)); + Copy(other.impl_.slice.template Reinterpret); } /// Move constructor with covariance / const conversion @@ -518,6 +387,9 @@ class Vector { return !(*this == other); } + /// @returns the internal slice of the vector + utils::Slice Slice() { return impl_.slice; } + private: /// Friend class (differing specializations of this class) template @@ -531,9 +403,6 @@ class Vector { template friend class VectorRef; - /// The slice type used by this vector - using Slice = utils::Slice; - template void AppendVariadic(Ts&&... args) { ((new (&impl_.slice.data[impl_.slice.len++]) T(std::forward(args))), ...); @@ -555,7 +424,7 @@ class Vector { /// Copies all the elements from `other` to this vector, replacing the content of this vector. /// @param other the - void Copy(const Slice& other) { + void Copy(const utils::Slice& other) { if (impl_.slice.cap < other.len) { ClearAndFree(); impl_.Allocate(other.len); @@ -592,7 +461,7 @@ class Vector { /// The internal structure for the vector with a small array. struct ImplWithSmallArray { TStorage small_arr[N]; - Slice slice = {small_arr[0].Get(), 0, N}; + utils::Slice slice = {small_arr[0].Get(), 0, N}; /// Allocates a new vector of `T` either from #small_arr, or from the heap, then assigns the /// pointer it to #slice.data, and updates #slice.cap. @@ -620,7 +489,7 @@ class Vector { /// The internal structure for the vector without a small array. struct ImplWithoutSmallArray { - Slice slice = {nullptr, 0, 0}; + utils::Slice slice = Empty; /// Allocates a new vector of `T` and assigns it to #slice.data, and updates #slice.cap. void Allocate(size_t new_cap) { @@ -759,15 +628,14 @@ class VectorRef { template >> VectorRef(const VectorRef& other) // NOLINT(runtime/explicit) - : slice_(*ReinterpretSlice(&other.slice_)) {} + : slice_(other.slice_.template Reinterpret()) {} /// Move constructor with covariance / const conversion /// @param other the vector reference template >> VectorRef(VectorRef&& other) // NOLINT(runtime/explicit) - : slice_(*ReinterpretSlice(&other.slice_)), - can_move_(other.can_move_) {} + : slice_(other.slice_.template Reinterpret()), can_move_(other.can_move_) {} /// Constructor from a Vector with covariance / const conversion /// @param vector the vector to create a reference of @@ -776,7 +644,7 @@ class VectorRef { size_t N, typename = std::enable_if_t>> VectorRef(Vector& vector) // NOLINT(runtime/explicit) - : slice_(*ReinterpretSlice(&vector.impl_.slice)) {} + : slice_(vector.impl_.slice.template Reinterpret()) {} /// Constructor from a moved Vector with covariance / const conversion /// @param vector the vector to create a reference of @@ -785,8 +653,7 @@ class VectorRef { size_t N, typename = std::enable_if_t>> VectorRef(Vector&& vector) // NOLINT(runtime/explicit) - : slice_(*ReinterpretSlice(&vector.impl_.slice)), - can_move_(vector.impl_.CanMove()) {} + : slice_(vector.impl_.slice.template Reinterpret()), can_move_(vector.impl_.CanMove()) {} /// Index operator /// @param i the element index. Must be less than `len`. @@ -805,7 +672,7 @@ class VectorRef { /// this is a safe operation. template VectorRef ReinterpretCast() const { - return {*ReinterpretSlice(&slice_)}; + return {slice_.template Reinterpret()}; } /// @returns true if the vector is empty. diff --git a/src/tint/utils/vector_test.cc b/src/tint/utils/vector_test.cc index ed9a97a48c..6245476053 100644 --- a/src/tint/utils/vector_test.cc +++ b/src/tint/utils/vector_test.cc @@ -79,31 +79,6 @@ static_assert(std::is_same_v, const C1*>); static_assert(std::is_same_v, const C1*>); static_assert(std::is_same_v, const C1*>); -static_assert(CanReinterpretSlice, "apply const"); -static_assert(!CanReinterpretSlice, "remove const"); -static_assert(CanReinterpretSlice, "up cast"); -static_assert(CanReinterpretSlice, "up cast"); -static_assert(CanReinterpretSlice, "up cast, apply const"); -static_assert(!CanReinterpretSlice, - "up cast, remove const"); -static_assert(!CanReinterpretSlice, "down cast"); -static_assert(!CanReinterpretSlice, "down cast"); -static_assert(!CanReinterpretSlice, - "down cast, apply const"); -static_assert(!CanReinterpretSlice, - "down cast, remove const"); -static_assert(!CanReinterpretSlice, - "down cast, apply const"); -static_assert(!CanReinterpretSlice, - "down cast, remove const"); -static_assert(!CanReinterpretSlice, "sideways cast"); -static_assert(!CanReinterpretSlice, - "sideways cast"); -static_assert(!CanReinterpretSlice, - "sideways cast, apply const"); -static_assert(!CanReinterpretSlice, - "sideways cast, remove const"); - //////////////////////////////////////////////////////////////////////////////// // TintVectorTest ////////////////////////////////////////////////////////////////////////////////