writer/spirv: Support overridable workgroup sizes

Generate an OpSpecConstantComposite instruction decorated with the
WorkgroupSize builtin. Only support a single stage with an overridable
workgroup size.

Bug: tint:713
Change-Id: I139123c0af8326fcbd796cb2fc9d223882206e19
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/51263
Commit-Queue: James Price <jrprice@google.com>
Auto-Submit: James Price <jrprice@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
James Price 2021-05-19 14:10:28 +00:00 committed by Commit Bot service account
parent 70f80bb13d
commit 0a1927b4c3
3 changed files with 131 additions and 11 deletions

View File

@ -444,19 +444,63 @@ bool Builder::GenerateExecutionModes(ast::Function* func, uint32_t id) {
{Operand::Int(id), Operand::Int(SpvExecutionModeOriginUpperLeft)});
} else if (func->pipeline_stage() == ast::PipelineStage::kCompute) {
auto& wgsize = func_sem->workgroup_size();
// Check if the workgroup_size uses pipeline-overridable constants.
if (wgsize[0].overridable_const || wgsize[1].overridable_const ||
wgsize[2].overridable_const) {
// TODO(crbug.com/tint/713): Handle overridable constants.
TINT_UNIMPLEMENTED(builder_.Diagnostics())
<< "pipeline-overridable workgroup sizes are not implemented";
if (has_overridable_workgroup_size_) {
// Only one stage can have a pipeline-overridable workgroup size.
// TODO(crbug.com/tint/810): Use LocalSizeId to handle this scenario.
TINT_ICE(builder_.Diagnostics())
<< "multiple stages using pipeline-overridable workgroup sizes";
}
has_overridable_workgroup_size_ = true;
sem::U32 u32;
sem::Vector vec3_u32(&u32, 3);
uint32_t vec3_u32_type_id = GenerateTypeIfNeeded(&vec3_u32);
if (vec3_u32_type_id == 0) {
return 0;
}
OperandList wgsize_ops;
auto wgsize_result = result_op();
wgsize_ops.push_back(Operand::Int(vec3_u32_type_id));
wgsize_ops.push_back(wgsize_result);
// Generate OpConstant instructions for each dimension.
for (int i = 0; i < 3; i++) {
auto constant = ScalarConstant::U32(wgsize[i].value);
if (wgsize[i].overridable_const) {
// Make the constant specializable.
auto* sem_const = builder_.Sem().Get(wgsize[i].overridable_const);
if (!sem_const->IsPipelineConstant()) {
TINT_ICE(builder_.Diagnostics())
<< "expected a pipeline-overridable constant";
}
constant.is_spec_op = true;
constant.constant_id = sem_const->ConstantId();
}
auto result = GenerateConstantIfNeeded(constant);
wgsize_ops.push_back(Operand::Int(result));
}
// Generate the WorkgroupSize builtin.
push_type(spv::Op::OpSpecConstantComposite, wgsize_ops);
push_annot(spv::Op::OpDecorate,
{wgsize_result, Operand::Int(SpvDecorationBuiltIn),
Operand::Int(SpvBuiltInWorkgroupSize)});
} else {
// Not overridable, so just use OpExecutionMode LocalSize.
uint32_t x = wgsize[0].value;
uint32_t y = wgsize[1].value;
uint32_t z = wgsize[2].value;
push_execution_mode(
spv::Op::OpExecutionMode,
{Operand::Int(id), Operand::Int(SpvExecutionModeLocalSize),
Operand::Int(x), Operand::Int(y), Operand::Int(z)});
}
uint32_t x = wgsize[0].value;
uint32_t y = wgsize[1].value;
uint32_t z = wgsize[2].value;
push_execution_mode(
spv::Op::OpExecutionMode,
{Operand::Int(id), Operand::Int(SpvExecutionModeLocalSize),
Operand::Int(x), Operand::Int(y), Operand::Int(z)});
}
for (auto builtin : func_sem->ReferencedBuiltinVariables()) {

View File

@ -533,6 +533,7 @@ class Builder {
std::vector<uint32_t> merge_stack_;
std::vector<uint32_t> continue_stack_;
std::unordered_set<uint32_t> capability_set_;
bool has_overridable_workgroup_size_ = false;
};
} // namespace spirv

View File

@ -197,7 +197,7 @@ TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_Default) {
)");
}
TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize) {
TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_Literals) {
auto* func = Func("main", {}, ty.void_(), ast::StatementList{},
ast::DecorationList{
WorkgroupSize(2, 4, 6),
@ -212,6 +212,81 @@ TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize) {
)");
}
TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_Const) {
GlobalConst("width", ty.i32(), Construct(ty.i32(), 2));
GlobalConst("height", ty.i32(), Construct(ty.i32(), 3));
GlobalConst("depth", ty.i32(), Construct(ty.i32(), 4));
auto* func = Func("main", {}, ty.void_(), ast::StatementList{},
ast::DecorationList{
WorkgroupSize("width", "height", "depth"),
Stage(ast::PipelineStage::kCompute),
});
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateExecutionModes(func, 3)) << b.error();
EXPECT_EQ(DumpInstructions(b.execution_modes()),
R"(OpExecutionMode %3 LocalSize 2 3 4
)");
}
TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_OverridableConst) {
GlobalConst("width", ty.i32(), Construct(ty.i32(), 2), {Override(7u)});
GlobalConst("height", ty.i32(), Construct(ty.i32(), 3), {Override(8u)});
GlobalConst("depth", ty.i32(), Construct(ty.i32(), 4), {Override(9u)});
auto* func = Func("main", {}, ty.void_(), ast::StatementList{},
ast::DecorationList{
WorkgroupSize("width", "height", "depth"),
Stage(ast::PipelineStage::kCompute),
});
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateExecutionModes(func, 3)) << b.error();
EXPECT_EQ(DumpInstructions(b.execution_modes()), "");
EXPECT_EQ(DumpInstructions(b.types()),
R"(%2 = OpTypeInt 32 0
%1 = OpTypeVector %2 3
%4 = OpSpecConstant %2 2
%5 = OpSpecConstant %2 3
%6 = OpSpecConstant %2 4
%3 = OpSpecConstantComposite %1 %4 %5 %6
)");
EXPECT_EQ(DumpInstructions(b.annots()),
R"(OpDecorate %4 SpecId 7
OpDecorate %5 SpecId 8
OpDecorate %6 SpecId 9
OpDecorate %3 BuiltIn WorkgroupSize
)");
}
TEST_F(BuilderTest, Decoration_ExecutionMode_WorkgroupSize_LiteralAndConst) {
GlobalConst("height", ty.i32(), Construct(ty.i32(), 2), {Override(7u)});
GlobalConst("depth", ty.i32(), Construct(ty.i32(), 3));
auto* func = Func("main", {}, ty.void_(), ast::StatementList{},
ast::DecorationList{
WorkgroupSize(4, "height", "depth"),
Stage(ast::PipelineStage::kCompute),
});
spirv::Builder& b = Build();
ASSERT_TRUE(b.GenerateExecutionModes(func, 3)) << b.error();
EXPECT_EQ(DumpInstructions(b.execution_modes()), "");
EXPECT_EQ(DumpInstructions(b.types()),
R"(%2 = OpTypeInt 32 0
%1 = OpTypeVector %2 3
%4 = OpConstant %2 4
%5 = OpSpecConstant %2 2
%6 = OpConstant %2 3
%3 = OpSpecConstantComposite %1 %4 %5 %6
)");
EXPECT_EQ(DumpInstructions(b.annots()),
R"(OpDecorate %5 SpecId 7
OpDecorate %3 BuiltIn WorkgroupSize
)");
}
TEST_F(BuilderTest, Decoration_ExecutionMode_MultipleFragment) {
auto* func1 = Func("main1", {}, ty.void_(), ast::StatementList{},
ast::DecorationList{