diff --git a/src/dawn_native/AsyncTask.cpp b/src/dawn_native/AsyncTask.cpp new file mode 100644 index 0000000000..8cdf55c376 --- /dev/null +++ b/src/dawn_native/AsyncTask.cpp @@ -0,0 +1,60 @@ +#include "dawn_native/AsyncTask.h" + +#include "dawn_platform/DawnPlatform.h" + +namespace dawn_native { + + AsnycTaskManager::AsnycTaskManager(dawn_platform::WorkerTaskPool* workerTaskPool) + : mWorkerTaskPool(workerTaskPool) { + } + + void AsnycTaskManager::PostTask(AsyncTask asyncTask) { + // If these allocations becomes expensive, we can slab-allocate tasks. + Ref waitableTask = AcquireRef(new WaitableTask()); + waitableTask->taskManager = this; + waitableTask->asyncTask = std::move(asyncTask); + + { + // We insert new waitableTask objects into mPendingTasks in main thread (PostTask()), + // and we may remove waitableTask objects from mPendingTasks in either main thread + // (WaitAllPendingTasks()) or sub-thread (TaskCompleted), so mPendingTasks should be + // protected by a mutex. + std::lock_guard lock(mPendingTasksMutex); + mPendingTasks.emplace(waitableTask.Get(), waitableTask); + } + + // Ref the task since it is accessed inside the worker function. + // The worker function will acquire and release the task upon completion. + waitableTask->Reference(); + waitableTask->waitableEvent = + mWorkerTaskPool->PostWorkerTask(DoWaitableTask, waitableTask.Get()); + } + + void AsnycTaskManager::HandleTaskCompletion(WaitableTask* task) { + std::lock_guard lock(mPendingTasksMutex); + auto iter = mPendingTasks.find(task); + if (iter != mPendingTasks.end()) { + mPendingTasks.erase(iter); + } + } + + void AsnycTaskManager::WaitAllPendingTasks() { + std::unordered_map> allPendingTasks; + + { + std::lock_guard lock(mPendingTasksMutex); + allPendingTasks.swap(mPendingTasks); + } + + for (auto& keyValue : allPendingTasks) { + keyValue.second->waitableEvent->Wait(); + } + } + + void AsnycTaskManager::DoWaitableTask(void* task) { + Ref waitableTask = AcquireRef(static_cast(task)); + waitableTask->asyncTask(); + waitableTask->taskManager->HandleTaskCompletion(waitableTask.Get()); + } + +} // namespace dawn_native diff --git a/src/dawn_native/AsyncTask.h b/src/dawn_native/AsyncTask.h new file mode 100644 index 0000000000..8fb8dc2098 --- /dev/null +++ b/src/dawn_native/AsyncTask.h @@ -0,0 +1,64 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DAWNNATIVE_ASYC_TASK_H_ +#define DAWNNATIVE_ASYC_TASK_H_ + +#include +#include +#include +#include + +#include "common/RefCounted.h" + +namespace dawn_platform { + class WaitableEvent; + class WorkerTaskPool; +} // namespace dawn_platform + +namespace dawn_native { + + // TODO(jiawei.shao@intel.com): we'll add additional things to AsyncTask in the future, like + // Cancel() and RunNow(). Cancelling helps avoid running the task's body when we are just + // shutting down the device. RunNow() could be used for more advanced scenarios, for example + // always doing ShaderModule initial compilation asynchronously, but being able to steal the + // task if we need it for synchronous pipeline compilation. + using AsyncTask = std::function; + + class AsnycTaskManager { + public: + explicit AsnycTaskManager(dawn_platform::WorkerTaskPool* workerTaskPool); + + void PostTask(AsyncTask asyncTask); + void WaitAllPendingTasks(); + + private: + class WaitableTask : public RefCounted { + public: + AsyncTask asyncTask; + AsnycTaskManager* taskManager; + std::unique_ptr waitableEvent; + }; + + static void DoWaitableTask(void* task); + void HandleTaskCompletion(WaitableTask* task); + + std::mutex mPendingTasksMutex; + std::unordered_map> mPendingTasks; + dawn_platform::WorkerTaskPool* mWorkerTaskPool; + }; + +} // namespace dawn_native + +#endif diff --git a/src/dawn_native/BUILD.gn b/src/dawn_native/BUILD.gn index 7ea9b167c7..982e80daec 100644 --- a/src/dawn_native/BUILD.gn +++ b/src/dawn_native/BUILD.gn @@ -165,6 +165,8 @@ source_set("dawn_native_sources") { sources += [ "Adapter.cpp", "Adapter.h", + "AsyncTask.cpp", + "AsyncTask.h", "AttachmentState.cpp", "AttachmentState.h", "BackendConnection.cpp", diff --git a/src/dawn_native/CMakeLists.txt b/src/dawn_native/CMakeLists.txt index f03d6fd678..d97bea3106 100644 --- a/src/dawn_native/CMakeLists.txt +++ b/src/dawn_native/CMakeLists.txt @@ -31,6 +31,8 @@ target_sources(dawn_native PRIVATE ${DAWN_NATIVE_UTILS_GEN_SOURCES} "Adapter.cpp" "Adapter.h" + "AsyncTask.cpp" + "AsyncTask.h" "AttachmentState.cpp" "AttachmentState.h" "BackendConnection.cpp" diff --git a/src/dawn_native/CallbackTaskManager.h b/src/dawn_native/CallbackTaskManager.h index 1be0eb22b0..49108ec0d0 100644 --- a/src/dawn_native/CallbackTaskManager.h +++ b/src/dawn_native/CallbackTaskManager.h @@ -21,8 +21,6 @@ namespace dawn_native { - class CallbackTaskManager; - struct CallbackTask { public: virtual ~CallbackTask() = default; diff --git a/src/dawn_platform/WorkerThread.cpp b/src/dawn_platform/WorkerThread.cpp index 7be4c3a63e..025ed1f62c 100644 --- a/src/dawn_platform/WorkerThread.cpp +++ b/src/dawn_platform/WorkerThread.cpp @@ -14,40 +14,84 @@ #include "dawn_platform/WorkerThread.h" -#include +#include +#include +#include #include "common/Assert.h" namespace { - class AsyncWaitableEvent final : public dawn_platform::WaitableEvent { + class AsyncWaitableEventImpl { public: - explicit AsyncWaitableEvent(std::function func) { - mFuture = std::async(std::launch::async, func); + AsyncWaitableEventImpl() : mIsComplete(false) { } - void Wait() override { - ASSERT(mFuture.valid()); - mFuture.wait(); + + void Wait() { + std::unique_lock lock(mMutex); + mCondition.wait(lock, [this] { return mIsComplete; }); } - bool IsComplete() override { - ASSERT(mFuture.valid()); - return mFuture.wait_for(std::chrono::seconds(0)) == std::future_status::ready; + + bool IsComplete() { + std::lock_guard lock(mMutex); + return mIsComplete; + } + + void MarkAsComplete() { + { + std::lock_guard lock(mMutex); + mIsComplete = true; + } + mCondition.notify_all(); } private: - // It is safe not to call Wait() in the destructor of AsyncWaitableEvent because since - // C++14 the destructor of std::future will always be blocked until its state becomes - // std::future_status::ready when it was created by a call of std::async and it is the - // last reference to the shared state. - // See https://en.cppreference.com/w/cpp/thread/future/~future for more details. - std::future mFuture; + std::mutex mMutex; + std::condition_variable mCondition; + bool mIsComplete; + }; + + class AsyncWaitableEvent final : public dawn_platform::WaitableEvent { + public: + explicit AsyncWaitableEvent() + : mWaitableEventImpl(std::make_shared()) { + } + + void Wait() override { + mWaitableEventImpl->Wait(); + } + + bool IsComplete() override { + return mWaitableEventImpl->IsComplete(); + } + + std::shared_ptr GetWaitableEventImpl() const { + return mWaitableEventImpl; + } + + private: + std::shared_ptr mWaitableEventImpl; }; } // anonymous namespace -std::unique_ptr AsyncWorkerThreadPool::PostWorkerTask( - dawn_platform::PostWorkerTaskCallback callback, - void* userdata) { - std::function doTask = [callback, userdata]() { callback(userdata); }; - return std::make_unique(doTask); -} \ No newline at end of file +namespace dawn_platform { + + std::unique_ptr AsyncWorkerThreadPool::PostWorkerTask( + dawn_platform::PostWorkerTaskCallback callback, + void* userdata) { + std::unique_ptr waitableEvent = std::make_unique(); + + std::function doTask = + [callback, userdata, waitableEventImpl = waitableEvent->GetWaitableEventImpl()]() { + callback(userdata); + waitableEventImpl->MarkAsComplete(); + }; + + std::thread thread(doTask); + thread.detach(); + + return waitableEvent; + } + +} // namespace dawn_platform diff --git a/src/dawn_platform/WorkerThread.h b/src/dawn_platform/WorkerThread.h index 56a5d1005d..49f81ad4fe 100644 --- a/src/dawn_platform/WorkerThread.h +++ b/src/dawn_platform/WorkerThread.h @@ -18,11 +18,15 @@ #include "common/NonCopyable.h" #include "dawn_platform/DawnPlatform.h" -class AsyncWorkerThreadPool : public dawn_platform::WorkerTaskPool, public NonCopyable { - public: - std::unique_ptr PostWorkerTask( - dawn_platform::PostWorkerTaskCallback callback, - void* userdata) override; -}; +namespace dawn_platform { + + class AsyncWorkerThreadPool : public dawn_platform::WorkerTaskPool, public NonCopyable { + public: + std::unique_ptr PostWorkerTask( + dawn_platform::PostWorkerTaskCallback callback, + void* userdata) override; + }; + +} // namespace dawn_platform #endif diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn index d4ba8069ca..a0ba923548 100644 --- a/src/tests/BUILD.gn +++ b/src/tests/BUILD.gn @@ -153,6 +153,7 @@ test("dawn_unittests") { "MockCallback.h", "ToggleParser.cpp", "ToggleParser.h", + "unittests/AsyncTaskTests.cpp", "unittests/BitSetIteratorTests.cpp", "unittests/BuddyAllocatorTests.cpp", "unittests/BuddyMemoryAllocatorTests.cpp", @@ -184,7 +185,6 @@ test("dawn_unittests") { "unittests/SystemUtilsTests.cpp", "unittests/ToBackendTests.cpp", "unittests/TypedIntegerTests.cpp", - "unittests/WorkerThreadTests.cpp", "unittests/validation/BindGroupValidationTests.cpp", "unittests/validation/BufferValidationTests.cpp", "unittests/validation/CommandBufferValidationTests.cpp", diff --git a/src/tests/unittests/AsyncTaskTests.cpp b/src/tests/unittests/AsyncTaskTests.cpp new file mode 100644 index 0000000000..5a5bcb1d02 --- /dev/null +++ b/src/tests/unittests/AsyncTaskTests.cpp @@ -0,0 +1,89 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// AsyncTaskTests: +// Simple tests for dawn_native::AsyncTask and dawn_native::AsnycTaskManager. + +#include + +#include +#include + +#include "common/NonCopyable.h" +#include "dawn_native/AsyncTask.h" +#include "dawn_platform/DawnPlatform.h" + +namespace { + + struct SimpleTaskResult { + uint32_t id; + }; + + // A thread-safe queue that stores the task results. + class ConcurrentTaskResultQueue : public NonCopyable { + public: + void AddResult(std::unique_ptr result) { + std::lock_guard lock(mMutex); + mTaskResults.push_back(std::move(result)); + } + + std::vector> GetAllResults() { + std::vector> outputResults; + { + std::lock_guard lock(mMutex); + outputResults.swap(mTaskResults); + } + return outputResults; + } + + private: + std::mutex mMutex; + std::vector> mTaskResults; + }; + + void DoTask(ConcurrentTaskResultQueue* resultQueue, uint32_t id) { + std::unique_ptr result = std::make_unique(); + result->id = id; + resultQueue->AddResult(std::move(result)); + } + +} // anonymous namespace + +class AsyncTaskTest : public testing::Test {}; + +// Emulate the basic usage of worker thread pool in Create*PipelineAsync(). +TEST_F(AsyncTaskTest, Basic) { + dawn_platform::Platform platform; + std::unique_ptr pool = platform.CreateWorkerTaskPool(); + + dawn_native::AsnycTaskManager taskManager(pool.get()); + ConcurrentTaskResultQueue taskResultQueue; + + constexpr size_t kTaskCount = 4u; + std::set idset; + for (uint32_t i = 0; i < kTaskCount; ++i) { + dawn_native::AsyncTask asyncTask([&taskResultQueue, i] { DoTask(&taskResultQueue, i); }); + taskManager.PostTask(std::move(asyncTask)); + idset.insert(i); + } + + taskManager.WaitAllPendingTasks(); + + std::vector> results = taskResultQueue.GetAllResults(); + ASSERT_EQ(kTaskCount, results.size()); + for (std::unique_ptr& result : results) { + idset.erase(result->id); + } + ASSERT_TRUE(idset.empty()); +} diff --git a/src/tests/unittests/WorkerThreadTests.cpp b/src/tests/unittests/WorkerThreadTests.cpp deleted file mode 100644 index 996d7e1eec..0000000000 --- a/src/tests/unittests/WorkerThreadTests.cpp +++ /dev/null @@ -1,169 +0,0 @@ -// Copyright 2021 The Dawn Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// -// WorkerThreadTests: -// Simple tests for the worker thread class. - -#include - -#include -#include -#include -#include - -#include "common/NonCopyable.h" -#include "dawn_platform/DawnPlatform.h" - -namespace { - - struct SimpleTaskResult { - uint32_t id; - bool isDone = false; - }; - - // A thread-safe queue that stores the task results. - class ConcurrentTaskResultQueue : public NonCopyable { - public: - void TaskCompleted(const SimpleTaskResult& result) { - ASSERT_TRUE(result.isDone); - - std::lock_guard lock(mMutex); - mTaskResultQueue.push(result); - } - - std::vector GetAndPopCompletedTasks() { - std::lock_guard lock(mMutex); - - std::vector results; - while (!mTaskResultQueue.empty()) { - results.push_back(mTaskResultQueue.front()); - mTaskResultQueue.pop(); - } - return results; - } - - private: - std::mutex mMutex; - std::queue mTaskResultQueue; - }; - - // A simple task that will be executed asynchronously with pool->PostWorkerTask(). - class SimpleTask : public NonCopyable { - public: - SimpleTask(uint32_t id, ConcurrentTaskResultQueue* resultQueue) - : mId(id), mResultQueue(resultQueue) { - } - - private: - friend class Tracker; - - static void DoTaskOnWorkerTaskPool(void* task) { - SimpleTask* simpleTaskPtr = static_cast(task); - simpleTaskPtr->doTask(); - } - - void doTask() { - SimpleTaskResult result; - result.id = mId; - result.isDone = true; - mResultQueue->TaskCompleted(result); - } - - uint32_t mId; - ConcurrentTaskResultQueue* mResultQueue; - }; - - // A simple implementation of task tracker which is only called in main thread and not - // thread-safe. - class Tracker : public NonCopyable { - public: - explicit Tracker(dawn_platform::WorkerTaskPool* pool) : mPool(pool) { - } - - void StartNewTask(uint32_t taskId) { - mTasksInFlight.emplace_back(this, mPool, taskId); - } - - uint64_t GetTasksInFlightCount() { - return mTasksInFlight.size(); - } - - void WaitAll() { - for (auto iter = mTasksInFlight.begin(); iter != mTasksInFlight.end(); ++iter) { - iter->waitableEvent->Wait(); - } - } - - // In Tick() we clean up all the completed tasks and consume all the available results. - void Tick() { - auto iter = mTasksInFlight.begin(); - while (iter != mTasksInFlight.end()) { - if (iter->waitableEvent->IsComplete()) { - iter = mTasksInFlight.erase(iter); - } else { - ++iter; - } - } - - const std::vector& results = - mCompletedTaskResultQueue.GetAndPopCompletedTasks(); - for (const SimpleTaskResult& result : results) { - EXPECT_TRUE(result.isDone); - } - } - - private: - SimpleTask* CreateSimpleTask(uint32_t taskId) { - return new SimpleTask(taskId, &mCompletedTaskResultQueue); - } - - struct WaitableTask { - WaitableTask(Tracker* tracker, dawn_platform::WorkerTaskPool* pool, uint32_t taskId) { - task.reset(tracker->CreateSimpleTask(taskId)); - waitableEvent = - pool->PostWorkerTask(SimpleTask::DoTaskOnWorkerTaskPool, task.get()); - } - - std::unique_ptr task; - std::unique_ptr waitableEvent; - }; - - dawn_platform::WorkerTaskPool* mPool; - - std::list mTasksInFlight; - ConcurrentTaskResultQueue mCompletedTaskResultQueue; - }; - -} // anonymous namespace - -class WorkerThreadTest : public testing::Test {}; - -// Emulate the basic usage of worker thread pool in Create*PipelineAsync(). -TEST_F(WorkerThreadTest, Basic) { - dawn_platform::Platform platform; - std::unique_ptr pool = platform.CreateWorkerTaskPool(); - Tracker tracker(pool.get()); - - constexpr uint32_t kTaskCount = 4; - for (uint32_t i = 0; i < kTaskCount; ++i) { - tracker.StartNewTask(i); - } - EXPECT_EQ(kTaskCount, tracker.GetTasksInFlightCount()); - - // Wait for the completion of all the tasks. - tracker.WaitAll(); - - tracker.Tick(); - EXPECT_EQ(0u, tracker.GetTasksInFlightCount()); -}