resolver: Add dependency-graph analysis

Performs a module-scope (global) declaration dependency analysis, so
that out-of-order global declarations can be re-ordered into dependency
order for consumption by the resolver.

The WGSL working group are currently debating whether out-of-order
declarations should be included in WebGPU V1, so this implementation
currently errors if module-scope declarations are declared out-of-order,
and the resolver does not currently use this sorted global list.

The analysis does however provide significantly better error diagnostics
when cyclic dependencies are formed, and when globals are declared
out-of-order.

The DependencyGraph also correctly now detects symbol collisions between
functions and types (tint:1308).

With this change, validation is duplicated between the DependencyGraph
and the Resolver. The now-unreachable validation will be removed from
the Resolver with a followup change.

Fixed: tint:1308
Bug: tint:1266
Change-Id: I809c23a069a86cf429f5ec8ef3ad9a98246766ab
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/69381
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
Ben Clayton 2021-11-22 11:44:57 +00:00 committed by Tint LUCI CQ
parent c87dc10ce3
commit 4183051b54
20 changed files with 2007 additions and 253 deletions

View File

@ -371,6 +371,8 @@ libtint_source_set("libtint_core_all_src") {
"program_id.h",
"reader/reader.cc",
"reader/reader.h",
"resolver/dependency_graph.cc",
"resolver/dependency_graph.h",
"resolver/resolver.cc",
"resolver/resolver.h",
"resolver/resolver_constants.cc",

View File

@ -238,6 +238,8 @@ set(TINT_LIB_SRCS
program.h
reader/reader.cc
reader/reader.h
resolver/dependency_graph.cc
resolver/dependency_graph.h
resolver/resolver.cc
resolver/resolver_constants.cc
resolver/resolver_validation.cc
@ -673,6 +675,7 @@ if(${TINT_BUILD_TESTS})
resolver/compound_statement_test.cc
resolver/control_block_validation_test.cc
resolver/decoration_validation_test.cc
resolver/dependency_graph_test.cc
resolver/entry_point_validation_test.cc
resolver/function_validation_test.cc
resolver/host_shareable_validation_test.cc

View File

@ -112,9 +112,9 @@ type declaration_order_check_1 = f32;
fn declaration_order_check_2() {}
type declaration_order_check_2 = f32;
type declaration_order_check_3 = f32;
let declaration_order_check_3 : i32 = 1;
let declaration_order_check_4 : i32 = 1;
)");

View File

@ -35,8 +35,10 @@
#include "src/ast/depth_multisampled_texture.h"
#include "src/ast/depth_texture.h"
#include "src/ast/disable_validation_decoration.h"
#include "src/ast/discard_statement.h"
#include "src/ast/external_texture.h"
#include "src/ast/f32.h"
#include "src/ast/fallthrough_statement.h"
#include "src/ast/float_literal_expression.h"
#include "src/ast/for_loop_statement.h"
#include "src/ast/i32.h"
@ -1892,6 +1894,19 @@ class ProgramBuilder {
return create<ast::ReturnStatement>(Expr(std::forward<EXPR>(val)));
}
/// Creates an ast::DiscardStatement
/// @param source the source information
/// @returns the discard statement pointer
const ast::DiscardStatement* Discard(const Source& source) {
return create<ast::DiscardStatement>(source);
}
/// Creates an ast::DiscardStatement
/// @returns the discard statement pointer
const ast::DiscardStatement* Discard() {
return create<ast::DiscardStatement>();
}
/// Creates a ast::Alias registering it with the AST().TypeDecls().
/// @param source the source information
/// @param name the alias name
@ -2205,6 +2220,19 @@ class ProgramBuilder {
return Case(ast::CaseSelectorList{}, body);
}
/// Creates an ast::FallthroughStatement
/// @param source the source information
/// @returns the fallthrough statement pointer
const ast::FallthroughStatement* Fallthrough(const Source& source) {
return create<ast::FallthroughStatement>(source);
}
/// Creates an ast::FallthroughStatement
/// @returns the fallthrough statement pointer
const ast::FallthroughStatement* Fallthrough() {
return create<ast::FallthroughStatement>();
}
/// Creates an ast::BuiltinDecoration
/// @param source the source information
/// @param builtin the builtin value

View File

@ -24,54 +24,6 @@ namespace {
using ResolverCallValidationTest = ResolverTest;
TEST_F(ResolverCallValidationTest, Recursive_Invalid) {
// fn main() {main(); }
SetSource(Source::Location{12, 34});
auto* call_expr = Call("main");
ast::VariableList params0;
Func("main", params0, ty.void_(),
ast::StatementList{
CallStmt(call_expr),
},
ast::DecorationList{
Stage(ast::PipelineStage::kVertex),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: recursion is not permitted. 'main' attempted to call "
"itself.");
}
TEST_F(ResolverCallValidationTest, Undeclared_Invalid) {
// fn main() {func(); return; }
// fn func() { return; }
SetSource(Source::Location{12, 34});
auto* call_expr = Call("func");
ast::VariableList params0;
Func("main", params0, ty.f32(),
ast::StatementList{
CallStmt(call_expr),
Return(),
},
ast::DecorationList{});
Func("func", params0, ty.f32(),
ast::StatementList{
Return(),
},
ast::DecorationList{});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: unable to find called function: func");
}
TEST_F(ResolverCallValidationTest, TooFewArgs) {
Func("foo", {Param(Sym(), ty.i32()), Param(Sym(), ty.f32())}, ty.void_(),
{Return()});
@ -115,11 +67,10 @@ TEST_F(ResolverCallValidationTest, UnusedRetval) {
Func("func", {}, ty.f32(), {Return(Expr(1.0f))}, {});
Func("main", {}, ty.void_(),
ast::StatementList{
{
CallStmt(Source{{12, 34}}, Call("func")),
Return(),
},
{});
});
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
@ -133,7 +84,7 @@ TEST_F(ResolverCallValidationTest, PointerArgument_VariableIdentExpr) {
auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
Func("foo", {param}, ty.void_(), {});
Func("main", {}, ty.void_(),
ast::StatementList{
{
Decl(Var("z", ty.i32(), Expr(1))),
CallStmt(Call("foo", AddressOf(Source{{12, 34}}, Expr("z")))),
});
@ -150,7 +101,7 @@ TEST_F(ResolverCallValidationTest, PointerArgument_ConstIdentExpr) {
auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
Func("foo", {param}, ty.void_(), {});
Func("main", {}, ty.void_(),
ast::StatementList{
{
Decl(Const("z", ty.i32(), Expr(1))),
CallStmt(Call("foo", AddressOf(Expr(Source{{12, 34}}, "z")))),
});
@ -170,7 +121,7 @@ TEST_F(ResolverCallValidationTest, PointerArgument_NotIdentExprVar) {
auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
Func("foo", {param}, ty.void_(), {});
Func("main", {}, ty.void_(),
ast::StatementList{
{
Decl(Var("v", ty.Of(S))),
CallStmt(Call(
"foo", AddressOf(Source{{12, 34}}, MemberAccessor("v", "m")))),
@ -193,7 +144,7 @@ TEST_F(ResolverCallValidationTest, PointerArgument_AddressOfMemberAccessor) {
auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
Func("foo", {param}, ty.void_(), {});
Func("main", {}, ty.void_(),
ast::StatementList{
{
Decl(Const("v", ty.Of(S), Construct(ty.Of(S)))),
CallStmt(Call("foo", AddressOf(Expr(Source{{12, 34}},
MemberAccessor("v", "m"))))),

View File

@ -0,0 +1,700 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/resolver/dependency_graph.h"
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "src/ast/continue_statement.h"
#include "src/ast/discard_statement.h"
#include "src/ast/fallthrough_statement.h"
#include "src/ast/traverse_expressions.h"
#include "src/scope_stack.h"
#include "src/sem/intrinsic.h"
#include "src/utils/defer.h"
#include "src/utils/scoped_assignment.h"
#include "src/utils/unique_vector.h"
#define TINT_DUMP_DEPENDENCY_GRAPH 0
namespace tint {
namespace resolver {
namespace {
// Forward declaration
struct Global;
/// Dependency describes how one global depends on another global
struct DependencyInfo {
/// The source of the symbol that forms the dependency
Source source;
/// A string describing how the dependency is referenced. e.g. 'calls'
const char* action = nullptr;
};
/// DependencyEdge describes the two Globals used to define a dependency
/// relationship.
struct DependencyEdge {
/// The Global that depends on #to
const Global* from;
/// The Global that is depended on by #from
const Global* to;
};
/// DependencyEdgeCmp implements the contracts of std::equal_to<DependencyEdge>
/// and std::hash<DependencyEdge>.
struct DependencyEdgeCmp {
/// Equality operator
bool operator()(const DependencyEdge& lhs, const DependencyEdge& rhs) const {
return lhs.from == rhs.from && lhs.to == rhs.to;
}
/// Hashing operator
inline std::size_t operator()(const DependencyEdge& d) const {
return utils::Hash(d.from, d.to);
}
};
/// A map of DependencyEdge to DependencyInfo
using DependencyEdges = std::unordered_map<DependencyEdge,
DependencyInfo,
DependencyEdgeCmp,
DependencyEdgeCmp>;
/// Global describes a module-scope variable, type or function.
struct Global {
explicit Global(const ast::Node* n) : node(n) {}
/// The declaration ast::Node
const ast::Node* node;
/// A list of dependencies that this global depends on
std::vector<Global*> deps;
};
/// A map of global name to Global
using GlobalMap = std::unordered_map<Symbol, Global*>;
/// Raises an ICE that a global ast::Node declaration type was not handled by
/// this system.
void UnhandledDecl(diag::List& diagnostics, const ast::Node* node) {
TINT_UNREACHABLE(Resolver, diagnostics)
<< "unhandled global declaration: " << node->TypeInfo().name;
}
/// Raises an error diagnostic with the given message and source.
void AddError(diag::List& diagnostics,
const std::string& msg,
const Source& source) {
diagnostics.add_error(diag::System::Resolver, msg, source);
}
/// Raises a note diagnostic with the given message and source.
void AddNote(diag::List& diagnostics,
const std::string& msg,
const Source& source) {
diagnostics.add_note(diag::System::Resolver, msg, source);
}
/// DependencyScanner is used to traverse a module to build the list of
/// global-to-global dependencies.
class DependencyScanner {
public:
/// Constructor
/// @param syms the program symbol table
/// @param globals_by_name map of global symbol to Global pointer
/// @param diagnostics diagnostic messages, appended with any errors found
/// @param graph the dependency graph to populate with resolved symbols
/// @param edges the map of globals-to-global dependency edges, which will
/// be populated by calls to Scan()
DependencyScanner(const SymbolTable& syms,
const GlobalMap& globals_by_name,
diag::List& diagnostics,
DependencyGraph& graph,
DependencyEdges& edges)
: symbols_(syms),
globals_(globals_by_name),
diagnostics_(diagnostics),
graph_(graph),
dependency_edges_(edges) {
// Register all the globals at global-scope
for (auto it : globals_by_name) {
scope_stack_.Set(it.first, it.second->node);
}
}
/// Walks the global declarations, resolving symbols, and determining the
/// dependencies of each global.
void Scan(Global* global) {
TINT_SCOPED_ASSIGNMENT(current_global_, global);
if (auto* str = global->node->As<ast::Struct>()) {
Declare(str->name, str);
for (auto* member : str->members) {
ResolveTypeDependency(member->type);
}
return;
}
if (auto* alias = global->node->As<ast::Alias>()) {
Declare(alias->name, alias);
ResolveTypeDependency(alias->type);
return;
}
if (auto* func = global->node->As<ast::Function>()) {
Declare(func->symbol, func);
TraverseFunction(func);
return;
}
if (auto* var = global->node->As<ast::Variable>()) {
Declare(var->symbol, var);
ResolveTypeDependency(var->type);
if (var->constructor) {
TraverseExpression(var->constructor);
}
return;
}
UnhandledDecl(diagnostics_, global->node);
}
private:
/// Traverses the function determining global dependencies.
void TraverseFunction(const ast::Function* func) {
scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop());
for (auto* param : func->params) {
Declare(param->symbol, param);
ResolveTypeDependency(param->type);
}
if (func->body) {
TraverseStatements(func->body->statements);
}
ResolveTypeDependency(func->return_type);
}
/// Traverses the statements determining global dependencies.
void TraverseStatements(const ast::StatementList& stmts) {
for (auto* s : stmts) {
TraverseStatement(s);
}
}
/// Traverses the statement determining global dependencies.
void TraverseStatement(const ast::Statement* stmt) {
if (stmt == nullptr) {
return;
}
if (auto* b = stmt->As<ast::AssignmentStatement>()) {
TraverseExpression(b->lhs);
TraverseExpression(b->rhs);
return;
}
if (auto* b = stmt->As<ast::BlockStatement>()) {
scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop());
TraverseStatements(b->statements);
return;
}
if (auto* r = stmt->As<ast::CallStatement>()) {
TraverseExpression(r->expr);
return;
}
if (auto* l = stmt->As<ast::ForLoopStatement>()) {
scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop());
TraverseStatement(l->initializer);
TraverseExpression(l->condition);
TraverseStatement(l->continuing);
TraverseStatement(l->body);
return;
}
if (auto* l = stmt->As<ast::LoopStatement>()) {
scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop());
TraverseStatements(l->body->statements);
TraverseStatement(l->continuing);
return;
}
if (auto* i = stmt->As<ast::IfStatement>()) {
TraverseExpression(i->condition);
TraverseStatement(i->body);
for (auto* e : i->else_statements) {
TraverseExpression(e->condition);
TraverseStatement(e->body);
}
return;
}
if (auto* r = stmt->As<ast::ReturnStatement>()) {
TraverseExpression(r->value);
return;
}
if (auto* s = stmt->As<ast::SwitchStatement>()) {
TraverseExpression(s->condition);
for (auto* c : s->body) {
for (auto* sel : c->selectors) {
TraverseExpression(sel);
}
TraverseStatement(c->body);
}
return;
}
if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
Declare(v->variable->symbol, v->variable);
ResolveTypeDependency(v->variable->type);
TraverseExpression(v->variable->constructor);
return;
}
if (stmt->IsAnyOf<ast::BreakStatement, ast::ContinueStatement,
ast::DiscardStatement, ast::FallthroughStatement>()) {
return;
}
AddError(diagnostics_,
"unknown statement type: " + std::string(stmt->TypeInfo().name),
stmt->source);
}
/// Adds the symbol definition to the current scope, raising an error if two
/// symbols collide within the same scope.
void Declare(Symbol symbol, const ast::Node* node) {
auto* old = scope_stack_.Set(symbol, node);
if (old != nullptr && node != old) {
auto name = symbols_.NameFor(symbol);
AddError(diagnostics_, "redeclaration of '" + name + "'", node->source);
AddNote(diagnostics_, "'" + name + "' previously declared here",
old->source);
}
}
/// Traverses the expression determining global dependencies.
void TraverseExpression(const ast::Expression* root) {
if (!root) {
return;
}
ast::TraverseExpressions(
root, diagnostics_, [&](const ast::Expression* expr) {
if (auto* ident = expr->As<ast::IdentifierExpression>()) {
auto* node = scope_stack_.Get(ident->symbol);
if (node == nullptr) {
if (!IsIntrinsic(ident->symbol)) {
UnknownSymbol(ident->symbol, ident->source, "identifier");
}
return ast::TraverseAction::Descend;
}
auto global_it = globals_.find(ident->symbol);
if (global_it != globals_.end() &&
node == global_it->second->node) {
ResolveGlobalDependency(ident, ident->symbol, "identifier",
"references");
} else {
graph_.resolved_symbols.emplace(ident, node);
}
}
if (auto* call = expr->As<ast::CallExpression>()) {
if (call->target.name) {
if (!IsIntrinsic(call->target.name->symbol)) {
ResolveGlobalDependency(call->target.name,
call->target.name->symbol, "function",
"calls");
}
}
if (call->target.type) {
ResolveTypeDependency(call->target.type);
}
}
return ast::TraverseAction::Descend;
});
}
/// Adds the type dependency to the currently processed global
void ResolveTypeDependency(const ast::Type* ty) {
if (ty == nullptr) {
return;
}
if (auto* tn = ty->As<ast::TypeName>()) {
ResolveGlobalDependency(tn, tn->name, "type", "references");
}
}
/// Adds the dependency to the currently processed global
void ResolveGlobalDependency(const ast::Node* from,
Symbol to,
const char* use,
const char* action) {
auto global_it = globals_.find(to);
if (global_it != globals_.end()) {
auto* global = global_it->second;
if (dependency_edges_
.emplace(DependencyEdge{current_global_, global},
DependencyInfo{from->source, action})
.second) {
current_global_->deps.emplace_back(global);
}
graph_.resolved_symbols.emplace(from, global->node);
} else {
UnknownSymbol(to, from->source, use);
}
}
/// @returns true if `name` is the name of an intrinsic function
bool IsIntrinsic(Symbol name) const {
return sem::ParseIntrinsicType(symbols_.NameFor(name)) !=
sem::IntrinsicType::kNone;
}
/// Appends an error to the diagnostics that the given symbol cannot be
/// resolved.
void UnknownSymbol(Symbol name, Source source, const char* use) {
AddError(
diagnostics_,
"unknown " + std::string(use) + ": '" + symbols_.NameFor(name) + "'",
source);
}
using VariableMap = std::unordered_map<Symbol, const ast::Variable*>;
const SymbolTable& symbols_;
const GlobalMap& globals_;
diag::List& diagnostics_;
DependencyGraph& graph_;
DependencyEdges& dependency_edges_;
ScopeStack<const ast::Node*> scope_stack_;
Global* current_global_ = nullptr;
};
/// The global dependency analysis system
struct DependencyAnalysis {
public:
/// Constructor
DependencyAnalysis(const SymbolTable& symbols,
diag::List& diagnostics,
DependencyGraph& graph)
: symbols_(symbols), diagnostics_(diagnostics), graph_(graph) {}
/// Performs global dependency analysis on the module, emitting any errors to
/// #diagnostics.
/// @returns true if analysis found no errors, otherwise false.
bool Run(const ast::Module& module, bool allow_out_of_order_decls) {
// Collect all the named globals from the AST module
GatherGlobals(module);
// Traverse the named globals to build the dependency graph
DetermineDependencies();
// Sort the globals into dependency order
SortGlobals();
#if TINT_DUMP_DEPENDENCY_GRAPH
DumpDependencyGraph();
#endif
if (!allow_out_of_order_decls) {
// Prevent out-of-order declarations.
ErrorOnOutOfOrderDeclarations();
}
graph_.ordered_globals = std::move(sorted_);
return !diagnostics_.contains_errors();
}
private:
/// @param node the ast::Node of the global declaration
/// @returns the symbol of the global declaration node
/// @note will raise an ICE if the node is not a type, function or variable
/// declaration
Symbol SymbolOf(const ast::Node* node) const {
if (auto* td = node->As<ast::TypeDecl>()) {
return td->name;
}
if (auto* func = node->As<ast::Function>()) {
return func->symbol;
}
if (auto* var = node->As<ast::Variable>()) {
return var->symbol;
}
UnhandledDecl(diagnostics_, node);
return {};
}
/// @param node the ast::Node of the global declaration
/// @returns the name of the global declaration node
/// @note will raise an ICE if the node is not a type, function or variable
/// declaration
std::string NameOf(const ast::Node* node) const {
return symbols_.NameFor(SymbolOf(node));
}
/// @param node the ast::Node of the global declaration
/// @returns a string representation of the global declaration kind
/// @note will raise an ICE if the node is not a type, function or variable
/// declaration
std::string KindOf(const ast::Node* node) {
if (node->Is<ast::Struct>()) {
return "struct";
}
if (node->Is<ast::Alias>()) {
return "alias";
}
if (node->Is<ast::Function>()) {
return "function";
}
if (auto* var = node->As<ast::Variable>()) {
return var->is_const ? "let" : "var";
}
UnhandledDecl(diagnostics_, node);
return {};
}
/// Traverses `module`, collecting all the global declarations and populating
/// the #globals and #declaration_order fields.
void GatherGlobals(const ast::Module& module) {
for (auto* node : module.GlobalDeclarations()) {
auto* global = allocator_.Create(node);
globals_.emplace(SymbolOf(node), global);
declaration_order_.emplace_back(global);
}
}
/// Walks the global declarations, determining the dependencies of each global
/// and adding these to each global's Global::deps field.
void DetermineDependencies() {
DependencyScanner scanner(symbols_, globals_, diagnostics_, graph_,
dependency_edges_);
for (auto* global : declaration_order_) {
scanner.Scan(global);
}
}
/// Performs a depth-first traversal of `root`'s dependencies, calling `enter`
/// as the function decends into each dependency and `exit` when bubbling back
/// up towards the root.
/// @param enter is a function with the signature: `bool(Global*)`. The
/// `enter` function returns true if TraverseDependencies() should traverse
/// the dependency, otherwise it will be skipped.
/// @param exit is a function with the signature: `void(Global*)`. The `exit`
/// function is only called if the corresponding `enter` call returned true.
template <typename ENTER, typename EXIT>
void TraverseDependencies(const Global* root, ENTER&& enter, EXIT&& exit) {
// Entry is a single entry in the traversal stack. Entry points to a
// dep_idx'th dependency of Entry::global.
struct Entry {
const Global* global; // The parent global
size_t dep_idx; // The dependency index in `global->deps`
};
if (!enter(root)) {
return;
}
std::vector<Entry> stack{Entry{root, 0}};
while (true) {
auto& entry = stack.back();
// Have we exhausted the dependencies of entry.global?
if (entry.dep_idx < entry.global->deps.size()) {
// No, there's more dependencies to traverse.
auto& dep = entry.global->deps[entry.dep_idx];
// Does the caller want to enter this dependency?
if (enter(dep)) { // Yes.
stack.push_back(Entry{dep, 0}); // Enter the dependency.
} else {
entry.dep_idx++; // No. Skip this node.
}
} else {
// Yes. Time to back up.
// Exit this global, pop the stack, and if there's another parent node,
// increment its dependency index, and loop again.
exit(entry.global);
stack.pop_back();
if (stack.empty()) {
return; // All done.
}
stack.back().dep_idx++;
}
}
}
/// SortGlobals sorts the globals into dependency order, erroring if cyclic
/// dependencies are found. The sorted dependencies are assigned to #sorted.
void SortGlobals() {
if (diagnostics_.contains_errors()) {
return; // This code assumes there are no undeclared identifiers.
}
std::unordered_set<const Global*> visited;
for (auto* global : declaration_order_) {
utils::UniqueVector<const Global*> stack;
TraverseDependencies(
global,
[&](const Global* g) { // Enter
if (!stack.add(g)) {
CyclicDependencyFound(g, stack);
return false;
}
if (sorted_.contains(g->node)) {
return false; // Visited this global already.
}
return true;
},
[&](const Global* g) { // Exit
sorted_.add(g->node);
stack.pop_back();
});
sorted_.add(global->node);
}
}
/// DepInfoFor() looks up the global dependency information for the dependency
/// of global `from` depending on `to`.
/// @note will raise an ICE if the edge is not found.
DependencyInfo DepInfoFor(const Global* from, const Global* to) const {
auto it = dependency_edges_.find(DependencyEdge{from, to});
if (it != dependency_edges_.end()) {
return it->second;
}
TINT_ICE(Resolver, diagnostics_)
<< "failed to find dependency info for edge: '" << NameOf(from->node)
<< "' -> '" << NameOf(to->node) << "'";
return {};
}
// TODO(crbug.com/tint/1266): Errors if there are any uses of globals before
// their declaration. Out-of-order declarations was added to the WGSL
// specification with https://github.com/gpuweb/gpuweb/pull/2244, but Mozilla
// have objections to this change so this feature is currently disabled via
// this function.
void ErrorOnOutOfOrderDeclarations() {
if (diagnostics_.contains_errors()) {
// Might have already errored about cyclic dependencies. No need to report
// out-of-order errors as well.
return;
}
std::unordered_set<const Global*> seen;
for (auto* global : declaration_order_) {
for (auto* dep : global->deps) {
if (!seen.count(dep)) {
auto info = DepInfoFor(global, dep);
auto name = NameOf(dep->node);
AddError(diagnostics_,
KindOf(dep->node) + " '" + name +
"' used before it has been declared",
info.source);
AddNote(diagnostics_,
KindOf(dep->node) + " '" + name + "' declared here",
dep->node->source);
}
}
seen.emplace(global);
}
}
/// CyclicDependencyFound() emits an error diagnostic for a cyclic dependency.
/// @param root is the global that starts the cyclic dependency, which must be
/// found in `stack`.
/// @param stack is the global dependency stack that contains a loop.
void CyclicDependencyFound(const Global* root,
const std::vector<const Global*>& stack) {
std::stringstream msg;
msg << "cyclic dependency found: ";
constexpr size_t kLoopNotStarted = ~0u;
size_t loop_start = kLoopNotStarted;
for (size_t i = 0; i < stack.size(); i++) {
auto* e = stack[i];
if (loop_start == kLoopNotStarted && e == root) {
loop_start = i;
}
if (loop_start != kLoopNotStarted) {
msg << "'" << NameOf(e->node) << "' -> ";
}
}
msg << "'" << NameOf(root->node) << "'";
AddError(diagnostics_, msg.str(), root->node->source);
for (size_t i = loop_start; i < stack.size(); i++) {
auto* from = stack[i];
auto* to = (i + 1 < stack.size()) ? stack[i + 1] : stack[loop_start];
auto info = DepInfoFor(from, to);
AddNote(diagnostics_,
KindOf(from->node) + " '" + NameOf(from->node) + "' " +
info.action + " " + KindOf(to->node) + " '" +
NameOf(to->node) + "' here",
info.source);
}
}
#if TINT_DUMP_DEPENDENCY_GRAPH
void DumpDependencyGraph() {
printf("=========================\n");
printf("------ declaration ------ \n");
for (auto* global : declaration_order_) {
printf("%s\n", NameOf(global->node).c_str());
}
printf("------ dependencies ------ \n");
for (auto* node : sorted) {
auto symbol = SymbolOf(node);
auto* global = globals.at(symbol);
printf("%s depends on:\n", symbols.NameFor(symbol).c_str());
for (auto& dep : global->deps) {
printf(" %s\n", NameOf(dep.global->node).c_str());
}
}
printf("=========================\n");
}
#endif // TINT_DUMP_DEPENDENCY_GRAPH
/// Program symbols
const SymbolTable& symbols_;
/// Program diagnostics
diag::List& diagnostics_;
/// The resulting dependency graph
DependencyGraph& graph_;
/// Allocator of Globals
BlockAllocator<Global> allocator_;
/// Global map, keyed by name. Populated by GatherGlobals().
GlobalMap globals_;
/// Map of DependencyEdge to DependencyInfo. Populated by
/// DetermineDependencies().
DependencyEdges dependency_edges_;
/// Globals in declaration order. Populated by GatherGlobals().
std::vector<Global*> declaration_order_;
/// Globals in sorted dependency order. Populated by SortGlobals().
utils::UniqueVector<const ast::Node*> sorted_;
};
} // namespace
DependencyGraph::DependencyGraph() = default;
DependencyGraph::DependencyGraph(DependencyGraph&&) = default;
DependencyGraph::~DependencyGraph() = default;
bool DependencyGraph::Build(const ast::Module& module,
const SymbolTable& symbols,
diag::List& diagnostics,
DependencyGraph& output,
bool allow_out_of_order_decls) {
DependencyAnalysis da{symbols, diagnostics, output};
return da.Run(module, allow_out_of_order_decls);
}
} // namespace resolver
} // namespace tint

View File

@ -0,0 +1,63 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_RESOLVER_DEPENDENCY_GRAPH_H_
#define SRC_RESOLVER_DEPENDENCY_GRAPH_H_
#include <unordered_map>
#include <vector>
#include "src/ast/module.h"
#include "src/diagnostic/diagnostic.h"
namespace tint {
namespace resolver {
/// DependencyGraph holds information about module-scope declaration dependency
/// analysis and symbol resolutions.
struct DependencyGraph {
/// Constructor
DependencyGraph();
/// Move-constructor
DependencyGraph(DependencyGraph&&);
/// Destructor
~DependencyGraph();
/// Build() performs symbol resolution and dependency analysis on `module`,
/// populating `output` with the resulting dependency graph.
/// @param module the AST module to analyse
/// @param symbols the symbol table
/// @param diagnostics the diagnostic list to populate with errors / warnings
/// @param output the resulting DependencyGraph
/// @param allow_out_of_order_decls if true, then out-of-order declarations
/// are not considered an error
/// @returns true on success, false on error
static bool Build(const ast::Module& module,
const SymbolTable& symbols,
diag::List& diagnostics,
DependencyGraph& output,
bool allow_out_of_order_decls);
/// All globals in dependency-sorted order.
std::vector<const ast::Node*> ordered_globals;
/// Map of ast::IdentifierExpression or ast::TypeName to a type, function, or
/// variable that declares the symbol.
std::unordered_map<const ast::Node*, const ast::Node*> resolved_symbols;
};
} // namespace resolver
} // namespace tint
#endif // SRC_RESOLVER_DEPENDENCY_GRAPH_H_

File diff suppressed because it is too large Load Diff

View File

@ -26,43 +26,6 @@ namespace {
class ResolverFunctionValidationTest : public resolver::TestHelper,
public testing::Test {};
TEST_F(ResolverFunctionValidationTest, FunctionNamesMustBeUnique_fail) {
// fn func() -> i32 { return 2; }
// fn func() -> i32 { return 2; }
Func(Source{{56, 78}}, "func", ast::VariableList{}, ty.i32(),
ast::StatementList{
Return(2),
},
ast::DecorationList{});
Func(Source{{12, 34}}, "func", ast::VariableList{}, ty.i32(),
ast::StatementList{
Return(2),
},
ast::DecorationList{});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(12:34 error: redefinition of 'func'
56:78 note: previous definition is here)");
}
TEST_F(ResolverFunctionValidationTest, ParameterNamesMustBeUnique_fail) {
// fn func(common_name : f32, x : i32, common_name : u32) { }
Func("func",
{
Param(Source{{56, 78}}, "common_name", ty.f32()),
Param("x", ty.i32()),
Param(Source{{12, 34}}, "common_name", ty.u32()),
},
ty.void_(), {});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
R"(12:34 error: redefinition of parameter 'common_name'
56:78 note: previous definition is here)");
}
TEST_F(ResolverFunctionValidationTest, ParameterNamesMustBeUnique_pass) {
// fn func_a(common_name : f32) { }
// fn func_b(common_name : f32) { }
@ -97,41 +60,6 @@ TEST_F(ResolverFunctionValidationTest,
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverFunctionValidationTest,
FunctionNameSameAsGlobalVariableName_Fail) {
// var foo:f32 = 3.14;
// fn foo() -> void {}
auto* global_var = Var(Source{{56, 78}}, "foo", ty.f32(),
ast::StorageClass::kPrivate, Expr(3.14f));
AST().AddGlobalVariable(global_var);
Func(Source{{12, 34}}, "foo", ast::VariableList{}, ty.void_(),
ast::StatementList{}, ast::DecorationList{});
EXPECT_FALSE(r()->Resolve()) << r()->error();
EXPECT_EQ(r()->error(),
"12:34 error: redefinition of 'foo'\n56:78 note: previous "
"definition is here");
}
TEST_F(ResolverFunctionValidationTest,
GlobalVariableNameSameAFunctionName_Fail) {
// fn foo() -> void {}
// var<private> foo:f32 = 3.14;
Func(Source{{12, 34}}, "foo", ast::VariableList{}, ty.void_(),
ast::StatementList{}, ast::DecorationList{});
auto* global_var = Var(Source{{56, 78}}, "foo", ty.f32(),
ast::StorageClass::kPrivate, Expr(3.14f));
AST().AddGlobalVariable(global_var);
EXPECT_FALSE(r()->Resolve()) << r()->error();
EXPECT_EQ(r()->error(),
"56:78 error: redefinition of 'foo'\n12:34 note: previous "
"definition is here");
}
TEST_F(ResolverFunctionValidationTest, FunctionUsingSameVariableName_Pass) {
// fn func() -> i32 {
// var func:i32 = 0;

View File

@ -95,6 +95,12 @@ bool Resolver::Resolve() {
return false;
}
if (!DependencyGraph::Build(builder_->AST(), builder_->Symbols(),
builder_->Diagnostics(), dependencies_,
/* allow_out_of_order_decls*/ false)) {
return false;
}
bool result = ResolveInternal();
if (!result && !diagnostics_.contains_errors()) {
@ -910,8 +916,7 @@ bool Resolver::Statement(const ast::Statement* stmt) {
return VariableDeclStatement(v);
}
AddError("unknown statement type for type determination: " +
std::string(stmt->TypeInfo().name),
AddError("unknown statement type: " + std::string(stmt->TypeInfo().name),
stmt->source);
return false;
}
@ -1075,15 +1080,23 @@ bool Resolver::ForLoopStatement(const ast::ForLoopStatement* stmt) {
sem::Expression* Resolver::Expression(const ast::Expression* root) {
std::vector<const ast::Expression*> sorted;
bool mark_failed = false;
if (!ast::TraverseExpressions<ast::TraverseOrder::RightToLeft>(
root, diagnostics_, [&](const ast::Expression* expr) {
Mark(expr);
if (!Mark(expr)) {
mark_failed = true;
return ast::TraverseAction::Stop;
}
sorted.emplace_back(expr);
return ast::TraverseAction::Descend;
})) {
return nullptr;
}
if (mark_failed) {
return nullptr;
}
for (auto* expr : utils::Reverse(sorted)) {
sem::Expression* sem_expr = nullptr;
if (auto* array = expr->As<ast::IndexAccessorExpression>()) {
@ -1524,7 +1537,7 @@ sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) {
return nullptr;
}
AddError("identifier must be declared before use: " + name, expr->source);
AddError("unknown identifier: '" + name + "'", expr->source);
return nullptr;
}
@ -2249,18 +2262,20 @@ std::string Resolver::VectorPretty(uint32_t size,
return vec_type.FriendlyName(builder_->Symbols());
}
void Resolver::Mark(const ast::Node* node) {
bool Resolver::Mark(const ast::Node* node) {
if (node == nullptr) {
TINT_ICE(Resolver, diagnostics_) << "Resolver::Mark() called with nullptr";
return false;
}
if (marked_.emplace(node).second) {
return;
return true;
}
TINT_ICE(Resolver, diagnostics_)
<< "AST node '" << node->TypeInfo().name
<< "' was encountered twice in the same AST of a Program\n"
<< "At: " << node->source << "\n"
<< "Pointer: " << node;
return false;
}
void Resolver::AddError(const std::string& msg, const Source& source) const {

View File

@ -25,6 +25,7 @@
#include "src/intrinsic_table.h"
#include "src/program_builder.h"
#include "src/resolver/dependency_graph.h"
#include "src/scope_stack.h"
#include "src/sem/binding_point.h"
#include "src/sem/block_statement.h"
@ -378,7 +379,8 @@ class Resolver {
/// the given node has not already been seen. Diamonds in the AST are
/// illegal.
/// @param node the AST node.
void Mark(const ast::Node* node);
/// @returns true on success, false on error
bool Mark(const ast::Node* node);
/// Adds the given error message to the diagnostics
void AddError(const std::string& msg, const Source& source) const;
@ -453,6 +455,7 @@ class Resolver {
ProgramBuilder* const builder_;
diag::List& diagnostics_;
std::unique_ptr<IntrinsicTable> const intrinsic_table_;
DependencyGraph dependencies_;
ScopeStack<sem::Variable*> variable_stack_;
std::unordered_map<Symbol, sem::Function*> symbol_to_function_;
std::vector<sem::Function*> entry_points_;

View File

@ -307,17 +307,6 @@ TEST_F(ResolverTest, Stmt_VariableDecl_Alias) {
EXPECT_TRUE(TypeOf(init)->Is<sem::I32>());
}
TEST_F(ResolverTest, Stmt_VariableDecl_AliasRedeclared) {
Alias(Source{{12, 34}}, "MyInt", ty.i32());
Alias(Source{{56, 78}}, "MyInt", ty.i32());
WrapInFunction();
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"56:78 error: type with the name 'MyInt' was already declared\n"
"12:34 note: first declared here");
}
TEST_F(ResolverTest, Stmt_VariableDecl_ModuleScope) {
auto* init = Expr(2);
Global("my_var", ty.i32(), ast::StorageClass::kPrivate, init);
@ -1979,9 +1968,9 @@ TEST_F(ResolverTest, ASTNodesAreReached) {
TEST_F(ResolverTest, ASTNodeNotReached) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder builder;
builder.Expr("1");
Resolver(&builder).Resolve();
ProgramBuilder b;
b.Expr("expr");
Resolver(&b).Resolve();
},
"internal compiler error: AST node 'tint::ast::IdentifierExpression' was "
"not reached by the resolver");
@ -1990,15 +1979,14 @@ TEST_F(ResolverTest, ASTNodeNotReached) {
TEST_F(ResolverTest, ASTNodeReachedTwice) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder builder;
auto* expr = builder.Expr("1");
auto* usesExprTwice = builder.Add(expr, expr);
builder.Global("g", builder.ty.i32(), ast::StorageClass::kPrivate,
usesExprTwice);
Resolver(&builder).Resolve();
ProgramBuilder b;
auto* expr = b.Expr(1);
b.Global("a", b.ty.i32(), ast::StorageClass::kPrivate, expr);
b.Global("b", b.ty.i32(), ast::StorageClass::kPrivate, expr);
Resolver(&b).Resolve();
},
"internal compiler error: AST node 'tint::ast::IdentifierExpression' was "
"encountered twice in the same AST of a Program");
"internal compiler error: AST node 'tint::ast::SintLiteralExpression' "
"was encountered twice in the same AST of a Program");
}
TEST_F(ResolverTest, UnaryOp_Not) {

View File

@ -111,8 +111,7 @@ TEST_F(ResolverValidationTest, Error_WithEmptySource) {
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"error: unknown statement type for type determination: "
"tint::resolver::FakeStmt");
"error: unknown statement type: tint::resolver::FakeStmt");
}
TEST_F(ResolverValidationTest, Stmt_Error_Unknown) {
@ -122,8 +121,7 @@ TEST_F(ResolverValidationTest, Stmt_Error_Unknown) {
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"2:30 error: unknown statement type for type determination: "
"tint::resolver::FakeStmt");
"2:30 error: unknown statement type: tint::resolver::FakeStmt");
}
TEST_F(ResolverValidationTest, Stmt_If_NonBool) {
@ -203,8 +201,7 @@ TEST_F(ResolverValidationTest, UsingUndefinedVariable_Fail) {
WrapInFunction(assign);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: identifier must be declared before use: b");
EXPECT_EQ(r()->error(), "12:34 error: unknown identifier: 'b'");
}
TEST_F(ResolverValidationTest, UsingUndefinedVariableInBlockStatement_Fail) {
@ -219,30 +216,7 @@ TEST_F(ResolverValidationTest, UsingUndefinedVariableInBlockStatement_Fail) {
WrapInFunction(body);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: identifier must be declared before use: b");
}
TEST_F(ResolverValidationTest, UsingUndefinedVariableGlobalVariableAfter_Fail) {
// fn my_func() {
// global_var = 3.14f;
// }
// var global_var: f32 = 2.1;
auto* lhs = Expr(Source{{12, 34}}, "global_var");
auto* rhs = Expr(3.14f);
Func("my_func", ast::VariableList{}, ty.void_(),
ast::StatementList{
Assign(lhs, rhs),
},
ast::DecorationList{Stage(ast::PipelineStage::kVertex)});
Global("global_var", ty.f32(), ast::StorageClass::kPrivate, Expr(2.1f));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: identifier must be declared before use: global_var");
EXPECT_EQ(r()->error(), "12:34 error: unknown identifier: 'b'");
}
TEST_F(ResolverValidationTest, UsingUndefinedVariableGlobalVariable_Pass) {
@ -255,7 +229,7 @@ TEST_F(ResolverValidationTest, UsingUndefinedVariableGlobalVariable_Pass) {
Global("global_var", ty.f32(), ast::StorageClass::kPrivate, Expr(2.1f));
Func("my_func", ast::VariableList{}, ty.void_(),
ast::StatementList{
{
Assign(Expr(Source{Source::Location{12, 34}}, "global_var"), 3.14f),
Return(),
});
@ -284,8 +258,7 @@ TEST_F(ResolverValidationTest, UsingUndefinedVariableInnerScope_Fail) {
WrapInFunction(outer_body);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: identifier must be declared before use: a");
EXPECT_EQ(r()->error(), "12:34 error: unknown identifier: 'a'");
}
TEST_F(ResolverValidationTest, UsingUndefinedVariableOuterScope_Pass) {
@ -327,16 +300,14 @@ TEST_F(ResolverValidationTest, UsingUndefinedVariableDifferentScope_Fail) {
WrapInFunction(outer_body);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: identifier must be declared before use: a");
EXPECT_EQ(r()->error(), "12:34 error: unknown identifier: 'a'");
}
TEST_F(ResolverValidationTest, StorageClass_FunctionVariableWorkgroupClass) {
auto* var = Var("var", ty.i32(), ast::StorageClass::kWorkgroup);
auto* stmt = Decl(var);
Func("func", ast::VariableList{}, ty.void_(), ast::StatementList{stmt},
ast::DecorationList{});
Func("func", ast::VariableList{}, ty.void_(), {stmt}, ast::DecorationList{});
EXPECT_FALSE(r()->Resolve());
@ -348,8 +319,7 @@ TEST_F(ResolverValidationTest, StorageClass_FunctionVariableI32) {
auto* var = Var("s", ty.i32(), ast::StorageClass::kPrivate);
auto* stmt = Decl(var);
Func("func", ast::VariableList{}, ty.void_(), ast::StatementList{stmt},
ast::DecorationList{});
Func("func", ast::VariableList{}, ty.void_(), {stmt}, ast::DecorationList{});
EXPECT_FALSE(r()->Resolve());

View File

@ -150,19 +150,6 @@ TEST_F(ResolverVarLetValidationTest, LetOfPtrConstructedWithRef) {
R"(12:34 error: cannot initialize let of type 'ptr<function, f32, read_write>' with value of type 'f32')");
}
TEST_F(ResolverVarLetValidationTest, LocalVarRedeclared) {
// var v : f32;
// var v : i32;
auto* v1 = Var("v", ty.f32(), ast::StorageClass::kNone);
auto* v2 = Var(Source{{12, 34}}, "v", ty.i32(), ast::StorageClass::kNone);
WrapInFunction(v1, v2);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
"12:34 error: redefinition of 'v'\nnote: previous definition is here");
}
TEST_F(ResolverVarLetValidationTest, LocalLetRedeclared) {
// let l : f32 = 1.;
// let l : i32 = 0;
@ -173,31 +160,7 @@ TEST_F(ResolverVarLetValidationTest, LocalLetRedeclared) {
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
"12:34 error: redefinition of 'l'\nnote: previous definition is here");
}
TEST_F(ResolverVarLetValidationTest, GlobalVarRedeclared) {
// var v : f32;
// var v : i32;
Global("v", ty.f32(), ast::StorageClass::kPrivate);
Global(Source{{12, 34}}, "v", ty.i32(), ast::StorageClass::kPrivate);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
"12:34 error: redefinition of 'v'\nnote: previous definition is here");
}
TEST_F(ResolverVarLetValidationTest, GlobalLetRedeclared) {
// let l : f32 = 0.1;
// let l : i32 = 0;
GlobalConst("l", ty.f32(), Expr(0.1f));
GlobalConst(Source{{12, 34}}, "l", ty.i32(), Expr(0));
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
"12:34 error: redefinition of 'l'\nnote: previous definition is here");
"12:34 error: redeclaration of 'l'\nnote: 'l' previously declared here");
}
TEST_F(ResolverVarLetValidationTest, GlobalVarRedeclaredAsLocal) {

View File

@ -15,6 +15,7 @@
#define SRC_SCOPE_STACK_H_
#include <unordered_map>
#include <utility>
#include <vector>
#include "src/symbol.h"
@ -45,14 +46,19 @@ class ScopeStack {
}
}
/// Assigns the value into the top most scope of the stack
/// @param symbol the symbol of the variable
/// Assigns the value into the top most scope of the stack.
/// @param symbol the symbol of the value
/// @param val the value
void Set(const Symbol& symbol, T val) { stack_.back()[symbol] = val; }
/// @returns the old value if there was an existing symbol at the top of the
/// stack, otherwise the zero initializer for type T.
T Set(const Symbol& symbol, T val) {
std::swap(val, stack_.back()[symbol]);
return val;
}
/// Retrieves a value from the stack
/// @param symbol the symbol to look for
/// @returns the variable, or the zero initializer if the value was not found
/// @returns the value, or the zero initializer if the value was not found
T Get(const Symbol& symbol) const {
for (auto iter = stack_.rbegin(); iter != stack_.rend(); ++iter) {
auto& map = *iter;

View File

@ -49,5 +49,23 @@ TEST_F(ScopeStackTest, Get_MissingSymbol) {
EXPECT_EQ(s.Get(sym), 0u);
}
TEST_F(ScopeStackTest, Set) {
ScopeStack<uint32_t> s;
Symbol a(1, ID());
Symbol b(2, ID());
EXPECT_EQ(s.Set(a, 5u), 0u);
EXPECT_EQ(s.Get(a), 5u);
EXPECT_EQ(s.Set(b, 10u), 0u);
EXPECT_EQ(s.Get(b), 10u);
EXPECT_EQ(s.Set(a, 20u), 5u);
EXPECT_EQ(s.Get(a), 20u);
EXPECT_EQ(s.Set(b, 25u), 10u);
EXPECT_EQ(s.Get(b), 25u);
}
} // namespace
} // namespace tint

View File

@ -16,13 +16,13 @@
/// If set to 1 then the transform::Manager will dump the WGSL of the program
/// before and after each transform. Helpful for debugging bad output.
#define PRINT_PROGRAM_FOR_EACH_TRANSFORM 0
#define TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM 0
#if PRINT_PROGRAM_FOR_EACH_TRANSFORM
#define IF_PRINT_PROGRAM(x) x
#else // PRINT_PROGRAM_FOR_EACH_TRANSFORM
#define IF_PRINT_PROGRAM(x)
#endif // PRINT_PROGRAM_FOR_EACH_TRANSFORM
#if TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
#define TINT_IF_PRINT_PROGRAM(x) x
#else // TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
#define TINT_IF_PRINT_PROGRAM(x)
#endif // TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
TINT_INSTANTIATE_TYPEINFO(tint::transform::Manager);
@ -33,7 +33,7 @@ Manager::Manager() = default;
Manager::~Manager() = default;
Output Manager::Run(const Program* program, const DataMap& data) {
#if PRINT_PROGRAM_FOR_EACH_TRANSFORM
#if TINT_PRINT_PROGRAM_FOR_EACH_TRANSFORM
auto print_program = [&](const char* msg, const Transform* transform) {
auto wgsl = Program::printer(program);
std::cout << "---------------------------------------------------------"
@ -52,19 +52,20 @@ Output Manager::Run(const Program* program, const DataMap& data) {
Output out;
if (!transforms_.empty()) {
for (const auto& transform : transforms_) {
IF_PRINT_PROGRAM(print_program("Input to", transform.get()));
TINT_IF_PRINT_PROGRAM(print_program("Input to", transform.get()));
auto res = transform->Run(program, data);
out.program = std::move(res.program);
out.data.Add(std::move(res.data));
program = &out.program;
if (!program->IsValid()) {
IF_PRINT_PROGRAM(print_program("Invalid output of", transform.get()));
TINT_IF_PRINT_PROGRAM(
print_program("Invalid output of", transform.get()));
return out;
}
if (transform == transforms_.back()) {
IF_PRINT_PROGRAM(print_program("Output of", transform.get()));
TINT_IF_PRINT_PROGRAM(print_program("Output of", transform.get()));
}
}
} else {

View File

@ -15,7 +15,9 @@
#ifndef SRC_UTILS_UNIQUE_VECTOR_H_
#define SRC_UTILS_UNIQUE_VECTOR_H_
#include <functional>
#include <unordered_set>
#include <utility>
#include <vector>
namespace tint {
@ -23,10 +25,14 @@ namespace utils {
/// UniqueVector is an ordered container that only contains unique items.
/// Attempting to add a duplicate is a no-op.
template <typename T, typename HASH = std::hash<T>>
template <typename T,
typename HASH = std::hash<T>,
typename EQUAL = std::equal_to<T>>
struct UniqueVector {
/// The iterator returned by begin() and end()
using ConstIterator = typename std::vector<T>::const_iterator;
/// The iterator returned by rbegin() and rend()
using ConstReverseIterator = typename std::vector<T>::const_reverse_iterator;
/// Constructor
UniqueVector() = default;
@ -43,11 +49,14 @@ struct UniqueVector {
/// add appends the item to the end of the vector, if the vector does not
/// already contain the given item.
/// @param item the item to append to the end of the vector
void add(const T& item) {
/// @returns true if the item was added, otherwise false.
bool add(const T& item) {
if (set.count(item) == 0) {
vector.emplace_back(item);
set.emplace(item);
return true;
}
return false;
}
/// @returns true if the vector contains `item`
@ -71,12 +80,27 @@ struct UniqueVector {
/// @returns an iterator to the end of the vector
ConstIterator end() const { return vector.end(); }
/// @returns an iterator to the beginning of the reversed vector
ConstReverseIterator rbegin() const { return vector.rbegin(); }
/// @returns an iterator to the end of the reversed vector
ConstReverseIterator rend() const { return vector.rend(); }
/// @returns a const reference to the internal vector
operator const std::vector<T>&() const { return vector; }
operator const std::vector<T> &() const { return vector; }
/// Removes the last element from the vector
/// @returns the popped element
T pop_back() {
auto el = std::move(vector.back());
set.erase(el);
vector.pop_back();
return el;
}
private:
std::vector<T> vector;
std::unordered_set<T, HASH> set;
std::unordered_set<T, HASH, EQUAL> set;
};
} // namespace utils

View File

@ -13,6 +13,7 @@
// limitations under the License.
#include "src/utils/unique_vector.h"
#include "src/utils/reverse.h"
#include "gtest/gtest.h"
@ -46,6 +47,10 @@ TEST(UniqueVectorTest, AddUnique) {
EXPECT_EQ(n, i);
i++;
}
for (auto n : Reverse(unique_vec)) {
i--;
EXPECT_EQ(n, i);
}
EXPECT_EQ(unique_vec[0], 0);
EXPECT_EQ(unique_vec[1], 1);
EXPECT_EQ(unique_vec[2], 2);
@ -65,6 +70,10 @@ TEST(UniqueVectorTest, AddDuplicates) {
EXPECT_EQ(n, i);
i++;
}
for (auto n : Reverse(unique_vec)) {
i--;
EXPECT_EQ(n, i);
}
EXPECT_EQ(unique_vec[0], 0);
EXPECT_EQ(unique_vec[1], 1);
EXPECT_EQ(unique_vec[2], 2);
@ -86,6 +95,32 @@ TEST(UniqueVectorTest, AsVector) {
EXPECT_EQ(n, i);
i++;
}
for (auto n : Reverse(unique_vec)) {
i--;
EXPECT_EQ(n, i);
}
}
TEST(UniqueVectorTest, PopBack) {
UniqueVector<int> unique_vec;
unique_vec.add(0);
unique_vec.add(2);
unique_vec.add(1);
EXPECT_EQ(unique_vec.pop_back(), 1);
EXPECT_EQ(unique_vec.size(), 2u);
EXPECT_EQ(unique_vec[0], 0);
EXPECT_EQ(unique_vec[1], 2);
EXPECT_EQ(unique_vec.pop_back(), 2);
EXPECT_EQ(unique_vec.size(), 1u);
EXPECT_EQ(unique_vec[0], 0);
unique_vec.add(1);
EXPECT_EQ(unique_vec.size(), 2u);
EXPECT_EQ(unique_vec[0], 0);
EXPECT_EQ(unique_vec[1], 1);
}
} // namespace

View File

@ -243,6 +243,7 @@ tint_unittests_source_set("tint_unittests_resolver_src") {
"../src/resolver/compound_statement_test.cc",
"../src/resolver/control_block_validation_test.cc",
"../src/resolver/decoration_validation_test.cc",
"../src/resolver/dependency_graph_test.cc",
"../src/resolver/entry_point_validation_test.cc",
"../src/resolver/function_validation_test.cc",
"../src/resolver/host_shareable_validation_test.cc",