Resolver: Traverse expressions without recursion

This CL changes the way that the resolver traverses expressions to avoid stack overflows for deeply nested expressions.

Instead of having the expression resolver methods call back into
Expression(), add a TraverseExpressions() method that collects all the
expression nodes with a simple DFS.

This currently only changes the way that Expressions are traversed. We
may need to do the same for statements.

Bug: chromium:1246375
Change-Id: Ie81905da1b790b6dd1df9f1ac42e06593d397c21
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/63700
Auto-Submit: Ben Clayton <bclayton@google.com>
Reviewed-by: David Neto <dneto@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton 2021-09-08 15:18:36 +00:00 committed by Tint LUCI CQ
parent be514a1efb
commit b7bcbf0d20
7 changed files with 200 additions and 86 deletions

View File

@ -713,6 +713,7 @@ if(${TINT_BUILD_TESTS})
utils/io/command_test.cc
utils/io/tmpfile_test.cc
utils/math_test.cc
utils/reverse_test.cc
utils/scoped_assignment_test.cc
utils/unique_vector_test.cc
writer/append_vector_test.cc

View File

@ -72,6 +72,7 @@
#include "src/utils/defer.h"
#include "src/utils/get_or_create.h"
#include "src/utils/math.h"
#include "src/utils/reverse.h"
#include "src/utils/scoped_assignment.h"
namespace tint {
@ -2244,21 +2245,62 @@ bool Resolver::ForLoopStatement(ast::ForLoopStatement* stmt) {
});
}
bool Resolver::Expressions(const ast::ExpressionList& list) {
for (auto* expr : list) {
Mark(expr);
if (!Expression(expr)) {
bool Resolver::TraverseExpressions(ast::Expression* root,
std::vector<ast::Expression*>& out) {
std::vector<ast::Expression*> to_visit;
to_visit.emplace_back(root);
auto add = [&](ast::Expression* e) {
Mark(e);
to_visit.emplace_back(e);
};
while (!to_visit.empty()) {
auto* expr = to_visit.back();
to_visit.pop_back();
out.emplace_back(expr);
if (auto* array = expr->As<ast::ArrayAccessorExpression>()) {
add(array->array());
add(array->idx_expr());
} else if (auto* bin_op = expr->As<ast::BinaryExpression>()) {
add(bin_op->lhs());
add(bin_op->rhs());
} else if (auto* bitcast = expr->As<ast::BitcastExpression>()) {
add(bitcast->expr());
} else if (auto* call = expr->As<ast::CallExpression>()) {
for (auto* arg : call->params()) {
add(arg);
}
} else if (auto* type_ctor = expr->As<ast::TypeConstructorExpression>()) {
for (auto* value : type_ctor->values()) {
add(value);
}
} else if (auto* member = expr->As<ast::MemberAccessorExpression>()) {
add(member->structure());
} else if (auto* unary = expr->As<ast::UnaryOpExpression>()) {
add(unary->expr());
} else if (expr->IsAnyOf<ast::ScalarConstructorExpression,
ast::IdentifierExpression>()) {
// Leaf expression
} else {
TINT_ICE(Resolver, diagnostics_)
<< "unhandled expression type: " << expr->TypeInfo().name;
return false;
}
}
return true;
}
bool Resolver::Expression(ast::Expression* expr) {
if (TypeOf(expr)) {
return true; // Already resolved
bool Resolver::Expression(ast::Expression* root) {
std::vector<ast::Expression*> sorted;
if (!TraverseExpressions(root, sorted)) {
return false;
}
for (auto* expr : utils::Reverse(sorted)) {
bool ok = false;
if (auto* array = expr->As<ast::ArrayAccessorExpression>()) {
ok = ArrayAccessor(array);
@ -2277,27 +2319,20 @@ bool Resolver::Expression(ast::Expression* expr) {
} else if (auto* unary = expr->As<ast::UnaryOpExpression>()) {
ok = UnaryOp(unary);
} else {
AddError("unknown expression for type determination", expr->source());
TINT_ICE(Resolver, diagnostics_)
<< "unhandled expression type: " << expr->TypeInfo().name;
return false;
}
if (!ok) {
return false;
}
}
return true;
}
bool Resolver::ArrayAccessor(ast::ArrayAccessorExpression* expr) {
Mark(expr->array());
if (!Expression(expr->array())) {
return false;
}
auto* idx = expr->idx_expr();
Mark(idx);
if (!Expression(idx)) {
return false;
}
auto* res = TypeOf(expr->array());
auto* parent_type = res->UnwrapRef();
const sem::Type* ret = nullptr;
@ -2345,10 +2380,6 @@ bool Resolver::ArrayAccessor(ast::ArrayAccessorExpression* expr) {
}
bool Resolver::Bitcast(ast::BitcastExpression* expr) {
Mark(expr->expr());
if (!Expression(expr->expr())) {
return false;
}
auto* ty = Type(expr->type());
if (!ty) {
return false;
@ -2362,10 +2393,6 @@ bool Resolver::Bitcast(ast::BitcastExpression* expr) {
}
bool Resolver::Call(ast::CallExpression* call) {
if (!Expressions(call->params())) {
return false;
}
Mark(call->func());
auto* ident = call->func();
auto name = builder_->Symbols().NameFor(ident->symbol());
@ -2641,13 +2668,6 @@ bool Resolver::ValidateFunctionCall(const ast::CallExpression* call,
bool Resolver::Constructor(ast::ConstructorExpression* expr) {
if (auto* type_ctor = expr->As<ast::TypeConstructorExpression>()) {
for (auto* value : type_ctor->values()) {
Mark(value);
if (!Expression(value)) {
return false;
}
}
auto* type = Type(type_ctor->type());
if (!type) {
return false;
@ -2994,11 +3014,6 @@ bool Resolver::Identifier(ast::IdentifierExpression* expr) {
}
bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) {
Mark(expr->structure());
if (!Expression(expr->structure())) {
return false;
}
auto* structure = TypeOf(expr->structure());
auto* storage_type = structure->UnwrapRef();
@ -3118,13 +3133,6 @@ bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) {
}
bool Resolver::Binary(ast::BinaryExpression* expr) {
Mark(expr->lhs());
Mark(expr->rhs());
if (!Expression(expr->lhs()) || !Expression(expr->rhs())) {
return false;
}
using Bool = sem::Bool;
using F32 = sem::F32;
using I32 = sem::I32;
@ -3330,13 +3338,6 @@ bool Resolver::Binary(ast::BinaryExpression* expr) {
}
bool Resolver::UnaryOp(ast::UnaryOpExpression* unary) {
Mark(unary->expr());
// Resolve the inner expression
if (!Expression(unary->expr())) {
return false;
}
auto* expr_type = TypeOf(unary->expr());
if (!expr_type) {
return false;

View File

@ -244,7 +244,6 @@ class Resolver {
bool Constructor(ast::ConstructorExpression*);
bool ElseStatement(ast::ElseStatement*);
bool Expression(ast::Expression*);
bool Expressions(const ast::ExpressionList&);
bool ForLoopStatement(ast::ForLoopStatement*);
bool Function(ast::Function*);
bool FunctionCall(const ast::CallExpression* call);
@ -262,6 +261,15 @@ class Resolver {
bool UnaryOp(ast::UnaryOpExpression*);
bool VariableDeclStatement(const ast::VariableDeclStatement*);
/// Performs a depth-first traversal of the expression nodes from `root`,
/// collecting all the visited expressions in pre-ordering (root first).
/// @param root the root expression node
/// @param out the ordered list of visited expression nodes, starting with the
/// root node, and ending with leaf nodes
/// @return true on success, false on error
bool TraverseExpressions(ast::Expression* root,
std::vector<ast::Expression*>& out);
// AST and Type validation methods
// Each return true on success, false on failure.
bool ValidateArray(const sem::Array* arr, const Source& source);

View File

@ -15,6 +15,7 @@
#include "src/resolver/resolver.h"
#include "gmock/gmock.h"
#include "gtest/gtest-spi.h"
#include "src/ast/assignment_statement.h"
#include "src/ast/bitcast_expression.h"
#include "src/ast/break_statement.h"
@ -56,10 +57,9 @@ class FakeStmt : public ast::Statement {
}
};
class FakeExpr : public ast::Expression {
class FakeExpr : public Castable<FakeExpr, ast::Expression> {
public:
FakeExpr(ProgramID program_id, Source source)
: ast::Expression(program_id, source) {}
FakeExpr(ProgramID program_id, Source source) : Base(program_id, source) {}
FakeExpr* Clone(CloneContext*) const override { return nullptr; }
void to_str(const sem::Info&, std::ostream&, size_t) const override {}
};
@ -158,14 +158,15 @@ TEST_F(ResolverValidationTest, Stmt_Else_NonBool) {
"12:34 error: else statement condition must be bool, got f32");
}
TEST_F(ResolverValidationTest, Expr_Error_Unknown) {
auto* e = create<FakeExpr>(Source{Source::Location{2, 30}});
WrapInFunction(e);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"2:30 error: unknown expression for type determination");
TEST_F(ResolverValidationTest, Expr_ErrUnknownExprType) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder b;
b.WrapInFunction(b.create<FakeExpr>());
Resolver(&b).Resolve();
},
"internal compiler error: unhandled expression type: "
"tint::resolver::FakeExpr");
}
TEST_F(ResolverValidationTest, Expr_DontCall_Function) {
@ -925,3 +926,5 @@ TEST_F(ResolverTest, Expr_Constructor_Cast_Pointer) {
} // namespace
} // namespace resolver
} // namespace tint
TINT_INSTANTIATE_TYPEINFO(tint::resolver::FakeExpr);

64
src/utils/reverse.h Normal file
View File

@ -0,0 +1,64 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_UTILS_REVERSE_H_
#define SRC_UTILS_REVERSE_H_
#include <iterator>
namespace tint {
namespace utils {
namespace detail {
/// Used by utils::Reverse to hold the underlying iterable.
/// begin(ReverseIterable<T>) and end(ReverseIterable<T>) are automatically
/// called for range-for loops, via argument-dependent lookup.
/// See https://en.cppreference.com/w/cpp/language/range-for
template <typename T>
struct ReverseIterable {
/// The wrapped iterable object.
T& iterable;
};
template <typename T>
auto begin(ReverseIterable<T> r_it) {
return std::rbegin(r_it.iterable);
}
template <typename T>
auto end(ReverseIterable<T> r_it) {
return std::rend(r_it.iterable);
}
} // namespace detail
/// Reverse returns an iterable wrapper that when used in range-for loops,
/// performs a reverse iteration over the object `iterable`.
/// Example:
/// ```
/// /* Equivalent to:
/// * for (auto it = vec.rbegin(); i != vec.rend(); ++it) {
/// * auto v = *it;
/// */
/// for (auto v : utils::Reverse(vec)) {
/// }
/// ```
template <typename T>
detail::ReverseIterable<T> Reverse(T&& iterable) {
return {iterable};
}
} // namespace utils
} // namespace tint
#endif // SRC_UTILS_REVERSE_H_

36
src/utils/reverse_test.cc Normal file
View File

@ -0,0 +1,36 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/utils/reverse.h"
#include <vector>
#include "gmock/gmock.h"
namespace tint {
namespace utils {
namespace {
TEST(ReverseTest, Vector) {
std::vector<int> vec{1, 3, 5, 7, 9};
std::vector<int> rev;
for (auto v : Reverse(vec)) {
rev.emplace_back(v);
}
ASSERT_THAT(rev, testing::ElementsAre(9, 7, 5, 3, 1));
}
} // namespace
} // namespace utils
} // namespace tint

View File

@ -320,6 +320,7 @@ tint_unittests_source_set("tint_unittests_core_src") {
"../src/utils/io/command_test.cc",
"../src/utils/io/tmpfile_test.cc",
"../src/utils/math_test.cc",
"../src/utils/reverse_test.cc",
"../src/utils/scoped_assignment_test.cc",
"../src/utils/unique_vector_test.cc",
"../src/writer/append_vector_test.cc",