resolver: Validate that entry points are not called

Fixed: chromium:1245112
Change-Id: Ibe7e686e7688761fd681bb6b1d331adad84f8d61
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/63161
Auto-Submit: Ben Clayton <bclayton@google.com>
Commit-Queue: David Neto <dneto@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton 2021-08-31 16:14:46 +00:00 committed by Tint LUCI CQ
parent 293d313bbc
commit b09723e58b
3 changed files with 41 additions and 15 deletions

View File

@ -26,8 +26,8 @@ class ResolverFunctionValidationTest : public resolver::TestHelper,
public testing::Test {};
TEST_F(ResolverFunctionValidationTest, FunctionNamesMustBeUnique_fail) {
// fn func -> i32 { return 2; }
// fn func -> i32 { return 2; }
// fn func() -> i32 { return 2; }
// fn func() -> i32 { return 2; }
Func(Source{{56, 78}}, "func", ast::VariableList{}, ty.i32(),
ast::StatementList{
Return(2),
@ -189,7 +189,7 @@ TEST_F(ResolverFunctionValidationTest, UnreachableCode_return) {
}
TEST_F(ResolverFunctionValidationTest, FunctionEndWithoutReturnStatement_Fail) {
// fn func -> int { var a:i32 = 2; }
// fn func() -> int { var a:i32 = 2; }
auto* var = Var("a", ty.i32(), ast::StorageClass::kNone, Expr(2));
@ -216,7 +216,7 @@ TEST_F(ResolverFunctionValidationTest,
TEST_F(ResolverFunctionValidationTest,
FunctionEndWithoutReturnStatementEmptyBody_Fail) {
// fn func -> int {}
// fn func() -> int {}
Func(Source{Source::Location{12, 34}}, "func", ast::VariableList{}, ty.i32(),
ast::StatementList{}, ast::DecorationList{});
@ -269,7 +269,7 @@ TEST_F(ResolverFunctionValidationTest,
TEST_F(ResolverFunctionValidationTest,
FunctionTypeMustMatchReturnStatementTypeMissing_fail) {
// fn func -> f32 { return; }
// fn func() -> f32 { return; }
Func("func", ast::VariableList{}, ty.f32(),
ast::StatementList{
Return(Source{Source::Location{12, 34}}, nullptr),
@ -284,7 +284,7 @@ TEST_F(ResolverFunctionValidationTest,
TEST_F(ResolverFunctionValidationTest,
FunctionTypeMustMatchReturnStatementTypeF32_pass) {
// fn func -> f32 { return 2.0; }
// fn func() -> f32 { return 2.0; }
Func("func", ast::VariableList{}, ty.f32(),
ast::StatementList{
Return(Source{Source::Location{12, 34}}, Expr(2.f)),
@ -296,7 +296,7 @@ TEST_F(ResolverFunctionValidationTest,
TEST_F(ResolverFunctionValidationTest,
FunctionTypeMustMatchReturnStatementTypeF32_fail) {
// fn func -> f32 { return 2; }
// fn func() -> f32 { return 2; }
Func("func", ast::VariableList{}, ty.f32(),
ast::StatementList{
Return(Source{Source::Location{12, 34}}, Expr(2)),
@ -312,7 +312,7 @@ TEST_F(ResolverFunctionValidationTest,
TEST_F(ResolverFunctionValidationTest,
FunctionTypeMustMatchReturnStatementTypeF32Alias_pass) {
// type myf32 = f32;
// fn func -> myf32 { return 2.0; }
// fn func() -> myf32 { return 2.0; }
auto* myf32 = Alias("myf32", ty.f32());
Func("func", ast::VariableList{}, ty.Of(myf32),
ast::StatementList{
@ -326,7 +326,7 @@ TEST_F(ResolverFunctionValidationTest,
TEST_F(ResolverFunctionValidationTest,
FunctionTypeMustMatchReturnStatementTypeF32Alias_fail) {
// type myf32 = f32;
// fn func -> myf32 { return 2; }
// fn func() -> myf32 { return 2; }
auto* myf32 = Alias("myf32", ty.f32());
Func("func", ast::VariableList{}, ty.Of(myf32),
ast::StatementList{
@ -340,6 +340,24 @@ TEST_F(ResolverFunctionValidationTest,
"type, returned 'u32', expected 'myf32'");
}
TEST_F(ResolverFunctionValidationTest, CannotCallEntryPoint) {
// [[stage(compute), workgroup_size(1)]] fn entrypoint() {}
// fn func() { return entrypoint(); }
Func("entrypoint", ast::VariableList{}, ty.void_(), {},
{Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)});
Func("func", ast::VariableList{}, ty.void_(),
{
create<ast::CallStatement>(Call(Source{{12, 34}}, "entrypoint")),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
R"(12:34 error: entry point functions cannot be the target of a function call)");
}
TEST_F(ResolverFunctionValidationTest, PipelineStage_MustBeUnique_Fail) {
// [[stage(fragment)]]
// [[stage(vertex)]]

View File

@ -2553,22 +2553,30 @@ bool Resolver::FunctionCall(const ast::CallExpression* call) {
}
bool Resolver::ValidateFunctionCall(const ast::CallExpression* call,
const FunctionInfo* callee_func) {
const FunctionInfo* target) {
auto* ident = call->func();
auto name = builder_->Symbols().NameFor(ident->symbol());
if (call->params().size() != callee_func->parameters.size()) {
bool more = call->params().size() > callee_func->parameters.size();
if (target->declaration->IsEntryPoint()) {
// https://www.w3.org/TR/WGSL/#function-restriction
// An entry point must never be the target of a function call.
AddError("entry point functions cannot be the target of a function call",
call->source());
return false;
}
if (call->params().size() != target->parameters.size()) {
bool more = call->params().size() > target->parameters.size();
AddError("too " + (more ? std::string("many") : std::string("few")) +
" arguments in call to '" + name + "', expected " +
std::to_string(callee_func->parameters.size()) + ", got " +
std::to_string(target->parameters.size()) + ", got " +
std::to_string(call->params().size()),
call->source());
return false;
}
for (size_t i = 0; i < call->params().size(); ++i) {
const VariableInfo* param = callee_func->parameters[i];
const VariableInfo* param = target->parameters[i];
const ast::Expression* arg_expr = call->params()[i];
auto* arg_type = TypeOf(arg_expr)->UnwrapRef();

View File

@ -281,7 +281,7 @@ class Resolver {
bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info);
bool ValidateFunction(const ast::Function* func, const FunctionInfo* info);
bool ValidateFunctionCall(const ast::CallExpression* call,
const FunctionInfo* info);
const FunctionInfo* target);
bool ValidateGlobalVariable(const VariableInfo* var);
bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco,
const sem::Type* storage_type);