Resolver: Traverse expressions without recursion
This CL changes the way that the resolver traverses expressions to avoid stack overflows for deeply nested expressions. Instead of having the expression resolver methods call back into Expression(), add a TraverseExpressions() method that collects all the expression nodes with a simple DFS. This currently only changes the way that Expressions are traversed. We may need to do the same for statements. Bug: chromium:1246375 Change-Id: Ie81905da1b790b6dd1df9f1ac42e06593d397c21 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/63700 Auto-Submit: Ben Clayton <bclayton@google.com> Reviewed-by: David Neto <dneto@google.com> Commit-Queue: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
parent
be514a1efb
commit
b7bcbf0d20
|
@ -713,6 +713,7 @@ if(${TINT_BUILD_TESTS})
|
|||
utils/io/command_test.cc
|
||||
utils/io/tmpfile_test.cc
|
||||
utils/math_test.cc
|
||||
utils/reverse_test.cc
|
||||
utils/scoped_assignment_test.cc
|
||||
utils/unique_vector_test.cc
|
||||
writer/append_vector_test.cc
|
||||
|
|
|
@ -72,6 +72,7 @@
|
|||
#include "src/utils/defer.h"
|
||||
#include "src/utils/get_or_create.h"
|
||||
#include "src/utils/math.h"
|
||||
#include "src/utils/reverse.h"
|
||||
#include "src/utils/scoped_assignment.h"
|
||||
|
||||
namespace tint {
|
||||
|
@ -2244,60 +2245,94 @@ bool Resolver::ForLoopStatement(ast::ForLoopStatement* stmt) {
|
|||
});
|
||||
}
|
||||
|
||||
bool Resolver::Expressions(const ast::ExpressionList& list) {
|
||||
for (auto* expr : list) {
|
||||
Mark(expr);
|
||||
if (!Expression(expr)) {
|
||||
bool Resolver::TraverseExpressions(ast::Expression* root,
|
||||
std::vector<ast::Expression*>& out) {
|
||||
std::vector<ast::Expression*> to_visit;
|
||||
to_visit.emplace_back(root);
|
||||
|
||||
auto add = [&](ast::Expression* e) {
|
||||
Mark(e);
|
||||
to_visit.emplace_back(e);
|
||||
};
|
||||
|
||||
while (!to_visit.empty()) {
|
||||
auto* expr = to_visit.back();
|
||||
to_visit.pop_back();
|
||||
|
||||
out.emplace_back(expr);
|
||||
|
||||
if (auto* array = expr->As<ast::ArrayAccessorExpression>()) {
|
||||
add(array->array());
|
||||
add(array->idx_expr());
|
||||
} else if (auto* bin_op = expr->As<ast::BinaryExpression>()) {
|
||||
add(bin_op->lhs());
|
||||
add(bin_op->rhs());
|
||||
} else if (auto* bitcast = expr->As<ast::BitcastExpression>()) {
|
||||
add(bitcast->expr());
|
||||
} else if (auto* call = expr->As<ast::CallExpression>()) {
|
||||
for (auto* arg : call->params()) {
|
||||
add(arg);
|
||||
}
|
||||
} else if (auto* type_ctor = expr->As<ast::TypeConstructorExpression>()) {
|
||||
for (auto* value : type_ctor->values()) {
|
||||
add(value);
|
||||
}
|
||||
} else if (auto* member = expr->As<ast::MemberAccessorExpression>()) {
|
||||
add(member->structure());
|
||||
} else if (auto* unary = expr->As<ast::UnaryOpExpression>()) {
|
||||
add(unary->expr());
|
||||
} else if (expr->IsAnyOf<ast::ScalarConstructorExpression,
|
||||
ast::IdentifierExpression>()) {
|
||||
// Leaf expression
|
||||
} else {
|
||||
TINT_ICE(Resolver, diagnostics_)
|
||||
<< "unhandled expression type: " << expr->TypeInfo().name;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Resolver::Expression(ast::Expression* expr) {
|
||||
if (TypeOf(expr)) {
|
||||
return true; // Already resolved
|
||||
}
|
||||
|
||||
bool ok = false;
|
||||
if (auto* array = expr->As<ast::ArrayAccessorExpression>()) {
|
||||
ok = ArrayAccessor(array);
|
||||
} else if (auto* bin_op = expr->As<ast::BinaryExpression>()) {
|
||||
ok = Binary(bin_op);
|
||||
} else if (auto* bitcast = expr->As<ast::BitcastExpression>()) {
|
||||
ok = Bitcast(bitcast);
|
||||
} else if (auto* call = expr->As<ast::CallExpression>()) {
|
||||
ok = Call(call);
|
||||
} else if (auto* ctor = expr->As<ast::ConstructorExpression>()) {
|
||||
ok = Constructor(ctor);
|
||||
} else if (auto* ident = expr->As<ast::IdentifierExpression>()) {
|
||||
ok = Identifier(ident);
|
||||
} else if (auto* member = expr->As<ast::MemberAccessorExpression>()) {
|
||||
ok = MemberAccessor(member);
|
||||
} else if (auto* unary = expr->As<ast::UnaryOpExpression>()) {
|
||||
ok = UnaryOp(unary);
|
||||
} else {
|
||||
AddError("unknown expression for type determination", expr->source());
|
||||
}
|
||||
|
||||
if (!ok) {
|
||||
bool Resolver::Expression(ast::Expression* root) {
|
||||
std::vector<ast::Expression*> sorted;
|
||||
if (!TraverseExpressions(root, sorted)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto* expr : utils::Reverse(sorted)) {
|
||||
bool ok = false;
|
||||
if (auto* array = expr->As<ast::ArrayAccessorExpression>()) {
|
||||
ok = ArrayAccessor(array);
|
||||
} else if (auto* bin_op = expr->As<ast::BinaryExpression>()) {
|
||||
ok = Binary(bin_op);
|
||||
} else if (auto* bitcast = expr->As<ast::BitcastExpression>()) {
|
||||
ok = Bitcast(bitcast);
|
||||
} else if (auto* call = expr->As<ast::CallExpression>()) {
|
||||
ok = Call(call);
|
||||
} else if (auto* ctor = expr->As<ast::ConstructorExpression>()) {
|
||||
ok = Constructor(ctor);
|
||||
} else if (auto* ident = expr->As<ast::IdentifierExpression>()) {
|
||||
ok = Identifier(ident);
|
||||
} else if (auto* member = expr->As<ast::MemberAccessorExpression>()) {
|
||||
ok = MemberAccessor(member);
|
||||
} else if (auto* unary = expr->As<ast::UnaryOpExpression>()) {
|
||||
ok = UnaryOp(unary);
|
||||
} else {
|
||||
TINT_ICE(Resolver, diagnostics_)
|
||||
<< "unhandled expression type: " << expr->TypeInfo().name;
|
||||
return false;
|
||||
}
|
||||
if (!ok) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Resolver::ArrayAccessor(ast::ArrayAccessorExpression* expr) {
|
||||
Mark(expr->array());
|
||||
if (!Expression(expr->array())) {
|
||||
return false;
|
||||
}
|
||||
auto* idx = expr->idx_expr();
|
||||
Mark(idx);
|
||||
if (!Expression(idx)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* res = TypeOf(expr->array());
|
||||
auto* parent_type = res->UnwrapRef();
|
||||
const sem::Type* ret = nullptr;
|
||||
|
@ -2345,10 +2380,6 @@ bool Resolver::ArrayAccessor(ast::ArrayAccessorExpression* expr) {
|
|||
}
|
||||
|
||||
bool Resolver::Bitcast(ast::BitcastExpression* expr) {
|
||||
Mark(expr->expr());
|
||||
if (!Expression(expr->expr())) {
|
||||
return false;
|
||||
}
|
||||
auto* ty = Type(expr->type());
|
||||
if (!ty) {
|
||||
return false;
|
||||
|
@ -2362,10 +2393,6 @@ bool Resolver::Bitcast(ast::BitcastExpression* expr) {
|
|||
}
|
||||
|
||||
bool Resolver::Call(ast::CallExpression* call) {
|
||||
if (!Expressions(call->params())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Mark(call->func());
|
||||
auto* ident = call->func();
|
||||
auto name = builder_->Symbols().NameFor(ident->symbol());
|
||||
|
@ -2641,13 +2668,6 @@ bool Resolver::ValidateFunctionCall(const ast::CallExpression* call,
|
|||
|
||||
bool Resolver::Constructor(ast::ConstructorExpression* expr) {
|
||||
if (auto* type_ctor = expr->As<ast::TypeConstructorExpression>()) {
|
||||
for (auto* value : type_ctor->values()) {
|
||||
Mark(value);
|
||||
if (!Expression(value)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
auto* type = Type(type_ctor->type());
|
||||
if (!type) {
|
||||
return false;
|
||||
|
@ -2994,11 +3014,6 @@ bool Resolver::Identifier(ast::IdentifierExpression* expr) {
|
|||
}
|
||||
|
||||
bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) {
|
||||
Mark(expr->structure());
|
||||
if (!Expression(expr->structure())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* structure = TypeOf(expr->structure());
|
||||
auto* storage_type = structure->UnwrapRef();
|
||||
|
||||
|
@ -3118,13 +3133,6 @@ bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) {
|
|||
}
|
||||
|
||||
bool Resolver::Binary(ast::BinaryExpression* expr) {
|
||||
Mark(expr->lhs());
|
||||
Mark(expr->rhs());
|
||||
|
||||
if (!Expression(expr->lhs()) || !Expression(expr->rhs())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
using Bool = sem::Bool;
|
||||
using F32 = sem::F32;
|
||||
using I32 = sem::I32;
|
||||
|
@ -3330,13 +3338,6 @@ bool Resolver::Binary(ast::BinaryExpression* expr) {
|
|||
}
|
||||
|
||||
bool Resolver::UnaryOp(ast::UnaryOpExpression* unary) {
|
||||
Mark(unary->expr());
|
||||
|
||||
// Resolve the inner expression
|
||||
if (!Expression(unary->expr())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* expr_type = TypeOf(unary->expr());
|
||||
if (!expr_type) {
|
||||
return false;
|
||||
|
|
|
@ -244,7 +244,6 @@ class Resolver {
|
|||
bool Constructor(ast::ConstructorExpression*);
|
||||
bool ElseStatement(ast::ElseStatement*);
|
||||
bool Expression(ast::Expression*);
|
||||
bool Expressions(const ast::ExpressionList&);
|
||||
bool ForLoopStatement(ast::ForLoopStatement*);
|
||||
bool Function(ast::Function*);
|
||||
bool FunctionCall(const ast::CallExpression* call);
|
||||
|
@ -262,6 +261,15 @@ class Resolver {
|
|||
bool UnaryOp(ast::UnaryOpExpression*);
|
||||
bool VariableDeclStatement(const ast::VariableDeclStatement*);
|
||||
|
||||
/// Performs a depth-first traversal of the expression nodes from `root`,
|
||||
/// collecting all the visited expressions in pre-ordering (root first).
|
||||
/// @param root the root expression node
|
||||
/// @param out the ordered list of visited expression nodes, starting with the
|
||||
/// root node, and ending with leaf nodes
|
||||
/// @return true on success, false on error
|
||||
bool TraverseExpressions(ast::Expression* root,
|
||||
std::vector<ast::Expression*>& out);
|
||||
|
||||
// AST and Type validation methods
|
||||
// Each return true on success, false on failure.
|
||||
bool ValidateArray(const sem::Array* arr, const Source& source);
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "src/resolver/resolver.h"
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
#include "gtest/gtest-spi.h"
|
||||
#include "src/ast/assignment_statement.h"
|
||||
#include "src/ast/bitcast_expression.h"
|
||||
#include "src/ast/break_statement.h"
|
||||
|
@ -56,10 +57,9 @@ class FakeStmt : public ast::Statement {
|
|||
}
|
||||
};
|
||||
|
||||
class FakeExpr : public ast::Expression {
|
||||
class FakeExpr : public Castable<FakeExpr, ast::Expression> {
|
||||
public:
|
||||
FakeExpr(ProgramID program_id, Source source)
|
||||
: ast::Expression(program_id, source) {}
|
||||
FakeExpr(ProgramID program_id, Source source) : Base(program_id, source) {}
|
||||
FakeExpr* Clone(CloneContext*) const override { return nullptr; }
|
||||
void to_str(const sem::Info&, std::ostream&, size_t) const override {}
|
||||
};
|
||||
|
@ -158,14 +158,15 @@ TEST_F(ResolverValidationTest, Stmt_Else_NonBool) {
|
|||
"12:34 error: else statement condition must be bool, got f32");
|
||||
}
|
||||
|
||||
TEST_F(ResolverValidationTest, Expr_Error_Unknown) {
|
||||
auto* e = create<FakeExpr>(Source{Source::Location{2, 30}});
|
||||
WrapInFunction(e);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
|
||||
EXPECT_EQ(r()->error(),
|
||||
"2:30 error: unknown expression for type determination");
|
||||
TEST_F(ResolverValidationTest, Expr_ErrUnknownExprType) {
|
||||
EXPECT_FATAL_FAILURE(
|
||||
{
|
||||
ProgramBuilder b;
|
||||
b.WrapInFunction(b.create<FakeExpr>());
|
||||
Resolver(&b).Resolve();
|
||||
},
|
||||
"internal compiler error: unhandled expression type: "
|
||||
"tint::resolver::FakeExpr");
|
||||
}
|
||||
|
||||
TEST_F(ResolverValidationTest, Expr_DontCall_Function) {
|
||||
|
@ -925,3 +926,5 @@ TEST_F(ResolverTest, Expr_Constructor_Cast_Pointer) {
|
|||
} // namespace
|
||||
} // namespace resolver
|
||||
} // namespace tint
|
||||
|
||||
TINT_INSTANTIATE_TYPEINFO(tint::resolver::FakeExpr);
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_UTILS_REVERSE_H_
|
||||
#define SRC_UTILS_REVERSE_H_
|
||||
|
||||
#include <iterator>
|
||||
|
||||
namespace tint {
|
||||
namespace utils {
|
||||
|
||||
namespace detail {
|
||||
/// Used by utils::Reverse to hold the underlying iterable.
|
||||
/// begin(ReverseIterable<T>) and end(ReverseIterable<T>) are automatically
|
||||
/// called for range-for loops, via argument-dependent lookup.
|
||||
/// See https://en.cppreference.com/w/cpp/language/range-for
|
||||
template <typename T>
|
||||
struct ReverseIterable {
|
||||
/// The wrapped iterable object.
|
||||
T& iterable;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
auto begin(ReverseIterable<T> r_it) {
|
||||
return std::rbegin(r_it.iterable);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
auto end(ReverseIterable<T> r_it) {
|
||||
return std::rend(r_it.iterable);
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
/// Reverse returns an iterable wrapper that when used in range-for loops,
|
||||
/// performs a reverse iteration over the object `iterable`.
|
||||
/// Example:
|
||||
/// ```
|
||||
/// /* Equivalent to:
|
||||
/// * for (auto it = vec.rbegin(); i != vec.rend(); ++it) {
|
||||
/// * auto v = *it;
|
||||
/// */
|
||||
/// for (auto v : utils::Reverse(vec)) {
|
||||
/// }
|
||||
/// ```
|
||||
template <typename T>
|
||||
detail::ReverseIterable<T> Reverse(T&& iterable) {
|
||||
return {iterable};
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace tint
|
||||
|
||||
#endif // SRC_UTILS_REVERSE_H_
|
|
@ -0,0 +1,36 @@
|
|||
// Copyright 2021 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "src/utils/reverse.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
|
||||
namespace tint {
|
||||
namespace utils {
|
||||
namespace {
|
||||
|
||||
TEST(ReverseTest, Vector) {
|
||||
std::vector<int> vec{1, 3, 5, 7, 9};
|
||||
std::vector<int> rev;
|
||||
for (auto v : Reverse(vec)) {
|
||||
rev.emplace_back(v);
|
||||
}
|
||||
ASSERT_THAT(rev, testing::ElementsAre(9, 7, 5, 3, 1));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace utils
|
||||
} // namespace tint
|
|
@ -320,6 +320,7 @@ tint_unittests_source_set("tint_unittests_core_src") {
|
|||
"../src/utils/io/command_test.cc",
|
||||
"../src/utils/io/tmpfile_test.cc",
|
||||
"../src/utils/math_test.cc",
|
||||
"../src/utils/reverse_test.cc",
|
||||
"../src/utils/scoped_assignment_test.cc",
|
||||
"../src/utils/unique_vector_test.cc",
|
||||
"../src/writer/append_vector_test.cc",
|
||||
|
|
Loading…
Reference in New Issue