[ir] Add optional CreateFunction parameters.

This CL adds the pipeline stage and workgroup_size as optional
parameters when creating a function in the IR.

Bug: tint:1718

Change-Id: Iae65dcb9557a644a17ec67fc5269d0c2db3f8aba
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/133001
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Dan Sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
dan sinclair 2023-05-16 14:43:45 +00:00 committed by Dawn LUCI CQ
parent 69b5900c88
commit f59547fb7f
5 changed files with 34 additions and 20 deletions

View File

@ -47,10 +47,13 @@ FunctionTerminator* Builder::CreateFunctionTerminator() {
return ir.flow_nodes.Create<FunctionTerminator>();
}
Function* Builder::CreateFunction(Symbol name, type::Type* return_type) {
Function* Builder::CreateFunction(Symbol name,
type::Type* return_type,
Function::PipelineStage stage,
std::optional<std::array<uint32_t, 3>> wg_size) {
TINT_ASSERT(IR, return_type);
auto* ir_func = ir.flow_nodes.Create<Function>(name, return_type);
auto* ir_func = ir.flow_nodes.Create<Function>(name, return_type, stage, wg_size);
ir_func->start_target = CreateBlock();
ir_func->end_target = CreateFunctionTerminator();

View File

@ -67,8 +67,13 @@ class Builder {
/// Creates a function flow node
/// @param name the function name
/// @param return_type the function return type
/// @param stage the function stage
/// @param wg_size the workgroup_size
/// @returns the flow node
Function* CreateFunction(Symbol name, type::Type* return_type);
Function* CreateFunction(Symbol name,
type::Type* return_type,
Function::PipelineStage stage = Function::PipelineStage::kUndefined,
std::optional<std::array<uint32_t, 3>> wg_size = {});
/// Creates an if flow node
/// @param condition the if condition

View File

@ -18,7 +18,11 @@ TINT_INSTANTIATE_TYPEINFO(tint::ir::Function);
namespace tint::ir {
Function::Function(Symbol n, type::Type* rt) : Base(), name(n), return_type(rt) {}
Function::Function(Symbol n,
type::Type* rt,
PipelineStage stage,
std::optional<std::array<uint32_t, 3>> wg_size)
: Base(), name(n), pipeline_stage(stage), workgroup_size(wg_size), return_type(rt) {}
Function::~Function() = default;

View File

@ -64,7 +64,12 @@ class Function : public utils::Castable<Function, FlowNode> {
/// Constructor
/// @param n the function name
/// @param rt the function return type
Function(Symbol n, type::Type* rt);
/// @param stage the function stage
/// @param wg_size the workgroup_size
Function(Symbol n,
type::Type* rt,
PipelineStage stage = PipelineStage::kUndefined,
std::optional<std::array<uint32_t, 3>> wg_size = {});
~Function() override;
/// The function name

View File

@ -46,9 +46,8 @@ TEST_F(SpvGeneratorImplTest, Function_DeduplicateType) {
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Compute) {
auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>());
func->pipeline_stage = ir::Function::PipelineStage::kCompute;
func->workgroup_size = {32, 4, 1};
auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>(),
ir::Function::PipelineStage::kCompute, {{32, 4, 1}});
b.Branch(func->start_target, func->end_target);
generator_.EmitFunction(func);
@ -65,8 +64,8 @@ OpFunctionEnd
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Fragment) {
auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>());
func->pipeline_stage = ir::Function::PipelineStage::kFragment;
auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>(),
ir::Function::PipelineStage::kFragment);
b.Branch(func->start_target, func->end_target);
generator_.EmitFunction(func);
@ -83,8 +82,8 @@ OpFunctionEnd
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Vertex) {
auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>());
func->pipeline_stage = ir::Function::PipelineStage::kVertex;
auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>(),
ir::Function::PipelineStage::kVertex);
b.Branch(func->start_target, func->end_target);
generator_.EmitFunction(func);
@ -100,18 +99,16 @@ OpFunctionEnd
}
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Multiple) {
auto* f1 = b.CreateFunction(mod.symbols.Register("main1"), mod.types.Get<type::Void>());
f1->pipeline_stage = ir::Function::PipelineStage::kCompute;
f1->workgroup_size = {32, 4, 1};
auto* f1 = b.CreateFunction(mod.symbols.Register("main1"), mod.types.Get<type::Void>(),
ir::Function::PipelineStage::kCompute, {{32, 4, 1}});
b.Branch(f1->start_target, f1->end_target);
auto* f2 = b.CreateFunction(mod.symbols.Register("main2"), mod.types.Get<type::Void>());
f2->pipeline_stage = ir::Function::PipelineStage::kCompute;
f2->workgroup_size = {8, 2, 16};
auto* f2 = b.CreateFunction(mod.symbols.Register("main2"), mod.types.Get<type::Void>(),
ir::Function::PipelineStage::kCompute, {{8, 2, 16}});
b.Branch(f2->start_target, f2->end_target);
auto* f3 = b.CreateFunction(mod.symbols.Register("main3"), mod.types.Get<type::Void>());
f3->pipeline_stage = ir::Function::PipelineStage::kFragment;
auto* f3 = b.CreateFunction(mod.symbols.Register("main3"), mod.types.Get<type::Void>(),
ir::Function::PipelineStage::kFragment);
b.Branch(f3->start_target, f3->end_target);
generator_.EmitFunction(f1);