Implement builder callback forwarding in the wire

This commit is contained in:
Corentin Wallez 2017-04-20 14:43:11 -04:00 committed by Corentin Wallez
parent 7f96177289
commit cd0ea35889
5 changed files with 225 additions and 19 deletions

View File

@ -91,6 +91,7 @@ class ObjectType(Type):
Type.__init__(self, name, record) Type.__init__(self, name, record)
self.methods = [] self.methods = []
self.native_methods = [] self.native_methods = []
self.built_type = None
############################################################ ############################################################
# PARSE # PARSE
@ -124,6 +125,14 @@ def link_object(obj, types):
obj.methods = [method for method in methods if not is_native_method(method)] obj.methods = [method for method in methods if not is_native_method(method)]
obj.native_methods = [method for method in methods if is_native_method(method)] obj.native_methods = [method for method in methods if is_native_method(method)]
# Compute the built object type for builders
if obj.is_builder:
for method in obj.methods:
if method.name.canonical_case() == "get result":
obj.built_type = method.return_type
break
assert(obj.built_type != None)
def parse_json(json): def parse_json(json):
category_to_parser = { category_to_parser = {
'bitmask': BitmaskType, 'bitmask': BitmaskType,

View File

@ -15,9 +15,12 @@
#include "Wire.h" #include "Wire.h"
#include "WireCmd.h" #include "WireCmd.h"
#include <cassert>
#include <cstring> #include <cstring>
#include <vector> #include <vector>
#include <iostream>
namespace nxt { namespace nxt {
namespace wire { namespace wire {
@ -26,6 +29,20 @@ namespace wire {
class Device; class Device;
struct BuilderCallbackData {
void Call(nxtBuilderErrorStatus status, const char* message) {
if (canCall && callback != nullptr) {
canCall = true;
callback(status, message, userdata1, userdata2);
}
}
nxtBuilderErrorCallback callback = nullptr;
nxtCallbackUserdata userdata1 = 0;
nxtCallbackUserdata userdata2 = 0;
bool canCall = true;
};
//* All non-Device objects of the client side have: //* All non-Device objects of the client side have:
//* - A pointer to the device to get where to serialize commands //* - A pointer to the device to get where to serialize commands
//* - The external reference count //* - The external reference count
@ -38,6 +55,8 @@ namespace wire {
Device* device; Device* device;
uint32_t refcount; uint32_t refcount;
uint32_t id; uint32_t id;
BuilderCallbackData builderCallback;
}; };
{% for type in by_category["object"] if not type.name.canonical_case() == "device" %} {% for type in by_category["object"] if not type.name.canonical_case() == "device" %}
@ -46,19 +65,58 @@ namespace wire {
}; };
{% endfor %} {% endfor %}
//* TODO: Remember objects so they can all be destroyed at device destruction. //* TODO(cwallez@chromium.org): Do something with objects before they are destroyed ?
//* - Call still uncalled builder callbacks
template<typename T> template<typename T>
class ObjectAllocator { class ObjectAllocator {
public: public:
struct ObjectAndSerial {
ObjectAndSerial(std::unique_ptr<T> object, uint32_t serial)
: object(std::move(object)), serial(serial) {
}
std::unique_ptr<T> object;
uint32_t serial;
};
ObjectAllocator(Device* device) : device(device) { ObjectAllocator(Device* device) : device(device) {
// ID 0 is nullptr
objects.emplace_back(nullptr, 0);
} }
T* New() { ObjectAndSerial* New() {
return new T(device, 1, GetNewId()); uint32_t id = GetNewId();
T* result = new T(device, 1, id);
auto object = std::unique_ptr<T>(result);
if (id >= objects.size()) {
assert(id == objects.size());
objects.emplace_back(std::move(object), 0);
} else {
assert(objects[id].object == nullptr);
//* TODO(cwallez@chromium.org): investigate if overflows could cause bad things to happen
objects[id].serial++;
objects[id].object = std::move(object);
}
return &objects[id];
} }
void Free(T* obj) { void Free(T* obj) {
FreeId(obj->id); FreeId(obj->id);
delete obj; objects[obj->id].object = nullptr;
}
T* GetObject(uint32_t id) {
if (id >= objects.size()) {
return nullptr;
}
return objects[id].object.get();
}
uint32_t GetSerial(uint32_t id) {
if (id >= objects.size()) {
return 0;
}
return objects[id].serial;
} }
private: private:
@ -77,6 +135,7 @@ namespace wire {
// 0 is an ID reserved to represent nullptr // 0 is an ID reserved to represent nullptr
uint32_t currentId = 1; uint32_t currentId = 1;
std::vector<uint32_t> freeIds; std::vector<uint32_t> freeIds;
std::vector<ObjectAndSerial> objects;
Device* device; Device* device;
}; };
@ -165,18 +224,31 @@ namespace wire {
//* For object creation, store the object ID the client will use for the result. //* For object creation, store the object ID the client will use for the result.
{% if method.return_type.category == "object" %} {% if method.return_type.category == "object" %}
auto result = self->device->{{method.return_type.name.camelCase()}}.New(); auto* allocation = self->device->{{method.return_type.name.camelCase()}}.New();
allocCmd->resultId = result->id;
return result; {% if type.is_builder %}
//* We are in GetResult, so the callback that should be called is the
//* currently set one. Copy it over to the created object and prevent the
//* builder from calling the callback on destruction.
allocation->object->builderCallback = self->builderCallback;
self->builderCallback.canCall = false;
{% endif %}
allocCmd->resultId = allocation->object->id;
allocCmd->resultSerial = allocation->serial;
return allocation->object.get();
{% endif %} {% endif %}
} }
{% endfor %} {% endfor %}
{% if type.is_builder %} {% if type.is_builder %}
void Client{{as_MethodSuffix(type.name, Name("set error callback"))}}(nxtBuilderErrorCallback callback, void Client{{as_MethodSuffix(type.name, Name("set error callback"))}}({{Type}}* self,
nxtBuilderErrorCallback callback,
nxtCallbackUserdata userdata1, nxtCallbackUserdata userdata1,
nxtCallbackUserdata userdata2) { nxtCallbackUserdata userdata2) {
//TODO(cwallez@chromium.org): will be implemented in a follow-up commit. self->builderCallback.callback = callback;
self->builderCallback.userdata1 = userdata1;
self->builderCallback.userdata2 = userdata2;
} }
{% endif %} {% endif %}
@ -189,6 +261,8 @@ namespace wire {
return; return;
} }
obj->builderCallback.Call(NXT_BUILDER_ERROR_STATUS_UNKNOWN, "Unknown");
wire::{{as_MethodSuffix(type.name, Name("destroy"))}}Cmd cmd; wire::{{as_MethodSuffix(type.name, Name("destroy"))}}Cmd cmd;
cmd.objectId = obj->id; cmd.objectId = obj->id;
@ -240,6 +314,11 @@ namespace wire {
case ReturnWireCmd::DeviceErrorCallback: case ReturnWireCmd::DeviceErrorCallback:
success = HandleDeviceErrorCallbackCmd(&commands, &size); success = HandleDeviceErrorCallbackCmd(&commands, &size);
break; break;
{% for type in by_category["object"] if type.is_builder %}
case ReturnWireCmd::{{type.name.CamelCase()}}ErrorCallback:
success = Handle{{type.name.CamelCase()}}ErrorCallbackCmd(&commands, &size);
break;
{% endfor %}
default: default:
success = false; success = false;
} }
@ -298,6 +377,30 @@ namespace wire {
return true; return true;
} }
{% for type in by_category["object"] if type.is_builder %}
{% set Type = type.name.CamelCase() %}
bool Handle{{Type}}ErrorCallbackCmd(const uint8_t** commands, size_t* size) {
const auto* cmd = GetCommand<Return{{Type}}ErrorCallbackCmd>(commands, size);
if (cmd == nullptr) {
return false;
}
if (cmd->GetMessage()[cmd->messageStrlen] != '\0') {
return false;
}
auto* builtObject = device->{{type.built_type.name.camelCase()}}.GetObject(cmd->builtObjectId);
uint32_t objectSerial = device->{{type.built_type.name.camelCase()}}.GetSerial(cmd->builtObjectId);
//* The object might have been deleted or a new object created with the same ID.
if (builtObject == nullptr || objectSerial != cmd->builtObjectSerial) {
return true;
}
builtObject->builderCallback.Call(static_cast<nxtBuilderErrorStatus>(cmd->status), cmd->GetMessage());
return true;
}
{% endfor %}
}; };
} }

View File

@ -71,5 +71,20 @@ namespace wire {
} }
{% endfor %} {% endfor %}
{% for type in by_category["object"] if type.is_builder %}
{% set Type = type.name.CamelCase() %}
size_t Return{{Type}}ErrorCallbackCmd::GetRequiredSize() const {
return sizeof(*this) + messageStrlen + 1;
}
char* Return{{Type}}ErrorCallbackCmd::GetMessage() {
return reinterpret_cast<char*>(this + 1);
}
const char* Return{{Type}}ErrorCallbackCmd::GetMessage() const {
return reinterpret_cast<const char*>(this + 1);
}
{% endfor %}
} }
} }

View File

@ -49,6 +49,7 @@ namespace wire {
//* Commands creating objects say which ID the created object will be referred as. //* Commands creating objects say which ID the created object will be referred as.
{% if method.return_type.category == "object" %} {% if method.return_type.category == "object" %}
uint32_t resultId; uint32_t resultId;
uint32_t resultSerial;
{% endif %} {% endif %}
//* Value types are directly in the command, objects being replaced with their IDs. //* Value types are directly in the command, objects being replaced with their IDs.
@ -88,13 +89,32 @@ namespace wire {
size_t GetRequiredSize() const; size_t GetRequiredSize() const;
}; };
{% endfor %} {% endfor %}
//* Enum used as a prefix to each command on the return wire format. //* Enum used as a prefix to each command on the return wire format.
enum class ReturnWireCmd : uint32_t { enum class ReturnWireCmd : uint32_t {
DeviceErrorCallback, DeviceErrorCallback,
{% for type in by_category["object"] if type.is_builder %}
{{type.name.CamelCase()}}ErrorCallback,
{% endfor %}
}; };
{% for type in by_category["object"] if type.is_builder %}
struct Return{{type.name.CamelCase()}}ErrorCallbackCmd {
wire::ReturnWireCmd commandId = ReturnWireCmd::{{type.name.CamelCase()}}ErrorCallback;
uint32_t builtObjectId;
uint32_t builtObjectSerial;
uint32_t status;
size_t messageStrlen;
size_t GetRequiredSize() const;
char* GetMessage();
const char* GetMessage() const;
};
{% endfor %}
} }
} }

View File

@ -25,8 +25,15 @@ namespace wire {
//* Stores what the backend knows about the type. //* Stores what the backend knows about the type.
template<typename T> template<typename T>
struct ObjectDataBase { struct ObjectDataBase {
//* The backend-provided handle to this object. //* The backend-provided handle and serial to this object.
T handle; T handle;
uint32_t serial = 0;
//* Built object ID and serial, needed to send to the client along with builder error callbacks
//* TODO(cwallez@chromium.org) only have this for builder T
uint32_t builtObjectId = 0;
uint32_t builtObjectSerial = 0;
//* Used by the error-propagation mechanism to know if this object is an error. //* Used by the error-propagation mechanism to know if this object is an error.
//* TODO(cwallez@chromium.org): this is doubling the memory usage of //* TODO(cwallez@chromium.org): this is doubling the memory usage of
//* std::vector<ObjectDataBase> consider making it a special marker value in handle instead. //* std::vector<ObjectDataBase> consider making it a special marker value in handle instead.
@ -105,6 +112,10 @@ namespace wire {
void ForwardDeviceErrorToServer(const char* message, nxtCallbackUserdata userdata); void ForwardDeviceErrorToServer(const char* message, nxtCallbackUserdata userdata);
{% for type in by_category["object"] if type.is_builder%}
void Forward{{type.name.CamelCase()}}ToClient(nxtBuilderErrorStatus status, const char* message, nxtCallbackUserdata userdata1, nxtCallbackUserdata userdata2);
{% endfor %}
class Server : public CommandHandler { class Server : public CommandHandler {
public: public:
Server(nxtDevice device, const nxtProcTable& procs, CommandSerializer* serializer) Server(nxtDevice device, const nxtProcTable& procs, CommandSerializer* serializer)
@ -127,6 +138,37 @@ namespace wire {
strcpy(allocCmd->GetMessage(), message); strcpy(allocCmd->GetMessage(), message);
} }
{% for type in by_category["object"] if type.is_builder%}
{% set Type = type.name.CamelCase() %}
void On{{Type}}Error(nxtBuilderErrorStatus status, const char* message, uint32_t id, uint32_t serial) {
auto* builder = known{{Type}}.Get(id);
if (builder == nullptr || builder->serial != serial) {
return;
}
if (status != NXT_BUILDER_ERROR_STATUS_SUCCESS) {
builder->valid = false;
}
if (status != NXT_BUILDER_ERROR_STATUS_UNKNOWN) {
//* Unknown is the only status that can be returned without a call to GetResult
//* so we are guaranteed to have created an object.
assert(builder->builtObjectId != 0);
Return{{Type}}ErrorCallbackCmd cmd;
cmd.builtObjectId = builder->builtObjectId;
cmd.builtObjectSerial = builder->builtObjectSerial;
cmd.status = status;
cmd.messageStrlen = std::strlen(message);
auto allocCmd = reinterpret_cast<Return{{Type}}ErrorCallbackCmd*>(GetCmdSpace(cmd.GetRequiredSize()));
*allocCmd = cmd;
strcpy(allocCmd->GetMessage(), message);
}
}
{% endfor %}
const uint8_t* HandleCommands(const uint8_t* commands, size_t size) override { const uint8_t* HandleCommands(const uint8_t* commands, size_t size) override {
while (size > sizeof(WireCmd)) { while (size > sizeof(WireCmd)) {
WireCmd cmdId = *reinterpret_cast<const WireCmd*>(commands); WireCmd cmdId = *reinterpret_cast<const WireCmd*>(commands);
@ -275,13 +317,20 @@ namespace wire {
//* At that point all the data has been upacked in cmd->* or arg_* //* At that point all the data has been upacked in cmd->* or arg_*
//* In all cases allocate the object data as it will be refered-to by the client. //* In all cases allocate the object data as it will be refered-to by the client.
{% set returns = method.return_type.name.canonical_case() != "void" %} {% set return_type = method.return_type %}
{% set returns = return_type.name.canonical_case() != "void" %}
{% if returns %} {% if returns %}
{% set Type = method.return_type.name.CamelCase() %} {% set Type = method.return_type.name.CamelCase() %}
auto* resultData = known{{Type}}.Allocate(cmd->resultId); auto* resultData = known{{Type}}.Allocate(cmd->resultId);
if (resultData == nullptr) { if (resultData == nullptr) {
return false; return false;
} }
resultData->serial = cmd->resultSerial;
{% if type.is_builder %}
selfData->builtObjectId = cmd->resultId;
selfData->builtObjectSerial = cmd->resultSerial;
{% endif %}
{% endif %} {% endif %}
//* After the data is allocated, apply the argument error propagation mechanism //* After the data is allocated, apply the argument error propagation mechanism
@ -305,16 +354,17 @@ namespace wire {
{% if returns %} {% if returns %}
resultData->handle = result; resultData->handle = result;
resultData->valid = result != nullptr; resultData->valid = result != nullptr;
{% endif %}
if (gotError) { //* builders remember the ID of the object they built so that they can send it
{% if type.is_builder %} //* in the callback to the client.
//* Get the data again, has been invalidated by the call to {% if return_type.is_builder %}
//* known.Allocate if (result != nullptr) {
known{{type.name.CamelCase()}}.Get(cmd->self)->valid = false; uint64_t userdata1 = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(this));
{% endif %} uint64_t userdata2 = (uint64_t(resultData->serial) << uint64_t(32)) + cmd->resultId;
gotError = false; procs.{{as_varName(return_type.name, Name("set error callback"))}}(result, Forward{{return_type.name.CamelCase()}}ToClient, userdata1, userdata2);
} }
{% endif %}
{% endif %}
return true; return true;
} }
@ -353,6 +403,15 @@ namespace wire {
auto server = reinterpret_cast<Server*>(static_cast<intptr_t>(userdata)); auto server = reinterpret_cast<Server*>(static_cast<intptr_t>(userdata));
server->OnDeviceError(message); server->OnDeviceError(message);
} }
{% for type in by_category["object"] if type.is_builder%}
void Forward{{type.name.CamelCase()}}ToClient(nxtBuilderErrorStatus status, const char* message, nxtCallbackUserdata userdata1, nxtCallbackUserdata userdata2) {
auto server = reinterpret_cast<Server*>(static_cast<intptr_t>(userdata1));
uint32_t id = userdata2 & 0xFFFFFFFFu;
uint32_t serial = userdata2 >> uint64_t(32);
server->On{{type.name.CamelCase()}}Error(status, message, id, serial);
}
{% endfor %}
} }
CommandHandler* NewServerCommandHandler(nxtDevice device, const nxtProcTable& procs, CommandSerializer* serializer) { CommandHandler* NewServerCommandHandler(nxtDevice device, const nxtProcTable& procs, CommandSerializer* serializer) {