dawn_wire/client: Add RequestTracker helper

This helper helps ensure correct handling of request maps by:

 - Forcing erasing to happen immediately when acquiring a request. This
   prevents some cases of iterator invalidation if we later change the
   container type.
 - Implements correct closure of all callbacks, including if the
   callbacks themselves add more callbacks.

Bug: dawn:1092

Change-Id: Ia0ba9f050bbf3f0dee846f537910523bebb3bf1b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/63003
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Auto-Submit: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
This commit is contained in:
Corentin Wallez 2021-09-01 16:40:22 +00:00 committed by Dawn LUCI CQ
parent 4a4a804476
commit baf8df396c
10 changed files with 190 additions and 172 deletions

View File

@ -140,25 +140,20 @@ namespace dawn_wire { namespace client {
} }
Buffer::~Buffer() { Buffer::~Buffer() {
// Callbacks need to be fired in all cases, as they can handle freeing resources ClearAllCallbacks(WGPUBufferMapAsyncStatus_DestroyedBeforeCallback);
// so we call them with "DestroyedBeforeCallback" status.
for (auto& it : mRequests) {
if (it.second.callback) {
it.second.callback(WGPUBufferMapAsyncStatus_DestroyedBeforeCallback, it.second.userdata);
}
}
mRequests.clear();
FreeMappedData(); FreeMappedData();
} }
void Buffer::CancelCallbacksForDisconnect() { void Buffer::CancelCallbacksForDisconnect() {
for (auto& it : mRequests) { ClearAllCallbacks(WGPUBufferMapAsyncStatus_DeviceLost);
if (it.second.callback) {
it.second.callback(WGPUBufferMapAsyncStatus_DeviceLost, it.second.userdata);
} }
void Buffer::ClearAllCallbacks(WGPUBufferMapAsyncStatus status) {
mRequests.CloseAll([status](MapRequestData* request) {
if (request->callback != nullptr) {
request->callback(status, request->userdata);
} }
mRequests.clear(); });
} }
void Buffer::MapAsync(WGPUMapModeFlags mode, void Buffer::MapAsync(WGPUMapModeFlags mode,
@ -177,10 +172,7 @@ namespace dawn_wire { namespace client {
// Create the request structure that will hold information while this mapping is // Create the request structure that will hold information while this mapping is
// in flight. // in flight.
uint64_t serial = mRequestSerial++; MapRequestData request = {};
ASSERT(mRequests.find(serial) == mRequests.end());
Buffer::MapRequestData request = {};
request.callback = callback; request.callback = callback;
request.userdata = userdata; request.userdata = userdata;
request.offset = offset; request.offset = offset;
@ -191,6 +183,8 @@ namespace dawn_wire { namespace client {
request.type = MapRequestType::Write; request.type = MapRequestType::Write;
} }
uint64_t serial = mRequests.Add(std::move(request));
// Serialize the command to send to the server. // Serialize the command to send to the server.
BufferMapAsyncCmd cmd; BufferMapAsyncCmd cmd;
cmd.bufferId = this->id; cmd.bufferId = this->id;
@ -200,26 +194,17 @@ namespace dawn_wire { namespace client {
cmd.size = size; cmd.size = size;
client->SerializeCommand(cmd); client->SerializeCommand(cmd);
// Register this request so that we can retrieve it from its serial when the server
// sends the callback.
mRequests[serial] = std::move(request);
} }
bool Buffer::OnMapAsyncCallback(uint64_t requestSerial, bool Buffer::OnMapAsyncCallback(uint64_t requestSerial,
uint32_t status, uint32_t status,
uint64_t readDataUpdateInfoLength, uint64_t readDataUpdateInfoLength,
const uint8_t* readDataUpdateInfo) { const uint8_t* readDataUpdateInfo) {
auto requestIt = mRequests.find(requestSerial); MapRequestData request;
if (requestIt == mRequests.end()) { if (!mRequests.Acquire(requestSerial, &request)) {
return false; return false;
} }
auto request = std::move(requestIt->second);
// Delete the request before calling the callback otherwise the callback could be fired a
// second time. If, for example, buffer.Unmap() is called inside the callback.
mRequests.erase(requestIt);
auto FailRequest = [&request]() -> bool { auto FailRequest = [&request]() -> bool {
if (request.callback != nullptr) { if (request.callback != nullptr) {
request.callback(WGPUBufferMapAsyncStatus_DeviceLost, request.userdata); request.callback(WGPUBufferMapAsyncStatus_DeviceLost, request.userdata);
@ -352,11 +337,11 @@ namespace dawn_wire { namespace client {
mMapSize = 0; mMapSize = 0;
// Tag all mapping requests still in flight as unmapped before callback. // Tag all mapping requests still in flight as unmapped before callback.
for (auto& it : mRequests) { mRequests.ForAll([](MapRequestData* request) {
if (it.second.clientStatus == WGPUBufferMapAsyncStatus_Success) { if (request->clientStatus == WGPUBufferMapAsyncStatus_Success) {
it.second.clientStatus = WGPUBufferMapAsyncStatus_UnmappedBeforeCallback; request->clientStatus = WGPUBufferMapAsyncStatus_UnmappedBeforeCallback;
}
} }
});
BufferUnmapCmd cmd; BufferUnmapCmd cmd;
cmd.self = ToAPI(this); cmd.self = ToAPI(this);
@ -368,11 +353,11 @@ namespace dawn_wire { namespace client {
FreeMappedData(); FreeMappedData();
// Tag all mapping requests still in flight as destroyed before callback. // Tag all mapping requests still in flight as destroyed before callback.
for (auto& it : mRequests) { mRequests.ForAll([](MapRequestData* request) {
if (it.second.clientStatus == WGPUBufferMapAsyncStatus_Success) { if (request->clientStatus == WGPUBufferMapAsyncStatus_Success) {
it.second.clientStatus = WGPUBufferMapAsyncStatus_DestroyedBeforeCallback; request->clientStatus = WGPUBufferMapAsyncStatus_DestroyedBeforeCallback;
}
} }
});
BufferDestroyCmd cmd; BufferDestroyCmd cmd;
cmd.self = ToAPI(this); cmd.self = ToAPI(this);

View File

@ -19,8 +19,7 @@
#include "dawn_wire/WireClient.h" #include "dawn_wire/WireClient.h"
#include "dawn_wire/client/ObjectBase.h" #include "dawn_wire/client/ObjectBase.h"
#include "dawn_wire/client/RequestTracker.h"
#include <map>
namespace dawn_wire { namespace client { namespace dawn_wire { namespace client {
@ -52,6 +51,7 @@ namespace dawn_wire { namespace client {
private: private:
void CancelCallbacksForDisconnect() override; void CancelCallbacksForDisconnect() override;
void ClearAllCallbacks(WGPUBufferMapAsyncStatus status);
bool IsMappedForReading() const; bool IsMappedForReading() const;
bool IsMappedForWriting() const; bool IsMappedForWriting() const;
@ -86,8 +86,7 @@ namespace dawn_wire { namespace client {
MapRequestType type = MapRequestType::None; MapRequestType type = MapRequestType::None;
}; };
std::map<uint64_t, MapRequestData> mRequests; RequestTracker<MapRequestData> mRequests;
uint64_t mRequestSerial = 0;
uint64_t mSize = 0; uint64_t mSize = 0;
// Only one mapped pointer can be active at a time because Unmap clears all the in-flight // Only one mapped pointer can be active at a time because Unmap clears all the in-flight

View File

@ -19,6 +19,7 @@
#include <dawn_wire/Wire.h> #include <dawn_wire/Wire.h>
#include "common/LinkedList.h" #include "common/LinkedList.h"
#include "common/NonCopyable.h"
#include "dawn_wire/ChunkedCommandSerializer.h" #include "dawn_wire/ChunkedCommandSerializer.h"
#include "dawn_wire/WireClient.h" #include "dawn_wire/WireClient.h"
#include "dawn_wire/WireCmd_autogen.h" #include "dawn_wire/WireCmd_autogen.h"

View File

@ -48,26 +48,23 @@ namespace dawn_wire { namespace client {
} }
Device::~Device() { Device::~Device() {
// Fire pending error scopes mErrorScopes.CloseAll([](ErrorScopeData* request) {
auto errorScopes = std::move(mErrorScopes); request->callback(WGPUErrorType_Unknown, "Device destroyed before callback",
for (const auto& it : errorScopes) { request->userdata);
it.second.callback(WGPUErrorType_Unknown, "Device destroyed before callback", });
it.second.userdata);
}
auto createPipelineAsyncRequests = std::move(mCreatePipelineAsyncRequests); mCreatePipelineAsyncRequests.CloseAll([](CreatePipelineAsyncRequest* request) {
for (const auto& it : createPipelineAsyncRequests) { if (request->createComputePipelineAsyncCallback != nullptr) {
if (it.second.createComputePipelineAsyncCallback != nullptr) { request->createComputePipelineAsyncCallback(
it.second.createComputePipelineAsyncCallback(
WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr, WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr,
"Device destroyed before callback", it.second.userdata); "Device destroyed before callback", request->userdata);
} else { } else {
ASSERT(it.second.createRenderPipelineAsyncCallback != nullptr); ASSERT(request->createRenderPipelineAsyncCallback != nullptr);
it.second.createRenderPipelineAsyncCallback( request->createRenderPipelineAsyncCallback(
WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr, WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr,
"Device destroyed before callback", it.second.userdata); "Device destroyed before callback", request->userdata);
}
} }
});
} }
void Device::HandleError(WGPUErrorType errorType, const char* message) { void Device::HandleError(WGPUErrorType errorType, const char* message) {
@ -91,25 +88,22 @@ namespace dawn_wire { namespace client {
} }
void Device::CancelCallbacksForDisconnect() { void Device::CancelCallbacksForDisconnect() {
for (auto& it : mCreatePipelineAsyncRequests) { mErrorScopes.CloseAll([](ErrorScopeData* request) {
ASSERT((it.second.createComputePipelineAsyncCallback != nullptr) ^ request->callback(WGPUErrorType_DeviceLost, "Device lost", request->userdata);
(it.second.createRenderPipelineAsyncCallback != nullptr)); });
if (it.second.createRenderPipelineAsyncCallback) {
it.second.createRenderPipelineAsyncCallback(
WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, "Device lost",
it.second.userdata);
} else {
it.second.createComputePipelineAsyncCallback(
WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, "Device lost",
it.second.userdata);
}
}
mCreatePipelineAsyncRequests.clear();
for (auto& it : mErrorScopes) { mCreatePipelineAsyncRequests.CloseAll([](CreatePipelineAsyncRequest* request) {
it.second.callback(WGPUErrorType_DeviceLost, "Device lost", it.second.userdata); if (request->createComputePipelineAsyncCallback != nullptr) {
request->createComputePipelineAsyncCallback(
WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, "Device lost",
request->userdata);
} else {
ASSERT(request->createRenderPipelineAsyncCallback != nullptr);
request->createRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceLost,
nullptr, "Device lost",
request->userdata);
} }
mErrorScopes.clear(); });
} }
std::weak_ptr<bool> Device::GetAliveWeakPtr() { std::weak_ptr<bool> Device::GetAliveWeakPtr() {
@ -152,10 +146,7 @@ namespace dawn_wire { namespace client {
return true; return true;
} }
uint64_t serial = mErrorScopeRequestSerial++; uint64_t serial = mErrorScopes.Add({callback, userdata});
ASSERT(mErrorScopes.find(serial) == mErrorScopes.end());
mErrorScopes[serial] = {callback, userdata};
DevicePopErrorScopeCmd cmd; DevicePopErrorScopeCmd cmd;
cmd.deviceId = this->id; cmd.deviceId = this->id;
@ -180,14 +171,11 @@ namespace dawn_wire { namespace client {
return false; return false;
} }
auto requestIt = mErrorScopes.find(requestSerial); ErrorScopeData request;
if (requestIt == mErrorScopes.end()) { if (!mErrorScopes.Acquire(requestSerial, &request)) {
return false; return false;
} }
ErrorScopeData request = std::move(requestIt->second);
mErrorScopes.erase(requestIt);
request.callback(type, message, request.userdata); request.callback(type, message, request.userdata);
return true; return true;
} }
@ -265,9 +253,6 @@ namespace dawn_wire { namespace client {
"GPU device disconnected", userdata); "GPU device disconnected", userdata);
} }
DeviceCreateComputePipelineAsyncCmd cmd;
cmd.deviceId = this->id;
// Copy compute to the deprecated computeStage or visa-versa, depending on which one is // Copy compute to the deprecated computeStage or visa-versa, depending on which one is
// populated, so that serialization doesn't fail. // populated, so that serialization doesn't fail.
// TODO(dawn:800): Remove once computeStage is removed. // TODO(dawn:800): Remove once computeStage is removed.
@ -280,35 +265,32 @@ namespace dawn_wire { namespace client {
localDescriptor.compute.entryPoint = localDescriptor.computeStage.entryPoint; localDescriptor.compute.entryPoint = localDescriptor.computeStage.entryPoint;
} }
cmd.descriptor = &localDescriptor;
uint64_t serial = mCreatePipelineAsyncRequestSerial++;
ASSERT(mCreatePipelineAsyncRequests.find(serial) == mCreatePipelineAsyncRequests.end());
cmd.requestSerial = serial;
auto* allocation = client->ComputePipelineAllocator().New(client); auto* allocation = client->ComputePipelineAllocator().New(client);
CreatePipelineAsyncRequest request = {}; CreatePipelineAsyncRequest request = {};
request.createComputePipelineAsyncCallback = callback; request.createComputePipelineAsyncCallback = callback;
request.userdata = userdata; request.userdata = userdata;
request.pipelineObjectID = allocation->object->id; request.pipelineObjectID = allocation->object->id;
cmd.pipelineObjectHandle = ObjectHandle{allocation->object->id, allocation->generation}; uint64_t serial = mCreatePipelineAsyncRequests.Add(std::move(request));
client->SerializeCommand(cmd);
mCreatePipelineAsyncRequests[serial] = std::move(request); DeviceCreateComputePipelineAsyncCmd cmd;
cmd.deviceId = this->id;
cmd.descriptor = &localDescriptor;
cmd.requestSerial = serial;
cmd.pipelineObjectHandle = ObjectHandle{allocation->object->id, allocation->generation};
client->SerializeCommand(cmd);
} }
bool Device::OnCreateComputePipelineAsyncCallback(uint64_t requestSerial, bool Device::OnCreateComputePipelineAsyncCallback(uint64_t requestSerial,
WGPUCreatePipelineAsyncStatus status, WGPUCreatePipelineAsyncStatus status,
const char* message) { const char* message) {
const auto& requestIt = mCreatePipelineAsyncRequests.find(requestSerial); CreatePipelineAsyncRequest request;
if (requestIt == mCreatePipelineAsyncRequests.end()) { if (!mCreatePipelineAsyncRequests.Acquire(requestSerial, &request)) {
return false; return false;
} }
CreatePipelineAsyncRequest request = std::move(requestIt->second);
mCreatePipelineAsyncRequests.erase(requestIt);
auto pipelineAllocation = auto pipelineAllocation =
client->ComputePipelineAllocator().GetObject(request.pipelineObjectID); client->ComputePipelineAllocator().GetObject(request.pipelineObjectID);
@ -333,37 +315,33 @@ namespace dawn_wire { namespace client {
return callback(WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, return callback(WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr,
"GPU device disconnected", userdata); "GPU device disconnected", userdata);
} }
DeviceCreateRenderPipelineAsyncCmd cmd;
cmd.deviceId = this->id;
cmd.descriptor = descriptor;
uint64_t serial = mCreatePipelineAsyncRequestSerial++;
ASSERT(mCreatePipelineAsyncRequests.find(serial) == mCreatePipelineAsyncRequests.end());
cmd.requestSerial = serial;
auto* allocation = client->RenderPipelineAllocator().New(client); auto* allocation = client->RenderPipelineAllocator().New(client);
CreatePipelineAsyncRequest request = {}; CreatePipelineAsyncRequest request = {};
request.createRenderPipelineAsyncCallback = callback; request.createRenderPipelineAsyncCallback = callback;
request.userdata = userdata; request.userdata = userdata;
request.pipelineObjectID = allocation->object->id; request.pipelineObjectID = allocation->object->id;
cmd.pipelineObjectHandle = ObjectHandle(allocation->object->id, allocation->generation); uint64_t serial = mCreatePipelineAsyncRequests.Add(std::move(request));
client->SerializeCommand(cmd);
mCreatePipelineAsyncRequests[serial] = std::move(request); DeviceCreateRenderPipelineAsyncCmd cmd;
cmd.deviceId = this->id;
cmd.descriptor = descriptor;
cmd.requestSerial = serial;
cmd.pipelineObjectHandle = ObjectHandle(allocation->object->id, allocation->generation);
client->SerializeCommand(cmd);
} }
bool Device::OnCreateRenderPipelineAsyncCallback(uint64_t requestSerial, bool Device::OnCreateRenderPipelineAsyncCallback(uint64_t requestSerial,
WGPUCreatePipelineAsyncStatus status, WGPUCreatePipelineAsyncStatus status,
const char* message) { const char* message) {
const auto& requestIt = mCreatePipelineAsyncRequests.find(requestSerial); CreatePipelineAsyncRequest request;
if (requestIt == mCreatePipelineAsyncRequests.end()) { if (!mCreatePipelineAsyncRequests.Acquire(requestSerial, &request)) {
return false; return false;
} }
CreatePipelineAsyncRequest request = std::move(requestIt->second);
mCreatePipelineAsyncRequests.erase(requestIt);
auto pipelineAllocation = auto pipelineAllocation =
client->RenderPipelineAllocator().GetObject(request.pipelineObjectID); client->RenderPipelineAllocator().GetObject(request.pipelineObjectID);

View File

@ -21,8 +21,8 @@
#include "dawn_wire/WireCmd_autogen.h" #include "dawn_wire/WireCmd_autogen.h"
#include "dawn_wire/client/ApiObjects_autogen.h" #include "dawn_wire/client/ApiObjects_autogen.h"
#include "dawn_wire/client/ObjectBase.h" #include "dawn_wire/client/ObjectBase.h"
#include "dawn_wire/client/RequestTracker.h"
#include <map>
#include <memory> #include <memory>
namespace dawn_wire { namespace client { namespace dawn_wire { namespace client {
@ -75,8 +75,7 @@ namespace dawn_wire { namespace client {
WGPUErrorCallback callback = nullptr; WGPUErrorCallback callback = nullptr;
void* userdata = nullptr; void* userdata = nullptr;
}; };
std::map<uint64_t, ErrorScopeData> mErrorScopes; RequestTracker<ErrorScopeData> mErrorScopes;
uint64_t mErrorScopeRequestSerial = 0;
uint64_t mErrorScopeStackSize = 0; uint64_t mErrorScopeStackSize = 0;
struct CreatePipelineAsyncRequest { struct CreatePipelineAsyncRequest {
@ -85,8 +84,7 @@ namespace dawn_wire { namespace client {
void* userdata = nullptr; void* userdata = nullptr;
ObjectId pipelineObjectID; ObjectId pipelineObjectID;
}; };
std::map<uint64_t, CreatePipelineAsyncRequest> mCreatePipelineAsyncRequests; RequestTracker<CreatePipelineAsyncRequest> mCreatePipelineAsyncRequests;
uint64_t mCreatePipelineAsyncRequestSerial = 0;
WGPUErrorCallback mErrorCallback = nullptr; WGPUErrorCallback mErrorCallback = nullptr;
WGPUDeviceLostCallback mDeviceLostCallback = nullptr; WGPUDeviceLostCallback mDeviceLostCallback = nullptr;

View File

@ -24,17 +24,11 @@ namespace dawn_wire { namespace client {
} }
bool Queue::OnWorkDoneCallback(uint64_t requestSerial, WGPUQueueWorkDoneStatus status) { bool Queue::OnWorkDoneCallback(uint64_t requestSerial, WGPUQueueWorkDoneStatus status) {
auto requestIt = mOnWorkDoneRequests.find(requestSerial); OnWorkDoneData request;
if (requestIt == mOnWorkDoneRequests.end()) { if (!mOnWorkDoneRequests.Acquire(requestSerial, &request)) {
return false; return false;
} }
// Remove the request data so that the callback cannot be called again.
// ex.) inside the callback: if the queue is deleted (when there are multiple queues),
// all callbacks reject.
OnWorkDoneData request = std::move(requestIt->second);
mOnWorkDoneRequests.erase(requestIt);
request.callback(status, request.userdata); request.callback(status, request.userdata);
return true; return true;
} }
@ -47,16 +41,13 @@ namespace dawn_wire { namespace client {
return; return;
} }
uint32_t serial = mOnWorkDoneSerial++; uint64_t serial = mOnWorkDoneRequests.Add({callback, userdata});
ASSERT(mOnWorkDoneRequests.find(serial) == mOnWorkDoneRequests.end());
QueueOnSubmittedWorkDoneCmd cmd; QueueOnSubmittedWorkDoneCmd cmd;
cmd.queueId = this->id; cmd.queueId = this->id;
cmd.signalValue = signalValue; cmd.signalValue = signalValue;
cmd.requestSerial = serial; cmd.requestSerial = serial;
mOnWorkDoneRequests[serial] = {callback, userdata};
client->SerializeCommand(cmd); client->SerializeCommand(cmd);
} }
@ -97,12 +88,11 @@ namespace dawn_wire { namespace client {
} }
void Queue::ClearAllCallbacks(WGPUQueueWorkDoneStatus status) { void Queue::ClearAllCallbacks(WGPUQueueWorkDoneStatus status) {
for (auto& it : mOnWorkDoneRequests) { mOnWorkDoneRequests.CloseAll([status](OnWorkDoneData* request) {
if (it.second.callback) { if (request->callback != nullptr) {
it.second.callback(status, it.second.userdata); request->callback(status, request->userdata);
} }
} });
mOnWorkDoneRequests.clear();
} }
}} // namespace dawn_wire::client }} // namespace dawn_wire::client

View File

@ -19,8 +19,7 @@
#include "dawn_wire/WireClient.h" #include "dawn_wire/WireClient.h"
#include "dawn_wire/client/ObjectBase.h" #include "dawn_wire/client/ObjectBase.h"
#include "dawn_wire/client/RequestTracker.h"
#include <map>
namespace dawn_wire { namespace client { namespace dawn_wire { namespace client {
@ -44,15 +43,13 @@ namespace dawn_wire { namespace client {
private: private:
void CancelCallbacksForDisconnect() override; void CancelCallbacksForDisconnect() override;
void ClearAllCallbacks(WGPUQueueWorkDoneStatus status); void ClearAllCallbacks(WGPUQueueWorkDoneStatus status);
struct OnWorkDoneData { struct OnWorkDoneData {
WGPUQueueWorkDoneCallback callback = nullptr; WGPUQueueWorkDoneCallback callback = nullptr;
void* userdata = nullptr; void* userdata = nullptr;
}; };
uint64_t mOnWorkDoneSerial = 0; RequestTracker<OnWorkDoneData> mOnWorkDoneRequests;
std::map<uint64_t, OnWorkDoneData> mOnWorkDoneRequests;
}; };
}} // namespace dawn_wire::client }} // namespace dawn_wire::client

View File

@ -0,0 +1,82 @@
// Copyright 2021 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DAWNWIRE_CLIENT_REQUESTTRACKER_H_
#define DAWNWIRE_CLIENT_REQUESTTRACKER_H_
#include "common/Assert.h"
#include "common/NonCopyable.h"
#include <cstdint>
#include <map>
namespace dawn_wire { namespace client {
class Device;
class MemoryTransferService;
template <typename Request>
class RequestTracker : NonCopyable {
public:
~RequestTracker() {
ASSERT(mRequests.empty());
}
uint64_t Add(Request&& request) {
mSerial++;
mRequests.emplace(mSerial, request);
return mSerial;
}
bool Acquire(uint64_t serial, Request* request) {
auto it = mRequests.find(serial);
if (it == mRequests.end()) {
return false;
}
*request = std::move(it->second);
mRequests.erase(it);
return true;
}
template <typename CloseFunc>
void CloseAll(CloseFunc&& closeFunc) {
// Call closeFunc on all requests while handling reentrancy where the callback of some
// requests may add some additional requests. We guarantee all callbacks for requests
// are called exactly onces, so keep closing new requests if the first batch added more.
// It is fine to loop infinitely here if that's what the application makes use do.
while (!mRequests.empty()) {
// Move mRequests to a local variable so that further reentrant modifications of
// mRequests don't invalidate the iterators.
auto allRequests = std::move(mRequests);
for (auto& it : allRequests) {
closeFunc(&it.second);
}
}
}
template <typename F>
void ForAll(F&& f) {
for (auto& it : mRequests) {
f(&it.second);
}
}
private:
uint64_t mSerial;
std::map<uint64_t, Request> mRequests;
};
}} // namespace dawn_wire::client
#endif // DAWNWIRE_CLIENT_REQUESTTRACKER_H_

View File

@ -19,15 +19,7 @@
namespace dawn_wire { namespace client { namespace dawn_wire { namespace client {
ShaderModule::~ShaderModule() { ShaderModule::~ShaderModule() {
// Callbacks need to be fired in all cases, as they can handle freeing resources. So we call ClearAllCallbacks(WGPUCompilationInfoRequestStatus_Unknown);
// them with "Unknown" status.
for (auto& it : mCompilationInfoRequests) {
if (it.second.callback) {
it.second.callback(WGPUCompilationInfoRequestStatus_Unknown, nullptr,
it.second.userdata);
}
}
mCompilationInfoRequests.clear();
} }
void ShaderModule::GetCompilationInfo(WGPUCompilationInfoCallback callback, void* userdata) { void ShaderModule::GetCompilationInfo(WGPUCompilationInfoCallback callback, void* userdata) {
@ -36,41 +28,37 @@ namespace dawn_wire { namespace client {
return; return;
} }
uint64_t serial = mCompilationInfoRequestSerial++; uint64_t serial = mCompilationInfoRequests.Add({callback, userdata});
ShaderModuleGetCompilationInfoCmd cmd; ShaderModuleGetCompilationInfoCmd cmd;
cmd.shaderModuleId = this->id; cmd.shaderModuleId = this->id;
cmd.requestSerial = serial; cmd.requestSerial = serial;
mCompilationInfoRequests[serial] = {callback, userdata};
client->SerializeCommand(cmd); client->SerializeCommand(cmd);
} }
bool ShaderModule::GetCompilationInfoCallback(uint64_t requestSerial, bool ShaderModule::GetCompilationInfoCallback(uint64_t requestSerial,
WGPUCompilationInfoRequestStatus status, WGPUCompilationInfoRequestStatus status,
const WGPUCompilationInfo* info) { const WGPUCompilationInfo* info) {
auto requestIt = mCompilationInfoRequests.find(requestSerial); CompilationInfoRequest request;
if (requestIt == mCompilationInfoRequests.end()) { if (!mCompilationInfoRequests.Acquire(requestSerial, &request)) {
return false; return false;
} }
// Remove the request data so that the callback cannot be called again.
// ex.) inside the callback: if the shader module is deleted, all callbacks reject.
CompilationInfoRequest request = std::move(requestIt->second);
mCompilationInfoRequests.erase(requestIt);
request.callback(status, info, request.userdata); request.callback(status, info, request.userdata);
return true; return true;
} }
void ShaderModule::CancelCallbacksForDisconnect() { void ShaderModule::CancelCallbacksForDisconnect() {
for (auto& it : mCompilationInfoRequests) { ClearAllCallbacks(WGPUCompilationInfoRequestStatus_DeviceLost);
if (it.second.callback) {
it.second.callback(WGPUCompilationInfoRequestStatus_DeviceLost, nullptr,
it.second.userdata);
} }
void ShaderModule::ClearAllCallbacks(WGPUCompilationInfoRequestStatus status) {
mCompilationInfoRequests.CloseAll([status](CompilationInfoRequest* request) {
if (request->callback != nullptr) {
request->callback(status, nullptr, request->userdata);
} }
mCompilationInfoRequests.clear(); });
} }
}} // namespace dawn_wire::client }} // namespace dawn_wire::client

View File

@ -17,8 +17,8 @@
#include <dawn/webgpu.h> #include <dawn/webgpu.h>
#include "common/SerialMap.h"
#include "dawn_wire/client/ObjectBase.h" #include "dawn_wire/client/ObjectBase.h"
#include "dawn_wire/client/RequestTracker.h"
namespace dawn_wire { namespace client { namespace dawn_wire { namespace client {
@ -32,15 +32,15 @@ namespace dawn_wire { namespace client {
WGPUCompilationInfoRequestStatus status, WGPUCompilationInfoRequestStatus status,
const WGPUCompilationInfo* info); const WGPUCompilationInfo* info);
void CancelCallbacksForDisconnect() override;
private: private:
void CancelCallbacksForDisconnect() override;
void ClearAllCallbacks(WGPUCompilationInfoRequestStatus status);
struct CompilationInfoRequest { struct CompilationInfoRequest {
WGPUCompilationInfoCallback callback = nullptr; WGPUCompilationInfoCallback callback = nullptr;
void* userdata = nullptr; void* userdata = nullptr;
}; };
uint64_t mCompilationInfoRequestSerial = 0; RequestTracker<CompilationInfoRequest> mCompilationInfoRequests;
std::map<uint64_t, CompilationInfoRequest> mCompilationInfoRequests;
}; };
}} // namespace dawn_wire::client }} // namespace dawn_wire::client