mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-07-12 08:05:53 +00:00
Implement AsyncWaitableEvent with std::condition_variable
This patch implements the default implementation of WaitableEvent (AsyncWaitableEvent) with std::condition_variable instead of std::future as std::future will always block its destructor until the async function returns, which makes us unable to clean up all the execution environment of the async task inside the async function. This patch also implements WorkerThreadTaskManager to manage all the async tasks (inherited from WorkerThreadTask) in the future, for example all the Create*PipelineAsync() tasks. This patch also updates the related dawn_unittest WorkerThreadTest. Basic. BUG=dawn:529 TEST=dawn_unittests Change-Id: Ie789ba788789e91128ffc416e7e768923828a367 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/51740 Reviewed-by: Corentin Wallez <cwallez@chromium.org> Reviewed-by: Austin Eng <enga@chromium.org> Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
This commit is contained in:
parent
462a3ba917
commit
5d39860fef
60
src/dawn_native/AsyncTask.cpp
Normal file
60
src/dawn_native/AsyncTask.cpp
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
#include "dawn_native/AsyncTask.h"
|
||||||
|
|
||||||
|
#include "dawn_platform/DawnPlatform.h"
|
||||||
|
|
||||||
|
namespace dawn_native {
|
||||||
|
|
||||||
|
AsnycTaskManager::AsnycTaskManager(dawn_platform::WorkerTaskPool* workerTaskPool)
|
||||||
|
: mWorkerTaskPool(workerTaskPool) {
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsnycTaskManager::PostTask(AsyncTask asyncTask) {
|
||||||
|
// If these allocations becomes expensive, we can slab-allocate tasks.
|
||||||
|
Ref<WaitableTask> waitableTask = AcquireRef(new WaitableTask());
|
||||||
|
waitableTask->taskManager = this;
|
||||||
|
waitableTask->asyncTask = std::move(asyncTask);
|
||||||
|
|
||||||
|
{
|
||||||
|
// We insert new waitableTask objects into mPendingTasks in main thread (PostTask()),
|
||||||
|
// and we may remove waitableTask objects from mPendingTasks in either main thread
|
||||||
|
// (WaitAllPendingTasks()) or sub-thread (TaskCompleted), so mPendingTasks should be
|
||||||
|
// protected by a mutex.
|
||||||
|
std::lock_guard<std::mutex> lock(mPendingTasksMutex);
|
||||||
|
mPendingTasks.emplace(waitableTask.Get(), waitableTask);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ref the task since it is accessed inside the worker function.
|
||||||
|
// The worker function will acquire and release the task upon completion.
|
||||||
|
waitableTask->Reference();
|
||||||
|
waitableTask->waitableEvent =
|
||||||
|
mWorkerTaskPool->PostWorkerTask(DoWaitableTask, waitableTask.Get());
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsnycTaskManager::HandleTaskCompletion(WaitableTask* task) {
|
||||||
|
std::lock_guard<std::mutex> lock(mPendingTasksMutex);
|
||||||
|
auto iter = mPendingTasks.find(task);
|
||||||
|
if (iter != mPendingTasks.end()) {
|
||||||
|
mPendingTasks.erase(iter);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsnycTaskManager::WaitAllPendingTasks() {
|
||||||
|
std::unordered_map<WaitableTask*, Ref<WaitableTask>> allPendingTasks;
|
||||||
|
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(mPendingTasksMutex);
|
||||||
|
allPendingTasks.swap(mPendingTasks);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (auto& keyValue : allPendingTasks) {
|
||||||
|
keyValue.second->waitableEvent->Wait();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AsnycTaskManager::DoWaitableTask(void* task) {
|
||||||
|
Ref<WaitableTask> waitableTask = AcquireRef(static_cast<WaitableTask*>(task));
|
||||||
|
waitableTask->asyncTask();
|
||||||
|
waitableTask->taskManager->HandleTaskCompletion(waitableTask.Get());
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace dawn_native
|
64
src/dawn_native/AsyncTask.h
Normal file
64
src/dawn_native/AsyncTask.h
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
// Copyright 2021 The Dawn Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef DAWNNATIVE_ASYC_TASK_H_
|
||||||
|
#define DAWNNATIVE_ASYC_TASK_H_
|
||||||
|
|
||||||
|
#include <functional>
|
||||||
|
#include <memory>
|
||||||
|
#include <mutex>
|
||||||
|
#include <unordered_map>
|
||||||
|
|
||||||
|
#include "common/RefCounted.h"
|
||||||
|
|
||||||
|
namespace dawn_platform {
|
||||||
|
class WaitableEvent;
|
||||||
|
class WorkerTaskPool;
|
||||||
|
} // namespace dawn_platform
|
||||||
|
|
||||||
|
namespace dawn_native {
|
||||||
|
|
||||||
|
// TODO(jiawei.shao@intel.com): we'll add additional things to AsyncTask in the future, like
|
||||||
|
// Cancel() and RunNow(). Cancelling helps avoid running the task's body when we are just
|
||||||
|
// shutting down the device. RunNow() could be used for more advanced scenarios, for example
|
||||||
|
// always doing ShaderModule initial compilation asynchronously, but being able to steal the
|
||||||
|
// task if we need it for synchronous pipeline compilation.
|
||||||
|
using AsyncTask = std::function<void()>;
|
||||||
|
|
||||||
|
class AsnycTaskManager {
|
||||||
|
public:
|
||||||
|
explicit AsnycTaskManager(dawn_platform::WorkerTaskPool* workerTaskPool);
|
||||||
|
|
||||||
|
void PostTask(AsyncTask asyncTask);
|
||||||
|
void WaitAllPendingTasks();
|
||||||
|
|
||||||
|
private:
|
||||||
|
class WaitableTask : public RefCounted {
|
||||||
|
public:
|
||||||
|
AsyncTask asyncTask;
|
||||||
|
AsnycTaskManager* taskManager;
|
||||||
|
std::unique_ptr<dawn_platform::WaitableEvent> waitableEvent;
|
||||||
|
};
|
||||||
|
|
||||||
|
static void DoWaitableTask(void* task);
|
||||||
|
void HandleTaskCompletion(WaitableTask* task);
|
||||||
|
|
||||||
|
std::mutex mPendingTasksMutex;
|
||||||
|
std::unordered_map<WaitableTask*, Ref<WaitableTask>> mPendingTasks;
|
||||||
|
dawn_platform::WorkerTaskPool* mWorkerTaskPool;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace dawn_native
|
||||||
|
|
||||||
|
#endif
|
@ -165,6 +165,8 @@ source_set("dawn_native_sources") {
|
|||||||
sources += [
|
sources += [
|
||||||
"Adapter.cpp",
|
"Adapter.cpp",
|
||||||
"Adapter.h",
|
"Adapter.h",
|
||||||
|
"AsyncTask.cpp",
|
||||||
|
"AsyncTask.h",
|
||||||
"AttachmentState.cpp",
|
"AttachmentState.cpp",
|
||||||
"AttachmentState.h",
|
"AttachmentState.h",
|
||||||
"BackendConnection.cpp",
|
"BackendConnection.cpp",
|
||||||
|
@ -31,6 +31,8 @@ target_sources(dawn_native PRIVATE
|
|||||||
${DAWN_NATIVE_UTILS_GEN_SOURCES}
|
${DAWN_NATIVE_UTILS_GEN_SOURCES}
|
||||||
"Adapter.cpp"
|
"Adapter.cpp"
|
||||||
"Adapter.h"
|
"Adapter.h"
|
||||||
|
"AsyncTask.cpp"
|
||||||
|
"AsyncTask.h"
|
||||||
"AttachmentState.cpp"
|
"AttachmentState.cpp"
|
||||||
"AttachmentState.h"
|
"AttachmentState.h"
|
||||||
"BackendConnection.cpp"
|
"BackendConnection.cpp"
|
||||||
|
@ -21,8 +21,6 @@
|
|||||||
|
|
||||||
namespace dawn_native {
|
namespace dawn_native {
|
||||||
|
|
||||||
class CallbackTaskManager;
|
|
||||||
|
|
||||||
struct CallbackTask {
|
struct CallbackTask {
|
||||||
public:
|
public:
|
||||||
virtual ~CallbackTask() = default;
|
virtual ~CallbackTask() = default;
|
||||||
|
@ -14,40 +14,84 @@
|
|||||||
|
|
||||||
#include "dawn_platform/WorkerThread.h"
|
#include "dawn_platform/WorkerThread.h"
|
||||||
|
|
||||||
#include <future>
|
#include <condition_variable>
|
||||||
|
#include <functional>
|
||||||
|
#include <thread>
|
||||||
|
|
||||||
#include "common/Assert.h"
|
#include "common/Assert.h"
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
class AsyncWaitableEvent final : public dawn_platform::WaitableEvent {
|
class AsyncWaitableEventImpl {
|
||||||
public:
|
public:
|
||||||
explicit AsyncWaitableEvent(std::function<void()> func) {
|
AsyncWaitableEventImpl() : mIsComplete(false) {
|
||||||
mFuture = std::async(std::launch::async, func);
|
|
||||||
}
|
}
|
||||||
void Wait() override {
|
|
||||||
ASSERT(mFuture.valid());
|
void Wait() {
|
||||||
mFuture.wait();
|
std::unique_lock<std::mutex> lock(mMutex);
|
||||||
|
mCondition.wait(lock, [this] { return mIsComplete; });
|
||||||
}
|
}
|
||||||
bool IsComplete() override {
|
|
||||||
ASSERT(mFuture.valid());
|
bool IsComplete() {
|
||||||
return mFuture.wait_for(std::chrono::seconds(0)) == std::future_status::ready;
|
std::lock_guard<std::mutex> lock(mMutex);
|
||||||
|
return mIsComplete;
|
||||||
|
}
|
||||||
|
|
||||||
|
void MarkAsComplete() {
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(mMutex);
|
||||||
|
mIsComplete = true;
|
||||||
|
}
|
||||||
|
mCondition.notify_all();
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
// It is safe not to call Wait() in the destructor of AsyncWaitableEvent because since
|
std::mutex mMutex;
|
||||||
// C++14 the destructor of std::future will always be blocked until its state becomes
|
std::condition_variable mCondition;
|
||||||
// std::future_status::ready when it was created by a call of std::async and it is the
|
bool mIsComplete;
|
||||||
// last reference to the shared state.
|
};
|
||||||
// See https://en.cppreference.com/w/cpp/thread/future/~future for more details.
|
|
||||||
std::future<void> mFuture;
|
class AsyncWaitableEvent final : public dawn_platform::WaitableEvent {
|
||||||
|
public:
|
||||||
|
explicit AsyncWaitableEvent()
|
||||||
|
: mWaitableEventImpl(std::make_shared<AsyncWaitableEventImpl>()) {
|
||||||
|
}
|
||||||
|
|
||||||
|
void Wait() override {
|
||||||
|
mWaitableEventImpl->Wait();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool IsComplete() override {
|
||||||
|
return mWaitableEventImpl->IsComplete();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<AsyncWaitableEventImpl> GetWaitableEventImpl() const {
|
||||||
|
return mWaitableEventImpl;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::shared_ptr<AsyncWaitableEventImpl> mWaitableEventImpl;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // anonymous namespace
|
} // anonymous namespace
|
||||||
|
|
||||||
std::unique_ptr<dawn_platform::WaitableEvent> AsyncWorkerThreadPool::PostWorkerTask(
|
namespace dawn_platform {
|
||||||
dawn_platform::PostWorkerTaskCallback callback,
|
|
||||||
void* userdata) {
|
std::unique_ptr<dawn_platform::WaitableEvent> AsyncWorkerThreadPool::PostWorkerTask(
|
||||||
std::function<void()> doTask = [callback, userdata]() { callback(userdata); };
|
dawn_platform::PostWorkerTaskCallback callback,
|
||||||
return std::make_unique<AsyncWaitableEvent>(doTask);
|
void* userdata) {
|
||||||
}
|
std::unique_ptr<AsyncWaitableEvent> waitableEvent = std::make_unique<AsyncWaitableEvent>();
|
||||||
|
|
||||||
|
std::function<void()> doTask =
|
||||||
|
[callback, userdata, waitableEventImpl = waitableEvent->GetWaitableEventImpl()]() {
|
||||||
|
callback(userdata);
|
||||||
|
waitableEventImpl->MarkAsComplete();
|
||||||
|
};
|
||||||
|
|
||||||
|
std::thread thread(doTask);
|
||||||
|
thread.detach();
|
||||||
|
|
||||||
|
return waitableEvent;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace dawn_platform
|
||||||
|
@ -18,11 +18,15 @@
|
|||||||
#include "common/NonCopyable.h"
|
#include "common/NonCopyable.h"
|
||||||
#include "dawn_platform/DawnPlatform.h"
|
#include "dawn_platform/DawnPlatform.h"
|
||||||
|
|
||||||
class AsyncWorkerThreadPool : public dawn_platform::WorkerTaskPool, public NonCopyable {
|
namespace dawn_platform {
|
||||||
public:
|
|
||||||
std::unique_ptr<dawn_platform::WaitableEvent> PostWorkerTask(
|
class AsyncWorkerThreadPool : public dawn_platform::WorkerTaskPool, public NonCopyable {
|
||||||
dawn_platform::PostWorkerTaskCallback callback,
|
public:
|
||||||
void* userdata) override;
|
std::unique_ptr<dawn_platform::WaitableEvent> PostWorkerTask(
|
||||||
};
|
dawn_platform::PostWorkerTaskCallback callback,
|
||||||
|
void* userdata) override;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace dawn_platform
|
||||||
|
|
||||||
#endif
|
#endif
|
||||||
|
@ -153,6 +153,7 @@ test("dawn_unittests") {
|
|||||||
"MockCallback.h",
|
"MockCallback.h",
|
||||||
"ToggleParser.cpp",
|
"ToggleParser.cpp",
|
||||||
"ToggleParser.h",
|
"ToggleParser.h",
|
||||||
|
"unittests/AsyncTaskTests.cpp",
|
||||||
"unittests/BitSetIteratorTests.cpp",
|
"unittests/BitSetIteratorTests.cpp",
|
||||||
"unittests/BuddyAllocatorTests.cpp",
|
"unittests/BuddyAllocatorTests.cpp",
|
||||||
"unittests/BuddyMemoryAllocatorTests.cpp",
|
"unittests/BuddyMemoryAllocatorTests.cpp",
|
||||||
@ -184,7 +185,6 @@ test("dawn_unittests") {
|
|||||||
"unittests/SystemUtilsTests.cpp",
|
"unittests/SystemUtilsTests.cpp",
|
||||||
"unittests/ToBackendTests.cpp",
|
"unittests/ToBackendTests.cpp",
|
||||||
"unittests/TypedIntegerTests.cpp",
|
"unittests/TypedIntegerTests.cpp",
|
||||||
"unittests/WorkerThreadTests.cpp",
|
|
||||||
"unittests/validation/BindGroupValidationTests.cpp",
|
"unittests/validation/BindGroupValidationTests.cpp",
|
||||||
"unittests/validation/BufferValidationTests.cpp",
|
"unittests/validation/BufferValidationTests.cpp",
|
||||||
"unittests/validation/CommandBufferValidationTests.cpp",
|
"unittests/validation/CommandBufferValidationTests.cpp",
|
||||||
|
89
src/tests/unittests/AsyncTaskTests.cpp
Normal file
89
src/tests/unittests/AsyncTaskTests.cpp
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
// Copyright 2021 The Dawn Authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
//
|
||||||
|
// AsyncTaskTests:
|
||||||
|
// Simple tests for dawn_native::AsyncTask and dawn_native::AsnycTaskManager.
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <mutex>
|
||||||
|
|
||||||
|
#include "common/NonCopyable.h"
|
||||||
|
#include "dawn_native/AsyncTask.h"
|
||||||
|
#include "dawn_platform/DawnPlatform.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
struct SimpleTaskResult {
|
||||||
|
uint32_t id;
|
||||||
|
};
|
||||||
|
|
||||||
|
// A thread-safe queue that stores the task results.
|
||||||
|
class ConcurrentTaskResultQueue : public NonCopyable {
|
||||||
|
public:
|
||||||
|
void AddResult(std::unique_ptr<SimpleTaskResult> result) {
|
||||||
|
std::lock_guard<std::mutex> lock(mMutex);
|
||||||
|
mTaskResults.push_back(std::move(result));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::unique_ptr<SimpleTaskResult>> GetAllResults() {
|
||||||
|
std::vector<std::unique_ptr<SimpleTaskResult>> outputResults;
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> lock(mMutex);
|
||||||
|
outputResults.swap(mTaskResults);
|
||||||
|
}
|
||||||
|
return outputResults;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::mutex mMutex;
|
||||||
|
std::vector<std::unique_ptr<SimpleTaskResult>> mTaskResults;
|
||||||
|
};
|
||||||
|
|
||||||
|
void DoTask(ConcurrentTaskResultQueue* resultQueue, uint32_t id) {
|
||||||
|
std::unique_ptr<SimpleTaskResult> result = std::make_unique<SimpleTaskResult>();
|
||||||
|
result->id = id;
|
||||||
|
resultQueue->AddResult(std::move(result));
|
||||||
|
}
|
||||||
|
|
||||||
|
} // anonymous namespace
|
||||||
|
|
||||||
|
class AsyncTaskTest : public testing::Test {};
|
||||||
|
|
||||||
|
// Emulate the basic usage of worker thread pool in Create*PipelineAsync().
|
||||||
|
TEST_F(AsyncTaskTest, Basic) {
|
||||||
|
dawn_platform::Platform platform;
|
||||||
|
std::unique_ptr<dawn_platform::WorkerTaskPool> pool = platform.CreateWorkerTaskPool();
|
||||||
|
|
||||||
|
dawn_native::AsnycTaskManager taskManager(pool.get());
|
||||||
|
ConcurrentTaskResultQueue taskResultQueue;
|
||||||
|
|
||||||
|
constexpr size_t kTaskCount = 4u;
|
||||||
|
std::set<uint32_t> idset;
|
||||||
|
for (uint32_t i = 0; i < kTaskCount; ++i) {
|
||||||
|
dawn_native::AsyncTask asyncTask([&taskResultQueue, i] { DoTask(&taskResultQueue, i); });
|
||||||
|
taskManager.PostTask(std::move(asyncTask));
|
||||||
|
idset.insert(i);
|
||||||
|
}
|
||||||
|
|
||||||
|
taskManager.WaitAllPendingTasks();
|
||||||
|
|
||||||
|
std::vector<std::unique_ptr<SimpleTaskResult>> results = taskResultQueue.GetAllResults();
|
||||||
|
ASSERT_EQ(kTaskCount, results.size());
|
||||||
|
for (std::unique_ptr<SimpleTaskResult>& result : results) {
|
||||||
|
idset.erase(result->id);
|
||||||
|
}
|
||||||
|
ASSERT_TRUE(idset.empty());
|
||||||
|
}
|
@ -1,169 +0,0 @@
|
|||||||
// Copyright 2021 The Dawn Authors
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
//
|
|
||||||
// WorkerThreadTests:
|
|
||||||
// Simple tests for the worker thread class.
|
|
||||||
|
|
||||||
#include <gtest/gtest.h>
|
|
||||||
|
|
||||||
#include <list>
|
|
||||||
#include <memory>
|
|
||||||
#include <mutex>
|
|
||||||
#include <queue>
|
|
||||||
|
|
||||||
#include "common/NonCopyable.h"
|
|
||||||
#include "dawn_platform/DawnPlatform.h"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
struct SimpleTaskResult {
|
|
||||||
uint32_t id;
|
|
||||||
bool isDone = false;
|
|
||||||
};
|
|
||||||
|
|
||||||
// A thread-safe queue that stores the task results.
|
|
||||||
class ConcurrentTaskResultQueue : public NonCopyable {
|
|
||||||
public:
|
|
||||||
void TaskCompleted(const SimpleTaskResult& result) {
|
|
||||||
ASSERT_TRUE(result.isDone);
|
|
||||||
|
|
||||||
std::lock_guard<std::mutex> lock(mMutex);
|
|
||||||
mTaskResultQueue.push(result);
|
|
||||||
}
|
|
||||||
|
|
||||||
std::vector<SimpleTaskResult> GetAndPopCompletedTasks() {
|
|
||||||
std::lock_guard<std::mutex> lock(mMutex);
|
|
||||||
|
|
||||||
std::vector<SimpleTaskResult> results;
|
|
||||||
while (!mTaskResultQueue.empty()) {
|
|
||||||
results.push_back(mTaskResultQueue.front());
|
|
||||||
mTaskResultQueue.pop();
|
|
||||||
}
|
|
||||||
return results;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::mutex mMutex;
|
|
||||||
std::queue<SimpleTaskResult> mTaskResultQueue;
|
|
||||||
};
|
|
||||||
|
|
||||||
// A simple task that will be executed asynchronously with pool->PostWorkerTask().
|
|
||||||
class SimpleTask : public NonCopyable {
|
|
||||||
public:
|
|
||||||
SimpleTask(uint32_t id, ConcurrentTaskResultQueue* resultQueue)
|
|
||||||
: mId(id), mResultQueue(resultQueue) {
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
friend class Tracker;
|
|
||||||
|
|
||||||
static void DoTaskOnWorkerTaskPool(void* task) {
|
|
||||||
SimpleTask* simpleTaskPtr = static_cast<SimpleTask*>(task);
|
|
||||||
simpleTaskPtr->doTask();
|
|
||||||
}
|
|
||||||
|
|
||||||
void doTask() {
|
|
||||||
SimpleTaskResult result;
|
|
||||||
result.id = mId;
|
|
||||||
result.isDone = true;
|
|
||||||
mResultQueue->TaskCompleted(result);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint32_t mId;
|
|
||||||
ConcurrentTaskResultQueue* mResultQueue;
|
|
||||||
};
|
|
||||||
|
|
||||||
// A simple implementation of task tracker which is only called in main thread and not
|
|
||||||
// thread-safe.
|
|
||||||
class Tracker : public NonCopyable {
|
|
||||||
public:
|
|
||||||
explicit Tracker(dawn_platform::WorkerTaskPool* pool) : mPool(pool) {
|
|
||||||
}
|
|
||||||
|
|
||||||
void StartNewTask(uint32_t taskId) {
|
|
||||||
mTasksInFlight.emplace_back(this, mPool, taskId);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t GetTasksInFlightCount() {
|
|
||||||
return mTasksInFlight.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
void WaitAll() {
|
|
||||||
for (auto iter = mTasksInFlight.begin(); iter != mTasksInFlight.end(); ++iter) {
|
|
||||||
iter->waitableEvent->Wait();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// In Tick() we clean up all the completed tasks and consume all the available results.
|
|
||||||
void Tick() {
|
|
||||||
auto iter = mTasksInFlight.begin();
|
|
||||||
while (iter != mTasksInFlight.end()) {
|
|
||||||
if (iter->waitableEvent->IsComplete()) {
|
|
||||||
iter = mTasksInFlight.erase(iter);
|
|
||||||
} else {
|
|
||||||
++iter;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const std::vector<SimpleTaskResult>& results =
|
|
||||||
mCompletedTaskResultQueue.GetAndPopCompletedTasks();
|
|
||||||
for (const SimpleTaskResult& result : results) {
|
|
||||||
EXPECT_TRUE(result.isDone);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
SimpleTask* CreateSimpleTask(uint32_t taskId) {
|
|
||||||
return new SimpleTask(taskId, &mCompletedTaskResultQueue);
|
|
||||||
}
|
|
||||||
|
|
||||||
struct WaitableTask {
|
|
||||||
WaitableTask(Tracker* tracker, dawn_platform::WorkerTaskPool* pool, uint32_t taskId) {
|
|
||||||
task.reset(tracker->CreateSimpleTask(taskId));
|
|
||||||
waitableEvent =
|
|
||||||
pool->PostWorkerTask(SimpleTask::DoTaskOnWorkerTaskPool, task.get());
|
|
||||||
}
|
|
||||||
|
|
||||||
std::unique_ptr<SimpleTask> task;
|
|
||||||
std::unique_ptr<dawn_platform::WaitableEvent> waitableEvent;
|
|
||||||
};
|
|
||||||
|
|
||||||
dawn_platform::WorkerTaskPool* mPool;
|
|
||||||
|
|
||||||
std::list<WaitableTask> mTasksInFlight;
|
|
||||||
ConcurrentTaskResultQueue mCompletedTaskResultQueue;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // anonymous namespace
|
|
||||||
|
|
||||||
class WorkerThreadTest : public testing::Test {};
|
|
||||||
|
|
||||||
// Emulate the basic usage of worker thread pool in Create*PipelineAsync().
|
|
||||||
TEST_F(WorkerThreadTest, Basic) {
|
|
||||||
dawn_platform::Platform platform;
|
|
||||||
std::unique_ptr<dawn_platform::WorkerTaskPool> pool = platform.CreateWorkerTaskPool();
|
|
||||||
Tracker tracker(pool.get());
|
|
||||||
|
|
||||||
constexpr uint32_t kTaskCount = 4;
|
|
||||||
for (uint32_t i = 0; i < kTaskCount; ++i) {
|
|
||||||
tracker.StartNewTask(i);
|
|
||||||
}
|
|
||||||
EXPECT_EQ(kTaskCount, tracker.GetTasksInFlightCount());
|
|
||||||
|
|
||||||
// Wait for the completion of all the tasks.
|
|
||||||
tracker.WaitAll();
|
|
||||||
|
|
||||||
tracker.Tick();
|
|
||||||
EXPECT_EQ(0u, tracker.GetTasksInFlightCount());
|
|
||||||
}
|
|
Loading…
x
Reference in New Issue
Block a user