From b7bcbf0d2087c73f9be28c9ad83c0b259ddecbe0 Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Wed, 8 Sep 2021 15:18:36 +0000 Subject: [PATCH] Resolver: Traverse expressions without recursion This CL changes the way that the resolver traverses expressions to avoid stack overflows for deeply nested expressions. Instead of having the expression resolver methods call back into Expression(), add a TraverseExpressions() method that collects all the expression nodes with a simple DFS. This currently only changes the way that Expressions are traversed. We may need to do the same for statements. Bug: chromium:1246375 Change-Id: Ie81905da1b790b6dd1df9f1ac42e06593d397c21 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/63700 Auto-Submit: Ben Clayton Reviewed-by: David Neto Commit-Queue: Ben Clayton Kokoro: Kokoro --- src/CMakeLists.txt | 1 + src/resolver/resolver.cc | 149 ++++++++++++++++---------------- src/resolver/resolver.h | 10 ++- src/resolver/validation_test.cc | 25 +++--- src/utils/reverse.h | 64 ++++++++++++++ src/utils/reverse_test.cc | 36 ++++++++ test/BUILD.gn | 1 + 7 files changed, 200 insertions(+), 86 deletions(-) create mode 100644 src/utils/reverse.h create mode 100644 src/utils/reverse_test.cc diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 9f81fda4c8..06e7392f24 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -713,6 +713,7 @@ if(${TINT_BUILD_TESTS}) utils/io/command_test.cc utils/io/tmpfile_test.cc utils/math_test.cc + utils/reverse_test.cc utils/scoped_assignment_test.cc utils/unique_vector_test.cc writer/append_vector_test.cc diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 20acfa86fe..f7c958b96d 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -72,6 +72,7 @@ #include "src/utils/defer.h" #include "src/utils/get_or_create.h" #include "src/utils/math.h" +#include "src/utils/reverse.h" #include "src/utils/scoped_assignment.h" namespace tint { @@ -2244,60 +2245,94 @@ bool Resolver::ForLoopStatement(ast::ForLoopStatement* stmt) { }); } -bool Resolver::Expressions(const ast::ExpressionList& list) { - for (auto* expr : list) { - Mark(expr); - if (!Expression(expr)) { +bool Resolver::TraverseExpressions(ast::Expression* root, + std::vector& out) { + std::vector to_visit; + to_visit.emplace_back(root); + + auto add = [&](ast::Expression* e) { + Mark(e); + to_visit.emplace_back(e); + }; + + while (!to_visit.empty()) { + auto* expr = to_visit.back(); + to_visit.pop_back(); + + out.emplace_back(expr); + + if (auto* array = expr->As()) { + add(array->array()); + add(array->idx_expr()); + } else if (auto* bin_op = expr->As()) { + add(bin_op->lhs()); + add(bin_op->rhs()); + } else if (auto* bitcast = expr->As()) { + add(bitcast->expr()); + } else if (auto* call = expr->As()) { + for (auto* arg : call->params()) { + add(arg); + } + } else if (auto* type_ctor = expr->As()) { + for (auto* value : type_ctor->values()) { + add(value); + } + } else if (auto* member = expr->As()) { + add(member->structure()); + } else if (auto* unary = expr->As()) { + add(unary->expr()); + } else if (expr->IsAnyOf()) { + // Leaf expression + } else { + TINT_ICE(Resolver, diagnostics_) + << "unhandled expression type: " << expr->TypeInfo().name; return false; } } + return true; } -bool Resolver::Expression(ast::Expression* expr) { - if (TypeOf(expr)) { - return true; // Already resolved - } - - bool ok = false; - if (auto* array = expr->As()) { - ok = ArrayAccessor(array); - } else if (auto* bin_op = expr->As()) { - ok = Binary(bin_op); - } else if (auto* bitcast = expr->As()) { - ok = Bitcast(bitcast); - } else if (auto* call = expr->As()) { - ok = Call(call); - } else if (auto* ctor = expr->As()) { - ok = Constructor(ctor); - } else if (auto* ident = expr->As()) { - ok = Identifier(ident); - } else if (auto* member = expr->As()) { - ok = MemberAccessor(member); - } else if (auto* unary = expr->As()) { - ok = UnaryOp(unary); - } else { - AddError("unknown expression for type determination", expr->source()); - } - - if (!ok) { +bool Resolver::Expression(ast::Expression* root) { + std::vector sorted; + if (!TraverseExpressions(root, sorted)) { return false; } + for (auto* expr : utils::Reverse(sorted)) { + bool ok = false; + if (auto* array = expr->As()) { + ok = ArrayAccessor(array); + } else if (auto* bin_op = expr->As()) { + ok = Binary(bin_op); + } else if (auto* bitcast = expr->As()) { + ok = Bitcast(bitcast); + } else if (auto* call = expr->As()) { + ok = Call(call); + } else if (auto* ctor = expr->As()) { + ok = Constructor(ctor); + } else if (auto* ident = expr->As()) { + ok = Identifier(ident); + } else if (auto* member = expr->As()) { + ok = MemberAccessor(member); + } else if (auto* unary = expr->As()) { + ok = UnaryOp(unary); + } else { + TINT_ICE(Resolver, diagnostics_) + << "unhandled expression type: " << expr->TypeInfo().name; + return false; + } + if (!ok) { + return false; + } + } + return true; } bool Resolver::ArrayAccessor(ast::ArrayAccessorExpression* expr) { - Mark(expr->array()); - if (!Expression(expr->array())) { - return false; - } auto* idx = expr->idx_expr(); - Mark(idx); - if (!Expression(idx)) { - return false; - } - auto* res = TypeOf(expr->array()); auto* parent_type = res->UnwrapRef(); const sem::Type* ret = nullptr; @@ -2345,10 +2380,6 @@ bool Resolver::ArrayAccessor(ast::ArrayAccessorExpression* expr) { } bool Resolver::Bitcast(ast::BitcastExpression* expr) { - Mark(expr->expr()); - if (!Expression(expr->expr())) { - return false; - } auto* ty = Type(expr->type()); if (!ty) { return false; @@ -2362,10 +2393,6 @@ bool Resolver::Bitcast(ast::BitcastExpression* expr) { } bool Resolver::Call(ast::CallExpression* call) { - if (!Expressions(call->params())) { - return false; - } - Mark(call->func()); auto* ident = call->func(); auto name = builder_->Symbols().NameFor(ident->symbol()); @@ -2641,13 +2668,6 @@ bool Resolver::ValidateFunctionCall(const ast::CallExpression* call, bool Resolver::Constructor(ast::ConstructorExpression* expr) { if (auto* type_ctor = expr->As()) { - for (auto* value : type_ctor->values()) { - Mark(value); - if (!Expression(value)) { - return false; - } - } - auto* type = Type(type_ctor->type()); if (!type) { return false; @@ -2994,11 +3014,6 @@ bool Resolver::Identifier(ast::IdentifierExpression* expr) { } bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) { - Mark(expr->structure()); - if (!Expression(expr->structure())) { - return false; - } - auto* structure = TypeOf(expr->structure()); auto* storage_type = structure->UnwrapRef(); @@ -3118,13 +3133,6 @@ bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) { } bool Resolver::Binary(ast::BinaryExpression* expr) { - Mark(expr->lhs()); - Mark(expr->rhs()); - - if (!Expression(expr->lhs()) || !Expression(expr->rhs())) { - return false; - } - using Bool = sem::Bool; using F32 = sem::F32; using I32 = sem::I32; @@ -3330,13 +3338,6 @@ bool Resolver::Binary(ast::BinaryExpression* expr) { } bool Resolver::UnaryOp(ast::UnaryOpExpression* unary) { - Mark(unary->expr()); - - // Resolve the inner expression - if (!Expression(unary->expr())) { - return false; - } - auto* expr_type = TypeOf(unary->expr()); if (!expr_type) { return false; diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index 9f02791d5d..ccbb97b6cc 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -244,7 +244,6 @@ class Resolver { bool Constructor(ast::ConstructorExpression*); bool ElseStatement(ast::ElseStatement*); bool Expression(ast::Expression*); - bool Expressions(const ast::ExpressionList&); bool ForLoopStatement(ast::ForLoopStatement*); bool Function(ast::Function*); bool FunctionCall(const ast::CallExpression* call); @@ -262,6 +261,15 @@ class Resolver { bool UnaryOp(ast::UnaryOpExpression*); bool VariableDeclStatement(const ast::VariableDeclStatement*); + /// Performs a depth-first traversal of the expression nodes from `root`, + /// collecting all the visited expressions in pre-ordering (root first). + /// @param root the root expression node + /// @param out the ordered list of visited expression nodes, starting with the + /// root node, and ending with leaf nodes + /// @return true on success, false on error + bool TraverseExpressions(ast::Expression* root, + std::vector& out); + // AST and Type validation methods // Each return true on success, false on failure. bool ValidateArray(const sem::Array* arr, const Source& source); diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc index 07396a79b7..bad2a8e92c 100644 --- a/src/resolver/validation_test.cc +++ b/src/resolver/validation_test.cc @@ -15,6 +15,7 @@ #include "src/resolver/resolver.h" #include "gmock/gmock.h" +#include "gtest/gtest-spi.h" #include "src/ast/assignment_statement.h" #include "src/ast/bitcast_expression.h" #include "src/ast/break_statement.h" @@ -56,10 +57,9 @@ class FakeStmt : public ast::Statement { } }; -class FakeExpr : public ast::Expression { +class FakeExpr : public Castable { public: - FakeExpr(ProgramID program_id, Source source) - : ast::Expression(program_id, source) {} + FakeExpr(ProgramID program_id, Source source) : Base(program_id, source) {} FakeExpr* Clone(CloneContext*) const override { return nullptr; } void to_str(const sem::Info&, std::ostream&, size_t) const override {} }; @@ -158,14 +158,15 @@ TEST_F(ResolverValidationTest, Stmt_Else_NonBool) { "12:34 error: else statement condition must be bool, got f32"); } -TEST_F(ResolverValidationTest, Expr_Error_Unknown) { - auto* e = create(Source{Source::Location{2, 30}}); - WrapInFunction(e); - - EXPECT_FALSE(r()->Resolve()); - - EXPECT_EQ(r()->error(), - "2:30 error: unknown expression for type determination"); +TEST_F(ResolverValidationTest, Expr_ErrUnknownExprType) { + EXPECT_FATAL_FAILURE( + { + ProgramBuilder b; + b.WrapInFunction(b.create()); + Resolver(&b).Resolve(); + }, + "internal compiler error: unhandled expression type: " + "tint::resolver::FakeExpr"); } TEST_F(ResolverValidationTest, Expr_DontCall_Function) { @@ -925,3 +926,5 @@ TEST_F(ResolverTest, Expr_Constructor_Cast_Pointer) { } // namespace } // namespace resolver } // namespace tint + +TINT_INSTANTIATE_TYPEINFO(tint::resolver::FakeExpr); diff --git a/src/utils/reverse.h b/src/utils/reverse.h new file mode 100644 index 0000000000..0848dc585e --- /dev/null +++ b/src/utils/reverse.h @@ -0,0 +1,64 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_UTILS_REVERSE_H_ +#define SRC_UTILS_REVERSE_H_ + +#include + +namespace tint { +namespace utils { + +namespace detail { +/// Used by utils::Reverse to hold the underlying iterable. +/// begin(ReverseIterable) and end(ReverseIterable) are automatically +/// called for range-for loops, via argument-dependent lookup. +/// See https://en.cppreference.com/w/cpp/language/range-for +template +struct ReverseIterable { + /// The wrapped iterable object. + T& iterable; +}; + +template +auto begin(ReverseIterable r_it) { + return std::rbegin(r_it.iterable); +} + +template +auto end(ReverseIterable r_it) { + return std::rend(r_it.iterable); +} +} // namespace detail + +/// Reverse returns an iterable wrapper that when used in range-for loops, +/// performs a reverse iteration over the object `iterable`. +/// Example: +/// ``` +/// /* Equivalent to: +/// * for (auto it = vec.rbegin(); i != vec.rend(); ++it) { +/// * auto v = *it; +/// */ +/// for (auto v : utils::Reverse(vec)) { +/// } +/// ``` +template +detail::ReverseIterable Reverse(T&& iterable) { + return {iterable}; +} + +} // namespace utils +} // namespace tint + +#endif // SRC_UTILS_REVERSE_H_ diff --git a/src/utils/reverse_test.cc b/src/utils/reverse_test.cc new file mode 100644 index 0000000000..25be0eb40a --- /dev/null +++ b/src/utils/reverse_test.cc @@ -0,0 +1,36 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/utils/reverse.h" + +#include + +#include "gmock/gmock.h" + +namespace tint { +namespace utils { +namespace { + +TEST(ReverseTest, Vector) { + std::vector vec{1, 3, 5, 7, 9}; + std::vector rev; + for (auto v : Reverse(vec)) { + rev.emplace_back(v); + } + ASSERT_THAT(rev, testing::ElementsAre(9, 7, 5, 3, 1)); +} + +} // namespace +} // namespace utils +} // namespace tint diff --git a/test/BUILD.gn b/test/BUILD.gn index c7fffad98b..ad410e75d6 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn @@ -320,6 +320,7 @@ tint_unittests_source_set("tint_unittests_core_src") { "../src/utils/io/command_test.cc", "../src/utils/io/tmpfile_test.cc", "../src/utils/math_test.cc", + "../src/utils/reverse_test.cc", "../src/utils/scoped_assignment_test.cc", "../src/utils/unique_vector_test.cc", "../src/writer/append_vector_test.cc",