From 829d165d7c59eb8a6633a3322dc7c3414baacaa0 Mon Sep 17 00:00:00 2001 From: Jiawei Shao Date: Wed, 17 Nov 2021 05:00:44 +0000 Subject: [PATCH] D3D12: Support [[num_workgroups]] for DispatchIndirect This patch supports [[num_workgroups]] on D3D12 for DispatchIndirect by appending the values of [[num_workgroups]] at the end of the scratch buffer for indirect dispatch validation and setting them as the root constants in the command signature. With this patch, for every DispatchIndirect call: - On D3D12: 1. Validation enabled, [[num_workgroups]] is used The dispatch indirect buffer needs to be validated, duplicated and be written into a scratch buffer (size: 6 * uint32_t). 2. Validation enabled, [[num_workgroups]] isn't used The dispatch indirect buffer needs to be validated and be written into a scratch buffer (size: 3 * uint32_t). 3. Validation disabled, [[num_workgroups]] is used The dispatch indirect buffer needs to be duplicated and be written into a scratch buffer (size: 6 * uint32_t). 4. Validation disabled, [[num_workgroups]] isn't used Neither transformations or scratch buffers are needed for the dispatch call. - On the other backends: 1. Validation enabled, The dispatch indirect buffer needs to be validated and be written into a scratch buffer (size: 3 * uint32_t). 2. Validation disabled, Neither transformations or scratch buffers are needed for the dispatch call. BUG=dawn:839 TEST=dawn_end2end_tests Change-Id: I4105f1b2e3c12f6df6e487ed535a627fbb342344 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/68843 Reviewed-by: Austin Eng Commit-Queue: Jiawei Shao --- src/dawn_native/Buffer.cpp | 8 +- src/dawn_native/ComputePassEncoder.cpp | 95 +++++++++++++------ src/dawn_native/ComputePassEncoder.h | 4 +- src/dawn_native/Device.cpp | 5 + src/dawn_native/Device.h | 5 + src/dawn_native/d3d12/CommandBufferD3D12.cpp | 15 +-- .../d3d12/ComputePipelineD3D12.cpp | 7 ++ src/dawn_native/d3d12/ComputePipelineD3D12.h | 2 + src/dawn_native/d3d12/DeviceD3D12.cpp | 5 + src/dawn_native/d3d12/DeviceD3D12.h | 3 + src/dawn_native/d3d12/PipelineLayoutD3D12.cpp | 36 ++++++- src/dawn_native/d3d12/PipelineLayoutD3D12.h | 5 +- src/tests/end2end/ComputeDispatchTests.cpp | 90 +++++++++++------- 13 files changed, 198 insertions(+), 82 deletions(-) diff --git a/src/dawn_native/Buffer.cpp b/src/dawn_native/Buffer.cpp index 10673207a6..4565df4c9f 100644 --- a/src/dawn_native/Buffer.cpp +++ b/src/dawn_native/Buffer.cpp @@ -152,9 +152,11 @@ namespace dawn_native { mUsage |= kInternalStorageBuffer; } - // We also add internal storage usage for Indirect buffers if validation is enabled, since - // validation involves binding them as storage buffers for use in a compute pass. - if ((mUsage & wgpu::BufferUsage::Indirect) && device->IsValidationEnabled()) { + // We also add internal storage usage for Indirect buffers for some transformations before + // DispatchIndirect calls on the backend (e.g. validations, support of [[num_workgroups]] on + // D3D12), since these transformations involve binding them as storage buffers for use in a + // compute pass. + if (mUsage & wgpu::BufferUsage::Indirect) { mUsage |= kInternalStorageBuffer; } diff --git a/src/dawn_native/ComputePassEncoder.cpp b/src/dawn_native/ComputePassEncoder.cpp index 281979ff8f..02079017bf 100644 --- a/src/dawn_native/ComputePassEncoder.cpp +++ b/src/dawn_native/ComputePassEncoder.cpp @@ -42,11 +42,14 @@ namespace dawn_native { // TODO(https://crbug.com/dawn/1108): Propagate validation feedback from this // shader in various failure modes. + // Type 'bool' cannot be used in storage class 'uniform' as it is non-host-shareable. Ref shaderModule; DAWN_TRY_ASSIGN(shaderModule, utils::CreateShaderModule(device, R"( [[block]] struct UniformParams { maxComputeWorkgroupsPerDimension: u32; clientOffsetInU32: u32; + enableValidation: u32; + duplicateNumWorkgroups: u32; }; [[block]] struct IndirectParams { @@ -54,7 +57,7 @@ namespace dawn_native { }; [[block]] struct ValidatedParams { - data: array; + data: array; }; [[group(0), binding(0)]] var uniformParams: UniformParams; @@ -65,10 +68,15 @@ namespace dawn_native { fn main() { for (var i = 0u; i < 3u; i = i + 1u) { var numWorkgroups = clientParams.data[uniformParams.clientOffsetInU32 + i]; - if (numWorkgroups > uniformParams.maxComputeWorkgroupsPerDimension) { + if (uniformParams.enableValidation > 0u && + numWorkgroups > uniformParams.maxComputeWorkgroupsPerDimension) { numWorkgroups = 0u; } validatedParams.data[i] = numWorkgroups; + + if (uniformParams.duplicateNumWorkgroups > 0u) { + validatedParams.data[i + 3u] = numWorkgroups; + } } } )")); @@ -191,9 +199,21 @@ namespace dawn_native { } ResultOrError, uint64_t>> - ComputePassEncoder::ValidateIndirectDispatch(BufferBase* indirectBuffer, - uint64_t indirectOffset) { + ComputePassEncoder::TransformIndirectDispatchBuffer(Ref indirectBuffer, + uint64_t indirectOffset) { DeviceBase* device = GetDevice(); + + const bool shouldDuplicateNumWorkgroups = + device->ShouldDuplicateNumWorkgroupsForDispatchIndirect( + mCommandBufferState.GetComputePipeline()); + if (!IsValidationEnabled() && !shouldDuplicateNumWorkgroups) { + return std::make_pair(indirectBuffer, indirectOffset); + } + + // Save the previous command buffer state so it can be restored after the + // validation inserts additional commands. + CommandBufferStateTracker previousState = mCommandBufferState; + auto* const store = device->GetInternalPipelineStore(); Ref validationPipeline; @@ -215,9 +235,14 @@ namespace dawn_native { const uint64_t clientIndirectBindingSize = kDispatchIndirectSize + clientOffsetFromAlignedBoundary; + // Neither 'enableValidation' nor 'duplicateNumWorkgroups' can be declared as 'bool' as + // currently in WGSL type 'bool' cannot be used in storage class 'uniform' as 'it is + // non-host-shareable'. struct UniformParams { uint32_t maxComputeWorkgroupsPerDimension; uint32_t clientOffsetInU32; + uint32_t enableValidation; + uint32_t duplicateNumWorkgroups; }; // Create a uniform buffer to hold parameters for the shader. @@ -227,6 +252,8 @@ namespace dawn_native { params.maxComputeWorkgroupsPerDimension = device->GetLimits().v1.maxComputeWorkgroupsPerDimension; params.clientOffsetInU32 = clientOffsetFromAlignedBoundary / sizeof(uint32_t); + params.enableValidation = static_cast(IsValidationEnabled()); + params.duplicateNumWorkgroups = static_cast(shouldDuplicateNumWorkgroups); DAWN_TRY_ASSIGN(uniformBuffer, utils::CreateBufferFromData( device, wgpu::BufferUsage::Uniform, {params})); @@ -234,25 +261,30 @@ namespace dawn_native { // Reserve space in the scratch buffer to hold the validated indirect params. ScratchBuffer& scratchBuffer = store->scratchIndirectStorage; - DAWN_TRY(scratchBuffer.EnsureCapacity(kDispatchIndirectSize)); + const uint64_t scratchBufferSize = + shouldDuplicateNumWorkgroups ? 2 * kDispatchIndirectSize : kDispatchIndirectSize; + DAWN_TRY(scratchBuffer.EnsureCapacity(scratchBufferSize)); Ref validatedIndirectBuffer = scratchBuffer.GetBuffer(); Ref validationBindGroup; - DAWN_TRY_ASSIGN( - validationBindGroup, - utils::MakeBindGroup( - device, layout, - { - {0, uniformBuffer}, - {1, indirectBuffer, clientIndirectBindingOffset, clientIndirectBindingSize}, - {2, validatedIndirectBuffer, 0, kDispatchIndirectSize}, - })); + ASSERT(indirectBuffer->GetUsage() & kInternalStorageBuffer); + DAWN_TRY_ASSIGN(validationBindGroup, + utils::MakeBindGroup(device, layout, + { + {0, uniformBuffer}, + {1, indirectBuffer, clientIndirectBindingOffset, + clientIndirectBindingSize}, + {2, validatedIndirectBuffer, 0, scratchBufferSize}, + })); // Issue commands to validate the indirect buffer. APISetPipeline(validationPipeline.Get()); APISetBindGroup(0, validationBindGroup.Get()); APIDispatch(1); + // Restore the state. + RestoreCommandBufferState(std::move(previousState)); + // Return the new indirect buffer and indirect buffer offset. return std::make_pair(std::move(validatedIndirectBuffer), uint64_t(0)); } @@ -287,27 +319,28 @@ namespace dawn_native { // the backend. Ref indirectBufferRef = indirectBuffer; - if (IsValidationEnabled()) { - // Save the previous command buffer state so it can be restored after the - // validation inserts additional commands. - CommandBufferStateTracker previousState = mCommandBufferState; - // Validate each indirect dispatch with a single dispatch to copy the indirect - // buffer params into a scratch buffer if they're valid, and otherwise zero them - // out. We could consider moving the validation earlier in the pass after the - // last point the indirect buffer was used with writable usage, as well as batch - // validation for multiple dispatches into one, but inserting commands at - // arbitrary points in the past is not possible right now. - DAWN_TRY_ASSIGN( - std::tie(indirectBufferRef, indirectOffset), - ValidateIndirectDispatch(indirectBufferRef.Get(), indirectOffset)); - - // Restore the state. - RestoreCommandBufferState(std::move(previousState)); + // Get applied indirect buffer with necessary changes on the original indirect + // buffer. For example, + // - Validate each indirect dispatch with a single dispatch to copy the indirect + // buffer params into a scratch buffer if they're valid, and otherwise zero them + // out. + // - Duplicate all the indirect dispatch parameters to support [[num_workgroups]] on + // D3D12. + // - Directly return the original indirect dispatch buffer if we don't need any + // transformations on it. + // We could consider moving the validation earlier in the pass after the last + // last point the indirect buffer was used with writable usage, as well as batch + // validation for multiple dispatches into one, but inserting commands at + // arbitrary points in the past is not possible right now. + DAWN_TRY_ASSIGN(std::tie(indirectBufferRef, indirectOffset), + TransformIndirectDispatchBuffer(indirectBufferRef, indirectOffset)); + // If we have created a new scratch dispatch indirect buffer in + // TransformIndirectDispatchBuffer(), we need to track it in mUsageTracker. + if (indirectBufferRef.Get() != indirectBuffer) { // |indirectBufferRef| was replaced with a scratch buffer. Add it to the // synchronization scope. - ASSERT(indirectBufferRef.Get() != indirectBuffer); scope.BufferUsedAs(indirectBufferRef.Get(), wgpu::BufferUsage::Indirect); mUsageTracker.AddReferencedBuffer(indirectBufferRef.Get()); } diff --git a/src/dawn_native/ComputePassEncoder.h b/src/dawn_native/ComputePassEncoder.h index 96841d868b..cc8132027d 100644 --- a/src/dawn_native/ComputePassEncoder.h +++ b/src/dawn_native/ComputePassEncoder.h @@ -64,8 +64,8 @@ namespace dawn_native { private: void DestroyImpl() override; - ResultOrError, uint64_t>> ValidateIndirectDispatch( - BufferBase* indirectBuffer, + ResultOrError, uint64_t>> TransformIndirectDispatchBuffer( + Ref indirectBuffer, uint64_t indirectOffset); void RestoreCommandBufferState(CommandBufferStateTracker state); diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp index 3003e4541a..a65285389e 100644 --- a/src/dawn_native/Device.cpp +++ b/src/dawn_native/Device.cpp @@ -1670,4 +1670,9 @@ namespace dawn_native { void DeviceBase::SetLabelImpl() { } + bool DeviceBase::ShouldDuplicateNumWorkgroupsForDispatchIndirect( + ComputePipelineBase* computePipeline) const { + return false; + } + } // namespace dawn_native diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h index eb2fb046a7..fd56a87061 100644 --- a/src/dawn_native/Device.h +++ b/src/dawn_native/Device.h @@ -336,11 +336,16 @@ namespace dawn_native { MaybeError Tick(); + // TODO(crbug.com/dawn/839): Organize the below backend-specific parameters into the struct + // BackendMetadata that we can query from the device. virtual uint32_t GetOptimalBytesPerRowAlignment() const = 0; virtual uint64_t GetOptimalBufferToTextureCopyOffsetAlignment() const = 0; virtual float GetTimestampPeriodInNS() const = 0; + virtual bool ShouldDuplicateNumWorkgroupsForDispatchIndirect( + ComputePipelineBase* computePipeline) const; + const CombinedLimits& GetLimits() const; AsyncTaskManager* GetAsyncTaskManager() const; diff --git a/src/dawn_native/d3d12/CommandBufferD3D12.cpp b/src/dawn_native/d3d12/CommandBufferD3D12.cpp index f50c8a49d7..e35964abb3 100644 --- a/src/dawn_native/d3d12/CommandBufferD3D12.cpp +++ b/src/dawn_native/d3d12/CommandBufferD3D12.cpp @@ -1063,22 +1063,15 @@ namespace dawn_native { namespace d3d12 { case Command::DispatchIndirect: { DispatchIndirectCmd* dispatch = mCommands.NextCommand(); - // TODO(dawn:839): support [[num_workgroups]] for DispatchIndirect calls - DAWN_INVALID_IF(lastPipeline->UsesNumWorkgroups(), - "Using %s with [[num_workgroups]] in a DispatchIndirect call " - "is not implemented.", - lastPipeline); - - Buffer* buffer = ToBackend(dispatch->indirectBuffer.Get()); - TransitionAndClearForSyncScope(commandContext, resourceUsages.dispatchUsages[currentDispatch]); DAWN_TRY(bindingTracker->Apply(commandContext)); ComPtr signature = - ToBackend(GetDevice())->GetDispatchIndirectSignature(); - commandList->ExecuteIndirect(signature.Get(), 1, buffer->GetD3D12Resource(), - dispatch->indirectOffset, nullptr, 0); + lastPipeline->GetDispatchIndirectCommandSignature(); + commandList->ExecuteIndirect( + signature.Get(), 1, ToBackend(dispatch->indirectBuffer)->GetD3D12Resource(), + dispatch->indirectOffset, nullptr, 0); currentDispatch++; break; } diff --git a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp index f73f1da634..744cb68b93 100644 --- a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp +++ b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp @@ -90,4 +90,11 @@ namespace dawn_native { namespace d3d12 { return GetStage(SingleShaderStage::Compute).metadata->usesNumWorkgroups; } + ComPtr ComputePipeline::GetDispatchIndirectCommandSignature() { + if (UsesNumWorkgroups()) { + return ToBackend(GetLayout())->GetDispatchIndirectCommandSignatureWithNumWorkgroups(); + } + return ToBackend(GetDevice())->GetDispatchIndirectSignature(); + } + }} // namespace dawn_native::d3d12 diff --git a/src/dawn_native/d3d12/ComputePipelineD3D12.h b/src/dawn_native/d3d12/ComputePipelineD3D12.h index fdba24f051..ddf7476ed1 100644 --- a/src/dawn_native/d3d12/ComputePipelineD3D12.h +++ b/src/dawn_native/d3d12/ComputePipelineD3D12.h @@ -42,6 +42,8 @@ namespace dawn_native { namespace d3d12 { bool UsesNumWorkgroups() const; + ComPtr GetDispatchIndirectCommandSignature(); + private: ~ComputePipeline() override; diff --git a/src/dawn_native/d3d12/DeviceD3D12.cpp b/src/dawn_native/d3d12/DeviceD3D12.cpp index 0c2d34ce33..027ead1fd1 100644 --- a/src/dawn_native/d3d12/DeviceD3D12.cpp +++ b/src/dawn_native/d3d12/DeviceD3D12.cpp @@ -675,4 +675,9 @@ namespace dawn_native { namespace d3d12 { return mTimestampPeriod; } + bool Device::ShouldDuplicateNumWorkgroupsForDispatchIndirect( + ComputePipelineBase* computePipeline) const { + return ToBackend(computePipeline)->UsesNumWorkgroups(); + } + }} // namespace dawn_native::d3d12 diff --git a/src/dawn_native/d3d12/DeviceD3D12.h b/src/dawn_native/d3d12/DeviceD3D12.h index 0e78cf991c..a1c48c1afa 100644 --- a/src/dawn_native/d3d12/DeviceD3D12.h +++ b/src/dawn_native/d3d12/DeviceD3D12.h @@ -139,6 +139,9 @@ namespace dawn_native { namespace d3d12 { float GetTimestampPeriodInNS() const override; + bool ShouldDuplicateNumWorkgroupsForDispatchIndirect( + ComputePipelineBase* computePipeline) const override; + private: using DeviceBase::DeviceBase; diff --git a/src/dawn_native/d3d12/PipelineLayoutD3D12.cpp b/src/dawn_native/d3d12/PipelineLayoutD3D12.cpp index 1a512fa60e..2953cab36b 100644 --- a/src/dawn_native/d3d12/PipelineLayoutD3D12.cpp +++ b/src/dawn_native/d3d12/PipelineLayoutD3D12.cpp @@ -184,7 +184,7 @@ namespace dawn_native { namespace d3d12 { numWorkgroupsConstants.Constants.Num32BitValues = 3; numWorkgroupsConstants.Constants.RegisterSpace = GetNumWorkgroupsRegisterSpace(); numWorkgroupsConstants.Constants.ShaderRegister = GetNumWorkgroupsShaderRegister(); - mNumWorkgroupsParamterIndex = rootParameters.size(); + mNumWorkgroupsParameterIndex = rootParameters.size(); // NOTE: We should consider moving this entry to earlier in the root signature since // dispatch sizes would need to be updated often rootParameters.emplace_back(numWorkgroupsConstants); @@ -265,6 +265,38 @@ namespace dawn_native { namespace d3d12 { } uint32_t PipelineLayout::GetNumWorkgroupsParameterIndex() const { - return mNumWorkgroupsParamterIndex; + return mNumWorkgroupsParameterIndex; } + + ID3D12CommandSignature* PipelineLayout::GetDispatchIndirectCommandSignatureWithNumWorkgroups() { + // mDispatchIndirectCommandSignatureWithNumWorkgroups won't be created until it is needed. + if (mDispatchIndirectCommandSignatureWithNumWorkgroups.Get() != nullptr) { + return mDispatchIndirectCommandSignatureWithNumWorkgroups.Get(); + } + + D3D12_INDIRECT_ARGUMENT_DESC argumentDescs[2] = {}; + argumentDescs[0].Type = D3D12_INDIRECT_ARGUMENT_TYPE_CONSTANT; + argumentDescs[0].Constant.RootParameterIndex = GetNumWorkgroupsParameterIndex(); + argumentDescs[0].Constant.Num32BitValuesToSet = 3; + argumentDescs[0].Constant.DestOffsetIn32BitValues = 0; + + // A command signature must contain exactly 1 Draw / Dispatch / DispatchMesh / DispatchRays + // command. That command must come last. + argumentDescs[1].Type = D3D12_INDIRECT_ARGUMENT_TYPE_DISPATCH; + + D3D12_COMMAND_SIGNATURE_DESC programDesc = {}; + programDesc.ByteStride = 6 * sizeof(uint32_t); + programDesc.NumArgumentDescs = 2; + programDesc.pArgumentDescs = argumentDescs; + + // The root signature must be specified if and only if the command signature changes one of + // the root arguments. + ToBackend(GetDevice()) + ->GetD3D12Device() + ->CreateCommandSignature( + &programDesc, GetRootSignature(), + IID_PPV_ARGS(&mDispatchIndirectCommandSignatureWithNumWorkgroups)); + return mDispatchIndirectCommandSignatureWithNumWorkgroups.Get(); + } + }} // namespace dawn_native::d3d12 diff --git a/src/dawn_native/d3d12/PipelineLayoutD3D12.h b/src/dawn_native/d3d12/PipelineLayoutD3D12.h index cf52f066e7..7a539b3c04 100644 --- a/src/dawn_native/d3d12/PipelineLayoutD3D12.h +++ b/src/dawn_native/d3d12/PipelineLayoutD3D12.h @@ -55,6 +55,8 @@ namespace dawn_native { namespace d3d12 { ID3D12RootSignature* GetRootSignature() const; + ID3D12CommandSignature* GetDispatchIndirectCommandSignatureWithNumWorkgroups(); + private: ~PipelineLayout() override = default; using PipelineLayoutBase::PipelineLayoutBase; @@ -66,8 +68,9 @@ namespace dawn_native { namespace d3d12 { kMaxBindGroups> mDynamicRootParameterIndices; uint32_t mFirstIndexOffsetParameterIndex; - uint32_t mNumWorkgroupsParamterIndex; + uint32_t mNumWorkgroupsParameterIndex; ComPtr mRootSignature; + ComPtr mDispatchIndirectCommandSignatureWithNumWorkgroups; }; }} // namespace dawn_native::d3d12 diff --git a/src/tests/end2end/ComputeDispatchTests.cpp b/src/tests/end2end/ComputeDispatchTests.cpp index cbc2d8642c..5589d64c9a 100644 --- a/src/tests/end2end/ComputeDispatchTests.cpp +++ b/src/tests/end2end/ComputeDispatchTests.cpp @@ -27,7 +27,7 @@ class ComputeDispatchTests : public DawnTest { // Write workgroup number into the output buffer if we saw the biggest dispatch // To make sure the dispatch was not called, write maximum u32 value for 0 dispatches - wgpu::ShaderModule moduleForDispatch = utils::CreateShaderModule(device, R"( + wgpu::ShaderModule module = utils::CreateShaderModule(device, R"( [[block]] struct OutputBuf { workGroups : vec3; }; @@ -47,9 +47,13 @@ class ComputeDispatchTests : public DawnTest { } })"); - // TODO(dawn:839): use moduleForDispatch for indirect dispatch tests when D3D12 supports - // [[num_workgroups]] for indirect dispatch. - wgpu::ShaderModule moduleForDispatchIndirect = utils::CreateShaderModule(device, R"( + wgpu::ComputePipelineDescriptor csDesc; + csDesc.compute.module = module; + csDesc.compute.entryPoint = "main"; + pipeline = device.CreateComputePipeline(&csDesc); + + // Test the use of the compute pipelines without using [[num_workgroups]] + wgpu::ShaderModule moduleWithoutNumWorkgroups = utils::CreateShaderModule(device, R"( [[block]] struct InputBuf { expectedDispatch : vec3; }; @@ -73,14 +77,8 @@ class ComputeDispatchTests : public DawnTest { output.workGroups = dispatch; } })"); - - wgpu::ComputePipelineDescriptor csDesc; - csDesc.compute.module = moduleForDispatch; - csDesc.compute.entryPoint = "main"; - pipelineForDispatch = device.CreateComputePipeline(&csDesc); - - csDesc.compute.module = moduleForDispatchIndirect; - pipelineForDispatchIndirect = device.CreateComputePipeline(&csDesc); + csDesc.compute.module = moduleWithoutNumWorkgroups; + pipelineWithoutNumWorkgroups = device.CreateComputePipeline(&csDesc); } void DirectTest(uint32_t x, uint32_t y, uint32_t z) { @@ -91,17 +89,16 @@ class ComputeDispatchTests : public DawnTest { kSentinelData); // Set up bind group and issue dispatch - wgpu::BindGroup bindGroup = - utils::MakeBindGroup(device, pipelineForDispatch.GetBindGroupLayout(0), - { - {0, dst, 0, 3 * sizeof(uint32_t)}, - }); + wgpu::BindGroup bindGroup = utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0), + { + {0, dst, 0, 3 * sizeof(uint32_t)}, + }); wgpu::CommandBuffer commands; { wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); - pass.SetPipeline(pipelineForDispatch); + pass.SetPipeline(pipeline); pass.SetBindGroup(0, bindGroup); pass.Dispatch(x, y, z); pass.EndPass(); @@ -118,7 +115,9 @@ class ComputeDispatchTests : public DawnTest { EXPECT_BUFFER_U32_RANGE_EQ(&expected[0], dst, 0, 3); } - void IndirectTest(std::vector indirectBufferData, uint64_t indirectOffset) { + void IndirectTest(std::vector indirectBufferData, + uint64_t indirectOffset, + bool useNumWorkgroups = true) { // Set up dst storage buffer to contain dispatch x, y, z wgpu::Buffer dst = utils::CreateBufferFromData( device, @@ -131,23 +130,34 @@ class ComputeDispatchTests : public DawnTest { uint32_t indirectStart = indirectOffset / sizeof(uint32_t); - wgpu::Buffer expectedBuffer = - utils::CreateBufferFromData(device, &indirectBufferData[indirectStart], - 3 * sizeof(uint32_t), wgpu::BufferUsage::Uniform); - // Set up bind group and issue dispatch - wgpu::BindGroup bindGroup = - utils::MakeBindGroup(device, pipelineForDispatchIndirect.GetBindGroupLayout(0), - { - {0, expectedBuffer, 0, 3 * sizeof(uint32_t)}, - {1, dst, 0, 3 * sizeof(uint32_t)}, - }); + wgpu::BindGroup bindGroup; + wgpu::ComputePipeline computePipelineForTest; + + if (useNumWorkgroups) { + computePipelineForTest = pipeline; + bindGroup = utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0), + { + {0, dst, 0, 3 * sizeof(uint32_t)}, + }); + } else { + computePipelineForTest = pipelineWithoutNumWorkgroups; + wgpu::Buffer expectedBuffer = + utils::CreateBufferFromData(device, &indirectBufferData[indirectStart], + 3 * sizeof(uint32_t), wgpu::BufferUsage::Uniform); + bindGroup = + utils::MakeBindGroup(device, pipelineWithoutNumWorkgroups.GetBindGroupLayout(0), + { + {0, expectedBuffer, 0, 3 * sizeof(uint32_t)}, + {1, dst, 0, 3 * sizeof(uint32_t)}, + }); + } wgpu::CommandBuffer commands; { wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); - pass.SetPipeline(pipelineForDispatchIndirect); + pass.SetPipeline(computePipelineForTest); pass.SetBindGroup(0, bindGroup); pass.DispatchIndirect(indirectBuffer, indirectOffset); pass.EndPass(); @@ -178,8 +188,8 @@ class ComputeDispatchTests : public DawnTest { } private: - wgpu::ComputePipeline pipelineForDispatch; - wgpu::ComputePipeline pipelineForDispatchIndirect; + wgpu::ComputePipeline pipeline; + wgpu::ComputePipeline pipelineWithoutNumWorkgroups; }; // Test basic direct @@ -207,6 +217,11 @@ TEST_P(ComputeDispatchTests, IndirectBasic) { IndirectTest({2, 3, 4}, 0); } +// Test basic indirect without using [[num_workgroups]] +TEST_P(ComputeDispatchTests, IndirectBasicWithoutNumWorkgroups) { + IndirectTest({2, 3, 4}, 0, false); +} + // Test no-op indirect TEST_P(ComputeDispatchTests, IndirectNoop) { // All dimensions are 0s @@ -227,6 +242,11 @@ TEST_P(ComputeDispatchTests, IndirectOffset) { IndirectTest({0, 0, 0, 2, 3, 4}, 3 * sizeof(uint32_t)); } +// Test indirect with buffer offset without using [[num_workgroups]] +TEST_P(ComputeDispatchTests, IndirectOffsetWithoutNumWorkgroups) { + IndirectTest({0, 0, 0, 2, 3, 4}, 3 * sizeof(uint32_t), false); +} + // Test indirect dispatches at max limit. TEST_P(ComputeDispatchTests, MaxWorkgroups) { // TODO(crbug.com/dawn/1165): Fails with WARP @@ -244,6 +264,9 @@ TEST_P(ComputeDispatchTests, MaxWorkgroups) { TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsNoop) { DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation")); + // TODO(crbug.com/dawn/839): Investigate why this test fails with WARP. + DAWN_SUPPRESS_TEST_IF(IsWARP()); + uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension; // All dimensions are above the max @@ -266,6 +289,9 @@ TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsNoop) { TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsWithOffsetNoop) { DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation")); + // TODO(crbug.com/dawn/839): Investigate why this test fails with WARP. + DAWN_SUPPRESS_TEST_IF(IsWARP()); + uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension; IndirectTest({1, 2, 3, max + 1, 4, 5}, 1 * sizeof(uint32_t));