diff --git a/src/dawn_wire/client/Buffer.cpp b/src/dawn_wire/client/Buffer.cpp index 3276325860..23ca0abe35 100644 --- a/src/dawn_wire/client/Buffer.cpp +++ b/src/dawn_wire/client/Buffer.cpp @@ -124,6 +124,10 @@ namespace dawn_wire { namespace client { size_t size, WGPUBufferMapCallback callback, void* userdata) { + if (device->GetClient()->IsDisconnected()) { + return callback(WGPUBufferMapAsyncStatus_DeviceLost, userdata); + } + // Handle the defaulting of size required by WebGPU. if (size == 0 && offset < mSize) { size = mSize - offset; diff --git a/src/dawn_wire/client/Client.cpp b/src/dawn_wire/client/Client.cpp index e8a754b5ff..0b53bc4fbf 100644 --- a/src/dawn_wire/client/Client.cpp +++ b/src/dawn_wire/client/Client.cpp @@ -83,6 +83,7 @@ namespace dawn_wire { namespace client { } void Client::Disconnect() { + mDisconnected = true; mSerializer = ChunkedCommandSerializer(NoopCommandSerializer::GetInstance()); if (mDevice != nullptr) { mDevice->HandleDeviceLost("GPU connection lost"); @@ -94,4 +95,8 @@ namespace dawn_wire { namespace client { mDevices.Append(device); } + bool Client::IsDisconnected() const { + return mDisconnected; + } + }} // namespace dawn_wire::client diff --git a/src/dawn_wire/client/Client.h b/src/dawn_wire/client/Client.h index ecdae95c31..9a44cc394f 100644 --- a/src/dawn_wire/client/Client.h +++ b/src/dawn_wire/client/Client.h @@ -60,6 +60,7 @@ namespace dawn_wire { namespace client { } void Disconnect(); + bool IsDisconnected() const; void TrackObject(Device* device); @@ -75,6 +76,7 @@ namespace dawn_wire { namespace client { std::unique_ptr mOwnedMemoryTransferService = nullptr; LinkedList mDevices; + bool mDisconnected = false; }; std::unique_ptr CreateInlineMemoryTransferService(); diff --git a/src/dawn_wire/client/Device.cpp b/src/dawn_wire/client/Device.cpp index 178857e014..fbaa05ae79 100644 --- a/src/dawn_wire/client/Device.cpp +++ b/src/dawn_wire/client/Device.cpp @@ -147,6 +147,11 @@ namespace dawn_wire { namespace client { } mErrorScopeStackSize--; + if (GetClient()->IsDisconnected()) { + callback(WGPUErrorType_DeviceLost, "GPU device disconnected", userdata); + return true; + } + uint64_t serial = mErrorScopeRequestSerial++; ASSERT(mErrorScopes.find(serial) == mErrorScopes.end()); @@ -211,6 +216,11 @@ namespace dawn_wire { namespace client { void Device::CreateReadyComputePipeline(WGPUComputePipelineDescriptor const* descriptor, WGPUCreateReadyComputePipelineCallback callback, void* userdata) { + if (device->GetClient()->IsDisconnected()) { + return callback(WGPUCreateReadyPipelineStatus_DeviceLost, nullptr, + "GPU device disconnected", userdata); + } + DeviceCreateReadyComputePipelineCmd cmd; cmd.device = ToAPI(this); cmd.descriptor = descriptor; @@ -262,6 +272,10 @@ namespace dawn_wire { namespace client { void Device::CreateReadyRenderPipeline(WGPURenderPipelineDescriptor const* descriptor, WGPUCreateReadyRenderPipelineCallback callback, void* userdata) { + if (GetClient()->IsDisconnected()) { + return callback(WGPUCreateReadyPipelineStatus_DeviceLost, nullptr, + "GPU device disconnected", userdata); + } DeviceCreateReadyRenderPipelineCmd cmd; cmd.device = ToAPI(this); cmd.descriptor = descriptor; diff --git a/src/dawn_wire/client/Fence.cpp b/src/dawn_wire/client/Fence.cpp index e2a8e6de60..a11d4e9e4a 100644 --- a/src/dawn_wire/client/Fence.cpp +++ b/src/dawn_wire/client/Fence.cpp @@ -48,6 +48,10 @@ namespace dawn_wire { namespace client { void Fence::OnCompletion(uint64_t value, WGPUFenceOnCompletionCallback callback, void* userdata) { + if (device->GetClient()->IsDisconnected()) { + return callback(WGPUFenceCompletionStatus_DeviceLost, userdata); + } + uint32_t serial = mOnCompletionRequestSerial++; ASSERT(mOnCompletionRequests.find(serial) == mOnCompletionRequests.end()); diff --git a/src/tests/unittests/wire/WireBufferMappingTests.cpp b/src/tests/unittests/wire/WireBufferMappingTests.cpp index 6e3e06bf83..f8e135aeee 100644 --- a/src/tests/unittests/wire/WireBufferMappingTests.cpp +++ b/src/tests/unittests/wire/WireBufferMappingTests.cpp @@ -692,3 +692,12 @@ TEST_F(WireBufferMappingTests, MapThenDisconnect) { EXPECT_CALL(*mockBufferMapCallback, Call(WGPUBufferMapAsyncStatus_DeviceLost, this)).Times(1); GetWireClient()->Disconnect(); } + +// Test that registering a callback after wire disconnect calls the callback with +// DeviceLost. +TEST_F(WireBufferMappingTests, MapAfterDisconnect) { + GetWireClient()->Disconnect(); + + EXPECT_CALL(*mockBufferMapCallback, Call(WGPUBufferMapAsyncStatus_DeviceLost, this)).Times(1); + wgpuBufferMapAsync(buffer, WGPUMapMode_Read, 0, kBufferSize, ToMockBufferMapCallback, this); +} diff --git a/src/tests/unittests/wire/WireCreateReadyPipelineTests.cpp b/src/tests/unittests/wire/WireCreateReadyPipelineTests.cpp index fbe84d9271..654996f8a2 100644 --- a/src/tests/unittests/wire/WireCreateReadyPipelineTests.cpp +++ b/src/tests/unittests/wire/WireCreateReadyPipelineTests.cpp @@ -276,3 +276,55 @@ TEST_F(WireCreateReadyPipelineTest, CreateReadyComputePipelineThenDisconnect) { .Times(1); GetWireClient()->Disconnect(); } + +// Test that registering a callback after wire disconnect calls the callback with +// DeviceLost. +TEST_F(WireCreateReadyPipelineTest, CreateReadyRenderPipelineAfterDisconnect) { + WGPUShaderModuleDescriptor vertexDescriptor = {}; + WGPUShaderModule vsModule = wgpuDeviceCreateShaderModule(device, &vertexDescriptor); + WGPUShaderModule apiVsModule = api.GetNewShaderModule(); + EXPECT_CALL(api, DeviceCreateShaderModule(apiDevice, _)).WillOnce(Return(apiVsModule)); + + WGPUProgrammableStageDescriptor fragmentStage = {}; + fragmentStage.module = vsModule; + fragmentStage.entryPoint = "main"; + + WGPURenderPipelineDescriptor pipelineDescriptor{}; + pipelineDescriptor.vertexStage.module = vsModule; + pipelineDescriptor.vertexStage.entryPoint = "main"; + pipelineDescriptor.fragmentStage = &fragmentStage; + + FlushClient(); + + GetWireClient()->Disconnect(); + + EXPECT_CALL(*mockCreateReadyRenderPipelineCallback, + Call(WGPUCreateReadyPipelineStatus_DeviceLost, nullptr, _, this)) + .Times(1); + wgpuDeviceCreateReadyRenderPipeline(device, &pipelineDescriptor, + ToMockCreateReadyRenderPipelineCallback, this); +} + +// Test that registering a callback after wire disconnect calls the callback with +// DeviceLost. +TEST_F(WireCreateReadyPipelineTest, CreateReadyComputePipelineAfterDisconnect) { + WGPUShaderModuleDescriptor csDescriptor{}; + WGPUShaderModule csModule = wgpuDeviceCreateShaderModule(device, &csDescriptor); + WGPUShaderModule apiCsModule = api.GetNewShaderModule(); + EXPECT_CALL(api, DeviceCreateShaderModule(apiDevice, _)).WillOnce(Return(apiCsModule)); + + WGPUComputePipelineDescriptor descriptor{}; + descriptor.computeStage.module = csModule; + descriptor.computeStage.entryPoint = "main"; + + FlushClient(); + + GetWireClient()->Disconnect(); + + EXPECT_CALL(*mockCreateReadyComputePipelineCallback, + Call(WGPUCreateReadyPipelineStatus_DeviceLost, nullptr, _, this)) + .Times(1); + + wgpuDeviceCreateReadyComputePipeline(device, &descriptor, + ToMockCreateReadyComputePipelineCallback, this); +} diff --git a/src/tests/unittests/wire/WireErrorCallbackTests.cpp b/src/tests/unittests/wire/WireErrorCallbackTests.cpp index 4d5d1dee63..be4782d734 100644 --- a/src/tests/unittests/wire/WireErrorCallbackTests.cpp +++ b/src/tests/unittests/wire/WireErrorCallbackTests.cpp @@ -235,6 +235,22 @@ TEST_F(WireErrorCallbackTests, PopErrorScopeThenDisconnect) { GetWireClient()->Disconnect(); } +// Test that registering a callback after wire disconnect calls the callback with +// DeviceLost. +TEST_F(WireErrorCallbackTests, PopErrorScopeAfterDisconnect) { + wgpuDevicePushErrorScope(device, WGPUErrorFilter_Validation); + EXPECT_CALL(api, DevicePushErrorScope(apiDevice, WGPUErrorFilter_Validation)).Times(1); + + FlushClient(); + + GetWireClient()->Disconnect(); + + EXPECT_CALL(*mockDevicePopErrorScopeCallback, + Call(WGPUErrorType_DeviceLost, ValidStringMessage(), this)) + .Times(1); + EXPECT_TRUE(wgpuDevicePopErrorScope(device, ToMockDevicePopErrorScopeCallback, this)); +} + // Test that PopErrorScope returns false if there are no error scopes. TEST_F(WireErrorCallbackTests, PopErrorScopeEmptyStack) { // Empty stack diff --git a/src/tests/unittests/wire/WireFenceTests.cpp b/src/tests/unittests/wire/WireFenceTests.cpp index 22837d0ae4..63ace923f3 100644 --- a/src/tests/unittests/wire/WireFenceTests.cpp +++ b/src/tests/unittests/wire/WireFenceTests.cpp @@ -160,6 +160,16 @@ TEST_F(WireFenceTests, OnCompletionThenDisconnect) { GetWireClient()->Disconnect(); } +// Test that registering a callback after wire disconnect calls the callback with +// DeviceLost. +TEST_F(WireFenceTests, OnCompletionAfterDisconnect) { + GetWireClient()->Disconnect(); + + EXPECT_CALL(*mockFenceOnCompletionCallback, Call(WGPUFenceCompletionStatus_DeviceLost, this)) + .Times(1); + wgpuFenceOnCompletion(fence, 0, ToMockFenceOnCompletionCallback, this); +} + // Without any flushes, it is valid to wait on a value less than or equal to // the last signaled value TEST_F(WireFenceTests, OnCompletionSynchronousValidationSuccess) {