diff --git a/src/resolver/decoration_validation_test.cc b/src/resolver/decoration_validation_test.cc index 85a26ea838..aac478453e 100644 --- a/src/resolver/decoration_validation_test.cc +++ b/src/resolver/decoration_validation_test.cc @@ -332,7 +332,6 @@ TEST_P(FunctionDecorationTest, IsValid) { ast::DecorationList decos = createDecorations(Source{{12, 34}}, *this, params.kind); - decos.emplace_back(Stage(ast::PipelineStage::kCompute)); Func("foo", ast::VariableList{}, ty.void_(), ast::StatementList{}, decos); if (params.should_pass) { @@ -355,10 +354,10 @@ INSTANTIATE_TEST_SUITE_P( TestParams{DecorationKind::kOverride, false}, TestParams{DecorationKind::kOffset, false}, TestParams{DecorationKind::kSize, false}, - // Skip kStage as we always apply it in this test + // Skip kStage as we do not apply it in this test TestParams{DecorationKind::kStride, false}, TestParams{DecorationKind::kStructBlock, false}, - TestParams{DecorationKind::kWorkgroup, true}, + // Skip kWorkgroup as this is a different error TestParams{DecorationKind::kBindingAndGroup, false})); } // namespace @@ -658,5 +657,46 @@ TEST_F(ResourceDecorationTest, BindingPointOnNonResource) { } // namespace } // namespace ResourceTests +namespace WorkgroupDecorationTests { +namespace { + +using WorkgroupDecoration = ResolverTest; + +TEST_F(WorkgroupDecoration, NotAnEntryPoint) { + Func("main", {}, ty.void_(), {}, + {create(Source{{12, 34}}, 1u)}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: the workgroup_size attribute is only valid for " + "compute stages"); +} + +TEST_F(WorkgroupDecoration, NotAComputeShader) { + Func("main", {}, ty.void_(), {}, + {Stage(ast::PipelineStage::kFragment), + create(Source{{12, 34}}, 1u)}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: the workgroup_size attribute is only valid for " + "compute stages"); +} + +TEST_F(WorkgroupDecoration, MultipleAttributes) { + Func(Source{{12, 34}}, "main", {}, ty.void_(), {}, + {Stage(ast::PipelineStage::kCompute), + create(1u), + create(2u)}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: only one workgroup_size attribute permitted per " + "entry point"); +} + +} // namespace +} // namespace WorkgroupDecorationTests + } // namespace resolver } // namespace tint diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index d5166d9176..145bdc0381 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -823,6 +823,38 @@ bool Resolver::ValidateFunction(const ast::Function* func, return false; } + auto stage_deco_count = 0; + auto workgroup_deco_count = 0; + for (auto* deco : func->decorations()) { + if (deco->Is()) { + stage_deco_count++; + } else if (deco->Is()) { + workgroup_deco_count++; + if (func->pipeline_stage() != ast::PipelineStage::kCompute) { + diagnostics_.add_error( + "the workgroup_size attribute is only valid for compute stages", + deco->source()); + return false; + } + } else if (!deco->Is()) { + diagnostics_.add_error("decoration is not valid for functions", + deco->source()); + return false; + } + } + if (stage_deco_count > 1) { + diagnostics_.add_error( + "v-0020", "only one stage decoration permitted per entry point", + func->source()); + return false; + } + if (workgroup_deco_count > 1) { + diagnostics_.add_error( + "only one workgroup_size attribute permitted per entry point", + func->source()); + return false; + } + for (auto* param : func->params()) { if (!ValidateParameter(variable_to_info_.at(param))) { return false; @@ -867,23 +899,6 @@ bool Resolver::ValidateFunction(const ast::Function* func, bool Resolver::ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info) { - auto stage_deco_count = 0; - for (auto* deco : func->decorations()) { - if (deco->Is()) { - stage_deco_count++; - } else if (!deco->Is()) { - diagnostics_.add_error("decoration is not valid for functions", - deco->source()); - return false; - } - } - if (stage_deco_count > 1) { - diagnostics_.add_error( - "v-0020", "only one stage decoration permitted per entry point", - func->source()); - return false; - } - // Use a lambda to validate the entry point decorations for a type. // Persistent state is used to track which builtins and locations have already // been seen, in order to catch conflicts. diff --git a/src/writer/wgsl/generator_impl_function_test.cc b/src/writer/wgsl/generator_impl_function_test.cc index 16b1c0dd5b..fc69ccf57a 100644 --- a/src/writer/wgsl/generator_impl_function_test.cc +++ b/src/writer/wgsl/generator_impl_function_test.cc @@ -74,6 +74,7 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_WorkgroupSize) { Return(), }, ast::DecorationList{ + Stage(ast::PipelineStage::kCompute), create(2u, 4u, 6u), }); @@ -82,7 +83,7 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_WorkgroupSize) { gen.increment_indent(); ASSERT_TRUE(gen.EmitFunction(func)); - EXPECT_EQ(gen.result(), R"( [[workgroup_size(2, 4, 6)]] + EXPECT_EQ(gen.result(), R"( [[stage(compute), workgroup_size(2, 4, 6)]] fn my_func() { discard; return; @@ -113,30 +114,6 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_Stage) { )"); } -TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_Multiple) { - auto* func = Func("my_func", ast::VariableList{}, ty.void_(), - ast::StatementList{ - create(), - Return(), - }, - ast::DecorationList{ - Stage(ast::PipelineStage::kFragment), - create(2u, 4u, 6u), - }); - - GeneratorImpl& gen = Build(); - - gen.increment_indent(); - - ASSERT_TRUE(gen.EmitFunction(func)); - EXPECT_EQ(gen.result(), R"( [[stage(fragment), workgroup_size(2, 4, 6)]] - fn my_func() { - discard; - return; - } -)"); -} - TEST_F(WgslGeneratorImplTest, Emit_Function_EntryPoint_Parameters) { auto vec4 = ty.vec4(); auto* coord = Param("coord", vec4, {Builtin(ast::Builtin::kPosition)});