From 0ecc48ecb7a1353b9e1a3cd9d3a125f6a770477c Mon Sep 17 00:00:00 2001 From: Natasha Lee Date: Wed, 15 Jan 2020 19:02:13 +0000 Subject: [PATCH] Handle DeviceLost error Handle DeviceLostCallback once DeviceLost error occurs. Disallow any other commands or actions on device to happen after device has been lost. Bug: dawn:68 Change-Id: Icbbbadf278cae5e6213050d00439118789c863dc Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/12801 Commit-Queue: Natasha Lee Reviewed-by: Austin Eng --- BUILD.gn | 2 + dawn.json | 3 + src/dawn_native/Buffer.cpp | 4 + src/dawn_native/Device.cpp | 59 +++++++- src/dawn_native/Device.h | 15 +- src/dawn_native/Error.cpp | 27 ++++ src/dawn_native/Error.h | 3 + src/dawn_native/ErrorData.h | 1 - src/dawn_native/Queue.cpp | 5 + src/dawn_native/Texture.cpp | 1 + src/dawn_native/d3d12/DeviceD3D12.cpp | 2 + src/dawn_native/metal/DeviceMTL.mm | 2 + src/dawn_native/null/DeviceNull.cpp | 2 + src/dawn_native/opengl/DeviceGL.cpp | 2 + src/dawn_native/vulkan/DeviceVk.cpp | 6 +- src/tests/DawnTest.cpp | 5 + src/tests/DawnTest.h | 1 + src/tests/end2end/DeviceLostTests.cpp | 192 ++++++++++++++++++++++++++ 18 files changed, 322 insertions(+), 10 deletions(-) create mode 100644 src/dawn_native/Error.cpp create mode 100644 src/tests/end2end/DeviceLostTests.cpp diff --git a/BUILD.gn b/BUILD.gn index b6201ce845..0bfede2457 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -204,6 +204,7 @@ source_set("libdawn_native_sources") { "src/dawn_native/DynamicUploader.h", "src/dawn_native/EncodingContext.cpp", "src/dawn_native/EncodingContext.h", + "src/dawn_native/Error.cpp", "src/dawn_native/Error.h", "src/dawn_native/ErrorData.cpp", "src/dawn_native/ErrorData.h", @@ -916,6 +917,7 @@ source_set("dawn_end2end_tests_sources") { "src/tests/end2end/DebugMarkerTests.cpp", "src/tests/end2end/DepthStencilStateTests.cpp", "src/tests/end2end/DestroyTests.cpp", + "src/tests/end2end/DeviceLostTests.cpp", "src/tests/end2end/DrawIndexedIndirectTests.cpp", "src/tests/end2end/DrawIndexedTests.cpp", "src/tests/end2end/DrawIndirectTests.cpp", diff --git a/dawn.json b/dawn.json index c8a318aea5..5a6fa82689 100644 --- a/dawn.json +++ b/dawn.json @@ -611,6 +611,9 @@ ], "TODO": "enga@: Make this a Dawn extension" }, + { + "name": "lose for testing" + }, { "name": "tick" }, diff --git a/src/dawn_native/Buffer.cpp b/src/dawn_native/Buffer.cpp index c57d23a129..8e72bd731b 100644 --- a/src/dawn_native/Buffer.cpp +++ b/src/dawn_native/Buffer.cpp @@ -17,6 +17,7 @@ #include "common/Assert.h" #include "dawn_native/Device.h" #include "dawn_native/DynamicUploader.h" +#include "dawn_native/ErrorData.h" #include "dawn_native/ValidationUtils_autogen.h" #include @@ -350,6 +351,7 @@ namespace dawn_native { } MaybeError BufferBase::ValidateSetSubData(uint32_t start, uint32_t count) const { + DAWN_TRY(GetDevice()->ValidateIsAlive()); DAWN_TRY(GetDevice()->ValidateObject(this)); switch (mState) { @@ -388,6 +390,7 @@ namespace dawn_native { } MaybeError BufferBase::ValidateMap(wgpu::BufferUsage requiredUsage) const { + DAWN_TRY(GetDevice()->ValidateIsAlive()); DAWN_TRY(GetDevice()->ValidateObject(this)); switch (mState) { @@ -407,6 +410,7 @@ namespace dawn_native { } MaybeError BufferBase::ValidateUnmap() const { + DAWN_TRY(GetDevice()->ValidateIsAlive()); DAWN_TRY(GetDevice()->ValidateObject(this)); switch (mState) { diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp index f5609dc9c7..d2aab654de 100644 --- a/src/dawn_native/Device.cpp +++ b/src/dawn_native/Device.cpp @@ -95,15 +95,20 @@ namespace dawn_native { } void DeviceBase::BaseDestructor() { - MaybeError err = WaitForIdleForDestruction(); - if (err.IsError()) { - // Assert that errors are device loss so that we can continue with destruction - ASSERT(err.AcquireError()->GetType() == wgpu::ErrorType::DeviceLost); + if (mLossStatus != LossStatus::Alive) { + return; } + // Assert that errors are device loss so that we can continue with destruction + AssertAndIgnoreDeviceLossError(WaitForIdleForDestruction()); Destroy(); + mLossStatus = LossStatus::AlreadyLost; } void DeviceBase::HandleError(wgpu::ErrorType type, const char* message) { + if (type == wgpu::ErrorType::DeviceLost) { + HandleLoss(message); + } + // Still forward device loss to error scope so it can reject them all mCurrentErrorScope->HandleError(type, message); } @@ -165,6 +170,33 @@ namespace dawn_native { return {}; } + MaybeError DeviceBase::ValidateIsAlive() const { + if (DAWN_LIKELY(mLossStatus == LossStatus::Alive)) { + return {}; + } + return DAWN_DEVICE_LOST_ERROR("Device is lost"); + } + + void DeviceBase::HandleLoss(const char* message) { + if (mLossStatus == LossStatus::AlreadyLost) { + return; + } + + Destroy(); + mLossStatus = LossStatus::AlreadyLost; + + if (mDeviceLostCallback) { + mDeviceLostCallback(message, mDeviceLostUserdata); + } + } + + void DeviceBase::LoseForTesting() { + mLossStatus = LossStatus::BeingLost; + // Assert that errors are device loss so that we can continue with destruction + AssertAndIgnoreDeviceLossError(WaitForIdleForDestruction()); + HandleError(wgpu::ErrorType::DeviceLost, "Device lost for testing"); + } + AdapterBase* DeviceBase::GetAdapter() const { return mAdapter; } @@ -563,8 +595,12 @@ namespace dawn_native { // Other Device API methods void DeviceBase::Tick() { - if (ConsumedError(TickImpl())) + if (ConsumedError(ValidateIsAlive())) { return; + } + if (ConsumedError(TickImpl())) { + return; + } { auto deferredResults = std::move(mDeferredCreateBufferMappedAsyncResults); @@ -651,6 +687,7 @@ namespace dawn_native { MaybeError DeviceBase::CreateBindGroupInternal(BindGroupBase** result, const BindGroupDescriptor* descriptor) { + DAWN_TRY(ValidateIsAlive()); if (IsValidationEnabled()) { DAWN_TRY(ValidateBindGroupDescriptor(this, descriptor)); } @@ -661,6 +698,7 @@ namespace dawn_native { MaybeError DeviceBase::CreateBindGroupLayoutInternal( BindGroupLayoutBase** result, const BindGroupLayoutDescriptor* descriptor) { + DAWN_TRY(ValidateIsAlive()); if (IsValidationEnabled()) { DAWN_TRY(ValidateBindGroupLayoutDescriptor(this, descriptor)); } @@ -670,6 +708,7 @@ namespace dawn_native { MaybeError DeviceBase::CreateBufferInternal(BufferBase** result, const BufferDescriptor* descriptor) { + DAWN_TRY(ValidateIsAlive()); if (IsValidationEnabled()) { DAWN_TRY(ValidateBufferDescriptor(this, descriptor)); } @@ -680,6 +719,7 @@ namespace dawn_native { MaybeError DeviceBase::CreateComputePipelineInternal( ComputePipelineBase** result, const ComputePipelineDescriptor* descriptor) { + DAWN_TRY(ValidateIsAlive()); if (IsValidationEnabled()) { DAWN_TRY(ValidateComputePipelineDescriptor(this, descriptor)); } @@ -704,6 +744,7 @@ namespace dawn_native { MaybeError DeviceBase::CreatePipelineLayoutInternal( PipelineLayoutBase** result, const PipelineLayoutDescriptor* descriptor) { + DAWN_TRY(ValidateIsAlive()); if (IsValidationEnabled()) { DAWN_TRY(ValidatePipelineLayoutDescriptor(this, descriptor)); } @@ -712,6 +753,7 @@ namespace dawn_native { } MaybeError DeviceBase::CreateQueueInternal(QueueBase** result) { + DAWN_TRY(ValidateIsAlive()); DAWN_TRY_ASSIGN(*result, CreateQueueImpl()); return {}; } @@ -719,6 +761,7 @@ namespace dawn_native { MaybeError DeviceBase::CreateRenderBundleEncoderInternal( RenderBundleEncoder** result, const RenderBundleEncoderDescriptor* descriptor) { + DAWN_TRY(ValidateIsAlive()); if (IsValidationEnabled()) { DAWN_TRY(ValidateRenderBundleEncoderDescriptor(this, descriptor)); } @@ -729,6 +772,7 @@ namespace dawn_native { MaybeError DeviceBase::CreateRenderPipelineInternal( RenderPipelineBase** result, const RenderPipelineDescriptor* descriptor) { + DAWN_TRY(ValidateIsAlive()); if (IsValidationEnabled()) { DAWN_TRY(ValidateRenderPipelineDescriptor(this, descriptor)); } @@ -761,6 +805,7 @@ namespace dawn_native { MaybeError DeviceBase::CreateSamplerInternal(SamplerBase** result, const SamplerDescriptor* descriptor) { + DAWN_TRY(ValidateIsAlive()); if (IsValidationEnabled()) { DAWN_TRY(ValidateSamplerDescriptor(this, descriptor)); } @@ -770,6 +815,7 @@ namespace dawn_native { MaybeError DeviceBase::CreateShaderModuleInternal(ShaderModuleBase** result, const ShaderModuleDescriptor* descriptor) { + DAWN_TRY(ValidateIsAlive()); if (IsValidationEnabled()) { DAWN_TRY(ValidateShaderModuleDescriptor(this, descriptor)); } @@ -779,6 +825,7 @@ namespace dawn_native { MaybeError DeviceBase::CreateSwapChainInternal(SwapChainBase** result, const SwapChainDescriptor* descriptor) { + DAWN_TRY(ValidateIsAlive()); if (IsValidationEnabled()) { DAWN_TRY(ValidateSwapChainDescriptor(this, descriptor)); } @@ -788,6 +835,7 @@ namespace dawn_native { MaybeError DeviceBase::CreateTextureInternal(TextureBase** result, const TextureDescriptor* descriptor) { + DAWN_TRY(ValidateIsAlive()); if (IsValidationEnabled()) { DAWN_TRY(ValidateTextureDescriptor(this, descriptor)); } @@ -798,6 +846,7 @@ namespace dawn_native { MaybeError DeviceBase::CreateTextureViewInternal(TextureViewBase** result, TextureBase* texture, const TextureViewDescriptor* descriptor) { + DAWN_TRY(ValidateIsAlive()); DAWN_TRY(ValidateObject(texture)); TextureViewDescriptor desc = GetTextureViewDescriptorWithDefaults(texture, descriptor); if (IsValidationEnabled()) { diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h index 9d0947d2b7..6e27c3b620 100644 --- a/src/dawn_native/Device.h +++ b/src/dawn_native/Device.h @@ -29,9 +29,6 @@ #include namespace dawn_native { - - using ErrorCallback = void (*)(const char* errorMessage, void* userData); - class AdapterBase; class AttachmentState; class AttachmentStateBlueprint; @@ -167,6 +164,9 @@ namespace dawn_native { void SetUncapturedErrorCallback(wgpu::ErrorCallback callback, void* userdata); void PushErrorScope(wgpu::ErrorFilter filter); bool PopErrorScope(wgpu::ErrorCallback callback, void* userdata); + + MaybeError ValidateIsAlive() const; + ErrorScope* GetCurrentErrorScope(); void Reference(); @@ -189,6 +189,7 @@ namespace dawn_native { bool IsValidationEnabled() const; size_t GetLazyClearCountForTesting(); void IncrementLazyClearCountForTesting(); + void LoseForTesting(); protected: void SetToggle(Toggle toggle, bool isEnabled); @@ -196,6 +197,13 @@ namespace dawn_native { void BaseDestructor(); std::unique_ptr mDynamicUploader; + // LossStatus::Alive means the device is alive and can be used normally. + // LossStatus::BeingLost means the device is in the process of being lost and should not + // accept any new commands. + // LossStatus::AlreadyLost means the device has been lost and can no longer be used, + // all resources have been freed. + enum class LossStatus { Alive, BeingLost, AlreadyLost }; + LossStatus mLossStatus = LossStatus::Alive; private: virtual ResultOrError CreateBindGroupImpl( @@ -263,6 +271,7 @@ namespace dawn_native { // resources. virtual MaybeError WaitForIdleForDestruction() = 0; + void HandleLoss(const char* message); wgpu::DeviceLostCallback mDeviceLostCallback = nullptr; void* mDeviceLostUserdata; diff --git a/src/dawn_native/Error.cpp b/src/dawn_native/Error.cpp new file mode 100644 index 0000000000..d1ca233ae7 --- /dev/null +++ b/src/dawn_native/Error.cpp @@ -0,0 +1,27 @@ +// Copyright 2018 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dawn_native/Error.h" + +#include "dawn_native/ErrorData.h" +#include "dawn_native/dawn_platform.h" + +namespace dawn_native { + void AssertAndIgnoreDeviceLossError(MaybeError maybeError) { + if (maybeError.IsError()) { + std::unique_ptr errorData = maybeError.AcquireError(); + ASSERT(errorData->GetType() == wgpu::ErrorType::DeviceLost); + } + } +} // namespace dawn_native \ No newline at end of file diff --git a/src/dawn_native/Error.h b/src/dawn_native/Error.h index 7f9dbcbc1e..87ac5405eb 100644 --- a/src/dawn_native/Error.h +++ b/src/dawn_native/Error.h @@ -81,6 +81,9 @@ namespace dawn_native { for (;;) \ break + // Assert that errors are device loss so that we can continue with destruction + void AssertAndIgnoreDeviceLossError(MaybeError maybeError); + } // namespace dawn_native #endif // DAWNNATIVE_ERROR_H_ diff --git a/src/dawn_native/ErrorData.h b/src/dawn_native/ErrorData.h index 27004de2f0..5d74d36327 100644 --- a/src/dawn_native/ErrorData.h +++ b/src/dawn_native/ErrorData.h @@ -28,7 +28,6 @@ namespace dawn { } namespace dawn_native { - enum class InternalErrorType : uint32_t; class ErrorData { diff --git a/src/dawn_native/Queue.cpp b/src/dawn_native/Queue.cpp index 0fbcdc7d05..803f299e39 100644 --- a/src/dawn_native/Queue.cpp +++ b/src/dawn_native/Queue.cpp @@ -34,6 +34,11 @@ namespace dawn_native { void QueueBase::Submit(uint32_t commandCount, CommandBufferBase* const* commands) { DeviceBase* device = GetDevice(); + if (device->ConsumedError(device->ValidateIsAlive())) { + // If device is lost, don't let any commands be submitted + return; + } + TRACE_EVENT0(device->GetPlatform(), General, "Queue::Submit"); if (device->IsValidationEnabled() && device->ConsumedError(ValidateSubmit(commandCount, commands))) { diff --git a/src/dawn_native/Texture.cpp b/src/dawn_native/Texture.cpp index 7b264c1ec1..5ec3b1f2c5 100644 --- a/src/dawn_native/Texture.cpp +++ b/src/dawn_native/Texture.cpp @@ -503,6 +503,7 @@ namespace dawn_native { } MaybeError TextureBase::ValidateDestroy() const { + DAWN_TRY(GetDevice()->ValidateIsAlive()); DAWN_TRY(GetDevice()->ValidateObject(this)); return {}; } diff --git a/src/dawn_native/d3d12/DeviceD3D12.cpp b/src/dawn_native/d3d12/DeviceD3D12.cpp index 352fe00495..6b76544e0a 100644 --- a/src/dawn_native/d3d12/DeviceD3D12.cpp +++ b/src/dawn_native/d3d12/DeviceD3D12.cpp @@ -403,6 +403,8 @@ namespace dawn_native { namespace d3d12 { } void Device::Destroy() { + ASSERT(mLossStatus != LossStatus::AlreadyLost); + // Immediately forget about all pending commands mPendingCommands.Release(); diff --git a/src/dawn_native/metal/DeviceMTL.mm b/src/dawn_native/metal/DeviceMTL.mm index 77faa4e377..6fa63d32e3 100644 --- a/src/dawn_native/metal/DeviceMTL.mm +++ b/src/dawn_native/metal/DeviceMTL.mm @@ -277,6 +277,8 @@ namespace dawn_native { namespace metal { } void Device::Destroy() { + ASSERT(mLossStatus != LossStatus::AlreadyLost); + [mCommandContext.AcquireCommands() release]; mMapTracker = nullptr; diff --git a/src/dawn_native/null/DeviceNull.cpp b/src/dawn_native/null/DeviceNull.cpp index be074b6b77..c93de3f882 100644 --- a/src/dawn_native/null/DeviceNull.cpp +++ b/src/dawn_native/null/DeviceNull.cpp @@ -166,6 +166,8 @@ namespace dawn_native { namespace null { } void Device::Destroy() { + ASSERT(mLossStatus != LossStatus::AlreadyLost); + mDynamicUploader = nullptr; mPendingOperations.clear(); diff --git a/src/dawn_native/opengl/DeviceGL.cpp b/src/dawn_native/opengl/DeviceGL.cpp index d44bd525f3..b794ed77e9 100644 --- a/src/dawn_native/opengl/DeviceGL.cpp +++ b/src/dawn_native/opengl/DeviceGL.cpp @@ -162,6 +162,8 @@ namespace dawn_native { namespace opengl { } void Device::Destroy() { + ASSERT(mLossStatus != LossStatus::AlreadyLost); + // Some operations might have been started since the last submit and waiting // on a serial that doesn't have a corresponding fence enqueued. Force all // operations to look as if they were completed (because they were). diff --git a/src/dawn_native/vulkan/DeviceVk.cpp b/src/dawn_native/vulkan/DeviceVk.cpp index 487e7efac5..6b44d49974 100644 --- a/src/dawn_native/vulkan/DeviceVk.cpp +++ b/src/dawn_native/vulkan/DeviceVk.cpp @@ -733,6 +733,8 @@ namespace dawn_native { namespace vulkan { } void Device::Destroy() { + ASSERT(mLossStatus != LossStatus::AlreadyLost); + // Immediately tag the recording context as unused so we don't try to submit it in Tick. mRecordingContext.used = false; fn.DestroyCommandPool(mVkDevice, mRecordingContext.commandPool, nullptr); @@ -741,7 +743,9 @@ namespace dawn_native { namespace vulkan { // on a serial that doesn't have a corresponding fence enqueued. Force all // operations to look as if they were completed (because they were). mCompletedSerial = mLastSubmittedSerial + 1; - Tick(); + + // Assert that errors are device loss so that we can continue with destruction + AssertAndIgnoreDeviceLossError(TickImpl()); ASSERT(mCommandsInFlight.Empty()); for (const CommandPoolAndBuffer& commands : mUnusedCommands) { diff --git a/src/tests/DawnTest.cpp b/src/tests/DawnTest.cpp index da7e575427..5f5dba4af9 100644 --- a/src/tests/DawnTest.cpp +++ b/src/tests/DawnTest.cpp @@ -582,6 +582,7 @@ void DawnTestBase::SetUp() { queue = device.CreateQueue(); device.SetUncapturedErrorCallback(OnDeviceError, this); + device.SetDeviceLostCallback(OnDeviceLost, this); } void DawnTestBase::TearDown() { @@ -618,6 +619,10 @@ void DawnTestBase::OnDeviceError(WGPUErrorType type, const char* message, void* self->mError = true; } +void DawnTestBase::OnDeviceLost(const char* message, void* userdata) { + FAIL() << "Device Lost during test: " << message; +} + std::ostringstream& DawnTestBase::AddBufferExpectation(const char* file, int line, const wgpu::Buffer& buffer, diff --git a/src/tests/DawnTest.h b/src/tests/DawnTest.h index fb8825dc2e..4378ac233a 100644 --- a/src/tests/DawnTest.h +++ b/src/tests/DawnTest.h @@ -249,6 +249,7 @@ class DawnTestBase { // Tracking for validation errors static void OnDeviceError(WGPUErrorType type, const char* message, void* userdata); + static void OnDeviceLost(const char* message, void* userdata); bool mExpectError = false; bool mError = false; diff --git a/src/tests/end2end/DeviceLostTests.cpp b/src/tests/end2end/DeviceLostTests.cpp new file mode 100644 index 0000000000..ab8f260ed8 --- /dev/null +++ b/src/tests/end2end/DeviceLostTests.cpp @@ -0,0 +1,192 @@ +// Copyright 2019 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tests/DawnTest.h" + +#include +#include "utils/ComboRenderPipelineDescriptor.h" +#include "utils/WGPUHelpers.h" + +#include + +using namespace testing; + +class MockDeviceLostCallback { + public: + MOCK_METHOD2(Call, void(const char* message, void* userdata)); +}; + +static std::unique_ptr mockDeviceLostCallback; +static void ToMockDeviceLostCallback(const char* message, void* userdata) { + mockDeviceLostCallback->Call(message, userdata); + DawnTestBase* self = static_cast(userdata); + self->StartExpectDeviceError(); +} + +class DeviceLostTest : public DawnTest { + protected: + void TestSetUp() override { + DAWN_SKIP_TEST_IF(UsesWire()); + DawnTest::TestSetUp(); + mockDeviceLostCallback = std::make_unique(); + } + + void TearDown() override { + DawnTest::TearDown(); + mockDeviceLostCallback = nullptr; + } + + void SetCallbackAndLoseForTesting() { + device.SetDeviceLostCallback(ToMockDeviceLostCallback, this); + EXPECT_CALL(*mockDeviceLostCallback, Call(_, this)).Times(1); + device.LoseForTesting(); + } +}; + +// Test that DeviceLostCallback is invoked when LostForTestimg is called +TEST_P(DeviceLostTest, DeviceLostCallbackIsCalled) { + SetCallbackAndLoseForTesting(); +} + +// Test that submit fails when device is lost +TEST_P(DeviceLostTest, SubmitFails) { + wgpu::CommandBuffer commands; + wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); + commands = encoder.Finish(); + + SetCallbackAndLoseForTesting(); + ASSERT_DEVICE_ERROR(queue.Submit(0, &commands)); +} + +// Test that CreateBindGroupLayout fails when device is lost +TEST_P(DeviceLostTest, CreateBindGroupLayoutFails) { + SetCallbackAndLoseForTesting(); + + wgpu::BindGroupLayoutBinding binding = {0, wgpu::ShaderStage::None, + wgpu::BindingType::UniformBuffer}; + wgpu::BindGroupLayoutDescriptor descriptor; + descriptor.bindingCount = 1; + descriptor.bindings = &binding; + ASSERT_DEVICE_ERROR(device.CreateBindGroupLayout(&descriptor)); +} + +// Test that CreateBindGroup fails when device is lost +TEST_P(DeviceLostTest, CreateBindGroupFails) { + SetCallbackAndLoseForTesting(); + + wgpu::BindGroupBinding binding; + binding.binding = 0; + binding.sampler = nullptr; + binding.textureView = nullptr; + binding.buffer = nullptr; + binding.offset = 0; + binding.size = 0; + + wgpu::BindGroupDescriptor descriptor; + descriptor.layout = nullptr; + descriptor.bindingCount = 1; + descriptor.bindings = &binding; + ASSERT_DEVICE_ERROR(device.CreateBindGroup(&descriptor)); +} + +// Test that CreatePipelineLayout fails when device is lost +TEST_P(DeviceLostTest, CreatePipelineLayoutFails) { + SetCallbackAndLoseForTesting(); + + wgpu::PipelineLayoutDescriptor descriptor; + descriptor.bindGroupLayoutCount = 0; + descriptor.bindGroupLayouts = nullptr; + ASSERT_DEVICE_ERROR(device.CreatePipelineLayout(&descriptor)); +} + +// Tests that CreateRenderBundleEncoder fails when device is lost +TEST_P(DeviceLostTest, CreateRenderBundleEncoderFails) { + SetCallbackAndLoseForTesting(); + + wgpu::RenderBundleEncoderDescriptor descriptor; + descriptor.colorFormatsCount = 0; + descriptor.colorFormats = nullptr; + ASSERT_DEVICE_ERROR(device.CreateRenderBundleEncoder(&descriptor)); +} + +// Tests that CreateComputePipeline fails when device is lost +TEST_P(DeviceLostTest, CreateComputePipelineFails) { + SetCallbackAndLoseForTesting(); + + wgpu::ComputePipelineDescriptor descriptor; + descriptor.layout = nullptr; + descriptor.computeStage.module = nullptr; + descriptor.nextInChain = nullptr; + ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&descriptor)); +} + +// Tests that CreateRenderPipeline fails when device is lost +TEST_P(DeviceLostTest, CreateRenderPipelineFails) { + SetCallbackAndLoseForTesting(); + + utils::ComboRenderPipelineDescriptor descriptor(device); + ASSERT_DEVICE_ERROR(device.CreateRenderPipeline(&descriptor)); +} + +// Tests that CreateSampler fails when device is lost +TEST_P(DeviceLostTest, CreateSamplerFails) { + SetCallbackAndLoseForTesting(); + + wgpu::SamplerDescriptor descriptor = utils::GetDefaultSamplerDescriptor(); + ASSERT_DEVICE_ERROR(device.CreateSampler(&descriptor)); +} + +// Tests that CreateShaderModule fails when device is lost +TEST_P(DeviceLostTest, CreateShaderModuleFails) { + SetCallbackAndLoseForTesting(); + + ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, utils::SingleShaderStage::Fragment, R"( + #version 450 + layout(location = 0) in vec4 color; + layout(location = 0) out vec4 fragColor; + void main() { + fragColor = color; + })")); +} + +// Tests that CreateSwapChain fails when device is lost +TEST_P(DeviceLostTest, CreateSwapChainFails) { + SetCallbackAndLoseForTesting(); + + wgpu::SwapChainDescriptor descriptor; + descriptor.nextInChain = nullptr; + ASSERT_DEVICE_ERROR(device.CreateSwapChain(&descriptor)); +} + +// Tests that CreateTexture fails when device is lost +TEST_P(DeviceLostTest, CreateTextureFails) { + SetCallbackAndLoseForTesting(); + + wgpu::TextureDescriptor descriptor; + descriptor.size.width = 4; + descriptor.size.height = 4; + descriptor.size.depth = 1; + descriptor.arrayLayerCount = 1; + descriptor.mipLevelCount = 1; + descriptor.dimension = wgpu::TextureDimension::e2D; + descriptor.usage = wgpu::TextureUsage::OutputAttachment; + + ASSERT_DEVICE_ERROR(device.CreateTexture(&descriptor)); +} + +TEST_P(DeviceLostTest, TickFails) { + SetCallbackAndLoseForTesting(); + ASSERT_DEVICE_ERROR(device.Tick()); +} +DAWN_INSTANTIATE_TEST(DeviceLostTest, D3D12Backend, VulkanBackend); \ No newline at end of file