diff --git a/src/dawn/native/CallbackTaskManager.cpp b/src/dawn/native/CallbackTaskManager.cpp index 8a19c1c463..ea72b969e0 100644 --- a/src/dawn/native/CallbackTaskManager.cpp +++ b/src/dawn/native/CallbackTaskManager.cpp @@ -102,4 +102,17 @@ void CallbackTaskManager::HandleShutDown() { } } +void CallbackTaskManager::Flush() { + if (!IsEmpty()) { + // If a user calls Queue::Submit inside the callback, then the device will be ticked, + // which in turns ticks the tracker, causing reentrance and dead lock here. To prevent + // such reentrant call, we remove all the callback tasks from mCallbackTaskManager, + // update mCallbackTaskManager, then call all the callbacks. + auto callbackTasks = AcquireCallbackTasks(); + for (std::unique_ptr& callbackTask : callbackTasks) { + callbackTask->Execute(); + } + } +} + } // namespace dawn::native diff --git a/src/dawn/native/CallbackTaskManager.h b/src/dawn/native/CallbackTaskManager.h index 266b80c66f..3cf00912c4 100644 --- a/src/dawn/native/CallbackTaskManager.h +++ b/src/dawn/native/CallbackTaskManager.h @@ -20,6 +20,7 @@ #include #include +#include "dawn/common/RefCounted.h" #include "dawn/common/TypeTraits.h" namespace dawn::native { @@ -47,10 +48,10 @@ struct CallbackTask { State mState = State::Normal; }; -class CallbackTaskManager { +class CallbackTaskManager : public RefCounted { public: CallbackTaskManager(); - ~CallbackTaskManager(); + ~CallbackTaskManager() override; void AddCallbackTask(std::unique_ptr callbackTask); void AddCallbackTask(std::function callback); @@ -63,9 +64,11 @@ class CallbackTaskManager { bool IsEmpty(); void HandleDeviceLoss(); void HandleShutDown(); - std::vector> AcquireCallbackTasks(); + void Flush(); private: + std::vector> AcquireCallbackTasks(); + std::mutex mCallbackTaskQueueMutex; std::vector> mCallbackTaskQueue; }; diff --git a/src/dawn/native/Device.cpp b/src/dawn/native/Device.cpp index d8b2effbaf..e795bc2f04 100644 --- a/src/dawn/native/Device.cpp +++ b/src/dawn/native/Device.cpp @@ -216,11 +216,6 @@ DeviceBase::~DeviceBase() { // We need to explicitly release the Queue before we complete the destructor so that the // Queue does not get destroyed after the Device. mQueue = nullptr; - // mAdapter is not set for mock test devices. - // TODO(crbug.com/dawn/1702): using a mock adapter could avoid the null checking. - if (mAdapter != nullptr) { - mAdapter->GetInstance()->RemoveDevice(this); - } } MaybeError DeviceBase::Initialize(Ref defaultQueue) { @@ -253,7 +248,7 @@ MaybeError DeviceBase::Initialize(Ref defaultQueue) { mCaches = std::make_unique(); mErrorScopeStack = std::make_unique(); mDynamicUploader = std::make_unique(this); - mCallbackTaskManager = std::make_unique(); + mCallbackTaskManager = AcquireRef(new CallbackTaskManager()); mDeprecationWarnings = std::make_unique(); mInternalPipelineStore = std::make_unique(this); @@ -311,7 +306,9 @@ void DeviceBase::WillDropLastExternalRef() { Destroy(); // Flush last remaining callback tasks. - FlushCallbackTaskQueue(); + do { + mCallbackTaskManager->Flush(); + } while (!mCallbackTaskManager->IsEmpty()); // Drop te device's reference to the queue. Because the application dropped the last external // references, they can no longer get the queue from APIGetQueue(). @@ -328,6 +325,16 @@ void DeviceBase::WillDropLastExternalRef() { dawn::WarningLog() << "Device lost after last external device reference dropped.\n" << message; }; + + // mAdapter is not set for mock test devices. + // TODO(crbug.com/dawn/1702): using a mock adapter could avoid the null checking. + if (mAdapter != nullptr) { + mAdapter->GetInstance()->RemoveDevice(this); + + // Once last external ref dropped, all callbacks should be forwarded to Instance's callback + // queue instead. + mCallbackTaskManager = mAdapter->GetInstance()->GetCallbackTaskManager(); + } } void DeviceBase::DestroyObjects() { @@ -566,7 +573,7 @@ void DeviceBase::APISetLoggingCallback(wgpu::LoggingCallback callback, void* use if (IsLost()) { return; } - FlushCallbackTaskQueue(); + mCallbackTaskManager->Flush(); mLoggingCallback = callback; mLoggingUserdata = userdata; } @@ -580,7 +587,7 @@ void DeviceBase::APISetUncapturedErrorCallback(wgpu::ErrorCallback callback, voi if (IsLost()) { return; } - FlushCallbackTaskQueue(); + mCallbackTaskManager->Flush(); mUncapturedErrorCallback = callback; mUncapturedErrorUserdata = userdata; } @@ -594,7 +601,7 @@ void DeviceBase::APISetDeviceLostCallback(wgpu::DeviceLostCallback callback, voi if (IsLost()) { return; } - FlushCallbackTaskQueue(); + mCallbackTaskManager->Flush(); mDeviceLostCallback = callback; mDeviceLostUserdata = userdata; } @@ -1261,13 +1268,13 @@ bool DeviceBase::APITick() { // to avoid deleting |this| in the middle of this function call. Ref self(this); if (ConsumedError(Tick())) { - FlushCallbackTaskQueue(); + mCallbackTaskManager->Flush(); return false; } // We have to check callback tasks in every APITick because it is not related to any global // serials. - FlushCallbackTaskQueue(); + mCallbackTaskManager->Flush(); // We don't throw an error when device is lost. This allows pending callbacks to be // executed even after the Device is lost/destroyed. @@ -1806,19 +1813,6 @@ void DeviceBase::ForceSetToggleForTesting(Toggle toggle, bool isEnabled) { mToggles.ForceSet(toggle, isEnabled); } -void DeviceBase::FlushCallbackTaskQueue() { - if (!mCallbackTaskManager->IsEmpty()) { - // If a user calls Queue::Submit inside the callback, then the device will be ticked, - // which in turns ticks the tracker, causing reentrance and dead lock here. To prevent - // such reentrant call, we remove all the callback tasks from mCallbackTaskManager, - // update mCallbackTaskManager, then call all the callbacks. - auto callbackTasks = mCallbackTaskManager->AcquireCallbackTasks(); - for (std::unique_ptr& callbackTask : callbackTasks) { - callbackTask->Execute(); - } - } -} - const CombinedLimits& DeviceBase::GetLimits() const { return mLimits; } @@ -1828,7 +1822,7 @@ AsyncTaskManager* DeviceBase::GetAsyncTaskManager() const { } CallbackTaskManager* DeviceBase::GetCallbackTaskManager() const { - return mCallbackTaskManager.get(); + return mCallbackTaskManager.Get(); } dawn::platform::WorkerTaskPool* DeviceBase::GetWorkerTaskPool() const { diff --git a/src/dawn/native/Device.h b/src/dawn/native/Device.h index 14a46213f2..f05bc667dd 100644 --- a/src/dawn/native/Device.h +++ b/src/dawn/native/Device.h @@ -482,7 +482,6 @@ class DeviceBase : public RefCountedWithExternalCount { virtual void SetLabelImpl(); virtual MaybeError TickImpl() = 0; - void FlushCallbackTaskQueue(); ResultOrError> CreateEmptyBindGroupLayout(); @@ -594,7 +593,7 @@ class DeviceBase : public RefCountedWithExternalCount { std::unique_ptr mInternalPipelineStore; - std::unique_ptr mCallbackTaskManager; + Ref mCallbackTaskManager; std::unique_ptr mWorkerTaskPool; std::string mLabel; CacheKey mDeviceCacheKey; diff --git a/src/dawn/native/Instance.cpp b/src/dawn/native/Instance.cpp index 63e3f1c4f7..6b8248624a 100644 --- a/src/dawn/native/Instance.cpp +++ b/src/dawn/native/Instance.cpp @@ -20,6 +20,7 @@ #include "dawn/common/GPUInfo.h" #include "dawn/common/Log.h" #include "dawn/common/SystemUtils.h" +#include "dawn/native/CallbackTaskManager.h" #include "dawn/native/ChainUtils_autogen.h" #include "dawn/native/Device.h" #include "dawn/native/ErrorData.h" @@ -172,6 +173,8 @@ MaybeError InstanceBase::Initialize(const InstanceDescriptor* descriptor) { } mRuntimeSearchPaths.push_back(""); + mCallbackTaskManager = AcquireRef(new CallbackTaskManager()); + // Initialize the platform to the default for now. mDefaultPlatform = std::make_unique(); SetPlatform(mDefaultPlatform.get()); @@ -533,13 +536,19 @@ bool InstanceBase::APIProcessEvents() { hasMoreEvents = device->APITick() || hasMoreEvents; } - return hasMoreEvents; + mCallbackTaskManager->Flush(); + + return hasMoreEvents || !mCallbackTaskManager->IsEmpty(); } const std::vector& InstanceBase::GetRuntimeSearchPaths() const { return mRuntimeSearchPaths; } +const Ref& InstanceBase::GetCallbackTaskManager() const { + return mCallbackTaskManager; +} + void InstanceBase::ConsumeError(std::unique_ptr error) { ASSERT(error != nullptr); dawn::ErrorLog() << error->GetFormattedMessage(); diff --git a/src/dawn/native/Instance.h b/src/dawn/native/Instance.h index f00a95ad0a..10c90dd746 100644 --- a/src/dawn/native/Instance.h +++ b/src/dawn/native/Instance.h @@ -39,6 +39,7 @@ class Platform; namespace dawn::native { +class CallbackTaskManager; class DeviceBase; class Surface; class XlibXcbFunctions; @@ -111,6 +112,8 @@ class InstanceBase final : public RefCountedWithExternalCount { const std::vector& GetRuntimeSearchPaths() const; + const Ref& GetCallbackTaskManager() const; + // Get backend-independent libraries that need to be loaded dynamically. const XlibXcbFunctions* GetOrCreateXlibXcbFunctions(); @@ -165,6 +168,8 @@ class InstanceBase final : public RefCountedWithExternalCount { std::unique_ptr mXlibXcbFunctions; #endif // defined(DAWN_USE_X11) + Ref mCallbackTaskManager; + std::set mDevicesList; mutable std::mutex mDevicesListMutex; }; diff --git a/src/dawn/tests/BUILD.gn b/src/dawn/tests/BUILD.gn index a6db8484f1..066f0a928d 100644 --- a/src/dawn/tests/BUILD.gn +++ b/src/dawn/tests/BUILD.gn @@ -318,6 +318,7 @@ dawn_test("dawn_unittests") { "unittests/native/CommandBufferEncodingTests.cpp", "unittests/native/CreatePipelineAsyncTaskTests.cpp", "unittests/native/DestroyObjectTests.cpp", + "unittests/native/DeviceAsyncTaskTests.cpp", "unittests/native/DeviceCreationTests.cpp", "unittests/native/ObjectContentHasherTests.cpp", "unittests/native/StreamTests.cpp", diff --git a/src/dawn/tests/unittests/native/CreatePipelineAsyncTaskTests.cpp b/src/dawn/tests/unittests/native/CreatePipelineAsyncTaskTests.cpp index 613177700a..9509850fdb 100644 --- a/src/dawn/tests/unittests/native/CreatePipelineAsyncTaskTests.cpp +++ b/src/dawn/tests/unittests/native/CreatePipelineAsyncTaskTests.cpp @@ -14,6 +14,11 @@ #include +#include +#include +#include +#include + #include "dawn/native/CreatePipelineAsyncTask.h" #include "dawn/utils/WGPUHelpers.h" #include "mocks/ComputePipelineMock.h" @@ -149,5 +154,38 @@ TEST_F(CreatePipelineAsyncTaskTests, InitializationInternalErrorInCreateComputeP EXPECT_CALL(*computePipelineMock.Get(), DestroyImpl).Times(1); } +// Test that a long async task's execution won't extend to after the device is dropped. +// Device dropping should wait for that task to finish. +TEST_F(CreatePipelineAsyncTaskTests, LongAsyncTaskFinishesBeforeDeviceIsDropped) { + wgpu::RenderPipelineDescriptor desc = {}; + desc.vertex.module = utils::CreateShaderModule(device, kVertexShader.data()); + desc.vertex.entryPoint = "main"; + Ref renderPipelineMock = + RenderPipelineMock::Create(mDeviceMock, FromCppAPI(&desc)); + + // Simulate that Initialize() would take a long time to finish. + ON_CALL(*renderPipelineMock.Get(), Initialize).WillByDefault([]() -> MaybeError { + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + return {}; + }); + + bool done = false; + auto asyncTask = std::make_unique( + renderPipelineMock, + [](WGPUCreatePipelineAsyncStatus status, WGPURenderPipeline returnPipeline, + const char* message, void* userdata) { + wgpu::RenderPipeline::Acquire(returnPipeline); + + *static_cast(userdata) = true; + }, + &done); + + CreateRenderPipelineAsyncTask::RunAsync(std::move(asyncTask)); + + device = nullptr; + // Dropping the device should force the async task to finish. + EXPECT_TRUE(done); +} + } // namespace } // namespace dawn::native diff --git a/src/dawn/tests/unittests/native/DeviceAsyncTaskTests.cpp b/src/dawn/tests/unittests/native/DeviceAsyncTaskTests.cpp new file mode 100644 index 0000000000..8309005a30 --- /dev/null +++ b/src/dawn/tests/unittests/native/DeviceAsyncTaskTests.cpp @@ -0,0 +1,50 @@ +// Copyright 2023 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include +#include +#include +#include +#include + +#include "dawn/native/AsyncTask.h" +#include "mocks/DawnMockTest.h" + +namespace dawn::native { +namespace { +using ::testing::Test; + +class DeviceAsyncTaskTests : public DawnMockTest {}; + +// Test that a long async task's execution won't extend to after the device is dropped. +// Device dropping should wait for that task to finish. +TEST_F(DeviceAsyncTaskTests, LongAsyncTaskFinishesBeforeDeviceIsDropped) { + std::atomic_bool done(false); + + // Simulate that an async task would take a long time to finish. + dawn::native::AsyncTask asyncTask([&done] { + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + done = true; + }); + + mDeviceMock->GetAsyncTaskManager()->PostTask(std::move(asyncTask)); + device = nullptr; + // Dropping the device should force the async task to finish. + EXPECT_TRUE(done.load()); +} + +} // namespace +} // namespace dawn::native