Reject all callbacks with DeviceLost on wire client disconnect

When the wire is disconnected, the client will not receive any
messages from the server. We need to manually reject all callbacks.

Bug: dawn:556
Change-Id: Ia03456b3209dbe0e1e54543d344180d11d4c6f1e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/31162
Commit-Queue: Austin Eng <enga@chromium.org>
Reviewed-by: Stephen White <senorblanco@chromium.org>
This commit is contained in:
Austin Eng 2020-11-11 21:01:18 +00:00 committed by Commit Bot service account
parent 3120d5ea0d
commit b70a5b02e9
14 changed files with 180 additions and 5 deletions

View File

@ -28,7 +28,7 @@ namespace dawn_wire { namespace client {
{% if type.name.CamelCase() in client_special_objects %} {% if type.name.CamelCase() in client_special_objects %}
class {{Type}}; class {{Type}};
{% else %} {% else %}
struct {{type.name.CamelCase()}} : ObjectBase { struct {{type.name.CamelCase()}} final : ObjectBase {
using ObjectBase::ObjectBase; using ObjectBase::ObjectBase;
}; };
{% endif %} {% endif %}

View File

@ -110,6 +110,15 @@ namespace dawn_wire { namespace client {
mRequests.clear(); mRequests.clear();
} }
void Buffer::CancelCallbacksForDisconnect() {
for (auto& it : mRequests) {
if (it.second.callback) {
it.second.callback(WGPUBufferMapAsyncStatus_DeviceLost, it.second.userdata);
}
}
mRequests.clear();
}
void Buffer::MapAsync(WGPUMapModeFlags mode, void Buffer::MapAsync(WGPUMapModeFlags mode,
size_t offset, size_t offset,
size_t size, size_t size,

View File

@ -24,7 +24,7 @@
namespace dawn_wire { namespace client { namespace dawn_wire { namespace client {
class Buffer : public ObjectBase { class Buffer final : public ObjectBase {
public: public:
using ObjectBase::ObjectBase; using ObjectBase::ObjectBase;
@ -49,6 +49,8 @@ namespace dawn_wire { namespace client {
void Destroy(); void Destroy();
private: private:
void CancelCallbacksForDisconnect() override;
bool IsMappedForReading() const; bool IsMappedForReading() const;
bool IsMappedForWriting() const; bool IsMappedForWriting() const;
bool CheckGetMappedRangeOffsetSize(size_t offset, size_t size) const; bool CheckGetMappedRangeOffsetSize(size_t offset, size_t size) const;

View File

@ -86,6 +86,7 @@ namespace dawn_wire { namespace client {
mSerializer = ChunkedCommandSerializer(NoopCommandSerializer::GetInstance()); mSerializer = ChunkedCommandSerializer(NoopCommandSerializer::GetInstance());
if (mDevice != nullptr) { if (mDevice != nullptr) {
mDevice->HandleDeviceLost("GPU connection lost"); mDevice->HandleDeviceLost("GPU connection lost");
mDevice->CancelCallbacksForDisconnect();
} }
} }

View File

@ -91,6 +91,36 @@ namespace dawn_wire { namespace client {
} }
} }
void Device::CancelCallbacksForDisconnect() {
for (auto& it : mCreateReadyPipelineRequests) {
ASSERT((it.second.createReadyComputePipelineCallback != nullptr) ^
(it.second.createReadyRenderPipelineCallback != nullptr));
if (it.second.createReadyRenderPipelineCallback) {
it.second.createReadyRenderPipelineCallback(
WGPUCreateReadyPipelineStatus_DeviceLost, nullptr, "Device lost",
it.second.userdata);
} else {
it.second.createReadyComputePipelineCallback(
WGPUCreateReadyPipelineStatus_DeviceLost, nullptr, "Device lost",
it.second.userdata);
}
}
mCreateReadyPipelineRequests.clear();
for (auto& it : mErrorScopes) {
it.second.callback(WGPUErrorType_DeviceLost, "Device lost", it.second.userdata);
}
mErrorScopes.clear();
for (auto& objectList : mObjects) {
LinkNode<ObjectBase>* object = objectList.head();
while (object != objectList.end()) {
object->value()->CancelCallbacksForDisconnect();
object = object->next();
}
}
}
void Device::SetUncapturedErrorCallback(WGPUErrorCallback errorCallback, void* errorUserdata) { void Device::SetUncapturedErrorCallback(WGPUErrorCallback errorCallback, void* errorUserdata) {
mErrorCallback = errorCallback; mErrorCallback = errorCallback;
mErrorUserdata = errorUserdata; mErrorUserdata = errorUserdata;

View File

@ -29,7 +29,7 @@ namespace dawn_wire { namespace client {
class Client; class Client;
class Queue; class Queue;
class Device : public ObjectBase { class Device final : public ObjectBase {
public: public:
Device(Client* client, uint32_t refcount, uint32_t id); Device(Client* client, uint32_t refcount, uint32_t id);
~Device(); ~Device();
@ -68,6 +68,8 @@ namespace dawn_wire { namespace client {
mObjects[ObjectTypeToTypeEnum<T>].Append(object); mObjects[ObjectTypeToTypeEnum<T>].Append(object);
} }
void CancelCallbacksForDisconnect() override;
private: private:
void DestroyAllObjects(); void DestroyAllObjects();

View File

@ -30,6 +30,15 @@ namespace dawn_wire { namespace client {
mOnCompletionRequests.clear(); mOnCompletionRequests.clear();
} }
void Fence::CancelCallbacksForDisconnect() {
for (auto& it : mOnCompletionRequests) {
if (it.second.callback) {
it.second.callback(WGPUFenceCompletionStatus_DeviceLost, it.second.userdata);
}
}
mOnCompletionRequests.clear();
}
void Fence::Initialize(Queue* queue, const WGPUFenceDescriptor* descriptor) { void Fence::Initialize(Queue* queue, const WGPUFenceDescriptor* descriptor) {
mQueue = queue; mQueue = queue;

View File

@ -23,7 +23,7 @@
namespace dawn_wire { namespace client { namespace dawn_wire { namespace client {
class Queue; class Queue;
class Fence : public ObjectBase { class Fence final : public ObjectBase {
public: public:
using ObjectBase::ObjectBase; using ObjectBase::ObjectBase;
~Fence(); ~Fence();
@ -38,6 +38,8 @@ namespace dawn_wire { namespace client {
Queue* GetQueue() const; Queue* GetQueue() const;
private: private:
void CancelCallbacksForDisconnect() override;
struct OnCompletionData { struct OnCompletionData {
WGPUFenceOnCompletionCallback callback = nullptr; WGPUFenceOnCompletionCallback callback = nullptr;
void* userdata = nullptr; void* userdata = nullptr;

View File

@ -38,6 +38,9 @@ namespace dawn_wire { namespace client {
RemoveFromList(); RemoveFromList();
} }
virtual void CancelCallbacksForDisconnect() {
}
Device* const device; Device* const device;
uint32_t refcount; uint32_t refcount;
const uint32_t id; const uint32_t id;

View File

@ -24,7 +24,7 @@
namespace dawn_wire { namespace client { namespace dawn_wire { namespace client {
class Queue : public ObjectBase { class Queue final : public ObjectBase {
public: public:
using ObjectBase::ObjectBase; using ObjectBase::ObjectBase;

View File

@ -14,6 +14,8 @@
#include "tests/unittests/wire/WireTest.h" #include "tests/unittests/wire/WireTest.h"
#include "dawn_wire/WireClient.h"
using namespace testing; using namespace testing;
using namespace dawn_wire; using namespace dawn_wire;
@ -674,3 +676,19 @@ TEST_F(WireBufferMappingTests, MaxSizeMappableBufferOOMDirectly) {
FlushClient(); FlushClient();
} }
} }
// Test that registering a callback then wire disconnect calls the callback with
// DeviceLost.
TEST_F(WireBufferMappingTests, MapThenDisconnect) {
wgpuBufferMapAsync(buffer, WGPUMapMode_Write, 0, kBufferSize, ToMockBufferMapCallback, this);
EXPECT_CALL(api, OnBufferMapAsyncCallback(apiBuffer, _, _)).WillOnce(InvokeWithoutArgs([&]() {
api.CallMapAsyncCallback(apiBuffer, WGPUBufferMapAsyncStatus_Success);
}));
EXPECT_CALL(api, BufferGetMappedRange(apiBuffer, 0, kBufferSize)).Times(1);
FlushClient();
EXPECT_CALL(*mockBufferMapCallback, Call(WGPUBufferMapAsyncStatus_DeviceLost, this)).Times(1);
GetWireClient()->Disconnect();
}

View File

@ -14,6 +14,8 @@
#include "tests/unittests/wire/WireTest.h" #include "tests/unittests/wire/WireTest.h"
#include "dawn_wire/WireClient.h"
using namespace testing; using namespace testing;
using namespace dawn_wire; using namespace dawn_wire;
@ -213,3 +215,64 @@ TEST_F(WireCreateReadyPipelineTest, CreateReadyRenderPipelineError) {
FlushServer(); FlushServer();
} }
// Test that registering a callback then wire disconnect calls the callback with
// DeviceLost.
TEST_F(WireCreateReadyPipelineTest, CreateReadyRenderPipelineThenDisconnect) {
WGPUShaderModuleDescriptor vertexDescriptor = {};
WGPUShaderModule vsModule = wgpuDeviceCreateShaderModule(device, &vertexDescriptor);
WGPUShaderModule apiVsModule = api.GetNewShaderModule();
EXPECT_CALL(api, DeviceCreateShaderModule(apiDevice, _)).WillOnce(Return(apiVsModule));
WGPUProgrammableStageDescriptor fragmentStage = {};
fragmentStage.module = vsModule;
fragmentStage.entryPoint = "main";
WGPURenderPipelineDescriptor pipelineDescriptor{};
pipelineDescriptor.vertexStage.module = vsModule;
pipelineDescriptor.vertexStage.entryPoint = "main";
pipelineDescriptor.fragmentStage = &fragmentStage;
wgpuDeviceCreateReadyRenderPipeline(device, &pipelineDescriptor,
ToMockCreateReadyRenderPipelineCallback, this);
EXPECT_CALL(api, OnDeviceCreateReadyRenderPipelineCallback(apiDevice, _, _, _))
.WillOnce(InvokeWithoutArgs([&]() {
api.CallDeviceCreateReadyRenderPipelineCallback(
apiDevice, WGPUCreateReadyPipelineStatus_Success, nullptr, "");
}));
FlushClient();
EXPECT_CALL(*mockCreateReadyRenderPipelineCallback,
Call(WGPUCreateReadyPipelineStatus_DeviceLost, _, _, this))
.Times(1);
GetWireClient()->Disconnect();
}
// Test that registering a callback then wire disconnect calls the callback with
// DeviceLost.
TEST_F(WireCreateReadyPipelineTest, CreateReadyComputePipelineThenDisconnect) {
WGPUShaderModuleDescriptor csDescriptor{};
WGPUShaderModule csModule = wgpuDeviceCreateShaderModule(device, &csDescriptor);
WGPUShaderModule apiCsModule = api.GetNewShaderModule();
EXPECT_CALL(api, DeviceCreateShaderModule(apiDevice, _)).WillOnce(Return(apiCsModule));
WGPUComputePipelineDescriptor descriptor{};
descriptor.computeStage.module = csModule;
descriptor.computeStage.entryPoint = "main";
wgpuDeviceCreateReadyComputePipeline(device, &descriptor,
ToMockCreateReadyComputePipelineCallback, this);
EXPECT_CALL(api, OnDeviceCreateReadyComputePipelineCallback(apiDevice, _, _, _))
.WillOnce(InvokeWithoutArgs([&]() {
api.CallDeviceCreateReadyComputePipelineCallback(
apiDevice, WGPUCreateReadyPipelineStatus_Success, nullptr, "");
}));
FlushClient();
EXPECT_CALL(*mockCreateReadyComputePipelineCallback,
Call(WGPUCreateReadyPipelineStatus_DeviceLost, _, _, this))
.Times(1);
GetWireClient()->Disconnect();
}

View File

@ -14,6 +14,8 @@
#include "tests/unittests/wire/WireTest.h" #include "tests/unittests/wire/WireTest.h"
#include "dawn_wire/WireClient.h"
using namespace testing; using namespace testing;
using namespace dawn_wire; using namespace dawn_wire;
@ -216,6 +218,23 @@ TEST_F(WireErrorCallbackTests, PopErrorScopeDeviceDestroyed) {
.Times(1); .Times(1);
} }
// Test that registering a callback then wire disconnect calls the callback with
// DeviceLost.
TEST_F(WireErrorCallbackTests, PopErrorScopeThenDisconnect) {
wgpuDevicePushErrorScope(device, WGPUErrorFilter_Validation);
EXPECT_CALL(api, DevicePushErrorScope(apiDevice, WGPUErrorFilter_Validation)).Times(1);
EXPECT_TRUE(wgpuDevicePopErrorScope(device, ToMockDevicePopErrorScopeCallback, this));
EXPECT_CALL(api, OnDevicePopErrorScopeCallback(apiDevice, _, _)).WillOnce(Return(true));
FlushClient();
EXPECT_CALL(*mockDevicePopErrorScopeCallback,
Call(WGPUErrorType_DeviceLost, ValidStringMessage(), this))
.Times(1);
GetWireClient()->Disconnect();
}
// Test that PopErrorScope returns false if there are no error scopes. // Test that PopErrorScope returns false if there are no error scopes.
TEST_F(WireErrorCallbackTests, PopErrorScopeEmptyStack) { TEST_F(WireErrorCallbackTests, PopErrorScopeEmptyStack) {
// Empty stack // Empty stack

View File

@ -14,6 +14,8 @@
#include "tests/unittests/wire/WireTest.h" #include "tests/unittests/wire/WireTest.h"
#include "dawn_wire/WireClient.h"
using namespace testing; using namespace testing;
using namespace dawn_wire; using namespace dawn_wire;
@ -143,6 +145,21 @@ TEST_F(WireFenceTests, OnCompletionError) {
FlushServer(); FlushServer();
} }
// Test that registering a callback then wire disconnect calls the callback with
// DeviceLost.
TEST_F(WireFenceTests, OnCompletionThenDisconnect) {
wgpuFenceOnCompletion(fence, 0, ToMockFenceOnCompletionCallback, this);
EXPECT_CALL(api, OnFenceOnCompletionCallback(apiFence, 0u, _, _))
.WillOnce(InvokeWithoutArgs([&]() {
api.CallFenceOnCompletionCallback(apiFence, WGPUFenceCompletionStatus_Success);
}));
FlushClient();
EXPECT_CALL(*mockFenceOnCompletionCallback, Call(WGPUFenceCompletionStatus_DeviceLost, this))
.Times(1);
GetWireClient()->Disconnect();
}
// Without any flushes, it is valid to wait on a value less than or equal to // Without any flushes, it is valid to wait on a value less than or equal to
// the last signaled value // the last signaled value
TEST_F(WireFenceTests, OnCompletionSynchronousValidationSuccess) { TEST_F(WireFenceTests, OnCompletionSynchronousValidationSuccess) {