diff --git a/src/dawn_wire/client/Buffer.cpp b/src/dawn_wire/client/Buffer.cpp index 40b5759712..da3691c584 100644 --- a/src/dawn_wire/client/Buffer.cpp +++ b/src/dawn_wire/client/Buffer.cpp @@ -63,6 +63,7 @@ namespace dawn_wire { namespace client { auto* bufferObjectAndSerial = wireClient->BufferAllocator().New(wireClient); Buffer* buffer = bufferObjectAndSerial->object.get(); buffer->mDevice = device; + buffer->mDeviceIsAlive = device->GetAliveWeakPtr(); buffer->mSize = descriptor->size; DeviceCreateBufferCmd cmd; @@ -92,6 +93,7 @@ namespace dawn_wire { namespace client { WGPUBuffer Buffer::CreateError(Device* device) { auto* allocation = device->client->BufferAllocator().New(device->client); allocation->object->mDevice = device; + allocation->object->mDeviceIsAlive = device->GetAliveWeakPtr(); DeviceCreateErrorBufferCmd cmd; cmd.self = ToAPI(device); @@ -140,8 +142,10 @@ namespace dawn_wire { namespace client { // Step 1. Do early validation of READ ^ WRITE because the server rejects mode = 0. if (!(isReadMode ^ isWriteMode)) { - mDevice->InjectError(WGPUErrorType_Validation, - "MapAsync mode must be exactly one of Read or Write"); + if (!mDeviceIsAlive.expired()) { + mDevice->InjectError(WGPUErrorType_Validation, + "MapAsync mode must be exactly one of Read or Write"); + } if (callback != nullptr) { callback(WGPUBufferMapAsyncStatus_Error, userdata); } @@ -163,7 +167,10 @@ namespace dawn_wire { namespace client { if (isReadMode) { request.readHandle.reset(client->GetMemoryTransferService()->CreateReadHandle(size)); if (request.readHandle == nullptr) { - mDevice->InjectError(WGPUErrorType_OutOfMemory, "Failed to create buffer mapping"); + if (!mDeviceIsAlive.expired()) { + mDevice->InjectError(WGPUErrorType_OutOfMemory, + "Failed to create buffer mapping"); + } callback(WGPUBufferMapAsyncStatus_Error, userdata); return; } @@ -171,7 +178,10 @@ namespace dawn_wire { namespace client { ASSERT(isWriteMode); request.writeHandle.reset(client->GetMemoryTransferService()->CreateWriteHandle(size)); if (request.writeHandle == nullptr) { - mDevice->InjectError(WGPUErrorType_OutOfMemory, "Failed to create buffer mapping"); + if (!mDeviceIsAlive.expired()) { + mDevice->InjectError(WGPUErrorType_OutOfMemory, + "Failed to create buffer mapping"); + } callback(WGPUBufferMapAsyncStatus_Error, userdata); return; } diff --git a/src/dawn_wire/client/Buffer.h b/src/dawn_wire/client/Buffer.h index d6a275c15d..5e0d5ec2a3 100644 --- a/src/dawn_wire/client/Buffer.h +++ b/src/dawn_wire/client/Buffer.h @@ -89,6 +89,8 @@ namespace dawn_wire { namespace client { void* mMappedData = nullptr; size_t mMapOffset = 0; size_t mMapSize = 0; + + std::weak_ptr mDeviceIsAlive; }; }} // namespace dawn_wire::client diff --git a/src/dawn_wire/client/Device.cpp b/src/dawn_wire/client/Device.cpp index 08a3e33831..b6ee4fceea 100644 --- a/src/dawn_wire/client/Device.cpp +++ b/src/dawn_wire/client/Device.cpp @@ -23,7 +23,7 @@ namespace dawn_wire { namespace client { Device::Device(Client* clientIn, uint32_t initialRefcount, uint32_t initialId) - : ObjectBase(clientIn, initialRefcount, initialId) { + : ObjectBase(clientIn, initialRefcount, initialId), mIsAlive(std::make_shared()) { #if defined(DAWN_ENABLE_ASSERTS) mErrorCallback = [](WGPUErrorType, char const*, void*) { static bool calledOnce = false; @@ -114,6 +114,10 @@ namespace dawn_wire { namespace client { mErrorScopes.clear(); } + std::weak_ptr Device::GetAliveWeakPtr() { + return mIsAlive; + } + void Device::SetUncapturedErrorCallback(WGPUErrorCallback errorCallback, void* errorUserdata) { mErrorCallback = errorCallback; mErrorUserdata = errorUserdata; diff --git a/src/dawn_wire/client/Device.h b/src/dawn_wire/client/Device.h index 6862bb3d3b..3f167004a8 100644 --- a/src/dawn_wire/client/Device.h +++ b/src/dawn_wire/client/Device.h @@ -64,6 +64,8 @@ namespace dawn_wire { namespace client { void CancelCallbacksForDisconnect() override; + std::weak_ptr GetAliveWeakPtr(); + private: struct ErrorScopeData { WGPUErrorCallback callback = nullptr; @@ -89,6 +91,8 @@ namespace dawn_wire { namespace client { void* mDeviceLostUserdata = nullptr; Queue* mDefaultQueue = nullptr; + + std::shared_ptr mIsAlive; }; }} // namespace dawn_wire::client diff --git a/src/dawn_wire/client/Fence.cpp b/src/dawn_wire/client/Fence.cpp index 9bb40b3e4a..d9a858a245 100644 --- a/src/dawn_wire/client/Fence.cpp +++ b/src/dawn_wire/client/Fence.cpp @@ -38,9 +38,7 @@ namespace dawn_wire { namespace client { mOnCompletionRequests.clear(); } - void Fence::Initialize(Queue* queue, const WGPUFenceDescriptor* descriptor) { - mQueue = queue; - + void Fence::Initialize(const WGPUFenceDescriptor* descriptor) { mCompletedValue = descriptor != nullptr ? descriptor->initialValue : 0u; } @@ -87,8 +85,4 @@ namespace dawn_wire { namespace client { return mCompletedValue; } - Queue* Fence::GetQueue() const { - return mQueue; - } - }} // namespace dawn_wire::client diff --git a/src/dawn_wire/client/Fence.h b/src/dawn_wire/client/Fence.h index a8fb5f4722..7c09948e56 100644 --- a/src/dawn_wire/client/Fence.h +++ b/src/dawn_wire/client/Fence.h @@ -27,7 +27,7 @@ namespace dawn_wire { namespace client { public: using ObjectBase::ObjectBase; ~Fence(); - void Initialize(Queue* queue, const WGPUFenceDescriptor* descriptor); + void Initialize(const WGPUFenceDescriptor* descriptor); void CheckPassedFences(); void OnCompletion(uint64_t value, WGPUFenceOnCompletionCallback callback, void* userdata); @@ -35,7 +35,6 @@ namespace dawn_wire { namespace client { bool OnCompletionCallback(uint64_t requestSerial, WGPUFenceCompletionStatus status); uint64_t GetCompletedValue() const; - Queue* GetQueue() const; private: void CancelCallbacksForDisconnect() override; @@ -44,7 +43,6 @@ namespace dawn_wire { namespace client { WGPUFenceOnCompletionCallback callback = nullptr; void* userdata = nullptr; }; - Queue* mQueue = nullptr; uint64_t mCompletedValue = 0; uint64_t mOnCompletionRequestSerial = 0; std::map mOnCompletionRequests; diff --git a/src/dawn_wire/client/Queue.cpp b/src/dawn_wire/client/Queue.cpp index 37ae3fa6f4..f5f68f782b 100644 --- a/src/dawn_wire/client/Queue.cpp +++ b/src/dawn_wire/client/Queue.cpp @@ -29,7 +29,7 @@ namespace dawn_wire { namespace client { client->SerializeCommand(cmd); Fence* fence = allocation->object.get(); - fence->Initialize(this, descriptor); + fence->Initialize(descriptor); return ToAPI(fence); } diff --git a/src/tests/unittests/wire/WireDestroyObjectTests.cpp b/src/tests/unittests/wire/WireDestroyObjectTests.cpp index f5e16b7b30..34b976dcc9 100644 --- a/src/tests/unittests/wire/WireDestroyObjectTests.cpp +++ b/src/tests/unittests/wire/WireDestroyObjectTests.cpp @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "tests/MockCallback.h" #include "tests/unittests/wire/WireTest.h" using namespace testing; @@ -43,3 +44,46 @@ TEST_F(WireDestroyObjectTests, DestroyDeviceDestroysChildren) { wgpuCommandEncoderFinish(encoder, nullptr); FlushClient(false); } + +// Test that calling a function that would generate an InjectError doesn't crash after +// the device is destroyed. +TEST_F(WireDestroyObjectTests, ImplicitInjectErrorAfterDestroyDevice) { + WGPUBufferDescriptor bufferDesc = {}; + bufferDesc.size = 4; + WGPUBuffer buffer = wgpuDeviceCreateBuffer(device, &bufferDesc); + + WGPUBuffer apiBuffer = api.GetNewBuffer(); + EXPECT_CALL(api, DeviceCreateBuffer(apiDevice, _)).WillOnce(Return(apiBuffer)); + + FlushClient(); + + { + // Control case: MapAsync errors on invalid WGPUMapMode. + MockCallback mockBufferMapCallback; + + EXPECT_CALL(api, DeviceInjectError(apiDevice, WGPUErrorType_Validation, _)); + EXPECT_CALL(mockBufferMapCallback, Call(WGPUBufferMapAsyncStatus_Error, this)); + wgpuBufferMapAsync(buffer, WGPUMapMode(0), 0, 4, mockBufferMapCallback.Callback(), + mockBufferMapCallback.MakeUserdata(this)); + + FlushClient(); + } + + { + // Now, release the device. InjectError shouldn't happen. + wgpuDeviceRelease(device); + MockCallback mockBufferMapCallback; + + EXPECT_CALL(mockBufferMapCallback, Call(WGPUBufferMapAsyncStatus_Error, this + 1)); + wgpuBufferMapAsync(buffer, WGPUMapMode(0), 0, 4, mockBufferMapCallback.Callback(), + mockBufferMapCallback.MakeUserdata(this + 1)); + + Sequence s1, s2; + // The device and child objects alre also released. + EXPECT_CALL(api, BufferRelease(apiBuffer)).InSequence(s1); + EXPECT_CALL(api, QueueRelease(apiQueue)).InSequence(s2); + EXPECT_CALL(api, DeviceRelease(apiDevice)).InSequence(s1, s2); + + FlushClient(); + } +}