diff --git a/src/resolver/builtins_validation_test.cc b/src/resolver/builtins_validation_test.cc index 23660f7375..6c0f20a115 100644 --- a/src/resolver/builtins_validation_test.cc +++ b/src/resolver/builtins_validation_test.cc @@ -39,6 +39,18 @@ TEST_F(ResolverBuiltinsValidationTest, PositionNotF32_Struct_Fail) { "12:34 error: store type of builtin(position) must be 'vec4'"); } +TEST_F(ResolverBuiltinsValidationTest, PositionNotF32_ReturnType_Fail) { + // [[stage(vertex)]] + // fn main() -> [[builtin(position)]] f32 { return 1.0; } + Func("main", {}, ty.f32(), {Return(1.0f)}, + {Stage(ast::PipelineStage::kVertex)}, + {Builtin(Source{{12, 34}}, ast::Builtin::kPosition)}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: store type of builtin(position) must be 'vec4'"); +} + TEST_F(ResolverBuiltinsValidationTest, FragDepthNotF32_Struct_Fail) { // struct MyInputs { // [[builtin(kFragDepth)]] p: i32; @@ -77,6 +89,18 @@ TEST_F(ResolverBuiltinsValidationTest, SampleMaskNotU32_Struct_Fail) { "12:34 error: store type of builtin(sample_mask) must be 'u32'"); } +TEST_F(ResolverBuiltinsValidationTest, SampleMaskNotU32_ReturnType_Fail) { + // [[stage(fragment)]] + // fn main() -> [[builtin(sample_mask)]] i32 { return 1; } + Func("main", {}, ty.i32(), {Return(1)}, + {Stage(ast::PipelineStage::kFragment)}, + {Builtin(Source{{12, 34}}, ast::Builtin::kSampleMask)}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: store type of builtin(sample_mask) must be 'u32'"); +} + TEST_F(ResolverBuiltinsValidationTest, SampleMaskIsNotU32_Fail) { // [[stage(fragment)]] // fn fs_main( diff --git a/src/resolver/decoration_validation_test.cc b/src/resolver/decoration_validation_test.cc index 8164445f0a..8702d2d876 100644 --- a/src/resolver/decoration_validation_test.cc +++ b/src/resolver/decoration_validation_test.cc @@ -230,8 +230,7 @@ TEST_P(FunctionReturnTypeDecorationTest, IsValid) { auto& params = GetParam(); Func("main", ast::VariableList{}, ty.f32(), ast::StatementList{Return(1.f)}, - ast::DecorationList{Stage(ast::PipelineStage::kCompute)}, - createDecorations({}, *this, params.kind)); + {}, createDecorations({}, *this, params.kind)); if (params.should_pass) { EXPECT_TRUE(r()->Resolve()) << r()->error(); @@ -244,6 +243,41 @@ TEST_P(FunctionReturnTypeDecorationTest, IsValid) { INSTANTIATE_TEST_SUITE_P( ResolverDecorationValidationTest, FunctionReturnTypeDecorationTest, + testing::Values(TestParams{DecorationKind::kAlign, false}, + TestParams{DecorationKind::kBinding, false}, + TestParams{DecorationKind::kBuiltin, false}, + TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kLocation, false}, + TestParams{DecorationKind::kOverride, false}, + TestParams{DecorationKind::kOffset, false}, + TestParams{DecorationKind::kSize, false}, + TestParams{DecorationKind::kStage, false}, + TestParams{DecorationKind::kStride, false}, + TestParams{DecorationKind::kStructBlock, false}, + TestParams{DecorationKind::kWorkgroup, false}, + TestParams{DecorationKind::kBindingAndGroup, false})); + +using EntryPointReturnTypeDecorationTest = TestWithParams; +TEST_P(EntryPointReturnTypeDecorationTest, IsValid) { + auto& params = GetParam(); + + Func("main", ast::VariableList{}, ty.vec4(), + {Return(Construct(ty.vec4(), 1.f))}, + {Stage(ast::PipelineStage::kCompute)}, + createDecorations({}, *this, params.kind)); + + if (params.should_pass) { + EXPECT_TRUE(r()->Resolve()) << r()->error(); + } else { + EXPECT_FALSE(r()->Resolve()) << r()->error(); + EXPECT_EQ(r()->error(), + "error: decoration is not valid for entry point return types"); + } +} + +INSTANTIATE_TEST_SUITE_P( + ResolverDecorationValidationTest, + EntryPointReturnTypeDecorationTest, testing::Values(TestParams{DecorationKind::kAlign, false}, TestParams{DecorationKind::kBinding, false}, TestParams{DecorationKind::kBuiltin, true}, @@ -258,9 +292,9 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kWorkgroup, false}, TestParams{DecorationKind::kBindingAndGroup, false})); -TEST_F(FunctionReturnTypeDecorationTest, DuplicateDecoration) { +TEST_F(EntryPointReturnTypeDecorationTest, DuplicateDecoration) { Func("main", ast::VariableList{}, ty.f32(), ast::StatementList{Return(1.f)}, - ast::DecorationList{Stage(ast::PipelineStage::kCompute)}, + ast::DecorationList{Stage(ast::PipelineStage::kFragment)}, ast::DecorationList{ Location(Source{{12, 34}}, 2), Location(Source{{56, 78}}, 3), diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 405f15cf20..0a083b09f5 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -1084,11 +1084,21 @@ bool Resolver::ValidateFunction(const ast::Function* func, } for (auto* deco : func->return_type_decorations()) { - if (!deco->IsAnyOf()) { + if (!func->IsEntryPoint()) { AddError("decoration is not valid for function return types", deco->source()); return false; } + + if (auto* builtin = deco->As()) { + if (!ValidateBuiltinDecoration(builtin, info->return_type)) { + return false; + } + } else if (!deco->Is()) { + AddError("decoration is not valid for entry point return types", + deco->source()); + return false; + } } }