From 3120d5ea0de8840298cc1dd3e702a1a3ebfe9523 Mon Sep 17 00:00:00 2001 From: Austin Eng Date: Wed, 11 Nov 2020 19:46:18 +0000 Subject: [PATCH] Track and destroy all child objects on wire client destruction This is needed so that: 1. We can support multiple devices in the wire. The device will need to know how to destroy its child objects. 2. The wire needs to be aware of all objects and their in-flight callbacks so that it can reject them if the wire is disconnnected. A future change will handle this. 3. Fix leaks of objects on page teardown. When the page is torn down, the wire client is destroyed, and we skip calling release() for all objects since the object holding the proc table was also destroyed. Bug: dawn:384, dawn:556 Change-Id: Ie23afe4e515b02e924fcfc2db92b749fd2257c9c Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/31160 Reviewed-by: Austin Eng Commit-Queue: Austin Eng --- generator/dawn_json_generator.py | 3 ++ generator/templates/dawn_wire/ObjectType.h | 34 +++++++++++++++++++ generator/templates/dawn_wire/WireCmd.h | 8 ++--- .../templates/dawn_wire/client/ApiObjects.h | 9 +++++ .../templates/dawn_wire/client/ClientBase.h | 17 +++++++--- src/dawn_wire/BUILD.gn | 1 + src/dawn_wire/client/Client.cpp | 16 +++++++-- src/dawn_wire/client/Client.h | 7 ++++ src/dawn_wire/client/Device.cpp | 25 +++++++++----- src/dawn_wire/client/Device.h | 11 ++++++ src/dawn_wire/client/Fence.h | 1 - src/dawn_wire/client/ObjectAllocator.h | 3 ++ src/dawn_wire/client/ObjectBase.h | 14 ++++++-- .../unittests/wire/WireDisconnectTests.cpp | 24 +++++++++++++ src/tests/unittests/wire/WireTest.cpp | 4 +++ src/tests/unittests/wire/WireTest.h | 1 + 16 files changed, 151 insertions(+), 27 deletions(-) create mode 100644 generator/templates/dawn_wire/ObjectType.h diff --git a/generator/dawn_json_generator.py b/generator/dawn_json_generator.py index 02c572337c..01d7040b30 100644 --- a/generator/dawn_json_generator.py +++ b/generator/dawn_json_generator.py @@ -762,6 +762,9 @@ class MultiGeneratorFromDawnJSON(Generator): lambda arg: annotated(as_wireType(arg.type), arg), }, additional_params ] + renders.append( + FileRender('dawn_wire/ObjectType.h', + 'src/dawn_wire/ObjectType_autogen.h', wire_params)) renders.append( FileRender('dawn_wire/WireCmd.h', 'src/dawn_wire/WireCmd_autogen.h', wire_params)) diff --git a/generator/templates/dawn_wire/ObjectType.h b/generator/templates/dawn_wire/ObjectType.h new file mode 100644 index 0000000000..e049d3782e --- /dev/null +++ b/generator/templates/dawn_wire/ObjectType.h @@ -0,0 +1,34 @@ +//* Copyright 2020 The Dawn Authors +//* +//* Licensed under the Apache License, Version 2.0 (the "License"); +//* you may not use this file except in compliance with the License. +//* You may obtain a copy of the License at +//* +//* http://www.apache.org/licenses/LICENSE-2.0 +//* +//* Unless required by applicable law or agreed to in writing, software +//* distributed under the License is distributed on an "AS IS" BASIS, +//* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +//* See the License for the specific language governing permissions and +//* limitations under the License. + +#ifndef DAWNWIRE_OBJECTTPYE_AUTOGEN_H_ +#define DAWNWIRE_OBJECTTPYE_AUTOGEN_H_ + +#include "common/ityp_array.h" + +namespace dawn_wire { + + enum class ObjectType : uint32_t { + {% for type in by_category["object"] %} + {{type.name.CamelCase()}}, + {% endfor %} + }; + + template + using PerObjectType = ityp::array; + +} // namespace dawn_wire + + +#endif // DAWNWIRE_OBJECTTPYE_AUTOGEN_H_ diff --git a/generator/templates/dawn_wire/WireCmd.h b/generator/templates/dawn_wire/WireCmd.h index a2216921c5..a592d2f8d2 100644 --- a/generator/templates/dawn_wire/WireCmd.h +++ b/generator/templates/dawn_wire/WireCmd.h @@ -17,6 +17,8 @@ #include +#include "dawn_wire/ObjectType_autogen.h" + namespace dawn_wire { using ObjectId = uint32_t; @@ -72,12 +74,6 @@ namespace dawn_wire { {% endfor %} }; - enum class ObjectType : uint32_t { - {% for type in by_category["object"] %} - {{type.name.CamelCase()}}, - {% endfor %} - }; - //* Enum used as a prefix to each command on the wire format. enum class WireCmd : uint32_t { {% for command in cmd_records["command"] %} diff --git a/generator/templates/dawn_wire/client/ApiObjects.h b/generator/templates/dawn_wire/client/ApiObjects.h index 288c7004de..46930e0420 100644 --- a/generator/templates/dawn_wire/client/ApiObjects.h +++ b/generator/templates/dawn_wire/client/ApiObjects.h @@ -15,8 +15,14 @@ #ifndef DAWNWIRE_CLIENT_APIOBJECTS_AUTOGEN_H_ #define DAWNWIRE_CLIENT_APIOBJECTS_AUTOGEN_H_ +#include "dawn_wire/ObjectType_autogen.h" +#include "dawn_wire/client/ObjectBase.h" + namespace dawn_wire { namespace client { + template + static constexpr ObjectType ObjectTypeToTypeEnum = static_cast(-1); + {% for type in by_category["object"] %} {% set Type = type.name.CamelCase() %} {% if type.name.CamelCase() in client_special_objects %} @@ -34,6 +40,9 @@ namespace dawn_wire { namespace client { return reinterpret_cast(obj); } + template <> + static constexpr ObjectType ObjectTypeToTypeEnum<{{type.name.CamelCase()}}> = ObjectType::{{type.name.CamelCase()}}; + {% endfor %} }} // namespace dawn_wire::client diff --git a/generator/templates/dawn_wire/client/ClientBase.h b/generator/templates/dawn_wire/client/ClientBase.h index 65ff08f24d..1f2c31e9db 100644 --- a/generator/templates/dawn_wire/client/ClientBase.h +++ b/generator/templates/dawn_wire/client/ClientBase.h @@ -24,11 +24,8 @@ namespace dawn_wire { namespace client { class ClientBase : public ChunkedCommandHandler, public ObjectIdProvider { public: - ClientBase() { - } - - virtual ~ClientBase() { - } + ClientBase() = default; + virtual ~ClientBase() = default; {% for type in by_category["object"] %} const ObjectAllocator<{{type.name.CamelCase()}}>& {{type.name.CamelCase()}}Allocator() const { @@ -39,6 +36,16 @@ namespace dawn_wire { namespace client { } {% endfor %} + void FreeObject(ObjectType objectType, ObjectBase* obj) { + switch (objectType) { + {% for type in by_category["object"] %} + case ObjectType::{{type.name.CamelCase()}}: + m{{type.name.CamelCase()}}Allocator.Free(static_cast<{{type.name.CamelCase()}}*>(obj)); + break; + {% endfor %} + } + } + private: // Implementation of the ObjectIdProvider interface {% for type in by_category["object"] %} diff --git a/src/dawn_wire/BUILD.gn b/src/dawn_wire/BUILD.gn index 0f5c70e05c..04340b006d 100644 --- a/src/dawn_wire/BUILD.gn +++ b/src/dawn_wire/BUILD.gn @@ -33,6 +33,7 @@ source_set("dawn_wire_headers") { dawn_json_generator("dawn_wire_gen") { target = "dawn_wire" outputs = [ + "src/dawn_wire/ObjectType_autogen.h", "src/dawn_wire/WireCmd_autogen.h", "src/dawn_wire/WireCmd_autogen.cpp", "src/dawn_wire/client/ApiObjects_autogen.h", diff --git a/src/dawn_wire/client/Client.cpp b/src/dawn_wire/client/Client.cpp index af0e40b091..a1a6f63b9e 100644 --- a/src/dawn_wire/client/Client.cpp +++ b/src/dawn_wire/client/Client.cpp @@ -53,8 +53,14 @@ namespace dawn_wire { namespace client { } Client::~Client() { - if (mDevice != nullptr) { - DeviceAllocator().Free(mDevice); + DestroyAllObjects(); + } + + void Client::DestroyAllObjects() { + while (!mDevices.empty()) { + // Note: We don't send a DestroyObject command for the device + // since freeing a device object is done out of band. + DeviceAllocator().Free(static_cast(mDevices.head()->value())); } } @@ -67,7 +73,7 @@ namespace dawn_wire { namespace client { ReservedTexture Client::ReserveTexture(WGPUDevice cDevice) { Device* device = FromAPI(cDevice); - ObjectAllocator::ObjectAndSerial* allocation = TextureAllocator().New(device); + auto* allocation = TextureAllocator().New(device); ReservedTexture result; result.texture = ToAPI(allocation->object.get()); @@ -83,4 +89,8 @@ namespace dawn_wire { namespace client { } } + void Client::TrackObject(Device* device) { + mDevices.Append(device); + } + }} // namespace dawn_wire::client diff --git a/src/dawn_wire/client/Client.h b/src/dawn_wire/client/Client.h index 47ec95b5db..ecdae95c31 100644 --- a/src/dawn_wire/client/Client.h +++ b/src/dawn_wire/client/Client.h @@ -18,6 +18,7 @@ #include #include +#include "common/LinkedList.h" #include "dawn_wire/ChunkedCommandSerializer.h" #include "dawn_wire/WireClient.h" #include "dawn_wire/WireCmd_autogen.h" @@ -60,7 +61,11 @@ namespace dawn_wire { namespace client { void Disconnect(); + void TrackObject(Device* device); + private: + void DestroyAllObjects(); + #include "dawn_wire/client/ClientPrototypes_autogen.inc" Device* mDevice = nullptr; @@ -68,6 +73,8 @@ namespace dawn_wire { namespace client { WireDeserializeAllocator mAllocator; MemoryTransferService* mMemoryTransferService = nullptr; std::unique_ptr mOwnedMemoryTransferService = nullptr; + + LinkedList mDevices; }; std::unique_ptr CreateInlineMemoryTransferService(); diff --git a/src/dawn_wire/client/Device.cpp b/src/dawn_wire/client/Device.cpp index 91d12fc4a4..f139230c96 100644 --- a/src/dawn_wire/client/Device.cpp +++ b/src/dawn_wire/client/Device.cpp @@ -23,10 +23,8 @@ namespace dawn_wire { namespace client { Device::Device(Client* client, uint32_t initialRefcount, uint32_t initialId) : ObjectBase(this, initialRefcount, initialId), mClient(client) { - this->device = this; - // Get the default queue for this device. - ObjectAllocator::ObjectAndSerial* allocation = mClient->QueueAllocator().New(this); + auto* allocation = mClient->QueueAllocator().New(this); mDefaultQueue = allocation->object.get(); DeviceGetDefaultQueueCmd cmd; @@ -58,14 +56,22 @@ namespace dawn_wire { namespace client { } } - // Destroy the default queue - DestroyObjectCmd cmd; - cmd.objectType = ObjectType::Queue; - cmd.objectId = mDefaultQueue->id; + DestroyAllObjects(); + } - mClient->SerializeCommand(cmd); + void Device::DestroyAllObjects() { + for (auto& objectList : mObjects) { + ObjectType objectType = static_cast(&objectList - mObjects.begin()); + while (!objectList.empty()) { + ObjectBase* object = objectList.head()->value(); - mClient->QueueAllocator().Free(mDefaultQueue); + DestroyObjectCmd cmd; + cmd.objectType = objectType; + cmd.objectId = object->id; + mClient->SerializeCommand(cmd); + mClient->FreeObject(objectType, object); + } + } } Client* Device::GetClient() { @@ -273,4 +279,5 @@ namespace dawn_wire { namespace client { return true; } + }} // namespace dawn_wire::client diff --git a/src/dawn_wire/client/Device.h b/src/dawn_wire/client/Device.h index 7d14c6fe7f..82c68990c7 100644 --- a/src/dawn_wire/client/Device.h +++ b/src/dawn_wire/client/Device.h @@ -17,7 +17,9 @@ #include +#include "common/LinkedList.h" #include "dawn_wire/WireCmd_autogen.h" +#include "dawn_wire/client/ApiObjects_autogen.h" #include "dawn_wire/client/ObjectBase.h" #include @@ -61,7 +63,14 @@ namespace dawn_wire { namespace client { WGPUQueue GetDefaultQueue(); + template + void TrackObject(T* object) { + mObjects[ObjectTypeToTypeEnum].Append(object); + } + private: + void DestroyAllObjects(); + struct ErrorScopeData { WGPUErrorCallback callback = nullptr; void* userdata = nullptr; @@ -87,6 +96,8 @@ namespace dawn_wire { namespace client { void* mDeviceLostUserdata = nullptr; Queue* mDefaultQueue = nullptr; + + PerObjectType> mObjects; }; }} // namespace dawn_wire::client diff --git a/src/dawn_wire/client/Fence.h b/src/dawn_wire/client/Fence.h index 58d89c3371..00791944bd 100644 --- a/src/dawn_wire/client/Fence.h +++ b/src/dawn_wire/client/Fence.h @@ -26,7 +26,6 @@ namespace dawn_wire { namespace client { class Fence : public ObjectBase { public: using ObjectBase::ObjectBase; - ~Fence(); void Initialize(Queue* queue, const WGPUFenceDescriptor* descriptor); diff --git a/src/dawn_wire/client/ObjectAllocator.h b/src/dawn_wire/client/ObjectAllocator.h index 215b9f4a32..bb0c4e4ed5 100644 --- a/src/dawn_wire/client/ObjectAllocator.h +++ b/src/dawn_wire/client/ObjectAllocator.h @@ -17,6 +17,7 @@ #include "common/Assert.h" #include "common/Compiler.h" +#include "dawn_wire/WireCmd_autogen.h" #include #include @@ -49,6 +50,7 @@ namespace dawn_wire { namespace client { ObjectAndSerial* New(ObjectOwner* owner) { uint32_t id = GetNewId(); auto object = std::make_unique(owner, 1, id); + owner->TrackObject(object.get()); if (id >= mObjects.size()) { ASSERT(id == mObjects.size()); @@ -67,6 +69,7 @@ namespace dawn_wire { namespace client { return &mObjects[id]; } void Free(T* obj) { + ASSERT(obj->IsInList()); if (DAWN_LIKELY(mObjects[obj->id].generation != std::numeric_limits::max())) { // Only recycle this ObjectId if the generation won't overflow on the next // allocation. diff --git a/src/dawn_wire/client/ObjectBase.h b/src/dawn_wire/client/ObjectBase.h index edf18f6c87..18778d6a64 100644 --- a/src/dawn_wire/client/ObjectBase.h +++ b/src/dawn_wire/client/ObjectBase.h @@ -17,6 +17,9 @@ #include +#include "common/LinkedList.h" +#include "dawn_wire/ObjectType_autogen.h" + namespace dawn_wire { namespace client { class Device; @@ -25,14 +28,19 @@ namespace dawn_wire { namespace client { // - A pointer to the device to get where to serialize commands // - The external reference count // - An ID that is used to refer to this object when talking with the server side - struct ObjectBase { + // - A next/prev pointer. They are part of a linked list of objects of the same type. + struct ObjectBase : public LinkNode { ObjectBase(Device* device, uint32_t refcount, uint32_t id) : device(device), refcount(refcount), id(id) { } - Device* device; + ~ObjectBase() { + RemoveFromList(); + } + + Device* const device; uint32_t refcount; - uint32_t id; + const uint32_t id; }; }} // namespace dawn_wire::client diff --git a/src/tests/unittests/wire/WireDisconnectTests.cpp b/src/tests/unittests/wire/WireDisconnectTests.cpp index 4e9b355f8c..5536292158 100644 --- a/src/tests/unittests/wire/WireDisconnectTests.cpp +++ b/src/tests/unittests/wire/WireDisconnectTests.cpp @@ -126,3 +126,27 @@ TEST_F(WireDisconnectTests, DisconnectThenServerLost) { EXPECT_CALL(mockDeviceLostCallback, Call(_, _)).Times(Exactly(0)); FlushServer(); } + +// Test that client objects are all destroyed if the WireClient is destroyed. +TEST_F(WireDisconnectTests, DeleteClientDestroysObjects) { + WGPUSamplerDescriptor desc = {}; + wgpuDeviceCreateCommandEncoder(device, nullptr); + wgpuDeviceCreateSampler(device, &desc); + + WGPUCommandEncoder apiCommandEncoder = api.GetNewCommandEncoder(); + EXPECT_CALL(api, DeviceCreateCommandEncoder(apiDevice, nullptr)) + .WillOnce(Return(apiCommandEncoder)); + + WGPUSampler apiSampler = api.GetNewSampler(); + EXPECT_CALL(api, DeviceCreateSampler(apiDevice, _)).WillOnce(Return(apiSampler)); + + FlushClient(); + + DeleteClient(); + + // Expect release on all objects created by the client. + EXPECT_CALL(api, QueueRelease(apiQueue)).Times(1); + EXPECT_CALL(api, CommandEncoderRelease(apiCommandEncoder)).Times(1); + EXPECT_CALL(api, SamplerRelease(apiSampler)).Times(1); + FlushClient(); +} diff --git a/src/tests/unittests/wire/WireTest.cpp b/src/tests/unittests/wire/WireTest.cpp index d23709c88c..7c8a4d53a0 100644 --- a/src/tests/unittests/wire/WireTest.cpp +++ b/src/tests/unittests/wire/WireTest.cpp @@ -113,6 +113,10 @@ void WireTest::DeleteServer() { mWireServer = nullptr; } +void WireTest::DeleteClient() { + mWireClient = nullptr; +} + void WireTest::SetupIgnoredCallExpectations() { EXPECT_CALL(api, DeviceTick(_)).Times(AnyNumber()); } diff --git a/src/tests/unittests/wire/WireTest.h b/src/tests/unittests/wire/WireTest.h index d4537d67e0..95fd3077ec 100644 --- a/src/tests/unittests/wire/WireTest.h +++ b/src/tests/unittests/wire/WireTest.h @@ -133,6 +133,7 @@ class WireTest : public testing::Test { dawn_wire::WireClient* GetWireClient(); void DeleteServer(); + void DeleteClient(); private: void SetupIgnoredCallExpectations();