tint/resolver: Bring back enum suggestions

The dependency graph no longer errors if a symbol cannot be resolved, instead the ResolvedIdentifier now has an unresolved variant.
This is required as the second resolve phase only has the full context of the identifier usage, to provide the hints.

Also: Split Slice out of the utils/vector.h, so it can be used as a lightweight view over static data.

Fixed: tint:1842
Change-Id: I31fa7697790be24c35b7e4fab5ca903c8a7afbba
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/121020
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
This commit is contained in:
Ben Clayton 2023-02-22 17:15:53 +00:00 committed by Dawn LUCI CQ
parent b549b3051e
commit afc53fa942
18 changed files with 682 additions and 389 deletions

View File

@ -218,6 +218,7 @@ libtint_source_set("libtint_base_src") {
"utils/map.h",
"utils/math.h",
"utils/scoped_assignment.h",
"utils/slice.h",
"utils/string.cc",
"utils/string.h",
"utils/unique_allocator.h",
@ -1402,6 +1403,7 @@ if (tint_build_unittests) {
"resolver/type_initializer_validation_test.cc",
"resolver/type_validation_test.cc",
"resolver/uniformity_test.cc",
"resolver/unresolved_identifier_test.cc",
"resolver/validation_test.cc",
"resolver/validator_is_storeable_test.cc",
"resolver/variable_test.cc",
@ -1553,6 +1555,7 @@ if (tint_build_unittests) {
"utils/result_test.cc",
"utils/reverse_test.cc",
"utils/scoped_assignment_test.cc",
"utils/slice_test.cc",
"utils/string_test.cc",
"utils/transform_test.cc",
"utils/unique_allocator_test.cc",

View File

@ -527,6 +527,7 @@ list(APPEND TINT_LIB_SRCS
utils/map.h
utils/math.h
utils/scoped_assignment.h
utils/slice.h
utils/string.cc
utils/string.h
utils/unique_allocator.h
@ -927,6 +928,7 @@ if(TINT_BUILD_TESTS)
resolver/struct_address_space_use_test.cc
resolver/type_initializer_validation_test.cc
resolver/type_validation_test.cc
resolver/unresolved_identifier_test.cc
resolver/validation_test.cc
resolver/validator_is_storeable_test.cc
resolver/variable_test.cc
@ -981,6 +983,7 @@ if(TINT_BUILD_TESTS)
utils/result_test.cc
utils/reverse_test.cc
utils/scoped_assignment_test.cc
utils/slice_test.cc
utils/string_test.cc
utils/transform_test.cc
utils/unique_allocator_test.cc

View File

@ -79,8 +79,6 @@ struct Global;
struct DependencyInfo {
/// The source of the symbol that forms the dependency
Source source;
/// A string describing how the dependency is referenced. e.g. 'calls'
const char* action = nullptr;
};
/// DependencyEdge describes the two Globals used to define a dependency
@ -174,12 +172,12 @@ class DependencyScanner {
Declare(str->name->symbol, str);
for (auto* member : str->members) {
TraverseAttributes(member->attributes);
TraverseTypeExpression(member->type);
TraverseExpression(member->type);
}
},
[&](const ast::Alias* alias) {
Declare(alias->name->symbol, alias);
TraverseTypeExpression(alias->type);
TraverseExpression(alias->type);
},
[&](const ast::Function* func) {
Declare(func->name->symbol, func);
@ -195,9 +193,7 @@ class DependencyScanner {
[&](const ast::Enable*) {
// Enable directives do not affect the dependency graph.
},
[&](const ast::ConstAssert* assertion) {
TraverseValueExpression(assertion->condition);
},
[&](const ast::ConstAssert* assertion) { TraverseExpression(assertion->condition); },
[&](Default) { UnhandledNode(diagnostics_, global->node); });
}
@ -205,12 +201,12 @@ class DependencyScanner {
/// Traverses the variable, performing symbol resolution.
void TraverseVariable(const ast::Variable* v) {
if (auto* var = v->As<ast::Var>()) {
TraverseAddressSpaceExpression(var->declared_address_space);
TraverseAccessExpression(var->declared_access);
TraverseExpression(var->declared_address_space);
TraverseExpression(var->declared_access);
}
TraverseTypeExpression(v->type);
TraverseExpression(v->type);
TraverseAttributes(v->attributes);
TraverseValueExpression(v->initializer);
TraverseExpression(v->initializer);
}
/// Traverses the function, performing symbol resolution and determining global dependencies.
@ -222,10 +218,10 @@ class DependencyScanner {
// with the same identifier as its type.
for (auto* param : func->params) {
TraverseAttributes(param->attributes);
TraverseTypeExpression(param->type);
TraverseExpression(param->type);
}
// Resolve the return type
TraverseTypeExpression(func->return_type);
TraverseExpression(func->return_type);
// Push the scope stack for the parameters and function body.
scope_stack_.Push();
@ -259,29 +255,29 @@ class DependencyScanner {
Switch(
stmt, //
[&](const ast::AssignmentStatement* a) {
TraverseValueExpression(a->lhs);
TraverseValueExpression(a->rhs);
TraverseExpression(a->lhs);
TraverseExpression(a->rhs);
},
[&](const ast::BlockStatement* b) {
scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop());
TraverseStatements(b->statements);
},
[&](const ast::BreakIfStatement* b) { TraverseValueExpression(b->condition); },
[&](const ast::CallStatement* r) { TraverseValueExpression(r->expr); },
[&](const ast::BreakIfStatement* b) { TraverseExpression(b->condition); },
[&](const ast::CallStatement* r) { TraverseExpression(r->expr); },
[&](const ast::CompoundAssignmentStatement* a) {
TraverseValueExpression(a->lhs);
TraverseValueExpression(a->rhs);
TraverseExpression(a->lhs);
TraverseExpression(a->rhs);
},
[&](const ast::ForLoopStatement* l) {
scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop());
TraverseStatement(l->initializer);
TraverseValueExpression(l->condition);
TraverseExpression(l->condition);
TraverseStatement(l->continuing);
TraverseStatement(l->body);
},
[&](const ast::IncrementDecrementStatement* i) { TraverseValueExpression(i->lhs); },
[&](const ast::IncrementDecrementStatement* i) { TraverseExpression(i->lhs); },
[&](const ast::LoopStatement* l) {
scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop());
@ -289,18 +285,18 @@ class DependencyScanner {
TraverseStatement(l->continuing);
},
[&](const ast::IfStatement* i) {
TraverseValueExpression(i->condition);
TraverseExpression(i->condition);
TraverseStatement(i->body);
if (i->else_statement) {
TraverseStatement(i->else_statement);
}
},
[&](const ast::ReturnStatement* r) { TraverseValueExpression(r->value); },
[&](const ast::ReturnStatement* r) { TraverseExpression(r->value); },
[&](const ast::SwitchStatement* s) {
TraverseValueExpression(s->condition);
TraverseExpression(s->condition);
for (auto* c : s->body) {
for (auto* sel : c->selectors) {
TraverseValueExpression(sel->expr);
TraverseExpression(sel->expr);
}
TraverseStatement(c->body);
}
@ -315,12 +311,10 @@ class DependencyScanner {
[&](const ast::WhileStatement* w) {
scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop());
TraverseValueExpression(w->condition);
TraverseExpression(w->condition);
TraverseStatement(w->body);
},
[&](const ast::ConstAssert* assertion) {
TraverseValueExpression(assertion->condition);
},
[&](const ast::ConstAssert* assertion) { TraverseExpression(assertion->condition); },
[&](Default) {
if (TINT_UNLIKELY((!stmt->IsAnyOf<ast::BreakStatement, ast::ContinueStatement,
ast::DiscardStatement>()))) {
@ -340,70 +334,28 @@ class DependencyScanner {
}
}
/// Traverses the expression @p root_expr for the intended use as a value, performing symbol
/// resolution and determining global dependencies.
void TraverseValueExpression(const ast::Expression* root) {
TraverseExpression(root, "identifier", "references");
}
/// Traverses the expression @p root_expr for the intended use as a type, performing symbol
/// resolution and determining global dependencies.
void TraverseTypeExpression(const ast::Expression* root) {
TraverseExpression(root, "type", "references");
}
/// Traverses the expression @p root_expr for the intended use as an address space, performing
/// symbol resolution and determining global dependencies.
void TraverseAddressSpaceExpression(const ast::Expression* root) {
TraverseExpression(root, "address space", "references");
}
/// Traverses the expression @p root_expr for the intended use as an access, performing symbol
/// resolution and determining global dependencies.
void TraverseAccessExpression(const ast::Expression* root) {
TraverseExpression(root, "access", "references");
}
/// Traverses the expression @p root_expr for the intended use as a call target, performing
/// symbol resolution and determining global dependencies.
void TraverseCallableExpression(const ast::Expression* root) {
TraverseExpression(root, "function", "calls");
}
/// Traverses the expression @p root_expr, performing symbol resolution and determining global
/// dependencies.
void TraverseExpression(const ast::Expression* root_expr,
const char* root_use,
const char* root_action) {
void TraverseExpression(const ast::Expression* root_expr) {
if (!root_expr) {
return;
}
struct Pending {
const ast::Expression* expr;
const char* use;
const char* action;
};
utils::Vector<Pending, 8> pending{{root_expr, root_use, root_action}};
utils::Vector<const ast::Expression*, 8> pending{root_expr};
while (!pending.IsEmpty()) {
auto next = pending.Pop();
ast::TraverseExpressions(next.expr, diagnostics_, [&](const ast::Expression* expr) {
ast::TraverseExpressions(pending.Pop(), diagnostics_, [&](const ast::Expression* expr) {
Switch(
expr,
[&](const ast::IdentifierExpression* e) {
AddDependency(e->identifier, e->identifier->symbol, next.use, next.action);
AddDependency(e->identifier, e->identifier->symbol);
if (auto* tmpl_ident = e->identifier->As<ast::TemplatedIdentifier>()) {
for (auto* arg : tmpl_ident->arguments) {
pending.Push({arg, "identifier", "references"});
pending.Push(arg);
}
}
},
[&](const ast::CallExpression* call) {
TraverseCallableExpression(call->target);
},
[&](const ast::BitcastExpression* cast) {
TraverseTypeExpression(cast->type);
});
[&](const ast::CallExpression* call) { TraverseExpression(call->target); },
[&](const ast::BitcastExpression* cast) { TraverseExpression(cast->type); });
return ast::TraverseAction::Descend;
});
}
@ -423,42 +375,42 @@ class DependencyScanner {
bool handled = Switch(
attr,
[&](const ast::BindingAttribute* binding) {
TraverseValueExpression(binding->expr);
TraverseExpression(binding->expr);
return true;
},
[&](const ast::BuiltinAttribute* builtin) {
TraverseExpression(builtin->builtin, "builtin", "references");
TraverseExpression(builtin->builtin);
return true;
},
[&](const ast::GroupAttribute* group) {
TraverseValueExpression(group->expr);
TraverseExpression(group->expr);
return true;
},
[&](const ast::IdAttribute* id) {
TraverseValueExpression(id->expr);
TraverseExpression(id->expr);
return true;
},
[&](const ast::InterpolateAttribute* interpolate) {
TraverseExpression(interpolate->type, "interpolation type", "references");
TraverseExpression(interpolate->sampling, "interpolation sampling", "references");
TraverseExpression(interpolate->type);
TraverseExpression(interpolate->sampling);
return true;
},
[&](const ast::LocationAttribute* loc) {
TraverseValueExpression(loc->expr);
TraverseExpression(loc->expr);
return true;
},
[&](const ast::StructMemberAlignAttribute* align) {
TraverseValueExpression(align->expr);
TraverseExpression(align->expr);
return true;
},
[&](const ast::StructMemberSizeAttribute* size) {
TraverseValueExpression(size->expr);
TraverseExpression(size->expr);
return true;
},
[&](const ast::WorkgroupAttribute* wg) {
TraverseValueExpression(wg->x);
TraverseValueExpression(wg->y);
TraverseValueExpression(wg->z);
TraverseExpression(wg->x);
TraverseExpression(wg->y);
TraverseExpression(wg->z);
return true;
});
if (handled) {
@ -476,10 +428,7 @@ class DependencyScanner {
}
/// Adds the dependency from @p from to @p to, erroring if @p to cannot be resolved.
void AddDependency(const ast::Identifier* from,
Symbol to,
const char* use,
const char* action) {
void AddDependency(const ast::Identifier* from, Symbol to) {
auto* resolved = scope_stack_.Get(to);
if (!resolved) {
auto s = symbols_.NameFor(to);
@ -521,13 +470,14 @@ class DependencyScanner {
return;
}
UnknownSymbol(to, from->source, use);
// Unresolved.
graph_.resolved_identifiers.Add(from, UnresolvedIdentifier{s});
return;
}
if (auto global = globals_.Find(to); global && (*global)->node == resolved) {
if (dependency_edges_.Add(DependencyEdge{current_global_, *global},
DependencyInfo{from->source, action})) {
DependencyInfo{from->source})) {
current_global_->deps.Push(*global);
}
}
@ -535,12 +485,6 @@ class DependencyScanner {
graph_.resolved_identifiers.Add(from, ResolvedIdentifier(resolved));
}
/// Appends an error to the diagnostics that the given symbol cannot be resolved.
void UnknownSymbol(Symbol name, Source source, const char* use) {
AddError(diagnostics_, "unknown " + std::string(use) + ": '" + symbols_.NameFor(name) + "'",
source);
}
using VariableMap = utils::Hashmap<Symbol, const ast::Variable*, 32>;
const SymbolTable& symbols_;
const GlobalMap& globals_;
@ -787,7 +731,7 @@ struct DependencyAnalysis {
auto* to = (i + 1 < stack.Length()) ? stack[i + 1] : stack[loop_start];
auto info = DepInfoFor(from, to);
AddNote(diagnostics_,
KindOf(from->node) + " '" + NameOf(from->node) + "' " + info.action + " " +
KindOf(from->node) + " '" + NameOf(from->node) + "' references " +
KindOf(to->node) + " '" + NameOf(to->node) + "' here",
info.source);
}
@ -831,8 +775,7 @@ struct DependencyAnalysis {
/// Global map, keyed by name. Populated by GatherGlobals().
GlobalMap globals_;
/// Map of DependencyEdge to DependencyInfo. Populated by
/// DetermineDependencies().
/// Map of DependencyEdge to DependencyInfo. Populated by DetermineDependencies().
DependencyEdges dependency_edges_;
/// Globals in declaration order. Populated by GatherGlobals().
@ -857,9 +800,6 @@ bool DependencyGraph::Build(const ast::Module& module,
}
std::string ResolvedIdentifier::String(const SymbolTable& symbols, diag::List& diagnostics) const {
if (!Resolved()) {
return "<unresolved symbol>";
}
if (auto* node = Node()) {
return Switch(
node,
@ -911,6 +851,10 @@ std::string ResolvedIdentifier::String(const SymbolTable& symbols, diag::List& d
if (auto fmt = TexelFormat(); fmt != builtin::TexelFormat::kUndefined) {
return "texel format '" + utils::ToString(fmt) + "'";
}
if (auto* unresolved = Unresolved()) {
return "unresolved identifier '" + unresolved->name + "'";
}
TINT_UNREACHABLE(Resolver, diagnostics) << "unhandled ResolvedIdentifier";
return "<unknown>";
}

View File

@ -32,8 +32,15 @@
namespace tint::resolver {
/// UnresolvedIdentifier is the variant value used by ResolvedIdentifier
struct UnresolvedIdentifier {
/// Name of the unresolved identifier
std::string name;
};
/// ResolvedIdentifier holds the resolution of an ast::Identifier.
/// Can hold one of:
/// - UnresolvedIdentifier
/// - const ast::TypeDecl* (as const ast::Node*)
/// - const ast::Variable* (as const ast::Node*)
/// - const ast::Function* (as const ast::Node*)
@ -47,15 +54,18 @@ namespace tint::resolver {
/// - builtin::TexelFormat
class ResolvedIdentifier {
public:
ResolvedIdentifier() = default;
/// Constructor
/// @param value the resolved identifier value
template <typename T>
ResolvedIdentifier(T value) : value_(value) {} // NOLINT(runtime/explicit)
/// @return true if the ResolvedIdentifier holds a value (successfully resolved)
bool Resolved() const { return !std::holds_alternative<std::monostate>(value_); }
/// @return the UnresolvedIdentifier if the identifier was not resolved
const UnresolvedIdentifier* Unresolved() const {
if (auto n = std::get_if<UnresolvedIdentifier>(&value_)) {
return n;
}
return nullptr;
}
/// @return the node pointer if the ResolvedIdentifier holds an AST node, otherwise nullptr
const ast::Node* Node() const {
@ -160,7 +170,7 @@ class ResolvedIdentifier {
std::string String(const SymbolTable& symbols, diag::List& diagnostics) const;
private:
std::variant<std::monostate,
std::variant<UnresolvedIdentifier,
const ast::Node*,
sem::BuiltinType,
builtin::Access,

View File

@ -309,47 +309,6 @@ std::ostream& operator<<(std::ostream& out, SymbolUseKind kind) {
return out << "<unknown>";
}
/// @returns the the diagnostic message name used for the given use
std::string DiagString(SymbolUseKind kind) {
switch (kind) {
case SymbolUseKind::GlobalVarType:
case SymbolUseKind::GlobalConstType:
case SymbolUseKind::AliasType:
case SymbolUseKind::StructMemberType:
case SymbolUseKind::ParameterType:
case SymbolUseKind::LocalVarType:
case SymbolUseKind::LocalLetType:
case SymbolUseKind::NestedLocalVarType:
case SymbolUseKind::NestedLocalLetType:
return "type";
case SymbolUseKind::GlobalVarArrayElemType:
case SymbolUseKind::GlobalVarVectorElemType:
case SymbolUseKind::GlobalVarMatrixElemType:
case SymbolUseKind::GlobalVarSampledTexElemType:
case SymbolUseKind::GlobalVarMultisampledTexElemType:
case SymbolUseKind::GlobalConstArrayElemType:
case SymbolUseKind::GlobalConstVectorElemType:
case SymbolUseKind::GlobalConstMatrixElemType:
case SymbolUseKind::LocalVarArrayElemType:
case SymbolUseKind::LocalVarVectorElemType:
case SymbolUseKind::LocalVarMatrixElemType:
case SymbolUseKind::GlobalVarValue:
case SymbolUseKind::GlobalVarArraySizeValue:
case SymbolUseKind::GlobalConstValue:
case SymbolUseKind::GlobalConstArraySizeValue:
case SymbolUseKind::LocalVarValue:
case SymbolUseKind::LocalVarArraySizeValue:
case SymbolUseKind::LocalLetValue:
case SymbolUseKind::NestedLocalVarValue:
case SymbolUseKind::NestedLocalLetValue:
case SymbolUseKind::WorkgroupSizeValue:
return "identifier";
case SymbolUseKind::CallFunction:
return "function";
}
return "<unknown>";
}
/// @returns the declaration scope depth for the symbol declaration kind.
/// Globals are at depth 0, parameters and locals are at depth 1,
/// nested locals are at depth 2.
@ -783,10 +742,16 @@ TEST_P(ResolverDependencyGraphUndeclaredSymbolTest, Test) {
// Build a use of a non-existent symbol
SymbolTestHelper helper(this);
helper.Add(use_kind, symbol, Source{{56, 78}});
auto* ident = helper.Add(use_kind, symbol, Source{{56, 78}});
helper.Build();
Build("56:78 error: unknown " + DiagString(use_kind) + ": 'SYMBOL'");
auto graph = Build();
auto resolved_identifier = graph.resolved_identifiers.Find(ident);
ASSERT_NE(resolved_identifier, nullptr);
auto* unresolved = resolved_identifier->Unresolved();
ASSERT_NE(unresolved, nullptr);
EXPECT_EQ(unresolved->name, "SYMBOL");
}
INSTANTIATE_TEST_SUITE_P(Types,
@ -826,14 +791,28 @@ TEST_F(ResolverDependencyGraphDeclSelfUse, GlobalConst) {
TEST_F(ResolverDependencyGraphDeclSelfUse, LocalVar) {
const Symbol symbol = Sym("SYMBOL");
WrapInFunction(Decl(Var(symbol, ty.i32(), Mul(Expr(Source{{12, 34}}, symbol), 123_i))));
Build("12:34 error: unknown identifier: 'SYMBOL'");
auto* ident = Ident(Source{{12, 34}}, symbol);
WrapInFunction(Decl(Var(symbol, ty.i32(), Mul(Expr(ident), 123_i))));
auto graph = Build();
auto resolved_identifier = graph.resolved_identifiers.Find(ident);
ASSERT_TRUE(resolved_identifier);
auto* unresolved = resolved_identifier->Unresolved();
ASSERT_NE(unresolved, nullptr);
EXPECT_EQ(unresolved->name, "SYMBOL");
}
TEST_F(ResolverDependencyGraphDeclSelfUse, LocalLet) {
const Symbol symbol = Sym("SYMBOL");
WrapInFunction(Decl(Let(symbol, ty.i32(), Mul(Expr(Source{{12, 34}}, symbol), 123_i))));
Build("12:34 error: unknown identifier: 'SYMBOL'");
auto* ident = Ident(Source{{12, 34}}, symbol);
WrapInFunction(Decl(Let(symbol, ty.i32(), Mul(Expr(ident), 123_i))));
auto graph = Build();
auto resolved_identifier = graph.resolved_identifiers.Find(ident);
ASSERT_TRUE(resolved_identifier);
auto* unresolved = resolved_identifier->Unresolved();
ASSERT_NE(unresolved, nullptr);
EXPECT_EQ(unresolved->name, "SYMBOL");
}
} // namespace undeclared_tests
@ -852,7 +831,7 @@ TEST_F(ResolverDependencyGraphCyclicRefTest, DirectCall) {
utils::Vector{CallStmt(Call(Ident(Source{{56, 78}}, "main")))});
Build(R"(12:34 error: cyclic dependency found: 'main' -> 'main'
56:78 note: function 'main' calls function 'main' here)");
56:78 note: function 'main' references function 'main' here)");
}
TEST_F(ResolverDependencyGraphCyclicRefTest, IndirectCall) {
@ -876,9 +855,9 @@ TEST_F(ResolverDependencyGraphCyclicRefTest, IndirectCall) {
utils::Vector{CallStmt(Call(Ident(Source{{5, 10}}, "c")))});
Build(R"(5:1 error: cyclic dependency found: 'b' -> 'c' -> 'd' -> 'b'
5:10 note: function 'b' calls function 'c' here
4:10 note: function 'c' calls function 'd' here
3:10 note: function 'd' calls function 'b' here)");
5:10 note: function 'b' references function 'c' here
4:10 note: function 'c' references function 'd' here
3:10 note: function 'd' references function 'b' here)");
}
TEST_F(ResolverDependencyGraphCyclicRefTest, Alias_Direct) {
@ -1160,17 +1139,22 @@ TEST_P(ResolverDependencyGraphResolveToUserDeclTest, Test) {
// If the declaration is visible to the use, then we expect the analysis to
// succeed.
bool expect_pass = ScopeDepth(decl_kind) <= ScopeDepth(use_kind);
auto graph = Build(expect_pass ? "" : "56:78 error: unknown identifier: 'SYMBOL'");
bool expect_resolved = ScopeDepth(decl_kind) <= ScopeDepth(use_kind);
auto graph = Build();
if (expect_pass) {
auto resolved_identifier = graph.resolved_identifiers.Find(use);
ASSERT_TRUE(resolved_identifier);
if (expect_resolved) {
// Check that the use resolves to the declaration
auto resolved_identifier = graph.resolved_identifiers.Find(use);
ASSERT_TRUE(resolved_identifier);
auto* resolved_node = resolved_identifier->Node();
EXPECT_EQ(resolved_node, decl)
<< "resolved: " << (resolved_node ? resolved_node->TypeInfo().name : "<null>") << "\n"
<< "decl: " << decl->TypeInfo().name;
} else {
auto* unresolved = resolved_identifier->Unresolved();
ASSERT_NE(unresolved, nullptr);
EXPECT_EQ(unresolved->name, "SYMBOL");
}
}

View File

@ -1070,11 +1070,14 @@ TEST_P(ResolverFunctionParameterValidationTest, AddressSpaceNoExtension) {
ss << param.address_space;
EXPECT_FALSE(r()->Resolve());
if (param.expectation == Expectation::kInvalid) {
EXPECT_EQ(r()->error(), "12:34 error: unknown identifier: '" + ss.str() + "'");
std::string err = R"(12:34 error: unresolved address space '${addr_space}'
12:34 note: Possible values: 'function', 'private', 'push_constant', 'storage', 'uniform', 'workgroup')";
err = utils::ReplaceAll(err, "${addr_space}", utils::ToString(param.address_space));
EXPECT_EQ(r()->error(), err);
} else {
EXPECT_EQ(r()->error(),
"12:34 error: function parameter of pointer type cannot be in '" + ss.str() +
"' address space");
"12:34 error: function parameter of pointer type cannot be in '" +
utils::ToString(param.address_space) + "' address space");
}
}
}
@ -1091,8 +1094,10 @@ TEST_P(ResolverFunctionParameterValidationTest, AddressSpaceWithExtension) {
} else {
EXPECT_FALSE(r()->Resolve());
if (param.expectation == Expectation::kInvalid) {
EXPECT_EQ(r()->error(), "12:34 error: unknown identifier: '" +
utils::ToString(param.address_space) + "'");
std::string err = R"(12:34 error: unresolved address space '${addr_space}'
12:34 note: Possible values: 'function', 'private', 'push_constant', 'storage', 'uniform', 'workgroup')";
err = utils::ReplaceAll(err, "${addr_space}", utils::ToString(param.address_space));
EXPECT_EQ(r()->error(), err);
} else {
EXPECT_EQ(r()->error(),
"12:34 error: function parameter of pointer type cannot be in '" +

View File

@ -1418,7 +1418,8 @@ sem::Expression* Resolver::Expression(const ast::Expression* root) {
for (auto* expr : utils::Reverse(sorted)) {
auto* sem_expr = Switch(
expr, [&](const ast::IndexAccessorExpression* array) { return IndexAccessor(array); },
expr, //
[&](const ast::IndexAccessorExpression* array) { return IndexAccessor(array); },
[&](const ast::BinaryExpression* bin_op) { return Binary(bin_op); },
[&](const ast::BitcastExpression* bitcast) { return Bitcast(bitcast); },
[&](const ast::CallExpression* call) { return Call(call); },
@ -1488,10 +1489,12 @@ sem::ValueExpression* Resolver::ValueExpression(const ast::Expression* expr) {
}
sem::TypeExpression* Resolver::TypeExpression(const ast::Expression* expr) {
identifier_resolve_hint_ = {expr, "type"};
return sem_.AsTypeExpression(Expression(expr));
}
sem::FunctionExpression* Resolver::FunctionExpression(const ast::Expression* expr) {
identifier_resolve_hint_ = {expr, "call target"};
return sem_.AsFunctionExpression(Expression(expr));
}
@ -1505,31 +1508,38 @@ type::Type* Resolver::Type(const ast::Expression* ast) {
sem::BuiltinEnumExpression<builtin::AddressSpace>* Resolver::AddressSpaceExpression(
const ast::Expression* expr) {
identifier_resolve_hint_ = {expr, "address space", builtin::kAddressSpaceStrings};
return sem_.AsAddressSpace(Expression(expr));
}
sem::BuiltinEnumExpression<builtin::BuiltinValue>* Resolver::BuiltinValueExpression(
const ast::Expression* expr) {
identifier_resolve_hint_ = {expr, "builtin value", builtin::kBuiltinValueStrings};
return sem_.AsBuiltinValue(Expression(expr));
}
sem::BuiltinEnumExpression<builtin::TexelFormat>* Resolver::TexelFormatExpression(
const ast::Expression* expr) {
identifier_resolve_hint_ = {expr, "texel format", builtin::kTexelFormatStrings};
return sem_.AsTexelFormat(Expression(expr));
}
sem::BuiltinEnumExpression<builtin::Access>* Resolver::AccessExpression(
const ast::Expression* expr) {
identifier_resolve_hint_ = {expr, "access", builtin::kAccessStrings};
return sem_.AsAccess(Expression(expr));
}
sem::BuiltinEnumExpression<builtin::InterpolationSampling>* Resolver::InterpolationSampling(
const ast::Expression* expr) {
identifier_resolve_hint_ = {expr, "interpolation sampling",
builtin::kInterpolationSamplingStrings};
return sem_.AsInterpolationSampling(Expression(expr));
}
sem::BuiltinEnumExpression<builtin::InterpolationType>* Resolver::InterpolationType(
const ast::Expression* expr) {
identifier_resolve_hint_ = {expr, "interpolation type", builtin::kInterpolationTypeStrings};
return sem_.AsInterpolationType(Expression(expr));
}
@ -2196,6 +2206,11 @@ sem::Call* Resolver::Call(const ast::CallExpression* expr) {
return ty_init_or_conv(ty);
}
if (auto* unresolved = resolved->Unresolved()) {
AddError("unresolved call target '" + unresolved->name + "'", expr->source);
return nullptr;
}
ErrorMismatchedResolvedIdentifier(ident->source, *resolved, "call target");
return nullptr;
}();
@ -2541,11 +2556,11 @@ type::Type* Resolver::BuiltinType(builtin::Builtin builtin_ty, const ast::Identi
return nullptr;
}
auto* format = sem_.AsTexelFormat(Expression(tmpl_ident->arguments[0]));
auto* format = TexelFormatExpression(tmpl_ident->arguments[0]);
if (TINT_UNLIKELY(!format)) {
return nullptr;
}
auto* access = sem_.AsAccess(Expression(tmpl_ident->arguments[1]));
auto* access = AccessExpression(tmpl_ident->arguments[1]);
if (TINT_UNLIKELY(!access)) {
return nullptr;
}
@ -3030,6 +3045,30 @@ sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) {
expr, current_statement_, fmt);
}
if (auto* unresolved = resolved->Unresolved()) {
if (identifier_resolve_hint_.expression == expr) {
AddError("unresolved " + std::string(identifier_resolve_hint_.usage) + " '" +
unresolved->name + "'",
expr->source);
if (!identifier_resolve_hint_.suggestions.IsEmpty()) {
// Filter out suggestions that have a leading underscore.
utils::Vector<const char*, 8> filtered;
for (auto* str : identifier_resolve_hint_.suggestions) {
if (str[0] != '_') {
filtered.Push(str);
}
}
std::ostringstream msg;
utils::SuggestAlternatives(unresolved->name,
filtered.Slice().Reinterpret<char const* const>(), msg);
AddNote(msg.str(), expr->source);
}
} else {
AddError("unresolved identifier '" + unresolved->name + "'", expr->source);
}
return nullptr;
}
TINT_UNREACHABLE(Resolver, diagnostics_)
<< "unhandled resolved identifier: " << resolved->String(builder_->Symbols(), diagnostics_);
return nullptr;

View File

@ -530,6 +530,17 @@ class Resolver {
std::unordered_set<const sem::Variable*> parameter_reads;
};
/// A hint for the usage of an identifier expression.
/// Used to provide more informative error diagnostics on resolution failure.
struct IdentifierResolveHint {
/// The expression this hint applies to
const ast::Expression* expression = nullptr;
/// The usage of the identifier.
const char* usage = "identifier";
/// Suggested strings if the identifier failed to resolve
utils::Slice<char const* const> suggestions = utils::Empty;
};
ProgramBuilder* const builder_;
diag::List& diagnostics_;
ConstEval const_eval_;
@ -555,6 +566,7 @@ class Resolver {
utils::Hashmap<const ast::Expression*, const ast::BinaryExpression*, 8>
logical_binary_lhs_to_parent_;
utils::Hashset<const ast::Expression*, 8> skip_const_eval_;
IdentifierResolveHint identifier_resolve_hint_;
};
} // namespace tint::resolver

View File

@ -1083,7 +1083,7 @@ TEST_P(StorageTextureDimensionTest, All) {
EXPECT_TRUE(r()->Resolve()) << r()->error();
} else {
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: unknown type: '" + std::string(params.name) + "'");
EXPECT_EQ(r()->error(), "12:34 error: unresolved type '" + std::string(params.name) + "'");
}
}
INSTANTIATE_TEST_SUITE_P(ResolverTypeValidationTest,

View File

@ -0,0 +1,108 @@
// Copyright 2023 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gmock/gmock.h"
#include "src/tint/resolver/resolver_test_helper.h"
using namespace tint::number_suffixes; // NOLINT
namespace tint::resolver {
namespace {
using ResolverUnresolvedIdentifierSuggestions = ResolverTest;
TEST_F(ResolverUnresolvedIdentifierSuggestions, AddressSpace) {
AST().AddGlobalVariable(create<ast::Var>(
Ident("v"), // name
ty.i32(), // type
Expr(Source{{12, 34}}, "privte"), // declared_address_space
nullptr, // declared_access
nullptr, // initializer
utils::Empty // attributes
));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(12:34 error: unresolved address space 'privte'
12:34 note: Did you mean 'private'?
Possible values: 'function', 'private', 'push_constant', 'storage', 'uniform', 'workgroup')");
}
TEST_F(ResolverUnresolvedIdentifierSuggestions, BuiltinValue) {
Func("f",
utils::Vector{
Param("p", ty.i32(), utils::Vector{Builtin(Expr(Source{{12, 34}}, "positon"))})},
ty.void_(), utils::Empty);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(12:34 error: unresolved builtin value 'positon'
12:34 note: Did you mean 'position'?
Possible values: 'frag_depth', 'front_facing', 'global_invocation_id', 'instance_index', 'local_invocation_id', 'local_invocation_index', 'num_workgroups', 'position', 'sample_index', 'sample_mask', 'vertex_index', 'workgroup_id')");
}
TEST_F(ResolverUnresolvedIdentifierSuggestions, TexelFormat) {
GlobalVar("v", ty("texture_storage_1d", Expr(Source{{12, 34}}, "rba8unorm"), "read"));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(12:34 error: unresolved texel format 'rba8unorm'
12:34 note: Did you mean 'rgba8unorm'?
Possible values: 'bgra8unorm', 'r32float', 'r32sint', 'r32uint', 'rg32float', 'rg32sint', 'rg32uint', 'rgba16float', 'rgba16sint', 'rgba16uint', 'rgba32float', 'rgba32sint', 'rgba32uint', 'rgba8sint', 'rgba8snorm', 'rgba8uint', 'rgba8unorm')");
}
TEST_F(ResolverUnresolvedIdentifierSuggestions, AccessMode) {
AST().AddGlobalVariable(create<ast::Var>(Ident("v"), // name
ty.i32(), // type
Expr("private"), // declared_address_space
Expr(Source{{12, 34}}, "reed"), // declared_access
nullptr, // initializer
utils::Empty // attributes
));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(12:34 error: unresolved access 'reed'
12:34 note: Did you mean 'read'?
Possible values: 'read', 'read_write', 'write')");
}
TEST_F(ResolverUnresolvedIdentifierSuggestions, InterpolationSampling) {
Structure("s", utils::Vector{
Member("m", ty.vec4<f32>(),
utils::Vector{
Interpolate(builtin::InterpolationType::kLinear,
Expr(Source{{12, 34}}, "centre")),
}),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(12:34 error: unresolved interpolation sampling 'centre'
12:34 note: Did you mean 'center'?
Possible values: 'center', 'centroid', 'sample')");
}
TEST_F(ResolverUnresolvedIdentifierSuggestions, InterpolationType) {
Structure("s", utils::Vector{
Member("m", ty.vec4<f32>(),
utils::Vector{
Interpolate(Expr(Source{{12, 34}}, "liner")),
}),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(12:34 error: unresolved interpolation type 'liner'
12:34 note: Did you mean 'linear'?
Possible values: 'flat', 'linear', 'perspective')");
}
} // namespace
} // namespace tint::resolver

View File

@ -168,7 +168,7 @@ TEST_F(ResolverValidationTest, UsingUndefinedVariable_Fail) {
WrapInFunction(assign);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(12:34 error: unknown identifier: 'b')");
EXPECT_EQ(r()->error(), R"(12:34 error: unresolved identifier 'b')");
}
TEST_F(ResolverValidationTest, UsingUndefinedVariableInBlockStatement_Fail) {
@ -183,7 +183,7 @@ TEST_F(ResolverValidationTest, UsingUndefinedVariableInBlockStatement_Fail) {
WrapInFunction(body);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(12:34 error: unknown identifier: 'b')");
EXPECT_EQ(r()->error(), R"(12:34 error: unresolved identifier 'b')");
}
TEST_F(ResolverValidationTest, UsingUndefinedVariableGlobalVariable_Pass) {
@ -223,7 +223,7 @@ TEST_F(ResolverValidationTest, UsingUndefinedVariableInnerScope_Fail) {
WrapInFunction(outer_body);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(12:34 error: unknown identifier: 'a')");
EXPECT_EQ(r()->error(), R"(12:34 error: unresolved identifier 'a')");
}
TEST_F(ResolverValidationTest, UsingUndefinedVariableOuterScope_Pass) {
@ -263,7 +263,7 @@ TEST_F(ResolverValidationTest, UsingUndefinedVariableDifferentScope_Fail) {
WrapInFunction(outer_body);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(12:34 error: unknown identifier: 'a')");
EXPECT_EQ(r()->error(), R"(12:34 error: unresolved identifier 'a')");
}
TEST_F(ResolverValidationTest, AddressSpace_FunctionVariableWorkgroupClass) {

205
src/tint/utils/slice.h Normal file
View File

@ -0,0 +1,205 @@
// Copyright 2023 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_UTILS_SLICE_H_
#define SRC_TINT_UTILS_SLICE_H_
#include <cstdint>
#include <iterator>
#include "src/tint/castable.h"
#include "src/tint/traits.h"
namespace tint::utils {
/// A type used to indicate an empty array.
struct EmptyType {};
/// An instance of the EmptyType.
static constexpr EmptyType Empty;
/// Mode enumerator for ReinterpretSlice
enum class ReinterpretMode {
/// Only upcasts of pointers are permitted
kSafe,
/// Potentially unsafe downcasts of pointers are also permitted
kUnsafe,
};
namespace detail {
template <typename TO, typename FROM>
static constexpr bool ConstRemoved = std::is_const_v<FROM> && !std::is_const_v<TO>;
/// Private implementation of tint::utils::CanReinterpretSlice.
/// Specialized for the case of TO equal to FROM, which is the common case, and avoids inspection of
/// the base classes, which can be troublesome if the slice is of an incomplete type.
template <ReinterpretMode MODE, typename TO, typename FROM>
struct CanReinterpretSlice {
private:
using TO_EL = std::remove_pointer_t<std::decay_t<TO>>;
using FROM_EL = std::remove_pointer_t<std::decay_t<FROM>>;
public:
/// @see utils::CanReinterpretSlice
static constexpr bool value =
// const can only be applied, not removed
!ConstRemoved<TO, FROM> &&
// Both TO and FROM are the same type (ignoring const)
(std::is_same_v<std::remove_const_t<TO>, std::remove_const_t<FROM>> ||
// Both TO and FROM are pointers...
((std::is_pointer_v<TO> && std::is_pointer_v<FROM>)&&
// const can only be applied to element type, not removed
!ConstRemoved<TO_EL, FROM_EL> &&
// Either:
// * Both the pointer elements are of the same type (ignoring const)
// * Both the pointer elements are both Castable, and MODE is kUnsafe, or FROM is of,
// or
// derives from TO
(std::is_same_v<std::remove_const_t<FROM_EL>, std::remove_const_t<TO_EL>> ||
(IsCastable<FROM_EL, TO_EL> &&
(MODE == ReinterpretMode::kUnsafe || traits::IsTypeOrDerived<FROM_EL, TO_EL>)))));
};
/// Specialization of 'CanReinterpretSlice' for when TO and FROM are equal types.
template <typename T, ReinterpretMode MODE>
struct CanReinterpretSlice<MODE, T, T> {
/// Always `true` as TO and FROM are the same type.
static constexpr bool value = true;
};
} // namespace detail
/// Evaluates whether a `Slice<FROM>` and be reinterpreted as a `Slice<TO>`.
/// Slices can be reinterpreted if:
/// * TO has the same or more 'constness' than FROM.
/// * And either:
/// * `FROM` and `TO` are pointers to the same type
/// * `FROM` and `TO` are pointers to CastableBase (or derived), and the pointee type of `TO` is of
/// the same type as, or is an ancestor of the pointee type of `FROM`.
template <ReinterpretMode MODE, typename TO, typename FROM>
static constexpr bool CanReinterpretSlice = detail::CanReinterpretSlice<MODE, TO, FROM>::value;
/// A slice represents a contigious array of elements of type T.
template <typename T>
struct Slice {
/// Type of `T`.
using value_type = T;
/// The pointer to the first element in the slice
T* data = nullptr;
/// The total number of elements in the slice
size_t len = 0;
/// The total capacity of the backing store for the slice
size_t cap = 0;
/// Constructor
Slice() = default;
/// Constructor
Slice(EmptyType) {} // NOLINT
/// Constructor
/// @param d pointer to the first element in the slice
/// @param l total number of elements in the slice
/// @param c total capacity of the backing store for the slice
Slice(T* d, size_t l, size_t c) : data(d), len(l), cap(c) {}
/// Constructor
/// @param elements c-array of elements
template <size_t N>
Slice(T (&elements)[N]) // NOLINT
: data(elements), len(N), cap(N) {}
/// Reinterprets this slice as `const Slice<TO>&`
/// @returns the reinterpreted slice
/// @see CanReinterpretSlice
template <typename TO, ReinterpretMode MODE = ReinterpretMode::kSafe>
const Slice<TO>& Reinterpret() const {
static_assert(CanReinterpretSlice<MODE, TO, T>);
return *Bitcast<const Slice<TO>*>(this);
}
/// Reinterprets this slice as `Slice<TO>&`
/// @returns the reinterpreted slice
/// @see CanReinterpretSlice
template <typename TO, ReinterpretMode MODE = ReinterpretMode::kSafe>
Slice<TO>& Reinterpret() {
static_assert(CanReinterpretSlice<MODE, TO, T>);
return *Bitcast<Slice<TO>*>(this);
}
/// @return true if the slice length is zero
bool IsEmpty() const { return len == 0; }
/// Index operator
/// @param i the element index. Must be less than `len`.
/// @returns a reference to the i'th element.
T& operator[](size_t i) { return data[i]; }
/// Index operator
/// @param i the element index. Must be less than `len`.
/// @returns a reference to the i'th element.
const T& operator[](size_t i) const { return data[i]; }
/// @returns a reference to the first element in the vector
T& Front() { return data[0]; }
/// @returns a reference to the first element in the vector
const T& Front() const { return data[0]; }
/// @returns a reference to the last element in the vector
T& Back() { return data[len - 1]; }
/// @returns a reference to the last element in the vector
const T& Back() const { return data[len - 1]; }
/// @returns a pointer to the first element in the vector
T* begin() { return data; }
/// @returns a pointer to the first element in the vector
const T* begin() const { return data; }
/// @returns a pointer to one past the last element in the vector
T* end() { return data + len; }
/// @returns a pointer to one past the last element in the vector
const T* end() const { return data + len; }
/// @returns a reverse iterator starting with the last element in the vector
auto rbegin() { return std::reverse_iterator<T*>(end()); }
/// @returns a reverse iterator starting with the last element in the vector
auto rbegin() const { return std::reverse_iterator<const T*>(end()); }
/// @returns the end for a reverse iterator
auto rend() { return std::reverse_iterator<T*>(begin()); }
/// @returns the end for a reverse iterator
auto rend() const { return std::reverse_iterator<const T*>(begin()); }
};
/// Deduction guide for Slice from c-array
template <typename T, size_t N>
Slice(T (&elements)[N]) -> Slice<T>;
} // namespace tint::utils
#endif // SRC_TINT_UTILS_SLICE_H_

View File

@ -0,0 +1,131 @@
// Copyright 2023 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/utils/slice.h"
#include "gmock/gmock.h"
namespace tint::utils {
namespace {
class C0 : public Castable<C0> {};
class C1 : public Castable<C1, C0> {};
class C2a : public Castable<C2a, C1> {};
class C2b : public Castable<C2b, C1> {};
////////////////////////////////////////////////////////////////////////////////
// Static asserts
////////////////////////////////////////////////////////////////////////////////
// Non-pointer
static_assert(CanReinterpretSlice<ReinterpretMode::kSafe, int, int>, "same type");
static_assert(CanReinterpretSlice<ReinterpretMode::kSafe, int const, int>, "apply const");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, int, int const>, "remove const");
// Non-castable pointers
static_assert(CanReinterpretSlice<ReinterpretMode::kSafe, int* const, int*>, "apply ptr const");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, int*, int* const>, "remove ptr const");
static_assert(CanReinterpretSlice<ReinterpretMode::kSafe, int const*, int*>, "apply el const");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, int*, int const*>, "remove el const");
// Castable
static_assert(CanReinterpretSlice<ReinterpretMode::kSafe, const C0*, C0*>, "apply const");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, C0*, const C0*>, "remove const");
static_assert(CanReinterpretSlice<ReinterpretMode::kSafe, C0*, C1*>, "up cast");
static_assert(CanReinterpretSlice<ReinterpretMode::kSafe, const C0*, const C1*>, "up cast");
static_assert(CanReinterpretSlice<ReinterpretMode::kSafe, const C0*, C1*>, "up cast, apply const");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, C0*, const C1*>,
"up cast, remove const");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, C1*, C0*>, "down cast");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, const C1*, const C0*>, "down cast");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, const C1*, C0*>,
"down cast, apply const");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, C1*, const C0*>,
"down cast, remove const");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, const C1*, C0*>,
"down cast, apply const");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, C1*, const C0*>,
"down cast, remove const");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, C2a*, C2b*>, "sideways cast");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, const C2a*, const C2b*>,
"sideways cast");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, const C2a*, C2b*>,
"sideways cast, apply const");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, C2a*, const C2b*>,
"sideways cast, remove const");
TEST(TintSliceTest, Ctor) {
Slice<int> slice;
EXPECT_EQ(slice.data, nullptr);
EXPECT_EQ(slice.len, 0u);
EXPECT_EQ(slice.cap, 0u);
EXPECT_TRUE(slice.IsEmpty());
}
TEST(TintSliceTest, CtorEmpty) {
Slice<int> slice{Empty};
EXPECT_EQ(slice.data, nullptr);
EXPECT_EQ(slice.len, 0u);
EXPECT_EQ(slice.cap, 0u);
EXPECT_TRUE(slice.IsEmpty());
}
TEST(TintSliceTest, CtorCArray) {
int elements[] = {1, 2, 3};
auto slice = Slice{elements};
EXPECT_EQ(slice.data, elements);
EXPECT_EQ(slice.len, 3u);
EXPECT_EQ(slice.cap, 3u);
EXPECT_FALSE(slice.IsEmpty());
}
TEST(TintSliceTest, Index) {
int elements[] = {1, 2, 3};
auto slice = Slice{elements};
EXPECT_EQ(slice[0], 1);
EXPECT_EQ(slice[1], 2);
EXPECT_EQ(slice[2], 3);
}
TEST(TintSliceTest, Front) {
int elements[] = {1, 2, 3};
auto slice = Slice{elements};
EXPECT_EQ(slice.Front(), 1);
}
TEST(TintSliceTest, Back) {
int elements[] = {1, 2, 3};
auto slice = Slice{elements};
EXPECT_EQ(slice.Back(), 3);
}
TEST(TintSliceTest, BeginEnd) {
int elements[] = {1, 2, 3};
auto slice = Slice{elements};
EXPECT_THAT(slice, testing::ElementsAre(1, 2, 3));
}
TEST(TintSliceTest, ReverseBeginEnd) {
int elements[] = {1, 2, 3};
auto slice = Slice{elements};
size_t i = 0;
for (auto it = slice.rbegin(); it != slice.rend(); it++) {
EXPECT_EQ(*it, elements[2 - i]);
i++;
}
}
} // namespace
} // namespace tint::utils

View File

@ -48,4 +48,36 @@ size_t Distance(std::string_view str_a, std::string_view str_b) {
return at(len_a, len_b);
}
void SuggestAlternatives(std::string_view got,
Slice<char const* const> strings,
std::ostringstream& ss) {
// If the string typed was within kSuggestionDistance of one of the possible enum values,
// suggest that. Don't bother with suggestions if the string was extremely long.
constexpr size_t kSuggestionDistance = 5;
constexpr size_t kSuggestionMaxLength = 64;
if (!got.empty() && got.size() < kSuggestionMaxLength) {
size_t candidate_dist = kSuggestionDistance;
const char* candidate = nullptr;
for (auto* str : strings) {
auto dist = utils::Distance(str, got);
if (dist < candidate_dist) {
candidate = str;
candidate_dist = dist;
}
}
if (candidate) {
ss << "Did you mean '" << candidate << "'?\n";
}
}
// List all the possible enumerator values
ss << "Possible values: ";
for (auto* str : strings) {
if (str != strings[0]) {
ss << ", ";
}
ss << "'" << str << "'";
}
}
} // namespace tint::utils

View File

@ -19,6 +19,8 @@
#include <string>
#include <variant>
#include "src/tint/utils/slice.h"
namespace tint::utils {
/// @param str the string to apply replacements to
@ -66,42 +68,13 @@ inline size_t HasPrefix(std::string_view str, std::string_view prefix) {
/// @returns the Levenshtein distance between @p a and @p b
size_t Distance(std::string_view a, std::string_view b);
/// Suggest alternatives for an unrecognized string from a list of expected values.
/// Suggest alternatives for an unrecognized string from a list of possible values.
/// @param got the unrecognized string
/// @param strings the list of expected values
/// @param strings the list of possible values
/// @param ss the stream to write the suggest and list of possible values to
template <size_t N>
void SuggestAlternatives(std::string_view got,
const char* const (&strings)[N],
std::ostringstream& ss) {
// If the string typed was within kSuggestionDistance of one of the possible enum values,
// suggest that. Don't bother with suggestions if the string was extremely long.
constexpr size_t kSuggestionDistance = 5;
constexpr size_t kSuggestionMaxLength = 64;
if (!got.empty() && got.size() < kSuggestionMaxLength) {
size_t candidate_dist = kSuggestionDistance;
const char* candidate = nullptr;
for (auto* str : strings) {
auto dist = utils::Distance(str, got);
if (dist < candidate_dist) {
candidate = str;
candidate_dist = dist;
}
}
if (candidate) {
ss << "Did you mean '" << candidate << "'?\n";
}
}
// List all the possible enumerator values
ss << "Possible values: ";
for (auto* str : strings) {
if (str != strings[0]) {
ss << ", ";
}
ss << "'" << str << "'";
}
}
Slice<char const* const> strings,
std::ostringstream& ss);
} // namespace tint::utils

View File

@ -60,14 +60,16 @@ TEST(StringTest, Distance) {
TEST(StringTest, SuggestAlternatives) {
{
const char* alternatives[] = {"hello world", "Hello World"};
std::ostringstream ss;
SuggestAlternatives("hello wordl", {"hello world", "Hello World"}, ss);
SuggestAlternatives("hello wordl", alternatives, ss);
EXPECT_EQ(ss.str(), R"(Did you mean 'hello world'?
Possible values: 'hello world', 'Hello World')");
}
{
const char* alternatives[] = {"foobar", "something else"};
std::ostringstream ss;
SuggestAlternatives("hello world", {"foobar", "something else"}, ss);
SuggestAlternatives("hello world", alternatives, ss);
EXPECT_EQ(ss.str(), R"(Possible values: 'foobar', 'something else')");
}
}

View File

@ -19,14 +19,14 @@
#include <stdint.h>
#include <algorithm>
#include <iterator>
#include <new>
#include <ostream>
#include <utility>
#include <vector>
#include "src/tint/castable.h"
#include "src/tint/traits.h"
#include "src/tint/utils/bitcast.h"
#include "src/tint/utils/compiler_macros.h"
#include "src/tint/utils/slice.h"
#include "src/tint/utils/string.h"
namespace tint::utils {
@ -41,137 +41,6 @@ class VectorRef;
namespace tint::utils {
/// A type used to indicate an empty array.
struct EmptyType {};
/// An instance of the EmptyType.
static constexpr EmptyType Empty;
/// A slice represents a contigious array of elements of type T.
template <typename T>
struct Slice {
/// The pointer to the first element in the slice
T* data = nullptr;
/// The total number of elements in the slice
size_t len = 0;
/// The total capacity of the backing store for the slice
size_t cap = 0;
/// Index operator
/// @param i the element index. Must be less than `len`.
/// @returns a reference to the i'th element.
T& operator[](size_t i) { return data[i]; }
/// Index operator
/// @param i the element index. Must be less than `len`.
/// @returns a reference to the i'th element.
const T& operator[](size_t i) const { return data[i]; }
/// @returns a reference to the first element in the vector
T& Front() { return data[0]; }
/// @returns a reference to the first element in the vector
const T& Front() const { return data[0]; }
/// @returns a reference to the last element in the vector
T& Back() { return data[len - 1]; }
/// @returns a reference to the last element in the vector
const T& Back() const { return data[len - 1]; }
/// @returns a pointer to the first element in the vector
T* begin() { return data; }
/// @returns a pointer to the first element in the vector
const T* begin() const { return data; }
/// @returns a pointer to one past the last element in the vector
T* end() { return data + len; }
/// @returns a pointer to one past the last element in the vector
const T* end() const { return data + len; }
/// @returns a reverse iterator starting with the last element in the vector
auto rbegin() { return std::reverse_iterator<T*>(end()); }
/// @returns a reverse iterator starting with the last element in the vector
auto rbegin() const { return std::reverse_iterator<const T*>(end()); }
/// @returns the end for a reverse iterator
auto rend() { return std::reverse_iterator<T*>(begin()); }
/// @returns the end for a reverse iterator
auto rend() const { return std::reverse_iterator<const T*>(begin()); }
};
/// Mode enumerator for ReinterpretSlice
enum class ReinterpretMode {
/// Only upcasts of pointers are permitted
kSafe,
/// Potentially unsafe downcasts of pointers are also permitted
kUnsafe,
};
namespace detail {
/// Private implementation of tint::utils::CanReinterpretSlice.
/// Specialized for the case of TO equal to FROM, which is the common case, and avoids inspection of
/// the base classes, which can be troublesome if the slice is of an incomplete type.
template <ReinterpretMode MODE, typename TO, typename FROM>
struct CanReinterpretSlice {
/// True if a slice of FROM can be reinterpreted as a slice of TO
static constexpr bool value =
// Both TO and FROM are pointers
(std::is_pointer_v<TO> && std::is_pointer_v<FROM>)&& //
// const can only be applied, not removed
(std::is_const_v<std::remove_pointer_t<TO>> ||
!std::is_const_v<std::remove_pointer_t<FROM>>)&& //
// TO and FROM are both Castable
IsCastable<std::remove_pointer_t<FROM>, std::remove_pointer_t<TO>> && //
// MODE is kUnsafe, or FROM is of, or derives from TO
(MODE == ReinterpretMode::kUnsafe ||
traits::IsTypeOrDerived<std::remove_pointer_t<FROM>, std::remove_pointer_t<TO>>);
};
/// Specialization of 'CanReinterpretSlice' for when TO and FROM are equal types.
template <typename T, ReinterpretMode MODE>
struct CanReinterpretSlice<MODE, T, T> {
/// Always `true` as TO and FROM are the same type.
static constexpr bool value = true;
};
} // namespace detail
/// Evaluates whether a `vector<FROM>` and be reinterpreted as a `vector<TO>`.
/// Vectors can be reinterpreted if both `FROM` and `TO` are pointers to a type that derives from
/// CastableBase, and the pointee type of `TO` is of the same type as, or is an ancestor of the
/// pointee type of `FROM`. Vectors of non-`const` Castable pointers can be converted to a vector of
/// `const` Castable pointers.
template <ReinterpretMode MODE, typename TO, typename FROM>
static constexpr bool CanReinterpretSlice = detail::CanReinterpretSlice<MODE, TO, FROM>::value;
/// Reinterprets `const Slice<FROM>*` as `const Slice<TO>*`
/// @param slice a pointer to the slice to reinterpret
/// @returns the reinterpreted slice
/// @see CanReinterpretSlice
template <ReinterpretMode MODE, typename TO, typename FROM>
const Slice<TO>* ReinterpretSlice(const Slice<FROM>* slice) {
static_assert(CanReinterpretSlice<MODE, TO, FROM>);
return Bitcast<const Slice<TO>*>(slice);
}
/// Reinterprets `Slice<FROM>*` as `Slice<TO>*`
/// @param slice a pointer to the slice to reinterpret
/// @returns the reinterpreted slice
/// @see CanReinterpretSlice
template <ReinterpretMode MODE, typename TO, typename FROM>
Slice<TO>* ReinterpretSlice(Slice<FROM>* slice) {
static_assert(CanReinterpretSlice<MODE, TO, FROM>);
return Bitcast<Slice<TO>*>(slice);
}
/// Vector is a small-object-optimized, dynamically-sized vector of contigious elements of type T.
///
/// Vector will fit `N` elements internally before spilling to heap allocations. If `N` is greater
@ -244,7 +113,7 @@ class Vector {
ReinterpretMode MODE,
typename = std::enable_if_t<CanReinterpretSlice<MODE, T, U>>>
Vector(const Vector<U, N2>& other) { // NOLINT(runtime/explicit)
Copy(*ReinterpretSlice<MODE, T>(&other.impl_.slice));
Copy(other.impl_.slice.template Reinterpret<T, MODE>);
}
/// Move constructor with covariance / const conversion
@ -518,6 +387,9 @@ class Vector {
return !(*this == other);
}
/// @returns the internal slice of the vector
utils::Slice<T> Slice() { return impl_.slice; }
private:
/// Friend class (differing specializations of this class)
template <typename, size_t>
@ -531,9 +403,6 @@ class Vector {
template <typename>
friend class VectorRef;
/// The slice type used by this vector
using Slice = utils::Slice<T>;
template <typename... Ts>
void AppendVariadic(Ts&&... args) {
((new (&impl_.slice.data[impl_.slice.len++]) T(std::forward<Ts>(args))), ...);
@ -555,7 +424,7 @@ class Vector {
/// Copies all the elements from `other` to this vector, replacing the content of this vector.
/// @param other the
void Copy(const Slice& other) {
void Copy(const utils::Slice<T>& other) {
if (impl_.slice.cap < other.len) {
ClearAndFree();
impl_.Allocate(other.len);
@ -592,7 +461,7 @@ class Vector {
/// The internal structure for the vector with a small array.
struct ImplWithSmallArray {
TStorage small_arr[N];
Slice slice = {small_arr[0].Get(), 0, N};
utils::Slice<T> slice = {small_arr[0].Get(), 0, N};
/// Allocates a new vector of `T` either from #small_arr, or from the heap, then assigns the
/// pointer it to #slice.data, and updates #slice.cap.
@ -620,7 +489,7 @@ class Vector {
/// The internal structure for the vector without a small array.
struct ImplWithoutSmallArray {
Slice slice = {nullptr, 0, 0};
utils::Slice<T> slice = Empty;
/// Allocates a new vector of `T` and assigns it to #slice.data, and updates #slice.cap.
void Allocate(size_t new_cap) {
@ -759,15 +628,14 @@ class VectorRef {
template <typename U,
typename = std::enable_if_t<CanReinterpretSlice<ReinterpretMode::kSafe, T, U>>>
VectorRef(const VectorRef<U>& other) // NOLINT(runtime/explicit)
: slice_(*ReinterpretSlice<ReinterpretMode::kSafe, T>(&other.slice_)) {}
: slice_(other.slice_.template Reinterpret<T>()) {}
/// Move constructor with covariance / const conversion
/// @param other the vector reference
template <typename U,
typename = std::enable_if_t<CanReinterpretSlice<ReinterpretMode::kSafe, T, U>>>
VectorRef(VectorRef<U>&& other) // NOLINT(runtime/explicit)
: slice_(*ReinterpretSlice<ReinterpretMode::kSafe, T>(&other.slice_)),
can_move_(other.can_move_) {}
: slice_(other.slice_.template Reinterpret<T>()), can_move_(other.can_move_) {}
/// Constructor from a Vector with covariance / const conversion
/// @param vector the vector to create a reference of
@ -776,7 +644,7 @@ class VectorRef {
size_t N,
typename = std::enable_if_t<CanReinterpretSlice<ReinterpretMode::kSafe, T, U>>>
VectorRef(Vector<U, N>& vector) // NOLINT(runtime/explicit)
: slice_(*ReinterpretSlice<ReinterpretMode::kSafe, T>(&vector.impl_.slice)) {}
: slice_(vector.impl_.slice.template Reinterpret<T>()) {}
/// Constructor from a moved Vector with covariance / const conversion
/// @param vector the vector to create a reference of
@ -785,8 +653,7 @@ class VectorRef {
size_t N,
typename = std::enable_if_t<CanReinterpretSlice<ReinterpretMode::kSafe, T, U>>>
VectorRef(Vector<U, N>&& vector) // NOLINT(runtime/explicit)
: slice_(*ReinterpretSlice<ReinterpretMode::kSafe, T>(&vector.impl_.slice)),
can_move_(vector.impl_.CanMove()) {}
: slice_(vector.impl_.slice.template Reinterpret<T>()), can_move_(vector.impl_.CanMove()) {}
/// Index operator
/// @param i the element index. Must be less than `len`.
@ -805,7 +672,7 @@ class VectorRef {
/// this is a safe operation.
template <typename U>
VectorRef<U> ReinterpretCast() const {
return {*ReinterpretSlice<ReinterpretMode::kUnsafe, U>(&slice_)};
return {slice_.template Reinterpret<U, ReinterpretMode::kUnsafe>()};
}
/// @returns true if the vector is empty.

View File

@ -79,31 +79,6 @@ static_assert(std::is_same_v<VectorCommonType<const C2a*, C2b*>, const C1*>);
static_assert(std::is_same_v<VectorCommonType<C2a*, const C2b*>, const C1*>);
static_assert(std::is_same_v<VectorCommonType<const C2a*, const C2b*>, const C1*>);
static_assert(CanReinterpretSlice<ReinterpretMode::kSafe, const C0*, C0*>, "apply const");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, C0*, const C0*>, "remove const");
static_assert(CanReinterpretSlice<ReinterpretMode::kSafe, C0*, C1*>, "up cast");
static_assert(CanReinterpretSlice<ReinterpretMode::kSafe, const C0*, const C1*>, "up cast");
static_assert(CanReinterpretSlice<ReinterpretMode::kSafe, const C0*, C1*>, "up cast, apply const");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, C0*, const C1*>,
"up cast, remove const");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, C1*, C0*>, "down cast");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, const C1*, const C0*>, "down cast");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, const C1*, C0*>,
"down cast, apply const");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, C1*, const C0*>,
"down cast, remove const");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, const C1*, C0*>,
"down cast, apply const");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, C1*, const C0*>,
"down cast, remove const");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, C2a*, C2b*>, "sideways cast");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, const C2a*, const C2b*>,
"sideways cast");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, const C2a*, C2b*>,
"sideways cast, apply const");
static_assert(!CanReinterpretSlice<ReinterpretMode::kSafe, C2a*, const C2b*>,
"sideways cast, remove const");
////////////////////////////////////////////////////////////////////////////////
// TintVectorTest
////////////////////////////////////////////////////////////////////////////////