From baf8df396c2c95bd44402dfdbea79eec20d4c6e0 Mon Sep 17 00:00:00 2001 From: Corentin Wallez Date: Wed, 1 Sep 2021 16:40:22 +0000 Subject: [PATCH] dawn_wire/client: Add RequestTracker helper This helper helps ensure correct handling of request maps by: - Forcing erasing to happen immediately when acquiring a request. This prevents some cases of iterator invalidation if we later change the container type. - Implements correct closure of all callbacks, including if the callbacks themselves add more callbacks. Bug: dawn:1092 Change-Id: Ia0ba9f050bbf3f0dee846f537910523bebb3bf1b Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/63003 Commit-Queue: Corentin Wallez Auto-Submit: Corentin Wallez Reviewed-by: Austin Eng --- src/dawn_wire/client/Buffer.cpp | 59 +++++------- src/dawn_wire/client/Buffer.h | 7 +- src/dawn_wire/client/Client.h | 1 + src/dawn_wire/client/Device.cpp | 126 +++++++++++--------------- src/dawn_wire/client/Device.h | 8 +- src/dawn_wire/client/Queue.cpp | 24 ++--- src/dawn_wire/client/Queue.h | 7 +- src/dawn_wire/client/RequestTracker.h | 82 +++++++++++++++++ src/dawn_wire/client/ShaderModule.cpp | 38 +++----- src/dawn_wire/client/ShaderModule.h | 10 +- 10 files changed, 190 insertions(+), 172 deletions(-) create mode 100644 src/dawn_wire/client/RequestTracker.h diff --git a/src/dawn_wire/client/Buffer.cpp b/src/dawn_wire/client/Buffer.cpp index 2233c8114e..f27b99ea4e 100644 --- a/src/dawn_wire/client/Buffer.cpp +++ b/src/dawn_wire/client/Buffer.cpp @@ -140,25 +140,20 @@ namespace dawn_wire { namespace client { } Buffer::~Buffer() { - // Callbacks need to be fired in all cases, as they can handle freeing resources - // so we call them with "DestroyedBeforeCallback" status. - for (auto& it : mRequests) { - if (it.second.callback) { - it.second.callback(WGPUBufferMapAsyncStatus_DestroyedBeforeCallback, it.second.userdata); - } - } - mRequests.clear(); - + ClearAllCallbacks(WGPUBufferMapAsyncStatus_DestroyedBeforeCallback); FreeMappedData(); } void Buffer::CancelCallbacksForDisconnect() { - for (auto& it : mRequests) { - if (it.second.callback) { - it.second.callback(WGPUBufferMapAsyncStatus_DeviceLost, it.second.userdata); + ClearAllCallbacks(WGPUBufferMapAsyncStatus_DeviceLost); + } + + void Buffer::ClearAllCallbacks(WGPUBufferMapAsyncStatus status) { + mRequests.CloseAll([status](MapRequestData* request) { + if (request->callback != nullptr) { + request->callback(status, request->userdata); } - } - mRequests.clear(); + }); } void Buffer::MapAsync(WGPUMapModeFlags mode, @@ -177,10 +172,7 @@ namespace dawn_wire { namespace client { // Create the request structure that will hold information while this mapping is // in flight. - uint64_t serial = mRequestSerial++; - ASSERT(mRequests.find(serial) == mRequests.end()); - - Buffer::MapRequestData request = {}; + MapRequestData request = {}; request.callback = callback; request.userdata = userdata; request.offset = offset; @@ -191,6 +183,8 @@ namespace dawn_wire { namespace client { request.type = MapRequestType::Write; } + uint64_t serial = mRequests.Add(std::move(request)); + // Serialize the command to send to the server. BufferMapAsyncCmd cmd; cmd.bufferId = this->id; @@ -200,26 +194,17 @@ namespace dawn_wire { namespace client { cmd.size = size; client->SerializeCommand(cmd); - - // Register this request so that we can retrieve it from its serial when the server - // sends the callback. - mRequests[serial] = std::move(request); } bool Buffer::OnMapAsyncCallback(uint64_t requestSerial, uint32_t status, uint64_t readDataUpdateInfoLength, const uint8_t* readDataUpdateInfo) { - auto requestIt = mRequests.find(requestSerial); - if (requestIt == mRequests.end()) { + MapRequestData request; + if (!mRequests.Acquire(requestSerial, &request)) { return false; } - auto request = std::move(requestIt->second); - // Delete the request before calling the callback otherwise the callback could be fired a - // second time. If, for example, buffer.Unmap() is called inside the callback. - mRequests.erase(requestIt); - auto FailRequest = [&request]() -> bool { if (request.callback != nullptr) { request.callback(WGPUBufferMapAsyncStatus_DeviceLost, request.userdata); @@ -352,11 +337,11 @@ namespace dawn_wire { namespace client { mMapSize = 0; // Tag all mapping requests still in flight as unmapped before callback. - for (auto& it : mRequests) { - if (it.second.clientStatus == WGPUBufferMapAsyncStatus_Success) { - it.second.clientStatus = WGPUBufferMapAsyncStatus_UnmappedBeforeCallback; + mRequests.ForAll([](MapRequestData* request) { + if (request->clientStatus == WGPUBufferMapAsyncStatus_Success) { + request->clientStatus = WGPUBufferMapAsyncStatus_UnmappedBeforeCallback; } - } + }); BufferUnmapCmd cmd; cmd.self = ToAPI(this); @@ -368,11 +353,11 @@ namespace dawn_wire { namespace client { FreeMappedData(); // Tag all mapping requests still in flight as destroyed before callback. - for (auto& it : mRequests) { - if (it.second.clientStatus == WGPUBufferMapAsyncStatus_Success) { - it.second.clientStatus = WGPUBufferMapAsyncStatus_DestroyedBeforeCallback; + mRequests.ForAll([](MapRequestData* request) { + if (request->clientStatus == WGPUBufferMapAsyncStatus_Success) { + request->clientStatus = WGPUBufferMapAsyncStatus_DestroyedBeforeCallback; } - } + }); BufferDestroyCmd cmd; cmd.self = ToAPI(this); diff --git a/src/dawn_wire/client/Buffer.h b/src/dawn_wire/client/Buffer.h index a7d3fabff7..0a24384389 100644 --- a/src/dawn_wire/client/Buffer.h +++ b/src/dawn_wire/client/Buffer.h @@ -19,8 +19,7 @@ #include "dawn_wire/WireClient.h" #include "dawn_wire/client/ObjectBase.h" - -#include +#include "dawn_wire/client/RequestTracker.h" namespace dawn_wire { namespace client { @@ -52,6 +51,7 @@ namespace dawn_wire { namespace client { private: void CancelCallbacksForDisconnect() override; + void ClearAllCallbacks(WGPUBufferMapAsyncStatus status); bool IsMappedForReading() const; bool IsMappedForWriting() const; @@ -86,8 +86,7 @@ namespace dawn_wire { namespace client { MapRequestType type = MapRequestType::None; }; - std::map mRequests; - uint64_t mRequestSerial = 0; + RequestTracker mRequests; uint64_t mSize = 0; // Only one mapped pointer can be active at a time because Unmap clears all the in-flight diff --git a/src/dawn_wire/client/Client.h b/src/dawn_wire/client/Client.h index 3616e37215..fc3758a0d8 100644 --- a/src/dawn_wire/client/Client.h +++ b/src/dawn_wire/client/Client.h @@ -19,6 +19,7 @@ #include #include "common/LinkedList.h" +#include "common/NonCopyable.h" #include "dawn_wire/ChunkedCommandSerializer.h" #include "dawn_wire/WireClient.h" #include "dawn_wire/WireCmd_autogen.h" diff --git a/src/dawn_wire/client/Device.cpp b/src/dawn_wire/client/Device.cpp index 95be206c96..17f98a5fc6 100644 --- a/src/dawn_wire/client/Device.cpp +++ b/src/dawn_wire/client/Device.cpp @@ -48,26 +48,23 @@ namespace dawn_wire { namespace client { } Device::~Device() { - // Fire pending error scopes - auto errorScopes = std::move(mErrorScopes); - for (const auto& it : errorScopes) { - it.second.callback(WGPUErrorType_Unknown, "Device destroyed before callback", - it.second.userdata); - } + mErrorScopes.CloseAll([](ErrorScopeData* request) { + request->callback(WGPUErrorType_Unknown, "Device destroyed before callback", + request->userdata); + }); - auto createPipelineAsyncRequests = std::move(mCreatePipelineAsyncRequests); - for (const auto& it : createPipelineAsyncRequests) { - if (it.second.createComputePipelineAsyncCallback != nullptr) { - it.second.createComputePipelineAsyncCallback( + mCreatePipelineAsyncRequests.CloseAll([](CreatePipelineAsyncRequest* request) { + if (request->createComputePipelineAsyncCallback != nullptr) { + request->createComputePipelineAsyncCallback( WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr, - "Device destroyed before callback", it.second.userdata); + "Device destroyed before callback", request->userdata); } else { - ASSERT(it.second.createRenderPipelineAsyncCallback != nullptr); - it.second.createRenderPipelineAsyncCallback( + ASSERT(request->createRenderPipelineAsyncCallback != nullptr); + request->createRenderPipelineAsyncCallback( WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr, - "Device destroyed before callback", it.second.userdata); + "Device destroyed before callback", request->userdata); } - } + }); } void Device::HandleError(WGPUErrorType errorType, const char* message) { @@ -91,25 +88,22 @@ namespace dawn_wire { namespace client { } void Device::CancelCallbacksForDisconnect() { - for (auto& it : mCreatePipelineAsyncRequests) { - ASSERT((it.second.createComputePipelineAsyncCallback != nullptr) ^ - (it.second.createRenderPipelineAsyncCallback != nullptr)); - if (it.second.createRenderPipelineAsyncCallback) { - it.second.createRenderPipelineAsyncCallback( - WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, "Device lost", - it.second.userdata); - } else { - it.second.createComputePipelineAsyncCallback( - WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, "Device lost", - it.second.userdata); - } - } - mCreatePipelineAsyncRequests.clear(); + mErrorScopes.CloseAll([](ErrorScopeData* request) { + request->callback(WGPUErrorType_DeviceLost, "Device lost", request->userdata); + }); - for (auto& it : mErrorScopes) { - it.second.callback(WGPUErrorType_DeviceLost, "Device lost", it.second.userdata); - } - mErrorScopes.clear(); + mCreatePipelineAsyncRequests.CloseAll([](CreatePipelineAsyncRequest* request) { + if (request->createComputePipelineAsyncCallback != nullptr) { + request->createComputePipelineAsyncCallback( + WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, "Device lost", + request->userdata); + } else { + ASSERT(request->createRenderPipelineAsyncCallback != nullptr); + request->createRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceLost, + nullptr, "Device lost", + request->userdata); + } + }); } std::weak_ptr Device::GetAliveWeakPtr() { @@ -152,10 +146,7 @@ namespace dawn_wire { namespace client { return true; } - uint64_t serial = mErrorScopeRequestSerial++; - ASSERT(mErrorScopes.find(serial) == mErrorScopes.end()); - - mErrorScopes[serial] = {callback, userdata}; + uint64_t serial = mErrorScopes.Add({callback, userdata}); DevicePopErrorScopeCmd cmd; cmd.deviceId = this->id; @@ -180,14 +171,11 @@ namespace dawn_wire { namespace client { return false; } - auto requestIt = mErrorScopes.find(requestSerial); - if (requestIt == mErrorScopes.end()) { + ErrorScopeData request; + if (!mErrorScopes.Acquire(requestSerial, &request)) { return false; } - ErrorScopeData request = std::move(requestIt->second); - - mErrorScopes.erase(requestIt); request.callback(type, message, request.userdata); return true; } @@ -265,9 +253,6 @@ namespace dawn_wire { namespace client { "GPU device disconnected", userdata); } - DeviceCreateComputePipelineAsyncCmd cmd; - cmd.deviceId = this->id; - // Copy compute to the deprecated computeStage or visa-versa, depending on which one is // populated, so that serialization doesn't fail. // TODO(dawn:800): Remove once computeStage is removed. @@ -280,35 +265,32 @@ namespace dawn_wire { namespace client { localDescriptor.compute.entryPoint = localDescriptor.computeStage.entryPoint; } - cmd.descriptor = &localDescriptor; - - uint64_t serial = mCreatePipelineAsyncRequestSerial++; - ASSERT(mCreatePipelineAsyncRequests.find(serial) == mCreatePipelineAsyncRequests.end()); - cmd.requestSerial = serial; - auto* allocation = client->ComputePipelineAllocator().New(client); + CreatePipelineAsyncRequest request = {}; request.createComputePipelineAsyncCallback = callback; request.userdata = userdata; request.pipelineObjectID = allocation->object->id; - cmd.pipelineObjectHandle = ObjectHandle{allocation->object->id, allocation->generation}; - client->SerializeCommand(cmd); + uint64_t serial = mCreatePipelineAsyncRequests.Add(std::move(request)); - mCreatePipelineAsyncRequests[serial] = std::move(request); + DeviceCreateComputePipelineAsyncCmd cmd; + cmd.deviceId = this->id; + cmd.descriptor = &localDescriptor; + cmd.requestSerial = serial; + cmd.pipelineObjectHandle = ObjectHandle{allocation->object->id, allocation->generation}; + + client->SerializeCommand(cmd); } bool Device::OnCreateComputePipelineAsyncCallback(uint64_t requestSerial, WGPUCreatePipelineAsyncStatus status, const char* message) { - const auto& requestIt = mCreatePipelineAsyncRequests.find(requestSerial); - if (requestIt == mCreatePipelineAsyncRequests.end()) { + CreatePipelineAsyncRequest request; + if (!mCreatePipelineAsyncRequests.Acquire(requestSerial, &request)) { return false; } - CreatePipelineAsyncRequest request = std::move(requestIt->second); - mCreatePipelineAsyncRequests.erase(requestIt); - auto pipelineAllocation = client->ComputePipelineAllocator().GetObject(request.pipelineObjectID); @@ -333,37 +315,33 @@ namespace dawn_wire { namespace client { return callback(WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, "GPU device disconnected", userdata); } - DeviceCreateRenderPipelineAsyncCmd cmd; - cmd.deviceId = this->id; - cmd.descriptor = descriptor; - - uint64_t serial = mCreatePipelineAsyncRequestSerial++; - ASSERT(mCreatePipelineAsyncRequests.find(serial) == mCreatePipelineAsyncRequests.end()); - cmd.requestSerial = serial; auto* allocation = client->RenderPipelineAllocator().New(client); + CreatePipelineAsyncRequest request = {}; request.createRenderPipelineAsyncCallback = callback; request.userdata = userdata; request.pipelineObjectID = allocation->object->id; - cmd.pipelineObjectHandle = ObjectHandle(allocation->object->id, allocation->generation); - client->SerializeCommand(cmd); + uint64_t serial = mCreatePipelineAsyncRequests.Add(std::move(request)); - mCreatePipelineAsyncRequests[serial] = std::move(request); + DeviceCreateRenderPipelineAsyncCmd cmd; + cmd.deviceId = this->id; + cmd.descriptor = descriptor; + cmd.requestSerial = serial; + cmd.pipelineObjectHandle = ObjectHandle(allocation->object->id, allocation->generation); + + client->SerializeCommand(cmd); } bool Device::OnCreateRenderPipelineAsyncCallback(uint64_t requestSerial, WGPUCreatePipelineAsyncStatus status, const char* message) { - const auto& requestIt = mCreatePipelineAsyncRequests.find(requestSerial); - if (requestIt == mCreatePipelineAsyncRequests.end()) { + CreatePipelineAsyncRequest request; + if (!mCreatePipelineAsyncRequests.Acquire(requestSerial, &request)) { return false; } - CreatePipelineAsyncRequest request = std::move(requestIt->second); - mCreatePipelineAsyncRequests.erase(requestIt); - auto pipelineAllocation = client->RenderPipelineAllocator().GetObject(request.pipelineObjectID); diff --git a/src/dawn_wire/client/Device.h b/src/dawn_wire/client/Device.h index 0bc2ca30c3..849364fdc3 100644 --- a/src/dawn_wire/client/Device.h +++ b/src/dawn_wire/client/Device.h @@ -21,8 +21,8 @@ #include "dawn_wire/WireCmd_autogen.h" #include "dawn_wire/client/ApiObjects_autogen.h" #include "dawn_wire/client/ObjectBase.h" +#include "dawn_wire/client/RequestTracker.h" -#include #include namespace dawn_wire { namespace client { @@ -75,8 +75,7 @@ namespace dawn_wire { namespace client { WGPUErrorCallback callback = nullptr; void* userdata = nullptr; }; - std::map mErrorScopes; - uint64_t mErrorScopeRequestSerial = 0; + RequestTracker mErrorScopes; uint64_t mErrorScopeStackSize = 0; struct CreatePipelineAsyncRequest { @@ -85,8 +84,7 @@ namespace dawn_wire { namespace client { void* userdata = nullptr; ObjectId pipelineObjectID; }; - std::map mCreatePipelineAsyncRequests; - uint64_t mCreatePipelineAsyncRequestSerial = 0; + RequestTracker mCreatePipelineAsyncRequests; WGPUErrorCallback mErrorCallback = nullptr; WGPUDeviceLostCallback mDeviceLostCallback = nullptr; diff --git a/src/dawn_wire/client/Queue.cpp b/src/dawn_wire/client/Queue.cpp index 1ac8c77819..098ddc5afc 100644 --- a/src/dawn_wire/client/Queue.cpp +++ b/src/dawn_wire/client/Queue.cpp @@ -24,17 +24,11 @@ namespace dawn_wire { namespace client { } bool Queue::OnWorkDoneCallback(uint64_t requestSerial, WGPUQueueWorkDoneStatus status) { - auto requestIt = mOnWorkDoneRequests.find(requestSerial); - if (requestIt == mOnWorkDoneRequests.end()) { + OnWorkDoneData request; + if (!mOnWorkDoneRequests.Acquire(requestSerial, &request)) { return false; } - // Remove the request data so that the callback cannot be called again. - // ex.) inside the callback: if the queue is deleted (when there are multiple queues), - // all callbacks reject. - OnWorkDoneData request = std::move(requestIt->second); - mOnWorkDoneRequests.erase(requestIt); - request.callback(status, request.userdata); return true; } @@ -47,16 +41,13 @@ namespace dawn_wire { namespace client { return; } - uint32_t serial = mOnWorkDoneSerial++; - ASSERT(mOnWorkDoneRequests.find(serial) == mOnWorkDoneRequests.end()); + uint64_t serial = mOnWorkDoneRequests.Add({callback, userdata}); QueueOnSubmittedWorkDoneCmd cmd; cmd.queueId = this->id; cmd.signalValue = signalValue; cmd.requestSerial = serial; - mOnWorkDoneRequests[serial] = {callback, userdata}; - client->SerializeCommand(cmd); } @@ -97,12 +88,11 @@ namespace dawn_wire { namespace client { } void Queue::ClearAllCallbacks(WGPUQueueWorkDoneStatus status) { - for (auto& it : mOnWorkDoneRequests) { - if (it.second.callback) { - it.second.callback(status, it.second.userdata); + mOnWorkDoneRequests.CloseAll([status](OnWorkDoneData* request) { + if (request->callback != nullptr) { + request->callback(status, request->userdata); } - } - mOnWorkDoneRequests.clear(); + }); } }} // namespace dawn_wire::client diff --git a/src/dawn_wire/client/Queue.h b/src/dawn_wire/client/Queue.h index d8e93a3106..901acac2d4 100644 --- a/src/dawn_wire/client/Queue.h +++ b/src/dawn_wire/client/Queue.h @@ -19,8 +19,7 @@ #include "dawn_wire/WireClient.h" #include "dawn_wire/client/ObjectBase.h" - -#include +#include "dawn_wire/client/RequestTracker.h" namespace dawn_wire { namespace client { @@ -44,15 +43,13 @@ namespace dawn_wire { namespace client { private: void CancelCallbacksForDisconnect() override; - void ClearAllCallbacks(WGPUQueueWorkDoneStatus status); struct OnWorkDoneData { WGPUQueueWorkDoneCallback callback = nullptr; void* userdata = nullptr; }; - uint64_t mOnWorkDoneSerial = 0; - std::map mOnWorkDoneRequests; + RequestTracker mOnWorkDoneRequests; }; }} // namespace dawn_wire::client diff --git a/src/dawn_wire/client/RequestTracker.h b/src/dawn_wire/client/RequestTracker.h new file mode 100644 index 0000000000..7ce2d0004f --- /dev/null +++ b/src/dawn_wire/client/RequestTracker.h @@ -0,0 +1,82 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DAWNWIRE_CLIENT_REQUESTTRACKER_H_ +#define DAWNWIRE_CLIENT_REQUESTTRACKER_H_ + +#include "common/Assert.h" +#include "common/NonCopyable.h" + +#include +#include + +namespace dawn_wire { namespace client { + + class Device; + class MemoryTransferService; + + template + class RequestTracker : NonCopyable { + public: + ~RequestTracker() { + ASSERT(mRequests.empty()); + } + + uint64_t Add(Request&& request) { + mSerial++; + mRequests.emplace(mSerial, request); + return mSerial; + } + + bool Acquire(uint64_t serial, Request* request) { + auto it = mRequests.find(serial); + if (it == mRequests.end()) { + return false; + } + *request = std::move(it->second); + mRequests.erase(it); + return true; + } + + template + void CloseAll(CloseFunc&& closeFunc) { + // Call closeFunc on all requests while handling reentrancy where the callback of some + // requests may add some additional requests. We guarantee all callbacks for requests + // are called exactly onces, so keep closing new requests if the first batch added more. + // It is fine to loop infinitely here if that's what the application makes use do. + while (!mRequests.empty()) { + // Move mRequests to a local variable so that further reentrant modifications of + // mRequests don't invalidate the iterators. + auto allRequests = std::move(mRequests); + for (auto& it : allRequests) { + closeFunc(&it.second); + } + } + } + + template + void ForAll(F&& f) { + for (auto& it : mRequests) { + f(&it.second); + } + } + + private: + uint64_t mSerial; + std::map mRequests; + }; + +}} // namespace dawn_wire::client + +#endif // DAWNWIRE_CLIENT_REQUESTTRACKER_H_ diff --git a/src/dawn_wire/client/ShaderModule.cpp b/src/dawn_wire/client/ShaderModule.cpp index fa7945aed0..c28b978c3a 100644 --- a/src/dawn_wire/client/ShaderModule.cpp +++ b/src/dawn_wire/client/ShaderModule.cpp @@ -19,15 +19,7 @@ namespace dawn_wire { namespace client { ShaderModule::~ShaderModule() { - // Callbacks need to be fired in all cases, as they can handle freeing resources. So we call - // them with "Unknown" status. - for (auto& it : mCompilationInfoRequests) { - if (it.second.callback) { - it.second.callback(WGPUCompilationInfoRequestStatus_Unknown, nullptr, - it.second.userdata); - } - } - mCompilationInfoRequests.clear(); + ClearAllCallbacks(WGPUCompilationInfoRequestStatus_Unknown); } void ShaderModule::GetCompilationInfo(WGPUCompilationInfoCallback callback, void* userdata) { @@ -36,41 +28,37 @@ namespace dawn_wire { namespace client { return; } - uint64_t serial = mCompilationInfoRequestSerial++; + uint64_t serial = mCompilationInfoRequests.Add({callback, userdata}); + ShaderModuleGetCompilationInfoCmd cmd; cmd.shaderModuleId = this->id; cmd.requestSerial = serial; - mCompilationInfoRequests[serial] = {callback, userdata}; - client->SerializeCommand(cmd); } bool ShaderModule::GetCompilationInfoCallback(uint64_t requestSerial, WGPUCompilationInfoRequestStatus status, const WGPUCompilationInfo* info) { - auto requestIt = mCompilationInfoRequests.find(requestSerial); - if (requestIt == mCompilationInfoRequests.end()) { + CompilationInfoRequest request; + if (!mCompilationInfoRequests.Acquire(requestSerial, &request)) { return false; } - // Remove the request data so that the callback cannot be called again. - // ex.) inside the callback: if the shader module is deleted, all callbacks reject. - CompilationInfoRequest request = std::move(requestIt->second); - mCompilationInfoRequests.erase(requestIt); - request.callback(status, info, request.userdata); return true; } void ShaderModule::CancelCallbacksForDisconnect() { - for (auto& it : mCompilationInfoRequests) { - if (it.second.callback) { - it.second.callback(WGPUCompilationInfoRequestStatus_DeviceLost, nullptr, - it.second.userdata); + ClearAllCallbacks(WGPUCompilationInfoRequestStatus_DeviceLost); + } + + void ShaderModule::ClearAllCallbacks(WGPUCompilationInfoRequestStatus status) { + mCompilationInfoRequests.CloseAll([status](CompilationInfoRequest* request) { + if (request->callback != nullptr) { + request->callback(status, nullptr, request->userdata); } - } - mCompilationInfoRequests.clear(); + }); } }} // namespace dawn_wire::client diff --git a/src/dawn_wire/client/ShaderModule.h b/src/dawn_wire/client/ShaderModule.h index d7ac55d667..f12a4d0f1b 100644 --- a/src/dawn_wire/client/ShaderModule.h +++ b/src/dawn_wire/client/ShaderModule.h @@ -17,8 +17,8 @@ #include -#include "common/SerialMap.h" #include "dawn_wire/client/ObjectBase.h" +#include "dawn_wire/client/RequestTracker.h" namespace dawn_wire { namespace client { @@ -32,15 +32,15 @@ namespace dawn_wire { namespace client { WGPUCompilationInfoRequestStatus status, const WGPUCompilationInfo* info); - void CancelCallbacksForDisconnect() override; - private: + void CancelCallbacksForDisconnect() override; + void ClearAllCallbacks(WGPUCompilationInfoRequestStatus status); + struct CompilationInfoRequest { WGPUCompilationInfoCallback callback = nullptr; void* userdata = nullptr; }; - uint64_t mCompilationInfoRequestSerial = 0; - std::map mCompilationInfoRequests; + RequestTracker mCompilationInfoRequests; }; }} // namespace dawn_wire::client