Add/update destruction handling for command encoding objects

- Renames ProgrammablePassEncoder to just ProgrammableEncoder since it is also used in RenderBundleEncoder which is not a "pass"
- Adds testing infrastructure to further test device errors
- Ensures AttachmentStates are de-reffed when encoder objects are destroyed for proper cleanup
- Makes sure that both encoded and partial encoded commands are freed at destruction

Bug: dawn:628
Change-Id: Id62ab02d54461c4da266963035e8666799f61e9a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/68461
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Loko Kung <lokokung@google.com>
This commit is contained in:
Loko Kung 2021-11-16 22:46:34 +00:00 committed by Dawn LUCI CQ
parent 8f4eacd082
commit 970739e4e3
24 changed files with 298 additions and 81 deletions

View File

@ -282,8 +282,8 @@ source_set("dawn_native_sources") {
"PipelineLayout.h", "PipelineLayout.h",
"PooledResourceMemoryAllocator.cpp", "PooledResourceMemoryAllocator.cpp",
"PooledResourceMemoryAllocator.h", "PooledResourceMemoryAllocator.h",
"ProgrammablePassEncoder.cpp", "ProgrammableEncoder.cpp",
"ProgrammablePassEncoder.h", "ProgrammableEncoder.h",
"QueryHelper.cpp", "QueryHelper.cpp",
"QueryHelper.h", "QueryHelper.h",
"QuerySet.cpp", "QuerySet.cpp",

View File

@ -127,8 +127,8 @@ target_sources(dawn_native PRIVATE
"PipelineLayout.h" "PipelineLayout.h"
"PooledResourceMemoryAllocator.cpp" "PooledResourceMemoryAllocator.cpp"
"PooledResourceMemoryAllocator.h" "PooledResourceMemoryAllocator.h"
"ProgrammablePassEncoder.cpp" "ProgrammableEncoder.cpp"
"ProgrammablePassEncoder.h" "ProgrammableEncoder.h"
"QueryHelper.cpp" "QueryHelper.cpp"
"QueryHelper.h" "QueryHelper.h"
"QuerySet.cpp" "QuerySet.cpp"

View File

@ -36,7 +36,6 @@ namespace dawn_native {
static CommandBufferBase* MakeError(DeviceBase* device); static CommandBufferBase* MakeError(DeviceBase* device);
void DestroyImpl() override;
ObjectType GetType() const override; ObjectType GetType() const override;
MaybeError ValidateCanUseInSubmitNow() const; MaybeError ValidateCanUseInSubmitNow() const;
@ -54,6 +53,8 @@ namespace dawn_native {
private: private:
CommandBufferBase(DeviceBase* device, ObjectBase::ErrorTag tag); CommandBufferBase(DeviceBase* device, ObjectBase::ErrorTag tag);
void DestroyImpl() override;
CommandBufferResourceUsage mResourceUsages; CommandBufferResourceUsage mResourceUsages;
}; };

View File

@ -468,12 +468,17 @@ namespace dawn_native {
CommandEncoder::CommandEncoder(DeviceBase* device, const CommandEncoderDescriptor*) CommandEncoder::CommandEncoder(DeviceBase* device, const CommandEncoderDescriptor*)
: ApiObjectBase(device, kLabelNotImplemented), mEncodingContext(device, this) { : ApiObjectBase(device, kLabelNotImplemented), mEncodingContext(device, this) {
TrackInDevice();
} }
ObjectType CommandEncoder::GetType() const { ObjectType CommandEncoder::GetType() const {
return ObjectType::CommandEncoder; return ObjectType::CommandEncoder;
} }
void CommandEncoder::DestroyImpl() {
mEncodingContext.Destroy();
}
CommandBufferResourceUsage CommandEncoder::AcquireResourceUsages() { CommandBufferResourceUsage CommandEncoder::AcquireResourceUsages() {
return CommandBufferResourceUsage{ return CommandBufferResourceUsage{
mEncodingContext.AcquireRenderPassUsages(), mEncodingContext.AcquireComputePassUsages(), mEncodingContext.AcquireRenderPassUsages(), mEncodingContext.AcquireComputePassUsages(),

View File

@ -30,7 +30,7 @@ namespace dawn_native {
public: public:
CommandEncoder(DeviceBase* device, const CommandEncoderDescriptor* descriptor); CommandEncoder(DeviceBase* device, const CommandEncoderDescriptor* descriptor);
ObjectType GetType() const; ObjectType GetType() const override;
CommandIterator AcquireCommands(); CommandIterator AcquireCommands();
CommandBufferResourceUsage AcquireResourceUsages(); CommandBufferResourceUsage AcquireResourceUsages();
@ -79,6 +79,7 @@ namespace dawn_native {
CommandBufferBase* APIFinish(const CommandBufferDescriptor* descriptor = nullptr); CommandBufferBase* APIFinish(const CommandBufferDescriptor* descriptor = nullptr);
private: private:
void DestroyImpl() override;
ResultOrError<Ref<CommandBufferBase>> FinishInternal( ResultOrError<Ref<CommandBufferBase>> FinishInternal(
const CommandBufferDescriptor* descriptor); const CommandBufferDescriptor* descriptor);

View File

@ -105,15 +105,15 @@ namespace dawn_native {
ComputePassEncoder::ComputePassEncoder(DeviceBase* device, ComputePassEncoder::ComputePassEncoder(DeviceBase* device,
CommandEncoder* commandEncoder, CommandEncoder* commandEncoder,
EncodingContext* encodingContext) EncodingContext* encodingContext)
: ProgrammablePassEncoder(device, encodingContext), mCommandEncoder(commandEncoder) { : ProgrammableEncoder(device, encodingContext), mCommandEncoder(commandEncoder) {
TrackInDevice();
} }
ComputePassEncoder::ComputePassEncoder(DeviceBase* device, ComputePassEncoder::ComputePassEncoder(DeviceBase* device,
CommandEncoder* commandEncoder, CommandEncoder* commandEncoder,
EncodingContext* encodingContext, EncodingContext* encodingContext,
ErrorTag errorTag) ErrorTag errorTag)
: ProgrammablePassEncoder(device, encodingContext, errorTag), : ProgrammableEncoder(device, encodingContext, errorTag), mCommandEncoder(commandEncoder) {
mCommandEncoder(commandEncoder) {
} }
ComputePassEncoder* ComputePassEncoder::MakeError(DeviceBase* device, ComputePassEncoder* ComputePassEncoder::MakeError(DeviceBase* device,
@ -122,6 +122,13 @@ namespace dawn_native {
return new ComputePassEncoder(device, commandEncoder, encodingContext, ObjectBase::kError); return new ComputePassEncoder(device, commandEncoder, encodingContext, ObjectBase::kError);
} }
void ComputePassEncoder::DestroyImpl() {
ApiObjectBase::DestroyImpl();
// Ensure that the pass has exited. This is done for passes only since validation requires
// they exit before destruction while bundles do not.
mEncodingContext->EnsurePassExited(this);
}
ObjectType ComputePassEncoder::GetType() const { ObjectType ComputePassEncoder::GetType() const {
return ObjectType::ComputePassEncoder; return ObjectType::ComputePassEncoder;
} }

View File

@ -19,13 +19,13 @@
#include "dawn_native/Error.h" #include "dawn_native/Error.h"
#include "dawn_native/Forward.h" #include "dawn_native/Forward.h"
#include "dawn_native/PassResourceUsageTracker.h" #include "dawn_native/PassResourceUsageTracker.h"
#include "dawn_native/ProgrammablePassEncoder.h" #include "dawn_native/ProgrammableEncoder.h"
namespace dawn_native { namespace dawn_native {
class SyncScopeUsageTracker; class SyncScopeUsageTracker;
class ComputePassEncoder final : public ProgrammablePassEncoder { class ComputePassEncoder final : public ProgrammableEncoder {
public: public:
ComputePassEncoder(DeviceBase* device, ComputePassEncoder(DeviceBase* device,
CommandEncoder* commandEncoder, CommandEncoder* commandEncoder,
@ -62,6 +62,8 @@ namespace dawn_native {
ErrorTag errorTag); ErrorTag errorTag);
private: private:
void DestroyImpl() override;
ResultOrError<std::pair<Ref<BufferBase>, uint64_t>> ValidateIndirectDispatch( ResultOrError<std::pair<Ref<BufferBase>, uint64_t>> ValidateIndirectDispatch(
BufferBase* indirectBuffer, BufferBase* indirectBuffer,
uint64_t indirectOffset); uint64_t indirectOffset);

View File

@ -266,10 +266,19 @@ namespace dawn_native {
// a ref to A, then B depends on A. We therefore try to destroy B before destroying A. Note // a ref to A, then B depends on A. We therefore try to destroy B before destroying A. Note
// that this only considers the immediate frontend dependencies, while backend objects could // that this only considers the immediate frontend dependencies, while backend objects could
// add complications and extra dependencies. // add complications and extra dependencies.
// TODO(dawn/628) Add types into the array as they are implemented. //
// Note that AttachmentState is not an ApiObject so it cannot be eagerly destroyed. However,
// since AttachmentStates are cached by the device, objects that hold references to
// AttachmentStates should make sure to un-ref them in their Destroy operation so that we
// can destroy the frontend cache.
// clang-format off // clang-format off
static constexpr std::array<ObjectType, 14> kObjectTypeDependencyOrder = { static constexpr std::array<ObjectType, 19> kObjectTypeDependencyOrder = {
ObjectType::ComputePassEncoder,
ObjectType::RenderPassEncoder,
ObjectType::RenderBundleEncoder,
ObjectType::RenderBundle,
ObjectType::CommandEncoder,
ObjectType::CommandBuffer, ObjectType::CommandBuffer,
ObjectType::RenderPipeline, ObjectType::RenderPipeline,
ObjectType::ComputePipeline, ObjectType::ComputePipeline,

View File

@ -29,9 +29,23 @@ namespace dawn_native {
} }
EncodingContext::~EncodingContext() { EncodingContext::~EncodingContext() {
Destroy();
}
void EncodingContext::Destroy() {
if (mDestroyed) {
return;
}
if (!mWereCommandsAcquired) { if (!mWereCommandsAcquired) {
FreeCommands(GetIterator()); FreeCommands(GetIterator());
} }
// If we weren't already finished, then we want to handle an error here so that any calls
// to Finish after Destroy will return a meaningful error.
if (!IsFinished()) {
HandleError(DAWN_FORMAT_VALIDATION_ERROR("Destroyed encoder cannot be finished."));
}
mDestroyed = true;
mCurrentEncoder = nullptr;
} }
CommandIterator EncodingContext::AcquireCommands() { CommandIterator EncodingContext::AcquireCommands() {

View File

@ -37,6 +37,10 @@ namespace dawn_native {
EncodingContext(DeviceBase* device, const ApiObjectBase* initialEncoder); EncodingContext(DeviceBase* device, const ApiObjectBase* initialEncoder);
~EncodingContext(); ~EncodingContext();
// Marks the encoding context as destroyed so that any future encodes will fail, and all
// encoded commands are released.
void Destroy();
CommandIterator AcquireCommands(); CommandIterator AcquireCommands();
CommandIterator* GetIterator(); CommandIterator* GetIterator();
@ -75,7 +79,10 @@ namespace dawn_native {
inline bool CheckCurrentEncoder(const ApiObjectBase* encoder) { inline bool CheckCurrentEncoder(const ApiObjectBase* encoder) {
if (DAWN_UNLIKELY(encoder != mCurrentEncoder)) { if (DAWN_UNLIKELY(encoder != mCurrentEncoder)) {
if (mCurrentEncoder != mTopLevelEncoder) { if (mDestroyed) {
HandleError(
DAWN_FORMAT_VALIDATION_ERROR("Recording in a destroyed %s.", encoder));
} else if (mCurrentEncoder != mTopLevelEncoder) {
// The top level encoder was used when a pass encoder was current. // The top level encoder was used when a pass encoder was current.
HandleError(DAWN_FORMAT_VALIDATION_ERROR( HandleError(DAWN_FORMAT_VALIDATION_ERROR(
"Command cannot be recorded while %s is active.", mCurrentEncoder)); "Command cannot be recorded while %s is active.", mCurrentEncoder));
@ -164,6 +171,7 @@ namespace dawn_native {
CommandIterator mIterator; CommandIterator mIterator;
bool mWasMovedToIterator = false; bool mWasMovedToIterator = false;
bool mWereCommandsAcquired = false; bool mWereCommandsAcquired = false;
bool mDestroyed = false;
std::unique_ptr<ErrorData> mError; std::unique_ptr<ErrorData> mError;
std::vector<std::string> mDebugGroupLabels; std::vector<std::string> mDebugGroupLabels;

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "dawn_native/ProgrammablePassEncoder.h" #include "dawn_native/ProgrammableEncoder.h"
#include "common/BitSetIterator.h" #include "common/BitSetIterator.h"
#include "common/ityp_array.h" #include "common/ityp_array.h"
@ -28,40 +28,32 @@
namespace dawn_native { namespace dawn_native {
ProgrammablePassEncoder::ProgrammablePassEncoder(DeviceBase* device, ProgrammableEncoder::ProgrammableEncoder(DeviceBase* device, EncodingContext* encodingContext)
EncodingContext* encodingContext)
: ApiObjectBase(device, kLabelNotImplemented), : ApiObjectBase(device, kLabelNotImplemented),
mEncodingContext(encodingContext), mEncodingContext(encodingContext),
mValidationEnabled(device->IsValidationEnabled()) { mValidationEnabled(device->IsValidationEnabled()) {
} }
ProgrammablePassEncoder::ProgrammablePassEncoder(DeviceBase* device, ProgrammableEncoder::ProgrammableEncoder(DeviceBase* device,
EncodingContext* encodingContext, EncodingContext* encodingContext,
ErrorTag errorTag) ErrorTag errorTag)
: ApiObjectBase(device, errorTag), : ApiObjectBase(device, errorTag),
mEncodingContext(encodingContext), mEncodingContext(encodingContext),
mValidationEnabled(device->IsValidationEnabled()) { mValidationEnabled(device->IsValidationEnabled()) {
} }
void ProgrammablePassEncoder::DeleteThis() { bool ProgrammableEncoder::IsValidationEnabled() const {
// This must be called prior to the destructor because it may generate an error message
// which calls the virtual RenderPassEncoder->GetType() as part of it's formatting.
mEncodingContext->EnsurePassExited(this);
ApiObjectBase::DeleteThis();
}
bool ProgrammablePassEncoder::IsValidationEnabled() const {
return mValidationEnabled; return mValidationEnabled;
} }
MaybeError ProgrammablePassEncoder::ValidateProgrammableEncoderEnd() const { MaybeError ProgrammableEncoder::ValidateProgrammableEncoderEnd() const {
DAWN_INVALID_IF(mDebugGroupStackSize != 0, DAWN_INVALID_IF(mDebugGroupStackSize != 0,
"PushDebugGroup called %u time(s) without a corresponding PopDebugGroup.", "PushDebugGroup called %u time(s) without a corresponding PopDebugGroup.",
mDebugGroupStackSize); mDebugGroupStackSize);
return {}; return {};
} }
void ProgrammablePassEncoder::APIInsertDebugMarker(const char* groupLabel) { void ProgrammableEncoder::APIInsertDebugMarker(const char* groupLabel) {
mEncodingContext->TryEncode( mEncodingContext->TryEncode(
this, this,
[&](CommandAllocator* allocator) -> MaybeError { [&](CommandAllocator* allocator) -> MaybeError {
@ -77,7 +69,7 @@ namespace dawn_native {
"encoding %s.InsertDebugMarker(\"%s\").", this, groupLabel); "encoding %s.InsertDebugMarker(\"%s\").", this, groupLabel);
} }
void ProgrammablePassEncoder::APIPopDebugGroup() { void ProgrammableEncoder::APIPopDebugGroup() {
mEncodingContext->TryEncode( mEncodingContext->TryEncode(
this, this,
[&](CommandAllocator* allocator) -> MaybeError { [&](CommandAllocator* allocator) -> MaybeError {
@ -95,7 +87,7 @@ namespace dawn_native {
"encoding %s.PopDebugGroup().", this); "encoding %s.PopDebugGroup().", this);
} }
void ProgrammablePassEncoder::APIPushDebugGroup(const char* groupLabel) { void ProgrammableEncoder::APIPushDebugGroup(const char* groupLabel) {
mEncodingContext->TryEncode( mEncodingContext->TryEncode(
this, this,
[&](CommandAllocator* allocator) -> MaybeError { [&](CommandAllocator* allocator) -> MaybeError {
@ -114,11 +106,10 @@ namespace dawn_native {
"encoding %s.PushDebugGroup(\"%s\").", this, groupLabel); "encoding %s.PushDebugGroup(\"%s\").", this, groupLabel);
} }
MaybeError ProgrammablePassEncoder::ValidateSetBindGroup( MaybeError ProgrammableEncoder::ValidateSetBindGroup(BindGroupIndex index,
BindGroupIndex index, BindGroupBase* group,
BindGroupBase* group, uint32_t dynamicOffsetCountIn,
uint32_t dynamicOffsetCountIn, const uint32_t* dynamicOffsetsIn) const {
const uint32_t* dynamicOffsetsIn) const {
DAWN_TRY(GetDevice()->ValidateObject(group)); DAWN_TRY(GetDevice()->ValidateObject(group));
DAWN_INVALID_IF(index >= kMaxBindGroupsTyped, DAWN_INVALID_IF(index >= kMaxBindGroupsTyped,
@ -192,11 +183,11 @@ namespace dawn_native {
return {}; return {};
} }
void ProgrammablePassEncoder::RecordSetBindGroup(CommandAllocator* allocator, void ProgrammableEncoder::RecordSetBindGroup(CommandAllocator* allocator,
BindGroupIndex index, BindGroupIndex index,
BindGroupBase* group, BindGroupBase* group,
uint32_t dynamicOffsetCount, uint32_t dynamicOffsetCount,
const uint32_t* dynamicOffsets) const { const uint32_t* dynamicOffsets) const {
SetBindGroupCmd* cmd = allocator->Allocate<SetBindGroupCmd>(Command::SetBindGroup); SetBindGroupCmd* cmd = allocator->Allocate<SetBindGroupCmd>(Command::SetBindGroup);
cmd->index = index; cmd->index = index;
cmd->group = group; cmd->group = group;

View File

@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef DAWNNATIVE_PROGRAMMABLEPASSENCODER_H_ #ifndef DAWNNATIVE_PROGRAMMABLEENCODER_H_
#define DAWNNATIVE_PROGRAMMABLEPASSENCODER_H_ #define DAWNNATIVE_PROGRAMMABLEENCODER_H_
#include "dawn_native/CommandEncoder.h" #include "dawn_native/CommandEncoder.h"
#include "dawn_native/Error.h" #include "dawn_native/Error.h"
@ -27,10 +27,10 @@ namespace dawn_native {
class DeviceBase; class DeviceBase;
// Base class for shared functionality between ComputePassEncoder and RenderPassEncoder. // Base class for shared functionality between programmable encoders.
class ProgrammablePassEncoder : public ApiObjectBase { class ProgrammableEncoder : public ApiObjectBase {
public: public:
ProgrammablePassEncoder(DeviceBase* device, EncodingContext* encodingContext); ProgrammableEncoder(DeviceBase* device, EncodingContext* encodingContext);
void APIInsertDebugMarker(const char* groupLabel); void APIInsertDebugMarker(const char* groupLabel);
void APIPopDebugGroup(); void APIPopDebugGroup();
@ -40,8 +40,6 @@ namespace dawn_native {
bool IsValidationEnabled() const; bool IsValidationEnabled() const;
MaybeError ValidateProgrammableEncoderEnd() const; MaybeError ValidateProgrammableEncoderEnd() const;
virtual void DeleteThis() override;
// Compute and render passes do different things on SetBindGroup. These are helper functions // Compute and render passes do different things on SetBindGroup. These are helper functions
// for the logic they have in common. // for the logic they have in common.
MaybeError ValidateSetBindGroup(BindGroupIndex index, MaybeError ValidateSetBindGroup(BindGroupIndex index,
@ -55,9 +53,9 @@ namespace dawn_native {
const uint32_t* dynamicOffsets) const; const uint32_t* dynamicOffsets) const;
// Construct an "error" programmable pass encoder. // Construct an "error" programmable pass encoder.
ProgrammablePassEncoder(DeviceBase* device, ProgrammableEncoder(DeviceBase* device,
EncodingContext* encodingContext, EncodingContext* encodingContext,
ErrorTag errorTag); ErrorTag errorTag);
EncodingContext* mEncodingContext = nullptr; EncodingContext* mEncodingContext = nullptr;
@ -69,4 +67,4 @@ namespace dawn_native {
} // namespace dawn_native } // namespace dawn_native
#endif // DAWNNATIVE_PROGRAMMABLEPASSENCODER_H_ #endif // DAWNNATIVE_PROGRAMMABLEENCODER_H_

View File

@ -36,10 +36,15 @@ namespace dawn_native {
mDepthReadOnly(depthReadOnly), mDepthReadOnly(depthReadOnly),
mStencilReadOnly(stencilReadOnly), mStencilReadOnly(stencilReadOnly),
mResourceUsage(std::move(resourceUsage)) { mResourceUsage(std::move(resourceUsage)) {
TrackInDevice();
} }
RenderBundleBase::~RenderBundleBase() { void RenderBundleBase::DestroyImpl() {
FreeCommands(&mCommands); FreeCommands(&mCommands);
// Remove reference to the attachment state so that we don't have lingering references to
// it preventing it from being uncached in the device.
mAttachmentState = nullptr;
} }
// static // static

View File

@ -33,7 +33,7 @@ namespace dawn_native {
struct RenderBundleDescriptor; struct RenderBundleDescriptor;
class RenderBundleEncoder; class RenderBundleEncoder;
class RenderBundleBase : public ApiObjectBase { class RenderBundleBase final : public ApiObjectBase {
public: public:
RenderBundleBase(RenderBundleEncoder* encoder, RenderBundleBase(RenderBundleEncoder* encoder,
const RenderBundleDescriptor* descriptor, const RenderBundleDescriptor* descriptor,
@ -55,12 +55,11 @@ namespace dawn_native {
const RenderPassResourceUsage& GetResourceUsage() const; const RenderPassResourceUsage& GetResourceUsage() const;
const IndirectDrawMetadata& GetIndirectDrawMetadata(); const IndirectDrawMetadata& GetIndirectDrawMetadata();
protected:
~RenderBundleBase() override;
private: private:
RenderBundleBase(DeviceBase* device, ErrorTag errorTag); RenderBundleBase(DeviceBase* device, ErrorTag errorTag);
void DestroyImpl() override;
CommandIterator mCommands; CommandIterator mCommands;
IndirectDrawMetadata mIndirectDrawMetadata; IndirectDrawMetadata mIndirectDrawMetadata;
Ref<AttachmentState> mAttachmentState; Ref<AttachmentState> mAttachmentState;

View File

@ -93,6 +93,7 @@ namespace dawn_native {
descriptor->depthReadOnly, descriptor->depthReadOnly,
descriptor->stencilReadOnly), descriptor->stencilReadOnly),
mBundleEncodingContext(device, this) { mBundleEncodingContext(device, this) {
TrackInDevice();
} }
RenderBundleEncoder::RenderBundleEncoder(DeviceBase* device, ErrorTag errorTag) RenderBundleEncoder::RenderBundleEncoder(DeviceBase* device, ErrorTag errorTag)
@ -100,6 +101,11 @@ namespace dawn_native {
mBundleEncodingContext(device, this) { mBundleEncodingContext(device, this) {
} }
void RenderBundleEncoder::DestroyImpl() {
RenderEncoderBase::DestroyImpl();
mBundleEncodingContext.Destroy();
}
// static // static
Ref<RenderBundleEncoder> RenderBundleEncoder::Create( Ref<RenderBundleEncoder> RenderBundleEncoder::Create(
DeviceBase* device, DeviceBase* device,

View File

@ -43,6 +43,8 @@ namespace dawn_native {
RenderBundleEncoder(DeviceBase* device, const RenderBundleEncoderDescriptor* descriptor); RenderBundleEncoder(DeviceBase* device, const RenderBundleEncoderDescriptor* descriptor);
RenderBundleEncoder(DeviceBase* device, ErrorTag errorTag); RenderBundleEncoder(DeviceBase* device, ErrorTag errorTag);
void DestroyImpl() override;
ResultOrError<RenderBundleBase*> FinishImpl(const RenderBundleDescriptor* descriptor); ResultOrError<RenderBundleBase*> FinishImpl(const RenderBundleDescriptor* descriptor);
MaybeError ValidateFinish(const RenderPassResourceUsage& usages) const; MaybeError ValidateFinish(const RenderPassResourceUsage& usages) const;

View File

@ -34,7 +34,7 @@ namespace dawn_native {
Ref<AttachmentState> attachmentState, Ref<AttachmentState> attachmentState,
bool depthReadOnly, bool depthReadOnly,
bool stencilReadOnly) bool stencilReadOnly)
: ProgrammablePassEncoder(device, encodingContext), : ProgrammableEncoder(device, encodingContext),
mIndirectDrawMetadata(device->GetLimits()), mIndirectDrawMetadata(device->GetLimits()),
mAttachmentState(std::move(attachmentState)), mAttachmentState(std::move(attachmentState)),
mDisableBaseVertex(device->IsToggleEnabled(Toggle::DisableBaseVertex)), mDisableBaseVertex(device->IsToggleEnabled(Toggle::DisableBaseVertex)),
@ -46,12 +46,18 @@ namespace dawn_native {
RenderEncoderBase::RenderEncoderBase(DeviceBase* device, RenderEncoderBase::RenderEncoderBase(DeviceBase* device,
EncodingContext* encodingContext, EncodingContext* encodingContext,
ErrorTag errorTag) ErrorTag errorTag)
: ProgrammablePassEncoder(device, encodingContext, errorTag), : ProgrammableEncoder(device, encodingContext, errorTag),
mIndirectDrawMetadata(device->GetLimits()), mIndirectDrawMetadata(device->GetLimits()),
mDisableBaseVertex(device->IsToggleEnabled(Toggle::DisableBaseVertex)), mDisableBaseVertex(device->IsToggleEnabled(Toggle::DisableBaseVertex)),
mDisableBaseInstance(device->IsToggleEnabled(Toggle::DisableBaseInstance)) { mDisableBaseInstance(device->IsToggleEnabled(Toggle::DisableBaseInstance)) {
} }
void RenderEncoderBase::DestroyImpl() {
// Remove reference to the attachment state so that we don't have lingering references to
// it preventing it from being uncached in the device.
mAttachmentState = nullptr;
}
const AttachmentState* RenderEncoderBase::GetAttachmentState() const { const AttachmentState* RenderEncoderBase::GetAttachmentState() const {
ASSERT(!IsError()); ASSERT(!IsError());
ASSERT(mAttachmentState != nullptr); ASSERT(mAttachmentState != nullptr);

View File

@ -20,11 +20,11 @@
#include "dawn_native/Error.h" #include "dawn_native/Error.h"
#include "dawn_native/IndirectDrawMetadata.h" #include "dawn_native/IndirectDrawMetadata.h"
#include "dawn_native/PassResourceUsageTracker.h" #include "dawn_native/PassResourceUsageTracker.h"
#include "dawn_native/ProgrammablePassEncoder.h" #include "dawn_native/ProgrammableEncoder.h"
namespace dawn_native { namespace dawn_native {
class RenderEncoderBase : public ProgrammablePassEncoder { class RenderEncoderBase : public ProgrammableEncoder {
public: public:
RenderEncoderBase(DeviceBase* device, RenderEncoderBase(DeviceBase* device,
EncodingContext* encodingContext, EncodingContext* encodingContext,
@ -67,6 +67,8 @@ namespace dawn_native {
// Construct an "error" render encoder base. // Construct an "error" render encoder base.
RenderEncoderBase(DeviceBase* device, EncodingContext* encodingContext, ErrorTag errorTag); RenderEncoderBase(DeviceBase* device, EncodingContext* encodingContext, ErrorTag errorTag);
void DestroyImpl() override;
CommandBufferStateTracker mCommandBufferState; CommandBufferStateTracker mCommandBufferState;
RenderPassResourceUsageTracker mUsageTracker; RenderPassResourceUsageTracker mUsageTracker;
IndirectDrawMetadata mIndirectDrawMetadata; IndirectDrawMetadata mIndirectDrawMetadata;

View File

@ -68,6 +68,7 @@ namespace dawn_native {
mRenderTargetHeight(renderTargetHeight), mRenderTargetHeight(renderTargetHeight),
mOcclusionQuerySet(occlusionQuerySet) { mOcclusionQuerySet(occlusionQuerySet) {
mUsageTracker = std::move(usageTracker); mUsageTracker = std::move(usageTracker);
TrackInDevice();
} }
RenderPassEncoder::RenderPassEncoder(DeviceBase* device, RenderPassEncoder::RenderPassEncoder(DeviceBase* device,
@ -83,6 +84,13 @@ namespace dawn_native {
return new RenderPassEncoder(device, commandEncoder, encodingContext, ObjectBase::kError); return new RenderPassEncoder(device, commandEncoder, encodingContext, ObjectBase::kError);
} }
void RenderPassEncoder::DestroyImpl() {
RenderEncoderBase::DestroyImpl();
// Ensure that the pass has exited. This is done for passes only since validation requires
// they exit before destruction while bundles do not.
mEncodingContext->EnsurePassExited(this);
}
ObjectType RenderPassEncoder::GetType() const { ObjectType RenderPassEncoder::GetType() const {
return ObjectType::RenderPassEncoder; return ObjectType::RenderPassEncoder;
} }

View File

@ -67,6 +67,8 @@ namespace dawn_native {
ErrorTag errorTag); ErrorTag errorTag);
private: private:
void DestroyImpl() override;
void TrackQueryAvailability(QuerySetBase* querySet, uint32_t queryIndex); void TrackQueryAvailability(QuerySetBase* querySet, uint32_t queryIndex);
// For render and compute passes, the encoding context is borrowed from the command encoder. // For render and compute passes, the encoding context is borrowed from the command encoder.

View File

@ -689,6 +689,10 @@ namespace dawn_native {
// Do not uncache the actual cached object if we are a blueprint or already destroyed. // Do not uncache the actual cached object if we are a blueprint or already destroyed.
GetDevice()->UncacheRenderPipeline(this); GetDevice()->UncacheRenderPipeline(this);
} }
// Remove reference to the attachment state so that we don't have lingering references to
// it preventing it from being uncached in the device.
mAttachmentState = nullptr;
return wasDestroyed; return wasDestroyed;
} }

View File

@ -12,10 +12,16 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <gmock/gmock.h>
#include "dawn_native/CommandEncoder.h"
#include "tests/unittests/validation/ValidationTest.h" #include "tests/unittests/validation/ValidationTest.h"
#include "utils/WGPUHelpers.h" #include "utils/WGPUHelpers.h"
using ::testing::HasSubstr;
class CommandBufferValidationTest : public ValidationTest {}; class CommandBufferValidationTest : public ValidationTest {};
// Test for an empty command buffer // Test for an empty command buffer
@ -39,7 +45,9 @@ TEST_F(CommandBufferValidationTest, EndedMidRenderPass) {
{ {
wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass); wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
ASSERT_DEVICE_ERROR(encoder.Finish()); ASSERT_DEVICE_ERROR(
encoder.Finish(),
HasSubstr("Command buffer recording ended before [RenderPassEncoder] was ended."));
} }
// Error case, command buffer ended mid-pass. Trying to use encoders after Finish // Error case, command buffer ended mid-pass. Trying to use encoders after Finish
@ -47,8 +55,12 @@ TEST_F(CommandBufferValidationTest, EndedMidRenderPass) {
{ {
wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass); wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
ASSERT_DEVICE_ERROR(encoder.Finish()); ASSERT_DEVICE_ERROR(
ASSERT_DEVICE_ERROR(pass.EndPass()); encoder.Finish(),
HasSubstr("Command buffer recording ended before [RenderPassEncoder] was ended."));
ASSERT_DEVICE_ERROR(
pass.EndPass(),
HasSubstr("Recording in an error or already ended [RenderPassEncoder]."));
} }
} }
@ -66,7 +78,9 @@ TEST_F(CommandBufferValidationTest, EndedMidComputePass) {
{ {
wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
ASSERT_DEVICE_ERROR(encoder.Finish()); ASSERT_DEVICE_ERROR(
encoder.Finish(),
HasSubstr("Command buffer recording ended before [ComputePassEncoder] was ended."));
} }
// Error case, command buffer ended mid-pass. Trying to use encoders after Finish // Error case, command buffer ended mid-pass. Trying to use encoders after Finish
@ -74,8 +88,12 @@ TEST_F(CommandBufferValidationTest, EndedMidComputePass) {
{ {
wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
ASSERT_DEVICE_ERROR(encoder.Finish()); ASSERT_DEVICE_ERROR(
ASSERT_DEVICE_ERROR(pass.EndPass()); encoder.Finish(),
HasSubstr("Command buffer recording ended before [ComputePassEncoder] was ended."));
ASSERT_DEVICE_ERROR(
pass.EndPass(),
HasSubstr("Recording in an error or already ended [ComputePassEncoder]."));
} }
} }
@ -97,7 +115,9 @@ TEST_F(CommandBufferValidationTest, RenderPassEndedTwice) {
wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass); wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
pass.EndPass(); pass.EndPass();
pass.EndPass(); pass.EndPass();
ASSERT_DEVICE_ERROR(encoder.Finish()); ASSERT_DEVICE_ERROR(
encoder.Finish(),
HasSubstr("Recording in an error or already ended [RenderPassEncoder]."));
} }
} }
@ -117,7 +137,9 @@ TEST_F(CommandBufferValidationTest, ComputePassEndedTwice) {
wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.EndPass(); pass.EndPass();
pass.EndPass(); pass.EndPass();
ASSERT_DEVICE_ERROR(encoder.Finish()); ASSERT_DEVICE_ERROR(
encoder.Finish(),
HasSubstr("Recording in an error or already ended [ComputePassEncoder]."));
} }
} }
@ -223,14 +245,18 @@ TEST_F(CommandBufferValidationTest, PassDereferenced) {
{ {
wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
encoder.BeginRenderPass(&dummyRenderPass); encoder.BeginRenderPass(&dummyRenderPass);
ASSERT_DEVICE_ERROR(encoder.Finish()); ASSERT_DEVICE_ERROR(
encoder.Finish(),
HasSubstr("Command buffer recording ended before [RenderPassEncoder] was ended."));
} }
// Error case, no reference is kept to a compute pass. // Error case, no reference is kept to a compute pass.
{ {
wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
encoder.BeginComputePass(); encoder.BeginComputePass();
ASSERT_DEVICE_ERROR(encoder.Finish()); ASSERT_DEVICE_ERROR(
encoder.Finish(),
HasSubstr("Command buffer recording ended before [ComputePassEncoder] was ended."));
} }
// Error case, beginning a new pass after failing to end a de-referenced pass. // Error case, beginning a new pass after failing to end a de-referenced pass.
@ -239,21 +265,25 @@ TEST_F(CommandBufferValidationTest, PassDereferenced) {
encoder.BeginRenderPass(&dummyRenderPass); encoder.BeginRenderPass(&dummyRenderPass);
wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.EndPass(); pass.EndPass();
ASSERT_DEVICE_ERROR(encoder.Finish()); ASSERT_DEVICE_ERROR(
encoder.Finish(),
HasSubstr("Command buffer recording ended before [RenderPassEncoder] was ended."));
} }
// Error case, deleting the pass after finishing the commend encoder shouldn't generate an // Error case, deleting the pass after finishing the command encoder shouldn't generate an
// uncaptured error. // uncaptured error.
{ {
wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
ASSERT_DEVICE_ERROR(encoder.Finish()); ASSERT_DEVICE_ERROR(
encoder.Finish(),
HasSubstr("Command buffer recording ended before [ComputePassEncoder] was ended."));
pass = nullptr; pass = nullptr;
} }
// Valid case, command encoder is never finished so the de-referenced pass shouldn't generate an // Valid case, command encoder is never finished so the de-referenced pass shouldn't
// uncaptured error. // generate an uncaptured error.
{ {
wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
encoder.BeginComputePass(); encoder.BeginComputePass();
@ -264,5 +294,80 @@ TEST_F(CommandBufferValidationTest, PassDereferenced) {
TEST_F(CommandBufferValidationTest, InjectValidationError) { TEST_F(CommandBufferValidationTest, InjectValidationError) {
wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
encoder.InjectValidationError("my error"); encoder.InjectValidationError("my error");
ASSERT_DEVICE_ERROR(encoder.Finish()); ASSERT_DEVICE_ERROR(encoder.Finish(), HasSubstr("my error"));
}
TEST_F(CommandBufferValidationTest, DestroyEncoder) {
// Skip these tests if we are using wire because the destroy functionality is not exposed
// and needs to use a cast to call manually. We cannot test this in the wire case since the
// only way to trigger the destroy call is by losing all references which means we cannot
// call finish.
DAWN_SKIP_TEST_IF(UsesWire());
DummyRenderPass dummyRenderPass(device);
// Control case, command buffer ended after the pass is ended.
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
pass.EndPass();
encoder.Finish();
}
// Destroyed encoder with encoded commands should emit error on finish.
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
pass.EndPass();
reinterpret_cast<dawn_native::CommandEncoder*>(encoder.Get())->Destroy();
ASSERT_DEVICE_ERROR(encoder.Finish(), HasSubstr("Destroyed encoder cannot be finished."));
}
// Destroyed encoder with encoded commands shouldn't emit an error if never finished.
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
pass.EndPass();
reinterpret_cast<dawn_native::CommandEncoder*>(encoder.Get())->Destroy();
}
// Destroyed encoder should allow encoding, and emit error on finish.
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
reinterpret_cast<dawn_native::CommandEncoder*>(encoder.Get())->Destroy();
wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
pass.EndPass();
ASSERT_DEVICE_ERROR(encoder.Finish(), HasSubstr("Destroyed encoder cannot be finished."));
}
// Destroyed encoder should allow encoding and shouldn't emit an error if never finished.
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
reinterpret_cast<dawn_native::CommandEncoder*>(encoder.Get())->Destroy();
wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
pass.EndPass();
}
// Destroying a finished encoder should not emit any errors.
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
pass.EndPass();
encoder.Finish();
reinterpret_cast<dawn_native::CommandEncoder*>(encoder.Get())->Destroy();
}
// Destroying an encoder twice should not emit any errors.
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
reinterpret_cast<dawn_native::CommandEncoder*>(encoder.Get())->Destroy();
reinterpret_cast<dawn_native::CommandEncoder*>(encoder.Get())->Destroy();
}
// Destroying an encoder twice and then calling finish should fail.
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
reinterpret_cast<dawn_native::CommandEncoder*>(encoder.Get())->Destroy();
reinterpret_cast<dawn_native::CommandEncoder*>(encoder.Get())->Destroy();
ASSERT_DEVICE_ERROR(encoder.Finish(), HasSubstr("Destroyed encoder cannot be finished."));
}
} }

View File

@ -129,12 +129,19 @@ void ValidationTest::TearDown() {
} }
} }
void ValidationTest::StartExpectDeviceError() { void ValidationTest::StartExpectDeviceError(testing::Matcher<std::string> errorMatcher) {
mExpectError = true; mExpectError = true;
mError = false; mError = false;
mErrorMatcher = errorMatcher;
} }
void ValidationTest::StartExpectDeviceError() {
StartExpectDeviceError(testing::_);
}
bool ValidationTest::EndExpectDeviceError() { bool ValidationTest::EndExpectDeviceError() {
mExpectError = false; mExpectError = false;
mErrorMatcher = testing::_;
return mError; return mError;
} }
std::string ValidationTest::GetLastDeviceErrorMessage() const { std::string ValidationTest::GetLastDeviceErrorMessage() const {
@ -210,6 +217,9 @@ void ValidationTest::OnDeviceError(WGPUErrorType type, const char* message, void
ASSERT_TRUE(self->mExpectError) << "Got unexpected device error: " << message; ASSERT_TRUE(self->mExpectError) << "Got unexpected device error: " << message;
ASSERT_FALSE(self->mError) << "Got two errors in expect block"; ASSERT_FALSE(self->mError) << "Got two errors in expect block";
if (self->mExpectError) {
ASSERT_THAT(message, self->mErrorMatcher);
}
self->mError = true; self->mError = true;
} }

View File

@ -19,10 +19,30 @@
#include "dawn/webgpu_cpp.h" #include "dawn/webgpu_cpp.h"
#include "dawn_native/DawnNative.h" #include "dawn_native/DawnNative.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#define ASSERT_DEVICE_ERROR(statement) \ // Argument helpers to allow macro overriding.
FlushWire(); \ #define UNIMPLEMENTED_MACRO(...) UNREACHABLE()
#define GET_3RD_ARG_HELPER_(_1, _2, NAME, ...) NAME
#define GET_3RD_ARG_(args) GET_3RD_ARG_HELPER_ args
// Overloaded to allow further validation of the error messages given an error is expected.
// Especially useful to verify that the expected errors are occuring, not just any error.
//
// Example usages:
// 1 Argument Case:
// ASSERT_DEVICE_ERROR(FunctionThatExpectsError());
//
// 2 Argument Case:
// ASSERT_DEVICE_ERROR(FunctionThatHasLongError(), HasSubstr("partial match"))
// ASSERT_DEVICE_ERROR(FunctionThatHasShortError(), Eq("exact match"));
#define ASSERT_DEVICE_ERROR(...) \
GET_3RD_ARG_((__VA_ARGS__, ASSERT_DEVICE_ERROR_IMPL_2_, ASSERT_DEVICE_ERROR_IMPL_1_, \
UNIMPLEMENTED_MACRO)) \
(__VA_ARGS__)
#define ASSERT_DEVICE_ERROR_IMPL_1_(statement) \
StartExpectDeviceError(); \ StartExpectDeviceError(); \
statement; \ statement; \
FlushWire(); \ FlushWire(); \
@ -32,6 +52,16 @@
do { \ do { \
} while (0) } while (0)
#define ASSERT_DEVICE_ERROR_IMPL_2_(statement, matcher) \
StartExpectDeviceError(matcher); \
statement; \
FlushWire(); \
if (!EndExpectDeviceError()) { \
FAIL() << "Expected device error in:\n " << #statement; \
} \
do { \
} while (0)
// Skip a test when the given condition is satisfied. // Skip a test when the given condition is satisfied.
#define DAWN_SKIP_TEST_IF(condition) \ #define DAWN_SKIP_TEST_IF(condition) \
do { \ do { \
@ -69,6 +99,7 @@ class ValidationTest : public testing::Test {
void SetUp() override; void SetUp() override;
void TearDown() override; void TearDown() override;
void StartExpectDeviceError(testing::Matcher<std::string> errorMatcher);
void StartExpectDeviceError(); void StartExpectDeviceError();
bool EndExpectDeviceError(); bool EndExpectDeviceError();
std::string GetLastDeviceErrorMessage() const; std::string GetLastDeviceErrorMessage() const;
@ -118,6 +149,7 @@ class ValidationTest : public testing::Test {
std::string mDeviceErrorMessage; std::string mDeviceErrorMessage;
bool mExpectError = false; bool mExpectError = false;
bool mError = false; bool mError = false;
testing::Matcher<std::string> mErrorMatcher;
}; };
#endif // TESTS_UNITTESTS_VALIDATIONTEST_H_ #endif // TESTS_UNITTESTS_VALIDATIONTEST_H_