dawn_wire/client: Add RequestTracker helper

This helper helps ensure correct handling of request maps by:

 - Forcing erasing to happen immediately when acquiring a request. This
   prevents some cases of iterator invalidation if we later change the
   container type.
 - Implements correct closure of all callbacks, including if the
   callbacks themselves add more callbacks.

Bug: dawn:1092

Change-Id: Ia0ba9f050bbf3f0dee846f537910523bebb3bf1b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/63003
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Auto-Submit: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
This commit is contained in:
Corentin Wallez 2021-09-01 16:40:22 +00:00 committed by Dawn LUCI CQ
parent 4a4a804476
commit baf8df396c
10 changed files with 190 additions and 172 deletions

View File

@ -140,25 +140,20 @@ namespace dawn_wire { namespace client {
}
Buffer::~Buffer() {
// Callbacks need to be fired in all cases, as they can handle freeing resources
// so we call them with "DestroyedBeforeCallback" status.
for (auto& it : mRequests) {
if (it.second.callback) {
it.second.callback(WGPUBufferMapAsyncStatus_DestroyedBeforeCallback, it.second.userdata);
}
}
mRequests.clear();
ClearAllCallbacks(WGPUBufferMapAsyncStatus_DestroyedBeforeCallback);
FreeMappedData();
}
void Buffer::CancelCallbacksForDisconnect() {
for (auto& it : mRequests) {
if (it.second.callback) {
it.second.callback(WGPUBufferMapAsyncStatus_DeviceLost, it.second.userdata);
ClearAllCallbacks(WGPUBufferMapAsyncStatus_DeviceLost);
}
void Buffer::ClearAllCallbacks(WGPUBufferMapAsyncStatus status) {
mRequests.CloseAll([status](MapRequestData* request) {
if (request->callback != nullptr) {
request->callback(status, request->userdata);
}
}
mRequests.clear();
});
}
void Buffer::MapAsync(WGPUMapModeFlags mode,
@ -177,10 +172,7 @@ namespace dawn_wire { namespace client {
// Create the request structure that will hold information while this mapping is
// in flight.
uint64_t serial = mRequestSerial++;
ASSERT(mRequests.find(serial) == mRequests.end());
Buffer::MapRequestData request = {};
MapRequestData request = {};
request.callback = callback;
request.userdata = userdata;
request.offset = offset;
@ -191,6 +183,8 @@ namespace dawn_wire { namespace client {
request.type = MapRequestType::Write;
}
uint64_t serial = mRequests.Add(std::move(request));
// Serialize the command to send to the server.
BufferMapAsyncCmd cmd;
cmd.bufferId = this->id;
@ -200,26 +194,17 @@ namespace dawn_wire { namespace client {
cmd.size = size;
client->SerializeCommand(cmd);
// Register this request so that we can retrieve it from its serial when the server
// sends the callback.
mRequests[serial] = std::move(request);
}
bool Buffer::OnMapAsyncCallback(uint64_t requestSerial,
uint32_t status,
uint64_t readDataUpdateInfoLength,
const uint8_t* readDataUpdateInfo) {
auto requestIt = mRequests.find(requestSerial);
if (requestIt == mRequests.end()) {
MapRequestData request;
if (!mRequests.Acquire(requestSerial, &request)) {
return false;
}
auto request = std::move(requestIt->second);
// Delete the request before calling the callback otherwise the callback could be fired a
// second time. If, for example, buffer.Unmap() is called inside the callback.
mRequests.erase(requestIt);
auto FailRequest = [&request]() -> bool {
if (request.callback != nullptr) {
request.callback(WGPUBufferMapAsyncStatus_DeviceLost, request.userdata);
@ -352,11 +337,11 @@ namespace dawn_wire { namespace client {
mMapSize = 0;
// Tag all mapping requests still in flight as unmapped before callback.
for (auto& it : mRequests) {
if (it.second.clientStatus == WGPUBufferMapAsyncStatus_Success) {
it.second.clientStatus = WGPUBufferMapAsyncStatus_UnmappedBeforeCallback;
mRequests.ForAll([](MapRequestData* request) {
if (request->clientStatus == WGPUBufferMapAsyncStatus_Success) {
request->clientStatus = WGPUBufferMapAsyncStatus_UnmappedBeforeCallback;
}
}
});
BufferUnmapCmd cmd;
cmd.self = ToAPI(this);
@ -368,11 +353,11 @@ namespace dawn_wire { namespace client {
FreeMappedData();
// Tag all mapping requests still in flight as destroyed before callback.
for (auto& it : mRequests) {
if (it.second.clientStatus == WGPUBufferMapAsyncStatus_Success) {
it.second.clientStatus = WGPUBufferMapAsyncStatus_DestroyedBeforeCallback;
mRequests.ForAll([](MapRequestData* request) {
if (request->clientStatus == WGPUBufferMapAsyncStatus_Success) {
request->clientStatus = WGPUBufferMapAsyncStatus_DestroyedBeforeCallback;
}
}
});
BufferDestroyCmd cmd;
cmd.self = ToAPI(this);

View File

@ -19,8 +19,7 @@
#include "dawn_wire/WireClient.h"
#include "dawn_wire/client/ObjectBase.h"
#include <map>
#include "dawn_wire/client/RequestTracker.h"
namespace dawn_wire { namespace client {
@ -52,6 +51,7 @@ namespace dawn_wire { namespace client {
private:
void CancelCallbacksForDisconnect() override;
void ClearAllCallbacks(WGPUBufferMapAsyncStatus status);
bool IsMappedForReading() const;
bool IsMappedForWriting() const;
@ -86,8 +86,7 @@ namespace dawn_wire { namespace client {
MapRequestType type = MapRequestType::None;
};
std::map<uint64_t, MapRequestData> mRequests;
uint64_t mRequestSerial = 0;
RequestTracker<MapRequestData> mRequests;
uint64_t mSize = 0;
// Only one mapped pointer can be active at a time because Unmap clears all the in-flight

View File

@ -19,6 +19,7 @@
#include <dawn_wire/Wire.h>
#include "common/LinkedList.h"
#include "common/NonCopyable.h"
#include "dawn_wire/ChunkedCommandSerializer.h"
#include "dawn_wire/WireClient.h"
#include "dawn_wire/WireCmd_autogen.h"

View File

@ -48,26 +48,23 @@ namespace dawn_wire { namespace client {
}
Device::~Device() {
// Fire pending error scopes
auto errorScopes = std::move(mErrorScopes);
for (const auto& it : errorScopes) {
it.second.callback(WGPUErrorType_Unknown, "Device destroyed before callback",
it.second.userdata);
}
mErrorScopes.CloseAll([](ErrorScopeData* request) {
request->callback(WGPUErrorType_Unknown, "Device destroyed before callback",
request->userdata);
});
auto createPipelineAsyncRequests = std::move(mCreatePipelineAsyncRequests);
for (const auto& it : createPipelineAsyncRequests) {
if (it.second.createComputePipelineAsyncCallback != nullptr) {
it.second.createComputePipelineAsyncCallback(
mCreatePipelineAsyncRequests.CloseAll([](CreatePipelineAsyncRequest* request) {
if (request->createComputePipelineAsyncCallback != nullptr) {
request->createComputePipelineAsyncCallback(
WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr,
"Device destroyed before callback", it.second.userdata);
"Device destroyed before callback", request->userdata);
} else {
ASSERT(it.second.createRenderPipelineAsyncCallback != nullptr);
it.second.createRenderPipelineAsyncCallback(
ASSERT(request->createRenderPipelineAsyncCallback != nullptr);
request->createRenderPipelineAsyncCallback(
WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr,
"Device destroyed before callback", it.second.userdata);
"Device destroyed before callback", request->userdata);
}
}
});
}
void Device::HandleError(WGPUErrorType errorType, const char* message) {
@ -91,25 +88,22 @@ namespace dawn_wire { namespace client {
}
void Device::CancelCallbacksForDisconnect() {
for (auto& it : mCreatePipelineAsyncRequests) {
ASSERT((it.second.createComputePipelineAsyncCallback != nullptr) ^
(it.second.createRenderPipelineAsyncCallback != nullptr));
if (it.second.createRenderPipelineAsyncCallback) {
it.second.createRenderPipelineAsyncCallback(
WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, "Device lost",
it.second.userdata);
} else {
it.second.createComputePipelineAsyncCallback(
WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, "Device lost",
it.second.userdata);
}
}
mCreatePipelineAsyncRequests.clear();
mErrorScopes.CloseAll([](ErrorScopeData* request) {
request->callback(WGPUErrorType_DeviceLost, "Device lost", request->userdata);
});
for (auto& it : mErrorScopes) {
it.second.callback(WGPUErrorType_DeviceLost, "Device lost", it.second.userdata);
}
mErrorScopes.clear();
mCreatePipelineAsyncRequests.CloseAll([](CreatePipelineAsyncRequest* request) {
if (request->createComputePipelineAsyncCallback != nullptr) {
request->createComputePipelineAsyncCallback(
WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, "Device lost",
request->userdata);
} else {
ASSERT(request->createRenderPipelineAsyncCallback != nullptr);
request->createRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceLost,
nullptr, "Device lost",
request->userdata);
}
});
}
std::weak_ptr<bool> Device::GetAliveWeakPtr() {
@ -152,10 +146,7 @@ namespace dawn_wire { namespace client {
return true;
}
uint64_t serial = mErrorScopeRequestSerial++;
ASSERT(mErrorScopes.find(serial) == mErrorScopes.end());
mErrorScopes[serial] = {callback, userdata};
uint64_t serial = mErrorScopes.Add({callback, userdata});
DevicePopErrorScopeCmd cmd;
cmd.deviceId = this->id;
@ -180,14 +171,11 @@ namespace dawn_wire { namespace client {
return false;
}
auto requestIt = mErrorScopes.find(requestSerial);
if (requestIt == mErrorScopes.end()) {
ErrorScopeData request;
if (!mErrorScopes.Acquire(requestSerial, &request)) {
return false;
}
ErrorScopeData request = std::move(requestIt->second);
mErrorScopes.erase(requestIt);
request.callback(type, message, request.userdata);
return true;
}
@ -265,9 +253,6 @@ namespace dawn_wire { namespace client {
"GPU device disconnected", userdata);
}
DeviceCreateComputePipelineAsyncCmd cmd;
cmd.deviceId = this->id;
// Copy compute to the deprecated computeStage or visa-versa, depending on which one is
// populated, so that serialization doesn't fail.
// TODO(dawn:800): Remove once computeStage is removed.
@ -280,35 +265,32 @@ namespace dawn_wire { namespace client {
localDescriptor.compute.entryPoint = localDescriptor.computeStage.entryPoint;
}
cmd.descriptor = &localDescriptor;
uint64_t serial = mCreatePipelineAsyncRequestSerial++;
ASSERT(mCreatePipelineAsyncRequests.find(serial) == mCreatePipelineAsyncRequests.end());
cmd.requestSerial = serial;
auto* allocation = client->ComputePipelineAllocator().New(client);
CreatePipelineAsyncRequest request = {};
request.createComputePipelineAsyncCallback = callback;
request.userdata = userdata;
request.pipelineObjectID = allocation->object->id;
cmd.pipelineObjectHandle = ObjectHandle{allocation->object->id, allocation->generation};
client->SerializeCommand(cmd);
uint64_t serial = mCreatePipelineAsyncRequests.Add(std::move(request));
mCreatePipelineAsyncRequests[serial] = std::move(request);
DeviceCreateComputePipelineAsyncCmd cmd;
cmd.deviceId = this->id;
cmd.descriptor = &localDescriptor;
cmd.requestSerial = serial;
cmd.pipelineObjectHandle = ObjectHandle{allocation->object->id, allocation->generation};
client->SerializeCommand(cmd);
}
bool Device::OnCreateComputePipelineAsyncCallback(uint64_t requestSerial,
WGPUCreatePipelineAsyncStatus status,
const char* message) {
const auto& requestIt = mCreatePipelineAsyncRequests.find(requestSerial);
if (requestIt == mCreatePipelineAsyncRequests.end()) {
CreatePipelineAsyncRequest request;
if (!mCreatePipelineAsyncRequests.Acquire(requestSerial, &request)) {
return false;
}
CreatePipelineAsyncRequest request = std::move(requestIt->second);
mCreatePipelineAsyncRequests.erase(requestIt);
auto pipelineAllocation =
client->ComputePipelineAllocator().GetObject(request.pipelineObjectID);
@ -333,37 +315,33 @@ namespace dawn_wire { namespace client {
return callback(WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr,
"GPU device disconnected", userdata);
}
DeviceCreateRenderPipelineAsyncCmd cmd;
cmd.deviceId = this->id;
cmd.descriptor = descriptor;
uint64_t serial = mCreatePipelineAsyncRequestSerial++;
ASSERT(mCreatePipelineAsyncRequests.find(serial) == mCreatePipelineAsyncRequests.end());
cmd.requestSerial = serial;
auto* allocation = client->RenderPipelineAllocator().New(client);
CreatePipelineAsyncRequest request = {};
request.createRenderPipelineAsyncCallback = callback;
request.userdata = userdata;
request.pipelineObjectID = allocation->object->id;
cmd.pipelineObjectHandle = ObjectHandle(allocation->object->id, allocation->generation);
client->SerializeCommand(cmd);
uint64_t serial = mCreatePipelineAsyncRequests.Add(std::move(request));
mCreatePipelineAsyncRequests[serial] = std::move(request);
DeviceCreateRenderPipelineAsyncCmd cmd;
cmd.deviceId = this->id;
cmd.descriptor = descriptor;
cmd.requestSerial = serial;
cmd.pipelineObjectHandle = ObjectHandle(allocation->object->id, allocation->generation);
client->SerializeCommand(cmd);
}
bool Device::OnCreateRenderPipelineAsyncCallback(uint64_t requestSerial,
WGPUCreatePipelineAsyncStatus status,
const char* message) {
const auto& requestIt = mCreatePipelineAsyncRequests.find(requestSerial);
if (requestIt == mCreatePipelineAsyncRequests.end()) {
CreatePipelineAsyncRequest request;
if (!mCreatePipelineAsyncRequests.Acquire(requestSerial, &request)) {
return false;
}
CreatePipelineAsyncRequest request = std::move(requestIt->second);
mCreatePipelineAsyncRequests.erase(requestIt);
auto pipelineAllocation =
client->RenderPipelineAllocator().GetObject(request.pipelineObjectID);

View File

@ -21,8 +21,8 @@
#include "dawn_wire/WireCmd_autogen.h"
#include "dawn_wire/client/ApiObjects_autogen.h"
#include "dawn_wire/client/ObjectBase.h"
#include "dawn_wire/client/RequestTracker.h"
#include <map>
#include <memory>
namespace dawn_wire { namespace client {
@ -75,8 +75,7 @@ namespace dawn_wire { namespace client {
WGPUErrorCallback callback = nullptr;
void* userdata = nullptr;
};
std::map<uint64_t, ErrorScopeData> mErrorScopes;
uint64_t mErrorScopeRequestSerial = 0;
RequestTracker<ErrorScopeData> mErrorScopes;
uint64_t mErrorScopeStackSize = 0;
struct CreatePipelineAsyncRequest {
@ -85,8 +84,7 @@ namespace dawn_wire { namespace client {
void* userdata = nullptr;
ObjectId pipelineObjectID;
};
std::map<uint64_t, CreatePipelineAsyncRequest> mCreatePipelineAsyncRequests;
uint64_t mCreatePipelineAsyncRequestSerial = 0;
RequestTracker<CreatePipelineAsyncRequest> mCreatePipelineAsyncRequests;
WGPUErrorCallback mErrorCallback = nullptr;
WGPUDeviceLostCallback mDeviceLostCallback = nullptr;

View File

@ -24,17 +24,11 @@ namespace dawn_wire { namespace client {
}
bool Queue::OnWorkDoneCallback(uint64_t requestSerial, WGPUQueueWorkDoneStatus status) {
auto requestIt = mOnWorkDoneRequests.find(requestSerial);
if (requestIt == mOnWorkDoneRequests.end()) {
OnWorkDoneData request;
if (!mOnWorkDoneRequests.Acquire(requestSerial, &request)) {
return false;
}
// Remove the request data so that the callback cannot be called again.
// ex.) inside the callback: if the queue is deleted (when there are multiple queues),
// all callbacks reject.
OnWorkDoneData request = std::move(requestIt->second);
mOnWorkDoneRequests.erase(requestIt);
request.callback(status, request.userdata);
return true;
}
@ -47,16 +41,13 @@ namespace dawn_wire { namespace client {
return;
}
uint32_t serial = mOnWorkDoneSerial++;
ASSERT(mOnWorkDoneRequests.find(serial) == mOnWorkDoneRequests.end());
uint64_t serial = mOnWorkDoneRequests.Add({callback, userdata});
QueueOnSubmittedWorkDoneCmd cmd;
cmd.queueId = this->id;
cmd.signalValue = signalValue;
cmd.requestSerial = serial;
mOnWorkDoneRequests[serial] = {callback, userdata};
client->SerializeCommand(cmd);
}
@ -97,12 +88,11 @@ namespace dawn_wire { namespace client {
}
void Queue::ClearAllCallbacks(WGPUQueueWorkDoneStatus status) {
for (auto& it : mOnWorkDoneRequests) {
if (it.second.callback) {
it.second.callback(status, it.second.userdata);
mOnWorkDoneRequests.CloseAll([status](OnWorkDoneData* request) {
if (request->callback != nullptr) {
request->callback(status, request->userdata);
}
}
mOnWorkDoneRequests.clear();
});
}
}} // namespace dawn_wire::client

View File

@ -19,8 +19,7 @@
#include "dawn_wire/WireClient.h"
#include "dawn_wire/client/ObjectBase.h"
#include <map>
#include "dawn_wire/client/RequestTracker.h"
namespace dawn_wire { namespace client {
@ -44,15 +43,13 @@ namespace dawn_wire { namespace client {
private:
void CancelCallbacksForDisconnect() override;
void ClearAllCallbacks(WGPUQueueWorkDoneStatus status);
struct OnWorkDoneData {
WGPUQueueWorkDoneCallback callback = nullptr;
void* userdata = nullptr;
};
uint64_t mOnWorkDoneSerial = 0;
std::map<uint64_t, OnWorkDoneData> mOnWorkDoneRequests;
RequestTracker<OnWorkDoneData> mOnWorkDoneRequests;
};
}} // namespace dawn_wire::client

View File

@ -0,0 +1,82 @@
// Copyright 2021 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DAWNWIRE_CLIENT_REQUESTTRACKER_H_
#define DAWNWIRE_CLIENT_REQUESTTRACKER_H_
#include "common/Assert.h"
#include "common/NonCopyable.h"
#include <cstdint>
#include <map>
namespace dawn_wire { namespace client {
class Device;
class MemoryTransferService;
template <typename Request>
class RequestTracker : NonCopyable {
public:
~RequestTracker() {
ASSERT(mRequests.empty());
}
uint64_t Add(Request&& request) {
mSerial++;
mRequests.emplace(mSerial, request);
return mSerial;
}
bool Acquire(uint64_t serial, Request* request) {
auto it = mRequests.find(serial);
if (it == mRequests.end()) {
return false;
}
*request = std::move(it->second);
mRequests.erase(it);
return true;
}
template <typename CloseFunc>
void CloseAll(CloseFunc&& closeFunc) {
// Call closeFunc on all requests while handling reentrancy where the callback of some
// requests may add some additional requests. We guarantee all callbacks for requests
// are called exactly onces, so keep closing new requests if the first batch added more.
// It is fine to loop infinitely here if that's what the application makes use do.
while (!mRequests.empty()) {
// Move mRequests to a local variable so that further reentrant modifications of
// mRequests don't invalidate the iterators.
auto allRequests = std::move(mRequests);
for (auto& it : allRequests) {
closeFunc(&it.second);
}
}
}
template <typename F>
void ForAll(F&& f) {
for (auto& it : mRequests) {
f(&it.second);
}
}
private:
uint64_t mSerial;
std::map<uint64_t, Request> mRequests;
};
}} // namespace dawn_wire::client
#endif // DAWNWIRE_CLIENT_REQUESTTRACKER_H_

View File

@ -19,15 +19,7 @@
namespace dawn_wire { namespace client {
ShaderModule::~ShaderModule() {
// Callbacks need to be fired in all cases, as they can handle freeing resources. So we call
// them with "Unknown" status.
for (auto& it : mCompilationInfoRequests) {
if (it.second.callback) {
it.second.callback(WGPUCompilationInfoRequestStatus_Unknown, nullptr,
it.second.userdata);
}
}
mCompilationInfoRequests.clear();
ClearAllCallbacks(WGPUCompilationInfoRequestStatus_Unknown);
}
void ShaderModule::GetCompilationInfo(WGPUCompilationInfoCallback callback, void* userdata) {
@ -36,41 +28,37 @@ namespace dawn_wire { namespace client {
return;
}
uint64_t serial = mCompilationInfoRequestSerial++;
uint64_t serial = mCompilationInfoRequests.Add({callback, userdata});
ShaderModuleGetCompilationInfoCmd cmd;
cmd.shaderModuleId = this->id;
cmd.requestSerial = serial;
mCompilationInfoRequests[serial] = {callback, userdata};
client->SerializeCommand(cmd);
}
bool ShaderModule::GetCompilationInfoCallback(uint64_t requestSerial,
WGPUCompilationInfoRequestStatus status,
const WGPUCompilationInfo* info) {
auto requestIt = mCompilationInfoRequests.find(requestSerial);
if (requestIt == mCompilationInfoRequests.end()) {
CompilationInfoRequest request;
if (!mCompilationInfoRequests.Acquire(requestSerial, &request)) {
return false;
}
// Remove the request data so that the callback cannot be called again.
// ex.) inside the callback: if the shader module is deleted, all callbacks reject.
CompilationInfoRequest request = std::move(requestIt->second);
mCompilationInfoRequests.erase(requestIt);
request.callback(status, info, request.userdata);
return true;
}
void ShaderModule::CancelCallbacksForDisconnect() {
for (auto& it : mCompilationInfoRequests) {
if (it.second.callback) {
it.second.callback(WGPUCompilationInfoRequestStatus_DeviceLost, nullptr,
it.second.userdata);
ClearAllCallbacks(WGPUCompilationInfoRequestStatus_DeviceLost);
}
void ShaderModule::ClearAllCallbacks(WGPUCompilationInfoRequestStatus status) {
mCompilationInfoRequests.CloseAll([status](CompilationInfoRequest* request) {
if (request->callback != nullptr) {
request->callback(status, nullptr, request->userdata);
}
}
mCompilationInfoRequests.clear();
});
}
}} // namespace dawn_wire::client

View File

@ -17,8 +17,8 @@
#include <dawn/webgpu.h>
#include "common/SerialMap.h"
#include "dawn_wire/client/ObjectBase.h"
#include "dawn_wire/client/RequestTracker.h"
namespace dawn_wire { namespace client {
@ -32,15 +32,15 @@ namespace dawn_wire { namespace client {
WGPUCompilationInfoRequestStatus status,
const WGPUCompilationInfo* info);
void CancelCallbacksForDisconnect() override;
private:
void CancelCallbacksForDisconnect() override;
void ClearAllCallbacks(WGPUCompilationInfoRequestStatus status);
struct CompilationInfoRequest {
WGPUCompilationInfoCallback callback = nullptr;
void* userdata = nullptr;
};
uint64_t mCompilationInfoRequestSerial = 0;
std::map<uint64_t, CompilationInfoRequest> mCompilationInfoRequests;
RequestTracker<CompilationInfoRequest> mCompilationInfoRequests;
};
}} // namespace dawn_wire::client