tint: limit expression depth to avoid stack overflow in backends

Bug: chromium:1324533
Change-Id: I2a334eaee59b2235830057b78c92b919ff0ea940
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/90302
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Antonio Maiorano 2022-05-17 15:01:42 +00:00 committed by Dawn LUCI CQ
parent 8ee1e11be7
commit 8ba6e1d6ec
4 changed files with 97 additions and 26 deletions

View File

@ -54,40 +54,59 @@ enum class TraverseOrder {
/// @param root the root expression node
/// @param diags the diagnostics used for error messages
/// @param callback the callback function. Must be of the signature:
/// `TraverseAction(const T*)` where T is an ast::Expression type.
/// `TraverseAction(const T* expr)` or `TraverseAction(const T* expr, size_t depth)` where T
/// is an ast::Expression type.
/// @return true on success, false on error
template <TraverseOrder ORDER = TraverseOrder::LeftToRight, typename CALLBACK>
bool TraverseExpressions(const ast::Expression* root, diag::List& diags, CALLBACK&& callback) {
using EXPR_TYPE = std::remove_pointer_t<traits::ParameterType<CALLBACK, 0>>;
std::vector<const ast::Expression*> to_visit{root};
constexpr static bool kHasDepthArg = traits::SignatureOfT<CALLBACK>::parameter_count == 2;
auto push_pair = [&](const ast::Expression* left, const ast::Expression* right) {
struct Pending {
const ast::Expression* expr;
size_t depth;
};
std::vector<Pending> to_visit{{root, 0}};
auto push_single = [&](const ast::Expression* expr, size_t depth) {
to_visit.push_back({expr, depth});
};
auto push_pair = [&](const ast::Expression* left, const ast::Expression* right, size_t depth) {
if (ORDER == TraverseOrder::LeftToRight) {
to_visit.push_back(right);
to_visit.push_back(left);
to_visit.push_back({right, depth});
to_visit.push_back({left, depth});
} else {
to_visit.push_back(left);
to_visit.push_back(right);
to_visit.push_back({left, depth});
to_visit.push_back({right, depth});
}
};
auto push_list = [&](const std::vector<const ast::Expression*>& exprs) {
auto push_list = [&](const std::vector<const ast::Expression*>& exprs, size_t depth) {
if (ORDER == TraverseOrder::LeftToRight) {
for (auto* expr : utils::Reverse(exprs)) {
to_visit.push_back(expr);
to_visit.push_back({expr, depth});
}
} else {
for (auto* expr : exprs) {
to_visit.push_back(expr);
to_visit.push_back({expr, depth});
}
}
};
while (!to_visit.empty()) {
auto* expr = to_visit.back();
auto& p = to_visit.back();
to_visit.pop_back();
const ast::Expression* expr = p.expr;
if (auto* filtered = expr->As<EXPR_TYPE>()) {
switch (callback(filtered)) {
if (auto* filtered = expr->template As<EXPR_TYPE>()) {
TraverseAction result;
if constexpr (kHasDepthArg) {
result = callback(filtered, p.depth);
} else {
result = callback(filtered);
}
switch (result) {
case TraverseAction::Stop:
return true;
case TraverseAction::Skip:
@ -100,32 +119,31 @@ bool TraverseExpressions(const ast::Expression* root, diag::List& diags, CALLBAC
bool ok = Switch(
expr,
[&](const IndexAccessorExpression* idx) {
push_pair(idx->object, idx->index);
push_pair(idx->object, idx->index, p.depth + 1);
return true;
},
[&](const BinaryExpression* bin_op) {
push_pair(bin_op->lhs, bin_op->rhs);
push_pair(bin_op->lhs, bin_op->rhs, p.depth + 1);
return true;
},
[&](const BitcastExpression* bitcast) {
to_visit.push_back(bitcast->expr);
push_single(bitcast->expr, p.depth + 1);
return true;
},
[&](const CallExpression* call) {
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include
// the function name in the traversal. to_visit.push_back(call->func);
push_list(call->args);
// the function name in the traversal. push_single(call->func);
push_list(call->args, p.depth + 1);
return true;
},
[&](const MemberAccessorExpression* member) {
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include
// the member name in the traversal. push_pair(member->structure,
// member->member);
to_visit.push_back(member->structure);
// the member name in the traversal. push_pair(member->member, p.depth + 1);
push_single(member->structure, p.depth + 1);
return true;
},
[&](const UnaryOpExpression* unary) {
to_visit.push_back(unary->expr);
push_single(unary->expr, p.depth + 1);
return true;
},
[&](Default) {

View File

@ -73,6 +73,23 @@ TEST_F(TraverseExpressionsTest, DescendBinaryExpression) {
}
}
TEST_F(TraverseExpressionsTest, Depth) {
std::vector<const ast::Expression*> e = {Expr(1_i), Expr(1_i), Expr(1_i), Expr(1_i)};
std::vector<const ast::Expression*> i = {Add(e[0], e[1]), Sub(e[2], e[3])};
auto* root = Mul(i[0], i[1]);
size_t j = 0;
size_t depths[] = {0, 1, 2, 2, 1, 2, 2};
{
TraverseExpressions<TraverseOrder::LeftToRight>( //
root, Diagnostics(), [&](const ast::Expression* expr, size_t depth) {
(void)expr;
EXPECT_THAT(depth, depths[j++]);
return ast::TraverseAction::Descend;
});
}
}
TEST_F(TraverseExpressionsTest, DescendBitcastExpression) {
auto* e = Expr(1_i);
auto* b0 = Bitcast<i32>(e);

View File

@ -1025,11 +1025,19 @@ sem::ForLoopStatement* Resolver::ForLoopStatement(const ast::ForLoopStatement* s
sem::Expression* Resolver::Expression(const ast::Expression* root) {
std::vector<const ast::Expression*> sorted;
bool mark_failed = false;
constexpr size_t kMaxExpressionDepth = 512U;
bool failed = false;
if (!ast::TraverseExpressions<ast::TraverseOrder::RightToLeft>(
root, diagnostics_, [&](const ast::Expression* expr) {
root, diagnostics_, [&](const ast::Expression* expr, size_t depth) {
if (depth > kMaxExpressionDepth) {
AddError(
"reached max expression depth of " + std::to_string(kMaxExpressionDepth),
expr->source);
failed = true;
return ast::TraverseAction::Stop;
}
if (!Mark(expr)) {
mark_failed = true;
failed = true;
return ast::TraverseAction::Stop;
}
sorted.emplace_back(expr);
@ -1038,7 +1046,7 @@ sem::Expression* Resolver::Expression(const ast::Expression* root) {
return nullptr;
}
if (mark_failed) {
if (failed) {
return nullptr;
}

View File

@ -2098,5 +2098,33 @@ TEST_F(ResolverTest, ModuleDependencyOrderedDeclarations) {
ElementsAre(f0, v0, a0, s0, f1, v1, a1, s1, f2, v2, a2, s2));
}
constexpr size_t kMaxExpressionDepth = 512U;
TEST_F(ResolverTest, MaxExpressionDepth_Pass) {
auto* b = Var("b", ty.i32());
const ast::Expression* chain = nullptr;
for (size_t i = 0; i < kMaxExpressionDepth; ++i) {
chain = Add(chain ? chain : Expr("b"), Expr("b"));
}
auto* a = Let("a", nullptr, chain);
WrapInFunction(b, a);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverTest, MaxExpressionDepth_Fail) {
auto* b = Var("b", ty.i32());
const ast::Expression* chain = nullptr;
for (size_t i = 0; i < kMaxExpressionDepth + 1; ++i) {
chain = Add(chain ? chain : Expr("b"), Expr("b"));
}
auto* a = Let("a", nullptr, chain);
WrapInFunction(b, a);
EXPECT_FALSE(r()->Resolve());
EXPECT_THAT(r()->error(), HasSubstr("error: reached max expression depth of " +
std::to_string(kMaxExpressionDepth)));
}
} // namespace
} // namespace tint::resolver