Add/update destruction handling for command encoding objects

- Renames ProgrammablePassEncoder to just ProgrammableEncoder since it is also used in RenderBundleEncoder which is not a "pass"
- Adds testing infrastructure to further test device errors
- Ensures AttachmentStates are de-reffed when encoder objects are destroyed for proper cleanup
- Makes sure that both encoded and partial encoded commands are freed at destruction

Bug: dawn:628
Change-Id: Id62ab02d54461c4da266963035e8666799f61e9a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/68461
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Loko Kung <lokokung@google.com>
This commit is contained in:
Loko Kung 2021-11-16 22:46:34 +00:00 committed by Dawn LUCI CQ
parent 8f4eacd082
commit 970739e4e3
24 changed files with 298 additions and 81 deletions

View File

@ -282,8 +282,8 @@ source_set("dawn_native_sources") {
"PipelineLayout.h",
"PooledResourceMemoryAllocator.cpp",
"PooledResourceMemoryAllocator.h",
"ProgrammablePassEncoder.cpp",
"ProgrammablePassEncoder.h",
"ProgrammableEncoder.cpp",
"ProgrammableEncoder.h",
"QueryHelper.cpp",
"QueryHelper.h",
"QuerySet.cpp",

View File

@ -127,8 +127,8 @@ target_sources(dawn_native PRIVATE
"PipelineLayout.h"
"PooledResourceMemoryAllocator.cpp"
"PooledResourceMemoryAllocator.h"
"ProgrammablePassEncoder.cpp"
"ProgrammablePassEncoder.h"
"ProgrammableEncoder.cpp"
"ProgrammableEncoder.h"
"QueryHelper.cpp"
"QueryHelper.h"
"QuerySet.cpp"

View File

@ -36,7 +36,6 @@ namespace dawn_native {
static CommandBufferBase* MakeError(DeviceBase* device);
void DestroyImpl() override;
ObjectType GetType() const override;
MaybeError ValidateCanUseInSubmitNow() const;
@ -54,6 +53,8 @@ namespace dawn_native {
private:
CommandBufferBase(DeviceBase* device, ObjectBase::ErrorTag tag);
void DestroyImpl() override;
CommandBufferResourceUsage mResourceUsages;
};

View File

@ -468,12 +468,17 @@ namespace dawn_native {
CommandEncoder::CommandEncoder(DeviceBase* device, const CommandEncoderDescriptor*)
: ApiObjectBase(device, kLabelNotImplemented), mEncodingContext(device, this) {
TrackInDevice();
}
ObjectType CommandEncoder::GetType() const {
return ObjectType::CommandEncoder;
}
void CommandEncoder::DestroyImpl() {
mEncodingContext.Destroy();
}
CommandBufferResourceUsage CommandEncoder::AcquireResourceUsages() {
return CommandBufferResourceUsage{
mEncodingContext.AcquireRenderPassUsages(), mEncodingContext.AcquireComputePassUsages(),

View File

@ -30,7 +30,7 @@ namespace dawn_native {
public:
CommandEncoder(DeviceBase* device, const CommandEncoderDescriptor* descriptor);
ObjectType GetType() const;
ObjectType GetType() const override;
CommandIterator AcquireCommands();
CommandBufferResourceUsage AcquireResourceUsages();
@ -79,6 +79,7 @@ namespace dawn_native {
CommandBufferBase* APIFinish(const CommandBufferDescriptor* descriptor = nullptr);
private:
void DestroyImpl() override;
ResultOrError<Ref<CommandBufferBase>> FinishInternal(
const CommandBufferDescriptor* descriptor);

View File

@ -105,15 +105,15 @@ namespace dawn_native {
ComputePassEncoder::ComputePassEncoder(DeviceBase* device,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext)
: ProgrammablePassEncoder(device, encodingContext), mCommandEncoder(commandEncoder) {
: ProgrammableEncoder(device, encodingContext), mCommandEncoder(commandEncoder) {
TrackInDevice();
}
ComputePassEncoder::ComputePassEncoder(DeviceBase* device,
CommandEncoder* commandEncoder,
EncodingContext* encodingContext,
ErrorTag errorTag)
: ProgrammablePassEncoder(device, encodingContext, errorTag),
mCommandEncoder(commandEncoder) {
: ProgrammableEncoder(device, encodingContext, errorTag), mCommandEncoder(commandEncoder) {
}
ComputePassEncoder* ComputePassEncoder::MakeError(DeviceBase* device,
@ -122,6 +122,13 @@ namespace dawn_native {
return new ComputePassEncoder(device, commandEncoder, encodingContext, ObjectBase::kError);
}
void ComputePassEncoder::DestroyImpl() {
ApiObjectBase::DestroyImpl();
// Ensure that the pass has exited. This is done for passes only since validation requires
// they exit before destruction while bundles do not.
mEncodingContext->EnsurePassExited(this);
}
ObjectType ComputePassEncoder::GetType() const {
return ObjectType::ComputePassEncoder;
}

View File

@ -19,13 +19,13 @@
#include "dawn_native/Error.h"
#include "dawn_native/Forward.h"
#include "dawn_native/PassResourceUsageTracker.h"
#include "dawn_native/ProgrammablePassEncoder.h"
#include "dawn_native/ProgrammableEncoder.h"
namespace dawn_native {
class SyncScopeUsageTracker;
class ComputePassEncoder final : public ProgrammablePassEncoder {
class ComputePassEncoder final : public ProgrammableEncoder {
public:
ComputePassEncoder(DeviceBase* device,
CommandEncoder* commandEncoder,
@ -62,6 +62,8 @@ namespace dawn_native {
ErrorTag errorTag);
private:
void DestroyImpl() override;
ResultOrError<std::pair<Ref<BufferBase>, uint64_t>> ValidateIndirectDispatch(
BufferBase* indirectBuffer,
uint64_t indirectOffset);

View File

@ -266,10 +266,19 @@ namespace dawn_native {
// a ref to A, then B depends on A. We therefore try to destroy B before destroying A. Note
// that this only considers the immediate frontend dependencies, while backend objects could
// add complications and extra dependencies.
// TODO(dawn/628) Add types into the array as they are implemented.
//
// Note that AttachmentState is not an ApiObject so it cannot be eagerly destroyed. However,
// since AttachmentStates are cached by the device, objects that hold references to
// AttachmentStates should make sure to un-ref them in their Destroy operation so that we
// can destroy the frontend cache.
// clang-format off
static constexpr std::array<ObjectType, 14> kObjectTypeDependencyOrder = {
static constexpr std::array<ObjectType, 19> kObjectTypeDependencyOrder = {
ObjectType::ComputePassEncoder,
ObjectType::RenderPassEncoder,
ObjectType::RenderBundleEncoder,
ObjectType::RenderBundle,
ObjectType::CommandEncoder,
ObjectType::CommandBuffer,
ObjectType::RenderPipeline,
ObjectType::ComputePipeline,

View File

@ -29,9 +29,23 @@ namespace dawn_native {
}
EncodingContext::~EncodingContext() {
Destroy();
}
void EncodingContext::Destroy() {
if (mDestroyed) {
return;
}
if (!mWereCommandsAcquired) {
FreeCommands(GetIterator());
}
// If we weren't already finished, then we want to handle an error here so that any calls
// to Finish after Destroy will return a meaningful error.
if (!IsFinished()) {
HandleError(DAWN_FORMAT_VALIDATION_ERROR("Destroyed encoder cannot be finished."));
}
mDestroyed = true;
mCurrentEncoder = nullptr;
}
CommandIterator EncodingContext::AcquireCommands() {

View File

@ -37,6 +37,10 @@ namespace dawn_native {
EncodingContext(DeviceBase* device, const ApiObjectBase* initialEncoder);
~EncodingContext();
// Marks the encoding context as destroyed so that any future encodes will fail, and all
// encoded commands are released.
void Destroy();
CommandIterator AcquireCommands();
CommandIterator* GetIterator();
@ -75,7 +79,10 @@ namespace dawn_native {
inline bool CheckCurrentEncoder(const ApiObjectBase* encoder) {
if (DAWN_UNLIKELY(encoder != mCurrentEncoder)) {
if (mCurrentEncoder != mTopLevelEncoder) {
if (mDestroyed) {
HandleError(
DAWN_FORMAT_VALIDATION_ERROR("Recording in a destroyed %s.", encoder));
} else if (mCurrentEncoder != mTopLevelEncoder) {
// The top level encoder was used when a pass encoder was current.
HandleError(DAWN_FORMAT_VALIDATION_ERROR(
"Command cannot be recorded while %s is active.", mCurrentEncoder));
@ -164,6 +171,7 @@ namespace dawn_native {
CommandIterator mIterator;
bool mWasMovedToIterator = false;
bool mWereCommandsAcquired = false;
bool mDestroyed = false;
std::unique_ptr<ErrorData> mError;
std::vector<std::string> mDebugGroupLabels;

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dawn_native/ProgrammablePassEncoder.h"
#include "dawn_native/ProgrammableEncoder.h"
#include "common/BitSetIterator.h"
#include "common/ityp_array.h"
@ -28,40 +28,32 @@
namespace dawn_native {
ProgrammablePassEncoder::ProgrammablePassEncoder(DeviceBase* device,
EncodingContext* encodingContext)
ProgrammableEncoder::ProgrammableEncoder(DeviceBase* device, EncodingContext* encodingContext)
: ApiObjectBase(device, kLabelNotImplemented),
mEncodingContext(encodingContext),
mValidationEnabled(device->IsValidationEnabled()) {
}
ProgrammablePassEncoder::ProgrammablePassEncoder(DeviceBase* device,
EncodingContext* encodingContext,
ErrorTag errorTag)
ProgrammableEncoder::ProgrammableEncoder(DeviceBase* device,
EncodingContext* encodingContext,
ErrorTag errorTag)
: ApiObjectBase(device, errorTag),
mEncodingContext(encodingContext),
mValidationEnabled(device->IsValidationEnabled()) {
}
void ProgrammablePassEncoder::DeleteThis() {
// This must be called prior to the destructor because it may generate an error message
// which calls the virtual RenderPassEncoder->GetType() as part of it's formatting.
mEncodingContext->EnsurePassExited(this);
ApiObjectBase::DeleteThis();
}
bool ProgrammablePassEncoder::IsValidationEnabled() const {
bool ProgrammableEncoder::IsValidationEnabled() const {
return mValidationEnabled;
}
MaybeError ProgrammablePassEncoder::ValidateProgrammableEncoderEnd() const {
MaybeError ProgrammableEncoder::ValidateProgrammableEncoderEnd() const {
DAWN_INVALID_IF(mDebugGroupStackSize != 0,
"PushDebugGroup called %u time(s) without a corresponding PopDebugGroup.",
mDebugGroupStackSize);
return {};
}
void ProgrammablePassEncoder::APIInsertDebugMarker(const char* groupLabel) {
void ProgrammableEncoder::APIInsertDebugMarker(const char* groupLabel) {
mEncodingContext->TryEncode(
this,
[&](CommandAllocator* allocator) -> MaybeError {
@ -77,7 +69,7 @@ namespace dawn_native {
"encoding %s.InsertDebugMarker(\"%s\").", this, groupLabel);
}
void ProgrammablePassEncoder::APIPopDebugGroup() {
void ProgrammableEncoder::APIPopDebugGroup() {
mEncodingContext->TryEncode(
this,
[&](CommandAllocator* allocator) -> MaybeError {
@ -95,7 +87,7 @@ namespace dawn_native {
"encoding %s.PopDebugGroup().", this);
}
void ProgrammablePassEncoder::APIPushDebugGroup(const char* groupLabel) {
void ProgrammableEncoder::APIPushDebugGroup(const char* groupLabel) {
mEncodingContext->TryEncode(
this,
[&](CommandAllocator* allocator) -> MaybeError {
@ -114,11 +106,10 @@ namespace dawn_native {
"encoding %s.PushDebugGroup(\"%s\").", this, groupLabel);
}
MaybeError ProgrammablePassEncoder::ValidateSetBindGroup(
BindGroupIndex index,
BindGroupBase* group,
uint32_t dynamicOffsetCountIn,
const uint32_t* dynamicOffsetsIn) const {
MaybeError ProgrammableEncoder::ValidateSetBindGroup(BindGroupIndex index,
BindGroupBase* group,
uint32_t dynamicOffsetCountIn,
const uint32_t* dynamicOffsetsIn) const {
DAWN_TRY(GetDevice()->ValidateObject(group));
DAWN_INVALID_IF(index >= kMaxBindGroupsTyped,
@ -192,11 +183,11 @@ namespace dawn_native {
return {};
}
void ProgrammablePassEncoder::RecordSetBindGroup(CommandAllocator* allocator,
BindGroupIndex index,
BindGroupBase* group,
uint32_t dynamicOffsetCount,
const uint32_t* dynamicOffsets) const {
void ProgrammableEncoder::RecordSetBindGroup(CommandAllocator* allocator,
BindGroupIndex index,
BindGroupBase* group,
uint32_t dynamicOffsetCount,
const uint32_t* dynamicOffsets) const {
SetBindGroupCmd* cmd = allocator->Allocate<SetBindGroupCmd>(Command::SetBindGroup);
cmd->index = index;
cmd->group = group;

View File

@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DAWNNATIVE_PROGRAMMABLEPASSENCODER_H_
#define DAWNNATIVE_PROGRAMMABLEPASSENCODER_H_
#ifndef DAWNNATIVE_PROGRAMMABLEENCODER_H_
#define DAWNNATIVE_PROGRAMMABLEENCODER_H_
#include "dawn_native/CommandEncoder.h"
#include "dawn_native/Error.h"
@ -27,10 +27,10 @@ namespace dawn_native {
class DeviceBase;
// Base class for shared functionality between ComputePassEncoder and RenderPassEncoder.
class ProgrammablePassEncoder : public ApiObjectBase {
// Base class for shared functionality between programmable encoders.
class ProgrammableEncoder : public ApiObjectBase {
public:
ProgrammablePassEncoder(DeviceBase* device, EncodingContext* encodingContext);
ProgrammableEncoder(DeviceBase* device, EncodingContext* encodingContext);
void APIInsertDebugMarker(const char* groupLabel);
void APIPopDebugGroup();
@ -40,8 +40,6 @@ namespace dawn_native {
bool IsValidationEnabled() const;
MaybeError ValidateProgrammableEncoderEnd() const;
virtual void DeleteThis() override;
// Compute and render passes do different things on SetBindGroup. These are helper functions
// for the logic they have in common.
MaybeError ValidateSetBindGroup(BindGroupIndex index,
@ -55,9 +53,9 @@ namespace dawn_native {
const uint32_t* dynamicOffsets) const;
// Construct an "error" programmable pass encoder.
ProgrammablePassEncoder(DeviceBase* device,
EncodingContext* encodingContext,
ErrorTag errorTag);
ProgrammableEncoder(DeviceBase* device,
EncodingContext* encodingContext,
ErrorTag errorTag);
EncodingContext* mEncodingContext = nullptr;
@ -69,4 +67,4 @@ namespace dawn_native {
} // namespace dawn_native
#endif // DAWNNATIVE_PROGRAMMABLEPASSENCODER_H_
#endif // DAWNNATIVE_PROGRAMMABLEENCODER_H_

View File

@ -36,10 +36,15 @@ namespace dawn_native {
mDepthReadOnly(depthReadOnly),
mStencilReadOnly(stencilReadOnly),
mResourceUsage(std::move(resourceUsage)) {
TrackInDevice();
}
RenderBundleBase::~RenderBundleBase() {
void RenderBundleBase::DestroyImpl() {
FreeCommands(&mCommands);
// Remove reference to the attachment state so that we don't have lingering references to
// it preventing it from being uncached in the device.
mAttachmentState = nullptr;
}
// static

View File

@ -33,7 +33,7 @@ namespace dawn_native {
struct RenderBundleDescriptor;
class RenderBundleEncoder;
class RenderBundleBase : public ApiObjectBase {
class RenderBundleBase final : public ApiObjectBase {
public:
RenderBundleBase(RenderBundleEncoder* encoder,
const RenderBundleDescriptor* descriptor,
@ -55,12 +55,11 @@ namespace dawn_native {
const RenderPassResourceUsage& GetResourceUsage() const;
const IndirectDrawMetadata& GetIndirectDrawMetadata();
protected:
~RenderBundleBase() override;
private:
RenderBundleBase(DeviceBase* device, ErrorTag errorTag);
void DestroyImpl() override;
CommandIterator mCommands;
IndirectDrawMetadata mIndirectDrawMetadata;
Ref<AttachmentState> mAttachmentState;

View File

@ -93,6 +93,7 @@ namespace dawn_native {
descriptor->depthReadOnly,
descriptor->stencilReadOnly),
mBundleEncodingContext(device, this) {
TrackInDevice();
}
RenderBundleEncoder::RenderBundleEncoder(DeviceBase* device, ErrorTag errorTag)
@ -100,6 +101,11 @@ namespace dawn_native {
mBundleEncodingContext(device, this) {
}
void RenderBundleEncoder::DestroyImpl() {
RenderEncoderBase::DestroyImpl();
mBundleEncodingContext.Destroy();
}
// static
Ref<RenderBundleEncoder> RenderBundleEncoder::Create(
DeviceBase* device,

View File

@ -43,6 +43,8 @@ namespace dawn_native {
RenderBundleEncoder(DeviceBase* device, const RenderBundleEncoderDescriptor* descriptor);
RenderBundleEncoder(DeviceBase* device, ErrorTag errorTag);
void DestroyImpl() override;
ResultOrError<RenderBundleBase*> FinishImpl(const RenderBundleDescriptor* descriptor);
MaybeError ValidateFinish(const RenderPassResourceUsage& usages) const;

View File

@ -34,7 +34,7 @@ namespace dawn_native {
Ref<AttachmentState> attachmentState,
bool depthReadOnly,
bool stencilReadOnly)
: ProgrammablePassEncoder(device, encodingContext),
: ProgrammableEncoder(device, encodingContext),
mIndirectDrawMetadata(device->GetLimits()),
mAttachmentState(std::move(attachmentState)),
mDisableBaseVertex(device->IsToggleEnabled(Toggle::DisableBaseVertex)),
@ -46,12 +46,18 @@ namespace dawn_native {
RenderEncoderBase::RenderEncoderBase(DeviceBase* device,
EncodingContext* encodingContext,
ErrorTag errorTag)
: ProgrammablePassEncoder(device, encodingContext, errorTag),
: ProgrammableEncoder(device, encodingContext, errorTag),
mIndirectDrawMetadata(device->GetLimits()),
mDisableBaseVertex(device->IsToggleEnabled(Toggle::DisableBaseVertex)),
mDisableBaseInstance(device->IsToggleEnabled(Toggle::DisableBaseInstance)) {
}
void RenderEncoderBase::DestroyImpl() {
// Remove reference to the attachment state so that we don't have lingering references to
// it preventing it from being uncached in the device.
mAttachmentState = nullptr;
}
const AttachmentState* RenderEncoderBase::GetAttachmentState() const {
ASSERT(!IsError());
ASSERT(mAttachmentState != nullptr);

View File

@ -20,11 +20,11 @@
#include "dawn_native/Error.h"
#include "dawn_native/IndirectDrawMetadata.h"
#include "dawn_native/PassResourceUsageTracker.h"
#include "dawn_native/ProgrammablePassEncoder.h"
#include "dawn_native/ProgrammableEncoder.h"
namespace dawn_native {
class RenderEncoderBase : public ProgrammablePassEncoder {
class RenderEncoderBase : public ProgrammableEncoder {
public:
RenderEncoderBase(DeviceBase* device,
EncodingContext* encodingContext,
@ -67,6 +67,8 @@ namespace dawn_native {
// Construct an "error" render encoder base.
RenderEncoderBase(DeviceBase* device, EncodingContext* encodingContext, ErrorTag errorTag);
void DestroyImpl() override;
CommandBufferStateTracker mCommandBufferState;
RenderPassResourceUsageTracker mUsageTracker;
IndirectDrawMetadata mIndirectDrawMetadata;

View File

@ -68,6 +68,7 @@ namespace dawn_native {
mRenderTargetHeight(renderTargetHeight),
mOcclusionQuerySet(occlusionQuerySet) {
mUsageTracker = std::move(usageTracker);
TrackInDevice();
}
RenderPassEncoder::RenderPassEncoder(DeviceBase* device,
@ -83,6 +84,13 @@ namespace dawn_native {
return new RenderPassEncoder(device, commandEncoder, encodingContext, ObjectBase::kError);
}
void RenderPassEncoder::DestroyImpl() {
RenderEncoderBase::DestroyImpl();
// Ensure that the pass has exited. This is done for passes only since validation requires
// they exit before destruction while bundles do not.
mEncodingContext->EnsurePassExited(this);
}
ObjectType RenderPassEncoder::GetType() const {
return ObjectType::RenderPassEncoder;
}

View File

@ -67,6 +67,8 @@ namespace dawn_native {
ErrorTag errorTag);
private:
void DestroyImpl() override;
void TrackQueryAvailability(QuerySetBase* querySet, uint32_t queryIndex);
// For render and compute passes, the encoding context is borrowed from the command encoder.

View File

@ -689,6 +689,10 @@ namespace dawn_native {
// Do not uncache the actual cached object if we are a blueprint or already destroyed.
GetDevice()->UncacheRenderPipeline(this);
}
// Remove reference to the attachment state so that we don't have lingering references to
// it preventing it from being uncached in the device.
mAttachmentState = nullptr;
return wasDestroyed;
}

View File

@ -12,10 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gmock/gmock.h>
#include "dawn_native/CommandEncoder.h"
#include "tests/unittests/validation/ValidationTest.h"
#include "utils/WGPUHelpers.h"
using ::testing::HasSubstr;
class CommandBufferValidationTest : public ValidationTest {};
// Test for an empty command buffer
@ -39,7 +45,9 @@ TEST_F(CommandBufferValidationTest, EndedMidRenderPass) {
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
ASSERT_DEVICE_ERROR(encoder.Finish());
ASSERT_DEVICE_ERROR(
encoder.Finish(),
HasSubstr("Command buffer recording ended before [RenderPassEncoder] was ended."));
}
// Error case, command buffer ended mid-pass. Trying to use encoders after Finish
@ -47,8 +55,12 @@ TEST_F(CommandBufferValidationTest, EndedMidRenderPass) {
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
ASSERT_DEVICE_ERROR(encoder.Finish());
ASSERT_DEVICE_ERROR(pass.EndPass());
ASSERT_DEVICE_ERROR(
encoder.Finish(),
HasSubstr("Command buffer recording ended before [RenderPassEncoder] was ended."));
ASSERT_DEVICE_ERROR(
pass.EndPass(),
HasSubstr("Recording in an error or already ended [RenderPassEncoder]."));
}
}
@ -66,7 +78,9 @@ TEST_F(CommandBufferValidationTest, EndedMidComputePass) {
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
ASSERT_DEVICE_ERROR(encoder.Finish());
ASSERT_DEVICE_ERROR(
encoder.Finish(),
HasSubstr("Command buffer recording ended before [ComputePassEncoder] was ended."));
}
// Error case, command buffer ended mid-pass. Trying to use encoders after Finish
@ -74,8 +88,12 @@ TEST_F(CommandBufferValidationTest, EndedMidComputePass) {
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
ASSERT_DEVICE_ERROR(encoder.Finish());
ASSERT_DEVICE_ERROR(pass.EndPass());
ASSERT_DEVICE_ERROR(
encoder.Finish(),
HasSubstr("Command buffer recording ended before [ComputePassEncoder] was ended."));
ASSERT_DEVICE_ERROR(
pass.EndPass(),
HasSubstr("Recording in an error or already ended [ComputePassEncoder]."));
}
}
@ -97,7 +115,9 @@ TEST_F(CommandBufferValidationTest, RenderPassEndedTwice) {
wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
pass.EndPass();
pass.EndPass();
ASSERT_DEVICE_ERROR(encoder.Finish());
ASSERT_DEVICE_ERROR(
encoder.Finish(),
HasSubstr("Recording in an error or already ended [RenderPassEncoder]."));
}
}
@ -117,7 +137,9 @@ TEST_F(CommandBufferValidationTest, ComputePassEndedTwice) {
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.EndPass();
pass.EndPass();
ASSERT_DEVICE_ERROR(encoder.Finish());
ASSERT_DEVICE_ERROR(
encoder.Finish(),
HasSubstr("Recording in an error or already ended [ComputePassEncoder]."));
}
}
@ -223,14 +245,18 @@ TEST_F(CommandBufferValidationTest, PassDereferenced) {
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
encoder.BeginRenderPass(&dummyRenderPass);
ASSERT_DEVICE_ERROR(encoder.Finish());
ASSERT_DEVICE_ERROR(
encoder.Finish(),
HasSubstr("Command buffer recording ended before [RenderPassEncoder] was ended."));
}
// Error case, no reference is kept to a compute pass.
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
encoder.BeginComputePass();
ASSERT_DEVICE_ERROR(encoder.Finish());
ASSERT_DEVICE_ERROR(
encoder.Finish(),
HasSubstr("Command buffer recording ended before [ComputePassEncoder] was ended."));
}
// Error case, beginning a new pass after failing to end a de-referenced pass.
@ -239,21 +265,25 @@ TEST_F(CommandBufferValidationTest, PassDereferenced) {
encoder.BeginRenderPass(&dummyRenderPass);
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.EndPass();
ASSERT_DEVICE_ERROR(encoder.Finish());
ASSERT_DEVICE_ERROR(
encoder.Finish(),
HasSubstr("Command buffer recording ended before [RenderPassEncoder] was ended."));
}
// Error case, deleting the pass after finishing the commend encoder shouldn't generate an
// Error case, deleting the pass after finishing the command encoder shouldn't generate an
// uncaptured error.
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
ASSERT_DEVICE_ERROR(encoder.Finish());
ASSERT_DEVICE_ERROR(
encoder.Finish(),
HasSubstr("Command buffer recording ended before [ComputePassEncoder] was ended."));
pass = nullptr;
}
// Valid case, command encoder is never finished so the de-referenced pass shouldn't generate an
// uncaptured error.
// Valid case, command encoder is never finished so the de-referenced pass shouldn't
// generate an uncaptured error.
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
encoder.BeginComputePass();
@ -264,5 +294,80 @@ TEST_F(CommandBufferValidationTest, PassDereferenced) {
TEST_F(CommandBufferValidationTest, InjectValidationError) {
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
encoder.InjectValidationError("my error");
ASSERT_DEVICE_ERROR(encoder.Finish());
ASSERT_DEVICE_ERROR(encoder.Finish(), HasSubstr("my error"));
}
TEST_F(CommandBufferValidationTest, DestroyEncoder) {
// Skip these tests if we are using wire because the destroy functionality is not exposed
// and needs to use a cast to call manually. We cannot test this in the wire case since the
// only way to trigger the destroy call is by losing all references which means we cannot
// call finish.
DAWN_SKIP_TEST_IF(UsesWire());
DummyRenderPass dummyRenderPass(device);
// Control case, command buffer ended after the pass is ended.
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
pass.EndPass();
encoder.Finish();
}
// Destroyed encoder with encoded commands should emit error on finish.
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
pass.EndPass();
reinterpret_cast<dawn_native::CommandEncoder*>(encoder.Get())->Destroy();
ASSERT_DEVICE_ERROR(encoder.Finish(), HasSubstr("Destroyed encoder cannot be finished."));
}
// Destroyed encoder with encoded commands shouldn't emit an error if never finished.
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
pass.EndPass();
reinterpret_cast<dawn_native::CommandEncoder*>(encoder.Get())->Destroy();
}
// Destroyed encoder should allow encoding, and emit error on finish.
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
reinterpret_cast<dawn_native::CommandEncoder*>(encoder.Get())->Destroy();
wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
pass.EndPass();
ASSERT_DEVICE_ERROR(encoder.Finish(), HasSubstr("Destroyed encoder cannot be finished."));
}
// Destroyed encoder should allow encoding and shouldn't emit an error if never finished.
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
reinterpret_cast<dawn_native::CommandEncoder*>(encoder.Get())->Destroy();
wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
pass.EndPass();
}
// Destroying a finished encoder should not emit any errors.
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&dummyRenderPass);
pass.EndPass();
encoder.Finish();
reinterpret_cast<dawn_native::CommandEncoder*>(encoder.Get())->Destroy();
}
// Destroying an encoder twice should not emit any errors.
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
reinterpret_cast<dawn_native::CommandEncoder*>(encoder.Get())->Destroy();
reinterpret_cast<dawn_native::CommandEncoder*>(encoder.Get())->Destroy();
}
// Destroying an encoder twice and then calling finish should fail.
{
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
reinterpret_cast<dawn_native::CommandEncoder*>(encoder.Get())->Destroy();
reinterpret_cast<dawn_native::CommandEncoder*>(encoder.Get())->Destroy();
ASSERT_DEVICE_ERROR(encoder.Finish(), HasSubstr("Destroyed encoder cannot be finished."));
}
}

View File

@ -129,12 +129,19 @@ void ValidationTest::TearDown() {
}
}
void ValidationTest::StartExpectDeviceError() {
void ValidationTest::StartExpectDeviceError(testing::Matcher<std::string> errorMatcher) {
mExpectError = true;
mError = false;
mErrorMatcher = errorMatcher;
}
void ValidationTest::StartExpectDeviceError() {
StartExpectDeviceError(testing::_);
}
bool ValidationTest::EndExpectDeviceError() {
mExpectError = false;
mErrorMatcher = testing::_;
return mError;
}
std::string ValidationTest::GetLastDeviceErrorMessage() const {
@ -210,6 +217,9 @@ void ValidationTest::OnDeviceError(WGPUErrorType type, const char* message, void
ASSERT_TRUE(self->mExpectError) << "Got unexpected device error: " << message;
ASSERT_FALSE(self->mError) << "Got two errors in expect block";
if (self->mExpectError) {
ASSERT_THAT(message, self->mErrorMatcher);
}
self->mError = true;
}

View File

@ -19,10 +19,30 @@
#include "dawn/webgpu_cpp.h"
#include "dawn_native/DawnNative.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#define ASSERT_DEVICE_ERROR(statement) \
FlushWire(); \
// Argument helpers to allow macro overriding.
#define UNIMPLEMENTED_MACRO(...) UNREACHABLE()
#define GET_3RD_ARG_HELPER_(_1, _2, NAME, ...) NAME
#define GET_3RD_ARG_(args) GET_3RD_ARG_HELPER_ args
// Overloaded to allow further validation of the error messages given an error is expected.
// Especially useful to verify that the expected errors are occuring, not just any error.
//
// Example usages:
// 1 Argument Case:
// ASSERT_DEVICE_ERROR(FunctionThatExpectsError());
//
// 2 Argument Case:
// ASSERT_DEVICE_ERROR(FunctionThatHasLongError(), HasSubstr("partial match"))
// ASSERT_DEVICE_ERROR(FunctionThatHasShortError(), Eq("exact match"));
#define ASSERT_DEVICE_ERROR(...) \
GET_3RD_ARG_((__VA_ARGS__, ASSERT_DEVICE_ERROR_IMPL_2_, ASSERT_DEVICE_ERROR_IMPL_1_, \
UNIMPLEMENTED_MACRO)) \
(__VA_ARGS__)
#define ASSERT_DEVICE_ERROR_IMPL_1_(statement) \
StartExpectDeviceError(); \
statement; \
FlushWire(); \
@ -32,6 +52,16 @@
do { \
} while (0)
#define ASSERT_DEVICE_ERROR_IMPL_2_(statement, matcher) \
StartExpectDeviceError(matcher); \
statement; \
FlushWire(); \
if (!EndExpectDeviceError()) { \
FAIL() << "Expected device error in:\n " << #statement; \
} \
do { \
} while (0)
// Skip a test when the given condition is satisfied.
#define DAWN_SKIP_TEST_IF(condition) \
do { \
@ -69,6 +99,7 @@ class ValidationTest : public testing::Test {
void SetUp() override;
void TearDown() override;
void StartExpectDeviceError(testing::Matcher<std::string> errorMatcher);
void StartExpectDeviceError();
bool EndExpectDeviceError();
std::string GetLastDeviceErrorMessage() const;
@ -118,6 +149,7 @@ class ValidationTest : public testing::Test {
std::string mDeviceErrorMessage;
bool mExpectError = false;
bool mError = false;
testing::Matcher<std::string> mErrorMatcher;
};
#endif // TESTS_UNITTESTS_VALIDATIONTEST_H_