[hlsl-writer] Emit numthreads for compute shaders.

This CL adds the numthreads annotation when emitting compute shaders.

Bug: tint:7
Change-Id: Ie0f47adfca0a0684f701f280958163b3da0019b4
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/27480
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair 2020-08-26 19:24:26 +00:00 committed by Commit Bot service account
parent 42b0e2d5af
commit fea2636945
3 changed files with 41 additions and 6 deletions

View File

@ -1354,6 +1354,13 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out,
return false; return false;
} }
if (ep->stage() == ast::PipelineStage::kCompute) {
// TODO(dsinclair): When we have a way to set the thread group size this
// should be updated.
out << "[numthreads(1, 1, 1)]" << std::endl;
make_indent(out);
}
auto outdata = ep_name_to_out_data_.find(current_ep_name_); auto outdata = ep_name_to_out_data_.find(current_ep_name_);
bool has_outdata = outdata != ep_name_to_out_data_.end(); bool has_outdata = outdata != ep_name_to_out_data_.end();
if (has_outdata) { if (has_outdata) {

View File

@ -19,6 +19,7 @@
#include "src/ast/location_decoration.h" #include "src/ast/location_decoration.h"
#include "src/ast/member_accessor_expression.h" #include "src/ast/member_accessor_expression.h"
#include "src/ast/module.h" #include "src/ast/module.h"
#include "src/ast/return_statement.h"
#include "src/ast/type/f32_type.h" #include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h" #include "src/ast/type/i32_type.h"
#include "src/ast/type/vector_type.h" #include "src/ast/type/vector_type.h"

View File

@ -1139,10 +1139,10 @@ void ep_2() {
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_EntryPoint_WithName) { TEST_F(HlslGeneratorImplTest_Function, Emit_Function_EntryPoint_WithName) {
ast::type::VoidType void_type; ast::type::VoidType void_type;
auto func = std::make_unique<ast::Function>("comp_main", ast::VariableList{}, auto func = std::make_unique<ast::Function>("frag_main", ast::VariableList{},
&void_type); &void_type);
auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kCompute, auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
"my_main", "comp_main"); "my_main", "frag_main");
mod()->AddFunction(std::move(func)); mod()->AddFunction(std::move(func));
mod()->AddEntryPoint(std::move(ep)); mod()->AddEntryPoint(std::move(ep));
@ -1158,10 +1158,10 @@ TEST_F(HlslGeneratorImplTest_Function,
Emit_Function_EntryPoint_WithNameCollision) { Emit_Function_EntryPoint_WithNameCollision) {
ast::type::VoidType void_type; ast::type::VoidType void_type;
auto func = std::make_unique<ast::Function>("comp_main", ast::VariableList{}, auto func = std::make_unique<ast::Function>("frag_main", ast::VariableList{},
&void_type); &void_type);
auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kCompute, auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kFragment,
"GeometryShader", "comp_main"); "GeometryShader", "frag_main");
mod()->AddFunction(std::move(func)); mod()->AddFunction(std::move(func));
mod()->AddEntryPoint(std::move(ep)); mod()->AddEntryPoint(std::move(ep));
@ -1173,6 +1173,33 @@ TEST_F(HlslGeneratorImplTest_Function,
)"); )");
} }
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_EntryPoint_Compute) {
ast::type::VoidType void_type;
ast::VariableList params;
auto func = std::make_unique<ast::Function>("comp_main", std::move(params),
&void_type);
auto body = std::make_unique<ast::BlockStatement>();
body->append(std::make_unique<ast::ReturnStatement>());
func->set_body(std::move(body));
mod()->AddFunction(std::move(func));
auto ep = std::make_unique<ast::EntryPoint>(ast::PipelineStage::kCompute,
"main", "comp_main");
mod()->AddEntryPoint(std::move(ep));
ASSERT_TRUE(td().Determine()) << td().error();
ASSERT_TRUE(gen().Generate(out())) << gen().error();
EXPECT_EQ(result(), R"([numthreads(1, 1, 1)]
void main() {
return;
}
)");
}
TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) { TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) {
ast::type::F32Type f32; ast::type::F32Type f32;
ast::type::ArrayType ary(&f32, 5); ast::type::ArrayType ary(&f32, 5);