Implement GPU-based validation for dispatchIndirect
Bug: dawn:1039 Change-Id: I1b77244d33b178c8e4d4b7d72dc038ccb9d65c48 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/67142 Reviewed-by: Corentin Wallez <cwallez@chromium.org> Commit-Queue: Austin Eng <enga@chromium.org>
This commit is contained in:
parent
deb4057d27
commit
bcfa7b1253
|
@ -73,6 +73,10 @@ namespace dawn_native {
|
|||
return mResourceUsages;
|
||||
}
|
||||
|
||||
CommandIterator* CommandBufferBase::GetCommandIteratorForTesting() {
|
||||
return &mCommands;
|
||||
}
|
||||
|
||||
bool IsCompleteSubresourceCopiedTo(const TextureBase* texture,
|
||||
const Extent3D copySize,
|
||||
const uint32_t mipLevel) {
|
||||
|
|
|
@ -43,6 +43,8 @@ namespace dawn_native {
|
|||
|
||||
const CommandBufferResourceUsage& GetResourceUsages() const;
|
||||
|
||||
CommandIterator* GetCommandIteratorForTesting();
|
||||
|
||||
protected:
|
||||
~CommandBufferBase() override;
|
||||
|
||||
|
|
|
@ -17,8 +17,10 @@
|
|||
#include "common/Assert.h"
|
||||
#include "common/BitSetIterator.h"
|
||||
#include "dawn_native/BindGroup.h"
|
||||
#include "dawn_native/ComputePassEncoder.h"
|
||||
#include "dawn_native/ComputePipeline.h"
|
||||
#include "dawn_native/Forward.h"
|
||||
#include "dawn_native/ObjectType_autogen.h"
|
||||
#include "dawn_native/PipelineLayout.h"
|
||||
#include "dawn_native/RenderPipeline.h"
|
||||
|
||||
|
@ -83,13 +85,15 @@ namespace dawn_native {
|
|||
MaybeError CommandBufferStateTracker::ValidateBufferInRangeForVertexBuffer(
|
||||
uint32_t vertexCount,
|
||||
uint32_t firstVertex) {
|
||||
RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
|
||||
|
||||
const ityp::bitset<VertexBufferSlot, kMaxVertexBuffers>&
|
||||
vertexBufferSlotsUsedAsVertexBuffer =
|
||||
mLastRenderPipeline->GetVertexBufferSlotsUsedAsVertexBuffer();
|
||||
lastRenderPipeline->GetVertexBufferSlotsUsedAsVertexBuffer();
|
||||
|
||||
for (auto usedSlotVertex : IterateBitSet(vertexBufferSlotsUsedAsVertexBuffer)) {
|
||||
const VertexBufferInfo& vertexBuffer =
|
||||
mLastRenderPipeline->GetVertexBuffer(usedSlotVertex);
|
||||
lastRenderPipeline->GetVertexBuffer(usedSlotVertex);
|
||||
uint64_t arrayStride = vertexBuffer.arrayStride;
|
||||
uint64_t bufferSize = mVertexBufferSizes[usedSlotVertex];
|
||||
|
||||
|
@ -120,13 +124,15 @@ namespace dawn_native {
|
|||
MaybeError CommandBufferStateTracker::ValidateBufferInRangeForInstanceBuffer(
|
||||
uint32_t instanceCount,
|
||||
uint32_t firstInstance) {
|
||||
RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
|
||||
|
||||
const ityp::bitset<VertexBufferSlot, kMaxVertexBuffers>&
|
||||
vertexBufferSlotsUsedAsInstanceBuffer =
|
||||
mLastRenderPipeline->GetVertexBufferSlotsUsedAsInstanceBuffer();
|
||||
lastRenderPipeline->GetVertexBufferSlotsUsedAsInstanceBuffer();
|
||||
|
||||
for (auto usedSlotInstance : IterateBitSet(vertexBufferSlotsUsedAsInstanceBuffer)) {
|
||||
const VertexBufferInfo& vertexBuffer =
|
||||
mLastRenderPipeline->GetVertexBuffer(usedSlotInstance);
|
||||
lastRenderPipeline->GetVertexBuffer(usedSlotInstance);
|
||||
uint64_t arrayStride = vertexBuffer.arrayStride;
|
||||
uint64_t bufferSize = mVertexBufferSizes[usedSlotInstance];
|
||||
if (arrayStride == 0) {
|
||||
|
@ -209,18 +215,19 @@ namespace dawn_native {
|
|||
}
|
||||
|
||||
if (aspects[VALIDATION_ASPECT_VERTEX_BUFFERS]) {
|
||||
ASSERT(mLastRenderPipeline != nullptr);
|
||||
RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
|
||||
|
||||
const ityp::bitset<VertexBufferSlot, kMaxVertexBuffers>& requiredVertexBuffers =
|
||||
mLastRenderPipeline->GetVertexBufferSlotsUsed();
|
||||
lastRenderPipeline->GetVertexBufferSlotsUsed();
|
||||
if (IsSubset(requiredVertexBuffers, mVertexBufferSlotsUsed)) {
|
||||
mAspects.set(VALIDATION_ASPECT_VERTEX_BUFFERS);
|
||||
}
|
||||
}
|
||||
|
||||
if (aspects[VALIDATION_ASPECT_INDEX_BUFFER] && mIndexBufferSet) {
|
||||
if (!IsStripPrimitiveTopology(mLastRenderPipeline->GetPrimitiveTopology()) ||
|
||||
mIndexFormat == mLastRenderPipeline->GetStripIndexFormat()) {
|
||||
RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
|
||||
if (!IsStripPrimitiveTopology(lastRenderPipeline->GetPrimitiveTopology()) ||
|
||||
mIndexFormat == lastRenderPipeline->GetStripIndexFormat()) {
|
||||
mAspects.set(VALIDATION_ASPECT_INDEX_BUFFER);
|
||||
}
|
||||
}
|
||||
|
@ -234,12 +241,13 @@ namespace dawn_native {
|
|||
if (DAWN_UNLIKELY(aspects[VALIDATION_ASPECT_INDEX_BUFFER])) {
|
||||
DAWN_INVALID_IF(!mIndexBufferSet, "Index buffer was not set.");
|
||||
|
||||
wgpu::IndexFormat pipelineIndexFormat = mLastRenderPipeline->GetStripIndexFormat();
|
||||
RenderPipelineBase* lastRenderPipeline = GetRenderPipeline();
|
||||
wgpu::IndexFormat pipelineIndexFormat = lastRenderPipeline->GetStripIndexFormat();
|
||||
DAWN_INVALID_IF(
|
||||
IsStripPrimitiveTopology(mLastRenderPipeline->GetPrimitiveTopology()) &&
|
||||
IsStripPrimitiveTopology(lastRenderPipeline->GetPrimitiveTopology()) &&
|
||||
mIndexFormat != pipelineIndexFormat,
|
||||
"Strip index format (%s) of %s does not match index buffer format (%s).",
|
||||
pipelineIndexFormat, mLastRenderPipeline, mIndexFormat);
|
||||
pipelineIndexFormat, lastRenderPipeline, mIndexFormat);
|
||||
|
||||
// The chunk of code above should be similar to the one in |RecomputeLazyAspects|.
|
||||
// It returns the first invalid state found. We shouldn't be able to reach this line
|
||||
|
@ -251,7 +259,7 @@ namespace dawn_native {
|
|||
|
||||
// TODO(dawn:563): Indicate which slots were not set.
|
||||
DAWN_INVALID_IF(aspects[VALIDATION_ASPECT_VERTEX_BUFFERS],
|
||||
"Vertex buffer slots required by %s were not set.", mLastRenderPipeline);
|
||||
"Vertex buffer slots required by %s were not set.", GetRenderPipeline());
|
||||
|
||||
if (DAWN_UNLIKELY(aspects[VALIDATION_ASPECT_BIND_GROUPS])) {
|
||||
for (BindGroupIndex i : IterateBitSet(mLastPipelineLayout->GetBindGroupLayoutsMask())) {
|
||||
|
@ -290,12 +298,15 @@ namespace dawn_native {
|
|||
}
|
||||
|
||||
void CommandBufferStateTracker::SetRenderPipeline(RenderPipelineBase* pipeline) {
|
||||
mLastRenderPipeline = pipeline;
|
||||
SetPipelineCommon(pipeline);
|
||||
}
|
||||
|
||||
void CommandBufferStateTracker::SetBindGroup(BindGroupIndex index, BindGroupBase* bindgroup) {
|
||||
void CommandBufferStateTracker::SetBindGroup(BindGroupIndex index,
|
||||
BindGroupBase* bindgroup,
|
||||
uint32_t dynamicOffsetCount,
|
||||
const uint32_t* dynamicOffsets) {
|
||||
mBindgroups[index] = bindgroup;
|
||||
mDynamicOffsets[index].assign(dynamicOffsets, dynamicOffsets + dynamicOffsetCount);
|
||||
mAspects.reset(VALIDATION_ASPECT_BIND_GROUPS);
|
||||
}
|
||||
|
||||
|
@ -311,8 +322,9 @@ namespace dawn_native {
|
|||
}
|
||||
|
||||
void CommandBufferStateTracker::SetPipelineCommon(PipelineBase* pipeline) {
|
||||
mLastPipelineLayout = pipeline->GetLayout();
|
||||
mMinBufferSizes = &pipeline->GetMinBufferSizes();
|
||||
mLastPipeline = pipeline;
|
||||
mLastPipelineLayout = pipeline != nullptr ? pipeline->GetLayout() : nullptr;
|
||||
mMinBufferSizes = pipeline != nullptr ? &pipeline->GetMinBufferSizes() : nullptr;
|
||||
|
||||
mAspects.set(VALIDATION_ASPECT_PIPELINE);
|
||||
|
||||
|
@ -324,6 +336,25 @@ namespace dawn_native {
|
|||
return mBindgroups[index];
|
||||
}
|
||||
|
||||
const std::vector<uint32_t>& CommandBufferStateTracker::GetDynamicOffsets(
|
||||
BindGroupIndex index) const {
|
||||
return mDynamicOffsets[index];
|
||||
}
|
||||
|
||||
bool CommandBufferStateTracker::HasPipeline() const {
|
||||
return mLastPipeline != nullptr;
|
||||
}
|
||||
|
||||
RenderPipelineBase* CommandBufferStateTracker::GetRenderPipeline() const {
|
||||
ASSERT(HasPipeline() && mLastPipeline->GetType() == ObjectType::RenderPipeline);
|
||||
return static_cast<RenderPipelineBase*>(mLastPipeline);
|
||||
}
|
||||
|
||||
ComputePipelineBase* CommandBufferStateTracker::GetComputePipeline() const {
|
||||
ASSERT(HasPipeline() && mLastPipeline->GetType() == ObjectType::ComputePipeline);
|
||||
return static_cast<ComputePipelineBase*>(mLastPipeline);
|
||||
}
|
||||
|
||||
PipelineLayoutBase* CommandBufferStateTracker::GetPipelineLayout() const {
|
||||
return mLastPipelineLayout;
|
||||
}
|
||||
|
|
|
@ -38,7 +38,10 @@ namespace dawn_native {
|
|||
// State-modifying methods
|
||||
void SetComputePipeline(ComputePipelineBase* pipeline);
|
||||
void SetRenderPipeline(RenderPipelineBase* pipeline);
|
||||
void SetBindGroup(BindGroupIndex index, BindGroupBase* bindgroup);
|
||||
void SetBindGroup(BindGroupIndex index,
|
||||
BindGroupBase* bindgroup,
|
||||
uint32_t dynamicOffsetCount,
|
||||
const uint32_t* dynamicOffsets);
|
||||
void SetIndexBuffer(wgpu::IndexFormat format, uint64_t size);
|
||||
void SetVertexBuffer(VertexBufferSlot slot, uint64_t size);
|
||||
|
||||
|
@ -46,6 +49,10 @@ namespace dawn_native {
|
|||
using ValidationAspects = std::bitset<kNumAspects>;
|
||||
|
||||
BindGroupBase* GetBindGroup(BindGroupIndex index) const;
|
||||
const std::vector<uint32_t>& GetDynamicOffsets(BindGroupIndex index) const;
|
||||
bool HasPipeline() const;
|
||||
RenderPipelineBase* GetRenderPipeline() const;
|
||||
ComputePipelineBase* GetComputePipeline() const;
|
||||
PipelineLayoutBase* GetPipelineLayout() const;
|
||||
wgpu::IndexFormat GetIndexFormat() const;
|
||||
uint64_t GetIndexBufferSize() const;
|
||||
|
@ -60,6 +67,7 @@ namespace dawn_native {
|
|||
ValidationAspects mAspects;
|
||||
|
||||
ityp::array<BindGroupIndex, BindGroupBase*, kMaxBindGroups> mBindgroups = {};
|
||||
ityp::array<BindGroupIndex, std::vector<uint32_t>, kMaxBindGroups> mDynamicOffsets = {};
|
||||
ityp::bitset<VertexBufferSlot, kMaxVertexBuffers> mVertexBufferSlotsUsed;
|
||||
bool mIndexBufferSet = false;
|
||||
wgpu::IndexFormat mIndexFormat;
|
||||
|
@ -68,7 +76,7 @@ namespace dawn_native {
|
|||
ityp::array<VertexBufferSlot, uint64_t, kMaxVertexBuffers> mVertexBufferSizes = {};
|
||||
|
||||
PipelineLayoutBase* mLastPipelineLayout = nullptr;
|
||||
RenderPipelineBase* mLastRenderPipeline = nullptr;
|
||||
PipelineBase* mLastPipeline = nullptr;
|
||||
|
||||
const RequiredBufferSizes* mMinBufferSizes = nullptr;
|
||||
};
|
||||
|
|
|
@ -14,18 +14,107 @@
|
|||
|
||||
#include "dawn_native/ComputePassEncoder.h"
|
||||
|
||||
#include "dawn_native/BindGroup.h"
|
||||
#include "dawn_native/BindGroupLayout.h"
|
||||
#include "dawn_native/Buffer.h"
|
||||
#include "dawn_native/CommandEncoder.h"
|
||||
#include "dawn_native/CommandValidation.h"
|
||||
#include "dawn_native/Commands.h"
|
||||
#include "dawn_native/ComputePipeline.h"
|
||||
#include "dawn_native/Device.h"
|
||||
#include "dawn_native/InternalPipelineStore.h"
|
||||
#include "dawn_native/ObjectType_autogen.h"
|
||||
#include "dawn_native/PassResourceUsageTracker.h"
|
||||
#include "dawn_native/QuerySet.h"
|
||||
|
||||
namespace dawn_native {
|
||||
|
||||
namespace {
|
||||
|
||||
ResultOrError<ComputePipelineBase*> GetOrCreateIndirectDispatchValidationPipeline(
|
||||
DeviceBase* device) {
|
||||
InternalPipelineStore* store = device->GetInternalPipelineStore();
|
||||
|
||||
if (store->dispatchIndirectValidationPipeline != nullptr) {
|
||||
return store->dispatchIndirectValidationPipeline.Get();
|
||||
}
|
||||
|
||||
ShaderModuleDescriptor descriptor;
|
||||
ShaderModuleWGSLDescriptor wgslDesc;
|
||||
descriptor.nextInChain = reinterpret_cast<ChainedStruct*>(&wgslDesc);
|
||||
|
||||
// TODO(https://crbug.com/dawn/1108): Propagate validation feedback from this
|
||||
// shader in various failure modes.
|
||||
wgslDesc.source = R"(
|
||||
[[block]] struct UniformParams {
|
||||
maxComputeWorkgroupsPerDimension: u32;
|
||||
clientOffsetInU32: u32;
|
||||
};
|
||||
|
||||
[[block]] struct IndirectParams {
|
||||
data: array<u32>;
|
||||
};
|
||||
|
||||
[[block]] struct ValidatedParams {
|
||||
data: array<u32, 3>;
|
||||
};
|
||||
|
||||
[[group(0), binding(0)]] var<uniform> uniformParams: UniformParams;
|
||||
[[group(0), binding(1)]] var<storage, read_write> clientParams: IndirectParams;
|
||||
[[group(0), binding(2)]] var<storage, write> validatedParams: ValidatedParams;
|
||||
|
||||
[[stage(compute), workgroup_size(1, 1, 1)]]
|
||||
fn main() {
|
||||
for (var i = 0u; i < 3u; i = i + 1u) {
|
||||
var numWorkgroups = clientParams.data[uniformParams.clientOffsetInU32 + i];
|
||||
if (numWorkgroups > uniformParams.maxComputeWorkgroupsPerDimension) {
|
||||
numWorkgroups = 0u;
|
||||
}
|
||||
validatedParams.data[i] = numWorkgroups;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
Ref<ShaderModuleBase> shaderModule;
|
||||
DAWN_TRY_ASSIGN(shaderModule, device->CreateShaderModule(&descriptor));
|
||||
|
||||
std::array<BindGroupLayoutEntry, 3> entries;
|
||||
entries[0].binding = 0;
|
||||
entries[0].visibility = wgpu::ShaderStage::Compute;
|
||||
entries[0].buffer.type = wgpu::BufferBindingType::Uniform;
|
||||
entries[1].binding = 1;
|
||||
entries[1].visibility = wgpu::ShaderStage::Compute;
|
||||
entries[1].buffer.type = kInternalStorageBufferBinding;
|
||||
entries[2].binding = 2;
|
||||
entries[2].visibility = wgpu::ShaderStage::Compute;
|
||||
entries[2].buffer.type = wgpu::BufferBindingType::Storage;
|
||||
|
||||
BindGroupLayoutDescriptor bindGroupLayoutDescriptor;
|
||||
bindGroupLayoutDescriptor.entryCount = entries.size();
|
||||
bindGroupLayoutDescriptor.entries = entries.data();
|
||||
Ref<BindGroupLayoutBase> bindGroupLayout;
|
||||
DAWN_TRY_ASSIGN(bindGroupLayout,
|
||||
device->CreateBindGroupLayout(&bindGroupLayoutDescriptor, true));
|
||||
|
||||
PipelineLayoutDescriptor pipelineDescriptor;
|
||||
pipelineDescriptor.bindGroupLayoutCount = 1;
|
||||
pipelineDescriptor.bindGroupLayouts = &bindGroupLayout.Get();
|
||||
Ref<PipelineLayoutBase> pipelineLayout;
|
||||
DAWN_TRY_ASSIGN(pipelineLayout, device->CreatePipelineLayout(&pipelineDescriptor));
|
||||
|
||||
ComputePipelineDescriptor computePipelineDescriptor = {};
|
||||
computePipelineDescriptor.layout = pipelineLayout.Get();
|
||||
computePipelineDescriptor.compute.module = shaderModule.Get();
|
||||
computePipelineDescriptor.compute.entryPoint = "main";
|
||||
|
||||
DAWN_TRY_ASSIGN(store->dispatchIndirectValidationPipeline,
|
||||
device->CreateComputePipeline(&computePipelineDescriptor));
|
||||
|
||||
return store->dispatchIndirectValidationPipeline.Get();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
ComputePassEncoder::ComputePassEncoder(DeviceBase* device,
|
||||
CommandEncoder* commandEncoder,
|
||||
EncodingContext* encodingContext)
|
||||
|
@ -107,6 +196,95 @@ namespace dawn_native {
|
|||
"encoding Dispatch (x: %u, y: %u, z: %u)", x, y, z);
|
||||
}
|
||||
|
||||
ResultOrError<std::pair<Ref<BufferBase>, uint64_t>>
|
||||
ComputePassEncoder::ValidateIndirectDispatch(BufferBase* indirectBuffer,
|
||||
uint64_t indirectOffset) {
|
||||
DeviceBase* device = GetDevice();
|
||||
auto* const store = device->GetInternalPipelineStore();
|
||||
|
||||
Ref<ComputePipelineBase> validationPipeline;
|
||||
DAWN_TRY_ASSIGN(validationPipeline, GetOrCreateIndirectDispatchValidationPipeline(device));
|
||||
|
||||
Ref<BindGroupLayoutBase> layout;
|
||||
DAWN_TRY_ASSIGN(layout, validationPipeline->GetBindGroupLayout(0));
|
||||
|
||||
uint32_t storageBufferOffsetAlignment =
|
||||
device->GetLimits().v1.minStorageBufferOffsetAlignment;
|
||||
|
||||
std::array<BindGroupEntry, 3> bindings;
|
||||
|
||||
// Storage binding holding the client's indirect buffer.
|
||||
BindGroupEntry& clientIndirectBinding = bindings[0];
|
||||
clientIndirectBinding.binding = 1;
|
||||
clientIndirectBinding.buffer = indirectBuffer;
|
||||
|
||||
// Let the offset be the indirectOffset, aligned down to |storageBufferOffsetAlignment|.
|
||||
const uint32_t clientOffsetFromAlignedBoundary =
|
||||
indirectOffset % storageBufferOffsetAlignment;
|
||||
const uint64_t clientOffsetAlignedDown = indirectOffset - clientOffsetFromAlignedBoundary;
|
||||
clientIndirectBinding.offset = clientOffsetAlignedDown;
|
||||
|
||||
// Let the size of the binding be the additional offset, plus the size.
|
||||
clientIndirectBinding.size = kDispatchIndirectSize + clientOffsetFromAlignedBoundary;
|
||||
|
||||
struct UniformParams {
|
||||
uint32_t maxComputeWorkgroupsPerDimension;
|
||||
uint32_t clientOffsetInU32;
|
||||
};
|
||||
|
||||
// Create a uniform buffer to hold parameters for the shader.
|
||||
Ref<BufferBase> uniformBuffer;
|
||||
{
|
||||
BufferDescriptor uniformDesc = {};
|
||||
uniformDesc.size = sizeof(UniformParams);
|
||||
uniformDesc.usage = wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst;
|
||||
uniformDesc.mappedAtCreation = true;
|
||||
DAWN_TRY_ASSIGN(uniformBuffer, device->CreateBuffer(&uniformDesc));
|
||||
|
||||
UniformParams* params = static_cast<UniformParams*>(
|
||||
uniformBuffer->GetMappedRange(0, sizeof(UniformParams)));
|
||||
params->maxComputeWorkgroupsPerDimension =
|
||||
device->GetLimits().v1.maxComputeWorkgroupsPerDimension;
|
||||
params->clientOffsetInU32 = clientOffsetFromAlignedBoundary / sizeof(uint32_t);
|
||||
uniformBuffer->Unmap();
|
||||
}
|
||||
|
||||
// Uniform buffer binding pointing to the uniform parameters.
|
||||
BindGroupEntry& uniformBinding = bindings[1];
|
||||
uniformBinding.binding = 0;
|
||||
uniformBinding.buffer = uniformBuffer.Get();
|
||||
uniformBinding.offset = 0;
|
||||
uniformBinding.size = sizeof(UniformParams);
|
||||
|
||||
// Reserve space in the scratch buffer to hold the validated indirect params.
|
||||
ScratchBuffer& scratchBuffer = store->scratchIndirectStorage;
|
||||
DAWN_TRY(scratchBuffer.EnsureCapacity(kDispatchIndirectSize));
|
||||
Ref<BufferBase> validatedIndirectBuffer = scratchBuffer.GetBuffer();
|
||||
|
||||
// Binding for the validated indirect params.
|
||||
BindGroupEntry& validatedParamsBinding = bindings[2];
|
||||
validatedParamsBinding.binding = 2;
|
||||
validatedParamsBinding.buffer = validatedIndirectBuffer.Get();
|
||||
validatedParamsBinding.offset = 0;
|
||||
validatedParamsBinding.size = kDispatchIndirectSize;
|
||||
|
||||
BindGroupDescriptor bindGroupDescriptor = {};
|
||||
bindGroupDescriptor.layout = layout.Get();
|
||||
bindGroupDescriptor.entryCount = bindings.size();
|
||||
bindGroupDescriptor.entries = bindings.data();
|
||||
|
||||
Ref<BindGroupBase> validationBindGroup;
|
||||
DAWN_TRY_ASSIGN(validationBindGroup, device->CreateBindGroup(&bindGroupDescriptor));
|
||||
|
||||
// Issue commands to validate the indirect buffer.
|
||||
APISetPipeline(validationPipeline.Get());
|
||||
APISetBindGroup(0, validationBindGroup.Get());
|
||||
APIDispatch(1);
|
||||
|
||||
// Return the new indirect buffer and indirect buffer offset.
|
||||
return std::make_pair(std::move(validatedIndirectBuffer), uint64_t(0));
|
||||
}
|
||||
|
||||
void ComputePassEncoder::APIDispatchIndirect(BufferBase* indirectBuffer,
|
||||
uint64_t indirectOffset) {
|
||||
mEncodingContext->TryEncode(
|
||||
|
@ -136,18 +314,46 @@ namespace dawn_native {
|
|||
indirectOffset, kDispatchIndirectSize, indirectBuffer->GetSize());
|
||||
}
|
||||
|
||||
// Record the synchronization scope for Dispatch, both the bindgroups and the
|
||||
// indirect buffer.
|
||||
SyncScopeUsageTracker scope;
|
||||
scope.BufferUsedAs(indirectBuffer, wgpu::BufferUsage::Indirect);
|
||||
mUsageTracker.AddReferencedBuffer(indirectBuffer);
|
||||
// TODO(crbug.com/dawn/1166): If validation is enabled, adding |indirectBuffer|
|
||||
// is needed for correct usage validation even though it will only be bound for
|
||||
// storage. This will unecessarily transition the |indirectBuffer| in
|
||||
// the backend.
|
||||
|
||||
Ref<BufferBase> indirectBufferRef = indirectBuffer;
|
||||
if (IsValidationEnabled()) {
|
||||
// Save the previous command buffer state so it can be restored after the
|
||||
// validation inserts additional commands.
|
||||
CommandBufferStateTracker previousState = mCommandBufferState;
|
||||
|
||||
// Validate each indirect dispatch with a single dispatch to copy the indirect
|
||||
// buffer params into a scratch buffer if they're valid, and otherwise zero them
|
||||
// out. We could consider moving the validation earlier in the pass after the
|
||||
// last point the indirect buffer was used with writable usage, as well as batch
|
||||
// validation for multiple dispatches into one, but inserting commands at
|
||||
// arbitrary points in the past is not possible right now.
|
||||
DAWN_TRY_ASSIGN(
|
||||
std::tie(indirectBufferRef, indirectOffset),
|
||||
ValidateIndirectDispatch(indirectBufferRef.Get(), indirectOffset));
|
||||
|
||||
// Restore the state.
|
||||
RestoreCommandBufferState(std::move(previousState));
|
||||
|
||||
// |indirectBufferRef| was replaced with a scratch buffer. Add it to the
|
||||
// synchronization scope.
|
||||
ASSERT(indirectBufferRef.Get() != indirectBuffer);
|
||||
scope.BufferUsedAs(indirectBufferRef.Get(), wgpu::BufferUsage::Indirect);
|
||||
mUsageTracker.AddReferencedBuffer(indirectBufferRef.Get());
|
||||
}
|
||||
|
||||
AddDispatchSyncScope(std::move(scope));
|
||||
|
||||
DispatchIndirectCmd* dispatch =
|
||||
allocator->Allocate<DispatchIndirectCmd>(Command::DispatchIndirect);
|
||||
dispatch->indirectBuffer = indirectBuffer;
|
||||
dispatch->indirectBuffer = std::move(indirectBufferRef);
|
||||
dispatch->indirectOffset = indirectOffset;
|
||||
|
||||
return {};
|
||||
},
|
||||
"encoding DispatchIndirect with %s", indirectBuffer);
|
||||
|
@ -187,10 +393,10 @@ namespace dawn_native {
|
|||
}
|
||||
|
||||
mUsageTracker.AddResourcesReferencedByBindGroup(group);
|
||||
|
||||
RecordSetBindGroup(allocator, groupIndex, group, dynamicOffsetCount,
|
||||
dynamicOffsets);
|
||||
mCommandBufferState.SetBindGroup(groupIndex, group);
|
||||
mCommandBufferState.SetBindGroup(groupIndex, group, dynamicOffsetCount,
|
||||
dynamicOffsets);
|
||||
|
||||
return {};
|
||||
},
|
||||
|
@ -226,4 +432,29 @@ namespace dawn_native {
|
|||
mUsageTracker.AddDispatch(scope.AcquireSyncScopeUsage());
|
||||
}
|
||||
|
||||
void ComputePassEncoder::RestoreCommandBufferState(CommandBufferStateTracker state) {
|
||||
// Encode commands for the backend to restore the pipeline and bind groups.
|
||||
if (state.HasPipeline()) {
|
||||
APISetPipeline(state.GetComputePipeline());
|
||||
}
|
||||
for (BindGroupIndex i(0); i < kMaxBindGroupsTyped; ++i) {
|
||||
BindGroupBase* bg = state.GetBindGroup(i);
|
||||
if (bg != nullptr) {
|
||||
const std::vector<uint32_t>& offsets = state.GetDynamicOffsets(i);
|
||||
if (offsets.empty()) {
|
||||
APISetBindGroup(static_cast<uint32_t>(i), bg);
|
||||
} else {
|
||||
APISetBindGroup(static_cast<uint32_t>(i), bg, offsets.size(), offsets.data());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Restore the frontend state tracking information.
|
||||
mCommandBufferState = std::move(state);
|
||||
}
|
||||
|
||||
CommandBufferStateTracker* ComputePassEncoder::GetCommandBufferStateTrackerForTesting() {
|
||||
return &mCommandBufferState;
|
||||
}
|
||||
|
||||
} // namespace dawn_native
|
||||
|
|
|
@ -50,6 +50,11 @@ namespace dawn_native {
|
|||
|
||||
void APIWriteTimestamp(QuerySetBase* querySet, uint32_t queryIndex);
|
||||
|
||||
CommandBufferStateTracker* GetCommandBufferStateTrackerForTesting();
|
||||
void RestoreCommandBufferStateForTesting(CommandBufferStateTracker state) {
|
||||
RestoreCommandBufferState(std::move(state));
|
||||
}
|
||||
|
||||
protected:
|
||||
ComputePassEncoder(DeviceBase* device,
|
||||
CommandEncoder* commandEncoder,
|
||||
|
@ -57,6 +62,12 @@ namespace dawn_native {
|
|||
ErrorTag errorTag);
|
||||
|
||||
private:
|
||||
ResultOrError<std::pair<Ref<BufferBase>, uint64_t>> ValidateIndirectDispatch(
|
||||
BufferBase* indirectBuffer,
|
||||
uint64_t indirectOffset);
|
||||
|
||||
void RestoreCommandBufferState(CommandBufferStateTracker state);
|
||||
|
||||
CommandBufferStateTracker mCommandBufferState;
|
||||
|
||||
// Adds the bindgroups used for the current dispatch to the SyncScopeResourceUsage and
|
||||
|
|
|
@ -52,6 +52,7 @@ namespace dawn_native {
|
|||
|
||||
Ref<ComputePipelineBase> renderValidationPipeline;
|
||||
Ref<ShaderModuleBase> renderValidationShader;
|
||||
Ref<ComputePipelineBase> dispatchIndirectValidationPipeline;
|
||||
};
|
||||
|
||||
} // namespace dawn_native
|
||||
|
|
|
@ -208,6 +208,9 @@ namespace dawn_native {
|
|||
BufferLocation::New(indirectBuffer, indirectOffset);
|
||||
}
|
||||
|
||||
// TODO(crbug.com/dawn/1166): Adding the indirectBuffer is needed for correct usage
|
||||
// validation, but it will unecessarily transition to indirectBuffer usage in the
|
||||
// backend.
|
||||
mUsageTracker.BufferUsedAs(indirectBuffer, wgpu::BufferUsage::Indirect);
|
||||
|
||||
return {};
|
||||
|
@ -404,7 +407,8 @@ namespace dawn_native {
|
|||
|
||||
RecordSetBindGroup(allocator, groupIndex, group, dynamicOffsetCount,
|
||||
dynamicOffsets);
|
||||
mCommandBufferState.SetBindGroup(groupIndex, group);
|
||||
mCommandBufferState.SetBindGroup(groupIndex, group, dynamicOffsetCount,
|
||||
dynamicOffsets);
|
||||
mUsageTracker.AddBindGroup(group);
|
||||
|
||||
return {};
|
||||
|
|
|
@ -221,6 +221,7 @@ test("dawn_unittests") {
|
|||
"unittests/SystemUtilsTests.cpp",
|
||||
"unittests/ToBackendTests.cpp",
|
||||
"unittests/TypedIntegerTests.cpp",
|
||||
"unittests/native/CommandBufferEncodingTests.cpp",
|
||||
"unittests/native/DestroyObjectTests.cpp",
|
||||
"unittests/validation/BindGroupValidationTests.cpp",
|
||||
"unittests/validation/BufferValidationTests.cpp",
|
||||
|
|
|
@ -12,9 +12,11 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tests/DawnNativeTest.h"
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "common/Assert.h"
|
||||
#include "dawn/dawn_proc.h"
|
||||
#include "dawn_native/ErrorData.h"
|
||||
|
||||
namespace dawn_native {
|
||||
|
@ -28,3 +30,54 @@ namespace dawn_native {
|
|||
}
|
||||
|
||||
} // namespace dawn_native
|
||||
|
||||
DawnNativeTest::DawnNativeTest() {
|
||||
dawnProcSetProcs(&dawn_native::GetProcs());
|
||||
}
|
||||
|
||||
DawnNativeTest::~DawnNativeTest() {
|
||||
device = wgpu::Device();
|
||||
dawnProcSetProcs(nullptr);
|
||||
}
|
||||
|
||||
void DawnNativeTest::SetUp() {
|
||||
instance = std::make_unique<dawn_native::Instance>();
|
||||
instance->DiscoverDefaultAdapters();
|
||||
|
||||
std::vector<dawn_native::Adapter> adapters = instance->GetAdapters();
|
||||
|
||||
// DawnNative unittests run against the null backend, find the corresponding adapter
|
||||
bool foundNullAdapter = false;
|
||||
for (auto& currentAdapter : adapters) {
|
||||
wgpu::AdapterProperties adapterProperties;
|
||||
currentAdapter.GetProperties(&adapterProperties);
|
||||
|
||||
if (adapterProperties.backendType == wgpu::BackendType::Null) {
|
||||
adapter = currentAdapter;
|
||||
foundNullAdapter = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
ASSERT(foundNullAdapter);
|
||||
|
||||
device = wgpu::Device(CreateTestDevice());
|
||||
device.SetUncapturedErrorCallback(DawnNativeTest::OnDeviceError, nullptr);
|
||||
}
|
||||
|
||||
void DawnNativeTest::TearDown() {
|
||||
}
|
||||
|
||||
WGPUDevice DawnNativeTest::CreateTestDevice() {
|
||||
// Disabled disallowing unsafe APIs so we can test them.
|
||||
dawn_native::DeviceDescriptor deviceDescriptor;
|
||||
deviceDescriptor.forceDisabledToggles.push_back("disallow_unsafe_apis");
|
||||
|
||||
return adapter.CreateDevice(&deviceDescriptor);
|
||||
}
|
||||
|
||||
// static
|
||||
void DawnNativeTest::OnDeviceError(WGPUErrorType type, const char* message, void* userdata) {
|
||||
ASSERT(type != WGPUErrorType_NoError);
|
||||
FAIL() << "Unexpected error: " << message;
|
||||
}
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "dawn/webgpu_cpp.h"
|
||||
#include "dawn_native/DawnNative.h"
|
||||
#include "dawn_native/ErrorData.h"
|
||||
|
||||
namespace dawn_native {
|
||||
|
@ -29,4 +31,23 @@ namespace dawn_native {
|
|||
|
||||
} // namespace dawn_native
|
||||
|
||||
class DawnNativeTest : public ::testing::Test {
|
||||
public:
|
||||
DawnNativeTest();
|
||||
~DawnNativeTest() override;
|
||||
|
||||
void SetUp() override;
|
||||
void TearDown() override;
|
||||
|
||||
virtual WGPUDevice CreateTestDevice();
|
||||
|
||||
protected:
|
||||
std::unique_ptr<dawn_native::Instance> instance;
|
||||
dawn_native::Adapter adapter;
|
||||
wgpu::Device device;
|
||||
|
||||
private:
|
||||
static void OnDeviceError(WGPUErrorType type, const char* message, void* userdata);
|
||||
};
|
||||
|
||||
#endif // TESTS_DAWNNATIVETEST_H_
|
||||
|
|
|
@ -158,8 +158,14 @@ class ComputeDispatchTests : public DawnTest {
|
|||
queue.Submit(1, &commands);
|
||||
|
||||
std::vector<uint32_t> expected;
|
||||
|
||||
uint32_t maxComputeWorkgroupsPerDimension =
|
||||
GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
|
||||
if (indirectBufferData[indirectStart] == 0 || indirectBufferData[indirectStart + 1] == 0 ||
|
||||
indirectBufferData[indirectStart + 2] == 0) {
|
||||
indirectBufferData[indirectStart + 2] == 0 ||
|
||||
indirectBufferData[indirectStart] > maxComputeWorkgroupsPerDimension ||
|
||||
indirectBufferData[indirectStart + 1] > maxComputeWorkgroupsPerDimension ||
|
||||
indirectBufferData[indirectStart + 2] > maxComputeWorkgroupsPerDimension) {
|
||||
expected = kSentinelData;
|
||||
} else {
|
||||
expected.assign(indirectBufferData.begin() + indirectStart,
|
||||
|
@ -221,6 +227,52 @@ TEST_P(ComputeDispatchTests, IndirectOffset) {
|
|||
IndirectTest({0, 0, 0, 2, 3, 4}, 3 * sizeof(uint32_t));
|
||||
}
|
||||
|
||||
// Test indirect dispatches at max limit.
|
||||
TEST_P(ComputeDispatchTests, MaxWorkgroups) {
|
||||
// TODO(crbug.com/dawn/1165): Fails with WARP
|
||||
DAWN_SUPPRESS_TEST_IF(IsWARP());
|
||||
uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
|
||||
|
||||
// Test that the maximum works in each dimension.
|
||||
// Note: Testing (max, max, max) is very slow.
|
||||
IndirectTest({max, 3, 4}, 0);
|
||||
IndirectTest({2, max, 4}, 0);
|
||||
IndirectTest({2, 3, max}, 0);
|
||||
}
|
||||
|
||||
// Test indirect dispatches exceeding the max limit are noop-ed.
|
||||
TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsNoop) {
|
||||
DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation"));
|
||||
|
||||
uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
|
||||
|
||||
// All dimensions are above the max
|
||||
IndirectTest({max + 1, max + 1, max + 1}, 0);
|
||||
|
||||
// Only x dimension is above the max
|
||||
IndirectTest({max + 1, 3, 4}, 0);
|
||||
IndirectTest({2 * max, 3, 4}, 0);
|
||||
|
||||
// Only y dimension is above the max
|
||||
IndirectTest({2, max + 1, 4}, 0);
|
||||
IndirectTest({2, 2 * max, 4}, 0);
|
||||
|
||||
// Only z dimension is above the max
|
||||
IndirectTest({2, 3, max + 1}, 0);
|
||||
IndirectTest({2, 3, 2 * max}, 0);
|
||||
}
|
||||
|
||||
// Test indirect dispatches exceeding the max limit with an offset are noop-ed.
|
||||
TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsWithOffsetNoop) {
|
||||
DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation"));
|
||||
|
||||
uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension;
|
||||
|
||||
IndirectTest({1, 2, 3, max + 1, 4, 5}, 1 * sizeof(uint32_t));
|
||||
IndirectTest({1, 2, 3, max + 1, 4, 5}, 2 * sizeof(uint32_t));
|
||||
IndirectTest({1, 2, 3, max + 1, 4, 5}, 3 * sizeof(uint32_t));
|
||||
}
|
||||
|
||||
DAWN_INSTANTIATE_TEST(ComputeDispatchTests,
|
||||
D3D12Backend(),
|
||||
MetalBackend(),
|
||||
|
|
|
@ -0,0 +1,310 @@
|
|||
// Copyright 2021 The Dawn Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "tests/DawnNativeTest.h"
|
||||
|
||||
#include "dawn_native/CommandBuffer.h"
|
||||
#include "dawn_native/Commands.h"
|
||||
#include "dawn_native/ComputePassEncoder.h"
|
||||
#include "utils/WGPUHelpers.h"
|
||||
|
||||
class CommandBufferEncodingTests : public DawnNativeTest {
|
||||
protected:
|
||||
void ExpectCommands(dawn_native::CommandIterator* commands,
|
||||
std::vector<std::pair<dawn_native::Command,
|
||||
std::function<void(dawn_native::CommandIterator*)>>>
|
||||
expectedCommands) {
|
||||
dawn_native::Command commandId;
|
||||
for (uint32_t commandIndex = 0; commands->NextCommandId(&commandId); ++commandIndex) {
|
||||
ASSERT_LT(commandIndex, expectedCommands.size()) << "Unexpected command";
|
||||
ASSERT_EQ(commandId, expectedCommands[commandIndex].first)
|
||||
<< "at command " << commandIndex;
|
||||
expectedCommands[commandIndex].second(commands);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Indirect dispatch validation changes the bind groups in the middle
|
||||
// of a pass. Test that bindings are restored after the validation runs.
|
||||
TEST_F(CommandBufferEncodingTests, ComputePassEncoderIndirectDispatchStateRestoration) {
|
||||
using namespace dawn_native;
|
||||
|
||||
wgpu::BindGroupLayout staticLayout =
|
||||
utils::MakeBindGroupLayout(device, {{
|
||||
0,
|
||||
wgpu::ShaderStage::Compute,
|
||||
wgpu::BufferBindingType::Uniform,
|
||||
}});
|
||||
|
||||
wgpu::BindGroupLayout dynamicLayout =
|
||||
utils::MakeBindGroupLayout(device, {{
|
||||
0,
|
||||
wgpu::ShaderStage::Compute,
|
||||
wgpu::BufferBindingType::Uniform,
|
||||
true,
|
||||
}});
|
||||
|
||||
// Create a simple pipeline
|
||||
wgpu::ComputePipelineDescriptor csDesc;
|
||||
csDesc.compute.module = utils::CreateShaderModule(device, R"(
|
||||
[[stage(compute), workgroup_size(1, 1, 1)]]
|
||||
fn main() {
|
||||
})");
|
||||
csDesc.compute.entryPoint = "main";
|
||||
|
||||
wgpu::PipelineLayout pl0 = utils::MakePipelineLayout(device, {staticLayout, dynamicLayout});
|
||||
csDesc.layout = pl0;
|
||||
wgpu::ComputePipeline pipeline0 = device.CreateComputePipeline(&csDesc);
|
||||
|
||||
wgpu::PipelineLayout pl1 = utils::MakePipelineLayout(device, {dynamicLayout, staticLayout});
|
||||
csDesc.layout = pl1;
|
||||
wgpu::ComputePipeline pipeline1 = device.CreateComputePipeline(&csDesc);
|
||||
|
||||
// Create buffers to use for both the indirect buffer and the bind groups.
|
||||
wgpu::Buffer indirectBuffer =
|
||||
utils::CreateBufferFromData<uint32_t>(device, wgpu::BufferUsage::Indirect, {1, 2, 3, 4});
|
||||
|
||||
wgpu::BufferDescriptor uniformBufferDesc = {};
|
||||
uniformBufferDesc.size = 512;
|
||||
uniformBufferDesc.usage = wgpu::BufferUsage::Uniform;
|
||||
wgpu::Buffer uniformBuffer = device.CreateBuffer(&uniformBufferDesc);
|
||||
|
||||
wgpu::BindGroup staticBG = utils::MakeBindGroup(device, staticLayout, {{0, uniformBuffer}});
|
||||
|
||||
wgpu::BindGroup dynamicBG =
|
||||
utils::MakeBindGroup(device, dynamicLayout, {{0, uniformBuffer, 0, 256}});
|
||||
|
||||
uint32_t dynamicOffset = 256;
|
||||
std::vector<uint32_t> emptyDynamicOffsets = {};
|
||||
std::vector<uint32_t> singleDynamicOffset = {dynamicOffset};
|
||||
|
||||
// Begin encoding commands.
|
||||
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
||||
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
||||
|
||||
CommandBufferStateTracker* stateTracker =
|
||||
FromAPI(pass.Get())->GetCommandBufferStateTrackerForTesting();
|
||||
|
||||
// Perform a dispatch indirect which will be preceded by a validation dispatch.
|
||||
pass.SetPipeline(pipeline0);
|
||||
pass.SetBindGroup(0, staticBG);
|
||||
pass.SetBindGroup(1, dynamicBG, 1, &dynamicOffset);
|
||||
EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get());
|
||||
|
||||
pass.DispatchIndirect(indirectBuffer, 0);
|
||||
|
||||
// Expect restored state.
|
||||
EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get());
|
||||
EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl0.Get());
|
||||
EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), staticBG.Get());
|
||||
EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), emptyDynamicOffsets);
|
||||
EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), dynamicBG.Get());
|
||||
EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), singleDynamicOffset);
|
||||
|
||||
// Dispatch again to check that the restored state can be used.
|
||||
// Also pass an indirect offset which should get replaced with the offset
|
||||
// into the scratch indirect buffer (0).
|
||||
pass.DispatchIndirect(indirectBuffer, 4);
|
||||
|
||||
// Expect restored state.
|
||||
EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get());
|
||||
EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl0.Get());
|
||||
EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), staticBG.Get());
|
||||
EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), emptyDynamicOffsets);
|
||||
EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), dynamicBG.Get());
|
||||
EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), singleDynamicOffset);
|
||||
|
||||
// Change the pipeline
|
||||
pass.SetPipeline(pipeline1);
|
||||
pass.SetBindGroup(0, dynamicBG, 1, &dynamicOffset);
|
||||
pass.SetBindGroup(1, staticBG);
|
||||
EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline1.Get());
|
||||
EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl1.Get());
|
||||
|
||||
pass.DispatchIndirect(indirectBuffer, 0);
|
||||
|
||||
// Expect restored state.
|
||||
EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline1.Get());
|
||||
EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl1.Get());
|
||||
EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), dynamicBG.Get());
|
||||
EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), singleDynamicOffset);
|
||||
EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), staticBG.Get());
|
||||
EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), emptyDynamicOffsets);
|
||||
|
||||
pass.EndPass();
|
||||
|
||||
wgpu::CommandBuffer commandBuffer = encoder.Finish();
|
||||
|
||||
auto ExpectSetPipeline = [](wgpu::ComputePipeline pipeline) {
|
||||
return [pipeline](CommandIterator* commands) {
|
||||
auto* cmd = commands->NextCommand<SetComputePipelineCmd>();
|
||||
EXPECT_EQ(ToAPI(cmd->pipeline.Get()), pipeline.Get());
|
||||
};
|
||||
};
|
||||
|
||||
auto ExpectSetBindGroup = [](uint32_t index, wgpu::BindGroup bg,
|
||||
std::vector<uint32_t> offsets = {}) {
|
||||
return [index, bg, offsets](CommandIterator* commands) {
|
||||
auto* cmd = commands->NextCommand<SetBindGroupCmd>();
|
||||
uint32_t* dynamicOffsets = nullptr;
|
||||
if (cmd->dynamicOffsetCount > 0) {
|
||||
dynamicOffsets = commands->NextData<uint32_t>(cmd->dynamicOffsetCount);
|
||||
}
|
||||
|
||||
ASSERT_EQ(cmd->index, BindGroupIndex(index));
|
||||
ASSERT_EQ(ToAPI(cmd->group.Get()), bg.Get());
|
||||
ASSERT_EQ(cmd->dynamicOffsetCount, offsets.size());
|
||||
for (uint32_t i = 0; i < cmd->dynamicOffsetCount; ++i) {
|
||||
ASSERT_EQ(dynamicOffsets[i], offsets[i]);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
// Initialize as null. Once we know the pointer, we'll check
|
||||
// that it's the same buffer every time.
|
||||
WGPUBuffer indirectScratchBuffer = nullptr;
|
||||
auto ExpectDispatchIndirect = [&](CommandIterator* commands) {
|
||||
auto* cmd = commands->NextCommand<DispatchIndirectCmd>();
|
||||
if (indirectScratchBuffer == nullptr) {
|
||||
indirectScratchBuffer = ToAPI(cmd->indirectBuffer.Get());
|
||||
}
|
||||
ASSERT_EQ(ToAPI(cmd->indirectBuffer.Get()), indirectScratchBuffer);
|
||||
ASSERT_EQ(cmd->indirectOffset, uint64_t(0));
|
||||
};
|
||||
|
||||
// Initialize as null. Once we know the pointer, we'll check
|
||||
// that it's the same pipeline every time.
|
||||
WGPUComputePipeline validationPipeline = nullptr;
|
||||
auto ExpectSetValidationPipeline = [&](CommandIterator* commands) {
|
||||
auto* cmd = commands->NextCommand<SetComputePipelineCmd>();
|
||||
WGPUComputePipeline pipeline = ToAPI(cmd->pipeline.Get());
|
||||
if (validationPipeline != nullptr) {
|
||||
EXPECT_EQ(pipeline, validationPipeline);
|
||||
} else {
|
||||
EXPECT_NE(pipeline, nullptr);
|
||||
validationPipeline = pipeline;
|
||||
}
|
||||
};
|
||||
|
||||
auto ExpectSetValidationBindGroup = [&](CommandIterator* commands) {
|
||||
auto* cmd = commands->NextCommand<SetBindGroupCmd>();
|
||||
ASSERT_EQ(cmd->index, BindGroupIndex(0));
|
||||
ASSERT_NE(cmd->group.Get(), nullptr);
|
||||
ASSERT_EQ(cmd->dynamicOffsetCount, 0u);
|
||||
};
|
||||
|
||||
auto ExpectSetValidationDispatch = [&](CommandIterator* commands) {
|
||||
auto* cmd = commands->NextCommand<DispatchCmd>();
|
||||
ASSERT_EQ(cmd->x, 1u);
|
||||
ASSERT_EQ(cmd->y, 1u);
|
||||
ASSERT_EQ(cmd->z, 1u);
|
||||
};
|
||||
|
||||
ExpectCommands(
|
||||
FromAPI(commandBuffer.Get())->GetCommandIteratorForTesting(),
|
||||
{
|
||||
{Command::BeginComputePass,
|
||||
[&](CommandIterator* commands) { SkipCommand(commands, Command::BeginComputePass); }},
|
||||
// Expect the state to be set.
|
||||
{Command::SetComputePipeline, ExpectSetPipeline(pipeline0)},
|
||||
{Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)},
|
||||
{Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})},
|
||||
|
||||
// Expect the validation.
|
||||
{Command::SetComputePipeline, ExpectSetValidationPipeline},
|
||||
{Command::SetBindGroup, ExpectSetValidationBindGroup},
|
||||
{Command::Dispatch, ExpectSetValidationDispatch},
|
||||
|
||||
// Expect the state to be restored.
|
||||
{Command::SetComputePipeline, ExpectSetPipeline(pipeline0)},
|
||||
{Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)},
|
||||
{Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})},
|
||||
|
||||
// Expect the dispatchIndirect.
|
||||
{Command::DispatchIndirect, ExpectDispatchIndirect},
|
||||
|
||||
// Expect the validation.
|
||||
{Command::SetComputePipeline, ExpectSetValidationPipeline},
|
||||
{Command::SetBindGroup, ExpectSetValidationBindGroup},
|
||||
{Command::Dispatch, ExpectSetValidationDispatch},
|
||||
|
||||
// Expect the state to be restored.
|
||||
{Command::SetComputePipeline, ExpectSetPipeline(pipeline0)},
|
||||
{Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)},
|
||||
{Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})},
|
||||
|
||||
// Expect the dispatchIndirect.
|
||||
{Command::DispatchIndirect, ExpectDispatchIndirect},
|
||||
|
||||
// Expect the state to be set (new pipeline).
|
||||
{Command::SetComputePipeline, ExpectSetPipeline(pipeline1)},
|
||||
{Command::SetBindGroup, ExpectSetBindGroup(0, dynamicBG, {dynamicOffset})},
|
||||
{Command::SetBindGroup, ExpectSetBindGroup(1, staticBG)},
|
||||
|
||||
// Expect the validation.
|
||||
{Command::SetComputePipeline, ExpectSetValidationPipeline},
|
||||
{Command::SetBindGroup, ExpectSetValidationBindGroup},
|
||||
{Command::Dispatch, ExpectSetValidationDispatch},
|
||||
|
||||
// Expect the state to be restored.
|
||||
{Command::SetComputePipeline, ExpectSetPipeline(pipeline1)},
|
||||
{Command::SetBindGroup, ExpectSetBindGroup(0, dynamicBG, {dynamicOffset})},
|
||||
{Command::SetBindGroup, ExpectSetBindGroup(1, staticBG)},
|
||||
|
||||
// Expect the dispatchIndirect.
|
||||
{Command::DispatchIndirect, ExpectDispatchIndirect},
|
||||
|
||||
{Command::EndComputePass,
|
||||
[&](CommandIterator* commands) { commands->NextCommand<EndComputePassCmd>(); }},
|
||||
});
|
||||
}
|
||||
|
||||
// Test that after restoring state, it is fully applied to the state tracker
|
||||
// and does not leak state changes that occured between a snapshot and the
|
||||
// state restoration.
|
||||
TEST_F(CommandBufferEncodingTests, StateNotLeakedAfterRestore) {
|
||||
using namespace dawn_native;
|
||||
|
||||
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
|
||||
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
|
||||
|
||||
CommandBufferStateTracker* stateTracker =
|
||||
FromAPI(pass.Get())->GetCommandBufferStateTrackerForTesting();
|
||||
|
||||
// Snapshot the state.
|
||||
CommandBufferStateTracker snapshot = *stateTracker;
|
||||
// Expect no pipeline in the snapshot
|
||||
EXPECT_FALSE(snapshot.HasPipeline());
|
||||
|
||||
// Create a simple pipeline
|
||||
wgpu::ComputePipelineDescriptor csDesc;
|
||||
csDesc.compute.module = utils::CreateShaderModule(device, R"(
|
||||
[[stage(compute), workgroup_size(1, 1, 1)]]
|
||||
fn main() {
|
||||
})");
|
||||
csDesc.compute.entryPoint = "main";
|
||||
wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
|
||||
|
||||
// Set the pipeline.
|
||||
pass.SetPipeline(pipeline);
|
||||
|
||||
// Expect the pipeline to be set.
|
||||
EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline.Get());
|
||||
|
||||
// Restore the state.
|
||||
FromAPI(pass.Get())->RestoreCommandBufferStateForTesting(std::move(snapshot));
|
||||
|
||||
// Expect no pipeline
|
||||
EXPECT_FALSE(stateTracker->HasPipeline());
|
||||
}
|
Loading…
Reference in New Issue