diff --git a/src/resolver/decoration_validation_test.cc b/src/resolver/decoration_validation_test.cc index ddf7b71789..8164445f0a 100644 --- a/src/resolver/decoration_validation_test.cc +++ b/src/resolver/decoration_validation_test.cc @@ -126,6 +126,105 @@ static ast::DecorationList createDecorations(const Source& source, return {}; } +using FunctionParameterDecorationTest = TestWithParams; +TEST_P(FunctionParameterDecorationTest, IsValid) { + auto& params = GetParam(); + + Func("main", + ast::VariableList{Param("a", ty.vec4(), + createDecorations({}, *this, params.kind))}, + ty.void_(), {}); + + if (params.should_pass) { + EXPECT_TRUE(r()->Resolve()) << r()->error(); + } else { + EXPECT_FALSE(r()->Resolve()) << r()->error(); + EXPECT_EQ(r()->error(), + "error: decoration is not valid for function parameters"); + } +} +INSTANTIATE_TEST_SUITE_P( + ResolverDecorationValidationTest, + FunctionParameterDecorationTest, + testing::Values(TestParams{DecorationKind::kAlign, false}, + TestParams{DecorationKind::kBinding, false}, + TestParams{DecorationKind::kBuiltin, false}, + TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kLocation, false}, + TestParams{DecorationKind::kOverride, false}, + TestParams{DecorationKind::kOffset, false}, + TestParams{DecorationKind::kSize, false}, + TestParams{DecorationKind::kStage, false}, + TestParams{DecorationKind::kStride, false}, + TestParams{DecorationKind::kStructBlock, false}, + TestParams{DecorationKind::kWorkgroup, false}, + TestParams{DecorationKind::kBindingAndGroup, false})); + +using EntryPointParameterDecorationTest = TestWithParams; +TEST_P(EntryPointParameterDecorationTest, IsValid) { + auto& params = GetParam(); + + Func("main", + ast::VariableList{Param("a", ty.vec4(), + createDecorations({}, *this, params.kind))}, + ty.void_(), {}, + ast::DecorationList{Stage(ast::PipelineStage::kFragment)}); + + if (params.should_pass) { + EXPECT_TRUE(r()->Resolve()) << r()->error(); + } else { + EXPECT_FALSE(r()->Resolve()) << r()->error(); + EXPECT_EQ(r()->error(), + "error: decoration is not valid for function parameters"); + } +} +INSTANTIATE_TEST_SUITE_P( + ResolverDecorationValidationTest, + EntryPointParameterDecorationTest, + testing::Values(TestParams{DecorationKind::kAlign, false}, + TestParams{DecorationKind::kBinding, false}, + TestParams{DecorationKind::kBuiltin, true}, + TestParams{DecorationKind::kGroup, false}, + TestParams{DecorationKind::kLocation, true}, + TestParams{DecorationKind::kOverride, false}, + TestParams{DecorationKind::kOffset, false}, + TestParams{DecorationKind::kSize, false}, + TestParams{DecorationKind::kStage, false}, + TestParams{DecorationKind::kStride, false}, + TestParams{DecorationKind::kStructBlock, false}, + TestParams{DecorationKind::kWorkgroup, false}, + TestParams{DecorationKind::kBindingAndGroup, false})); + +TEST_F(EntryPointParameterDecorationTest, DuplicateDecoration) { + Func("main", ast::VariableList{}, ty.f32(), ast::StatementList{Return(1.f)}, + {Stage(ast::PipelineStage::kFragment)}, + { + Location(Source{{12, 34}}, 2), + Location(Source{{56, 78}}, 3), + }); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + R"(56:78 error: duplicate location decoration +12:34 note: first decoration declared here)"); +} + +TEST_F(EntryPointParameterDecorationTest, DuplicateInternalDecoration) { + auto* s = + Param("s", ty.sampler(ast::SamplerKind::kSampler), + ast::DecorationList{ + create(0), + create(0), + ASTNodes().Create( + ID(), ast::DisabledValidation::kBindingPointCollision), + ASTNodes().Create( + ID(), ast::DisabledValidation::kEntryPointParameter), + }); + Func("f", {s}, ty.void_(), {}, {Stage(ast::PipelineStage::kFragment)}); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + using FunctionReturnTypeDecorationTest = TestWithParams; TEST_P(FunctionReturnTypeDecorationTest, IsValid) { auto& params = GetParam(); @@ -385,22 +484,6 @@ TEST_F(VariableDecorationTest, DuplicateDecoration) { 12:34 note: first decoration declared here)"); } -TEST_F(VariableDecorationTest, DuplicateInternalDecoration) { - auto* s = - Param("s", ty.sampler(ast::SamplerKind::kSampler), - ast::DecorationList{ - create(0), - create(0), - ASTNodes().Create( - ID(), ast::DisabledValidation::kBindingPointCollision), - ASTNodes().Create( - ID(), ast::DisabledValidation::kEntryPointParameter), - }); - Func("f", {s}, ty.void_(), {}); - - EXPECT_TRUE(r()->Resolve()) << r()->error(); -} - using ConstantDecorationTest = TestWithParams; TEST_P(ConstantDecorationTest, IsValid) { auto& params = GetParam(); diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index ce2e608577..405f15cf20 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -927,15 +927,30 @@ bool Resolver::ValidateVariable(const VariableInfo* info) { return true; } -bool Resolver::ValidateParameter(const VariableInfo* info) { +bool Resolver::ValidateParameter(const ast::Function* func, + const VariableInfo* info) { if (!ValidateVariable(info)) { return false; } for (auto* deco : info->declaration->decorations()) { + if (!func->IsEntryPoint()) { + AddError("decoration is not valid for function parameters", + deco->source()); + return false; + } + if (auto* builtin = deco->As()) { if (!ValidateBuiltinDecoration(builtin, info->type)) { return false; } + } else if (!deco->IsAnyOf() && + !IsValidationDisabled( + info->declaration->decorations(), + ast::DisabledValidation::kEntryPointParameter)) { + AddError("decoration is not valid for function parameters", + deco->source()); + return false; } } return true; @@ -1041,7 +1056,7 @@ bool Resolver::ValidateFunction(const ast::Function* func, } for (auto* param : func->params()) { - if (!ValidateParameter(variable_to_info_.at(param))) { + if (!ValidateParameter(func, variable_to_info_.at(param))) { return false; } } diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index 531912fe82..336b44c128 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -279,7 +279,7 @@ class Resolver { bool ValidateMatrix(const sem::Matrix* matirx_type, const Source& source); bool ValidateMatrixConstructor(const ast::TypeConstructorExpression* ctor, const sem::Matrix* matrix_type); - bool ValidateParameter(const VariableInfo* info); + bool ValidateParameter(const ast::Function* func, const VariableInfo* info); bool ValidateReturn(const ast::ReturnStatement* ret); bool ValidateStatements(const ast::StatementList& stmts); bool ValidateStorageTexture(const ast::StorageTexture* t); diff --git a/src/transform/canonicalize_entry_point_io.cc b/src/transform/canonicalize_entry_point_io.cc index ebac0b9feb..eaf9ce9b82 100644 --- a/src/transform/canonicalize_entry_point_io.cc +++ b/src/transform/canonicalize_entry_point_io.cc @@ -151,6 +151,13 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap& data) { << "nested pipeline IO struct"; } + ast::DecorationList new_decorations = RemoveDecorations( + &ctx, member->Declaration()->decorations(), + [](const ast::Decoration* deco) { + return !deco->IsAnyOf(); + }); + if (cfg->builtin_style == BuiltinStyle::kParameter && ast::HasDecoration( member->Declaration()->decorations())) { @@ -158,19 +165,12 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap& data) { // parameters, then move it to the parameter list. auto* member_ty = CreateASTTypeFor(&ctx, member->Type()); auto new_param_name = ctx.dst->Sym(); - new_parameters.push_back(ctx.dst->Param( - new_param_name, member_ty, - ctx.Clone(member->Declaration()->decorations()))); + new_parameters.push_back( + ctx.dst->Param(new_param_name, member_ty, new_decorations)); init_values.push_back(ctx.dst->Expr(new_param_name)); continue; } - ast::DecorationList new_decorations = RemoveDecorations( - &ctx, member->Declaration()->decorations(), - [](const ast::Decoration* deco) { - return !deco->IsAnyOf(); - }); auto member_name = ctx.Clone(member->Declaration()->symbol()); auto* member_type = ctx.Clone(member->Declaration()->type()); new_struct_members.push_back(