tint: limit expression depth to avoid stack overflow in backends
Bug: chromium:1324533 Change-Id: I2a334eaee59b2235830057b78c92b919ff0ea940 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/90302 Reviewed-by: Ben Clayton <bclayton@google.com> Commit-Queue: Antonio Maiorano <amaiorano@google.com> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
parent
8ee1e11be7
commit
8ba6e1d6ec
|
@ -54,40 +54,59 @@ enum class TraverseOrder {
|
||||||
/// @param root the root expression node
|
/// @param root the root expression node
|
||||||
/// @param diags the diagnostics used for error messages
|
/// @param diags the diagnostics used for error messages
|
||||||
/// @param callback the callback function. Must be of the signature:
|
/// @param callback the callback function. Must be of the signature:
|
||||||
/// `TraverseAction(const T*)` where T is an ast::Expression type.
|
/// `TraverseAction(const T* expr)` or `TraverseAction(const T* expr, size_t depth)` where T
|
||||||
|
/// is an ast::Expression type.
|
||||||
/// @return true on success, false on error
|
/// @return true on success, false on error
|
||||||
template <TraverseOrder ORDER = TraverseOrder::LeftToRight, typename CALLBACK>
|
template <TraverseOrder ORDER = TraverseOrder::LeftToRight, typename CALLBACK>
|
||||||
bool TraverseExpressions(const ast::Expression* root, diag::List& diags, CALLBACK&& callback) {
|
bool TraverseExpressions(const ast::Expression* root, diag::List& diags, CALLBACK&& callback) {
|
||||||
using EXPR_TYPE = std::remove_pointer_t<traits::ParameterType<CALLBACK, 0>>;
|
using EXPR_TYPE = std::remove_pointer_t<traits::ParameterType<CALLBACK, 0>>;
|
||||||
std::vector<const ast::Expression*> to_visit{root};
|
constexpr static bool kHasDepthArg = traits::SignatureOfT<CALLBACK>::parameter_count == 2;
|
||||||
|
|
||||||
auto push_pair = [&](const ast::Expression* left, const ast::Expression* right) {
|
struct Pending {
|
||||||
|
const ast::Expression* expr;
|
||||||
|
size_t depth;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<Pending> to_visit{{root, 0}};
|
||||||
|
|
||||||
|
auto push_single = [&](const ast::Expression* expr, size_t depth) {
|
||||||
|
to_visit.push_back({expr, depth});
|
||||||
|
};
|
||||||
|
auto push_pair = [&](const ast::Expression* left, const ast::Expression* right, size_t depth) {
|
||||||
if (ORDER == TraverseOrder::LeftToRight) {
|
if (ORDER == TraverseOrder::LeftToRight) {
|
||||||
to_visit.push_back(right);
|
to_visit.push_back({right, depth});
|
||||||
to_visit.push_back(left);
|
to_visit.push_back({left, depth});
|
||||||
} else {
|
} else {
|
||||||
to_visit.push_back(left);
|
to_visit.push_back({left, depth});
|
||||||
to_visit.push_back(right);
|
to_visit.push_back({right, depth});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
auto push_list = [&](const std::vector<const ast::Expression*>& exprs) {
|
auto push_list = [&](const std::vector<const ast::Expression*>& exprs, size_t depth) {
|
||||||
if (ORDER == TraverseOrder::LeftToRight) {
|
if (ORDER == TraverseOrder::LeftToRight) {
|
||||||
for (auto* expr : utils::Reverse(exprs)) {
|
for (auto* expr : utils::Reverse(exprs)) {
|
||||||
to_visit.push_back(expr);
|
to_visit.push_back({expr, depth});
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (auto* expr : exprs) {
|
for (auto* expr : exprs) {
|
||||||
to_visit.push_back(expr);
|
to_visit.push_back({expr, depth});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
while (!to_visit.empty()) {
|
while (!to_visit.empty()) {
|
||||||
auto* expr = to_visit.back();
|
auto& p = to_visit.back();
|
||||||
to_visit.pop_back();
|
to_visit.pop_back();
|
||||||
|
const ast::Expression* expr = p.expr;
|
||||||
|
|
||||||
if (auto* filtered = expr->As<EXPR_TYPE>()) {
|
if (auto* filtered = expr->template As<EXPR_TYPE>()) {
|
||||||
switch (callback(filtered)) {
|
TraverseAction result;
|
||||||
|
if constexpr (kHasDepthArg) {
|
||||||
|
result = callback(filtered, p.depth);
|
||||||
|
} else {
|
||||||
|
result = callback(filtered);
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (result) {
|
||||||
case TraverseAction::Stop:
|
case TraverseAction::Stop:
|
||||||
return true;
|
return true;
|
||||||
case TraverseAction::Skip:
|
case TraverseAction::Skip:
|
||||||
|
@ -100,32 +119,31 @@ bool TraverseExpressions(const ast::Expression* root, diag::List& diags, CALLBAC
|
||||||
bool ok = Switch(
|
bool ok = Switch(
|
||||||
expr,
|
expr,
|
||||||
[&](const IndexAccessorExpression* idx) {
|
[&](const IndexAccessorExpression* idx) {
|
||||||
push_pair(idx->object, idx->index);
|
push_pair(idx->object, idx->index, p.depth + 1);
|
||||||
return true;
|
return true;
|
||||||
},
|
},
|
||||||
[&](const BinaryExpression* bin_op) {
|
[&](const BinaryExpression* bin_op) {
|
||||||
push_pair(bin_op->lhs, bin_op->rhs);
|
push_pair(bin_op->lhs, bin_op->rhs, p.depth + 1);
|
||||||
return true;
|
return true;
|
||||||
},
|
},
|
||||||
[&](const BitcastExpression* bitcast) {
|
[&](const BitcastExpression* bitcast) {
|
||||||
to_visit.push_back(bitcast->expr);
|
push_single(bitcast->expr, p.depth + 1);
|
||||||
return true;
|
return true;
|
||||||
},
|
},
|
||||||
[&](const CallExpression* call) {
|
[&](const CallExpression* call) {
|
||||||
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include
|
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include
|
||||||
// the function name in the traversal. to_visit.push_back(call->func);
|
// the function name in the traversal. push_single(call->func);
|
||||||
push_list(call->args);
|
push_list(call->args, p.depth + 1);
|
||||||
return true;
|
return true;
|
||||||
},
|
},
|
||||||
[&](const MemberAccessorExpression* member) {
|
[&](const MemberAccessorExpression* member) {
|
||||||
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include
|
// TODO(crbug.com/tint/1257): Resolver breaks if we actually include
|
||||||
// the member name in the traversal. push_pair(member->structure,
|
// the member name in the traversal. push_pair(member->member, p.depth + 1);
|
||||||
// member->member);
|
push_single(member->structure, p.depth + 1);
|
||||||
to_visit.push_back(member->structure);
|
|
||||||
return true;
|
return true;
|
||||||
},
|
},
|
||||||
[&](const UnaryOpExpression* unary) {
|
[&](const UnaryOpExpression* unary) {
|
||||||
to_visit.push_back(unary->expr);
|
push_single(unary->expr, p.depth + 1);
|
||||||
return true;
|
return true;
|
||||||
},
|
},
|
||||||
[&](Default) {
|
[&](Default) {
|
||||||
|
|
|
@ -73,6 +73,23 @@ TEST_F(TraverseExpressionsTest, DescendBinaryExpression) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TraverseExpressionsTest, Depth) {
|
||||||
|
std::vector<const ast::Expression*> e = {Expr(1_i), Expr(1_i), Expr(1_i), Expr(1_i)};
|
||||||
|
std::vector<const ast::Expression*> i = {Add(e[0], e[1]), Sub(e[2], e[3])};
|
||||||
|
auto* root = Mul(i[0], i[1]);
|
||||||
|
|
||||||
|
size_t j = 0;
|
||||||
|
size_t depths[] = {0, 1, 2, 2, 1, 2, 2};
|
||||||
|
{
|
||||||
|
TraverseExpressions<TraverseOrder::LeftToRight>( //
|
||||||
|
root, Diagnostics(), [&](const ast::Expression* expr, size_t depth) {
|
||||||
|
(void)expr;
|
||||||
|
EXPECT_THAT(depth, depths[j++]);
|
||||||
|
return ast::TraverseAction::Descend;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(TraverseExpressionsTest, DescendBitcastExpression) {
|
TEST_F(TraverseExpressionsTest, DescendBitcastExpression) {
|
||||||
auto* e = Expr(1_i);
|
auto* e = Expr(1_i);
|
||||||
auto* b0 = Bitcast<i32>(e);
|
auto* b0 = Bitcast<i32>(e);
|
||||||
|
|
|
@ -1025,11 +1025,19 @@ sem::ForLoopStatement* Resolver::ForLoopStatement(const ast::ForLoopStatement* s
|
||||||
|
|
||||||
sem::Expression* Resolver::Expression(const ast::Expression* root) {
|
sem::Expression* Resolver::Expression(const ast::Expression* root) {
|
||||||
std::vector<const ast::Expression*> sorted;
|
std::vector<const ast::Expression*> sorted;
|
||||||
bool mark_failed = false;
|
constexpr size_t kMaxExpressionDepth = 512U;
|
||||||
|
bool failed = false;
|
||||||
if (!ast::TraverseExpressions<ast::TraverseOrder::RightToLeft>(
|
if (!ast::TraverseExpressions<ast::TraverseOrder::RightToLeft>(
|
||||||
root, diagnostics_, [&](const ast::Expression* expr) {
|
root, diagnostics_, [&](const ast::Expression* expr, size_t depth) {
|
||||||
|
if (depth > kMaxExpressionDepth) {
|
||||||
|
AddError(
|
||||||
|
"reached max expression depth of " + std::to_string(kMaxExpressionDepth),
|
||||||
|
expr->source);
|
||||||
|
failed = true;
|
||||||
|
return ast::TraverseAction::Stop;
|
||||||
|
}
|
||||||
if (!Mark(expr)) {
|
if (!Mark(expr)) {
|
||||||
mark_failed = true;
|
failed = true;
|
||||||
return ast::TraverseAction::Stop;
|
return ast::TraverseAction::Stop;
|
||||||
}
|
}
|
||||||
sorted.emplace_back(expr);
|
sorted.emplace_back(expr);
|
||||||
|
@ -1038,7 +1046,7 @@ sem::Expression* Resolver::Expression(const ast::Expression* root) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (mark_failed) {
|
if (failed) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2098,5 +2098,33 @@ TEST_F(ResolverTest, ModuleDependencyOrderedDeclarations) {
|
||||||
ElementsAre(f0, v0, a0, s0, f1, v1, a1, s1, f2, v2, a2, s2));
|
ElementsAre(f0, v0, a0, s0, f1, v1, a1, s1, f2, v2, a2, s2));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
constexpr size_t kMaxExpressionDepth = 512U;
|
||||||
|
|
||||||
|
TEST_F(ResolverTest, MaxExpressionDepth_Pass) {
|
||||||
|
auto* b = Var("b", ty.i32());
|
||||||
|
const ast::Expression* chain = nullptr;
|
||||||
|
for (size_t i = 0; i < kMaxExpressionDepth; ++i) {
|
||||||
|
chain = Add(chain ? chain : Expr("b"), Expr("b"));
|
||||||
|
}
|
||||||
|
auto* a = Let("a", nullptr, chain);
|
||||||
|
WrapInFunction(b, a);
|
||||||
|
|
||||||
|
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ResolverTest, MaxExpressionDepth_Fail) {
|
||||||
|
auto* b = Var("b", ty.i32());
|
||||||
|
const ast::Expression* chain = nullptr;
|
||||||
|
for (size_t i = 0; i < kMaxExpressionDepth + 1; ++i) {
|
||||||
|
chain = Add(chain ? chain : Expr("b"), Expr("b"));
|
||||||
|
}
|
||||||
|
auto* a = Let("a", nullptr, chain);
|
||||||
|
WrapInFunction(b, a);
|
||||||
|
|
||||||
|
EXPECT_FALSE(r()->Resolve());
|
||||||
|
EXPECT_THAT(r()->error(), HasSubstr("error: reached max expression depth of " +
|
||||||
|
std::to_string(kMaxExpressionDepth)));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tint::resolver
|
} // namespace tint::resolver
|
||||||
|
|
Loading…
Reference in New Issue