Implement the device error callback.

This adds support for "natively defined" API types like callbacks that
will have to be implemented manually for each target language. Also this
splits the concept of "native method" into a set of native methods per
language.

Removes the "Synchronous error" concept that was used to make builders
work in the maybe Monad, this will have to be reinroduced with builder
callbacks.
This commit is contained in:
Corentin Wallez 2017-04-20 14:42:36 -04:00 committed by Corentin Wallez
parent 682a8250b3
commit 4b410a33ca
15 changed files with 219 additions and 55 deletions

View File

@ -28,8 +28,6 @@
BackendBinding* CreateMetalBinding();
namespace backend {
void RegisterSynchronousErrorCallback(nxtDevice device, void(*)(const char*, void*), void* userData);
namespace opengl {
void Init(void* (*getProc)(const char*), nxtProcTable* procs, nxtDevice* device);
void HACKCLEAR();
@ -61,6 +59,10 @@ class OpenGLBinding : public BackendBinding {
}
};
void PrintDeviceError(const char* message, nxt::CallbackUserdata) {
std::cout << "Device error: " << message << std::endl;
}
enum class BackendType {
OpenGL,
Metal,
@ -83,15 +85,6 @@ static nxt::wire::CommandHandler* wireClient = nullptr;
static nxt::wire::TerribleCommandBuffer* c2sBuf = nullptr;
static nxt::wire::TerribleCommandBuffer* s2cBuf = nullptr;
void HandleSynchronousError(const char* errorMessage, void* userData) {
std::cerr << errorMessage << std::endl;
if (userData != nullptr) {
auto wireServer = reinterpret_cast<nxt::wire::CommandHandler*>(userData);
wireServer->OnSynchronousError();
}
}
void GetProcTableAndDevice(nxtProcTable* procs, nxt::Device* device) {
switch (backendType) {
case BackendType::OpenGL:
@ -147,8 +140,7 @@ void GetProcTableAndDevice(nxtProcTable* procs, nxt::Device* device) {
break;
}
//TODO(cwallez@chromium.org) this will disappear
backend::RegisterSynchronousErrorCallback(backendDevice, HandleSynchronousError, wireServer);
procs->deviceSetErrorCallback(device->Get(), PrintDeviceError, 0);
}
nxt::ShaderModule CreateShaderModule(const nxt::Device& device, nxt::ShaderStage stage, const char* source) {

View File

@ -74,6 +74,10 @@ class NativeType(Type):
def __init__(self, name, record):
Type.__init__(self, name, record, native=True)
class NativelyDefined(Type):
def __init__(self, name, record):
Type.__init__(self, name, record)
class MethodArgument:
def __init__(self, name, typ, annotation):
self.name = name
@ -86,11 +90,17 @@ class ObjectType(Type):
def __init__(self, name, record):
Type.__init__(self, name, record)
self.methods = []
self.native_methods = []
############################################################
# PARSE
############################################################
import json
def is_native_method(method):
return method.return_type.category == "natively defined" or \
any([arg.type.category == "natively defined" for arg in method.arguments])
def link_object(obj, types):
def make_method(record):
arguments = []
@ -110,13 +120,16 @@ def link_object(obj, types):
return Method(Name(record['name']), types[record.get('returns', 'void')], arguments)
obj.methods = [make_method(m) for m in obj.record.get('methods', [])]
methods = [make_method(m) for m in obj.record.get('methods', [])]
obj.methods = [method for method in methods if not is_native_method(method)]
obj.native_methods = [method for method in methods if is_native_method(method)]
def parse_json(json):
category_to_parser = {
'bitmask': BitmaskType,
'enum': EnumType,
'native': NativeType,
'natively defined': NativelyDefined,
'object': ObjectType,
}
@ -296,11 +309,14 @@ def as_backendType(typ):
else:
return as_cType(typ.name)
def native_methods(types, typ):
return [
def c_native_methods(types, typ):
return cpp_native_methods(typ) + [
Method(Name('reference'), types['void'], []),
Method(Name('release'), types['void'], []),
] + typ.methods
]
def cpp_native_methods(typ):
return typ.methods + typ.native_methods
def debug(text):
print(text)
@ -349,26 +365,29 @@ def main():
'as_cppType': as_cppType,
'as_varName': as_varName,
'decorate': decorate,
'native_methods': lambda typ: native_methods(api_params['types'], typ)
}
renders = []
c_params = {'native_methods': lambda typ: c_native_methods(api_params['types'], typ)}
if 'nxt' in targets:
renders.append(FileRender('api.h', 'nxt/nxt.h', [base_params, api_params]))
renders.append(FileRender('api.c', 'nxt/nxt.c', [base_params, api_params]))
renders.append(FileRender('api.h', 'nxt/nxt.h', [base_params, api_params, c_params]))
renders.append(FileRender('api.c', 'nxt/nxt.c', [base_params, api_params, c_params]))
if 'nxtcpp' in targets:
renders.append(FileRender('apicpp.h', 'nxt/nxtcpp.h', [base_params, api_params]))
renders.append(FileRender('apicpp.cpp', 'nxt/nxtcpp.cpp', [base_params, api_params]))
additional_params = {'native_methods': cpp_native_methods}
renders.append(FileRender('apicpp.h', 'nxt/nxtcpp.h', [base_params, api_params, additional_params]))
renders.append(FileRender('apicpp.cpp', 'nxt/nxtcpp.cpp', [base_params, api_params, additional_params]))
if 'mock_nxt' in targets:
renders.append(FileRender('mock_api.h', 'mock/mock_nxt.h', [base_params, api_params]))
renders.append(FileRender('mock_api.cpp', 'mock/mock_nxt.cpp', [base_params, api_params]))
renders.append(FileRender('mock_api.h', 'mock/mock_nxt.h', [base_params, api_params, c_params]))
renders.append(FileRender('mock_api.cpp', 'mock/mock_nxt.cpp', [base_params, api_params, c_params]))
base_backend_params = [
base_params,
api_params,
c_params,
{
'as_backendType': lambda typ: as_backendType(typ), # TODO as_backendType and friends take a Type and not a Name :(
'as_annotated_backendType': lambda arg: annotated(as_backendType(arg.type), arg)

View File

@ -33,6 +33,10 @@
{% endfor %}
// Custom types depending on the target language
typedef uint64_t nxtCallbackUserdata;
typedef void (*nxtDeviceErrorCallback)(const char* message, nxtCallbackUserdata userdata);
#ifdef __cplusplus
extern "C" {
#endif

View File

@ -64,7 +64,7 @@ namespace nxt {
{{as_varName(arg.name)}}.Get()
{%- elif arg.type.category == "enum" or arg.type.category == "bitmask" -%}
static_cast<{{as_cType(arg.type.name)}}>({{as_varName(arg.name)}})
{%- elif arg.type.category == "native" -%}
{%- elif arg.type.category in ["native", "natively defined"] -%}
{{as_varName(arg.name)}}
{%- else -%}
UNHANDLED
@ -76,7 +76,7 @@ namespace nxt {
)
{%- endmacro %}
{% for method in type.methods %}
{% for method in native_methods(type) %}
{{render_cpp_method_declaration(type, method)}} {
{% if method.return_type.name.concatcase() == "void" %}
{{render_cpp_to_c_method_call(type, method)}};

View File

@ -47,6 +47,10 @@ namespace nxt {
{% endfor %}
{% for type in by_category["natively defined"] %}
using {{as_cppType(type.name)}} = {{as_cType(type.name)}};
{% endfor %}
{% for type in by_category["object"] %}
class {{as_cppType(type.name)}};
{% endfor %}
@ -132,7 +136,7 @@ namespace nxt {
using ObjectBase::ObjectBase;
using ObjectBase::operator=;
{% for method in type.methods %}
{% for method in native_methods(type) %}
{{render_cpp_method_declaration(type, method)}};
{% endfor %}

View File

@ -100,6 +100,9 @@ namespace wire {
ObjectAllocator<{{type.name.CamelCase()}}> {{type.name.camelCase()}};
{% endfor %}
nxtDeviceErrorCallback errorCallback = nullptr;
nxtCallbackUserdata errorUserdata;
private:
CommandSerializer* serializer = nullptr;
};
@ -200,6 +203,11 @@ namespace wire {
void ClientDeviceRelease(Device* self) {
}
void ClientDeviceSetErrorCallback(Device* self, nxtDeviceErrorCallback callback, nxtCallbackUserdata userdata) {
self->errorCallback = callback;
self->errorUserdata = userdata;
}
nxtProcTable GetProcs() {
nxtProcTable table;
{% for type in by_category["object"] %}
@ -216,16 +224,72 @@ namespace wire {
}
const uint8_t* HandleCommands(const uint8_t* commands, size_t size) override {
// TODO(cwallez@chromium.org): process callbacks
return nullptr;
}
while (size > sizeof(ReturnWireCmd)) {
ReturnWireCmd cmdId = *reinterpret_cast<const ReturnWireCmd*>(commands);
void OnSynchronousError() override {
// TODO(cwallez@chromium.org): this will disappear
bool success = false;
switch (cmdId) {
case ReturnWireCmd::DeviceErrorCallback:
success = HandleDeviceErrorCallbackCmd(&commands, &size);
break;
default:
success = false;
}
if (!success) {
return nullptr;
}
}
if (size != 0) {
return nullptr;
}
return commands;
}
private:
Device* device = nullptr;
//* Helper function for the getting of the command data in command handlers.
//* Checks there is enough data left, updates the buffer / size and returns
//* the command (or nullptr for an error).
template<typename T>
static const T* GetCommand(const uint8_t** commands, size_t* size) {
if (*size < sizeof(T)) {
return nullptr;
}
const T* cmd = reinterpret_cast<const T*>(*commands);
size_t cmdSize = cmd->GetRequiredSize();
if (*size < cmdSize) {
return nullptr;
}
*commands += cmdSize;
*size -= cmdSize;
return cmd;
}
bool HandleDeviceErrorCallbackCmd(const uint8_t** commands, size_t* size) {
const auto* cmd = GetCommand<ReturnDeviceErrorCallbackCmd>(commands, size);
if (cmd == nullptr) {
return false;
}
if (cmd->GetMessage()[cmd->messageStrlen] != '\0') {
return false;
}
if (device->errorCallback != nullptr) {
device->errorCallback(cmd->GetMessage(), device->errorUserdata);
}
return true;
}
};
}

View File

@ -83,13 +83,18 @@ namespace wire {
//* The command structure used when sending that an ID is destroyed.
{% set Suffix = as_MethodSuffix(type.name, Name("destroy")) %}
struct {{Suffix}}Cmd {
wire::WireCmd commandId = wire::WireCmd::{{Suffix}};
WireCmd commandId = WireCmd::{{Suffix}};
uint32_t objectId;
size_t GetRequiredSize() const;
};
{% endfor %}
//* Enum used as a prefix to each command on the return wire format.
enum class ReturnWireCmd : uint32_t {
DeviceErrorCallback,
};
}
}

View File

@ -28,7 +28,7 @@ namespace wire {
//* The backend-provided handle to this object.
T handle;
//* Used by the error-propagation mechanism to know if this object is an error.
//* TODO(cwallez@chromium.org): this is doubling the memory usae of
//* TODO(cwallez@chromium.org): this is doubling the memory usage of
//* std::vector<ObjectDataBase> consider making it a special marker value in handle instead.
bool valid;
//* Whether this object has been allocated, used by the KnownObjects queries
@ -103,13 +103,28 @@ namespace wire {
std::vector<Data> known;
};
void ForwardDeviceErrorToServer(const char* message, nxtCallbackUserdata userdata);
class Server : public CommandHandler {
public:
Server(nxtDevice device, const nxtProcTable& procs) : procs(procs) {
Server(nxtDevice device, const nxtProcTable& procs, CommandSerializer* serializer)
: procs(procs), serializer(serializer) {
//* The client-server knowledge is bootstrapped with device 1.
auto* deviceData = knownDevice.Allocate(1);
deviceData->handle = device;
deviceData->valid = true;
auto userdata = static_cast<nxtCallbackUserdata>(reinterpret_cast<intptr_t>(this));
procs.deviceSetErrorCallback(device, ForwardDeviceErrorToServer, userdata);
}
void OnDeviceError(const char* message) {
ReturnDeviceErrorCallbackCmd cmd;
cmd.messageStrlen = std::strlen(message);
auto allocCmd = reinterpret_cast<ReturnDeviceErrorCallbackCmd*>(GetCmdSpace(cmd.GetRequiredSize()));
*allocCmd = cmd;
strcpy(allocCmd->GetMessage(), message);
}
const uint8_t* HandleCommands(const uint8_t* commands, size_t size) override {
@ -147,14 +162,15 @@ namespace wire {
return commands;
}
void OnSynchronousError() override {
gotError = true;
}
private:
nxtProcTable procs;
CommandSerializer* serializer = nullptr;
bool gotError = false;
void* GetCmdSpace(size_t size) {
return serializer->GetCmdSpace(size);
}
//* The list of known IDs for each object type.
{% for type in by_category["object"] %}
KnownObjects<{{as_cType(type.name)}}> known{{type.name.CamelCase()}};
@ -164,7 +180,7 @@ namespace wire {
//* Checks there is enough data left, updates the buffer / size and returns
//* the command (or nullptr for an error).
template<typename T>
const T* GetCommand(const uint8_t** commands, size_t* size) {
static const T* GetCommand(const uint8_t** commands, size_t* size) {
if (*size < sizeof(T)) {
return nullptr;
}
@ -333,11 +349,14 @@ namespace wire {
{% endfor %}
};
void ForwardDeviceErrorToServer(const char* message, nxtCallbackUserdata userdata) {
auto server = reinterpret_cast<Server*>(static_cast<intptr_t>(userdata));
server->OnDeviceError(message);
}
}
CommandHandler* NewServerCommandHandler(nxtDevice device, const nxtProcTable& procs, CommandSerializer* serializer) {
//TODO(cwallez@chromium.org) do something with the serializer
return new server::Server(device, procs);
return new server::Server(device, procs, serializer);
}
}

View File

@ -190,6 +190,9 @@
}
]
},
"callback userdata": {
"category": "natively defined"
},
"char": {
"category": "native"
},
@ -368,9 +371,19 @@
{"name": "source", "type": "bind group"},
{"name": "target", "type": "bind group"}
]
},
{
"name": "set error callback",
"args": [
{"name": "callback", "type": "device error callback"},
{"name": "userdata", "type": "callback userdata"}
]
}
]
},
"device error callback": {
"category": "natively defined"
},
"filter mode": {
"category": "enum",
"values": [

View File

@ -30,11 +30,6 @@
namespace backend {
void RegisterSynchronousErrorCallback(nxtDevice device, ErrorCallback callback, void* userData) {
auto deviceBase = reinterpret_cast<DeviceBase*>(device);
deviceBase->RegisterErrorCallback(callback, userData);
}
// DeviceBase::Caches
// The caches are unordered_sets of pointers with special hash and compare functions
@ -57,13 +52,13 @@ namespace backend {
void DeviceBase::HandleError(const char* message) {
if (errorCallback) {
errorCallback(message, errorUserData);
errorCallback(message, errorUserdata);
}
}
void DeviceBase::RegisterErrorCallback(ErrorCallback callback, void* userData) {
void DeviceBase::SetErrorCallback(nxt::DeviceErrorCallback callback, nxt::CallbackUserdata userdata) {
this->errorCallback = callback;
this->errorUserData = userData;
this->errorUserdata = userdata;
}
BindGroupLayoutBase* DeviceBase::GetOrCreateBindGroupLayout(const BindGroupLayoutBase* blueprint, BindGroupLayoutBuilder* builder) {

View File

@ -30,7 +30,7 @@ namespace backend {
~DeviceBase();
void HandleError(const char* message);
void RegisterErrorCallback(ErrorCallback callback, void* userData);
void SetErrorCallback(nxt::DeviceErrorCallback, nxt::CallbackUserdata userdata);
virtual BindGroupBase* CreateBindGroup(BindGroupBuilder* builder) = 0;
virtual BindGroupLayoutBase* CreateBindGroupLayout(BindGroupLayoutBuilder* builder) = 0;
@ -85,8 +85,8 @@ namespace backend {
struct Caches;
Caches* caches = nullptr;
ErrorCallback errorCallback = nullptr;
void* errorUserData = nullptr;
nxt::DeviceErrorCallback errorCallback = nullptr;
nxt::CallbackUserdata errorUserdata;
};
}

View File

@ -30,6 +30,8 @@ SetPic(wire_autogen)
add_library(nxt_wire SHARED
${WIRE_DIR}/TerribleCommandBuffer.h
${WIRE_DIR}/WireCmd.cpp
${WIRE_DIR}/WireCmd.h
)
target_link_libraries(nxt_wire wire_autogen)
SetCXX14(nxt_wire)

View File

@ -33,8 +33,6 @@ namespace wire {
public:
virtual ~CommandHandler() = default;
virtual const uint8_t* HandleCommands(const uint8_t* commands, size_t size) = 0;
virtual void OnSynchronousError() = 0;
};
CommandHandler* NewClientDevice(nxtProcTable* procs, nxtDevice* device, CommandSerializer* serializer);

33
src/wire/WireCmd.cpp Normal file
View File

@ -0,0 +1,33 @@
// Copyright 2017 The NXT Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "WireCmd.h"
namespace nxt {
namespace wire {
size_t ReturnDeviceErrorCallbackCmd::GetRequiredSize() const {
return sizeof(*this) + messageStrlen + 1;
}
char* ReturnDeviceErrorCallbackCmd::GetMessage() {
return reinterpret_cast<char*>(this + 1);
}
const char* ReturnDeviceErrorCallbackCmd::GetMessage() const {
return reinterpret_cast<const char*>(this + 1);
}
}
}

View File

@ -17,4 +17,20 @@
#include "wire/WireCmd_autogen.h"
namespace nxt {
namespace wire {
struct ReturnDeviceErrorCallbackCmd {
wire::ReturnWireCmd commandId = ReturnWireCmd::DeviceErrorCallback;
size_t messageStrlen;
size_t GetRequiredSize() const;
char* GetMessage();
const char* GetMessage() const;
};
}
}
#endif // WIRE_WIRECMD_H_