Implement builder callback forwarding in the wire
This commit is contained in:
parent
7f96177289
commit
cd0ea35889
|
@ -91,6 +91,7 @@ class ObjectType(Type):
|
|||
Type.__init__(self, name, record)
|
||||
self.methods = []
|
||||
self.native_methods = []
|
||||
self.built_type = None
|
||||
|
||||
############################################################
|
||||
# PARSE
|
||||
|
@ -124,6 +125,14 @@ def link_object(obj, types):
|
|||
obj.methods = [method for method in methods if not is_native_method(method)]
|
||||
obj.native_methods = [method for method in methods if is_native_method(method)]
|
||||
|
||||
# Compute the built object type for builders
|
||||
if obj.is_builder:
|
||||
for method in obj.methods:
|
||||
if method.name.canonical_case() == "get result":
|
||||
obj.built_type = method.return_type
|
||||
break
|
||||
assert(obj.built_type != None)
|
||||
|
||||
def parse_json(json):
|
||||
category_to_parser = {
|
||||
'bitmask': BitmaskType,
|
||||
|
|
|
@ -15,9 +15,12 @@
|
|||
#include "Wire.h"
|
||||
#include "WireCmd.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cstring>
|
||||
#include <vector>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace nxt {
|
||||
namespace wire {
|
||||
|
||||
|
@ -26,6 +29,20 @@ namespace wire {
|
|||
|
||||
class Device;
|
||||
|
||||
struct BuilderCallbackData {
|
||||
void Call(nxtBuilderErrorStatus status, const char* message) {
|
||||
if (canCall && callback != nullptr) {
|
||||
canCall = true;
|
||||
callback(status, message, userdata1, userdata2);
|
||||
}
|
||||
}
|
||||
|
||||
nxtBuilderErrorCallback callback = nullptr;
|
||||
nxtCallbackUserdata userdata1 = 0;
|
||||
nxtCallbackUserdata userdata2 = 0;
|
||||
bool canCall = true;
|
||||
};
|
||||
|
||||
//* All non-Device objects of the client side have:
|
||||
//* - A pointer to the device to get where to serialize commands
|
||||
//* - The external reference count
|
||||
|
@ -38,6 +55,8 @@ namespace wire {
|
|||
Device* device;
|
||||
uint32_t refcount;
|
||||
uint32_t id;
|
||||
|
||||
BuilderCallbackData builderCallback;
|
||||
};
|
||||
|
||||
{% for type in by_category["object"] if not type.name.canonical_case() == "device" %}
|
||||
|
@ -46,19 +65,58 @@ namespace wire {
|
|||
};
|
||||
{% endfor %}
|
||||
|
||||
//* TODO: Remember objects so they can all be destroyed at device destruction.
|
||||
//* TODO(cwallez@chromium.org): Do something with objects before they are destroyed ?
|
||||
//* - Call still uncalled builder callbacks
|
||||
template<typename T>
|
||||
class ObjectAllocator {
|
||||
public:
|
||||
struct ObjectAndSerial {
|
||||
ObjectAndSerial(std::unique_ptr<T> object, uint32_t serial)
|
||||
: object(std::move(object)), serial(serial) {
|
||||
}
|
||||
std::unique_ptr<T> object;
|
||||
uint32_t serial;
|
||||
};
|
||||
|
||||
ObjectAllocator(Device* device) : device(device) {
|
||||
// ID 0 is nullptr
|
||||
objects.emplace_back(nullptr, 0);
|
||||
}
|
||||
|
||||
T* New() {
|
||||
return new T(device, 1, GetNewId());
|
||||
ObjectAndSerial* New() {
|
||||
uint32_t id = GetNewId();
|
||||
T* result = new T(device, 1, id);
|
||||
auto object = std::unique_ptr<T>(result);
|
||||
|
||||
if (id >= objects.size()) {
|
||||
assert(id == objects.size());
|
||||
objects.emplace_back(std::move(object), 0);
|
||||
} else {
|
||||
assert(objects[id].object == nullptr);
|
||||
//* TODO(cwallez@chromium.org): investigate if overflows could cause bad things to happen
|
||||
objects[id].serial++;
|
||||
objects[id].object = std::move(object);
|
||||
}
|
||||
|
||||
return &objects[id];
|
||||
}
|
||||
void Free(T* obj) {
|
||||
FreeId(obj->id);
|
||||
delete obj;
|
||||
objects[obj->id].object = nullptr;
|
||||
}
|
||||
|
||||
T* GetObject(uint32_t id) {
|
||||
if (id >= objects.size()) {
|
||||
return nullptr;
|
||||
}
|
||||
return objects[id].object.get();
|
||||
}
|
||||
|
||||
uint32_t GetSerial(uint32_t id) {
|
||||
if (id >= objects.size()) {
|
||||
return 0;
|
||||
}
|
||||
return objects[id].serial;
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -77,6 +135,7 @@ namespace wire {
|
|||
// 0 is an ID reserved to represent nullptr
|
||||
uint32_t currentId = 1;
|
||||
std::vector<uint32_t> freeIds;
|
||||
std::vector<ObjectAndSerial> objects;
|
||||
Device* device;
|
||||
};
|
||||
|
||||
|
@ -165,18 +224,31 @@ namespace wire {
|
|||
|
||||
//* For object creation, store the object ID the client will use for the result.
|
||||
{% if method.return_type.category == "object" %}
|
||||
auto result = self->device->{{method.return_type.name.camelCase()}}.New();
|
||||
allocCmd->resultId = result->id;
|
||||
return result;
|
||||
auto* allocation = self->device->{{method.return_type.name.camelCase()}}.New();
|
||||
|
||||
{% if type.is_builder %}
|
||||
//* We are in GetResult, so the callback that should be called is the
|
||||
//* currently set one. Copy it over to the created object and prevent the
|
||||
//* builder from calling the callback on destruction.
|
||||
allocation->object->builderCallback = self->builderCallback;
|
||||
self->builderCallback.canCall = false;
|
||||
{% endif %}
|
||||
|
||||
allocCmd->resultId = allocation->object->id;
|
||||
allocCmd->resultSerial = allocation->serial;
|
||||
return allocation->object.get();
|
||||
{% endif %}
|
||||
}
|
||||
{% endfor %}
|
||||
|
||||
{% if type.is_builder %}
|
||||
void Client{{as_MethodSuffix(type.name, Name("set error callback"))}}(nxtBuilderErrorCallback callback,
|
||||
void Client{{as_MethodSuffix(type.name, Name("set error callback"))}}({{Type}}* self,
|
||||
nxtBuilderErrorCallback callback,
|
||||
nxtCallbackUserdata userdata1,
|
||||
nxtCallbackUserdata userdata2) {
|
||||
//TODO(cwallez@chromium.org): will be implemented in a follow-up commit.
|
||||
self->builderCallback.callback = callback;
|
||||
self->builderCallback.userdata1 = userdata1;
|
||||
self->builderCallback.userdata2 = userdata2;
|
||||
}
|
||||
{% endif %}
|
||||
|
||||
|
@ -189,6 +261,8 @@ namespace wire {
|
|||
return;
|
||||
}
|
||||
|
||||
obj->builderCallback.Call(NXT_BUILDER_ERROR_STATUS_UNKNOWN, "Unknown");
|
||||
|
||||
wire::{{as_MethodSuffix(type.name, Name("destroy"))}}Cmd cmd;
|
||||
cmd.objectId = obj->id;
|
||||
|
||||
|
@ -240,6 +314,11 @@ namespace wire {
|
|||
case ReturnWireCmd::DeviceErrorCallback:
|
||||
success = HandleDeviceErrorCallbackCmd(&commands, &size);
|
||||
break;
|
||||
{% for type in by_category["object"] if type.is_builder %}
|
||||
case ReturnWireCmd::{{type.name.CamelCase()}}ErrorCallback:
|
||||
success = Handle{{type.name.CamelCase()}}ErrorCallbackCmd(&commands, &size);
|
||||
break;
|
||||
{% endfor %}
|
||||
default:
|
||||
success = false;
|
||||
}
|
||||
|
@ -298,6 +377,30 @@ namespace wire {
|
|||
return true;
|
||||
}
|
||||
|
||||
{% for type in by_category["object"] if type.is_builder %}
|
||||
{% set Type = type.name.CamelCase() %}
|
||||
bool Handle{{Type}}ErrorCallbackCmd(const uint8_t** commands, size_t* size) {
|
||||
const auto* cmd = GetCommand<Return{{Type}}ErrorCallbackCmd>(commands, size);
|
||||
if (cmd == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (cmd->GetMessage()[cmd->messageStrlen] != '\0') {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* builtObject = device->{{type.built_type.name.camelCase()}}.GetObject(cmd->builtObjectId);
|
||||
uint32_t objectSerial = device->{{type.built_type.name.camelCase()}}.GetSerial(cmd->builtObjectId);
|
||||
|
||||
//* The object might have been deleted or a new object created with the same ID.
|
||||
if (builtObject == nullptr || objectSerial != cmd->builtObjectSerial) {
|
||||
return true;
|
||||
}
|
||||
|
||||
builtObject->builderCallback.Call(static_cast<nxtBuilderErrorStatus>(cmd->status), cmd->GetMessage());
|
||||
return true;
|
||||
}
|
||||
{% endfor %}
|
||||
};
|
||||
|
||||
}
|
||||
|
|
|
@ -71,5 +71,20 @@ namespace wire {
|
|||
}
|
||||
{% endfor %}
|
||||
|
||||
{% for type in by_category["object"] if type.is_builder %}
|
||||
{% set Type = type.name.CamelCase() %}
|
||||
size_t Return{{Type}}ErrorCallbackCmd::GetRequiredSize() const {
|
||||
return sizeof(*this) + messageStrlen + 1;
|
||||
}
|
||||
|
||||
char* Return{{Type}}ErrorCallbackCmd::GetMessage() {
|
||||
return reinterpret_cast<char*>(this + 1);
|
||||
}
|
||||
|
||||
const char* Return{{Type}}ErrorCallbackCmd::GetMessage() const {
|
||||
return reinterpret_cast<const char*>(this + 1);
|
||||
}
|
||||
{% endfor %}
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -49,6 +49,7 @@ namespace wire {
|
|||
//* Commands creating objects say which ID the created object will be referred as.
|
||||
{% if method.return_type.category == "object" %}
|
||||
uint32_t resultId;
|
||||
uint32_t resultSerial;
|
||||
{% endif %}
|
||||
|
||||
//* Value types are directly in the command, objects being replaced with their IDs.
|
||||
|
@ -88,13 +89,32 @@ namespace wire {
|
|||
|
||||
size_t GetRequiredSize() const;
|
||||
};
|
||||
|
||||
{% endfor %}
|
||||
|
||||
//* Enum used as a prefix to each command on the return wire format.
|
||||
enum class ReturnWireCmd : uint32_t {
|
||||
DeviceErrorCallback,
|
||||
{% for type in by_category["object"] if type.is_builder %}
|
||||
{{type.name.CamelCase()}}ErrorCallback,
|
||||
{% endfor %}
|
||||
};
|
||||
|
||||
{% for type in by_category["object"] if type.is_builder %}
|
||||
struct Return{{type.name.CamelCase()}}ErrorCallbackCmd {
|
||||
wire::ReturnWireCmd commandId = ReturnWireCmd::{{type.name.CamelCase()}}ErrorCallback;
|
||||
|
||||
uint32_t builtObjectId;
|
||||
uint32_t builtObjectSerial;
|
||||
uint32_t status;
|
||||
size_t messageStrlen;
|
||||
|
||||
size_t GetRequiredSize() const;
|
||||
char* GetMessage();
|
||||
const char* GetMessage() const;
|
||||
};
|
||||
{% endfor %}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -25,8 +25,15 @@ namespace wire {
|
|||
//* Stores what the backend knows about the type.
|
||||
template<typename T>
|
||||
struct ObjectDataBase {
|
||||
//* The backend-provided handle to this object.
|
||||
//* The backend-provided handle and serial to this object.
|
||||
T handle;
|
||||
uint32_t serial = 0;
|
||||
|
||||
//* Built object ID and serial, needed to send to the client along with builder error callbacks
|
||||
//* TODO(cwallez@chromium.org) only have this for builder T
|
||||
uint32_t builtObjectId = 0;
|
||||
uint32_t builtObjectSerial = 0;
|
||||
|
||||
//* Used by the error-propagation mechanism to know if this object is an error.
|
||||
//* TODO(cwallez@chromium.org): this is doubling the memory usage of
|
||||
//* std::vector<ObjectDataBase> consider making it a special marker value in handle instead.
|
||||
|
@ -105,6 +112,10 @@ namespace wire {
|
|||
|
||||
void ForwardDeviceErrorToServer(const char* message, nxtCallbackUserdata userdata);
|
||||
|
||||
{% for type in by_category["object"] if type.is_builder%}
|
||||
void Forward{{type.name.CamelCase()}}ToClient(nxtBuilderErrorStatus status, const char* message, nxtCallbackUserdata userdata1, nxtCallbackUserdata userdata2);
|
||||
{% endfor %}
|
||||
|
||||
class Server : public CommandHandler {
|
||||
public:
|
||||
Server(nxtDevice device, const nxtProcTable& procs, CommandSerializer* serializer)
|
||||
|
@ -127,6 +138,37 @@ namespace wire {
|
|||
strcpy(allocCmd->GetMessage(), message);
|
||||
}
|
||||
|
||||
{% for type in by_category["object"] if type.is_builder%}
|
||||
{% set Type = type.name.CamelCase() %}
|
||||
void On{{Type}}Error(nxtBuilderErrorStatus status, const char* message, uint32_t id, uint32_t serial) {
|
||||
auto* builder = known{{Type}}.Get(id);
|
||||
|
||||
if (builder == nullptr || builder->serial != serial) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (status != NXT_BUILDER_ERROR_STATUS_SUCCESS) {
|
||||
builder->valid = false;
|
||||
}
|
||||
|
||||
if (status != NXT_BUILDER_ERROR_STATUS_UNKNOWN) {
|
||||
//* Unknown is the only status that can be returned without a call to GetResult
|
||||
//* so we are guaranteed to have created an object.
|
||||
assert(builder->builtObjectId != 0);
|
||||
|
||||
Return{{Type}}ErrorCallbackCmd cmd;
|
||||
cmd.builtObjectId = builder->builtObjectId;
|
||||
cmd.builtObjectSerial = builder->builtObjectSerial;
|
||||
cmd.status = status;
|
||||
cmd.messageStrlen = std::strlen(message);
|
||||
|
||||
auto allocCmd = reinterpret_cast<Return{{Type}}ErrorCallbackCmd*>(GetCmdSpace(cmd.GetRequiredSize()));
|
||||
*allocCmd = cmd;
|
||||
strcpy(allocCmd->GetMessage(), message);
|
||||
}
|
||||
}
|
||||
{% endfor %}
|
||||
|
||||
const uint8_t* HandleCommands(const uint8_t* commands, size_t size) override {
|
||||
while (size > sizeof(WireCmd)) {
|
||||
WireCmd cmdId = *reinterpret_cast<const WireCmd*>(commands);
|
||||
|
@ -275,13 +317,20 @@ namespace wire {
|
|||
//* At that point all the data has been upacked in cmd->* or arg_*
|
||||
|
||||
//* In all cases allocate the object data as it will be refered-to by the client.
|
||||
{% set returns = method.return_type.name.canonical_case() != "void" %}
|
||||
{% set return_type = method.return_type %}
|
||||
{% set returns = return_type.name.canonical_case() != "void" %}
|
||||
{% if returns %}
|
||||
{% set Type = method.return_type.name.CamelCase() %}
|
||||
auto* resultData = known{{Type}}.Allocate(cmd->resultId);
|
||||
if (resultData == nullptr) {
|
||||
return false;
|
||||
}
|
||||
resultData->serial = cmd->resultSerial;
|
||||
|
||||
{% if type.is_builder %}
|
||||
selfData->builtObjectId = cmd->resultId;
|
||||
selfData->builtObjectSerial = cmd->resultSerial;
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
//* After the data is allocated, apply the argument error propagation mechanism
|
||||
|
@ -305,16 +354,17 @@ namespace wire {
|
|||
{% if returns %}
|
||||
resultData->handle = result;
|
||||
resultData->valid = result != nullptr;
|
||||
{% endif %}
|
||||
|
||||
if (gotError) {
|
||||
{% if type.is_builder %}
|
||||
//* Get the data again, has been invalidated by the call to
|
||||
//* known.Allocate
|
||||
known{{type.name.CamelCase()}}.Get(cmd->self)->valid = false;
|
||||
{% endif %}
|
||||
gotError = false;
|
||||
//* builders remember the ID of the object they built so that they can send it
|
||||
//* in the callback to the client.
|
||||
{% if return_type.is_builder %}
|
||||
if (result != nullptr) {
|
||||
uint64_t userdata1 = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(this));
|
||||
uint64_t userdata2 = (uint64_t(resultData->serial) << uint64_t(32)) + cmd->resultId;
|
||||
procs.{{as_varName(return_type.name, Name("set error callback"))}}(result, Forward{{return_type.name.CamelCase()}}ToClient, userdata1, userdata2);
|
||||
}
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
@ -353,6 +403,15 @@ namespace wire {
|
|||
auto server = reinterpret_cast<Server*>(static_cast<intptr_t>(userdata));
|
||||
server->OnDeviceError(message);
|
||||
}
|
||||
|
||||
{% for type in by_category["object"] if type.is_builder%}
|
||||
void Forward{{type.name.CamelCase()}}ToClient(nxtBuilderErrorStatus status, const char* message, nxtCallbackUserdata userdata1, nxtCallbackUserdata userdata2) {
|
||||
auto server = reinterpret_cast<Server*>(static_cast<intptr_t>(userdata1));
|
||||
uint32_t id = userdata2 & 0xFFFFFFFFu;
|
||||
uint32_t serial = userdata2 >> uint64_t(32);
|
||||
server->On{{type.name.CamelCase()}}Error(status, message, id, serial);
|
||||
}
|
||||
{% endfor %}
|
||||
}
|
||||
|
||||
CommandHandler* NewServerCommandHandler(nxtDevice device, const nxtProcTable& procs, CommandSerializer* serializer) {
|
||||
|
|
Loading…
Reference in New Issue