tint: limit expression depth to avoid stack overflow in backends
Bug: chromium:1324533 Change-Id: I2a334eaee59b2235830057b78c92b919ff0ea940 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/90302 Reviewed-by: Ben Clayton <bclayton@google.com> Commit-Queue: Antonio Maiorano <amaiorano@google.com> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
parent
8ee1e11be7
commit
8ba6e1d6ec
|
@ -54,40 +54,59 @@ enum class TraverseOrder {
|
|||
/// @param root the root expression node
|
||||
/// @param diags the diagnostics used for error messages
|
||||
/// @param callback the callback function. Must be of the signature:
|
||||
/// `TraverseAction(const T*)` where T is an ast::Expression type.
|
||||
/// `TraverseAction(const T* expr)` or `TraverseAction(const T* expr, size_t depth)` where T
|
||||
/// is an ast::Expression type.
|
||||
/// @return true on success, false on error
|
||||
template <TraverseOrder ORDER = TraverseOrder::LeftToRight, typename CALLBACK>
|
||||
bool TraverseExpressions(const ast::Expression* root, diag::List& diags, CALLBACK&& callback) {
|
||||
using EXPR_TYPE = std::remove_pointer_t<traits::ParameterType<CALLBACK, 0>>;
|
||||
std::vector<const ast::Expression*> to_visit{root};
|
||||
constexpr static bool kHasDepthArg = traits::SignatureOfT<CALLBACK>::parameter_count == 2;
|
||||
|
||||
auto push_pair = [&](const ast::Expression* left, const ast::Expression* right) {
|
||||
struct Pending {
|
||||
const ast::Expression* expr;
|
||||
size_t depth;
|
||||
};
|
||||
|
||||
std::vector<Pending> to_visit{{root, 0}};
|
||||
|
||||
auto push_single = [&](const ast::Expression* expr, size_t depth) {
|
||||
to_visit.push_back({expr, depth});
|
||||
};
|
||||
auto push_pair = [&](const ast::Expression* left, const ast::Expression* right, size_t depth) {
|
||||
if (ORDER == TraverseOrder::LeftToRight) {
|
||||
to_visit.push_back(right);
|
||||
to_visit.push_back(left);
|
||||
to_visit.push_back({right, depth});
|
||||
to_visit.push_back({left, depth});
|
||||
} else {
|
||||
to_visit.push_back(left);
|
||||
to_visit.push_back(right);
|
||||
to_visit.push_back({left, depth});
|
||||
to_visit.push_back({right, depth});
|
||||
}
|
||||
};
|
||||
auto push_list = [&](const std::vector<const ast::Expression*>& exprs) {
|
||||
auto push_list = [&](const std::vector<const ast::Expression*>& exprs, size_t depth) {
|
||||
if (ORDER == TraverseOrder::LeftToRight) {
|
||||
for (auto* expr : utils::Reverse(exprs)) {
|
||||
to_visit.push_back(expr);
|
||||
to_visit.push_back({expr, depth});
|
||||
}
|
||||
} else {
|
||||
for (auto* expr : exprs) {
|
||||
to_visit.push_back(expr);
|
||||
to_visit.push_back({expr, depth});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
while (!to_visit.empty()) {
|
||||
auto* expr = to_visit.back();
|
||||
auto& p = to_visit.back();
|
||||
to_visit.pop_back();
|
||||
const ast::Expression* expr = p.expr;
|
||||
|
||||
if (auto* filtered = expr->As<EXPR_TYPE>()) {
|
||||
switch (callback(filtered)) {
|
||||
if (auto* filtered = expr->template As<EXPR_TYPE>()) {
|
||||
TraverseAction result;
|
||||
if constexpr (kHasDepthArg) {
|
||||
result = callback(filtered, p.depth);
|
||||
} else {
|
||||
result = callback(filtered);
|
||||
}
|
||||
|
||||
switch (result) {
|
||||
case TraverseAction::Stop:
|
||||
return true;
|
||||
case TraverseAction::Skip:
|
||||
|
@ -100,32 +119,31 @@ bool TraverseExpressions(const ast::Expression* root, diag::List& diags, CALLBAC
|
|||
bool ok = Switch(
|
||||
expr,
|
||||
[&](const IndexAccessorExpression* idx) {
|
||||
push_pair(idx->object, idx->index);
|
||||
push_pair(idx->object, idx->index, p.depth + 1);
|
||||
return true;
|
||||
},
|
||||
[&](const BinaryExpression* bin_op) {
|
||||
push_pair(bin_op->lhs, bin_op->rhs);
|
||||
push_pair(bin_op->lhs, bin_op->rhs, p.depth + 1);
|
||||
return true;
|
||||
},
|
||||
[&](const BitcastExpression* bitcast) {
|
||||
to_visit.push_back(bitcast->expr);
|
||||
push_single(bitcast->expr, p.depth + 1);
|
||||
return true;
|
||||
},
|
||||
[&](const CallExpression* call) {
|
||||
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include
|
||||
// the function name in the traversal. to_visit.push_back(call->func);
|
||||
push_list(call->args);
|
||||
// the function name in the traversal. push_single(call->func);
|
||||
push_list(call->args, p.depth + 1);
|
||||
return true;
|
||||
},
|
||||
[&](const MemberAccessorExpression* member) {
|
||||
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include
|
||||
// the member name in the traversal. push_pair(member->structure,
|
||||
// member->member);
|
||||
to_visit.push_back(member->structure);
|
||||
// the member name in the traversal. push_pair(member->member, p.depth + 1);
|
||||
push_single(member->structure, p.depth + 1);
|
||||
return true;
|
||||
},
|
||||
[&](const UnaryOpExpression* unary) {
|
||||
to_visit.push_back(unary->expr);
|
||||
push_single(unary->expr, p.depth + 1);
|
||||
return true;
|
||||
},
|
||||
[&](Default) {
|
||||
|
|
|
@ -73,6 +73,23 @@ TEST_F(TraverseExpressionsTest, DescendBinaryExpression) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(TraverseExpressionsTest, Depth) {
|
||||
std::vector<const ast::Expression*> e = {Expr(1_i), Expr(1_i), Expr(1_i), Expr(1_i)};
|
||||
std::vector<const ast::Expression*> i = {Add(e[0], e[1]), Sub(e[2], e[3])};
|
||||
auto* root = Mul(i[0], i[1]);
|
||||
|
||||
size_t j = 0;
|
||||
size_t depths[] = {0, 1, 2, 2, 1, 2, 2};
|
||||
{
|
||||
TraverseExpressions<TraverseOrder::LeftToRight>( //
|
||||
root, Diagnostics(), [&](const ast::Expression* expr, size_t depth) {
|
||||
(void)expr;
|
||||
EXPECT_THAT(depth, depths[j++]);
|
||||
return ast::TraverseAction::Descend;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TraverseExpressionsTest, DescendBitcastExpression) {
|
||||
auto* e = Expr(1_i);
|
||||
auto* b0 = Bitcast<i32>(e);
|
||||
|
|
|
@ -1025,11 +1025,19 @@ sem::ForLoopStatement* Resolver::ForLoopStatement(const ast::ForLoopStatement* s
|
|||
|
||||
sem::Expression* Resolver::Expression(const ast::Expression* root) {
|
||||
std::vector<const ast::Expression*> sorted;
|
||||
bool mark_failed = false;
|
||||
constexpr size_t kMaxExpressionDepth = 512U;
|
||||
bool failed = false;
|
||||
if (!ast::TraverseExpressions<ast::TraverseOrder::RightToLeft>(
|
||||
root, diagnostics_, [&](const ast::Expression* expr) {
|
||||
root, diagnostics_, [&](const ast::Expression* expr, size_t depth) {
|
||||
if (depth > kMaxExpressionDepth) {
|
||||
AddError(
|
||||
"reached max expression depth of " + std::to_string(kMaxExpressionDepth),
|
||||
expr->source);
|
||||
failed = true;
|
||||
return ast::TraverseAction::Stop;
|
||||
}
|
||||
if (!Mark(expr)) {
|
||||
mark_failed = true;
|
||||
failed = true;
|
||||
return ast::TraverseAction::Stop;
|
||||
}
|
||||
sorted.emplace_back(expr);
|
||||
|
@ -1038,7 +1046,7 @@ sem::Expression* Resolver::Expression(const ast::Expression* root) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
if (mark_failed) {
|
||||
if (failed) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -2098,5 +2098,33 @@ TEST_F(ResolverTest, ModuleDependencyOrderedDeclarations) {
|
|||
ElementsAre(f0, v0, a0, s0, f1, v1, a1, s1, f2, v2, a2, s2));
|
||||
}
|
||||
|
||||
constexpr size_t kMaxExpressionDepth = 512U;
|
||||
|
||||
TEST_F(ResolverTest, MaxExpressionDepth_Pass) {
|
||||
auto* b = Var("b", ty.i32());
|
||||
const ast::Expression* chain = nullptr;
|
||||
for (size_t i = 0; i < kMaxExpressionDepth; ++i) {
|
||||
chain = Add(chain ? chain : Expr("b"), Expr("b"));
|
||||
}
|
||||
auto* a = Let("a", nullptr, chain);
|
||||
WrapInFunction(b, a);
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverTest, MaxExpressionDepth_Fail) {
|
||||
auto* b = Var("b", ty.i32());
|
||||
const ast::Expression* chain = nullptr;
|
||||
for (size_t i = 0; i < kMaxExpressionDepth + 1; ++i) {
|
||||
chain = Add(chain ? chain : Expr("b"), Expr("b"));
|
||||
}
|
||||
auto* a = Let("a", nullptr, chain);
|
||||
WrapInFunction(b, a);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_THAT(r()->error(), HasSubstr("error: reached max expression depth of " +
|
||||
std::to_string(kMaxExpressionDepth)));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tint::resolver
|
||||
|
|
Loading…
Reference in New Issue