diff --git a/src/dawn_wire/WireClient.cpp b/src/dawn_wire/WireClient.cpp index 972d6decd8..e6fe263767 100644 --- a/src/dawn_wire/WireClient.cpp +++ b/src/dawn_wire/WireClient.cpp @@ -42,6 +42,10 @@ namespace dawn_wire { return mImpl->ReserveTexture(device); } + void WireClient::Disconnect() { + mImpl->Disconnect(); + } + namespace client { MemoryTransferService::~MemoryTransferService() = default; diff --git a/src/dawn_wire/client/Client.cpp b/src/dawn_wire/client/Client.cpp index 71c1a5f63a..1214e3824f 100644 --- a/src/dawn_wire/client/Client.cpp +++ b/src/dawn_wire/client/Client.cpp @@ -13,6 +13,8 @@ // limitations under the License. #include "dawn_wire/client/Client.h" + +#include "common/Compiler.h" #include "dawn_wire/client/Device.h" namespace dawn_wire { namespace client { @@ -44,4 +46,21 @@ namespace dawn_wire { namespace client { return result; } + void* Client::GetCmdSpace(size_t size) { + if (DAWN_UNLIKELY(mIsDisconnected)) { + if (size > mDummyCmdSpace.size()) { + mDummyCmdSpace.resize(size); + } + return mDummyCmdSpace.data(); + } + return mSerializer->GetCmdSpace(size); + } + + void Client::Disconnect() { + if (!mIsDisconnected) { + mIsDisconnected = true; + mDevice->HandleDeviceLost("GPU connection lost"); + } + } + }} // namespace dawn_wire::client diff --git a/src/dawn_wire/client/Client.h b/src/dawn_wire/client/Client.h index f7d311f3c9..6a8e4b844c 100644 --- a/src/dawn_wire/client/Client.h +++ b/src/dawn_wire/client/Client.h @@ -33,13 +33,6 @@ namespace dawn_wire { namespace client { Client(CommandSerializer* serializer, MemoryTransferService* memoryTransferService); ~Client(); - const volatile char* HandleCommands(const volatile char* commands, size_t size); - ReservedTexture ReserveTexture(WGPUDevice device); - - void* GetCmdSpace(size_t size) { - return mSerializer->GetCmdSpace(size); - } - WGPUDevice GetDevice() const { return reinterpret_cast(mDevice); } @@ -48,6 +41,13 @@ namespace dawn_wire { namespace client { return mMemoryTransferService; } + const volatile char* HandleCommands(const volatile char* commands, size_t size); + ReservedTexture ReserveTexture(WGPUDevice device); + + void* GetCmdSpace(size_t size); + + void Disconnect(); + private: #include "dawn_wire/client/ClientPrototypes_autogen.inc" @@ -56,6 +56,9 @@ namespace dawn_wire { namespace client { WireDeserializeAllocator mAllocator; MemoryTransferService* mMemoryTransferService = nullptr; std::unique_ptr mOwnedMemoryTransferService = nullptr; + + std::vector mDummyCmdSpace; + bool mIsDisconnected = false; }; DawnProcTable GetProcs(); diff --git a/src/dawn_wire/client/Device.cpp b/src/dawn_wire/client/Device.cpp index 7d62b06e98..67021ef046 100644 --- a/src/dawn_wire/client/Device.cpp +++ b/src/dawn_wire/client/Device.cpp @@ -43,7 +43,8 @@ namespace dawn_wire { namespace client { } void Device::HandleDeviceLost(const char* message) { - if (mDeviceLostCallback) { + if (mDeviceLostCallback && !mDidRunLostCallback) { + mDidRunLostCallback = true; mDeviceLostCallback(message, mDeviceLostUserdata); } } diff --git a/src/dawn_wire/client/Device.h b/src/dawn_wire/client/Device.h index af5934e825..e70cae2b11 100644 --- a/src/dawn_wire/client/Device.h +++ b/src/dawn_wire/client/Device.h @@ -52,6 +52,7 @@ namespace dawn_wire { namespace client { Client* mClient = nullptr; WGPUErrorCallback mErrorCallback = nullptr; WGPUDeviceLostCallback mDeviceLostCallback = nullptr; + bool mDidRunLostCallback = false; void* mErrorUserdata; void* mDeviceLostUserdata; }; diff --git a/src/include/dawn_wire/WireClient.h b/src/include/dawn_wire/WireClient.h index 7c8ee408c6..5b5f33c2ac 100644 --- a/src/include/dawn_wire/WireClient.h +++ b/src/include/dawn_wire/WireClient.h @@ -52,6 +52,10 @@ namespace dawn_wire { ReservedTexture ReserveTexture(WGPUDevice device); + // Disconnects the client. + // Commands allocated after this point will not be sent. + void Disconnect(); + private: std::unique_ptr mImpl; }; diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn index 577db233be..cdc2578806 100644 --- a/src/tests/BUILD.gn +++ b/src/tests/BUILD.gn @@ -129,6 +129,7 @@ test("dawn_unittests") { "${dawn_root}/src/dawn_wire/client/ClientMemoryTransferService_mock.h", "${dawn_root}/src/dawn_wire/server/ServerMemoryTransferService_mock.cpp", "${dawn_root}/src/dawn_wire/server/ServerMemoryTransferService_mock.h", + "MockCallback.h", "unittests/BitSetIteratorTests.cpp", "unittests/BuddyAllocatorTests.cpp", "unittests/BuddyMemoryAllocatorTests.cpp", @@ -181,6 +182,7 @@ test("dawn_unittests") { "unittests/wire/WireArgumentTests.cpp", "unittests/wire/WireBasicTests.cpp", "unittests/wire/WireBufferMappingTests.cpp", + "unittests/wire/WireDisconnectTests.cpp", "unittests/wire/WireErrorCallbackTests.cpp", "unittests/wire/WireExtensionTests.cpp", "unittests/wire/WireFenceTests.cpp", diff --git a/src/tests/MockCallback.h b/src/tests/MockCallback.h new file mode 100644 index 0000000000..c3dfb4ab90 --- /dev/null +++ b/src/tests/MockCallback.h @@ -0,0 +1,101 @@ +// Copyright 2020 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "common/Assert.h" + +#include +#include + +namespace testing { + + template + class MockCallback; + + // Helper class for mocking callbacks used for Dawn callbacks with |void* userdata| + // as the last callback argument. + // + // Example Usage: + // MockCallback mock; + // + // void* foo = XYZ; // this is the callback userdata + // + // wgpuDeviceSetDeviceLostCallback(device, mock.Callback(), mock.MakeUserdata(foo)); + // EXPECT_CALL(mock, Call(_, foo)); + template + class MockCallback : public ::testing::MockFunction { + using CallbackType = R (*)(Args...); + + public: + // Helper function makes it easier to get the callback using |foo.Callback()| + // unstead of MockCallback::Callback. + static CallbackType Callback() { + return CallUnboundCallback; + } + + void* MakeUserdata(void* userdata) { + auto mockAndUserdata = + std::unique_ptr(new MockAndUserdata{this, userdata}); + + // Add the userdata to a set of userdata for this mock. We never + // remove from this set even if a callback should only be called once so that + // repeated calls to the callback still forward the userdata correctly. + // Userdata will be destroyed when the mock is destroyed. + auto it = mUserdatas.insert(std::move(mockAndUserdata)); + ASSERT(it.second); + return it.first->get(); + } + + private: + struct MockAndUserdata { + MockCallback* mock; + void* userdata; + }; + + static R CallUnboundCallback(Args... args) { + std::tuple tuple = std::make_tuple(args...); + + constexpr size_t ArgC = sizeof...(Args); + static_assert(ArgC >= 1, "Mock callback requires at least one argument (the userdata)"); + + // Get the userdata. It should be the last argument. + auto userdata = std::get(tuple); + static_assert(std::is_same::value, + "Last callback argument must be void* userdata"); + + // Extract the mock. + ASSERT(userdata != nullptr); + auto* mockAndUserdata = reinterpret_cast(userdata); + MockCallback* mock = mockAndUserdata->mock; + ASSERT(mock != nullptr); + + // Replace the userdata + std::get(tuple) = mockAndUserdata->userdata; + + // Forward the callback to the mock. + return mock->CallImpl(std::make_index_sequence{}, std::move(tuple)); + } + + // This helper cannot be inlined because we dependent on the templated index sequence + // to unpack the tuple arguments. + template + R CallImpl(const std::index_sequence&, std::tuple args) { + return this->Call(std::get(args)...); + } + + std::set> mUserdatas; + }; + +} // namespace testing diff --git a/src/tests/unittests/wire/WireDisconnectTests.cpp b/src/tests/unittests/wire/WireDisconnectTests.cpp new file mode 100644 index 0000000000..4e9b355f8c --- /dev/null +++ b/src/tests/unittests/wire/WireDisconnectTests.cpp @@ -0,0 +1,128 @@ +// Copyright 2020 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tests/unittests/wire/WireTest.h" + +#include "common/Assert.h" +#include "dawn_wire/WireClient.h" +#include "tests/MockCallback.h" + +using namespace testing; +using namespace dawn_wire; + +namespace { + + class WireDisconnectTests : public WireTest {}; + +} // anonymous namespace + +// Test that commands are not received if the client disconnects. +TEST_F(WireDisconnectTests, CommandsAfterDisconnect) { + // Sanity check that commands work at all. + wgpuDeviceCreateCommandEncoder(device, nullptr); + + WGPUCommandEncoder apiCmdBufEncoder = api.GetNewCommandEncoder(); + EXPECT_CALL(api, DeviceCreateCommandEncoder(apiDevice, nullptr)) + .WillOnce(Return(apiCmdBufEncoder)); + FlushClient(); + + // Disconnect. + GetWireClient()->Disconnect(); + + // Command is not received because client disconnected. + wgpuDeviceCreateCommandEncoder(device, nullptr); + EXPECT_CALL(api, DeviceCreateCommandEncoder(_, _)).Times(Exactly(0)); + FlushClient(); +} + +// Test that commands that are serialized before a disconnect but flushed +// after are received. +TEST_F(WireDisconnectTests, FlushAfterDisconnect) { + // Sanity check that commands work at all. + wgpuDeviceCreateCommandEncoder(device, nullptr); + + // Disconnect. + GetWireClient()->Disconnect(); + + // Already-serialized commmands are still received. + WGPUCommandEncoder apiCmdBufEncoder = api.GetNewCommandEncoder(); + EXPECT_CALL(api, DeviceCreateCommandEncoder(apiDevice, nullptr)) + .WillOnce(Return(apiCmdBufEncoder)); + FlushClient(); +} + +// Check that disconnecting the wire client calls the device lost callback exacty once. +TEST_F(WireDisconnectTests, CallsDeviceLostCallback) { + MockCallback mockDeviceLostCallback; + wgpuDeviceSetDeviceLostCallback(device, mockDeviceLostCallback.Callback(), + mockDeviceLostCallback.MakeUserdata(this)); + + // Disconnect the wire client. We should receive device lost only once. + EXPECT_CALL(mockDeviceLostCallback, Call(_, this)).Times(Exactly(1)); + GetWireClient()->Disconnect(); + GetWireClient()->Disconnect(); +} + +// Check that disconnecting the wire client after a device loss does not trigger the callback again. +TEST_F(WireDisconnectTests, ServerLostThenDisconnect) { + MockCallback mockDeviceLostCallback; + wgpuDeviceSetDeviceLostCallback(device, mockDeviceLostCallback.Callback(), + mockDeviceLostCallback.MakeUserdata(this)); + + api.CallDeviceLostCallback(apiDevice, "some reason"); + + // Flush the device lost return command. + EXPECT_CALL(mockDeviceLostCallback, Call(StrEq("some reason"), this)).Times(Exactly(1)); + FlushServer(); + + // Disconnect the client. We shouldn't see the lost callback again. + EXPECT_CALL(mockDeviceLostCallback, Call(_, _)).Times(Exactly(0)); + GetWireClient()->Disconnect(); +} + +// Check that disconnecting the wire client inside the device loss callback does not trigger the +// callback again. +TEST_F(WireDisconnectTests, ServerLostThenDisconnectInCallback) { + MockCallback mockDeviceLostCallback; + wgpuDeviceSetDeviceLostCallback(device, mockDeviceLostCallback.Callback(), + mockDeviceLostCallback.MakeUserdata(this)); + + api.CallDeviceLostCallback(apiDevice, "lost reason"); + + // Disconnect the client inside the lost callback. We should see the callback + // only once. + EXPECT_CALL(mockDeviceLostCallback, Call(StrEq("lost reason"), this)) + .WillOnce(InvokeWithoutArgs([&]() { + EXPECT_CALL(mockDeviceLostCallback, Call(_, _)).Times(Exactly(0)); + GetWireClient()->Disconnect(); + })); + FlushServer(); +} + +// Check that a device loss after a disconnect does not trigger the callback again. +TEST_F(WireDisconnectTests, DisconnectThenServerLost) { + MockCallback mockDeviceLostCallback; + wgpuDeviceSetDeviceLostCallback(device, mockDeviceLostCallback.Callback(), + mockDeviceLostCallback.MakeUserdata(this)); + + // Disconnect the client. We should see the callback once. + EXPECT_CALL(mockDeviceLostCallback, Call(_, this)).Times(Exactly(1)); + GetWireClient()->Disconnect(); + + // Lose the device on the server. The client callback shouldn't be + // called again. + api.CallDeviceLostCallback(apiDevice, "lost reason"); + EXPECT_CALL(mockDeviceLostCallback, Call(_, _)).Times(Exactly(0)); + FlushServer(); +}