Forward callbacks to Instance after Device is destroyed.
When InstanceBase::ProcessEvents() iterates through list of devices, one device might be being destructed on another thread. Even if we try to increase ref count of that device inside the ProcessEvents(), the device might be in the middle of destructor call on another thread, increasing the ref count is invalid in this case. This CL attempts to fix this issue by removing the device's pointer from InstanceBase earlier: when DeviceBase::WillDropLastExternalRef() is called. After this point, any callback registered to this device will be forwarded to InstanceBase's callback queue instead. Bug: dawn:752 Change-Id: I8ae86575e34f753e52a76f5fc774bbb5366a1b85 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/124281 Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Quyen Le <lehoangquyen@chromium.org> Reviewed-by: Loko Kung <lokokung@google.com>
This commit is contained in:
parent
f359395288
commit
de078bad8d
|
@ -102,4 +102,17 @@ void CallbackTaskManager::HandleShutDown() {
|
|||
}
|
||||
}
|
||||
|
||||
void CallbackTaskManager::Flush() {
|
||||
if (!IsEmpty()) {
|
||||
// If a user calls Queue::Submit inside the callback, then the device will be ticked,
|
||||
// which in turns ticks the tracker, causing reentrance and dead lock here. To prevent
|
||||
// such reentrant call, we remove all the callback tasks from mCallbackTaskManager,
|
||||
// update mCallbackTaskManager, then call all the callbacks.
|
||||
auto callbackTasks = AcquireCallbackTasks();
|
||||
for (std::unique_ptr<CallbackTask>& callbackTask : callbackTasks) {
|
||||
callbackTask->Execute();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace dawn::native
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include <mutex>
|
||||
#include <vector>
|
||||
|
||||
#include "dawn/common/RefCounted.h"
|
||||
#include "dawn/common/TypeTraits.h"
|
||||
|
||||
namespace dawn::native {
|
||||
|
@ -47,10 +48,10 @@ struct CallbackTask {
|
|||
State mState = State::Normal;
|
||||
};
|
||||
|
||||
class CallbackTaskManager {
|
||||
class CallbackTaskManager : public RefCounted {
|
||||
public:
|
||||
CallbackTaskManager();
|
||||
~CallbackTaskManager();
|
||||
~CallbackTaskManager() override;
|
||||
|
||||
void AddCallbackTask(std::unique_ptr<CallbackTask> callbackTask);
|
||||
void AddCallbackTask(std::function<void()> callback);
|
||||
|
@ -63,9 +64,11 @@ class CallbackTaskManager {
|
|||
bool IsEmpty();
|
||||
void HandleDeviceLoss();
|
||||
void HandleShutDown();
|
||||
std::vector<std::unique_ptr<CallbackTask>> AcquireCallbackTasks();
|
||||
void Flush();
|
||||
|
||||
private:
|
||||
std::vector<std::unique_ptr<CallbackTask>> AcquireCallbackTasks();
|
||||
|
||||
std::mutex mCallbackTaskQueueMutex;
|
||||
std::vector<std::unique_ptr<CallbackTask>> mCallbackTaskQueue;
|
||||
};
|
||||
|
|
|
@ -216,11 +216,6 @@ DeviceBase::~DeviceBase() {
|
|||
// We need to explicitly release the Queue before we complete the destructor so that the
|
||||
// Queue does not get destroyed after the Device.
|
||||
mQueue = nullptr;
|
||||
// mAdapter is not set for mock test devices.
|
||||
// TODO(crbug.com/dawn/1702): using a mock adapter could avoid the null checking.
|
||||
if (mAdapter != nullptr) {
|
||||
mAdapter->GetInstance()->RemoveDevice(this);
|
||||
}
|
||||
}
|
||||
|
||||
MaybeError DeviceBase::Initialize(Ref<QueueBase> defaultQueue) {
|
||||
|
@ -253,7 +248,7 @@ MaybeError DeviceBase::Initialize(Ref<QueueBase> defaultQueue) {
|
|||
mCaches = std::make_unique<DeviceBase::Caches>();
|
||||
mErrorScopeStack = std::make_unique<ErrorScopeStack>();
|
||||
mDynamicUploader = std::make_unique<DynamicUploader>(this);
|
||||
mCallbackTaskManager = std::make_unique<CallbackTaskManager>();
|
||||
mCallbackTaskManager = AcquireRef(new CallbackTaskManager());
|
||||
mDeprecationWarnings = std::make_unique<DeprecationWarnings>();
|
||||
mInternalPipelineStore = std::make_unique<InternalPipelineStore>(this);
|
||||
|
||||
|
@ -311,7 +306,9 @@ void DeviceBase::WillDropLastExternalRef() {
|
|||
Destroy();
|
||||
|
||||
// Flush last remaining callback tasks.
|
||||
FlushCallbackTaskQueue();
|
||||
do {
|
||||
mCallbackTaskManager->Flush();
|
||||
} while (!mCallbackTaskManager->IsEmpty());
|
||||
|
||||
// Drop te device's reference to the queue. Because the application dropped the last external
|
||||
// references, they can no longer get the queue from APIGetQueue().
|
||||
|
@ -328,6 +325,16 @@ void DeviceBase::WillDropLastExternalRef() {
|
|||
dawn::WarningLog() << "Device lost after last external device reference dropped.\n"
|
||||
<< message;
|
||||
};
|
||||
|
||||
// mAdapter is not set for mock test devices.
|
||||
// TODO(crbug.com/dawn/1702): using a mock adapter could avoid the null checking.
|
||||
if (mAdapter != nullptr) {
|
||||
mAdapter->GetInstance()->RemoveDevice(this);
|
||||
|
||||
// Once last external ref dropped, all callbacks should be forwarded to Instance's callback
|
||||
// queue instead.
|
||||
mCallbackTaskManager = mAdapter->GetInstance()->GetCallbackTaskManager();
|
||||
}
|
||||
}
|
||||
|
||||
void DeviceBase::DestroyObjects() {
|
||||
|
@ -566,7 +573,7 @@ void DeviceBase::APISetLoggingCallback(wgpu::LoggingCallback callback, void* use
|
|||
if (IsLost()) {
|
||||
return;
|
||||
}
|
||||
FlushCallbackTaskQueue();
|
||||
mCallbackTaskManager->Flush();
|
||||
mLoggingCallback = callback;
|
||||
mLoggingUserdata = userdata;
|
||||
}
|
||||
|
@ -580,7 +587,7 @@ void DeviceBase::APISetUncapturedErrorCallback(wgpu::ErrorCallback callback, voi
|
|||
if (IsLost()) {
|
||||
return;
|
||||
}
|
||||
FlushCallbackTaskQueue();
|
||||
mCallbackTaskManager->Flush();
|
||||
mUncapturedErrorCallback = callback;
|
||||
mUncapturedErrorUserdata = userdata;
|
||||
}
|
||||
|
@ -594,7 +601,7 @@ void DeviceBase::APISetDeviceLostCallback(wgpu::DeviceLostCallback callback, voi
|
|||
if (IsLost()) {
|
||||
return;
|
||||
}
|
||||
FlushCallbackTaskQueue();
|
||||
mCallbackTaskManager->Flush();
|
||||
mDeviceLostCallback = callback;
|
||||
mDeviceLostUserdata = userdata;
|
||||
}
|
||||
|
@ -1261,13 +1268,13 @@ bool DeviceBase::APITick() {
|
|||
// to avoid deleting |this| in the middle of this function call.
|
||||
Ref<DeviceBase> self(this);
|
||||
if (ConsumedError(Tick())) {
|
||||
FlushCallbackTaskQueue();
|
||||
mCallbackTaskManager->Flush();
|
||||
return false;
|
||||
}
|
||||
|
||||
// We have to check callback tasks in every APITick because it is not related to any global
|
||||
// serials.
|
||||
FlushCallbackTaskQueue();
|
||||
mCallbackTaskManager->Flush();
|
||||
|
||||
// We don't throw an error when device is lost. This allows pending callbacks to be
|
||||
// executed even after the Device is lost/destroyed.
|
||||
|
@ -1806,19 +1813,6 @@ void DeviceBase::ForceSetToggleForTesting(Toggle toggle, bool isEnabled) {
|
|||
mToggles.ForceSet(toggle, isEnabled);
|
||||
}
|
||||
|
||||
void DeviceBase::FlushCallbackTaskQueue() {
|
||||
if (!mCallbackTaskManager->IsEmpty()) {
|
||||
// If a user calls Queue::Submit inside the callback, then the device will be ticked,
|
||||
// which in turns ticks the tracker, causing reentrance and dead lock here. To prevent
|
||||
// such reentrant call, we remove all the callback tasks from mCallbackTaskManager,
|
||||
// update mCallbackTaskManager, then call all the callbacks.
|
||||
auto callbackTasks = mCallbackTaskManager->AcquireCallbackTasks();
|
||||
for (std::unique_ptr<CallbackTask>& callbackTask : callbackTasks) {
|
||||
callbackTask->Execute();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const CombinedLimits& DeviceBase::GetLimits() const {
|
||||
return mLimits;
|
||||
}
|
||||
|
@ -1828,7 +1822,7 @@ AsyncTaskManager* DeviceBase::GetAsyncTaskManager() const {
|
|||
}
|
||||
|
||||
CallbackTaskManager* DeviceBase::GetCallbackTaskManager() const {
|
||||
return mCallbackTaskManager.get();
|
||||
return mCallbackTaskManager.Get();
|
||||
}
|
||||
|
||||
dawn::platform::WorkerTaskPool* DeviceBase::GetWorkerTaskPool() const {
|
||||
|
|
|
@ -482,7 +482,6 @@ class DeviceBase : public RefCountedWithExternalCount {
|
|||
virtual void SetLabelImpl();
|
||||
|
||||
virtual MaybeError TickImpl() = 0;
|
||||
void FlushCallbackTaskQueue();
|
||||
|
||||
ResultOrError<Ref<BindGroupLayoutBase>> CreateEmptyBindGroupLayout();
|
||||
|
||||
|
@ -594,7 +593,7 @@ class DeviceBase : public RefCountedWithExternalCount {
|
|||
|
||||
std::unique_ptr<InternalPipelineStore> mInternalPipelineStore;
|
||||
|
||||
std::unique_ptr<CallbackTaskManager> mCallbackTaskManager;
|
||||
Ref<CallbackTaskManager> mCallbackTaskManager;
|
||||
std::unique_ptr<dawn::platform::WorkerTaskPool> mWorkerTaskPool;
|
||||
std::string mLabel;
|
||||
CacheKey mDeviceCacheKey;
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "dawn/common/GPUInfo.h"
|
||||
#include "dawn/common/Log.h"
|
||||
#include "dawn/common/SystemUtils.h"
|
||||
#include "dawn/native/CallbackTaskManager.h"
|
||||
#include "dawn/native/ChainUtils_autogen.h"
|
||||
#include "dawn/native/Device.h"
|
||||
#include "dawn/native/ErrorData.h"
|
||||
|
@ -172,6 +173,8 @@ MaybeError InstanceBase::Initialize(const InstanceDescriptor* descriptor) {
|
|||
}
|
||||
mRuntimeSearchPaths.push_back("");
|
||||
|
||||
mCallbackTaskManager = AcquireRef(new CallbackTaskManager());
|
||||
|
||||
// Initialize the platform to the default for now.
|
||||
mDefaultPlatform = std::make_unique<dawn::platform::Platform>();
|
||||
SetPlatform(mDefaultPlatform.get());
|
||||
|
@ -533,13 +536,19 @@ bool InstanceBase::APIProcessEvents() {
|
|||
hasMoreEvents = device->APITick() || hasMoreEvents;
|
||||
}
|
||||
|
||||
return hasMoreEvents;
|
||||
mCallbackTaskManager->Flush();
|
||||
|
||||
return hasMoreEvents || !mCallbackTaskManager->IsEmpty();
|
||||
}
|
||||
|
||||
const std::vector<std::string>& InstanceBase::GetRuntimeSearchPaths() const {
|
||||
return mRuntimeSearchPaths;
|
||||
}
|
||||
|
||||
const Ref<CallbackTaskManager>& InstanceBase::GetCallbackTaskManager() const {
|
||||
return mCallbackTaskManager;
|
||||
}
|
||||
|
||||
void InstanceBase::ConsumeError(std::unique_ptr<ErrorData> error) {
|
||||
ASSERT(error != nullptr);
|
||||
dawn::ErrorLog() << error->GetFormattedMessage();
|
||||
|
|
|
@ -39,6 +39,7 @@ class Platform;
|
|||
|
||||
namespace dawn::native {
|
||||
|
||||
class CallbackTaskManager;
|
||||
class DeviceBase;
|
||||
class Surface;
|
||||
class XlibXcbFunctions;
|
||||
|
@ -111,6 +112,8 @@ class InstanceBase final : public RefCountedWithExternalCount {
|
|||
|
||||
const std::vector<std::string>& GetRuntimeSearchPaths() const;
|
||||
|
||||
const Ref<CallbackTaskManager>& GetCallbackTaskManager() const;
|
||||
|
||||
// Get backend-independent libraries that need to be loaded dynamically.
|
||||
const XlibXcbFunctions* GetOrCreateXlibXcbFunctions();
|
||||
|
||||
|
@ -165,6 +168,8 @@ class InstanceBase final : public RefCountedWithExternalCount {
|
|||
std::unique_ptr<XlibXcbFunctions> mXlibXcbFunctions;
|
||||
#endif // defined(DAWN_USE_X11)
|
||||
|
||||
Ref<CallbackTaskManager> mCallbackTaskManager;
|
||||
|
||||
std::set<DeviceBase*> mDevicesList;
|
||||
mutable std::mutex mDevicesListMutex;
|
||||
};
|
||||
|
|
|
@ -318,6 +318,7 @@ dawn_test("dawn_unittests") {
|
|||
"unittests/native/CommandBufferEncodingTests.cpp",
|
||||
"unittests/native/CreatePipelineAsyncTaskTests.cpp",
|
||||
"unittests/native/DestroyObjectTests.cpp",
|
||||
"unittests/native/DeviceAsyncTaskTests.cpp",
|
||||
"unittests/native/DeviceCreationTests.cpp",
|
||||
"unittests/native/ObjectContentHasherTests.cpp",
|
||||
"unittests/native/StreamTests.cpp",
|
||||
|
|
|
@ -14,6 +14,11 @@
|
|||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <chrono>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
|
||||
#include "dawn/native/CreatePipelineAsyncTask.h"
|
||||
#include "dawn/utils/WGPUHelpers.h"
|
||||
#include "mocks/ComputePipelineMock.h"
|
||||
|
@ -149,5 +154,38 @@ TEST_F(CreatePipelineAsyncTaskTests, InitializationInternalErrorInCreateComputeP
|
|||
EXPECT_CALL(*computePipelineMock.Get(), DestroyImpl).Times(1);
|
||||
}
|
||||
|
||||
// Test that a long async task's execution won't extend to after the device is dropped.
|
||||
// Device dropping should wait for that task to finish.
|
||||
TEST_F(CreatePipelineAsyncTaskTests, LongAsyncTaskFinishesBeforeDeviceIsDropped) {
|
||||
wgpu::RenderPipelineDescriptor desc = {};
|
||||
desc.vertex.module = utils::CreateShaderModule(device, kVertexShader.data());
|
||||
desc.vertex.entryPoint = "main";
|
||||
Ref<RenderPipelineMock> renderPipelineMock =
|
||||
RenderPipelineMock::Create(mDeviceMock, FromCppAPI(&desc));
|
||||
|
||||
// Simulate that Initialize() would take a long time to finish.
|
||||
ON_CALL(*renderPipelineMock.Get(), Initialize).WillByDefault([]() -> MaybeError {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
|
||||
return {};
|
||||
});
|
||||
|
||||
bool done = false;
|
||||
auto asyncTask = std::make_unique<CreateRenderPipelineAsyncTask>(
|
||||
renderPipelineMock,
|
||||
[](WGPUCreatePipelineAsyncStatus status, WGPURenderPipeline returnPipeline,
|
||||
const char* message, void* userdata) {
|
||||
wgpu::RenderPipeline::Acquire(returnPipeline);
|
||||
|
||||
*static_cast<bool*>(userdata) = true;
|
||||
},
|
||||
&done);
|
||||
|
||||
CreateRenderPipelineAsyncTask::RunAsync(std::move(asyncTask));
|
||||
|
||||
device = nullptr;
|
||||
// Dropping the device should force the async task to finish.
|
||||
EXPECT_TRUE(done);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace dawn::native
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
// Copyright 2023 The Dawn Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <chrono>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
#include <utility>
|
||||
|
||||
#include "dawn/native/AsyncTask.h"
|
||||
#include "mocks/DawnMockTest.h"
|
||||
|
||||
namespace dawn::native {
|
||||
namespace {
|
||||
using ::testing::Test;
|
||||
|
||||
class DeviceAsyncTaskTests : public DawnMockTest {};
|
||||
|
||||
// Test that a long async task's execution won't extend to after the device is dropped.
|
||||
// Device dropping should wait for that task to finish.
|
||||
TEST_F(DeviceAsyncTaskTests, LongAsyncTaskFinishesBeforeDeviceIsDropped) {
|
||||
std::atomic_bool done(false);
|
||||
|
||||
// Simulate that an async task would take a long time to finish.
|
||||
dawn::native::AsyncTask asyncTask([&done] {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(1000));
|
||||
done = true;
|
||||
});
|
||||
|
||||
mDeviceMock->GetAsyncTaskManager()->PostTask(std::move(asyncTask));
|
||||
device = nullptr;
|
||||
// Dropping the device should force the async task to finish.
|
||||
EXPECT_TRUE(done.load());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace dawn::native
|
Loading…
Reference in New Issue