Implement AsyncWaitableEvent with std::condition_variable

This patch implements the default implementation of WaitableEvent
(AsyncWaitableEvent) with std::condition_variable instead of
std::future as std::future will always block its destructor until
the async function returns, which makes us unable to clean up all
the execution environment of the async task inside the async
function.

This patch also implements WorkerThreadTaskManager to manage all
the async tasks (inherited from WorkerThreadTask) in the future,
for example all the Create*PipelineAsync() tasks.

This patch also updates the related dawn_unittest WorkerThreadTest.
Basic.

BUG=dawn:529
TEST=dawn_unittests

Change-Id: Ie789ba788789e91128ffc416e7e768923828a367
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/51740
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
This commit is contained in:
Jiawei Shao
2021-05-27 00:49:03 +00:00
committed by Dawn LUCI CQ
parent 462a3ba917
commit 5d39860fef
10 changed files with 294 additions and 200 deletions

View File

@@ -153,6 +153,7 @@ test("dawn_unittests") {
"MockCallback.h",
"ToggleParser.cpp",
"ToggleParser.h",
"unittests/AsyncTaskTests.cpp",
"unittests/BitSetIteratorTests.cpp",
"unittests/BuddyAllocatorTests.cpp",
"unittests/BuddyMemoryAllocatorTests.cpp",
@@ -184,7 +185,6 @@ test("dawn_unittests") {
"unittests/SystemUtilsTests.cpp",
"unittests/ToBackendTests.cpp",
"unittests/TypedIntegerTests.cpp",
"unittests/WorkerThreadTests.cpp",
"unittests/validation/BindGroupValidationTests.cpp",
"unittests/validation/BufferValidationTests.cpp",
"unittests/validation/CommandBufferValidationTests.cpp",

View File

@@ -0,0 +1,89 @@
// Copyright 2021 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// AsyncTaskTests:
// Simple tests for dawn_native::AsyncTask and dawn_native::AsnycTaskManager.
#include <gtest/gtest.h>
#include <memory>
#include <mutex>
#include "common/NonCopyable.h"
#include "dawn_native/AsyncTask.h"
#include "dawn_platform/DawnPlatform.h"
namespace {
struct SimpleTaskResult {
uint32_t id;
};
// A thread-safe queue that stores the task results.
class ConcurrentTaskResultQueue : public NonCopyable {
public:
void AddResult(std::unique_ptr<SimpleTaskResult> result) {
std::lock_guard<std::mutex> lock(mMutex);
mTaskResults.push_back(std::move(result));
}
std::vector<std::unique_ptr<SimpleTaskResult>> GetAllResults() {
std::vector<std::unique_ptr<SimpleTaskResult>> outputResults;
{
std::lock_guard<std::mutex> lock(mMutex);
outputResults.swap(mTaskResults);
}
return outputResults;
}
private:
std::mutex mMutex;
std::vector<std::unique_ptr<SimpleTaskResult>> mTaskResults;
};
void DoTask(ConcurrentTaskResultQueue* resultQueue, uint32_t id) {
std::unique_ptr<SimpleTaskResult> result = std::make_unique<SimpleTaskResult>();
result->id = id;
resultQueue->AddResult(std::move(result));
}
} // anonymous namespace
class AsyncTaskTest : public testing::Test {};
// Emulate the basic usage of worker thread pool in Create*PipelineAsync().
TEST_F(AsyncTaskTest, Basic) {
dawn_platform::Platform platform;
std::unique_ptr<dawn_platform::WorkerTaskPool> pool = platform.CreateWorkerTaskPool();
dawn_native::AsnycTaskManager taskManager(pool.get());
ConcurrentTaskResultQueue taskResultQueue;
constexpr size_t kTaskCount = 4u;
std::set<uint32_t> idset;
for (uint32_t i = 0; i < kTaskCount; ++i) {
dawn_native::AsyncTask asyncTask([&taskResultQueue, i] { DoTask(&taskResultQueue, i); });
taskManager.PostTask(std::move(asyncTask));
idset.insert(i);
}
taskManager.WaitAllPendingTasks();
std::vector<std::unique_ptr<SimpleTaskResult>> results = taskResultQueue.GetAllResults();
ASSERT_EQ(kTaskCount, results.size());
for (std::unique_ptr<SimpleTaskResult>& result : results) {
idset.erase(result->id);
}
ASSERT_TRUE(idset.empty());
}

View File

@@ -1,169 +0,0 @@
// Copyright 2021 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// WorkerThreadTests:
// Simple tests for the worker thread class.
#include <gtest/gtest.h>
#include <list>
#include <memory>
#include <mutex>
#include <queue>
#include "common/NonCopyable.h"
#include "dawn_platform/DawnPlatform.h"
namespace {
struct SimpleTaskResult {
uint32_t id;
bool isDone = false;
};
// A thread-safe queue that stores the task results.
class ConcurrentTaskResultQueue : public NonCopyable {
public:
void TaskCompleted(const SimpleTaskResult& result) {
ASSERT_TRUE(result.isDone);
std::lock_guard<std::mutex> lock(mMutex);
mTaskResultQueue.push(result);
}
std::vector<SimpleTaskResult> GetAndPopCompletedTasks() {
std::lock_guard<std::mutex> lock(mMutex);
std::vector<SimpleTaskResult> results;
while (!mTaskResultQueue.empty()) {
results.push_back(mTaskResultQueue.front());
mTaskResultQueue.pop();
}
return results;
}
private:
std::mutex mMutex;
std::queue<SimpleTaskResult> mTaskResultQueue;
};
// A simple task that will be executed asynchronously with pool->PostWorkerTask().
class SimpleTask : public NonCopyable {
public:
SimpleTask(uint32_t id, ConcurrentTaskResultQueue* resultQueue)
: mId(id), mResultQueue(resultQueue) {
}
private:
friend class Tracker;
static void DoTaskOnWorkerTaskPool(void* task) {
SimpleTask* simpleTaskPtr = static_cast<SimpleTask*>(task);
simpleTaskPtr->doTask();
}
void doTask() {
SimpleTaskResult result;
result.id = mId;
result.isDone = true;
mResultQueue->TaskCompleted(result);
}
uint32_t mId;
ConcurrentTaskResultQueue* mResultQueue;
};
// A simple implementation of task tracker which is only called in main thread and not
// thread-safe.
class Tracker : public NonCopyable {
public:
explicit Tracker(dawn_platform::WorkerTaskPool* pool) : mPool(pool) {
}
void StartNewTask(uint32_t taskId) {
mTasksInFlight.emplace_back(this, mPool, taskId);
}
uint64_t GetTasksInFlightCount() {
return mTasksInFlight.size();
}
void WaitAll() {
for (auto iter = mTasksInFlight.begin(); iter != mTasksInFlight.end(); ++iter) {
iter->waitableEvent->Wait();
}
}
// In Tick() we clean up all the completed tasks and consume all the available results.
void Tick() {
auto iter = mTasksInFlight.begin();
while (iter != mTasksInFlight.end()) {
if (iter->waitableEvent->IsComplete()) {
iter = mTasksInFlight.erase(iter);
} else {
++iter;
}
}
const std::vector<SimpleTaskResult>& results =
mCompletedTaskResultQueue.GetAndPopCompletedTasks();
for (const SimpleTaskResult& result : results) {
EXPECT_TRUE(result.isDone);
}
}
private:
SimpleTask* CreateSimpleTask(uint32_t taskId) {
return new SimpleTask(taskId, &mCompletedTaskResultQueue);
}
struct WaitableTask {
WaitableTask(Tracker* tracker, dawn_platform::WorkerTaskPool* pool, uint32_t taskId) {
task.reset(tracker->CreateSimpleTask(taskId));
waitableEvent =
pool->PostWorkerTask(SimpleTask::DoTaskOnWorkerTaskPool, task.get());
}
std::unique_ptr<SimpleTask> task;
std::unique_ptr<dawn_platform::WaitableEvent> waitableEvent;
};
dawn_platform::WorkerTaskPool* mPool;
std::list<WaitableTask> mTasksInFlight;
ConcurrentTaskResultQueue mCompletedTaskResultQueue;
};
} // anonymous namespace
class WorkerThreadTest : public testing::Test {};
// Emulate the basic usage of worker thread pool in Create*PipelineAsync().
TEST_F(WorkerThreadTest, Basic) {
dawn_platform::Platform platform;
std::unique_ptr<dawn_platform::WorkerTaskPool> pool = platform.CreateWorkerTaskPool();
Tracker tracker(pool.get());
constexpr uint32_t kTaskCount = 4;
for (uint32_t i = 0; i < kTaskCount; ++i) {
tracker.StartNewTask(i);
}
EXPECT_EQ(kTaskCount, tracker.GetTasksInFlightCount());
// Wait for the completion of all the tasks.
tracker.WaitAll();
tracker.Tick();
EXPECT_EQ(0u, tracker.GetTasksInFlightCount());
}