From 0a1927b4c3e4dcbc3045996c3e37f6e73ef19c9b Mon Sep 17 00:00:00 2001 From: James Price Date: Wed, 19 May 2021 14:10:28 +0000 Subject: [PATCH] writer/spirv: Support overridable workgroup sizes Generate an OpSpecConstantComposite instruction decorated with the WorkgroupSize builtin. Only support a single stage with an overridable workgroup size. Bug: tint:713 Change-Id: I139123c0af8326fcbd796cb2fc9d223882206e19 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51263 Commit-Queue: James Price Auto-Submit: James Price Reviewed-by: Ben Clayton --- src/writer/spirv/builder.cc | 64 ++++++++++++--- src/writer/spirv/builder.h | 1 + .../spirv/builder_function_decoration_test.cc | 77 ++++++++++++++++++- 3 files changed, 131 insertions(+), 11 deletions(-) diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 60c159c52f..c59b4e9c04 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -444,19 +444,63 @@ bool Builder::GenerateExecutionModes(ast::Function* func, uint32_t id) { {Operand::Int(id), Operand::Int(SpvExecutionModeOriginUpperLeft)}); } else if (func->pipeline_stage() == ast::PipelineStage::kCompute) { auto& wgsize = func_sem->workgroup_size(); + + // Check if the workgroup_size uses pipeline-overridable constants. if (wgsize[0].overridable_const || wgsize[1].overridable_const || wgsize[2].overridable_const) { - // TODO(crbug.com/tint/713): Handle overridable constants. - TINT_UNIMPLEMENTED(builder_.Diagnostics()) - << "pipeline-overridable workgroup sizes are not implemented"; + if (has_overridable_workgroup_size_) { + // Only one stage can have a pipeline-overridable workgroup size. + // TODO(crbug.com/tint/810): Use LocalSizeId to handle this scenario. + TINT_ICE(builder_.Diagnostics()) + << "multiple stages using pipeline-overridable workgroup sizes"; + } + has_overridable_workgroup_size_ = true; + + sem::U32 u32; + sem::Vector vec3_u32(&u32, 3); + uint32_t vec3_u32_type_id = GenerateTypeIfNeeded(&vec3_u32); + if (vec3_u32_type_id == 0) { + return 0; + } + + OperandList wgsize_ops; + auto wgsize_result = result_op(); + wgsize_ops.push_back(Operand::Int(vec3_u32_type_id)); + wgsize_ops.push_back(wgsize_result); + + // Generate OpConstant instructions for each dimension. + for (int i = 0; i < 3; i++) { + auto constant = ScalarConstant::U32(wgsize[i].value); + if (wgsize[i].overridable_const) { + // Make the constant specializable. + auto* sem_const = builder_.Sem().Get(wgsize[i].overridable_const); + if (!sem_const->IsPipelineConstant()) { + TINT_ICE(builder_.Diagnostics()) + << "expected a pipeline-overridable constant"; + } + constant.is_spec_op = true; + constant.constant_id = sem_const->ConstantId(); + } + + auto result = GenerateConstantIfNeeded(constant); + wgsize_ops.push_back(Operand::Int(result)); + } + + // Generate the WorkgroupSize builtin. + push_type(spv::Op::OpSpecConstantComposite, wgsize_ops); + push_annot(spv::Op::OpDecorate, + {wgsize_result, Operand::Int(SpvDecorationBuiltIn), + Operand::Int(SpvBuiltInWorkgroupSize)}); + } else { + // Not overridable, so just use OpExecutionMode LocalSize. + uint32_t x = wgsize[0].value; + uint32_t y = wgsize[1].value; + uint32_t z = wgsize[2].value; + push_execution_mode( + spv::Op::OpExecutionMode, + {Operand::Int(id), Operand::Int(SpvExecutionModeLocalSize), + Operand::Int(x), Operand::Int(y), Operand::Int(z)}); } - uint32_t x = wgsize[0].value; - uint32_t y = wgsize[1].value; - uint32_t z = wgsize[2].value; - push_execution_mode( - spv::Op::OpExecutionMode, - {Operand::Int(id), Operand::Int(SpvExecutionModeLocalSize), - Operand::Int(x), Operand::Int(y), Operand::Int(z)}); } for (auto builtin : func_sem->ReferencedBuiltinVariables()) { diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index 98675ff37c..0774e7abef 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -533,6 +533,7 @@ class Builder { std::vector merge_stack_; std::vector continue_stack_; std::unordered_set capability_set_; + bool has_overridable_workgroup_size_ = false; }; } // namespace spirv diff --git a/src/writer/spirv/builder_function_decoration_test.cc b/src/writer/spirv/builder_function_decoration_test.cc index 91bc0050fb..c2f8b5ede7 100644 --- a/src/writer/spirv/builder_function_decoration_test.cc +++ b/src/writer/spirv/builder_function_decoration_test.cc @@ -197,7 +197,7 @@ TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_Default) { )"); } -TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize) { +TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_Literals) { auto* func = Func("main", {}, ty.void_(), ast::StatementList{}, ast::DecorationList{ WorkgroupSize(2, 4, 6), @@ -212,6 +212,81 @@ TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize) { )"); } +TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_Const) { + GlobalConst("width", ty.i32(), Construct(ty.i32(), 2)); + GlobalConst("height", ty.i32(), Construct(ty.i32(), 3)); + GlobalConst("depth", ty.i32(), Construct(ty.i32(), 4)); + auto* func = Func("main", {}, ty.void_(), ast::StatementList{}, + ast::DecorationList{ + WorkgroupSize("width", "height", "depth"), + Stage(ast::PipelineStage::kCompute), + }); + + spirv::Builder& b = Build(); + + ASSERT_TRUE(b.GenerateExecutionModes(func, 3)) << b.error(); + EXPECT_EQ(DumpInstructions(b.execution_modes()), + R"(OpExecutionMode %3 LocalSize 2 3 4 +)"); +} + +TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_OverridableConst) { + GlobalConst("width", ty.i32(), Construct(ty.i32(), 2), {Override(7u)}); + GlobalConst("height", ty.i32(), Construct(ty.i32(), 3), {Override(8u)}); + GlobalConst("depth", ty.i32(), Construct(ty.i32(), 4), {Override(9u)}); + auto* func = Func("main", {}, ty.void_(), ast::StatementList{}, + ast::DecorationList{ + WorkgroupSize("width", "height", "depth"), + Stage(ast::PipelineStage::kCompute), + }); + + spirv::Builder& b = Build(); + + ASSERT_TRUE(b.GenerateExecutionModes(func, 3)) << b.error(); + EXPECT_EQ(DumpInstructions(b.execution_modes()), ""); + EXPECT_EQ(DumpInstructions(b.types()), + R"(%2 = OpTypeInt 32 0 +%1 = OpTypeVector %2 3 +%4 = OpSpecConstant %2 2 +%5 = OpSpecConstant %2 3 +%6 = OpSpecConstant %2 4 +%3 = OpSpecConstantComposite %1 %4 %5 %6 +)"); + EXPECT_EQ(DumpInstructions(b.annots()), + R"(OpDecorate %4 SpecId 7 +OpDecorate %5 SpecId 8 +OpDecorate %6 SpecId 9 +OpDecorate %3 BuiltIn WorkgroupSize +)"); +} + +TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_LiteralAndConst) { + GlobalConst("height", ty.i32(), Construct(ty.i32(), 2), {Override(7u)}); + GlobalConst("depth", ty.i32(), Construct(ty.i32(), 3)); + auto* func = Func("main", {}, ty.void_(), ast::StatementList{}, + ast::DecorationList{ + WorkgroupSize(4, "height", "depth"), + Stage(ast::PipelineStage::kCompute), + }); + + spirv::Builder& b = Build(); + + ASSERT_TRUE(b.GenerateExecutionModes(func, 3)) << b.error(); + EXPECT_EQ(DumpInstructions(b.execution_modes()), ""); + EXPECT_EQ(DumpInstructions(b.types()), + R"(%2 = OpTypeInt 32 0 +%1 = OpTypeVector %2 3 +%4 = OpConstant %2 4 +%5 = OpSpecConstant %2 2 +%6 = OpConstant %2 3 +%3 = OpSpecConstantComposite %1 %4 %5 %6 +)"); + EXPECT_EQ(DumpInstructions(b.annots()), + R"(OpDecorate %5 SpecId 7 +OpDecorate %3 BuiltIn WorkgroupSize +)"); +} + TEST_F(BuilderTest, Decoration_ExecutionMode_MultipleFragment) { auto* func1 = Func("main1", {}, ty.void_(), ast::StatementList{}, ast::DecorationList{