ast: Add TraverseExpressions()

An ast::Expression traversal helper extracted from Resolver.

Change-Id: I88754cbc86cc12cbf8348fb36a3f038904017f3d
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/67202
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton 2021-10-21 20:38:54 +00:00 committed by Tint LUCI CQ
parent 72789de9f5
commit f164a4a723
7 changed files with 414 additions and 71 deletions

View File

@ -317,6 +317,7 @@ libtint_source_set("libtint_core_all_src") {
"ast/texture.cc",
"ast/texture.h",
"ast/type.h",
"ast/traverse_expressions.h",
"ast/type_constructor_expression.cc",
"ast/type_constructor_expression.h",
"ast/type_decl.cc",

View File

@ -180,6 +180,7 @@ set(TINT_LIB_SRCS
ast/switch_statement.h
ast/texture.cc
ast/texture.h
ast/traverse_expressions.h
ast/type_constructor_expression.cc
ast/type_constructor_expression.h
ast/type_name.cc
@ -639,6 +640,7 @@ if(${TINT_BUILD_TESTS})
ast/switch_statement_test.cc
ast/test_helper.h
ast/texture_test.cc
ast/traverse_expressions_test.cc
ast/type_constructor_expression_test.cc
ast/u32_test.cc
ast/uint_literal_test.cc

View File

@ -0,0 +1,141 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_AST_TRAVERSE_EXPRESSIONS_H_
#define SRC_AST_TRAVERSE_EXPRESSIONS_H_
#include <vector>
#include "src/ast/array_accessor_expression.h"
#include "src/ast/binary_expression.h"
#include "src/ast/bitcast_expression.h"
#include "src/ast/call_expression.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/phony_expression.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/type_constructor_expression.h"
#include "src/ast/unary_op_expression.h"
#include "src/utils/reverse.h"
namespace tint {
namespace ast {
/// The action to perform after calling the TraverseExpressions() callback
/// function.
enum class TraverseAction {
/// Stop traversal immediately.
Stop,
/// Descend into this expression.
Descend,
/// Do not descend into this expression.
Skip,
};
/// The order TraverseExpressions() will traverse expressions
enum class TraverseOrder {
/// Expressions will be traversed from left to right
LeftToRight,
/// Expressions will be traversed from right to left
RightToLeft,
};
/// TraverseExpressions performs a depth-first traversal of the expression nodes
/// from `root`, calling `callback` for each of the visited expressions that
/// match the predicate parameter type, in pre-ordering (root first).
/// @param root the root expression node
/// @param diags the diagnostics used for error messages
/// @param callback the callback function. Must be of the signature:
/// `TraverseAction(const T*)` where T is an ast::Expression type.
/// @return true on success, false on error
template <TraverseOrder ORDER = TraverseOrder::LeftToRight, typename CALLBACK>
bool TraverseExpressions(const ast::Expression* root,
diag::List& diags,
CALLBACK&& callback) {
using EXPR_TYPE = std::remove_pointer_t<traits::ParamTypeT<CALLBACK, 0>>;
std::vector<const ast::Expression*> to_visit{root};
auto push_pair = [&](const ast::Expression* left,
const ast::Expression* right) {
if (ORDER == TraverseOrder::LeftToRight) {
to_visit.push_back(right);
to_visit.push_back(left);
} else {
to_visit.push_back(left);
to_visit.push_back(right);
}
};
auto push_list = [&](const std::vector<const ast::Expression*>& exprs) {
if (ORDER == TraverseOrder::LeftToRight) {
for (auto* expr : utils::Reverse(exprs)) {
to_visit.push_back(expr);
}
} else {
for (auto* expr : exprs) {
to_visit.push_back(expr);
}
}
};
while (!to_visit.empty()) {
auto* expr = to_visit.back();
to_visit.pop_back();
if (auto* filtered = expr->As<EXPR_TYPE>()) {
switch (callback(filtered)) {
case TraverseAction::Stop:
return true;
case TraverseAction::Skip:
continue;
case TraverseAction::Descend:
break;
}
}
if (auto* array = expr->As<ast::ArrayAccessorExpression>()) {
push_pair(array->array, array->index);
} else if (auto* bin_op = expr->As<ast::BinaryExpression>()) {
push_pair(bin_op->lhs, bin_op->rhs);
} else if (auto* bitcast = expr->As<ast::BitcastExpression>()) {
to_visit.push_back(bitcast->expr);
} else if (auto* call = expr->As<ast::CallExpression>()) {
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include the
// function name in the traversal.
// to_visit.push_back(call->func);
push_list(call->args);
} else if (auto* type_ctor = expr->As<ast::TypeConstructorExpression>()) {
push_list(type_ctor->values);
} else if (auto* member = expr->As<ast::MemberAccessorExpression>()) {
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include the
// member name in the traversal.
// push_pair(member->structure, member->member);
to_visit.push_back(member->structure);
} else if (auto* unary = expr->As<ast::UnaryOpExpression>()) {
to_visit.push_back(unary->expr);
} else if (expr->IsAnyOf<ast::ScalarConstructorExpression,
ast::IdentifierExpression,
ast::PhonyExpression>()) {
// Leaf expression
} else {
TINT_ICE(AST, diags) << "unhandled expression type: "
<< expr->TypeInfo().name;
return false;
}
}
return true;
}
} // namespace ast
} // namespace tint
#endif // SRC_AST_TRAVERSE_EXPRESSIONS_H_

View File

@ -0,0 +1,262 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/ast/traverse_expressions.h"
#include "gmock/gmock.h"
#include "src/ast/test_helper.h"
namespace tint {
namespace ast {
namespace {
using ::testing::ElementsAre;
using TraverseExpressionsTest = TestHelper;
TEST_F(TraverseExpressionsTest, DescendArrayAccessorExpression) {
std::vector<const ast::Expression*> e = {Expr(1), Expr(1), Expr(1), Expr(1)};
std::vector<const ast::Expression*> i = {IndexAccessor(e[0], e[1]),
IndexAccessor(e[2], e[3])};
auto* root = IndexAccessor(i[0], i[1]);
{
std::vector<const ast::Expression*> l2r;
TraverseExpressions<TraverseOrder::LeftToRight>(
root, Diagnostics(), [&](const ast::Expression* expr) {
l2r.push_back(expr);
return ast::TraverseAction::Descend;
});
EXPECT_THAT(l2r, ElementsAre(root, i[0], e[0], e[1], i[1], e[2], e[3]));
}
{
std::vector<const ast::Expression*> r2l;
TraverseExpressions<TraverseOrder::RightToLeft>(
root, Diagnostics(), [&](const ast::Expression* expr) {
r2l.push_back(expr);
return ast::TraverseAction::Descend;
});
EXPECT_THAT(r2l, ElementsAre(root, i[1], e[3], e[2], i[0], e[1], e[0]));
}
}
TEST_F(TraverseExpressionsTest, DescendBinaryExpression) {
std::vector<const ast::Expression*> e = {Expr(1), Expr(1), Expr(1), Expr(1)};
std::vector<const ast::Expression*> i = {Add(e[0], e[1]), Sub(e[2], e[3])};
auto* root = Mul(i[0], i[1]);
{
std::vector<const ast::Expression*> l2r;
TraverseExpressions<TraverseOrder::LeftToRight>(
root, Diagnostics(), [&](const ast::Expression* expr) {
l2r.push_back(expr);
return ast::TraverseAction::Descend;
});
EXPECT_THAT(l2r, ElementsAre(root, i[0], e[0], e[1], i[1], e[2], e[3]));
}
{
std::vector<const ast::Expression*> r2l;
TraverseExpressions<TraverseOrder::RightToLeft>(
root, Diagnostics(), [&](const ast::Expression* expr) {
r2l.push_back(expr);
return ast::TraverseAction::Descend;
});
EXPECT_THAT(r2l, ElementsAre(root, i[1], e[3], e[2], i[0], e[1], e[0]));
}
}
TEST_F(TraverseExpressionsTest, DescendBitcastExpression) {
auto* e = Expr(1);
auto* b0 = Bitcast<i32>(e);
auto* b1 = Bitcast<i32>(b0);
auto* b2 = Bitcast<i32>(b1);
auto* root = Bitcast<i32>(b2);
{
std::vector<const ast::Expression*> l2r;
TraverseExpressions<TraverseOrder::LeftToRight>(
root, Diagnostics(), [&](const ast::Expression* expr) {
l2r.push_back(expr);
return ast::TraverseAction::Descend;
});
EXPECT_THAT(l2r, ElementsAre(root, b2, b1, b0, e));
}
{
std::vector<const ast::Expression*> r2l;
TraverseExpressions<TraverseOrder::RightToLeft>(
root, Diagnostics(), [&](const ast::Expression* expr) {
r2l.push_back(expr);
return ast::TraverseAction::Descend;
});
EXPECT_THAT(r2l, ElementsAre(root, b2, b1, b0, e));
}
}
TEST_F(TraverseExpressionsTest, DescendCallExpression) {
std::vector<const ast::Expression*> e = {Expr(1), Expr(1), Expr(1), Expr(1)};
std::vector<const ast::Expression*> c = {Call("a", e[0], e[1]),
Call("b", e[2], e[3])};
auto* root = Call("c", c[0], c[1]);
{
std::vector<const ast::Expression*> l2r;
TraverseExpressions<TraverseOrder::LeftToRight>(
root, Diagnostics(), [&](const ast::Expression* expr) {
l2r.push_back(expr);
return ast::TraverseAction::Descend;
});
EXPECT_THAT(l2r, ElementsAre(root, c[0], e[0], e[1], c[1], e[2], e[3]));
}
{
std::vector<const ast::Expression*> r2l;
TraverseExpressions<TraverseOrder::RightToLeft>(
root, Diagnostics(), [&](const ast::Expression* expr) {
r2l.push_back(expr);
return ast::TraverseAction::Descend;
});
EXPECT_THAT(r2l, ElementsAre(root, c[1], e[3], e[2], c[0], e[1], e[0]));
}
}
TEST_F(TraverseExpressionsTest, DescendTypeConstructorExpression) {
std::vector<const ast::Expression*> e = {Expr(1), Expr(1), Expr(1), Expr(1)};
std::vector<const ast::Expression*> c = {vec2<i32>(e[0], e[1]),
vec2<i32>(e[2], e[3])};
auto* root = vec2<i32>(c[0], c[1]);
{
std::vector<const ast::Expression*> l2r;
TraverseExpressions<TraverseOrder::LeftToRight>(
root, Diagnostics(), [&](const ast::Expression* expr) {
l2r.push_back(expr);
return ast::TraverseAction::Descend;
});
EXPECT_THAT(l2r, ElementsAre(root, c[0], e[0], e[1], c[1], e[2], e[3]));
}
{
std::vector<const ast::Expression*> r2l;
TraverseExpressions<TraverseOrder::RightToLeft>(
root, Diagnostics(), [&](const ast::Expression* expr) {
r2l.push_back(expr);
return ast::TraverseAction::Descend;
});
EXPECT_THAT(r2l, ElementsAre(root, c[1], e[3], e[2], c[0], e[1], e[0]));
}
}
// TODO(crbug.com/tint/1257): Test ignores member accessor 'member' field.
// Replace with the test below when fixed.
TEST_F(TraverseExpressionsTest, DescendMemberIndexExpression) {
auto* e = Expr(1);
auto* m = MemberAccessor(e, Expr("a"));
auto* root = MemberAccessor(m, Expr("b"));
{
std::vector<const ast::Expression*> l2r;
TraverseExpressions<TraverseOrder::LeftToRight>(
root, Diagnostics(), [&](const ast::Expression* expr) {
l2r.push_back(expr);
return ast::TraverseAction::Descend;
});
EXPECT_THAT(l2r, ElementsAre(root, m, e));
}
{
std::vector<const ast::Expression*> r2l;
TraverseExpressions<TraverseOrder::RightToLeft>(
root, Diagnostics(), [&](const ast::Expression* expr) {
r2l.push_back(expr);
return ast::TraverseAction::Descend;
});
EXPECT_THAT(r2l, ElementsAre(root, m, e));
}
}
// TODO(crbug.com/tint/1257): The correct test for DescendMemberIndexExpression.
TEST_F(TraverseExpressionsTest, DISABLED_DescendMemberIndexExpression) {
auto* e = Expr(1);
std::vector<const ast::IdentifierExpression*> i = {Expr("a"), Expr("b")};
auto* m = MemberAccessor(e, i[0]);
auto* root = MemberAccessor(m, i[1]);
{
std::vector<const ast::Expression*> l2r;
TraverseExpressions<TraverseOrder::LeftToRight>(
root, Diagnostics(), [&](const ast::Expression* expr) {
l2r.push_back(expr);
return ast::TraverseAction::Descend;
});
EXPECT_THAT(l2r, ElementsAre(root, m, e, i[0], i[1]));
}
{
std::vector<const ast::Expression*> r2l;
TraverseExpressions<TraverseOrder::RightToLeft>(
root, Diagnostics(), [&](const ast::Expression* expr) {
r2l.push_back(expr);
return ast::TraverseAction::Descend;
});
EXPECT_THAT(r2l, ElementsAre(root, i[1], m, i[0], e));
}
}
TEST_F(TraverseExpressionsTest, DescendUnaryExpression) {
auto* e = Expr(1);
auto* u0 = AddressOf(e);
auto* u1 = Deref(u0);
auto* u2 = AddressOf(u1);
auto* root = Deref(u2);
{
std::vector<const ast::Expression*> l2r;
TraverseExpressions<TraverseOrder::LeftToRight>(
root, Diagnostics(), [&](const ast::Expression* expr) {
l2r.push_back(expr);
return ast::TraverseAction::Descend;
});
EXPECT_THAT(l2r, ElementsAre(root, u2, u1, u0, e));
}
{
std::vector<const ast::Expression*> r2l;
TraverseExpressions<TraverseOrder::RightToLeft>(
root, Diagnostics(), [&](const ast::Expression* expr) {
r2l.push_back(expr);
return ast::TraverseAction::Descend;
});
EXPECT_THAT(r2l, ElementsAre(root, u2, u1, u0, e));
}
}
TEST_F(TraverseExpressionsTest, Skip) {
std::vector<const ast::Expression*> e = {Expr(1), Expr(1), Expr(1), Expr(1)};
std::vector<const ast::Expression*> i = {IndexAccessor(e[0], e[1]),
IndexAccessor(e[2], e[3])};
auto* root = IndexAccessor(i[0], i[1]);
std::vector<const ast::Expression*> order;
TraverseExpressions<TraverseOrder::LeftToRight>(
root, Diagnostics(), [&](const ast::Expression* expr) {
order.push_back(expr);
return expr == i[0] ? ast::TraverseAction::Skip
: ast::TraverseAction::Descend;
});
EXPECT_THAT(order, ElementsAre(root, i[0], i[1], e[2], e[3]));
}
TEST_F(TraverseExpressionsTest, Stop) {
std::vector<const ast::Expression*> e = {Expr(1), Expr(1), Expr(1), Expr(1)};
std::vector<const ast::Expression*> i = {IndexAccessor(e[0], e[1]),
IndexAccessor(e[2], e[3])};
auto* root = IndexAccessor(i[0], i[1]);
std::vector<const ast::Expression*> order;
TraverseExpressions<TraverseOrder::LeftToRight>(
root, Diagnostics(), [&](const ast::Expression* expr) {
order.push_back(expr);
return expr == i[0] ? ast::TraverseAction::Stop
: ast::TraverseAction::Descend;
});
EXPECT_THAT(order, ElementsAre(root, i[0]));
}
} // namespace
} // namespace ast
} // namespace tint

View File

@ -45,6 +45,7 @@
#include "src/ast/storage_texture.h"
#include "src/ast/struct_block_decoration.h"
#include "src/ast/switch_statement.h"
#include "src/ast/traverse_expressions.h"
#include "src/ast/type_name.h"
#include "src/ast/unary_op_expression.h"
#include "src/ast/variable_decl_statement.h"
@ -469,7 +470,6 @@ Resolver::VariableInfo* Resolver::Variable(const ast::Variable* var,
// Does the variable have a constructor?
if (auto* ctor = var->constructor) {
Mark(var->constructor);
if (!Expression(var->constructor)) {
return nullptr;
}
@ -1886,7 +1886,6 @@ bool Resolver::Function(const ast::Function* func) {
continue;
}
Mark(expr);
if (!Expression(expr)) {
return false;
}
@ -2061,7 +2060,6 @@ bool Resolver::Statement(const ast::Statement* stmt) {
return true;
}
if (auto* c = stmt->As<ast::CallStatement>()) {
Mark(c->expr);
if (!Expression(c->expr)) {
return false;
}
@ -2138,7 +2136,6 @@ bool Resolver::IfStatement(const ast::IfStatement* stmt) {
builder_->create<sem::IfStatement>(stmt, current_compound_statement_);
builder_->Sem().Add(stmt, sem);
return Scope(sem, [&] {
Mark(stmt->condition);
if (!Expression(stmt->condition)) {
return false;
}
@ -2175,7 +2172,6 @@ bool Resolver::ElseStatement(const ast::ElseStatement* stmt) {
builder_->Sem().Add(stmt, sem);
return Scope(sem, [&] {
if (auto* cond = stmt->condition) {
Mark(cond);
if (!Expression(cond)) {
return false;
}
@ -2250,7 +2246,6 @@ bool Resolver::ForLoopStatement(const ast::ForLoopStatement* stmt) {
}
if (auto* condition = stmt->condition) {
Mark(condition);
if (!Expression(condition)) {
return false;
}
@ -2279,58 +2274,14 @@ bool Resolver::ForLoopStatement(const ast::ForLoopStatement* stmt) {
});
}
bool Resolver::TraverseExpressions(const ast::Expression* root,
std::vector<const ast::Expression*>& out) {
std::vector<const ast::Expression*> to_visit;
to_visit.emplace_back(root);
auto add = [&](const ast::Expression* e) {
Mark(e);
to_visit.emplace_back(e);
};
while (!to_visit.empty()) {
auto* expr = to_visit.back();
to_visit.pop_back();
out.emplace_back(expr);
if (auto* array = expr->As<ast::ArrayAccessorExpression>()) {
add(array->array);
add(array->index);
} else if (auto* bin_op = expr->As<ast::BinaryExpression>()) {
add(bin_op->lhs);
add(bin_op->rhs);
} else if (auto* bitcast = expr->As<ast::BitcastExpression>()) {
add(bitcast->expr);
} else if (auto* call = expr->As<ast::CallExpression>()) {
for (auto* arg : call->args) {
add(arg);
}
} else if (auto* type_ctor = expr->As<ast::TypeConstructorExpression>()) {
for (auto* value : type_ctor->values) {
add(value);
}
} else if (auto* member = expr->As<ast::MemberAccessorExpression>()) {
add(member->structure);
} else if (auto* unary = expr->As<ast::UnaryOpExpression>()) {
add(unary->expr);
} else if (expr->IsAnyOf<ast::ScalarConstructorExpression,
ast::IdentifierExpression>()) {
// Leaf expression
} else {
TINT_ICE(Resolver, diagnostics_)
<< "unhandled expression type: " << expr->TypeInfo().name;
return false;
}
}
return true;
}
bool Resolver::Expression(const ast::Expression* root) {
std::vector<const ast::Expression*> sorted;
if (!TraverseExpressions(root, sorted)) {
if (!ast::TraverseExpressions<ast::TraverseOrder::RightToLeft>(
root, diagnostics_, [&](const ast::Expression* expr) {
Mark(expr);
sorted.emplace_back(expr);
return ast::TraverseAction::Descend;
})) {
return false;
}
@ -3874,7 +3825,6 @@ sem::Array* Resolver::Array(const ast::Array* arr) {
// sem::Array uses a size of 0 for a runtime-sized array.
uint32_t count = 0;
if (auto* count_expr = arr->count) {
Mark(count_expr);
if (!Expression(count_expr)) {
return nullptr;
}
@ -4340,7 +4290,6 @@ bool Resolver::Return(const ast::ReturnStatement* ret) {
current_function_->return_statements.push_back(ret);
if (auto* value = ret->value) {
Mark(value);
if (!Expression(value)) {
return false;
}
@ -4424,7 +4373,6 @@ bool Resolver::SwitchStatement(const ast::SwitchStatement* stmt) {
builder_->create<sem::SwitchStatement>(stmt, current_compound_statement_);
builder_->Sem().Add(stmt, sem);
return Scope(sem, [&] {
Mark(stmt->condition);
if (!Expression(stmt->condition)) {
return false;
}
@ -4442,9 +4390,6 @@ bool Resolver::SwitchStatement(const ast::SwitchStatement* stmt) {
}
bool Resolver::Assignment(const ast::AssignmentStatement* a) {
Mark(a->lhs);
Mark(a->rhs);
if (!Expression(a->lhs) || !Expression(a->rhs)) {
return false;
}

View File

@ -262,15 +262,6 @@ class Resolver {
bool UnaryOp(const ast::UnaryOpExpression*);
bool VariableDeclStatement(const ast::VariableDeclStatement*);
/// Performs a depth-first traversal of the expression nodes from `root`,
/// collecting all the visited expressions in pre-ordering (root first).
/// @param root the root expression node
/// @param out the ordered list of visited expression nodes, starting with the
/// root node, and ending with leaf nodes
/// @return true on success, false on error
bool TraverseExpressions(const ast::Expression* root,
std::vector<const ast::Expression*>& out);
// AST and Type validation methods
// Each return true on success, false on failure.
bool ValidateArray(const sem::Array* arr, const Source& source);

View File

@ -207,6 +207,7 @@ tint_unittests_source_set("tint_unittests_core_src") {
"../src/ast/switch_statement_test.cc",
"../src/ast/test_helper.h",
"../src/ast/texture_test.cc",
"../src/ast/traverse_expressions_test.cc",
"../src/ast/type_constructor_expression_test.cc",
"../src/ast/u32_test.cc",
"../src/ast/uint_literal_test.cc",