diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index 4ed643d050..bf66a22868 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -1354,6 +1354,13 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out, return false; } + if (ep->stage() == ast::PipelineStage::kCompute) { + // TODO(dsinclair): When we have a way to set the thread group size this + // should be updated. + out << "[numthreads(1, 1, 1)]" << std::endl; + make_indent(out); + } + auto outdata = ep_name_to_out_data_.find(current_ep_name_); bool has_outdata = outdata != ep_name_to_out_data_.end(); if (has_outdata) { diff --git a/src/writer/hlsl/generator_impl_entry_point_test.cc b/src/writer/hlsl/generator_impl_entry_point_test.cc index 2177806344..189c2abe3a 100644 --- a/src/writer/hlsl/generator_impl_entry_point_test.cc +++ b/src/writer/hlsl/generator_impl_entry_point_test.cc @@ -19,6 +19,7 @@ #include "src/ast/location_decoration.h" #include "src/ast/member_accessor_expression.h" #include "src/ast/module.h" +#include "src/ast/return_statement.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/i32_type.h" #include "src/ast/type/vector_type.h" diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc index 51b6830dda..d79394b365 100644 --- a/src/writer/hlsl/generator_impl_function_test.cc +++ b/src/writer/hlsl/generator_impl_function_test.cc @@ -1139,10 +1139,10 @@ void ep_2() { TEST_F(HlslGeneratorImplTest_Function, Emit_Function_EntryPoint_WithName) { ast::type::VoidType void_type; - auto func = std::make_unique("comp_main", ast::VariableList{}, + auto func = std::make_unique("frag_main", ast::VariableList{}, &void_type); - auto ep = std::make_unique(ast::PipelineStage::kCompute, - "my_main", "comp_main"); + auto ep = std::make_unique(ast::PipelineStage::kFragment, + "my_main", "frag_main"); mod()->AddFunction(std::move(func)); mod()->AddEntryPoint(std::move(ep)); @@ -1158,10 +1158,10 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Function_EntryPoint_WithNameCollision) { ast::type::VoidType void_type; - auto func = std::make_unique("comp_main", ast::VariableList{}, + auto func = std::make_unique("frag_main", ast::VariableList{}, &void_type); - auto ep = std::make_unique(ast::PipelineStage::kCompute, - "GeometryShader", "comp_main"); + auto ep = std::make_unique(ast::PipelineStage::kFragment, + "GeometryShader", "frag_main"); mod()->AddFunction(std::move(func)); mod()->AddEntryPoint(std::move(ep)); @@ -1173,6 +1173,33 @@ TEST_F(HlslGeneratorImplTest_Function, )"); } +TEST_F(HlslGeneratorImplTest_Function, Emit_Function_EntryPoint_Compute) { + ast::type::VoidType void_type; + + ast::VariableList params; + auto func = std::make_unique("comp_main", std::move(params), + &void_type); + + auto body = std::make_unique(); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod()->AddFunction(std::move(func)); + + auto ep = std::make_unique(ast::PipelineStage::kCompute, + "main", "comp_main"); + mod()->AddEntryPoint(std::move(ep)); + + ASSERT_TRUE(td().Determine()) << td().error(); + ASSERT_TRUE(gen().Generate(out())) << gen().error(); + EXPECT_EQ(result(), R"([numthreads(1, 1, 1)] +void main() { + return; +} + +)"); +} + TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) { ast::type::F32Type f32; ast::type::ArrayType ary(&f32, 5);