dawn_wire/client: Add RequestTracker helper
This helper helps ensure correct handling of request maps by: - Forcing erasing to happen immediately when acquiring a request. This prevents some cases of iterator invalidation if we later change the container type. - Implements correct closure of all callbacks, including if the callbacks themselves add more callbacks. Bug: dawn:1092 Change-Id: Ia0ba9f050bbf3f0dee846f537910523bebb3bf1b Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/63003 Commit-Queue: Corentin Wallez <cwallez@chromium.org> Auto-Submit: Corentin Wallez <cwallez@chromium.org> Reviewed-by: Austin Eng <enga@chromium.org>
This commit is contained in:
parent
4a4a804476
commit
baf8df396c
|
@ -140,25 +140,20 @@ namespace dawn_wire { namespace client {
|
||||||
}
|
}
|
||||||
|
|
||||||
Buffer::~Buffer() {
|
Buffer::~Buffer() {
|
||||||
// Callbacks need to be fired in all cases, as they can handle freeing resources
|
ClearAllCallbacks(WGPUBufferMapAsyncStatus_DestroyedBeforeCallback);
|
||||||
// so we call them with "DestroyedBeforeCallback" status.
|
|
||||||
for (auto& it : mRequests) {
|
|
||||||
if (it.second.callback) {
|
|
||||||
it.second.callback(WGPUBufferMapAsyncStatus_DestroyedBeforeCallback, it.second.userdata);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
mRequests.clear();
|
|
||||||
|
|
||||||
FreeMappedData();
|
FreeMappedData();
|
||||||
}
|
}
|
||||||
|
|
||||||
void Buffer::CancelCallbacksForDisconnect() {
|
void Buffer::CancelCallbacksForDisconnect() {
|
||||||
for (auto& it : mRequests) {
|
ClearAllCallbacks(WGPUBufferMapAsyncStatus_DeviceLost);
|
||||||
if (it.second.callback) {
|
}
|
||||||
it.second.callback(WGPUBufferMapAsyncStatus_DeviceLost, it.second.userdata);
|
|
||||||
|
void Buffer::ClearAllCallbacks(WGPUBufferMapAsyncStatus status) {
|
||||||
|
mRequests.CloseAll([status](MapRequestData* request) {
|
||||||
|
if (request->callback != nullptr) {
|
||||||
|
request->callback(status, request->userdata);
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
mRequests.clear();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Buffer::MapAsync(WGPUMapModeFlags mode,
|
void Buffer::MapAsync(WGPUMapModeFlags mode,
|
||||||
|
@ -177,10 +172,7 @@ namespace dawn_wire { namespace client {
|
||||||
|
|
||||||
// Create the request structure that will hold information while this mapping is
|
// Create the request structure that will hold information while this mapping is
|
||||||
// in flight.
|
// in flight.
|
||||||
uint64_t serial = mRequestSerial++;
|
MapRequestData request = {};
|
||||||
ASSERT(mRequests.find(serial) == mRequests.end());
|
|
||||||
|
|
||||||
Buffer::MapRequestData request = {};
|
|
||||||
request.callback = callback;
|
request.callback = callback;
|
||||||
request.userdata = userdata;
|
request.userdata = userdata;
|
||||||
request.offset = offset;
|
request.offset = offset;
|
||||||
|
@ -191,6 +183,8 @@ namespace dawn_wire { namespace client {
|
||||||
request.type = MapRequestType::Write;
|
request.type = MapRequestType::Write;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint64_t serial = mRequests.Add(std::move(request));
|
||||||
|
|
||||||
// Serialize the command to send to the server.
|
// Serialize the command to send to the server.
|
||||||
BufferMapAsyncCmd cmd;
|
BufferMapAsyncCmd cmd;
|
||||||
cmd.bufferId = this->id;
|
cmd.bufferId = this->id;
|
||||||
|
@ -200,26 +194,17 @@ namespace dawn_wire { namespace client {
|
||||||
cmd.size = size;
|
cmd.size = size;
|
||||||
|
|
||||||
client->SerializeCommand(cmd);
|
client->SerializeCommand(cmd);
|
||||||
|
|
||||||
// Register this request so that we can retrieve it from its serial when the server
|
|
||||||
// sends the callback.
|
|
||||||
mRequests[serial] = std::move(request);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Buffer::OnMapAsyncCallback(uint64_t requestSerial,
|
bool Buffer::OnMapAsyncCallback(uint64_t requestSerial,
|
||||||
uint32_t status,
|
uint32_t status,
|
||||||
uint64_t readDataUpdateInfoLength,
|
uint64_t readDataUpdateInfoLength,
|
||||||
const uint8_t* readDataUpdateInfo) {
|
const uint8_t* readDataUpdateInfo) {
|
||||||
auto requestIt = mRequests.find(requestSerial);
|
MapRequestData request;
|
||||||
if (requestIt == mRequests.end()) {
|
if (!mRequests.Acquire(requestSerial, &request)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto request = std::move(requestIt->second);
|
|
||||||
// Delete the request before calling the callback otherwise the callback could be fired a
|
|
||||||
// second time. If, for example, buffer.Unmap() is called inside the callback.
|
|
||||||
mRequests.erase(requestIt);
|
|
||||||
|
|
||||||
auto FailRequest = [&request]() -> bool {
|
auto FailRequest = [&request]() -> bool {
|
||||||
if (request.callback != nullptr) {
|
if (request.callback != nullptr) {
|
||||||
request.callback(WGPUBufferMapAsyncStatus_DeviceLost, request.userdata);
|
request.callback(WGPUBufferMapAsyncStatus_DeviceLost, request.userdata);
|
||||||
|
@ -352,11 +337,11 @@ namespace dawn_wire { namespace client {
|
||||||
mMapSize = 0;
|
mMapSize = 0;
|
||||||
|
|
||||||
// Tag all mapping requests still in flight as unmapped before callback.
|
// Tag all mapping requests still in flight as unmapped before callback.
|
||||||
for (auto& it : mRequests) {
|
mRequests.ForAll([](MapRequestData* request) {
|
||||||
if (it.second.clientStatus == WGPUBufferMapAsyncStatus_Success) {
|
if (request->clientStatus == WGPUBufferMapAsyncStatus_Success) {
|
||||||
it.second.clientStatus = WGPUBufferMapAsyncStatus_UnmappedBeforeCallback;
|
request->clientStatus = WGPUBufferMapAsyncStatus_UnmappedBeforeCallback;
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
|
|
||||||
BufferUnmapCmd cmd;
|
BufferUnmapCmd cmd;
|
||||||
cmd.self = ToAPI(this);
|
cmd.self = ToAPI(this);
|
||||||
|
@ -368,11 +353,11 @@ namespace dawn_wire { namespace client {
|
||||||
FreeMappedData();
|
FreeMappedData();
|
||||||
|
|
||||||
// Tag all mapping requests still in flight as destroyed before callback.
|
// Tag all mapping requests still in flight as destroyed before callback.
|
||||||
for (auto& it : mRequests) {
|
mRequests.ForAll([](MapRequestData* request) {
|
||||||
if (it.second.clientStatus == WGPUBufferMapAsyncStatus_Success) {
|
if (request->clientStatus == WGPUBufferMapAsyncStatus_Success) {
|
||||||
it.second.clientStatus = WGPUBufferMapAsyncStatus_DestroyedBeforeCallback;
|
request->clientStatus = WGPUBufferMapAsyncStatus_DestroyedBeforeCallback;
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
|
|
||||||
BufferDestroyCmd cmd;
|
BufferDestroyCmd cmd;
|
||||||
cmd.self = ToAPI(this);
|
cmd.self = ToAPI(this);
|
||||||
|
|
|
@ -19,8 +19,7 @@
|
||||||
|
|
||||||
#include "dawn_wire/WireClient.h"
|
#include "dawn_wire/WireClient.h"
|
||||||
#include "dawn_wire/client/ObjectBase.h"
|
#include "dawn_wire/client/ObjectBase.h"
|
||||||
|
#include "dawn_wire/client/RequestTracker.h"
|
||||||
#include <map>
|
|
||||||
|
|
||||||
namespace dawn_wire { namespace client {
|
namespace dawn_wire { namespace client {
|
||||||
|
|
||||||
|
@ -52,6 +51,7 @@ namespace dawn_wire { namespace client {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void CancelCallbacksForDisconnect() override;
|
void CancelCallbacksForDisconnect() override;
|
||||||
|
void ClearAllCallbacks(WGPUBufferMapAsyncStatus status);
|
||||||
|
|
||||||
bool IsMappedForReading() const;
|
bool IsMappedForReading() const;
|
||||||
bool IsMappedForWriting() const;
|
bool IsMappedForWriting() const;
|
||||||
|
@ -86,8 +86,7 @@ namespace dawn_wire { namespace client {
|
||||||
|
|
||||||
MapRequestType type = MapRequestType::None;
|
MapRequestType type = MapRequestType::None;
|
||||||
};
|
};
|
||||||
std::map<uint64_t, MapRequestData> mRequests;
|
RequestTracker<MapRequestData> mRequests;
|
||||||
uint64_t mRequestSerial = 0;
|
|
||||||
uint64_t mSize = 0;
|
uint64_t mSize = 0;
|
||||||
|
|
||||||
// Only one mapped pointer can be active at a time because Unmap clears all the in-flight
|
// Only one mapped pointer can be active at a time because Unmap clears all the in-flight
|
||||||
|
|
|
@ -19,6 +19,7 @@
|
||||||
#include <dawn_wire/Wire.h>
|
#include <dawn_wire/Wire.h>
|
||||||
|
|
||||||
#include "common/LinkedList.h"
|
#include "common/LinkedList.h"
|
||||||
|
#include "common/NonCopyable.h"
|
||||||
#include "dawn_wire/ChunkedCommandSerializer.h"
|
#include "dawn_wire/ChunkedCommandSerializer.h"
|
||||||
#include "dawn_wire/WireClient.h"
|
#include "dawn_wire/WireClient.h"
|
||||||
#include "dawn_wire/WireCmd_autogen.h"
|
#include "dawn_wire/WireCmd_autogen.h"
|
||||||
|
|
|
@ -48,26 +48,23 @@ namespace dawn_wire { namespace client {
|
||||||
}
|
}
|
||||||
|
|
||||||
Device::~Device() {
|
Device::~Device() {
|
||||||
// Fire pending error scopes
|
mErrorScopes.CloseAll([](ErrorScopeData* request) {
|
||||||
auto errorScopes = std::move(mErrorScopes);
|
request->callback(WGPUErrorType_Unknown, "Device destroyed before callback",
|
||||||
for (const auto& it : errorScopes) {
|
request->userdata);
|
||||||
it.second.callback(WGPUErrorType_Unknown, "Device destroyed before callback",
|
});
|
||||||
it.second.userdata);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto createPipelineAsyncRequests = std::move(mCreatePipelineAsyncRequests);
|
mCreatePipelineAsyncRequests.CloseAll([](CreatePipelineAsyncRequest* request) {
|
||||||
for (const auto& it : createPipelineAsyncRequests) {
|
if (request->createComputePipelineAsyncCallback != nullptr) {
|
||||||
if (it.second.createComputePipelineAsyncCallback != nullptr) {
|
request->createComputePipelineAsyncCallback(
|
||||||
it.second.createComputePipelineAsyncCallback(
|
|
||||||
WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr,
|
WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr,
|
||||||
"Device destroyed before callback", it.second.userdata);
|
"Device destroyed before callback", request->userdata);
|
||||||
} else {
|
} else {
|
||||||
ASSERT(it.second.createRenderPipelineAsyncCallback != nullptr);
|
ASSERT(request->createRenderPipelineAsyncCallback != nullptr);
|
||||||
it.second.createRenderPipelineAsyncCallback(
|
request->createRenderPipelineAsyncCallback(
|
||||||
WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr,
|
WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr,
|
||||||
"Device destroyed before callback", it.second.userdata);
|
"Device destroyed before callback", request->userdata);
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
void Device::HandleError(WGPUErrorType errorType, const char* message) {
|
void Device::HandleError(WGPUErrorType errorType, const char* message) {
|
||||||
|
@ -91,25 +88,22 @@ namespace dawn_wire { namespace client {
|
||||||
}
|
}
|
||||||
|
|
||||||
void Device::CancelCallbacksForDisconnect() {
|
void Device::CancelCallbacksForDisconnect() {
|
||||||
for (auto& it : mCreatePipelineAsyncRequests) {
|
mErrorScopes.CloseAll([](ErrorScopeData* request) {
|
||||||
ASSERT((it.second.createComputePipelineAsyncCallback != nullptr) ^
|
request->callback(WGPUErrorType_DeviceLost, "Device lost", request->userdata);
|
||||||
(it.second.createRenderPipelineAsyncCallback != nullptr));
|
});
|
||||||
if (it.second.createRenderPipelineAsyncCallback) {
|
|
||||||
it.second.createRenderPipelineAsyncCallback(
|
|
||||||
WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, "Device lost",
|
|
||||||
it.second.userdata);
|
|
||||||
} else {
|
|
||||||
it.second.createComputePipelineAsyncCallback(
|
|
||||||
WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, "Device lost",
|
|
||||||
it.second.userdata);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
mCreatePipelineAsyncRequests.clear();
|
|
||||||
|
|
||||||
for (auto& it : mErrorScopes) {
|
mCreatePipelineAsyncRequests.CloseAll([](CreatePipelineAsyncRequest* request) {
|
||||||
it.second.callback(WGPUErrorType_DeviceLost, "Device lost", it.second.userdata);
|
if (request->createComputePipelineAsyncCallback != nullptr) {
|
||||||
}
|
request->createComputePipelineAsyncCallback(
|
||||||
mErrorScopes.clear();
|
WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, "Device lost",
|
||||||
|
request->userdata);
|
||||||
|
} else {
|
||||||
|
ASSERT(request->createRenderPipelineAsyncCallback != nullptr);
|
||||||
|
request->createRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceLost,
|
||||||
|
nullptr, "Device lost",
|
||||||
|
request->userdata);
|
||||||
|
}
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
std::weak_ptr<bool> Device::GetAliveWeakPtr() {
|
std::weak_ptr<bool> Device::GetAliveWeakPtr() {
|
||||||
|
@ -152,10 +146,7 @@ namespace dawn_wire { namespace client {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t serial = mErrorScopeRequestSerial++;
|
uint64_t serial = mErrorScopes.Add({callback, userdata});
|
||||||
ASSERT(mErrorScopes.find(serial) == mErrorScopes.end());
|
|
||||||
|
|
||||||
mErrorScopes[serial] = {callback, userdata};
|
|
||||||
|
|
||||||
DevicePopErrorScopeCmd cmd;
|
DevicePopErrorScopeCmd cmd;
|
||||||
cmd.deviceId = this->id;
|
cmd.deviceId = this->id;
|
||||||
|
@ -180,14 +171,11 @@ namespace dawn_wire { namespace client {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto requestIt = mErrorScopes.find(requestSerial);
|
ErrorScopeData request;
|
||||||
if (requestIt == mErrorScopes.end()) {
|
if (!mErrorScopes.Acquire(requestSerial, &request)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
ErrorScopeData request = std::move(requestIt->second);
|
|
||||||
|
|
||||||
mErrorScopes.erase(requestIt);
|
|
||||||
request.callback(type, message, request.userdata);
|
request.callback(type, message, request.userdata);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -265,9 +253,6 @@ namespace dawn_wire { namespace client {
|
||||||
"GPU device disconnected", userdata);
|
"GPU device disconnected", userdata);
|
||||||
}
|
}
|
||||||
|
|
||||||
DeviceCreateComputePipelineAsyncCmd cmd;
|
|
||||||
cmd.deviceId = this->id;
|
|
||||||
|
|
||||||
// Copy compute to the deprecated computeStage or visa-versa, depending on which one is
|
// Copy compute to the deprecated computeStage or visa-versa, depending on which one is
|
||||||
// populated, so that serialization doesn't fail.
|
// populated, so that serialization doesn't fail.
|
||||||
// TODO(dawn:800): Remove once computeStage is removed.
|
// TODO(dawn:800): Remove once computeStage is removed.
|
||||||
|
@ -280,35 +265,32 @@ namespace dawn_wire { namespace client {
|
||||||
localDescriptor.compute.entryPoint = localDescriptor.computeStage.entryPoint;
|
localDescriptor.compute.entryPoint = localDescriptor.computeStage.entryPoint;
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.descriptor = &localDescriptor;
|
|
||||||
|
|
||||||
uint64_t serial = mCreatePipelineAsyncRequestSerial++;
|
|
||||||
ASSERT(mCreatePipelineAsyncRequests.find(serial) == mCreatePipelineAsyncRequests.end());
|
|
||||||
cmd.requestSerial = serial;
|
|
||||||
|
|
||||||
auto* allocation = client->ComputePipelineAllocator().New(client);
|
auto* allocation = client->ComputePipelineAllocator().New(client);
|
||||||
|
|
||||||
CreatePipelineAsyncRequest request = {};
|
CreatePipelineAsyncRequest request = {};
|
||||||
request.createComputePipelineAsyncCallback = callback;
|
request.createComputePipelineAsyncCallback = callback;
|
||||||
request.userdata = userdata;
|
request.userdata = userdata;
|
||||||
request.pipelineObjectID = allocation->object->id;
|
request.pipelineObjectID = allocation->object->id;
|
||||||
|
|
||||||
cmd.pipelineObjectHandle = ObjectHandle{allocation->object->id, allocation->generation};
|
uint64_t serial = mCreatePipelineAsyncRequests.Add(std::move(request));
|
||||||
client->SerializeCommand(cmd);
|
|
||||||
|
|
||||||
mCreatePipelineAsyncRequests[serial] = std::move(request);
|
DeviceCreateComputePipelineAsyncCmd cmd;
|
||||||
|
cmd.deviceId = this->id;
|
||||||
|
cmd.descriptor = &localDescriptor;
|
||||||
|
cmd.requestSerial = serial;
|
||||||
|
cmd.pipelineObjectHandle = ObjectHandle{allocation->object->id, allocation->generation};
|
||||||
|
|
||||||
|
client->SerializeCommand(cmd);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Device::OnCreateComputePipelineAsyncCallback(uint64_t requestSerial,
|
bool Device::OnCreateComputePipelineAsyncCallback(uint64_t requestSerial,
|
||||||
WGPUCreatePipelineAsyncStatus status,
|
WGPUCreatePipelineAsyncStatus status,
|
||||||
const char* message) {
|
const char* message) {
|
||||||
const auto& requestIt = mCreatePipelineAsyncRequests.find(requestSerial);
|
CreatePipelineAsyncRequest request;
|
||||||
if (requestIt == mCreatePipelineAsyncRequests.end()) {
|
if (!mCreatePipelineAsyncRequests.Acquire(requestSerial, &request)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
CreatePipelineAsyncRequest request = std::move(requestIt->second);
|
|
||||||
mCreatePipelineAsyncRequests.erase(requestIt);
|
|
||||||
|
|
||||||
auto pipelineAllocation =
|
auto pipelineAllocation =
|
||||||
client->ComputePipelineAllocator().GetObject(request.pipelineObjectID);
|
client->ComputePipelineAllocator().GetObject(request.pipelineObjectID);
|
||||||
|
|
||||||
|
@ -333,37 +315,33 @@ namespace dawn_wire { namespace client {
|
||||||
return callback(WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr,
|
return callback(WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr,
|
||||||
"GPU device disconnected", userdata);
|
"GPU device disconnected", userdata);
|
||||||
}
|
}
|
||||||
DeviceCreateRenderPipelineAsyncCmd cmd;
|
|
||||||
cmd.deviceId = this->id;
|
|
||||||
cmd.descriptor = descriptor;
|
|
||||||
|
|
||||||
uint64_t serial = mCreatePipelineAsyncRequestSerial++;
|
|
||||||
ASSERT(mCreatePipelineAsyncRequests.find(serial) == mCreatePipelineAsyncRequests.end());
|
|
||||||
cmd.requestSerial = serial;
|
|
||||||
|
|
||||||
auto* allocation = client->RenderPipelineAllocator().New(client);
|
auto* allocation = client->RenderPipelineAllocator().New(client);
|
||||||
|
|
||||||
CreatePipelineAsyncRequest request = {};
|
CreatePipelineAsyncRequest request = {};
|
||||||
request.createRenderPipelineAsyncCallback = callback;
|
request.createRenderPipelineAsyncCallback = callback;
|
||||||
request.userdata = userdata;
|
request.userdata = userdata;
|
||||||
request.pipelineObjectID = allocation->object->id;
|
request.pipelineObjectID = allocation->object->id;
|
||||||
|
|
||||||
cmd.pipelineObjectHandle = ObjectHandle(allocation->object->id, allocation->generation);
|
uint64_t serial = mCreatePipelineAsyncRequests.Add(std::move(request));
|
||||||
client->SerializeCommand(cmd);
|
|
||||||
|
|
||||||
mCreatePipelineAsyncRequests[serial] = std::move(request);
|
DeviceCreateRenderPipelineAsyncCmd cmd;
|
||||||
|
cmd.deviceId = this->id;
|
||||||
|
cmd.descriptor = descriptor;
|
||||||
|
cmd.requestSerial = serial;
|
||||||
|
cmd.pipelineObjectHandle = ObjectHandle(allocation->object->id, allocation->generation);
|
||||||
|
|
||||||
|
client->SerializeCommand(cmd);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Device::OnCreateRenderPipelineAsyncCallback(uint64_t requestSerial,
|
bool Device::OnCreateRenderPipelineAsyncCallback(uint64_t requestSerial,
|
||||||
WGPUCreatePipelineAsyncStatus status,
|
WGPUCreatePipelineAsyncStatus status,
|
||||||
const char* message) {
|
const char* message) {
|
||||||
const auto& requestIt = mCreatePipelineAsyncRequests.find(requestSerial);
|
CreatePipelineAsyncRequest request;
|
||||||
if (requestIt == mCreatePipelineAsyncRequests.end()) {
|
if (!mCreatePipelineAsyncRequests.Acquire(requestSerial, &request)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
CreatePipelineAsyncRequest request = std::move(requestIt->second);
|
|
||||||
mCreatePipelineAsyncRequests.erase(requestIt);
|
|
||||||
|
|
||||||
auto pipelineAllocation =
|
auto pipelineAllocation =
|
||||||
client->RenderPipelineAllocator().GetObject(request.pipelineObjectID);
|
client->RenderPipelineAllocator().GetObject(request.pipelineObjectID);
|
||||||
|
|
||||||
|
|
|
@ -21,8 +21,8 @@
|
||||||
#include "dawn_wire/WireCmd_autogen.h"
|
#include "dawn_wire/WireCmd_autogen.h"
|
||||||
#include "dawn_wire/client/ApiObjects_autogen.h"
|
#include "dawn_wire/client/ApiObjects_autogen.h"
|
||||||
#include "dawn_wire/client/ObjectBase.h"
|
#include "dawn_wire/client/ObjectBase.h"
|
||||||
|
#include "dawn_wire/client/RequestTracker.h"
|
||||||
|
|
||||||
#include <map>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
namespace dawn_wire { namespace client {
|
namespace dawn_wire { namespace client {
|
||||||
|
@ -75,8 +75,7 @@ namespace dawn_wire { namespace client {
|
||||||
WGPUErrorCallback callback = nullptr;
|
WGPUErrorCallback callback = nullptr;
|
||||||
void* userdata = nullptr;
|
void* userdata = nullptr;
|
||||||
};
|
};
|
||||||
std::map<uint64_t, ErrorScopeData> mErrorScopes;
|
RequestTracker<ErrorScopeData> mErrorScopes;
|
||||||
uint64_t mErrorScopeRequestSerial = 0;
|
|
||||||
uint64_t mErrorScopeStackSize = 0;
|
uint64_t mErrorScopeStackSize = 0;
|
||||||
|
|
||||||
struct CreatePipelineAsyncRequest {
|
struct CreatePipelineAsyncRequest {
|
||||||
|
@ -85,8 +84,7 @@ namespace dawn_wire { namespace client {
|
||||||
void* userdata = nullptr;
|
void* userdata = nullptr;
|
||||||
ObjectId pipelineObjectID;
|
ObjectId pipelineObjectID;
|
||||||
};
|
};
|
||||||
std::map<uint64_t, CreatePipelineAsyncRequest> mCreatePipelineAsyncRequests;
|
RequestTracker<CreatePipelineAsyncRequest> mCreatePipelineAsyncRequests;
|
||||||
uint64_t mCreatePipelineAsyncRequestSerial = 0;
|
|
||||||
|
|
||||||
WGPUErrorCallback mErrorCallback = nullptr;
|
WGPUErrorCallback mErrorCallback = nullptr;
|
||||||
WGPUDeviceLostCallback mDeviceLostCallback = nullptr;
|
WGPUDeviceLostCallback mDeviceLostCallback = nullptr;
|
||||||
|
|
|
@ -24,17 +24,11 @@ namespace dawn_wire { namespace client {
|
||||||
}
|
}
|
||||||
|
|
||||||
bool Queue::OnWorkDoneCallback(uint64_t requestSerial, WGPUQueueWorkDoneStatus status) {
|
bool Queue::OnWorkDoneCallback(uint64_t requestSerial, WGPUQueueWorkDoneStatus status) {
|
||||||
auto requestIt = mOnWorkDoneRequests.find(requestSerial);
|
OnWorkDoneData request;
|
||||||
if (requestIt == mOnWorkDoneRequests.end()) {
|
if (!mOnWorkDoneRequests.Acquire(requestSerial, &request)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove the request data so that the callback cannot be called again.
|
|
||||||
// ex.) inside the callback: if the queue is deleted (when there are multiple queues),
|
|
||||||
// all callbacks reject.
|
|
||||||
OnWorkDoneData request = std::move(requestIt->second);
|
|
||||||
mOnWorkDoneRequests.erase(requestIt);
|
|
||||||
|
|
||||||
request.callback(status, request.userdata);
|
request.callback(status, request.userdata);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
@ -47,16 +41,13 @@ namespace dawn_wire { namespace client {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t serial = mOnWorkDoneSerial++;
|
uint64_t serial = mOnWorkDoneRequests.Add({callback, userdata});
|
||||||
ASSERT(mOnWorkDoneRequests.find(serial) == mOnWorkDoneRequests.end());
|
|
||||||
|
|
||||||
QueueOnSubmittedWorkDoneCmd cmd;
|
QueueOnSubmittedWorkDoneCmd cmd;
|
||||||
cmd.queueId = this->id;
|
cmd.queueId = this->id;
|
||||||
cmd.signalValue = signalValue;
|
cmd.signalValue = signalValue;
|
||||||
cmd.requestSerial = serial;
|
cmd.requestSerial = serial;
|
||||||
|
|
||||||
mOnWorkDoneRequests[serial] = {callback, userdata};
|
|
||||||
|
|
||||||
client->SerializeCommand(cmd);
|
client->SerializeCommand(cmd);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -97,12 +88,11 @@ namespace dawn_wire { namespace client {
|
||||||
}
|
}
|
||||||
|
|
||||||
void Queue::ClearAllCallbacks(WGPUQueueWorkDoneStatus status) {
|
void Queue::ClearAllCallbacks(WGPUQueueWorkDoneStatus status) {
|
||||||
for (auto& it : mOnWorkDoneRequests) {
|
mOnWorkDoneRequests.CloseAll([status](OnWorkDoneData* request) {
|
||||||
if (it.second.callback) {
|
if (request->callback != nullptr) {
|
||||||
it.second.callback(status, it.second.userdata);
|
request->callback(status, request->userdata);
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
mOnWorkDoneRequests.clear();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}} // namespace dawn_wire::client
|
}} // namespace dawn_wire::client
|
||||||
|
|
|
@ -19,8 +19,7 @@
|
||||||
|
|
||||||
#include "dawn_wire/WireClient.h"
|
#include "dawn_wire/WireClient.h"
|
||||||
#include "dawn_wire/client/ObjectBase.h"
|
#include "dawn_wire/client/ObjectBase.h"
|
||||||
|
#include "dawn_wire/client/RequestTracker.h"
|
||||||
#include <map>
|
|
||||||
|
|
||||||
namespace dawn_wire { namespace client {
|
namespace dawn_wire { namespace client {
|
||||||
|
|
||||||
|
@ -44,15 +43,13 @@ namespace dawn_wire { namespace client {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
void CancelCallbacksForDisconnect() override;
|
void CancelCallbacksForDisconnect() override;
|
||||||
|
|
||||||
void ClearAllCallbacks(WGPUQueueWorkDoneStatus status);
|
void ClearAllCallbacks(WGPUQueueWorkDoneStatus status);
|
||||||
|
|
||||||
struct OnWorkDoneData {
|
struct OnWorkDoneData {
|
||||||
WGPUQueueWorkDoneCallback callback = nullptr;
|
WGPUQueueWorkDoneCallback callback = nullptr;
|
||||||
void* userdata = nullptr;
|
void* userdata = nullptr;
|
||||||
};
|
};
|
||||||
uint64_t mOnWorkDoneSerial = 0;
|
RequestTracker<OnWorkDoneData> mOnWorkDoneRequests;
|
||||||
std::map<uint64_t, OnWorkDoneData> mOnWorkDoneRequests;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
}} // namespace dawn_wire::client
|
}} // namespace dawn_wire::client
|
||||||
|
|
|
@ -0,0 +1,82 @@
|
||||||
|
// Copyright 2021 The Dawn Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef DAWNWIRE_CLIENT_REQUESTTRACKER_H_
|
||||||
|
#define DAWNWIRE_CLIENT_REQUESTTRACKER_H_
|
||||||
|
|
||||||
|
#include "common/Assert.h"
|
||||||
|
#include "common/NonCopyable.h"
|
||||||
|
|
||||||
|
#include <cstdint>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
|
namespace dawn_wire { namespace client {
|
||||||
|
|
||||||
|
class Device;
|
||||||
|
class MemoryTransferService;
|
||||||
|
|
||||||
|
template <typename Request>
|
||||||
|
class RequestTracker : NonCopyable {
|
||||||
|
public:
|
||||||
|
~RequestTracker() {
|
||||||
|
ASSERT(mRequests.empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t Add(Request&& request) {
|
||||||
|
mSerial++;
|
||||||
|
mRequests.emplace(mSerial, request);
|
||||||
|
return mSerial;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Acquire(uint64_t serial, Request* request) {
|
||||||
|
auto it = mRequests.find(serial);
|
||||||
|
if (it == mRequests.end()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
*request = std::move(it->second);
|
||||||
|
mRequests.erase(it);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename CloseFunc>
|
||||||
|
void CloseAll(CloseFunc&& closeFunc) {
|
||||||
|
// Call closeFunc on all requests while handling reentrancy where the callback of some
|
||||||
|
// requests may add some additional requests. We guarantee all callbacks for requests
|
||||||
|
// are called exactly onces, so keep closing new requests if the first batch added more.
|
||||||
|
// It is fine to loop infinitely here if that's what the application makes use do.
|
||||||
|
while (!mRequests.empty()) {
|
||||||
|
// Move mRequests to a local variable so that further reentrant modifications of
|
||||||
|
// mRequests don't invalidate the iterators.
|
||||||
|
auto allRequests = std::move(mRequests);
|
||||||
|
for (auto& it : allRequests) {
|
||||||
|
closeFunc(&it.second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename F>
|
||||||
|
void ForAll(F&& f) {
|
||||||
|
for (auto& it : mRequests) {
|
||||||
|
f(&it.second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
uint64_t mSerial;
|
||||||
|
std::map<uint64_t, Request> mRequests;
|
||||||
|
};
|
||||||
|
|
||||||
|
}} // namespace dawn_wire::client
|
||||||
|
|
||||||
|
#endif // DAWNWIRE_CLIENT_REQUESTTRACKER_H_
|
|
@ -19,15 +19,7 @@
|
||||||
namespace dawn_wire { namespace client {
|
namespace dawn_wire { namespace client {
|
||||||
|
|
||||||
ShaderModule::~ShaderModule() {
|
ShaderModule::~ShaderModule() {
|
||||||
// Callbacks need to be fired in all cases, as they can handle freeing resources. So we call
|
ClearAllCallbacks(WGPUCompilationInfoRequestStatus_Unknown);
|
||||||
// them with "Unknown" status.
|
|
||||||
for (auto& it : mCompilationInfoRequests) {
|
|
||||||
if (it.second.callback) {
|
|
||||||
it.second.callback(WGPUCompilationInfoRequestStatus_Unknown, nullptr,
|
|
||||||
it.second.userdata);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
mCompilationInfoRequests.clear();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void ShaderModule::GetCompilationInfo(WGPUCompilationInfoCallback callback, void* userdata) {
|
void ShaderModule::GetCompilationInfo(WGPUCompilationInfoCallback callback, void* userdata) {
|
||||||
|
@ -36,41 +28,37 @@ namespace dawn_wire { namespace client {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t serial = mCompilationInfoRequestSerial++;
|
uint64_t serial = mCompilationInfoRequests.Add({callback, userdata});
|
||||||
|
|
||||||
ShaderModuleGetCompilationInfoCmd cmd;
|
ShaderModuleGetCompilationInfoCmd cmd;
|
||||||
cmd.shaderModuleId = this->id;
|
cmd.shaderModuleId = this->id;
|
||||||
cmd.requestSerial = serial;
|
cmd.requestSerial = serial;
|
||||||
|
|
||||||
mCompilationInfoRequests[serial] = {callback, userdata};
|
|
||||||
|
|
||||||
client->SerializeCommand(cmd);
|
client->SerializeCommand(cmd);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ShaderModule::GetCompilationInfoCallback(uint64_t requestSerial,
|
bool ShaderModule::GetCompilationInfoCallback(uint64_t requestSerial,
|
||||||
WGPUCompilationInfoRequestStatus status,
|
WGPUCompilationInfoRequestStatus status,
|
||||||
const WGPUCompilationInfo* info) {
|
const WGPUCompilationInfo* info) {
|
||||||
auto requestIt = mCompilationInfoRequests.find(requestSerial);
|
CompilationInfoRequest request;
|
||||||
if (requestIt == mCompilationInfoRequests.end()) {
|
if (!mCompilationInfoRequests.Acquire(requestSerial, &request)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Remove the request data so that the callback cannot be called again.
|
|
||||||
// ex.) inside the callback: if the shader module is deleted, all callbacks reject.
|
|
||||||
CompilationInfoRequest request = std::move(requestIt->second);
|
|
||||||
mCompilationInfoRequests.erase(requestIt);
|
|
||||||
|
|
||||||
request.callback(status, info, request.userdata);
|
request.callback(status, info, request.userdata);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void ShaderModule::CancelCallbacksForDisconnect() {
|
void ShaderModule::CancelCallbacksForDisconnect() {
|
||||||
for (auto& it : mCompilationInfoRequests) {
|
ClearAllCallbacks(WGPUCompilationInfoRequestStatus_DeviceLost);
|
||||||
if (it.second.callback) {
|
}
|
||||||
it.second.callback(WGPUCompilationInfoRequestStatus_DeviceLost, nullptr,
|
|
||||||
it.second.userdata);
|
void ShaderModule::ClearAllCallbacks(WGPUCompilationInfoRequestStatus status) {
|
||||||
|
mCompilationInfoRequests.CloseAll([status](CompilationInfoRequest* request) {
|
||||||
|
if (request->callback != nullptr) {
|
||||||
|
request->callback(status, nullptr, request->userdata);
|
||||||
}
|
}
|
||||||
}
|
});
|
||||||
mCompilationInfoRequests.clear();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}} // namespace dawn_wire::client
|
}} // namespace dawn_wire::client
|
||||||
|
|
|
@ -17,8 +17,8 @@
|
||||||
|
|
||||||
#include <dawn/webgpu.h>
|
#include <dawn/webgpu.h>
|
||||||
|
|
||||||
#include "common/SerialMap.h"
|
|
||||||
#include "dawn_wire/client/ObjectBase.h"
|
#include "dawn_wire/client/ObjectBase.h"
|
||||||
|
#include "dawn_wire/client/RequestTracker.h"
|
||||||
|
|
||||||
namespace dawn_wire { namespace client {
|
namespace dawn_wire { namespace client {
|
||||||
|
|
||||||
|
@ -32,15 +32,15 @@ namespace dawn_wire { namespace client {
|
||||||
WGPUCompilationInfoRequestStatus status,
|
WGPUCompilationInfoRequestStatus status,
|
||||||
const WGPUCompilationInfo* info);
|
const WGPUCompilationInfo* info);
|
||||||
|
|
||||||
void CancelCallbacksForDisconnect() override;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
void CancelCallbacksForDisconnect() override;
|
||||||
|
void ClearAllCallbacks(WGPUCompilationInfoRequestStatus status);
|
||||||
|
|
||||||
struct CompilationInfoRequest {
|
struct CompilationInfoRequest {
|
||||||
WGPUCompilationInfoCallback callback = nullptr;
|
WGPUCompilationInfoCallback callback = nullptr;
|
||||||
void* userdata = nullptr;
|
void* userdata = nullptr;
|
||||||
};
|
};
|
||||||
uint64_t mCompilationInfoRequestSerial = 0;
|
RequestTracker<CompilationInfoRequest> mCompilationInfoRequests;
|
||||||
std::map<uint64_t, CompilationInfoRequest> mCompilationInfoRequests;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
}} // namespace dawn_wire::client
|
}} // namespace dawn_wire::client
|
||||||
|
|
Loading…
Reference in New Issue