diff --git a/src/common/BUILD.gn b/src/common/BUILD.gn index 8db531f6d0..761ecd0c6f 100644 --- a/src/common/BUILD.gn +++ b/src/common/BUILD.gn @@ -171,6 +171,7 @@ if (is_win || is_linux || is_chromeos || is_mac || is_fuchsia || is_android) { "Math.cpp", "Math.h", "NSRef.h", + "NonCopyable.h", "PlacementAllocated.h", "Platform.h", "RefBase.h", diff --git a/src/common/CMakeLists.txt b/src/common/CMakeLists.txt index 3b28bba492..46c1737585 100644 --- a/src/common/CMakeLists.txt +++ b/src/common/CMakeLists.txt @@ -33,6 +33,7 @@ target_sources(dawn_common PRIVATE "Math.cpp" "Math.h" "NSRef.h" + "NonCopyable.h" "PlacementAllocated.h" "Platform.h" "RefBase.h" diff --git a/src/common/NonCopyable.h b/src/common/NonCopyable.h new file mode 100644 index 0000000000..e711f7133a --- /dev/null +++ b/src/common/NonCopyable.h @@ -0,0 +1,32 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef COMMON_NONCOPYABLE_H_ +#define COMMON_NONCOPYABLE_H_ + +// NonCopyable: +// the base class for the classes that are not copyable. +// + +class NonCopyable { + protected: + constexpr NonCopyable() = default; + ~NonCopyable() = default; + + private: + NonCopyable(const NonCopyable&) = delete; + void operator=(const NonCopyable&) = delete; +}; + +#endif diff --git a/src/dawn_platform/BUILD.gn b/src/dawn_platform/BUILD.gn index 91c9e75a14..cbd323814e 100644 --- a/src/dawn_platform/BUILD.gn +++ b/src/dawn_platform/BUILD.gn @@ -25,6 +25,8 @@ dawn_component("dawn_platform") { "${dawn_root}/src/include/dawn_platform/DawnPlatform.h", "${dawn_root}/src/include/dawn_platform/dawn_platform_export.h", "DawnPlatform.cpp", + "WorkerThread.cpp", + "WorkerThread.h", "tracing/EventTracer.cpp", "tracing/EventTracer.h", "tracing/TraceEvent.h", diff --git a/src/dawn_platform/CMakeLists.txt b/src/dawn_platform/CMakeLists.txt index b8075e29e1..92372bbe71 100644 --- a/src/dawn_platform/CMakeLists.txt +++ b/src/dawn_platform/CMakeLists.txt @@ -23,6 +23,8 @@ target_sources(dawn_platform PRIVATE "${DAWN_INCLUDE_DIR}/dawn_platform/DawnPlatform.h" "${DAWN_INCLUDE_DIR}/dawn_platform/dawn_platform_export.h" "DawnPlatform.cpp" + "WorkerThread.cpp" + "WorkerThread.h" "tracing/EventTracer.cpp" "tracing/EventTracer.h" "tracing/TraceEvent.h" diff --git a/src/dawn_platform/DawnPlatform.cpp b/src/dawn_platform/DawnPlatform.cpp index b772bacc49..1bedbcb141 100644 --- a/src/dawn_platform/DawnPlatform.cpp +++ b/src/dawn_platform/DawnPlatform.cpp @@ -13,6 +13,7 @@ // limitations under the License. #include "dawn_platform/DawnPlatform.h" +#include "dawn_platform/WorkerThread.h" #include "common/Assert.h" @@ -55,4 +56,8 @@ namespace dawn_platform { return nullptr; } + std::unique_ptr Platform::CreateWorkerTaskPool() { + return std::make_unique(); + } + } // namespace dawn_platform diff --git a/src/dawn_platform/WorkerThread.cpp b/src/dawn_platform/WorkerThread.cpp new file mode 100644 index 0000000000..64d09f153a --- /dev/null +++ b/src/dawn_platform/WorkerThread.cpp @@ -0,0 +1,51 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dawn_platform/WorkerThread.h" + +#include + +#include "common/Assert.h" + +namespace { + + class AsyncWaitableEvent final : public dawn_platform::WaitableEvent { + public: + explicit AsyncWaitableEvent(std::function func) { + mFuture = std::async(std::launch::async, func); + } + virtual ~AsyncWaitableEvent() override { + ASSERT(IsComplete()); + } + void Wait() override { + ASSERT(mFuture.valid()); + mFuture.wait(); + } + bool IsComplete() override { + ASSERT(mFuture.valid()); + return mFuture.wait_for(std::chrono::seconds(0)) == std::future_status::ready; + } + + private: + std::future mFuture; + }; + +} // anonymous namespace + +std::unique_ptr AsyncWorkerThreadPool::PostWorkerTask( + dawn_platform::PostWorkerTaskCallback callback, + void* userdata) { + std::function doTask = [callback, userdata]() { callback(userdata); }; + return std::make_unique(doTask); +} \ No newline at end of file diff --git a/src/dawn_platform/WorkerThread.h b/src/dawn_platform/WorkerThread.h new file mode 100644 index 0000000000..56a5d1005d --- /dev/null +++ b/src/dawn_platform/WorkerThread.h @@ -0,0 +1,28 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef COMMON_WORKERTHREAD_H_ +#define COMMON_WORKERTHREAD_H_ + +#include "common/NonCopyable.h" +#include "dawn_platform/DawnPlatform.h" + +class AsyncWorkerThreadPool : public dawn_platform::WorkerTaskPool, public NonCopyable { + public: + std::unique_ptr PostWorkerTask( + dawn_platform::PostWorkerTaskCallback callback, + void* userdata) override; +}; + +#endif diff --git a/src/include/dawn_platform/DawnPlatform.h b/src/include/dawn_platform/DawnPlatform.h index 4a00f5319e..3a28419334 100644 --- a/src/include/dawn_platform/DawnPlatform.h +++ b/src/include/dawn_platform/DawnPlatform.h @@ -19,6 +19,7 @@ #include #include +#include #include @@ -60,6 +61,24 @@ namespace dawn_platform { CachingInterface& operator=(const CachingInterface&) = delete; }; + class DAWN_PLATFORM_EXPORT WaitableEvent { + public: + WaitableEvent() = default; + virtual ~WaitableEvent() = default; + virtual void Wait() = 0; // Wait for completion + virtual bool IsComplete() = 0; // Non-blocking check if the event is complete + }; + + using PostWorkerTaskCallback = void (*)(void* userdata); + + class DAWN_PLATFORM_EXPORT WorkerTaskPool { + public: + WorkerTaskPool() = default; + virtual ~WorkerTaskPool() = default; + virtual std::unique_ptr PostWorkerTask(PostWorkerTaskCallback, + void* userdata) = 0; + }; + class DAWN_PLATFORM_EXPORT Platform { public: Platform(); @@ -85,6 +104,7 @@ namespace dawn_platform { // device which uses it to persistently cache objects. virtual CachingInterface* GetCachingInterface(const void* fingerprint, size_t fingerprintSize); + virtual std::unique_ptr CreateWorkerTaskPool(); private: Platform(const Platform&) = delete; diff --git a/src/tests/BUILD.gn b/src/tests/BUILD.gn index 3971b09282..daabd3b2eb 100644 --- a/src/tests/BUILD.gn +++ b/src/tests/BUILD.gn @@ -181,6 +181,7 @@ test("dawn_unittests") { "unittests/SystemUtilsTests.cpp", "unittests/ToBackendTests.cpp", "unittests/TypedIntegerTests.cpp", + "unittests/WorkerThreadTests.cpp", "unittests/validation/BindGroupValidationTests.cpp", "unittests/validation/BufferValidationTests.cpp", "unittests/validation/CommandBufferValidationTests.cpp", diff --git a/src/tests/unittests/WorkerThreadTests.cpp b/src/tests/unittests/WorkerThreadTests.cpp new file mode 100644 index 0000000000..8faee5d064 --- /dev/null +++ b/src/tests/unittests/WorkerThreadTests.cpp @@ -0,0 +1,169 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// WorkerThreadTests: +// Simple tests for the worker thread class. + +#include + +#include +#include +#include +#include + +#include "common/NonCopyable.h" +#include "dawn_platform/DawnPlatform.h" + +namespace { + + struct SimpleTaskResult { + uint32_t id; + bool isDone = false; + }; + + // A thread-safe queue that stores the task results. + class ConcurrentTaskResultQueue : public NonCopyable { + public: + void TaskCompleted(const SimpleTaskResult& result) { + ASSERT_TRUE(result.isDone); + + std::lock_guard lock(mMutex); + mTaskResultQueue.push(result); + } + + std::vector GetAndPopCompletedTasks() { + std::lock_guard lock(mMutex); + + std::vector results; + while (!mTaskResultQueue.empty()) { + results.push_back(mTaskResultQueue.front()); + mTaskResultQueue.pop(); + } + return results; + } + + private: + std::mutex mMutex; + std::queue mTaskResultQueue; + }; + + // A simple task that will be executed asynchronously with pool->PostWorkerTask(). + class SimpleTask : public NonCopyable { + public: + SimpleTask(uint32_t id, ConcurrentTaskResultQueue* resultQueue) + : mId(id), mResultQueue(resultQueue) { + } + + private: + friend class Tracker; + + static void DoTaskOnWorkerTaskPool(void* task) { + SimpleTask* simpleTaskPtr = static_cast(task); + simpleTaskPtr->doTask(); + } + + void doTask() { + SimpleTaskResult result; + result.id = mId; + result.isDone = true; + mResultQueue->TaskCompleted(result); + } + + uint32_t mId; + ConcurrentTaskResultQueue* mResultQueue; + }; + + // A simple implementation of task tracker which is only called in main thread and not + // thread-safe. + class Tracker : public NonCopyable { + public: + explicit Tracker(dawn_platform::WorkerTaskPool* pool) : mPool(pool) { + } + + void StartNewTask(uint32_t taskId) { + mTasksInFlight.emplace_back(this, mPool, taskId); + } + + uint64_t GetTasksInFlightCount() { + return mTasksInFlight.size(); + } + + void WaitAll() { + for (auto iter = mTasksInFlight.begin(); iter != mTasksInFlight.end(); ++iter) { + iter->waitableEvent->Wait(); + } + } + + // In Tick() we clean up all the completed tasks and consume all the available results. + void Tick() { + auto iter = mTasksInFlight.begin(); + while (iter != mTasksInFlight.end()) { + if (iter->waitableEvent->IsComplete()) { + iter = mTasksInFlight.erase(iter); + } else { + ++iter; + } + } + + const std::vector& results = + mCompletedTaskResultQueue.GetAndPopCompletedTasks(); + for (const SimpleTaskResult& result : results) { + EXPECT_TRUE(result.isDone); + } + } + + private: + SimpleTask* CreateSimpleTask(uint32_t taskId) { + return new SimpleTask(taskId, &mCompletedTaskResultQueue); + } + + struct WaitableTask { + WaitableTask(Tracker* tracker, dawn_platform::WorkerTaskPool* pool, uint32_t taskId) { + task.reset(tracker->CreateSimpleTask(taskId)); + waitableEvent = + pool->PostWorkerTask(SimpleTask::DoTaskOnWorkerTaskPool, task.get()); + } + + std::unique_ptr task; + std::unique_ptr waitableEvent; + }; + + dawn_platform::WorkerTaskPool* mPool; + + std::list mTasksInFlight; + ConcurrentTaskResultQueue mCompletedTaskResultQueue; + }; + +} // anonymous namespace + +class WorkerThreadTest : public testing::Test {}; + +// Emulate the basic usage of worker thread pool in CreateReady*Pipeline(). +TEST_F(WorkerThreadTest, Basic) { + dawn_platform::Platform platform; + std::unique_ptr pool = platform.CreateWorkerTaskPool(); + Tracker tracker(pool.get()); + + constexpr uint32_t kTaskCount = 4; + for (uint32_t i = 0; i < kTaskCount; ++i) { + tracker.StartNewTask(i); + } + EXPECT_EQ(kTaskCount, tracker.GetTasksInFlightCount()); + + // Wait for the completion of all the tasks. + tracker.WaitAll(); + + tracker.Tick(); + EXPECT_EQ(0u, tracker.GetTasksInFlightCount()); +}