diff --git a/src/dawn_native/BUILD.gn b/src/dawn_native/BUILD.gn index 711fa8c80a..8e4f0c4ddf 100644 --- a/src/dawn_native/BUILD.gn +++ b/src/dawn_native/BUILD.gn @@ -184,6 +184,8 @@ source_set("dawn_native_sources") { "Buffer.h", "CachedObject.cpp", "CachedObject.h", + "CallbackTaskManager.cpp", + "CallbackTaskManager.h", "CommandAllocator.cpp", "CommandAllocator.h", "CommandBuffer.cpp", @@ -204,8 +206,8 @@ source_set("dawn_native_sources") { "ComputePipeline.h", "CopyTextureForBrowserHelper.cpp", "CopyTextureForBrowserHelper.h", - "CreatePipelineAsyncTracker.cpp", - "CreatePipelineAsyncTracker.h", + "CreatePipelineAsyncTask.cpp", + "CreatePipelineAsyncTask.h", "Device.cpp", "Device.h", "DynamicUploader.cpp", diff --git a/src/dawn_native/CMakeLists.txt b/src/dawn_native/CMakeLists.txt index ba6cec9d32..b0d470b6ed 100644 --- a/src/dawn_native/CMakeLists.txt +++ b/src/dawn_native/CMakeLists.txt @@ -50,6 +50,8 @@ target_sources(dawn_native PRIVATE "Buffer.h" "CachedObject.cpp" "CachedObject.h" + "CallbackTaskManager.cpp" + "CallbackTaskManager.h" "CommandAllocator.cpp" "CommandAllocator.h" "CommandBuffer.cpp" @@ -70,8 +72,8 @@ target_sources(dawn_native PRIVATE "ComputePipeline.h" "CopyTextureForBrowserHelper.cpp" "CopyTextureForBrowserHelper.h" - "CreatePipelineAsyncTracker.cpp" - "CreatePipelineAsyncTracker.h" + "CreatePipelineAsyncTask.cpp" + "CreatePipelineAsyncTask.h" "Device.cpp" "Device.h" "DynamicUploader.cpp" diff --git a/src/dawn_native/CallbackTaskManager.cpp b/src/dawn_native/CallbackTaskManager.cpp new file mode 100644 index 0000000000..1c9106c261 --- /dev/null +++ b/src/dawn_native/CallbackTaskManager.cpp @@ -0,0 +1,37 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dawn_native/CallbackTaskManager.h" + +namespace dawn_native { + + bool CallbackTaskManager::IsEmpty() { + std::lock_guard lock(mCallbackTaskQueueMutex); + return mCallbackTaskQueue.empty(); + } + + std::vector> CallbackTaskManager::AcquireCallbackTasks() { + std::lock_guard lock(mCallbackTaskQueueMutex); + + std::vector> allTasks; + allTasks.swap(mCallbackTaskQueue); + return allTasks; + } + + void CallbackTaskManager::AddCallbackTask(std::unique_ptr callbackTask) { + std::lock_guard lock(mCallbackTaskQueueMutex); + mCallbackTaskQueue.push_back(std::move(callbackTask)); + } + +} // namespace dawn_native diff --git a/src/dawn_native/CallbackTaskManager.h b/src/dawn_native/CallbackTaskManager.h new file mode 100644 index 0000000000..1be0eb22b0 --- /dev/null +++ b/src/dawn_native/CallbackTaskManager.h @@ -0,0 +1,47 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DAWNNATIVE_CALLBACK_TASK_MANAGER_H_ +#define DAWNNATIVE_CALLBACK_TASK_MANAGER_H_ + +#include +#include +#include + +namespace dawn_native { + + class CallbackTaskManager; + + struct CallbackTask { + public: + virtual ~CallbackTask() = default; + virtual void Finish() = 0; + virtual void HandleShutDown() = 0; + virtual void HandleDeviceLoss() = 0; + }; + + class CallbackTaskManager { + public: + void AddCallbackTask(std::unique_ptr callbackTask); + bool IsEmpty(); + std::vector> AcquireCallbackTasks(); + + private: + std::mutex mCallbackTaskQueueMutex; + std::vector> mCallbackTaskQueue; + }; + +} // namespace dawn_native + +#endif diff --git a/src/dawn_native/CreatePipelineAsyncTracker.cpp b/src/dawn_native/CreatePipelineAsyncTask.cpp similarity index 55% rename from src/dawn_native/CreatePipelineAsyncTracker.cpp rename to src/dawn_native/CreatePipelineAsyncTask.cpp index 23b8310c43..b6a32b12e7 100644 --- a/src/dawn_native/CreatePipelineAsyncTracker.cpp +++ b/src/dawn_native/CreatePipelineAsyncTask.cpp @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "dawn_native/CreatePipelineAsyncTracker.h" +#include "dawn_native/CreatePipelineAsyncTask.h" #include "dawn_native/ComputePipeline.h" #include "dawn_native/Device.h" @@ -20,25 +20,23 @@ namespace dawn_native { - CreatePipelineAsyncTaskBase::CreatePipelineAsyncTaskBase(std::string errorMessage, - void* userdata) + CreatePipelineAsyncCallbackTaskBase::CreatePipelineAsyncCallbackTaskBase( + std::string errorMessage, + void* userdata) : mErrorMessage(errorMessage), mUserData(userdata) { } - CreatePipelineAsyncTaskBase::~CreatePipelineAsyncTaskBase() { - } - - CreateComputePipelineAsyncTask::CreateComputePipelineAsyncTask( + CreateComputePipelineAsyncCallbackTask::CreateComputePipelineAsyncCallbackTask( Ref pipeline, std::string errorMessage, WGPUCreateComputePipelineAsyncCallback callback, void* userdata) - : CreatePipelineAsyncTaskBase(errorMessage, userdata), + : CreatePipelineAsyncCallbackTaskBase(errorMessage, userdata), mPipeline(std::move(pipeline)), mCreateComputePipelineAsyncCallback(callback) { } - void CreateComputePipelineAsyncTask::Finish() { + void CreateComputePipelineAsyncCallbackTask::Finish() { ASSERT(mCreateComputePipelineAsyncCallback != nullptr); if (mPipeline.Get() != nullptr) { @@ -51,31 +49,31 @@ namespace dawn_native { } } - void CreateComputePipelineAsyncTask::HandleShutDown() { + void CreateComputePipelineAsyncCallbackTask::HandleShutDown() { ASSERT(mCreateComputePipelineAsyncCallback != nullptr); mCreateComputePipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr, "Device destroyed before callback", mUserData); } - void CreateComputePipelineAsyncTask::HandleDeviceLoss() { + void CreateComputePipelineAsyncCallbackTask::HandleDeviceLoss() { ASSERT(mCreateComputePipelineAsyncCallback != nullptr); mCreateComputePipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, "Device lost before callback", mUserData); } - CreateRenderPipelineAsyncTask::CreateRenderPipelineAsyncTask( + CreateRenderPipelineAsyncCallbackTask::CreateRenderPipelineAsyncCallbackTask( Ref pipeline, std::string errorMessage, WGPUCreateRenderPipelineAsyncCallback callback, void* userdata) - : CreatePipelineAsyncTaskBase(errorMessage, userdata), + : CreatePipelineAsyncCallbackTaskBase(errorMessage, userdata), mPipeline(std::move(pipeline)), mCreateRenderPipelineAsyncCallback(callback) { } - void CreateRenderPipelineAsyncTask::Finish() { + void CreateRenderPipelineAsyncCallbackTask::Finish() { ASSERT(mCreateRenderPipelineAsyncCallback != nullptr); if (mPipeline.Get() != nullptr) { @@ -88,62 +86,18 @@ namespace dawn_native { } } - void CreateRenderPipelineAsyncTask::HandleShutDown() { + void CreateRenderPipelineAsyncCallbackTask::HandleShutDown() { ASSERT(mCreateRenderPipelineAsyncCallback != nullptr); mCreateRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr, "Device destroyed before callback", mUserData); } - void CreateRenderPipelineAsyncTask::HandleDeviceLoss() { + void CreateRenderPipelineAsyncCallbackTask::HandleDeviceLoss() { ASSERT(mCreateRenderPipelineAsyncCallback != nullptr); mCreateRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, "Device lost before callback", mUserData); } - CreatePipelineAsyncTracker::CreatePipelineAsyncTracker(DeviceBase* device) : mDevice(device) { - } - - CreatePipelineAsyncTracker::~CreatePipelineAsyncTracker() { - ASSERT(mCreatePipelineAsyncTasksInFlight.Empty()); - } - - void CreatePipelineAsyncTracker::TrackTask(std::unique_ptr task, - ExecutionSerial serial) { - mCreatePipelineAsyncTasksInFlight.Enqueue(std::move(task), serial); - mDevice->AddFutureSerial(serial); - } - - void CreatePipelineAsyncTracker::Tick(ExecutionSerial finishedSerial) { - // If a user calls Queue::Submit inside Create*PipelineAsync, then the device will be - // ticked, which in turns ticks the tracker, causing reentrance here. To prevent the - // reentrant call from invalidating mCreatePipelineAsyncTasksInFlight while in use by the - // first call, we remove the tasks to finish from the queue, update - // mCreatePipelineAsyncTasksInFlight, then run the callbacks. - std::vector> tasks; - for (auto& task : mCreatePipelineAsyncTasksInFlight.IterateUpTo(finishedSerial)) { - tasks.push_back(std::move(task)); - } - mCreatePipelineAsyncTasksInFlight.ClearUpTo(finishedSerial); - - for (auto& task : tasks) { - task->Finish(); - } - } - - void CreatePipelineAsyncTracker::ClearForShutDown() { - for (auto& task : mCreatePipelineAsyncTasksInFlight.IterateAll()) { - task->HandleShutDown(); - } - mCreatePipelineAsyncTasksInFlight.Clear(); - } - - void CreatePipelineAsyncTracker::ClearForDeviceLoss() { - for (auto& task : mCreatePipelineAsyncTasksInFlight.IterateAll()) { - task->HandleDeviceLoss(); - } - mCreatePipelineAsyncTasksInFlight.Clear(); - } - } // namespace dawn_native diff --git a/src/dawn_native/CreatePipelineAsyncTask.h b/src/dawn_native/CreatePipelineAsyncTask.h new file mode 100644 index 0000000000..9cddfa2e34 --- /dev/null +++ b/src/dawn_native/CreatePipelineAsyncTask.h @@ -0,0 +1,68 @@ +// Copyright 2020 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DAWNNATIVE_CREATEPIPELINEASYNCTASK_H_ +#define DAWNNATIVE_CREATEPIPELINEASYNCTASK_H_ + +#include "common/RefCounted.h" +#include "dawn/webgpu.h" +#include "dawn_native/CallbackTaskManager.h" + +namespace dawn_native { + + class ComputePipelineBase; + class DeviceBase; + class RenderPipelineBase; + + struct CreatePipelineAsyncCallbackTaskBase : CallbackTask { + CreatePipelineAsyncCallbackTaskBase(std::string errorMessage, void* userData); + + protected: + std::string mErrorMessage; + void* mUserData; + }; + + struct CreateComputePipelineAsyncCallbackTask final : CreatePipelineAsyncCallbackTaskBase { + CreateComputePipelineAsyncCallbackTask(Ref pipeline, + std::string errorMessage, + WGPUCreateComputePipelineAsyncCallback callback, + void* userdata); + + void Finish() final; + void HandleShutDown() final; + void HandleDeviceLoss() final; + + private: + Ref mPipeline; + WGPUCreateComputePipelineAsyncCallback mCreateComputePipelineAsyncCallback; + }; + + struct CreateRenderPipelineAsyncCallbackTask final : CreatePipelineAsyncCallbackTaskBase { + CreateRenderPipelineAsyncCallbackTask(Ref pipeline, + std::string errorMessage, + WGPUCreateRenderPipelineAsyncCallback callback, + void* userdata); + + void Finish() final; + void HandleShutDown() final; + void HandleDeviceLoss() final; + + private: + Ref mPipeline; + WGPUCreateRenderPipelineAsyncCallback mCreateRenderPipelineAsyncCallback; + }; + +} // namespace dawn_native + +#endif // DAWNNATIVE_CREATEPIPELINEASYNCTASK_H_ diff --git a/src/dawn_native/CreatePipelineAsyncTracker.h b/src/dawn_native/CreatePipelineAsyncTracker.h deleted file mode 100644 index 738d71930f..0000000000 --- a/src/dawn_native/CreatePipelineAsyncTracker.h +++ /dev/null @@ -1,93 +0,0 @@ -// Copyright 2020 The Dawn Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef DAWNNATIVE_CREATEPIPELINEASYNCTRACKER_H_ -#define DAWNNATIVE_CREATEPIPELINEASYNCTRACKER_H_ - -#include "common/RefCounted.h" -#include "common/SerialQueue.h" -#include "dawn/webgpu.h" -#include "dawn_native/IntegerTypes.h" - -#include -#include - -namespace dawn_native { - - class ComputePipelineBase; - class DeviceBase; - class RenderPipelineBase; - - struct CreatePipelineAsyncTaskBase { - CreatePipelineAsyncTaskBase(std::string errorMessage, void* userData); - virtual ~CreatePipelineAsyncTaskBase(); - - virtual void Finish() = 0; - virtual void HandleShutDown() = 0; - virtual void HandleDeviceLoss() = 0; - - protected: - std::string mErrorMessage; - void* mUserData; - }; - - struct CreateComputePipelineAsyncTask final : public CreatePipelineAsyncTaskBase { - CreateComputePipelineAsyncTask(Ref pipeline, - std::string errorMessage, - WGPUCreateComputePipelineAsyncCallback callback, - void* userdata); - - void Finish() final; - void HandleShutDown() final; - void HandleDeviceLoss() final; - - private: - Ref mPipeline; - WGPUCreateComputePipelineAsyncCallback mCreateComputePipelineAsyncCallback; - }; - - struct CreateRenderPipelineAsyncTask final : public CreatePipelineAsyncTaskBase { - CreateRenderPipelineAsyncTask(Ref pipeline, - std::string errorMessage, - WGPUCreateRenderPipelineAsyncCallback callback, - void* userdata); - - void Finish() final; - void HandleShutDown() final; - void HandleDeviceLoss() final; - - private: - Ref mPipeline; - WGPUCreateRenderPipelineAsyncCallback mCreateRenderPipelineAsyncCallback; - }; - - class CreatePipelineAsyncTracker { - public: - explicit CreatePipelineAsyncTracker(DeviceBase* device); - ~CreatePipelineAsyncTracker(); - - void TrackTask(std::unique_ptr task, ExecutionSerial serial); - void Tick(ExecutionSerial finishedSerial); - void ClearForShutDown(); - void ClearForDeviceLoss(); - - private: - DeviceBase* mDevice; - SerialQueue> - mCreatePipelineAsyncTasksInFlight; - }; - -} // namespace dawn_native - -#endif // DAWNNATIVE_CREATEPIPELINEASYNCTRACKER_H_ diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp index 713b88eb2f..a2443b2d0b 100644 --- a/src/dawn_native/Device.cpp +++ b/src/dawn_native/Device.cpp @@ -20,11 +20,12 @@ #include "dawn_native/BindGroup.h" #include "dawn_native/BindGroupLayout.h" #include "dawn_native/Buffer.h" +#include "dawn_native/CallbackTaskManager.h" #include "dawn_native/CommandBuffer.h" #include "dawn_native/CommandEncoder.h" #include "dawn_native/CompilationMessages.h" #include "dawn_native/ComputePipeline.h" -#include "dawn_native/CreatePipelineAsyncTracker.h" +#include "dawn_native/CreatePipelineAsyncTask.h" #include "dawn_native/DynamicUploader.h" #include "dawn_native/ErrorData.h" #include "dawn_native/ErrorScope.h" @@ -125,7 +126,7 @@ namespace dawn_native { mCaches = std::make_unique(); mErrorScopeStack = std::make_unique(); mDynamicUploader = std::make_unique(this); - mCreatePipelineAsyncTracker = std::make_unique(this); + mCallbackTaskManager = std::make_unique(); mDeprecationWarnings = std::make_unique(); mInternalPipelineStore = std::make_unique(); mPersistentCache = std::make_unique(this); @@ -142,8 +143,11 @@ namespace dawn_native { void DeviceBase::ShutDownBase() { // Skip handling device facilities if they haven't even been created (or failed doing so) if (mState != State::BeingCreated) { - // Reject all async pipeline creations. - mCreatePipelineAsyncTracker->ClearForShutDown(); + // Call all the callbacks immediately as the device is about to shut down. + auto callbackTasks = mCallbackTaskManager->AcquireCallbackTasks(); + for (std::unique_ptr& callbackTask : callbackTasks) { + callbackTask->HandleShutDown(); + } } // Disconnect the device, depending on which state we are currently in. @@ -188,7 +192,7 @@ namespace dawn_native { mState = State::Disconnected; mDynamicUploader = nullptr; - mCreatePipelineAsyncTracker = nullptr; + mCallbackTaskManager = nullptr; mPersistentCache = nullptr; mEmptyBindGroupLayout = nullptr; @@ -238,7 +242,10 @@ namespace dawn_native { } mQueue->HandleDeviceLoss(); - mCreatePipelineAsyncTracker->ClearForDeviceLoss(); + auto callbackTasks = mCallbackTaskManager->AcquireCallbackTasks(); + for (std::unique_ptr& callbackTask : callbackTasks) { + callbackTask->HandleDeviceLoss(); + } // Still forward device loss errors to the error scopes so they all reject. mErrorScopeStack->HandleError(ToWGPUErrorType(type), message); @@ -766,10 +773,10 @@ namespace dawn_native { } Ref result = maybeResult.AcquireSuccess(); - std::unique_ptr request = - std::make_unique(std::move(result), "", callback, - userdata); - mCreatePipelineAsyncTracker->TrackTask(std::move(request), GetPendingCommandSerial()); + std::unique_ptr callbackTask = + std::make_unique(std::move(result), "", callback, + userdata); + mCallbackTaskManager->AddCallbackTask(std::move(callbackTask)); } RenderBundleEncoder* DeviceBase::APICreateRenderBundleEncoder( const RenderBundleEncoderDescriptor* descriptor) { @@ -951,8 +958,19 @@ namespace dawn_native { // reclaiming resources one tick earlier. mDynamicUploader->Deallocate(mCompletedSerial); mQueue->Tick(mCompletedSerial); + } - mCreatePipelineAsyncTracker->Tick(mCompletedSerial); + // We have to check mCallbackTaskManager in every Tick because it is not related to any + // global serials. + if (!mCallbackTaskManager->IsEmpty()) { + // If a user calls Queue::Submit inside the callback, then the device will be ticked, + // which in turns ticks the tracker, causing reentrance and dead lock here. To prevent + // such reentrant call, we remove all the callback tasks from mCallbackTaskManager, + // update mCallbackTaskManager, then call all the callbacks. + auto callbackTasks = mCallbackTaskManager->AcquireCallbackTasks(); + for (std::unique_ptr& callbackTask : callbackTasks) { + callbackTask->Finish(); + } } return {}; @@ -1158,10 +1176,10 @@ namespace dawn_native { result = AddOrGetCachedPipeline(resultOrError.AcquireSuccess(), blueprintHash); } - std::unique_ptr request = - std::make_unique(result, errorMessage, callback, - userdata); - mCreatePipelineAsyncTracker->TrackTask(std::move(request), GetPendingCommandSerial()); + std::unique_ptr callbackTask = + std::make_unique( + std::move(result), errorMessage, callback, userdata); + mCallbackTaskManager->AddCallbackTask(std::move(callbackTask)); } ResultOrError> DeviceBase::CreatePipelineLayout( diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h index 87d567f75f..134bff89bc 100644 --- a/src/dawn_native/Device.h +++ b/src/dawn_native/Device.h @@ -34,7 +34,7 @@ namespace dawn_native { class AttachmentState; class AttachmentStateBlueprint; class BindGroupLayoutBase; - class CreatePipelineAsyncTracker; + class CallbackTaskManager; class DynamicUploader; class ErrorScopeStack; class ExternalTextureBase; @@ -402,7 +402,7 @@ namespace dawn_native { Ref mEmptyBindGroupLayout; std::unique_ptr mDynamicUploader; - std::unique_ptr mCreatePipelineAsyncTracker; + std::unique_ptr mCallbackTaskManager; Ref mQueue; struct DeprecationWarnings;