mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-17 00:47:13 +00:00
Resolver: Traverse expressions without recursion
This CL changes the way that the resolver traverses expressions to avoid stack overflows for deeply nested expressions. Instead of having the expression resolver methods call back into Expression(), add a TraverseExpressions() method that collects all the expression nodes with a simple DFS. This currently only changes the way that Expressions are traversed. We may need to do the same for statements. Bug: chromium:1246375 Change-Id: Ie81905da1b790b6dd1df9f1ac42e06593d397c21 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/63700 Auto-Submit: Ben Clayton <bclayton@google.com> Reviewed-by: David Neto <dneto@google.com> Commit-Queue: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
committed by
Tint LUCI CQ
parent
be514a1efb
commit
b7bcbf0d20
@@ -72,6 +72,7 @@
|
||||
#include "src/utils/defer.h"
|
||||
#include "src/utils/get_or_create.h"
|
||||
#include "src/utils/math.h"
|
||||
#include "src/utils/reverse.h"
|
||||
#include "src/utils/scoped_assignment.h"
|
||||
|
||||
namespace tint {
|
||||
@@ -2244,60 +2245,94 @@ bool Resolver::ForLoopStatement(ast::ForLoopStatement* stmt) {
|
||||
});
|
||||
}
|
||||
|
||||
bool Resolver::Expressions(const ast::ExpressionList& list) {
|
||||
for (auto* expr : list) {
|
||||
Mark(expr);
|
||||
if (!Expression(expr)) {
|
||||
bool Resolver::TraverseExpressions(ast::Expression* root,
|
||||
std::vector<ast::Expression*>& out) {
|
||||
std::vector<ast::Expression*> to_visit;
|
||||
to_visit.emplace_back(root);
|
||||
|
||||
auto add = [&](ast::Expression* e) {
|
||||
Mark(e);
|
||||
to_visit.emplace_back(e);
|
||||
};
|
||||
|
||||
while (!to_visit.empty()) {
|
||||
auto* expr = to_visit.back();
|
||||
to_visit.pop_back();
|
||||
|
||||
out.emplace_back(expr);
|
||||
|
||||
if (auto* array = expr->As<ast::ArrayAccessorExpression>()) {
|
||||
add(array->array());
|
||||
add(array->idx_expr());
|
||||
} else if (auto* bin_op = expr->As<ast::BinaryExpression>()) {
|
||||
add(bin_op->lhs());
|
||||
add(bin_op->rhs());
|
||||
} else if (auto* bitcast = expr->As<ast::BitcastExpression>()) {
|
||||
add(bitcast->expr());
|
||||
} else if (auto* call = expr->As<ast::CallExpression>()) {
|
||||
for (auto* arg : call->params()) {
|
||||
add(arg);
|
||||
}
|
||||
} else if (auto* type_ctor = expr->As<ast::TypeConstructorExpression>()) {
|
||||
for (auto* value : type_ctor->values()) {
|
||||
add(value);
|
||||
}
|
||||
} else if (auto* member = expr->As<ast::MemberAccessorExpression>()) {
|
||||
add(member->structure());
|
||||
} else if (auto* unary = expr->As<ast::UnaryOpExpression>()) {
|
||||
add(unary->expr());
|
||||
} else if (expr->IsAnyOf<ast::ScalarConstructorExpression,
|
||||
ast::IdentifierExpression>()) {
|
||||
// Leaf expression
|
||||
} else {
|
||||
TINT_ICE(Resolver, diagnostics_)
|
||||
<< "unhandled expression type: " << expr->TypeInfo().name;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Resolver::Expression(ast::Expression* expr) {
|
||||
if (TypeOf(expr)) {
|
||||
return true; // Already resolved
|
||||
}
|
||||
|
||||
bool ok = false;
|
||||
if (auto* array = expr->As<ast::ArrayAccessorExpression>()) {
|
||||
ok = ArrayAccessor(array);
|
||||
} else if (auto* bin_op = expr->As<ast::BinaryExpression>()) {
|
||||
ok = Binary(bin_op);
|
||||
} else if (auto* bitcast = expr->As<ast::BitcastExpression>()) {
|
||||
ok = Bitcast(bitcast);
|
||||
} else if (auto* call = expr->As<ast::CallExpression>()) {
|
||||
ok = Call(call);
|
||||
} else if (auto* ctor = expr->As<ast::ConstructorExpression>()) {
|
||||
ok = Constructor(ctor);
|
||||
} else if (auto* ident = expr->As<ast::IdentifierExpression>()) {
|
||||
ok = Identifier(ident);
|
||||
} else if (auto* member = expr->As<ast::MemberAccessorExpression>()) {
|
||||
ok = MemberAccessor(member);
|
||||
} else if (auto* unary = expr->As<ast::UnaryOpExpression>()) {
|
||||
ok = UnaryOp(unary);
|
||||
} else {
|
||||
AddError("unknown expression for type determination", expr->source());
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
bool Resolver::Expression(ast::Expression* root) {
|
||||
std::vector<ast::Expression*> sorted;
|
||||
if (!TraverseExpressions(root, sorted)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto* expr : utils::Reverse(sorted)) {
|
||||
bool ok = false;
|
||||
if (auto* array = expr->As<ast::ArrayAccessorExpression>()) {
|
||||
ok = ArrayAccessor(array);
|
||||
} else if (auto* bin_op = expr->As<ast::BinaryExpression>()) {
|
||||
ok = Binary(bin_op);
|
||||
} else if (auto* bitcast = expr->As<ast::BitcastExpression>()) {
|
||||
ok = Bitcast(bitcast);
|
||||
} else if (auto* call = expr->As<ast::CallExpression>()) {
|
||||
ok = Call(call);
|
||||
} else if (auto* ctor = expr->As<ast::ConstructorExpression>()) {
|
||||
ok = Constructor(ctor);
|
||||
} else if (auto* ident = expr->As<ast::IdentifierExpression>()) {
|
||||
ok = Identifier(ident);
|
||||
} else if (auto* member = expr->As<ast::MemberAccessorExpression>()) {
|
||||
ok = MemberAccessor(member);
|
||||
} else if (auto* unary = expr->As<ast::UnaryOpExpression>()) {
|
||||
ok = UnaryOp(unary);
|
||||
} else {
|
||||
TINT_ICE(Resolver, diagnostics_)
|
||||
<< "unhandled expression type: " << expr->TypeInfo().name;
|
||||
return false;
|
||||
}
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Resolver::ArrayAccessor(ast::ArrayAccessorExpression* expr) {
|
||||
Mark(expr->array());
|
||||
if (!Expression(expr->array())) {
|
||||
return false;
|
||||
}
|
||||
auto* idx = expr->idx_expr();
|
||||
Mark(idx);
|
||||
if (!Expression(idx)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* res = TypeOf(expr->array());
|
||||
auto* parent_type = res->UnwrapRef();
|
||||
const sem::Type* ret = nullptr;
|
||||
@@ -2345,10 +2380,6 @@ bool Resolver::ArrayAccessor(ast::ArrayAccessorExpression* expr) {
|
||||
}
|
||||
|
||||
bool Resolver::Bitcast(ast::BitcastExpression* expr) {
|
||||
Mark(expr->expr());
|
||||
if (!Expression(expr->expr())) {
|
||||
return false;
|
||||
}
|
||||
auto* ty = Type(expr->type());
|
||||
if (!ty) {
|
||||
return false;
|
||||
@@ -2362,10 +2393,6 @@ bool Resolver::Bitcast(ast::BitcastExpression* expr) {
|
||||
}
|
||||
|
||||
bool Resolver::Call(ast::CallExpression* call) {
|
||||
if (!Expressions(call->params())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Mark(call->func());
|
||||
auto* ident = call->func();
|
||||
auto name = builder_->Symbols().NameFor(ident->symbol());
|
||||
@@ -2641,13 +2668,6 @@ bool Resolver::ValidateFunctionCall(const ast::CallExpression* call,
|
||||
|
||||
bool Resolver::Constructor(ast::ConstructorExpression* expr) {
|
||||
if (auto* type_ctor = expr->As<ast::TypeConstructorExpression>()) {
|
||||
for (auto* value : type_ctor->values()) {
|
||||
Mark(value);
|
||||
if (!Expression(value)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
auto* type = Type(type_ctor->type());
|
||||
if (!type) {
|
||||
return false;
|
||||
@@ -2994,11 +3014,6 @@ bool Resolver::Identifier(ast::IdentifierExpression* expr) {
|
||||
}
|
||||
|
||||
bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) {
|
||||
Mark(expr->structure());
|
||||
if (!Expression(expr->structure())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* structure = TypeOf(expr->structure());
|
||||
auto* storage_type = structure->UnwrapRef();
|
||||
|
||||
@@ -3118,13 +3133,6 @@ bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) {
|
||||
}
|
||||
|
||||
bool Resolver::Binary(ast::BinaryExpression* expr) {
|
||||
Mark(expr->lhs());
|
||||
Mark(expr->rhs());
|
||||
|
||||
if (!Expression(expr->lhs()) || !Expression(expr->rhs())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
using Bool = sem::Bool;
|
||||
using F32 = sem::F32;
|
||||
using I32 = sem::I32;
|
||||
@@ -3330,13 +3338,6 @@ bool Resolver::Binary(ast::BinaryExpression* expr) {
|
||||
}
|
||||
|
||||
bool Resolver::UnaryOp(ast::UnaryOpExpression* unary) {
|
||||
Mark(unary->expr());
|
||||
|
||||
// Resolve the inner expression
|
||||
if (!Expression(unary->expr())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* expr_type = TypeOf(unary->expr());
|
||||
if (!expr_type) {
|
||||
return false;
|
||||
|
||||
@@ -244,7 +244,6 @@ class Resolver {
|
||||
bool Constructor(ast::ConstructorExpression*);
|
||||
bool ElseStatement(ast::ElseStatement*);
|
||||
bool Expression(ast::Expression*);
|
||||
bool Expressions(const ast::ExpressionList&);
|
||||
bool ForLoopStatement(ast::ForLoopStatement*);
|
||||
bool Function(ast::Function*);
|
||||
bool FunctionCall(const ast::CallExpression* call);
|
||||
@@ -262,6 +261,15 @@ class Resolver {
|
||||
bool UnaryOp(ast::UnaryOpExpression*);
|
||||
bool VariableDeclStatement(const ast::VariableDeclStatement*);
|
||||
|
||||
/// Performs a depth-first traversal of the expression nodes from `root`,
|
||||
/// collecting all the visited expressions in pre-ordering (root first).
|
||||
/// @param root the root expression node
|
||||
/// @param out the ordered list of visited expression nodes, starting with the
|
||||
/// root node, and ending with leaf nodes
|
||||
/// @return true on success, false on error
|
||||
bool TraverseExpressions(ast::Expression* root,
|
||||
std::vector<ast::Expression*>& out);
|
||||
|
||||
// AST and Type validation methods
|
||||
// Each return true on success, false on failure.
|
||||
bool ValidateArray(const sem::Array* arr, const Source& source);
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "src/resolver/resolver.h"
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
#include "gtest/gtest-spi.h"
|
||||
#include "src/ast/assignment_statement.h"
|
||||
#include "src/ast/bitcast_expression.h"
|
||||
#include "src/ast/break_statement.h"
|
||||
@@ -56,10 +57,9 @@ class FakeStmt : public ast::Statement {
|
||||
}
|
||||
};
|
||||
|
||||
class FakeExpr : public ast::Expression {
|
||||
class FakeExpr : public Castable<FakeExpr, ast::Expression> {
|
||||
public:
|
||||
FakeExpr(ProgramID program_id, Source source)
|
||||
: ast::Expression(program_id, source) {}
|
||||
FakeExpr(ProgramID program_id, Source source) : Base(program_id, source) {}
|
||||
FakeExpr* Clone(CloneContext*) const override { return nullptr; }
|
||||
void to_str(const sem::Info&, std::ostream&, size_t) const override {}
|
||||
};
|
||||
@@ -158,14 +158,15 @@ TEST_F(ResolverValidationTest, Stmt_Else_NonBool) {
|
||||
"12:34 error: else statement condition must be bool, got f32");
|
||||
}
|
||||
|
||||
TEST_F(ResolverValidationTest, Expr_Error_Unknown) {
|
||||
auto* e = create<FakeExpr>(Source{Source::Location{2, 30}});
|
||||
WrapInFunction(e);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
|
||||
EXPECT_EQ(r()->error(),
|
||||
"2:30 error: unknown expression for type determination");
|
||||
TEST_F(ResolverValidationTest, Expr_ErrUnknownExprType) {
|
||||
EXPECT_FATAL_FAILURE(
|
||||
{
|
||||
ProgramBuilder b;
|
||||
b.WrapInFunction(b.create<FakeExpr>());
|
||||
Resolver(&b).Resolve();
|
||||
},
|
||||
"internal compiler error: unhandled expression type: "
|
||||
"tint::resolver::FakeExpr");
|
||||
}
|
||||
|
||||
TEST_F(ResolverValidationTest, Expr_DontCall_Function) {
|
||||
@@ -925,3 +926,5 @@ TEST_F(ResolverTest, Expr_Constructor_Cast_Pointer) {
|
||||
} // namespace
|
||||
} // namespace resolver
|
||||
} // namespace tint
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::resolver::FakeExpr);
|
||||
|
||||
Reference in New Issue
Block a user