D3D12: Support [[num_workgroups]] for Dispatch

This patch implements [[num_workgroups]] on the API side for
Dispatch() calls by setting num_workgroups.xyz as root constants.

This patch also adds temporary validation on the D3D12 backend:
using a compute pipeline with [[num_workgroups]] in a
DispatchIndirect call is rejected as not yet supported.

BUG=dawn:839
TEST=dawn_end2end_tests

Change-Id: Iaee2ffd162e9420e4e80944fbb222f10a4600c6a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/66580
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
This commit is contained in:
Jiawei Shao 2021-10-20 00:58:48 +00:00 committed by Dawn LUCI CQ
parent a5c0c8f6be
commit 1349ca182e
9 changed files with 124 additions and 21 deletions

View File

@ -677,6 +677,8 @@ namespace dawn_native {
metadata->localWorkgroupSize.x = entryPoint.workgroup_size_x;
metadata->localWorkgroupSize.y = entryPoint.workgroup_size_y;
metadata->localWorkgroupSize.z = entryPoint.workgroup_size_z;
metadata->usesNumWorkgroups = entryPoint.num_workgroups_used;
}
if (metadata->stage == SingleShaderStage::Vertex) {

View File

@ -204,6 +204,8 @@ namespace dawn_native {
// Store overridableConstants from tint program
std::unordered_map<std::string, OverridableConstant> overridableConstants;
bool usesNumWorkgroups = false;
};
class ShaderModuleBase : public ApiObjectBase, public CachedObject {

View File

@ -254,6 +254,18 @@ namespace dawn_native { namespace d3d12 {
return {};
}
// Uploads the dispatch dimensions as three 32-bit compute root constants so that a
// shader using [[num_workgroups]] can read them. Does nothing when |pipeline| does
// not use num_workgroups.
//
// commandList: command list the root constants are recorded into.
// pipeline:    currently bound compute pipeline; must not be null.
// dispatch:    the Dispatch command whose x/y/z are uploaded.
void RecordNumWorkgroupsForDispatch(ID3D12GraphicsCommandList* commandList,
                                    ComputePipeline* pipeline,
                                    DispatchCmd* dispatch) {
    if (!pipeline->UsesNumWorkgroups()) {
        return;
    }

    // Copy the dimensions into an explicit array instead of handing the command
    // struct itself to D3D12: the upload then no longer depends on x, y and z being
    // the first three contiguous members of DispatchCmd.
    const uint32_t numWorkgroups[3] = {dispatch->x, dispatch->y, dispatch->z};

    PipelineLayout* layout = ToBackend(pipeline->GetLayout());
    commandList->SetComputeRoot32BitConstants(layout->GetNumWorkgroupsParameterIndex(), 3,
                                              numWorkgroups, 0);
}
// Records the necessary barriers for a synchronization scope using the resource usage
// data pre-computed in the frontend. Also performs lazy initialization if required.
// Returns whether any UAV are used in the synchronization scope.
@ -1030,6 +1042,7 @@ namespace dawn_native { namespace d3d12 {
ID3D12GraphicsCommandList* commandList = commandContext->GetCommandList();
Command type;
ComputePipeline* lastPipeline = nullptr;
while (mCommands.NextCommandId(&type)) {
switch (type) {
case Command::Dispatch: {
@ -1045,6 +1058,7 @@ namespace dawn_native { namespace d3d12 {
resourceUsages.dispatchUsages[currentDispatch]);
DAWN_TRY(bindingTracker->Apply(commandContext));
RecordNumWorkgroupsForDispatch(commandList, lastPipeline, dispatch);
commandList->Dispatch(dispatch->x, dispatch->y, dispatch->z);
currentDispatch++;
break;
@ -1052,6 +1066,14 @@ namespace dawn_native { namespace d3d12 {
case Command::DispatchIndirect: {
DispatchIndirectCmd* dispatch = mCommands.NextCommand<DispatchIndirectCmd>();
// TODO(dawn:839): support [[num_workgroups]] for DispatchIndirect calls
if (lastPipeline->UsesNumWorkgroups()) {
return DAWN_VALIDATION_ERROR(
"Using a compute pipeline with [[num_workgroups]] in a "
"DispatchIndirect call is not implemented");
}
Buffer* buffer = ToBackend(dispatch->indirectBuffer.Get());
TransitionAndClearForSyncScope(commandContext,
@ -1078,6 +1100,7 @@ namespace dawn_native { namespace d3d12 {
commandList->SetPipelineState(pipeline->GetPipelineState());
bindingTracker->OnSetPipeline(pipeline);
lastPipeline = pipeline;
break;
}

View File

@ -84,4 +84,8 @@ namespace dawn_native { namespace d3d12 {
CreateComputePipelineAsyncTask::RunAsync(std::move(asyncTask));
}
// Reports the usesNumWorkgroups flag recorded in the metadata of this pipeline's
// compute stage.
bool ComputePipeline::UsesNumWorkgroups() const {
    const auto& computeStage = GetStage(SingleShaderStage::Compute);
    return computeStage.metadata->usesNumWorkgroups;
}
}} // namespace dawn_native::d3d12

View File

@ -40,6 +40,8 @@ namespace dawn_native { namespace d3d12 {
// Dawn API
void SetLabelImpl() override;
bool UsesNumWorkgroups() const;
private:
~ComputePipeline() override;
using ComputePipelineBase::ComputePipelineBase;

View File

@ -174,6 +174,21 @@ namespace dawn_native { namespace d3d12 {
// would need to be updated often
rootParameters.emplace_back(indexOffsetConstants);
// Always allocate 3 constants for num_workgroups_x, num_workgroups_y and num_workgroups_z
// for Dispatch calls
// NOTE: We should consider delaying root signature creation until we know how many values
// we need
D3D12_ROOT_PARAMETER numWorkgroupsConstants{};
numWorkgroupsConstants.ShaderVisibility = D3D12_SHADER_VISIBILITY_ALL;
numWorkgroupsConstants.ParameterType = D3D12_ROOT_PARAMETER_TYPE_32BIT_CONSTANTS;
numWorkgroupsConstants.Constants.Num32BitValues = 3;
numWorkgroupsConstants.Constants.RegisterSpace = GetNumWorkgroupsRegisterSpace();
numWorkgroupsConstants.Constants.ShaderRegister = GetNumWorkgroupsShaderRegister();
mNumWorkgroupsParamterIndex = rootParameters.size();
// NOTE: We should consider moving this entry to earlier in the root signature since
// dispatch sizes would need to be updated often
rootParameters.emplace_back(numWorkgroupsConstants);
D3D12_ROOT_SIGNATURE_DESC rootSignatureDescriptor;
rootSignatureDescriptor.NumParameters = rootParameters.size();
rootSignatureDescriptor.pParameters = rootParameters.data();
@ -230,7 +245,7 @@ namespace dawn_native { namespace d3d12 {
}
uint32_t PipelineLayout::GetFirstIndexOffsetRegisterSpace() const {
return kReservedRegisterSpace;
return kFirstIndexOffsetRegisterSpace;
}
uint32_t PipelineLayout::GetFirstIndexOffsetShaderRegister() const {
@ -240,4 +255,16 @@ namespace dawn_native { namespace d3d12 {
// Returns the stored root parameter index for the first-index-offset constants.
uint32_t PipelineLayout::GetFirstIndexOffsetParameterIndex() const {
    const uint32_t parameterIndex = mFirstIndexOffsetParameterIndex;
    return parameterIndex;
}
// Returns the register space (kNumWorkgroupsRegisterSpace) assigned to the
// num_workgroups root constants.
uint32_t PipelineLayout::GetNumWorkgroupsRegisterSpace() const {
    constexpr uint32_t registerSpace = kNumWorkgroupsRegisterSpace;
    return registerSpace;
}
// Returns the shader register (kNumWorkgroupsBaseRegister) assigned to the
// num_workgroups root constants.
uint32_t PipelineLayout::GetNumWorkgroupsShaderRegister() const {
    constexpr uint32_t shaderRegister = kNumWorkgroupsBaseRegister;
    return shaderRegister;
}
// Returns the stored root parameter index for the num_workgroups constants.
uint32_t PipelineLayout::GetNumWorkgroupsParameterIndex() const {
    const uint32_t parameterIndex = mNumWorkgroupsParamterIndex;
    return parameterIndex;
}
}} // namespace dawn_native::d3d12

View File

@ -26,6 +26,9 @@ namespace dawn_native { namespace d3d12 {
// We reserve a register space that a user cannot use.
static constexpr uint32_t kReservedRegisterSpace = kMaxBindGroups + 1;
static constexpr uint32_t kFirstOffsetInfoBaseRegister = 0;
static constexpr uint32_t kFirstIndexOffsetRegisterSpace = kReservedRegisterSpace;
static constexpr uint32_t kNumWorkgroupsRegisterSpace = kReservedRegisterSpace + 1;
static constexpr uint32_t kNumWorkgroupsBaseRegister = 0;
class Device;
@ -46,6 +49,10 @@ namespace dawn_native { namespace d3d12 {
uint32_t GetFirstIndexOffsetShaderRegister() const;
uint32_t GetFirstIndexOffsetParameterIndex() const;
uint32_t GetNumWorkgroupsRegisterSpace() const;
uint32_t GetNumWorkgroupsShaderRegister() const;
uint32_t GetNumWorkgroupsParameterIndex() const;
ID3D12RootSignature* GetRootSignature() const;
private:
@ -59,6 +66,7 @@ namespace dawn_native { namespace d3d12 {
kMaxBindGroups>
mDynamicRootParameterIndices;
uint32_t mFirstIndexOffsetParameterIndex;
uint32_t mNumWorkgroupsParamterIndex;
ComPtr<ID3D12RootSignature> mRootSignature;
};

View File

@ -106,6 +106,9 @@ namespace dawn_native { namespace d3d12 {
tint::transform::BindingRemapper::BindingPoints bindingPoints;
tint::transform::BindingRemapper::AccessControls accessControls;
bool isRobustnessEnabled;
bool usesNumWorkgroups;
uint32_t numWorkgroupsRegisterSpace;
uint32_t numWorkgroupsShaderRegister;
// FXC/DXC common inputs
bool disableWorkgroupInit;
@ -125,7 +128,7 @@ namespace dawn_native { namespace d3d12 {
uint32_t compileFlags,
const Device* device,
const tint::Program* program,
const BindingInfoArray& moduleBindingInfo) {
const EntryPointMetadata& entryPoint) {
Compiler compiler;
uint64_t dxcVersion = 0;
if (device->IsToggleEnabled(Toggle::UseDXC)) {
@ -145,6 +148,7 @@ namespace dawn_native { namespace d3d12 {
// Tint AST to make the "bindings" decoration match the offset chosen by
// d3d12::BindGroupLayout so that Tint produces HLSL with the correct registers
// assigned to each interface variable.
const BindingInfoArray& moduleBindingInfo = entryPoint.bindings;
for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
const auto& groupBindingInfo = moduleBindingInfo[group];
@ -189,6 +193,9 @@ namespace dawn_native { namespace d3d12 {
request.isRobustnessEnabled = device->IsRobustnessEnabled();
request.disableWorkgroupInit =
device->IsToggleEnabled(Toggle::DisableWorkgroupInit);
request.usesNumWorkgroups = entryPoint.usesNumWorkgroups;
request.numWorkgroupsShaderRegister = layout->GetNumWorkgroupsShaderRegister();
request.numWorkgroupsRegisterSpace = layout->GetNumWorkgroupsRegisterSpace();
request.fxcVersion = compiler == Compiler::FXC ? GetD3DCompilerVersion() : 0;
request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0;
request.deviceInfo = &device->GetDeviceInfo();
@ -234,6 +241,10 @@ namespace dawn_native { namespace d3d12 {
stream << " accessControls=";
Serialize(stream, accessControls);
stream << " useNumWorkgroups=" << usesNumWorkgroups;
stream << " numWorkgroupsRegisterSpace=" << numWorkgroupsRegisterSpace;
stream << " numWorkgroupsShaderRegister=" << numWorkgroupsShaderRegister;
stream << " shaderModel=" << deviceInfo->shaderModel;
stream << " disableWorkgroupInit=" << disableWorkgroupInit;
stream << " isRobustnessEnabled=" << isRobustnessEnabled;
@ -423,6 +434,10 @@ namespace dawn_native { namespace d3d12 {
tint::writer::hlsl::Options options;
options.disable_workgroup_init = request.disableWorkgroupInit;
if (request.usesNumWorkgroups) {
options.root_constant_binding_point.group = request.numWorkgroupsRegisterSpace;
options.root_constant_binding_point.binding = request.numWorkgroupsShaderRegister;
}
auto result = tint::writer::hlsl::Generate(&transformedProgram, options);
if (!result.success) {
errorStream << "Generator: " << result.error << std::endl;
@ -547,9 +562,9 @@ namespace dawn_native { namespace d3d12 {
}
ShaderCompilationRequest request;
DAWN_TRY_ASSIGN(request, ShaderCompilationRequest::Create(
entryPointName, stage, layout, compileFlags, device, program,
GetEntryPoint(entryPointName).bindings));
DAWN_TRY_ASSIGN(request, ShaderCompilationRequest::Create(entryPointName, stage, layout,
compileFlags, device, program,
GetEntryPoint(entryPointName)));
PersistentCacheKey shaderCacheKey;
DAWN_TRY_ASSIGN(shaderCacheKey, request.CreateCacheKey());

View File

@ -26,9 +26,30 @@ class ComputeDispatchTests : public DawnTest {
DawnTest::SetUp();
// Write workgroup number into the output buffer if we saw the biggest dispatch
// This is a workaround since D3D12 doesn't have gl_NumWorkGroups
// To make sure the dispatch was not called, write maximum u32 value for 0 dispatches
wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
wgpu::ShaderModule moduleForDispatch = utils::CreateShaderModule(device, R"(
[[block]] struct OutputBuf {
workGroups : vec3<u32>;
};
[[group(0), binding(0)]] var<storage, read_write> output : OutputBuf;
[[stage(compute), workgroup_size(1, 1, 1)]]
fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3<u32>,
[[builtin(num_workgroups)]] dispatch : vec3<u32>) {
if (dispatch.x == 0u || dispatch.y == 0u || dispatch.z == 0u) {
output.workGroups = vec3<u32>(0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu);
return;
}
if (all(GlobalInvocationID == dispatch - vec3<u32>(1u, 1u, 1u))) {
output.workGroups = dispatch;
}
})");
// TODO(dawn:839): use moduleForDispatch for indirect dispatch tests when D3D12 supports
// [[num_workgroups]] for indirect dispatch.
wgpu::ShaderModule moduleForDispatchIndirect = utils::CreateShaderModule(device, R"(
[[block]] struct InputBuf {
expectedDispatch : vec3<u32>;
};
@ -54,9 +75,12 @@ class ComputeDispatchTests : public DawnTest {
})");
wgpu::ComputePipelineDescriptor csDesc;
csDesc.compute.module = module;
csDesc.compute.module = moduleForDispatch;
csDesc.compute.entryPoint = "main";
pipeline = device.CreateComputePipeline(&csDesc);
pipelineForDispatch = device.CreateComputePipeline(&csDesc);
csDesc.compute.module = moduleForDispatchIndirect;
pipelineForDispatchIndirect = device.CreateComputePipeline(&csDesc);
}
void DirectTest(uint32_t x, uint32_t y, uint32_t z) {
@ -66,23 +90,18 @@ class ComputeDispatchTests : public DawnTest {
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
kSentinelData);
std::initializer_list<uint32_t> expectedBufferData{x, y, z};
wgpu::Buffer expectedBuffer = utils::CreateBufferFromData<uint32_t>(
device, wgpu::BufferUsage::Uniform, expectedBufferData);
// Set up bind group and issue dispatch
wgpu::BindGroup bindGroup =
utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
utils::MakeBindGroup(device, pipelineForDispatch.GetBindGroupLayout(0),
{
{0, expectedBuffer, 0, 3 * sizeof(uint32_t)},
{1, dst, 0, 3 * sizeof(uint32_t)},
{0, dst, 0, 3 * sizeof(uint32_t)},
});
wgpu::CommandBuffer commands;
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.SetPipeline(pipeline);
pass.SetPipeline(pipelineForDispatch);
pass.SetBindGroup(0, bindGroup);
pass.Dispatch(x, y, z);
pass.EndPass();
@ -93,7 +112,7 @@ class ComputeDispatchTests : public DawnTest {
queue.Submit(1, &commands);
std::vector<uint32_t> expected =
x == 0 || y == 0 || z == 0 ? kSentinelData : expectedBufferData;
x == 0 || y == 0 || z == 0 ? kSentinelData : std::initializer_list<uint32_t>{x, y, z};
// Verify the dispatch got called if all group counts are not zero
EXPECT_BUFFER_U32_RANGE_EQ(&expected[0], dst, 0, 3);
@ -118,7 +137,7 @@ class ComputeDispatchTests : public DawnTest {
// Set up bind group and issue dispatch
wgpu::BindGroup bindGroup =
utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
utils::MakeBindGroup(device, pipelineForDispatchIndirect.GetBindGroupLayout(0),
{
{0, expectedBuffer, 0, 3 * sizeof(uint32_t)},
{1, dst, 0, 3 * sizeof(uint32_t)},
@ -128,7 +147,7 @@ class ComputeDispatchTests : public DawnTest {
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.SetPipeline(pipeline);
pass.SetPipeline(pipelineForDispatchIndirect);
pass.SetBindGroup(0, bindGroup);
pass.DispatchIndirect(indirectBuffer, indirectOffset);
pass.EndPass();
@ -153,7 +172,8 @@ class ComputeDispatchTests : public DawnTest {
}
private:
wgpu::ComputePipeline pipeline;
wgpu::ComputePipeline pipelineForDispatch;
wgpu::ComputePipeline pipelineForDispatchIndirect;
};
// Test basic direct