dawn_wire/client: Add RequestTracker helper
This helper helps ensure correct handling of request maps by: - Forcing erasing to happen immediately when acquiring a request. This prevents some cases of iterator invalidation if we later change the container type. - Implements correct closure of all callbacks, including if the callbacks themselves add more callbacks. Bug: dawn:1092 Change-Id: Ia0ba9f050bbf3f0dee846f537910523bebb3bf1b Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/63003 Commit-Queue: Corentin Wallez <cwallez@chromium.org> Auto-Submit: Corentin Wallez <cwallez@chromium.org> Reviewed-by: Austin Eng <enga@chromium.org>
This commit is contained in:
parent
4a4a804476
commit
baf8df396c
|
@ -140,25 +140,20 @@ namespace dawn_wire { namespace client {
|
|||
}
|
||||
|
||||
Buffer::~Buffer() {
|
||||
// Callbacks need to be fired in all cases, as they can handle freeing resources
|
||||
// so we call them with "DestroyedBeforeCallback" status.
|
||||
for (auto& it : mRequests) {
|
||||
if (it.second.callback) {
|
||||
it.second.callback(WGPUBufferMapAsyncStatus_DestroyedBeforeCallback, it.second.userdata);
|
||||
}
|
||||
}
|
||||
mRequests.clear();
|
||||
|
||||
ClearAllCallbacks(WGPUBufferMapAsyncStatus_DestroyedBeforeCallback);
|
||||
FreeMappedData();
|
||||
}
|
||||
|
||||
void Buffer::CancelCallbacksForDisconnect() {
|
||||
for (auto& it : mRequests) {
|
||||
if (it.second.callback) {
|
||||
it.second.callback(WGPUBufferMapAsyncStatus_DeviceLost, it.second.userdata);
|
||||
ClearAllCallbacks(WGPUBufferMapAsyncStatus_DeviceLost);
|
||||
}
|
||||
|
||||
void Buffer::ClearAllCallbacks(WGPUBufferMapAsyncStatus status) {
|
||||
mRequests.CloseAll([status](MapRequestData* request) {
|
||||
if (request->callback != nullptr) {
|
||||
request->callback(status, request->userdata);
|
||||
}
|
||||
mRequests.clear();
|
||||
});
|
||||
}
|
||||
|
||||
void Buffer::MapAsync(WGPUMapModeFlags mode,
|
||||
|
@ -177,10 +172,7 @@ namespace dawn_wire { namespace client {
|
|||
|
||||
// Create the request structure that will hold information while this mapping is
|
||||
// in flight.
|
||||
uint64_t serial = mRequestSerial++;
|
||||
ASSERT(mRequests.find(serial) == mRequests.end());
|
||||
|
||||
Buffer::MapRequestData request = {};
|
||||
MapRequestData request = {};
|
||||
request.callback = callback;
|
||||
request.userdata = userdata;
|
||||
request.offset = offset;
|
||||
|
@ -191,6 +183,8 @@ namespace dawn_wire { namespace client {
|
|||
request.type = MapRequestType::Write;
|
||||
}
|
||||
|
||||
uint64_t serial = mRequests.Add(std::move(request));
|
||||
|
||||
// Serialize the command to send to the server.
|
||||
BufferMapAsyncCmd cmd;
|
||||
cmd.bufferId = this->id;
|
||||
|
@ -200,26 +194,17 @@ namespace dawn_wire { namespace client {
|
|||
cmd.size = size;
|
||||
|
||||
client->SerializeCommand(cmd);
|
||||
|
||||
// Register this request so that we can retrieve it from its serial when the server
|
||||
// sends the callback.
|
||||
mRequests[serial] = std::move(request);
|
||||
}
|
||||
|
||||
bool Buffer::OnMapAsyncCallback(uint64_t requestSerial,
|
||||
uint32_t status,
|
||||
uint64_t readDataUpdateInfoLength,
|
||||
const uint8_t* readDataUpdateInfo) {
|
||||
auto requestIt = mRequests.find(requestSerial);
|
||||
if (requestIt == mRequests.end()) {
|
||||
MapRequestData request;
|
||||
if (!mRequests.Acquire(requestSerial, &request)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto request = std::move(requestIt->second);
|
||||
// Delete the request before calling the callback otherwise the callback could be fired a
|
||||
// second time. If, for example, buffer.Unmap() is called inside the callback.
|
||||
mRequests.erase(requestIt);
|
||||
|
||||
auto FailRequest = [&request]() -> bool {
|
||||
if (request.callback != nullptr) {
|
||||
request.callback(WGPUBufferMapAsyncStatus_DeviceLost, request.userdata);
|
||||
|
@ -352,11 +337,11 @@ namespace dawn_wire { namespace client {
|
|||
mMapSize = 0;
|
||||
|
||||
// Tag all mapping requests still in flight as unmapped before callback.
|
||||
for (auto& it : mRequests) {
|
||||
if (it.second.clientStatus == WGPUBufferMapAsyncStatus_Success) {
|
||||
it.second.clientStatus = WGPUBufferMapAsyncStatus_UnmappedBeforeCallback;
|
||||
}
|
||||
mRequests.ForAll([](MapRequestData* request) {
|
||||
if (request->clientStatus == WGPUBufferMapAsyncStatus_Success) {
|
||||
request->clientStatus = WGPUBufferMapAsyncStatus_UnmappedBeforeCallback;
|
||||
}
|
||||
});
|
||||
|
||||
BufferUnmapCmd cmd;
|
||||
cmd.self = ToAPI(this);
|
||||
|
@ -368,11 +353,11 @@ namespace dawn_wire { namespace client {
|
|||
FreeMappedData();
|
||||
|
||||
// Tag all mapping requests still in flight as destroyed before callback.
|
||||
for (auto& it : mRequests) {
|
||||
if (it.second.clientStatus == WGPUBufferMapAsyncStatus_Success) {
|
||||
it.second.clientStatus = WGPUBufferMapAsyncStatus_DestroyedBeforeCallback;
|
||||
}
|
||||
mRequests.ForAll([](MapRequestData* request) {
|
||||
if (request->clientStatus == WGPUBufferMapAsyncStatus_Success) {
|
||||
request->clientStatus = WGPUBufferMapAsyncStatus_DestroyedBeforeCallback;
|
||||
}
|
||||
});
|
||||
|
||||
BufferDestroyCmd cmd;
|
||||
cmd.self = ToAPI(this);
|
||||
|
|
|
@ -19,8 +19,7 @@
|
|||
|
||||
#include "dawn_wire/WireClient.h"
|
||||
#include "dawn_wire/client/ObjectBase.h"
|
||||
|
||||
#include <map>
|
||||
#include "dawn_wire/client/RequestTracker.h"
|
||||
|
||||
namespace dawn_wire { namespace client {
|
||||
|
||||
|
@ -52,6 +51,7 @@ namespace dawn_wire { namespace client {
|
|||
|
||||
private:
|
||||
void CancelCallbacksForDisconnect() override;
|
||||
void ClearAllCallbacks(WGPUBufferMapAsyncStatus status);
|
||||
|
||||
bool IsMappedForReading() const;
|
||||
bool IsMappedForWriting() const;
|
||||
|
@ -86,8 +86,7 @@ namespace dawn_wire { namespace client {
|
|||
|
||||
MapRequestType type = MapRequestType::None;
|
||||
};
|
||||
std::map<uint64_t, MapRequestData> mRequests;
|
||||
uint64_t mRequestSerial = 0;
|
||||
RequestTracker<MapRequestData> mRequests;
|
||||
uint64_t mSize = 0;
|
||||
|
||||
// Only one mapped pointer can be active at a time because Unmap clears all the in-flight
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <dawn_wire/Wire.h>
|
||||
|
||||
#include "common/LinkedList.h"
|
||||
#include "common/NonCopyable.h"
|
||||
#include "dawn_wire/ChunkedCommandSerializer.h"
|
||||
#include "dawn_wire/WireClient.h"
|
||||
#include "dawn_wire/WireCmd_autogen.h"
|
||||
|
|
|
@ -48,26 +48,23 @@ namespace dawn_wire { namespace client {
|
|||
}
|
||||
|
||||
Device::~Device() {
|
||||
// Fire pending error scopes
|
||||
auto errorScopes = std::move(mErrorScopes);
|
||||
for (const auto& it : errorScopes) {
|
||||
it.second.callback(WGPUErrorType_Unknown, "Device destroyed before callback",
|
||||
it.second.userdata);
|
||||
}
|
||||
mErrorScopes.CloseAll([](ErrorScopeData* request) {
|
||||
request->callback(WGPUErrorType_Unknown, "Device destroyed before callback",
|
||||
request->userdata);
|
||||
});
|
||||
|
||||
auto createPipelineAsyncRequests = std::move(mCreatePipelineAsyncRequests);
|
||||
for (const auto& it : createPipelineAsyncRequests) {
|
||||
if (it.second.createComputePipelineAsyncCallback != nullptr) {
|
||||
it.second.createComputePipelineAsyncCallback(
|
||||
mCreatePipelineAsyncRequests.CloseAll([](CreatePipelineAsyncRequest* request) {
|
||||
if (request->createComputePipelineAsyncCallback != nullptr) {
|
||||
request->createComputePipelineAsyncCallback(
|
||||
WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr,
|
||||
"Device destroyed before callback", it.second.userdata);
|
||||
"Device destroyed before callback", request->userdata);
|
||||
} else {
|
||||
ASSERT(it.second.createRenderPipelineAsyncCallback != nullptr);
|
||||
it.second.createRenderPipelineAsyncCallback(
|
||||
ASSERT(request->createRenderPipelineAsyncCallback != nullptr);
|
||||
request->createRenderPipelineAsyncCallback(
|
||||
WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr,
|
||||
"Device destroyed before callback", it.second.userdata);
|
||||
}
|
||||
"Device destroyed before callback", request->userdata);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
void Device::HandleError(WGPUErrorType errorType, const char* message) {
|
||||
|
@ -91,25 +88,22 @@ namespace dawn_wire { namespace client {
|
|||
}
|
||||
|
||||
void Device::CancelCallbacksForDisconnect() {
|
||||
for (auto& it : mCreatePipelineAsyncRequests) {
|
||||
ASSERT((it.second.createComputePipelineAsyncCallback != nullptr) ^
|
||||
(it.second.createRenderPipelineAsyncCallback != nullptr));
|
||||
if (it.second.createRenderPipelineAsyncCallback) {
|
||||
it.second.createRenderPipelineAsyncCallback(
|
||||
WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, "Device lost",
|
||||
it.second.userdata);
|
||||
} else {
|
||||
it.second.createComputePipelineAsyncCallback(
|
||||
WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, "Device lost",
|
||||
it.second.userdata);
|
||||
}
|
||||
}
|
||||
mCreatePipelineAsyncRequests.clear();
|
||||
mErrorScopes.CloseAll([](ErrorScopeData* request) {
|
||||
request->callback(WGPUErrorType_DeviceLost, "Device lost", request->userdata);
|
||||
});
|
||||
|
||||
for (auto& it : mErrorScopes) {
|
||||
it.second.callback(WGPUErrorType_DeviceLost, "Device lost", it.second.userdata);
|
||||
mCreatePipelineAsyncRequests.CloseAll([](CreatePipelineAsyncRequest* request) {
|
||||
if (request->createComputePipelineAsyncCallback != nullptr) {
|
||||
request->createComputePipelineAsyncCallback(
|
||||
WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, "Device lost",
|
||||
request->userdata);
|
||||
} else {
|
||||
ASSERT(request->createRenderPipelineAsyncCallback != nullptr);
|
||||
request->createRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceLost,
|
||||
nullptr, "Device lost",
|
||||
request->userdata);
|
||||
}
|
||||
mErrorScopes.clear();
|
||||
});
|
||||
}
|
||||
|
||||
std::weak_ptr<bool> Device::GetAliveWeakPtr() {
|
||||
|
@ -152,10 +146,7 @@ namespace dawn_wire { namespace client {
|
|||
return true;
|
||||
}
|
||||
|
||||
uint64_t serial = mErrorScopeRequestSerial++;
|
||||
ASSERT(mErrorScopes.find(serial) == mErrorScopes.end());
|
||||
|
||||
mErrorScopes[serial] = {callback, userdata};
|
||||
uint64_t serial = mErrorScopes.Add({callback, userdata});
|
||||
|
||||
DevicePopErrorScopeCmd cmd;
|
||||
cmd.deviceId = this->id;
|
||||
|
@ -180,14 +171,11 @@ namespace dawn_wire { namespace client {
|
|||
return false;
|
||||
}
|
||||
|
||||
auto requestIt = mErrorScopes.find(requestSerial);
|
||||
if (requestIt == mErrorScopes.end()) {
|
||||
ErrorScopeData request;
|
||||
if (!mErrorScopes.Acquire(requestSerial, &request)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
ErrorScopeData request = std::move(requestIt->second);
|
||||
|
||||
mErrorScopes.erase(requestIt);
|
||||
request.callback(type, message, request.userdata);
|
||||
return true;
|
||||
}
|
||||
|
@ -265,9 +253,6 @@ namespace dawn_wire { namespace client {
|
|||
"GPU device disconnected", userdata);
|
||||
}
|
||||
|
||||
DeviceCreateComputePipelineAsyncCmd cmd;
|
||||
cmd.deviceId = this->id;
|
||||
|
||||
// Copy compute to the deprecated computeStage or visa-versa, depending on which one is
|
||||
// populated, so that serialization doesn't fail.
|
||||
// TODO(dawn:800): Remove once computeStage is removed.
|
||||
|
@ -280,35 +265,32 @@ namespace dawn_wire { namespace client {
|
|||
localDescriptor.compute.entryPoint = localDescriptor.computeStage.entryPoint;
|
||||
}
|
||||
|
||||
cmd.descriptor = &localDescriptor;
|
||||
|
||||
uint64_t serial = mCreatePipelineAsyncRequestSerial++;
|
||||
ASSERT(mCreatePipelineAsyncRequests.find(serial) == mCreatePipelineAsyncRequests.end());
|
||||
cmd.requestSerial = serial;
|
||||
|
||||
auto* allocation = client->ComputePipelineAllocator().New(client);
|
||||
|
||||
CreatePipelineAsyncRequest request = {};
|
||||
request.createComputePipelineAsyncCallback = callback;
|
||||
request.userdata = userdata;
|
||||
request.pipelineObjectID = allocation->object->id;
|
||||
|
||||
cmd.pipelineObjectHandle = ObjectHandle{allocation->object->id, allocation->generation};
|
||||
client->SerializeCommand(cmd);
|
||||
uint64_t serial = mCreatePipelineAsyncRequests.Add(std::move(request));
|
||||
|
||||
mCreatePipelineAsyncRequests[serial] = std::move(request);
|
||||
DeviceCreateComputePipelineAsyncCmd cmd;
|
||||
cmd.deviceId = this->id;
|
||||
cmd.descriptor = &localDescriptor;
|
||||
cmd.requestSerial = serial;
|
||||
cmd.pipelineObjectHandle = ObjectHandle{allocation->object->id, allocation->generation};
|
||||
|
||||
client->SerializeCommand(cmd);
|
||||
}
|
||||
|
||||
bool Device::OnCreateComputePipelineAsyncCallback(uint64_t requestSerial,
|
||||
WGPUCreatePipelineAsyncStatus status,
|
||||
const char* message) {
|
||||
const auto& requestIt = mCreatePipelineAsyncRequests.find(requestSerial);
|
||||
if (requestIt == mCreatePipelineAsyncRequests.end()) {
|
||||
CreatePipelineAsyncRequest request;
|
||||
if (!mCreatePipelineAsyncRequests.Acquire(requestSerial, &request)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
CreatePipelineAsyncRequest request = std::move(requestIt->second);
|
||||
mCreatePipelineAsyncRequests.erase(requestIt);
|
||||
|
||||
auto pipelineAllocation =
|
||||
client->ComputePipelineAllocator().GetObject(request.pipelineObjectID);
|
||||
|
||||
|
@ -333,37 +315,33 @@ namespace dawn_wire { namespace client {
|
|||
return callback(WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr,
|
||||
"GPU device disconnected", userdata);
|
||||
}
|
||||
DeviceCreateRenderPipelineAsyncCmd cmd;
|
||||
cmd.deviceId = this->id;
|
||||
cmd.descriptor = descriptor;
|
||||
|
||||
uint64_t serial = mCreatePipelineAsyncRequestSerial++;
|
||||
ASSERT(mCreatePipelineAsyncRequests.find(serial) == mCreatePipelineAsyncRequests.end());
|
||||
cmd.requestSerial = serial;
|
||||
|
||||
auto* allocation = client->RenderPipelineAllocator().New(client);
|
||||
|
||||
CreatePipelineAsyncRequest request = {};
|
||||
request.createRenderPipelineAsyncCallback = callback;
|
||||
request.userdata = userdata;
|
||||
request.pipelineObjectID = allocation->object->id;
|
||||
|
||||
cmd.pipelineObjectHandle = ObjectHandle(allocation->object->id, allocation->generation);
|
||||
client->SerializeCommand(cmd);
|
||||
uint64_t serial = mCreatePipelineAsyncRequests.Add(std::move(request));
|
||||
|
||||
mCreatePipelineAsyncRequests[serial] = std::move(request);
|
||||
DeviceCreateRenderPipelineAsyncCmd cmd;
|
||||
cmd.deviceId = this->id;
|
||||
cmd.descriptor = descriptor;
|
||||
cmd.requestSerial = serial;
|
||||
cmd.pipelineObjectHandle = ObjectHandle(allocation->object->id, allocation->generation);
|
||||
|
||||
client->SerializeCommand(cmd);
|
||||
}
|
||||
|
||||
bool Device::OnCreateRenderPipelineAsyncCallback(uint64_t requestSerial,
|
||||
WGPUCreatePipelineAsyncStatus status,
|
||||
const char* message) {
|
||||
const auto& requestIt = mCreatePipelineAsyncRequests.find(requestSerial);
|
||||
if (requestIt == mCreatePipelineAsyncRequests.end()) {
|
||||
CreatePipelineAsyncRequest request;
|
||||
if (!mCreatePipelineAsyncRequests.Acquire(requestSerial, &request)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
CreatePipelineAsyncRequest request = std::move(requestIt->second);
|
||||
mCreatePipelineAsyncRequests.erase(requestIt);
|
||||
|
||||
auto pipelineAllocation =
|
||||
client->RenderPipelineAllocator().GetObject(request.pipelineObjectID);
|
||||
|
||||
|
|
|
@ -21,8 +21,8 @@
|
|||
#include "dawn_wire/WireCmd_autogen.h"
|
||||
#include "dawn_wire/client/ApiObjects_autogen.h"
|
||||
#include "dawn_wire/client/ObjectBase.h"
|
||||
#include "dawn_wire/client/RequestTracker.h"
|
||||
|
||||
#include <map>
|
||||
#include <memory>
|
||||
|
||||
namespace dawn_wire { namespace client {
|
||||
|
@ -75,8 +75,7 @@ namespace dawn_wire { namespace client {
|
|||
WGPUErrorCallback callback = nullptr;
|
||||
void* userdata = nullptr;
|
||||
};
|
||||
std::map<uint64_t, ErrorScopeData> mErrorScopes;
|
||||
uint64_t mErrorScopeRequestSerial = 0;
|
||||
RequestTracker<ErrorScopeData> mErrorScopes;
|
||||
uint64_t mErrorScopeStackSize = 0;
|
||||
|
||||
struct CreatePipelineAsyncRequest {
|
||||
|
@ -85,8 +84,7 @@ namespace dawn_wire { namespace client {
|
|||
void* userdata = nullptr;
|
||||
ObjectId pipelineObjectID;
|
||||
};
|
||||
std::map<uint64_t, CreatePipelineAsyncRequest> mCreatePipelineAsyncRequests;
|
||||
uint64_t mCreatePipelineAsyncRequestSerial = 0;
|
||||
RequestTracker<CreatePipelineAsyncRequest> mCreatePipelineAsyncRequests;
|
||||
|
||||
WGPUErrorCallback mErrorCallback = nullptr;
|
||||
WGPUDeviceLostCallback mDeviceLostCallback = nullptr;
|
||||
|
|
|
@ -24,17 +24,11 @@ namespace dawn_wire { namespace client {
|
|||
}
|
||||
|
||||
bool Queue::OnWorkDoneCallback(uint64_t requestSerial, WGPUQueueWorkDoneStatus status) {
|
||||
auto requestIt = mOnWorkDoneRequests.find(requestSerial);
|
||||
if (requestIt == mOnWorkDoneRequests.end()) {
|
||||
OnWorkDoneData request;
|
||||
if (!mOnWorkDoneRequests.Acquire(requestSerial, &request)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Remove the request data so that the callback cannot be called again.
|
||||
// ex.) inside the callback: if the queue is deleted (when there are multiple queues),
|
||||
// all callbacks reject.
|
||||
OnWorkDoneData request = std::move(requestIt->second);
|
||||
mOnWorkDoneRequests.erase(requestIt);
|
||||
|
||||
request.callback(status, request.userdata);
|
||||
return true;
|
||||
}
|
||||
|
@ -47,16 +41,13 @@ namespace dawn_wire { namespace client {
|
|||
return;
|
||||
}
|
||||
|
||||
uint32_t serial = mOnWorkDoneSerial++;
|
||||
ASSERT(mOnWorkDoneRequests.find(serial) == mOnWorkDoneRequests.end());
|
||||
uint64_t serial = mOnWorkDoneRequests.Add({callback, userdata});
|
||||
|
||||
QueueOnSubmittedWorkDoneCmd cmd;
|
||||
cmd.queueId = this->id;
|
||||
cmd.signalValue = signalValue;
|
||||
cmd.requestSerial = serial;
|
||||
|
||||
mOnWorkDoneRequests[serial] = {callback, userdata};
|
||||
|
||||
client->SerializeCommand(cmd);
|
||||
}
|
||||
|
||||
|
@ -97,12 +88,11 @@ namespace dawn_wire { namespace client {
|
|||
}
|
||||
|
||||
void Queue::ClearAllCallbacks(WGPUQueueWorkDoneStatus status) {
|
||||
for (auto& it : mOnWorkDoneRequests) {
|
||||
if (it.second.callback) {
|
||||
it.second.callback(status, it.second.userdata);
|
||||
mOnWorkDoneRequests.CloseAll([status](OnWorkDoneData* request) {
|
||||
if (request->callback != nullptr) {
|
||||
request->callback(status, request->userdata);
|
||||
}
|
||||
}
|
||||
mOnWorkDoneRequests.clear();
|
||||
});
|
||||
}
|
||||
|
||||
}} // namespace dawn_wire::client
|
||||
|
|
|
@ -19,8 +19,7 @@
|
|||
|
||||
#include "dawn_wire/WireClient.h"
|
||||
#include "dawn_wire/client/ObjectBase.h"
|
||||
|
||||
#include <map>
|
||||
#include "dawn_wire/client/RequestTracker.h"
|
||||
|
||||
namespace dawn_wire { namespace client {
|
||||
|
||||
|
@ -44,15 +43,13 @@ namespace dawn_wire { namespace client {
|
|||
|
||||
private:
|
||||
void CancelCallbacksForDisconnect() override;
|
||||
|
||||
void ClearAllCallbacks(WGPUQueueWorkDoneStatus status);
|
||||
|
||||
struct OnWorkDoneData {
|
||||
WGPUQueueWorkDoneCallback callback = nullptr;
|
||||
void* userdata = nullptr;
|
||||
};
|
||||
uint64_t mOnWorkDoneSerial = 0;
|
||||
std::map<uint64_t, OnWorkDoneData> mOnWorkDoneRequests;
|
||||
RequestTracker<OnWorkDoneData> mOnWorkDoneRequests;
|
||||
};
|
||||
|
||||
}} // namespace dawn_wire::client
|
||||
|
|
|
@ -0,0 +1,82 @@
|
|||
// Copyright 2021 The Dawn Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef DAWNWIRE_CLIENT_REQUESTTRACKER_H_
|
||||
#define DAWNWIRE_CLIENT_REQUESTTRACKER_H_
|
||||
|
||||
#include "common/Assert.h"
|
||||
#include "common/NonCopyable.h"
|
||||
|
||||
#include <cstdint>
|
||||
#include <map>
|
||||
|
||||
namespace dawn_wire { namespace client {
|
||||
|
||||
class Device;
|
||||
class MemoryTransferService;
|
||||
|
||||
template <typename Request>
|
||||
class RequestTracker : NonCopyable {
|
||||
public:
|
||||
~RequestTracker() {
|
||||
ASSERT(mRequests.empty());
|
||||
}
|
||||
|
||||
uint64_t Add(Request&& request) {
|
||||
mSerial++;
|
||||
mRequests.emplace(mSerial, request);
|
||||
return mSerial;
|
||||
}
|
||||
|
||||
bool Acquire(uint64_t serial, Request* request) {
|
||||
auto it = mRequests.find(serial);
|
||||
if (it == mRequests.end()) {
|
||||
return false;
|
||||
}
|
||||
*request = std::move(it->second);
|
||||
mRequests.erase(it);
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename CloseFunc>
|
||||
void CloseAll(CloseFunc&& closeFunc) {
|
||||
// Call closeFunc on all requests while handling reentrancy where the callback of some
|
||||
// requests may add some additional requests. We guarantee all callbacks for requests
|
||||
// are called exactly onces, so keep closing new requests if the first batch added more.
|
||||
// It is fine to loop infinitely here if that's what the application makes use do.
|
||||
while (!mRequests.empty()) {
|
||||
// Move mRequests to a local variable so that further reentrant modifications of
|
||||
// mRequests don't invalidate the iterators.
|
||||
auto allRequests = std::move(mRequests);
|
||||
for (auto& it : allRequests) {
|
||||
closeFunc(&it.second);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename F>
|
||||
void ForAll(F&& f) {
|
||||
for (auto& it : mRequests) {
|
||||
f(&it.second);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
uint64_t mSerial;
|
||||
std::map<uint64_t, Request> mRequests;
|
||||
};
|
||||
|
||||
}} // namespace dawn_wire::client
|
||||
|
||||
#endif // DAWNWIRE_CLIENT_REQUESTTRACKER_H_
|
|
@ -19,15 +19,7 @@
|
|||
namespace dawn_wire { namespace client {
|
||||
|
||||
ShaderModule::~ShaderModule() {
|
||||
// Callbacks need to be fired in all cases, as they can handle freeing resources. So we call
|
||||
// them with "Unknown" status.
|
||||
for (auto& it : mCompilationInfoRequests) {
|
||||
if (it.second.callback) {
|
||||
it.second.callback(WGPUCompilationInfoRequestStatus_Unknown, nullptr,
|
||||
it.second.userdata);
|
||||
}
|
||||
}
|
||||
mCompilationInfoRequests.clear();
|
||||
ClearAllCallbacks(WGPUCompilationInfoRequestStatus_Unknown);
|
||||
}
|
||||
|
||||
void ShaderModule::GetCompilationInfo(WGPUCompilationInfoCallback callback, void* userdata) {
|
||||
|
@ -36,41 +28,37 @@ namespace dawn_wire { namespace client {
|
|||
return;
|
||||
}
|
||||
|
||||
uint64_t serial = mCompilationInfoRequestSerial++;
|
||||
uint64_t serial = mCompilationInfoRequests.Add({callback, userdata});
|
||||
|
||||
ShaderModuleGetCompilationInfoCmd cmd;
|
||||
cmd.shaderModuleId = this->id;
|
||||
cmd.requestSerial = serial;
|
||||
|
||||
mCompilationInfoRequests[serial] = {callback, userdata};
|
||||
|
||||
client->SerializeCommand(cmd);
|
||||
}
|
||||
|
||||
bool ShaderModule::GetCompilationInfoCallback(uint64_t requestSerial,
|
||||
WGPUCompilationInfoRequestStatus status,
|
||||
const WGPUCompilationInfo* info) {
|
||||
auto requestIt = mCompilationInfoRequests.find(requestSerial);
|
||||
if (requestIt == mCompilationInfoRequests.end()) {
|
||||
CompilationInfoRequest request;
|
||||
if (!mCompilationInfoRequests.Acquire(requestSerial, &request)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Remove the request data so that the callback cannot be called again.
|
||||
// ex.) inside the callback: if the shader module is deleted, all callbacks reject.
|
||||
CompilationInfoRequest request = std::move(requestIt->second);
|
||||
mCompilationInfoRequests.erase(requestIt);
|
||||
|
||||
request.callback(status, info, request.userdata);
|
||||
return true;
|
||||
}
|
||||
|
||||
void ShaderModule::CancelCallbacksForDisconnect() {
|
||||
for (auto& it : mCompilationInfoRequests) {
|
||||
if (it.second.callback) {
|
||||
it.second.callback(WGPUCompilationInfoRequestStatus_DeviceLost, nullptr,
|
||||
it.second.userdata);
|
||||
ClearAllCallbacks(WGPUCompilationInfoRequestStatus_DeviceLost);
|
||||
}
|
||||
|
||||
void ShaderModule::ClearAllCallbacks(WGPUCompilationInfoRequestStatus status) {
|
||||
mCompilationInfoRequests.CloseAll([status](CompilationInfoRequest* request) {
|
||||
if (request->callback != nullptr) {
|
||||
request->callback(status, nullptr, request->userdata);
|
||||
}
|
||||
mCompilationInfoRequests.clear();
|
||||
});
|
||||
}
|
||||
|
||||
}} // namespace dawn_wire::client
|
||||
|
|
|
@ -17,8 +17,8 @@
|
|||
|
||||
#include <dawn/webgpu.h>
|
||||
|
||||
#include "common/SerialMap.h"
|
||||
#include "dawn_wire/client/ObjectBase.h"
|
||||
#include "dawn_wire/client/RequestTracker.h"
|
||||
|
||||
namespace dawn_wire { namespace client {
|
||||
|
||||
|
@ -32,15 +32,15 @@ namespace dawn_wire { namespace client {
|
|||
WGPUCompilationInfoRequestStatus status,
|
||||
const WGPUCompilationInfo* info);
|
||||
|
||||
void CancelCallbacksForDisconnect() override;
|
||||
|
||||
private:
|
||||
void CancelCallbacksForDisconnect() override;
|
||||
void ClearAllCallbacks(WGPUCompilationInfoRequestStatus status);
|
||||
|
||||
struct CompilationInfoRequest {
|
||||
WGPUCompilationInfoCallback callback = nullptr;
|
||||
void* userdata = nullptr;
|
||||
};
|
||||
uint64_t mCompilationInfoRequestSerial = 0;
|
||||
std::map<uint64_t, CompilationInfoRequest> mCompilationInfoRequests;
|
||||
RequestTracker<CompilationInfoRequest> mCompilationInfoRequests;
|
||||
};
|
||||
|
||||
}} // namespace dawn_wire::client
|
||||
|
|
Loading…
Reference in New Issue