DawnNative: Defer callbacks' triggerings to APITick().

Currently in the middle of some functions, we execute callbacks
immediately such as inside Buffer::APIMapAsync(), Device::HandleError()
or Queue::Submit().

Firstly, this has risks. The functions might be in a middle of modifying
internal states. By triggering callbacks, users might call API
functions again which could further modify the internal states
unexpectedly or access the states in an inconsistent way.

Secondly, upcoming thread safe API which locks the public functions with
a mutex might encounter deadlock. Because callbacks might cause
re-entrances which would unexpectedly lock the public function again.

This CL attempts to limit number of functions that are allowed to
trigger callbacks. Other functions that want to trigger callbacks will
instead enqueue a request to execute callbacks in the next
Device::APITick() call.

Currently the functions that will be allowed to trigger callbacks are:
- Device::WillDropLastExternalRef()
- Device::APITick()
- Device::APISetLoggingCallback()
- Device::APISetUncapturedErrorCallback()
- Device::APISetDeviceLostCallback()

Bug: dawn:1672
Change-Id: Iabca00f1b6f8f69eb5e966ffaa43dda5ae20fa8b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/120940
Commit-Queue: Quyen Le <lehoangquyen@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Le Hoang Quyen 2023-03-06 19:03:26 +00:00 committed by Dawn LUCI CQ
parent 77a90cb796
commit 8ab7dbe424
19 changed files with 288 additions and 210 deletions

View File

@ -30,5 +30,17 @@ struct HasEqualityOperator<
std::is_same<decltype(std::declval<LHS>() == std::declval<RHS>()), bool>::value>> { std::is_same<decltype(std::declval<LHS>() == std::declval<RHS>()), bool>::value>> {
static constexpr const bool value = true; static constexpr const bool value = true;
}; };
template <typename T>
struct IsCString {
static constexpr bool Eval() {
using Tp = std::decay_t<T>;
if (!std::is_pointer_v<Tp>) {
return false;
}
return std::is_same_v<std::remove_cv_t<std::remove_pointer_t<Tp>>, char>;
}
static constexpr const bool value = Eval();
};
#endif // SRC_DAWN_COMMON_TYPETRAITS_H_ #endif // SRC_DAWN_COMMON_TYPETRAITS_H_

View File

@ -39,21 +39,22 @@ namespace {
struct MapRequestTask : TrackTaskCallback { struct MapRequestTask : TrackTaskCallback {
MapRequestTask(dawn::platform::Platform* platform, Ref<BufferBase> buffer, MapRequestID id) MapRequestTask(dawn::platform::Platform* platform, Ref<BufferBase> buffer, MapRequestID id)
: TrackTaskCallback(platform), buffer(std::move(buffer)), id(id) {} : TrackTaskCallback(platform), buffer(std::move(buffer)), id(id) {}
void Finish() override {
ASSERT(mSerial != kMaxExecutionSerial);
TRACE_EVENT1(mPlatform, General, "Buffer::TaskInFlight::Finished", "serial",
uint64_t(mSerial));
buffer->OnMapRequestCompleted(id, WGPUBufferMapAsyncStatus_Success);
}
void HandleDeviceLoss() override {
buffer->OnMapRequestCompleted(id, WGPUBufferMapAsyncStatus_DeviceLost);
}
void HandleShutDown() override {
buffer->OnMapRequestCompleted(id, WGPUBufferMapAsyncStatus_DestroyedBeforeCallback);
}
~MapRequestTask() override = default; ~MapRequestTask() override = default;
private: private:
void FinishImpl() override {
ASSERT(mSerial != kMaxExecutionSerial);
TRACE_EVENT1(mPlatform, General, "Buffer::TaskInFlight::Finished", "serial",
uint64_t(mSerial));
buffer->CallbackOnMapRequestCompleted(id, WGPUBufferMapAsyncStatus_Success);
}
void HandleDeviceLossImpl() override {
buffer->CallbackOnMapRequestCompleted(id, WGPUBufferMapAsyncStatus_DeviceLost);
}
void HandleShutDownImpl() override {
buffer->CallbackOnMapRequestCompleted(id, WGPUBufferMapAsyncStatus_DestroyedBeforeCallback);
}
Ref<BufferBase> buffer; Ref<BufferBase> buffer;
MapRequestID id; MapRequestID id;
}; };
@ -134,46 +135,6 @@ MaybeError ValidateBufferDescriptor(DeviceBase* device, const BufferDescriptor*
return {}; return {};
} }
// BufferBase::PendingMappingCallback
BufferBase::PendingMappingCallback::PendingMappingCallback()
: callback(nullptr), userdata(nullptr) {}
// Ensure to call the callback.
BufferBase::PendingMappingCallback::~PendingMappingCallback() {
ASSERT(callback == nullptr);
ASSERT(userdata == nullptr);
}
BufferBase::PendingMappingCallback::PendingMappingCallback(
BufferBase::PendingMappingCallback&& other) {
this->callback = std::move(other.callback);
this->userdata = std::move(other.userdata);
this->status = other.status;
other.callback = nullptr;
other.userdata = nullptr;
}
BufferBase::PendingMappingCallback& BufferBase::PendingMappingCallback::operator=(
PendingMappingCallback&& other) {
if (&other != this) {
this->callback = std::move(other.callback);
this->userdata = std::move(other.userdata);
this->status = other.status;
other.callback = nullptr;
other.userdata = nullptr;
}
return *this;
}
void BufferBase::PendingMappingCallback::Call() {
if (callback != nullptr) {
callback(status, userdata);
callback = nullptr;
userdata = nullptr;
}
}
// Buffer // Buffer
BufferBase::BufferBase(DeviceBase* device, const BufferDescriptor* descriptor) BufferBase::BufferBase(DeviceBase* device, const BufferDescriptor* descriptor)
@ -227,20 +188,17 @@ BufferBase::~BufferBase() {
} }
void BufferBase::DestroyImpl() { void BufferBase::DestroyImpl() {
PendingMappingCallback toCall;
if (mState == BufferState::Mapped || mState == BufferState::PendingMap) { if (mState == BufferState::Mapped || mState == BufferState::PendingMap) {
toCall = UnmapInternal(WGPUBufferMapAsyncStatus_DestroyedBeforeCallback); UnmapInternal(WGPUBufferMapAsyncStatus_DestroyedBeforeCallback);
} else if (mState == BufferState::MappedAtCreation) { } else if (mState == BufferState::MappedAtCreation) {
if (mStagingBuffer != nullptr) { if (mStagingBuffer != nullptr) {
mStagingBuffer = nullptr; mStagingBuffer = nullptr;
} else if (mSize != 0) { } else if (mSize != 0) {
toCall = UnmapInternal(WGPUBufferMapAsyncStatus_DestroyedBeforeCallback); UnmapInternal(WGPUBufferMapAsyncStatus_DestroyedBeforeCallback);
} }
} }
mState = BufferState::Destroyed; mState = BufferState::Destroyed;
toCall.Call();
} }
// static // static
@ -376,31 +334,29 @@ MaybeError BufferBase::ValidateCanUseOnQueueNow() const {
UNREACHABLE(); UNREACHABLE();
} }
// Store the callback to be called in an intermediate struct that bubbles up the call stack std::function<void()> BufferBase::PrepareMappingCallback(MapRequestID mapID,
// and is called by the top most function at the very end. It helps to make sure that
// all code paths ensure that nothing happens after the callback.
BufferBase::PendingMappingCallback BufferBase::WillCallMappingCallback(
MapRequestID mapID,
WGPUBufferMapAsyncStatus status) { WGPUBufferMapAsyncStatus status) {
ASSERT(!IsError()); ASSERT(!IsError());
PendingMappingCallback toCall;
if (mMapCallback != nullptr && mapID == mLastMapID) { if (mMapCallback != nullptr && mapID == mLastMapID) {
toCall.callback = std::move(mMapCallback); auto callback = std::move(mMapCallback);
toCall.userdata = std::move(mMapUserdata); auto userdata = std::move(mMapUserdata);
WGPUBufferMapAsyncStatus actualStatus;
if (GetDevice()->IsLost()) { if (GetDevice()->IsLost()) {
toCall.status = WGPUBufferMapAsyncStatus_DeviceLost; actualStatus = WGPUBufferMapAsyncStatus_DeviceLost;
} else { } else {
toCall.status = status; actualStatus = status;
} }
// Tag the callback as fired before firing it, otherwise it could fire a second time if // Tag the callback as fired before firing it, otherwise it could fire a second time if
// for example buffer.Unmap() is called inside the application-provided callback. // for example buffer.Unmap() is called before the MapRequestTask completes.
mMapCallback = nullptr; mMapCallback = nullptr;
mMapUserdata = nullptr; mMapUserdata = nullptr;
return std::bind(callback, actualStatus, userdata);
} }
return toCall; return [] {};
} }
void BufferBase::APIMapAsync(wgpu::MapMode mode, void BufferBase::APIMapAsync(wgpu::MapMode mode,
@ -412,7 +368,8 @@ void BufferBase::APIMapAsync(wgpu::MapMode mode,
// rejects the callback and doesn't produce a validation error. // rejects the callback and doesn't produce a validation error.
if (mState == BufferState::PendingMap) { if (mState == BufferState::PendingMap) {
if (callback) { if (callback) {
callback(WGPUBufferMapAsyncStatus_Error, userdata); GetDevice()->GetCallbackTaskManager()->AddCallbackTask(
callback, WGPUBufferMapAsyncStatus_Error, userdata);
} }
return; return;
} }
@ -429,7 +386,7 @@ void BufferBase::APIMapAsync(wgpu::MapMode mode,
"calling %s.MapAsync(%s, %u, %u, ...).", this, mode, offset, "calling %s.MapAsync(%s, %u, %u, ...).", this, mode, offset,
size)) { size)) {
if (callback) { if (callback) {
callback(status, userdata); GetDevice()->GetCallbackTaskManager()->AddCallbackTask(callback, status, userdata);
} }
return; return;
} }
@ -444,7 +401,8 @@ void BufferBase::APIMapAsync(wgpu::MapMode mode,
mState = BufferState::PendingMap; mState = BufferState::PendingMap;
if (GetDevice()->ConsumedError(MapAsyncImpl(mode, offset, size))) { if (GetDevice()->ConsumedError(MapAsyncImpl(mode, offset, size))) {
WillCallMappingCallback(mLastMapID, WGPUBufferMapAsyncStatus_DeviceLost).Call(); GetDevice()->GetCallbackTaskManager()->AddCallbackTask(
PrepareMappingCallback(mLastMapID, WGPUBufferMapAsyncStatus_DeviceLost));
return; return;
} }
std::unique_ptr<MapRequestTask> request = std::unique_ptr<MapRequestTask> request =
@ -518,17 +476,15 @@ MaybeError BufferBase::Unmap() {
if (mState == BufferState::MappedAtCreation && mStagingBuffer != nullptr) { if (mState == BufferState::MappedAtCreation && mStagingBuffer != nullptr) {
DAWN_TRY(CopyFromStagingBuffer()); DAWN_TRY(CopyFromStagingBuffer());
} }
UnmapInternal(WGPUBufferMapAsyncStatus_UnmappedBeforeCallback).Call(); UnmapInternal(WGPUBufferMapAsyncStatus_UnmappedBeforeCallback);
return {}; return {};
} }
BufferBase::PendingMappingCallback BufferBase::UnmapInternal( void BufferBase::UnmapInternal(WGPUBufferMapAsyncStatus callbackStatus) {
WGPUBufferMapAsyncStatus callbackStatus) { // Unmaps resources on the backend.
PendingMappingCallback toCall;
// Unmaps resources on the backend and returns the callback.
if (mState == BufferState::PendingMap) { if (mState == BufferState::PendingMap) {
toCall = WillCallMappingCallback(mLastMapID, callbackStatus); GetDevice()->GetCallbackTaskManager()->AddCallbackTask(
PrepareMappingCallback(mLastMapID, callbackStatus));
UnmapImpl(); UnmapImpl();
} else if (mState == BufferState::Mapped) { } else if (mState == BufferState::Mapped) {
UnmapImpl(); UnmapImpl();
@ -539,7 +495,6 @@ BufferBase::PendingMappingCallback BufferBase::UnmapInternal(
} }
mState = BufferState::Unmapped; mState = BufferState::Unmapped;
return toCall;
} }
MaybeError BufferBase::ValidateMapAsync(wgpu::MapMode mode, MaybeError BufferBase::ValidateMapAsync(wgpu::MapMode mode,
@ -640,13 +595,15 @@ MaybeError BufferBase::ValidateUnmap() const {
return {}; return {};
} }
void BufferBase::OnMapRequestCompleted(MapRequestID mapID, WGPUBufferMapAsyncStatus status) { void BufferBase::CallbackOnMapRequestCompleted(MapRequestID mapID,
PendingMappingCallback toCall = WillCallMappingCallback(mapID, status); WGPUBufferMapAsyncStatus status) {
if (mapID == mLastMapID && status == WGPUBufferMapAsyncStatus_Success && if (mapID == mLastMapID && status == WGPUBufferMapAsyncStatus_Success &&
mState == BufferState::PendingMap) { mState == BufferState::PendingMap) {
mState = BufferState::Mapped; mState = BufferState::Mapped;
} }
toCall.Call();
auto cb = PrepareMappingCallback(mapID, status);
cb();
} }
bool BufferBase::NeedsInitialization() const { bool BufferBase::NeedsInitialization() const {

View File

@ -15,6 +15,7 @@
#ifndef SRC_DAWN_NATIVE_BUFFER_H_ #ifndef SRC_DAWN_NATIVE_BUFFER_H_
#define SRC_DAWN_NATIVE_BUFFER_H_ #define SRC_DAWN_NATIVE_BUFFER_H_
#include <functional>
#include <memory> #include <memory>
#include "dawn/common/NonCopyable.h" #include "dawn/common/NonCopyable.h"
@ -65,7 +66,7 @@ class BufferBase : public ApiObjectBase {
wgpu::BufferUsage GetUsageExternalOnly() const; wgpu::BufferUsage GetUsageExternalOnly() const;
MaybeError MapAtCreation(); MaybeError MapAtCreation();
void OnMapRequestCompleted(MapRequestID mapID, WGPUBufferMapAsyncStatus status); void CallbackOnMapRequestCompleted(MapRequestID mapID, WGPUBufferMapAsyncStatus status);
MaybeError ValidateCanUseOnQueueNow() const; MaybeError ValidateCanUseOnQueueNow() const;
@ -108,23 +109,7 @@ class BufferBase : public ApiObjectBase {
ExecutionSerial mLastUsageSerial = ExecutionSerial(0); ExecutionSerial mLastUsageSerial = ExecutionSerial(0);
private: private:
// A helper structure to enforce that the mapAsync callback is called only at the very end of std::function<void()> PrepareMappingCallback(MapRequestID mapID,
// methods that might trigger callbacks. Non-copyable but movable for the assertion in the
// destructor to ensure not to forget to call the callback
struct [[nodiscard]] PendingMappingCallback : public NonCopyable {
WGPUBufferMapCallback callback;
void* userdata;
WGPUBufferMapAsyncStatus status;
PendingMappingCallback();
~PendingMappingCallback();
PendingMappingCallback(PendingMappingCallback&& other);
PendingMappingCallback& operator=(PendingMappingCallback&& other);
void Call();
};
PendingMappingCallback WillCallMappingCallback(MapRequestID mapID,
WGPUBufferMapAsyncStatus status); WGPUBufferMapAsyncStatus status);
virtual MaybeError MapAtCreationImpl() = 0; virtual MaybeError MapAtCreationImpl() = 0;
@ -140,7 +125,7 @@ class BufferBase : public ApiObjectBase {
WGPUBufferMapAsyncStatus* status) const; WGPUBufferMapAsyncStatus* status) const;
MaybeError ValidateUnmap() const; MaybeError ValidateUnmap() const;
bool CanGetMappedRange(bool writable, size_t offset, size_t size) const; bool CanGetMappedRange(bool writable, size_t offset, size_t size) const;
PendingMappingCallback UnmapInternal(WGPUBufferMapAsyncStatus callbackStatus); void UnmapInternal(WGPUBufferMapAsyncStatus callbackStatus);
uint64_t mSize = 0; uint64_t mSize = 0;
wgpu::BufferUsage mUsage = wgpu::BufferUsage::None; wgpu::BufferUsage mUsage = wgpu::BufferUsage::None;

View File

@ -16,8 +16,52 @@
#include <utility> #include <utility>
#include "dawn/common/Assert.h"
namespace dawn::native { namespace dawn::native {
namespace {
struct GenericFunctionTask : CallbackTask {
public:
explicit GenericFunctionTask(std::function<void()> func) : mFunction(std::move(func)) {}
private:
void FinishImpl() override { mFunction(); }
void HandleShutDownImpl() override { mFunction(); }
void HandleDeviceLossImpl() override { mFunction(); }
std::function<void()> mFunction;
};
} // namespace
void CallbackTask::Execute() {
switch (mState) {
case State::HandleDeviceLoss:
HandleDeviceLossImpl();
break;
case State::HandleShutDown:
HandleShutDownImpl();
break;
default:
FinishImpl();
}
}
void CallbackTask::OnShutDown() {
// Only first state change will have effects in final Execute().
if (mState != State::Normal) {
return;
}
mState = State::HandleShutDown;
}
void CallbackTask::OnDeviceLoss() {
if (mState != State::Normal) {
return;
}
mState = State::HandleDeviceLoss;
}
CallbackTaskManager::CallbackTaskManager() = default; CallbackTaskManager::CallbackTaskManager() = default;
CallbackTaskManager::~CallbackTaskManager() = default; CallbackTaskManager::~CallbackTaskManager() = default;
@ -40,4 +84,22 @@ void CallbackTaskManager::AddCallbackTask(std::unique_ptr<CallbackTask> callback
mCallbackTaskQueue.push_back(std::move(callbackTask)); mCallbackTaskQueue.push_back(std::move(callbackTask));
} }
void CallbackTaskManager::AddCallbackTask(std::function<void()> callback) {
AddCallbackTask(std::make_unique<GenericFunctionTask>(std::move(callback)));
}
void CallbackTaskManager::HandleDeviceLoss() {
std::lock_guard<std::mutex> lock(mCallbackTaskQueueMutex);
for (auto& task : mCallbackTaskQueue) {
task->OnDeviceLoss();
}
}
void CallbackTaskManager::HandleShutDown() {
std::lock_guard<std::mutex> lock(mCallbackTaskQueueMutex);
for (auto& task : mCallbackTaskQueue) {
task->OnShutDown();
}
}
} // namespace dawn::native } // namespace dawn::native

View File

@ -15,18 +15,36 @@
#ifndef SRC_DAWN_NATIVE_CALLBACKTASKMANAGER_H_ #ifndef SRC_DAWN_NATIVE_CALLBACKTASKMANAGER_H_
#define SRC_DAWN_NATIVE_CALLBACKTASKMANAGER_H_ #define SRC_DAWN_NATIVE_CALLBACKTASKMANAGER_H_
#include <functional>
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <vector> #include <vector>
#include "dawn/common/TypeTraits.h"
namespace dawn::native { namespace dawn::native {
struct CallbackTask { struct CallbackTask {
public: public:
virtual ~CallbackTask() = default; virtual ~CallbackTask() = default;
virtual void Finish() = 0;
virtual void HandleShutDown() = 0; void Execute();
virtual void HandleDeviceLoss() = 0; void OnShutDown();
void OnDeviceLoss();
protected:
virtual void FinishImpl() = 0;
virtual void HandleShutDownImpl() = 0;
virtual void HandleDeviceLossImpl() = 0;
private:
enum class State {
Normal,
HandleShutDown,
HandleDeviceLoss,
};
State mState = State::Normal;
}; };
class CallbackTaskManager { class CallbackTaskManager {
@ -35,7 +53,16 @@ class CallbackTaskManager {
~CallbackTaskManager(); ~CallbackTaskManager();
void AddCallbackTask(std::unique_ptr<CallbackTask> callbackTask); void AddCallbackTask(std::unique_ptr<CallbackTask> callbackTask);
void AddCallbackTask(std::function<void()> callback);
template <typename... Args>
void AddCallbackTask(void (*callback)(Args... args), Args... args) {
static_assert((!IsCString<Args>::value && ...), "passing C string argument is not allowed");
AddCallbackTask([=] { callback(args...); });
}
bool IsEmpty(); bool IsEmpty();
void HandleDeviceLoss();
void HandleShutDown();
std::vector<std::unique_ptr<CallbackTask>> AcquireCallbackTasks(); std::vector<std::unique_ptr<CallbackTask>> AcquireCallbackTasks();
private: private:

View File

@ -55,19 +55,19 @@ CreateComputePipelineAsyncCallbackTask::CreateComputePipelineAsyncCallbackTask(
CreateComputePipelineAsyncCallbackTask::~CreateComputePipelineAsyncCallbackTask() = default; CreateComputePipelineAsyncCallbackTask::~CreateComputePipelineAsyncCallbackTask() = default;
void CreateComputePipelineAsyncCallbackTask::Finish() { void CreateComputePipelineAsyncCallbackTask::FinishImpl() {
ASSERT(mCreateComputePipelineAsyncCallback != nullptr); ASSERT(mCreateComputePipelineAsyncCallback != nullptr);
mCreateComputePipelineAsyncCallback(mStatus, ToAPI(mPipeline.Detach()), mErrorMessage.c_str(), mCreateComputePipelineAsyncCallback(mStatus, ToAPI(mPipeline.Detach()), mErrorMessage.c_str(),
mUserData); mUserData);
} }
void CreateComputePipelineAsyncCallbackTask::HandleShutDown() { void CreateComputePipelineAsyncCallbackTask::HandleShutDownImpl() {
ASSERT(mCreateComputePipelineAsyncCallback != nullptr); ASSERT(mCreateComputePipelineAsyncCallback != nullptr);
mCreateComputePipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr, mCreateComputePipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr,
"Device destroyed before callback", mUserData); "Device destroyed before callback", mUserData);
} }
void CreateComputePipelineAsyncCallbackTask::HandleDeviceLoss() { void CreateComputePipelineAsyncCallbackTask::HandleDeviceLossImpl() {
ASSERT(mCreateComputePipelineAsyncCallback != nullptr); ASSERT(mCreateComputePipelineAsyncCallback != nullptr);
mCreateComputePipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, mCreateComputePipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr,
"Device lost before callback", mUserData); "Device lost before callback", mUserData);
@ -91,19 +91,19 @@ CreateRenderPipelineAsyncCallbackTask::CreateRenderPipelineAsyncCallbackTask(
CreateRenderPipelineAsyncCallbackTask::~CreateRenderPipelineAsyncCallbackTask() = default; CreateRenderPipelineAsyncCallbackTask::~CreateRenderPipelineAsyncCallbackTask() = default;
void CreateRenderPipelineAsyncCallbackTask::Finish() { void CreateRenderPipelineAsyncCallbackTask::FinishImpl() {
ASSERT(mCreateRenderPipelineAsyncCallback != nullptr); ASSERT(mCreateRenderPipelineAsyncCallback != nullptr);
mCreateRenderPipelineAsyncCallback(mStatus, ToAPI(mPipeline.Detach()), mErrorMessage.c_str(), mCreateRenderPipelineAsyncCallback(mStatus, ToAPI(mPipeline.Detach()), mErrorMessage.c_str(),
mUserData); mUserData);
} }
void CreateRenderPipelineAsyncCallbackTask::HandleShutDown() { void CreateRenderPipelineAsyncCallbackTask::HandleShutDownImpl() {
ASSERT(mCreateRenderPipelineAsyncCallback != nullptr); ASSERT(mCreateRenderPipelineAsyncCallback != nullptr);
mCreateRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr, mCreateRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceDestroyed, nullptr,
"Device destroyed before callback", mUserData); "Device destroyed before callback", mUserData);
} }
void CreateRenderPipelineAsyncCallbackTask::HandleDeviceLoss() { void CreateRenderPipelineAsyncCallbackTask::HandleDeviceLossImpl() {
ASSERT(mCreateRenderPipelineAsyncCallback != nullptr); ASSERT(mCreateRenderPipelineAsyncCallback != nullptr);
mCreateRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr, mCreateRenderPipelineAsyncCallback(WGPUCreatePipelineAsyncStatus_DeviceLost, nullptr,
"Device lost before callback", mUserData); "Device lost before callback", mUserData);

View File

@ -55,11 +55,11 @@ struct CreateComputePipelineAsyncCallbackTask : CreatePipelineAsyncCallbackTaskB
void* userdata); void* userdata);
~CreateComputePipelineAsyncCallbackTask() override; ~CreateComputePipelineAsyncCallbackTask() override;
void Finish() override;
void HandleShutDown() final;
void HandleDeviceLoss() final;
protected: protected:
void FinishImpl() override;
void HandleShutDownImpl() final;
void HandleDeviceLossImpl() final;
Ref<ComputePipelineBase> mPipeline; Ref<ComputePipelineBase> mPipeline;
WGPUCreateComputePipelineAsyncCallback mCreateComputePipelineAsyncCallback; WGPUCreateComputePipelineAsyncCallback mCreateComputePipelineAsyncCallback;
}; };
@ -74,11 +74,11 @@ struct CreateRenderPipelineAsyncCallbackTask : CreatePipelineAsyncCallbackTaskBa
void* userdata); void* userdata);
~CreateRenderPipelineAsyncCallbackTask() override; ~CreateRenderPipelineAsyncCallbackTask() override;
void Finish() override;
void HandleShutDown() final;
void HandleDeviceLoss() final;
protected: protected:
void FinishImpl() override;
void HandleShutDownImpl() final;
void HandleDeviceLossImpl() final;
Ref<RenderPipelineBase> mPipeline; Ref<RenderPipelineBase> mPipeline;
WGPUCreateRenderPipelineAsyncCallback mCreateRenderPipelineAsyncCallback; WGPUCreateRenderPipelineAsyncCallback mCreateRenderPipelineAsyncCallback;
}; };

View File

@ -103,20 +103,20 @@ struct LoggingCallbackTask : CallbackTask {
mLoggingType(loggingType), mLoggingType(loggingType),
mMessage(message), mMessage(message),
mUserdata(userdata) { mUserdata(userdata) {
// Since the Finish() will be called in uncertain future in which time the message // Since the FinishImpl() will be called in uncertain future in which time the message
// may already disposed, we must keep a local copy in the CallbackTask. // may already disposed, we must keep a local copy in the CallbackTask.
} }
void Finish() override { mCallback(mLoggingType, mMessage.c_str(), mUserdata); } private:
void FinishImpl() override { mCallback(mLoggingType, mMessage.c_str(), mUserdata); }
void HandleShutDown() override { void HandleShutDownImpl() override {
// Do the logging anyway // Do the logging anyway
mCallback(mLoggingType, mMessage.c_str(), mUserdata); mCallback(mLoggingType, mMessage.c_str(), mUserdata);
} }
void HandleDeviceLoss() override { mCallback(mLoggingType, mMessage.c_str(), mUserdata); } void HandleDeviceLossImpl() override { mCallback(mLoggingType, mMessage.c_str(), mUserdata); }
private:
// As all deferred callback tasks will be triggered before modifying the registered // As all deferred callback tasks will be triggered before modifying the registered
// callback or shutting down, we are ensured that callback function and userdata pointer // callback or shutting down, we are ensured that callback function and userdata pointer
// stored in tasks is valid when triggered. // stored in tasks is valid when triggered.
@ -305,6 +305,9 @@ void DeviceBase::WillDropLastExternalRef() {
// out all remaining refs. // out all remaining refs.
Destroy(); Destroy();
// Flush last remaining callback tasks.
FlushCallbackTaskQueue();
// Drop te device's reference to the queue. Because the application dropped the last external // Drop te device's reference to the queue. Because the application dropped the last external
// references, they can no longer get the queue from APIGetQueue(). // references, they can no longer get the queue from APIGetQueue().
mQueue = nullptr; mQueue = nullptr;
@ -381,18 +384,16 @@ void DeviceBase::Destroy() {
if (mState != State::BeingCreated) { if (mState != State::BeingCreated) {
// The device is being destroyed so it will be lost, call the application callback. // The device is being destroyed so it will be lost, call the application callback.
if (mDeviceLostCallback != nullptr) { if (mDeviceLostCallback != nullptr) {
mDeviceLostCallback(WGPUDeviceLostReason_Destroyed, "Device was destroyed.", mCallbackTaskManager->AddCallbackTask(
mDeviceLostUserdata); std::bind(mDeviceLostCallback, WGPUDeviceLostReason_Destroyed,
"Device was destroyed.", mDeviceLostUserdata));
mDeviceLostCallback = nullptr; mDeviceLostCallback = nullptr;
} }
// Call all the callbacks immediately as the device is about to shut down. // Call all the callbacks immediately as the device is about to shut down.
// TODO(crbug.com/dawn/826): Cancel the tasks that are in flight if possible. // TODO(crbug.com/dawn/826): Cancel the tasks that are in flight if possible.
mAsyncTaskManager->WaitAllPendingTasks(); mAsyncTaskManager->WaitAllPendingTasks();
auto callbackTasks = mCallbackTaskManager->AcquireCallbackTasks(); mCallbackTaskManager->HandleShutDown();
for (std::unique_ptr<CallbackTask>& callbackTask : callbackTasks) {
callbackTask->HandleShutDown();
}
} }
// Disconnect the device, depending on which state we are currently in. // Disconnect the device, depending on which state we are currently in.
@ -436,9 +437,6 @@ void DeviceBase::Destroy() {
// Call TickImpl once last time to clean up resources // Call TickImpl once last time to clean up resources
// Ignore errors so that we can continue with destruction // Ignore errors so that we can continue with destruction
IgnoreErrors(TickImpl()); IgnoreErrors(TickImpl());
// Trigger all in-flight TrackTask callbacks from 'mQueue'.
FlushCallbackTaskQueue();
} }
// At this point GPU operations are always finished, so we are in the disconnected state. // At this point GPU operations are always finished, so we are in the disconnected state.
@ -512,11 +510,15 @@ void DeviceBase::HandleError(std::unique_ptr<ErrorData> error,
// TODO(lokokung) Update call sites that take the c-string to take string_view. // TODO(lokokung) Update call sites that take the c-string to take string_view.
const std::string messageStr = error->GetFormattedMessage(); const std::string messageStr = error->GetFormattedMessage();
const char* message = messageStr.c_str();
if (type == InternalErrorType::DeviceLost) { if (type == InternalErrorType::DeviceLost) {
// The device was lost, call the application callback. // The device was lost, schedule the application callback's executation.
// Note: we don't invoke the callbacks directly here because it could cause re-entrances ->
// possible deadlock.
if (mDeviceLostCallback != nullptr) { if (mDeviceLostCallback != nullptr) {
mDeviceLostCallback(lost_reason, message, mDeviceLostUserdata); mCallbackTaskManager->AddCallbackTask([callback = mDeviceLostCallback, lost_reason,
messageStr, userdata = mDeviceLostUserdata] {
callback(lost_reason, messageStr.c_str(), userdata);
});
mDeviceLostCallback = nullptr; mDeviceLostCallback = nullptr;
} }
@ -524,21 +526,22 @@ void DeviceBase::HandleError(std::unique_ptr<ErrorData> error,
// TODO(crbug.com/dawn/826): Cancel the tasks that are in flight if possible. // TODO(crbug.com/dawn/826): Cancel the tasks that are in flight if possible.
mAsyncTaskManager->WaitAllPendingTasks(); mAsyncTaskManager->WaitAllPendingTasks();
auto callbackTasks = mCallbackTaskManager->AcquireCallbackTasks(); mCallbackTaskManager->HandleDeviceLoss();
for (std::unique_ptr<CallbackTask>& callbackTask : callbackTasks) {
callbackTask->HandleDeviceLoss();
}
// Still forward device loss errors to the error scopes so they all reject. // Still forward device loss errors to the error scopes so they all reject.
mErrorScopeStack->HandleError(ToWGPUErrorType(type), message); mErrorScopeStack->HandleError(ToWGPUErrorType(type), messageStr.c_str());
} else { } else {
// Pass the error to the error scope stack and call the uncaptured error callback // Pass the error to the error scope stack and call the uncaptured error callback
// if it isn't handled. DeviceLost is not handled here because it should be // if it isn't handled. DeviceLost is not handled here because it should be
// handled by the lost callback. // handled by the lost callback.
bool captured = mErrorScopeStack->HandleError(ToWGPUErrorType(type), message); bool captured = mErrorScopeStack->HandleError(ToWGPUErrorType(type), messageStr.c_str());
if (!captured && mUncapturedErrorCallback != nullptr) { if (!captured && mUncapturedErrorCallback != nullptr) {
mUncapturedErrorCallback(static_cast<WGPUErrorType>(ToWGPUErrorType(type)), message, mCallbackTaskManager->AddCallbackTask([callback = mUncapturedErrorCallback, type,
mUncapturedErrorUserdata); messageStr,
userdata = mUncapturedErrorUserdata] {
callback(static_cast<WGPUErrorType>(ToWGPUErrorType(type)), messageStr.c_str(),
userdata);
});
} }
} }
} }
@ -607,15 +610,21 @@ bool DeviceBase::APIPopErrorScope(wgpu::ErrorCallback callback, void* userdata)
} }
// TODO(crbug.com/dawn/1122): Call callbacks only on wgpuInstanceProcessEvents // TODO(crbug.com/dawn/1122): Call callbacks only on wgpuInstanceProcessEvents
if (IsLost()) { if (IsLost()) {
callback(WGPUErrorType_DeviceLost, "GPU device disconnected", userdata); mCallbackTaskManager->AddCallbackTask(
std::bind(callback, WGPUErrorType_DeviceLost, "GPU device disconnected", userdata));
return returnValue; return returnValue;
} }
if (mErrorScopeStack->Empty()) { if (mErrorScopeStack->Empty()) {
callback(WGPUErrorType_Unknown, "No error scopes to pop", userdata); mCallbackTaskManager->AddCallbackTask(
std::bind(callback, WGPUErrorType_Unknown, "No error scopes to pop", userdata));
return returnValue; return returnValue;
} }
ErrorScope scope = mErrorScopeStack->Pop(); ErrorScope scope = mErrorScopeStack->Pop();
callback(static_cast<WGPUErrorType>(scope.GetErrorType()), scope.GetErrorMessage(), userdata); mCallbackTaskManager->AddCallbackTask(
[callback, errorType = static_cast<WGPUErrorType>(scope.GetErrorType()),
message = scope.GetErrorMessage(),
userdata] { callback(errorType, message.c_str(), userdata); });
return returnValue; return returnValue;
} }
@ -710,11 +719,12 @@ void DeviceBase::AssumeCommandsComplete() {
mCompletedSerial = mLastSubmittedSerial; mCompletedSerial = mLastSubmittedSerial;
} }
bool DeviceBase::IsDeviceIdle() { bool DeviceBase::HasPendingTasks() {
if (mAsyncTaskManager->HasPendingTasks()) { return mAsyncTaskManager->HasPendingTasks() || !mCallbackTaskManager->IsEmpty();
return false;
} }
if (!mCallbackTaskManager->IsEmpty()) {
bool DeviceBase::IsDeviceIdle() {
if (HasPendingTasks()) {
return false; return false;
} }
return !HasScheduledCommands(); return !HasScheduledCommands();
@ -1089,7 +1099,10 @@ void DeviceBase::APICreateComputePipelineAsync(const ComputePipelineDescriptor*
WGPUCreatePipelineAsyncStatus status = WGPUCreatePipelineAsyncStatus status =
CreatePipelineAsyncStatusFromErrorType(error->GetType()); CreatePipelineAsyncStatusFromErrorType(error->GetType());
// TODO(crbug.com/dawn/1122): Call callbacks only on wgpuInstanceProcessEvents // TODO(crbug.com/dawn/1122): Call callbacks only on wgpuInstanceProcessEvents
callback(status, nullptr, error->GetMessage().c_str(), userdata); mCallbackTaskManager->AddCallbackTask(
[callback, status, message = error->GetMessage(), userdata] {
callback(status, nullptr, message.c_str(), userdata);
});
} }
} }
PipelineLayoutBase* DeviceBase::APICreatePipelineLayout( PipelineLayoutBase* DeviceBase::APICreatePipelineLayout(
@ -1133,7 +1146,10 @@ void DeviceBase::APICreateRenderPipelineAsync(const RenderPipelineDescriptor* de
WGPUCreatePipelineAsyncStatus status = WGPUCreatePipelineAsyncStatus status =
CreatePipelineAsyncStatusFromErrorType(error->GetType()); CreatePipelineAsyncStatusFromErrorType(error->GetType());
// TODO(crbug.com/dawn/1122): Call callbacks only on wgpuInstanceProcessEvents // TODO(crbug.com/dawn/1122): Call callbacks only on wgpuInstanceProcessEvents
callback(status, nullptr, error->GetMessage().c_str(), userdata); mCallbackTaskManager->AddCallbackTask(
[callback, status, message = error->GetMessage(), userdata] {
callback(status, nullptr, message.c_str(), userdata);
});
} }
} }
RenderBundleEncoder* DeviceBase::APICreateRenderBundleEncoder( RenderBundleEncoder* DeviceBase::APICreateRenderBundleEncoder(
@ -1239,10 +1255,21 @@ bool DeviceBase::APITick() {
// Tick may trigger callbacks which drop a ref to the device itself. Hold a Ref to ourselves // Tick may trigger callbacks which drop a ref to the device itself. Hold a Ref to ourselves
// to avoid deleting |this| in the middle of this function call. // to avoid deleting |this| in the middle of this function call.
Ref<DeviceBase> self(this); Ref<DeviceBase> self(this);
if (IsLost() || ConsumedError(Tick())) { if (ConsumedError(Tick())) {
FlushCallbackTaskQueue();
return false; return false;
} }
// We have to check callback tasks in every APITick because it is not related to any global
// serials.
FlushCallbackTaskQueue();
// We don't throw an error when device is lost. This allows pending callbacks to be
// executed even after the Device is lost/destroyed.
if (IsLost()) {
return HasPendingTasks();
}
TRACE_EVENT1(GetPlatform(), General, "DeviceBase::APITick::IsDeviceIdle", "isDeviceIdle", TRACE_EVENT1(GetPlatform(), General, "DeviceBase::APITick::IsDeviceIdle", "isDeviceIdle",
IsDeviceIdle()); IsDeviceIdle());
@ -1250,12 +1277,13 @@ bool DeviceBase::APITick() {
} }
MaybeError DeviceBase::Tick() { MaybeError DeviceBase::Tick() {
DAWN_TRY(ValidateIsAlive()); if (IsLost() || !HasScheduledCommands()) {
return {};
}
// To avoid overly ticking, we only want to tick when: // To avoid overly ticking, we only want to tick when:
// 1. the last submitted serial has moved beyond the completed serial // 1. the last submitted serial has moved beyond the completed serial
// 2. or the backend still has pending commands to submit. // 2. or the backend still has pending commands to submit.
if (HasScheduledCommands()) {
DAWN_TRY(CheckPassedSerials()); DAWN_TRY(CheckPassedSerials());
DAWN_TRY(TickImpl()); DAWN_TRY(TickImpl());
@ -1264,11 +1292,6 @@ MaybeError DeviceBase::Tick() {
// reclaiming resources one tick earlier. // reclaiming resources one tick earlier.
mDynamicUploader->Deallocate(mCompletedSerial); mDynamicUploader->Deallocate(mCompletedSerial);
mQueue->Tick(mCompletedSerial); mQueue->Tick(mCompletedSerial);
}
// We have to check callback tasks in every Tick because it is not related to any global
// serials.
FlushCallbackTaskQueue();
return {}; return {};
} }
@ -1517,8 +1540,9 @@ MaybeError DeviceBase::CreateComputePipelineAsync(const ComputePipelineDescripto
GetCachedComputePipeline(uninitializedComputePipeline.Get()); GetCachedComputePipeline(uninitializedComputePipeline.Get());
if (cachedComputePipeline.Get() != nullptr) { if (cachedComputePipeline.Get() != nullptr) {
// TODO(crbug.com/dawn/1122): Call callbacks only on wgpuInstanceProcessEvents // TODO(crbug.com/dawn/1122): Call callbacks only on wgpuInstanceProcessEvents
callback(WGPUCreatePipelineAsyncStatus_Success, ToAPI(cachedComputePipeline.Detach()), "", mCallbackTaskManager->AddCallbackTask(
userdata); std::bind(callback, WGPUCreatePipelineAsyncStatus_Success,
ToAPI(cachedComputePipeline.Detach()), "", userdata));
} else { } else {
// Otherwise we will create the pipeline object in InitializeComputePipelineAsyncImpl(), // Otherwise we will create the pipeline object in InitializeComputePipelineAsyncImpl(),
// where the pipeline object may be initialized asynchronously and the result will be // where the pipeline object may be initialized asynchronously and the result will be
@ -1658,8 +1682,9 @@ MaybeError DeviceBase::CreateRenderPipelineAsync(const RenderPipelineDescriptor*
GetCachedRenderPipeline(uninitializedRenderPipeline.Get()); GetCachedRenderPipeline(uninitializedRenderPipeline.Get());
if (cachedRenderPipeline != nullptr) { if (cachedRenderPipeline != nullptr) {
// TODO(crbug.com/dawn/1122): Call callbacks only on wgpuInstanceProcessEvents // TODO(crbug.com/dawn/1122): Call callbacks only on wgpuInstanceProcessEvents
callback(WGPUCreatePipelineAsyncStatus_Success, ToAPI(cachedRenderPipeline.Detach()), "", mCallbackTaskManager->AddCallbackTask(
userdata); std::bind(callback, WGPUCreatePipelineAsyncStatus_Success,
ToAPI(cachedRenderPipeline.Detach()), "", userdata));
} else { } else {
// Otherwise we will create the pipeline object in InitializeRenderPipelineAsyncImpl(), // Otherwise we will create the pipeline object in InitializeRenderPipelineAsyncImpl(),
// where the pipeline object may be initialized asynchronously and the result will be // where the pipeline object may be initialized asynchronously and the result will be
@ -1784,7 +1809,7 @@ void DeviceBase::FlushCallbackTaskQueue() {
// update mCallbackTaskManager, then call all the callbacks. // update mCallbackTaskManager, then call all the callbacks.
auto callbackTasks = mCallbackTaskManager->AcquireCallbackTasks(); auto callbackTasks = mCallbackTaskManager->AcquireCallbackTasks();
for (std::unique_ptr<CallbackTask>& callbackTask : callbackTasks) { for (std::unique_ptr<CallbackTask>& callbackTask : callbackTasks) {
callbackTask->Finish(); callbackTask->Execute();
} }
} }
} }
@ -1814,14 +1839,14 @@ void DeviceBase::AddComputePipelineAsyncCallbackTask(
struct CreateComputePipelineAsyncWaitableCallbackTask final struct CreateComputePipelineAsyncWaitableCallbackTask final
: CreateComputePipelineAsyncCallbackTask { : CreateComputePipelineAsyncCallbackTask {
using CreateComputePipelineAsyncCallbackTask::CreateComputePipelineAsyncCallbackTask; using CreateComputePipelineAsyncCallbackTask::CreateComputePipelineAsyncCallbackTask;
void Finish() final { void FinishImpl() final {
// TODO(dawn:529): call AddOrGetCachedComputePipeline() asynchronously in // TODO(dawn:529): call AddOrGetCachedComputePipeline() asynchronously in
// CreateComputePipelineAsyncTaskImpl::Run() when the front-end pipeline cache is // CreateComputePipelineAsyncTaskImpl::Run() when the front-end pipeline cache is
// thread-safe. // thread-safe.
ASSERT(mPipeline != nullptr); ASSERT(mPipeline != nullptr);
mPipeline = mPipeline->GetDevice()->AddOrGetCachedComputePipeline(mPipeline); mPipeline = mPipeline->GetDevice()->AddOrGetCachedComputePipeline(mPipeline);
CreateComputePipelineAsyncCallbackTask::Finish(); CreateComputePipelineAsyncCallbackTask::FinishImpl();
} }
}; };
@ -1839,7 +1864,7 @@ void DeviceBase::AddRenderPipelineAsyncCallbackTask(Ref<RenderPipelineBase> pipe
: CreateRenderPipelineAsyncCallbackTask { : CreateRenderPipelineAsyncCallbackTask {
using CreateRenderPipelineAsyncCallbackTask::CreateRenderPipelineAsyncCallbackTask; using CreateRenderPipelineAsyncCallbackTask::CreateRenderPipelineAsyncCallbackTask;
void Finish() final { void FinishImpl() final {
// TODO(dawn:529): call AddOrGetCachedRenderPipeline() asynchronously in // TODO(dawn:529): call AddOrGetCachedRenderPipeline() asynchronously in
// CreateRenderPipelineAsyncTaskImpl::Run() when the front-end pipeline cache is // CreateRenderPipelineAsyncTaskImpl::Run() when the front-end pipeline cache is
// thread-safe. // thread-safe.
@ -1847,7 +1872,7 @@ void DeviceBase::AddRenderPipelineAsyncCallbackTask(Ref<RenderPipelineBase> pipe
mPipeline = mPipeline->GetDevice()->AddOrGetCachedRenderPipeline(mPipeline); mPipeline = mPipeline->GetDevice()->AddOrGetCachedRenderPipeline(mPipeline);
} }
CreateRenderPipelineAsyncCallbackTask::Finish(); CreateRenderPipelineAsyncCallbackTask::FinishImpl();
} }
}; };

View File

@ -515,6 +515,7 @@ class DeviceBase : public RefCountedWithExternalCount {
// and waiting on a serial that doesn't have a corresponding fence enqueued. Fake serials to // and waiting on a serial that doesn't have a corresponding fence enqueued. Fake serials to
// make all commands look completed. // make all commands look completed.
void AssumeCommandsComplete(); void AssumeCommandsComplete();
bool HasPendingTasks();
bool IsDeviceIdle(); bool IsDeviceIdle();
// mCompletedSerial tracks the last completed command serial that the fence has returned. // mCompletedSerial tracks the last completed command serial that the fence has returned.

View File

@ -43,8 +43,8 @@ wgpu::ErrorType ErrorScope::GetErrorType() const {
return mCapturedError; return mCapturedError;
} }
const char* ErrorScope::GetErrorMessage() const { const std::string& ErrorScope::GetErrorMessage() const {
return mErrorMessage.c_str(); return mErrorMessage;
} }
ErrorScopeStack::ErrorScopeStack() = default; ErrorScopeStack::ErrorScopeStack() = default;

View File

@ -25,7 +25,7 @@ namespace dawn::native {
class ErrorScope { class ErrorScope {
public: public:
wgpu::ErrorType GetErrorType() const; wgpu::ErrorType GetErrorType() const;
const char* GetErrorMessage() const; const std::string& GetErrorMessage() const;
private: private:
friend class ErrorScopeStack; friend class ErrorScopeStack;

View File

@ -135,7 +135,10 @@ struct SubmittedWorkDone : TrackTaskCallback {
WGPUQueueWorkDoneCallback callback, WGPUQueueWorkDoneCallback callback,
void* userdata) void* userdata)
: TrackTaskCallback(platform), mCallback(callback), mUserdata(userdata) {} : TrackTaskCallback(platform), mCallback(callback), mUserdata(userdata) {}
void Finish() override { ~SubmittedWorkDone() override = default;
private:
void FinishImpl() override {
ASSERT(mCallback != nullptr); ASSERT(mCallback != nullptr);
ASSERT(mSerial != kMaxExecutionSerial); ASSERT(mSerial != kMaxExecutionSerial);
TRACE_EVENT1(mPlatform, General, "Queue::SubmittedWorkDone::Finished", "serial", TRACE_EVENT1(mPlatform, General, "Queue::SubmittedWorkDone::Finished", "serial",
@ -143,15 +146,13 @@ struct SubmittedWorkDone : TrackTaskCallback {
mCallback(WGPUQueueWorkDoneStatus_Success, mUserdata); mCallback(WGPUQueueWorkDoneStatus_Success, mUserdata);
mCallback = nullptr; mCallback = nullptr;
} }
void HandleDeviceLoss() override { void HandleDeviceLossImpl() override {
ASSERT(mCallback != nullptr); ASSERT(mCallback != nullptr);
mCallback(WGPUQueueWorkDoneStatus_DeviceLost, mUserdata); mCallback(WGPUQueueWorkDoneStatus_DeviceLost, mUserdata);
mCallback = nullptr; mCallback = nullptr;
} }
void HandleShutDown() override { HandleDeviceLoss(); } void HandleShutDownImpl() override { HandleDeviceLossImpl(); }
~SubmittedWorkDone() override = default;
private:
WGPUQueueWorkDoneCallback mCallback = nullptr; WGPUQueueWorkDoneCallback mCallback = nullptr;
void* mUserdata; void* mUserdata;
}; };
@ -207,7 +208,8 @@ void QueueBase::APIOnSubmittedWorkDone(uint64_t signalValue,
// The error status depends on the type of error so we let the validation function choose it // The error status depends on the type of error so we let the validation function choose it
WGPUQueueWorkDoneStatus status; WGPUQueueWorkDoneStatus status;
if (GetDevice()->ConsumedError(ValidateOnSubmittedWorkDone(signalValue, &status))) { if (GetDevice()->ConsumedError(ValidateOnSubmittedWorkDone(signalValue, &status))) {
callback(status, userdata); GetDevice()->GetCallbackTaskManager()->AddCallbackTask(
[callback, status, userdata] { callback(status, userdata); });
return; return;
} }
@ -272,7 +274,8 @@ void QueueBase::Tick(ExecutionSerial finishedSerial) {
void QueueBase::HandleDeviceLoss() { void QueueBase::HandleDeviceLoss() {
for (auto& task : mTasksInFlight.IterateAll()) { for (auto& task : mTasksInFlight.IterateAll()) {
task->HandleDeviceLoss(); task->OnDeviceLoss();
GetDevice()->GetCallbackTaskManager()->AddCallbackTask(std::move(task));
} }
mTasksInFlight.Clear(); mTasksInFlight.Clear();
} }

View File

@ -1037,6 +1037,7 @@ void DawnTestBase::LoseDeviceForTesting(wgpu::Device device) {
Call(WGPUDeviceLostReason_Undefined, testing::_, resolvedDevice.Get())) Call(WGPUDeviceLostReason_Undefined, testing::_, resolvedDevice.Get()))
.Times(1); .Times(1);
resolvedDevice.ForceLoss(wgpu::DeviceLostReason::Undefined, "Device lost for testing"); resolvedDevice.ForceLoss(wgpu::DeviceLostReason::Undefined, "Device lost for testing");
resolvedDevice.Tick();
} }
std::ostringstream& DawnTestBase::AddBufferExpectation(const char* file, std::ostringstream& DawnTestBase::AddBufferExpectation(const char* file,

View File

@ -109,6 +109,7 @@
EXPECT_CALL(mDeviceErrorCallback, \ EXPECT_CALL(mDeviceErrorCallback, \
Call(testing::Ne(WGPUErrorType_NoError), matcher, device.Get())); \ Call(testing::Ne(WGPUErrorType_NoError), matcher, device.Get())); \
statement; \ statement; \
device.Tick(); \
FlushWire(); \ FlushWire(); \
testing::Mock::VerifyAndClearExpectations(&mDeviceErrorCallback); \ testing::Mock::VerifyAndClearExpectations(&mDeviceErrorCallback); \
do { \ do { \

View File

@ -189,7 +189,10 @@ TEST_P(DeviceLifetimeTests, DroppedBeforeMappedAtCreationBuffer) {
// Test that the device can be dropped before a buffer created from it, then mapping the buffer // Test that the device can be dropped before a buffer created from it, then mapping the buffer
// fails. // fails.
TEST_P(DeviceLifetimeTests, DroppedThenMapBuffer) { // TODO(crbug.com/dawn/752): Re-enable this test once we implement Instance.ProcessEvents().
// Currently the callbacks are called inside Device.Tick() only. However, since we drop the device,
// there is no way to call Device.Tick() anymore.
TEST_P(DeviceLifetimeTests, DISABLED_DroppedThenMapBuffer) {
wgpu::BufferDescriptor desc = {}; wgpu::BufferDescriptor desc = {};
desc.size = 4; desc.size = 4;
desc.usage = wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst; desc.usage = wgpu::BufferUsage::MapRead | wgpu::BufferUsage::CopyDst;

View File

@ -202,6 +202,7 @@ TEST_P(MaxLimitTests, MaxBufferBindingSize) {
device.PopErrorScope([](WGPUErrorType type, const char*, device.PopErrorScope([](WGPUErrorType type, const char*,
void* userdata) { *static_cast<WGPUErrorType*>(userdata) = type; }, void* userdata) { *static_cast<WGPUErrorType*>(userdata) = type; },
&oomResult); &oomResult);
device.Tick();
FlushWire(); FlushWire();
// Max buffer size is smaller than the max buffer binding size. // Max buffer size is smaller than the max buffer binding size.
DAWN_TEST_UNSUPPORTED_IF(oomResult == WGPUErrorType_OutOfMemory); DAWN_TEST_UNSUPPORTED_IF(oomResult == WGPUErrorType_OutOfMemory);

View File

@ -218,14 +218,3 @@ TEST_F(DeviceTickValidationTest, DestroyDeviceBeforeAPITick) {
device.Destroy(); device.Destroy();
device.Tick(); device.Tick();
} }
// Device destroy before an internal Tick should return an error.
TEST_F(DeviceTickValidationTest, DestroyDeviceBeforeInternalTick) {
DAWN_SKIP_TEST_IF(UsesWire());
ExpectDeviceDestruction();
device.Destroy();
dawn::native::DeviceBase* nativeDevice = dawn::native::FromAPI(device.Get());
ASSERT_DEVICE_ERROR(EXPECT_TRUE(nativeDevice->ConsumedError(nativeDevice->Tick())),
HasSubstr("[Device] is lost."));
}

View File

@ -45,6 +45,12 @@ static void ToMockQueueWorkDone(WGPUQueueWorkDoneStatus status, void* userdata)
} }
class ErrorScopeValidationTest : public ValidationTest { class ErrorScopeValidationTest : public ValidationTest {
protected:
void FlushWireAndTick() {
FlushWire();
device.Tick();
}
private: private:
void SetUp() override { void SetUp() override {
ValidationTest::SetUp(); ValidationTest::SetUp();
@ -67,7 +73,7 @@ TEST_F(ErrorScopeValidationTest, Success) {
EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_NoError, _, this)).Times(1); EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_NoError, _, this)).Times(1);
device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this); device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this);
FlushWire(); FlushWireAndTick();
} }
// Test the simple case where the error scope catches an error. // Test the simple case where the error scope catches an error.
@ -80,7 +86,7 @@ TEST_F(ErrorScopeValidationTest, CatchesError) {
EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_Validation, _, this)).Times(1); EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_Validation, _, this)).Times(1);
device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this); device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this);
FlushWire(); FlushWireAndTick();
} }
// Test that errors bubble to the parent scope if not handled by the current scope. // Test that errors bubble to the parent scope if not handled by the current scope.
@ -95,13 +101,13 @@ TEST_F(ErrorScopeValidationTest, ErrorBubbles) {
// OutOfMemory does not match Validation error. // OutOfMemory does not match Validation error.
EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_NoError, _, this)).Times(1); EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_NoError, _, this)).Times(1);
device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this); device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this);
FlushWire(); FlushWireAndTick();
// Parent validation error scope captures the error. // Parent validation error scope captures the error.
EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_Validation, _, this + 1)) EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_Validation, _, this + 1))
.Times(1); .Times(1);
device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this + 1); device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this + 1);
FlushWire(); FlushWireAndTick();
} }
// Test that if an error scope matches an error, it does not bubble to the parent scope. // Test that if an error scope matches an error, it does not bubble to the parent scope.
@ -116,13 +122,13 @@ TEST_F(ErrorScopeValidationTest, HandledErrorsStopBubbling) {
// Inner scope catches the error. // Inner scope catches the error.
EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_Validation, _, this)).Times(1); EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_Validation, _, this)).Times(1);
device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this); device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this);
FlushWire(); FlushWireAndTick();
// Parent scope does not see the error. // Parent scope does not see the error.
EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_NoError, _, this + 1)) EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_NoError, _, this + 1))
.Times(1); .Times(1);
device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this + 1); device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this + 1);
FlushWire(); FlushWireAndTick();
} }
// Test that if no error scope handles an error, it goes to the device UncapturedError callback // Test that if no error scope handles an error, it goes to the device UncapturedError callback
@ -135,7 +141,7 @@ TEST_F(ErrorScopeValidationTest, UnhandledErrorsMatchUncapturedErrorCallback) {
EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_NoError, _, this)).Times(1); EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_NoError, _, this)).Times(1);
device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this); device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this);
FlushWire(); FlushWireAndTick();
} }
// Check that push/popping error scopes must be balanced. // Check that push/popping error scopes must be balanced.
@ -145,6 +151,7 @@ TEST_F(ErrorScopeValidationTest, PushPopBalanced) {
EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_Unknown, _, this)) EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_Unknown, _, this))
.Times(1); .Times(1);
device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this); device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this);
FlushWireAndTick();
} }
// Too many pops // Too many pops
{ {
@ -153,11 +160,12 @@ TEST_F(ErrorScopeValidationTest, PushPopBalanced) {
EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_NoError, _, this + 1)) EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_NoError, _, this + 1))
.Times(1); .Times(1);
device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this + 1); device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this + 1);
FlushWire(); FlushWireAndTick();
EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_Unknown, _, this + 2)) EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_Unknown, _, this + 2))
.Times(1); .Times(1);
device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this + 2); device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this + 2);
FlushWireAndTick();
} }
} }
@ -221,10 +229,11 @@ TEST_F(ErrorScopeValidationTest, DeviceDestroyedBeforePop) {
device.PushErrorScope(wgpu::ErrorFilter::Validation); device.PushErrorScope(wgpu::ErrorFilter::Validation);
ExpectDeviceDestruction(); ExpectDeviceDestruction();
device.Destroy(); device.Destroy();
FlushWire(); FlushWireAndTick();
EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_DeviceLost, _, this)).Times(1); EXPECT_CALL(*mockDevicePopErrorScopeCallback, Call(WGPUErrorType_DeviceLost, _, this)).Times(1);
device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this); device.PopErrorScope(ToMockDevicePopErrorScopeCallback, this);
FlushWireAndTick();
} }
// Regression test that on device shutdown, we don't get a recursion in O(pushed error scope) that // Regression test that on device shutdown, we don't get a recursion in O(pushed error scope) that

View File

@ -47,6 +47,7 @@
#define ASSERT_DEVICE_ERROR_IMPL_1_(statement) \ #define ASSERT_DEVICE_ERROR_IMPL_1_(statement) \
StartExpectDeviceError(); \ StartExpectDeviceError(); \
statement; \ statement; \
device.Tick(); \
FlushWire(); \ FlushWire(); \
if (!EndExpectDeviceError()) { \ if (!EndExpectDeviceError()) { \
FAIL() << "Expected device error in:\n " << #statement; \ FAIL() << "Expected device error in:\n " << #statement; \
@ -57,6 +58,7 @@
#define ASSERT_DEVICE_ERROR_IMPL_2_(statement, matcher) \ #define ASSERT_DEVICE_ERROR_IMPL_2_(statement, matcher) \
StartExpectDeviceError(matcher); \ StartExpectDeviceError(matcher); \
statement; \ statement; \
device.Tick(); \
FlushWire(); \ FlushWire(); \
if (!EndExpectDeviceError()) { \ if (!EndExpectDeviceError()) { \
FAIL() << "Expected device error in:\n " << #statement; \ FAIL() << "Expected device error in:\n " << #statement; \