From bcfa7b12533097ae4fed5f3f8f675c0b4ba44f98 Mon Sep 17 00:00:00 2001 From: Austin Eng Date: Fri, 29 Oct 2021 18:52:33 +0000 Subject: [PATCH] Implement GPU-based validation for dispatchIndirect Bug: dawn:1039 Change-Id: I1b77244d33b178c8e4d4b7d72dc038ccb9d65c48 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/67142 Reviewed-by: Corentin Wallez Commit-Queue: Austin Eng --- src/dawn_native/CommandBuffer.cpp | 4 + src/dawn_native/CommandBuffer.h | 2 + src/dawn_native/CommandBufferStateTracker.cpp | 63 +++- src/dawn_native/CommandBufferStateTracker.h | 12 +- src/dawn_native/ComputePassEncoder.cpp | 243 +++++++++++++- src/dawn_native/ComputePassEncoder.h | 11 + src/dawn_native/InternalPipelineStore.h | 1 + src/dawn_native/RenderEncoderBase.cpp | 6 +- src/tests/BUILD.gn | 1 + src/tests/DawnNativeTest.cpp | 55 +++- src/tests/DawnNativeTest.h | 21 ++ src/tests/end2end/ComputeDispatchTests.cpp | 54 ++- .../native/CommandBufferEncodingTests.cpp | 310 ++++++++++++++++++ 13 files changed, 756 insertions(+), 27 deletions(-) create mode 100644 src/tests/unittests/native/CommandBufferEncodingTests.cpp diff --git a/src/dawn_native/CommandBuffer.cpp b/src/dawn_native/CommandBuffer.cpp index 18fef0d952..a964f40bec 100644 --- a/src/dawn_native/CommandBuffer.cpp +++ b/src/dawn_native/CommandBuffer.cpp @@ -73,6 +73,10 @@ namespace dawn_native { return mResourceUsages; } + CommandIterator* CommandBufferBase::GetCommandIteratorForTesting() { + return &mCommands; + } + bool IsCompleteSubresourceCopiedTo(const TextureBase* texture, const Extent3D copySize, const uint32_t mipLevel) { diff --git a/src/dawn_native/CommandBuffer.h b/src/dawn_native/CommandBuffer.h index 2800929d36..c6d47ae24c 100644 --- a/src/dawn_native/CommandBuffer.h +++ b/src/dawn_native/CommandBuffer.h @@ -43,6 +43,8 @@ namespace dawn_native { const CommandBufferResourceUsage& GetResourceUsages() const; + CommandIterator* GetCommandIteratorForTesting(); + protected: ~CommandBufferBase() override; diff --git a/src/dawn_native/CommandBufferStateTracker.cpp b/src/dawn_native/CommandBufferStateTracker.cpp index 5210936c59..45892a6260 100644 --- a/src/dawn_native/CommandBufferStateTracker.cpp +++ b/src/dawn_native/CommandBufferStateTracker.cpp @@ -17,8 +17,10 @@ #include "common/Assert.h" #include "common/BitSetIterator.h" #include "dawn_native/BindGroup.h" +#include "dawn_native/ComputePassEncoder.h" #include "dawn_native/ComputePipeline.h" #include "dawn_native/Forward.h" +#include "dawn_native/ObjectType_autogen.h" #include "dawn_native/PipelineLayout.h" #include "dawn_native/RenderPipeline.h" @@ -83,13 +85,15 @@ namespace dawn_native { MaybeError CommandBufferStateTracker::ValidateBufferInRangeForVertexBuffer( uint32_t vertexCount, uint32_t firstVertex) { + RenderPipelineBase* lastRenderPipeline = GetRenderPipeline(); + const ityp::bitset& vertexBufferSlotsUsedAsVertexBuffer = - mLastRenderPipeline->GetVertexBufferSlotsUsedAsVertexBuffer(); + lastRenderPipeline->GetVertexBufferSlotsUsedAsVertexBuffer(); for (auto usedSlotVertex : IterateBitSet(vertexBufferSlotsUsedAsVertexBuffer)) { const VertexBufferInfo& vertexBuffer = - mLastRenderPipeline->GetVertexBuffer(usedSlotVertex); + lastRenderPipeline->GetVertexBuffer(usedSlotVertex); uint64_t arrayStride = vertexBuffer.arrayStride; uint64_t bufferSize = mVertexBufferSizes[usedSlotVertex]; @@ -120,13 +124,15 @@ namespace dawn_native { MaybeError CommandBufferStateTracker::ValidateBufferInRangeForInstanceBuffer( uint32_t instanceCount, uint32_t firstInstance) { + RenderPipelineBase* lastRenderPipeline = GetRenderPipeline(); + const ityp::bitset& vertexBufferSlotsUsedAsInstanceBuffer = - mLastRenderPipeline->GetVertexBufferSlotsUsedAsInstanceBuffer(); + lastRenderPipeline->GetVertexBufferSlotsUsedAsInstanceBuffer(); for (auto usedSlotInstance : IterateBitSet(vertexBufferSlotsUsedAsInstanceBuffer)) { const VertexBufferInfo& vertexBuffer = - mLastRenderPipeline->GetVertexBuffer(usedSlotInstance); + lastRenderPipeline->GetVertexBuffer(usedSlotInstance); uint64_t arrayStride = vertexBuffer.arrayStride; uint64_t bufferSize = mVertexBufferSizes[usedSlotInstance]; if (arrayStride == 0) { @@ -209,18 +215,19 @@ namespace dawn_native { } if (aspects[VALIDATION_ASPECT_VERTEX_BUFFERS]) { - ASSERT(mLastRenderPipeline != nullptr); + RenderPipelineBase* lastRenderPipeline = GetRenderPipeline(); const ityp::bitset& requiredVertexBuffers = - mLastRenderPipeline->GetVertexBufferSlotsUsed(); + lastRenderPipeline->GetVertexBufferSlotsUsed(); if (IsSubset(requiredVertexBuffers, mVertexBufferSlotsUsed)) { mAspects.set(VALIDATION_ASPECT_VERTEX_BUFFERS); } } if (aspects[VALIDATION_ASPECT_INDEX_BUFFER] && mIndexBufferSet) { - if (!IsStripPrimitiveTopology(mLastRenderPipeline->GetPrimitiveTopology()) || - mIndexFormat == mLastRenderPipeline->GetStripIndexFormat()) { + RenderPipelineBase* lastRenderPipeline = GetRenderPipeline(); + if (!IsStripPrimitiveTopology(lastRenderPipeline->GetPrimitiveTopology()) || + mIndexFormat == lastRenderPipeline->GetStripIndexFormat()) { mAspects.set(VALIDATION_ASPECT_INDEX_BUFFER); } } @@ -234,12 +241,13 @@ namespace dawn_native { if (DAWN_UNLIKELY(aspects[VALIDATION_ASPECT_INDEX_BUFFER])) { DAWN_INVALID_IF(!mIndexBufferSet, "Index buffer was not set."); - wgpu::IndexFormat pipelineIndexFormat = mLastRenderPipeline->GetStripIndexFormat(); + RenderPipelineBase* lastRenderPipeline = GetRenderPipeline(); + wgpu::IndexFormat pipelineIndexFormat = lastRenderPipeline->GetStripIndexFormat(); DAWN_INVALID_IF( - IsStripPrimitiveTopology(mLastRenderPipeline->GetPrimitiveTopology()) && + IsStripPrimitiveTopology(lastRenderPipeline->GetPrimitiveTopology()) && mIndexFormat != pipelineIndexFormat, "Strip index format (%s) of %s does not match index buffer format (%s).", - pipelineIndexFormat, mLastRenderPipeline, mIndexFormat); + pipelineIndexFormat, lastRenderPipeline, mIndexFormat); // The chunk of code above should be similar to the one in |RecomputeLazyAspects|. // It returns the first invalid state found. We shouldn't be able to reach this line @@ -251,7 +259,7 @@ namespace dawn_native { // TODO(dawn:563): Indicate which slots were not set. DAWN_INVALID_IF(aspects[VALIDATION_ASPECT_VERTEX_BUFFERS], - "Vertex buffer slots required by %s were not set.", mLastRenderPipeline); + "Vertex buffer slots required by %s were not set.", GetRenderPipeline()); if (DAWN_UNLIKELY(aspects[VALIDATION_ASPECT_BIND_GROUPS])) { for (BindGroupIndex i : IterateBitSet(mLastPipelineLayout->GetBindGroupLayoutsMask())) { @@ -290,12 +298,15 @@ namespace dawn_native { } void CommandBufferStateTracker::SetRenderPipeline(RenderPipelineBase* pipeline) { - mLastRenderPipeline = pipeline; SetPipelineCommon(pipeline); } - void CommandBufferStateTracker::SetBindGroup(BindGroupIndex index, BindGroupBase* bindgroup) { + void CommandBufferStateTracker::SetBindGroup(BindGroupIndex index, + BindGroupBase* bindgroup, + uint32_t dynamicOffsetCount, + const uint32_t* dynamicOffsets) { mBindgroups[index] = bindgroup; + mDynamicOffsets[index].assign(dynamicOffsets, dynamicOffsets + dynamicOffsetCount); mAspects.reset(VALIDATION_ASPECT_BIND_GROUPS); } @@ -311,8 +322,9 @@ namespace dawn_native { } void CommandBufferStateTracker::SetPipelineCommon(PipelineBase* pipeline) { - mLastPipelineLayout = pipeline->GetLayout(); - mMinBufferSizes = &pipeline->GetMinBufferSizes(); + mLastPipeline = pipeline; + mLastPipelineLayout = pipeline != nullptr ? pipeline->GetLayout() : nullptr; + mMinBufferSizes = pipeline != nullptr ? &pipeline->GetMinBufferSizes() : nullptr; mAspects.set(VALIDATION_ASPECT_PIPELINE); @@ -324,6 +336,25 @@ namespace dawn_native { return mBindgroups[index]; } + const std::vector& CommandBufferStateTracker::GetDynamicOffsets( + BindGroupIndex index) const { + return mDynamicOffsets[index]; + } + + bool CommandBufferStateTracker::HasPipeline() const { + return mLastPipeline != nullptr; + } + + RenderPipelineBase* CommandBufferStateTracker::GetRenderPipeline() const { + ASSERT(HasPipeline() && mLastPipeline->GetType() == ObjectType::RenderPipeline); + return static_cast(mLastPipeline); + } + + ComputePipelineBase* CommandBufferStateTracker::GetComputePipeline() const { + ASSERT(HasPipeline() && mLastPipeline->GetType() == ObjectType::ComputePipeline); + return static_cast(mLastPipeline); + } + PipelineLayoutBase* CommandBufferStateTracker::GetPipelineLayout() const { return mLastPipelineLayout; } diff --git a/src/dawn_native/CommandBufferStateTracker.h b/src/dawn_native/CommandBufferStateTracker.h index 0a6c587a98..5686956faf 100644 --- a/src/dawn_native/CommandBufferStateTracker.h +++ b/src/dawn_native/CommandBufferStateTracker.h @@ -38,7 +38,10 @@ namespace dawn_native { // State-modifying methods void SetComputePipeline(ComputePipelineBase* pipeline); void SetRenderPipeline(RenderPipelineBase* pipeline); - void SetBindGroup(BindGroupIndex index, BindGroupBase* bindgroup); + void SetBindGroup(BindGroupIndex index, + BindGroupBase* bindgroup, + uint32_t dynamicOffsetCount, + const uint32_t* dynamicOffsets); void SetIndexBuffer(wgpu::IndexFormat format, uint64_t size); void SetVertexBuffer(VertexBufferSlot slot, uint64_t size); @@ -46,6 +49,10 @@ namespace dawn_native { using ValidationAspects = std::bitset; BindGroupBase* GetBindGroup(BindGroupIndex index) const; + const std::vector& GetDynamicOffsets(BindGroupIndex index) const; + bool HasPipeline() const; + RenderPipelineBase* GetRenderPipeline() const; + ComputePipelineBase* GetComputePipeline() const; PipelineLayoutBase* GetPipelineLayout() const; wgpu::IndexFormat GetIndexFormat() const; uint64_t GetIndexBufferSize() const; @@ -60,6 +67,7 @@ namespace dawn_native { ValidationAspects mAspects; ityp::array mBindgroups = {}; + ityp::array, kMaxBindGroups> mDynamicOffsets = {}; ityp::bitset mVertexBufferSlotsUsed; bool mIndexBufferSet = false; wgpu::IndexFormat mIndexFormat; @@ -68,7 +76,7 @@ namespace dawn_native { ityp::array mVertexBufferSizes = {}; PipelineLayoutBase* mLastPipelineLayout = nullptr; - RenderPipelineBase* mLastRenderPipeline = nullptr; + PipelineBase* mLastPipeline = nullptr; const RequiredBufferSizes* mMinBufferSizes = nullptr; }; diff --git a/src/dawn_native/ComputePassEncoder.cpp b/src/dawn_native/ComputePassEncoder.cpp index 1aa4845886..05c68fb11c 100644 --- a/src/dawn_native/ComputePassEncoder.cpp +++ b/src/dawn_native/ComputePassEncoder.cpp @@ -14,18 +14,107 @@ #include "dawn_native/ComputePassEncoder.h" +#include "dawn_native/BindGroup.h" +#include "dawn_native/BindGroupLayout.h" #include "dawn_native/Buffer.h" #include "dawn_native/CommandEncoder.h" #include "dawn_native/CommandValidation.h" #include "dawn_native/Commands.h" #include "dawn_native/ComputePipeline.h" #include "dawn_native/Device.h" +#include "dawn_native/InternalPipelineStore.h" #include "dawn_native/ObjectType_autogen.h" #include "dawn_native/PassResourceUsageTracker.h" #include "dawn_native/QuerySet.h" namespace dawn_native { + namespace { + + ResultOrError GetOrCreateIndirectDispatchValidationPipeline( + DeviceBase* device) { + InternalPipelineStore* store = device->GetInternalPipelineStore(); + + if (store->dispatchIndirectValidationPipeline != nullptr) { + return store->dispatchIndirectValidationPipeline.Get(); + } + + ShaderModuleDescriptor descriptor; + ShaderModuleWGSLDescriptor wgslDesc; + descriptor.nextInChain = reinterpret_cast(&wgslDesc); + + // TODO(https://crbug.com/dawn/1108): Propagate validation feedback from this + // shader in various failure modes. + wgslDesc.source = R"( + [[block]] struct UniformParams { + maxComputeWorkgroupsPerDimension: u32; + clientOffsetInU32: u32; + }; + + [[block]] struct IndirectParams { + data: array; + }; + + [[block]] struct ValidatedParams { + data: array; + }; + + [[group(0), binding(0)]] var uniformParams: UniformParams; + [[group(0), binding(1)]] var clientParams: IndirectParams; + [[group(0), binding(2)]] var validatedParams: ValidatedParams; + + [[stage(compute), workgroup_size(1, 1, 1)]] + fn main() { + for (var i = 0u; i < 3u; i = i + 1u) { + var numWorkgroups = clientParams.data[uniformParams.clientOffsetInU32 + i]; + if (numWorkgroups > uniformParams.maxComputeWorkgroupsPerDimension) { + numWorkgroups = 0u; + } + validatedParams.data[i] = numWorkgroups; + } + } + )"; + + Ref shaderModule; + DAWN_TRY_ASSIGN(shaderModule, device->CreateShaderModule(&descriptor)); + + std::array entries; + entries[0].binding = 0; + entries[0].visibility = wgpu::ShaderStage::Compute; + entries[0].buffer.type = wgpu::BufferBindingType::Uniform; + entries[1].binding = 1; + entries[1].visibility = wgpu::ShaderStage::Compute; + entries[1].buffer.type = kInternalStorageBufferBinding; + entries[2].binding = 2; + entries[2].visibility = wgpu::ShaderStage::Compute; + entries[2].buffer.type = wgpu::BufferBindingType::Storage; + + BindGroupLayoutDescriptor bindGroupLayoutDescriptor; + bindGroupLayoutDescriptor.entryCount = entries.size(); + bindGroupLayoutDescriptor.entries = entries.data(); + Ref bindGroupLayout; + DAWN_TRY_ASSIGN(bindGroupLayout, + device->CreateBindGroupLayout(&bindGroupLayoutDescriptor, true)); + + PipelineLayoutDescriptor pipelineDescriptor; + pipelineDescriptor.bindGroupLayoutCount = 1; + pipelineDescriptor.bindGroupLayouts = &bindGroupLayout.Get(); + Ref pipelineLayout; + DAWN_TRY_ASSIGN(pipelineLayout, device->CreatePipelineLayout(&pipelineDescriptor)); + + ComputePipelineDescriptor computePipelineDescriptor = {}; + computePipelineDescriptor.layout = pipelineLayout.Get(); + computePipelineDescriptor.compute.module = shaderModule.Get(); + computePipelineDescriptor.compute.entryPoint = "main"; + + DAWN_TRY_ASSIGN(store->dispatchIndirectValidationPipeline, + device->CreateComputePipeline(&computePipelineDescriptor)); + + return store->dispatchIndirectValidationPipeline.Get(); + } + + } // namespace + ComputePassEncoder::ComputePassEncoder(DeviceBase* device, CommandEncoder* commandEncoder, EncodingContext* encodingContext) @@ -107,6 +196,95 @@ namespace dawn_native { "encoding Dispatch (x: %u, y: %u, z: %u)", x, y, z); } + ResultOrError, uint64_t>> + ComputePassEncoder::ValidateIndirectDispatch(BufferBase* indirectBuffer, + uint64_t indirectOffset) { + DeviceBase* device = GetDevice(); + auto* const store = device->GetInternalPipelineStore(); + + Ref validationPipeline; + DAWN_TRY_ASSIGN(validationPipeline, GetOrCreateIndirectDispatchValidationPipeline(device)); + + Ref layout; + DAWN_TRY_ASSIGN(layout, validationPipeline->GetBindGroupLayout(0)); + + uint32_t storageBufferOffsetAlignment = + device->GetLimits().v1.minStorageBufferOffsetAlignment; + + std::array bindings; + + // Storage binding holding the client's indirect buffer. + BindGroupEntry& clientIndirectBinding = bindings[0]; + clientIndirectBinding.binding = 1; + clientIndirectBinding.buffer = indirectBuffer; + + // Let the offset be the indirectOffset, aligned down to |storageBufferOffsetAlignment|. + const uint32_t clientOffsetFromAlignedBoundary = + indirectOffset % storageBufferOffsetAlignment; + const uint64_t clientOffsetAlignedDown = indirectOffset - clientOffsetFromAlignedBoundary; + clientIndirectBinding.offset = clientOffsetAlignedDown; + + // Let the size of the binding be the additional offset, plus the size. + clientIndirectBinding.size = kDispatchIndirectSize + clientOffsetFromAlignedBoundary; + + struct UniformParams { + uint32_t maxComputeWorkgroupsPerDimension; + uint32_t clientOffsetInU32; + }; + + // Create a uniform buffer to hold parameters for the shader. + Ref uniformBuffer; + { + BufferDescriptor uniformDesc = {}; + uniformDesc.size = sizeof(UniformParams); + uniformDesc.usage = wgpu::BufferUsage::Uniform | wgpu::BufferUsage::CopyDst; + uniformDesc.mappedAtCreation = true; + DAWN_TRY_ASSIGN(uniformBuffer, device->CreateBuffer(&uniformDesc)); + + UniformParams* params = static_cast( + uniformBuffer->GetMappedRange(0, sizeof(UniformParams))); + params->maxComputeWorkgroupsPerDimension = + device->GetLimits().v1.maxComputeWorkgroupsPerDimension; + params->clientOffsetInU32 = clientOffsetFromAlignedBoundary / sizeof(uint32_t); + uniformBuffer->Unmap(); + } + + // Uniform buffer binding pointing to the uniform parameters. + BindGroupEntry& uniformBinding = bindings[1]; + uniformBinding.binding = 0; + uniformBinding.buffer = uniformBuffer.Get(); + uniformBinding.offset = 0; + uniformBinding.size = sizeof(UniformParams); + + // Reserve space in the scratch buffer to hold the validated indirect params. + ScratchBuffer& scratchBuffer = store->scratchIndirectStorage; + DAWN_TRY(scratchBuffer.EnsureCapacity(kDispatchIndirectSize)); + Ref validatedIndirectBuffer = scratchBuffer.GetBuffer(); + + // Binding for the validated indirect params. + BindGroupEntry& validatedParamsBinding = bindings[2]; + validatedParamsBinding.binding = 2; + validatedParamsBinding.buffer = validatedIndirectBuffer.Get(); + validatedParamsBinding.offset = 0; + validatedParamsBinding.size = kDispatchIndirectSize; + + BindGroupDescriptor bindGroupDescriptor = {}; + bindGroupDescriptor.layout = layout.Get(); + bindGroupDescriptor.entryCount = bindings.size(); + bindGroupDescriptor.entries = bindings.data(); + + Ref validationBindGroup; + DAWN_TRY_ASSIGN(validationBindGroup, device->CreateBindGroup(&bindGroupDescriptor)); + + // Issue commands to validate the indirect buffer. + APISetPipeline(validationPipeline.Get()); + APISetBindGroup(0, validationBindGroup.Get()); + APIDispatch(1); + + // Return the new indirect buffer and indirect buffer offset. + return std::make_pair(std::move(validatedIndirectBuffer), uint64_t(0)); + } + void ComputePassEncoder::APIDispatchIndirect(BufferBase* indirectBuffer, uint64_t indirectOffset) { mEncodingContext->TryEncode( @@ -136,18 +314,46 @@ namespace dawn_native { indirectOffset, kDispatchIndirectSize, indirectBuffer->GetSize()); } - // Record the synchronization scope for Dispatch, both the bindgroups and the - // indirect buffer. SyncScopeUsageTracker scope; scope.BufferUsedAs(indirectBuffer, wgpu::BufferUsage::Indirect); mUsageTracker.AddReferencedBuffer(indirectBuffer); + // TODO(crbug.com/dawn/1166): If validation is enabled, adding |indirectBuffer| + // is needed for correct usage validation even though it will only be bound for + // storage. This will unecessarily transition the |indirectBuffer| in + // the backend. + + Ref indirectBufferRef = indirectBuffer; + if (IsValidationEnabled()) { + // Save the previous command buffer state so it can be restored after the + // validation inserts additional commands. + CommandBufferStateTracker previousState = mCommandBufferState; + + // Validate each indirect dispatch with a single dispatch to copy the indirect + // buffer params into a scratch buffer if they're valid, and otherwise zero them + // out. We could consider moving the validation earlier in the pass after the + // last point the indirect buffer was used with writable usage, as well as batch + // validation for multiple dispatches into one, but inserting commands at + // arbitrary points in the past is not possible right now. + DAWN_TRY_ASSIGN( + std::tie(indirectBufferRef, indirectOffset), + ValidateIndirectDispatch(indirectBufferRef.Get(), indirectOffset)); + + // Restore the state. + RestoreCommandBufferState(std::move(previousState)); + + // |indirectBufferRef| was replaced with a scratch buffer. Add it to the + // synchronization scope. + ASSERT(indirectBufferRef.Get() != indirectBuffer); + scope.BufferUsedAs(indirectBufferRef.Get(), wgpu::BufferUsage::Indirect); + mUsageTracker.AddReferencedBuffer(indirectBufferRef.Get()); + } + AddDispatchSyncScope(std::move(scope)); DispatchIndirectCmd* dispatch = allocator->Allocate(Command::DispatchIndirect); - dispatch->indirectBuffer = indirectBuffer; + dispatch->indirectBuffer = std::move(indirectBufferRef); dispatch->indirectOffset = indirectOffset; - return {}; }, "encoding DispatchIndirect with %s", indirectBuffer); @@ -187,10 +393,10 @@ namespace dawn_native { } mUsageTracker.AddResourcesReferencedByBindGroup(group); - RecordSetBindGroup(allocator, groupIndex, group, dynamicOffsetCount, dynamicOffsets); - mCommandBufferState.SetBindGroup(groupIndex, group); + mCommandBufferState.SetBindGroup(groupIndex, group, dynamicOffsetCount, + dynamicOffsets); return {}; }, @@ -226,4 +432,29 @@ namespace dawn_native { mUsageTracker.AddDispatch(scope.AcquireSyncScopeUsage()); } + void ComputePassEncoder::RestoreCommandBufferState(CommandBufferStateTracker state) { + // Encode commands for the backend to restore the pipeline and bind groups. + if (state.HasPipeline()) { + APISetPipeline(state.GetComputePipeline()); + } + for (BindGroupIndex i(0); i < kMaxBindGroupsTyped; ++i) { + BindGroupBase* bg = state.GetBindGroup(i); + if (bg != nullptr) { + const std::vector& offsets = state.GetDynamicOffsets(i); + if (offsets.empty()) { + APISetBindGroup(static_cast(i), bg); + } else { + APISetBindGroup(static_cast(i), bg, offsets.size(), offsets.data()); + } + } + } + + // Restore the frontend state tracking information. + mCommandBufferState = std::move(state); + } + + CommandBufferStateTracker* ComputePassEncoder::GetCommandBufferStateTrackerForTesting() { + return &mCommandBufferState; + } + } // namespace dawn_native diff --git a/src/dawn_native/ComputePassEncoder.h b/src/dawn_native/ComputePassEncoder.h index b0962f4b91..03997cebad 100644 --- a/src/dawn_native/ComputePassEncoder.h +++ b/src/dawn_native/ComputePassEncoder.h @@ -50,6 +50,11 @@ namespace dawn_native { void APIWriteTimestamp(QuerySetBase* querySet, uint32_t queryIndex); + CommandBufferStateTracker* GetCommandBufferStateTrackerForTesting(); + void RestoreCommandBufferStateForTesting(CommandBufferStateTracker state) { + RestoreCommandBufferState(std::move(state)); + } + protected: ComputePassEncoder(DeviceBase* device, CommandEncoder* commandEncoder, @@ -57,6 +62,12 @@ namespace dawn_native { ErrorTag errorTag); private: + ResultOrError, uint64_t>> ValidateIndirectDispatch( + BufferBase* indirectBuffer, + uint64_t indirectOffset); + + void RestoreCommandBufferState(CommandBufferStateTracker state); + CommandBufferStateTracker mCommandBufferState; // Adds the bindgroups used for the current dispatch to the SyncScopeResourceUsage and diff --git a/src/dawn_native/InternalPipelineStore.h b/src/dawn_native/InternalPipelineStore.h index acf3b13dce..803e0dfd38 100644 --- a/src/dawn_native/InternalPipelineStore.h +++ b/src/dawn_native/InternalPipelineStore.h @@ -52,6 +52,7 @@ namespace dawn_native { Ref renderValidationPipeline; Ref renderValidationShader; + Ref dispatchIndirectValidationPipeline; }; } // namespace dawn_native diff --git a/src/dawn_native/RenderEncoderBase.cpp b/src/dawn_native/RenderEncoderBase.cpp index 0445a972be..a8ef2ffbf1 100644 --- a/src/dawn_native/RenderEncoderBase.cpp +++ b/src/dawn_native/RenderEncoderBase.cpp @@ -208,6 +208,9 @@ namespace dawn_native { BufferLocation::New(indirectBuffer, indirectOffset); } + // TODO(crbug.com/dawn/1166): Adding the indirectBuffer is needed for correct usage + // validation, but it will unecessarily transition to indirectBuffer usage in the + // backend. mUsageTracker.BufferUsedAs(indirectBuffer, wgpu::BufferUsage::Indirect); return {}; @@ -404,7 +407,8 @@ namespace dawn_native { RecordSetBindGroup(allocator, groupIndex, group, dynamicOffsetCount, dynamicOffsets); - mCommandBufferState.SetBindGroup(groupIndex, group); + mCommandBufferState.SetBindGroup(groupIndex, group, dynamicOffsetCount, + dynamicOffsets); mUsageTracker.AddBindGroup(group); return {}; diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn index c7f4b7a96e..8ffd921524 100644 --- a/src/tests/BUILD.gn +++ b/src/tests/BUILD.gn @@ -221,6 +221,7 @@ test("dawn_unittests") { "unittests/SystemUtilsTests.cpp", "unittests/ToBackendTests.cpp", "unittests/TypedIntegerTests.cpp", + "unittests/native/CommandBufferEncodingTests.cpp", "unittests/native/DestroyObjectTests.cpp", "unittests/validation/BindGroupValidationTests.cpp", "unittests/validation/BufferValidationTests.cpp", diff --git a/src/tests/DawnNativeTest.cpp b/src/tests/DawnNativeTest.cpp index d39c8e0d8e..28d69bfa31 100644 --- a/src/tests/DawnNativeTest.cpp +++ b/src/tests/DawnNativeTest.cpp @@ -12,9 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#include "tests/DawnNativeTest.h" #include "absl/strings/str_cat.h" +#include "common/Assert.h" +#include "dawn/dawn_proc.h" #include "dawn_native/ErrorData.h" namespace dawn_native { @@ -28,3 +30,54 @@ namespace dawn_native { } } // namespace dawn_native + +DawnNativeTest::DawnNativeTest() { + dawnProcSetProcs(&dawn_native::GetProcs()); +} + +DawnNativeTest::~DawnNativeTest() { + device = wgpu::Device(); + dawnProcSetProcs(nullptr); +} + +void DawnNativeTest::SetUp() { + instance = std::make_unique(); + instance->DiscoverDefaultAdapters(); + + std::vector adapters = instance->GetAdapters(); + + // DawnNative unittests run against the null backend, find the corresponding adapter + bool foundNullAdapter = false; + for (auto& currentAdapter : adapters) { + wgpu::AdapterProperties adapterProperties; + currentAdapter.GetProperties(&adapterProperties); + + if (adapterProperties.backendType == wgpu::BackendType::Null) { + adapter = currentAdapter; + foundNullAdapter = true; + break; + } + } + + ASSERT(foundNullAdapter); + + device = wgpu::Device(CreateTestDevice()); + device.SetUncapturedErrorCallback(DawnNativeTest::OnDeviceError, nullptr); +} + +void DawnNativeTest::TearDown() { +} + +WGPUDevice DawnNativeTest::CreateTestDevice() { + // Disabled disallowing unsafe APIs so we can test them. + dawn_native::DeviceDescriptor deviceDescriptor; + deviceDescriptor.forceDisabledToggles.push_back("disallow_unsafe_apis"); + + return adapter.CreateDevice(&deviceDescriptor); +} + +// static +void DawnNativeTest::OnDeviceError(WGPUErrorType type, const char* message, void* userdata) { + ASSERT(type != WGPUErrorType_NoError); + FAIL() << "Unexpected error: " << message; +} diff --git a/src/tests/DawnNativeTest.h b/src/tests/DawnNativeTest.h index 94fdafbbec..91904a39e5 100644 --- a/src/tests/DawnNativeTest.h +++ b/src/tests/DawnNativeTest.h @@ -17,6 +17,8 @@ #include +#include "dawn/webgpu_cpp.h" +#include "dawn_native/DawnNative.h" #include "dawn_native/ErrorData.h" namespace dawn_native { @@ -29,4 +31,23 @@ namespace dawn_native { } // namespace dawn_native +class DawnNativeTest : public ::testing::Test { + public: + DawnNativeTest(); + ~DawnNativeTest() override; + + void SetUp() override; + void TearDown() override; + + virtual WGPUDevice CreateTestDevice(); + + protected: + std::unique_ptr instance; + dawn_native::Adapter adapter; + wgpu::Device device; + + private: + static void OnDeviceError(WGPUErrorType type, const char* message, void* userdata); +}; + #endif // TESTS_DAWNNATIVETEST_H_ diff --git a/src/tests/end2end/ComputeDispatchTests.cpp b/src/tests/end2end/ComputeDispatchTests.cpp index 1a8b163f5d..cbc2d8642c 100644 --- a/src/tests/end2end/ComputeDispatchTests.cpp +++ b/src/tests/end2end/ComputeDispatchTests.cpp @@ -158,8 +158,14 @@ class ComputeDispatchTests : public DawnTest { queue.Submit(1, &commands); std::vector expected; + + uint32_t maxComputeWorkgroupsPerDimension = + GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension; if (indirectBufferData[indirectStart] == 0 || indirectBufferData[indirectStart + 1] == 0 || - indirectBufferData[indirectStart + 2] == 0) { + indirectBufferData[indirectStart + 2] == 0 || + indirectBufferData[indirectStart] > maxComputeWorkgroupsPerDimension || + indirectBufferData[indirectStart + 1] > maxComputeWorkgroupsPerDimension || + indirectBufferData[indirectStart + 2] > maxComputeWorkgroupsPerDimension) { expected = kSentinelData; } else { expected.assign(indirectBufferData.begin() + indirectStart, @@ -221,6 +227,52 @@ TEST_P(ComputeDispatchTests, IndirectOffset) { IndirectTest({0, 0, 0, 2, 3, 4}, 3 * sizeof(uint32_t)); } +// Test indirect dispatches at max limit. +TEST_P(ComputeDispatchTests, MaxWorkgroups) { + // TODO(crbug.com/dawn/1165): Fails with WARP + DAWN_SUPPRESS_TEST_IF(IsWARP()); + uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension; + + // Test that the maximum works in each dimension. + // Note: Testing (max, max, max) is very slow. + IndirectTest({max, 3, 4}, 0); + IndirectTest({2, max, 4}, 0); + IndirectTest({2, 3, max}, 0); +} + +// Test indirect dispatches exceeding the max limit are noop-ed. +TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsNoop) { + DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation")); + + uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension; + + // All dimensions are above the max + IndirectTest({max + 1, max + 1, max + 1}, 0); + + // Only x dimension is above the max + IndirectTest({max + 1, 3, 4}, 0); + IndirectTest({2 * max, 3, 4}, 0); + + // Only y dimension is above the max + IndirectTest({2, max + 1, 4}, 0); + IndirectTest({2, 2 * max, 4}, 0); + + // Only z dimension is above the max + IndirectTest({2, 3, max + 1}, 0); + IndirectTest({2, 3, 2 * max}, 0); +} + +// Test indirect dispatches exceeding the max limit with an offset are noop-ed. +TEST_P(ComputeDispatchTests, ExceedsMaxWorkgroupsWithOffsetNoop) { + DAWN_TEST_UNSUPPORTED_IF(HasToggleEnabled("skip_validation")); + + uint32_t max = GetSupportedLimits().limits.maxComputeWorkgroupsPerDimension; + + IndirectTest({1, 2, 3, max + 1, 4, 5}, 1 * sizeof(uint32_t)); + IndirectTest({1, 2, 3, max + 1, 4, 5}, 2 * sizeof(uint32_t)); + IndirectTest({1, 2, 3, max + 1, 4, 5}, 3 * sizeof(uint32_t)); +} + DAWN_INSTANTIATE_TEST(ComputeDispatchTests, D3D12Backend(), MetalBackend(), diff --git a/src/tests/unittests/native/CommandBufferEncodingTests.cpp b/src/tests/unittests/native/CommandBufferEncodingTests.cpp new file mode 100644 index 0000000000..c1ca2d993b --- /dev/null +++ b/src/tests/unittests/native/CommandBufferEncodingTests.cpp @@ -0,0 +1,310 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tests/DawnNativeTest.h" + +#include "dawn_native/CommandBuffer.h" +#include "dawn_native/Commands.h" +#include "dawn_native/ComputePassEncoder.h" +#include "utils/WGPUHelpers.h" + +class CommandBufferEncodingTests : public DawnNativeTest { + protected: + void ExpectCommands(dawn_native::CommandIterator* commands, + std::vector>> + expectedCommands) { + dawn_native::Command commandId; + for (uint32_t commandIndex = 0; commands->NextCommandId(&commandId); ++commandIndex) { + ASSERT_LT(commandIndex, expectedCommands.size()) << "Unexpected command"; + ASSERT_EQ(commandId, expectedCommands[commandIndex].first) + << "at command " << commandIndex; + expectedCommands[commandIndex].second(commands); + } + } +}; + +// Indirect dispatch validation changes the bind groups in the middle +// of a pass. Test that bindings are restored after the validation runs. +TEST_F(CommandBufferEncodingTests, ComputePassEncoderIndirectDispatchStateRestoration) { + using namespace dawn_native; + + wgpu::BindGroupLayout staticLayout = + utils::MakeBindGroupLayout(device, {{ + 0, + wgpu::ShaderStage::Compute, + wgpu::BufferBindingType::Uniform, + }}); + + wgpu::BindGroupLayout dynamicLayout = + utils::MakeBindGroupLayout(device, {{ + 0, + wgpu::ShaderStage::Compute, + wgpu::BufferBindingType::Uniform, + true, + }}); + + // Create a simple pipeline + wgpu::ComputePipelineDescriptor csDesc; + csDesc.compute.module = utils::CreateShaderModule(device, R"( + [[stage(compute), workgroup_size(1, 1, 1)]] + fn main() { + })"); + csDesc.compute.entryPoint = "main"; + + wgpu::PipelineLayout pl0 = utils::MakePipelineLayout(device, {staticLayout, dynamicLayout}); + csDesc.layout = pl0; + wgpu::ComputePipeline pipeline0 = device.CreateComputePipeline(&csDesc); + + wgpu::PipelineLayout pl1 = utils::MakePipelineLayout(device, {dynamicLayout, staticLayout}); + csDesc.layout = pl1; + wgpu::ComputePipeline pipeline1 = device.CreateComputePipeline(&csDesc); + + // Create buffers to use for both the indirect buffer and the bind groups. + wgpu::Buffer indirectBuffer = + utils::CreateBufferFromData(device, wgpu::BufferUsage::Indirect, {1, 2, 3, 4}); + + wgpu::BufferDescriptor uniformBufferDesc = {}; + uniformBufferDesc.size = 512; + uniformBufferDesc.usage = wgpu::BufferUsage::Uniform; + wgpu::Buffer uniformBuffer = device.CreateBuffer(&uniformBufferDesc); + + wgpu::BindGroup staticBG = utils::MakeBindGroup(device, staticLayout, {{0, uniformBuffer}}); + + wgpu::BindGroup dynamicBG = + utils::MakeBindGroup(device, dynamicLayout, {{0, uniformBuffer, 0, 256}}); + + uint32_t dynamicOffset = 256; + std::vector emptyDynamicOffsets = {}; + std::vector singleDynamicOffset = {dynamicOffset}; + + // Begin encoding commands. + wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + + CommandBufferStateTracker* stateTracker = + FromAPI(pass.Get())->GetCommandBufferStateTrackerForTesting(); + + // Perform a dispatch indirect which will be preceded by a validation dispatch. + pass.SetPipeline(pipeline0); + pass.SetBindGroup(0, staticBG); + pass.SetBindGroup(1, dynamicBG, 1, &dynamicOffset); + EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get()); + + pass.DispatchIndirect(indirectBuffer, 0); + + // Expect restored state. + EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get()); + EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl0.Get()); + EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), staticBG.Get()); + EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), emptyDynamicOffsets); + EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), dynamicBG.Get()); + EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), singleDynamicOffset); + + // Dispatch again to check that the restored state can be used. + // Also pass an indirect offset which should get replaced with the offset + // into the scratch indirect buffer (0). + pass.DispatchIndirect(indirectBuffer, 4); + + // Expect restored state. + EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline0.Get()); + EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl0.Get()); + EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), staticBG.Get()); + EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), emptyDynamicOffsets); + EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), dynamicBG.Get()); + EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), singleDynamicOffset); + + // Change the pipeline + pass.SetPipeline(pipeline1); + pass.SetBindGroup(0, dynamicBG, 1, &dynamicOffset); + pass.SetBindGroup(1, staticBG); + EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline1.Get()); + EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl1.Get()); + + pass.DispatchIndirect(indirectBuffer, 0); + + // Expect restored state. + EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline1.Get()); + EXPECT_EQ(ToAPI(stateTracker->GetPipelineLayout()), pl1.Get()); + EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(0))), dynamicBG.Get()); + EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(0)), singleDynamicOffset); + EXPECT_EQ(ToAPI(stateTracker->GetBindGroup(BindGroupIndex(1))), staticBG.Get()); + EXPECT_EQ(stateTracker->GetDynamicOffsets(BindGroupIndex(1)), emptyDynamicOffsets); + + pass.EndPass(); + + wgpu::CommandBuffer commandBuffer = encoder.Finish(); + + auto ExpectSetPipeline = [](wgpu::ComputePipeline pipeline) { + return [pipeline](CommandIterator* commands) { + auto* cmd = commands->NextCommand(); + EXPECT_EQ(ToAPI(cmd->pipeline.Get()), pipeline.Get()); + }; + }; + + auto ExpectSetBindGroup = [](uint32_t index, wgpu::BindGroup bg, + std::vector offsets = {}) { + return [index, bg, offsets](CommandIterator* commands) { + auto* cmd = commands->NextCommand(); + uint32_t* dynamicOffsets = nullptr; + if (cmd->dynamicOffsetCount > 0) { + dynamicOffsets = commands->NextData(cmd->dynamicOffsetCount); + } + + ASSERT_EQ(cmd->index, BindGroupIndex(index)); + ASSERT_EQ(ToAPI(cmd->group.Get()), bg.Get()); + ASSERT_EQ(cmd->dynamicOffsetCount, offsets.size()); + for (uint32_t i = 0; i < cmd->dynamicOffsetCount; ++i) { + ASSERT_EQ(dynamicOffsets[i], offsets[i]); + } + }; + }; + + // Initialize as null. Once we know the pointer, we'll check + // that it's the same buffer every time. + WGPUBuffer indirectScratchBuffer = nullptr; + auto ExpectDispatchIndirect = [&](CommandIterator* commands) { + auto* cmd = commands->NextCommand(); + if (indirectScratchBuffer == nullptr) { + indirectScratchBuffer = ToAPI(cmd->indirectBuffer.Get()); + } + ASSERT_EQ(ToAPI(cmd->indirectBuffer.Get()), indirectScratchBuffer); + ASSERT_EQ(cmd->indirectOffset, uint64_t(0)); + }; + + // Initialize as null. Once we know the pointer, we'll check + // that it's the same pipeline every time. + WGPUComputePipeline validationPipeline = nullptr; + auto ExpectSetValidationPipeline = [&](CommandIterator* commands) { + auto* cmd = commands->NextCommand(); + WGPUComputePipeline pipeline = ToAPI(cmd->pipeline.Get()); + if (validationPipeline != nullptr) { + EXPECT_EQ(pipeline, validationPipeline); + } else { + EXPECT_NE(pipeline, nullptr); + validationPipeline = pipeline; + } + }; + + auto ExpectSetValidationBindGroup = [&](CommandIterator* commands) { + auto* cmd = commands->NextCommand(); + ASSERT_EQ(cmd->index, BindGroupIndex(0)); + ASSERT_NE(cmd->group.Get(), nullptr); + ASSERT_EQ(cmd->dynamicOffsetCount, 0u); + }; + + auto ExpectSetValidationDispatch = [&](CommandIterator* commands) { + auto* cmd = commands->NextCommand(); + ASSERT_EQ(cmd->x, 1u); + ASSERT_EQ(cmd->y, 1u); + ASSERT_EQ(cmd->z, 1u); + }; + + ExpectCommands( + FromAPI(commandBuffer.Get())->GetCommandIteratorForTesting(), + { + {Command::BeginComputePass, + [&](CommandIterator* commands) { SkipCommand(commands, Command::BeginComputePass); }}, + // Expect the state to be set. + {Command::SetComputePipeline, ExpectSetPipeline(pipeline0)}, + {Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)}, + {Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})}, + + // Expect the validation. + {Command::SetComputePipeline, ExpectSetValidationPipeline}, + {Command::SetBindGroup, ExpectSetValidationBindGroup}, + {Command::Dispatch, ExpectSetValidationDispatch}, + + // Expect the state to be restored. + {Command::SetComputePipeline, ExpectSetPipeline(pipeline0)}, + {Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)}, + {Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})}, + + // Expect the dispatchIndirect. + {Command::DispatchIndirect, ExpectDispatchIndirect}, + + // Expect the validation. + {Command::SetComputePipeline, ExpectSetValidationPipeline}, + {Command::SetBindGroup, ExpectSetValidationBindGroup}, + {Command::Dispatch, ExpectSetValidationDispatch}, + + // Expect the state to be restored. + {Command::SetComputePipeline, ExpectSetPipeline(pipeline0)}, + {Command::SetBindGroup, ExpectSetBindGroup(0, staticBG)}, + {Command::SetBindGroup, ExpectSetBindGroup(1, dynamicBG, {dynamicOffset})}, + + // Expect the dispatchIndirect. + {Command::DispatchIndirect, ExpectDispatchIndirect}, + + // Expect the state to be set (new pipeline). + {Command::SetComputePipeline, ExpectSetPipeline(pipeline1)}, + {Command::SetBindGroup, ExpectSetBindGroup(0, dynamicBG, {dynamicOffset})}, + {Command::SetBindGroup, ExpectSetBindGroup(1, staticBG)}, + + // Expect the validation. + {Command::SetComputePipeline, ExpectSetValidationPipeline}, + {Command::SetBindGroup, ExpectSetValidationBindGroup}, + {Command::Dispatch, ExpectSetValidationDispatch}, + + // Expect the state to be restored. + {Command::SetComputePipeline, ExpectSetPipeline(pipeline1)}, + {Command::SetBindGroup, ExpectSetBindGroup(0, dynamicBG, {dynamicOffset})}, + {Command::SetBindGroup, ExpectSetBindGroup(1, staticBG)}, + + // Expect the dispatchIndirect. + {Command::DispatchIndirect, ExpectDispatchIndirect}, + + {Command::EndComputePass, + [&](CommandIterator* commands) { commands->NextCommand(); }}, + }); +} + +// Test that after restoring state, it is fully applied to the state tracker +// and does not leak state changes that occured between a snapshot and the +// state restoration. +TEST_F(CommandBufferEncodingTests, StateNotLeakedAfterRestore) { + using namespace dawn_native; + + wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + + CommandBufferStateTracker* stateTracker = + FromAPI(pass.Get())->GetCommandBufferStateTrackerForTesting(); + + // Snapshot the state. + CommandBufferStateTracker snapshot = *stateTracker; + // Expect no pipeline in the snapshot + EXPECT_FALSE(snapshot.HasPipeline()); + + // Create a simple pipeline + wgpu::ComputePipelineDescriptor csDesc; + csDesc.compute.module = utils::CreateShaderModule(device, R"( + [[stage(compute), workgroup_size(1, 1, 1)]] + fn main() { + })"); + csDesc.compute.entryPoint = "main"; + wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc); + + // Set the pipeline. + pass.SetPipeline(pipeline); + + // Expect the pipeline to be set. + EXPECT_EQ(ToAPI(stateTracker->GetComputePipeline()), pipeline.Get()); + + // Restore the state. + FromAPI(pass.Get())->RestoreCommandBufferStateForTesting(std::move(snapshot)); + + // Expect no pipeline + EXPECT_FALSE(stateTracker->HasPipeline()); +}