[ir] Add optional CreateFunction parameters.

This CL adds the pipeline stage and workgroup_size as optional
parameters when creating a function in the IR.

Bug: tint:1718

Change-Id: Iae65dcb9557a644a17ec67fc5269d0c2db3f8aba
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/133001
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
dan sinclair 2023-05-16 14:43:45 +00:00 committed by Dawn LUCI CQ
parent 69b5900c88
commit f59547fb7f
5 changed files with 34 additions and 20 deletions

View File

@ -47,10 +47,13 @@ FunctionTerminator* Builder::CreateFunctionTerminator() {
return ir.flow_nodes.Create<FunctionTerminator>(); return ir.flow_nodes.Create<FunctionTerminator>();
} }
Function* Builder::CreateFunction(Symbol name, type::Type* return_type) { Function* Builder::CreateFunction(Symbol name,
type::Type* return_type,
Function::PipelineStage stage,
std::optional<std::array<uint32_t, 3>> wg_size) {
TINT_ASSERT(IR, return_type); TINT_ASSERT(IR, return_type);
auto* ir_func = ir.flow_nodes.Create<Function>(name, return_type); auto* ir_func = ir.flow_nodes.Create<Function>(name, return_type, stage, wg_size);
ir_func->start_target = CreateBlock(); ir_func->start_target = CreateBlock();
ir_func->end_target = CreateFunctionTerminator(); ir_func->end_target = CreateFunctionTerminator();

View File

@ -67,8 +67,13 @@ class Builder {
/// Creates a function flow node /// Creates a function flow node
/// @param name the function name /// @param name the function name
/// @param return_type the function return type /// @param return_type the function return type
/// @param stage the function stage
/// @param wg_size the workgroup_size
/// @returns the flow node /// @returns the flow node
Function* CreateFunction(Symbol name, type::Type* return_type); Function* CreateFunction(Symbol name,
type::Type* return_type,
Function::PipelineStage stage = Function::PipelineStage::kUndefined,
std::optional<std::array<uint32_t, 3>> wg_size = {});
/// Creates an if flow node /// Creates an if flow node
/// @param condition the if condition /// @param condition the if condition

View File

@ -18,7 +18,11 @@ TINT_INSTANTIATE_TYPEINFO(tint::ir::Function);
namespace tint::ir { namespace tint::ir {
Function::Function(Symbol n, type::Type* rt) : Base(), name(n), return_type(rt) {} Function::Function(Symbol n,
type::Type* rt,
PipelineStage stage,
std::optional<std::array<uint32_t, 3>> wg_size)
: Base(), name(n), pipeline_stage(stage), workgroup_size(wg_size), return_type(rt) {}
Function::~Function() = default; Function::~Function() = default;

View File

@ -64,7 +64,12 @@ class Function : public utils::Castable<Function, FlowNode> {
/// Constructor /// Constructor
/// @param n the function name /// @param n the function name
/// @param rt the function return type /// @param rt the function return type
Function(Symbol n, type::Type* rt); /// @param stage the function stage
/// @param wg_size the workgroup_size
Function(Symbol n,
type::Type* rt,
PipelineStage stage = PipelineStage::kUndefined,
std::optional<std::array<uint32_t, 3>> wg_size = {});
~Function() override; ~Function() override;
/// The function name /// The function name

View File

@ -46,9 +46,8 @@ TEST_F(SpvGeneratorImplTest, Function_DeduplicateType) {
} }
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Compute) { TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Compute) {
auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>()); auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>(),
func->pipeline_stage = ir::Function::PipelineStage::kCompute; ir::Function::PipelineStage::kCompute, {{32, 4, 1}});
func->workgroup_size = {32, 4, 1};
b.Branch(func->start_target, func->end_target); b.Branch(func->start_target, func->end_target);
generator_.EmitFunction(func); generator_.EmitFunction(func);
@ -65,8 +64,8 @@ OpFunctionEnd
} }
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Fragment) { TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Fragment) {
auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>()); auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>(),
func->pipeline_stage = ir::Function::PipelineStage::kFragment; ir::Function::PipelineStage::kFragment);
b.Branch(func->start_target, func->end_target); b.Branch(func->start_target, func->end_target);
generator_.EmitFunction(func); generator_.EmitFunction(func);
@ -83,8 +82,8 @@ OpFunctionEnd
} }
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Vertex) { TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Vertex) {
auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>()); auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>(),
func->pipeline_stage = ir::Function::PipelineStage::kVertex; ir::Function::PipelineStage::kVertex);
b.Branch(func->start_target, func->end_target); b.Branch(func->start_target, func->end_target);
generator_.EmitFunction(func); generator_.EmitFunction(func);
@ -100,18 +99,16 @@ OpFunctionEnd
} }
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Multiple) { TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Multiple) {
auto* f1 = b.CreateFunction(mod.symbols.Register("main1"), mod.types.Get<type::Void>()); auto* f1 = b.CreateFunction(mod.symbols.Register("main1"), mod.types.Get<type::Void>(),
f1->pipeline_stage = ir::Function::PipelineStage::kCompute; ir::Function::PipelineStage::kCompute, {{32, 4, 1}});
f1->workgroup_size = {32, 4, 1};
b.Branch(f1->start_target, f1->end_target); b.Branch(f1->start_target, f1->end_target);
auto* f2 = b.CreateFunction(mod.symbols.Register("main2"), mod.types.Get<type::Void>()); auto* f2 = b.CreateFunction(mod.symbols.Register("main2"), mod.types.Get<type::Void>(),
f2->pipeline_stage = ir::Function::PipelineStage::kCompute; ir::Function::PipelineStage::kCompute, {{8, 2, 16}});
f2->workgroup_size = {8, 2, 16};
b.Branch(f2->start_target, f2->end_target); b.Branch(f2->start_target, f2->end_target);
auto* f3 = b.CreateFunction(mod.symbols.Register("main3"), mod.types.Get<type::Void>()); auto* f3 = b.CreateFunction(mod.symbols.Register("main3"), mod.types.Get<type::Void>(),
f3->pipeline_stage = ir::Function::PipelineStage::kFragment; ir::Function::PipelineStage::kFragment);
b.Branch(f3->start_target, f3->end_target); b.Branch(f3->start_target, f3->end_target);
generator_.EmitFunction(f1); generator_.EmitFunction(f1);