resolver: Optimize type dispatch with Switch()

Bug: tint:1383
Change-Id: Ia02c7ddd3e46d36134f5430e4f22df04993b2158
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/81104
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Kokoro: Ben Clayton <bclayton@chromium.org>
This commit is contained in:
Ben Clayton 2022-02-18 22:06:33 +00:00 committed by Tint LUCI CQ
parent 473b6087ac
commit 38f1e9c75c
1 changed files with 161 additions and 180 deletions

View File

@ -139,34 +139,31 @@ class DependencyScanner {
/// dependencies of each global. /// dependencies of each global.
void Scan(Global* global) { void Scan(Global* global) {
TINT_SCOPED_ASSIGNMENT(current_global_, global); TINT_SCOPED_ASSIGNMENT(current_global_, global);
Switch(
if (auto* str = global->node->As<ast::Struct>()) { global->node,
[&](const ast::Struct* str) {
Declare(str->name, str); Declare(str->name, str);
for (auto* member : str->members) { for (auto* member : str->members) {
TraverseType(member->type); TraverseType(member->type);
} }
return; },
} [&](const ast::Alias* alias) {
if (auto* alias = global->node->As<ast::Alias>()) {
Declare(alias->name, alias); Declare(alias->name, alias);
TraverseType(alias->type); TraverseType(alias->type);
return; },
} [&](const ast::Function* func) {
if (auto* func = global->node->As<ast::Function>()) {
Declare(func->symbol, func); Declare(func->symbol, func);
TraverseAttributes(func->attributes); TraverseAttributes(func->attributes);
TraverseFunction(func); TraverseFunction(func);
return; },
} [&](const ast::Variable* var) {
if (auto* var = global->node->As<ast::Variable>()) {
Declare(var->symbol, var); Declare(var->symbol, var);
TraverseType(var->type); TraverseType(var->type);
if (var->constructor) { if (var->constructor) {
TraverseExpression(var->constructor); TraverseExpression(var->constructor);
} }
return; },
} [&](Default) { UnhandledNode(diagnostics_, global->node); });
UnhandledNode(diagnostics_, global->node);
} }
private: private:
@ -208,54 +205,49 @@ class DependencyScanner {
/// Traverses the statement, performing symbol resolution and determining /// Traverses the statement, performing symbol resolution and determining
/// global dependencies. /// global dependencies.
void TraverseStatement(const ast::Statement* stmt) { void TraverseStatement(const ast::Statement* stmt) {
if (stmt == nullptr) { if (!stmt) {
return; return;
} }
if (auto* b = stmt->As<ast::AssignmentStatement>()) { Switch(
TraverseExpression(b->lhs); stmt, //
TraverseExpression(b->rhs); [&](const ast::AssignmentStatement* a) {
return; TraverseExpression(a->lhs);
} TraverseExpression(a->rhs);
if (auto* b = stmt->As<ast::BlockStatement>()) { },
[&](const ast::BlockStatement* b) {
scope_stack_.Push(); scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop()); TINT_DEFER(scope_stack_.Pop());
TraverseStatements(b->statements); TraverseStatements(b->statements);
return; },
} [&](const ast::CallStatement* r) { //
if (auto* r = stmt->As<ast::CallStatement>()) {
TraverseExpression(r->expr); TraverseExpression(r->expr);
return; },
} [&](const ast::ForLoopStatement* l) {
if (auto* l = stmt->As<ast::ForLoopStatement>()) {
scope_stack_.Push(); scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop()); TINT_DEFER(scope_stack_.Pop());
TraverseStatement(l->initializer); TraverseStatement(l->initializer);
TraverseExpression(l->condition); TraverseExpression(l->condition);
TraverseStatement(l->continuing); TraverseStatement(l->continuing);
TraverseStatement(l->body); TraverseStatement(l->body);
return; },
} [&](const ast::LoopStatement* l) {
if (auto* l = stmt->As<ast::LoopStatement>()) {
scope_stack_.Push(); scope_stack_.Push();
TINT_DEFER(scope_stack_.Pop()); TINT_DEFER(scope_stack_.Pop());
TraverseStatements(l->body->statements); TraverseStatements(l->body->statements);
TraverseStatement(l->continuing); TraverseStatement(l->continuing);
return; },
} [&](const ast::IfStatement* i) {
if (auto* i = stmt->As<ast::IfStatement>()) {
TraverseExpression(i->condition); TraverseExpression(i->condition);
TraverseStatement(i->body); TraverseStatement(i->body);
for (auto* e : i->else_statements) { for (auto* e : i->else_statements) {
TraverseExpression(e->condition); TraverseExpression(e->condition);
TraverseStatement(e->body); TraverseStatement(e->body);
} }
return; },
} [&](const ast::ReturnStatement* r) { //
if (auto* r = stmt->As<ast::ReturnStatement>()) {
TraverseExpression(r->value); TraverseExpression(r->value);
return; },
} [&](const ast::SwitchStatement* s) {
if (auto* s = stmt->As<ast::SwitchStatement>()) {
TraverseExpression(s->condition); TraverseExpression(s->condition);
for (auto* c : s->body) { for (auto* c : s->body) {
for (auto* sel : c->selectors) { for (auto* sel : c->selectors) {
@ -263,24 +255,23 @@ class DependencyScanner {
} }
TraverseStatement(c->body); TraverseStatement(c->body);
} }
return; },
} [&](const ast::VariableDeclStatement* v) {
if (auto* v = stmt->As<ast::VariableDeclStatement>()) {
if (auto* shadows = scope_stack_.Get(v->variable->symbol)) { if (auto* shadows = scope_stack_.Get(v->variable->symbol)) {
graph_.shadows.emplace(v->variable, shadows); graph_.shadows.emplace(v->variable, shadows);
} }
TraverseType(v->variable->type); TraverseType(v->variable->type);
TraverseExpression(v->variable->constructor); TraverseExpression(v->variable->constructor);
Declare(v->variable->symbol, v->variable); Declare(v->variable->symbol, v->variable);
return; },
} [&](Default) {
if (stmt->IsAnyOf<ast::BreakStatement, ast::ContinueStatement, if (!stmt->IsAnyOf<ast::BreakStatement, ast::ContinueStatement,
ast::DiscardStatement, ast::FallthroughStatement>()) { ast::DiscardStatement,
return; ast::FallthroughStatement>()) {
}
UnhandledNode(diagnostics_, stmt); UnhandledNode(diagnostics_, stmt);
} }
});
}
/// Adds the symbol definition to the current scope, raising an error if two /// Adds the symbol definition to the current scope, raising an error if two
/// symbols collide within the same scope. /// symbols collide within the same scope.
@ -302,10 +293,12 @@ class DependencyScanner {
} }
ast::TraverseExpressions( ast::TraverseExpressions(
root, diagnostics_, [&](const ast::Expression* expr) { root, diagnostics_, [&](const ast::Expression* expr) {
if (auto* ident = expr->As<ast::IdentifierExpression>()) { Switch(
expr,
[&](const ast::IdentifierExpression* ident) {
AddDependency(ident, ident->symbol, "identifier", "references"); AddDependency(ident, ident->symbol, "identifier", "references");
} },
if (auto* call = expr->As<ast::CallExpression>()) { [&](const ast::CallExpression* call) {
if (call->target.name) { if (call->target.name) {
AddDependency(call->target.name, call->target.name->symbol, AddDependency(call->target.name, call->target.name->symbol,
"function", "calls"); "function", "calls");
@ -313,10 +306,10 @@ class DependencyScanner {
if (call->target.type) { if (call->target.type) {
TraverseType(call->target.type); TraverseType(call->target.type);
} }
} },
if (auto* cast = expr->As<ast::BitcastExpression>()) { [&](const ast::BitcastExpression* cast) {
TraverseType(cast->type); TraverseType(cast->type);
} });
return ast::TraverseAction::Descend; return ast::TraverseAction::Descend;
}); });
} }
@ -324,51 +317,45 @@ class DependencyScanner {
/// Traverses the type node, performing symbol resolution and determining /// Traverses the type node, performing symbol resolution and determining
/// global dependencies. /// global dependencies.
void TraverseType(const ast::Type* ty) { void TraverseType(const ast::Type* ty) {
if (ty == nullptr) { if (!ty) {
return; return;
} }
if (auto* arr = ty->As<ast::Array>()) { Switch(
TraverseType(arr->type); ty, //
[&](const ast::Array* arr) {
TraverseType(arr->type); //
TraverseExpression(arr->count); TraverseExpression(arr->count);
return; },
} [&](const ast::Atomic* atomic) { //
if (auto* atomic = ty->As<ast::Atomic>()) {
TraverseType(atomic->type); TraverseType(atomic->type);
return; },
} [&](const ast::Matrix* mat) { //
if (auto* mat = ty->As<ast::Matrix>()) {
TraverseType(mat->type); TraverseType(mat->type);
return; },
} [&](const ast::Pointer* ptr) { //
if (auto* ptr = ty->As<ast::Pointer>()) {
TraverseType(ptr->type); TraverseType(ptr->type);
return; },
} [&](const ast::TypeName* tn) { //
if (auto* tn = ty->As<ast::TypeName>()) {
AddDependency(tn, tn->name, "type", "references"); AddDependency(tn, tn->name, "type", "references");
return; },
} [&](const ast::Vector* vec) { //
if (auto* vec = ty->As<ast::Vector>()) {
TraverseType(vec->type); TraverseType(vec->type);
return; },
} [&](const ast::SampledTexture* tex) { //
if (auto* tex = ty->As<ast::SampledTexture>()) {
TraverseType(tex->type); TraverseType(tex->type);
return; },
} [&](const ast::MultisampledTexture* tex) { //
if (auto* tex = ty->As<ast::MultisampledTexture>()) {
TraverseType(tex->type); TraverseType(tex->type);
return; },
} [&](Default) {
if (ty->IsAnyOf<ast::Void, ast::Bool, ast::I32, ast::U32, ast::F32, if (!ty->IsAnyOf<ast::Void, ast::Bool, ast::I32, ast::U32, ast::F32,
ast::DepthTexture, ast::DepthMultisampledTexture, ast::DepthTexture, ast::DepthMultisampledTexture,
ast::StorageTexture, ast::ExternalTexture, ast::StorageTexture, ast::ExternalTexture,
ast::Sampler>()) { ast::Sampler>()) {
return;
}
UnhandledNode(diagnostics_, ty); UnhandledNode(diagnostics_, ty);
} }
});
}
/// Traverses the attribute list, performing symbol resolution and /// Traverses the attribute list, performing symbol resolution and
/// determining global dependencies. /// determining global dependencies.
@ -490,17 +477,15 @@ struct DependencyAnalysis {
/// @note will raise an ICE if the node is not a type, function or variable /// @note will raise an ICE if the node is not a type, function or variable
/// declaration /// declaration
Symbol SymbolOf(const ast::Node* node) const { Symbol SymbolOf(const ast::Node* node) const {
if (auto* td = node->As<ast::TypeDecl>()) { return Switch(
return td->name; node, //
} [&](const ast::TypeDecl* td) { return td->name; },
if (auto* func = node->As<ast::Function>()) { [&](const ast::Function* func) { return func->symbol; },
return func->symbol; [&](const ast::Variable* var) { return var->symbol; },
} [&](Default) {
if (auto* var = node->As<ast::Variable>()) {
return var->symbol;
}
UnhandledNode(diagnostics_, node); UnhandledNode(diagnostics_, node);
return {}; return Symbol{};
});
} }
/// @param node the ast::Node of the global declaration /// @param node the ast::Node of the global declaration
@ -516,20 +501,16 @@ struct DependencyAnalysis {
/// @note will raise an ICE if the node is not a type, function or variable /// @note will raise an ICE if the node is not a type, function or variable
/// declaration /// declaration
std::string KindOf(const ast::Node* node) { std::string KindOf(const ast::Node* node) {
if (node->Is<ast::Struct>()) { return Switch(
return "struct"; node, //
} [&](const ast::Struct*) { return "struct"; },
if (node->Is<ast::Alias>()) { [&](const ast::Alias*) { return "alias"; },
return "alias"; [&](const ast::Function*) { return "function"; },
} [&](const ast::Variable* var) { return var->is_const ? "let" : "var"; },
if (node->Is<ast::Function>()) { [&](Default) {
return "function";
}
if (auto* var = node->As<ast::Variable>()) {
return var->is_const ? "let" : "var";
}
UnhandledNode(diagnostics_, node); UnhandledNode(diagnostics_, node);
return {}; return "<error>";
});
} }
/// Traverses `module`, collecting all the global declarations and populating /// Traverses `module`, collecting all the global declarations and populating