D3D12: Support [[num_workgroups]] for DispatchIndirect
This patch supports [[num_workgroups]] on D3D12 for DispatchIndirect by appending the values of [[num_workgroups]] at the end of the scratch buffer for indirect dispatch validation and setting them as the root constants in the command signature. With this patch, for every DispatchIndirect call: - On D3D12: 1. Validation enabled, [[num_workgroups]] is used The dispatch indirect buffer needs to be validated, duplicated and be written into a scratch buffer (size: 6 * uint32_t). 2. Validation enabled, [[num_workgroups]] isn't used The dispatch indirect buffer needs to be validated and be written into a scratch buffer (size: 3 * uint32_t). 3. Validation disabled, [[num_workgroups]] is used The dispatch indirect buffer needs to be duplicated and be written into a scratch buffer (size: 6 * uint32_t). 4. Validation disabled, [[num_workgroups]] isn't used Neither transformations or scratch buffers are needed for the dispatch call. - On the other backends: 1. Validation enabled, The dispatch indirect buffer needs to be validated and be written into a scratch buffer (size: 3 * uint32_t). 2. Validation disabled, Neither transformations or scratch buffers are needed for the dispatch call. BUG=dawn:839 TEST=dawn_end2end_tests Change-Id: I4105f1b2e3c12f6df6e487ed535a627fbb342344 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/68843 Reviewed-by: Austin Eng <enga@chromium.org> Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
This commit is contained in:
parent
8ce15b3ce9
commit
829d165d7c
|
@ -152,9 +152,11 @@ namespace dawn_native {
|
|||
mUsage |= kInternalStorageBuffer;
|
||||
}
|
||||
|
||||
// We also add internal storage usage for Indirect buffers if validation is enabled, since
|
||||
// validation involves binding them as storage buffers for use in a compute pass.
|
||||
if ((mUsage & wgpu::BufferUsage::Indirect) && device->IsValidationEnabled()) {
|
||||
// We also add internal storage usage for Indirect buffers for some transformations before
|
||||
// DispatchIndirect calls on the backend (e.g. validations, support of [[num_workgroups]] on
|
||||
// D3D12), since these transformations involve binding them as storage buffers for use in a
|
||||
// compute pass.
|
||||
if (mUsage & wgpu::BufferUsage::Indirect) {
|
||||
mUsage |= kInternalStorageBuffer;
|
||||
}
|
||||
|
||||
|
|
|
@ -42,11 +42,14 @@ namespace dawn_native {
|
|||
|
||||
// TODO(https://crbug.com/dawn/1108): Propagate validation feedback from this
|
||||
// shader in various failure modes.
|
||||
// Type 'bool' cannot be used in storage class 'uniform' as it is non-host-shareable.
|
||||
Ref<ShaderModuleBase> shaderModule;
|
||||
DAWN_TRY_ASSIGN(shaderModule, utils::CreateShaderModule(device, R"(
|
||||
[[block]] struct UniformParams {
|
||||
maxComputeWorkgroupsPerDimension: u32;
|
||||
clientOffsetInU32: u32;
|
||||
enableValidation: u32;
|
||||
duplicateNumWorkgroups: u32;
|
||||
};
|
||||
|
||||
[[block]] struct IndirectParams {
|
||||
|
@ -54,7 +57,7 @@ namespace dawn_native {
|
|||
};
|
||||
|
||||
[[block]] struct ValidatedParams {
|
||||
data: array<u32, 3>;
|
||||
data: array<u32>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<uniform> uniformParams: UniformParams;
|
||||
|
@ -65,10 +68,15 @@ namespace dawn_native {
|
|||
fn main() {
|
||||
for (var i = 0u; i < 3u; i = i + 1u) {
|
||||
var numWorkgroups = clientParams.data[uniformParams.clientOffsetInU32 + i];
|
||||
if (numWorkgroups > uniformParams.maxComputeWorkgroupsPerDimension) {
|
||||
if (uniformParams.enableValidation > 0u &&
|
||||
numWorkgroups > uniformParams.maxComputeWorkgroupsPerDimension) {
|
||||
numWorkgroups = 0u;
|
||||
}
|
||||
validatedParams.data[i] = numWorkgroups;
|
||||
|
||||
if (uniformParams.duplicateNumWorkgroups > 0u) {
|
||||
validatedParams.data[i + 3u] = numWorkgroups;
|
||||
}
|
||||
}
|
||||
}
|
||||
)"));
|
||||
|
@ -191,9 +199,21 @@ namespace dawn_native {
|
|||
}
|
||||
|
||||
ResultOrError<std::pair<Ref<BufferBase>, uint64_t>>
|
||||
ComputePassEncoder::ValidateIndirectDispatch(BufferBase* indirectBuffer,
|
||||
ComputePassEncoder::TransformIndirectDispatchBuffer(Ref<BufferBase> indirectBuffer,
|
||||
uint64_t indirectOffset) {
|
||||
DeviceBase* device = GetDevice();
|
||||
|
||||
const bool shouldDuplicateNumWorkgroups =
|
||||
device->ShouldDuplicateNumWorkgroupsForDispatchIndirect(
|
||||
mCommandBufferState.GetComputePipeline());
|
||||
if (!IsValidationEnabled() && !shouldDuplicateNumWorkgroups) {
|
||||
return std::make_pair(indirectBuffer, indirectOffset);
|
||||
}
|
||||
|
||||
// Save the previous command buffer state so it can be restored after the
|
||||
// validation inserts additional commands.
|
||||
CommandBufferStateTracker previousState = mCommandBufferState;
|
||||
|
||||
auto* const store = device->GetInternalPipelineStore();
|
||||
|
||||
Ref<ComputePipelineBase> validationPipeline;
|
||||
|
@ -215,9 +235,14 @@ namespace dawn_native {
|
|||
const uint64_t clientIndirectBindingSize =
|
||||
kDispatchIndirectSize + clientOffsetFromAlignedBoundary;
|
||||
|
||||
// Neither 'enableValidation' nor 'duplicateNumWorkgroups' can be declared as 'bool' as
|
||||
// currently in WGSL type 'bool' cannot be used in storage class 'uniform' as 'it is
|
||||
// non-host-shareable'.
|
||||
struct UniformParams {
|
||||
uint32_t maxComputeWorkgroupsPerDimension;
|
||||
uint32_t clientOffsetInU32;
|
||||
uint32_t enableValidation;
|
||||
uint32_t duplicateNumWorkgroups;
|
||||
};
|
||||
|
||||
// Create a uniform buffer to hold parameters for the shader.
|
||||
|
@ -227,6 +252,8 @@ namespace dawn_native {
|
|||
params.maxComputeWorkgroupsPerDimension =
|
||||
device->GetLimits().v1.maxComputeWorkgroupsPerDimension;
|
||||
params.clientOffsetInU32 = clientOffsetFromAlignedBoundary / sizeof(uint32_t);
|
||||
params.enableValidation = static_cast<uint32_t>(IsValidationEnabled());
|
||||
params.duplicateNumWorkgroups = static_cast<uint32_t>(shouldDuplicateNumWorkgroups);
|
||||
|
||||
DAWN_TRY_ASSIGN(uniformBuffer, utils::CreateBufferFromData(
|
||||
device, wgpu::BufferUsage::Uniform, {params}));
|
||||
|
@ -234,18 +261,20 @@ namespace dawn_native {
|
|||
|
||||
// Reserve space in the scratch buffer to hold the validated indirect params.
|
||||
ScratchBuffer& scratchBuffer = store->scratchIndirectStorage;
|
||||
DAWN_TRY(scratchBuffer.EnsureCapacity(kDispatchIndirectSize));
|
||||
const uint64_t scratchBufferSize =
|
||||
shouldDuplicateNumWorkgroups ? 2 * kDispatchIndirectSize : kDispatchIndirectSize;
|
||||
DAWN_TRY(scratchBuffer.EnsureCapacity(scratchBufferSize));
|
||||
Ref<BufferBase> validatedIndirectBuffer = scratchBuffer.GetBuffer();
|
||||
|
||||
Ref<BindGroupBase> validationBindGroup;
|
||||
DAWN_TRY_ASSIGN(
|
||||
validationBindGroup,
|
||||
utils::MakeBindGroup(
|
||||
device, layout,
|
||||
ASSERT(indirectBuffer->GetUsage() & kInternalStorageBuffer);
|
||||
DAWN_TRY_ASSIGN(validationBindGroup,
|
||||
utils::MakeBindGroup(device, layout,
|
||||
{
|
||||
{0, uniformBuffer},
|
||||
{1, indirectBuffer, clientIndirectBindingOffset, clientIndirectBindingSize},
|
||||
{2, validatedIndirectBuffer, 0, kDispatchIndirectSize},
|
||||
{1, indirectBuffer, clientIndirectBindingOffset,
|
||||
clientIndirectBindingSize},
|
||||
{2, validatedIndirectBuffer, 0, scratchBufferSize},
|
||||
}));
|
||||
|
||||
// Issue commands to validate the indirect buffer.
|
||||
|
@ -253,6 +282,9 @@ namespace dawn_native {
|
|||
APISetBindGroup(0, validationBindGroup.Get());
|
||||
APIDispatch(1);
|
||||
|
||||
// Restore the state.
|
||||
RestoreCommandBufferState(std::move(previousState));
|
||||
|
||||
// Return the new indirect buffer and indirect buffer offset.
|
||||
return std::make_pair(std::move(validatedIndirectBuffer), uint64_t(0));
|
||||
}
|
||||
|
@ -287,27 +319,28 @@ namespace dawn_native {
|
|||
// the backend.
|
||||
|
||||
Ref<BufferBase> indirectBufferRef = indirectBuffer;
|
||||
if (IsValidationEnabled()) {
|
||||
// Save the previous command buffer state so it can be restored after the
|
||||
// validation inserts additional commands.
|
||||
CommandBufferStateTracker previousState = mCommandBufferState;
|
||||
|
||||
// Validate each indirect dispatch with a single dispatch to copy the indirect
|
||||
// Get applied indirect buffer with necessary changes on the original indirect
|
||||
// buffer. For example,
|
||||
// - Validate each indirect dispatch with a single dispatch to copy the indirect
|
||||
// buffer params into a scratch buffer if they're valid, and otherwise zero them
|
||||
// out. We could consider moving the validation earlier in the pass after the
|
||||
// out.
|
||||
// - Duplicate all the indirect dispatch parameters to support [[num_workgroups]] on
|
||||
// D3D12.
|
||||
// - Directly return the original indirect dispatch buffer if we don't need any
|
||||
// transformations on it.
|
||||
// We could consider moving the validation earlier in the pass after the last
|
||||
// last point the indirect buffer was used with writable usage, as well as batch
|
||||
// validation for multiple dispatches into one, but inserting commands at
|
||||
// arbitrary points in the past is not possible right now.
|
||||
DAWN_TRY_ASSIGN(
|
||||
std::tie(indirectBufferRef, indirectOffset),
|
||||
ValidateIndirectDispatch(indirectBufferRef.Get(), indirectOffset));
|
||||
|
||||
// Restore the state.
|
||||
RestoreCommandBufferState(std::move(previousState));
|
||||
DAWN_TRY_ASSIGN(std::tie(indirectBufferRef, indirectOffset),
|
||||
TransformIndirectDispatchBuffer(indirectBufferRef, indirectOffset));
|
||||
|
||||
// If we have created a new scratch dispatch indirect buffer in
|
||||
// TransformIndirectDispatchBuffer(), we need to track it in mUsageTracker.
|
||||
if (indirectBufferRef.Get() != indirectBuffer) {
|
||||
// |indirectBufferRef| was replaced with a scratch buffer. Add it to the
|
||||
// synchronization scope.
|
||||
ASSERT(indirectBufferRef.Get() != indirectBuffer);
|
||||
scope.BufferUsedAs(indirectBufferRef.Get(), wgpu::BufferUsage::Indirect);
|
||||
mUsageTracker.AddReferencedBuffer(indirectBufferRef.Get());
|
||||
}
|
||||
|
|
|
@ -64,8 +64,8 @@ namespace dawn_native {
|
|||
private:
|
||||
void DestroyImpl() override;
|
||||
|
||||
ResultOrError<std::pair<Ref<BufferBase>, uint64_t>> ValidateIndirectDispatch(
|
||||
BufferBase* indirectBuffer,
|
||||
ResultOrError<std::pair<Ref<BufferBase>, uint64_t>> TransformIndirectDispatchBuffer(
|
||||
Ref<BufferBase> indirectBuffer,
|
||||
uint64_t indirectOffset);
|
||||
|
||||
void RestoreCommandBufferState(CommandBufferStateTracker state);
|
||||
|
|
|
@ -1670,4 +1670,9 @@ namespace dawn_native {
|
|||
void DeviceBase::SetLabelImpl() {
|
||||
}
|
||||
|
||||
bool DeviceBase::ShouldDuplicateNumWorkgroupsForDispatchIndirect(
|
||||
ComputePipelineBase* computePipeline) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace dawn_native
|
||||
|
|
|
@ -336,11 +336,16 @@ namespace dawn_native {
|
|||
|
||||
MaybeError Tick();
|
||||
|
||||
// TODO(crbug.com/dawn/839): Organize the below backend-specific parameters into the struct
|
||||
// BackendMetadata that we can query from the device.
|
||||
virtual uint32_t GetOptimalBytesPerRowAlignment() const = 0;
|
||||
virtual uint64_t GetOptimalBufferToTextureCopyOffsetAlignment() const = 0;
|
||||
|
||||
virtual float GetTimestampPeriodInNS() const = 0;
|
||||
|
||||
virtual bool ShouldDuplicateNumWorkgroupsForDispatchIndirect(
|
||||
ComputePipelineBase* computePipeline) const;
|
||||
|
||||
const CombinedLimits& GetLimits() const;
|
||||
|
||||
AsyncTaskManager* GetAsyncTaskManager() const;
|
||||
|
|
|
@ -1063,21 +1063,14 @@ namespace dawn_native { namespace d3d12 {
|
|||
case Command::DispatchIndirect: {
|
||||
DispatchIndirectCmd* dispatch = mCommands.NextCommand<DispatchIndirectCmd>();
|
||||
|
||||
// TODO(dawn:839): support [[num_workgroups]] for DispatchIndirect calls
|
||||
DAWN_INVALID_IF(lastPipeline->UsesNumWorkgroups(),
|
||||
"Using %s with [[num_workgroups]] in a DispatchIndirect call "
|
||||
"is not implemented.",
|
||||
lastPipeline);
|
||||
|
||||
Buffer* buffer = ToBackend(dispatch->indirectBuffer.Get());
|
||||
|
||||
TransitionAndClearForSyncScope(commandContext,
|
||||
resourceUsages.dispatchUsages[currentDispatch]);
|
||||
DAWN_TRY(bindingTracker->Apply(commandContext));
|
||||
|
||||
ComPtr<ID3D12CommandSignature> signature =
|
||||
ToBackend(GetDevice())->GetDispatchIndirectSignature();
|
||||
commandList->ExecuteIndirect(signature.Get(), 1, buffer->GetD3D12Resource(),
|
||||
lastPipeline->GetDispatchIndirectCommandSignature();
|
||||
commandList->ExecuteIndirect(
|
||||
signature.Get(), 1, ToBackend(dispatch->indirectBuffer)->GetD3D12Resource(),
|
||||
dispatch->indirectOffset, nullptr, 0);
|
||||
currentDispatch++;
|
||||
break;
|
||||
|
|
|
@ -90,4 +90,11 @@ namespace dawn_native { namespace d3d12 {
|
|||
return GetStage(SingleShaderStage::Compute).metadata->usesNumWorkgroups;
|
||||
}
|
||||
|
||||
ComPtr<ID3D12CommandSignature> ComputePipeline::GetDispatchIndirectCommandSignature() {
|
||||
if (UsesNumWorkgroups()) {
|
||||
return ToBackend(GetLayout())->GetDispatchIndirectCommandSignatureWithNumWorkgroups();
|
||||
}
|
||||
return ToBackend(GetDevice())->GetDispatchIndirectSignature();
|
||||
}
|
||||
|
||||
}} // namespace dawn_native::d3d12
|
||||
|
|
|
@ -42,6 +42,8 @@ namespace dawn_native { namespace d3d12 {
|
|||
|
||||
bool UsesNumWorkgroups() const;
|
||||
|
||||
ComPtr<ID3D12CommandSignature> GetDispatchIndirectCommandSignature();
|
||||
|
||||
private:
|
||||
~ComputePipeline() override;
|
||||
|
||||
|
|
|
@ -675,4 +675,9 @@ namespace dawn_native { namespace d3d12 {
|
|||
return mTimestampPeriod;
|
||||
}
|
||||
|
||||
bool Device::ShouldDuplicateNumWorkgroupsForDispatchIndirect(
|
||||
ComputePipelineBase* computePipeline) const {
|
||||
return ToBackend(computePipeline)->UsesNumWorkgroups();
|
||||
}
|
||||
|
||||
}} // namespace dawn_native::d3d12
|
||||
|
|
|
@ -139,6 +139,9 @@ namespace dawn_native { namespace d3d12 {
|
|||
|
||||
float GetTimestampPeriodInNS() const override;
|
||||
|
||||
bool ShouldDuplicateNumWorkgroupsForDispatchIndirect(
|
||||
ComputePipelineBase* computePipeline) const override;
|
||||
|
||||
private:
|
||||
using DeviceBase::DeviceBase;
|
||||
|
||||
|
|
|
@ -184,7 +184,7 @@ namespace dawn_native { namespace d3d12 {
|
|||
numWorkgroupsConstants.Constants.Num32BitValues = 3;
|
||||
numWorkgroupsConstants.Constants.RegisterSpace = GetNumWorkgroupsRegisterSpace();
|
||||
numWorkgroupsConstants.Constants.ShaderRegister = GetNumWorkgroupsShaderRegister();
|
||||
mNumWorkgroupsParamterIndex = rootParameters.size();
|
||||
mNumWorkgroupsParameterIndex = rootParameters.size();
|
||||
// NOTE: We should consider moving this entry to earlier in the root signature since
|
||||
// dispatch sizes would need to be updated often
|
||||
rootParameters.emplace_back(numWorkgroupsConstants);
|
||||
|
@ -265,6 +265,38 @@ namespace dawn_native { namespace d3d12 {
|
|||
}
|
||||
|
||||
uint32_t PipelineLayout::GetNumWorkgroupsParameterIndex() const {
|
||||
return mNumWorkgroupsParamterIndex;
|
||||
return mNumWorkgroupsParameterIndex;
|
||||
}
|
||||
|
||||
ID3D12CommandSignature* PipelineLayout::GetDispatchIndirectCommandSignatureWithNumWorkgroups() {
|
||||
// mDispatchIndirectCommandSignatureWithNumWorkgroups won't be created until it is needed.
|
||||
if (mDispatchIndirectCommandSignatureWithNumWorkgroups.Get() != nullptr) {
|
||||
return mDispatchIndirectCommandSignatureWithNumWorkgroups.Get();
|
||||
}
|
||||
|
||||
D3D12_INDIRECT_ARGUMENT_DESC argumentDescs[2] = {};
|
||||
argumentDescs[0].Type = D3D12_INDIRECT_ARGUMENT_TYPE_CONSTANT;
|
||||
argumentDescs[0].Constant.RootParameterIndex = GetNumWorkgroupsParameterIndex();
|
||||
argumentDescs[0].Constant.Num32BitValuesToSet = 3;
|
||||
argumentDescs[0].Constant.DestOffsetIn32BitValues = 0;
|
||||
|
||||
// A command signature must contain exactly 1 Draw / Dispatch / DispatchMesh / DispatchRays
|
||||
// command. That command must come last.
|
||||
argumentDescs[1].Type = D3D12_INDIRECT_ARGUMENT_TYPE_DISPATCH;
|
||||
|
||||
D3D12_COMMAND_SIGNATURE_DESC programDesc = {};
|
||||
programDesc.ByteStride = 6 * sizeof(uint32_t);
|
||||
programDesc.NumArgumentDescs = 2;
|
||||
programDesc.pArgumentDescs = argumentDescs;
|
||||
|
||||
// The root signature must be specified if and only if the command signature changes one of
|
||||
// the root arguments.
|
||||
ToBackend(GetDevice())
|
||||
->GetD3D12Device()
|
||||
->CreateCommandSignature(
|
||||
&programDesc, GetRootSignature(),
|
||||
IID_PPV_ARGS(&mDispatchIndirectCommandSignatureWithNumWorkgroups));
|
||||
return mDispatchIndirectCommandSignatureWithNumWorkgroups.Get();
|
||||
}
|
||||
|
||||
}} // namespace dawn_native::d3d12
|
||||
|
|
|
@ -55,6 +55,8 @@ namespace dawn_native { namespace d3d12 {
|
|||
|
||||
ID3D12RootSignature* GetRootSignature() const;
|
||||
|
||||
ID3D12CommandSignature* GetDispatchIndirectCommandSignatureWithNumWorkgroups();
|
||||
|
||||
private:
|
||||
~PipelineLayout() override = default;
|
||||
using PipelineLayoutBase::PipelineLayoutBase;
|
||||
|
@ -66,8 +68,9 @@ namespace dawn_native { namespace d3d12 {
|
|||
kMaxBindGroups>
|
||||
mDynamicRootParameterIndices;
|
||||
uint32_t mFirstIndexOffsetParameterIndex;
|
||||
uint32_t mNumWorkgroupsParamterIndex;
|
||||
uint32_t mNumWorkgroupsParameterIndex;
|
||||
ComPtr<ID3D12RootSignature> mRootSignature;
|
||||
ComPtr<ID3D12CommandSignature> mDispatchIndirectCommandSignatureWithNumWorkgroups;
|
||||
};
|
||||
|
||||
}} // namespace dawn_native::d3d12
|
||||
|
|
|
@ -27,7 +27,7 @@ class ComputeDispatchTests : public DawnTest {
|
|||
|
||||
// Write workgroup number into the output buffer if we saw the biggest dispatch
|
||||
// To make sure the dispatch was not called, write maximum u32 value for 0 dispatches
|
||||
wgpu::ShaderModule moduleForDispatch = utils::CreateShaderModule(device, R"(
|
||||
wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
|
||||
[[block]] struct OutputBuf {
|
||||
workGroups : vec3<u32>;
|
||||
};
|
||||
|
@ -47,9 +47,13 @@ class ComputeDispatchTests : public DawnTest {
|
|||
}
|
||||
})");
|
||||
|
||||
// TODO(dawn:839): use moduleForDispatch for indirect dispatch tests when D3D12 supports
|
||||
// [[num_workgroups]] for indirect dispatch.
|
||||
wgpu::ShaderModule moduleForDispatchIndirect = utils::CreateShaderModule(device, R"(
|
||||
wgpu::ComputePipelineDescriptor csDesc;
|
||||
csDesc.compute.module = module;
|
||||
csDesc.compute.entryPoint = "main";
|
||||
pipeline = device.CreateComputePipeline(&csDesc);
|
||||
|
||||
// Test the use of the compute pipelines without using [[num_workgroups]]
|
||||
wgpu::ShaderModule moduleWithoutNumWorkgroups = utils::CreateShaderModule(device, R"(
|
||||
[[block]] struct InputBuf {
|
||||
expectedDispatch : vec3<u32>;
|
||||
};
|
||||
|
@ -73,14 +77,8 @@ class ComputeDispatchTests : public DawnTest {
|
|||
output.workGroups = dispatch;
|
||||
}
|
||||
})");
|
||||
|
||||
wgpu::ComputePipelineDescriptor csDesc;
|
||||
csDesc.compute.module = moduleForDispatch;
|
||||
csDesc.compute.entryPoint = "main";
|
||||
pipelineForDispatch = device.CreateComputePipeline(&csDesc);
|
||||
|
||||
csDesc.compute.module = moduleForDispatchIndirect;
|
||||
pipelineForDispatchIndirect = device.CreateComputePipeline(&csDesc);
|
||||
csDesc.compute.module = moduleWithoutNumWorkgroups;
|
||||
pipelineWithoutNumWorkgroups = device.CreateComputePipeline(&csDesc);
|
||||
}
|
||||
|
||||
void DirectTest(uint32_t x, uint32_t y, uint32_t z) {
|
||||
|
@ -91,8 +89,7 @@ class ComputeDispatchTests : public DawnTest {
|
|||
kSentinelData);
|
||||
|
||||
// Set up bind group and issue dispatch
|
||||
wgpu::BindGroup bindGroup =
|
||||
utils::MakeBindGroup(device, pipelineForDispatch.GetBindGroupLayout(0),
|
||||
wgpu::BindGroup bindGroup = utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
|
||||
{
|
||||
{0, dst, 0, 3 * sizeof(uint32_t)},
|
||||
});
|
||||
|
@ -101,7 +98,7 @@ class ComputeDispatchTests : public DawnTest {
|
|||
{
|
||||
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
||||
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
||||
pass.SetPipeline(pipelineForDispatch);
|
||||
pass.SetPipeline(pipeline);
|
||||
pass.SetBindGroup(0, bindGroup);
|
||||
pass.Dispatch(x, y, z);
|
||||
pass.EndPass();
|
||||
|
@ -118,7 +115,9 @@ class ComputeDispatchTests : public DawnTest {
|
|||
EXPECT_BUFFER_U32_RANGE_EQ(&expected[0], dst, 0, 3);
|
||||
}
|
||||
|
||||
void IndirectTest(std::vector<uint32_t> indirectBufferData, uint64_t indirectOffset) {
|
||||
void IndirectTest(std::vector<uint32_t> indirectBufferData,
|
||||
uint64_t indirectOffset,
|
||||
bool useNumWorkgroups = true) {
|
||||
// Set up dst storage buffer to contain dispatch x, y, z
|
||||
wgpu::Buffer dst = utils::CreateBufferFromData<uint32_t>(
|
||||
device,
|
||||
|
@ -131,23 +130,34 @@ class ComputeDispatchTests : public DawnTest {
|
|||
|
||||
uint32_t indirectStart = indirectOffset / sizeof(uint32_t);
|
||||
|
||||
// Set up bind group and issue dispatch
|
||||
wgpu::BindGroup bindGroup;
|
||||
wgpu::ComputePipeline computePipelineForTest;
|
||||
|
||||
if (useNumWorkgroups) {
|
||||
computePipelineForTest = pipeline;
|
||||
bindGroup = utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
|
||||
{
|
||||
{0, dst, 0, 3 * sizeof(uint32_t)},
|
||||
});
|
||||
} else {
|
||||
computePipelineForTest = pipelineWithoutNumWorkgroups;
|
||||
wgpu::Buffer expectedBuffer =
|
||||
utils::CreateBufferFromData(device, &indirectBufferData[indirectStart],
|
||||
3 * sizeof(uint32_t), wgpu::BufferUsage::Uniform);
|
||||
|
||||
// Set up bind group and issue dispatch
|
||||
wgpu::BindGroup bindGroup =
|
||||
utils::MakeBindGroup(device, pipelineForDispatchIndirect.GetBindGroupLayout(0),
|
||||
bindGroup =
|
||||
utils::MakeBindGroup(device, pipelineWithoutNumWorkgroups.GetBindGroupLayout(0),
|
||||
{
|
||||
{0, expectedBuffer, 0, 3 * sizeof(uint32_t)},
|
||||
{1, dst, 0, 3 * sizeof(uint32_t)},
|
||||
});
|
||||
}
|
||||
|
||||
wgpu::CommandBuffer commands;
|
||||
{
|
||||
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
||||
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
||||
pass.SetPipeline(pipelineForDispatchIndirect);
|
||||
pass.SetPipeline(computePipelineForTest);
|
||||
pass.SetBindGroup(0, bindGroup);
|
||||
pass.DispatchIndirect(indirectBuffer, indirectOffset);
|
||||
pass.EndPass();
|
||||
|
@ -178,8 +188,8 @@ class ComputeDispatchTests : public DawnTest {
|
|||
}
|
||||
|
||||
private:
|
||||
wgpu::ComputePipeline pipelineForDispatch;
|
||||
wgpu::ComputePipeline pipelineForDispatchIndirect;
|
||||
wgpu::ComputePipeline pipeline;
|
||||
wgpu::ComputePipeline pipelineWithoutNumWorkgroups;
|
||||
};
|
||||
|
||||
// Test basic direct
|
||||
|
@ -207,6 +217,11 @@ TEST_P(ComputeDispatchTests, IndirectBasic) {
|
|||
IndirectTest({2, 3, 4}, 0);
|
||||
}
|
||||
|
||||
// Test basic indirect without using [[num_workgroups]]
|
||||
TEST_P(ComputeDispatchTests, IndirectBasicWithoutNumWorkgroups) {
|
||||
IndirectTest({2, 3, 4}, 0, false);
|
||||
}
|
||||
|
||||
// Test no-op indirect
|
||||
TEST_P(ComputeDispatchTests, IndirectNoop) {
|
||||
// All dimensions are 0s
|
||||
|
@ -227,6 +242,11 @@ TEST_P(ComputeDispatchTests, IndirectOffset) {
|
|||
IndirectTest({0, 0, 0, 2, 3, 4}, 3 * sizeof(uint32_t));
|
||||
}
|
||||
|
||||
// Test indirect with buffer offset without using [[num_workgroups]]
|
||||
TEST_P(ComputeDispatchTests, IndirectOffsetWithoutNumWorkgroups) {
|
||||
IndirectTest({0, 0, 0, 2, 3, 4}, 3 * sizeof(uint32_t), false);
|
||||
}
|
||||
|
||||
// Test indirect dispatches at max limit.
|
||||
TEST_P(ComputeDispatchTests, MaxWorkgroups) {
|
||||
// TODO(crbug.com/dawn/1165): Fails with WARP
|
||||
|
@ -244,6 +264,9 @@ TEST_P(ComputeDispatchTests, MaxWorkgroups) {
|
|||
TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsNoop) {
|
||||
DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation"));
|
||||
|
||||
// TODO(crbug.com/dawn/839): Investigate why this test fails with WARP.
|
||||
DAWN_SUPPRESS_TEST_IF(IsWARP());
|
||||
|
||||
uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
|
||||
|
||||
// All dimensions are above the max
|
||||
|
@ -266,6 +289,9 @@ TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsNoop) {
|
|||
TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsWithOffsetNoop) {
|
||||
DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation"));
|
||||
|
||||
// TODO(crbug.com/dawn/839): Investigate why this test fails with WARP.
|
||||
DAWN_SUPPRESS_TEST_IF(IsWARP());
|
||||
|
||||
uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
|
||||
|
||||
IndirectTest({1, 2, 3, max + 1, 4, 5}, 1 * sizeof(uint32_t));
|
||||
|
|
Loading…
Reference in New Issue