From b70a5b02e9712693ffb796c0b0d8cfa24ca5eba5 Mon Sep 17 00:00:00 2001 From: Austin Eng Date: Wed, 11 Nov 2020 21:01:18 +0000 Subject: [PATCH] Reject all callbacks with DeviceLost on wire client disconnect When the wire is disconnected, the client will not receive any messages from the server. We need to manually reject all callbacks. Bug: dawn:556 Change-Id: Ia03456b3209dbe0e1e54543d344180d11d4c6f1e Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/31162 Commit-Queue: Austin Eng Reviewed-by: Stephen White --- .../templates/dawn_wire/client/ApiObjects.h | 2 +- src/dawn_wire/client/Buffer.cpp | 9 +++ src/dawn_wire/client/Buffer.h | 4 +- src/dawn_wire/client/Client.cpp | 1 + src/dawn_wire/client/Device.cpp | 30 +++++++++ src/dawn_wire/client/Device.h | 4 +- src/dawn_wire/client/Fence.cpp | 9 +++ src/dawn_wire/client/Fence.h | 4 +- src/dawn_wire/client/ObjectBase.h | 3 + src/dawn_wire/client/Queue.h | 2 +- .../unittests/wire/WireBufferMappingTests.cpp | 18 ++++++ .../wire/WireCreateReadyPipelineTests.cpp | 63 +++++++++++++++++++ .../unittests/wire/WireErrorCallbackTests.cpp | 19 ++++++ src/tests/unittests/wire/WireFenceTests.cpp | 17 +++++ 14 files changed, 180 insertions(+), 5 deletions(-) diff --git a/generator/templates/dawn_wire/client/ApiObjects.h b/generator/templates/dawn_wire/client/ApiObjects.h index 46930e0420..0d8421b976 100644 --- a/generator/templates/dawn_wire/client/ApiObjects.h +++ b/generator/templates/dawn_wire/client/ApiObjects.h @@ -28,7 +28,7 @@ namespace dawn_wire { namespace client { {% if type.name.CamelCase() in client_special_objects %} class {{Type}}; {% else %} - struct {{type.name.CamelCase()}} : ObjectBase { + struct {{type.name.CamelCase()}} final : ObjectBase { using ObjectBase::ObjectBase; }; {% endif %} diff --git a/src/dawn_wire/client/Buffer.cpp b/src/dawn_wire/client/Buffer.cpp index 9cbec3d163..3276325860 100644 --- a/src/dawn_wire/client/Buffer.cpp +++ b/src/dawn_wire/client/Buffer.cpp @@ -110,6 +110,15 @@ namespace dawn_wire { namespace client { mRequests.clear(); } + void Buffer::CancelCallbacksForDisconnect() { + for (auto& it : mRequests) { + if (it.second.callback) { + it.second.callback(WGPUBufferMapAsyncStatus_DeviceLost, it.second.userdata); + } + } + mRequests.clear(); + } + void Buffer::MapAsync(WGPUMapModeFlags mode, size_t offset, size_t size, diff --git a/src/dawn_wire/client/Buffer.h b/src/dawn_wire/client/Buffer.h index 9268a6495f..cacfd48b2e 100644 --- a/src/dawn_wire/client/Buffer.h +++ b/src/dawn_wire/client/Buffer.h @@ -24,7 +24,7 @@ namespace dawn_wire { namespace client { - class Buffer : public ObjectBase { + class Buffer final : public ObjectBase { public: using ObjectBase::ObjectBase; @@ -49,6 +49,8 @@ namespace dawn_wire { namespace client { void Destroy(); private: + void CancelCallbacksForDisconnect() override; + bool IsMappedForReading() const; bool IsMappedForWriting() const; bool CheckGetMappedRangeOffsetSize(size_t offset, size_t size) const; diff --git a/src/dawn_wire/client/Client.cpp b/src/dawn_wire/client/Client.cpp index a1a6f63b9e..e8a754b5ff 100644 --- a/src/dawn_wire/client/Client.cpp +++ b/src/dawn_wire/client/Client.cpp @@ -86,6 +86,7 @@ namespace dawn_wire { namespace client { mSerializer = ChunkedCommandSerializer(NoopCommandSerializer::GetInstance()); if (mDevice != nullptr) { mDevice->HandleDeviceLost("GPU connection lost"); + mDevice->CancelCallbacksForDisconnect(); } } diff --git a/src/dawn_wire/client/Device.cpp b/src/dawn_wire/client/Device.cpp index f139230c96..178857e014 100644 --- a/src/dawn_wire/client/Device.cpp +++ b/src/dawn_wire/client/Device.cpp @@ -91,6 +91,36 @@ namespace dawn_wire { namespace client { } } + void Device::CancelCallbacksForDisconnect() { + for (auto& it : mCreateReadyPipelineRequests) { + ASSERT((it.second.createReadyComputePipelineCallback != nullptr) ^ + (it.second.createReadyRenderPipelineCallback != nullptr)); + if (it.second.createReadyRenderPipelineCallback) { + it.second.createReadyRenderPipelineCallback( + WGPUCreateReadyPipelineStatus_DeviceLost, nullptr, "Device lost", + it.second.userdata); + } else { + it.second.createReadyComputePipelineCallback( + WGPUCreateReadyPipelineStatus_DeviceLost, nullptr, "Device lost", + it.second.userdata); + } + } + mCreateReadyPipelineRequests.clear(); + + for (auto& it : mErrorScopes) { + it.second.callback(WGPUErrorType_DeviceLost, "Device lost", it.second.userdata); + } + mErrorScopes.clear(); + + for (auto& objectList : mObjects) { + LinkNode* object = objectList.head(); + while (object != objectList.end()) { + object->value()->CancelCallbacksForDisconnect(); + object = object->next(); + } + } + } + void Device::SetUncapturedErrorCallback(WGPUErrorCallback errorCallback, void* errorUserdata) { mErrorCallback = errorCallback; mErrorUserdata = errorUserdata; diff --git a/src/dawn_wire/client/Device.h b/src/dawn_wire/client/Device.h index 82c68990c7..eef03a573f 100644 --- a/src/dawn_wire/client/Device.h +++ b/src/dawn_wire/client/Device.h @@ -29,7 +29,7 @@ namespace dawn_wire { namespace client { class Client; class Queue; - class Device : public ObjectBase { + class Device final : public ObjectBase { public: Device(Client* client, uint32_t refcount, uint32_t id); ~Device(); @@ -68,6 +68,8 @@ namespace dawn_wire { namespace client { mObjects[ObjectTypeToTypeEnum].Append(object); } + void CancelCallbacksForDisconnect() override; + private: void DestroyAllObjects(); diff --git a/src/dawn_wire/client/Fence.cpp b/src/dawn_wire/client/Fence.cpp index 329998f1ac..e2a8e6de60 100644 --- a/src/dawn_wire/client/Fence.cpp +++ b/src/dawn_wire/client/Fence.cpp @@ -30,6 +30,15 @@ namespace dawn_wire { namespace client { mOnCompletionRequests.clear(); } + void Fence::CancelCallbacksForDisconnect() { + for (auto& it : mOnCompletionRequests) { + if (it.second.callback) { + it.second.callback(WGPUFenceCompletionStatus_DeviceLost, it.second.userdata); + } + } + mOnCompletionRequests.clear(); + } + void Fence::Initialize(Queue* queue, const WGPUFenceDescriptor* descriptor) { mQueue = queue; diff --git a/src/dawn_wire/client/Fence.h b/src/dawn_wire/client/Fence.h index 00791944bd..a8fb5f4722 100644 --- a/src/dawn_wire/client/Fence.h +++ b/src/dawn_wire/client/Fence.h @@ -23,7 +23,7 @@ namespace dawn_wire { namespace client { class Queue; - class Fence : public ObjectBase { + class Fence final : public ObjectBase { public: using ObjectBase::ObjectBase; ~Fence(); @@ -38,6 +38,8 @@ namespace dawn_wire { namespace client { Queue* GetQueue() const; private: + void CancelCallbacksForDisconnect() override; + struct OnCompletionData { WGPUFenceOnCompletionCallback callback = nullptr; void* userdata = nullptr; diff --git a/src/dawn_wire/client/ObjectBase.h b/src/dawn_wire/client/ObjectBase.h index 18778d6a64..e317611f90 100644 --- a/src/dawn_wire/client/ObjectBase.h +++ b/src/dawn_wire/client/ObjectBase.h @@ -38,6 +38,9 @@ namespace dawn_wire { namespace client { RemoveFromList(); } + virtual void CancelCallbacksForDisconnect() { + } + Device* const device; uint32_t refcount; const uint32_t id; diff --git a/src/dawn_wire/client/Queue.h b/src/dawn_wire/client/Queue.h index 9e50348368..91c93935a7 100644 --- a/src/dawn_wire/client/Queue.h +++ b/src/dawn_wire/client/Queue.h @@ -24,7 +24,7 @@ namespace dawn_wire { namespace client { - class Queue : public ObjectBase { + class Queue final : public ObjectBase { public: using ObjectBase::ObjectBase; diff --git a/src/tests/unittests/wire/WireBufferMappingTests.cpp b/src/tests/unittests/wire/WireBufferMappingTests.cpp index 1f0f31ba4b..6e3e06bf83 100644 --- a/src/tests/unittests/wire/WireBufferMappingTests.cpp +++ b/src/tests/unittests/wire/WireBufferMappingTests.cpp @@ -14,6 +14,8 @@ #include "tests/unittests/wire/WireTest.h" +#include "dawn_wire/WireClient.h" + using namespace testing; using namespace dawn_wire; @@ -674,3 +676,19 @@ TEST_F(WireBufferMappingTests, MaxSizeMappableBufferOOMDirectly) { FlushClient(); } } + +// Test that registering a callback then wire disconnect calls the callback with +// DeviceLost. +TEST_F(WireBufferMappingTests, MapThenDisconnect) { + wgpuBufferMapAsync(buffer, WGPUMapMode_Write, 0, kBufferSize, ToMockBufferMapCallback, this); + + EXPECT_CALL(api, OnBufferMapAsyncCallback(apiBuffer, _, _)).WillOnce(InvokeWithoutArgs([&]() { + api.CallMapAsyncCallback(apiBuffer, WGPUBufferMapAsyncStatus_Success); + })); + EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, kBufferSize)).Times(1); + + FlushClient(); + + EXPECT_CALL(*mockBufferMapCallback, Call(WGPUBufferMapAsyncStatus_DeviceLost, this)).Times(1); + GetWireClient()->Disconnect(); +} diff --git a/src/tests/unittests/wire/WireCreateReadyPipelineTests.cpp b/src/tests/unittests/wire/WireCreateReadyPipelineTests.cpp index 98fb1f8d05..fbe84d9271 100644 --- a/src/tests/unittests/wire/WireCreateReadyPipelineTests.cpp +++ b/src/tests/unittests/wire/WireCreateReadyPipelineTests.cpp @@ -14,6 +14,8 @@ #include "tests/unittests/wire/WireTest.h" +#include "dawn_wire/WireClient.h" + using namespace testing; using namespace dawn_wire; @@ -213,3 +215,64 @@ TEST_F(WireCreateReadyPipelineTest, CreateReadyRenderPipelineError) { FlushServer(); } + +// Test that registering a callback then wire disconnect calls the callback with +// DeviceLost. +TEST_F(WireCreateReadyPipelineTest, CreateReadyRenderPipelineThenDisconnect) { + WGPUShaderModuleDescriptor vertexDescriptor = {}; + WGPUShaderModule vsModule = wgpuDeviceCreateShaderModule(device, &vertexDescriptor); + WGPUShaderModule apiVsModule = api.GetNewShaderModule(); + EXPECT_CALL(api, DeviceCreateShaderModule(apiDevice, _)).WillOnce(Return(apiVsModule)); + + WGPUProgrammableStageDescriptor fragmentStage = {}; + fragmentStage.module = vsModule; + fragmentStage.entryPoint = "main"; + + WGPURenderPipelineDescriptor pipelineDescriptor{}; + pipelineDescriptor.vertexStage.module = vsModule; + pipelineDescriptor.vertexStage.entryPoint = "main"; + pipelineDescriptor.fragmentStage = &fragmentStage; + + wgpuDeviceCreateReadyRenderPipeline(device, &pipelineDescriptor, + ToMockCreateReadyRenderPipelineCallback, this); + EXPECT_CALL(api, OnDeviceCreateReadyRenderPipelineCallback(apiDevice, _, _, _)) + .WillOnce(InvokeWithoutArgs([&]() { + api.CallDeviceCreateReadyRenderPipelineCallback( + apiDevice, WGPUCreateReadyPipelineStatus_Success, nullptr, ""); + })); + + FlushClient(); + + EXPECT_CALL(*mockCreateReadyRenderPipelineCallback, + Call(WGPUCreateReadyPipelineStatus_DeviceLost, _, _, this)) + .Times(1); + GetWireClient()->Disconnect(); +} + +// Test that registering a callback then wire disconnect calls the callback with +// DeviceLost. +TEST_F(WireCreateReadyPipelineTest, CreateReadyComputePipelineThenDisconnect) { + WGPUShaderModuleDescriptor csDescriptor{}; + WGPUShaderModule csModule = wgpuDeviceCreateShaderModule(device, &csDescriptor); + WGPUShaderModule apiCsModule = api.GetNewShaderModule(); + EXPECT_CALL(api, DeviceCreateShaderModule(apiDevice, _)).WillOnce(Return(apiCsModule)); + + WGPUComputePipelineDescriptor descriptor{}; + descriptor.computeStage.module = csModule; + descriptor.computeStage.entryPoint = "main"; + + wgpuDeviceCreateReadyComputePipeline(device, &descriptor, + ToMockCreateReadyComputePipelineCallback, this); + EXPECT_CALL(api, OnDeviceCreateReadyComputePipelineCallback(apiDevice, _, _, _)) + .WillOnce(InvokeWithoutArgs([&]() { + api.CallDeviceCreateReadyComputePipelineCallback( + apiDevice, WGPUCreateReadyPipelineStatus_Success, nullptr, ""); + })); + + FlushClient(); + + EXPECT_CALL(*mockCreateReadyComputePipelineCallback, + Call(WGPUCreateReadyPipelineStatus_DeviceLost, _, _, this)) + .Times(1); + GetWireClient()->Disconnect(); +} diff --git a/src/tests/unittests/wire/WireErrorCallbackTests.cpp b/src/tests/unittests/wire/WireErrorCallbackTests.cpp index 3fe25f0171..4d5d1dee63 100644 --- a/src/tests/unittests/wire/WireErrorCallbackTests.cpp +++ b/src/tests/unittests/wire/WireErrorCallbackTests.cpp @@ -14,6 +14,8 @@ #include "tests/unittests/wire/WireTest.h" +#include "dawn_wire/WireClient.h" + using namespace testing; using namespace dawn_wire; @@ -216,6 +218,23 @@ TEST_F(WireErrorCallbackTests, PopErrorScopeDeviceDestroyed) { .Times(1); } +// Test that registering a callback then wire disconnect calls the callback with +// DeviceLost. +TEST_F(WireErrorCallbackTests, PopErrorScopeThenDisconnect) { + wgpuDevicePushErrorScope(device, WGPUErrorFilter_Validation); + EXPECT_CALL(api, DevicePushErrorScope(apiDevice, WGPUErrorFilter_Validation)).Times(1); + + EXPECT_TRUE(wgpuDevicePopErrorScope(device, ToMockDevicePopErrorScopeCallback, this)); + EXPECT_CALL(api, OnDevicePopErrorScopeCallback(apiDevice, _, _)).WillOnce(Return(true)); + + FlushClient(); + + EXPECT_CALL(*mockDevicePopErrorScopeCallback, + Call(WGPUErrorType_DeviceLost, ValidStringMessage(), this)) + .Times(1); + GetWireClient()->Disconnect(); +} + // Test that PopErrorScope returns false if there are no error scopes. TEST_F(WireErrorCallbackTests, PopErrorScopeEmptyStack) { // Empty stack diff --git a/src/tests/unittests/wire/WireFenceTests.cpp b/src/tests/unittests/wire/WireFenceTests.cpp index 29ae5ee5fa..22837d0ae4 100644 --- a/src/tests/unittests/wire/WireFenceTests.cpp +++ b/src/tests/unittests/wire/WireFenceTests.cpp @@ -14,6 +14,8 @@ #include "tests/unittests/wire/WireTest.h" +#include "dawn_wire/WireClient.h" + using namespace testing; using namespace dawn_wire; @@ -143,6 +145,21 @@ TEST_F(WireFenceTests, OnCompletionError) { FlushServer(); } +// Test that registering a callback then wire disconnect calls the callback with +// DeviceLost. +TEST_F(WireFenceTests, OnCompletionThenDisconnect) { + wgpuFenceOnCompletion(fence, 0, ToMockFenceOnCompletionCallback, this); + EXPECT_CALL(api, OnFenceOnCompletionCallback(apiFence, 0u, _, _)) + .WillOnce(InvokeWithoutArgs([&]() { + api.CallFenceOnCompletionCallback(apiFence, WGPUFenceCompletionStatus_Success); + })); + FlushClient(); + + EXPECT_CALL(*mockFenceOnCompletionCallback, Call(WGPUFenceCompletionStatus_DeviceLost, this)) + .Times(1); + GetWireClient()->Disconnect(); +} + // Without any flushes, it is valid to wait on a value less than or equal to // the last signaled value TEST_F(WireFenceTests, OnCompletionSynchronousValidationSuccess) {