dawn::wire::client: Track the object generation on the objects

Previously the ObjectAllocator was tracking the generation on the side
of the object. This was done to avoid the need to check that the objects
aren't null before accessing the generation in ClientHandlers. This is
only a very minor optimization for return commands so it is removed in
favor of simplifying the code.

The code is simplified in a bunch of place by getting the ObjectHandle
for an object directly (since it knows the generation now) instead of
walking the object graph returned by the allocator.

The ObjectBase class is also changed to store an ObjectHandle
interrnally that's only accessible via getters. Encapsulating the other
memebers will be done in follow-up CLs.

Also adds the generation to the ObjectBaseParams since all ObjectBases
now require it.

Bug: dawn:1451
Change-Id: Ic6c850fc989f715f7c80952ff447b7c29378cd27
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/93146
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Corentin Wallez 2022-06-14 14:55:46 +00:00 committed by Dawn LUCI CQ
parent 87af04b769
commit 0f97df8c53
16 changed files with 124 additions and 125 deletions

View File

@ -58,8 +58,8 @@ namespace dawn::wire::client {
//* For object creation, store the object ID the client will use for the result.
{% if method.return_type.category == "object" %}
auto* allocation = self->client->{{method.return_type.name.CamelCase()}}Allocator().New(self->client);
cmd.result = ObjectHandle{allocation->object->id, allocation->generation};
auto* returnObject = self->client->{{method.return_type.name.CamelCase()}}Allocator().New(self->client);
cmd.result = returnObject->GetWireHandle();
{% endif %}
{% for arg in method.arguments %}
@ -72,7 +72,7 @@ namespace dawn::wire::client {
self->client->SerializeCommand(cmd);
{% if method.return_type.category == "object" %}
return reinterpret_cast<{{as_cType(method.return_type.name)}}>(allocation->object.get());
return ToAPI(returnObject);
{% endif %}
{% else %}
return self->{{method.name.CamelCase()}}(
@ -94,7 +94,7 @@ namespace dawn::wire::client {
DestroyObjectCmd cmd;
cmd.objectType = ObjectType::{{type.name.CamelCase()}};
cmd.objectId = obj->id;
cmd.objectId = obj->GetWireId();
obj->client->SerializeCommand(cmd);
obj->client->{{type.name.CamelCase()}}Allocator().Free(obj);

View File

@ -54,12 +54,12 @@ namespace dawn::wire::client {
if (object == nullptr) {
return WireResult::FatalError;
}
*out = reinterpret_cast<{{as_wireType(type)}}>(object)->id;
*out = reinterpret_cast<{{as_wireType(type)}}>(object)->GetWireId();
return WireResult::Success;
}
WireResult GetOptionalId({{as_cType(type.name)}} object, ObjectId* out) const final {
ASSERT(out != nullptr);
*out = (object == nullptr ? 0 : reinterpret_cast<{{as_wireType(type)}}>(object)->id);
*out = (object == nullptr ? 0 : reinterpret_cast<{{as_wireType(type)}}>(object)->GetWireId());
return WireResult::Success;
}
{% endfor %}

View File

@ -33,8 +33,7 @@ namespace dawn::wire::client {
{% if member.type.dict_name == "ObjectHandle" %}
{{Type}}* {{name}} = {{Type}}Allocator().GetObject(cmd.{{name}}.id);
uint32_t {{name}}Generation = {{Type}}Allocator().GetGeneration(cmd.{{name}}.id);
if ({{name}}Generation != cmd.{{name}}.generation) {
if ({{name}} != nullptr && {{name}}->GetWireGeneration() != cmd.{{name}}.generation) {
{{name}} = nullptr;
}
{% endif %}

View File

@ -28,7 +28,6 @@ ObjectHandle& ObjectHandle::operator=(const volatile ObjectHandle& rhs) {
return *this;
}
ObjectHandle::ObjectHandle(const ObjectHandle& rhs) = default;
ObjectHandle& ObjectHandle::operator=(const ObjectHandle& rhs) = default;

View File

@ -70,13 +70,13 @@ void Adapter::RequestDevice(const WGPUDeviceDescriptor* descriptor,
return;
}
auto* allocation = client->DeviceAllocator().New(client);
uint64_t serial = mRequestDeviceRequests.Add({callback, allocation->object->id, userdata});
Device* device = client->DeviceAllocator().New(client);
uint64_t serial = mRequestDeviceRequests.Add({callback, device->GetWireId(), userdata});
AdapterRequestDeviceCmd cmd;
cmd.adapterId = this->id;
cmd.adapterId = GetWireId();
cmd.requestSerial = serial;
cmd.deviceObjectHandle = ObjectHandle(allocation->object->id, allocation->generation);
cmd.deviceObjectHandle = device->GetWireHandle();
cmd.descriptor = descriptor;
client->SerializeCommand(cmd);

View File

@ -40,7 +40,7 @@ WGPUBuffer Buffer::Create(Device* device, const WGPUBufferDescriptor* descriptor
std::unique_ptr<MemoryTransferService::WriteHandle> writeHandle = nullptr;
DeviceCreateBufferCmd cmd;
cmd.deviceId = device->id;
cmd.deviceId = device->GetWireId();
cmd.descriptor = descriptor;
cmd.readHandleCreateInfoLength = 0;
cmd.readHandleCreateInfo = nullptr;
@ -74,8 +74,7 @@ WGPUBuffer Buffer::Create(Device* device, const WGPUBufferDescriptor* descriptor
// Create the buffer and send the creation command.
// This must happen after any potential device->CreateErrorBuffer()
// as server expects allocating ids to be monotonically increasing
auto* bufferObjectAndSerial = wireClient->BufferAllocator().New(wireClient);
Buffer* buffer = bufferObjectAndSerial->object.get();
Buffer* buffer = wireClient->BufferAllocator().New(wireClient);
buffer->mDevice = device;
buffer->mDeviceIsAlive = device->GetAliveWeakPtr();
buffer->mSize = descriptor->size;
@ -98,7 +97,7 @@ WGPUBuffer Buffer::Create(Device* device, const WGPUBufferDescriptor* descriptor
buffer->mMappedData = writeHandle->GetData();
}
cmd.result = ObjectHandle{buffer->id, bufferObjectAndSerial->generation};
cmd.result = buffer->GetWireHandle();
wireClient->SerializeCommand(
cmd, cmd.readHandleCreateInfoLength + cmd.writeHandleCreateInfoLength,
@ -126,18 +125,18 @@ WGPUBuffer Buffer::Create(Device* device, const WGPUBufferDescriptor* descriptor
// static
WGPUBuffer Buffer::CreateError(Device* device, const WGPUBufferDescriptor* descriptor) {
auto* allocation = device->client->BufferAllocator().New(device->client);
allocation->object->mDevice = device;
allocation->object->mDeviceIsAlive = device->GetAliveWeakPtr();
allocation->object->mSize = descriptor->size;
allocation->object->mUsage = static_cast<WGPUBufferUsage>(descriptor->usage);
Buffer* buffer = device->client->BufferAllocator().New(device->client);
buffer->mDevice = device;
buffer->mDeviceIsAlive = device->GetAliveWeakPtr();
buffer->mSize = descriptor->size;
buffer->mUsage = static_cast<WGPUBufferUsage>(descriptor->usage);
DeviceCreateErrorBufferCmd cmd;
cmd.self = ToAPI(device);
cmd.result = ObjectHandle{allocation->object->id, allocation->generation};
cmd.result = buffer->GetWireHandle();
device->client->SerializeCommand(cmd);
return ToAPI(allocation->object.get());
return ToAPI(buffer);
}
Buffer::~Buffer() {
@ -188,7 +187,7 @@ void Buffer::MapAsync(WGPUMapModeFlags mode,
// Serialize the command to send to the server.
BufferMapAsyncCmd cmd;
cmd.bufferId = this->id;
cmd.bufferId = GetWireId();
cmd.requestSerial = serial;
cmd.mode = mode;
cmd.offset = offset;
@ -301,7 +300,7 @@ void Buffer::Unmap() {
mWriteHandle->SizeOfSerializeDataUpdate(mMapOffset, mMapSize);
BufferUpdateMappedDataCmd cmd;
cmd.bufferId = id;
cmd.bufferId = GetWireId();
cmd.writeDataUpdateInfoLength = writeDataUpdateInfoLength;
cmd.writeDataUpdateInfo = nullptr;
cmd.offset = mMapOffset;

View File

@ -59,7 +59,7 @@ void Client::DestroyAllObjects() {
DestroyObjectCmd cmd;
cmd.objectType = ObjectType::Device;
cmd.objectId = object->id;
cmd.objectId = object->GetWireId();
SerializeCommand(cmd);
FreeObject(ObjectType::Device, object);
}
@ -74,7 +74,7 @@ void Client::DestroyAllObjects() {
DestroyObjectCmd cmd;
cmd.objectType = objectType;
cmd.objectId = object->id;
cmd.objectId = object->GetWireId();
SerializeCommand(cmd);
FreeObject(objectType, object);
}
@ -82,46 +82,46 @@ void Client::DestroyAllObjects() {
}
ReservedTexture Client::ReserveTexture(WGPUDevice device) {
auto* allocation = TextureAllocator().New(this);
Texture* texture = TextureAllocator().New(this);
ReservedTexture result;
result.texture = ToAPI(allocation->object.get());
result.id = allocation->object->id;
result.generation = allocation->generation;
result.deviceId = FromAPI(device)->id;
result.deviceGeneration = DeviceAllocator().GetGeneration(FromAPI(device)->id);
result.texture = ToAPI(texture);
result.id = texture->GetWireId();
result.generation = texture->GetWireGeneration();
result.deviceId = FromAPI(device)->GetWireId();
result.deviceGeneration = FromAPI(device)->GetWireGeneration();
return result;
}
ReservedSwapChain Client::ReserveSwapChain(WGPUDevice device) {
auto* allocation = SwapChainAllocator().New(this);
SwapChain* swapChain = SwapChainAllocator().New(this);
ReservedSwapChain result;
result.swapchain = ToAPI(allocation->object.get());
result.id = allocation->object->id;
result.generation = allocation->generation;
result.deviceId = FromAPI(device)->id;
result.deviceGeneration = DeviceAllocator().GetGeneration(FromAPI(device)->id);
result.swapchain = ToAPI(swapChain);
result.id = swapChain->GetWireId();
result.generation = swapChain->GetWireGeneration();
result.deviceId = FromAPI(device)->GetWireId();
result.deviceGeneration = FromAPI(device)->GetWireGeneration();
return result;
}
ReservedDevice Client::ReserveDevice() {
auto* allocation = DeviceAllocator().New(this);
Device* device = DeviceAllocator().New(this);
ReservedDevice result;
result.device = ToAPI(allocation->object.get());
result.id = allocation->object->id;
result.generation = allocation->generation;
result.device = ToAPI(device);
result.id = device->GetWireId();
result.generation = device->GetWireGeneration();
return result;
}
ReservedInstance Client::ReserveInstance() {
auto* allocation = InstanceAllocator().New(this);
Instance* instance = InstanceAllocator().New(this);
ReservedInstance result;
result.instance = ToAPI(allocation->object.get());
result.id = allocation->object->id;
result.generation = allocation->generation;
result.instance = ToAPI(instance);
result.id = instance->GetWireId();
result.generation = instance->GetWireGeneration();
return result;
}

View File

@ -158,7 +158,7 @@ bool Device::PopErrorScope(WGPUErrorCallback callback, void* userdata) {
uint64_t serial = mErrorScopes.Add({callback, userdata});
DevicePopErrorScopeCmd cmd;
cmd.deviceId = this->id;
cmd.deviceId = GetWireId();
cmd.requestSerial = serial;
client->SerializeCommand(cmd);
return true;
@ -219,12 +219,11 @@ WGPUQueue Device::GetQueue() {
// on construction.
if (mQueue == nullptr) {
// Get the primary queue for this device.
auto* allocation = client->QueueAllocator().New(client);
mQueue = allocation->object.get();
mQueue = client->QueueAllocator().New(client);
DeviceGetQueueCmd cmd;
cmd.self = ToAPI(this);
cmd.result = ObjectHandle{allocation->object->id, allocation->generation};
cmd.result = mQueue->GetWireHandle();
client->SerializeCommand(cmd);
}
@ -241,20 +240,20 @@ void Device::CreateComputePipelineAsync(WGPUComputePipelineDescriptor const* des
"GPU device disconnected", userdata);
}
auto* allocation = client->ComputePipelineAllocator().New(client);
ComputePipeline* pipeline = client->ComputePipelineAllocator().New(client);
CreatePipelineAsyncRequest request = {};
request.createComputePipelineAsyncCallback = callback;
request.userdata = userdata;
request.pipelineObjectID = allocation->object->id;
request.pipelineObjectID = pipeline->GetWireId();
uint64_t serial = mCreatePipelineAsyncRequests.Add(std::move(request));
DeviceCreateComputePipelineAsyncCmd cmd;
cmd.deviceId = this->id;
cmd.deviceId = GetWireId();
cmd.descriptor = descriptor;
cmd.requestSerial = serial;
cmd.pipelineObjectHandle = ObjectHandle{allocation->object->id, allocation->generation};
cmd.pipelineObjectHandle = pipeline->GetWireHandle();
client->SerializeCommand(cmd);
}
@ -292,20 +291,20 @@ void Device::CreateRenderPipelineAsync(WGPURenderPipelineDescriptor const* descr
"GPU device disconnected", userdata);
}
auto* allocation = client->RenderPipelineAllocator().New(client);
RenderPipeline* pipeline = client->RenderPipelineAllocator().New(client);
CreatePipelineAsyncRequest request = {};
request.createRenderPipelineAsyncCallback = callback;
request.userdata = userdata;
request.pipelineObjectID = allocation->object->id;
request.pipelineObjectID = pipeline->GetWireId();
uint64_t serial = mCreatePipelineAsyncRequests.Add(std::move(request));
DeviceCreateRenderPipelineAsyncCmd cmd;
cmd.deviceId = this->id;
cmd.deviceId = GetWireId();
cmd.descriptor = descriptor;
cmd.requestSerial = serial;
cmd.pipelineObjectHandle = ObjectHandle(allocation->object->id, allocation->generation);
cmd.pipelineObjectHandle = pipeline->GetWireHandle();
client->SerializeCommand(cmd);
}

View File

@ -40,13 +40,13 @@ void Instance::RequestAdapter(const WGPURequestAdapterOptions* options,
return;
}
auto* allocation = client->AdapterAllocator().New(client);
uint64_t serial = mRequestAdapterRequests.Add({callback, allocation->object->id, userdata});
Adapter* adapter = client->AdapterAllocator().New(client);
uint64_t serial = mRequestAdapterRequests.Add({callback, adapter->GetWireId(), userdata});
InstanceRequestAdapterCmd cmd;
cmd.instanceId = this->id;
cmd.instanceId = GetWireId();
cmd.requestSerial = serial;
cmd.adapterObjectHandle = ObjectHandle(allocation->object->id, allocation->generation);
cmd.adapterObjectHandle = adapter->GetWireHandle();
cmd.options = options;
client->SerializeCommand(cmd);

View File

@ -23,86 +23,72 @@
#include "dawn/common/Assert.h"
#include "dawn/common/Compiler.h"
#include "dawn/wire/WireCmd_autogen.h"
#include "dawn/wire/client/ObjectBase.h"
namespace dawn::wire::client {
template <typename T>
class ObjectAllocator {
public:
struct ObjectAndSerial {
ObjectAndSerial(std::unique_ptr<T> object, uint32_t generation)
: object(std::move(object)), generation(generation) {}
std::unique_ptr<T> object;
uint32_t generation;
};
ObjectAllocator() {
// ID 0 is nullptr
mObjects.emplace_back(nullptr, 0);
mObjects.emplace_back(nullptr);
}
template <typename Client>
ObjectAndSerial* New(Client* client) {
uint32_t id = GetNewId();
ObjectBaseParams params = {client, id};
T* New(Client* client) {
ObjectHandle handle = GetFreeHandle();
ObjectBaseParams params = {client, handle};
auto object = std::make_unique<T>(params);
client->TrackObject(object.get());
if (id >= mObjects.size()) {
ASSERT(id == mObjects.size());
mObjects.emplace_back(std::move(object), 0);
if (handle.id >= mObjects.size()) {
ASSERT(handle.id == mObjects.size());
mObjects.emplace_back(std::move(object));
} else {
ASSERT(mObjects[id].object == nullptr);
mObjects[id].generation++;
// The generation should never overflow. We don't recycle ObjectIds that would
// overflow their next generation.
ASSERT(mObjects[id].generation != 0);
mObjects[id].object = std::move(object);
ASSERT(handle.generation != 0);
ASSERT(mObjects[handle.id] == nullptr);
mObjects[handle.id] = std::move(object);
}
return &mObjects[id];
return mObjects[handle.id].get();
}
void Free(T* obj) {
ASSERT(obj->IsInList());
if (DAWN_LIKELY(mObjects[obj->id].generation != std::numeric_limits<uint32_t>::max())) {
// Only recycle this ObjectId if the generation won't overflow on the next
// allocation.
FreeId(obj->id);
// The wire reuses ID for objects to keep them in a packed array starting from 0.
// To avoid issues with asynchronous server->client communication referring to an ID that's
// already reused, each handle also has a generation that's increment by one on each reuse.
// Avoid overflows by only reusing the ID if the increment of the generation won't overflow.
ObjectHandle currentHandle = obj->GetWireHandle();
if (DAWN_LIKELY(currentHandle.generation != std::numeric_limits<ObjectGeneration>::max())) {
mFreeHandles.push_back({currentHandle.id, currentHandle.generation + 1});
}
mObjects[obj->id].object = nullptr;
mObjects[currentHandle.id] = nullptr;
}
T* GetObject(uint32_t id) {
if (id >= mObjects.size()) {
return nullptr;
}
return mObjects[id].object.get();
}
uint32_t GetGeneration(uint32_t id) {
if (id >= mObjects.size()) {
return 0;
}
return mObjects[id].generation;
return mObjects[id].get();
}
private:
uint32_t GetNewId() {
if (mFreeIds.empty()) {
return mCurrentId++;
ObjectHandle GetFreeHandle() {
if (mFreeHandles.empty()) {
return {mCurrentId++, 0};
}
uint32_t id = mFreeIds.back();
mFreeIds.pop_back();
return id;
ObjectHandle handle = mFreeHandles.back();
mFreeHandles.pop_back();
return handle;
}
void FreeId(uint32_t id) { mFreeIds.push_back(id); }
// 0 is an ID reserved to represent nullptr
uint32_t mCurrentId = 1;
std::vector<uint32_t> mFreeIds;
std::vector<ObjectAndSerial> mObjects;
std::vector<ObjectHandle> mFreeHandles;
std::vector<std::unique_ptr<T>> mObjects;
};
} // namespace dawn::wire::client

View File

@ -17,10 +17,22 @@
namespace dawn::wire::client {
ObjectBase::ObjectBase(const ObjectBaseParams& params)
: client(params.client), refcount(1), id(params.id) {}
: client(params.client), refcount(1), mHandle(params.handle) {}
ObjectBase::~ObjectBase() {
RemoveFromList();
}
const ObjectHandle& ObjectBase::GetWireHandle() const {
return mHandle;
}
ObjectId ObjectBase::GetWireId() const {
return mHandle.id;
}
ObjectGeneration ObjectBase::GetWireGeneration() const {
return mHandle.generation;
}
} // namespace dawn::wire::client

View File

@ -18,7 +18,6 @@
#include "dawn/webgpu.h"
#include "dawn/common/LinkedList.h"
#include "dawn/wire/ObjectType_autogen.h"
#include "dawn/wire/ObjectHandle.h"
namespace dawn::wire::client {
@ -27,7 +26,7 @@ class Client;
struct ObjectBaseParams {
Client* client;
ObjectId id;
ObjectHandle handle;
};
// All objects on the client side have:
@ -35,15 +34,23 @@ struct ObjectBaseParams {
// - The external reference count, starting at 1.
// - An ID that is used to refer to this object when talking with the server side
// - A next/prev pointer. They are part of a linked list of objects of the same type.
struct ObjectBase : public LinkNode<ObjectBase> {
class ObjectBase : public LinkNode<ObjectBase> {
public:
explicit ObjectBase(const ObjectBaseParams& params);
~ObjectBase();
virtual void CancelCallbacksForDisconnect() {}
const ObjectHandle& GetWireHandle() const;
ObjectId GetWireId() const;
ObjectGeneration GetWireGeneration() const;
// TODO(dawn:1451): Make these members private.
Client* const client;
uint32_t refcount;
const ObjectId id;
private:
const ObjectHandle mHandle;
};
} // namespace dawn::wire::client

View File

@ -22,19 +22,18 @@ namespace dawn::wire::client {
// static
WGPUQuerySet QuerySet::Create(Device* device, const WGPUQuerySetDescriptor* descriptor) {
Client* wireClient = device->client;
auto* objectAndSerial = wireClient->QuerySetAllocator().New(wireClient);
QuerySet* querySet = wireClient->QuerySetAllocator().New(wireClient);
// Copy over descriptor data for reflection.
QuerySet* querySet = objectAndSerial->object.get();
querySet->mType = descriptor->type;
querySet->mCount = descriptor->count;
// Send the Device::CreateQuerySet command without modifications.
DeviceCreateQuerySetCmd cmd;
cmd.self = ToAPI(device);
cmd.selfId = device->id;
cmd.selfId = device->GetWireId();
cmd.descriptor = descriptor;
cmd.result = ObjectHandle{querySet->id, objectAndSerial->generation};
cmd.result = querySet->GetWireHandle();
wireClient->SerializeCommand(cmd);
return ToAPI(querySet);

View File

@ -44,7 +44,7 @@ void Queue::OnSubmittedWorkDone(uint64_t signalValue,
uint64_t serial = mOnWorkDoneRequests.Add({callback, userdata});
QueueOnSubmittedWorkDoneCmd cmd;
cmd.queueId = this->id;
cmd.queueId = GetWireId();
cmd.signalValue = signalValue;
cmd.requestSerial = serial;
@ -55,8 +55,8 @@ void Queue::WriteBuffer(WGPUBuffer cBuffer, uint64_t bufferOffset, const void* d
Buffer* buffer = FromAPI(cBuffer);
QueueWriteBufferCmd cmd;
cmd.queueId = id;
cmd.bufferId = buffer->id;
cmd.queueId = GetWireId();
cmd.bufferId = buffer->GetWireId();
cmd.bufferOffset = bufferOffset;
cmd.data = static_cast<const uint8_t*>(data);
cmd.size = size;
@ -70,7 +70,7 @@ void Queue::WriteTexture(const WGPUImageCopyTexture* destination,
const WGPUTextureDataLayout* dataLayout,
const WGPUExtent3D* writeSize) {
QueueWriteTextureCmd cmd;
cmd.queueId = id;
cmd.queueId = GetWireId();
cmd.destination = destination;
cmd.data = static_cast<const uint8_t*>(data);
cmd.dataSize = dataSize;

View File

@ -31,7 +31,7 @@ void ShaderModule::GetCompilationInfo(WGPUCompilationInfoCallback callback, void
uint64_t serial = mCompilationInfoRequests.Add({callback, userdata});
ShaderModuleGetCompilationInfoCmd cmd;
cmd.shaderModuleId = this->id;
cmd.shaderModuleId = GetWireId();
cmd.requestSerial = serial;
client->SerializeCommand(cmd);

View File

@ -22,10 +22,9 @@ namespace dawn::wire::client {
// static
WGPUTexture Texture::Create(Device* device, const WGPUTextureDescriptor* descriptor) {
Client* wireClient = device->client;
auto* textureObjectAndSerial = wireClient->TextureAllocator().New(wireClient);
Texture* texture = wireClient->TextureAllocator().New(wireClient);
// Copy over descriptor data for reflection.
Texture* texture = textureObjectAndSerial->object.get();
texture->mSize = descriptor->size;
texture->mMipLevelCount = descriptor->mipLevelCount;
texture->mSampleCount = descriptor->sampleCount;
@ -36,9 +35,9 @@ WGPUTexture Texture::Create(Device* device, const WGPUTextureDescriptor* descrip
// Send the Device::CreateTexture command without modifications.
DeviceCreateTextureCmd cmd;
cmd.self = ToAPI(device);
cmd.selfId = device->id;
cmd.selfId = device->GetWireId();
cmd.descriptor = descriptor;
cmd.result = ObjectHandle{texture->id, textureObjectAndSerial->generation};
cmd.result = texture->GetWireHandle();
wireClient->SerializeCommand(cmd);
return ToAPI(texture);