D3D12: Factor SetVertexBuffer tracking to match other tracking classes

Bug: dawn:201
Change-Id: I711e93a706b5043318263b203d3f3dc7f1a675bb
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/11000
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
This commit is contained in:
Austin Eng 2019-09-09 21:09:11 +00:00 committed by Commit Bot service account
parent 882ff72742
commit 8e37315012
2 changed files with 85 additions and 83 deletions

View File

@ -38,6 +38,7 @@
namespace dawn_native { namespace d3d12 {
namespace {
DXGI_FORMAT DXGIIndexFormat(dawn::IndexFormat format) {
switch (format) {
case dawn::IndexFormat::Uint16:
@ -63,6 +64,12 @@ namespace dawn_native { namespace d3d12 {
return false;
}
// Bundle of render-target descriptor handles built per render pass.
// Presumably these are the arguments later handed to
// ID3D12GraphicsCommandList::OMSetRenderTargets -- confirm at the call site.
// Every member is zero/empty-initialized by default.
struct OMSetRenderTargetArgs {
// Presumably the number of leading entries in RTVs that are in use.
unsigned int numRTVs = 0;
// CPU descriptor handles for color attachments; capacity is kMaxColorAttachments.
std::array<D3D12_CPU_DESCRIPTOR_HANDLE, kMaxColorAttachments> RTVs = {};
// CPU descriptor handle for the depth-stencil view (zero-initialized by default).
D3D12_CPU_DESCRIPTOR_HANDLE dsv = {};
};
} // anonymous namespace
class BindGroupStateTracker {
@ -291,12 +298,6 @@ namespace dawn_native { namespace d3d12 {
Device* mDevice;
};
// Bundle of render-target descriptor handles built per render pass.
// Presumably these are the arguments later handed to
// ID3D12GraphicsCommandList::OMSetRenderTargets -- confirm at the call site.
// Every member is zero/empty-initialized by default.
struct OMSetRenderTargetArgs {
// Presumably the number of leading entries in RTVs that are in use.
unsigned int numRTVs = 0;
// CPU descriptor handles for color attachments; capacity is kMaxColorAttachments.
std::array<D3D12_CPU_DESCRIPTOR_HANDLE, kMaxColorAttachments> RTVs = {};
// CPU descriptor handle for the depth-stencil view (zero-initialized by default).
D3D12_CPU_DESCRIPTOR_HANDLE dsv = {};
};
class RenderPassDescriptorHeapTracker {
public:
RenderPassDescriptorHeapTracker(Device* device) : mDevice(device) {
@ -325,8 +326,8 @@ namespace dawn_native { namespace d3d12 {
}
}
// TODO(jiawei.shao@intel.com): use hash map <RenderPass, OMSetRenderTargetArgs> as cache to
// avoid redundant RTV and DSV memory allocations.
// TODO(jiawei.shao@intel.com): use hash map <RenderPass, OMSetRenderTargetArgs> as
// cache to avoid redundant RTV and DSV memory allocations.
OMSetRenderTargetArgs GetSubpassOMSetRenderTargetArgs(BeginRenderPassCmd* renderPass) {
OMSetRenderTargetArgs args = {};
@ -380,6 +381,73 @@ namespace dawn_native { namespace d3d12 {
namespace {
// Caches the D3D12 vertex buffer views described by SetVertexBuffers commands
// and defers the actual IASetVertexBuffers call until Apply. Only a single
// [start, end) dirty slot range is tracked, so one Apply issues at most one
// IASetVertexBuffers call.
class VertexBufferTracker {
public:
// Records `count` bindings: buffers[i] / offsets[i] describe slot
// startSlot + i. Updates the cached views and widens the dirty range; nothing
// is written to a command list here.
// NOTE(review): startSlot + count is not checked against kMaxVertexBuffers
// (the size of mD3D12BufferViews) -- presumably validated by the caller.
void OnSetVertexBuffers(uint32_t startSlot,
uint32_t count,
Ref<BufferBase>* buffers,
uint64_t* offsets) {
// Grow the dirty range to the union of the old range and the new slots.
mStartSlot = std::min(mStartSlot, startSlot);
mEndSlot = std::max(mEndSlot, startSlot + count);
for (uint32_t i = 0; i < count; ++i) {
Buffer* buffer = ToBackend(buffers[i].Get());
auto* d3d12BufferView = &mD3D12BufferViews[startSlot + i];
// The view starts `offsets[i]` bytes into the buffer and covers the rest of it.
d3d12BufferView->BufferLocation = buffer->GetVA() + offsets[i];
// NOTE(review): offsets[i] is 64-bit; presumably the remaining size always
// fits the width of SizeInBytes and offsets[i] <= GetSize() -- confirm.
d3d12BufferView->SizeInBytes = buffer->GetSize() - offsets[i];
// The bufferView stride is set based on the input state before a draw.
}
}
// Flushes pending vertex buffer state to `commandList` for a draw that uses
// `renderPipeline` (must be non-null). Does nothing when the dirty range is
// empty and no slot needed a stride refresh.
void Apply(ID3D12GraphicsCommandList* commandList,
const RenderPipeline* renderPipeline) {
ASSERT(renderPipeline != nullptr);
std::bitset<kMaxVertexBuffers> inputsMask = renderPipeline->GetInputsSetMask();
// Work on local copies so the pipeline-driven extension below is only applied
// to this flush.
uint32_t startSlot = mStartSlot;
uint32_t endSlot = mEndSlot;
// If the input state has changed, we need to update the StrideInBytes
// for the D3D12 buffer views. We also need to extend the dirty range to
// touch all these slots because the stride may have changed.
if (mLastAppliedRenderPipeline != renderPipeline) {
mLastAppliedRenderPipeline = renderPipeline;
for (uint32_t slot : IterateBitSet(inputsMask)) {
startSlot = std::min(startSlot, slot);
endSlot = std::max(endSlot, slot + 1);
mD3D12BufferViews[slot].StrideInBytes =
renderPipeline->GetInput(slot).stride;
}
}
// Empty range: nothing was set since the last flush and no stride changed.
if (endSlot <= startSlot) {
return;
}
// mD3D12BufferViews is kept up to date with the most recent data passed
// to SetVertexBuffers. This makes it correct to only track the start
// and end of the dirty range. When Apply is called,
// we will at worst set non-dirty vertex buffers in duplicate.
uint32_t count = endSlot - startSlot;
commandList->IASetVertexBuffers(startSlot, count, &mD3D12BufferViews[startSlot]);
// Reset to the empty-range sentinel (start > end) so the next Apply is a no-op
// unless OnSetVertexBuffers or a pipeline change dirties slots again.
mStartSlot = kMaxVertexBuffers;
mEndSlot = 0;
}
private:
// Pipeline whose input strides were most recently written into
// mD3D12BufferViews by Apply; nullptr until the first Apply.
const RenderPipeline* mLastAppliedRenderPipeline = nullptr;
// mStartSlot and mEndSlot indicate the [start, end) range of dirty vertex
// buffers. If there are multiple calls to SetVertexBuffers, the start and end
// represent the union of the dirty ranges (the union may have non-dirty
// data in the middle of the range). start = kMaxVertexBuffers, end = 0 means
// the range is empty.
uint32_t mStartSlot = kMaxVertexBuffers;
uint32_t mEndSlot = 0;
// Cached view for every slot, indexed by slot number.
std::array<D3D12_VERTEX_BUFFER_VIEW, kMaxVertexBuffers> mD3D12BufferViews = {};
};
void AllocateAndSetDescriptorHeaps(Device* device,
BindGroupStateTracker* bindingTracker,
RenderPassDescriptorHeapTracker* renderPassTracker,
@ -719,47 +787,6 @@ namespace dawn_native { namespace d3d12 {
DAWN_ASSERT(renderPassTracker.IsHeapAllocationCompleted());
}
// Issues one IASetVertexBuffers call covering the dirty slot range stored in
// `vertexBuffersInfo`, refreshing per-slot strides first when `renderPipeline`
// differs from the last one flushed. Both pointers must be non-null. On
// return the dirty range in `vertexBuffersInfo` is reset to empty.
void CommandBuffer::FlushSetVertexBuffers(ComPtr<ID3D12GraphicsCommandList> commandList,
VertexBuffersInfo* vertexBuffersInfo,
const RenderPipeline* renderPipeline) {
DAWN_ASSERT(vertexBuffersInfo != nullptr);
DAWN_ASSERT(renderPipeline != nullptr);
auto inputsMask = renderPipeline->GetInputsSetMask();
// Local copies: the pipeline-driven extension below applies to this flush only.
uint32_t startSlot = vertexBuffersInfo->startSlot;
uint32_t endSlot = vertexBuffersInfo->endSlot;
// If the input state has changed, we need to update the StrideInBytes
// for the D3D12 buffer views. We also need to extend the dirty range to
// touch all these slots because the stride may have changed.
if (vertexBuffersInfo->lastRenderPipeline != renderPipeline) {
vertexBuffersInfo->lastRenderPipeline = renderPipeline;
for (uint32_t slot : IterateBitSet(inputsMask)) {
startSlot = std::min(startSlot, slot);
endSlot = std::max(endSlot, slot + 1);
vertexBuffersInfo->d3d12BufferViews[slot].StrideInBytes =
renderPipeline->GetInput(slot).stride;
}
}
// Empty range: nothing was set since the last flush and no stride changed.
if (endSlot <= startSlot) {
return;
}
// d3d12BufferViews is kept up to date with the most recent data passed
// to SetVertexBuffers. This makes it correct to only track the start
// and end of the dirty range. When FlushSetVertexBuffers is called,
// we will at worst set non-dirty vertex buffers in duplicate.
uint32_t count = endSlot - startSlot;
commandList->IASetVertexBuffers(startSlot, count,
&vertexBuffersInfo->d3d12BufferViews[startSlot]);
// Reset to the empty-range sentinel (start > end).
vertexBuffersInfo->startSlot = kMaxVertexBuffers;
vertexBuffersInfo->endSlot = 0;
}
void CommandBuffer::RecordComputePass(ComPtr<ID3D12GraphicsCommandList> commandList,
BindGroupStateTracker* bindingTracker) {
PipelineLayout* lastLayout = nullptr;
@ -969,14 +996,14 @@ namespace dawn_native { namespace d3d12 {
RenderPipeline* lastPipeline = nullptr;
PipelineLayout* lastLayout = nullptr;
VertexBuffersInfo vertexBuffersInfo = {};
VertexBufferTracker vertexBufferTracker = {};
auto EncodeRenderBundleCommand = [&](CommandIterator* iter, Command type) {
switch (type) {
case Command::Draw: {
DrawCmd* draw = iter->NextCommand<DrawCmd>();
FlushSetVertexBuffers(commandList, &vertexBuffersInfo, lastPipeline);
vertexBufferTracker.Apply(commandList.Get(), lastPipeline);
commandList->DrawInstanced(draw->vertexCount, draw->instanceCount,
draw->firstVertex, draw->firstInstance);
} break;
@ -984,7 +1011,7 @@ namespace dawn_native { namespace d3d12 {
case Command::DrawIndexed: {
DrawIndexedCmd* draw = iter->NextCommand<DrawIndexedCmd>();
FlushSetVertexBuffers(commandList, &vertexBuffersInfo, lastPipeline);
vertexBufferTracker.Apply(commandList.Get(), lastPipeline);
commandList->DrawIndexedInstanced(draw->indexCount, draw->instanceCount,
draw->firstIndex, draw->baseVertex,
draw->firstInstance);
@ -993,7 +1020,7 @@ namespace dawn_native { namespace d3d12 {
case Command::DrawIndirect: {
DrawIndirectCmd* draw = iter->NextCommand<DrawIndirectCmd>();
FlushSetVertexBuffers(commandList, &vertexBuffersInfo, lastPipeline);
vertexBufferTracker.Apply(commandList.Get(), lastPipeline);
Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
ComPtr<ID3D12CommandSignature> signature =
ToBackend(GetDevice())->GetDrawIndirectSignature();
@ -1005,7 +1032,7 @@ namespace dawn_native { namespace d3d12 {
case Command::DrawIndexedIndirect: {
DrawIndexedIndirectCmd* draw = iter->NextCommand<DrawIndexedIndirectCmd>();
FlushSetVertexBuffers(commandList, &vertexBuffersInfo, lastPipeline);
vertexBufferTracker.Apply(commandList.Get(), lastPipeline);
Buffer* buffer = ToBackend(draw->indirectBuffer.Get());
ComPtr<ID3D12CommandSignature> signature =
ToBackend(GetDevice())->GetDrawIndexedIndirectSignature();
@ -1096,22 +1123,11 @@ namespace dawn_native { namespace d3d12 {
case Command::SetVertexBuffers: {
SetVertexBuffersCmd* cmd = iter->NextCommand<SetVertexBuffersCmd>();
auto buffers = iter->NextData<Ref<BufferBase>>(cmd->count);
auto offsets = iter->NextData<uint64_t>(cmd->count);
Ref<BufferBase>* buffers = iter->NextData<Ref<BufferBase>>(cmd->count);
uint64_t* offsets = iter->NextData<uint64_t>(cmd->count);
vertexBuffersInfo.startSlot =
std::min(vertexBuffersInfo.startSlot, cmd->startSlot);
vertexBuffersInfo.endSlot =
std::max(vertexBuffersInfo.endSlot, cmd->startSlot + cmd->count);
for (uint32_t i = 0; i < cmd->count; ++i) {
Buffer* buffer = ToBackend(buffers[i].Get());
auto* d3d12BufferView =
&vertexBuffersInfo.d3d12BufferViews[cmd->startSlot + i];
d3d12BufferView->BufferLocation = buffer->GetVA() + offsets[i];
d3d12BufferView->SizeInBytes = buffer->GetSize() - offsets[i];
// The bufferView stride is set based on the input state before a draw.
}
vertexBufferTracker.OnSetVertexBuffers(cmd->startSlot, cmd->count, buffers,
offsets);
} break;
default:

View File

@ -35,17 +35,6 @@ namespace dawn_native { namespace d3d12 {
class RenderPassDescriptorHeapTracker;
class RenderPipeline;
// Plain state holder for pending vertex buffer bindings: the cached D3D12
// views, the dirty slot range, and the pipeline last used when flushing.
struct VertexBuffersInfo {
// Pipeline whose strides were most recently written into d3d12BufferViews;
// nullptr until the first flush.
const RenderPipeline* lastRenderPipeline = nullptr;
// startSlot and endSlot indicate the range of dirty vertex buffers.
// If there are multiple calls to SetVertexBuffers, the start and end
// represent the union of the dirty ranges (the union may have non-dirty
// data in the middle of the range). startSlot = kMaxVertexBuffers with
// endSlot = 0 is the initial, empty range.
uint32_t startSlot = kMaxVertexBuffers;
uint32_t endSlot = 0;
// Cached view for every slot, indexed by slot number.
std::array<D3D12_VERTEX_BUFFER_VIEW, kMaxVertexBuffers> d3d12BufferViews = {};
};
class CommandBuffer : public CommandBufferBase {
public:
CommandBuffer(CommandEncoderBase* encoder, const CommandBufferDescriptor* descriptor);
@ -54,9 +43,6 @@ namespace dawn_native { namespace d3d12 {
void RecordCommands(ComPtr<ID3D12GraphicsCommandList> commandList, uint32_t indexInSubmit);
private:
void FlushSetVertexBuffers(ComPtr<ID3D12GraphicsCommandList> commandList,
VertexBuffersInfo* vertexBuffersInfo,
const RenderPipeline* lastRenderPipeline);
void RecordComputePass(ComPtr<ID3D12GraphicsCommandList> commandList,
BindGroupStateTracker* bindingTracker);
void RecordRenderPass(ComPtr<ID3D12GraphicsCommandList> commandList,