[ir] Add optional CreateFunction parameters.
This CL adds the pipeline stage and workgroup_size as optional parameters when creating a function in the IR. Bug: tint:1718 Change-Id: Iae65dcb9557a644a17ec67fc5269d0c2db3f8aba Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/133001 Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Dan Sinclair <dsinclair@chromium.org> Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
69b5900c88
commit
f59547fb7f
|
@ -47,10 +47,13 @@ FunctionTerminator* Builder::CreateFunctionTerminator() {
|
||||||
return ir.flow_nodes.Create<FunctionTerminator>();
|
return ir.flow_nodes.Create<FunctionTerminator>();
|
||||||
}
|
}
|
||||||
|
|
||||||
Function* Builder::CreateFunction(Symbol name, type::Type* return_type) {
|
Function* Builder::CreateFunction(Symbol name,
|
||||||
|
type::Type* return_type,
|
||||||
|
Function::PipelineStage stage,
|
||||||
|
std::optional<std::array<uint32_t, 3>> wg_size) {
|
||||||
TINT_ASSERT(IR, return_type);
|
TINT_ASSERT(IR, return_type);
|
||||||
|
|
||||||
auto* ir_func = ir.flow_nodes.Create<Function>(name, return_type);
|
auto* ir_func = ir.flow_nodes.Create<Function>(name, return_type, stage, wg_size);
|
||||||
ir_func->start_target = CreateBlock();
|
ir_func->start_target = CreateBlock();
|
||||||
ir_func->end_target = CreateFunctionTerminator();
|
ir_func->end_target = CreateFunctionTerminator();
|
||||||
|
|
||||||
|
|
|
@ -67,8 +67,13 @@ class Builder {
|
||||||
/// Creates a function flow node
|
/// Creates a function flow node
|
||||||
/// @param name the function name
|
/// @param name the function name
|
||||||
/// @param return_type the function return type
|
/// @param return_type the function return type
|
||||||
|
/// @param stage the function stage
|
||||||
|
/// @param wg_size the workgroup_size
|
||||||
/// @returns the flow node
|
/// @returns the flow node
|
||||||
Function* CreateFunction(Symbol name, type::Type* return_type);
|
Function* CreateFunction(Symbol name,
|
||||||
|
type::Type* return_type,
|
||||||
|
Function::PipelineStage stage = Function::PipelineStage::kUndefined,
|
||||||
|
std::optional<std::array<uint32_t, 3>> wg_size = {});
|
||||||
|
|
||||||
/// Creates an if flow node
|
/// Creates an if flow node
|
||||||
/// @param condition the if condition
|
/// @param condition the if condition
|
||||||
|
|
|
@ -18,7 +18,11 @@ TINT_INSTANTIATE_TYPEINFO(tint::ir::Function);
|
||||||
|
|
||||||
namespace tint::ir {
|
namespace tint::ir {
|
||||||
|
|
||||||
Function::Function(Symbol n, type::Type* rt) : Base(), name(n), return_type(rt) {}
|
Function::Function(Symbol n,
|
||||||
|
type::Type* rt,
|
||||||
|
PipelineStage stage,
|
||||||
|
std::optional<std::array<uint32_t, 3>> wg_size)
|
||||||
|
: Base(), name(n), pipeline_stage(stage), workgroup_size(wg_size), return_type(rt) {}
|
||||||
|
|
||||||
Function::~Function() = default;
|
Function::~Function() = default;
|
||||||
|
|
||||||
|
|
|
@ -64,7 +64,12 @@ class Function : public utils::Castable<Function, FlowNode> {
|
||||||
/// Constructor
|
/// Constructor
|
||||||
/// @param n the function name
|
/// @param n the function name
|
||||||
/// @param rt the function return type
|
/// @param rt the function return type
|
||||||
Function(Symbol n, type::Type* rt);
|
/// @param stage the function stage
|
||||||
|
/// @param wg_size the workgroup_size
|
||||||
|
Function(Symbol n,
|
||||||
|
type::Type* rt,
|
||||||
|
PipelineStage stage = PipelineStage::kUndefined,
|
||||||
|
std::optional<std::array<uint32_t, 3>> wg_size = {});
|
||||||
~Function() override;
|
~Function() override;
|
||||||
|
|
||||||
/// The function name
|
/// The function name
|
||||||
|
|
|
@ -46,9 +46,8 @@ TEST_F(SpvGeneratorImplTest, Function_DeduplicateType) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Compute) {
|
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Compute) {
|
||||||
auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>());
|
auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>(),
|
||||||
func->pipeline_stage = ir::Function::PipelineStage::kCompute;
|
ir::Function::PipelineStage::kCompute, {{32, 4, 1}});
|
||||||
func->workgroup_size = {32, 4, 1};
|
|
||||||
b.Branch(func->start_target, func->end_target);
|
b.Branch(func->start_target, func->end_target);
|
||||||
|
|
||||||
generator_.EmitFunction(func);
|
generator_.EmitFunction(func);
|
||||||
|
@ -65,8 +64,8 @@ OpFunctionEnd
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Fragment) {
|
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Fragment) {
|
||||||
auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>());
|
auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>(),
|
||||||
func->pipeline_stage = ir::Function::PipelineStage::kFragment;
|
ir::Function::PipelineStage::kFragment);
|
||||||
b.Branch(func->start_target, func->end_target);
|
b.Branch(func->start_target, func->end_target);
|
||||||
|
|
||||||
generator_.EmitFunction(func);
|
generator_.EmitFunction(func);
|
||||||
|
@ -83,8 +82,8 @@ OpFunctionEnd
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Vertex) {
|
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Vertex) {
|
||||||
auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>());
|
auto* func = b.CreateFunction(mod.symbols.Register("main"), mod.types.Get<type::Void>(),
|
||||||
func->pipeline_stage = ir::Function::PipelineStage::kVertex;
|
ir::Function::PipelineStage::kVertex);
|
||||||
b.Branch(func->start_target, func->end_target);
|
b.Branch(func->start_target, func->end_target);
|
||||||
|
|
||||||
generator_.EmitFunction(func);
|
generator_.EmitFunction(func);
|
||||||
|
@ -100,18 +99,16 @@ OpFunctionEnd
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Multiple) {
|
TEST_F(SpvGeneratorImplTest, Function_EntryPoint_Multiple) {
|
||||||
auto* f1 = b.CreateFunction(mod.symbols.Register("main1"), mod.types.Get<type::Void>());
|
auto* f1 = b.CreateFunction(mod.symbols.Register("main1"), mod.types.Get<type::Void>(),
|
||||||
f1->pipeline_stage = ir::Function::PipelineStage::kCompute;
|
ir::Function::PipelineStage::kCompute, {{32, 4, 1}});
|
||||||
f1->workgroup_size = {32, 4, 1};
|
|
||||||
b.Branch(f1->start_target, f1->end_target);
|
b.Branch(f1->start_target, f1->end_target);
|
||||||
|
|
||||||
auto* f2 = b.CreateFunction(mod.symbols.Register("main2"), mod.types.Get<type::Void>());
|
auto* f2 = b.CreateFunction(mod.symbols.Register("main2"), mod.types.Get<type::Void>(),
|
||||||
f2->pipeline_stage = ir::Function::PipelineStage::kCompute;
|
ir::Function::PipelineStage::kCompute, {{8, 2, 16}});
|
||||||
f2->workgroup_size = {8, 2, 16};
|
|
||||||
b.Branch(f2->start_target, f2->end_target);
|
b.Branch(f2->start_target, f2->end_target);
|
||||||
|
|
||||||
auto* f3 = b.CreateFunction(mod.symbols.Register("main3"), mod.types.Get<type::Void>());
|
auto* f3 = b.CreateFunction(mod.symbols.Register("main3"), mod.types.Get<type::Void>(),
|
||||||
f3->pipeline_stage = ir::Function::PipelineStage::kFragment;
|
ir::Function::PipelineStage::kFragment);
|
||||||
b.Branch(f3->start_target, f3->end_target);
|
b.Branch(f3->start_target, f3->end_target);
|
||||||
|
|
||||||
generator_.EmitFunction(f1);
|
generator_.EmitFunction(f1);
|
||||||
|
|
Loading…
Reference in New Issue