#include "dawn_native/AsyncTask.h"

#include <mutex>
#include <unordered_map>
#include <utility>

#include "dawn_platform/DawnPlatform.h"

namespace dawn::native {

    // Stores the worker pool that PostTask() submits work to.
    AsyncTaskManager::AsyncTaskManager(dawn::platform::WorkerTaskPool* workerTaskPool)
        : mWorkerTaskPool(workerTaskPool) {
    }

    // Schedules `asyncTask` to run via DoWaitableTask() on the worker pool and tracks it in
    // mPendingTasks until it completes or WaitAllPendingTasks() takes ownership of it.
    void AsyncTaskManager::PostTask(AsyncTask asyncTask) {
        // If these allocations become expensive, we can slab-allocate tasks.
        Ref<WaitableTask> waitableTask = AcquireRef(new WaitableTask());
        waitableTask->taskManager = this;
        waitableTask->asyncTask = std::move(asyncTask);

        {
            // We insert new waitableTask objects into mPendingTasks in main thread (PostTask()),
            // and we may remove waitableTask objects from mPendingTasks in either main thread
            // (WaitAllPendingTasks()) or sub-thread (HandleTaskCompletion()), so mPendingTasks
            // should be protected by a mutex.
            // The entry is inserted BEFORE the work is posted so that a worker that finishes
            // quickly always finds (and removes) its own entry in HandleTaskCompletion().
            std::lock_guard<std::mutex> lock(mPendingTasksMutex);
            mPendingTasks.emplace(waitableTask.Get(), waitableTask);
        }

        // Ref the task since it is accessed inside the worker function.
        // The worker function will acquire and release the task upon completion.
        waitableTask->Reference();
        // NOTE(review): waitableEvent is assigned only after the task is already visible in
        // mPendingTasks. WaitAllPendingTasks() dereferences it, so this relies on PostTask()
        // and WaitAllPendingTasks() being called from the same thread (as stated above).
        waitableTask->waitableEvent =
            mWorkerTaskPool->PostWorkerTask(DoWaitableTask, waitableTask.Get());
    }

    // Removes `task` from mPendingTasks, if it is still there. The entry may already be gone
    // because WaitAllPendingTasks() swaps the whole map out; erasing a missing key is a no-op.
    void AsyncTaskManager::HandleTaskCompletion(WaitableTask* task) {
        std::lock_guard<std::mutex> lock(mPendingTasksMutex);
        mPendingTasks.erase(task);
    }

    // Blocks until every task that was pending at the time of the call has finished.
    // Tasks posted after the snapshot below is taken are not waited on.
    void AsyncTaskManager::WaitAllPendingTasks() {
        std::unordered_map<WaitableTask*, Ref<WaitableTask>> allPendingTasks;

        {
            // Take the pending set under the lock, then wait without holding it:
            // HandleTaskCompletion() (run by workers) acquires the same mutex.
            std::lock_guard<std::mutex> lock(mPendingTasksMutex);
            allPendingTasks.swap(mPendingTasks);
        }

        for (auto& [_, task] : allPendingTasks) {
            task->waitableEvent->Wait();
        }
    }

    // Returns whether any posted task has not yet been removed from mPendingTasks. Workers
    // remove entries concurrently, so the answer may be stale once the lock is released.
    bool AsyncTaskManager::HasPendingTasks() {
        std::lock_guard<std::mutex> lock(mPendingTasksMutex);
        return !mPendingTasks.empty();
    }

    // Worker-pool entry point. `task` is the WaitableTask* passed by PostTask(). The extra
    // reference taken there is adopted here with AcquireRef and released when this returns.
    void AsyncTaskManager::DoWaitableTask(void* task) {
        Ref<WaitableTask> waitableTask = AcquireRef(static_cast<WaitableTask*>(task));
        waitableTask->asyncTask();
        waitableTask->taskManager->HandleTaskCompletion(waitableTask.Get());
    }

}  // namespace dawn::native