// Copyright 2018 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dawn_native/ComputePassEncoder.h"

#include "dawn_native/BindGroup.h"
#include "dawn_native/BindGroupLayout.h"
#include "dawn_native/Buffer.h"
#include "dawn_native/CommandEncoder.h"
#include "dawn_native/CommandValidation.h"
#include "dawn_native/Commands.h"
#include "dawn_native/ComputePipeline.h"
#include "dawn_native/Device.h"
#include "dawn_native/InternalPipelineStore.h"
#include "dawn_native/ObjectType_autogen.h"
#include "dawn_native/PassResourceUsageTracker.h"
#include "dawn_native/QuerySet.h"
#include "dawn_native/utils/WGPUHelpers.h"

namespace dawn::native {

    namespace {

        ResultOrError<ComputePipelineBase*> GetOrCreateIndirectDispatchValidationPipeline(
            DeviceBase* device) {
            InternalPipelineStore* store = device->GetInternalPipelineStore();

            if (store->dispatchIndirectValidationPipeline != nullptr) {
                return store->dispatchIndirectValidationPipeline.Get();
            }

            // TODO(https://crbug.com/dawn/1108): Propagate validation feedback from this
            // shader in various failure modes.
            // Note: the UniformParams flags are declared as u32 because type 'bool' cannot be
            // used in storage class 'uniform' as it is non-host-shareable.
            Ref<ShaderModuleBase> shaderModule;
            DAWN_TRY_ASSIGN(shaderModule, utils::CreateShaderModule(device, R"(
                struct UniformParams {
                    maxComputeWorkgroupsPerDimension: u32;
                    clientOffsetInU32: u32;
                    enableValidation: u32;
                    duplicateNumWorkgroups: u32;
                };

                struct IndirectParams {
                    data: array<u32>;
                };

                struct ValidatedParams {
                    data: array<u32>;
                };

                [[group(0), binding(0)]] var<uniform> uniformParams: UniformParams;
                [[group(0), binding(1)]] var<storage, read_write> clientParams: IndirectParams;
                [[group(0), binding(2)]] var<storage, write> validatedParams: ValidatedParams;

                [[stage(compute), workgroup_size(1, 1, 1)]]
                fn main() {
                    for (var i = 0u; i < 3u; i = i + 1u) {
                        var numWorkgroups = clientParams.data[uniformParams.clientOffsetInU32 + i];
                        if (uniformParams.enableValidation > 0u &&
                            numWorkgroups > uniformParams.maxComputeWorkgroupsPerDimension) {
                            numWorkgroups = 0u;
                        }
                        validatedParams.data[i] = numWorkgroups;

                        if (uniformParams.duplicateNumWorkgroups > 0u) {
                            validatedParams.data[i + 3u] = numWorkgroups;
                        }
                    }
                }
            )"));

            Ref<BindGroupLayoutBase> bindGroupLayout;
            DAWN_TRY_ASSIGN(
                bindGroupLayout,
                utils::MakeBindGroupLayout(
                    device,
                    {
                        {0, wgpu::ShaderStage::Compute, wgpu::BufferBindingType::Uniform},
                        {1, wgpu::ShaderStage::Compute, kInternalStorageBufferBinding},
                        {2, wgpu::ShaderStage::Compute, wgpu::BufferBindingType::Storage},
                    },
                    /* allowInternalBinding */ true));

            Ref<PipelineLayoutBase> pipelineLayout;
            DAWN_TRY_ASSIGN(pipelineLayout,
                            utils::MakeBasicPipelineLayout(device, bindGroupLayout));

            ComputePipelineDescriptor computePipelineDescriptor = {};
            computePipelineDescriptor.layout = pipelineLayout.Get();
            computePipelineDescriptor.compute.module = shaderModule.Get();
            computePipelineDescriptor.compute.entryPoint = "main";

            DAWN_TRY_ASSIGN(store->dispatchIndirectValidationPipeline,
                            device->CreateComputePipeline(&computePipelineDescriptor));

            return store->dispatchIndirectValidationPipeline.Get();
        }

    }  // namespace
    ComputePassEncoder::ComputePassEncoder(DeviceBase* device,
                                           const ComputePassDescriptor* descriptor,
                                           CommandEncoder* commandEncoder,
                                           EncodingContext* encodingContext)
        : ProgrammableEncoder(device, descriptor->label, encodingContext),
          mCommandEncoder(commandEncoder) {
        TrackInDevice();
    }

    ComputePassEncoder::ComputePassEncoder(DeviceBase* device,
                                           CommandEncoder* commandEncoder,
                                           EncodingContext* encodingContext,
                                           ErrorTag errorTag)
        : ProgrammableEncoder(device, encodingContext, errorTag),
          mCommandEncoder(commandEncoder) {
    }

    ComputePassEncoder* ComputePassEncoder::MakeError(DeviceBase* device,
                                                      CommandEncoder* commandEncoder,
                                                      EncodingContext* encodingContext) {
        return new ComputePassEncoder(device, commandEncoder, encodingContext,
                                      ObjectBase::kError);
    }

    void ComputePassEncoder::DestroyImpl() {
        // Ensure that the pass has exited. This is done for passes only since validation requires
        // they exit before destruction while bundles do not.
        mEncodingContext->EnsurePassExited(this);
    }

    ObjectType ComputePassEncoder::GetType() const {
        return ObjectType::ComputePassEncoder;
    }

    void ComputePassEncoder::APIEndPass() {
        if (mEncodingContext->TryEncode(
                this,
                [&](CommandAllocator* allocator) -> MaybeError {
                    if (IsValidationEnabled()) {
                        DAWN_TRY(ValidateProgrammableEncoderEnd());
                    }

                    allocator->Allocate<EndComputePassCmd>(Command::EndComputePass);

                    return {};
                },
                "encoding %s.EndPass().", this)) {
            mEncodingContext->ExitComputePass(this, mUsageTracker.AcquireResourceUsage());
        }
    }

    void ComputePassEncoder::APIDispatch(uint32_t x, uint32_t y, uint32_t z) {
        mEncodingContext->TryEncode(
            this,
            [&](CommandAllocator* allocator) -> MaybeError {
                if (IsValidationEnabled()) {
                    DAWN_TRY(mCommandBufferState.ValidateCanDispatch());

                    uint32_t workgroupsPerDimension =
                        GetDevice()->GetLimits().v1.maxComputeWorkgroupsPerDimension;

                    DAWN_INVALID_IF(
                        x > workgroupsPerDimension,
                        "Dispatch size X (%u) exceeds max compute workgroups per dimension (%u).",
                        x, workgroupsPerDimension);

                    DAWN_INVALID_IF(
                        y > workgroupsPerDimension,
                        "Dispatch size Y (%u) exceeds max compute workgroups per dimension (%u).",
                        y, workgroupsPerDimension);

                    DAWN_INVALID_IF(
                        z > workgroupsPerDimension,
                        "Dispatch size Z (%u) exceeds max compute workgroups per dimension (%u).",
                        z, workgroupsPerDimension);
                }

                // Record the synchronization scope for Dispatch, which is just the current
                // bindgroups.
                AddDispatchSyncScope();

                DispatchCmd* dispatch = allocator->Allocate<DispatchCmd>(Command::Dispatch);
                dispatch->x = x;
                dispatch->y = y;
                dispatch->z = z;

                return {};
            },
            "encoding %s.Dispatch(%u, %u, %u).", this, x, y, z);
    }

    ResultOrError<std::pair<Ref<BufferBase>, uint64_t>>
    ComputePassEncoder::TransformIndirectDispatchBuffer(Ref<BufferBase> indirectBuffer,
                                                         uint64_t indirectOffset) {
        DeviceBase* device = GetDevice();

        const bool shouldDuplicateNumWorkgroups =
            device->ShouldDuplicateNumWorkgroupsForDispatchIndirect(
                mCommandBufferState.GetComputePipeline());
        if (!IsValidationEnabled() && !shouldDuplicateNumWorkgroups) {
            return std::make_pair(indirectBuffer, indirectOffset);
        }

        // Save the previous command buffer state so it can be restored after the
        // validation inserts additional commands.
        CommandBufferStateTracker previousState = mCommandBufferState;

        auto* const store = device->GetInternalPipelineStore();

        Ref<ComputePipelineBase> validationPipeline;
        DAWN_TRY_ASSIGN(validationPipeline, GetOrCreateIndirectDispatchValidationPipeline(device));

        Ref<BindGroupLayoutBase> layout;
        DAWN_TRY_ASSIGN(layout, validationPipeline->GetBindGroupLayout(0));

        uint32_t storageBufferOffsetAlignment =
            device->GetLimits().v1.minStorageBufferOffsetAlignment;

        // Let the binding offset be the indirectOffset, aligned down to
        // |storageBufferOffsetAlignment|.
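        // For example (hypothetical values, for illustration only): with indirectOffset == 260
        // and minStorageBufferOffsetAlignment == 256, the binding starts at offset 256 with a
        // size of 4 + kDispatchIndirectSize bytes, and the validation shader reads the client
        // params starting at clientOffsetInU32 == 1.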
        const uint32_t clientOffsetFromAlignedBoundary =
            indirectOffset % storageBufferOffsetAlignment;
        const uint64_t clientOffsetAlignedDown = indirectOffset - clientOffsetFromAlignedBoundary;
        const uint64_t clientIndirectBindingOffset = clientOffsetAlignedDown;

        // Let the size of the binding be the additional offset, plus the size.
        const uint64_t clientIndirectBindingSize =
            kDispatchIndirectSize + clientOffsetFromAlignedBoundary;

        // Neither 'enableValidation' nor 'duplicateNumWorkgroups' can be declared as 'bool' as
        // currently in WGSL type 'bool' cannot be used in storage class 'uniform' as it is
        // non-host-shareable.
        struct UniformParams {
            uint32_t maxComputeWorkgroupsPerDimension;
            uint32_t clientOffsetInU32;
            uint32_t enableValidation;
            uint32_t duplicateNumWorkgroups;
        };

        // Create a uniform buffer to hold parameters for the shader.
        Ref<BufferBase> uniformBuffer;
        {
            UniformParams params;
            params.maxComputeWorkgroupsPerDimension =
                device->GetLimits().v1.maxComputeWorkgroupsPerDimension;
            params.clientOffsetInU32 = clientOffsetFromAlignedBoundary / sizeof(uint32_t);
            params.enableValidation = static_cast<uint32_t>(IsValidationEnabled());
            params.duplicateNumWorkgroups = static_cast<uint32_t>(shouldDuplicateNumWorkgroups);

            DAWN_TRY_ASSIGN(uniformBuffer, utils::CreateBufferFromData(
                                               device, wgpu::BufferUsage::Uniform, {params}));
        }

        // Reserve space in the scratch buffer to hold the validated indirect params.
        ScratchBuffer& scratchBuffer = store->scratchIndirectStorage;
        const uint64_t scratchBufferSize =
            shouldDuplicateNumWorkgroups ? 2 * kDispatchIndirectSize : kDispatchIndirectSize;
        DAWN_TRY(scratchBuffer.EnsureCapacity(scratchBufferSize));
        Ref<BufferBase> validatedIndirectBuffer = scratchBuffer.GetBuffer();

        Ref<BindGroupBase> validationBindGroup;
        ASSERT(indirectBuffer->GetUsage() & kInternalStorageBuffer);
        DAWN_TRY_ASSIGN(validationBindGroup,
                        utils::MakeBindGroup(device, layout,
                                             {
                                                 {0, uniformBuffer},
                                                 {1, indirectBuffer, clientIndirectBindingOffset,
                                                  clientIndirectBindingSize},
                                                 {2, validatedIndirectBuffer, 0, scratchBufferSize},
                                             }));

        // Issue commands to validate the indirect buffer.
        APISetPipeline(validationPipeline.Get());
        APISetBindGroup(0, validationBindGroup.Get());
        APIDispatch(1);

        // Restore the state.
        RestoreCommandBufferState(std::move(previousState));

        // Return the new indirect buffer and indirect buffer offset.
        return std::make_pair(std::move(validatedIndirectBuffer), uint64_t(0));
    }
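    // Note on the scratch buffer produced by TransformIndirectDispatchBuffer() above: entries
    // [0..2] hold the validated (possibly zeroed-out) x/y/z workgroup counts, and when the
    // parameters are duplicated to support [[num_workgroups]], entries [3..5] hold a second
    // copy. APIDispatchIndirect() below may therefore end up dispatching from that scratch
    // buffer at offset 0 instead of the client's buffer.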
    void ComputePassEncoder::APIDispatchIndirect(BufferBase* indirectBuffer,
                                                 uint64_t indirectOffset) {
        mEncodingContext->TryEncode(
            this,
            [&](CommandAllocator* allocator) -> MaybeError {
                if (IsValidationEnabled()) {
                    DAWN_TRY(GetDevice()->ValidateObject(indirectBuffer));
                    DAWN_TRY(ValidateCanUseAs(indirectBuffer, wgpu::BufferUsage::Indirect));
                    DAWN_TRY(mCommandBufferState.ValidateCanDispatch());

                    DAWN_INVALID_IF(indirectOffset % 4 != 0,
                                    "Indirect offset (%u) is not a multiple of 4.",
                                    indirectOffset);

                    DAWN_INVALID_IF(
                        indirectOffset >= indirectBuffer->GetSize() ||
                            indirectOffset + kDispatchIndirectSize > indirectBuffer->GetSize(),
                        "Indirect offset (%u) and dispatch size (%u) exceed the indirect buffer "
                        "size (%u).",
                        indirectOffset, kDispatchIndirectSize, indirectBuffer->GetSize());
                }

                SyncScopeUsageTracker scope;
                scope.BufferUsedAs(indirectBuffer, wgpu::BufferUsage::Indirect);
                mUsageTracker.AddReferencedBuffer(indirectBuffer);
                // TODO(crbug.com/dawn/1166): If validation is enabled, adding |indirectBuffer|
                // is needed for correct usage validation even though it will only be bound for
                // storage. This will unnecessarily transition the |indirectBuffer| in
                // the backend.

                Ref<BufferBase> indirectBufferRef = indirectBuffer;

                // Get the indirect buffer to apply, with any necessary changes made to the
                // original indirect buffer. For example,
                // - Validate each indirect dispatch with a single dispatch to copy the indirect
                //   buffer params into a scratch buffer if they're valid, and otherwise zero them
                //   out.
                // - Duplicate all the indirect dispatch parameters to support [[num_workgroups]]
                //   on D3D12.
                // - Directly return the original indirect dispatch buffer if we don't need any
                //   transformations on it.
                // We could consider moving the validation earlier in the pass, after the last
                // point the indirect buffer was used with writable usage, as well as batching
                // validation for multiple dispatches into one, but inserting commands at
                // arbitrary points in the past is not possible right now.
                DAWN_TRY_ASSIGN(std::tie(indirectBufferRef, indirectOffset),
                                TransformIndirectDispatchBuffer(indirectBufferRef, indirectOffset));

                // If we have created a new scratch dispatch indirect buffer in
                // TransformIndirectDispatchBuffer(), we need to track it in mUsageTracker.
                if (indirectBufferRef.Get() != indirectBuffer) {
                    // |indirectBufferRef| was replaced with a scratch buffer. Add it to the
                    // synchronization scope.
                    scope.BufferUsedAs(indirectBufferRef.Get(), wgpu::BufferUsage::Indirect);
                    mUsageTracker.AddReferencedBuffer(indirectBufferRef.Get());
                }

                AddDispatchSyncScope(std::move(scope));

                DispatchIndirectCmd* dispatch =
                    allocator->Allocate<DispatchIndirectCmd>(Command::DispatchIndirect);
                dispatch->indirectBuffer = std::move(indirectBufferRef);
                dispatch->indirectOffset = indirectOffset;

                return {};
            },
            "encoding %s.DispatchIndirect(%s, %u).", this, indirectBuffer, indirectOffset);
    }

    void ComputePassEncoder::APISetPipeline(ComputePipelineBase* pipeline) {
        mEncodingContext->TryEncode(
            this,
            [&](CommandAllocator* allocator) -> MaybeError {
                if (IsValidationEnabled()) {
                    DAWN_TRY(GetDevice()->ValidateObject(pipeline));
                }

                mCommandBufferState.SetComputePipeline(pipeline);

                SetComputePipelineCmd* cmd =
                    allocator->Allocate<SetComputePipelineCmd>(Command::SetComputePipeline);
                cmd->pipeline = pipeline;

                return {};
            },
            "encoding %s.SetPipeline(%s).", this, pipeline);
    }

    void ComputePassEncoder::APISetBindGroup(uint32_t groupIndexIn,
                                             BindGroupBase* group,
                                             uint32_t dynamicOffsetCount,
                                             const uint32_t* dynamicOffsets) {
        mEncodingContext->TryEncode(
            this,
            [&](CommandAllocator* allocator) -> MaybeError {
                BindGroupIndex groupIndex(groupIndexIn);

                if (IsValidationEnabled()) {
                    DAWN_TRY(ValidateSetBindGroup(groupIndex, group, dynamicOffsetCount,
                                                  dynamicOffsets));
                }

                mUsageTracker.AddResourcesReferencedByBindGroup(group);
                RecordSetBindGroup(allocator, groupIndex, group, dynamicOffsetCount,
                                   dynamicOffsets);
                mCommandBufferState.SetBindGroup(groupIndex, group, dynamicOffsetCount,
                                                 dynamicOffsets);

                return {};
            },
            "encoding %s.SetBindGroup(%u, %s, %u, ...).", this, groupIndexIn, group,
            dynamicOffsetCount);
    }

    void ComputePassEncoder::APIWriteTimestamp(QuerySetBase* querySet, uint32_t queryIndex) {
        mEncodingContext->TryEncode(
            this,
            [&](CommandAllocator* allocator) -> MaybeError {
                if (IsValidationEnabled()) {
                    DAWN_TRY(GetDevice()->ValidateObject(querySet));
                    DAWN_TRY(ValidateTimestampQuery(querySet, queryIndex));
                }

                mCommandEncoder->TrackQueryAvailability(querySet, queryIndex);

                WriteTimestampCmd* cmd =
                    allocator->Allocate<WriteTimestampCmd>(Command::WriteTimestamp);
                cmd->querySet = querySet;
                cmd->queryIndex = queryIndex;

                return {};
            },
            "encoding %s.WriteTimestamp(%s, %u).", this, querySet, queryIndex);
    }
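    // The synchronization scope of a dispatch is the set of resources referenced by the
    // currently-set bind groups, plus, for indirect dispatches, the indirect buffer that the
    // caller already added to |scope|.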
    void ComputePassEncoder::AddDispatchSyncScope(SyncScopeUsageTracker scope) {
        PipelineLayoutBase* layout = mCommandBufferState.GetPipelineLayout();
        for (BindGroupIndex i : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
            scope.AddBindGroup(mCommandBufferState.GetBindGroup(i));
        }
        mUsageTracker.AddDispatch(scope.AcquireSyncScopeUsage());
    }

    void ComputePassEncoder::RestoreCommandBufferState(CommandBufferStateTracker state) {
        // Encode commands for the backend to restore the pipeline and bind groups.
        if (state.HasPipeline()) {
            APISetPipeline(state.GetComputePipeline());
        }
        for (BindGroupIndex i(0); i < kMaxBindGroupsTyped; ++i) {
            BindGroupBase* bg = state.GetBindGroup(i);
            if (bg != nullptr) {
                const std::vector<uint32_t>& offsets = state.GetDynamicOffsets(i);
                if (offsets.empty()) {
                    APISetBindGroup(static_cast<uint32_t>(i), bg);
                } else {
                    APISetBindGroup(static_cast<uint32_t>(i), bg, offsets.size(),
                                    offsets.data());
                }
            }
        }

        // Restore the frontend state tracking information.
        mCommandBufferState = std::move(state);
    }

    CommandBufferStateTracker* ComputePassEncoder::GetCommandBufferStateTrackerForTesting() {
        return &mCommandBufferState;
    }

}  // namespace dawn::native