Make reeantrant creation calls returning Ref for command encoding.

Fixed: dawn:723
Change-Id: I953e0aa7d663f68e15c021448a90ecf799fef891
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/84766
Reviewed-by: Austin Eng <enga@chromium.org>
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
This commit is contained in:
Corentin Wallez 2022-03-29 08:10:03 +00:00 committed by Dawn LUCI CQ
parent 46d5480d20
commit e055ae5b52
10 changed files with 114 additions and 58 deletions

View File

@ -700,6 +700,11 @@ namespace dawn::native {
// Implementation of the API's command recording methods
ComputePassEncoder* CommandEncoder::APIBeginComputePass(
const ComputePassDescriptor* descriptor) {
return BeginComputePass(descriptor).Detach();
}
Ref<ComputePassEncoder> CommandEncoder::BeginComputePass(
const ComputePassDescriptor* descriptor) {
DeviceBase* device = GetDevice();
@ -748,9 +753,9 @@ namespace dawn::native {
descriptor = &defaultDescriptor;
}
ComputePassEncoder* passEncoder = new ComputePassEncoder(
Ref<ComputePassEncoder> passEncoder = ComputePassEncoder::Create(
device, descriptor, this, &mEncodingContext, std::move(timestampWritesAtEnd));
mEncodingContext.EnterPass(passEncoder);
mEncodingContext.EnterPass(passEncoder.Get());
return passEncoder;
}
@ -758,6 +763,10 @@ namespace dawn::native {
}
RenderPassEncoder* CommandEncoder::APIBeginRenderPass(const RenderPassDescriptor* descriptor) {
return BeginRenderPass(descriptor).Detach();
}
Ref<RenderPassEncoder> CommandEncoder::BeginRenderPass(const RenderPassDescriptor* descriptor) {
DeviceBase* device = GetDevice();
RenderPassResourceUsageTracker usageTracker;
@ -907,11 +916,11 @@ namespace dawn::native {
"encoding %s.BeginRenderPass(%s).", this, descriptor);
if (success) {
RenderPassEncoder* passEncoder = new RenderPassEncoder(
Ref<RenderPassEncoder> passEncoder = RenderPassEncoder::Create(
device, descriptor, this, &mEncodingContext, std::move(usageTracker),
std::move(attachmentState), std::move(timestampWritesAtEnd), width, height,
depthReadOnly, stencilReadOnly);
mEncodingContext.EnterPass(passEncoder);
mEncodingContext.EnterPass(passEncoder.Get());
return passEncoder;
}
@ -1355,14 +1364,14 @@ namespace dawn::native {
CommandBufferBase* CommandEncoder::APIFinish(const CommandBufferDescriptor* descriptor) {
Ref<CommandBufferBase> commandBuffer;
if (GetDevice()->ConsumedError(FinishInternal(descriptor), &commandBuffer)) {
if (GetDevice()->ConsumedError(Finish(descriptor), &commandBuffer)) {
return CommandBufferBase::MakeError(GetDevice());
}
ASSERT(!IsError());
return commandBuffer.Detach();
}
ResultOrError<Ref<CommandBufferBase>> CommandEncoder::FinishInternal(
ResultOrError<Ref<CommandBufferBase>> CommandEncoder::Finish(
const CommandBufferDescriptor* descriptor) {
DeviceBase* device = GetDevice();

View File

@ -86,13 +86,16 @@ namespace dawn::native {
CommandBufferBase* APIFinish(const CommandBufferDescriptor* descriptor = nullptr);
Ref<ComputePassEncoder> BeginComputePass(const ComputePassDescriptor* descriptor = nullptr);
Ref<RenderPassEncoder> BeginRenderPass(const RenderPassDescriptor* descriptor);
ResultOrError<Ref<CommandBufferBase>> Finish(
const CommandBufferDescriptor* descriptor = nullptr);
private:
CommandEncoder(DeviceBase* device, const CommandEncoderDescriptor* descriptor);
CommandEncoder(DeviceBase* device, ObjectBase::ErrorTag tag);
void DestroyImpl() override;
ResultOrError<Ref<CommandBufferBase>> FinishInternal(
const CommandBufferDescriptor* descriptor);
// Helper to be able to implement both APICopyTextureToTexture and
// APICopyTextureToTextureInternal. The only difference between both

View File

@ -121,6 +121,17 @@ namespace dawn::native {
TrackInDevice();
}
// static
Ref<ComputePassEncoder> ComputePassEncoder::Create(
DeviceBase* device,
const ComputePassDescriptor* descriptor,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext,
std::vector<TimestampWrite> timestampWritesAtEnd) {
return AcquireRef(new ComputePassEncoder(device, descriptor, commandEncoder,
encodingContext, std::move(timestampWritesAtEnd)));
}
ComputePassEncoder::ComputePassEncoder(DeviceBase* device,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext,
@ -128,10 +139,12 @@ namespace dawn::native {
: ProgrammableEncoder(device, encodingContext, errorTag), mCommandEncoder(commandEncoder) {
}
ComputePassEncoder* ComputePassEncoder::MakeError(DeviceBase* device,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext) {
return new ComputePassEncoder(device, commandEncoder, encodingContext, ObjectBase::kError);
// static
Ref<ComputePassEncoder> ComputePassEncoder::MakeError(DeviceBase* device,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext) {
return AcquireRef(
new ComputePassEncoder(device, commandEncoder, encodingContext, ObjectBase::kError));
}
void ComputePassEncoder::DestroyImpl() {

View File

@ -27,15 +27,14 @@ namespace dawn::native {
class ComputePassEncoder final : public ProgrammableEncoder {
public:
ComputePassEncoder(DeviceBase* device,
const ComputePassDescriptor* descriptor,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext,
std::vector<TimestampWrite> timestampWritesAtEnd);
static ComputePassEncoder* MakeError(DeviceBase* device,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext);
static Ref<ComputePassEncoder> Create(DeviceBase* device,
const ComputePassDescriptor* descriptor,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext,
std::vector<TimestampWrite> timestampWritesAtEnd);
static Ref<ComputePassEncoder> MakeError(DeviceBase* device,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext);
ObjectType GetType() const override;
@ -61,6 +60,11 @@ namespace dawn::native {
}
protected:
ComputePassEncoder(DeviceBase* device,
const ComputePassDescriptor* descriptor,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext,
std::vector<TimestampWrite> timestampWritesAtEnd);
ComputePassEncoder(DeviceBase* device,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext,

View File

@ -555,9 +555,8 @@ namespace dawn::native {
{{0, uniformBuffer}, {1, sampler}, {2, srcTextureView}}));
// Create command encoder.
CommandEncoderDescriptor encoderDesc = {};
// TODO(dawn:723): change to not use AcquireRef for reentrant object creation.
Ref<CommandEncoder> encoder = AcquireRef(device->APICreateCommandEncoder(&encoderDesc));
Ref<CommandEncoder> encoder;
DAWN_TRY_ASSIGN(encoder, device->CreateCommandEncoder());
// Prepare dst texture view as color Attachment.
TextureViewDescriptor dstTextureViewDesc;
@ -581,9 +580,7 @@ namespace dawn::native {
RenderPassDescriptor renderPassDesc;
renderPassDesc.colorAttachmentCount = 1;
renderPassDesc.colorAttachments = &colorAttachmentDesc;
// TODO(dawn:723): change to not use AcquireRef for reentrant object creation.
Ref<RenderPassEncoder> passEncoder =
AcquireRef(encoder->APIBeginRenderPass(&renderPassDesc));
Ref<RenderPassEncoder> passEncoder = encoder->BeginRenderPass(&renderPassDesc);
// Start pipeline and encode commands to complete
// the copy from src texture to dst texture with transformation.
@ -595,8 +592,8 @@ namespace dawn::native {
passEncoder->APIEnd();
// Finsh encoding.
// TODO(dawn:723): change to not use AcquireRef for reentrant object creation.
Ref<CommandBufferBase> commandBuffer = AcquireRef(encoder->APIFinish());
Ref<CommandBufferBase> commandBuffer;
DAWN_TRY_ASSIGN(commandBuffer, encoder->Finish());
CommandBufferBase* submitCommandBuffer = commandBuffer.Get();
// Submit command buffer.

View File

@ -987,11 +987,6 @@ namespace dawn::native {
}
CommandEncoder* DeviceBase::APICreateCommandEncoder(
const CommandEncoderDescriptor* descriptor) {
const CommandEncoderDescriptor defaultDescriptor = {};
if (descriptor == nullptr) {
descriptor = &defaultDescriptor;
}
Ref<CommandEncoder> result;
if (ConsumedError(CreateCommandEncoder(descriptor), &result,
"calling %s.CreateCommandEncoder(%s).", this, descriptor)) {
@ -1359,6 +1354,11 @@ namespace dawn::native {
ResultOrError<Ref<CommandEncoder>> DeviceBase::CreateCommandEncoder(
const CommandEncoderDescriptor* descriptor) {
const CommandEncoderDescriptor defaultDescriptor = {};
if (descriptor == nullptr) {
descriptor = &defaultDescriptor;
}
DAWN_TRY(ValidateIsAlive());
if (IsValidationEnabled()) {
DAWN_TRY(ValidateCommandEncoderDescriptor(this, descriptor));

View File

@ -201,7 +201,7 @@ namespace dawn::native {
bool allowInternalBinding = false);
ResultOrError<Ref<BufferBase>> CreateBuffer(const BufferDescriptor* descriptor);
ResultOrError<Ref<CommandEncoder>> CreateCommandEncoder(
const CommandEncoderDescriptor* descriptor);
const CommandEncoderDescriptor* descriptor = nullptr);
ResultOrError<Ref<ComputePipelineBase>> CreateComputePipeline(
const ComputePipelineDescriptor* descriptor);
MaybeError CreateComputePipelineAsync(const ComputePipelineDescriptor* descriptor,
@ -218,7 +218,8 @@ namespace dawn::native {
MaybeError CreateRenderPipelineAsync(const RenderPipelineDescriptor* descriptor,
WGPUCreateRenderPipelineAsyncCallback callback,
void* userdata);
ResultOrError<Ref<SamplerBase>> CreateSampler(const SamplerDescriptor* descriptor);
ResultOrError<Ref<SamplerBase>> CreateSampler(
const SamplerDescriptor* descriptor = nullptr);
ResultOrError<Ref<ShaderModuleBase>> CreateShaderModule(
const ShaderModuleDescriptor* descriptor,
OwnedCompilationMessages* compilationMessages = nullptr);

View File

@ -204,9 +204,7 @@ namespace dawn::native {
{{0, timestamps}, {1, availability}, {2, params}}));
// Create compute encoder and issue dispatch.
ComputePassDescriptor passDesc = {};
// TODO(dawn:723): change to not use AcquireRef for reentrant object creation.
Ref<ComputePassEncoder> pass = AcquireRef(encoder->APIBeginComputePass(&passDesc));
Ref<ComputePassEncoder> pass = encoder->BeginComputePass();
pass->APISetPipeline(pipeline);
pass->APISetBindGroup(0, bindGroup.Get());
pass->APIDispatch(

View File

@ -74,6 +74,25 @@ namespace dawn::native {
TrackInDevice();
}
// static
Ref<RenderPassEncoder> RenderPassEncoder::Create(
DeviceBase* device,
const RenderPassDescriptor* descriptor,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext,
RenderPassResourceUsageTracker usageTracker,
Ref<AttachmentState> attachmentState,
std::vector<TimestampWrite> timestampWritesAtEnd,
uint32_t renderTargetWidth,
uint32_t renderTargetHeight,
bool depthReadOnly,
bool stencilReadOnly) {
return AcquireRef(new RenderPassEncoder(
device, descriptor, commandEncoder, encodingContext, std::move(usageTracker),
std::move(attachmentState), std::move(timestampWritesAtEnd), renderTargetWidth,
renderTargetHeight, depthReadOnly, stencilReadOnly));
}
RenderPassEncoder::RenderPassEncoder(DeviceBase* device,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext,
@ -81,10 +100,12 @@ namespace dawn::native {
: RenderEncoderBase(device, encodingContext, errorTag), mCommandEncoder(commandEncoder) {
}
RenderPassEncoder* RenderPassEncoder::MakeError(DeviceBase* device,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext) {
return new RenderPassEncoder(device, commandEncoder, encodingContext, ObjectBase::kError);
// static
Ref<RenderPassEncoder> RenderPassEncoder::MakeError(DeviceBase* device,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext) {
return AcquireRef(
new RenderPassEncoder(device, commandEncoder, encodingContext, ObjectBase::kError));
}
void RenderPassEncoder::DestroyImpl() {

View File

@ -25,21 +25,20 @@ namespace dawn::native {
class RenderPassEncoder final : public RenderEncoderBase {
public:
RenderPassEncoder(DeviceBase* device,
const RenderPassDescriptor* descriptor,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext,
RenderPassResourceUsageTracker usageTracker,
Ref<AttachmentState> attachmentState,
std::vector<TimestampWrite> timestampWritesAtEnd,
uint32_t renderTargetWidth,
uint32_t renderTargetHeight,
bool depthReadOnly,
bool stencilReadOnly);
static RenderPassEncoder* MakeError(DeviceBase* device,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext);
static Ref<RenderPassEncoder> Create(DeviceBase* device,
const RenderPassDescriptor* descriptor,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext,
RenderPassResourceUsageTracker usageTracker,
Ref<AttachmentState> attachmentState,
std::vector<TimestampWrite> timestampWritesAtEnd,
uint32_t renderTargetWidth,
uint32_t renderTargetHeight,
bool depthReadOnly,
bool stencilReadOnly);
static Ref<RenderPassEncoder> MakeError(DeviceBase* device,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext);
ObjectType GetType() const override;
@ -63,6 +62,17 @@ namespace dawn::native {
void APIWriteTimestamp(QuerySetBase* querySet, uint32_t queryIndex);
protected:
RenderPassEncoder(DeviceBase* device,
const RenderPassDescriptor* descriptor,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext,
RenderPassResourceUsageTracker usageTracker,
Ref<AttachmentState> attachmentState,
std::vector<TimestampWrite> timestampWritesAtEnd,
uint32_t renderTargetWidth,
uint32_t renderTargetHeight,
bool depthReadOnly,
bool stencilReadOnly);
RenderPassEncoder(DeviceBase* device,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext,