diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp index bd3989b8c0..cfb45a1927 100644 --- a/src/dawn_native/ShaderModule.cpp +++ b/src/dawn_native/ShaderModule.cpp @@ -677,6 +677,8 @@ namespace dawn_native { metadata->localWorkgroupSize.x = entryPoint.workgroup_size_x; metadata->localWorkgroupSize.y = entryPoint.workgroup_size_y; metadata->localWorkgroupSize.z = entryPoint.workgroup_size_z; + + metadata->usesNumWorkgroups = entryPoint.num_workgroups_used; } if (metadata->stage == SingleShaderStage::Vertex) { diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h index 8f8081cedb..070d001650 100644 --- a/src/dawn_native/ShaderModule.h +++ b/src/dawn_native/ShaderModule.h @@ -204,6 +204,8 @@ namespace dawn_native { // Store overridableConstants from tint program std::unordered_map overridableConstants; + + bool usesNumWorkgroups = false; }; class ShaderModuleBase : public ApiObjectBase, public CachedObject { diff --git a/src/dawn_native/d3d12/CommandBufferD3D12.cpp b/src/dawn_native/d3d12/CommandBufferD3D12.cpp index 4bdc3b0760..dc32d10c29 100644 --- a/src/dawn_native/d3d12/CommandBufferD3D12.cpp +++ b/src/dawn_native/d3d12/CommandBufferD3D12.cpp @@ -254,6 +254,18 @@ namespace dawn_native { namespace d3d12 { return {}; } + void RecordNumWorkgroupsForDispatch(ID3D12GraphicsCommandList* commandList, + ComputePipeline* pipeline, + DispatchCmd* dispatch) { + if (!pipeline->UsesNumWorkgroups()) { + return; + } + + PipelineLayout* layout = ToBackend(pipeline->GetLayout()); + commandList->SetComputeRoot32BitConstants(layout->GetNumWorkgroupsParameterIndex(), 3, + dispatch, 0); + } + // Records the necessary barriers for a synchronization scope using the resource usage // data pre-computed in the frontend. Also performs lazy initialization if required. // Returns whether any UAV are used in the synchronization scope. @@ -1030,6 +1042,7 @@ namespace dawn_native { namespace d3d12 { ID3D12GraphicsCommandList* commandList = commandContext->GetCommandList(); Command type; + ComputePipeline* lastPipeline = nullptr; while (mCommands.NextCommandId(&type)) { switch (type) { case Command::Dispatch: { @@ -1045,6 +1058,7 @@ namespace dawn_native { namespace d3d12 { resourceUsages.dispatchUsages[currentDispatch]); DAWN_TRY(bindingTracker->Apply(commandContext)); + RecordNumWorkgroupsForDispatch(commandList, lastPipeline, dispatch); commandList->Dispatch(dispatch->x, dispatch->y, dispatch->z); currentDispatch++; break; @@ -1052,6 +1066,14 @@ namespace dawn_native { namespace d3d12 { case Command::DispatchIndirect: { DispatchIndirectCmd* dispatch = mCommands.NextCommand(); + + // TODO(dawn:839): support [[num_workgroups]] for DispatchIndirect calls + if (lastPipeline->UsesNumWorkgroups()) { + return DAWN_VALIDATION_ERROR( + "Using a compute pipeline with [[num_workgroups]] in a " + "DispatchIndirect call is not implemented"); + } + Buffer* buffer = ToBackend(dispatch->indirectBuffer.Get()); TransitionAndClearForSyncScope(commandContext, @@ -1078,6 +1100,7 @@ namespace dawn_native { namespace d3d12 { commandList->SetPipelineState(pipeline->GetPipelineState()); bindingTracker->OnSetPipeline(pipeline); + lastPipeline = pipeline; break; } diff --git a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp index 29ae08af9d..54ddc7e143 100644 --- a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp +++ b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp @@ -84,4 +84,8 @@ namespace dawn_native { namespace d3d12 { CreateComputePipelineAsyncTask::RunAsync(std::move(asyncTask)); } + bool ComputePipeline::UsesNumWorkgroups() const { + return GetStage(SingleShaderStage::Compute).metadata->usesNumWorkgroups; + } + }} // namespace dawn_native::d3d12 diff --git a/src/dawn_native/d3d12/ComputePipelineD3D12.h b/src/dawn_native/d3d12/ComputePipelineD3D12.h index 7c7a02d2c6..b652026013 100644 --- a/src/dawn_native/d3d12/ComputePipelineD3D12.h +++ b/src/dawn_native/d3d12/ComputePipelineD3D12.h @@ -40,6 +40,8 @@ namespace dawn_native { namespace d3d12 { // Dawn API void SetLabelImpl() override; + bool UsesNumWorkgroups() const; + private: ~ComputePipeline() override; using ComputePipelineBase::ComputePipelineBase; diff --git a/src/dawn_native/d3d12/PipelineLayoutD3D12.cpp b/src/dawn_native/d3d12/PipelineLayoutD3D12.cpp index 372b61b6d4..1a512fa60e 100644 --- a/src/dawn_native/d3d12/PipelineLayoutD3D12.cpp +++ b/src/dawn_native/d3d12/PipelineLayoutD3D12.cpp @@ -174,6 +174,21 @@ namespace dawn_native { namespace d3d12 { // would need to be updated often rootParameters.emplace_back(indexOffsetConstants); + // Always allocate 3 constants for num_workgroups_x, num_workgroups_y and num_workgroups_z + // for Dispatch calls + // NOTE: We should consider delaying root signature creation until we know how many values + // we need + D3D12_ROOT_PARAMETER numWorkgroupsConstants{}; + numWorkgroupsConstants.ShaderVisibility = D3D12_SHADER_VISIBILITY_ALL; + numWorkgroupsConstants.ParameterType = D3D12_ROOT_PARAMETER_TYPE_32BIT_CONSTANTS; + numWorkgroupsConstants.Constants.Num32BitValues = 3; + numWorkgroupsConstants.Constants.RegisterSpace = GetNumWorkgroupsRegisterSpace(); + numWorkgroupsConstants.Constants.ShaderRegister = GetNumWorkgroupsShaderRegister(); + mNumWorkgroupsParamterIndex = rootParameters.size(); + // NOTE: We should consider moving this entry to earlier in the root signature since + // dispatch sizes would need to be updated often + rootParameters.emplace_back(numWorkgroupsConstants); + D3D12_ROOT_SIGNATURE_DESC rootSignatureDescriptor; rootSignatureDescriptor.NumParameters = rootParameters.size(); rootSignatureDescriptor.pParameters = rootParameters.data(); @@ -230,7 +245,7 @@ namespace dawn_native { namespace d3d12 { } uint32_t PipelineLayout::GetFirstIndexOffsetRegisterSpace() const { - return kReservedRegisterSpace; + return kFirstIndexOffsetRegisterSpace; } uint32_t PipelineLayout::GetFirstIndexOffsetShaderRegister() const { @@ -240,4 +255,16 @@ namespace dawn_native { namespace d3d12 { uint32_t PipelineLayout::GetFirstIndexOffsetParameterIndex() const { return mFirstIndexOffsetParameterIndex; } + + uint32_t PipelineLayout::GetNumWorkgroupsRegisterSpace() const { + return kNumWorkgroupsRegisterSpace; + } + + uint32_t PipelineLayout::GetNumWorkgroupsShaderRegister() const { + return kNumWorkgroupsBaseRegister; + } + + uint32_t PipelineLayout::GetNumWorkgroupsParameterIndex() const { + return mNumWorkgroupsParamterIndex; + } }} // namespace dawn_native::d3d12 diff --git a/src/dawn_native/d3d12/PipelineLayoutD3D12.h b/src/dawn_native/d3d12/PipelineLayoutD3D12.h index b1efc0d00f..cf52f066e7 100644 --- a/src/dawn_native/d3d12/PipelineLayoutD3D12.h +++ b/src/dawn_native/d3d12/PipelineLayoutD3D12.h @@ -26,6 +26,9 @@ namespace dawn_native { namespace d3d12 { // We reserve a register space that a user cannot use. static constexpr uint32_t kReservedRegisterSpace = kMaxBindGroups + 1; static constexpr uint32_t kFirstOffsetInfoBaseRegister = 0; + static constexpr uint32_t kFirstIndexOffsetRegisterSpace = kReservedRegisterSpace; + static constexpr uint32_t kNumWorkgroupsRegisterSpace = kReservedRegisterSpace + 1; + static constexpr uint32_t kNumWorkgroupsBaseRegister = 0; class Device; @@ -46,6 +49,10 @@ namespace dawn_native { namespace d3d12 { uint32_t GetFirstIndexOffsetShaderRegister() const; uint32_t GetFirstIndexOffsetParameterIndex() const; + uint32_t GetNumWorkgroupsRegisterSpace() const; + uint32_t GetNumWorkgroupsShaderRegister() const; + uint32_t GetNumWorkgroupsParameterIndex() const; + ID3D12RootSignature* GetRootSignature() const; private: @@ -59,6 +66,7 @@ namespace dawn_native { namespace d3d12 { kMaxBindGroups> mDynamicRootParameterIndices; uint32_t mFirstIndexOffsetParameterIndex; + uint32_t mNumWorkgroupsParamterIndex; ComPtr mRootSignature; }; diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp index 2dafc9623f..89b5825977 100644 --- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp +++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp @@ -106,6 +106,9 @@ namespace dawn_native { namespace d3d12 { tint::transform::BindingRemapper::BindingPoints bindingPoints; tint::transform::BindingRemapper::AccessControls accessControls; bool isRobustnessEnabled; + bool usesNumWorkgroups; + uint32_t numWorkgroupsRegisterSpace; + uint32_t numWorkgroupsShaderRegister; // FXC/DXC common inputs bool disableWorkgroupInit; @@ -125,7 +128,7 @@ namespace dawn_native { namespace d3d12 { uint32_t compileFlags, const Device* device, const tint::Program* program, - const BindingInfoArray& moduleBindingInfo) { + const EntryPointMetadata& entryPoint) { Compiler compiler; uint64_t dxcVersion = 0; if (device->IsToggleEnabled(Toggle::UseDXC)) { @@ -145,6 +148,7 @@ namespace dawn_native { namespace d3d12 { // Tint AST to make the "bindings" decoration match the offset chosen by // d3d12::BindGroupLayout so that Tint produces HLSL with the correct registers // assigned to each interface variable. + const BindingInfoArray& moduleBindingInfo = entryPoint.bindings; for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group)); const auto& groupBindingInfo = moduleBindingInfo[group]; @@ -189,6 +193,9 @@ namespace dawn_native { namespace d3d12 { request.isRobustnessEnabled = device->IsRobustnessEnabled(); request.disableWorkgroupInit = device->IsToggleEnabled(Toggle::DisableWorkgroupInit); + request.usesNumWorkgroups = entryPoint.usesNumWorkgroups; + request.numWorkgroupsShaderRegister = layout->GetNumWorkgroupsShaderRegister(); + request.numWorkgroupsRegisterSpace = layout->GetNumWorkgroupsRegisterSpace(); request.fxcVersion = compiler == Compiler::FXC ? GetD3DCompilerVersion() : 0; request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0; request.deviceInfo = &device->GetDeviceInfo(); @@ -234,6 +241,10 @@ namespace dawn_native { namespace d3d12 { stream << " accessControls="; Serialize(stream, accessControls); + stream << " useNumWorkgroups=" << usesNumWorkgroups; + stream << " numWorkgroupsRegisterSpace=" << numWorkgroupsRegisterSpace; + stream << " numWorkgroupsShaderRegister=" << numWorkgroupsShaderRegister; + stream << " shaderModel=" << deviceInfo->shaderModel; stream << " disableWorkgroupInit=" << disableWorkgroupInit; stream << " isRobustnessEnabled=" << isRobustnessEnabled; @@ -423,6 +434,10 @@ namespace dawn_native { namespace d3d12 { tint::writer::hlsl::Options options; options.disable_workgroup_init = request.disableWorkgroupInit; + if (request.usesNumWorkgroups) { + options.root_constant_binding_point.group = request.numWorkgroupsRegisterSpace; + options.root_constant_binding_point.binding = request.numWorkgroupsShaderRegister; + } auto result = tint::writer::hlsl::Generate(&transformedProgram, options); if (!result.success) { errorStream << "Generator: " << result.error << std::endl; @@ -547,9 +562,9 @@ namespace dawn_native { namespace d3d12 { } ShaderCompilationRequest request; - DAWN_TRY_ASSIGN(request, ShaderCompilationRequest::Create( - entryPointName, stage, layout, compileFlags, device, program, - GetEntryPoint(entryPointName).bindings)); + DAWN_TRY_ASSIGN(request, ShaderCompilationRequest::Create(entryPointName, stage, layout, + compileFlags, device, program, + GetEntryPoint(entryPointName))); PersistentCacheKey shaderCacheKey; DAWN_TRY_ASSIGN(shaderCacheKey, request.CreateCacheKey()); diff --git a/src/tests/end2end/ComputeDispatchTests.cpp b/src/tests/end2end/ComputeDispatchTests.cpp index 7ec40762cb..1a8b163f5d 100644 --- a/src/tests/end2end/ComputeDispatchTests.cpp +++ b/src/tests/end2end/ComputeDispatchTests.cpp @@ -26,9 +26,30 @@ class ComputeDispatchTests : public DawnTest { DawnTest::SetUp(); // Write workgroup number into the output buffer if we saw the biggest dispatch - // This is a workaround since D3D12 doesn't have gl_NumWorkGroups // To make sure the dispatch was not called, write maximum u32 value for 0 dispatches - wgpu::ShaderModule module = utils::CreateShaderModule(device, R"( + wgpu::ShaderModule moduleForDispatch = utils::CreateShaderModule(device, R"( + [[block]] struct OutputBuf { + workGroups : vec3; + }; + + [[group(0), binding(0)]] var output : OutputBuf; + + [[stage(compute), workgroup_size(1, 1, 1)]] + fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3, + [[builtin(num_workgroups)]] dispatch : vec3) { + if (dispatch.x == 0u || dispatch.y == 0u || dispatch.z == 0u) { + output.workGroups = vec3(0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu); + return; + } + + if (all(GlobalInvocationID == dispatch - vec3(1u, 1u, 1u))) { + output.workGroups = dispatch; + } + })"); + + // TODO(dawn:839): use moduleForDispatch for indirect dispatch tests when D3D12 supports + // [[num_workgroups]] for indirect dispatch. + wgpu::ShaderModule moduleForDispatchIndirect = utils::CreateShaderModule(device, R"( [[block]] struct InputBuf { expectedDispatch : vec3; }; @@ -54,9 +75,12 @@ class ComputeDispatchTests : public DawnTest { })"); wgpu::ComputePipelineDescriptor csDesc; - csDesc.compute.module = module; + csDesc.compute.module = moduleForDispatch; csDesc.compute.entryPoint = "main"; - pipeline = device.CreateComputePipeline(&csDesc); + pipelineForDispatch = device.CreateComputePipeline(&csDesc); + + csDesc.compute.module = moduleForDispatchIndirect; + pipelineForDispatchIndirect = device.CreateComputePipeline(&csDesc); } void DirectTest(uint32_t x, uint32_t y, uint32_t z) { @@ -66,23 +90,18 @@ class ComputeDispatchTests : public DawnTest { wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst, kSentinelData); - std::initializer_list expectedBufferData{x, y, z}; - wgpu::Buffer expectedBuffer = utils::CreateBufferFromData( - device, wgpu::BufferUsage::Uniform, expectedBufferData); - // Set up bind group and issue dispatch wgpu::BindGroup bindGroup = - utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0), + utils::MakeBindGroup(device, pipelineForDispatch.GetBindGroupLayout(0), { - {0, expectedBuffer, 0, 3 * sizeof(uint32_t)}, - {1, dst, 0, 3 * sizeof(uint32_t)}, + {0, dst, 0, 3 * sizeof(uint32_t)}, }); wgpu::CommandBuffer commands; { wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); - pass.SetPipeline(pipeline); + pass.SetPipeline(pipelineForDispatch); pass.SetBindGroup(0, bindGroup); pass.Dispatch(x, y, z); pass.EndPass(); @@ -93,7 +112,7 @@ class ComputeDispatchTests : public DawnTest { queue.Submit(1, &commands); std::vector expected = - x == 0 || y == 0 || z == 0 ? kSentinelData : expectedBufferData; + x == 0 || y == 0 || z == 0 ? kSentinelData : std::initializer_list{x, y, z}; // Verify the dispatch got called if all group counts are not zero EXPECT_BUFFER_U32_RANGE_EQ(&expected[0], dst, 0, 3); @@ -118,7 +137,7 @@ class ComputeDispatchTests : public DawnTest { // Set up bind group and issue dispatch wgpu::BindGroup bindGroup = - utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0), + utils::MakeBindGroup(device, pipelineForDispatchIndirect.GetBindGroupLayout(0), { {0, expectedBuffer, 0, 3 * sizeof(uint32_t)}, {1, dst, 0, 3 * sizeof(uint32_t)}, @@ -128,7 +147,7 @@ class ComputeDispatchTests : public DawnTest { { wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); - pass.SetPipeline(pipeline); + pass.SetPipeline(pipelineForDispatchIndirect); pass.SetBindGroup(0, bindGroup); pass.DispatchIndirect(indirectBuffer, indirectOffset); pass.EndPass(); @@ -153,7 +172,8 @@ class ComputeDispatchTests : public DawnTest { } private: - wgpu::ComputePipeline pipeline; + wgpu::ComputePipeline pipelineForDispatch; + wgpu::ComputePipeline pipelineForDispatchIndirect; }; // Test basic direct