From 93bea5cb50e30f115088be752145df88562bfa25 Mon Sep 17 00:00:00 2001 From: Austin Eng Date: Wed, 15 Apr 2020 02:00:14 +0000 Subject: [PATCH] Add WireClient::Disconnect to stop serializing commands This prevents the client from continuing to send commands when the wire connection has dropped. In Chromium this may be because the connection to the GPU process is lost and the transfer buffer may be destroyed. This CL also adds a new helper to make testing callbacks with mocks easier. Bug: chromium:1070392 Change-Id: I6a69c32cc506069554ead18ee83a156ca70e2ce2 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/19160 Commit-Queue: Austin Eng Reviewed-by: Kai Ninomiya --- src/dawn_wire/WireClient.cpp | 4 + src/dawn_wire/client/Client.cpp | 19 +++ src/dawn_wire/client/Client.h | 17 ++- src/dawn_wire/client/Device.cpp | 3 +- src/dawn_wire/client/Device.h | 1 + src/include/dawn_wire/WireClient.h | 4 + src/tests/BUILD.gn | 2 + src/tests/MockCallback.h | 101 ++++++++++++++ .../unittests/wire/WireDisconnectTests.cpp | 128 ++++++++++++++++++ 9 files changed, 271 insertions(+), 8 deletions(-) create mode 100644 src/tests/MockCallback.h create mode 100644 src/tests/unittests/wire/WireDisconnectTests.cpp diff --git a/src/dawn_wire/WireClient.cpp b/src/dawn_wire/WireClient.cpp index 972d6decd8..e6fe263767 100644 --- a/src/dawn_wire/WireClient.cpp +++ b/src/dawn_wire/WireClient.cpp @@ -42,6 +42,10 @@ namespace dawn_wire { return mImpl->ReserveTexture(device); } + void WireClient::Disconnect() { + mImpl->Disconnect(); + } + namespace client { MemoryTransferService::~MemoryTransferService() = default; diff --git a/src/dawn_wire/client/Client.cpp b/src/dawn_wire/client/Client.cpp index 71c1a5f63a..1214e3824f 100644 --- a/src/dawn_wire/client/Client.cpp +++ b/src/dawn_wire/client/Client.cpp @@ -13,6 +13,8 @@ // limitations under the License. #include "dawn_wire/client/Client.h" + +#include "common/Compiler.h" #include "dawn_wire/client/Device.h" namespace dawn_wire { namespace client { @@ -44,4 +46,21 @@ namespace dawn_wire { namespace client { return result; } + void* Client::GetCmdSpace(size_t size) { + if (DAWN_UNLIKELY(mIsDisconnected)) { + if (size > mDummyCmdSpace.size()) { + mDummyCmdSpace.resize(size); + } + return mDummyCmdSpace.data(); + } + return mSerializer->GetCmdSpace(size); + } + + void Client::Disconnect() { + if (!mIsDisconnected) { + mIsDisconnected = true; + mDevice->HandleDeviceLost("GPU connection lost"); + } + } + }} // namespace dawn_wire::client diff --git a/src/dawn_wire/client/Client.h b/src/dawn_wire/client/Client.h index f7d311f3c9..6a8e4b844c 100644 --- a/src/dawn_wire/client/Client.h +++ b/src/dawn_wire/client/Client.h @@ -33,13 +33,6 @@ namespace dawn_wire { namespace client { Client(CommandSerializer* serializer, MemoryTransferService* memoryTransferService); ~Client(); - const volatile char* HandleCommands(const volatile char* commands, size_t size); - ReservedTexture ReserveTexture(WGPUDevice device); - - void* GetCmdSpace(size_t size) { - return mSerializer->GetCmdSpace(size); - } - WGPUDevice GetDevice() const { return reinterpret_cast(mDevice); } @@ -48,6 +41,13 @@ namespace dawn_wire { namespace client { return mMemoryTransferService; } + const volatile char* HandleCommands(const volatile char* commands, size_t size); + ReservedTexture ReserveTexture(WGPUDevice device); + + void* GetCmdSpace(size_t size); + + void Disconnect(); + private: #include "dawn_wire/client/ClientPrototypes_autogen.inc" @@ -56,6 +56,9 @@ namespace dawn_wire { namespace client { WireDeserializeAllocator mAllocator; MemoryTransferService* mMemoryTransferService = nullptr; std::unique_ptr mOwnedMemoryTransferService = nullptr; + + std::vector mDummyCmdSpace; + bool mIsDisconnected = false; }; DawnProcTable GetProcs(); diff --git a/src/dawn_wire/client/Device.cpp b/src/dawn_wire/client/Device.cpp index 7d62b06e98..67021ef046 100644 --- a/src/dawn_wire/client/Device.cpp +++ b/src/dawn_wire/client/Device.cpp @@ -43,7 +43,8 @@ namespace dawn_wire { namespace client { } void Device::HandleDeviceLost(const char* message) { - if (mDeviceLostCallback) { + if (mDeviceLostCallback && !mDidRunLostCallback) { + mDidRunLostCallback = true; mDeviceLostCallback(message, mDeviceLostUserdata); } } diff --git a/src/dawn_wire/client/Device.h b/src/dawn_wire/client/Device.h index af5934e825..e70cae2b11 100644 --- a/src/dawn_wire/client/Device.h +++ b/src/dawn_wire/client/Device.h @@ -52,6 +52,7 @@ namespace dawn_wire { namespace client { Client* mClient = nullptr; WGPUErrorCallback mErrorCallback = nullptr; WGPUDeviceLostCallback mDeviceLostCallback = nullptr; + bool mDidRunLostCallback = false; void* mErrorUserdata; void* mDeviceLostUserdata; }; diff --git a/src/include/dawn_wire/WireClient.h b/src/include/dawn_wire/WireClient.h index 7c8ee408c6..5b5f33c2ac 100644 --- a/src/include/dawn_wire/WireClient.h +++ b/src/include/dawn_wire/WireClient.h @@ -52,6 +52,10 @@ namespace dawn_wire { ReservedTexture ReserveTexture(WGPUDevice device); + // Disconnects the client. + // Commands allocated after this point will not be sent. + void Disconnect(); + private: std::unique_ptr mImpl; }; diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn index 577db233be..cdc2578806 100644 --- a/src/tests/BUILD.gn +++ b/src/tests/BUILD.gn @@ -129,6 +129,7 @@ test("dawn_unittests") { "${dawn_root}/src/dawn_wire/client/ClientMemoryTransferService_mock.h", "${dawn_root}/src/dawn_wire/server/ServerMemoryTransferService_mock.cpp", "${dawn_root}/src/dawn_wire/server/ServerMemoryTransferService_mock.h", + "MockCallback.h", "unittests/BitSetIteratorTests.cpp", "unittests/BuddyAllocatorTests.cpp", "unittests/BuddyMemoryAllocatorTests.cpp", @@ -181,6 +182,7 @@ test("dawn_unittests") { "unittests/wire/WireArgumentTests.cpp", "unittests/wire/WireBasicTests.cpp", "unittests/wire/WireBufferMappingTests.cpp", + "unittests/wire/WireDisconnectTests.cpp", "unittests/wire/WireErrorCallbackTests.cpp", "unittests/wire/WireExtensionTests.cpp", "unittests/wire/WireFenceTests.cpp", diff --git a/src/tests/MockCallback.h b/src/tests/MockCallback.h new file mode 100644 index 0000000000..c3dfb4ab90 --- /dev/null +++ b/src/tests/MockCallback.h @@ -0,0 +1,101 @@ +// Copyright 2020 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "common/Assert.h" + +#include +#include + +namespace testing { + + template + class MockCallback; + + // Helper class for mocking callbacks used for Dawn callbacks with |void* userdata| + // as the last callback argument. + // + // Example Usage: + // MockCallback mock; + // + // void* foo = XYZ; // this is the callback userdata + // + // wgpuDeviceSetDeviceLostCallback(device, mock.Callback(), mock.MakeUserdata(foo)); + // EXPECT_CALL(mock, Call(_, foo)); + template + class MockCallback : public ::testing::MockFunction { + using CallbackType = R (*)(Args...); + + public: + // Helper function makes it easier to get the callback using |foo.Callback()| + // unstead of MockCallback::Callback. + static CallbackType Callback() { + return CallUnboundCallback; + } + + void* MakeUserdata(void* userdata) { + auto mockAndUserdata = + std::unique_ptr(new MockAndUserdata{this, userdata}); + + // Add the userdata to a set of userdata for this mock. We never + // remove from this set even if a callback should only be called once so that + // repeated calls to the callback still forward the userdata correctly. + // Userdata will be destroyed when the mock is destroyed. + auto it = mUserdatas.insert(std::move(mockAndUserdata)); + ASSERT(it.second); + return it.first->get(); + } + + private: + struct MockAndUserdata { + MockCallback* mock; + void* userdata; + }; + + static R CallUnboundCallback(Args... args) { + std::tuple tuple = std::make_tuple(args...); + + constexpr size_t ArgC = sizeof...(Args); + static_assert(ArgC >= 1, "Mock callback requires at least one argument (the userdata)"); + + // Get the userdata. It should be the last argument. + auto userdata = std::get(tuple); + static_assert(std::is_same::value, + "Last callback argument must be void* userdata"); + + // Extract the mock. + ASSERT(userdata != nullptr); + auto* mockAndUserdata = reinterpret_cast(userdata); + MockCallback* mock = mockAndUserdata->mock; + ASSERT(mock != nullptr); + + // Replace the userdata + std::get(tuple) = mockAndUserdata->userdata; + + // Forward the callback to the mock. + return mock->CallImpl(std::make_index_sequence{}, std::move(tuple)); + } + + // This helper cannot be inlined because we dependent on the templated index sequence + // to unpack the tuple arguments. + template + R CallImpl(const std::index_sequence&, std::tuple args) { + return this->Call(std::get(args)...); + } + + std::set> mUserdatas; + }; + +} // namespace testing diff --git a/src/tests/unittests/wire/WireDisconnectTests.cpp b/src/tests/unittests/wire/WireDisconnectTests.cpp new file mode 100644 index 0000000000..4e9b355f8c --- /dev/null +++ b/src/tests/unittests/wire/WireDisconnectTests.cpp @@ -0,0 +1,128 @@ +// Copyright 2020 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tests/unittests/wire/WireTest.h" + +#include "common/Assert.h" +#include "dawn_wire/WireClient.h" +#include "tests/MockCallback.h" + +using namespace testing; +using namespace dawn_wire; + +namespace { + + class WireDisconnectTests : public WireTest {}; + +} // anonymous namespace + +// Test that commands are not received if the client disconnects. +TEST_F(WireDisconnectTests, CommandsAfterDisconnect) { + // Sanity check that commands work at all. + wgpuDeviceCreateCommandEncoder(device, nullptr); + + WGPUCommandEncoder apiCmdBufEncoder = api.GetNewCommandEncoder(); + EXPECT_CALL(api, DeviceCreateCommandEncoder(apiDevice, nullptr)) + .WillOnce(Return(apiCmdBufEncoder)); + FlushClient(); + + // Disconnect. + GetWireClient()->Disconnect(); + + // Command is not received because client disconnected. + wgpuDeviceCreateCommandEncoder(device, nullptr); + EXPECT_CALL(api, DeviceCreateCommandEncoder(_, _)).Times(Exactly(0)); + FlushClient(); +} + +// Test that commands that are serialized before a disconnect but flushed +// after are received. +TEST_F(WireDisconnectTests, FlushAfterDisconnect) { + // Sanity check that commands work at all. + wgpuDeviceCreateCommandEncoder(device, nullptr); + + // Disconnect. + GetWireClient()->Disconnect(); + + // Already-serialized commmands are still received. + WGPUCommandEncoder apiCmdBufEncoder = api.GetNewCommandEncoder(); + EXPECT_CALL(api, DeviceCreateCommandEncoder(apiDevice, nullptr)) + .WillOnce(Return(apiCmdBufEncoder)); + FlushClient(); +} + +// Check that disconnecting the wire client calls the device lost callback exacty once. +TEST_F(WireDisconnectTests, CallsDeviceLostCallback) { + MockCallback mockDeviceLostCallback; + wgpuDeviceSetDeviceLostCallback(device, mockDeviceLostCallback.Callback(), + mockDeviceLostCallback.MakeUserdata(this)); + + // Disconnect the wire client. We should receive device lost only once. + EXPECT_CALL(mockDeviceLostCallback, Call(_, this)).Times(Exactly(1)); + GetWireClient()->Disconnect(); + GetWireClient()->Disconnect(); +} + +// Check that disconnecting the wire client after a device loss does not trigger the callback again. +TEST_F(WireDisconnectTests, ServerLostThenDisconnect) { + MockCallback mockDeviceLostCallback; + wgpuDeviceSetDeviceLostCallback(device, mockDeviceLostCallback.Callback(), + mockDeviceLostCallback.MakeUserdata(this)); + + api.CallDeviceLostCallback(apiDevice, "some reason"); + + // Flush the device lost return command. + EXPECT_CALL(mockDeviceLostCallback, Call(StrEq("some reason"), this)).Times(Exactly(1)); + FlushServer(); + + // Disconnect the client. We shouldn't see the lost callback again. + EXPECT_CALL(mockDeviceLostCallback, Call(_, _)).Times(Exactly(0)); + GetWireClient()->Disconnect(); +} + +// Check that disconnecting the wire client inside the device loss callback does not trigger the +// callback again. +TEST_F(WireDisconnectTests, ServerLostThenDisconnectInCallback) { + MockCallback mockDeviceLostCallback; + wgpuDeviceSetDeviceLostCallback(device, mockDeviceLostCallback.Callback(), + mockDeviceLostCallback.MakeUserdata(this)); + + api.CallDeviceLostCallback(apiDevice, "lost reason"); + + // Disconnect the client inside the lost callback. We should see the callback + // only once. + EXPECT_CALL(mockDeviceLostCallback, Call(StrEq("lost reason"), this)) + .WillOnce(InvokeWithoutArgs([&]() { + EXPECT_CALL(mockDeviceLostCallback, Call(_, _)).Times(Exactly(0)); + GetWireClient()->Disconnect(); + })); + FlushServer(); +} + +// Check that a device loss after a disconnect does not trigger the callback again. +TEST_F(WireDisconnectTests, DisconnectThenServerLost) { + MockCallback mockDeviceLostCallback; + wgpuDeviceSetDeviceLostCallback(device, mockDeviceLostCallback.Callback(), + mockDeviceLostCallback.MakeUserdata(this)); + + // Disconnect the client. We should see the callback once. + EXPECT_CALL(mockDeviceLostCallback, Call(_, this)).Times(Exactly(1)); + GetWireClient()->Disconnect(); + + // Lose the device on the server. The client callback shouldn't be + // called again. + api.CallDeviceLostCallback(apiDevice, "lost reason"); + EXPECT_CALL(mockDeviceLostCallback, Call(_, _)).Times(Exactly(0)); + FlushServer(); +}