From 8c58491d25e5079fbda8ebc948bc39773e21542e Mon Sep 17 00:00:00 2001 From: Austin Eng Date: Thu, 14 Jan 2021 00:51:58 +0000 Subject: [PATCH] dawn_wire: Skip device inject error if the client drops the device If the client drops the last reference to the device, it would dereference an invalid pointer upon calling InjectError. So, skip the call. We can't keep the device alive if the Buffer is still alive because we intend to make all objects internally null if you delete their device. It is ok to skip error injection because if the client deletes the device, it should not expect to receive any more error callbacks. Bug: dawn:384 Change-Id: I4c694310e4395b06cd49603fc5d4cd846799decb Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/37580 Reviewed-by: Jiawei Shao Reviewed-by: Corentin Wallez Commit-Queue: Austin Eng --- src/dawn_wire/client/Buffer.cpp | 18 ++++++-- src/dawn_wire/client/Buffer.h | 2 + src/dawn_wire/client/Device.cpp | 6 ++- src/dawn_wire/client/Device.h | 4 ++ src/dawn_wire/client/Fence.cpp | 8 +--- src/dawn_wire/client/Fence.h | 4 +- src/dawn_wire/client/Queue.cpp | 2 +- .../unittests/wire/WireDestroyObjectTests.cpp | 44 +++++++++++++++++++ 8 files changed, 72 insertions(+), 16 deletions(-) diff --git a/src/dawn_wire/client/Buffer.cpp b/src/dawn_wire/client/Buffer.cpp index 40b5759712..da3691c584 100644 --- a/src/dawn_wire/client/Buffer.cpp +++ b/src/dawn_wire/client/Buffer.cpp @@ -63,6 +63,7 @@ namespace dawn_wire { namespace client { auto* bufferObjectAndSerial = wireClient->BufferAllocator().New(wireClient); Buffer* buffer = bufferObjectAndSerial->object.get(); buffer->mDevice = device; + buffer->mDeviceIsAlive = device->GetAliveWeakPtr(); buffer->mSize = descriptor->size; DeviceCreateBufferCmd cmd; @@ -92,6 +93,7 @@ namespace dawn_wire { namespace client { WGPUBuffer Buffer::CreateError(Device* device) { auto* allocation = device->client->BufferAllocator().New(device->client); allocation->object->mDevice = device; + allocation->object->mDeviceIsAlive = device->GetAliveWeakPtr(); DeviceCreateErrorBufferCmd cmd; cmd.self = ToAPI(device); @@ -140,8 +142,10 @@ namespace dawn_wire { namespace client { // Step 1. Do early validation of READ ^ WRITE because the server rejects mode = 0. if (!(isReadMode ^ isWriteMode)) { - mDevice->InjectError(WGPUErrorType_Validation, - "MapAsync mode must be exactly one of Read or Write"); + if (!mDeviceIsAlive.expired()) { + mDevice->InjectError(WGPUErrorType_Validation, + "MapAsync mode must be exactly one of Read or Write"); + } if (callback != nullptr) { callback(WGPUBufferMapAsyncStatus_Error, userdata); } @@ -163,7 +167,10 @@ namespace dawn_wire { namespace client { if (isReadMode) { request.readHandle.reset(client->GetMemoryTransferService()->CreateReadHandle(size)); if (request.readHandle == nullptr) { - mDevice->InjectError(WGPUErrorType_OutOfMemory, "Failed to create buffer mapping"); + if (!mDeviceIsAlive.expired()) { + mDevice->InjectError(WGPUErrorType_OutOfMemory, + "Failed to create buffer mapping"); + } callback(WGPUBufferMapAsyncStatus_Error, userdata); return; } @@ -171,7 +178,10 @@ namespace dawn_wire { namespace client { ASSERT(isWriteMode); request.writeHandle.reset(client->GetMemoryTransferService()->CreateWriteHandle(size)); if (request.writeHandle == nullptr) { - mDevice->InjectError(WGPUErrorType_OutOfMemory, "Failed to create buffer mapping"); + if (!mDeviceIsAlive.expired()) { + mDevice->InjectError(WGPUErrorType_OutOfMemory, + "Failed to create buffer mapping"); + } callback(WGPUBufferMapAsyncStatus_Error, userdata); return; } diff --git a/src/dawn_wire/client/Buffer.h b/src/dawn_wire/client/Buffer.h index d6a275c15d..5e0d5ec2a3 100644 --- a/src/dawn_wire/client/Buffer.h +++ b/src/dawn_wire/client/Buffer.h @@ -89,6 +89,8 @@ namespace dawn_wire { namespace client { void* mMappedData = nullptr; size_t mMapOffset = 0; size_t mMapSize = 0; + + std::weak_ptr mDeviceIsAlive; }; }} // namespace dawn_wire::client diff --git a/src/dawn_wire/client/Device.cpp b/src/dawn_wire/client/Device.cpp index 08a3e33831..b6ee4fceea 100644 --- a/src/dawn_wire/client/Device.cpp +++ b/src/dawn_wire/client/Device.cpp @@ -23,7 +23,7 @@ namespace dawn_wire { namespace client { Device::Device(Client* clientIn, uint32_t initialRefcount, uint32_t initialId) - : ObjectBase(clientIn, initialRefcount, initialId) { + : ObjectBase(clientIn, initialRefcount, initialId), mIsAlive(std::make_shared()) { #if defined(DAWN_ENABLE_ASSERTS) mErrorCallback = [](WGPUErrorType, char const*, void*) { static bool calledOnce = false; @@ -114,6 +114,10 @@ namespace dawn_wire { namespace client { mErrorScopes.clear(); } + std::weak_ptr Device::GetAliveWeakPtr() { + return mIsAlive; + } + void Device::SetUncapturedErrorCallback(WGPUErrorCallback errorCallback, void* errorUserdata) { mErrorCallback = errorCallback; mErrorUserdata = errorUserdata; diff --git a/src/dawn_wire/client/Device.h b/src/dawn_wire/client/Device.h index 6862bb3d3b..3f167004a8 100644 --- a/src/dawn_wire/client/Device.h +++ b/src/dawn_wire/client/Device.h @@ -64,6 +64,8 @@ namespace dawn_wire { namespace client { void CancelCallbacksForDisconnect() override; + std::weak_ptr GetAliveWeakPtr(); + private: struct ErrorScopeData { WGPUErrorCallback callback = nullptr; @@ -89,6 +91,8 @@ namespace dawn_wire { namespace client { void* mDeviceLostUserdata = nullptr; Queue* mDefaultQueue = nullptr; + + std::shared_ptr mIsAlive; }; }} // namespace dawn_wire::client diff --git a/src/dawn_wire/client/Fence.cpp b/src/dawn_wire/client/Fence.cpp index 9bb40b3e4a..d9a858a245 100644 --- a/src/dawn_wire/client/Fence.cpp +++ b/src/dawn_wire/client/Fence.cpp @@ -38,9 +38,7 @@ namespace dawn_wire { namespace client { mOnCompletionRequests.clear(); } - void Fence::Initialize(Queue* queue, const WGPUFenceDescriptor* descriptor) { - mQueue = queue; - + void Fence::Initialize(const WGPUFenceDescriptor* descriptor) { mCompletedValue = descriptor != nullptr ? descriptor->initialValue : 0u; } @@ -87,8 +85,4 @@ namespace dawn_wire { namespace client { return mCompletedValue; } - Queue* Fence::GetQueue() const { - return mQueue; - } - }} // namespace dawn_wire::client diff --git a/src/dawn_wire/client/Fence.h b/src/dawn_wire/client/Fence.h index a8fb5f4722..7c09948e56 100644 --- a/src/dawn_wire/client/Fence.h +++ b/src/dawn_wire/client/Fence.h @@ -27,7 +27,7 @@ namespace dawn_wire { namespace client { public: using ObjectBase::ObjectBase; ~Fence(); - void Initialize(Queue* queue, const WGPUFenceDescriptor* descriptor); + void Initialize(const WGPUFenceDescriptor* descriptor); void CheckPassedFences(); void OnCompletion(uint64_t value, WGPUFenceOnCompletionCallback callback, void* userdata); @@ -35,7 +35,6 @@ namespace dawn_wire { namespace client { bool OnCompletionCallback(uint64_t requestSerial, WGPUFenceCompletionStatus status); uint64_t GetCompletedValue() const; - Queue* GetQueue() const; private: void CancelCallbacksForDisconnect() override; @@ -44,7 +43,6 @@ namespace dawn_wire { namespace client { WGPUFenceOnCompletionCallback callback = nullptr; void* userdata = nullptr; }; - Queue* mQueue = nullptr; uint64_t mCompletedValue = 0; uint64_t mOnCompletionRequestSerial = 0; std::map mOnCompletionRequests; diff --git a/src/dawn_wire/client/Queue.cpp b/src/dawn_wire/client/Queue.cpp index 37ae3fa6f4..f5f68f782b 100644 --- a/src/dawn_wire/client/Queue.cpp +++ b/src/dawn_wire/client/Queue.cpp @@ -29,7 +29,7 @@ namespace dawn_wire { namespace client { client->SerializeCommand(cmd); Fence* fence = allocation->object.get(); - fence->Initialize(this, descriptor); + fence->Initialize(descriptor); return ToAPI(fence); } diff --git a/src/tests/unittests/wire/WireDestroyObjectTests.cpp b/src/tests/unittests/wire/WireDestroyObjectTests.cpp index f5e16b7b30..34b976dcc9 100644 --- a/src/tests/unittests/wire/WireDestroyObjectTests.cpp +++ b/src/tests/unittests/wire/WireDestroyObjectTests.cpp @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include "tests/MockCallback.h" #include "tests/unittests/wire/WireTest.h" using namespace testing; @@ -43,3 +44,46 @@ TEST_F(WireDestroyObjectTests, DestroyDeviceDestroysChildren) { wgpuCommandEncoderFinish(encoder, nullptr); FlushClient(false); } + +// Test that calling a function that would generate an InjectError doesn't crash after +// the device is destroyed. +TEST_F(WireDestroyObjectTests, ImplicitInjectErrorAfterDestroyDevice) { + WGPUBufferDescriptor bufferDesc = {}; + bufferDesc.size = 4; + WGPUBuffer buffer = wgpuDeviceCreateBuffer(device, &bufferDesc); + + WGPUBuffer apiBuffer = api.GetNewBuffer(); + EXPECT_CALL(api, DeviceCreateBuffer(apiDevice, _)).WillOnce(Return(apiBuffer)); + + FlushClient(); + + { + // Control case: MapAsync errors on invalid WGPUMapMode. + MockCallback mockBufferMapCallback; + + EXPECT_CALL(api, DeviceInjectError(apiDevice, WGPUErrorType_Validation, _)); + EXPECT_CALL(mockBufferMapCallback, Call(WGPUBufferMapAsyncStatus_Error, this)); + wgpuBufferMapAsync(buffer, WGPUMapMode(0), 0, 4, mockBufferMapCallback.Callback(), + mockBufferMapCallback.MakeUserdata(this)); + + FlushClient(); + } + + { + // Now, release the device. InjectError shouldn't happen. + wgpuDeviceRelease(device); + MockCallback mockBufferMapCallback; + + EXPECT_CALL(mockBufferMapCallback, Call(WGPUBufferMapAsyncStatus_Error, this + 1)); + wgpuBufferMapAsync(buffer, WGPUMapMode(0), 0, 4, mockBufferMapCallback.Callback(), + mockBufferMapCallback.MakeUserdata(this + 1)); + + Sequence s1, s2; + // The device and child objects alre also released. + EXPECT_CALL(api, BufferRelease(apiBuffer)).InSequence(s1); + EXPECT_CALL(api, QueueRelease(apiQueue)).InSequence(s2); + EXPECT_CALL(api, DeviceRelease(apiDevice)).InSequence(s1, s2); + + FlushClient(); + } +}