diff --git a/generator/templates/dawn_wire/server/ServerDoers.cpp b/generator/templates/dawn_wire/server/ServerDoers.cpp index 0c6ce4298b..025233655f 100644 --- a/generator/templates/dawn_wire/server/ServerDoers.cpp +++ b/generator/templates/dawn_wire/server/ServerDoers.cpp @@ -98,6 +98,11 @@ namespace dawn_wire { namespace server { *data->childObjectTypesAndIds.begin()); DoDestroyObject(childObjectType, childObjectId); } + if (data->handle != nullptr) { + //* Deregisters uncaptured error and device lost callbacks since + //* they should not be forwarded if the device no longer exists on the wire. + ClearDeviceCallbacks(data->handle); + } {% endif %} if (data->handle != nullptr) { mProcs.{{as_varName(type.name, Name("release"))}}(data->handle); diff --git a/src/dawn_wire/WireClient.cpp b/src/dawn_wire/WireClient.cpp index de73a84462..0dcea37306 100644 --- a/src/dawn_wire/WireClient.cpp +++ b/src/dawn_wire/WireClient.cpp @@ -37,6 +37,10 @@ namespace dawn_wire { return mImpl->ReserveTexture(device); } + ReservedDevice WireClient::ReserveDevice() { + return mImpl->ReserveDevice(); + } + void WireClient::Disconnect() { mImpl->Disconnect(); } diff --git a/src/dawn_wire/WireServer.cpp b/src/dawn_wire/WireServer.cpp index 723f691324..763b5fc65f 100644 --- a/src/dawn_wire/WireServer.cpp +++ b/src/dawn_wire/WireServer.cpp @@ -40,6 +40,14 @@ namespace dawn_wire { return mImpl->InjectTexture(texture, id, generation, deviceId, deviceGeneration); } + bool WireServer::InjectDevice(WGPUDevice device, uint32_t id, uint32_t generation) { + return mImpl->InjectDevice(device, id, generation); + } + + WGPUDevice WireServer::GetDevice(uint32_t id, uint32_t generation) { + return mImpl->GetDevice(id, generation); + } + namespace server { MemoryTransferService::MemoryTransferService() = default; diff --git a/src/dawn_wire/client/Client.cpp b/src/dawn_wire/client/Client.cpp index 0ca5f61c8c..b5665293f9 100644 --- a/src/dawn_wire/client/Client.cpp +++ b/src/dawn_wire/client/Client.cpp @@ -85,8 +85,13 @@ namespace dawn_wire { namespace client { } WGPUDevice Client::GetDevice() { + // This function is deprecated. The concept of a "default" device on the wire + // will be removed in favor of ReserveDevice/InjectDevice. if (mDevice == nullptr) { - mDevice = DeviceAllocator().New(this)->object.get(); + ReservedDevice reservation = ReserveDevice(); + mDevice = FromAPI(reservation.device); + ASSERT(reservation.id == 1); + ASSERT(reservation.generation == 0); } return reinterpret_cast(mDevice); } @@ -103,6 +108,16 @@ namespace dawn_wire { namespace client { return result; } + ReservedDevice Client::ReserveDevice() { + auto* allocation = DeviceAllocator().New(this); + + ReservedDevice result; + result.device = ToAPI(allocation->object.get()); + result.id = allocation->object->id; + result.generation = allocation->generation; + return result; + } + void Client::Disconnect() { mDisconnected = true; mSerializer = ChunkedCommandSerializer(NoopCommandSerializer::GetInstance()); diff --git a/src/dawn_wire/client/Client.h b/src/dawn_wire/client/Client.h index 4902df890e..dd7ac76260 100644 --- a/src/dawn_wire/client/Client.h +++ b/src/dawn_wire/client/Client.h @@ -46,6 +46,7 @@ namespace dawn_wire { namespace client { } ReservedTexture ReserveTexture(WGPUDevice device); + ReservedDevice ReserveDevice(); template void SerializeCommand(const Cmd& cmd) { diff --git a/src/dawn_wire/client/Device.cpp b/src/dawn_wire/client/Device.cpp index b6ee4fceea..2d643cbfcb 100644 --- a/src/dawn_wire/client/Device.cpp +++ b/src/dawn_wire/client/Device.cpp @@ -45,15 +45,6 @@ namespace dawn_wire { namespace client { } }; #endif // DAWN_ENABLE_ASSERTS - // Get the default queue for this device. - auto* allocation = client->QueueAllocator().New(client); - mDefaultQueue = allocation->object.get(); - - DeviceGetDefaultQueueCmd cmd; - cmd.self = ToAPI(this); - cmd.result = ObjectHandle{allocation->object->id, allocation->generation}; - - client->SerializeCommand(cmd); } Device::~Device() { @@ -206,6 +197,22 @@ namespace dawn_wire { namespace client { } WGPUQueue Device::GetDefaultQueue() { + // The queue is lazily created because if a Device is created by + // Reserve/Inject, we cannot send the getDefaultQueue message until + // it has been injected on the Server. It cannot happen immediately + // on construction. + if (mDefaultQueue == nullptr) { + // Get the default queue for this device. + auto* allocation = client->QueueAllocator().New(client); + mDefaultQueue = allocation->object.get(); + + DeviceGetDefaultQueueCmd cmd; + cmd.self = ToAPI(this); + cmd.result = ObjectHandle{allocation->object->id, allocation->generation}; + + client->SerializeCommand(cmd); + } + mDefaultQueue->refcount++; return ToAPI(mDefaultQueue); } diff --git a/src/dawn_wire/server/ObjectStorage.h b/src/dawn_wire/server/ObjectStorage.h index 74cc5a77e9..c803f53337 100644 --- a/src/dawn_wire/server/ObjectStorage.h +++ b/src/dawn_wire/server/ObjectStorage.h @@ -160,6 +160,17 @@ namespace dawn_wire { namespace server { return objects; } + std::vector GetAllHandles() { + std::vector objects; + for (Data& data : mKnown) { + if (data.allocated && data.handle != nullptr) { + objects.push_back(data.handle); + } + } + + return objects; + } + private: std::vector mKnown; }; diff --git a/src/dawn_wire/server/Server.cpp b/src/dawn_wire/server/Server.cpp index 67a9dd5618..39e50ea118 100644 --- a/src/dawn_wire/server/Server.cpp +++ b/src/dawn_wire/server/Server.cpp @@ -23,7 +23,6 @@ namespace dawn_wire { namespace server { MemoryTransferService* memoryTransferService) : mSerializer(serializer), mProcs(procs), - mDeviceOnCreation(device), mMemoryTransferService(memoryTransferService), mIsAlive(std::make_shared(true)) { if (mMemoryTransferService == nullptr) { @@ -31,38 +30,21 @@ namespace dawn_wire { namespace server { mOwnedMemoryTransferService = CreateInlineMemoryTransferService(); mMemoryTransferService = mOwnedMemoryTransferService.get(); } - // The client-server knowledge is bootstrapped with device 1. - auto* deviceData = DeviceObjects().Allocate(1); - deviceData->handle = device; - // Take an extra ref. All objects may be freed by the client, but this - // one is externally owned. - mProcs.deviceReference(device); - - // Note: these callbacks are manually inlined here since they do not acquire and - // free their userdata. - mProcs.deviceSetUncapturedErrorCallback( - device, - [](WGPUErrorType type, const char* message, void* userdata) { - Server* server = static_cast(userdata); - server->OnUncapturedError(type, message); - }, - this); - mProcs.deviceSetDeviceLostCallback( - device, - [](const char* message, void* userdata) { - Server* server = static_cast(userdata); - server->OnDeviceLost(message); - }, - this); + // For the deprecated initialization path: + // The client-server knowledge is bootstrapped with device 1, generation 0. + if (device != nullptr) { + bool success = InjectDevice(device, 1, 0); + ASSERT(success); + } } Server::~Server() { // Un-set the error and lost callbacks since we cannot forward them // after the server has been destroyed. - mProcs.deviceSetUncapturedErrorCallback(mDeviceOnCreation, nullptr, nullptr); - mProcs.deviceSetDeviceLostCallback(mDeviceOnCreation, nullptr, nullptr); - + for (WGPUDevice device : DeviceObjects().GetAllHandles()) { + ClearDeviceCallbacks(device); + } DestroyAllObjects(mProcs); } @@ -71,6 +53,7 @@ namespace dawn_wire { namespace server { uint32_t generation, uint32_t deviceId, uint32_t deviceGeneration) { + ASSERT(texture != nullptr); ObjectData* device = DeviceObjects().Get(deviceId); if (device == nullptr || device->generation != deviceGeneration) { return false; @@ -97,6 +80,57 @@ namespace dawn_wire { namespace server { return true; } + bool Server::InjectDevice(WGPUDevice device, uint32_t id, uint32_t generation) { + ASSERT(device != nullptr); + ObjectData* data = DeviceObjects().Allocate(id); + if (data == nullptr) { + return false; + } + + data->handle = device; + data->generation = generation; + data->allocated = true; + + // The device is externally owned so it shouldn't be destroyed when we receive a destroy + // message from the client. Add a reference to counterbalance the eventual release. + mProcs.deviceReference(device); + + // Set callbacks to forward errors to the client. + // Note: these callbacks are manually inlined here since they do not acquire and + // free their userdata. + mProcs.deviceSetUncapturedErrorCallback( + device, + [](WGPUErrorType type, const char* message, void* userdata) { + Server* server = static_cast(userdata); + server->OnUncapturedError(type, message); + }, + this); + mProcs.deviceSetDeviceLostCallback( + device, + [](const char* message, void* userdata) { + Server* server = static_cast(userdata); + server->OnDeviceLost(message); + }, + this); + + return true; + } + + WGPUDevice Server::GetDevice(uint32_t id, uint32_t generation) { + ObjectData* data = DeviceObjects().Get(id); + if (data == nullptr || data->generation != generation) { + return nullptr; + } + return data->handle; + } + + void Server::ClearDeviceCallbacks(WGPUDevice device) { + // Un-set the error and lost callbacks since we cannot forward them + // after the server has been destroyed. + mProcs.deviceSetUncapturedErrorCallback(device, nullptr, nullptr); + mProcs.deviceSetDeviceLostCallback(device, nullptr, nullptr); + } + bool TrackDeviceChild(ObjectDataBase* device, ObjectType type, ObjectId id) { auto it = static_cast*>(device)->childObjectTypesAndIds.insert( PackObjectTypeAndId(type, id)); diff --git a/src/dawn_wire/server/Server.h b/src/dawn_wire/server/Server.h index f45ed0d6a4..4056896aaa 100644 --- a/src/dawn_wire/server/Server.h +++ b/src/dawn_wire/server/Server.h @@ -167,6 +167,10 @@ namespace dawn_wire { namespace server { uint32_t deviceId, uint32_t deviceGeneration); + bool InjectDevice(WGPUDevice device, uint32_t id, uint32_t generation); + + WGPUDevice GetDevice(uint32_t id, uint32_t generation); + template ::value>> std::unique_ptr MakeUserdata() { @@ -186,6 +190,7 @@ namespace dawn_wire { namespace server { mSerializer.SerializeCommand(cmd, extraSize, SerializeExtraSize); } + void ClearDeviceCallbacks(WGPUDevice device); // Error callbacks void OnUncapturedError(WGPUErrorType type, const char* message); @@ -212,7 +217,6 @@ namespace dawn_wire { namespace server { WireDeserializeAllocator mAllocator; ChunkedCommandSerializer mSerializer; DawnProcTable mProcs; - WGPUDevice mDeviceOnCreation; std::unique_ptr mOwnedMemoryTransferService = nullptr; MemoryTransferService* mMemoryTransferService = nullptr; diff --git a/src/include/dawn_wire/WireClient.h b/src/include/dawn_wire/WireClient.h index 8af02a9719..b8f12472a3 100644 --- a/src/include/dawn_wire/WireClient.h +++ b/src/include/dawn_wire/WireClient.h @@ -38,6 +38,12 @@ namespace dawn_wire { uint32_t deviceGeneration; }; + struct ReservedDevice { + WGPUDevice device; + uint32_t id; + uint32_t generation; + }; + struct DAWN_WIRE_EXPORT WireClientDescriptor { CommandSerializer* serializer; client::MemoryTransferService* memoryTransferService = nullptr; @@ -53,6 +59,7 @@ namespace dawn_wire { size_t size) override final; ReservedTexture ReserveTexture(WGPUDevice device); + ReservedDevice ReserveDevice(); // Disconnects the client. // Commands allocated after this point will not be sent. diff --git a/src/include/dawn_wire/WireServer.h b/src/include/dawn_wire/WireServer.h index ad36f4402c..9ff6fed145 100644 --- a/src/include/dawn_wire/WireServer.h +++ b/src/include/dawn_wire/WireServer.h @@ -50,6 +50,17 @@ namespace dawn_wire { uint32_t deviceId = 1, uint32_t deviceGeneration = 0); + bool InjectDevice(WGPUDevice device, uint32_t id, uint32_t generation); + + // Look up a device by (id, generation) pair. Returns nullptr if the generation + // has expired or the id is not found. + // The Wire does not have destroy hooks to allow an embedder to observe when an object + // has been destroyed, but in Chrome, we need to know the list of live devices so we + // can call device.Tick() on all of them periodically to ensure progress on asynchronous + // work is made. Getting this list can be done by tracking the (id, generation) of + // previously injected devices, and observing if GetDevice(id, generation) returns non-null. + WGPUDevice GetDevice(uint32_t id, uint32_t generation); + private: std::unique_ptr mImpl; }; diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn index 5a3e310e42..e614db2b3c 100644 --- a/src/tests/BUILD.gn +++ b/src/tests/BUILD.gn @@ -223,6 +223,7 @@ test("dawn_unittests") { "unittests/wire/WireErrorCallbackTests.cpp", "unittests/wire/WireExtensionTests.cpp", "unittests/wire/WireFenceTests.cpp", + "unittests/wire/WireInjectDeviceTests.cpp", "unittests/wire/WireInjectTextureTests.cpp", "unittests/wire/WireMemoryTransferServiceTests.cpp", "unittests/wire/WireMultipleDeviceTests.cpp", diff --git a/src/tests/unittests/wire/WireDestroyObjectTests.cpp b/src/tests/unittests/wire/WireDestroyObjectTests.cpp index 34b976dcc9..2c7ddc29e5 100644 --- a/src/tests/unittests/wire/WireDestroyObjectTests.cpp +++ b/src/tests/unittests/wire/WireDestroyObjectTests.cpp @@ -36,10 +36,19 @@ TEST_F(WireDestroyObjectTests, DestroyDeviceDestroysChildren) { // The device and child objects should be released. EXPECT_CALL(api, CommandEncoderRelease(apiEncoder)).InSequence(s1); EXPECT_CALL(api, QueueRelease(apiQueue)).InSequence(s2); + EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr)) + .Times(1) + .InSequence(s1, s2); + EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr)) + .Times(1) + .InSequence(s1, s2); EXPECT_CALL(api, DeviceRelease(apiDevice)).InSequence(s1, s2); FlushClient(); + // Signal that we already released and cleared callbacks for |apiDevice| + DefaultApiDeviceWasReleased(); + // Using the command encoder should be an error. wgpuCommandEncoderFinish(encoder, nullptr); FlushClient(false); @@ -82,8 +91,17 @@ TEST_F(WireDestroyObjectTests, ImplicitInjectErrorAfterDestroyDevice) { // The device and child objects alre also released. EXPECT_CALL(api, BufferRelease(apiBuffer)).InSequence(s1); EXPECT_CALL(api, QueueRelease(apiQueue)).InSequence(s2); + EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr)) + .Times(1) + .InSequence(s1, s2); + EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr)) + .Times(1) + .InSequence(s1, s2); EXPECT_CALL(api, DeviceRelease(apiDevice)).InSequence(s1, s2); FlushClient(); + + // Signal that we already released and cleared callbacks for |apiDevice| + DefaultApiDeviceWasReleased(); } } diff --git a/src/tests/unittests/wire/WireDisconnectTests.cpp b/src/tests/unittests/wire/WireDisconnectTests.cpp index f44df13c91..d3f65a91c3 100644 --- a/src/tests/unittests/wire/WireDisconnectTests.cpp +++ b/src/tests/unittests/wire/WireDisconnectTests.cpp @@ -149,6 +149,15 @@ TEST_F(WireDisconnectTests, DeleteClientDestroysObjects) { EXPECT_CALL(api, QueueRelease(apiQueue)).Times(1).InSequence(s1); EXPECT_CALL(api, CommandEncoderRelease(apiCommandEncoder)).Times(1).InSequence(s2); EXPECT_CALL(api, SamplerRelease(apiSampler)).Times(1).InSequence(s3); + EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr)) + .Times(1) + .InSequence(s1, s2); + EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr)) + .Times(1) + .InSequence(s1, s2); EXPECT_CALL(api, DeviceRelease(apiDevice)).Times(1).InSequence(s1, s2, s3); FlushClient(); + + // Signal that we already released and cleared callbacks for |apiDevice| + DefaultApiDeviceWasReleased(); } diff --git a/src/tests/unittests/wire/WireInjectDeviceTests.cpp b/src/tests/unittests/wire/WireInjectDeviceTests.cpp new file mode 100644 index 0000000000..8f1dda3b6a --- /dev/null +++ b/src/tests/unittests/wire/WireInjectDeviceTests.cpp @@ -0,0 +1,184 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tests/unittests/wire/WireTest.h" + +#include "dawn_wire/WireClient.h" +#include "dawn_wire/WireServer.h" + +using namespace testing; +using namespace dawn_wire; + +class WireInjectDeviceTests : public WireTest { + public: + WireInjectDeviceTests() { + } + ~WireInjectDeviceTests() override = default; +}; + +// Test that reserving and injecting a device makes calls on the client object forward to the +// server object correctly. +TEST_F(WireInjectDeviceTests, CallAfterReserveInject) { + ReservedDevice reservation = GetWireClient()->ReserveDevice(); + + WGPUDevice serverDevice = api.GetNewDevice(); + EXPECT_CALL(api, DeviceReference(serverDevice)); + EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, _, _)); + EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, _, _)); + ASSERT_TRUE( + GetWireServer()->InjectDevice(serverDevice, reservation.id, reservation.generation)); + + WGPUBufferDescriptor bufferDesc = {}; + wgpuDeviceCreateBuffer(reservation.device, &bufferDesc); + WGPUBuffer serverBuffer = api.GetNewBuffer(); + EXPECT_CALL(api, DeviceCreateBuffer(serverDevice, _)).WillOnce(Return(serverBuffer)); + FlushClient(); + + // Called on shutdown. + EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, nullptr, nullptr)) + .Times(Exactly(1)); + EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, nullptr, nullptr)) + .Times(Exactly(1)); +} + +// Test that reserve correctly returns different IDs each time. +TEST_F(WireInjectDeviceTests, ReserveDifferentIDs) { + ReservedDevice reservation1 = GetWireClient()->ReserveDevice(); + ReservedDevice reservation2 = GetWireClient()->ReserveDevice(); + + ASSERT_NE(reservation1.id, reservation2.id); + ASSERT_NE(reservation1.device, reservation2.device); +} + +// Test that injecting the same id without a destroy first fails. +TEST_F(WireInjectDeviceTests, InjectExistingID) { + ReservedDevice reservation = GetWireClient()->ReserveDevice(); + + WGPUDevice serverDevice = api.GetNewDevice(); + EXPECT_CALL(api, DeviceReference(serverDevice)); + EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, _, _)); + EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, _, _)); + ASSERT_TRUE( + GetWireServer()->InjectDevice(serverDevice, reservation.id, reservation.generation)); + + // ID already in use, call fails. + ASSERT_FALSE( + GetWireServer()->InjectDevice(serverDevice, reservation.id, reservation.generation)); + + // Called on shutdown. + EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, nullptr, nullptr)) + .Times(Exactly(1)); + EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, nullptr, nullptr)) + .Times(Exactly(1)); +} + +// Test that the server only borrows the device and does a single reference-release +TEST_F(WireInjectDeviceTests, InjectedDeviceLifetime) { + ReservedDevice reservation = GetWireClient()->ReserveDevice(); + + // Injecting the device adds a reference + WGPUDevice serverDevice = api.GetNewDevice(); + EXPECT_CALL(api, DeviceReference(serverDevice)); + EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, _, _)); + EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, _, _)); + ASSERT_TRUE( + GetWireServer()->InjectDevice(serverDevice, reservation.id, reservation.generation)); + + // Releasing the device removes a single reference and clears its error callbacks. + wgpuDeviceRelease(reservation.device); + EXPECT_CALL(api, DeviceRelease(serverDevice)); + EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, nullptr, nullptr)).Times(1); + EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, nullptr, nullptr)).Times(1); + FlushClient(); + + // Deleting the server doesn't release a second reference. + DeleteServer(); + Mock::VerifyAndClearExpectations(&api); +} + +// Test that it is an error to get the default queue of a device before it has been +// injected on the server. +TEST_F(WireInjectDeviceTests, GetQueueBeforeInject) { + ReservedDevice reservation = GetWireClient()->ReserveDevice(); + + wgpuDeviceGetDefaultQueue(reservation.device); + FlushClient(false); +} + +// Test that it is valid to get the default queue of a device after it has been +// injected on the server. +TEST_F(WireInjectDeviceTests, GetQueueAfterInject) { + ReservedDevice reservation = GetWireClient()->ReserveDevice(); + + WGPUDevice serverDevice = api.GetNewDevice(); + EXPECT_CALL(api, DeviceReference(serverDevice)); + EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, _, _)); + EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, _, _)); + ASSERT_TRUE( + GetWireServer()->InjectDevice(serverDevice, reservation.id, reservation.generation)); + + wgpuDeviceGetDefaultQueue(reservation.device); + + WGPUQueue apiQueue = api.GetNewQueue(); + EXPECT_CALL(api, DeviceGetDefaultQueue(serverDevice)).WillOnce(Return(apiQueue)); + FlushClient(); + + // Called on shutdown. + EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice, nullptr, nullptr)) + .Times(Exactly(1)); + EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice, nullptr, nullptr)) + .Times(Exactly(1)); +} + +// Test that the list of live devices can be reflected using GetDevice. +TEST_F(WireInjectDeviceTests, ReflectLiveDevices) { + // Reserve two devices. + ReservedDevice reservation1 = GetWireClient()->ReserveDevice(); + ReservedDevice reservation2 = GetWireClient()->ReserveDevice(); + + // Inject both devices. + + WGPUDevice serverDevice1 = api.GetNewDevice(); + EXPECT_CALL(api, DeviceReference(serverDevice1)); + EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice1, _, _)); + EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice1, _, _)); + ASSERT_TRUE( + GetWireServer()->InjectDevice(serverDevice1, reservation1.id, reservation1.generation)); + + WGPUDevice serverDevice2 = api.GetNewDevice(); + EXPECT_CALL(api, DeviceReference(serverDevice2)); + EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice2, _, _)); + EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice2, _, _)); + ASSERT_TRUE( + GetWireServer()->InjectDevice(serverDevice2, reservation2.id, reservation2.generation)); + + // Test that both devices can be reflected. + ASSERT_EQ(serverDevice1, GetWireServer()->GetDevice(reservation1.id, reservation1.generation)); + ASSERT_EQ(serverDevice2, GetWireServer()->GetDevice(reservation2.id, reservation2.generation)); + + // Release the first device + wgpuDeviceRelease(reservation1.device); + EXPECT_CALL(api, DeviceRelease(serverDevice1)); + EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice1, nullptr, nullptr)).Times(1); + EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice1, nullptr, nullptr)).Times(1); + FlushClient(); + + // The first device should no longer reflect, but the second should + ASSERT_EQ(nullptr, GetWireServer()->GetDevice(reservation1.id, reservation1.generation)); + ASSERT_EQ(serverDevice2, GetWireServer()->GetDevice(reservation2.id, reservation2.generation)); + + // Called on shutdown. + EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(serverDevice2, nullptr, nullptr)).Times(1); + EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(serverDevice2, nullptr, nullptr)).Times(1); +} diff --git a/src/tests/unittests/wire/WireMultipleDeviceTests.cpp b/src/tests/unittests/wire/WireMultipleDeviceTests.cpp index 75cedcd7f5..216c122848 100644 --- a/src/tests/unittests/wire/WireMultipleDeviceTests.cpp +++ b/src/tests/unittests/wire/WireMultipleDeviceTests.cpp @@ -83,9 +83,10 @@ class WireMultipleDeviceTests : public testing::Test { // These are called on server destruction to clear the callbacks. They must not be // called after the server is destroyed. - EXPECT_CALL(mApi, OnDeviceSetUncapturedErrorCallback(_, nullptr, nullptr)) + EXPECT_CALL(mApi, OnDeviceSetUncapturedErrorCallback(mServerDevice, nullptr, nullptr)) + .Times(Exactly(1)); + EXPECT_CALL(mApi, OnDeviceSetDeviceLostCallback(mServerDevice, nullptr, nullptr)) .Times(Exactly(1)); - EXPECT_CALL(mApi, OnDeviceSetDeviceLostCallback(_, nullptr, nullptr)).Times(Exactly(1)); mWireServer = nullptr; } diff --git a/src/tests/unittests/wire/WireTest.cpp b/src/tests/unittests/wire/WireTest.cpp index 260951194b..d3f17c3ad6 100644 --- a/src/tests/unittests/wire/WireTest.cpp +++ b/src/tests/unittests/wire/WireTest.cpp @@ -88,15 +88,23 @@ void WireTest::TearDown() { api.IgnoreAllReleaseCalls(); mWireClient = nullptr; - if (mWireServer) { + if (mWireServer && apiDevice) { // These are called on server destruction to clear the callbacks. They must not be // called after the server is destroyed. - EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(_, nullptr, nullptr)).Times(Exactly(1)); - EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(_, nullptr, nullptr)).Times(Exactly(1)); + EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr)) + .Times(Exactly(1)); + EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr)) + .Times(Exactly(1)); } mWireServer = nullptr; } +// This should be called if |apiDevice| is no longer exists on the wire. +// This signals that expectations in |TearDowb| shouldn't be added. +void WireTest::DefaultApiDeviceWasReleased() { + apiDevice = nullptr; +} + void WireTest::FlushClient(bool success) { ASSERT_EQ(mC2sBuf->Flush(), success); @@ -123,8 +131,10 @@ void WireTest::DeleteServer() { if (mWireServer) { // These are called on server destruction to clear the callbacks. They must not be // called after the server is destroyed. - EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(_, nullptr, nullptr)).Times(Exactly(1)); - EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(_, nullptr, nullptr)).Times(Exactly(1)); + EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(apiDevice, nullptr, nullptr)) + .Times(Exactly(1)); + EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(apiDevice, nullptr, nullptr)) + .Times(Exactly(1)); } mWireServer = nullptr; } diff --git a/src/tests/unittests/wire/WireTest.h b/src/tests/unittests/wire/WireTest.h index 95fd3077ec..03ac641168 100644 --- a/src/tests/unittests/wire/WireTest.h +++ b/src/tests/unittests/wire/WireTest.h @@ -123,6 +123,8 @@ class WireTest : public testing::Test { void FlushClient(bool success = true); void FlushServer(bool success = true); + void DefaultApiDeviceWasReleased(); + testing::StrictMock api; WGPUDevice apiDevice; WGPUQueue apiQueue;