diff --git a/examples/ComputeBoids.cpp b/examples/ComputeBoids.cpp index 518bbbc7c2..1d6a2e449f 100644 --- a/examples/ComputeBoids.cpp +++ b/examples/ComputeBoids.cpp @@ -264,11 +264,13 @@ void initCommandBuffers() { auto& bufferSrc = particleBuffers[i]; auto& bufferDst = particleBuffers[(i + 1) % 2]; commandBuffers[i] = device.CreateCommandBufferBuilder() - .SetPipeline(updatePipeline) - .TransitionBufferUsage(bufferSrc, nxt::BufferUsageBit::Storage) - .TransitionBufferUsage(bufferDst, nxt::BufferUsageBit::Storage) - .SetBindGroup(0, updateBGs[i]) - .Dispatch(kNumParticles, 1, 1) + .BeginComputePass() + .SetPipeline(updatePipeline) + .TransitionBufferUsage(bufferSrc, nxt::BufferUsageBit::Storage) + .TransitionBufferUsage(bufferDst, nxt::BufferUsageBit::Storage) + .SetBindGroup(0, updateBGs[i]) + .Dispatch(kNumParticles, 1, 1) + .EndComputePass() .BeginRenderPass(renderpass, framebuffer) .BeginRenderSubpass() diff --git a/examples/HelloCompute.cpp b/examples/HelloCompute.cpp index 0d3d58c896..9e611ddcd4 100644 --- a/examples/HelloCompute.cpp +++ b/examples/HelloCompute.cpp @@ -126,10 +126,12 @@ void init() { void frame() { nxt::CommandBuffer commands = device.CreateCommandBufferBuilder() - .SetPipeline(computePipeline) - .TransitionBufferUsage(buffer, nxt::BufferUsageBit::Storage) - .SetBindGroup(0, computeBindGroup) - .Dispatch(1, 1, 1) + .BeginComputePass() + .SetPipeline(computePipeline) + .TransitionBufferUsage(buffer, nxt::BufferUsageBit::Storage) + .SetBindGroup(0, computeBindGroup) + .Dispatch(1, 1, 1) + .EndComputePass() .BeginRenderPass(renderpass, framebuffer) .BeginRenderSubpass() diff --git a/next.json b/next.json index b038e560c0..ecbe42fb1b 100644 --- a/next.json +++ b/next.json @@ -247,6 +247,9 @@ "name": "get result", "returns": "command buffer" }, + { + "name": "begin compute pass" + }, { "name": "begin render pass", "args": [ @@ -340,6 +343,9 @@ {"name": "first instance", "type": "uint32_t"} ] }, + { + "name": "end compute pass" + }, { "name": "end render pass" }, diff --git a/src/backend/CommandBuffer.cpp b/src/backend/CommandBuffer.cpp index b182662128..cf68b145dc 100644 --- a/src/backend/CommandBuffer.cpp +++ b/src/backend/CommandBuffer.cpp @@ -105,6 +105,12 @@ namespace backend { Command type; while(commands->NextCommandId(&type)) { switch (type) { + case Command::BeginComputePass: + { + BeginComputePassCmd* begin = commands->NextCommand(); + begin->~BeginComputePassCmd(); + } + break; case Command::BeginRenderPass: { BeginRenderPassCmd* begin = commands->NextCommand(); @@ -153,6 +159,12 @@ namespace backend { draw->~DrawElementsCmd(); } break; + case Command::EndComputePass: + { + EndComputePassCmd* cmd = commands->NextCommand(); + cmd->~EndComputePassCmd(); + } + break; case Command::EndRenderPass: { EndRenderPassCmd* cmd = commands->NextCommand(); @@ -226,6 +238,10 @@ namespace backend { void SkipCommand(CommandIterator* commands, Command type) { switch (type) { + case Command::BeginComputePass: + commands->NextCommand(); + break; + case Command::BeginRenderPass: commands->NextCommand(); break; @@ -258,6 +274,10 @@ namespace backend { commands->NextCommand(); break; + case Command::EndComputePass: + commands->NextCommand(); + break; + case Command::EndRenderPass: commands->NextCommand(); break; @@ -323,6 +343,15 @@ namespace backend { Command type; while (iterator.NextCommandId(&type)) { switch (type) { + case Command::BeginComputePass: + { + iterator.NextCommand(); + if (!state->BeginComputePass()) { + return false; + } + } + break; + case Command::BeginRenderPass: { BeginRenderPassCmd* cmd = iterator.NextCommand(); @@ -424,6 +453,15 @@ namespace backend { } break; + case Command::EndComputePass: + { + iterator.NextCommand(); + if (!state->EndComputePass()) { + return false; + } + } + break; + case Command::EndRenderPass: { iterator.NextCommand(); @@ -542,8 +580,8 @@ namespace backend { return device->CreateCommandBuffer(this); } - void CommandBufferBuilder::BeginRenderSubpass() { - allocator.Allocate(Command::BeginRenderSubpass); + void CommandBufferBuilder::BeginComputePass() { + allocator.Allocate(Command::BeginComputePass); } void CommandBufferBuilder::BeginRenderPass(RenderPassBase* renderPass, FramebufferBase* framebuffer) { @@ -553,6 +591,10 @@ namespace backend { cmd->framebuffer = framebuffer; } + void CommandBufferBuilder::BeginRenderSubpass() { + allocator.Allocate(Command::BeginRenderSubpass); + } + void CommandBufferBuilder::CopyBufferToBuffer(BufferBase* source, uint32_t sourceOffset, BufferBase* destination, uint32_t destinationOffset, uint32_t size) { CopyBufferToBufferCmd* copy = allocator.Allocate(Command::CopyBufferToBuffer); new(copy) CopyBufferToBufferCmd; @@ -623,6 +665,10 @@ namespace backend { draw->firstInstance = firstInstance; } + void CommandBufferBuilder::EndComputePass() { + allocator.Allocate(Command::EndComputePass); + } + void CommandBufferBuilder::EndRenderPass() { allocator.Allocate(Command::EndRenderPass); } diff --git a/src/backend/CommandBuffer.h b/src/backend/CommandBuffer.h index 4b0c90def2..3caeda721d 100644 --- a/src/backend/CommandBuffer.h +++ b/src/backend/CommandBuffer.h @@ -59,6 +59,7 @@ namespace backend { CommandIterator AcquireCommands(); // NXT API + void BeginComputePass(); void BeginRenderPass(RenderPassBase* renderPass, FramebufferBase* framebuffer); void BeginRenderSubpass(); void CopyBufferToBuffer(BufferBase* source, uint32_t sourceOffset, BufferBase* destination, uint32_t destinationOffset, uint32_t size); @@ -71,6 +72,7 @@ namespace backend { void Dispatch(uint32_t x, uint32_t y, uint32_t z); void DrawArrays(uint32_t vertexCount, uint32_t instanceCount, uint32_t firstVertex, uint32_t firstInstance); void DrawElements(uint32_t vertexCount, uint32_t instanceCount, uint32_t firstIndex, uint32_t firstInstance); + void EndComputePass(); void EndRenderPass(); void EndRenderSubpass(); void SetPushConstants(nxt::ShaderStageBit stage, uint32_t offset, uint32_t count, const void* data); diff --git a/src/backend/CommandBufferStateTracker.cpp b/src/backend/CommandBufferStateTracker.cpp index 122ddb1647..1b3a332598 100644 --- a/src/backend/CommandBufferStateTracker.cpp +++ b/src/backend/CommandBufferStateTracker.cpp @@ -61,7 +61,7 @@ namespace backend { bool CommandBufferStateTracker::ValidateCanDispatch() { constexpr ValidationAspects requiredAspects = - 1 << VALIDATION_ASPECT_COMPUTE_PIPELINE | + 1 << VALIDATION_ASPECT_COMPUTE_PIPELINE | // implicitly requires COMPUTE_PASS 1 << VALIDATION_ASPECT_BIND_GROUPS; if ((requiredAspects & ~aspects).none()) { // Fast return-true path if everything is good @@ -83,7 +83,7 @@ namespace backend { bool CommandBufferStateTracker::ValidateCanDrawArrays() { // TODO(kainino@chromium.org): Check for a current render pass constexpr ValidationAspects requiredAspects = - 1 << VALIDATION_ASPECT_RENDER_PIPELINE | + 1 << VALIDATION_ASPECT_RENDER_PIPELINE | // implicitly requires RENDER_SUBPASS 1 << VALIDATION_ASPECT_BIND_GROUPS | 1 << VALIDATION_ASPECT_VERTEX_BUFFERS; if ((requiredAspects & ~aspects).none()) { @@ -118,6 +118,29 @@ namespace backend { builder->HandleError("Can't end command buffer with an active render pass"); return false; } + if (aspects[VALIDATION_ASPECT_COMPUTE_PASS]) { + builder->HandleError("Can't end command buffer with an active compute pass"); + return false; + } + return true; + } + + bool CommandBufferStateTracker::BeginComputePass() { + if (currentRenderPass != nullptr) { + builder->HandleError("Cannot begin a compute pass while a render pass is active"); + return false; + } + aspects.set(VALIDATION_ASPECT_COMPUTE_PASS); + return true; + } + + bool CommandBufferStateTracker::EndComputePass() { + if (!aspects[VALIDATION_ASPECT_COMPUTE_PASS]) { + builder->HandleError("Can't end a compute pass without beginning one"); + return false; + } + aspects.reset(VALIDATION_ASPECT_COMPUTE_PASS); + UnsetPipeline(); return true; } @@ -193,6 +216,10 @@ namespace backend { }; bool CommandBufferStateTracker::BeginRenderPass(RenderPassBase* renderPass, FramebufferBase* framebuffer) { + if (aspects[VALIDATION_ASPECT_COMPUTE_PASS]) { + builder->HandleError("Cannot begin a render pass while a compute pass is active"); + return false; + } if (currentRenderPass != nullptr) { builder->HandleError("A render pass is already active"); return false; @@ -207,7 +234,6 @@ namespace backend { currentFramebuffer = framebuffer; currentSubpass = 0; - UnsetPipeline(); return true; } @@ -234,6 +260,10 @@ namespace backend { PipelineLayoutBase* layout = pipeline->GetLayout(); if (pipeline->IsCompute()) { + if (!aspects[VALIDATION_ASPECT_COMPUTE_PASS]) { + builder->HandleError("A compute pass must be active when a compute pipeline is set"); + return false; + } if (currentRenderPass) { builder->HandleError("Can't use a compute pipeline while a render pass is active"); return false; diff --git a/src/backend/CommandBufferStateTracker.h b/src/backend/CommandBufferStateTracker.h index 81d04c8c03..c490c600e0 100644 --- a/src/backend/CommandBufferStateTracker.h +++ b/src/backend/CommandBufferStateTracker.h @@ -39,6 +39,8 @@ namespace backend { bool ValidateEndCommandBuffer() const; // State-modifying methods + bool BeginComputePass(); + bool EndComputePass(); bool BeginSubpass(); bool EndSubpass(); bool BeginRenderPass(RenderPassBase* renderPass, FramebufferBase* framebuffer); @@ -66,6 +68,7 @@ namespace backend { VALIDATION_ASPECT_VERTEX_BUFFERS, VALIDATION_ASPECT_INDEX_BUFFER, VALIDATION_ASPECT_RENDER_SUBPASS, + VALIDATION_ASPECT_COMPUTE_PASS, VALIDATION_ASPECT_COUNT }; diff --git a/src/backend/Commands.h b/src/backend/Commands.h index 3e3839cbb3..4df6eacadb 100644 --- a/src/backend/Commands.h +++ b/src/backend/Commands.h @@ -28,6 +28,7 @@ namespace backend { // dependencies: Ref needs Object to be defined. enum class Command { + BeginComputePass, BeginRenderPass, BeginRenderSubpass, CopyBufferToBuffer, @@ -36,6 +37,7 @@ namespace backend { Dispatch, DrawArrays, DrawElements, + EndComputePass, EndRenderPass, EndRenderSubpass, SetPipeline, @@ -48,6 +50,9 @@ namespace backend { TransitionTextureUsage, }; + struct BeginComputePassCmd { + }; + struct BeginRenderPassCmd { Ref renderPass; Ref framebuffer; @@ -104,6 +109,9 @@ namespace backend { uint32_t firstInstance; }; + struct EndComputePassCmd { + }; + struct EndRenderPassCmd { }; diff --git a/src/backend/d3d12/CommandBufferD3D12.cpp b/src/backend/d3d12/CommandBufferD3D12.cpp index cce134e9f7..f9e9036a16 100644 --- a/src/backend/d3d12/CommandBufferD3D12.cpp +++ b/src/backend/d3d12/CommandBufferD3D12.cpp @@ -253,6 +253,12 @@ namespace d3d12 { while(commands.NextCommandId(&type)) { switch (type) { + case Command::BeginComputePass: + { + commands.NextCommand(); + } + break; + case Command::BeginRenderPass: { BeginRenderPassCmd* beginRenderPassCmd = commands.NextCommand(); @@ -365,6 +371,12 @@ namespace d3d12 { } break; + case Command::EndComputePass: + { + commands.NextCommand(); + } + break; + case Command::EndRenderPass: { EndRenderPassCmd* cmd = commands.NextCommand(); diff --git a/src/backend/metal/CommandBufferMTL.mm b/src/backend/metal/CommandBufferMTL.mm index dd4cc37964..b367d5b9f3 100644 --- a/src/backend/metal/CommandBufferMTL.mm +++ b/src/backend/metal/CommandBufferMTL.mm @@ -47,31 +47,35 @@ namespace metal { RenderPass* currentRenderPass = nullptr; Framebuffer* currentFramebuffer = nullptr; - void FinishEncoders() { + void EnsureNoBlitEncoder() { ASSERT(render == nil); + ASSERT(compute == nil); if (blit != nil) { [blit endEncoding]; blit = nil; } - if (compute != nil) { - [compute endEncoding]; - compute = nil; - } } void EnsureBlit(id commandBuffer) { + ASSERT(render == nil); + ASSERT(compute == nil); if (blit == nil) { - FinishEncoders(); blit = [commandBuffer blitCommandEncoder]; } } - void EnsureCompute(id commandBuffer) { - if (compute == nil) { - FinishEncoders(); - compute = [commandBuffer computeCommandEncoder]; - // TODO(cwallez@chromium.org): does any state need to be reset? - } + + void BeginCompute(id commandBuffer) { + EnsureNoBlitEncoder(); + compute = [commandBuffer computeCommandEncoder]; + // TODO(cwallez@chromium.org): does any state need to be reset? } + + void EndCompute() { + ASSERT(compute != nil); + [compute endEncoding]; + compute = nil; + } + void BeginSubpass(id commandBuffer, uint32_t subpass) { ASSERT(currentRenderPass); if (render != nil) { @@ -111,7 +115,8 @@ namespace metal { render = [commandBuffer renderCommandEncoderWithDescriptor:descriptor]; // TODO(cwallez@chromium.org): does any state need to be reset? } - void EndRenderPass() { + + void EndSubpass() { ASSERT(render != nil); [render endEncoding]; render = nil; @@ -141,12 +146,19 @@ namespace metal { uint32_t currentSubpass = 0; while (commands.NextCommandId(&type)) { switch (type) { + case Command::BeginComputePass: + { + commands.NextCommand(); + encoders.BeginCompute(commandBuffer); + } + break; + case Command::BeginRenderPass: { BeginRenderPassCmd* beginRenderPassCmd = commands.NextCommand(); encoders.currentRenderPass = ToBackend(beginRenderPassCmd->renderPass.Get()); encoders.currentFramebuffer = ToBackend(beginRenderPassCmd->framebuffer.Get()); - encoders.FinishEncoders(); + encoders.EnsureNoBlitEncoder(); currentSubpass = 0; } break; @@ -243,7 +255,7 @@ namespace metal { case Command::Dispatch: { DispatchCmd* dispatch = commands.NextCommand(); - encoders.EnsureCompute(commandBuffer); + ASSERT(encoders.compute); ASSERT(lastPipeline->IsCompute()); [encoders.compute dispatchThreadgroups:MTLSizeMake(dispatch->x, dispatch->y, dispatch->z) @@ -282,16 +294,23 @@ namespace metal { } break; + case Command::EndComputePass: + { + commands.NextCommand(); + encoders.EndCompute(); + } + break; + case Command::EndRenderPass: { commands.NextCommand(); - encoders.EndRenderPass(); } break; case Command::EndRenderSubpass: { commands.NextCommand(); + encoders.EndSubpass(); currentSubpass += 1; } break; @@ -302,7 +321,7 @@ namespace metal { lastPipeline = ToBackend(cmd->pipeline).Get(); if (lastPipeline->IsCompute()) { - encoders.EnsureCompute(commandBuffer); + ASSERT(encoders.compute); lastPipeline->Encode(encoders.compute); } else { ASSERT(encoders.render); @@ -340,7 +359,7 @@ namespace metal { const auto& layout = group->GetLayout()->GetBindingInfo(); if (lastPipeline->IsCompute()) { - encoders.EnsureCompute(commandBuffer); + ASSERT(encoders.compute); } else { ASSERT(encoders.render); } @@ -501,7 +520,9 @@ namespace metal { } } - encoders.FinishEncoders(); + encoders.EnsureNoBlitEncoder(); + ASSERT(encoders.render == nil); + ASSERT(encoders.compute == nil); } } diff --git a/src/backend/opengl/CommandBufferGL.cpp b/src/backend/opengl/CommandBufferGL.cpp index 0de3e55518..6a0d16e9d2 100644 --- a/src/backend/opengl/CommandBufferGL.cpp +++ b/src/backend/opengl/CommandBufferGL.cpp @@ -65,6 +65,12 @@ namespace opengl { while(commands.NextCommandId(&type)) { switch (type) { + case Command::BeginComputePass: + { + commands.NextCommand(); + } + break; + case Command::BeginRenderPass: { commands.NextCommand(); @@ -166,6 +172,12 @@ namespace opengl { } break; + case Command::EndComputePass: + { + commands.NextCommand(); + } + break; + case Command::EndRenderPass: { commands.NextCommand();