d3d12: track graphics/compute state independently

Fixes a bug where Dawn incorrectly did not re-apply state
when transitioning between compute and render passes. If
a compute and render pipeline share the same pipeline layout,
all of the resources for the graphics pipeline need to be rebound
since the graphics state in D3D12 is disjoint from the compute
state.

Fixed: dawn:1689
Change-Id: I7d25a1c7954039c4130e67b682ebc05324353e9a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/124540
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Loko Kung <lokokung@google.com>
This commit is contained in:
Austin Eng 2023-03-17 18:42:52 +00:00 committed by Dawn LUCI CQ
parent b4c5e0d32a
commit a66fa9b06f
2 changed files with 130 additions and 29 deletions

View File

@ -383,18 +383,20 @@ MaybeError TransitionAndClearForSyncScope(CommandRecordingContext* commandContex
} // anonymous namespace
class DescriptorHeapState;
class BindGroupStateTracker : public BindGroupTrackerBase<false, uint64_t> {
using Base = BindGroupTrackerBase;
public:
explicit BindGroupStateTracker(Device* device)
BindGroupStateTracker(Device* device, DescriptorHeapState* heapState, bool inCompute)
: BindGroupTrackerBase(),
mDevice(device),
mHeapState(heapState),
mInCompute(inCompute),
mViewAllocator(device->GetViewShaderVisibleDescriptorAllocator()),
mSamplerAllocator(device->GetSamplerShaderVisibleDescriptorAllocator()) {}
void SetInComputePass(bool inCompute_) { mInCompute = inCompute_; }
MaybeError Apply(CommandRecordingContext* commandContext) {
BeforeApply();
@ -454,20 +456,9 @@ class BindGroupStateTracker : public BindGroupTrackerBase<false, uint64_t> {
return {};
}
void SetID3D12DescriptorHeaps(ID3D12GraphicsCommandList* commandList) {
ASSERT(commandList != nullptr);
std::array<ID3D12DescriptorHeap*, 2> descriptorHeaps = {
mViewAllocator->GetShaderVisibleHeap(), mSamplerAllocator->GetShaderVisibleHeap()};
ASSERT(descriptorHeaps[0] != nullptr);
ASSERT(descriptorHeaps[1] != nullptr);
commandList->SetDescriptorHeaps(descriptorHeaps.size(), descriptorHeaps.data());
void ResetRootSamplerTables() { mBoundRootSamplerTables = {}; }
// Descriptor table state is undefined at the beginning of a command list and after
// descriptor heaps are changed on a command list. Invalidate the root sampler tables to
// reset the root descriptor table for samplers, otherwise the shader cannot access the
// descriptor heaps.
mBoundRootSamplerTables = {};
}
void SetID3D12DescriptorHeaps(ID3D12GraphicsCommandList* commandList);
private:
void UpdateRootSignatureIfNecessary(ID3D12GraphicsCommandList* commandList) {
@ -480,7 +471,7 @@ class BindGroupStateTracker : public BindGroupTrackerBase<false, uint64_t> {
ToBackend(mPipelineLayout)->GetRootSignature());
}
// Invalidate the root sampler tables previously set in the root signature.
mBoundRootSamplerTables = {};
ResetRootSamplerTables();
}
}
@ -607,6 +598,7 @@ class BindGroupStateTracker : public BindGroupTrackerBase<false, uint64_t> {
}
Device* mDevice;
DescriptorHeapState* mHeapState;
bool mInCompute = false;
@ -617,6 +609,43 @@ class BindGroupStateTracker : public BindGroupTrackerBase<false, uint64_t> {
ShaderVisibleDescriptorAllocator* mSamplerAllocator;
};
class DescriptorHeapState {
public:
explicit DescriptorHeapState(Device* device)
: mDevice(device),
mComputeBindingTracker(device, this, true),
mGraphicsBindingTracker(device, this, false) {}
void SetID3D12DescriptorHeaps(ID3D12GraphicsCommandList* commandList) {
ASSERT(commandList != nullptr);
std::array<ID3D12DescriptorHeap*, 2> descriptorHeaps = {
mDevice->GetViewShaderVisibleDescriptorAllocator()->GetShaderVisibleHeap(),
mDevice->GetSamplerShaderVisibleDescriptorAllocator()->GetShaderVisibleHeap()};
ASSERT(descriptorHeaps[0] != nullptr);
ASSERT(descriptorHeaps[1] != nullptr);
commandList->SetDescriptorHeaps(descriptorHeaps.size(), descriptorHeaps.data());
// Descriptor table state is undefined at the beginning of a command list and after
// descriptor heaps are changed on a command list. Invalidate the root sampler tables to
// reset the root descriptor table for samplers, otherwise the shader cannot access the
// descriptor heaps.
mComputeBindingTracker.ResetRootSamplerTables();
mGraphicsBindingTracker.ResetRootSamplerTables();
}
BindGroupStateTracker* GetComputeBindingTracker() { return &mComputeBindingTracker; }
BindGroupStateTracker* GetGraphicsBindingTracker() { return &mGraphicsBindingTracker; }
private:
Device* mDevice;
BindGroupStateTracker mComputeBindingTracker;
BindGroupStateTracker mGraphicsBindingTracker;
};
void BindGroupStateTracker::SetID3D12DescriptorHeaps(ID3D12GraphicsCommandList* commandList) {
mHeapState->SetID3D12DescriptorHeaps(commandList);
}
namespace {
class VertexBufferTracker {
public:
@ -726,13 +755,12 @@ CommandBuffer::CommandBuffer(CommandEncoder* encoder, const CommandBufferDescrip
MaybeError CommandBuffer::RecordCommands(CommandRecordingContext* commandContext) {
Device* device = ToBackend(GetDevice());
BindGroupStateTracker bindingTracker(device);
ID3D12GraphicsCommandList* commandList = commandContext->GetCommandList();
DescriptorHeapState descriptorHeapState(device);
// Make sure we use the correct descriptors for this command list. Could be done once per
// actual command list but here is ok because there should be few command buffers.
bindingTracker.SetID3D12DescriptorHeaps(commandList);
ID3D12GraphicsCommandList* commandList = commandContext->GetCommandList();
descriptorHeapState.SetID3D12DescriptorHeaps(commandList);
size_t nextComputePassNumber = 0;
size_t nextRenderPassNumber = 0;
@ -743,11 +771,9 @@ MaybeError CommandBuffer::RecordCommands(CommandRecordingContext* commandContext
case Command::BeginComputePass: {
BeginComputePassCmd* cmd = mCommands.NextCommand<BeginComputePassCmd>();
bindingTracker.SetInComputePass(true);
DAWN_TRY(
RecordComputePass(commandContext, &bindingTracker, cmd,
GetResourceUsages().computePasses[nextComputePassNumber]));
DAWN_TRY(RecordComputePass(
commandContext, descriptorHeapState.GetComputeBindingTracker(), cmd,
GetResourceUsages().computePasses[nextComputePassNumber]));
nextComputePassNumber++;
break;
@ -761,11 +787,11 @@ MaybeError CommandBuffer::RecordCommands(CommandRecordingContext* commandContext
DAWN_TRY(TransitionAndClearForSyncScope(
commandContext, GetResourceUsages().renderPasses[nextRenderPassNumber],
&passHasUAV));
bindingTracker.SetInComputePass(false);
LazyClearRenderPassAttachments(beginRenderPassCmd);
DAWN_TRY(RecordRenderPass(commandContext, &bindingTracker, beginRenderPassCmd,
passHasUAV));
DAWN_TRY(RecordRenderPass(commandContext,
descriptorHeapState.GetGraphicsBindingTracker(),
beginRenderPassCmd, passHasUAV));
nextRenderPassNumber++;
break;

View File

@ -16,6 +16,7 @@
#include "dawn/common/Constants.h"
#include "dawn/tests/DawnTest.h"
#include "dawn/utils/WGPUHelpers.h"
class PipelineLayoutTests : public DawnTest {};
@ -68,6 +69,80 @@ TEST_P(PipelineLayoutTests, DynamicBuffersOverflow) {
device.CreatePipelineLayout(&descriptor);
}
// Regression test for crbug.com/dawn/1689. Test using a compute pass and a render pass,
// where the two pipelines have the same pipeline layout.
TEST_P(PipelineLayoutTests, ComputeAndRenderSamePipelineLayout) {
wgpu::TextureFormat format = wgpu::TextureFormat::RGBA8Unorm;
wgpu::ShaderModule shaderModule = utils::CreateShaderModule(device, R"(
@compute @workgroup_size(8, 8)
fn computeMain() {}
@vertex fn vertexMain() -> @builtin(position) vec4f {
return vec4f(0.0);
}
@fragment fn fragmentMain() -> @location(0) vec4f {
return vec4f(0.0, 0.0, 0.0, 1.0);
}
)");
wgpu::BindGroupLayout bgl = utils::MakeBindGroupLayout(
device, {{0, wgpu::ShaderStage::Compute, wgpu::BufferBindingType::Uniform}});
wgpu::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, &bgl);
wgpu::ComputePipeline computePipeline;
{
wgpu::ComputePipelineDescriptor desc = {};
desc.layout = pl;
desc.compute.module = shaderModule;
desc.compute.entryPoint = "computeMain";
computePipeline = device.CreateComputePipeline(&desc);
}
wgpu::RenderPipeline renderPipeline;
{
wgpu::RenderPipelineDescriptor desc = {};
desc.layout = pl;
desc.vertex.module = shaderModule;
desc.vertex.entryPoint = "vertexMain";
wgpu::FragmentState fragment = {};
desc.fragment = &fragment;
fragment.module = shaderModule;
fragment.entryPoint = "fragmentMain";
fragment.targetCount = 1;
wgpu::ColorTargetState colorTargetState = {};
colorTargetState.format = format;
fragment.targets = &colorTargetState;
renderPipeline = device.CreateRenderPipeline(&desc);
}
wgpu::Buffer buffer = utils::CreateBufferFromData(device, wgpu::BufferUsage::Uniform, {1});
wgpu::BindGroup bg0 = utils::MakeBindGroup(device, bgl, {{0, buffer}});
wgpu::BindGroup bg1 = utils::MakeBindGroup(device, bgl, {{0, buffer}});
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
{
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.SetPipeline(computePipeline);
pass.SetBindGroup(0, bg0);
pass.DispatchWorkgroups(1);
pass.End();
}
{
utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, 4, 4, format);
wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo);
pass.SetPipeline(renderPipeline);
pass.SetBindGroup(0, bg1);
pass.Draw(1);
pass.End();
}
wgpu::CommandBuffer commands = encoder.Finish();
queue.Submit(1, &commands);
}
DAWN_INSTANTIATE_TEST(PipelineLayoutTests,
D3D12Backend(),
MetalBackend(),