From 38f1e9c75c5c9f73b02bb82ec1e2e9611a554f9a Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Fri, 18 Feb 2022 22:06:33 +0000 Subject: [PATCH] resolver: Optimize type dispatch with Switch() Bug: tint:1383 Change-Id: Ia02c7ddd3e46d36134f5430e4f22df04993b2158 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/81104 Reviewed-by: Antonio Maiorano Commit-Queue: Ben Clayton Kokoro: Ben Clayton --- src/resolver/dependency_graph.cc | 341 +++++++++++++++---------------- 1 file changed, 161 insertions(+), 180 deletions(-) diff --git a/src/resolver/dependency_graph.cc b/src/resolver/dependency_graph.cc index cbb574de8f..bf82f2d0d7 100644 --- a/src/resolver/dependency_graph.cc +++ b/src/resolver/dependency_graph.cc @@ -139,34 +139,31 @@ class DependencyScanner { /// dependencies of each global. void Scan(Global* global) { TINT_SCOPED_ASSIGNMENT(current_global_, global); - - if (auto* str = global->node->As()) { - Declare(str->name, str); - for (auto* member : str->members) { - TraverseType(member->type); - } - return; - } - if (auto* alias = global->node->As()) { - Declare(alias->name, alias); - TraverseType(alias->type); - return; - } - if (auto* func = global->node->As()) { - Declare(func->symbol, func); - TraverseAttributes(func->attributes); - TraverseFunction(func); - return; - } - if (auto* var = global->node->As()) { - Declare(var->symbol, var); - TraverseType(var->type); - if (var->constructor) { - TraverseExpression(var->constructor); - } - return; - } - UnhandledNode(diagnostics_, global->node); + Switch( + global->node, + [&](const ast::Struct* str) { + Declare(str->name, str); + for (auto* member : str->members) { + TraverseType(member->type); + } + }, + [&](const ast::Alias* alias) { + Declare(alias->name, alias); + TraverseType(alias->type); + }, + [&](const ast::Function* func) { + Declare(func->symbol, func); + TraverseAttributes(func->attributes); + TraverseFunction(func); + }, + [&](const ast::Variable* var) { + Declare(var->symbol, var); + TraverseType(var->type); + if (var->constructor) { + TraverseExpression(var->constructor); + } + }, + [&](Default) { UnhandledNode(diagnostics_, global->node); }); } private: @@ -208,78 +205,72 @@ class DependencyScanner { /// Traverses the statement, performing symbol resolution and determining /// global dependencies. void TraverseStatement(const ast::Statement* stmt) { - if (stmt == nullptr) { + if (!stmt) { return; } - if (auto* b = stmt->As()) { - TraverseExpression(b->lhs); - TraverseExpression(b->rhs); - return; - } - if (auto* b = stmt->As()) { - scope_stack_.Push(); - TINT_DEFER(scope_stack_.Pop()); - TraverseStatements(b->statements); - return; - } - if (auto* r = stmt->As()) { - TraverseExpression(r->expr); - return; - } - if (auto* l = stmt->As()) { - scope_stack_.Push(); - TINT_DEFER(scope_stack_.Pop()); - TraverseStatement(l->initializer); - TraverseExpression(l->condition); - TraverseStatement(l->continuing); - TraverseStatement(l->body); - return; - } - if (auto* l = stmt->As()) { - scope_stack_.Push(); - TINT_DEFER(scope_stack_.Pop()); - TraverseStatements(l->body->statements); - TraverseStatement(l->continuing); - return; - } - if (auto* i = stmt->As()) { - TraverseExpression(i->condition); - TraverseStatement(i->body); - for (auto* e : i->else_statements) { - TraverseExpression(e->condition); - TraverseStatement(e->body); - } - return; - } - if (auto* r = stmt->As()) { - TraverseExpression(r->value); - return; - } - if (auto* s = stmt->As()) { - TraverseExpression(s->condition); - for (auto* c : s->body) { - for (auto* sel : c->selectors) { - TraverseExpression(sel); - } - TraverseStatement(c->body); - } - return; - } - if (auto* v = stmt->As()) { - if (auto* shadows = scope_stack_.Get(v->variable->symbol)) { - graph_.shadows.emplace(v->variable, shadows); - } - TraverseType(v->variable->type); - TraverseExpression(v->variable->constructor); - Declare(v->variable->symbol, v->variable); - return; - } - if (stmt->IsAnyOf()) { - return; - } - - UnhandledNode(diagnostics_, stmt); + Switch( + stmt, // + [&](const ast::AssignmentStatement* a) { + TraverseExpression(a->lhs); + TraverseExpression(a->rhs); + }, + [&](const ast::BlockStatement* b) { + scope_stack_.Push(); + TINT_DEFER(scope_stack_.Pop()); + TraverseStatements(b->statements); + }, + [&](const ast::CallStatement* r) { // + TraverseExpression(r->expr); + }, + [&](const ast::ForLoopStatement* l) { + scope_stack_.Push(); + TINT_DEFER(scope_stack_.Pop()); + TraverseStatement(l->initializer); + TraverseExpression(l->condition); + TraverseStatement(l->continuing); + TraverseStatement(l->body); + }, + [&](const ast::LoopStatement* l) { + scope_stack_.Push(); + TINT_DEFER(scope_stack_.Pop()); + TraverseStatements(l->body->statements); + TraverseStatement(l->continuing); + }, + [&](const ast::IfStatement* i) { + TraverseExpression(i->condition); + TraverseStatement(i->body); + for (auto* e : i->else_statements) { + TraverseExpression(e->condition); + TraverseStatement(e->body); + } + }, + [&](const ast::ReturnStatement* r) { // + TraverseExpression(r->value); + }, + [&](const ast::SwitchStatement* s) { + TraverseExpression(s->condition); + for (auto* c : s->body) { + for (auto* sel : c->selectors) { + TraverseExpression(sel); + } + TraverseStatement(c->body); + } + }, + [&](const ast::VariableDeclStatement* v) { + if (auto* shadows = scope_stack_.Get(v->variable->symbol)) { + graph_.shadows.emplace(v->variable, shadows); + } + TraverseType(v->variable->type); + TraverseExpression(v->variable->constructor); + Declare(v->variable->symbol, v->variable); + }, + [&](Default) { + if (!stmt->IsAnyOf()) { + UnhandledNode(diagnostics_, stmt); + } + }); } /// Adds the symbol definition to the current scope, raising an error if two @@ -302,21 +293,23 @@ class DependencyScanner { } ast::TraverseExpressions( root, diagnostics_, [&](const ast::Expression* expr) { - if (auto* ident = expr->As()) { - AddDependency(ident, ident->symbol, "identifier", "references"); - } - if (auto* call = expr->As()) { - if (call->target.name) { - AddDependency(call->target.name, call->target.name->symbol, - "function", "calls"); - } - if (call->target.type) { - TraverseType(call->target.type); - } - } - if (auto* cast = expr->As()) { - TraverseType(cast->type); - } + Switch( + expr, + [&](const ast::IdentifierExpression* ident) { + AddDependency(ident, ident->symbol, "identifier", "references"); + }, + [&](const ast::CallExpression* call) { + if (call->target.name) { + AddDependency(call->target.name, call->target.name->symbol, + "function", "calls"); + } + if (call->target.type) { + TraverseType(call->target.type); + } + }, + [&](const ast::BitcastExpression* cast) { + TraverseType(cast->type); + }); return ast::TraverseAction::Descend; }); } @@ -324,50 +317,44 @@ class DependencyScanner { /// Traverses the type node, performing symbol resolution and determining /// global dependencies. void TraverseType(const ast::Type* ty) { - if (ty == nullptr) { + if (!ty) { return; } - if (auto* arr = ty->As()) { - TraverseType(arr->type); - TraverseExpression(arr->count); - return; - } - if (auto* atomic = ty->As()) { - TraverseType(atomic->type); - return; - } - if (auto* mat = ty->As()) { - TraverseType(mat->type); - return; - } - if (auto* ptr = ty->As()) { - TraverseType(ptr->type); - return; - } - if (auto* tn = ty->As()) { - AddDependency(tn, tn->name, "type", "references"); - return; - } - if (auto* vec = ty->As()) { - TraverseType(vec->type); - return; - } - if (auto* tex = ty->As()) { - TraverseType(tex->type); - return; - } - if (auto* tex = ty->As()) { - TraverseType(tex->type); - return; - } - if (ty->IsAnyOf()) { - return; - } - - UnhandledNode(diagnostics_, ty); + Switch( + ty, // + [&](const ast::Array* arr) { + TraverseType(arr->type); // + TraverseExpression(arr->count); + }, + [&](const ast::Atomic* atomic) { // + TraverseType(atomic->type); + }, + [&](const ast::Matrix* mat) { // + TraverseType(mat->type); + }, + [&](const ast::Pointer* ptr) { // + TraverseType(ptr->type); + }, + [&](const ast::TypeName* tn) { // + AddDependency(tn, tn->name, "type", "references"); + }, + [&](const ast::Vector* vec) { // + TraverseType(vec->type); + }, + [&](const ast::SampledTexture* tex) { // + TraverseType(tex->type); + }, + [&](const ast::MultisampledTexture* tex) { // + TraverseType(tex->type); + }, + [&](Default) { + if (!ty->IsAnyOf()) { + UnhandledNode(diagnostics_, ty); + } + }); } /// Traverses the attribute list, performing symbol resolution and @@ -490,17 +477,15 @@ struct DependencyAnalysis { /// @note will raise an ICE if the node is not a type, function or variable /// declaration Symbol SymbolOf(const ast::Node* node) const { - if (auto* td = node->As()) { - return td->name; - } - if (auto* func = node->As()) { - return func->symbol; - } - if (auto* var = node->As()) { - return var->symbol; - } - UnhandledNode(diagnostics_, node); - return {}; + return Switch( + node, // + [&](const ast::TypeDecl* td) { return td->name; }, + [&](const ast::Function* func) { return func->symbol; }, + [&](const ast::Variable* var) { return var->symbol; }, + [&](Default) { + UnhandledNode(diagnostics_, node); + return Symbol{}; + }); } /// @param node the ast::Node of the global declaration @@ -516,20 +501,16 @@ struct DependencyAnalysis { /// @note will raise an ICE if the node is not a type, function or variable /// declaration std::string KindOf(const ast::Node* node) { - if (node->Is()) { - return "struct"; - } - if (node->Is()) { - return "alias"; - } - if (node->Is()) { - return "function"; - } - if (auto* var = node->As()) { - return var->is_const ? "let" : "var"; - } - UnhandledNode(diagnostics_, node); - return {}; + return Switch( + node, // + [&](const ast::Struct*) { return "struct"; }, + [&](const ast::Alias*) { return "alias"; }, + [&](const ast::Function*) { return "function"; }, + [&](const ast::Variable* var) { return var->is_const ? "let" : "var"; }, + [&](Default) { + UnhandledNode(diagnostics_, node); + return ""; + }); } /// Traverses `module`, collecting all the global declarations and populating