ast: Add TraverseExpressions()

An ast::Expression traversal helper extracted from Resolver.

Change-Id: I88754cbc86cc12cbf8348fb36a3f038904017f3d
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/67202
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton
2021-10-21 20:38:54 +00:00
committed by Tint LUCI CQ
parent 72789de9f5
commit f164a4a723
7 changed files with 414 additions and 71 deletions

View File

@@ -45,6 +45,7 @@
#include "src/ast/storage_texture.h"
#include "src/ast/struct_block_decoration.h"
#include "src/ast/switch_statement.h"
#include "src/ast/traverse_expressions.h"
#include "src/ast/type_name.h"
#include "src/ast/unary_op_expression.h"
#include "src/ast/variable_decl_statement.h"
@@ -469,7 +470,6 @@ Resolver::VariableInfo* Resolver::Variable(const ast::Variable* var,
// Does the variable have a constructor?
if (auto* ctor = var->constructor) {
Mark(var->constructor);
if (!Expression(var->constructor)) {
return nullptr;
}
@@ -1886,7 +1886,6 @@ bool Resolver::Function(const ast::Function* func) {
continue;
}
Mark(expr);
if (!Expression(expr)) {
return false;
}
@@ -2061,7 +2060,6 @@ bool Resolver::Statement(const ast::Statement* stmt) {
return true;
}
if (auto* c = stmt->As<ast::CallStatement>()) {
Mark(c->expr);
if (!Expression(c->expr)) {
return false;
}
@@ -2138,7 +2136,6 @@ bool Resolver::IfStatement(const ast::IfStatement* stmt) {
builder_->create<sem::IfStatement>(stmt, current_compound_statement_);
builder_->Sem().Add(stmt, sem);
return Scope(sem, [&] {
Mark(stmt->condition);
if (!Expression(stmt->condition)) {
return false;
}
@@ -2175,7 +2172,6 @@ bool Resolver::ElseStatement(const ast::ElseStatement* stmt) {
builder_->Sem().Add(stmt, sem);
return Scope(sem, [&] {
if (auto* cond = stmt->condition) {
Mark(cond);
if (!Expression(cond)) {
return false;
}
@@ -2250,7 +2246,6 @@ bool Resolver::ForLoopStatement(const ast::ForLoopStatement* stmt) {
}
if (auto* condition = stmt->condition) {
Mark(condition);
if (!Expression(condition)) {
return false;
}
@@ -2279,58 +2274,14 @@ bool Resolver::ForLoopStatement(const ast::ForLoopStatement* stmt) {
});
}
bool Resolver::TraverseExpressions(const ast::Expression* root,
std::vector<const ast::Expression*>& out) {
std::vector<const ast::Expression*> to_visit;
to_visit.emplace_back(root);
auto add = [&](const ast::Expression* e) {
Mark(e);
to_visit.emplace_back(e);
};
while (!to_visit.empty()) {
auto* expr = to_visit.back();
to_visit.pop_back();
out.emplace_back(expr);
if (auto* array = expr->As<ast::ArrayAccessorExpression>()) {
add(array->array);
add(array->index);
} else if (auto* bin_op = expr->As<ast::BinaryExpression>()) {
add(bin_op->lhs);
add(bin_op->rhs);
} else if (auto* bitcast = expr->As<ast::BitcastExpression>()) {
add(bitcast->expr);
} else if (auto* call = expr->As<ast::CallExpression>()) {
for (auto* arg : call->args) {
add(arg);
}
} else if (auto* type_ctor = expr->As<ast::TypeConstructorExpression>()) {
for (auto* value : type_ctor->values) {
add(value);
}
} else if (auto* member = expr->As<ast::MemberAccessorExpression>()) {
add(member->structure);
} else if (auto* unary = expr->As<ast::UnaryOpExpression>()) {
add(unary->expr);
} else if (expr->IsAnyOf<ast::ScalarConstructorExpression,
ast::IdentifierExpression>()) {
// Leaf expression
} else {
TINT_ICE(Resolver, diagnostics_)
<< "unhandled expression type: " << expr->TypeInfo().name;
return false;
}
}
return true;
}
bool Resolver::Expression(const ast::Expression* root) {
std::vector<const ast::Expression*> sorted;
if (!TraverseExpressions(root, sorted)) {
if (!ast::TraverseExpressions<ast::TraverseOrder::RightToLeft>(
root, diagnostics_, [&](const ast::Expression* expr) {
Mark(expr);
sorted.emplace_back(expr);
return ast::TraverseAction::Descend;
})) {
return false;
}
@@ -3874,7 +3825,6 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
// sem::Array uses a size of 0 for a runtime-sized array.
uint32_t count = 0;
if (auto* count_expr = arr->count) {
Mark(count_expr);
if (!Expression(count_expr)) {
return nullptr;
}
@@ -4340,7 +4290,6 @@ bool Resolver::Return(const ast::ReturnStatement* ret) {
current_function_->return_statements.push_back(ret);
if (auto* value = ret->value) {
Mark(value);
if (!Expression(value)) {
return false;
}
@@ -4424,7 +4373,6 @@ bool Resolver::SwitchStatement(const ast::SwitchStatement* stmt) {
builder_->create<sem::SwitchStatement>(stmt, current_compound_statement_);
builder_->Sem().Add(stmt, sem);
return Scope(sem, [&] {
Mark(stmt->condition);
if (!Expression(stmt->condition)) {
return false;
}
@@ -4442,9 +4390,6 @@ bool Resolver::SwitchStatement(const ast::SwitchStatement* stmt) {
}
bool Resolver::Assignment(const ast::AssignmentStatement* a) {
Mark(a->lhs);
Mark(a->rhs);
if (!Expression(a->lhs) || !Expression(a->rhs)) {
return false;
}

View File

@@ -262,15 +262,6 @@ class Resolver {
bool UnaryOp(const ast::UnaryOpExpression*);
bool VariableDeclStatement(const ast::VariableDeclStatement*);
/// Performs a depth-first traversal of the expression nodes from `root`,
/// collecting all the visited expressions in pre-ordering (root first).
/// @param root the root expression node
/// @param out the ordered list of visited expression nodes, starting with the
/// root node, and ending with leaf nodes
/// @return true on success, false on error
bool TraverseExpressions(const ast::Expression* root,
std::vector<const ast::Expression*>& out);
// AST and Type validation methods
// Each return true on success, false on failure.
bool ValidateArray(const sem::Array* arr, const Source& source);