Wire: Move the logic of [de]serialization in WireCmd.

This will help with follow-up changes that add support for a more
complete grammer of types, including structures containing pointers
to objects or other structures.

Instead of having the wire::Client and wire::Server directly act on
buffer memory, a couple interfaces are introduced so that WireCmd can do
things like get the object<->id mapping and temporary allocations.

While the serialization and deserialization of most commands was moved
into WireCmd, the commands that don't directly correspond to NXT methods
have their logic moved inside Client and Server and will be made to
expose the new interface in a follow-up commit.
This commit is contained in:
Corentin Wallez 2018-06-06 17:36:49 +02:00 committed by Corentin Wallez
parent 419e9841a8
commit 88fb8fa353
10 changed files with 496 additions and 351 deletions

View File

@ -24,8 +24,7 @@
#include <string>
#include <vector>
namespace nxt {
namespace wire {
namespace nxt { namespace wire {
//* Client side implementation of the API, will serialize everything to memory to send to the server side.
namespace client {
@ -187,7 +186,7 @@ namespace wire {
//* The client wire uses the global NXT device to store its global data such as the serializer
//* and the object id allocators.
class Device : public ObjectBase {
class Device : public ObjectBase, public wire::ObjectIdProvider {
public:
Device(CommandSerializer* serializer)
: ObjectBase(this, 1, 1),
@ -205,6 +204,13 @@ namespace wire {
ObjectAllocator<{{type.name.CamelCase()}}> {{type.name.camelCase()}};
{% endfor %}
// Implementation of the ObjectIdProvider interface
{% for type in by_category["object"] %}
ObjectId GetId({{as_cType(type.name)}} object) const override {
return reinterpret_cast<{{as_backendType(type)}}>(object)->id;
}
{% endfor %}
void HandleError(const char* message) {
if (errorCallback) {
errorCallback(message, errorUserdata);
@ -226,55 +232,18 @@ namespace wire {
{% set Suffix = as_MethodSuffix(type.name, method.name) %}
{{as_backendType(method.return_type)}} Client{{Suffix}}(
{{-as_backendType(type)}} self
{{-as_cType(type.name)}} cSelf
{%- for arg in method.arguments -%}
, {{as_annotated_backendType(arg)}}
, {{as_annotated_cType(arg)}}
{%- endfor -%}
) {
{{as_backendType(type)}} self = reinterpret_cast<{{as_backendType(type)}}>(cSelf);
Device* device = self->device;
wire::{{Suffix}}Cmd cmd;
//* Create the structure going on the wire on the stack and fill it with the value
//* arguments so it can compute its size.
{
//* Value objects are stored as IDs
{% for arg in method.arguments if arg.annotation == "value" %}
{% if arg.type.category == "object" %}
cmd.{{as_varName(arg.name)}} = {{as_varName(arg.name)}}->id;
{% else %}
cmd.{{as_varName(arg.name)}} = {{as_varName(arg.name)}};
{% endif %}
{% endfor %}
cmd.self = self->id;
//* The length of const char* is considered a value argument.
{% for arg in method.arguments if arg.length == "strlen" %}
cmd.{{as_varName(arg.name)}}Strlen = strlen({{as_varName(arg.name)}});
{% endfor %}
}
//* Allocate space to send the command and copy the value args over.
size_t requiredSize = cmd.GetRequiredSize();
auto allocCmd = reinterpret_cast<decltype(cmd)*>(device->GetCmdSpace(requiredSize));
*allocCmd = cmd;
//* In the allocated space, write the non-value arguments.
{% for arg in method.arguments if arg.annotation != "value" %}
{% set argName = as_varName(arg.name) %}
{% if arg.length == "strlen" %}
memcpy(allocCmd->GetPtr_{{argName}}(), {{argName}}, allocCmd->{{argName}}Strlen + 1);
{% elif arg.length == "constant_one" %}
memcpy(allocCmd->GetPtr_{{argName}}(), {{argName}}, sizeof(*{{argName}}));
{% elif arg.type.category == "object" %}
auto {{argName}}Storage = reinterpret_cast<uint32_t*>(allocCmd->GetPtr_{{argName}}());
for (size_t i = 0; i < {{as_varName(arg.length.name)}}; i++) {
{{argName}}Storage[i] = {{argName}}[i]->id;
}
{% else %}
memcpy(allocCmd->GetPtr_{{argName}}(), {{argName}}, {{as_varName(arg.length.name)}} * sizeof(*{{argName}}));
{% endif %}
{% endfor %}
cmd.self = cSelf;
//* For object creation, store the object ID the client will use for the result.
{% if method.return_type.category == "object" %}
@ -288,8 +257,20 @@ namespace wire {
self->builderCallback.canCall = false;
{% endif %}
allocCmd->resultId = allocation->object->id;
allocCmd->resultSerial = allocation->serial;
cmd.resultId = allocation->object->id;
cmd.resultSerial = allocation->serial;
{% endif %}
{% for arg in method.arguments %}
cmd.{{as_varName(arg.name)}} = {{as_varName(arg.name)}};
{% endfor %}
//* Allocate space to send the command and copy the value args over.
size_t requiredSize = cmd.GetRequiredSize();
char* allocatedBuffer = static_cast<char*>(device->GetCmdSpace(requiredSize));
cmd.Serialize(allocatedBuffer, *device);
{% if method.return_type.category == "object" %}
return allocation->object.get();
{% endif %}
}
@ -320,8 +301,7 @@ namespace wire {
wire::{{as_MethodSuffix(type.name, Name("destroy"))}}Cmd cmd;
cmd.objectId = obj->id;
size_t requiredSize = cmd.GetRequiredSize();
auto allocCmd = reinterpret_cast<decltype(cmd)*>(obj->device->GetCmdSpace(requiredSize));
auto allocCmd = static_cast<decltype(cmd)*>(obj->device->GetCmdSpace(sizeof(cmd)));
*allocCmd = cmd;
obj->device->{{type.name.camelCase()}}.Free(obj);
@ -349,8 +329,7 @@ namespace wire {
cmd.start = start;
cmd.size = size;
size_t requiredSize = cmd.GetRequiredSize();
auto allocCmd = reinterpret_cast<decltype(cmd)*>(buffer->device->GetCmdSpace(requiredSize));
auto allocCmd = static_cast<decltype(cmd)*>(buffer->device->GetCmdSpace(sizeof(cmd)));
*allocCmd = cmd;
}
@ -359,7 +338,9 @@ namespace wire {
ASSERT(false);
}
void ProxyClientBufferUnmap(Buffer* buffer) {
void ProxyClientBufferUnmap(nxtBuffer cBuffer) {
Buffer* buffer = reinterpret_cast<Buffer*>(cBuffer);
//* Invalidate the local pointer, and cancel all other in-flight requests that would turn into
//* errors anyway (you can't double map). This prevents race when the following happens, where
//* the application code would have unmapped a buffer but still receive a callback:
@ -373,7 +354,7 @@ namespace wire {
}
buffer->ClearMapRequests(NXT_BUFFER_MAP_ASYNC_STATUS_UNKNOWN);
ClientBufferUnmap(buffer);
ClientBufferUnmap(cBuffer);
}
void ClientDeviceReference(Device*) {
@ -414,7 +395,7 @@ namespace wire {
Client(Device* device) : mDevice(device) {
}
const uint8_t* HandleCommands(const uint8_t* commands, size_t size) override {
const char* HandleCommands(const char* commands, size_t size) override {
while (size > sizeof(ReturnWireCmd)) {
ReturnWireCmd cmdId = *reinterpret_cast<const ReturnWireCmd*>(commands);
@ -453,49 +434,52 @@ namespace wire {
//* Helper function for the getting of the command data in command handlers.
//* Checks there is enough data left, updates the buffer / size and returns
//* the command (or nullptr for an error).
template<typename T>
static const T* GetCommand(const uint8_t** commands, size_t* size) {
if (*size < sizeof(T)) {
template <typename T>
static const T* GetData(const char** buffer, size_t* size, size_t count) {
// TODO(cwallez@chromium.org): Check for overflow
size_t totalSize = count * sizeof(T);
if (*size < totalSize) {
return nullptr;
}
const T* cmd = reinterpret_cast<const T*>(*commands);
const T* data = reinterpret_cast<const T*>(*buffer);
size_t cmdSize = cmd->GetRequiredSize();
if (*size < cmdSize) {
return nullptr;
*buffer += totalSize;
*size -= totalSize;
return data;
}
template <typename T>
static const T* GetCommand(const char** commands, size_t* size) {
return GetData<T>(commands, size, 1);
}
*commands += cmdSize;
*size -= cmdSize;
return cmd;
}
bool HandleDeviceErrorCallbackCmd(const uint8_t** commands, size_t* size) {
bool HandleDeviceErrorCallbackCmd(const char** commands, size_t* size) {
const auto* cmd = GetCommand<ReturnDeviceErrorCallbackCmd>(commands, size);
if (cmd == nullptr) {
return false;
}
if (cmd->GetMessage()[cmd->messageStrlen] != '\0') {
const char* message = GetData<char>(commands, size, cmd->messageStrlen + 1);
if (message == nullptr || message[cmd->messageStrlen] != '\0') {
return false;
}
mDevice->HandleError(cmd->GetMessage());
mDevice->HandleError(message);
return true;
}
{% for type in by_category["object"] if type.is_builder %}
{% set Type = type.name.CamelCase() %}
bool Handle{{Type}}ErrorCallbackCmd(const uint8_t** commands, size_t* size) {
bool Handle{{Type}}ErrorCallbackCmd(const char** commands, size_t* size) {
const auto* cmd = GetCommand<Return{{Type}}ErrorCallbackCmd>(commands, size);
if (cmd == nullptr) {
return false;
}
if (cmd->GetMessage()[cmd->messageStrlen] != '\0') {
const char* message = GetData<char>(commands, size, cmd->messageStrlen + 1);
if (message == nullptr || message[cmd->messageStrlen] != '\0') {
return false;
}
@ -507,18 +491,18 @@ namespace wire {
return true;
}
bool called = builtObject->builderCallback.Call(static_cast<nxtBuilderErrorStatus>(cmd->status), cmd->GetMessage());
bool called = builtObject->builderCallback.Call(static_cast<nxtBuilderErrorStatus>(cmd->status), message);
// Unhandled builder errors are forwarded to the device
if (!called && cmd->status != NXT_BUILDER_ERROR_STATUS_SUCCESS && cmd->status != NXT_BUILDER_ERROR_STATUS_UNKNOWN) {
mDevice->HandleError(("Unhandled builder error: " + std::string(cmd->GetMessage())).c_str());
mDevice->HandleError(("Unhandled builder error: " + std::string(message)).c_str());
}
return true;
}
{% endfor %}
bool HandleBufferMapReadAsyncCallback(const uint8_t** commands, size_t* size) {
bool HandleBufferMapReadAsyncCallback(const char** commands, size_t* size) {
const auto* cmd = GetCommand<ReturnBufferMapReadAsyncCallbackCmd>(commands, size);
if (cmd == nullptr) {
return false;
@ -554,8 +538,14 @@ namespace wire {
if (buffer->mappedData != nullptr) {
return false;
}
const char* requestData = GetData<char>(commands, size, request.size);
if (requestData == nullptr) {
return false;
}
buffer->mappedData = malloc(request.size);
memcpy(buffer->mappedData, cmd->GetData(), request.size);
memcpy(buffer->mappedData, requestData, request.size);
request.callback(static_cast<nxtBufferMapAsyncStatus>(cmd->status), buffer->mappedData, request.userdata);
} else {
@ -577,5 +567,4 @@ namespace wire {
return new client::Client(clientDevice);
}
}
}
}} // namespace nxt::wire

View File

@ -12,83 +12,246 @@
//* See the License for the specific language governing permissions and
//* limitations under the License.
#include "wire/WireCmd_autogen.h"
#include "wire/WireCmd.h"
namespace nxt {
namespace wire {
#include <cstring>
namespace nxt { namespace wire {
// Macro to simplify error handling, similar to NXT_TRY but for DeserializeResult.
#define DESERIALIZE_TRY(EXPR) \
{ \
DeserializeResult exprResult = EXPR; \
if (exprResult != DeserializeResult::Success) { \
return exprResult; \
} \
}
// Consumes from (buffer, size) enough memory to contain T[count] and return it in data.
// Returns FatalError if not enough memory was available
template <typename T>
DeserializeResult GetPtrFromBuffer(const char** buffer, size_t* size, size_t count, const T** data) {
// TODO(cwallez@chromium.org): For robustness we would need to handle overflows here.
size_t totalSize = sizeof(T) * count;
if (totalSize > *size) {
return DeserializeResult::FatalError;
}
*data = reinterpret_cast<const T*>(*buffer);
*buffer += totalSize;
*size -= totalSize;
return DeserializeResult::Success;
}
// Allocates enough space from allocator to countain T[count] and return it in out.
// Return FatalError if the allocator couldn't allocate the memory.
template <typename T>
DeserializeResult GetSpace(DeserializeAllocator* allocator, size_t count, T** out) {
// TODO(cwallez@chromium.org): For robustness we would need to handle overflows here.
size_t totalSize = sizeof(T) * count;
*out = static_cast<T*>(allocator->GetSpace(totalSize));
if (*out == nullptr) {
return DeserializeResult::FatalError;
}
return DeserializeResult::Success;
}
{% for type in by_category["object"] %}
{% for method in type.methods %}
{% set Suffix = as_MethodSuffix(type.name, method.name) %}
{% set Cmd = Suffix + "Cmd" %}
size_t {{Suffix}}Cmd::GetRequiredSize() const {
size_t result = sizeof(*this);
//* Structure for the wire format of each of the commands. Parameters passed by value
//* are embedded directly in the structure. Other parameters are assumed to be in the
//* memory directly following the structure in the buffer. With value parameters the
//* structure can compute how much buffer size it needs and where the start of non-value
//* parameters is in the buffer.
struct {{Cmd}}Transfer {
//* Start the structure with the command ID, so that casting to WireCmd gives the ID.
wire::WireCmd commandId;
ObjectId self;
{% if method.return_type.category == "object" %}
ObjectId resultId;
ObjectSerial resultSerial;
{% endif %}
//* Value types are directly in the command, objects being replaced with their IDs.
{% for arg in method.arguments if arg.annotation == "value" %}
{% if arg.type.category == "object" %}
ObjectId {{as_varName(arg.name)}};
{% else %}
{{as_cType(arg.type.name)}} {{as_varName(arg.name)}};
{% endif %}
{% endfor %}
//* const char* have their length embedded directly in the command.
{% for arg in method.arguments if arg.length == "strlen" %}
size_t {{as_varName(arg.name)}}Strlen;
{% endfor %}
};
size_t {{Cmd}}::GetRequiredSize() const {
size_t result = sizeof({{Cmd}}Transfer);
{% for arg in method.arguments if arg.annotation != "value" %}
{% set argName = as_varName(arg.name) %}
{% if arg.length == "strlen" %}
result += {{as_varName(arg.name)}}Strlen + 1;
result += std::strlen({{as_varName(arg.name)}});
{% elif arg.length == "constant_one" %}
result += sizeof({{as_cType(arg.type.name)}});
{% elif arg.type.category == "object" %}
result += {{as_varName(arg.length.name)}} * sizeof(uint32_t);
result += {{as_varName(arg.length.name)}} * sizeof(ObjectId);
{% else %}
result += {{as_varName(arg.length.name)}} * sizeof({{as_cType(arg.type.name)}});
{% endif %}
{% endfor %}
return result;
}
{% for const in ["", "const"] %}
{% for get_arg in method.arguments if get_arg.annotation != "value" %}
void {{Cmd}}::Serialize(char* buffer, const ObjectIdProvider& objectIdProvider) const {
auto transfer = reinterpret_cast<{{Cmd}}Transfer*>(buffer);
buffer += sizeof({{Cmd}}Transfer);
{{const}} uint8_t* {{Suffix}}Cmd::GetPtr_{{as_varName(get_arg.name)}}() {{const}} {
//* Start counting after the current structure
{{const}} uint8_t* ptr = reinterpret_cast<{{const}} uint8_t*>(this + 1);
transfer->commandId = wire::WireCmd::{{Suffix}};
transfer->self = objectIdProvider.GetId(self);
//* Increment the pointer until we find the 'arg' then return early.
//* This will mean some of the code will be unreachable but there is no
//* "break" in Jinja2.
{% for arg in method.arguments if arg.annotation != "value" %}
{% if get_arg == arg %}
return ptr;
{% if method.return_type.category == "object" %}
transfer->resultId = resultId;
transfer->resultSerial = resultSerial;
{% endif %}
{% if arg.length == "strlen" %}
ptr += {{as_varName(arg.name)}}Strlen + 1;
{% elif arg.length == "constant_one" %}
ptr += sizeof({{as_cType(arg.type.name)}});
{% elif arg.type.category == "object" %}
ptr += {{as_varName(arg.length.name)}} * sizeof(uint32_t);
//* Value types are directly in the command, objects being replaced with their IDs.
{% for arg in method.arguments if arg.annotation == "value" %}
{% set argName = as_varName(arg.name) %}
{% if arg.type.category == "object" %}
transfer->{{argName}} = objectIdProvider.GetId(this->{{argName}});
{% else %}
ptr += {{as_varName(arg.length.name)}} * sizeof({{as_cType(arg.type.name)}});
transfer->{{argName}} = this->{{argName}};
{% endif %}
{% endfor %}
//* const char* have their length embedded directly in the command.
{% for arg in method.arguments if arg.length == "strlen" %}
{% set argName = as_varName(arg.name) %}
transfer->{{argName}}Strlen = std::strlen(this->{{argName}});
{% endfor %}
//* In the allocated space, write the non-value arguments.
{% for arg in method.arguments if arg.annotation != "value" %}
{% set argName = as_varName(arg.name) %}
{% if arg.length == "strlen" %}
memcpy(buffer, this->{{argName}}, transfer->{{argName}}Strlen);
buffer += transfer->{{argName}}Strlen;
{% elif arg.length == "constant_one" %}
memcpy(buffer, this->{{argName}}, sizeof(*(this->{{argName}})));
buffer += sizeof(*(this->{{argName}}));
{% elif arg.type.category == "object" %}
{% set argLength = as_varName(arg.length.name) %}
auto {{argName}}Storage = reinterpret_cast<ObjectId*>(buffer);
for (size_t i = 0; i < {{argLength}}; i++) {
{{argName}}Storage[i] = objectIdProvider.GetId(this->{{argName}}[i]);
}
buffer += sizeof(ObjectId) * {{argLength}};
{% else %}
{% set argLength = as_varName(arg.length.name) %}
memcpy(buffer, this->{{argName}}, {{argLength}} * sizeof(*(this->{{argName}})));
buffer += {{argLength}} * sizeof(*(this->{{argName}}));
{% endif %}
{% endfor %}
}
{% endfor %}
{% endfor %}
DeserializeResult {{Cmd}}::Deserialize(const char** buffer, size_t* size, DeserializeAllocator* allocator, const ObjectIdResolver& resolver) {
(void) allocator;
const {{Cmd}}Transfer* transfer = nullptr;
DESERIALIZE_TRY(GetPtrFromBuffer(buffer, size, 1, &transfer));
selfId = transfer->self;
{% if method.return_type.category == "object" %}
resultId = transfer->resultId;
resultSerial = transfer->resultSerial;
{% endif %}
DESERIALIZE_TRY(resolver.GetFromId(selfId, &self));
{% for arg in method.arguments if arg.annotation == "value" %}
{% set argName = as_varName(arg.name) %}
{% if arg.type.category == "object" %}
DESERIALIZE_TRY(resolver.GetFromId(transfer->{{argName}}, &(this->{{argName}})));
{% else %}
this->{{argName}} = transfer->{{argName}};
{% endif %}
{% endfor %}
{% set Suffix = as_MethodSuffix(type.name, Name("destroy")) %}
size_t {{Suffix}}Cmd::GetRequiredSize() const {
return sizeof(*this);
{% for arg in method.arguments if arg.annotation != "value" %}
{% set argName = as_varName(arg.name) %}
{% if arg.length == "strlen" %}
{
size_t stringLength = transfer->{{argName}}Strlen;
const char* stringInBuffer = nullptr;
DESERIALIZE_TRY(GetPtrFromBuffer(buffer, size, stringLength, &stringInBuffer));
char* copiedString = nullptr;
DESERIALIZE_TRY(GetSpace(allocator, stringLength + 1, &copiedString));
memcpy(copiedString, stringInBuffer, stringLength);
copiedString[stringLength] = '\0';
this->{{argName}} = copiedString;
}
{% elif arg.length == "constant_one" %}
{
const {{as_cType(arg.type.name)}}* argInBuffer = nullptr;
DESERIALIZE_TRY(GetPtrFromBuffer(buffer, size, 1, &argInBuffer));
{{as_cType(arg.type.name)}}* copiedArg = nullptr;
DESERIALIZE_TRY(GetSpace(allocator, 1, &copiedArg));
memcpy(copiedArg, argInBuffer, sizeof(*{{argName}}));
this->{{argName}} = copiedArg;
}
{% elif arg.type.category == "object" %}
{% set argLength = as_varName(arg.length.name) %}
{
const ObjectId* idsInBuffer = nullptr;
DESERIALIZE_TRY(GetPtrFromBuffer(buffer, size, {{argLength}}, &idsInBuffer));
{{as_cType(arg.type.name)}}* copiedObjects = nullptr;
DESERIALIZE_TRY(GetSpace(allocator, {{argLength}}, &copiedObjects));
for (size_t i = 0; i < {{argLength}}; i++) {
DESERIALIZE_TRY(resolver.GetFromId(idsInBuffer[i], &copiedObjects[i]));
}
this->{{argName}} = copiedObjects;
}
{% else %}
{% set argLength = as_varName(arg.length.name) %}
{
const {{as_cType(arg.type.name)}}* argInBuffer = nullptr;
DESERIALIZE_TRY(GetPtrFromBuffer(buffer, size, {{argLength}}, &argInBuffer));
{{as_cType(arg.type.name)}}* copiedArg = nullptr;
DESERIALIZE_TRY(GetSpace(allocator, {{argLength}}, &copiedArg))
memcpy(copiedArg, argInBuffer, {{argLength}} * sizeof(*{{argName}}));
this->{{argName}} = copiedArg;
}
{% endif %}
{% endfor %}
return DeserializeResult::Success;
}
{% endfor %}
{% for type in by_category["object"] if type.is_builder %}
{% set Type = type.name.CamelCase() %}
size_t Return{{Type}}ErrorCallbackCmd::GetRequiredSize() const {
return sizeof(*this) + messageStrlen + 1;
}
char* Return{{Type}}ErrorCallbackCmd::GetMessage() {
return reinterpret_cast<char*>(this + 1);
}
const char* Return{{Type}}ErrorCallbackCmd::GetMessage() const {
return reinterpret_cast<const char*>(this + 1);
}
{% endfor %}
}
}
}} // namespace nxt::wire

View File

@ -15,10 +15,41 @@
#ifndef WIRE_WIRECMD_AUTOGEN_H_
#define WIRE_WIRECMD_AUTOGEN_H_
#include <nxt/nxt.h>
namespace nxt { namespace wire {
namespace nxt {
namespace wire {
using ObjectId = uint32_t;
using ObjectSerial = uint32_t;
enum class DeserializeResult {
Success,
FatalError,
ErrorObject,
};
// Interface to allocate more space to deserialize pointed-to data.
// nullptr is treated as an error.
class DeserializeAllocator {
public:
virtual void* GetSpace(size_t size) = 0;
};
// Interface to convert an ID to a server object, if possible.
// Methods return FatalError if the ID is for a non-existent object, ErrorObject if the
// object is an error value and Success otherwise.
class ObjectIdResolver {
public:
{% for type in by_category["object"] %}
virtual DeserializeResult GetFromId(ObjectId id, {{as_cType(type.name)}}* out) const = 0;
{% endfor %}
};
// Interface to convert a client object to its ID for the wiring.
class ObjectIdProvider {
public:
{% for type in by_category["object"] %}
virtual ObjectId GetId({{as_cType(type.name)}} object) const = 0;
{% endfor %}
};
//* Enum used as a prefix to each command on the wire format.
enum class WireCmd : uint32_t {
@ -34,50 +65,45 @@ namespace wire {
{% for type in by_category["object"] %}
{% for method in type.methods %}
{% set Suffix = as_MethodSuffix(type.name, method.name) %}
{% set Cmd = Suffix + "Cmd" %}
//* Structure for the wire format of each of the commands. Parameters passed by value
//* are embedded directly in the structure. Other parameters are assumed to be in the
//* memory directly following the structure in the buffer. With value parameters the
//* structure can compute how much buffer size it needs and where the start of non-value
//* parameters is in the buffer.
struct {{Suffix}}Cmd {
//* These are "structure" version of the list of arguments to the different NXT methods.
//* They provide helpers to serialize/deserialize to/from a buffer.
struct {{Cmd}} {
//* From a filled structure, compute how much size will be used in the serialization buffer.
size_t GetRequiredSize() const;
//* Start the structure with the command ID, so that casting to WireCmd gives the ID.
wire::WireCmd commandId = wire::WireCmd::{{Suffix}};
//* Serialize the structure and everything it points to into serializeBuffer which must be
//* big enough to contain all the data (as queried from GetRequiredSize).
void Serialize(char* serializeBuffer, const ObjectIdProvider& objectIdProvider) const;
uint32_t self;
//* Deserializes the structure from a buffer, consuming a maximum of *size bytes. When this
//* function returns, buffer and size will be updated by the number of bytes consumed to
//* deserialize the structure. Structures containing pointers will use allocator to get
//* scratch space to deserialize the pointed-to data.
//* Deserialize returns:
//* - Success if everything went well (yay!)
//* - FatalError is something bad happened (buffer too small for example)
//* - ErrorObject if one if the deserialized object is an error value, for the implementation
//* of the Maybe monad.
//* If the return value is not FatalError, selfId, resultId and resultSerial (if present) are
//* filled.
DeserializeResult Deserialize(const char** buffer, size_t* size, DeserializeAllocator* allocator, const ObjectIdResolver& resolver);
{{as_cType(type.name)}} self;
//* Command handlers want to know the object ID in addition to the backing object.
//* Doesn't need to be filled before Serialize, or GetRequiredSize.
ObjectId selfId;
//* Commands creating objects say which ID the created object will be referred as.
{% if method.return_type.category == "object" %}
uint32_t resultId;
uint32_t resultSerial;
ObjectId resultId;
ObjectSerial resultSerial;
{% endif %}
//* Value types are directly in the command, objects being replaced with their IDs.
{% for arg in method.arguments if arg.annotation == "value" %}
{% if arg.type.category == "object" %}
uint32_t {{as_varName(arg.name)}};
{% else %}
{{as_cType(arg.type.name)}} {{as_varName(arg.name)}};
{% endif %}
{% endfor %}
//* const char* have their length embedded directly in the command.
{% for arg in method.arguments if arg.length == "strlen" %}
size_t {{as_varName(arg.name)}}Strlen;
{% endfor %}
//* The following commands do computation, provided the members for value parameters
//* have been initialized.
//* Compute how much buffer memory is required to hold the structure and all its arguments.
size_t GetRequiredSize() const;
//* Gets the pointer to the start of the buffer containing a non-value parameter.
{% for get_arg in method.arguments if get_arg.annotation != "value" %}
{% set ArgName = as_varName(get_arg.name) %}
uint8_t* GetPtr_{{ArgName}}();
const uint8_t* GetPtr_{{ArgName}}() const;
{% for arg in method.arguments %}
{{as_annotated_cType(arg)}};
{% endfor %}
};
{% endfor %}
@ -86,9 +112,7 @@ namespace wire {
{% set Suffix = as_MethodSuffix(type.name, Name("destroy")) %}
struct {{Suffix}}Cmd {
WireCmd commandId = WireCmd::{{Suffix}};
uint32_t objectId;
size_t GetRequiredSize() const;
ObjectId objectId;
};
{% endfor %}
@ -102,22 +126,18 @@ namespace wire {
BufferMapReadAsyncCallback,
};
//* Command for the server calling a builder status callback.
{% for type in by_category["object"] if type.is_builder %}
struct Return{{type.name.CamelCase()}}ErrorCallbackCmd {
wire::ReturnWireCmd commandId = ReturnWireCmd::{{type.name.CamelCase()}}ErrorCallback;
uint32_t builtObjectId;
uint32_t builtObjectSerial;
ObjectId builtObjectId;
ObjectSerial builtObjectSerial;
uint32_t status;
size_t messageStrlen;
size_t GetRequiredSize() const;
char* GetMessage();
const char* GetMessage() const;
};
{% endfor %}
}
}
}} // namespace nxt::wire
#endif // WIRE_WIRECMD_AUTOGEN_H_

View File

@ -17,11 +17,12 @@
#include "common/Assert.h"
#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <vector>
namespace nxt {
namespace wire {
namespace nxt { namespace wire {
namespace server {
class Server;
@ -72,6 +73,19 @@ namespace wire {
//* Get a backend objects for a given client ID.
//* Returns nullptr if the ID hasn't previously been allocated.
const Data* Get(uint32_t id) const {
if (id >= mKnown.size()) {
return nullptr;
}
const Data* data = &mKnown[id];
if (!data->allocated) {
return nullptr;
}
return data;
}
Data* Get(uint32_t id) {
if (id >= mKnown.size()) {
return nullptr;
@ -130,7 +144,60 @@ namespace wire {
void ForwardBufferMapReadAsync(nxtBufferMapAsyncStatus status, const void* ptr, nxtCallbackUserdata userdata);
class Server : public CommandHandler {
// A really really simple implementation of the DeserializeAllocator. It's main feature
// is that it has some inline storage so as to avoid allocations for the majority of
// commands.
class ServerAllocator : public DeserializeAllocator {
public:
ServerAllocator() {
Reset();
}
~ServerAllocator() {
Reset();
}
void* GetSpace(size_t size) override {
// Return space in the current buffer if possible first.
if (mRemainingSize >= size) {
char* buffer = mCurrentBuffer;
mCurrentBuffer += size;
mRemainingSize -= size;
return buffer;
}
// Otherwise allocate a new buffer and try again.
size_t allocationSize = std::max(size, size_t(2048));
char* allocation = static_cast<char*>(malloc(allocationSize));
if (allocation == nullptr) {
return nullptr;
}
mAllocations.push_back(allocation);
mCurrentBuffer = allocation;
mRemainingSize = allocationSize;
return GetSpace(size);
}
void Reset() {
for (auto allocation : mAllocations) {
free(allocation);
}
mAllocations.clear();
// The initial buffer is the inline buffer so that some allocations can be skipped
mCurrentBuffer = mStaticBuffer;
mRemainingSize = sizeof(mStaticBuffer);
}
private:
size_t mRemainingSize = 0;
char* mCurrentBuffer = nullptr;
char mStaticBuffer[2048];
std::vector<char*> mAllocations;
};
class Server : public CommandHandler, public ObjectIdResolver {
public:
Server(nxtDevice device, const nxtProcTable& procs, CommandSerializer* serializer)
: mProcs(procs), mSerializer(serializer) {
@ -147,9 +214,11 @@ namespace wire {
ReturnDeviceErrorCallbackCmd cmd;
cmd.messageStrlen = std::strlen(message);
auto allocCmd = reinterpret_cast<ReturnDeviceErrorCallbackCmd*>(GetCmdSpace(cmd.GetRequiredSize()));
auto allocCmd = static_cast<ReturnDeviceErrorCallbackCmd*>(GetCmdSpace(sizeof(cmd)));
*allocCmd = cmd;
strcpy(allocCmd->GetMessage(), message);
char* messageAlloc = static_cast<char*>(GetCmdSpace(cmd.messageStrlen + 1));
strcpy(messageAlloc, message);
}
{% for type in by_category["object"] if type.is_builder%}
@ -176,9 +245,10 @@ namespace wire {
cmd.status = status;
cmd.messageStrlen = std::strlen(message);
auto allocCmd = reinterpret_cast<Return{{Type}}ErrorCallbackCmd*>(GetCmdSpace(cmd.GetRequiredSize()));
auto allocCmd = static_cast<Return{{Type}}ErrorCallbackCmd*>(GetCmdSpace(sizeof(cmd)));
*allocCmd = cmd;
strcpy(allocCmd->GetMessage(), message);
char* messageAlloc = static_cast<char*>(GetCmdSpace(strlen(message) + 1));
strcpy(messageAlloc, message);
}
}
{% endfor %}
@ -189,23 +259,22 @@ namespace wire {
cmd.bufferSerial = data->bufferSerial;
cmd.requestSerial = data->requestSerial;
cmd.status = status;
cmd.dataLength = 0;
if (status == NXT_BUFFER_MAP_ASYNC_STATUS_SUCCESS) {
cmd.dataLength = data->size;
}
auto allocCmd = reinterpret_cast<ReturnBufferMapReadAsyncCallbackCmd*>(GetCmdSpace(cmd.GetRequiredSize()));
auto allocCmd = static_cast<ReturnBufferMapReadAsyncCallbackCmd*>(GetCmdSpace(sizeof(cmd)));
*allocCmd = cmd;
if (status == NXT_BUFFER_MAP_ASYNC_STATUS_SUCCESS) {
memcpy(allocCmd->GetData(), ptr, data->size);
allocCmd->dataLength = data->size;
void* dataAlloc = GetCmdSpace(data->size);
memcpy(dataAlloc, ptr, data->size);
}
delete data;
}
const uint8_t* HandleCommands(const uint8_t* commands, size_t size) override {
const char* HandleCommands(const char* commands, size_t size) override {
mProcs.deviceTick(mKnownDevice.Get(1)->handle);
while (size > sizeof(WireCmd)) {
@ -236,6 +305,7 @@ namespace wire {
if (!success) {
return nullptr;
}
mAllocator.Reset();
}
if (size != 0) {
@ -249,10 +319,29 @@ namespace wire {
nxtProcTable mProcs;
CommandSerializer* mSerializer = nullptr;
ServerAllocator mAllocator;
void* GetCmdSpace(size_t size) {
return mSerializer->GetCmdSpace(size);
}
// Implementation of the ObjectIdResolver interface
{% for type in by_category["object"] %}
DeserializeResult GetFromId(ObjectId id, {{as_cType(type.name)}}* out) const override {
auto data = mKnown{{type.name.CamelCase()}}.Get(id);
if (data == nullptr) {
return DeserializeResult::FatalError;
}
*out = data->handle;
if (data->valid) {
return DeserializeResult::Success;
} else {
return DeserializeResult::ErrorObject;
}
}
{% endfor %}
//* The list of known IDs for each object type.
{% for type in by_category["object"] %}
KnownObjects<{{as_cType(type.name)}}> mKnown{{type.name.CamelCase()}};
@ -262,20 +351,15 @@ namespace wire {
//* Checks there is enough data left, updates the buffer / size and returns
//* the command (or nullptr for an error).
template<typename T>
static const T* GetCommand(const uint8_t** commands, size_t* size) {
static const T* GetCommand(const char** commands, size_t* size) {
if (*size < sizeof(T)) {
return nullptr;
}
const T* cmd = reinterpret_cast<const T*>(*commands);
size_t cmdSize = cmd->GetRequiredSize();
if (*size < cmdSize) {
return nullptr;
}
*commands += cmdSize;
*size -= cmdSize;
*commands += sizeof(T);
*size -= sizeof(T);
return cmd;
}
@ -287,99 +371,42 @@ namespace wire {
//* The generic command handlers
bool Handle{{Suffix}}(const uint8_t** commands, size_t* size) {
//* Get command ptr, and check it fits in the buffer.
const auto* cmd = GetCommand<{{Suffix}}Cmd>(commands, size);
if (cmd == nullptr) {
bool Handle{{Suffix}}(const char** commands, size_t* size) {
{{Suffix}}Cmd cmd;
DeserializeResult deserializeResult = cmd.Deserialize(commands, size, &mAllocator, *this);
if (deserializeResult == DeserializeResult::FatalError) {
return false;
}
//* While unpacking arguments, if any of them is an error, valid will be set to false.
bool valid = true;
//* Unpack 'self'
{% set Type = type.name.CamelCase() %}
{{as_cType(type.name)}} self;
auto* selfData = mKnown{{Type}}.Get(cmd->self);
{
if (selfData == nullptr) {
return false;
}
valid = valid && selfData->valid;
self = selfData->handle;
}
//* Unpack value objects from IDs.
{% for arg in method.arguments if arg.annotation == "value" and arg.type.category == "object" %}
{% set Type = arg.type.name.CamelCase() %}
{{as_cType(arg.type.name)}} arg_{{as_varName(arg.name)}};
{
auto* data = mKnown{{Type}}.Get(cmd->{{as_varName(arg.name)}});
if (data == nullptr) {
return false;
}
valid = valid && data->valid;
arg_{{as_varName(arg.name)}} = data->handle;
}
{% endfor %}
//* Unpack pointer arguments
{% for arg in method.arguments if arg.annotation != "value" %}
{% set argName = as_varName(arg.name) %}
const {{as_cType(arg.type.name)}}* arg_{{argName}};
{% if arg.length == "strlen" %}
//* Unpack strings, checking they are null-terminated.
arg_{{argName}} = reinterpret_cast<const {{as_cType(arg.type.name)}}*>(cmd->GetPtr_{{argName}}());
if (arg_{{argName}}[cmd->{{argName}}Strlen] != 0) {
return false;
}
{% elif arg.type.category == "object" %}
//* Unpack arrays of objects.
//* TODO(cwallez@chromium.org) do not allocate when there are few objects.
std::vector<{{as_cType(arg.type.name)}}> {{argName}}Storage(cmd->{{as_varName(arg.length.name)}});
auto {{argName}}Ids = reinterpret_cast<const uint32_t*>(cmd->GetPtr_{{argName}}());
for (size_t i = 0; i < cmd->{{as_varName(arg.length.name)}}; i++) {
{% set Type = arg.type.name.CamelCase() %}
auto* data = mKnown{{Type}}.Get({{argName}}Ids[i]);
if (data == nullptr) {
return false;
}
{{argName}}Storage[i] = data->handle;
valid = valid && data->valid;
}
arg_{{argName}} = {{argName}}Storage.data();
{% else %}
//* For anything else, just get the pointer.
arg_{{argName}} = reinterpret_cast<const {{as_cType(arg.type.name)}}*>(cmd->GetPtr_{{argName}}());
{% endif %}
{% endfor %}
//* At that point all the data has been upacked in cmd->* or arg_*
auto* selfData = mKnown{{type.name.CamelCase()}}.Get(cmd.selfId);
ASSERT(selfData != nullptr);
//* In all cases allocate the object data as it will be refered-to by the client.
{% set return_type = method.return_type %}
{% set returns = return_type.name.canonical_case() != "void" %}
{% if returns %}
{% set Type = method.return_type.name.CamelCase() %}
auto* resultData = mKnown{{Type}}.Allocate(cmd->resultId);
auto* resultData = mKnown{{Type}}.Allocate(cmd.resultId);
if (resultData == nullptr) {
return false;
}
resultData->serial = cmd->resultSerial;
resultData->serial = cmd.resultSerial;
{% if type.is_builder %}
selfData->builtObjectId = cmd->resultId;
selfData->builtObjectSerial = cmd->resultSerial;
selfData->builtObjectId = cmd.resultId;
selfData->builtObjectSerial = cmd.resultSerial;
{% endif %}
{% endif %}
//* After the data is allocated, apply the argument error propagation mechanism
if (!valid) {
if (deserializeResult == DeserializeResult::ErrorObject) {
{% if type.is_builder %}
selfData->valid = false;
//* If we are in GetResult, fake an error callback
{% if returns %}
On{{type.name.CamelCase()}}Error(NXT_BUILDER_ERROR_STATUS_ERROR, "Maybe monad", cmd->self, selfData->serial);
On{{type.name.CamelCase()}}Error(NXT_BUILDER_ERROR_STATUS_ERROR, "Maybe monad", cmd.selfId, selfData->serial);
{% endif %}
{% endif %}
return true;
@ -388,13 +415,9 @@ namespace wire {
{% if returns %}
auto result ={{" "}}
{%- endif %}
mProcs.{{as_varName(type.name, method.name)}}(self
mProcs.{{as_varName(type.name, method.name)}}(cmd.self
{%- for arg in method.arguments -%}
{%- if arg.annotation == "value" and arg.type.category != "object" -%}
, cmd->{{as_varName(arg.name)}}
{%- else -%}
, arg_{{as_varName(arg.name)}}
{%- endif -%}
, cmd.{{as_varName(arg.name)}}
{%- endfor -%}
);
@ -407,7 +430,7 @@ namespace wire {
{% if return_type.is_builder %}
if (result != nullptr) {
uint64_t userdata1 = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(this));
uint64_t userdata2 = (uint64_t(resultData->serial) << uint64_t(32)) + cmd->resultId;
uint64_t userdata2 = (uint64_t(resultData->serial) << uint64_t(32)) + cmd.resultId;
mProcs.{{as_varName(return_type.name, Name("set error callback"))}}(result, Forward{{return_type.name.CamelCase()}}ToClient, userdata1, userdata2);
}
{% endif %}
@ -420,18 +443,20 @@ namespace wire {
//* Handlers for the destruction of objects: clients do the tracking of the
//* reference / release and only send destroy on refcount = 0.
{% set Suffix = as_MethodSuffix(type.name, Name("destroy")) %}
bool Handle{{Suffix}}(const uint8_t** commands, size_t* size) {
bool Handle{{Suffix}}(const char** commands, size_t* size) {
const auto* cmd = GetCommand<{{Suffix}}Cmd>(commands, size);
if (cmd == nullptr) {
return false;
}
ObjectId objectId = cmd->objectId;
//* ID 0 are reserved for nullptr and cannot be destroyed.
if (cmd->objectId == 0) {
if (objectId == 0) {
return false;
}
auto* data = mKnown{{type.name.CamelCase()}}.Get(cmd->objectId);
auto* data = mKnown{{type.name.CamelCase()}}.Get(objectId);
if (data == nullptr) {
return false;
}
@ -440,12 +465,12 @@ namespace wire {
mProcs.{{as_varName(type.name, Name("release"))}}(data->handle);
}
mKnown{{type.name.CamelCase()}}.Free(cmd->objectId);
mKnown{{type.name.CamelCase()}}.Free(objectId);
return true;
}
{% endfor %}
bool HandleBufferMapReadAsync(const uint8_t** commands, size_t* size) {
bool HandleBufferMapReadAsync(const char** commands, size_t* size) {
//* These requests are just forwarded to the buffer, with userdata containing what the client
//* will require in the return command.
const auto* cmd = GetCommand<BufferMapReadAsyncCmd>(commands, size);
@ -453,17 +478,22 @@ namespace wire {
return false;
}
auto* buffer = mKnownBuffer.Get(cmd->bufferId);
ObjectId bufferId = cmd->bufferId;
uint32_t requestSerial = cmd->requestSerial;
uint32_t requestSize = cmd->size;
uint32_t requestStart = cmd->start;
auto* buffer = mKnownBuffer.Get(bufferId);
if (buffer == nullptr) {
return false;
}
auto* data = new MapReadUserdata;
data->server = this;
data->bufferId = cmd->bufferId;
data->bufferId = bufferId;
data->bufferSerial = buffer->serial;
data->requestSerial = cmd->requestSerial;
data->size = cmd->size;
data->requestSerial = requestSerial;
data->size = requestSize;
auto userdata = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(data));
@ -473,7 +503,7 @@ namespace wire {
return true;
}
mProcs.bufferMapReadAsync(buffer->handle, cmd->start, cmd->size, ForwardBufferMapReadAsync, userdata);
mProcs.bufferMapReadAsync(buffer->handle, requestStart, requestSize, ForwardBufferMapReadAsync, userdata);
return true;
}
@ -503,5 +533,4 @@ namespace wire {
return new server::Server(device, procs, serializer);
}
}
}
}} // namespace nxt::wire

View File

@ -25,7 +25,6 @@ Generate(
${GENERATOR_COMMON_ARGS}
-T wire
EXTRA_SOURCES
${WIRE_DIR}/WireCmd.cpp
${WIRE_DIR}/WireCmd.h
)
target_include_directories(wire_autogen PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})

View File

@ -31,7 +31,7 @@ namespace nxt { namespace wire {
return nullptr;
}
uint8_t* result = &mBuffer[mOffset];
char* result = &mBuffer[mOffset];
mOffset += size;
if (mOffset > sizeof(mBuffer)) {

View File

@ -34,7 +34,7 @@ namespace nxt { namespace wire {
private:
CommandHandler* mHandler = nullptr;
size_t mOffset = 0;
uint8_t mBuffer[10000000];
char mBuffer[10000000];
};
}} // namespace nxt::wire

View File

@ -31,7 +31,7 @@ namespace nxt { namespace wire {
class CommandHandler {
public:
virtual ~CommandHandler() = default;
virtual const uint8_t* HandleCommands(const uint8_t* commands, size_t size) = 0;
virtual const char* HandleCommands(const char* commands, size_t size) = 0;
};
CommandHandler* NewClientDevice(nxtProcTable* procs,

View File

@ -1,47 +0,0 @@
// Copyright 2017 The NXT Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "wire/WireCmd.h"
namespace nxt { namespace wire {
size_t ReturnDeviceErrorCallbackCmd::GetRequiredSize() const {
return sizeof(*this) + messageStrlen + 1;
}
char* ReturnDeviceErrorCallbackCmd::GetMessage() {
return reinterpret_cast<char*>(this + 1);
}
const char* ReturnDeviceErrorCallbackCmd::GetMessage() const {
return reinterpret_cast<const char*>(this + 1);
}
size_t BufferMapReadAsyncCmd::GetRequiredSize() const {
return sizeof(*this);
}
size_t ReturnBufferMapReadAsyncCallbackCmd::GetRequiredSize() const {
return sizeof(*this) + dataLength;
}
void* ReturnBufferMapReadAsyncCallbackCmd::GetData() {
return this + 1;
}
const void* ReturnBufferMapReadAsyncCallbackCmd::GetData() const {
return this + 1;
}
}} // namespace nxt::wire

View File

@ -15,6 +15,8 @@
#ifndef WIRE_WIRECMD_H_
#define WIRE_WIRECMD_H_
#include <nxt/nxt.h>
#include "wire/WireCmd_autogen.h"
namespace nxt { namespace wire {
@ -23,10 +25,6 @@ namespace nxt { namespace wire {
wire::ReturnWireCmd commandId = ReturnWireCmd::DeviceErrorCallback;
size_t messageStrlen;
size_t GetRequiredSize() const;
char* GetMessage();
const char* GetMessage() const;
};
struct BufferMapReadAsyncCmd {
@ -36,8 +34,6 @@ namespace nxt { namespace wire {
uint32_t requestSerial;
uint32_t start;
uint32_t size;
size_t GetRequiredSize() const;
};
struct ReturnBufferMapReadAsyncCallbackCmd {
@ -48,10 +44,6 @@ namespace nxt { namespace wire {
uint32_t requestSerial;
uint32_t status;
uint32_t dataLength;
size_t GetRequiredSize() const;
void* GetData();
const void* GetData() const;
};
}} // namespace nxt::wire