diff --git a/src/resolver/function_validation_test.cc b/src/resolver/function_validation_test.cc index b0b8437f5d..5f38b877ea 100644 --- a/src/resolver/function_validation_test.cc +++ b/src/resolver/function_validation_test.cc @@ -26,8 +26,8 @@ class ResolverFunctionValidationTest : public resolver::TestHelper, public testing::Test {}; TEST_F(ResolverFunctionValidationTest, FunctionNamesMustBeUnique_fail) { - // fn func -> i32 { return 2; } - // fn func -> i32 { return 2; } + // fn func() -> i32 { return 2; } + // fn func() -> i32 { return 2; } Func(Source{{56, 78}}, "func", ast::VariableList{}, ty.i32(), ast::StatementList{ Return(2), @@ -189,7 +189,7 @@ TEST_F(ResolverFunctionValidationTest, UnreachableCode_return) { } TEST_F(ResolverFunctionValidationTest, FunctionEndWithoutReturnStatement_Fail) { - // fn func -> int { var a:i32 = 2; } + // fn func() -> int { var a:i32 = 2; } auto* var = Var("a", ty.i32(), ast::StorageClass::kNone, Expr(2)); @@ -216,7 +216,7 @@ TEST_F(ResolverFunctionValidationTest, TEST_F(ResolverFunctionValidationTest, FunctionEndWithoutReturnStatementEmptyBody_Fail) { - // fn func -> int {} + // fn func() -> int {} Func(Source{Source::Location{12, 34}}, "func", ast::VariableList{}, ty.i32(), ast::StatementList{}, ast::DecorationList{}); @@ -269,7 +269,7 @@ TEST_F(ResolverFunctionValidationTest, TEST_F(ResolverFunctionValidationTest, FunctionTypeMustMatchReturnStatementTypeMissing_fail) { - // fn func -> f32 { return; } + // fn func() -> f32 { return; } Func("func", ast::VariableList{}, ty.f32(), ast::StatementList{ Return(Source{Source::Location{12, 34}}, nullptr), @@ -284,7 +284,7 @@ TEST_F(ResolverFunctionValidationTest, TEST_F(ResolverFunctionValidationTest, FunctionTypeMustMatchReturnStatementTypeF32_pass) { - // fn func -> f32 { return 2.0; } + // fn func() -> f32 { return 2.0; } Func("func", ast::VariableList{}, ty.f32(), ast::StatementList{ Return(Source{Source::Location{12, 34}}, Expr(2.f)), @@ -296,7 +296,7 @@ TEST_F(ResolverFunctionValidationTest, TEST_F(ResolverFunctionValidationTest, FunctionTypeMustMatchReturnStatementTypeF32_fail) { - // fn func -> f32 { return 2; } + // fn func() -> f32 { return 2; } Func("func", ast::VariableList{}, ty.f32(), ast::StatementList{ Return(Source{Source::Location{12, 34}}, Expr(2)), @@ -312,7 +312,7 @@ TEST_F(ResolverFunctionValidationTest, TEST_F(ResolverFunctionValidationTest, FunctionTypeMustMatchReturnStatementTypeF32Alias_pass) { // type myf32 = f32; - // fn func -> myf32 { return 2.0; } + // fn func() -> myf32 { return 2.0; } auto* myf32 = Alias("myf32", ty.f32()); Func("func", ast::VariableList{}, ty.Of(myf32), ast::StatementList{ @@ -326,7 +326,7 @@ TEST_F(ResolverFunctionValidationTest, TEST_F(ResolverFunctionValidationTest, FunctionTypeMustMatchReturnStatementTypeF32Alias_fail) { // type myf32 = f32; - // fn func -> myf32 { return 2; } + // fn func() -> myf32 { return 2; } auto* myf32 = Alias("myf32", ty.f32()); Func("func", ast::VariableList{}, ty.Of(myf32), ast::StatementList{ @@ -340,6 +340,24 @@ TEST_F(ResolverFunctionValidationTest, "type, returned 'u32', expected 'myf32'"); } +TEST_F(ResolverFunctionValidationTest, CannotCallEntryPoint) { + // [[stage(compute), workgroup_size(1)]] fn entrypoint() {} + // fn func() { return entrypoint(); } + Func("entrypoint", ast::VariableList{}, ty.void_(), {}, + {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1)}); + + Func("func", ast::VariableList{}, ty.void_(), + { + create(Call(Source{{12, 34}}, "entrypoint")), + }); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ( + r()->error(), + + R"(12:34 error: entry point functions cannot be the target of a function call)"); +} + TEST_F(ResolverFunctionValidationTest, PipelineStage_MustBeUnique_Fail) { // [[stage(fragment)]] // [[stage(vertex)]] diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index b61cf818d4..1c9e86705b 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -2553,22 +2553,30 @@ bool Resolver::FunctionCall(const ast::CallExpression* call) { } bool Resolver::ValidateFunctionCall(const ast::CallExpression* call, - const FunctionInfo* callee_func) { + const FunctionInfo* target) { auto* ident = call->func(); auto name = builder_->Symbols().NameFor(ident->symbol()); - if (call->params().size() != callee_func->parameters.size()) { - bool more = call->params().size() > callee_func->parameters.size(); + if (target->declaration->IsEntryPoint()) { + // https://www.w3.org/TR/WGSL/#function-restriction + // An entry point must never be the target of a function call. + AddError("entry point functions cannot be the target of a function call", + call->source()); + return false; + } + + if (call->params().size() != target->parameters.size()) { + bool more = call->params().size() > target->parameters.size(); AddError("too " + (more ? std::string("many") : std::string("few")) + " arguments in call to '" + name + "', expected " + - std::to_string(callee_func->parameters.size()) + ", got " + + std::to_string(target->parameters.size()) + ", got " + std::to_string(call->params().size()), call->source()); return false; } for (size_t i = 0; i < call->params().size(); ++i) { - const VariableInfo* param = callee_func->parameters[i]; + const VariableInfo* param = target->parameters[i]; const ast::Expression* arg_expr = call->params()[i]; auto* arg_type = TypeOf(arg_expr)->UnwrapRef(); diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index dbfe50679a..9f02791d5d 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -281,7 +281,7 @@ class Resolver { bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info); bool ValidateFunction(const ast::Function* func, const FunctionInfo* info); bool ValidateFunctionCall(const ast::CallExpression* call, - const FunctionInfo* info); + const FunctionInfo* target); bool ValidateGlobalVariable(const VariableInfo* var); bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco, const sem::Type* storage_type);