D3D12: Support [[num_workgroups]] for DispatchIndirect

This patch supports [[num_workgroups]] on D3D12 for DispatchIndirect
by appending the values of [[num_workgroups]] at the end of the
scratch buffer for indirect dispatch validation and setting them as
the root constants in the command signature.

With this patch, for every DispatchIndirect call:
- On D3D12:
1. Validation enabled, [[num_workgroups]] is used
The dispatch indirect buffer needs to be validated, duplicated and
be written into a scratch buffer (size: 6 * uint32_t).

2. Validation enabled, [[num_workgroups]] isn't used
The dispatch indirect buffer needs to be validated and be written
into a scratch buffer (size: 3 * uint32_t).

3. Validation disabled, [[num_workgroups]] is used
The dispatch indirect buffer needs to be duplicated and be written
into a scratch buffer (size: 6 * uint32_t).

4. Validation disabled, [[num_workgroups]] isn't used
Neither transformations or scratch buffers are needed for the dispatch
call.

- On the other backends:
1. Validation enabled,
The dispatch indirect buffer needs to be validated and be written
into a scratch buffer (size: 3 * uint32_t).

2. Validation disabled,
Neither transformations or scratch buffers are needed for the dispatch
call.

BUG=dawn:839
TEST=dawn_end2end_tests

Change-Id: I4105f1b2e3c12f6df6e487ed535a627fbb342344
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/68843
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
This commit is contained in:
Jiawei Shao 2021-11-17 05:00:44 +00:00 committed by Dawn LUCI CQ
parent 8ce15b3ce9
commit 829d165d7c
13 changed files with 198 additions and 82 deletions

View File

@ -152,9 +152,11 @@ namespace dawn_native {
mUsage |= kInternalStorageBuffer; mUsage |= kInternalStorageBuffer;
} }
// We also add internal storage usage for Indirect buffers if validation is enabled, since // We also add internal storage usage for Indirect buffers for some transformations before
// validation involves binding them as storage buffers for use in a compute pass. // DispatchIndirect calls on the backend (e.g. validations, support of [[num_workgroups]] on
if ((mUsage & wgpu::BufferUsage::Indirect) && device->IsValidationEnabled()) { // D3D12), since these transformations involve binding them as storage buffers for use in a
// compute pass.
if (mUsage & wgpu::BufferUsage::Indirect) {
mUsage |= kInternalStorageBuffer; mUsage |= kInternalStorageBuffer;
} }

View File

@ -42,11 +42,14 @@ namespace dawn_native {
// TODO(https://crbug.com/dawn/1108): Propagate validation feedback from this // TODO(https://crbug.com/dawn/1108): Propagate validation feedback from this
// shader in various failure modes. // shader in various failure modes.
// Type 'bool' cannot be used in storage class 'uniform' as it is non-host-shareable.
Ref<ShaderModuleBase> shaderModule; Ref<ShaderModuleBase> shaderModule;
DAWN_TRY_ASSIGN(shaderModule, utils::CreateShaderModule(device, R"( DAWN_TRY_ASSIGN(shaderModule, utils::CreateShaderModule(device, R"(
[[block]] struct UniformParams { [[block]] struct UniformParams {
maxComputeWorkgroupsPerDimension: u32; maxComputeWorkgroupsPerDimension: u32;
clientOffsetInU32: u32; clientOffsetInU32: u32;
enableValidation: u32;
duplicateNumWorkgroups: u32;
}; };
[[block]] struct IndirectParams { [[block]] struct IndirectParams {
@ -54,7 +57,7 @@ namespace dawn_native {
}; };
[[block]] struct ValidatedParams { [[block]] struct ValidatedParams {
data: array<u32, 3>; data: array<u32>;
}; };
[[group(0), binding(0)]] var<uniform> uniformParams: UniformParams; [[group(0), binding(0)]] var<uniform> uniformParams: UniformParams;
@ -65,10 +68,15 @@ namespace dawn_native {
fn main() { fn main() {
for (var i = 0u; i < 3u; i = i + 1u) { for (var i = 0u; i < 3u; i = i + 1u) {
var numWorkgroups = clientParams.data[uniformParams.clientOffsetInU32 + i]; var numWorkgroups = clientParams.data[uniformParams.clientOffsetInU32 + i];
if (numWorkgroups > uniformParams.maxComputeWorkgroupsPerDimension) { if (uniformParams.enableValidation > 0u &&
numWorkgroups > uniformParams.maxComputeWorkgroupsPerDimension) {
numWorkgroups = 0u; numWorkgroups = 0u;
} }
validatedParams.data[i] = numWorkgroups; validatedParams.data[i] = numWorkgroups;
if (uniformParams.duplicateNumWorkgroups > 0u) {
validatedParams.data[i + 3u] = numWorkgroups;
}
} }
} }
)")); )"));
@ -191,9 +199,21 @@ namespace dawn_native {
} }
ResultOrError<std::pair<Ref<BufferBase>, uint64_t>> ResultOrError<std::pair<Ref<BufferBase>, uint64_t>>
ComputePassEncoder::ValidateIndirectDispatch(BufferBase* indirectBuffer, ComputePassEncoder::TransformIndirectDispatchBuffer(Ref<BufferBase> indirectBuffer,
uint64_t indirectOffset) { uint64_t indirectOffset) {
DeviceBase* device = GetDevice(); DeviceBase* device = GetDevice();
const bool shouldDuplicateNumWorkgroups =
device->ShouldDuplicateNumWorkgroupsForDispatchIndirect(
mCommandBufferState.GetComputePipeline());
if (!IsValidationEnabled() && !shouldDuplicateNumWorkgroups) {
return std::make_pair(indirectBuffer, indirectOffset);
}
// Save the previous command buffer state so it can be restored after the
// validation inserts additional commands.
CommandBufferStateTracker previousState = mCommandBufferState;
auto* const store = device->GetInternalPipelineStore(); auto* const store = device->GetInternalPipelineStore();
Ref<ComputePipelineBase> validationPipeline; Ref<ComputePipelineBase> validationPipeline;
@ -215,9 +235,14 @@ namespace dawn_native {
const uint64_t clientIndirectBindingSize = const uint64_t clientIndirectBindingSize =
kDispatchIndirectSize + clientOffsetFromAlignedBoundary; kDispatchIndirectSize + clientOffsetFromAlignedBoundary;
// Neither 'enableValidation' nor 'duplicateNumWorkgroups' can be declared as 'bool' as
// currently in WGSL type 'bool' cannot be used in storage class 'uniform' as 'it is
// non-host-shareable'.
struct UniformParams { struct UniformParams {
uint32_t maxComputeWorkgroupsPerDimension; uint32_t maxComputeWorkgroupsPerDimension;
uint32_t clientOffsetInU32; uint32_t clientOffsetInU32;
uint32_t enableValidation;
uint32_t duplicateNumWorkgroups;
}; };
// Create a uniform buffer to hold parameters for the shader. // Create a uniform buffer to hold parameters for the shader.
@ -227,6 +252,8 @@ namespace dawn_native {
params.maxComputeWorkgroupsPerDimension = params.maxComputeWorkgroupsPerDimension =
device->GetLimits().v1.maxComputeWorkgroupsPerDimension; device->GetLimits().v1.maxComputeWorkgroupsPerDimension;
params.clientOffsetInU32 = clientOffsetFromAlignedBoundary / sizeof(uint32_t); params.clientOffsetInU32 = clientOffsetFromAlignedBoundary / sizeof(uint32_t);
params.enableValidation = static_cast<uint32_t>(IsValidationEnabled());
params.duplicateNumWorkgroups = static_cast<uint32_t>(shouldDuplicateNumWorkgroups);
DAWN_TRY_ASSIGN(uniformBuffer, utils::CreateBufferFromData( DAWN_TRY_ASSIGN(uniformBuffer, utils::CreateBufferFromData(
device, wgpu::BufferUsage::Uniform, {params})); device, wgpu::BufferUsage::Uniform, {params}));
@ -234,25 +261,30 @@ namespace dawn_native {
// Reserve space in the scratch buffer to hold the validated indirect params. // Reserve space in the scratch buffer to hold the validated indirect params.
ScratchBuffer& scratchBuffer = store->scratchIndirectStorage; ScratchBuffer& scratchBuffer = store->scratchIndirectStorage;
DAWN_TRY(scratchBuffer.EnsureCapacity(kDispatchIndirectSize)); const uint64_t scratchBufferSize =
shouldDuplicateNumWorkgroups ? 2 * kDispatchIndirectSize : kDispatchIndirectSize;
DAWN_TRY(scratchBuffer.EnsureCapacity(scratchBufferSize));
Ref<BufferBase> validatedIndirectBuffer = scratchBuffer.GetBuffer(); Ref<BufferBase> validatedIndirectBuffer = scratchBuffer.GetBuffer();
Ref<BindGroupBase> validationBindGroup; Ref<BindGroupBase> validationBindGroup;
DAWN_TRY_ASSIGN( ASSERT(indirectBuffer->GetUsage() & kInternalStorageBuffer);
validationBindGroup, DAWN_TRY_ASSIGN(validationBindGroup,
utils::MakeBindGroup( utils::MakeBindGroup(device, layout,
device, layout, {
{ {0, uniformBuffer},
{0, uniformBuffer}, {1, indirectBuffer, clientIndirectBindingOffset,
{1, indirectBuffer, clientIndirectBindingOffset, clientIndirectBindingSize}, clientIndirectBindingSize},
{2, validatedIndirectBuffer, 0, kDispatchIndirectSize}, {2, validatedIndirectBuffer, 0, scratchBufferSize},
})); }));
// Issue commands to validate the indirect buffer. // Issue commands to validate the indirect buffer.
APISetPipeline(validationPipeline.Get()); APISetPipeline(validationPipeline.Get());
APISetBindGroup(0, validationBindGroup.Get()); APISetBindGroup(0, validationBindGroup.Get());
APIDispatch(1); APIDispatch(1);
// Restore the state.
RestoreCommandBufferState(std::move(previousState));
// Return the new indirect buffer and indirect buffer offset. // Return the new indirect buffer and indirect buffer offset.
return std::make_pair(std::move(validatedIndirectBuffer), uint64_t(0)); return std::make_pair(std::move(validatedIndirectBuffer), uint64_t(0));
} }
@ -287,27 +319,28 @@ namespace dawn_native {
// the backend. // the backend.
Ref<BufferBase> indirectBufferRef = indirectBuffer; Ref<BufferBase> indirectBufferRef = indirectBuffer;
if (IsValidationEnabled()) {
// Save the previous command buffer state so it can be restored after the
// validation inserts additional commands.
CommandBufferStateTracker previousState = mCommandBufferState;
// Validate each indirect dispatch with a single dispatch to copy the indirect // Get applied indirect buffer with necessary changes on the original indirect
// buffer params into a scratch buffer if they're valid, and otherwise zero them // buffer. For example,
// out. We could consider moving the validation earlier in the pass after the // - Validate each indirect dispatch with a single dispatch to copy the indirect
// last point the indirect buffer was used with writable usage, as well as batch // buffer params into a scratch buffer if they're valid, and otherwise zero them
// validation for multiple dispatches into one, but inserting commands at // out.
// arbitrary points in the past is not possible right now. // - Duplicate all the indirect dispatch parameters to support [[num_workgroups]] on
DAWN_TRY_ASSIGN( // D3D12.
std::tie(indirectBufferRef, indirectOffset), // - Directly return the original indirect dispatch buffer if we don't need any
ValidateIndirectDispatch(indirectBufferRef.Get(), indirectOffset)); // transformations on it.
// We could consider moving the validation earlier in the pass after the last
// Restore the state. // last point the indirect buffer was used with writable usage, as well as batch
RestoreCommandBufferState(std::move(previousState)); // validation for multiple dispatches into one, but inserting commands at
// arbitrary points in the past is not possible right now.
DAWN_TRY_ASSIGN(std::tie(indirectBufferRef, indirectOffset),
TransformIndirectDispatchBuffer(indirectBufferRef, indirectOffset));
// If we have created a new scratch dispatch indirect buffer in
// TransformIndirectDispatchBuffer(), we need to track it in mUsageTracker.
if (indirectBufferRef.Get() != indirectBuffer) {
// |indirectBufferRef| was replaced with a scratch buffer. Add it to the // |indirectBufferRef| was replaced with a scratch buffer. Add it to the
// synchronization scope. // synchronization scope.
ASSERT(indirectBufferRef.Get() != indirectBuffer);
scope.BufferUsedAs(indirectBufferRef.Get(), wgpu::BufferUsage::Indirect); scope.BufferUsedAs(indirectBufferRef.Get(), wgpu::BufferUsage::Indirect);
mUsageTracker.AddReferencedBuffer(indirectBufferRef.Get()); mUsageTracker.AddReferencedBuffer(indirectBufferRef.Get());
} }

View File

@ -64,8 +64,8 @@ namespace dawn_native {
private: private:
void DestroyImpl() override; void DestroyImpl() override;
ResultOrError<std::pair<Ref<BufferBase>, uint64_t>> ValidateIndirectDispatch( ResultOrError<std::pair<Ref<BufferBase>, uint64_t>> TransformIndirectDispatchBuffer(
BufferBase* indirectBuffer, Ref<BufferBase> indirectBuffer,
uint64_t indirectOffset); uint64_t indirectOffset);
void RestoreCommandBufferState(CommandBufferStateTracker state); void RestoreCommandBufferState(CommandBufferStateTracker state);

View File

@ -1670,4 +1670,9 @@ namespace dawn_native {
void DeviceBase::SetLabelImpl() { void DeviceBase::SetLabelImpl() {
} }
bool DeviceBase::ShouldDuplicateNumWorkgroupsForDispatchIndirect(
ComputePipelineBase* computePipeline) const {
return false;
}
} // namespace dawn_native } // namespace dawn_native

View File

@ -336,11 +336,16 @@ namespace dawn_native {
MaybeError Tick(); MaybeError Tick();
// TODO(crbug.com/dawn/839): Organize the below backend-specific parameters into the struct
// BackendMetadata that we can query from the device.
virtual uint32_t GetOptimalBytesPerRowAlignment() const = 0; virtual uint32_t GetOptimalBytesPerRowAlignment() const = 0;
virtual uint64_t GetOptimalBufferToTextureCopyOffsetAlignment() const = 0; virtual uint64_t GetOptimalBufferToTextureCopyOffsetAlignment() const = 0;
virtual float GetTimestampPeriodInNS() const = 0; virtual float GetTimestampPeriodInNS() const = 0;
virtual bool ShouldDuplicateNumWorkgroupsForDispatchIndirect(
ComputePipelineBase* computePipeline) const;
const CombinedLimits& GetLimits() const; const CombinedLimits& GetLimits() const;
AsyncTaskManager* GetAsyncTaskManager() const; AsyncTaskManager* GetAsyncTaskManager() const;

View File

@ -1063,22 +1063,15 @@ namespace dawn_native { namespace d3d12 {
case Command::DispatchIndirect: { case Command::DispatchIndirect: {
DispatchIndirectCmd* dispatch = mCommands.NextCommand<DispatchIndirectCmd>(); DispatchIndirectCmd* dispatch = mCommands.NextCommand<DispatchIndirectCmd>();
// TODO(dawn:839): support [[num_workgroups]] for DispatchIndirect calls
DAWN_INVALID_IF(lastPipeline->UsesNumWorkgroups(),
"Using %s with [[num_workgroups]] in a DispatchIndirect call "
"is not implemented.",
lastPipeline);
Buffer* buffer = ToBackend(dispatch->indirectBuffer.Get());
TransitionAndClearForSyncScope(commandContext, TransitionAndClearForSyncScope(commandContext,
resourceUsages.dispatchUsages[currentDispatch]); resourceUsages.dispatchUsages[currentDispatch]);
DAWN_TRY(bindingTracker->Apply(commandContext)); DAWN_TRY(bindingTracker->Apply(commandContext));
ComPtr<ID3D12CommandSignature> signature = ComPtr<ID3D12CommandSignature> signature =
ToBackend(GetDevice())->GetDispatchIndirectSignature(); lastPipeline->GetDispatchIndirectCommandSignature();
commandList->ExecuteIndirect(signature.Get(), 1, buffer->GetD3D12Resource(), commandList->ExecuteIndirect(
dispatch->indirectOffset, nullptr, 0); signature.Get(), 1, ToBackend(dispatch->indirectBuffer)->GetD3D12Resource(),
dispatch->indirectOffset, nullptr, 0);
currentDispatch++; currentDispatch++;
break; break;
} }

View File

@ -90,4 +90,11 @@ namespace dawn_native { namespace d3d12 {
return GetStage(SingleShaderStage::Compute).metadata->usesNumWorkgroups; return GetStage(SingleShaderStage::Compute).metadata->usesNumWorkgroups;
} }
ComPtr<ID3D12CommandSignature> ComputePipeline::GetDispatchIndirectCommandSignature() {
if (UsesNumWorkgroups()) {
return ToBackend(GetLayout())->GetDispatchIndirectCommandSignatureWithNumWorkgroups();
}
return ToBackend(GetDevice())->GetDispatchIndirectSignature();
}
}} // namespace dawn_native::d3d12 }} // namespace dawn_native::d3d12

View File

@ -42,6 +42,8 @@ namespace dawn_native { namespace d3d12 {
bool UsesNumWorkgroups() const; bool UsesNumWorkgroups() const;
ComPtr<ID3D12CommandSignature> GetDispatchIndirectCommandSignature();
private: private:
~ComputePipeline() override; ~ComputePipeline() override;

View File

@ -675,4 +675,9 @@ namespace dawn_native { namespace d3d12 {
return mTimestampPeriod; return mTimestampPeriod;
} }
bool Device::ShouldDuplicateNumWorkgroupsForDispatchIndirect(
ComputePipelineBase* computePipeline) const {
return ToBackend(computePipeline)->UsesNumWorkgroups();
}
}} // namespace dawn_native::d3d12 }} // namespace dawn_native::d3d12

View File

@ -139,6 +139,9 @@ namespace dawn_native { namespace d3d12 {
float GetTimestampPeriodInNS() const override; float GetTimestampPeriodInNS() const override;
bool ShouldDuplicateNumWorkgroupsForDispatchIndirect(
ComputePipelineBase* computePipeline) const override;
private: private:
using DeviceBase::DeviceBase; using DeviceBase::DeviceBase;

View File

@ -184,7 +184,7 @@ namespace dawn_native { namespace d3d12 {
numWorkgroupsConstants.Constants.Num32BitValues = 3; numWorkgroupsConstants.Constants.Num32BitValues = 3;
numWorkgroupsConstants.Constants.RegisterSpace = GetNumWorkgroupsRegisterSpace(); numWorkgroupsConstants.Constants.RegisterSpace = GetNumWorkgroupsRegisterSpace();
numWorkgroupsConstants.Constants.ShaderRegister = GetNumWorkgroupsShaderRegister(); numWorkgroupsConstants.Constants.ShaderRegister = GetNumWorkgroupsShaderRegister();
mNumWorkgroupsParamterIndex = rootParameters.size(); mNumWorkgroupsParameterIndex = rootParameters.size();
// NOTE: We should consider moving this entry to earlier in the root signature since // NOTE: We should consider moving this entry to earlier in the root signature since
// dispatch sizes would need to be updated often // dispatch sizes would need to be updated often
rootParameters.emplace_back(numWorkgroupsConstants); rootParameters.emplace_back(numWorkgroupsConstants);
@ -265,6 +265,38 @@ namespace dawn_native { namespace d3d12 {
} }
uint32_t PipelineLayout::GetNumWorkgroupsParameterIndex() const { uint32_t PipelineLayout::GetNumWorkgroupsParameterIndex() const {
return mNumWorkgroupsParamterIndex; return mNumWorkgroupsParameterIndex;
} }
ID3D12CommandSignature* PipelineLayout::GetDispatchIndirectCommandSignatureWithNumWorkgroups() {
// mDispatchIndirectCommandSignatureWithNumWorkgroups won't be created until it is needed.
if (mDispatchIndirectCommandSignatureWithNumWorkgroups.Get() != nullptr) {
return mDispatchIndirectCommandSignatureWithNumWorkgroups.Get();
}
D3D12_INDIRECT_ARGUMENT_DESC argumentDescs[2] = {};
argumentDescs[0].Type = D3D12_INDIRECT_ARGUMENT_TYPE_CONSTANT;
argumentDescs[0].Constant.RootParameterIndex = GetNumWorkgroupsParameterIndex();
argumentDescs[0].Constant.Num32BitValuesToSet = 3;
argumentDescs[0].Constant.DestOffsetIn32BitValues = 0;
// A command signature must contain exactly 1 Draw / Dispatch / DispatchMesh / DispatchRays
// command. That command must come last.
argumentDescs[1].Type = D3D12_INDIRECT_ARGUMENT_TYPE_DISPATCH;
D3D12_COMMAND_SIGNATURE_DESC programDesc = {};
programDesc.ByteStride = 6 * sizeof(uint32_t);
programDesc.NumArgumentDescs = 2;
programDesc.pArgumentDescs = argumentDescs;
// The root signature must be specified if and only if the command signature changes one of
// the root arguments.
ToBackend(GetDevice())
->GetD3D12Device()
->CreateCommandSignature(
&programDesc, GetRootSignature(),
IID_PPV_ARGS(&mDispatchIndirectCommandSignatureWithNumWorkgroups));
return mDispatchIndirectCommandSignatureWithNumWorkgroups.Get();
}
}} // namespace dawn_native::d3d12 }} // namespace dawn_native::d3d12

View File

@ -55,6 +55,8 @@ namespace dawn_native { namespace d3d12 {
ID3D12RootSignature* GetRootSignature() const; ID3D12RootSignature* GetRootSignature() const;
ID3D12CommandSignature* GetDispatchIndirectCommandSignatureWithNumWorkgroups();
private: private:
~PipelineLayout() override = default; ~PipelineLayout() override = default;
using PipelineLayoutBase::PipelineLayoutBase; using PipelineLayoutBase::PipelineLayoutBase;
@ -66,8 +68,9 @@ namespace dawn_native { namespace d3d12 {
kMaxBindGroups> kMaxBindGroups>
mDynamicRootParameterIndices; mDynamicRootParameterIndices;
uint32_t mFirstIndexOffsetParameterIndex; uint32_t mFirstIndexOffsetParameterIndex;
uint32_t mNumWorkgroupsParamterIndex; uint32_t mNumWorkgroupsParameterIndex;
ComPtr<ID3D12RootSignature> mRootSignature; ComPtr<ID3D12RootSignature> mRootSignature;
ComPtr<ID3D12CommandSignature> mDispatchIndirectCommandSignatureWithNumWorkgroups;
}; };
}} // namespace dawn_native::d3d12 }} // namespace dawn_native::d3d12

View File

@ -27,7 +27,7 @@ class ComputeDispatchTests : public DawnTest {
// Write workgroup number into the output buffer if we saw the biggest dispatch // Write workgroup number into the output buffer if we saw the biggest dispatch
// To make sure the dispatch was not called, write maximum u32 value for 0 dispatches // To make sure the dispatch was not called, write maximum u32 value for 0 dispatches
wgpu::ShaderModule moduleForDispatch = utils::CreateShaderModule(device, R"( wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
[[block]] struct OutputBuf { [[block]] struct OutputBuf {
workGroups : vec3<u32>; workGroups : vec3<u32>;
}; };
@ -47,9 +47,13 @@ class ComputeDispatchTests : public DawnTest {
} }
})"); })");
// TODO(dawn:839): use moduleForDispatch for indirect dispatch tests when D3D12 supports wgpu::ComputePipelineDescriptor csDesc;
// [[num_workgroups]] for indirect dispatch. csDesc.compute.module = module;
wgpu::ShaderModule moduleForDispatchIndirect = utils::CreateShaderModule(device, R"( csDesc.compute.entryPoint = "main";
pipeline = device.CreateComputePipeline(&csDesc);
// Test the use of the compute pipelines without using [[num_workgroups]]
wgpu::ShaderModule moduleWithoutNumWorkgroups = utils::CreateShaderModule(device, R"(
[[block]] struct InputBuf { [[block]] struct InputBuf {
expectedDispatch : vec3<u32>; expectedDispatch : vec3<u32>;
}; };
@ -73,14 +77,8 @@ class ComputeDispatchTests : public DawnTest {
output.workGroups = dispatch; output.workGroups = dispatch;
} }
})"); })");
csDesc.compute.module = moduleWithoutNumWorkgroups;
wgpu::ComputePipelineDescriptor csDesc; pipelineWithoutNumWorkgroups = device.CreateComputePipeline(&csDesc);
csDesc.compute.module = moduleForDispatch;
csDesc.compute.entryPoint = "main";
pipelineForDispatch = device.CreateComputePipeline(&csDesc);
csDesc.compute.module = moduleForDispatchIndirect;
pipelineForDispatchIndirect = device.CreateComputePipeline(&csDesc);
} }
void DirectTest(uint32_t x, uint32_t y, uint32_t z) { void DirectTest(uint32_t x, uint32_t y, uint32_t z) {
@ -91,17 +89,16 @@ class ComputeDispatchTests : public DawnTest {
kSentinelData); kSentinelData);
// Set up bind group and issue dispatch // Set up bind group and issue dispatch
wgpu::BindGroup bindGroup = wgpu::BindGroup bindGroup = utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
utils::MakeBindGroup(device, pipelineForDispatch.GetBindGroupLayout(0), {
{ {0, dst, 0, 3 * sizeof(uint32_t)},
{0, dst, 0, 3 * sizeof(uint32_t)}, });
});
wgpu::CommandBuffer commands; wgpu::CommandBuffer commands;
{ {
wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.SetPipeline(pipelineForDispatch); pass.SetPipeline(pipeline);
pass.SetBindGroup(0, bindGroup); pass.SetBindGroup(0, bindGroup);
pass.Dispatch(x, y, z); pass.Dispatch(x, y, z);
pass.EndPass(); pass.EndPass();
@ -118,7 +115,9 @@ class ComputeDispatchTests : public DawnTest {
EXPECT_BUFFER_U32_RANGE_EQ(&expected[0], dst, 0, 3); EXPECT_BUFFER_U32_RANGE_EQ(&expected[0], dst, 0, 3);
} }
void IndirectTest(std::vector<uint32_t> indirectBufferData, uint64_t indirectOffset) { void IndirectTest(std::vector<uint32_t> indirectBufferData,
uint64_t indirectOffset,
bool useNumWorkgroups = true) {
// Set up dst storage buffer to contain dispatch x, y, z // Set up dst storage buffer to contain dispatch x, y, z
wgpu::Buffer dst = utils::CreateBufferFromData<uint32_t>( wgpu::Buffer dst = utils::CreateBufferFromData<uint32_t>(
device, device,
@ -131,23 +130,34 @@ class ComputeDispatchTests : public DawnTest {
uint32_t indirectStart = indirectOffset / sizeof(uint32_t); uint32_t indirectStart = indirectOffset / sizeof(uint32_t);
wgpu::Buffer expectedBuffer =
utils::CreateBufferFromData(device, &indirectBufferData[indirectStart],
3 * sizeof(uint32_t), wgpu::BufferUsage::Uniform);
// Set up bind group and issue dispatch // Set up bind group and issue dispatch
wgpu::BindGroup bindGroup = wgpu::BindGroup bindGroup;
utils::MakeBindGroup(device, pipelineForDispatchIndirect.GetBindGroupLayout(0), wgpu::ComputePipeline computePipelineForTest;
{
{0, expectedBuffer, 0, 3 * sizeof(uint32_t)}, if (useNumWorkgroups) {
{1, dst, 0, 3 * sizeof(uint32_t)}, computePipelineForTest = pipeline;
}); bindGroup = utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
{
{0, dst, 0, 3 * sizeof(uint32_t)},
});
} else {
computePipelineForTest = pipelineWithoutNumWorkgroups;
wgpu::Buffer expectedBuffer =
utils::CreateBufferFromData(device, &indirectBufferData[indirectStart],
3 * sizeof(uint32_t), wgpu::BufferUsage::Uniform);
bindGroup =
utils::MakeBindGroup(device, pipelineWithoutNumWorkgroups.GetBindGroupLayout(0),
{
{0, expectedBuffer, 0, 3 * sizeof(uint32_t)},
{1, dst, 0, 3 * sizeof(uint32_t)},
});
}
wgpu::CommandBuffer commands; wgpu::CommandBuffer commands;
{ {
wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.SetPipeline(pipelineForDispatchIndirect); pass.SetPipeline(computePipelineForTest);
pass.SetBindGroup(0, bindGroup); pass.SetBindGroup(0, bindGroup);
pass.DispatchIndirect(indirectBuffer, indirectOffset); pass.DispatchIndirect(indirectBuffer, indirectOffset);
pass.EndPass(); pass.EndPass();
@ -178,8 +188,8 @@ class ComputeDispatchTests : public DawnTest {
} }
private: private:
wgpu::ComputePipeline pipelineForDispatch; wgpu::ComputePipeline pipeline;
wgpu::ComputePipeline pipelineForDispatchIndirect; wgpu::ComputePipeline pipelineWithoutNumWorkgroups;
}; };
// Test basic direct // Test basic direct
@ -207,6 +217,11 @@ TEST_P(ComputeDispatchTests, IndirectBasic) {
IndirectTest({2, 3, 4}, 0); IndirectTest({2, 3, 4}, 0);
} }
// Test basic indirect without using [[num_workgroups]]
TEST_P(ComputeDispatchTests, IndirectBasicWithoutNumWorkgroups) {
IndirectTest({2, 3, 4}, 0, false);
}
// Test no-op indirect // Test no-op indirect
TEST_P(ComputeDispatchTests, IndirectNoop) { TEST_P(ComputeDispatchTests, IndirectNoop) {
// All dimensions are 0s // All dimensions are 0s
@ -227,6 +242,11 @@ TEST_P(ComputeDispatchTests, IndirectOffset) {
IndirectTest({0, 0, 0, 2, 3, 4}, 3 * sizeof(uint32_t)); IndirectTest({0, 0, 0, 2, 3, 4}, 3 * sizeof(uint32_t));
} }
// Test indirect with buffer offset without using [[num_workgroups]]
TEST_P(ComputeDispatchTests, IndirectOffsetWithoutNumWorkgroups) {
IndirectTest({0, 0, 0, 2, 3, 4}, 3 * sizeof(uint32_t), false);
}
// Test indirect dispatches at max limit. // Test indirect dispatches at max limit.
TEST_P(ComputeDispatchTests, MaxWorkgroups) { TEST_P(ComputeDispatchTests, MaxWorkgroups) {
// TODO(crbug.com/dawn/1165): Fails with WARP // TODO(crbug.com/dawn/1165): Fails with WARP
@ -244,6 +264,9 @@ TEST_P(ComputeDispatchTests, MaxWorkgroups) {
TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsNoop) { TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsNoop) {
DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation")); DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation"));
// TODO(crbug.com/dawn/839): Investigate why this test fails with WARP.
DAWN_SUPPRESS_TEST_IF(IsWARP());
uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension; uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
// All dimensions are above the max // All dimensions are above the max
@ -266,6 +289,9 @@ TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsNoop) {
TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsWithOffsetNoop) { TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsWithOffsetNoop) {
DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation")); DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation"));
// TODO(crbug.com/dawn/839): Investigate why this test fails with WARP.
DAWN_SUPPRESS_TEST_IF(IsWARP());
uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension; uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
IndirectTest({1, 2, 3, max + 1, 4, 5}, 1 * sizeof(uint32_t)); IndirectTest({1, 2, 3, max + 1, 4, 5}, 1 * sizeof(uint32_t));