Implement GPU-based validation for dispatchIndirect

Bug: dawn:1039
Change-Id: I1b77244d33b178c8e4d4b7d72dc038ccb9d65c48
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/67142
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
Austin Eng 2021-10-29 18:52:33 +00:00 (committed by Dawn LUCI CQ)
parent deb4057d27
commit bcfa7b1253
13 changed files with 756 additions and 27 deletions

View File

@@ -73,6 +73,10 @@ namespace dawn_native {
         return mResourceUsages;
     }

+    CommandIterator* CommandBufferBase::GetCommandIteratorForTesting() {
+        return &mCommands;
+    }
+
     bool IsCompleteSubresourceCopiedTo(const TextureBase* texture,
                                        const Extent3D copySize,
                                        const uint32_t mipLevel) {

View File

@@ -43,6 +43,8 @@ namespace dawn_native {
         const CommandBufferResourceUsage& GetResourceUsages() const;

+        CommandIterator* GetCommandIteratorForTesting();
+
       protected:
         ~CommandBufferBase() override;

View File

@ -17,8 +17,10 @@
#include "common/Assert.h" #include "common/Assert.h"
#include "common/BitSetIterator.h" #include "common/BitSetIterator.h"
#include "dawn_native/BindGroup.h" #include "dawn_native/BindGroup.h"
#include "dawn_native/ComputePassEncoder.h"
#include "dawn_native/ComputePipeline.h" #include "dawn_native/ComputePipeline.h"
#include "dawn_native/Forward.h" #include "dawn_native/Forward.h"
#include "dawn_native/ObjectType_autogen.h"
#include "dawn_native/PipelineLayout.h" #include "dawn_native/PipelineLayout.h"
#include "dawn_native/RenderPipeline.h" #include "dawn_native/RenderPipeline.h"
@@ -83,13 +85,15 @@ namespace dawn_native {
     MaybeError CommandBufferStateTracker::ValidateBufferInRangeForVertexBuffer(
         uint32_t vertexCount,
         uint32_t firstVertex) {
+        RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
         const ityp::bitset<VertexBufferSlot, kMaxVertexBuffers>&
             vertexBufferSlotsUsedAsVertexBuffer =
-                mLastRenderPipeline->GetVertexBufferSlotsUsedAsVertexBuffer();
+                lastRenderPipeline->GetVertexBufferSlotsUsedAsVertexBuffer();

         for (auto usedSlotVertex : IterateBitSet(vertexBufferSlotsUsedAsVertexBuffer)) {
             const VertexBufferInfo& vertexBuffer =
-                mLastRenderPipeline->GetVertexBuffer(usedSlotVertex);
+                lastRenderPipeline->GetVertexBuffer(usedSlotVertex);
             uint64_t arrayStride = vertexBuffer.arrayStride;
             uint64_t bufferSize = mVertexBufferSizes[usedSlotVertex];
@@ -120,13 +124,15 @@ namespace dawn_native {
     MaybeError CommandBufferStateTracker::ValidateBufferInRangeForInstanceBuffer(
         uint32_t instanceCount,
         uint32_t firstInstance) {
+        RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
         const ityp::bitset<VertexBufferSlot, kMaxVertexBuffers>&
             vertexBufferSlotsUsedAsInstanceBuffer =
-                mLastRenderPipeline->GetVertexBufferSlotsUsedAsInstanceBuffer();
+                lastRenderPipeline->GetVertexBufferSlotsUsedAsInstanceBuffer();

         for (auto usedSlotInstance : IterateBitSet(vertexBufferSlotsUsedAsInstanceBuffer)) {
             const VertexBufferInfo& vertexBuffer =
-                mLastRenderPipeline->GetVertexBuffer(usedSlotInstance);
+                lastRenderPipeline->GetVertexBuffer(usedSlotInstance);
             uint64_t arrayStride = vertexBuffer.arrayStride;
             uint64_t bufferSize = mVertexBufferSizes[usedSlotInstance];
             if (arrayStride == 0) {
@@ -209,18 +215,19 @@ namespace dawn_native {
         }

         if (aspects[VALIDATION_ASPECT_VERTEX_BUFFERS]) {
-            ASSERT(mLastRenderPipeline != nullptr);
+            RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
             const ityp::bitset<VertexBufferSlot, kMaxVertexBuffers>& requiredVertexBuffers =
-                mLastRenderPipeline->GetVertexBufferSlotsUsed();
+                lastRenderPipeline->GetVertexBufferSlotsUsed();
             if (IsSubset(requiredVertexBuffers, mVertexBufferSlotsUsed)) {
                 mAspects.set(VALIDATION_ASPECT_VERTEX_BUFFERS);
             }
         }

         if (aspects[VALIDATION_ASPECT_INDEX_BUFFER] && mIndexBufferSet) {
-            if (!IsStripPrimitiveTopology(mLastRenderPipeline->GetPrimitiveTopology()) ||
-                mIndexFormat == mLastRenderPipeline->GetStripIndexFormat()) {
+            RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
+            if (!IsStripPrimitiveTopology(lastRenderPipeline->GetPrimitiveTopology()) ||
+                mIndexFormat == lastRenderPipeline->GetStripIndexFormat()) {
                 mAspects.set(VALIDATION_ASPECT_INDEX_BUFFER);
             }
         }
@@ -234,12 +241,13 @@ namespace dawn_native {
         if (DAWN_UNLIKELY(aspects[VALIDATION_ASPECT_INDEX_BUFFER])) {
             DAWN_INVALID_IF(!mIndexBufferSet, "Index buffer was not set.");

-            wgpu::IndexFormat pipelineIndexFormat = mLastRenderPipeline->GetStripIndexFormat();
+            RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
+            wgpu::IndexFormat pipelineIndexFormat = lastRenderPipeline->GetStripIndexFormat();
             DAWN_INVALID_IF(
-                IsStripPrimitiveTopology(mLastRenderPipeline->GetPrimitiveTopology()) &&
+                IsStripPrimitiveTopology(lastRenderPipeline->GetPrimitiveTopology()) &&
                     mIndexFormat != pipelineIndexFormat,
                 "Strip index format (%s) of %s does not match index buffer format (%s).",
-                pipelineIndexFormat, mLastRenderPipeline, mIndexFormat);
+                pipelineIndexFormat, lastRenderPipeline, mIndexFormat);

             // The chunk of code above should be similar to the one in |RecomputeLazyAspects|.
             // It returns the first invalid state found. We shouldn't be able to reach this line
@ -251,7 +259,7 @@ namespace dawn_native {
// TODO(dawn:563): Indicate which slots were not set. // TODO(dawn:563): Indicate which slots were not set.
DAWN_INVALID_IF(aspects[VALIDATION_ASPECT_VERTEX_BUFFERS], DAWN_INVALID_IF(aspects[VALIDATION_ASPECT_VERTEX_BUFFERS],
"Vertex buffer slots required by %s were not set.", mLastRenderPipeline); "Vertex buffer slots required by %s were not set.", GetRenderPipeline());
if (DAWN_UNLIKELY(aspects[VALIDATION_ASPECT_BIND_GROUPS])) { if (DAWN_UNLIKELY(aspects[VALIDATION_ASPECT_BIND_GROUPS])) {
for (BindGroupIndex i : IterateBitSet(mLastPipelineLayout->GetBindGroupLayoutsMask())) { for (BindGroupIndex i : IterateBitSet(mLastPipelineLayout->GetBindGroupLayoutsMask())) {
@ -290,12 +298,15 @@ namespace dawn_native {
} }
void CommandBufferStateTracker::SetRenderPipeline(RenderPipelineBase* pipeline) { void CommandBufferStateTracker::SetRenderPipeline(RenderPipelineBase* pipeline) {
mLastRenderPipeline = pipeline;
SetPipelineCommon(pipeline); SetPipelineCommon(pipeline);
} }
void CommandBufferStateTracker::SetBindGroup(BindGroupIndex index, BindGroupBase* bindgroup) { void CommandBufferStateTracker::SetBindGroup(BindGroupIndex index,
BindGroupBase* bindgroup,
uint32_t dynamicOffsetCount,
const uint32_t* dynamicOffsets) {
mBindgroups[index] = bindgroup; mBindgroups[index] = bindgroup;
mDynamicOffsets[index].assign(dynamicOffsets, dynamicOffsets + dynamicOffsetCount);
mAspects.reset(VALIDATION_ASPECT_BIND_GROUPS); mAspects.reset(VALIDATION_ASPECT_BIND_GROUPS);
} }
@ -311,8 +322,9 @@ namespace dawn_native {
} }
void CommandBufferStateTracker::SetPipelineCommon(PipelineBase* pipeline) { void CommandBufferStateTracker::SetPipelineCommon(PipelineBase* pipeline) {
mLastPipelineLayout = pipeline->GetLayout(); mLastPipeline = pipeline;
mMinBufferSizes = &pipeline->GetMinBufferSizes(); mLastPipelineLayout = pipeline != nullptr ? pipeline->GetLayout() : nullptr;
mMinBufferSizes = pipeline != nullptr ? &pipeline->GetMinBufferSizes() : nullptr;
mAspects.set(VALIDATION_ASPECT_PIPELINE); mAspects.set(VALIDATION_ASPECT_PIPELINE);
@@ -324,6 +336,25 @@ namespace dawn_native {
         return mBindgroups[index];
     }

+    const std::vector<uint32_t>& CommandBufferStateTracker::GetDynamicOffsets(
+        BindGroupIndex index) const {
+        return mDynamicOffsets[index];
+    }
+
+    bool CommandBufferStateTracker::HasPipeline() const {
+        return mLastPipeline != nullptr;
+    }
+
+    RenderPipelineBase* CommandBufferStateTracker::GetRenderPipeline() const {
+        ASSERT(HasPipeline() && mLastPipeline->GetType() == ObjectType::RenderPipeline);
+        return static_cast<RenderPipelineBase*>(mLastPipeline);
+    }
+
+    ComputePipelineBase* CommandBufferStateTracker::GetComputePipeline() const {
+        ASSERT(HasPipeline() && mLastPipeline->GetType() == ObjectType::ComputePipeline);
+        return static_cast<ComputePipelineBase*>(mLastPipeline);
+    }
+
     PipelineLayoutBase* CommandBufferStateTracker::GetPipelineLayout() const {
         return mLastPipelineLayout;
     }
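
A caller-side sketch of the new accessors (hypothetical code, not part of the change): the typed getters assert on the stored pipeline's ObjectType, so callers gate on HasPipeline() and on the kind of pass being encoded before downcasting.

    // Hypothetical caller (mirrors RestoreCommandBufferState later in this change):
    void ReapplyPipeline(ComputePassEncoder* encoder, const CommandBufferStateTracker& state) {
        if (state.HasPipeline()) {
            // GetComputePipeline() asserts the saved pipeline's type, so only
            // call it when a pipeline was actually set on this compute pass.
            encoder->APISetPipeline(state.GetComputePipeline());
        }
    }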

View File

@@ -38,7 +38,10 @@ namespace dawn_native {
         // State-modifying methods
         void SetComputePipeline(ComputePipelineBase* pipeline);
         void SetRenderPipeline(RenderPipelineBase* pipeline);
-        void SetBindGroup(BindGroupIndex index, BindGroupBase* bindgroup);
+        void SetBindGroup(BindGroupIndex index,
+                          BindGroupBase* bindgroup,
+                          uint32_t dynamicOffsetCount,
+                          const uint32_t* dynamicOffsets);
         void SetIndexBuffer(wgpu::IndexFormat format, uint64_t size);
         void SetVertexBuffer(VertexBufferSlot slot, uint64_t size);
@ -46,6 +49,10 @@ namespace dawn_native {
using ValidationAspects = std::bitset<kNumAspects>; using ValidationAspects = std::bitset<kNumAspects>;
BindGroupBase* GetBindGroup(BindGroupIndex index) const; BindGroupBase* GetBindGroup(BindGroupIndex index) const;
const std::vector<uint32_t>& GetDynamicOffsets(BindGroupIndex index) const;
bool HasPipeline() const;
RenderPipelineBase* GetRenderPipeline() const;
ComputePipelineBase* GetComputePipeline() const;
PipelineLayoutBase* GetPipelineLayout() const; PipelineLayoutBase* GetPipelineLayout() const;
wgpu::IndexFormat GetIndexFormat() const; wgpu::IndexFormat GetIndexFormat() const;
uint64_t GetIndexBufferSize() const; uint64_t GetIndexBufferSize() const;
@@ -60,6 +67,7 @@ namespace dawn_native {
         ValidationAspects mAspects;

         ityp::array<BindGroupIndex, BindGroupBase*, kMaxBindGroups> mBindgroups = {};
+        ityp::array<BindGroupIndex, std::vector<uint32_t>, kMaxBindGroups> mDynamicOffsets = {};
         ityp::bitset<VertexBufferSlot, kMaxVertexBuffers> mVertexBufferSlotsUsed;
         bool mIndexBufferSet = false;
         wgpu::IndexFormat mIndexFormat;
@@ -68,7 +76,7 @@ namespace dawn_native {
         ityp::array<VertexBufferSlot, uint64_t, kMaxVertexBuffers> mVertexBufferSizes = {};

         PipelineLayoutBase* mLastPipelineLayout = nullptr;
-        RenderPipelineBase* mLastRenderPipeline = nullptr;
+        PipelineBase* mLastPipeline = nullptr;
         const RequiredBufferSizes* mMinBufferSizes = nullptr;
     };

View File

@@ -14,18 +14,107 @@
 #include "dawn_native/ComputePassEncoder.h"

+#include "dawn_native/BindGroup.h"
+#include "dawn_native/BindGroupLayout.h"
 #include "dawn_native/Buffer.h"
 #include "dawn_native/CommandEncoder.h"
 #include "dawn_native/CommandValidation.h"
 #include "dawn_native/Commands.h"
 #include "dawn_native/ComputePipeline.h"
 #include "dawn_native/Device.h"
+#include "dawn_native/InternalPipelineStore.h"
 #include "dawn_native/ObjectType_autogen.h"
 #include "dawn_native/PassResourceUsageTracker.h"
 #include "dawn_native/QuerySet.h"

 namespace dawn_native {
+    namespace {
+
+        ResultOrError<ComputePipelineBase*> GetOrCreateIndirectDispatchValidationPipeline(
+            DeviceBase* device) {
+            InternalPipelineStore* store = device->GetInternalPipelineStore();
+
+            if (store->dispatchIndirectValidationPipeline != nullptr) {
+                return store->dispatchIndirectValidationPipeline.Get();
+            }
+
+            ShaderModuleDescriptor descriptor;
+            ShaderModuleWGSLDescriptor wgslDesc;
+            descriptor.nextInChain = reinterpret_cast<ChainedStruct*>(&wgslDesc);
+
+            // TODO(https://crbug.com/dawn/1108): Propagate validation feedback from this
+            // shader in various failure modes.
+            wgslDesc.source = R"(
+                [[block]] struct UniformParams {
+                    maxComputeWorkgroupsPerDimension: u32;
+                    clientOffsetInU32: u32;
+                };
+
+                [[block]] struct IndirectParams {
+                    data: array<u32>;
+                };
+
+                [[block]] struct ValidatedParams {
+                    data: array<u32, 3>;
+                };
+
+                [[group(0), binding(0)]] var<uniform> uniformParams: UniformParams;
+                [[group(0), binding(1)]] var<storage, read_write> clientParams: IndirectParams;
+                [[group(0), binding(2)]] var<storage, write> validatedParams: ValidatedParams;
+
+                [[stage(compute), workgroup_size(1, 1, 1)]]
+                fn main() {
+                    for (var i = 0u; i < 3u; i = i + 1u) {
+                        var numWorkgroups = clientParams.data[uniformParams.clientOffsetInU32 + i];
+                        if (numWorkgroups > uniformParams.maxComputeWorkgroupsPerDimension) {
+                            numWorkgroups = 0u;
+                        }
+                        validatedParams.data[i] = numWorkgroups;
+                    }
+                }
+            )";
+
+            Ref<ShaderModuleBase> shaderModule;
+            DAWN_TRY_ASSIGN(shaderModule, device->CreateShaderModule(&descriptor));
+
+            std::array<BindGroupLayoutEntry, 3> entries;
+            entries[0].binding = 0;
+            entries[0].visibility = wgpu::ShaderStage::Compute;
+            entries[0].buffer.type = wgpu::BufferBindingType::Uniform;
+            entries[1].binding = 1;
+            entries[1].visibility = wgpu::ShaderStage::Compute;
+            entries[1].buffer.type = kInternalStorageBufferBinding;
+            entries[2].binding = 2;
+            entries[2].visibility = wgpu::ShaderStage::Compute;
+            entries[2].buffer.type = wgpu::BufferBindingType::Storage;
+
+            BindGroupLayoutDescriptor bindGroupLayoutDescriptor;
+            bindGroupLayoutDescriptor.entryCount = entries.size();
+            bindGroupLayoutDescriptor.entries = entries.data();
+
+            Ref<BindGroupLayoutBase> bindGroupLayout;
+            DAWN_TRY_ASSIGN(bindGroupLayout,
+                            device->CreateBindGroupLayout(&bindGroupLayoutDescriptor, true));
+
+            PipelineLayoutDescriptor pipelineDescriptor;
+            pipelineDescriptor.bindGroupLayoutCount = 1;
+            pipelineDescriptor.bindGroupLayouts = &bindGroupLayout.Get();
+
+            Ref<PipelineLayoutBase> pipelineLayout;
+            DAWN_TRY_ASSIGN(pipelineLayout, device->CreatePipelineLayout(&pipelineDescriptor));
+
+            ComputePipelineDescriptor computePipelineDescriptor = {};
+            computePipelineDescriptor.layout = pipelineLayout.Get();
+            computePipelineDescriptor.compute.module = shaderModule.Get();
+            computePipelineDescriptor.compute.entryPoint = "main";
+
+            DAWN_TRY_ASSIGN(store->dispatchIndirectValidationPipeline,
+                            device->CreateComputePipeline(&computePipelineDescriptor));
+
+            return store->dispatchIndirectValidationPipeline.Get();
+        }
+
+    }  // namespace
+
     ComputePassEncoder::ComputePassEncoder(DeviceBase* device,
                                            CommandEncoder* commandEncoder,
                                            EncodingContext* encodingContext)
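
For intuition, the validation shader above amounts to this CPU-side clamp (an illustrative sketch, not code from the change; the function name is hypothetical):

    #include <cstdint>

    // Copy the three workgroup counts out of the client data, zeroing any
    // dimension that exceeds the device limit so the dispatch becomes a no-op.
    void ValidateDispatchParams(const uint32_t* clientData,
                                uint32_t clientOffsetInU32,
                                uint32_t maxComputeWorkgroupsPerDimension,
                                uint32_t validated[3]) {
        for (uint32_t i = 0; i < 3; ++i) {
            const uint32_t numWorkgroups = clientData[clientOffsetInU32 + i];
            validated[i] =
                numWorkgroups > maxComputeWorkgroupsPerDimension ? 0u : numWorkgroups;
        }
    }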
@@ -107,6 +196,95 @@ namespace dawn_native {
                            "encoding Dispatch (x: %u, y: %u, z: %u)", x, y, z);
     }
+    ResultOrError<std::pair<Ref<BufferBase>, uint64_t>>
+    ComputePassEncoder::ValidateIndirectDispatch(BufferBase* indirectBuffer,
+                                                 uint64_t indirectOffset) {
+        DeviceBase* device = GetDevice();
+        auto* const store = device->GetInternalPipelineStore();
+
+        Ref<ComputePipelineBase> validationPipeline;
+        DAWN_TRY_ASSIGN(validationPipeline, GetOrCreateIndirectDispatchValidationPipeline(device));
+
+        Ref<BindGroupLayoutBase> layout;
+        DAWN_TRY_ASSIGN(layout, validationPipeline->GetBindGroupLayout(0));
+
+        uint32_t storageBufferOffsetAlignment =
+            device->GetLimits().v1.minStorageBufferOffsetAlignment;
+
+        std::array<BindGroupEntry, 3> bindings;
+
+        // Storage binding holding the client's indirect buffer.
+        BindGroupEntry& clientIndirectBinding = bindings[0];
+        clientIndirectBinding.binding = 1;
+        clientIndirectBinding.buffer = indirectBuffer;
+
+        // Let the offset be the indirectOffset, aligned down to |storageBufferOffsetAlignment|.
+        const uint32_t clientOffsetFromAlignedBoundary =
+            indirectOffset % storageBufferOffsetAlignment;
+        const uint64_t clientOffsetAlignedDown = indirectOffset - clientOffsetFromAlignedBoundary;
+        clientIndirectBinding.offset = clientOffsetAlignedDown;
+
+        // Let the size of the binding be the additional offset, plus the size.
+        clientIndirectBinding.size = kDispatchIndirectSize + clientOffsetFromAlignedBoundary;
+
+        struct UniformParams {
+            uint32_t maxComputeWorkgroupsPerDimension;
+            uint32_t clientOffsetInU32;
+        };
+
+        // Create a uniform buffer to hold parameters for the shader.
+        Ref<BufferBase> uniformBuffer;
+        {
+            BufferDescriptor uniformDesc = {};
+            uniformDesc.size = sizeof(UniformParams);
+            uniformDesc.usage = wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst;
+            uniformDesc.mappedAtCreation = true;
+            DAWN_TRY_ASSIGN(uniformBuffer, device->CreateBuffer(&uniformDesc));
+
+            UniformParams* params = static_cast<UniformParams*>(
+                uniformBuffer->GetMappedRange(0, sizeof(UniformParams)));
+            params->maxComputeWorkgroupsPerDimension =
+                device->GetLimits().v1.maxComputeWorkgroupsPerDimension;
+            params->clientOffsetInU32 = clientOffsetFromAlignedBoundary / sizeof(uint32_t);
+            uniformBuffer->Unmap();
+        }
+
+        // Uniform buffer binding pointing to the uniform parameters.
+        BindGroupEntry& uniformBinding = bindings[1];
+        uniformBinding.binding = 0;
+        uniformBinding.buffer = uniformBuffer.Get();
+        uniformBinding.offset = 0;
+        uniformBinding.size = sizeof(UniformParams);
+
+        // Reserve space in the scratch buffer to hold the validated indirect params.
+        ScratchBuffer& scratchBuffer = store->scratchIndirectStorage;
+        DAWN_TRY(scratchBuffer.EnsureCapacity(kDispatchIndirectSize));
+        Ref<BufferBase> validatedIndirectBuffer = scratchBuffer.GetBuffer();
+
+        // Binding for the validated indirect params.
+        BindGroupEntry& validatedParamsBinding = bindings[2];
+        validatedParamsBinding.binding = 2;
+        validatedParamsBinding.buffer = validatedIndirectBuffer.Get();
+        validatedParamsBinding.offset = 0;
+        validatedParamsBinding.size = kDispatchIndirectSize;
+
+        BindGroupDescriptor bindGroupDescriptor = {};
+        bindGroupDescriptor.layout = layout.Get();
+        bindGroupDescriptor.entryCount = bindings.size();
+        bindGroupDescriptor.entries = bindings.data();
+
+        Ref<BindGroupBase> validationBindGroup;
+        DAWN_TRY_ASSIGN(validationBindGroup, device->CreateBindGroup(&bindGroupDescriptor));
+
+        // Issue commands to validate the indirect buffer.
+        APISetPipeline(validationPipeline.Get());
+        APISetBindGroup(0, validationBindGroup.Get());
+        APIDispatch(1);
+
+        // Return the new indirect buffer and indirect buffer offset.
+        return std::make_pair(std::move(validatedIndirectBuffer), uint64_t(0));
+    }
+
     void ComputePassEncoder::APIDispatchIndirect(BufferBase* indirectBuffer,
                                                  uint64_t indirectOffset) {
         mEncodingContext->TryEncode(
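
To make the offset realignment in ValidateIndirectDispatch concrete, here is the same arithmetic with assumed values (illustrative only; a real minStorageBufferOffsetAlignment may differ):

    #include <cstdint>

    // Suppose minStorageBufferOffsetAlignment = 256 and indirectOffset = 268
    // (kDispatchIndirectSize is 3 * sizeof(uint32_t) = 12 bytes).
    const uint32_t storageBufferOffsetAlignment = 256;
    const uint64_t indirectOffset = 268;
    const uint32_t clientOffsetFromAlignedBoundary =
        indirectOffset % storageBufferOffsetAlignment;               // 12
    const uint64_t clientOffsetAlignedDown =
        indirectOffset - clientOffsetFromAlignedBoundary;            // 256
    const uint64_t bindingSize = 12 + clientOffsetFromAlignedBoundary;  // 24 bytes
    // The storage binding starts at byte 256 and spans 24 bytes; the shader
    // reads the three counts starting at clientOffsetInU32 = 12 / 4 = 3.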
@@ -136,18 +314,46 @@ namespace dawn_native {
                     indirectOffset, kDispatchIndirectSize, indirectBuffer->GetSize());
                 }

+                // Record the synchronization scope for Dispatch, both the bindgroups and the
+                // indirect buffer.
                 SyncScopeUsageTracker scope;
                 scope.BufferUsedAs(indirectBuffer, wgpu::BufferUsage::Indirect);
                 mUsageTracker.AddReferencedBuffer(indirectBuffer);
+                // TODO(crbug.com/dawn/1166): If validation is enabled, adding |indirectBuffer|
+                // is needed for correct usage validation even though it will only be bound for
+                // storage. This will unnecessarily transition the |indirectBuffer| in
+                // the backend.
+
+                Ref<BufferBase> indirectBufferRef = indirectBuffer;
+
+                if (IsValidationEnabled()) {
+                    // Save the previous command buffer state so it can be restored after the
+                    // validation inserts additional commands.
+                    CommandBufferStateTracker previousState = mCommandBufferState;
+
+                    // Validate each indirect dispatch with a single dispatch to copy the indirect
+                    // buffer params into a scratch buffer if they're valid, and otherwise zero them
+                    // out. We could consider moving the validation earlier in the pass after the
+                    // last point the indirect buffer was used with writable usage, as well as batch
+                    // validation for multiple dispatches into one, but inserting commands at
+                    // arbitrary points in the past is not possible right now.
+                    DAWN_TRY_ASSIGN(
+                        std::tie(indirectBufferRef, indirectOffset),
+                        ValidateIndirectDispatch(indirectBufferRef.Get(), indirectOffset));
+
+                    // Restore the state.
+                    RestoreCommandBufferState(std::move(previousState));
+
+                    // |indirectBufferRef| was replaced with a scratch buffer. Add it to the
+                    // synchronization scope.
+                    ASSERT(indirectBufferRef.Get() != indirectBuffer);
+                    scope.BufferUsedAs(indirectBufferRef.Get(), wgpu::BufferUsage::Indirect);
+                    mUsageTracker.AddReferencedBuffer(indirectBufferRef.Get());
+                }
+
                 AddDispatchSyncScope(std::move(scope));

                 DispatchIndirectCmd* dispatch =
                     allocator->Allocate<DispatchIndirectCmd>(Command::DispatchIndirect);
-                dispatch->indirectBuffer = indirectBuffer;
+                dispatch->indirectBuffer = std::move(indirectBufferRef);
                 dispatch->indirectOffset = indirectOffset;

                 return {};
             },
             "encoding DispatchIndirect with %s", indirectBuffer);
@@ -187,10 +393,10 @@ namespace dawn_native {
                 }

                 mUsageTracker.AddResourcesReferencedByBindGroup(group);
                 RecordSetBindGroup(allocator, groupIndex, group, dynamicOffsetCount,
                                    dynamicOffsets);
-                mCommandBufferState.SetBindGroup(groupIndex, group);
+                mCommandBufferState.SetBindGroup(groupIndex, group, dynamicOffsetCount,
+                                                 dynamicOffsets);

                 return {};
             },
@@ -226,4 +432,29 @@ namespace dawn_native {
         mUsageTracker.AddDispatch(scope.AcquireSyncScopeUsage());
     }

+    void ComputePassEncoder::RestoreCommandBufferState(CommandBufferStateTracker state) {
+        // Encode commands for the backend to restore the pipeline and bind groups.
+        if (state.HasPipeline()) {
+            APISetPipeline(state.GetComputePipeline());
+        }
+        for (BindGroupIndex i(0); i < kMaxBindGroupsTyped; ++i) {
+            BindGroupBase* bg = state.GetBindGroup(i);
+            if (bg != nullptr) {
+                const std::vector<uint32_t>& offsets = state.GetDynamicOffsets(i);
+                if (offsets.empty()) {
+                    APISetBindGroup(static_cast<uint32_t>(i), bg);
+                } else {
+                    APISetBindGroup(static_cast<uint32_t>(i), bg, offsets.size(), offsets.data());
+                }
+            }
+        }
+
+        // Restore the frontend state tracking information.
+        mCommandBufferState = std::move(state);
+    }
+
+    CommandBufferStateTracker* ComputePassEncoder::GetCommandBufferStateTrackerForTesting() {
+        return &mCommandBufferState;
+    }
+
 }  // namespace dawn_native

View File

@@ -50,6 +50,11 @@ namespace dawn_native {
         void APIWriteTimestamp(QuerySetBase* querySet, uint32_t queryIndex);

+        CommandBufferStateTracker* GetCommandBufferStateTrackerForTesting();
+        void RestoreCommandBufferStateForTesting(CommandBufferStateTracker state) {
+            RestoreCommandBufferState(std::move(state));
+        }
+
       protected:
         ComputePassEncoder(DeviceBase* device,
                            CommandEncoder* commandEncoder,
@@ -57,6 +62,12 @@ namespace dawn_native {
                            ErrorTag errorTag);

       private:
+        ResultOrError<std::pair<Ref<BufferBase>, uint64_t>> ValidateIndirectDispatch(
+            BufferBase* indirectBuffer,
+            uint64_t indirectOffset);
+
+        void RestoreCommandBufferState(CommandBufferStateTracker state);
+
         CommandBufferStateTracker mCommandBufferState;

         // Adds the bindgroups used for the current dispatch to the SyncScopeResourceUsage and

View File

@@ -52,6 +52,7 @@ namespace dawn_native {
         Ref<ComputePipelineBase> renderValidationPipeline;
         Ref<ShaderModuleBase> renderValidationShader;
+        Ref<ComputePipelineBase> dispatchIndirectValidationPipeline;
     };

 }  // namespace dawn_native

View File

@@ -208,6 +208,9 @@ namespace dawn_native {
                     BufferLocation::New(indirectBuffer, indirectOffset);
                 }

+                // TODO(crbug.com/dawn/1166): Adding the indirectBuffer is needed for correct usage
+                // validation, but it will unnecessarily transition to indirectBuffer usage in the
+                // backend.
                 mUsageTracker.BufferUsedAs(indirectBuffer, wgpu::BufferUsage::Indirect);

                 return {};
@@ -404,7 +407,8 @@ namespace dawn_native {
                 RecordSetBindGroup(allocator, groupIndex, group, dynamicOffsetCount,
                                    dynamicOffsets);
-                mCommandBufferState.SetBindGroup(groupIndex, group);
+                mCommandBufferState.SetBindGroup(groupIndex, group, dynamicOffsetCount,
+                                                 dynamicOffsets);
                 mUsageTracker.AddBindGroup(group);

                 return {};

View File

@@ -221,6 +221,7 @@ test("dawn_unittests") {
     "unittests/SystemUtilsTests.cpp",
     "unittests/ToBackendTests.cpp",
     "unittests/TypedIntegerTests.cpp",
+    "unittests/native/CommandBufferEncodingTests.cpp",
     "unittests/native/DestroyObjectTests.cpp",
     "unittests/validation/BindGroupValidationTests.cpp",
     "unittests/validation/BufferValidationTests.cpp",

View File

@@ -12,9 +12,11 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include <gtest/gtest.h>
+#include "tests/DawnNativeTest.h"

 #include "absl/strings/str_cat.h"
+#include "common/Assert.h"
+#include "dawn/dawn_proc.h"
 #include "dawn_native/ErrorData.h"

 namespace dawn_native {
@@ -28,3 +30,54 @@ namespace dawn_native {
     }

 }  // namespace dawn_native
+
+DawnNativeTest::DawnNativeTest() {
+    dawnProcSetProcs(&dawn_native::GetProcs());
+}
+
+DawnNativeTest::~DawnNativeTest() {
+    device = wgpu::Device();
+    dawnProcSetProcs(nullptr);
+}
+
+void DawnNativeTest::SetUp() {
+    instance = std::make_unique<dawn_native::Instance>();
+    instance->DiscoverDefaultAdapters();
+
+    std::vector<dawn_native::Adapter> adapters = instance->GetAdapters();
+
+    // DawnNative unittests run against the null backend, so find the corresponding adapter.
+    bool foundNullAdapter = false;
+    for (auto& currentAdapter : adapters) {
+        wgpu::AdapterProperties adapterProperties;
+        currentAdapter.GetProperties(&adapterProperties);
+
+        if (adapterProperties.backendType == wgpu::BackendType::Null) {
+            adapter = currentAdapter;
+            foundNullAdapter = true;
+            break;
+        }
+    }
+    ASSERT(foundNullAdapter);
+
+    device = wgpu::Device(CreateTestDevice());
+    device.SetUncapturedErrorCallback(DawnNativeTest::OnDeviceError, nullptr);
+}
+
+void DawnNativeTest::TearDown() {
+}
+
+WGPUDevice DawnNativeTest::CreateTestDevice() {
+    // Disable disallowing unsafe APIs so we can test them.
+    dawn_native::DeviceDescriptor deviceDescriptor;
+    deviceDescriptor.forceDisabledToggles.push_back("disallow_unsafe_apis");
+
+    return adapter.CreateDevice(&deviceDescriptor);
+}
+
+// static
+void DawnNativeTest::OnDeviceError(WGPUErrorType type, const char* message, void* userdata) {
+    ASSERT(type != WGPUErrorType_NoError);
+    FAIL() << "Unexpected error: " << message;
+}
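
CreateTestDevice is virtual so a fixture subclass can customize device creation, for example to force different toggles. A hypothetical subclass (not part of this change; forceEnabledToggles is assumed to exist on dawn_native::DeviceDescriptor alongside forceDisabledToggles):

    class SkipValidationTest : public DawnNativeTest {
      private:
        WGPUDevice CreateTestDevice() override {
            // Assumed field: forceEnabledToggles, the counterpart of
            // forceDisabledToggles used above.
            dawn_native::DeviceDescriptor desc;
            desc.forceEnabledToggles.push_back("skip_validation");
            return adapter.CreateDevice(&desc);
        }
    };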

View File

@@ -17,6 +17,8 @@

 #include <gtest/gtest.h>

+#include "dawn/webgpu_cpp.h"
+#include "dawn_native/DawnNative.h"
 #include "dawn_native/ErrorData.h"

 namespace dawn_native {
@@ -29,4 +31,23 @@ namespace dawn_native {

 }  // namespace dawn_native

+class DawnNativeTest : public ::testing::Test {
+  public:
+    DawnNativeTest();
+    ~DawnNativeTest() override;
+
+    void SetUp() override;
+    void TearDown() override;
+
+    virtual WGPUDevice CreateTestDevice();
+
+  protected:
+    std::unique_ptr<dawn_native::Instance> instance;
+    dawn_native::Adapter adapter;
+    wgpu::Device device;
+
+  private:
+    static void OnDeviceError(WGPUErrorType type, const char* message, void* userdata);
+};
+
 #endif  // TESTS_DAWNNATIVETEST_H_

View File

@@ -158,8 +158,14 @@ class ComputeDispatchTests : public DawnTest {
         queue.Submit(1, &commands);

         std::vector<uint32_t> expected;

+        uint32_t maxComputeWorkgroupsPerDimension =
+            GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
+
         if (indirectBufferData[indirectStart] == 0 || indirectBufferData[indirectStart + 1] == 0 ||
-            indirectBufferData[indirectStart + 2] == 0) {
+            indirectBufferData[indirectStart + 2] == 0 ||
+            indirectBufferData[indirectStart] > maxComputeWorkgroupsPerDimension ||
+            indirectBufferData[indirectStart + 1] > maxComputeWorkgroupsPerDimension ||
+            indirectBufferData[indirectStart + 2] > maxComputeWorkgroupsPerDimension) {
             expected = kSentinelData;
         } else {
             expected.assign(indirectBufferData.begin() + indirectStart,
@@ -221,6 +227,52 @@ TEST_P(ComputeDispatchTests, IndirectOffset) {
     IndirectTest({0, 0, 0, 2, 3, 4}, 3 * sizeof(uint32_t));
 }

+// Test indirect dispatches at max limit.
+TEST_P(ComputeDispatchTests, MaxWorkgroups) {
+    // TODO(crbug.com/dawn/1165): Fails with WARP
+    DAWN_SUPPRESS_TEST_IF(IsWARP());
+    uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
+
+    // Test that the maximum works in each dimension.
+    // Note: Testing (max, max, max) is very slow.
+    IndirectTest({max, 3, 4}, 0);
+    IndirectTest({2, max, 4}, 0);
+    IndirectTest({2, 3, max}, 0);
+}
+
+// Test indirect dispatches exceeding the max limit are noop-ed.
+TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsNoop) {
+    DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation"));
+
+    uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
+
+    // All dimensions are above the max
+    IndirectTest({max + 1, max + 1, max + 1}, 0);
+
+    // Only x dimension is above the max
+    IndirectTest({max + 1, 3, 4}, 0);
+    IndirectTest({2 * max, 3, 4}, 0);
+
+    // Only y dimension is above the max
+    IndirectTest({2, max + 1, 4}, 0);
+    IndirectTest({2, 2 * max, 4}, 0);
+
+    // Only z dimension is above the max
+    IndirectTest({2, 3, max + 1}, 0);
+    IndirectTest({2, 3, 2 * max}, 0);
+}
+
+// Test indirect dispatches exceeding the max limit with an offset are noop-ed.
+TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsWithOffsetNoop) {
+    DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation"));
+
+    uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
+
+    // Each offset selects three consecutive u32s from {1, 2, 3, max + 1, 4, 5}:
+    // (2, 3, max + 1), (3, max + 1, 4), and (max + 1, 4, 5). Every triple
+    // contains a dimension above the limit, so all three dispatches are noop-ed.
+    IndirectTest({1, 2, 3, max + 1, 4, 5}, 1 * sizeof(uint32_t));
+    IndirectTest({1, 2, 3, max + 1, 4, 5}, 2 * sizeof(uint32_t));
+    IndirectTest({1, 2, 3, max + 1, 4, 5}, 3 * sizeof(uint32_t));
+}
+
 DAWN_INSTANTIATE_TEST(ComputeDispatchTests,
                       D3D12Backend(),
                       MetalBackend(),

View File

@@ -0,0 +1,310 @@
// Copyright 2021 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tests/DawnNativeTest.h"
#include "dawn_native/CommandBuffer.h"
#include "dawn_native/Commands.h"
#include "dawn_native/ComputePassEncoder.h"
#include "utils/WGPUHelpers.h"
class CommandBufferEncodingTests : public DawnNativeTest {
protected:
void ExpectCommands(dawn_native::CommandIterator* commands,
std::vector<std::pair<dawn_native::Command,
std::function<void(dawn_native::CommandIterator*)>>>
expectedCommands) {
dawn_native::Command commandId;
for (uint32_t commandIndex = 0; commands->NextCommandId(&commandId); ++commandIndex) {
ASSERT_LT(commandIndex, expectedCommands.size()) << "Unexpected command";
ASSERT_EQ(commandId, expectedCommands[commandIndex].first)
<< "at command " << commandIndex;
expectedCommands[commandIndex].second(commands);
}
}
};
// Indirect dispatch validation changes the bind groups in the middle
// of a pass. Test that bindings are restored after the validation runs.
TEST_F(CommandBufferEncodingTests, ComputePassEncoderIndirectDispatchStateRestoration) {
using namespace dawn_native;
wgpu::BindGroupLayout staticLayout =
utils::MakeBindGroupLayout(device, {{
0,
wgpu::ShaderStage::Compute,
wgpu::BufferBindingType::Uniform,
}});
wgpu::BindGroupLayout dynamicLayout =
utils::MakeBindGroupLayout(device, {{
0,
wgpu::ShaderStage::Compute,
wgpu::BufferBindingType::Uniform,
true,
}});
// Create a simple pipeline
wgpu::ComputePipelineDescriptor csDesc;
csDesc.compute.module = utils::CreateShaderModule(device, R"(
[[stage(compute), workgroup_size(1, 1, 1)]]
fn main() {
})");
csDesc.compute.entryPoint = "main";
wgpu::PipelineLayout pl0 = utils::MakePipelineLayout(device, {staticLayout, dynamicLayout});
csDesc.layout = pl0;
wgpu::ComputePipeline pipeline0 = device.CreateComputePipeline(&csDesc);
wgpu::PipelineLayout pl1 = utils::MakePipelineLayout(device, {dynamicLayout, staticLayout});
csDesc.layout = pl1;
wgpu::ComputePipeline pipeline1 = device.CreateComputePipeline(&csDesc);
// Create buffers to use for both the indirect buffer and the bind groups.
wgpu::Buffer indirectBuffer =
utils::CreateBufferFromData<uint32_t>(device, wgpu::BufferUsage::Indirect, {1, 2, 3, 4});
wgpu::BufferDescriptor uniformBufferDesc = {};
uniformBufferDesc.size = 512;
uniformBufferDesc.usage = wgpu::BufferUsage::Uniform;
wgpu::Buffer uniformBuffer = device.CreateBuffer(&uniformBufferDesc);
wgpu::BindGroup staticBG = utils::MakeBindGroup(device, staticLayout, {{0, uniformBuffer}});
wgpu::BindGroup dynamicBG =
utils::MakeBindGroup(device, dynamicLayout, {{0, uniformBuffer, 0, 256}});
uint32_t dynamicOffset = 256;
std::vector<uint32_t> emptyDynamicOffsets = {};
std::vector<uint32_t> singleDynamicOffset = {dynamicOffset};
// Begin encoding commands.
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
CommandBufferStateTracker* stateTracker =
FromAPI(pass.Get())->GetCommandBufferStateTrackerForTesting();
// Perform a dispatch indirect which will be preceded by a validation dispatch.
pass.SetPipeline(pipeline0);
pass.SetBindGroup(0, staticBG);
pass.SetBindGroup(1, dynamicBG, 1, &dynamicOffset);
EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get());
pass.DispatchIndirect(indirectBuffer, 0);
// Expect restored state.
EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get());
EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl0.Get());
EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), staticBG.Get());
EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), emptyDynamicOffsets);
EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), dynamicBG.Get());
EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), singleDynamicOffset);
// Dispatch again to check that the restored state can be used.
// Also pass an indirect offset which should get replaced with the offset
// into the scratch indirect buffer (0).
pass.DispatchIndirect(indirectBuffer, 4);
// Expect restored state.
EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get());
EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl0.Get());
EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), staticBG.Get());
EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), emptyDynamicOffsets);
EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), dynamicBG.Get());
EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), singleDynamicOffset);
// Change the pipeline
pass.SetPipeline(pipeline1);
pass.SetBindGroup(0, dynamicBG, 1, &dynamicOffset);
pass.SetBindGroup(1, staticBG);
EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline1.Get());
EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl1.Get());
pass.DispatchIndirect(indirectBuffer, 0);
// Expect restored state.
EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline1.Get());
EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl1.Get());
EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), dynamicBG.Get());
EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), singleDynamicOffset);
EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), staticBG.Get());
EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), emptyDynamicOffsets);
pass.EndPass();
wgpu::CommandBuffer commandBuffer = encoder.Finish();
auto ExpectSetPipeline = [](wgpu::ComputePipeline pipeline) {
return [pipeline](CommandIterator* commands) {
auto* cmd = commands->NextCommand<SetComputePipelineCmd>();
EXPECT_EQ(ToAPI(cmd->pipeline.Get()), pipeline.Get());
};
};
auto ExpectSetBindGroup = [](uint32_t index, wgpu::BindGroup bg,
std::vector<uint32_t> offsets = {}) {
return [index, bg, offsets](CommandIterator* commands) {
auto* cmd = commands->NextCommand<SetBindGroupCmd>();
uint32_t* dynamicOffsets = nullptr;
if (cmd->dynamicOffsetCount > 0) {
dynamicOffsets = commands->NextData<uint32_t>(cmd->dynamicOffsetCount);
}
ASSERT_EQ(cmd->index, BindGroupIndex(index));
ASSERT_EQ(ToAPI(cmd->group.Get()), bg.Get());
ASSERT_EQ(cmd->dynamicOffsetCount, offsets.size());
for (uint32_t i = 0; i < cmd->dynamicOffsetCount; ++i) {
ASSERT_EQ(dynamicOffsets[i], offsets[i]);
}
};
};
// Initialize as null. Once we know the pointer, we'll check
// that it's the same buffer every time.
WGPUBuffer indirectScratchBuffer = nullptr;
auto ExpectDispatchIndirect = [&](CommandIterator* commands) {
auto* cmd = commands->NextCommand<DispatchIndirectCmd>();
if (indirectScratchBuffer == nullptr) {
indirectScratchBuffer = ToAPI(cmd->indirectBuffer.Get());
}
ASSERT_EQ(ToAPI(cmd->indirectBuffer.Get()), indirectScratchBuffer);
ASSERT_EQ(cmd->indirectOffset, uint64_t(0));
};
// Initialize as null. Once we know the pointer, we'll check
// that it's the same pipeline every time.
WGPUComputePipeline validationPipeline = nullptr;
auto ExpectSetValidationPipeline = [&](CommandIterator* commands) {
auto* cmd = commands->NextCommand<SetComputePipelineCmd>();
WGPUComputePipeline pipeline = ToAPI(cmd->pipeline.Get());
if (validationPipeline != nullptr) {
EXPECT_EQ(pipeline, validationPipeline);
} else {
EXPECT_NE(pipeline, nullptr);
validationPipeline = pipeline;
}
};
auto ExpectSetValidationBindGroup = [&](CommandIterator* commands) {
auto* cmd = commands->NextCommand<SetBindGroupCmd>();
ASSERT_EQ(cmd->index, BindGroupIndex(0));
ASSERT_NE(cmd->group.Get(), nullptr);
ASSERT_EQ(cmd->dynamicOffsetCount, 0u);
};
auto ExpectSetValidationDispatch = [&](CommandIterator* commands) {
auto* cmd = commands->NextCommand<DispatchCmd>();
ASSERT_EQ(cmd->x, 1u);
ASSERT_EQ(cmd->y, 1u);
ASSERT_EQ(cmd->z, 1u);
};
ExpectCommands(
FromAPI(commandBuffer.Get())->GetCommandIteratorForTesting(),
{
{Command::BeginComputePass,
[&](CommandIterator* commands) { SkipCommand(commands, Command::BeginComputePass); }},
// Expect the state to be set.
{Command::SetComputePipeline, ExpectSetPipeline(pipeline0)},
{Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)},
{Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})},
// Expect the validation.
{Command::SetComputePipeline, ExpectSetValidationPipeline},
{Command::SetBindGroup, ExpectSetValidationBindGroup},
{Command::Dispatch, ExpectSetValidationDispatch},
// Expect the state to be restored.
{Command::SetComputePipeline, ExpectSetPipeline(pipeline0)},
{Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)},
{Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})},
// Expect the dispatchIndirect.
{Command::DispatchIndirect, ExpectDispatchIndirect},
// Expect the validation.
{Command::SetComputePipeline, ExpectSetValidationPipeline},
{Command::SetBindGroup, ExpectSetValidationBindGroup},
{Command::Dispatch, ExpectSetValidationDispatch},
// Expect the state to be restored.
{Command::SetComputePipeline, ExpectSetPipeline(pipeline0)},
{Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)},
{Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})},
// Expect the dispatchIndirect.
{Command::DispatchIndirect, ExpectDispatchIndirect},
// Expect the state to be set (new pipeline).
{Command::SetComputePipeline, ExpectSetPipeline(pipeline1)},
{Command::SetBindGroup, ExpectSetBindGroup(0, dynamicBG, {dynamicOffset})},
{Command::SetBindGroup, ExpectSetBindGroup(1, staticBG)},
// Expect the validation.
{Command::SetComputePipeline, ExpectSetValidationPipeline},
{Command::SetBindGroup, ExpectSetValidationBindGroup},
{Command::Dispatch, ExpectSetValidationDispatch},
// Expect the state to be restored.
{Command::SetComputePipeline, ExpectSetPipeline(pipeline1)},
{Command::SetBindGroup, ExpectSetBindGroup(0, dynamicBG, {dynamicOffset})},
{Command::SetBindGroup, ExpectSetBindGroup(1, staticBG)},
// Expect the dispatchIndirect.
{Command::DispatchIndirect, ExpectDispatchIndirect},
{Command::EndComputePass,
[&](CommandIterator* commands) { commands->NextCommand<EndComputePassCmd>(); }},
});
}
// Test that after restoring state, it is fully applied to the state tracker
// and does not leak state changes that occurred between a snapshot and the
// state restoration.
TEST_F(CommandBufferEncodingTests, StateNotLeakedAfterRestore) {
using namespace dawn_native;
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
CommandBufferStateTracker* stateTracker =
FromAPI(pass.Get())->GetCommandBufferStateTrackerForTesting();
// Snapshot the state.
CommandBufferStateTracker snapshot = *stateTracker;
// Expect no pipeline in the snapshot
EXPECT_FALSE(snapshot.HasPipeline());
// Create a simple pipeline
wgpu::ComputePipelineDescriptor csDesc;
csDesc.compute.module = utils::CreateShaderModule(device, R"(
[[stage(compute), workgroup_size(1, 1, 1)]]
fn main() {
})");
csDesc.compute.entryPoint = "main";
wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
// Set the pipeline.
pass.SetPipeline(pipeline);
// Expect the pipeline to be set.
EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline.Get());
// Restore the state.
FromAPI(pass.Get())->RestoreCommandBufferStateForTesting(std::move(snapshot));
// Expect no pipeline
EXPECT_FALSE(stateTracker->HasPipeline());
}