diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 9f81fda4c8..06e7392f24 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -713,6 +713,7 @@ if(${TINT_BUILD_TESTS}) utils/io/command_test.cc utils/io/tmpfile_test.cc utils/math_test.cc + utils/reverse_test.cc utils/scoped_assignment_test.cc utils/unique_vector_test.cc writer/append_vector_test.cc diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 20acfa86fe..f7c958b96d 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -72,6 +72,7 @@ #include "src/utils/defer.h" #include "src/utils/get_or_create.h" #include "src/utils/math.h" +#include "src/utils/reverse.h" #include "src/utils/scoped_assignment.h" namespace tint { @@ -2244,60 +2245,94 @@ bool Resolver::ForLoopStatement(ast::ForLoopStatement* stmt) { }); } -bool Resolver::Expressions(const ast::ExpressionList& list) { - for (auto* expr : list) { - Mark(expr); - if (!Expression(expr)) { +bool Resolver::TraverseExpressions(ast::Expression* root, + std::vector& out) { + std::vector to_visit; + to_visit.emplace_back(root); + + auto add = [&](ast::Expression* e) { + Mark(e); + to_visit.emplace_back(e); + }; + + while (!to_visit.empty()) { + auto* expr = to_visit.back(); + to_visit.pop_back(); + + out.emplace_back(expr); + + if (auto* array = expr->As()) { + add(array->array()); + add(array->idx_expr()); + } else if (auto* bin_op = expr->As()) { + add(bin_op->lhs()); + add(bin_op->rhs()); + } else if (auto* bitcast = expr->As()) { + add(bitcast->expr()); + } else if (auto* call = expr->As()) { + for (auto* arg : call->params()) { + add(arg); + } + } else if (auto* type_ctor = expr->As()) { + for (auto* value : type_ctor->values()) { + add(value); + } + } else if (auto* member = expr->As()) { + add(member->structure()); + } else if (auto* unary = expr->As()) { + add(unary->expr()); + } else if (expr->IsAnyOf()) { + // Leaf expression + } else { + TINT_ICE(Resolver, diagnostics_) + << "unhandled expression type: " << expr->TypeInfo().name; return false; } } + return true; } -bool Resolver::Expression(ast::Expression* expr) { - if (TypeOf(expr)) { - return true; // Already resolved - } - - bool ok = false; - if (auto* array = expr->As()) { - ok = ArrayAccessor(array); - } else if (auto* bin_op = expr->As()) { - ok = Binary(bin_op); - } else if (auto* bitcast = expr->As()) { - ok = Bitcast(bitcast); - } else if (auto* call = expr->As()) { - ok = Call(call); - } else if (auto* ctor = expr->As()) { - ok = Constructor(ctor); - } else if (auto* ident = expr->As()) { - ok = Identifier(ident); - } else if (auto* member = expr->As()) { - ok = MemberAccessor(member); - } else if (auto* unary = expr->As()) { - ok = UnaryOp(unary); - } else { - AddError("unknown expression for type determination", expr->source()); - } - - if (!ok) { +bool Resolver::Expression(ast::Expression* root) { + std::vector sorted; + if (!TraverseExpressions(root, sorted)) { return false; } + for (auto* expr : utils::Reverse(sorted)) { + bool ok = false; + if (auto* array = expr->As()) { + ok = ArrayAccessor(array); + } else if (auto* bin_op = expr->As()) { + ok = Binary(bin_op); + } else if (auto* bitcast = expr->As()) { + ok = Bitcast(bitcast); + } else if (auto* call = expr->As()) { + ok = Call(call); + } else if (auto* ctor = expr->As()) { + ok = Constructor(ctor); + } else if (auto* ident = expr->As()) { + ok = Identifier(ident); + } else if (auto* member = expr->As()) { + ok = MemberAccessor(member); + } else if (auto* unary = expr->As()) { + ok = UnaryOp(unary); + } else { + TINT_ICE(Resolver, diagnostics_) + << "unhandled expression type: " << expr->TypeInfo().name; + return false; + } + if (!ok) { + return false; + } + } + return true; } bool Resolver::ArrayAccessor(ast::ArrayAccessorExpression* expr) { - Mark(expr->array()); - if (!Expression(expr->array())) { - return false; - } auto* idx = expr->idx_expr(); - Mark(idx); - if (!Expression(idx)) { - return false; - } - auto* res = TypeOf(expr->array()); auto* parent_type = res->UnwrapRef(); const sem::Type* ret = nullptr; @@ -2345,10 +2380,6 @@ bool Resolver::ArrayAccessor(ast::ArrayAccessorExpression* expr) { } bool Resolver::Bitcast(ast::BitcastExpression* expr) { - Mark(expr->expr()); - if (!Expression(expr->expr())) { - return false; - } auto* ty = Type(expr->type()); if (!ty) { return false; @@ -2362,10 +2393,6 @@ bool Resolver::Bitcast(ast::BitcastExpression* expr) { } bool Resolver::Call(ast::CallExpression* call) { - if (!Expressions(call->params())) { - return false; - } - Mark(call->func()); auto* ident = call->func(); auto name = builder_->Symbols().NameFor(ident->symbol()); @@ -2641,13 +2668,6 @@ bool Resolver::ValidateFunctionCall(const ast::CallExpression* call, bool Resolver::Constructor(ast::ConstructorExpression* expr) { if (auto* type_ctor = expr->As()) { - for (auto* value : type_ctor->values()) { - Mark(value); - if (!Expression(value)) { - return false; - } - } - auto* type = Type(type_ctor->type()); if (!type) { return false; @@ -2994,11 +3014,6 @@ bool Resolver::Identifier(ast::IdentifierExpression* expr) { } bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) { - Mark(expr->structure()); - if (!Expression(expr->structure())) { - return false; - } - auto* structure = TypeOf(expr->structure()); auto* storage_type = structure->UnwrapRef(); @@ -3118,13 +3133,6 @@ bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) { } bool Resolver::Binary(ast::BinaryExpression* expr) { - Mark(expr->lhs()); - Mark(expr->rhs()); - - if (!Expression(expr->lhs()) || !Expression(expr->rhs())) { - return false; - } - using Bool = sem::Bool; using F32 = sem::F32; using I32 = sem::I32; @@ -3330,13 +3338,6 @@ bool Resolver::Binary(ast::BinaryExpression* expr) { } bool Resolver::UnaryOp(ast::UnaryOpExpression* unary) { - Mark(unary->expr()); - - // Resolve the inner expression - if (!Expression(unary->expr())) { - return false; - } - auto* expr_type = TypeOf(unary->expr()); if (!expr_type) { return false; diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index 9f02791d5d..ccbb97b6cc 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -244,7 +244,6 @@ class Resolver { bool Constructor(ast::ConstructorExpression*); bool ElseStatement(ast::ElseStatement*); bool Expression(ast::Expression*); - bool Expressions(const ast::ExpressionList&); bool ForLoopStatement(ast::ForLoopStatement*); bool Function(ast::Function*); bool FunctionCall(const ast::CallExpression* call); @@ -262,6 +261,15 @@ class Resolver { bool UnaryOp(ast::UnaryOpExpression*); bool VariableDeclStatement(const ast::VariableDeclStatement*); + /// Performs a depth-first traversal of the expression nodes from `root`, + /// collecting all the visited expressions in pre-ordering (root first). + /// @param root the root expression node + /// @param out the ordered list of visited expression nodes, starting with the + /// root node, and ending with leaf nodes + /// @return true on success, false on error + bool TraverseExpressions(ast::Expression* root, + std::vector& out); + // AST and Type validation methods // Each return true on success, false on failure. bool ValidateArray(const sem::Array* arr, const Source& source); diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc index 07396a79b7..bad2a8e92c 100644 --- a/src/resolver/validation_test.cc +++ b/src/resolver/validation_test.cc @@ -15,6 +15,7 @@ #include "src/resolver/resolver.h" #include "gmock/gmock.h" +#include "gtest/gtest-spi.h" #include "src/ast/assignment_statement.h" #include "src/ast/bitcast_expression.h" #include "src/ast/break_statement.h" @@ -56,10 +57,9 @@ class FakeStmt : public ast::Statement { } }; -class FakeExpr : public ast::Expression { +class FakeExpr : public Castable { public: - FakeExpr(ProgramID program_id, Source source) - : ast::Expression(program_id, source) {} + FakeExpr(ProgramID program_id, Source source) : Base(program_id, source) {} FakeExpr* Clone(CloneContext*) const override { return nullptr; } void to_str(const sem::Info&, std::ostream&, size_t) const override {} }; @@ -158,14 +158,15 @@ TEST_F(ResolverValidationTest, Stmt_Else_NonBool) { "12:34 error: else statement condition must be bool, got f32"); } -TEST_F(ResolverValidationTest, Expr_Error_Unknown) { - auto* e = create(Source{Source::Location{2, 30}}); - WrapInFunction(e); - - EXPECT_FALSE(r()->Resolve()); - - EXPECT_EQ(r()->error(), - "2:30 error: unknown expression for type determination"); +TEST_F(ResolverValidationTest, Expr_ErrUnknownExprType) { + EXPECT_FATAL_FAILURE( + { + ProgramBuilder b; + b.WrapInFunction(b.create()); + Resolver(&b).Resolve(); + }, + "internal compiler error: unhandled expression type: " + "tint::resolver::FakeExpr"); } TEST_F(ResolverValidationTest, Expr_DontCall_Function) { @@ -925,3 +926,5 @@ TEST_F(ResolverTest, Expr_Constructor_Cast_Pointer) { } // namespace } // namespace resolver } // namespace tint + +TINT_INSTANTIATE_TYPEINFO(tint::resolver::FakeExpr); diff --git a/src/utils/reverse.h b/src/utils/reverse.h new file mode 100644 index 0000000000..0848dc585e --- /dev/null +++ b/src/utils/reverse.h @@ -0,0 +1,64 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_UTILS_REVERSE_H_ +#define SRC_UTILS_REVERSE_H_ + +#include + +namespace tint { +namespace utils { + +namespace detail { +/// Used by utils::Reverse to hold the underlying iterable. +/// begin(ReverseIterable) and end(ReverseIterable) are automatically +/// called for range-for loops, via argument-dependent lookup. +/// See https://en.cppreference.com/w/cpp/language/range-for +template +struct ReverseIterable { + /// The wrapped iterable object. + T& iterable; +}; + +template +auto begin(ReverseIterable r_it) { + return std::rbegin(r_it.iterable); +} + +template +auto end(ReverseIterable r_it) { + return std::rend(r_it.iterable); +} +} // namespace detail + +/// Reverse returns an iterable wrapper that when used in range-for loops, +/// performs a reverse iteration over the object `iterable`. +/// Example: +/// ``` +/// /* Equivalent to: +/// * for (auto it = vec.rbegin(); i != vec.rend(); ++it) { +/// * auto v = *it; +/// */ +/// for (auto v : utils::Reverse(vec)) { +/// } +/// ``` +template +detail::ReverseIterable Reverse(T&& iterable) { + return {iterable}; +} + +} // namespace utils +} // namespace tint + +#endif // SRC_UTILS_REVERSE_H_ diff --git a/src/utils/reverse_test.cc b/src/utils/reverse_test.cc new file mode 100644 index 0000000000..25be0eb40a --- /dev/null +++ b/src/utils/reverse_test.cc @@ -0,0 +1,36 @@ +// Copyright 2021 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/utils/reverse.h" + +#include + +#include "gmock/gmock.h" + +namespace tint { +namespace utils { +namespace { + +TEST(ReverseTest, Vector) { + std::vector vec{1, 3, 5, 7, 9}; + std::vector rev; + for (auto v : Reverse(vec)) { + rev.emplace_back(v); + } + ASSERT_THAT(rev, testing::ElementsAre(9, 7, 5, 3, 1)); +} + +} // namespace +} // namespace utils +} // namespace tint diff --git a/test/BUILD.gn b/test/BUILD.gn index c7fffad98b..ad410e75d6 100644 --- a/test/BUILD.gn +++ b/test/BUILD.gn @@ -320,6 +320,7 @@ tint_unittests_source_set("tint_unittests_core_src") { "../src/utils/io/command_test.cc", "../src/utils/io/tmpfile_test.cc", "../src/utils/math_test.cc", + "../src/utils/reverse_test.cc", "../src/utils/scoped_assignment_test.cc", "../src/utils/unique_vector_test.cc", "../src/writer/append_vector_test.cc",