diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc index 08afc35e10..8b9e63c6e2 100644 --- a/src/writer/wgsl/generator_impl.cc +++ b/src/writer/wgsl/generator_impl.cc @@ -44,6 +44,7 @@ #include "src/ast/scalar_constructor_expression.h" #include "src/ast/set_decoration.h" #include "src/ast/sint_literal.h" +#include "src/ast/stage_decoration.h" #include "src/ast/statement.h" #include "src/ast/struct.h" #include "src/ast/struct_member.h" @@ -434,6 +435,9 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) { out_ << "workgroup_size(" << std::to_string(x) << ", " << std::to_string(y) << ", " << std::to_string(z) << ")"; } + if (deco->IsStage()) { + out_ << "stage(" << deco->AsStage()->value() << ")"; + } out_ << "]]" << std::endl; } diff --git a/src/writer/wgsl/generator_impl_function_test.cc b/src/writer/wgsl/generator_impl_function_test.cc index 25bac24ed0..95ad8448af 100644 --- a/src/writer/wgsl/generator_impl_function_test.cc +++ b/src/writer/wgsl/generator_impl_function_test.cc @@ -15,7 +15,9 @@ #include "gtest/gtest.h" #include "src/ast/discard_statement.h" #include "src/ast/function.h" +#include "src/ast/pipeline_stage.h" #include "src/ast/return_statement.h" +#include "src/ast/stage_decoration.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/i32_type.h" #include "src/ast/type/void_type.h" @@ -78,7 +80,7 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithParams) { )"); } -TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecorations) { +TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_WorkgroupSize) { auto body = std::make_unique(); body->append(std::make_unique()); body->append(std::make_unique()); @@ -100,6 +102,54 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecorations) { )"); } +TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_Stage) { + auto body = std::make_unique(); + body->append(std::make_unique()); + body->append(std::make_unique()); + + ast::type::VoidType void_type; + ast::Function func("my_func", {}, &void_type); + func.add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + func.set_body(std::move(body)); + + GeneratorImpl g; + g.increment_indent(); + + ASSERT_TRUE(g.EmitFunction(&func)); + EXPECT_EQ(g.result(), R"( [[stage(fragment)]] + fn my_func() -> void { + discard; + return; + } +)"); +} + +TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_Multiple) { + auto body = std::make_unique(); + body->append(std::make_unique()); + body->append(std::make_unique()); + + ast::type::VoidType void_type; + ast::Function func("my_func", {}, &void_type); + func.add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + func.add_decoration(std::make_unique(2u, 4u, 6u)); + func.set_body(std::move(body)); + + GeneratorImpl g; + g.increment_indent(); + + ASSERT_TRUE(g.EmitFunction(&func)); + EXPECT_EQ(g.result(), R"( [[stage(fragment)]] + [[workgroup_size(2, 4, 6)]] + fn my_func() -> void { + discard; + return; + } +)"); +} + } // namespace } // namespace wgsl } // namespace writer