Introduce Begin/EndComputePass (#70)

This commit is contained in:
Kai Ninomiya 2017-07-10 14:07:24 -07:00 committed by GitHub
parent afdcf7d828
commit 296951df60
11 changed files with 177 additions and 33 deletions

View File

@ -264,11 +264,13 @@ void initCommandBuffers() {
auto& bufferSrc = particleBuffers[i]; auto& bufferSrc = particleBuffers[i];
auto& bufferDst = particleBuffers[(i + 1) % 2]; auto& bufferDst = particleBuffers[(i + 1) % 2];
commandBuffers[i] = device.CreateCommandBufferBuilder() commandBuffers[i] = device.CreateCommandBufferBuilder()
.BeginComputePass()
.SetPipeline(updatePipeline) .SetPipeline(updatePipeline)
.TransitionBufferUsage(bufferSrc, nxt::BufferUsageBit::Storage) .TransitionBufferUsage(bufferSrc, nxt::BufferUsageBit::Storage)
.TransitionBufferUsage(bufferDst, nxt::BufferUsageBit::Storage) .TransitionBufferUsage(bufferDst, nxt::BufferUsageBit::Storage)
.SetBindGroup(0, updateBGs[i]) .SetBindGroup(0, updateBGs[i])
.Dispatch(kNumParticles, 1, 1) .Dispatch(kNumParticles, 1, 1)
.EndComputePass()
.BeginRenderPass(renderpass, framebuffer) .BeginRenderPass(renderpass, framebuffer)
.BeginRenderSubpass() .BeginRenderSubpass()

View File

@ -126,10 +126,12 @@ void init() {
void frame() { void frame() {
nxt::CommandBuffer commands = device.CreateCommandBufferBuilder() nxt::CommandBuffer commands = device.CreateCommandBufferBuilder()
.BeginComputePass()
.SetPipeline(computePipeline) .SetPipeline(computePipeline)
.TransitionBufferUsage(buffer, nxt::BufferUsageBit::Storage) .TransitionBufferUsage(buffer, nxt::BufferUsageBit::Storage)
.SetBindGroup(0, computeBindGroup) .SetBindGroup(0, computeBindGroup)
.Dispatch(1, 1, 1) .Dispatch(1, 1, 1)
.EndComputePass()
.BeginRenderPass(renderpass, framebuffer) .BeginRenderPass(renderpass, framebuffer)
.BeginRenderSubpass() .BeginRenderSubpass()

View File

@ -247,6 +247,9 @@
"name": "get result", "name": "get result",
"returns": "command buffer" "returns": "command buffer"
}, },
{
"name": "begin compute pass"
},
{ {
"name": "begin render pass", "name": "begin render pass",
"args": [ "args": [
@ -340,6 +343,9 @@
{"name": "first instance", "type": "uint32_t"} {"name": "first instance", "type": "uint32_t"}
] ]
}, },
{
"name": "end compute pass"
},
{ {
"name": "end render pass" "name": "end render pass"
}, },

View File

@ -105,6 +105,12 @@ namespace backend {
Command type; Command type;
while(commands->NextCommandId(&type)) { while(commands->NextCommandId(&type)) {
switch (type) { switch (type) {
case Command::BeginComputePass:
{
BeginComputePassCmd* begin = commands->NextCommand<BeginComputePassCmd>();
begin->~BeginComputePassCmd();
}
break;
case Command::BeginRenderPass: case Command::BeginRenderPass:
{ {
BeginRenderPassCmd* begin = commands->NextCommand<BeginRenderPassCmd>(); BeginRenderPassCmd* begin = commands->NextCommand<BeginRenderPassCmd>();
@ -153,6 +159,12 @@ namespace backend {
draw->~DrawElementsCmd(); draw->~DrawElementsCmd();
} }
break; break;
case Command::EndComputePass:
{
EndComputePassCmd* cmd = commands->NextCommand<EndComputePassCmd>();
cmd->~EndComputePassCmd();
}
break;
case Command::EndRenderPass: case Command::EndRenderPass:
{ {
EndRenderPassCmd* cmd = commands->NextCommand<EndRenderPassCmd>(); EndRenderPassCmd* cmd = commands->NextCommand<EndRenderPassCmd>();
@ -226,6 +238,10 @@ namespace backend {
void SkipCommand(CommandIterator* commands, Command type) { void SkipCommand(CommandIterator* commands, Command type) {
switch (type) { switch (type) {
case Command::BeginComputePass:
commands->NextCommand<BeginComputePassCmd>();
break;
case Command::BeginRenderPass: case Command::BeginRenderPass:
commands->NextCommand<BeginRenderPassCmd>(); commands->NextCommand<BeginRenderPassCmd>();
break; break;
@ -258,6 +274,10 @@ namespace backend {
commands->NextCommand<DrawElementsCmd>(); commands->NextCommand<DrawElementsCmd>();
break; break;
case Command::EndComputePass:
commands->NextCommand<EndComputePassCmd>();
break;
case Command::EndRenderPass: case Command::EndRenderPass:
commands->NextCommand<EndRenderPassCmd>(); commands->NextCommand<EndRenderPassCmd>();
break; break;
@ -323,6 +343,15 @@ namespace backend {
Command type; Command type;
while (iterator.NextCommandId(&type)) { while (iterator.NextCommandId(&type)) {
switch (type) { switch (type) {
case Command::BeginComputePass:
{
iterator.NextCommand<BeginComputePassCmd>();
if (!state->BeginComputePass()) {
return false;
}
}
break;
case Command::BeginRenderPass: case Command::BeginRenderPass:
{ {
BeginRenderPassCmd* cmd = iterator.NextCommand<BeginRenderPassCmd>(); BeginRenderPassCmd* cmd = iterator.NextCommand<BeginRenderPassCmd>();
@ -424,6 +453,15 @@ namespace backend {
} }
break; break;
case Command::EndComputePass:
{
iterator.NextCommand<EndComputePassCmd>();
if (!state->EndComputePass()) {
return false;
}
}
break;
case Command::EndRenderPass: case Command::EndRenderPass:
{ {
iterator.NextCommand<EndRenderPassCmd>(); iterator.NextCommand<EndRenderPassCmd>();
@ -542,8 +580,8 @@ namespace backend {
return device->CreateCommandBuffer(this); return device->CreateCommandBuffer(this);
} }
void CommandBufferBuilder::BeginRenderSubpass() { void CommandBufferBuilder::BeginComputePass() {
allocator.Allocate<BeginRenderSubpassCmd>(Command::BeginRenderSubpass); allocator.Allocate<BeginComputePassCmd>(Command::BeginComputePass);
} }
void CommandBufferBuilder::BeginRenderPass(RenderPassBase* renderPass, FramebufferBase* framebuffer) { void CommandBufferBuilder::BeginRenderPass(RenderPassBase* renderPass, FramebufferBase* framebuffer) {
@ -553,6 +591,10 @@ namespace backend {
cmd->framebuffer = framebuffer; cmd->framebuffer = framebuffer;
} }
void CommandBufferBuilder::BeginRenderSubpass() {
allocator.Allocate<BeginRenderSubpassCmd>(Command::BeginRenderSubpass);
}
void CommandBufferBuilder::CopyBufferToBuffer(BufferBase* source, uint32_t sourceOffset, BufferBase* destination, uint32_t destinationOffset, uint32_t size) { void CommandBufferBuilder::CopyBufferToBuffer(BufferBase* source, uint32_t sourceOffset, BufferBase* destination, uint32_t destinationOffset, uint32_t size) {
CopyBufferToBufferCmd* copy = allocator.Allocate<CopyBufferToBufferCmd>(Command::CopyBufferToBuffer); CopyBufferToBufferCmd* copy = allocator.Allocate<CopyBufferToBufferCmd>(Command::CopyBufferToBuffer);
new(copy) CopyBufferToBufferCmd; new(copy) CopyBufferToBufferCmd;
@ -623,6 +665,10 @@ namespace backend {
draw->firstInstance = firstInstance; draw->firstInstance = firstInstance;
} }
void CommandBufferBuilder::EndComputePass() {
allocator.Allocate<EndComputePassCmd>(Command::EndComputePass);
}
void CommandBufferBuilder::EndRenderPass() { void CommandBufferBuilder::EndRenderPass() {
allocator.Allocate<EndRenderPassCmd>(Command::EndRenderPass); allocator.Allocate<EndRenderPassCmd>(Command::EndRenderPass);
} }

View File

@ -59,6 +59,7 @@ namespace backend {
CommandIterator AcquireCommands(); CommandIterator AcquireCommands();
// NXT API // NXT API
void BeginComputePass();
void BeginRenderPass(RenderPassBase* renderPass, FramebufferBase* framebuffer); void BeginRenderPass(RenderPassBase* renderPass, FramebufferBase* framebuffer);
void BeginRenderSubpass(); void BeginRenderSubpass();
void CopyBufferToBuffer(BufferBase* source, uint32_t sourceOffset, BufferBase* destination, uint32_t destinationOffset, uint32_t size); void CopyBufferToBuffer(BufferBase* source, uint32_t sourceOffset, BufferBase* destination, uint32_t destinationOffset, uint32_t size);
@ -71,6 +72,7 @@ namespace backend {
void Dispatch(uint32_t x, uint32_t y, uint32_t z); void Dispatch(uint32_t x, uint32_t y, uint32_t z);
void DrawArrays(uint32_t vertexCount, uint32_t instanceCount, uint32_t firstVertex, uint32_t firstInstance); void DrawArrays(uint32_t vertexCount, uint32_t instanceCount, uint32_t firstVertex, uint32_t firstInstance);
void DrawElements(uint32_t vertexCount, uint32_t instanceCount, uint32_t firstIndex, uint32_t firstInstance); void DrawElements(uint32_t vertexCount, uint32_t instanceCount, uint32_t firstIndex, uint32_t firstInstance);
void EndComputePass();
void EndRenderPass(); void EndRenderPass();
void EndRenderSubpass(); void EndRenderSubpass();
void SetPushConstants(nxt::ShaderStageBit stage, uint32_t offset, uint32_t count, const void* data); void SetPushConstants(nxt::ShaderStageBit stage, uint32_t offset, uint32_t count, const void* data);

View File

@ -61,7 +61,7 @@ namespace backend {
bool CommandBufferStateTracker::ValidateCanDispatch() { bool CommandBufferStateTracker::ValidateCanDispatch() {
constexpr ValidationAspects requiredAspects = constexpr ValidationAspects requiredAspects =
1 << VALIDATION_ASPECT_COMPUTE_PIPELINE | 1 << VALIDATION_ASPECT_COMPUTE_PIPELINE | // implicitly requires COMPUTE_PASS
1 << VALIDATION_ASPECT_BIND_GROUPS; 1 << VALIDATION_ASPECT_BIND_GROUPS;
if ((requiredAspects & ~aspects).none()) { if ((requiredAspects & ~aspects).none()) {
// Fast return-true path if everything is good // Fast return-true path if everything is good
@ -83,7 +83,7 @@ namespace backend {
bool CommandBufferStateTracker::ValidateCanDrawArrays() { bool CommandBufferStateTracker::ValidateCanDrawArrays() {
// TODO(kainino@chromium.org): Check for a current render pass // TODO(kainino@chromium.org): Check for a current render pass
constexpr ValidationAspects requiredAspects = constexpr ValidationAspects requiredAspects =
1 << VALIDATION_ASPECT_RENDER_PIPELINE | 1 << VALIDATION_ASPECT_RENDER_PIPELINE | // implicitly requires RENDER_SUBPASS
1 << VALIDATION_ASPECT_BIND_GROUPS | 1 << VALIDATION_ASPECT_BIND_GROUPS |
1 << VALIDATION_ASPECT_VERTEX_BUFFERS; 1 << VALIDATION_ASPECT_VERTEX_BUFFERS;
if ((requiredAspects & ~aspects).none()) { if ((requiredAspects & ~aspects).none()) {
@ -118,6 +118,29 @@ namespace backend {
builder->HandleError("Can't end command buffer with an active render pass"); builder->HandleError("Can't end command buffer with an active render pass");
return false; return false;
} }
if (aspects[VALIDATION_ASPECT_COMPUTE_PASS]) {
builder->HandleError("Can't end command buffer with an active compute pass");
return false;
}
return true;
}
bool CommandBufferStateTracker::BeginComputePass() {
if (currentRenderPass != nullptr) {
builder->HandleError("Cannot begin a compute pass while a render pass is active");
return false;
}
aspects.set(VALIDATION_ASPECT_COMPUTE_PASS);
return true;
}
bool CommandBufferStateTracker::EndComputePass() {
if (!aspects[VALIDATION_ASPECT_COMPUTE_PASS]) {
builder->HandleError("Can't end a compute pass without beginning one");
return false;
}
aspects.reset(VALIDATION_ASPECT_COMPUTE_PASS);
UnsetPipeline();
return true; return true;
} }
@ -193,6 +216,10 @@ namespace backend {
}; };
bool CommandBufferStateTracker::BeginRenderPass(RenderPassBase* renderPass, FramebufferBase* framebuffer) { bool CommandBufferStateTracker::BeginRenderPass(RenderPassBase* renderPass, FramebufferBase* framebuffer) {
if (aspects[VALIDATION_ASPECT_COMPUTE_PASS]) {
builder->HandleError("Cannot begin a render pass while a compute pass is active");
return false;
}
if (currentRenderPass != nullptr) { if (currentRenderPass != nullptr) {
builder->HandleError("A render pass is already active"); builder->HandleError("A render pass is already active");
return false; return false;
@ -207,7 +234,6 @@ namespace backend {
currentFramebuffer = framebuffer; currentFramebuffer = framebuffer;
currentSubpass = 0; currentSubpass = 0;
UnsetPipeline();
return true; return true;
} }
@ -234,6 +260,10 @@ namespace backend {
PipelineLayoutBase* layout = pipeline->GetLayout(); PipelineLayoutBase* layout = pipeline->GetLayout();
if (pipeline->IsCompute()) { if (pipeline->IsCompute()) {
if (!aspects[VALIDATION_ASPECT_COMPUTE_PASS]) {
builder->HandleError("A compute pass must be active when a compute pipeline is set");
return false;
}
if (currentRenderPass) { if (currentRenderPass) {
builder->HandleError("Can't use a compute pipeline while a render pass is active"); builder->HandleError("Can't use a compute pipeline while a render pass is active");
return false; return false;

View File

@ -39,6 +39,8 @@ namespace backend {
bool ValidateEndCommandBuffer() const; bool ValidateEndCommandBuffer() const;
// State-modifying methods // State-modifying methods
bool BeginComputePass();
bool EndComputePass();
bool BeginSubpass(); bool BeginSubpass();
bool EndSubpass(); bool EndSubpass();
bool BeginRenderPass(RenderPassBase* renderPass, FramebufferBase* framebuffer); bool BeginRenderPass(RenderPassBase* renderPass, FramebufferBase* framebuffer);
@ -66,6 +68,7 @@ namespace backend {
VALIDATION_ASPECT_VERTEX_BUFFERS, VALIDATION_ASPECT_VERTEX_BUFFERS,
VALIDATION_ASPECT_INDEX_BUFFER, VALIDATION_ASPECT_INDEX_BUFFER,
VALIDATION_ASPECT_RENDER_SUBPASS, VALIDATION_ASPECT_RENDER_SUBPASS,
VALIDATION_ASPECT_COMPUTE_PASS,
VALIDATION_ASPECT_COUNT VALIDATION_ASPECT_COUNT
}; };

View File

@ -28,6 +28,7 @@ namespace backend {
// dependencies: Ref<Object> needs Object to be defined. // dependencies: Ref<Object> needs Object to be defined.
enum class Command { enum class Command {
BeginComputePass,
BeginRenderPass, BeginRenderPass,
BeginRenderSubpass, BeginRenderSubpass,
CopyBufferToBuffer, CopyBufferToBuffer,
@ -36,6 +37,7 @@ namespace backend {
Dispatch, Dispatch,
DrawArrays, DrawArrays,
DrawElements, DrawElements,
EndComputePass,
EndRenderPass, EndRenderPass,
EndRenderSubpass, EndRenderSubpass,
SetPipeline, SetPipeline,
@ -48,6 +50,9 @@ namespace backend {
TransitionTextureUsage, TransitionTextureUsage,
}; };
struct BeginComputePassCmd {
};
struct BeginRenderPassCmd { struct BeginRenderPassCmd {
Ref<RenderPassBase> renderPass; Ref<RenderPassBase> renderPass;
Ref<FramebufferBase> framebuffer; Ref<FramebufferBase> framebuffer;
@ -104,6 +109,9 @@ namespace backend {
uint32_t firstInstance; uint32_t firstInstance;
}; };
struct EndComputePassCmd {
};
struct EndRenderPassCmd { struct EndRenderPassCmd {
}; };

View File

@ -253,6 +253,12 @@ namespace d3d12 {
while(commands.NextCommandId(&type)) { while(commands.NextCommandId(&type)) {
switch (type) { switch (type) {
case Command::BeginComputePass:
{
commands.NextCommand<BeginComputePassCmd>();
}
break;
case Command::BeginRenderPass: case Command::BeginRenderPass:
{ {
BeginRenderPassCmd* beginRenderPassCmd = commands.NextCommand<BeginRenderPassCmd>(); BeginRenderPassCmd* beginRenderPassCmd = commands.NextCommand<BeginRenderPassCmd>();
@ -365,6 +371,12 @@ namespace d3d12 {
} }
break; break;
case Command::EndComputePass:
{
commands.NextCommand<EndComputePassCmd>();
}
break;
case Command::EndRenderPass: case Command::EndRenderPass:
{ {
EndRenderPassCmd* cmd = commands.NextCommand<EndRenderPassCmd>(); EndRenderPassCmd* cmd = commands.NextCommand<EndRenderPassCmd>();

View File

@ -47,31 +47,35 @@ namespace metal {
RenderPass* currentRenderPass = nullptr; RenderPass* currentRenderPass = nullptr;
Framebuffer* currentFramebuffer = nullptr; Framebuffer* currentFramebuffer = nullptr;
void FinishEncoders() { void EnsureNoBlitEncoder() {
ASSERT(render == nil); ASSERT(render == nil);
ASSERT(compute == nil);
if (blit != nil) { if (blit != nil) {
[blit endEncoding]; [blit endEncoding];
blit = nil; blit = nil;
} }
if (compute != nil) {
[compute endEncoding];
compute = nil;
}
} }
void EnsureBlit(id<MTLCommandBuffer> commandBuffer) { void EnsureBlit(id<MTLCommandBuffer> commandBuffer) {
ASSERT(render == nil);
ASSERT(compute == nil);
if (blit == nil) { if (blit == nil) {
FinishEncoders();
blit = [commandBuffer blitCommandEncoder]; blit = [commandBuffer blitCommandEncoder];
} }
} }
void EnsureCompute(id<MTLCommandBuffer> commandBuffer) {
if (compute == nil) { void BeginCompute(id<MTLCommandBuffer> commandBuffer) {
FinishEncoders(); EnsureNoBlitEncoder();
compute = [commandBuffer computeCommandEncoder]; compute = [commandBuffer computeCommandEncoder];
// TODO(cwallez@chromium.org): does any state need to be reset? // TODO(cwallez@chromium.org): does any state need to be reset?
} }
void EndCompute() {
ASSERT(compute != nil);
[compute endEncoding];
compute = nil;
} }
void BeginSubpass(id<MTLCommandBuffer> commandBuffer, uint32_t subpass) { void BeginSubpass(id<MTLCommandBuffer> commandBuffer, uint32_t subpass) {
ASSERT(currentRenderPass); ASSERT(currentRenderPass);
if (render != nil) { if (render != nil) {
@ -111,7 +115,8 @@ namespace metal {
render = [commandBuffer renderCommandEncoderWithDescriptor:descriptor]; render = [commandBuffer renderCommandEncoderWithDescriptor:descriptor];
// TODO(cwallez@chromium.org): does any state need to be reset? // TODO(cwallez@chromium.org): does any state need to be reset?
} }
void EndRenderPass() {
void EndSubpass() {
ASSERT(render != nil); ASSERT(render != nil);
[render endEncoding]; [render endEncoding];
render = nil; render = nil;
@ -141,12 +146,19 @@ namespace metal {
uint32_t currentSubpass = 0; uint32_t currentSubpass = 0;
while (commands.NextCommandId(&type)) { while (commands.NextCommandId(&type)) {
switch (type) { switch (type) {
case Command::BeginComputePass:
{
commands.NextCommand<BeginComputePassCmd>();
encoders.BeginCompute(commandBuffer);
}
break;
case Command::BeginRenderPass: case Command::BeginRenderPass:
{ {
BeginRenderPassCmd* beginRenderPassCmd = commands.NextCommand<BeginRenderPassCmd>(); BeginRenderPassCmd* beginRenderPassCmd = commands.NextCommand<BeginRenderPassCmd>();
encoders.currentRenderPass = ToBackend(beginRenderPassCmd->renderPass.Get()); encoders.currentRenderPass = ToBackend(beginRenderPassCmd->renderPass.Get());
encoders.currentFramebuffer = ToBackend(beginRenderPassCmd->framebuffer.Get()); encoders.currentFramebuffer = ToBackend(beginRenderPassCmd->framebuffer.Get());
encoders.FinishEncoders(); encoders.EnsureNoBlitEncoder();
currentSubpass = 0; currentSubpass = 0;
} }
break; break;
@ -243,7 +255,7 @@ namespace metal {
case Command::Dispatch: case Command::Dispatch:
{ {
DispatchCmd* dispatch = commands.NextCommand<DispatchCmd>(); DispatchCmd* dispatch = commands.NextCommand<DispatchCmd>();
encoders.EnsureCompute(commandBuffer); ASSERT(encoders.compute);
ASSERT(lastPipeline->IsCompute()); ASSERT(lastPipeline->IsCompute());
[encoders.compute dispatchThreadgroups:MTLSizeMake(dispatch->x, dispatch->y, dispatch->z) [encoders.compute dispatchThreadgroups:MTLSizeMake(dispatch->x, dispatch->y, dispatch->z)
@ -282,16 +294,23 @@ namespace metal {
} }
break; break;
case Command::EndComputePass:
{
commands.NextCommand<EndComputePassCmd>();
encoders.EndCompute();
}
break;
case Command::EndRenderPass: case Command::EndRenderPass:
{ {
commands.NextCommand<EndRenderPassCmd>(); commands.NextCommand<EndRenderPassCmd>();
encoders.EndRenderPass();
} }
break; break;
case Command::EndRenderSubpass: case Command::EndRenderSubpass:
{ {
commands.NextCommand<EndRenderSubpassCmd>(); commands.NextCommand<EndRenderSubpassCmd>();
encoders.EndSubpass();
currentSubpass += 1; currentSubpass += 1;
} }
break; break;
@ -302,7 +321,7 @@ namespace metal {
lastPipeline = ToBackend(cmd->pipeline).Get(); lastPipeline = ToBackend(cmd->pipeline).Get();
if (lastPipeline->IsCompute()) { if (lastPipeline->IsCompute()) {
encoders.EnsureCompute(commandBuffer); ASSERT(encoders.compute);
lastPipeline->Encode(encoders.compute); lastPipeline->Encode(encoders.compute);
} else { } else {
ASSERT(encoders.render); ASSERT(encoders.render);
@ -340,7 +359,7 @@ namespace metal {
const auto& layout = group->GetLayout()->GetBindingInfo(); const auto& layout = group->GetLayout()->GetBindingInfo();
if (lastPipeline->IsCompute()) { if (lastPipeline->IsCompute()) {
encoders.EnsureCompute(commandBuffer); ASSERT(encoders.compute);
} else { } else {
ASSERT(encoders.render); ASSERT(encoders.render);
} }
@ -501,7 +520,9 @@ namespace metal {
} }
} }
encoders.FinishEncoders(); encoders.EnsureNoBlitEncoder();
ASSERT(encoders.render == nil);
ASSERT(encoders.compute == nil);
} }
} }

View File

@ -65,6 +65,12 @@ namespace opengl {
while(commands.NextCommandId(&type)) { while(commands.NextCommandId(&type)) {
switch (type) { switch (type) {
case Command::BeginComputePass:
{
commands.NextCommand<BeginComputePassCmd>();
}
break;
case Command::BeginRenderPass: case Command::BeginRenderPass:
{ {
commands.NextCommand<BeginRenderPassCmd>(); commands.NextCommand<BeginRenderPassCmd>();
@ -166,6 +172,12 @@ namespace opengl {
} }
break; break;
case Command::EndComputePass:
{
commands.NextCommand<EndComputePassCmd>();
}
break;
case Command::EndRenderPass: case Command::EndRenderPass:
{ {
commands.NextCommand<EndRenderPassCmd>(); commands.NextCommand<EndRenderPassCmd>();