D3D12: Support [[num_workgroups]] for DispatchIndirect
This patch supports [[num_workgroups]] on D3D12 for DispatchIndirect by appending the values of [[num_workgroups]] at the end of the scratch buffer for indirect dispatch validation and setting them as the root constants in the command signature. With this patch, for every DispatchIndirect call: - On D3D12: 1. Validation enabled, [[num_workgroups]] is used The dispatch indirect buffer needs to be validated, duplicated and be written into a scratch buffer (size: 6 * uint32_t). 2. Validation enabled, [[num_workgroups]] isn't used The dispatch indirect buffer needs to be validated and be written into a scratch buffer (size: 3 * uint32_t). 3. Validation disabled, [[num_workgroups]] is used The dispatch indirect buffer needs to be duplicated and be written into a scratch buffer (size: 6 * uint32_t). 4. Validation disabled, [[num_workgroups]] isn't used Neither transformations or scratch buffers are needed for the dispatch call. - On the other backends: 1. Validation enabled, The dispatch indirect buffer needs to be validated and be written into a scratch buffer (size: 3 * uint32_t). 2. Validation disabled, Neither transformations or scratch buffers are needed for the dispatch call. BUG=dawn:839 TEST=dawn_end2end_tests Change-Id: I4105f1b2e3c12f6df6e487ed535a627fbb342344 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/68843 Reviewed-by: Austin Eng <enga@chromium.org> Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
This commit is contained in:
parent
8ce15b3ce9
commit
829d165d7c
|
@ -152,9 +152,11 @@ namespace dawn_native {
|
||||||
mUsage |= kInternalStorageBuffer;
|
mUsage |= kInternalStorageBuffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
// We also add internal storage usage for Indirect buffers if validation is enabled, since
|
// We also add internal storage usage for Indirect buffers for some transformations before
|
||||||
// validation involves binding them as storage buffers for use in a compute pass.
|
// DispatchIndirect calls on the backend (e.g. validations, support of [[num_workgroups]] on
|
||||||
if ((mUsage & wgpu::BufferUsage::Indirect) && device->IsValidationEnabled()) {
|
// D3D12), since these transformations involve binding them as storage buffers for use in a
|
||||||
|
// compute pass.
|
||||||
|
if (mUsage & wgpu::BufferUsage::Indirect) {
|
||||||
mUsage |= kInternalStorageBuffer;
|
mUsage |= kInternalStorageBuffer;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -42,11 +42,14 @@ namespace dawn_native {
|
||||||
|
|
||||||
// TODO(https://crbug.com/dawn/1108): Propagate validation feedback from this
|
// TODO(https://crbug.com/dawn/1108): Propagate validation feedback from this
|
||||||
// shader in various failure modes.
|
// shader in various failure modes.
|
||||||
|
// Type 'bool' cannot be used in storage class 'uniform' as it is non-host-shareable.
|
||||||
Ref<ShaderModuleBase> shaderModule;
|
Ref<ShaderModuleBase> shaderModule;
|
||||||
DAWN_TRY_ASSIGN(shaderModule, utils::CreateShaderModule(device, R"(
|
DAWN_TRY_ASSIGN(shaderModule, utils::CreateShaderModule(device, R"(
|
||||||
[[block]] struct UniformParams {
|
[[block]] struct UniformParams {
|
||||||
maxComputeWorkgroupsPerDimension: u32;
|
maxComputeWorkgroupsPerDimension: u32;
|
||||||
clientOffsetInU32: u32;
|
clientOffsetInU32: u32;
|
||||||
|
enableValidation: u32;
|
||||||
|
duplicateNumWorkgroups: u32;
|
||||||
};
|
};
|
||||||
|
|
||||||
[[block]] struct IndirectParams {
|
[[block]] struct IndirectParams {
|
||||||
|
@ -54,7 +57,7 @@ namespace dawn_native {
|
||||||
};
|
};
|
||||||
|
|
||||||
[[block]] struct ValidatedParams {
|
[[block]] struct ValidatedParams {
|
||||||
data: array<u32, 3>;
|
data: array<u32>;
|
||||||
};
|
};
|
||||||
|
|
||||||
[[group(0), binding(0)]] var<uniform> uniformParams: UniformParams;
|
[[group(0), binding(0)]] var<uniform> uniformParams: UniformParams;
|
||||||
|
@ -65,10 +68,15 @@ namespace dawn_native {
|
||||||
fn main() {
|
fn main() {
|
||||||
for (var i = 0u; i < 3u; i = i + 1u) {
|
for (var i = 0u; i < 3u; i = i + 1u) {
|
||||||
var numWorkgroups = clientParams.data[uniformParams.clientOffsetInU32 + i];
|
var numWorkgroups = clientParams.data[uniformParams.clientOffsetInU32 + i];
|
||||||
if (numWorkgroups > uniformParams.maxComputeWorkgroupsPerDimension) {
|
if (uniformParams.enableValidation > 0u &&
|
||||||
|
numWorkgroups > uniformParams.maxComputeWorkgroupsPerDimension) {
|
||||||
numWorkgroups = 0u;
|
numWorkgroups = 0u;
|
||||||
}
|
}
|
||||||
validatedParams.data[i] = numWorkgroups;
|
validatedParams.data[i] = numWorkgroups;
|
||||||
|
|
||||||
|
if (uniformParams.duplicateNumWorkgroups > 0u) {
|
||||||
|
validatedParams.data[i + 3u] = numWorkgroups;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
)"));
|
)"));
|
||||||
|
@ -191,9 +199,21 @@ namespace dawn_native {
|
||||||
}
|
}
|
||||||
|
|
||||||
ResultOrError<std::pair<Ref<BufferBase>, uint64_t>>
|
ResultOrError<std::pair<Ref<BufferBase>, uint64_t>>
|
||||||
ComputePassEncoder::ValidateIndirectDispatch(BufferBase* indirectBuffer,
|
ComputePassEncoder::TransformIndirectDispatchBuffer(Ref<BufferBase> indirectBuffer,
|
||||||
uint64_t indirectOffset) {
|
uint64_t indirectOffset) {
|
||||||
DeviceBase* device = GetDevice();
|
DeviceBase* device = GetDevice();
|
||||||
|
|
||||||
|
const bool shouldDuplicateNumWorkgroups =
|
||||||
|
device->ShouldDuplicateNumWorkgroupsForDispatchIndirect(
|
||||||
|
mCommandBufferState.GetComputePipeline());
|
||||||
|
if (!IsValidationEnabled() && !shouldDuplicateNumWorkgroups) {
|
||||||
|
return std::make_pair(indirectBuffer, indirectOffset);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save the previous command buffer state so it can be restored after the
|
||||||
|
// validation inserts additional commands.
|
||||||
|
CommandBufferStateTracker previousState = mCommandBufferState;
|
||||||
|
|
||||||
auto* const store = device->GetInternalPipelineStore();
|
auto* const store = device->GetInternalPipelineStore();
|
||||||
|
|
||||||
Ref<ComputePipelineBase> validationPipeline;
|
Ref<ComputePipelineBase> validationPipeline;
|
||||||
|
@ -215,9 +235,14 @@ namespace dawn_native {
|
||||||
const uint64_t clientIndirectBindingSize =
|
const uint64_t clientIndirectBindingSize =
|
||||||
kDispatchIndirectSize + clientOffsetFromAlignedBoundary;
|
kDispatchIndirectSize + clientOffsetFromAlignedBoundary;
|
||||||
|
|
||||||
|
// Neither 'enableValidation' nor 'duplicateNumWorkgroups' can be declared as 'bool' as
|
||||||
|
// currently in WGSL type 'bool' cannot be used in storage class 'uniform' as 'it is
|
||||||
|
// non-host-shareable'.
|
||||||
struct UniformParams {
|
struct UniformParams {
|
||||||
uint32_t maxComputeWorkgroupsPerDimension;
|
uint32_t maxComputeWorkgroupsPerDimension;
|
||||||
uint32_t clientOffsetInU32;
|
uint32_t clientOffsetInU32;
|
||||||
|
uint32_t enableValidation;
|
||||||
|
uint32_t duplicateNumWorkgroups;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Create a uniform buffer to hold parameters for the shader.
|
// Create a uniform buffer to hold parameters for the shader.
|
||||||
|
@ -227,6 +252,8 @@ namespace dawn_native {
|
||||||
params.maxComputeWorkgroupsPerDimension =
|
params.maxComputeWorkgroupsPerDimension =
|
||||||
device->GetLimits().v1.maxComputeWorkgroupsPerDimension;
|
device->GetLimits().v1.maxComputeWorkgroupsPerDimension;
|
||||||
params.clientOffsetInU32 = clientOffsetFromAlignedBoundary / sizeof(uint32_t);
|
params.clientOffsetInU32 = clientOffsetFromAlignedBoundary / sizeof(uint32_t);
|
||||||
|
params.enableValidation = static_cast<uint32_t>(IsValidationEnabled());
|
||||||
|
params.duplicateNumWorkgroups = static_cast<uint32_t>(shouldDuplicateNumWorkgroups);
|
||||||
|
|
||||||
DAWN_TRY_ASSIGN(uniformBuffer, utils::CreateBufferFromData(
|
DAWN_TRY_ASSIGN(uniformBuffer, utils::CreateBufferFromData(
|
||||||
device, wgpu::BufferUsage::Uniform, {params}));
|
device, wgpu::BufferUsage::Uniform, {params}));
|
||||||
|
@ -234,25 +261,30 @@ namespace dawn_native {
|
||||||
|
|
||||||
// Reserve space in the scratch buffer to hold the validated indirect params.
|
// Reserve space in the scratch buffer to hold the validated indirect params.
|
||||||
ScratchBuffer& scratchBuffer = store->scratchIndirectStorage;
|
ScratchBuffer& scratchBuffer = store->scratchIndirectStorage;
|
||||||
DAWN_TRY(scratchBuffer.EnsureCapacity(kDispatchIndirectSize));
|
const uint64_t scratchBufferSize =
|
||||||
|
shouldDuplicateNumWorkgroups ? 2 * kDispatchIndirectSize : kDispatchIndirectSize;
|
||||||
|
DAWN_TRY(scratchBuffer.EnsureCapacity(scratchBufferSize));
|
||||||
Ref<BufferBase> validatedIndirectBuffer = scratchBuffer.GetBuffer();
|
Ref<BufferBase> validatedIndirectBuffer = scratchBuffer.GetBuffer();
|
||||||
|
|
||||||
Ref<BindGroupBase> validationBindGroup;
|
Ref<BindGroupBase> validationBindGroup;
|
||||||
DAWN_TRY_ASSIGN(
|
ASSERT(indirectBuffer->GetUsage() & kInternalStorageBuffer);
|
||||||
validationBindGroup,
|
DAWN_TRY_ASSIGN(validationBindGroup,
|
||||||
utils::MakeBindGroup(
|
utils::MakeBindGroup(device, layout,
|
||||||
device, layout,
|
{
|
||||||
{
|
{0, uniformBuffer},
|
||||||
{0, uniformBuffer},
|
{1, indirectBuffer, clientIndirectBindingOffset,
|
||||||
{1, indirectBuffer, clientIndirectBindingOffset, clientIndirectBindingSize},
|
clientIndirectBindingSize},
|
||||||
{2, validatedIndirectBuffer, 0, kDispatchIndirectSize},
|
{2, validatedIndirectBuffer, 0, scratchBufferSize},
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Issue commands to validate the indirect buffer.
|
// Issue commands to validate the indirect buffer.
|
||||||
APISetPipeline(validationPipeline.Get());
|
APISetPipeline(validationPipeline.Get());
|
||||||
APISetBindGroup(0, validationBindGroup.Get());
|
APISetBindGroup(0, validationBindGroup.Get());
|
||||||
APIDispatch(1);
|
APIDispatch(1);
|
||||||
|
|
||||||
|
// Restore the state.
|
||||||
|
RestoreCommandBufferState(std::move(previousState));
|
||||||
|
|
||||||
// Return the new indirect buffer and indirect buffer offset.
|
// Return the new indirect buffer and indirect buffer offset.
|
||||||
return std::make_pair(std::move(validatedIndirectBuffer), uint64_t(0));
|
return std::make_pair(std::move(validatedIndirectBuffer), uint64_t(0));
|
||||||
}
|
}
|
||||||
|
@ -287,27 +319,28 @@ namespace dawn_native {
|
||||||
// the backend.
|
// the backend.
|
||||||
|
|
||||||
Ref<BufferBase> indirectBufferRef = indirectBuffer;
|
Ref<BufferBase> indirectBufferRef = indirectBuffer;
|
||||||
if (IsValidationEnabled()) {
|
|
||||||
// Save the previous command buffer state so it can be restored after the
|
|
||||||
// validation inserts additional commands.
|
|
||||||
CommandBufferStateTracker previousState = mCommandBufferState;
|
|
||||||
|
|
||||||
// Validate each indirect dispatch with a single dispatch to copy the indirect
|
// Get applied indirect buffer with necessary changes on the original indirect
|
||||||
// buffer params into a scratch buffer if they're valid, and otherwise zero them
|
// buffer. For example,
|
||||||
// out. We could consider moving the validation earlier in the pass after the
|
// - Validate each indirect dispatch with a single dispatch to copy the indirect
|
||||||
// last point the indirect buffer was used with writable usage, as well as batch
|
// buffer params into a scratch buffer if they're valid, and otherwise zero them
|
||||||
// validation for multiple dispatches into one, but inserting commands at
|
// out.
|
||||||
// arbitrary points in the past is not possible right now.
|
// - Duplicate all the indirect dispatch parameters to support [[num_workgroups]] on
|
||||||
DAWN_TRY_ASSIGN(
|
// D3D12.
|
||||||
std::tie(indirectBufferRef, indirectOffset),
|
// - Directly return the original indirect dispatch buffer if we don't need any
|
||||||
ValidateIndirectDispatch(indirectBufferRef.Get(), indirectOffset));
|
// transformations on it.
|
||||||
|
// We could consider moving the validation earlier in the pass after the last
|
||||||
// Restore the state.
|
// last point the indirect buffer was used with writable usage, as well as batch
|
||||||
RestoreCommandBufferState(std::move(previousState));
|
// validation for multiple dispatches into one, but inserting commands at
|
||||||
|
// arbitrary points in the past is not possible right now.
|
||||||
|
DAWN_TRY_ASSIGN(std::tie(indirectBufferRef, indirectOffset),
|
||||||
|
TransformIndirectDispatchBuffer(indirectBufferRef, indirectOffset));
|
||||||
|
|
||||||
|
// If we have created a new scratch dispatch indirect buffer in
|
||||||
|
// TransformIndirectDispatchBuffer(), we need to track it in mUsageTracker.
|
||||||
|
if (indirectBufferRef.Get() != indirectBuffer) {
|
||||||
// |indirectBufferRef| was replaced with a scratch buffer. Add it to the
|
// |indirectBufferRef| was replaced with a scratch buffer. Add it to the
|
||||||
// synchronization scope.
|
// synchronization scope.
|
||||||
ASSERT(indirectBufferRef.Get() != indirectBuffer);
|
|
||||||
scope.BufferUsedAs(indirectBufferRef.Get(), wgpu::BufferUsage::Indirect);
|
scope.BufferUsedAs(indirectBufferRef.Get(), wgpu::BufferUsage::Indirect);
|
||||||
mUsageTracker.AddReferencedBuffer(indirectBufferRef.Get());
|
mUsageTracker.AddReferencedBuffer(indirectBufferRef.Get());
|
||||||
}
|
}
|
||||||
|
|
|
@ -64,8 +64,8 @@ namespace dawn_native {
|
||||||
private:
|
private:
|
||||||
void DestroyImpl() override;
|
void DestroyImpl() override;
|
||||||
|
|
||||||
ResultOrError<std::pair<Ref<BufferBase>, uint64_t>> ValidateIndirectDispatch(
|
ResultOrError<std::pair<Ref<BufferBase>, uint64_t>> TransformIndirectDispatchBuffer(
|
||||||
BufferBase* indirectBuffer,
|
Ref<BufferBase> indirectBuffer,
|
||||||
uint64_t indirectOffset);
|
uint64_t indirectOffset);
|
||||||
|
|
||||||
void RestoreCommandBufferState(CommandBufferStateTracker state);
|
void RestoreCommandBufferState(CommandBufferStateTracker state);
|
||||||
|
|
|
@ -1670,4 +1670,9 @@ namespace dawn_native {
|
||||||
void DeviceBase::SetLabelImpl() {
|
void DeviceBase::SetLabelImpl() {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool DeviceBase::ShouldDuplicateNumWorkgroupsForDispatchIndirect(
|
||||||
|
ComputePipelineBase* computePipeline) const {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace dawn_native
|
} // namespace dawn_native
|
||||||
|
|
|
@ -336,11 +336,16 @@ namespace dawn_native {
|
||||||
|
|
||||||
MaybeError Tick();
|
MaybeError Tick();
|
||||||
|
|
||||||
|
// TODO(crbug.com/dawn/839): Organize the below backend-specific parameters into the struct
|
||||||
|
// BackendMetadata that we can query from the device.
|
||||||
virtual uint32_t GetOptimalBytesPerRowAlignment() const = 0;
|
virtual uint32_t GetOptimalBytesPerRowAlignment() const = 0;
|
||||||
virtual uint64_t GetOptimalBufferToTextureCopyOffsetAlignment() const = 0;
|
virtual uint64_t GetOptimalBufferToTextureCopyOffsetAlignment() const = 0;
|
||||||
|
|
||||||
virtual float GetTimestampPeriodInNS() const = 0;
|
virtual float GetTimestampPeriodInNS() const = 0;
|
||||||
|
|
||||||
|
virtual bool ShouldDuplicateNumWorkgroupsForDispatchIndirect(
|
||||||
|
ComputePipelineBase* computePipeline) const;
|
||||||
|
|
||||||
const CombinedLimits& GetLimits() const;
|
const CombinedLimits& GetLimits() const;
|
||||||
|
|
||||||
AsyncTaskManager* GetAsyncTaskManager() const;
|
AsyncTaskManager* GetAsyncTaskManager() const;
|
||||||
|
|
|
@ -1063,22 +1063,15 @@ namespace dawn_native { namespace d3d12 {
|
||||||
case Command::DispatchIndirect: {
|
case Command::DispatchIndirect: {
|
||||||
DispatchIndirectCmd* dispatch = mCommands.NextCommand<DispatchIndirectCmd>();
|
DispatchIndirectCmd* dispatch = mCommands.NextCommand<DispatchIndirectCmd>();
|
||||||
|
|
||||||
// TODO(dawn:839): support [[num_workgroups]] for DispatchIndirect calls
|
|
||||||
DAWN_INVALID_IF(lastPipeline->UsesNumWorkgroups(),
|
|
||||||
"Using %s with [[num_workgroups]] in a DispatchIndirect call "
|
|
||||||
"is not implemented.",
|
|
||||||
lastPipeline);
|
|
||||||
|
|
||||||
Buffer* buffer = ToBackend(dispatch->indirectBuffer.Get());
|
|
||||||
|
|
||||||
TransitionAndClearForSyncScope(commandContext,
|
TransitionAndClearForSyncScope(commandContext,
|
||||||
resourceUsages.dispatchUsages[currentDispatch]);
|
resourceUsages.dispatchUsages[currentDispatch]);
|
||||||
DAWN_TRY(bindingTracker->Apply(commandContext));
|
DAWN_TRY(bindingTracker->Apply(commandContext));
|
||||||
|
|
||||||
ComPtr<ID3D12CommandSignature> signature =
|
ComPtr<ID3D12CommandSignature> signature =
|
||||||
ToBackend(GetDevice())->GetDispatchIndirectSignature();
|
lastPipeline->GetDispatchIndirectCommandSignature();
|
||||||
commandList->ExecuteIndirect(signature.Get(), 1, buffer->GetD3D12Resource(),
|
commandList->ExecuteIndirect(
|
||||||
dispatch->indirectOffset, nullptr, 0);
|
signature.Get(), 1, ToBackend(dispatch->indirectBuffer)->GetD3D12Resource(),
|
||||||
|
dispatch->indirectOffset, nullptr, 0);
|
||||||
currentDispatch++;
|
currentDispatch++;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
|
@ -90,4 +90,11 @@ namespace dawn_native { namespace d3d12 {
|
||||||
return GetStage(SingleShaderStage::Compute).metadata->usesNumWorkgroups;
|
return GetStage(SingleShaderStage::Compute).metadata->usesNumWorkgroups;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ComPtr<ID3D12CommandSignature> ComputePipeline::GetDispatchIndirectCommandSignature() {
|
||||||
|
if (UsesNumWorkgroups()) {
|
||||||
|
return ToBackend(GetLayout())->GetDispatchIndirectCommandSignatureWithNumWorkgroups();
|
||||||
|
}
|
||||||
|
return ToBackend(GetDevice())->GetDispatchIndirectSignature();
|
||||||
|
}
|
||||||
|
|
||||||
}} // namespace dawn_native::d3d12
|
}} // namespace dawn_native::d3d12
|
||||||
|
|
|
@ -42,6 +42,8 @@ namespace dawn_native { namespace d3d12 {
|
||||||
|
|
||||||
bool UsesNumWorkgroups() const;
|
bool UsesNumWorkgroups() const;
|
||||||
|
|
||||||
|
ComPtr<ID3D12CommandSignature> GetDispatchIndirectCommandSignature();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
~ComputePipeline() override;
|
~ComputePipeline() override;
|
||||||
|
|
||||||
|
|
|
@ -675,4 +675,9 @@ namespace dawn_native { namespace d3d12 {
|
||||||
return mTimestampPeriod;
|
return mTimestampPeriod;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool Device::ShouldDuplicateNumWorkgroupsForDispatchIndirect(
|
||||||
|
ComputePipelineBase* computePipeline) const {
|
||||||
|
return ToBackend(computePipeline)->UsesNumWorkgroups();
|
||||||
|
}
|
||||||
|
|
||||||
}} // namespace dawn_native::d3d12
|
}} // namespace dawn_native::d3d12
|
||||||
|
|
|
@ -139,6 +139,9 @@ namespace dawn_native { namespace d3d12 {
|
||||||
|
|
||||||
float GetTimestampPeriodInNS() const override;
|
float GetTimestampPeriodInNS() const override;
|
||||||
|
|
||||||
|
bool ShouldDuplicateNumWorkgroupsForDispatchIndirect(
|
||||||
|
ComputePipelineBase* computePipeline) const override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
using DeviceBase::DeviceBase;
|
using DeviceBase::DeviceBase;
|
||||||
|
|
||||||
|
|
|
@ -184,7 +184,7 @@ namespace dawn_native { namespace d3d12 {
|
||||||
numWorkgroupsConstants.Constants.Num32BitValues = 3;
|
numWorkgroupsConstants.Constants.Num32BitValues = 3;
|
||||||
numWorkgroupsConstants.Constants.RegisterSpace = GetNumWorkgroupsRegisterSpace();
|
numWorkgroupsConstants.Constants.RegisterSpace = GetNumWorkgroupsRegisterSpace();
|
||||||
numWorkgroupsConstants.Constants.ShaderRegister = GetNumWorkgroupsShaderRegister();
|
numWorkgroupsConstants.Constants.ShaderRegister = GetNumWorkgroupsShaderRegister();
|
||||||
mNumWorkgroupsParamterIndex = rootParameters.size();
|
mNumWorkgroupsParameterIndex = rootParameters.size();
|
||||||
// NOTE: We should consider moving this entry to earlier in the root signature since
|
// NOTE: We should consider moving this entry to earlier in the root signature since
|
||||||
// dispatch sizes would need to be updated often
|
// dispatch sizes would need to be updated often
|
||||||
rootParameters.emplace_back(numWorkgroupsConstants);
|
rootParameters.emplace_back(numWorkgroupsConstants);
|
||||||
|
@ -265,6 +265,38 @@ namespace dawn_native { namespace d3d12 {
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t PipelineLayout::GetNumWorkgroupsParameterIndex() const {
|
uint32_t PipelineLayout::GetNumWorkgroupsParameterIndex() const {
|
||||||
return mNumWorkgroupsParamterIndex;
|
return mNumWorkgroupsParameterIndex;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ID3D12CommandSignature* PipelineLayout::GetDispatchIndirectCommandSignatureWithNumWorkgroups() {
|
||||||
|
// mDispatchIndirectCommandSignatureWithNumWorkgroups won't be created until it is needed.
|
||||||
|
if (mDispatchIndirectCommandSignatureWithNumWorkgroups.Get() != nullptr) {
|
||||||
|
return mDispatchIndirectCommandSignatureWithNumWorkgroups.Get();
|
||||||
|
}
|
||||||
|
|
||||||
|
D3D12_INDIRECT_ARGUMENT_DESC argumentDescs[2] = {};
|
||||||
|
argumentDescs[0].Type = D3D12_INDIRECT_ARGUMENT_TYPE_CONSTANT;
|
||||||
|
argumentDescs[0].Constant.RootParameterIndex = GetNumWorkgroupsParameterIndex();
|
||||||
|
argumentDescs[0].Constant.Num32BitValuesToSet = 3;
|
||||||
|
argumentDescs[0].Constant.DestOffsetIn32BitValues = 0;
|
||||||
|
|
||||||
|
// A command signature must contain exactly 1 Draw / Dispatch / DispatchMesh / DispatchRays
|
||||||
|
// command. That command must come last.
|
||||||
|
argumentDescs[1].Type = D3D12_INDIRECT_ARGUMENT_TYPE_DISPATCH;
|
||||||
|
|
||||||
|
D3D12_COMMAND_SIGNATURE_DESC programDesc = {};
|
||||||
|
programDesc.ByteStride = 6 * sizeof(uint32_t);
|
||||||
|
programDesc.NumArgumentDescs = 2;
|
||||||
|
programDesc.pArgumentDescs = argumentDescs;
|
||||||
|
|
||||||
|
// The root signature must be specified if and only if the command signature changes one of
|
||||||
|
// the root arguments.
|
||||||
|
ToBackend(GetDevice())
|
||||||
|
->GetD3D12Device()
|
||||||
|
->CreateCommandSignature(
|
||||||
|
&programDesc, GetRootSignature(),
|
||||||
|
IID_PPV_ARGS(&mDispatchIndirectCommandSignatureWithNumWorkgroups));
|
||||||
|
return mDispatchIndirectCommandSignatureWithNumWorkgroups.Get();
|
||||||
|
}
|
||||||
|
|
||||||
}} // namespace dawn_native::d3d12
|
}} // namespace dawn_native::d3d12
|
||||||
|
|
|
@ -55,6 +55,8 @@ namespace dawn_native { namespace d3d12 {
|
||||||
|
|
||||||
ID3D12RootSignature* GetRootSignature() const;
|
ID3D12RootSignature* GetRootSignature() const;
|
||||||
|
|
||||||
|
ID3D12CommandSignature* GetDispatchIndirectCommandSignatureWithNumWorkgroups();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
~PipelineLayout() override = default;
|
~PipelineLayout() override = default;
|
||||||
using PipelineLayoutBase::PipelineLayoutBase;
|
using PipelineLayoutBase::PipelineLayoutBase;
|
||||||
|
@ -66,8 +68,9 @@ namespace dawn_native { namespace d3d12 {
|
||||||
kMaxBindGroups>
|
kMaxBindGroups>
|
||||||
mDynamicRootParameterIndices;
|
mDynamicRootParameterIndices;
|
||||||
uint32_t mFirstIndexOffsetParameterIndex;
|
uint32_t mFirstIndexOffsetParameterIndex;
|
||||||
uint32_t mNumWorkgroupsParamterIndex;
|
uint32_t mNumWorkgroupsParameterIndex;
|
||||||
ComPtr<ID3D12RootSignature> mRootSignature;
|
ComPtr<ID3D12RootSignature> mRootSignature;
|
||||||
|
ComPtr<ID3D12CommandSignature> mDispatchIndirectCommandSignatureWithNumWorkgroups;
|
||||||
};
|
};
|
||||||
|
|
||||||
}} // namespace dawn_native::d3d12
|
}} // namespace dawn_native::d3d12
|
||||||
|
|
|
@ -27,7 +27,7 @@ class ComputeDispatchTests : public DawnTest {
|
||||||
|
|
||||||
// Write workgroup number into the output buffer if we saw the biggest dispatch
|
// Write workgroup number into the output buffer if we saw the biggest dispatch
|
||||||
// To make sure the dispatch was not called, write maximum u32 value for 0 dispatches
|
// To make sure the dispatch was not called, write maximum u32 value for 0 dispatches
|
||||||
wgpu::ShaderModule moduleForDispatch = utils::CreateShaderModule(device, R"(
|
wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
|
||||||
[[block]] struct OutputBuf {
|
[[block]] struct OutputBuf {
|
||||||
workGroups : vec3<u32>;
|
workGroups : vec3<u32>;
|
||||||
};
|
};
|
||||||
|
@ -47,9 +47,13 @@ class ComputeDispatchTests : public DawnTest {
|
||||||
}
|
}
|
||||||
})");
|
})");
|
||||||
|
|
||||||
// TODO(dawn:839): use moduleForDispatch for indirect dispatch tests when D3D12 supports
|
wgpu::ComputePipelineDescriptor csDesc;
|
||||||
// [[num_workgroups]] for indirect dispatch.
|
csDesc.compute.module = module;
|
||||||
wgpu::ShaderModule moduleForDispatchIndirect = utils::CreateShaderModule(device, R"(
|
csDesc.compute.entryPoint = "main";
|
||||||
|
pipeline = device.CreateComputePipeline(&csDesc);
|
||||||
|
|
||||||
|
// Test the use of the compute pipelines without using [[num_workgroups]]
|
||||||
|
wgpu::ShaderModule moduleWithoutNumWorkgroups = utils::CreateShaderModule(device, R"(
|
||||||
[[block]] struct InputBuf {
|
[[block]] struct InputBuf {
|
||||||
expectedDispatch : vec3<u32>;
|
expectedDispatch : vec3<u32>;
|
||||||
};
|
};
|
||||||
|
@ -73,14 +77,8 @@ class ComputeDispatchTests : public DawnTest {
|
||||||
output.workGroups = dispatch;
|
output.workGroups = dispatch;
|
||||||
}
|
}
|
||||||
})");
|
})");
|
||||||
|
csDesc.compute.module = moduleWithoutNumWorkgroups;
|
||||||
wgpu::ComputePipelineDescriptor csDesc;
|
pipelineWithoutNumWorkgroups = device.CreateComputePipeline(&csDesc);
|
||||||
csDesc.compute.module = moduleForDispatch;
|
|
||||||
csDesc.compute.entryPoint = "main";
|
|
||||||
pipelineForDispatch = device.CreateComputePipeline(&csDesc);
|
|
||||||
|
|
||||||
csDesc.compute.module = moduleForDispatchIndirect;
|
|
||||||
pipelineForDispatchIndirect = device.CreateComputePipeline(&csDesc);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void DirectTest(uint32_t x, uint32_t y, uint32_t z) {
|
void DirectTest(uint32_t x, uint32_t y, uint32_t z) {
|
||||||
|
@ -91,17 +89,16 @@ class ComputeDispatchTests : public DawnTest {
|
||||||
kSentinelData);
|
kSentinelData);
|
||||||
|
|
||||||
// Set up bind group and issue dispatch
|
// Set up bind group and issue dispatch
|
||||||
wgpu::BindGroup bindGroup =
|
wgpu::BindGroup bindGroup = utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
|
||||||
utils::MakeBindGroup(device, pipelineForDispatch.GetBindGroupLayout(0),
|
{
|
||||||
{
|
{0, dst, 0, 3 * sizeof(uint32_t)},
|
||||||
{0, dst, 0, 3 * sizeof(uint32_t)},
|
});
|
||||||
});
|
|
||||||
|
|
||||||
wgpu::CommandBuffer commands;
|
wgpu::CommandBuffer commands;
|
||||||
{
|
{
|
||||||
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
||||||
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
||||||
pass.SetPipeline(pipelineForDispatch);
|
pass.SetPipeline(pipeline);
|
||||||
pass.SetBindGroup(0, bindGroup);
|
pass.SetBindGroup(0, bindGroup);
|
||||||
pass.Dispatch(x, y, z);
|
pass.Dispatch(x, y, z);
|
||||||
pass.EndPass();
|
pass.EndPass();
|
||||||
|
@ -118,7 +115,9 @@ class ComputeDispatchTests : public DawnTest {
|
||||||
EXPECT_BUFFER_U32_RANGE_EQ(&expected[0], dst, 0, 3);
|
EXPECT_BUFFER_U32_RANGE_EQ(&expected[0], dst, 0, 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
void IndirectTest(std::vector<uint32_t> indirectBufferData, uint64_t indirectOffset) {
|
void IndirectTest(std::vector<uint32_t> indirectBufferData,
|
||||||
|
uint64_t indirectOffset,
|
||||||
|
bool useNumWorkgroups = true) {
|
||||||
// Set up dst storage buffer to contain dispatch x, y, z
|
// Set up dst storage buffer to contain dispatch x, y, z
|
||||||
wgpu::Buffer dst = utils::CreateBufferFromData<uint32_t>(
|
wgpu::Buffer dst = utils::CreateBufferFromData<uint32_t>(
|
||||||
device,
|
device,
|
||||||
|
@ -131,23 +130,34 @@ class ComputeDispatchTests : public DawnTest {
|
||||||
|
|
||||||
uint32_t indirectStart = indirectOffset / sizeof(uint32_t);
|
uint32_t indirectStart = indirectOffset / sizeof(uint32_t);
|
||||||
|
|
||||||
wgpu::Buffer expectedBuffer =
|
|
||||||
utils::CreateBufferFromData(device, &indirectBufferData[indirectStart],
|
|
||||||
3 * sizeof(uint32_t), wgpu::BufferUsage::Uniform);
|
|
||||||
|
|
||||||
// Set up bind group and issue dispatch
|
// Set up bind group and issue dispatch
|
||||||
wgpu::BindGroup bindGroup =
|
wgpu::BindGroup bindGroup;
|
||||||
utils::MakeBindGroup(device, pipelineForDispatchIndirect.GetBindGroupLayout(0),
|
wgpu::ComputePipeline computePipelineForTest;
|
||||||
{
|
|
||||||
{0, expectedBuffer, 0, 3 * sizeof(uint32_t)},
|
if (useNumWorkgroups) {
|
||||||
{1, dst, 0, 3 * sizeof(uint32_t)},
|
computePipelineForTest = pipeline;
|
||||||
});
|
bindGroup = utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0),
|
||||||
|
{
|
||||||
|
{0, dst, 0, 3 * sizeof(uint32_t)},
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
computePipelineForTest = pipelineWithoutNumWorkgroups;
|
||||||
|
wgpu::Buffer expectedBuffer =
|
||||||
|
utils::CreateBufferFromData(device, &indirectBufferData[indirectStart],
|
||||||
|
3 * sizeof(uint32_t), wgpu::BufferUsage::Uniform);
|
||||||
|
bindGroup =
|
||||||
|
utils::MakeBindGroup(device, pipelineWithoutNumWorkgroups.GetBindGroupLayout(0),
|
||||||
|
{
|
||||||
|
{0, expectedBuffer, 0, 3 * sizeof(uint32_t)},
|
||||||
|
{1, dst, 0, 3 * sizeof(uint32_t)},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
wgpu::CommandBuffer commands;
|
wgpu::CommandBuffer commands;
|
||||||
{
|
{
|
||||||
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
||||||
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
||||||
pass.SetPipeline(pipelineForDispatchIndirect);
|
pass.SetPipeline(computePipelineForTest);
|
||||||
pass.SetBindGroup(0, bindGroup);
|
pass.SetBindGroup(0, bindGroup);
|
||||||
pass.DispatchIndirect(indirectBuffer, indirectOffset);
|
pass.DispatchIndirect(indirectBuffer, indirectOffset);
|
||||||
pass.EndPass();
|
pass.EndPass();
|
||||||
|
@ -178,8 +188,8 @@ class ComputeDispatchTests : public DawnTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
wgpu::ComputePipeline pipelineForDispatch;
|
wgpu::ComputePipeline pipeline;
|
||||||
wgpu::ComputePipeline pipelineForDispatchIndirect;
|
wgpu::ComputePipeline pipelineWithoutNumWorkgroups;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Test basic direct
|
// Test basic direct
|
||||||
|
@ -207,6 +217,11 @@ TEST_P(ComputeDispatchTests, IndirectBasic) {
|
||||||
IndirectTest({2, 3, 4}, 0);
|
IndirectTest({2, 3, 4}, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test basic indirect without using [[num_workgroups]]
|
||||||
|
TEST_P(ComputeDispatchTests, IndirectBasicWithoutNumWorkgroups) {
|
||||||
|
IndirectTest({2, 3, 4}, 0, false);
|
||||||
|
}
|
||||||
|
|
||||||
// Test no-op indirect
|
// Test no-op indirect
|
||||||
TEST_P(ComputeDispatchTests, IndirectNoop) {
|
TEST_P(ComputeDispatchTests, IndirectNoop) {
|
||||||
// All dimensions are 0s
|
// All dimensions are 0s
|
||||||
|
@ -227,6 +242,11 @@ TEST_P(ComputeDispatchTests, IndirectOffset) {
|
||||||
IndirectTest({0, 0, 0, 2, 3, 4}, 3 * sizeof(uint32_t));
|
IndirectTest({0, 0, 0, 2, 3, 4}, 3 * sizeof(uint32_t));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test indirect with buffer offset without using [[num_workgroups]]
|
||||||
|
TEST_P(ComputeDispatchTests, IndirectOffsetWithoutNumWorkgroups) {
|
||||||
|
IndirectTest({0, 0, 0, 2, 3, 4}, 3 * sizeof(uint32_t), false);
|
||||||
|
}
|
||||||
|
|
||||||
// Test indirect dispatches at max limit.
|
// Test indirect dispatches at max limit.
|
||||||
TEST_P(ComputeDispatchTests, MaxWorkgroups) {
|
TEST_P(ComputeDispatchTests, MaxWorkgroups) {
|
||||||
// TODO(crbug.com/dawn/1165): Fails with WARP
|
// TODO(crbug.com/dawn/1165): Fails with WARP
|
||||||
|
@ -244,6 +264,9 @@ TEST_P(ComputeDispatchTests, MaxWorkgroups) {
|
||||||
TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsNoop) {
|
TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsNoop) {
|
||||||
DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation"));
|
DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation"));
|
||||||
|
|
||||||
|
// TODO(crbug.com/dawn/839): Investigate why this test fails with WARP.
|
||||||
|
DAWN_SUPPRESS_TEST_IF(IsWARP());
|
||||||
|
|
||||||
uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
|
uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
|
||||||
|
|
||||||
// All dimensions are above the max
|
// All dimensions are above the max
|
||||||
|
@ -266,6 +289,9 @@ TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsNoop) {
|
||||||
TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsWithOffsetNoop) {
|
TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsWithOffsetNoop) {
|
||||||
DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation"));
|
DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation"));
|
||||||
|
|
||||||
|
// TODO(crbug.com/dawn/839): Investigate why this test fails with WARP.
|
||||||
|
DAWN_SUPPRESS_TEST_IF(IsWARP());
|
||||||
|
|
||||||
uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
|
uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
|
||||||
|
|
||||||
IndirectTest({1, 2, 3, max + 1, 4, 5}, 1 * sizeof(uint32_t));
|
IndirectTest({1, 2, 3, max + 1, 4, 5}, 1 * sizeof(uint32_t));
|
||||||
|
|
Loading…
Reference in New Issue