dawn_wire: Add Reserve/InjectDevice

Now that the wire does enough tracking to prevent a malicious client
from freeing a device before its child objects, and the device is no
longer a "special" object with regard to reference/release, it is
safe to support multiple devices on the wire. The simplest way to
use this in WebGPU (to fix createReadyRenderPipeline validation)
is to add a reserve/inject device API similar to the one we use for
swapchain textures.

Bug: dawn:565
Change-Id: Ie956aff528c5610c9ecc5c189dab2d22185cb572
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/37800
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
This commit is contained in:
Austin Eng 2021-01-19 19:27:52 +00:00 committed by Commit Bot service account
parent b830da7d6e
commit 8bcde8e394
18 changed files with 377 additions and 45 deletions

View File

@ -98,6 +98,11 @@ namespace dawn_wire { namespace server {
*data->childObjectTypesAndIds.begin()); *data->childObjectTypesAndIds.begin());
DoDestroyObject(childObjectType, childObjectId); DoDestroyObject(childObjectType, childObjectId);
} }
if (data->handle != nullptr) {
//* Deregisters uncaptured error and device lost callbacks since
//* they should not be forwarded if the device no longer exists on the wire.
ClearDeviceCallbacks(data->handle);
}
{% endif %} {% endif %}
if (data->handle != nullptr) { if (data->handle != nullptr) {
mProcs.{{as_varName(type.name, Name("release"))}}(data->handle); mProcs.{{as_varName(type.name, Name("release"))}}(data->handle);

View File

@ -37,6 +37,10 @@ namespace dawn_wire {
return mImpl->ReserveTexture(device); return mImpl->ReserveTexture(device);
} }
ReservedDevice WireClient::ReserveDevice() {
return mImpl->ReserveDevice();
}
void WireClient::Disconnect() { void WireClient::Disconnect() {
mImpl->Disconnect(); mImpl->Disconnect();
} }

View File

@ -40,6 +40,14 @@ namespace dawn_wire {
return mImpl->InjectTexture(texture, id, generation, deviceId, deviceGeneration); return mImpl->InjectTexture(texture, id, generation, deviceId, deviceGeneration);
} }
bool WireServer::InjectDevice(WGPUDevice device, uint32_t id, uint32_t generation) {
return mImpl->InjectDevice(device, id, generation);
}
WGPUDevice WireServer::GetDevice(uint32_t id, uint32_t generation) {
return mImpl->GetDevice(id, generation);
}
namespace server { namespace server {
MemoryTransferService::MemoryTransferService() = default; MemoryTransferService::MemoryTransferService() = default;

View File

@ -85,8 +85,13 @@ namespace dawn_wire { namespace client {
} }
WGPUDevice Client::GetDevice() { WGPUDevice Client::GetDevice() {
// This function is deprecated. The concept of a "default" device on the wire
// will be removed in favor of ReserveDevice/InjectDevice.
if (mDevice == nullptr) { if (mDevice == nullptr) {
mDevice = DeviceAllocator().New(this)->object.get(); ReservedDevice reservation = ReserveDevice();
mDevice = FromAPI(reservation.device);
ASSERT(reservation.id == 1);
ASSERT(reservation.generation == 0);
} }
return reinterpret_cast<WGPUDeviceImpl*>(mDevice); return reinterpret_cast<WGPUDeviceImpl*>(mDevice);
} }
@ -103,6 +108,16 @@ namespace dawn_wire { namespace client {
return result; return result;
} }
ReservedDevice Client::ReserveDevice() {
auto* allocation = DeviceAllocator().New(this);
ReservedDevice result;
result.device = ToAPI(allocation->object.get());
result.id = allocation->object->id;
result.generation = allocation->generation;
return result;
}
void Client::Disconnect() { void Client::Disconnect() {
mDisconnected = true; mDisconnected = true;
mSerializer = ChunkedCommandSerializer(NoopCommandSerializer::GetInstance()); mSerializer = ChunkedCommandSerializer(NoopCommandSerializer::GetInstance());

View File

@ -46,6 +46,7 @@ namespace dawn_wire { namespace client {
} }
ReservedTexture ReserveTexture(WGPUDevice device); ReservedTexture ReserveTexture(WGPUDevice device);
ReservedDevice ReserveDevice();
template <typename Cmd> template <typename Cmd>
void SerializeCommand(const Cmd& cmd) { void SerializeCommand(const Cmd& cmd) {

View File

@ -45,15 +45,6 @@ namespace dawn_wire { namespace client {
} }
}; };
#endif // DAWN_ENABLE_ASSERTS #endif // DAWN_ENABLE_ASSERTS
// Get the default queue for this device.
auto* allocation = client->QueueAllocator().New(client);
mDefaultQueue = allocation->object.get();
DeviceGetDefaultQueueCmd cmd;
cmd.self = ToAPI(this);
cmd.result = ObjectHandle{allocation->object->id, allocation->generation};
client->SerializeCommand(cmd);
} }
Device::~Device() { Device::~Device() {
@ -206,6 +197,22 @@ namespace dawn_wire { namespace client {
} }
WGPUQueue Device::GetDefaultQueue() { WGPUQueue Device::GetDefaultQueue() {
// The queue is lazily created because if a Device is created by
// Reserve/Inject, we cannot send the getDefaultQueue message until
// it has been injected on the Server. It cannot happen immediately
// on construction.
if (mDefaultQueue == nullptr) {
// Get the default queue for this device.
auto* allocation = client->QueueAllocator().New(client);
mDefaultQueue = allocation->object.get();
DeviceGetDefaultQueueCmd cmd;
cmd.self = ToAPI(this);
cmd.result = ObjectHandle{allocation->object->id, allocation->generation};
client->SerializeCommand(cmd);
}
mDefaultQueue->refcount++; mDefaultQueue->refcount++;
return ToAPI(mDefaultQueue); return ToAPI(mDefaultQueue);
} }

View File

@ -160,6 +160,17 @@ namespace dawn_wire { namespace server {
return objects; return objects;
} }
std::vector<T> GetAllHandles() {
std::vector<T> objects;
for (Data& data : mKnown) {
if (data.allocated && data.handle != nullptr) {
objects.push_back(data.handle);
}
}
return objects;
}
private: private:
std::vector<Data> mKnown; std::vector<Data> mKnown;
}; };

View File

@ -23,7 +23,6 @@ namespace dawn_wire { namespace server {
MemoryTransferService* memoryTransferService) MemoryTransferService* memoryTransferService)
: mSerializer(serializer), : mSerializer(serializer),
mProcs(procs), mProcs(procs),
mDeviceOnCreation(device),
mMemoryTransferService(memoryTransferService), mMemoryTransferService(memoryTransferService),
mIsAlive(std::make_shared<bool>(true)) { mIsAlive(std::make_shared<bool>(true)) {
if (mMemoryTransferService == nullptr) { if (mMemoryTransferService == nullptr) {
@ -31,38 +30,21 @@ namespace dawn_wire { namespace server {
mOwnedMemoryTransferService = CreateInlineMemoryTransferService(); mOwnedMemoryTransferService = CreateInlineMemoryTransferService();
mMemoryTransferService = mOwnedMemoryTransferService.get(); mMemoryTransferService = mOwnedMemoryTransferService.get();
} }
// The client-server knowledge is bootstrapped with device 1.
auto* deviceData = DeviceObjects().Allocate(1);
deviceData->handle = device;
// Take an extra ref. All objects may be freed by the client, but this // For the deprecated initialization path:
// one is externally owned. // The client-server knowledge is bootstrapped with device 1, generation 0.
mProcs.deviceReference(device); if (device != nullptr) {
bool success = InjectDevice(device, 1, 0);
// Note: these callbacks are manually inlined here since they do not acquire and ASSERT(success);
// free their userdata. }
mProcs.deviceSetUncapturedErrorCallback(
device,
[](WGPUErrorType type, const char* message, void* userdata) {
Server* server = static_cast<Server*>(userdata);
server->OnUncapturedError(type, message);
},
this);
mProcs.deviceSetDeviceLostCallback(
device,
[](const char* message, void* userdata) {
Server* server = static_cast<Server*>(userdata);
server->OnDeviceLost(message);
},
this);
} }
Server::~Server() { Server::~Server() {
// Un-set the error and lost callbacks since we cannot forward them // Un-set the error and lost callbacks since we cannot forward them
// after the server has been destroyed. // after the server has been destroyed.
mProcs.deviceSetUncapturedErrorCallback(mDeviceOnCreation, nullptr, nullptr); for (WGPUDevice device : DeviceObjects().GetAllHandles()) {
mProcs.deviceSetDeviceLostCallback(mDeviceOnCreation, nullptr, nullptr); ClearDeviceCallbacks(device);
}
DestroyAllObjects(mProcs); DestroyAllObjects(mProcs);
} }
@ -71,6 +53,7 @@ namespace dawn_wire { namespace server {
uint32_t generation, uint32_t generation,
uint32_t deviceId, uint32_t deviceId,
uint32_t deviceGeneration) { uint32_t deviceGeneration) {
ASSERT(texture != nullptr);
ObjectData<WGPUDevice>* device = DeviceObjects().Get(deviceId); ObjectData<WGPUDevice>* device = DeviceObjects().Get(deviceId);
if (device == nullptr || device->generation != deviceGeneration) { if (device == nullptr || device->generation != deviceGeneration) {
return false; return false;
@ -97,6 +80,57 @@ namespace dawn_wire { namespace server {
return true; return true;
} }
bool Server::InjectDevice(WGPUDevice device, uint32_t id, uint32_t generation) {
ASSERT(device != nullptr);
ObjectData<WGPUDevice>* data = DeviceObjects().Allocate(id);
if (data == nullptr) {
return false;
}
data->handle = device;
data->generation = generation;
data->allocated = true;
// The device is externally owned so it shouldn't be destroyed when we receive a destroy
// message from the client. Add a reference to counterbalance the eventual release.
mProcs.deviceReference(device);
// Set callbacks to forward errors to the client.
// Note: these callbacks are manually inlined here since they do not acquire and
// free their userdata.
mProcs.deviceSetUncapturedErrorCallback(
device,
[](WGPUErrorType type, const char* message, void* userdata) {
Server* server = static_cast<Server*>(userdata);
server->OnUncapturedError(type, message);
},
this);
mProcs.deviceSetDeviceLostCallback(
device,
[](const char* message, void* userdata) {
Server* server = static_cast<Server*>(userdata);
server->OnDeviceLost(message);
},
this);
return true;
}
WGPUDevice Server::GetDevice(uint32_t id, uint32_t generation) {
ObjectData<WGPUDevice>* data = DeviceObjects().Get(id);
if (data == nullptr || data->generation != generation) {
return nullptr;
}
return data->handle;
}
void Server::ClearDeviceCallbacks(WGPUDevice device) {
// Un-set the error and lost callbacks since we cannot forward them
// after the server has been destroyed.
mProcs.deviceSetUncapturedErrorCallback(device, nullptr, nullptr);
mProcs.deviceSetDeviceLostCallback(device, nullptr, nullptr);
}
bool TrackDeviceChild(ObjectDataBase<WGPUDevice>* device, ObjectType type, ObjectId id) { bool TrackDeviceChild(ObjectDataBase<WGPUDevice>* device, ObjectType type, ObjectId id) {
auto it = static_cast<ObjectData<WGPUDevice>*>(device)->childObjectTypesAndIds.insert( auto it = static_cast<ObjectData<WGPUDevice>*>(device)->childObjectTypesAndIds.insert(
PackObjectTypeAndId(type, id)); PackObjectTypeAndId(type, id));

View File

@ -167,6 +167,10 @@ namespace dawn_wire { namespace server {
uint32_t deviceId, uint32_t deviceId,
uint32_t deviceGeneration); uint32_t deviceGeneration);
bool InjectDevice(WGPUDevice device, uint32_t id, uint32_t generation);
WGPUDevice GetDevice(uint32_t id, uint32_t generation);
template <typename T, template <typename T,
typename Enable = std::enable_if<std::is_base_of<CallbackUserdata, T>::value>> typename Enable = std::enable_if<std::is_base_of<CallbackUserdata, T>::value>>
std::unique_ptr<T> MakeUserdata() { std::unique_ptr<T> MakeUserdata() {
@ -186,6 +190,7 @@ namespace dawn_wire { namespace server {
mSerializer.SerializeCommand(cmd, extraSize, SerializeExtraSize); mSerializer.SerializeCommand(cmd, extraSize, SerializeExtraSize);
} }
void ClearDeviceCallbacks(WGPUDevice device);
// Error callbacks // Error callbacks
void OnUncapturedError(WGPUErrorType type, const char* message); void OnUncapturedError(WGPUErrorType type, const char* message);
@ -212,7 +217,6 @@ namespace dawn_wire { namespace server {
WireDeserializeAllocator mAllocator; WireDeserializeAllocator mAllocator;
ChunkedCommandSerializer mSerializer; ChunkedCommandSerializer mSerializer;
DawnProcTable mProcs; DawnProcTable mProcs;
WGPUDevice mDeviceOnCreation;
std::unique_ptr<MemoryTransferService> mOwnedMemoryTransferService = nullptr; std::unique_ptr<MemoryTransferService> mOwnedMemoryTransferService = nullptr;
MemoryTransferService* mMemoryTransferService = nullptr; MemoryTransferService* mMemoryTransferService = nullptr;

View File

@ -38,6 +38,12 @@ namespace dawn_wire {
uint32_t deviceGeneration; uint32_t deviceGeneration;
}; };
struct ReservedDevice {
WGPUDevice device;
uint32_t id;
uint32_t generation;
};
struct DAWN_WIRE_EXPORT WireClientDescriptor { struct DAWN_WIRE_EXPORT WireClientDescriptor {
CommandSerializer* serializer; CommandSerializer* serializer;
client::MemoryTransferService* memoryTransferService = nullptr; client::MemoryTransferService* memoryTransferService = nullptr;
@ -53,6 +59,7 @@ namespace dawn_wire {
size_t size) override final; size_t size) override final;
ReservedTexture ReserveTexture(WGPUDevice device); ReservedTexture ReserveTexture(WGPUDevice device);
ReservedDevice ReserveDevice();
// Disconnects the client. // Disconnects the client.
// Commands allocated after this point will not be sent. // Commands allocated after this point will not be sent.

View File

@ -50,6 +50,17 @@ namespace dawn_wire {
uint32_t deviceId = 1, uint32_t deviceId = 1,
uint32_t deviceGeneration = 0); uint32_t deviceGeneration = 0);
bool InjectDevice(WGPUDevice device, uint32_t id, uint32_t generation);
// Look up a device by (id, generation) pair. Returns nullptr if the generation
// has expired or the id is not found.
// The Wire does not have destroy hooks to allow an embedder to observe when an object
// has been destroyed, but in Chrome, we need to know the list of live devices so we
// can call device.Tick() on all of them periodically to ensure progress on asynchronous
// work is made. Getting this list can be done by tracking the (id, generation) of
// previously injected devices, and observing if GetDevice(id, generation) returns non-null.
WGPUDevice GetDevice(uint32_t id, uint32_t generation);
private: private:
std::unique_ptr<server::Server> mImpl; std::unique_ptr<server::Server> mImpl;
}; };

View File

@ -223,6 +223,7 @@ test("dawn_unittests") {
"unittests/wire/WireErrorCallbackTests.cpp", "unittests/wire/WireErrorCallbackTests.cpp",
"unittests/wire/WireExtensionTests.cpp", "unittests/wire/WireExtensionTests.cpp",
"unittests/wire/WireFenceTests.cpp", "unittests/wire/WireFenceTests.cpp",
"unittests/wire/WireInjectDeviceTests.cpp",
"unittests/wire/WireInjectTextureTests.cpp", "unittests/wire/WireInjectTextureTests.cpp",
"unittests/wire/WireMemoryTransferServiceTests.cpp", "unittests/wire/WireMemoryTransferServiceTests.cpp",
"unittests/wire/WireMultipleDeviceTests.cpp", "unittests/wire/WireMultipleDeviceTests.cpp",

View File

@ -36,10 +36,19 @@ TEST_F(WireDestroyObjectTests, DestroyDeviceDestroysChildren) {
// The device and child objects should be released. // The device and child objects should be released.
EXPECT_CALL(api, CommandEncoderRelease(apiEncoder)).InSequence(s1); EXPECT_CALL(api, CommandEncoderRelease(apiEncoder)).InSequence(s1);
EXPECT_CALL(api, QueueRelease(apiQueue)).InSequence(s2); EXPECT_CALL(api, QueueRelease(apiQueue)).InSequence(s2);
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr))
.Times(1)
.InSequence(s1, s2);
EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr))
.Times(1)
.InSequence(s1, s2);
EXPECT_CALL(api, DeviceRelease(apiDevice)).InSequence(s1, s2); EXPECT_CALL(api, DeviceRelease(apiDevice)).InSequence(s1, s2);
FlushClient(); FlushClient();
// Signal that we already released and cleared callbacks for |apiDevice|
DefaultApiDeviceWasReleased();
// Using the command encoder should be an error. // Using the command encoder should be an error.
wgpuCommandEncoderFinish(encoder, nullptr); wgpuCommandEncoderFinish(encoder, nullptr);
FlushClient(false); FlushClient(false);
@ -82,8 +91,17 @@ TEST_F(WireDestroyObjectTests, ImplicitInjectErrorAfterDestroyDevice) {
// The device and child objects alre also released. // The device and child objects alre also released.
EXPECT_CALL(api, BufferRelease(apiBuffer)).InSequence(s1); EXPECT_CALL(api, BufferRelease(apiBuffer)).InSequence(s1);
EXPECT_CALL(api, QueueRelease(apiQueue)).InSequence(s2); EXPECT_CALL(api, QueueRelease(apiQueue)).InSequence(s2);
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr))
.Times(1)
.InSequence(s1, s2);
EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr))
.Times(1)
.InSequence(s1, s2);
EXPECT_CALL(api, DeviceRelease(apiDevice)).InSequence(s1, s2); EXPECT_CALL(api, DeviceRelease(apiDevice)).InSequence(s1, s2);
FlushClient(); FlushClient();
// Signal that we already released and cleared callbacks for |apiDevice|
DefaultApiDeviceWasReleased();
} }
} }

View File

@ -149,6 +149,15 @@ TEST_F(WireDisconnectTests, DeleteClientDestroysObjects) {
EXPECT_CALL(api, QueueRelease(apiQueue)).Times(1).InSequence(s1); EXPECT_CALL(api, QueueRelease(apiQueue)).Times(1).InSequence(s1);
EXPECT_CALL(api, CommandEncoderRelease(apiCommandEncoder)).Times(1).InSequence(s2); EXPECT_CALL(api, CommandEncoderRelease(apiCommandEncoder)).Times(1).InSequence(s2);
EXPECT_CALL(api, SamplerRelease(apiSampler)).Times(1).InSequence(s3); EXPECT_CALL(api, SamplerRelease(apiSampler)).Times(1).InSequence(s3);
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr))
.Times(1)
.InSequence(s1, s2);
EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr))
.Times(1)
.InSequence(s1, s2);
EXPECT_CALL(api, DeviceRelease(apiDevice)).Times(1).InSequence(s1, s2, s3); EXPECT_CALL(api, DeviceRelease(apiDevice)).Times(1).InSequence(s1, s2, s3);
FlushClient(); FlushClient();
// Signal that we already released and cleared callbacks for |apiDevice|
DefaultApiDeviceWasReleased();
} }

View File

@ -0,0 +1,184 @@
// Copyright 2021 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tests/unittests/wire/WireTest.h"
#include "dawn_wire/WireClient.h"
#include "dawn_wire/WireServer.h"
using namespace testing;
using namespace dawn_wire;
class WireInjectDeviceTests : public WireTest {
public:
WireInjectDeviceTests() {
}
~WireInjectDeviceTests() override = default;
};
// Test that reserving and injecting a device makes calls on the client object forward to the
// server object correctly.
TEST_F(WireInjectDeviceTests, CallAfterReserveInject) {
ReservedDevice reservation = GetWireClient()->ReserveDevice();
WGPUDevice serverDevice = api.GetNewDevice();
EXPECT_CALL(api, DeviceReference(serverDevice));
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, _, _));
EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, _, _));
ASSERT_TRUE(
GetWireServer()->InjectDevice(serverDevice, reservation.id, reservation.generation));
WGPUBufferDescriptor bufferDesc = {};
wgpuDeviceCreateBuffer(reservation.device, &bufferDesc);
WGPUBuffer serverBuffer = api.GetNewBuffer();
EXPECT_CALL(api, DeviceCreateBuffer(serverDevice, _)).WillOnce(Return(serverBuffer));
FlushClient();
// Called on shutdown.
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, nullptr, nullptr))
.Times(Exactly(1));
EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, nullptr, nullptr))
.Times(Exactly(1));
}
// Test that reserve correctly returns different IDs each time.
TEST_F(WireInjectDeviceTests, ReserveDifferentIDs) {
ReservedDevice reservation1 = GetWireClient()->ReserveDevice();
ReservedDevice reservation2 = GetWireClient()->ReserveDevice();
ASSERT_NE(reservation1.id, reservation2.id);
ASSERT_NE(reservation1.device, reservation2.device);
}
// Test that injecting the same id without a destroy first fails.
TEST_F(WireInjectDeviceTests, InjectExistingID) {
ReservedDevice reservation = GetWireClient()->ReserveDevice();
WGPUDevice serverDevice = api.GetNewDevice();
EXPECT_CALL(api, DeviceReference(serverDevice));
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, _, _));
EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, _, _));
ASSERT_TRUE(
GetWireServer()->InjectDevice(serverDevice, reservation.id, reservation.generation));
// ID already in use, call fails.
ASSERT_FALSE(
GetWireServer()->InjectDevice(serverDevice, reservation.id, reservation.generation));
// Called on shutdown.
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, nullptr, nullptr))
.Times(Exactly(1));
EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, nullptr, nullptr))
.Times(Exactly(1));
}
// Test that the server only borrows the device and does a single reference-release
TEST_F(WireInjectDeviceTests, InjectedDeviceLifetime) {
ReservedDevice reservation = GetWireClient()->ReserveDevice();
// Injecting the device adds a reference
WGPUDevice serverDevice = api.GetNewDevice();
EXPECT_CALL(api, DeviceReference(serverDevice));
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, _, _));
EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, _, _));
ASSERT_TRUE(
GetWireServer()->InjectDevice(serverDevice, reservation.id, reservation.generation));
// Releasing the device removes a single reference and clears its error callbacks.
wgpuDeviceRelease(reservation.device);
EXPECT_CALL(api, DeviceRelease(serverDevice));
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, nullptr, nullptr)).Times(1);
EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, nullptr, nullptr)).Times(1);
FlushClient();
// Deleting the server doesn't release a second reference.
DeleteServer();
Mock::VerifyAndClearExpectations(&api);
}
// Test that it is an error to get the default queue of a device before it has been
// injected on the server.
TEST_F(WireInjectDeviceTests, GetQueueBeforeInject) {
ReservedDevice reservation = GetWireClient()->ReserveDevice();
wgpuDeviceGetDefaultQueue(reservation.device);
FlushClient(false);
}
// Test that it is valid to get the default queue of a device after it has been
// injected on the server.
TEST_F(WireInjectDeviceTests, GetQueueAfterInject) {
ReservedDevice reservation = GetWireClient()->ReserveDevice();
WGPUDevice serverDevice = api.GetNewDevice();
EXPECT_CALL(api, DeviceReference(serverDevice));
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, _, _));
EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, _, _));
ASSERT_TRUE(
GetWireServer()->InjectDevice(serverDevice, reservation.id, reservation.generation));
wgpuDeviceGetDefaultQueue(reservation.device);
WGPUQueue apiQueue = api.GetNewQueue();
EXPECT_CALL(api, DeviceGetDefaultQueue(serverDevice)).WillOnce(Return(apiQueue));
FlushClient();
// Called on shutdown.
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, nullptr, nullptr))
.Times(Exactly(1));
EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, nullptr, nullptr))
.Times(Exactly(1));
}
// Test that the list of live devices can be reflected using GetDevice.
TEST_F(WireInjectDeviceTests, ReflectLiveDevices) {
// Reserve two devices.
ReservedDevice reservation1 = GetWireClient()->ReserveDevice();
ReservedDevice reservation2 = GetWireClient()->ReserveDevice();
// Inject both devices.
WGPUDevice serverDevice1 = api.GetNewDevice();
EXPECT_CALL(api, DeviceReference(serverDevice1));
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice1, _, _));
EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice1, _, _));
ASSERT_TRUE(
GetWireServer()->InjectDevice(serverDevice1, reservation1.id, reservation1.generation));
WGPUDevice serverDevice2 = api.GetNewDevice();
EXPECT_CALL(api, DeviceReference(serverDevice2));
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice2, _, _));
EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice2, _, _));
ASSERT_TRUE(
GetWireServer()->InjectDevice(serverDevice2, reservation2.id, reservation2.generation));
// Test that both devices can be reflected.
ASSERT_EQ(serverDevice1, GetWireServer()->GetDevice(reservation1.id, reservation1.generation));
ASSERT_EQ(serverDevice2, GetWireServer()->GetDevice(reservation2.id, reservation2.generation));
// Release the first device
wgpuDeviceRelease(reservation1.device);
EXPECT_CALL(api, DeviceRelease(serverDevice1));
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice1, nullptr, nullptr)).Times(1);
EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice1, nullptr, nullptr)).Times(1);
FlushClient();
// The first device should no longer reflect, but the second should
ASSERT_EQ(nullptr, GetWireServer()->GetDevice(reservation1.id, reservation1.generation));
ASSERT_EQ(serverDevice2, GetWireServer()->GetDevice(reservation2.id, reservation2.generation));
// Called on shutdown.
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice2, nullptr, nullptr)).Times(1);
EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice2, nullptr, nullptr)).Times(1);
}

View File

@ -83,9 +83,10 @@ class WireMultipleDeviceTests : public testing::Test {
// These are called on server destruction to clear the callbacks. They must not be // These are called on server destruction to clear the callbacks. They must not be
// called after the server is destroyed. // called after the server is destroyed.
EXPECT_CALL(mApi, OnDeviceSetUncapturedErrorCallback(_, nullptr, nullptr)) EXPECT_CALL(mApi, OnDeviceSetUncapturedErrorCallback(mServerDevice, nullptr, nullptr))
.Times(Exactly(1));
EXPECT_CALL(mApi, OnDeviceSetDeviceLostCallback(mServerDevice, nullptr, nullptr))
.Times(Exactly(1)); .Times(Exactly(1));
EXPECT_CALL(mApi, OnDeviceSetDeviceLostCallback(_, nullptr, nullptr)).Times(Exactly(1));
mWireServer = nullptr; mWireServer = nullptr;
} }

View File

@ -88,15 +88,23 @@ void WireTest::TearDown() {
api.IgnoreAllReleaseCalls(); api.IgnoreAllReleaseCalls();
mWireClient = nullptr; mWireClient = nullptr;
if (mWireServer) { if (mWireServer && apiDevice) {
// These are called on server destruction to clear the callbacks. They must not be // These are called on server destruction to clear the callbacks. They must not be
// called after the server is destroyed. // called after the server is destroyed.
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(_, nullptr, nullptr)).Times(Exactly(1)); EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr))
EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(_, nullptr, nullptr)).Times(Exactly(1)); .Times(Exactly(1));
EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr))
.Times(Exactly(1));
} }
mWireServer = nullptr; mWireServer = nullptr;
} }
// This should be called if |apiDevice| is no longer exists on the wire.
// This signals that expectations in |TearDowb| shouldn't be added.
void WireTest::DefaultApiDeviceWasReleased() {
apiDevice = nullptr;
}
void WireTest::FlushClient(bool success) { void WireTest::FlushClient(bool success) {
ASSERT_EQ(mC2sBuf->Flush(), success); ASSERT_EQ(mC2sBuf->Flush(), success);
@ -123,8 +131,10 @@ void WireTest::DeleteServer() {
if (mWireServer) { if (mWireServer) {
// These are called on server destruction to clear the callbacks. They must not be // These are called on server destruction to clear the callbacks. They must not be
// called after the server is destroyed. // called after the server is destroyed.
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(_, nullptr, nullptr)).Times(Exactly(1)); EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr))
EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(_, nullptr, nullptr)).Times(Exactly(1)); .Times(Exactly(1));
EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr))
.Times(Exactly(1));
} }
mWireServer = nullptr; mWireServer = nullptr;
} }

View File

@ -123,6 +123,8 @@ class WireTest : public testing::Test {
void FlushClient(bool success = true); void FlushClient(bool success = true);
void FlushServer(bool success = true); void FlushServer(bool success = true);
void DefaultApiDeviceWasReleased();
testing::StrictMock<MockProcTable> api; testing::StrictMock<MockProcTable> api;
WGPUDevice apiDevice; WGPUDevice apiDevice;
WGPUQueue apiQueue; WGPUQueue apiQueue;