D3D12: Support [[num_workgroups]] for Dispatch
This patch implements [[num_workgroups]] on the API side for Dispatch() calls by setting num_workgroups.xyz as root constants. This patch also adds a temporary validation that on D3D12 backend using a compute pipeline with [[num_workgroups]] in a DispatchIndirect call is not supported. BUG=dawn:839 TEST=dawn_end2end_tests Change-Id: Iaee2ffd162e9420e4e80944fbb222f10a4600c6a Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/66580 Reviewed-by: Corentin Wallez <cwallez@chromium.org> Reviewed-by: Austin Eng <enga@chromium.org> Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
This commit is contained in:
parent
a5c0c8f6be
commit
1349ca182e
|
@ -677,6 +677,8 @@ namespace dawn_native {
|
|||
metadata->localWorkgroupSize.x = entryPoint.workgroup_size_x;
|
||||
metadata->localWorkgroupSize.y = entryPoint.workgroup_size_y;
|
||||
metadata->localWorkgroupSize.z = entryPoint.workgroup_size_z;
|
||||
|
||||
metadata->usesNumWorkgroups = entryPoint.num_workgroups_used;
|
||||
}
|
||||
|
||||
if (metadata->stage == SingleShaderStage::Vertex) {
|
||||
|
|
|
@ -204,6 +204,8 @@ namespace dawn_native {
|
|||
|
||||
// Store overridableConstants from tint program
|
||||
std::unordered_map<std::string, OverridableConstant> overridableConstants;
|
||||
|
||||
bool usesNumWorkgroups = false;
|
||||
};
|
||||
|
||||
class ShaderModuleBase : public ApiObjectBase, public CachedObject {
|
||||
|
|
|
@ -254,6 +254,18 @@ namespace dawn_native { namespace d3d12 {
|
|||
return {};
|
||||
}
|
||||
|
||||
void RecordNumWorkgroupsForDispatch(ID3D12GraphicsCommandList* commandList,
|
||||
ComputePipeline* pipeline,
|
||||
DispatchCmd* dispatch) {
|
||||
if (!pipeline->UsesNumWorkgroups()) {
|
||||
return;
|
||||
}
|
||||
|
||||
PipelineLayout* layout = ToBackend(pipeline->GetLayout());
|
||||
commandList->SetComputeRoot32BitConstants(layout->GetNumWorkgroupsParameterIndex(), 3,
|
||||
dispatch, 0);
|
||||
}
|
||||
|
||||
// Records the necessary barriers for a synchronization scope using the resource usage
|
||||
// data pre-computed in the frontend. Also performs lazy initialization if required.
|
||||
// Returns whether any UAV are used in the synchronization scope.
|
||||
|
@ -1030,6 +1042,7 @@ namespace dawn_native { namespace d3d12 {
|
|||
ID3D12GraphicsCommandList* commandList = commandContext->GetCommandList();
|
||||
|
||||
Command type;
|
||||
ComputePipeline* lastPipeline = nullptr;
|
||||
while (mCommands.NextCommandId(&type)) {
|
||||
switch (type) {
|
||||
case Command::Dispatch: {
|
||||
|
@ -1045,6 +1058,7 @@ namespace dawn_native { namespace d3d12 {
|
|||
resourceUsages.dispatchUsages[currentDispatch]);
|
||||
DAWN_TRY(bindingTracker->Apply(commandContext));
|
||||
|
||||
RecordNumWorkgroupsForDispatch(commandList, lastPipeline, dispatch);
|
||||
commandList->Dispatch(dispatch->x, dispatch->y, dispatch->z);
|
||||
currentDispatch++;
|
||||
break;
|
||||
|
@ -1052,6 +1066,14 @@ namespace dawn_native { namespace d3d12 {
|
|||
|
||||
case Command::DispatchIndirect: {
|
||||
DispatchIndirectCmd* dispatch = mCommands.NextCommand<DispatchIndirectCmd>();
|
||||
|
||||
// TODO(dawn:839): support [[num_workgroups]] for DispatchIndirect calls
|
||||
if (lastPipeline->UsesNumWorkgroups()) {
|
||||
return DAWN_VALIDATION_ERROR(
|
||||
"Using a compute pipeline with [[num_workgroups]] in a "
|
||||
"DispatchIndirect call is not implemented");
|
||||
}
|
||||
|
||||
Buffer* buffer = ToBackend(dispatch->indirectBuffer.Get());
|
||||
|
||||
TransitionAndClearForSyncScope(commandContext,
|
||||
|
@ -1078,6 +1100,7 @@ namespace dawn_native { namespace d3d12 {
|
|||
commandList->SetPipelineState(pipeline->GetPipelineState());
|
||||
|
||||
bindingTracker->OnSetPipeline(pipeline);
|
||||
lastPipeline = pipeline;
|
||||
break;
|
||||
}
|
||||
|
||||
|
|
|
@ -84,4 +84,8 @@ namespace dawn_native { namespace d3d12 {
|
|||
CreateComputePipelineAsyncTask::RunAsync(std::move(asyncTask));
|
||||
}
|
||||
|
||||
bool ComputePipeline::UsesNumWorkgroups() const {
|
||||
return GetStage(SingleShaderStage::Compute).metadata->usesNumWorkgroups;
|
||||
}
|
||||
|
||||
}} // namespace dawn_native::d3d12
|
||||
|
|
|
@ -40,6 +40,8 @@ namespace dawn_native { namespace d3d12 {
|
|||
// Dawn API
|
||||
void SetLabelImpl() override;
|
||||
|
||||
bool UsesNumWorkgroups() const;
|
||||
|
||||
private:
|
||||
~ComputePipeline() override;
|
||||
using ComputePipelineBase::ComputePipelineBase;
|
||||
|
|
|
@ -174,6 +174,21 @@ namespace dawn_native { namespace d3d12 {
|
|||
// would need to be updated often
|
||||
rootParameters.emplace_back(indexOffsetConstants);
|
||||
|
||||
// Always allocate 3 constants for num_workgroups_x, num_workgroups_y and num_workgroups_z
|
||||
// for Dispatch calls
|
||||
// NOTE: We should consider delaying root signature creation until we know how many values
|
||||
// we need
|
||||
D3D12_ROOT_PARAMETER numWorkgroupsConstants{};
|
||||
numWorkgroupsConstants.ShaderVisibility = D3D12_SHADER_VISIBILITY_ALL;
|
||||
numWorkgroupsConstants.ParameterType = D3D12_ROOT_PARAMETER_TYPE_32BIT_CONSTANTS;
|
||||
numWorkgroupsConstants.Constants.Num32BitValues = 3;
|
||||
numWorkgroupsConstants.Constants.RegisterSpace = GetNumWorkgroupsRegisterSpace();
|
||||
numWorkgroupsConstants.Constants.ShaderRegister = GetNumWorkgroupsShaderRegister();
|
||||
mNumWorkgroupsParamterIndex = rootParameters.size();
|
||||
// NOTE: We should consider moving this entry to earlier in the root signature since
|
||||
// dispatch sizes would need to be updated often
|
||||
rootParameters.emplace_back(numWorkgroupsConstants);
|
||||
|
||||
D3D12_ROOT_SIGNATURE_DESC rootSignatureDescriptor;
|
||||
rootSignatureDescriptor.NumParameters = rootParameters.size();
|
||||
rootSignatureDescriptor.pParameters = rootParameters.data();
|
||||
|
@ -230,7 +245,7 @@ namespace dawn_native { namespace d3d12 {
|
|||
}
|
||||
|
||||
uint32_t PipelineLayout::GetFirstIndexOffsetRegisterSpace() const {
|
||||
return kReservedRegisterSpace;
|
||||
return kFirstIndexOffsetRegisterSpace;
|
||||
}
|
||||
|
||||
uint32_t PipelineLayout::GetFirstIndexOffsetShaderRegister() const {
|
||||
|
@ -240,4 +255,16 @@ namespace dawn_native { namespace d3d12 {
|
|||
uint32_t PipelineLayout::GetFirstIndexOffsetParameterIndex() const {
|
||||
return mFirstIndexOffsetParameterIndex;
|
||||
}
|
||||
|
||||
uint32_t PipelineLayout::GetNumWorkgroupsRegisterSpace() const {
|
||||
return kNumWorkgroupsRegisterSpace;
|
||||
}
|
||||
|
||||
uint32_t PipelineLayout::GetNumWorkgroupsShaderRegister() const {
|
||||
return kNumWorkgroupsBaseRegister;
|
||||
}
|
||||
|
||||
uint32_t PipelineLayout::GetNumWorkgroupsParameterIndex() const {
|
||||
return mNumWorkgroupsParamterIndex;
|
||||
}
|
||||
}} // namespace dawn_native::d3d12
|
||||
|
|
|
@ -26,6 +26,9 @@ namespace dawn_native { namespace d3d12 {
|
|||
// We reserve a register space that a user cannot use.
|
||||
static constexpr uint32_t kReservedRegisterSpace = kMaxBindGroups + 1;
|
||||
static constexpr uint32_t kFirstOffsetInfoBaseRegister = 0;
|
||||
static constexpr uint32_t kFirstIndexOffsetRegisterSpace = kReservedRegisterSpace;
|
||||
static constexpr uint32_t kNumWorkgroupsRegisterSpace = kReservedRegisterSpace + 1;
|
||||
static constexpr uint32_t kNumWorkgroupsBaseRegister = 0;
|
||||
|
||||
class Device;
|
||||
|
||||
|
@ -46,6 +49,10 @@ namespace dawn_native { namespace d3d12 {
|
|||
uint32_t GetFirstIndexOffsetShaderRegister() const;
|
||||
uint32_t GetFirstIndexOffsetParameterIndex() const;
|
||||
|
||||
uint32_t GetNumWorkgroupsRegisterSpace() const;
|
||||
uint32_t GetNumWorkgroupsShaderRegister() const;
|
||||
uint32_t GetNumWorkgroupsParameterIndex() const;
|
||||
|
||||
ID3D12RootSignature* GetRootSignature() const;
|
||||
|
||||
private:
|
||||
|
@ -59,6 +66,7 @@ namespace dawn_native { namespace d3d12 {
|
|||
kMaxBindGroups>
|
||||
mDynamicRootParameterIndices;
|
||||
uint32_t mFirstIndexOffsetParameterIndex;
|
||||
uint32_t mNumWorkgroupsParamterIndex;
|
||||
ComPtr<ID3D12RootSignature> mRootSignature;
|
||||
};
|
||||
|
||||
|
|
|
@ -106,6 +106,9 @@ namespace dawn_native { namespace d3d12 {
|
|||
tint::transform::BindingRemapper::BindingPoints bindingPoints;
|
||||
tint::transform::BindingRemapper::AccessControls accessControls;
|
||||
bool isRobustnessEnabled;
|
||||
bool usesNumWorkgroups;
|
||||
uint32_t numWorkgroupsRegisterSpace;
|
||||
uint32_t numWorkgroupsShaderRegister;
|
||||
|
||||
// FXC/DXC common inputs
|
||||
bool disableWorkgroupInit;
|
||||
|
@ -125,7 +128,7 @@ namespace dawn_native { namespace d3d12 {
|
|||
uint32_t compileFlags,
|
||||
const Device* device,
|
||||
const tint::Program* program,
|
||||
const BindingInfoArray& moduleBindingInfo) {
|
||||
const EntryPointMetadata& entryPoint) {
|
||||
Compiler compiler;
|
||||
uint64_t dxcVersion = 0;
|
||||
if (device->IsToggleEnabled(Toggle::UseDXC)) {
|
||||
|
@ -145,6 +148,7 @@ namespace dawn_native { namespace d3d12 {
|
|||
// Tint AST to make the "bindings" decoration match the offset chosen by
|
||||
// d3d12::BindGroupLayout so that Tint produces HLSL with the correct registers
|
||||
// assigned to each interface variable.
|
||||
const BindingInfoArray& moduleBindingInfo = entryPoint.bindings;
|
||||
for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
|
||||
const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
|
||||
const auto& groupBindingInfo = moduleBindingInfo[group];
|
||||
|
@ -189,6 +193,9 @@ namespace dawn_native { namespace d3d12 {
|
|||
request.isRobustnessEnabled = device->IsRobustnessEnabled();
|
||||
request.disableWorkgroupInit =
|
||||
device->IsToggleEnabled(Toggle::DisableWorkgroupInit);
|
||||
request.usesNumWorkgroups = entryPoint.usesNumWorkgroups;
|
||||
request.numWorkgroupsShaderRegister = layout->GetNumWorkgroupsShaderRegister();
|
||||
request.numWorkgroupsRegisterSpace = layout->GetNumWorkgroupsRegisterSpace();
|
||||
request.fxcVersion = compiler == Compiler::FXC ? GetD3DCompilerVersion() : 0;
|
||||
request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0;
|
||||
request.deviceInfo = &device->GetDeviceInfo();
|
||||
|
@ -234,6 +241,10 @@ namespace dawn_native { namespace d3d12 {
|
|||
stream << " accessControls=";
|
||||
Serialize(stream, accessControls);
|
||||
|
||||
stream << " useNumWorkgroups=" << usesNumWorkgroups;
|
||||
stream << " numWorkgroupsRegisterSpace=" << numWorkgroupsRegisterSpace;
|
||||
stream << " numWorkgroupsShaderRegister=" << numWorkgroupsShaderRegister;
|
||||
|
||||
stream << " shaderModel=" << deviceInfo->shaderModel;
|
||||
stream << " disableWorkgroupInit=" << disableWorkgroupInit;
|
||||
stream << " isRobustnessEnabled=" << isRobustnessEnabled;
|
||||
|
@ -423,6 +434,10 @@ namespace dawn_native { namespace d3d12 {
|
|||
|
||||
tint::writer::hlsl::Options options;
|
||||
options.disable_workgroup_init = request.disableWorkgroupInit;
|
||||
if (request.usesNumWorkgroups) {
|
||||
options.root_constant_binding_point.group = request.numWorkgroupsRegisterSpace;
|
||||
options.root_constant_binding_point.binding = request.numWorkgroupsShaderRegister;
|
||||
}
|
||||
auto result = tint::writer::hlsl::Generate(&transformedProgram, options);
|
||||
if (!result.success) {
|
||||
errorStream << "Generator: " << result.error << std::endl;
|
||||
|
@ -547,9 +562,9 @@ namespace dawn_native { namespace d3d12 {
|
|||
}
|
||||
|
||||
ShaderCompilationRequest request;
|
||||
DAWN_TRY_ASSIGN(request, ShaderCompilationRequest::Create(
|
||||
entryPointName, stage, layout, compileFlags, device, program,
|
||||
GetEntryPoint(entryPointName).bindings));
|
||||
DAWN_TRY_ASSIGN(request, ShaderCompilationRequest::Create(entryPointName, stage, layout,
|
||||
compileFlags, device, program,
|
||||
GetEntryPoint(entryPointName)));
|
||||
|
||||
PersistentCacheKey shaderCacheKey;
|
||||
DAWN_TRY_ASSIGN(shaderCacheKey, request.CreateCacheKey());
|
||||
|
|
|
@ -26,9 +26,30 @@ class ComputeDispatchTests : public DawnTest {
|
|||
DawnTest::SetUp();
|
||||
|
||||
// Write workgroup number into the output buffer if we saw the biggest dispatch
|
||||
// This is a workaround since D3D12 doesn't have gl_NumWorkGroups
|
||||
// To make sure the dispatch was not called, write maximum u32 value for 0 dispatches
|
||||
wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
|
||||
wgpu::ShaderModule moduleForDispatch = utils::CreateShaderModule(device, R"(
|
||||
[[block]] struct OutputBuf {
|
||||
workGroups : vec3<u32>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<storage, read_write> output : OutputBuf;
|
||||
|
||||
[[stage(compute), workgroup_size(1, 1, 1)]]
|
||||
fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3<u32>,
|
||||
[[builtin(num_workgroups)]] dispatch : vec3<u32>) {
|
||||
if (dispatch.x == 0u || dispatch.y == 0u || dispatch.z == 0u) {
|
||||
output.workGroups = vec3<u32>(0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu);
|
||||
return;
|
||||
}
|
||||
|
||||
if (all(GlobalInvocationID == dispatch - vec3<u32>(1u, 1u, 1u))) {
|
||||
output.workGroups = dispatch;
|
||||
}
|
||||
})");
|
||||
|
||||
// TODO(dawn:839): use moduleForDispatch for indirect dispatch tests when D3D12 supports
|
||||
// [[num_workgroups]] for indirect dispatch.
|
||||
wgpu::ShaderModule moduleForDispatchIndirect = utils::CreateShaderModule(device, R"(
|
||||
[[block]] struct InputBuf {
|
||||
expectedDispatch : vec3<u32>;
|
||||
};
|
||||
|
@ -54,9 +75,12 @@ class ComputeDispatchTests : public DawnTest {
|
|||
})");
|
||||
|
||||
wgpu::ComputePipelineDescriptor csDesc;
|
||||
csDesc.compute.module = module;
|
||||
csDesc.compute.module = moduleForDispatch;
|
||||
csDesc.compute.entryPoint = "main";
|
||||
pipeline = device.CreateComputePipeline(&csDesc);
|
||||
pipelineForDispatch = device.CreateComputePipeline(&csDesc);
|
||||
|
||||
csDesc.compute.module = moduleForDispatchIndirect;
|
||||
pipelineForDispatchIndirect = device.CreateComputePipeline(&csDesc);
|
||||
}
|
||||
|
||||
void DirectTest(uint32_t x, uint32_t y, uint32_t z) {
|
||||
|
@ -66,23 +90,18 @@ class ComputeDispatchTests : public DawnTest {
|
|||
wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst,
|
||||
kSentinelData);
|
||||
|
||||
std::initializer_list<uint32_t> expectedBufferData{x, y, z};
|
||||
wgpu::Buffer expectedBuffer = utils::CreateBufferFromData<uint32_t>(
|
||||
device, wgpu::BufferUsage::Uniform, expectedBufferData);
|
||||
|
||||
// Set up bind group and issue dispatch
|
||||
wgpu::BindGroup bindGroup =
|
||||
utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
|
||||
utils::MakeBindGroup(device, pipelineForDispatch.GetBindGroupLayout(0),
|
||||
{
|
||||
{0, expectedBuffer, 0, 3 * sizeof(uint32_t)},
|
||||
{1, dst, 0, 3 * sizeof(uint32_t)},
|
||||
{0, dst, 0, 3 * sizeof(uint32_t)},
|
||||
});
|
||||
|
||||
wgpu::CommandBuffer commands;
|
||||
{
|
||||
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
||||
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
||||
pass.SetPipeline(pipeline);
|
||||
pass.SetPipeline(pipelineForDispatch);
|
||||
pass.SetBindGroup(0, bindGroup);
|
||||
pass.Dispatch(x, y, z);
|
||||
pass.EndPass();
|
||||
|
@ -93,7 +112,7 @@ class ComputeDispatchTests : public DawnTest {
|
|||
queue.Submit(1, &commands);
|
||||
|
||||
std::vector<uint32_t> expected =
|
||||
x == 0 || y == 0 || z == 0 ? kSentinelData : expectedBufferData;
|
||||
x == 0 || y == 0 || z == 0 ? kSentinelData : std::initializer_list<uint32_t>{x, y, z};
|
||||
|
||||
// Verify the dispatch got called if all group counts are not zero
|
||||
EXPECT_BUFFER_U32_RANGE_EQ(&expected[0], dst, 0, 3);
|
||||
|
@ -118,7 +137,7 @@ class ComputeDispatchTests : public DawnTest {
|
|||
|
||||
// Set up bind group and issue dispatch
|
||||
wgpu::BindGroup bindGroup =
|
||||
utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
|
||||
utils::MakeBindGroup(device, pipelineForDispatchIndirect.GetBindGroupLayout(0),
|
||||
{
|
||||
{0, expectedBuffer, 0, 3 * sizeof(uint32_t)},
|
||||
{1, dst, 0, 3 * sizeof(uint32_t)},
|
||||
|
@ -128,7 +147,7 @@ class ComputeDispatchTests : public DawnTest {
|
|||
{
|
||||
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
||||
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
||||
pass.SetPipeline(pipeline);
|
||||
pass.SetPipeline(pipelineForDispatchIndirect);
|
||||
pass.SetBindGroup(0, bindGroup);
|
||||
pass.DispatchIndirect(indirectBuffer, indirectOffset);
|
||||
pass.EndPass();
|
||||
|
@ -153,7 +172,8 @@ class ComputeDispatchTests : public DawnTest {
|
|||
}
|
||||
|
||||
private:
|
||||
wgpu::ComputePipeline pipeline;
|
||||
wgpu::ComputePipeline pipelineForDispatch;
|
||||
wgpu::ComputePipeline pipelineForDispatchIndirect;
|
||||
};
|
||||
|
||||
// Test basic direct
|
||||
|
|
Loading…
Reference in New Issue