Add WireClient::Disconnect to stop serializing commands

This prevents the client from continuing to send commands
when the wire connection has dropped. In Chromium this may
be because the connection to the GPU process is lost and the
transfer buffer may be destroyed.

This CL also adds a new helper to make testing callbacks
with mocks easier.

Bug: chromium:1070392
Change-Id: I6a69c32cc506069554ead18ee83a156ca70e2ce2
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/19160
Commit-Queue: Austin Eng <enga@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
This commit is contained in:
Austin Eng 2020-04-15 02:00:14 +00:00 committed by Commit Bot service account
parent beaaa5a601
commit 93bea5cb50
9 changed files with 271 additions and 8 deletions

View File

@ -42,6 +42,10 @@ namespace dawn_wire {
return mImpl->ReserveTexture(device);
}
void WireClient::Disconnect() {
mImpl->Disconnect();
}
namespace client {
MemoryTransferService::~MemoryTransferService() = default;

View File

@ -13,6 +13,8 @@
// limitations under the License.
#include "dawn_wire/client/Client.h"
#include "common/Compiler.h"
#include "dawn_wire/client/Device.h"
namespace dawn_wire { namespace client {
@ -44,4 +46,21 @@ namespace dawn_wire { namespace client {
return result;
}
void* Client::GetCmdSpace(size_t size) {
if (DAWN_UNLIKELY(mIsDisconnected)) {
if (size > mDummyCmdSpace.size()) {
mDummyCmdSpace.resize(size);
}
return mDummyCmdSpace.data();
}
return mSerializer->GetCmdSpace(size);
}
void Client::Disconnect() {
if (!mIsDisconnected) {
mIsDisconnected = true;
mDevice->HandleDeviceLost("GPU connection lost");
}
}
}} // namespace dawn_wire::client

View File

@ -33,13 +33,6 @@ namespace dawn_wire { namespace client {
Client(CommandSerializer* serializer, MemoryTransferService* memoryTransferService);
~Client();
const volatile char* HandleCommands(const volatile char* commands, size_t size);
ReservedTexture ReserveTexture(WGPUDevice device);
void* GetCmdSpace(size_t size) {
return mSerializer->GetCmdSpace(size);
}
WGPUDevice GetDevice() const {
return reinterpret_cast<WGPUDeviceImpl*>(mDevice);
}
@ -48,6 +41,13 @@ namespace dawn_wire { namespace client {
return mMemoryTransferService;
}
const volatile char* HandleCommands(const volatile char* commands, size_t size);
ReservedTexture ReserveTexture(WGPUDevice device);
void* GetCmdSpace(size_t size);
void Disconnect();
private:
#include "dawn_wire/client/ClientPrototypes_autogen.inc"
@ -56,6 +56,9 @@ namespace dawn_wire { namespace client {
WireDeserializeAllocator mAllocator;
MemoryTransferService* mMemoryTransferService = nullptr;
std::unique_ptr<MemoryTransferService> mOwnedMemoryTransferService = nullptr;
std::vector<char> mDummyCmdSpace;
bool mIsDisconnected = false;
};
DawnProcTable GetProcs();

View File

@ -43,7 +43,8 @@ namespace dawn_wire { namespace client {
}
void Device::HandleDeviceLost(const char* message) {
if (mDeviceLostCallback) {
if (mDeviceLostCallback && !mDidRunLostCallback) {
mDidRunLostCallback = true;
mDeviceLostCallback(message, mDeviceLostUserdata);
}
}

View File

@ -52,6 +52,7 @@ namespace dawn_wire { namespace client {
Client* mClient = nullptr;
WGPUErrorCallback mErrorCallback = nullptr;
WGPUDeviceLostCallback mDeviceLostCallback = nullptr;
bool mDidRunLostCallback = false;
void* mErrorUserdata;
void* mDeviceLostUserdata;
};

View File

@ -52,6 +52,10 @@ namespace dawn_wire {
ReservedTexture ReserveTexture(WGPUDevice device);
// Disconnects the client.
// Commands allocated after this point will not be sent.
void Disconnect();
private:
std::unique_ptr<client::Client> mImpl;
};

View File

@ -129,6 +129,7 @@ test("dawn_unittests") {
"${dawn_root}/src/dawn_wire/client/ClientMemoryTransferService_mock.h",
"${dawn_root}/src/dawn_wire/server/ServerMemoryTransferService_mock.cpp",
"${dawn_root}/src/dawn_wire/server/ServerMemoryTransferService_mock.h",
"MockCallback.h",
"unittests/BitSetIteratorTests.cpp",
"unittests/BuddyAllocatorTests.cpp",
"unittests/BuddyMemoryAllocatorTests.cpp",
@ -181,6 +182,7 @@ test("dawn_unittests") {
"unittests/wire/WireArgumentTests.cpp",
"unittests/wire/WireBasicTests.cpp",
"unittests/wire/WireBufferMappingTests.cpp",
"unittests/wire/WireDisconnectTests.cpp",
"unittests/wire/WireErrorCallbackTests.cpp",
"unittests/wire/WireExtensionTests.cpp",
"unittests/wire/WireFenceTests.cpp",

101
src/tests/MockCallback.h Normal file
View File

@ -0,0 +1,101 @@
// Copyright 2020 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gmock/gmock.h>
#include "common/Assert.h"
#include <memory>
#include <set>
namespace testing {
template <typename F>
class MockCallback;
// Helper class for mocking callbacks used for Dawn callbacks with |void* userdata|
// as the last callback argument.
//
// Example Usage:
// MockCallback<WGPUDeviceLostCallback> mock;
//
// void* foo = XYZ; // this is the callback userdata
//
// wgpuDeviceSetDeviceLostCallback(device, mock.Callback(), mock.MakeUserdata(foo));
// EXPECT_CALL(mock, Call(_, foo));
template <typename R, typename... Args>
class MockCallback<R (*)(Args...)> : public ::testing::MockFunction<R(Args...)> {
using CallbackType = R (*)(Args...);
public:
// Helper function makes it easier to get the callback using |foo.Callback()|
// unstead of MockCallback<CallbackType>::Callback.
static CallbackType Callback() {
return CallUnboundCallback;
}
void* MakeUserdata(void* userdata) {
auto mockAndUserdata =
std::unique_ptr<MockAndUserdata>(new MockAndUserdata{this, userdata});
// Add the userdata to a set of userdata for this mock. We never
// remove from this set even if a callback should only be called once so that
// repeated calls to the callback still forward the userdata correctly.
// Userdata will be destroyed when the mock is destroyed.
auto it = mUserdatas.insert(std::move(mockAndUserdata));
ASSERT(it.second);
return it.first->get();
}
private:
struct MockAndUserdata {
MockCallback* mock;
void* userdata;
};
static R CallUnboundCallback(Args... args) {
std::tuple<Args...> tuple = std::make_tuple(args...);
constexpr size_t ArgC = sizeof...(Args);
static_assert(ArgC >= 1, "Mock callback requires at least one argument (the userdata)");
// Get the userdata. It should be the last argument.
auto userdata = std::get<ArgC - 1>(tuple);
static_assert(std::is_same<decltype(userdata), void*>::value,
"Last callback argument must be void* userdata");
// Extract the mock.
ASSERT(userdata != nullptr);
auto* mockAndUserdata = reinterpret_cast<MockAndUserdata*>(userdata);
MockCallback* mock = mockAndUserdata->mock;
ASSERT(mock != nullptr);
// Replace the userdata
std::get<ArgC - 1>(tuple) = mockAndUserdata->userdata;
// Forward the callback to the mock.
return mock->CallImpl(std::make_index_sequence<ArgC>{}, std::move(tuple));
}
// This helper cannot be inlined because we dependent on the templated index sequence
// to unpack the tuple arguments.
template <size_t... Is>
R CallImpl(const std::index_sequence<Is...>&, std::tuple<Args...> args) {
return this->Call(std::get<Is>(args)...);
}
std::set<std::unique_ptr<MockAndUserdata>> mUserdatas;
};
} // namespace testing

View File

@ -0,0 +1,128 @@
// Copyright 2020 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tests/unittests/wire/WireTest.h"
#include "common/Assert.h"
#include "dawn_wire/WireClient.h"
#include "tests/MockCallback.h"
using namespace testing;
using namespace dawn_wire;
namespace {
class WireDisconnectTests : public WireTest {};
} // anonymous namespace
// Test that commands are not received if the client disconnects.
TEST_F(WireDisconnectTests, CommandsAfterDisconnect) {
// Sanity check that commands work at all.
wgpuDeviceCreateCommandEncoder(device, nullptr);
WGPUCommandEncoder apiCmdBufEncoder = api.GetNewCommandEncoder();
EXPECT_CALL(api, DeviceCreateCommandEncoder(apiDevice, nullptr))
.WillOnce(Return(apiCmdBufEncoder));
FlushClient();
// Disconnect.
GetWireClient()->Disconnect();
// Command is not received because client disconnected.
wgpuDeviceCreateCommandEncoder(device, nullptr);
EXPECT_CALL(api, DeviceCreateCommandEncoder(_, _)).Times(Exactly(0));
FlushClient();
}
// Test that commands that are serialized before a disconnect but flushed
// after are received.
TEST_F(WireDisconnectTests, FlushAfterDisconnect) {
// Sanity check that commands work at all.
wgpuDeviceCreateCommandEncoder(device, nullptr);
// Disconnect.
GetWireClient()->Disconnect();
// Already-serialized commmands are still received.
WGPUCommandEncoder apiCmdBufEncoder = api.GetNewCommandEncoder();
EXPECT_CALL(api, DeviceCreateCommandEncoder(apiDevice, nullptr))
.WillOnce(Return(apiCmdBufEncoder));
FlushClient();
}
// Check that disconnecting the wire client calls the device lost callback exacty once.
TEST_F(WireDisconnectTests, CallsDeviceLostCallback) {
MockCallback<WGPUDeviceLostCallback> mockDeviceLostCallback;
wgpuDeviceSetDeviceLostCallback(device, mockDeviceLostCallback.Callback(),
mockDeviceLostCallback.MakeUserdata(this));
// Disconnect the wire client. We should receive device lost only once.
EXPECT_CALL(mockDeviceLostCallback, Call(_, this)).Times(Exactly(1));
GetWireClient()->Disconnect();
GetWireClient()->Disconnect();
}
// Check that disconnecting the wire client after a device loss does not trigger the callback again.
TEST_F(WireDisconnectTests, ServerLostThenDisconnect) {
MockCallback<WGPUDeviceLostCallback> mockDeviceLostCallback;
wgpuDeviceSetDeviceLostCallback(device, mockDeviceLostCallback.Callback(),
mockDeviceLostCallback.MakeUserdata(this));
api.CallDeviceLostCallback(apiDevice, "some reason");
// Flush the device lost return command.
EXPECT_CALL(mockDeviceLostCallback, Call(StrEq("some reason"), this)).Times(Exactly(1));
FlushServer();
// Disconnect the client. We shouldn't see the lost callback again.
EXPECT_CALL(mockDeviceLostCallback, Call(_, _)).Times(Exactly(0));
GetWireClient()->Disconnect();
}
// Check that disconnecting the wire client inside the device loss callback does not trigger the
// callback again.
TEST_F(WireDisconnectTests, ServerLostThenDisconnectInCallback) {
MockCallback<WGPUDeviceLostCallback> mockDeviceLostCallback;
wgpuDeviceSetDeviceLostCallback(device, mockDeviceLostCallback.Callback(),
mockDeviceLostCallback.MakeUserdata(this));
api.CallDeviceLostCallback(apiDevice, "lost reason");
// Disconnect the client inside the lost callback. We should see the callback
// only once.
EXPECT_CALL(mockDeviceLostCallback, Call(StrEq("lost reason"), this))
.WillOnce(InvokeWithoutArgs([&]() {
EXPECT_CALL(mockDeviceLostCallback, Call(_, _)).Times(Exactly(0));
GetWireClient()->Disconnect();
}));
FlushServer();
}
// Check that a device loss after a disconnect does not trigger the callback again.
TEST_F(WireDisconnectTests, DisconnectThenServerLost) {
MockCallback<WGPUDeviceLostCallback> mockDeviceLostCallback;
wgpuDeviceSetDeviceLostCallback(device, mockDeviceLostCallback.Callback(),
mockDeviceLostCallback.MakeUserdata(this));
// Disconnect the client. We should see the callback once.
EXPECT_CALL(mockDeviceLostCallback, Call(_, this)).Times(Exactly(1));
GetWireClient()->Disconnect();
// Lose the device on the server. The client callback shouldn't be
// called again.
api.CallDeviceLostCallback(apiDevice, "lost reason");
EXPECT_CALL(mockDeviceLostCallback, Call(_, _)).Times(Exactly(0));
FlushServer();
}