dawn_wire: Return early in callbacks after the server is destroyed

After the server is destroyed, the server's can't do anything like
forward callbacks to the client. Track this with a weak_ptr and
return early if it has expired.

It also updates device destruction in dawn_native so the lost
callback is always called, even on graceful destruction. This
is consistent with the rest of WebGPU where all callbacks are
guaranteed to be called in finite time.

Bug: chromium:1147416, chromium:1161943
Change-Id: Ib80dea36517401a2b8eafb01ded255ebbe757aef
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/35840
Commit-Queue: Austin Eng <enga@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Jiawei Shao <jiawei.shao@intel.com>
This commit is contained in:
Austin Eng 2021-01-05 08:37:08 +00:00 committed by Commit Bot service account
parent e3fd026108
commit 200941c797
8 changed files with 200 additions and 146 deletions

View File

@ -34,11 +34,31 @@ namespace dawn_wire { namespace server {
auto* deviceData = DeviceObjects().Allocate(1);
deviceData->handle = device;
mProcs.deviceSetUncapturedErrorCallback(device, ForwardUncapturedError, this);
mProcs.deviceSetDeviceLostCallback(device, ForwardDeviceLost, this);
// Note: these callbacks are manually inlined here since they do not acquire and
// free their userdata.
mProcs.deviceSetUncapturedErrorCallback(
device,
[](WGPUErrorType type, const char* message, void* userdata) {
Server* server = static_cast<Server*>(userdata);
server->OnUncapturedError(type, message);
},
this);
mProcs.deviceSetDeviceLostCallback(
device,
[](const char* message, void* userdata) {
Server* server = static_cast<Server*>(userdata);
server->OnDeviceLost(message);
},
this);
}
Server::~Server() {
// Un-set the error and lost callbacks since we cannot forward them
// after the server has been destroyed.
WGPUDevice device = DeviceObjects().Get(1)->handle;
mProcs.deviceSetUncapturedErrorCallback(device, nullptr, nullptr);
mProcs.deviceSetDeviceLostCallback(device, nullptr, nullptr);
DestroyAllObjects(mProcs);
}

View File

@ -23,8 +23,94 @@ namespace dawn_wire { namespace server {
class Server;
class MemoryTransferService;
struct MapUserdata {
Server* server;
// CallbackUserdata and its derived classes are intended to be created by
// Server::MakeUserdata<T> and then passed as the userdata argument for Dawn
// callbacks.
// It contains a pointer back to the Server so that the callback can call the
// Server to perform operations like serialization, and it contains a weak pointer
// |serverIsAlive|. If the weak pointer has expired, it means the server has
// been destroyed and the callback must not use the Server pointer.
// To assist with checking |serverIsAlive| and lifetime management of the userdata,
// |ForwardToServer| (defined later in this file) can be used to acquire the userdata,
// return early if |serverIsAlive| has expired, and then forward the arguments
// to userdata->server->MyCallbackHandler.
//
// Example Usage:
//
// struct MyUserdata : CallbackUserdata { uint32_t foo; };
//
// auto userdata = MakeUserdata<MyUserdata>();
// userdata->foo = 2;
//
// // TODO(enga): Make the template inference for ForwardToServer cleaner with C++17
// callMyCallbackHandler(
// ForwardToServer<decltype(&Server::MyCallbackHandler)>::Func<
// &Server::MyCallbackHandler>(),
// userdata.release());
//
// void Server::MyCallbackHandler(MyUserdata* userdata) { }
struct CallbackUserdata {
Server* const server;
std::weak_ptr<bool> const serverIsAlive;
private:
friend class Server;
CallbackUserdata() = delete;
CallbackUserdata(Server* server, const std::shared_ptr<bool>& serverIsAlive)
: server(server), serverIsAlive(serverIsAlive) {
}
};
template <typename F>
class ForwardToServer;
template <typename R, typename... Args>
class ForwardToServer<R (Server::*)(Args...)> {
private:
// Get the type T of the last argument. It has CallbackUserdata as its base.
using UserdataT = typename std::remove_pointer<typename std::decay<decltype(
std::get<sizeof...(Args) - 1>(std::declval<std::tuple<Args...>>()))>::type>::type;
static_assert(std::is_base_of<CallbackUserdata, UserdataT>::value,
"Last argument of callback handler should derive from CallbackUserdata.");
template <class T, class... Ts>
struct UntypedCallbackImpl;
template <std::size_t... I, class... Ts>
struct UntypedCallbackImpl<std::index_sequence<I...>, Ts...> {
template <R (Server::*Func)(Args...)>
static auto ForwardToServer(
// Unpack and forward the types of the parameter pack.
// Append void* as the last argument.
typename std::tuple_element<I, std::tuple<Ts...>>::type... args,
void* userdata) {
// Acquire the userdata, and cast it to UserdataT.
std::unique_ptr<UserdataT> data(static_cast<UserdataT*>(userdata));
if (data->serverIsAlive.expired()) {
// Do nothing if the server has already been destroyed.
return;
}
// Forward the arguments and the typed userdata to the Server:: member function.
(data->server->*Func)(std::forward<decltype(args)>(args)..., data.get());
}
};
// Generate a free function which has all of the same arguments, except the last
// userdata argument is void* instead of UserdataT*. Dawn's userdata args are void*.
using UntypedCallback =
UntypedCallbackImpl<std::make_index_sequence<sizeof...(Args) - 1>, Args...>;
public:
template <R (Server::*F)(Args...)>
static auto Func() {
return UntypedCallback::template ForwardToServer<F>;
}
};
struct MapUserdata : CallbackUserdata {
using CallbackUserdata::CallbackUserdata;
ObjectHandle buffer;
WGPUBuffer bufferObj;
uint32_t requestSerial;
@ -36,28 +122,31 @@ namespace dawn_wire { namespace server {
std::unique_ptr<MemoryTransferService::WriteHandle> writeHandle = nullptr;
};
struct ErrorScopeUserdata {
Server* server;
struct ErrorScopeUserdata : CallbackUserdata {
using CallbackUserdata::CallbackUserdata;
// TODO(enga): ObjectHandle device;
// when the wire supports multiple devices.
uint64_t requestSerial;
};
struct FenceCompletionUserdata {
Server* server;
struct FenceCompletionUserdata : CallbackUserdata {
using CallbackUserdata::CallbackUserdata;
ObjectHandle fence;
uint64_t value;
};
struct FenceOnCompletionUserdata {
Server* server;
struct FenceOnCompletionUserdata : CallbackUserdata {
using CallbackUserdata::CallbackUserdata;
ObjectHandle fence;
uint64_t requestSerial;
};
struct CreateReadyPipelineUserData {
std::weak_ptr<bool> isServerAlive;
Server* server;
struct CreateReadyPipelineUserData : CallbackUserdata {
using CallbackUserdata::CallbackUserdata;
uint64_t requestSerial;
ObjectId pipelineObjectID;
};
@ -76,6 +165,12 @@ namespace dawn_wire { namespace server {
bool InjectTexture(WGPUTexture texture, uint32_t id, uint32_t generation);
template <typename T,
typename Enable = std::enable_if<std::is_base_of<CallbackUserdata, T>::value>>
std::unique_ptr<T> MakeUserdata() {
return std::unique_ptr<T>(new T(this, mIsAlive));
}
private:
template <typename Cmd>
void SerializeCommand(const Cmd& cmd) {
@ -89,21 +184,6 @@ namespace dawn_wire { namespace server {
mSerializer.SerializeCommand(cmd, extraSize, SerializeExtraSize);
}
// Forwarding callbacks
static void ForwardUncapturedError(WGPUErrorType type, const char* message, void* userdata);
static void ForwardDeviceLost(const char* message, void* userdata);
static void ForwardPopErrorScope(WGPUErrorType type, const char* message, void* userdata);
static void ForwardBufferMapAsync(WGPUBufferMapAsyncStatus status, void* userdata);
static void ForwardFenceCompletedValue(WGPUFenceCompletionStatus status, void* userdata);
static void ForwardFenceOnCompletion(WGPUFenceCompletionStatus status, void* userdata);
static void ForwardCreateReadyComputePipeline(WGPUCreateReadyPipelineStatus status,
WGPUComputePipeline pipeline,
const char* message,
void* userdata);
static void ForwardCreateReadyRenderPipeline(WGPUCreateReadyPipelineStatus status,
WGPURenderPipeline pipeline,
const char* message,
void* userdata);
// Error callbacks
void OnUncapturedError(WGPUErrorType type, const char* message);

View File

@ -77,8 +77,7 @@ namespace dawn_wire { namespace server {
return false;
}
std::unique_ptr<MapUserdata> userdata = std::make_unique<MapUserdata>();
userdata->server = this;
std::unique_ptr<MapUserdata> userdata = MakeUserdata<MapUserdata>();
userdata->buffer = ObjectHandle{bufferId, buffer->generation};
userdata->bufferObj = buffer->handle;
userdata->requestSerial = requestSerial;
@ -112,8 +111,11 @@ namespace dawn_wire { namespace server {
userdata->readHandle = std::unique_ptr<MemoryTransferService::ReadHandle>(readHandle);
}
mProcs.bufferMapAsync(buffer->handle, mode, offset, size, ForwardBufferMapAsync,
userdata.release());
mProcs.bufferMapAsync(
buffer->handle, mode, offset, size,
ForwardToServer<decltype(
&Server::OnBufferMapAsyncCallback)>::Func<&Server::OnBufferMapAsyncCallback>(),
userdata.release());
return true;
}
@ -206,14 +208,7 @@ namespace dawn_wire { namespace server {
static_cast<size_t>(writeFlushInfoLength));
}
void Server::ForwardBufferMapAsync(WGPUBufferMapAsyncStatus status, void* userdata) {
auto data = static_cast<MapUserdata*>(userdata);
data->server->OnBufferMapAsyncCallback(status, data);
}
void Server::OnBufferMapAsyncCallback(WGPUBufferMapAsyncStatus status, MapUserdata* userdata) {
std::unique_ptr<MapUserdata> data(userdata);
void Server::OnBufferMapAsyncCallback(WGPUBufferMapAsyncStatus status, MapUserdata* data) {
// Skip sending the callback if the buffer has already been destroyed.
auto* bufferData = BufferObjects().Get(data->buffer.id);
if (bufferData == nullptr || bufferData->generation != data->buffer.generation) {

View File

@ -16,50 +16,6 @@
namespace dawn_wire { namespace server {
void Server::ForwardUncapturedError(WGPUErrorType type, const char* message, void* userdata) {
auto server = static_cast<Server*>(userdata);
server->OnUncapturedError(type, message);
}
void Server::ForwardDeviceLost(const char* message, void* userdata) {
auto server = static_cast<Server*>(userdata);
server->OnDeviceLost(message);
}
void Server::ForwardCreateReadyComputePipeline(WGPUCreateReadyPipelineStatus status,
WGPUComputePipeline pipeline,
const char* message,
void* userdata) {
std::unique_ptr<CreateReadyPipelineUserData> createReadyPipelineUserData(
static_cast<CreateReadyPipelineUserData*>(userdata));
// We need to ensure createReadyPipelineUserData->server is still pointing to a valid
// object before doing any operations on it.
if (createReadyPipelineUserData->isServerAlive.expired()) {
return;
}
createReadyPipelineUserData->server->OnCreateReadyComputePipelineCallback(
status, pipeline, message, createReadyPipelineUserData.release());
}
void Server::ForwardCreateReadyRenderPipeline(WGPUCreateReadyPipelineStatus status,
WGPURenderPipeline pipeline,
const char* message,
void* userdata) {
std::unique_ptr<CreateReadyPipelineUserData> createReadyPipelineUserData(
static_cast<CreateReadyPipelineUserData*>(userdata));
// We need to ensure createReadyPipelineUserData->server is still pointing to a valid
// object before doing any operations on it.
if (createReadyPipelineUserData->isServerAlive.expired()) {
return;
}
createReadyPipelineUserData->server->OnCreateReadyRenderPipelineCallback(
status, pipeline, message, createReadyPipelineUserData.release());
}
void Server::OnUncapturedError(WGPUErrorType type, const char* message) {
ReturnDeviceUncapturedErrorCallbackCmd cmd;
cmd.type = type;
@ -76,17 +32,32 @@ namespace dawn_wire { namespace server {
}
bool Server::DoDevicePopErrorScope(WGPUDevice cDevice, uint64_t requestSerial) {
ErrorScopeUserdata* userdata = new ErrorScopeUserdata;
userdata->server = this;
auto userdata = MakeUserdata<ErrorScopeUserdata>();
userdata->requestSerial = requestSerial;
bool success = mProcs.devicePopErrorScope(cDevice, ForwardPopErrorScope, userdata);
ErrorScopeUserdata* unownedUserdata = userdata.release();
bool success = mProcs.devicePopErrorScope(
cDevice,
ForwardToServer<decltype(
&Server::OnDevicePopErrorScope)>::Func<&Server::OnDevicePopErrorScope>(),
unownedUserdata);
if (!success) {
delete userdata;
delete unownedUserdata;
}
return success;
}
void Server::OnDevicePopErrorScope(WGPUErrorType type,
const char* message,
ErrorScopeUserdata* userdata) {
ReturnDevicePopErrorScopeCallbackCmd cmd;
cmd.requestSerial = userdata->requestSerial;
cmd.type = type;
cmd.message = message;
SerializeCommand(cmd);
}
bool Server::DoDeviceCreateReadyComputePipeline(
WGPUDevice cDevice,
uint64_t requestSerial,
@ -99,24 +70,22 @@ namespace dawn_wire { namespace server {
resultData->generation = pipelineObjectHandle.generation;
std::unique_ptr<CreateReadyPipelineUserData> userdata =
std::make_unique<CreateReadyPipelineUserData>();
userdata->isServerAlive = mIsAlive;
userdata->server = this;
auto userdata = MakeUserdata<CreateReadyPipelineUserData>();
userdata->requestSerial = requestSerial;
userdata->pipelineObjectID = pipelineObjectHandle.id;
mProcs.deviceCreateReadyComputePipeline(
cDevice, descriptor, ForwardCreateReadyComputePipeline, userdata.release());
cDevice, descriptor,
ForwardToServer<decltype(&Server::OnCreateReadyComputePipelineCallback)>::Func<
&Server::OnCreateReadyComputePipelineCallback>(),
userdata.release());
return true;
}
void Server::OnCreateReadyComputePipelineCallback(WGPUCreateReadyPipelineStatus status,
WGPUComputePipeline pipeline,
const char* message,
CreateReadyPipelineUserData* userdata) {
std::unique_ptr<CreateReadyPipelineUserData> data(userdata);
CreateReadyPipelineUserData* data) {
auto* computePipelineObject = ComputePipelineObjects().Get(data->pipelineObjectID);
ASSERT(computePipelineObject != nullptr);
@ -158,24 +127,22 @@ namespace dawn_wire { namespace server {
resultData->generation = pipelineObjectHandle.generation;
std::unique_ptr<CreateReadyPipelineUserData> userdata =
std::make_unique<CreateReadyPipelineUserData>();
userdata->isServerAlive = mIsAlive;
userdata->server = this;
auto userdata = MakeUserdata<CreateReadyPipelineUserData>();
userdata->requestSerial = requestSerial;
userdata->pipelineObjectID = pipelineObjectHandle.id;
mProcs.deviceCreateReadyRenderPipeline(
cDevice, descriptor, ForwardCreateReadyRenderPipeline, userdata.release());
cDevice, descriptor,
ForwardToServer<decltype(&Server::OnCreateReadyRenderPipelineCallback)>::Func<
&Server::OnCreateReadyRenderPipelineCallback>(),
userdata.release());
return true;
}
void Server::OnCreateReadyRenderPipelineCallback(WGPUCreateReadyPipelineStatus status,
WGPURenderPipeline pipeline,
const char* message,
CreateReadyPipelineUserData* userdata) {
std::unique_ptr<CreateReadyPipelineUserData> data(userdata);
CreateReadyPipelineUserData* data) {
auto* renderPipelineObject = RenderPipelineObjects().Get(data->pipelineObjectID);
ASSERT(renderPipelineObject != nullptr);
@ -206,23 +173,4 @@ namespace dawn_wire { namespace server {
SerializeCommand(cmd);
}
// static
void Server::ForwardPopErrorScope(WGPUErrorType type, const char* message, void* userdata) {
auto* data = reinterpret_cast<ErrorScopeUserdata*>(userdata);
data->server->OnDevicePopErrorScope(type, message, data);
}
void Server::OnDevicePopErrorScope(WGPUErrorType type,
const char* message,
ErrorScopeUserdata* userdata) {
std::unique_ptr<ErrorScopeUserdata> data{userdata};
ReturnDevicePopErrorScopeCallbackCmd cmd;
cmd.requestSerial = data->requestSerial;
cmd.type = type;
cmd.message = message;
SerializeCommand(cmd);
}
}} // namespace dawn_wire::server

View File

@ -18,15 +18,8 @@
namespace dawn_wire { namespace server {
void Server::ForwardFenceCompletedValue(WGPUFenceCompletionStatus status, void* userdata) {
auto data = static_cast<FenceCompletionUserdata*>(userdata);
data->server->OnFenceCompletedValueUpdated(status, data);
}
void Server::OnFenceCompletedValueUpdated(WGPUFenceCompletionStatus status,
FenceCompletionUserdata* userdata) {
std::unique_ptr<FenceCompletionUserdata> data(userdata);
FenceCompletionUserdata* data) {
if (status != WGPUFenceCompletionStatus_Success) {
return;
}
@ -49,25 +42,20 @@ namespace dawn_wire { namespace server {
return false;
}
FenceOnCompletionUserdata* userdata = new FenceOnCompletionUserdata;
userdata->server = this;
auto userdata = MakeUserdata<FenceOnCompletionUserdata>();
userdata->fence = ObjectHandle{fenceId, fence->generation};
userdata->requestSerial = requestSerial;
mProcs.fenceOnCompletion(fence->handle, value, ForwardFenceOnCompletion, userdata);
mProcs.fenceOnCompletion(
fence->handle, value,
ForwardToServer<decltype(
&Server::OnFenceOnCompletion)>::Func<&Server::OnFenceOnCompletion>(),
userdata.release());
return true;
}
// static
void Server::ForwardFenceOnCompletion(WGPUFenceCompletionStatus status, void* userdata) {
auto* data = reinterpret_cast<FenceOnCompletionUserdata*>(userdata);
data->server->OnFenceOnCompletion(status, data);
}
void Server::OnFenceOnCompletion(WGPUFenceCompletionStatus status,
FenceOnCompletionUserdata* userdata) {
std::unique_ptr<FenceOnCompletionUserdata> data{userdata};
FenceOnCompletionUserdata* data) {
ReturnFenceOnCompletionCallbackCmd cmd;
cmd.fence = data->fence;
cmd.requestSerial = data->requestSerial;

View File

@ -28,12 +28,15 @@ namespace dawn_wire { namespace server {
auto* fence = FenceObjects().Get(fenceId);
ASSERT(fence != nullptr);
FenceCompletionUserdata* userdata = new FenceCompletionUserdata;
userdata->server = this;
auto userdata = MakeUserdata<FenceCompletionUserdata>();
userdata->fence = ObjectHandle{fenceId, fence->generation};
userdata->value = signalValue;
mProcs.fenceOnCompletion(cFence, signalValue, ForwardFenceCompletedValue, userdata);
mProcs.fenceOnCompletion(
cFence, signalValue,
ForwardToServer<decltype(&Server::OnFenceCompletedValueUpdated)>::Func<
&Server::OnFenceCompletedValueUpdated>(),
userdata.release());
return true;
}

View File

@ -79,6 +79,12 @@ class WireMultipleDeviceTests : public testing::Test {
~WireHolder() {
mApi.IgnoreAllReleaseCalls();
mWireClient = nullptr;
// These are called on server destruction to clear the callbacks. They must not be
// called after the server is destroyed.
EXPECT_CALL(mApi, OnDeviceSetUncapturedErrorCallback(_, nullptr, nullptr))
.Times(Exactly(1));
EXPECT_CALL(mApi, OnDeviceSetDeviceLostCallback(_, nullptr, nullptr)).Times(Exactly(1));
mWireServer = nullptr;
}

View File

@ -86,6 +86,13 @@ void WireTest::TearDown() {
// cannot be null.
api.IgnoreAllReleaseCalls();
mWireClient = nullptr;
if (mWireServer) {
// These are called on server destruction to clear the callbacks. They must not be
// called after the server is destroyed.
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(_, nullptr, nullptr)).Times(Exactly(1));
EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(_, nullptr, nullptr)).Times(Exactly(1));
}
mWireServer = nullptr;
}
@ -110,6 +117,13 @@ dawn_wire::WireClient* WireTest::GetWireClient() {
void WireTest::DeleteServer() {
EXPECT_CALL(api, QueueRelease(apiQueue)).Times(1);
if (mWireServer) {
// These are called on server destruction to clear the callbacks. They must not be
// called after the server is destroyed.
EXPECT_CALL(api, OnDeviceSetUncapturedErrorCallback(_, nullptr, nullptr)).Times(Exactly(1));
EXPECT_CALL(api, OnDeviceSetDeviceLostCallback(_, nullptr, nullptr)).Times(Exactly(1));
}
mWireServer = nullptr;
}