mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-08-03 02:35:55 +00:00
Implement WaitableEvent and WorkerTaskPool for multi-threaded tasks
This patch adds the basic implementation of WaitableEvent and WorkerTaskPool for multi-threaded tasks in Dawn (for example, the multi-threaded implementation of CreateReady*Pipeline()). BUG=dawn:529 TEST=dawn_unittests Change-Id: Ibf84348f4c0f0d26badc19ae94cd536cef89d084 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/36360 Reviewed-by: Corentin Wallez <cwallez@chromium.org> Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
This commit is contained in:
parent
311a17a8fe
commit
064f33e441
@ -171,6 +171,7 @@ if (is_win || is_linux || is_chromeos || is_mac || is_fuchsia || is_android) {
|
|||||||
"Math.cpp",
|
"Math.cpp",
|
||||||
"Math.h",
|
"Math.h",
|
||||||
"NSRef.h",
|
"NSRef.h",
|
||||||
|
"NonCopyable.h",
|
||||||
"PlacementAllocated.h",
|
"PlacementAllocated.h",
|
||||||
"Platform.h",
|
"Platform.h",
|
||||||
"RefBase.h",
|
"RefBase.h",
|
||||||
|
@ -33,6 +33,7 @@ target_sources(dawn_common PRIVATE
|
|||||||
"Math.cpp"
|
"Math.cpp"
|
||||||
"Math.h"
|
"Math.h"
|
||||||
"NSRef.h"
|
"NSRef.h"
|
||||||
|
"NonCopyable.h"
|
||||||
"PlacementAllocated.h"
|
"PlacementAllocated.h"
|
||||||
"Platform.h"
|
"Platform.h"
|
||||||
"RefBase.h"
|
"RefBase.h"
|
||||||
|
32
src/common/NonCopyable.h
Normal file
32
src/common/NonCopyable.h
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
// Copyright 2021 The Dawn Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef COMMON_NONCOPYABLE_H_
|
||||||
|
#define COMMON_NONCOPYABLE_H_
|
||||||
|
|
||||||
|
// NonCopyable:
|
||||||
|
// the base class for the classes that are not copyable.
|
||||||
|
//
|
||||||
|
|
||||||
|
class NonCopyable {
|
||||||
|
protected:
|
||||||
|
constexpr NonCopyable() = default;
|
||||||
|
~NonCopyable() = default;
|
||||||
|
|
||||||
|
private:
|
||||||
|
NonCopyable(const NonCopyable&) = delete;
|
||||||
|
void operator=(const NonCopyable&) = delete;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
@ -25,6 +25,8 @@ dawn_component("dawn_platform") {
|
|||||||
"${dawn_root}/src/include/dawn_platform/DawnPlatform.h",
|
"${dawn_root}/src/include/dawn_platform/DawnPlatform.h",
|
||||||
"${dawn_root}/src/include/dawn_platform/dawn_platform_export.h",
|
"${dawn_root}/src/include/dawn_platform/dawn_platform_export.h",
|
||||||
"DawnPlatform.cpp",
|
"DawnPlatform.cpp",
|
||||||
|
"WorkerThread.cpp",
|
||||||
|
"WorkerThread.h",
|
||||||
"tracing/EventTracer.cpp",
|
"tracing/EventTracer.cpp",
|
||||||
"tracing/EventTracer.h",
|
"tracing/EventTracer.h",
|
||||||
"tracing/TraceEvent.h",
|
"tracing/TraceEvent.h",
|
||||||
|
@ -23,6 +23,8 @@ target_sources(dawn_platform PRIVATE
|
|||||||
"${DAWN_INCLUDE_DIR}/dawn_platform/DawnPlatform.h"
|
"${DAWN_INCLUDE_DIR}/dawn_platform/DawnPlatform.h"
|
||||||
"${DAWN_INCLUDE_DIR}/dawn_platform/dawn_platform_export.h"
|
"${DAWN_INCLUDE_DIR}/dawn_platform/dawn_platform_export.h"
|
||||||
"DawnPlatform.cpp"
|
"DawnPlatform.cpp"
|
||||||
|
"WorkerThread.cpp"
|
||||||
|
"WorkerThread.h"
|
||||||
"tracing/EventTracer.cpp"
|
"tracing/EventTracer.cpp"
|
||||||
"tracing/EventTracer.h"
|
"tracing/EventTracer.h"
|
||||||
"tracing/TraceEvent.h"
|
"tracing/TraceEvent.h"
|
||||||
|
@ -13,6 +13,7 @@
|
|||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include "dawn_platform/DawnPlatform.h"
|
#include "dawn_platform/DawnPlatform.h"
|
||||||
|
#include "dawn_platform/WorkerThread.h"
|
||||||
|
|
||||||
#include "common/Assert.h"
|
#include "common/Assert.h"
|
||||||
|
|
||||||
@ -55,4 +56,8 @@ namespace dawn_platform {
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<dawn_platform::WorkerTaskPool> Platform::CreateWorkerTaskPool() {
|
||||||
|
return std::make_unique<AsyncWorkerThreadPool>();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace dawn_platform
|
} // namespace dawn_platform
|
||||||
|
51
src/dawn_platform/WorkerThread.cpp
Normal file
51
src/dawn_platform/WorkerThread.cpp
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
// Copyright 2021 The Dawn Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "dawn_platform/WorkerThread.h"
|
||||||
|
|
||||||
|
#include <future>
|
||||||
|
|
||||||
|
#include "common/Assert.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
class AsyncWaitableEvent final : public dawn_platform::WaitableEvent {
|
||||||
|
public:
|
||||||
|
explicit AsyncWaitableEvent(std::function<void()> func) {
|
||||||
|
mFuture = std::async(std::launch::async, func);
|
||||||
|
}
|
||||||
|
virtual ~AsyncWaitableEvent() override {
|
||||||
|
ASSERT(IsComplete());
|
||||||
|
}
|
||||||
|
void Wait() override {
|
||||||
|
ASSERT(mFuture.valid());
|
||||||
|
mFuture.wait();
|
||||||
|
}
|
||||||
|
bool IsComplete() override {
|
||||||
|
ASSERT(mFuture.valid());
|
||||||
|
return mFuture.wait_for(std::chrono::seconds(0)) == std::future_status::ready;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::future<void> mFuture;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
std::unique_ptr<dawn_platform::WaitableEvent> AsyncWorkerThreadPool::PostWorkerTask(
|
||||||
|
dawn_platform::PostWorkerTaskCallback callback,
|
||||||
|
void* userdata) {
|
||||||
|
std::function<void()> doTask = [callback, userdata]() { callback(userdata); };
|
||||||
|
return std::make_unique<AsyncWaitableEvent>(doTask);
|
||||||
|
}
|
28
src/dawn_platform/WorkerThread.h
Normal file
28
src/dawn_platform/WorkerThread.h
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
// Copyright 2021 The Dawn Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef COMMON_WORKERTHREAD_H_
|
||||||
|
#define COMMON_WORKERTHREAD_H_
|
||||||
|
|
||||||
|
#include "common/NonCopyable.h"
|
||||||
|
#include "dawn_platform/DawnPlatform.h"
|
||||||
|
|
||||||
|
class AsyncWorkerThreadPool : public dawn_platform::WorkerTaskPool, public NonCopyable {
|
||||||
|
public:
|
||||||
|
std::unique_ptr<dawn_platform::WaitableEvent> PostWorkerTask(
|
||||||
|
dawn_platform::PostWorkerTaskCallback callback,
|
||||||
|
void* userdata) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
#endif
|
@ -19,6 +19,7 @@
|
|||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <cstdint>
|
#include <cstdint>
|
||||||
|
#include <memory>
|
||||||
|
|
||||||
#include <dawn/webgpu.h>
|
#include <dawn/webgpu.h>
|
||||||
|
|
||||||
@ -60,6 +61,24 @@ namespace dawn_platform {
|
|||||||
CachingInterface& operator=(const CachingInterface&) = delete;
|
CachingInterface& operator=(const CachingInterface&) = delete;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
class DAWN_PLATFORM_EXPORT WaitableEvent {
|
||||||
|
public:
|
||||||
|
WaitableEvent() = default;
|
||||||
|
virtual ~WaitableEvent() = default;
|
||||||
|
virtual void Wait() = 0; // Wait for completion
|
||||||
|
virtual bool IsComplete() = 0; // Non-blocking check if the event is complete
|
||||||
|
};
|
||||||
|
|
||||||
|
using PostWorkerTaskCallback = void (*)(void* userdata);
|
||||||
|
|
||||||
|
class DAWN_PLATFORM_EXPORT WorkerTaskPool {
|
||||||
|
public:
|
||||||
|
WorkerTaskPool() = default;
|
||||||
|
virtual ~WorkerTaskPool() = default;
|
||||||
|
virtual std::unique_ptr<WaitableEvent> PostWorkerTask(PostWorkerTaskCallback,
|
||||||
|
void* userdata) = 0;
|
||||||
|
};
|
||||||
|
|
||||||
class DAWN_PLATFORM_EXPORT Platform {
|
class DAWN_PLATFORM_EXPORT Platform {
|
||||||
public:
|
public:
|
||||||
Platform();
|
Platform();
|
||||||
@ -85,6 +104,7 @@ namespace dawn_platform {
|
|||||||
// device which uses it to persistently cache objects.
|
// device which uses it to persistently cache objects.
|
||||||
virtual CachingInterface* GetCachingInterface(const void* fingerprint,
|
virtual CachingInterface* GetCachingInterface(const void* fingerprint,
|
||||||
size_t fingerprintSize);
|
size_t fingerprintSize);
|
||||||
|
virtual std::unique_ptr<WorkerTaskPool> CreateWorkerTaskPool();
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Platform(const Platform&) = delete;
|
Platform(const Platform&) = delete;
|
||||||
|
@ -181,6 +181,7 @@ test("dawn_unittests") {
|
|||||||
"unittests/SystemUtilsTests.cpp",
|
"unittests/SystemUtilsTests.cpp",
|
||||||
"unittests/ToBackendTests.cpp",
|
"unittests/ToBackendTests.cpp",
|
||||||
"unittests/TypedIntegerTests.cpp",
|
"unittests/TypedIntegerTests.cpp",
|
||||||
|
"unittests/WorkerThreadTests.cpp",
|
||||||
"unittests/validation/BindGroupValidationTests.cpp",
|
"unittests/validation/BindGroupValidationTests.cpp",
|
||||||
"unittests/validation/BufferValidationTests.cpp",
|
"unittests/validation/BufferValidationTests.cpp",
|
||||||
"unittests/validation/CommandBufferValidationTests.cpp",
|
"unittests/validation/CommandBufferValidationTests.cpp",
|
||||||
|
169
src/tests/unittests/WorkerThreadTests.cpp
Normal file
169
src/tests/unittests/WorkerThreadTests.cpp
Normal file
@ -0,0 +1,169 @@
|
|||||||
|
// Copyright 2021 The Dawn Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
// WorkerThreadTests:
|
||||||
|
// Simple tests for the worker thread class.
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <list>
|
||||||
|
#include <memory>
|
||||||
|
#include <mutex>
|
||||||
|
#include <queue>
|
||||||
|
|
||||||
|
#include "common/NonCopyable.h"
|
||||||
|
#include "dawn_platform/DawnPlatform.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct SimpleTaskResult {
|
||||||
|
uint32_t id;
|
||||||
|
bool isDone = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A thread-safe queue that stores the task results.
|
||||||
|
class ConcurrentTaskResultQueue : public NonCopyable {
|
||||||
|
public:
|
||||||
|
void TaskCompleted(const SimpleTaskResult& result) {
|
||||||
|
ASSERT_TRUE(result.isDone);
|
||||||
|
|
||||||
|
std::lock_guard<std::mutex> lock(mMutex);
|
||||||
|
mTaskResultQueue.push(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<SimpleTaskResult> GetAndPopCompletedTasks() {
|
||||||
|
std::lock_guard<std::mutex> lock(mMutex);
|
||||||
|
|
||||||
|
std::vector<SimpleTaskResult> results;
|
||||||
|
while (!mTaskResultQueue.empty()) {
|
||||||
|
results.push_back(mTaskResultQueue.front());
|
||||||
|
mTaskResultQueue.pop();
|
||||||
|
}
|
||||||
|
return results;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::mutex mMutex;
|
||||||
|
std::queue<SimpleTaskResult> mTaskResultQueue;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A simple task that will be executed asynchronously with pool->PostWorkerTask().
|
||||||
|
class SimpleTask : public NonCopyable {
|
||||||
|
public:
|
||||||
|
SimpleTask(uint32_t id, ConcurrentTaskResultQueue* resultQueue)
|
||||||
|
: mId(id), mResultQueue(resultQueue) {
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
friend class Tracker;
|
||||||
|
|
||||||
|
static void DoTaskOnWorkerTaskPool(void* task) {
|
||||||
|
SimpleTask* simpleTaskPtr = static_cast<SimpleTask*>(task);
|
||||||
|
simpleTaskPtr->doTask();
|
||||||
|
}
|
||||||
|
|
||||||
|
void doTask() {
|
||||||
|
SimpleTaskResult result;
|
||||||
|
result.id = mId;
|
||||||
|
result.isDone = true;
|
||||||
|
mResultQueue->TaskCompleted(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint32_t mId;
|
||||||
|
ConcurrentTaskResultQueue* mResultQueue;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A simple implementation of task tracker which is only called in main thread and not
|
||||||
|
// thread-safe.
|
||||||
|
class Tracker : public NonCopyable {
|
||||||
|
public:
|
||||||
|
explicit Tracker(dawn_platform::WorkerTaskPool* pool) : mPool(pool) {
|
||||||
|
}
|
||||||
|
|
||||||
|
void StartNewTask(uint32_t taskId) {
|
||||||
|
mTasksInFlight.emplace_back(this, mPool, taskId);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint64_t GetTasksInFlightCount() {
|
||||||
|
return mTasksInFlight.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
void WaitAll() {
|
||||||
|
for (auto iter = mTasksInFlight.begin(); iter != mTasksInFlight.end(); ++iter) {
|
||||||
|
iter->waitableEvent->Wait();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// In Tick() we clean up all the completed tasks and consume all the available results.
|
||||||
|
void Tick() {
|
||||||
|
auto iter = mTasksInFlight.begin();
|
||||||
|
while (iter != mTasksInFlight.end()) {
|
||||||
|
if (iter->waitableEvent->IsComplete()) {
|
||||||
|
iter = mTasksInFlight.erase(iter);
|
||||||
|
} else {
|
||||||
|
++iter;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const std::vector<SimpleTaskResult>& results =
|
||||||
|
mCompletedTaskResultQueue.GetAndPopCompletedTasks();
|
||||||
|
for (const SimpleTaskResult& result : results) {
|
||||||
|
EXPECT_TRUE(result.isDone);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
SimpleTask* CreateSimpleTask(uint32_t taskId) {
|
||||||
|
return new SimpleTask(taskId, &mCompletedTaskResultQueue);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct WaitableTask {
|
||||||
|
WaitableTask(Tracker* tracker, dawn_platform::WorkerTaskPool* pool, uint32_t taskId) {
|
||||||
|
task.reset(tracker->CreateSimpleTask(taskId));
|
||||||
|
waitableEvent =
|
||||||
|
pool->PostWorkerTask(SimpleTask::DoTaskOnWorkerTaskPool, task.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_ptr<SimpleTask> task;
|
||||||
|
std::unique_ptr<dawn_platform::WaitableEvent> waitableEvent;
|
||||||
|
};
|
||||||
|
|
||||||
|
dawn_platform::WorkerTaskPool* mPool;
|
||||||
|
|
||||||
|
std::list<WaitableTask> mTasksInFlight;
|
||||||
|
ConcurrentTaskResultQueue mCompletedTaskResultQueue;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
class WorkerThreadTest : public testing::Test {};
|
||||||
|
|
||||||
|
// Emulate the basic usage of worker thread pool in CreateReady*Pipeline().
|
||||||
|
TEST_F(WorkerThreadTest, Basic) {
|
||||||
|
dawn_platform::Platform platform;
|
||||||
|
std::unique_ptr<dawn_platform::WorkerTaskPool> pool = platform.CreateWorkerTaskPool();
|
||||||
|
Tracker tracker(pool.get());
|
||||||
|
|
||||||
|
constexpr uint32_t kTaskCount = 4;
|
||||||
|
for (uint32_t i = 0; i < kTaskCount; ++i) {
|
||||||
|
tracker.StartNewTask(i);
|
||||||
|
}
|
||||||
|
EXPECT_EQ(kTaskCount, tracker.GetTasksInFlightCount());
|
||||||
|
|
||||||
|
// Wait for the completion of all the tasks.
|
||||||
|
tracker.WaitAll();
|
||||||
|
|
||||||
|
tracker.Tick();
|
||||||
|
EXPECT_EQ(0u, tracker.GetTasksInFlightCount());
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user