Implement the device error callback.

This adds support for "natively defined" API types like callbacks that
will have to be implemented manually for each target language. Also this
splits the concept of "native method" into a set of native methods per
language.

Removes the "Synchronous error" concept that was used to make builders
work in the maybe Monad, this will have to be reinroduced with builder
callbacks.
This commit is contained in:
Corentin Wallez 2017-04-20 14:42:36 -04:00 committed by Corentin Wallez
parent 682a8250b3
commit 4b410a33ca
15 changed files with 219 additions and 55 deletions

View File

@ -28,8 +28,6 @@
BackendBinding* CreateMetalBinding(); BackendBinding* CreateMetalBinding();
namespace backend { namespace backend {
void RegisterSynchronousErrorCallback(nxtDevice device, void(*)(const char*, void*), void* userData);
namespace opengl { namespace opengl {
void Init(void* (*getProc)(const char*), nxtProcTable* procs, nxtDevice* device); void Init(void* (*getProc)(const char*), nxtProcTable* procs, nxtDevice* device);
void HACKCLEAR(); void HACKCLEAR();
@ -61,6 +59,10 @@ class OpenGLBinding : public BackendBinding {
} }
}; };
void PrintDeviceError(const char* message, nxt::CallbackUserdata) {
std::cout << "Device error: " << message << std::endl;
}
enum class BackendType { enum class BackendType {
OpenGL, OpenGL,
Metal, Metal,
@ -83,15 +85,6 @@ static nxt::wire::CommandHandler* wireClient = nullptr;
static nxt::wire::TerribleCommandBuffer* c2sBuf = nullptr; static nxt::wire::TerribleCommandBuffer* c2sBuf = nullptr;
static nxt::wire::TerribleCommandBuffer* s2cBuf = nullptr; static nxt::wire::TerribleCommandBuffer* s2cBuf = nullptr;
void HandleSynchronousError(const char* errorMessage, void* userData) {
std::cerr << errorMessage << std::endl;
if (userData != nullptr) {
auto wireServer = reinterpret_cast<nxt::wire::CommandHandler*>(userData);
wireServer->OnSynchronousError();
}
}
void GetProcTableAndDevice(nxtProcTable* procs, nxt::Device* device) { void GetProcTableAndDevice(nxtProcTable* procs, nxt::Device* device) {
switch (backendType) { switch (backendType) {
case BackendType::OpenGL: case BackendType::OpenGL:
@ -147,8 +140,7 @@ void GetProcTableAndDevice(nxtProcTable* procs, nxt::Device* device) {
break; break;
} }
//TODO(cwallez@chromium.org) this will disappear procs->deviceSetErrorCallback(device->Get(), PrintDeviceError, 0);
backend::RegisterSynchronousErrorCallback(backendDevice, HandleSynchronousError, wireServer);
} }
nxt::ShaderModule CreateShaderModule(const nxt::Device& device, nxt::ShaderStage stage, const char* source) { nxt::ShaderModule CreateShaderModule(const nxt::Device& device, nxt::ShaderStage stage, const char* source) {

View File

@ -74,6 +74,10 @@ class NativeType(Type):
def __init__(self, name, record): def __init__(self, name, record):
Type.__init__(self, name, record, native=True) Type.__init__(self, name, record, native=True)
class NativelyDefined(Type):
def __init__(self, name, record):
Type.__init__(self, name, record)
class MethodArgument: class MethodArgument:
def __init__(self, name, typ, annotation): def __init__(self, name, typ, annotation):
self.name = name self.name = name
@ -86,11 +90,17 @@ class ObjectType(Type):
def __init__(self, name, record): def __init__(self, name, record):
Type.__init__(self, name, record) Type.__init__(self, name, record)
self.methods = [] self.methods = []
self.native_methods = []
############################################################ ############################################################
# PARSE # PARSE
############################################################ ############################################################
import json import json
def is_native_method(method):
return method.return_type.category == "natively defined" or \
any([arg.type.category == "natively defined" for arg in method.arguments])
def link_object(obj, types): def link_object(obj, types):
def make_method(record): def make_method(record):
arguments = [] arguments = []
@ -110,13 +120,16 @@ def link_object(obj, types):
return Method(Name(record['name']), types[record.get('returns', 'void')], arguments) return Method(Name(record['name']), types[record.get('returns', 'void')], arguments)
obj.methods = [make_method(m) for m in obj.record.get('methods', [])] methods = [make_method(m) for m in obj.record.get('methods', [])]
obj.methods = [method for method in methods if not is_native_method(method)]
obj.native_methods = [method for method in methods if is_native_method(method)]
def parse_json(json): def parse_json(json):
category_to_parser = { category_to_parser = {
'bitmask': BitmaskType, 'bitmask': BitmaskType,
'enum': EnumType, 'enum': EnumType,
'native': NativeType, 'native': NativeType,
'natively defined': NativelyDefined,
'object': ObjectType, 'object': ObjectType,
} }
@ -296,11 +309,14 @@ def as_backendType(typ):
else: else:
return as_cType(typ.name) return as_cType(typ.name)
def native_methods(types, typ): def c_native_methods(types, typ):
return [ return cpp_native_methods(typ) + [
Method(Name('reference'), types['void'], []), Method(Name('reference'), types['void'], []),
Method(Name('release'), types['void'], []), Method(Name('release'), types['void'], []),
] + typ.methods ]
def cpp_native_methods(typ):
return typ.methods + typ.native_methods
def debug(text): def debug(text):
print(text) print(text)
@ -349,26 +365,29 @@ def main():
'as_cppType': as_cppType, 'as_cppType': as_cppType,
'as_varName': as_varName, 'as_varName': as_varName,
'decorate': decorate, 'decorate': decorate,
'native_methods': lambda typ: native_methods(api_params['types'], typ)
} }
renders = [] renders = []
c_params = {'native_methods': lambda typ: c_native_methods(api_params['types'], typ)}
if 'nxt' in targets: if 'nxt' in targets:
renders.append(FileRender('api.h', 'nxt/nxt.h', [base_params, api_params])) renders.append(FileRender('api.h', 'nxt/nxt.h', [base_params, api_params, c_params]))
renders.append(FileRender('api.c', 'nxt/nxt.c', [base_params, api_params])) renders.append(FileRender('api.c', 'nxt/nxt.c', [base_params, api_params, c_params]))
if 'nxtcpp' in targets: if 'nxtcpp' in targets:
renders.append(FileRender('apicpp.h', 'nxt/nxtcpp.h', [base_params, api_params])) additional_params = {'native_methods': cpp_native_methods}
renders.append(FileRender('apicpp.cpp', 'nxt/nxtcpp.cpp', [base_params, api_params])) renders.append(FileRender('apicpp.h', 'nxt/nxtcpp.h', [base_params, api_params, additional_params]))
renders.append(FileRender('apicpp.cpp', 'nxt/nxtcpp.cpp', [base_params, api_params, additional_params]))
if 'mock_nxt' in targets: if 'mock_nxt' in targets:
renders.append(FileRender('mock_api.h', 'mock/mock_nxt.h', [base_params, api_params])) renders.append(FileRender('mock_api.h', 'mock/mock_nxt.h', [base_params, api_params, c_params]))
renders.append(FileRender('mock_api.cpp', 'mock/mock_nxt.cpp', [base_params, api_params])) renders.append(FileRender('mock_api.cpp', 'mock/mock_nxt.cpp', [base_params, api_params, c_params]))
base_backend_params = [ base_backend_params = [
base_params, base_params,
api_params, api_params,
c_params,
{ {
'as_backendType': lambda typ: as_backendType(typ), # TODO as_backendType and friends take a Type and not a Name :( 'as_backendType': lambda typ: as_backendType(typ), # TODO as_backendType and friends take a Type and not a Name :(
'as_annotated_backendType': lambda arg: annotated(as_backendType(arg.type), arg) 'as_annotated_backendType': lambda arg: annotated(as_backendType(arg.type), arg)

View File

@ -33,6 +33,10 @@
{% endfor %} {% endfor %}
// Custom types depending on the target language
typedef uint64_t nxtCallbackUserdata;
typedef void (*nxtDeviceErrorCallback)(const char* message, nxtCallbackUserdata userdata);
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif

View File

@ -64,7 +64,7 @@ namespace nxt {
{{as_varName(arg.name)}}.Get() {{as_varName(arg.name)}}.Get()
{%- elif arg.type.category == "enum" or arg.type.category == "bitmask" -%} {%- elif arg.type.category == "enum" or arg.type.category == "bitmask" -%}
static_cast<{{as_cType(arg.type.name)}}>({{as_varName(arg.name)}}) static_cast<{{as_cType(arg.type.name)}}>({{as_varName(arg.name)}})
{%- elif arg.type.category == "native" -%} {%- elif arg.type.category in ["native", "natively defined"] -%}
{{as_varName(arg.name)}} {{as_varName(arg.name)}}
{%- else -%} {%- else -%}
UNHANDLED UNHANDLED
@ -76,7 +76,7 @@ namespace nxt {
) )
{%- endmacro %} {%- endmacro %}
{% for method in type.methods %} {% for method in native_methods(type) %}
{{render_cpp_method_declaration(type, method)}} { {{render_cpp_method_declaration(type, method)}} {
{% if method.return_type.name.concatcase() == "void" %} {% if method.return_type.name.concatcase() == "void" %}
{{render_cpp_to_c_method_call(type, method)}}; {{render_cpp_to_c_method_call(type, method)}};

View File

@ -47,6 +47,10 @@ namespace nxt {
{% endfor %} {% endfor %}
{% for type in by_category["natively defined"] %}
using {{as_cppType(type.name)}} = {{as_cType(type.name)}};
{% endfor %}
{% for type in by_category["object"] %} {% for type in by_category["object"] %}
class {{as_cppType(type.name)}}; class {{as_cppType(type.name)}};
{% endfor %} {% endfor %}
@ -132,7 +136,7 @@ namespace nxt {
using ObjectBase::ObjectBase; using ObjectBase::ObjectBase;
using ObjectBase::operator=; using ObjectBase::operator=;
{% for method in type.methods %} {% for method in native_methods(type) %}
{{render_cpp_method_declaration(type, method)}}; {{render_cpp_method_declaration(type, method)}};
{% endfor %} {% endfor %}

View File

@ -100,6 +100,9 @@ namespace wire {
ObjectAllocator<{{type.name.CamelCase()}}> {{type.name.camelCase()}}; ObjectAllocator<{{type.name.CamelCase()}}> {{type.name.camelCase()}};
{% endfor %} {% endfor %}
nxtDeviceErrorCallback errorCallback = nullptr;
nxtCallbackUserdata errorUserdata;
private: private:
CommandSerializer* serializer = nullptr; CommandSerializer* serializer = nullptr;
}; };
@ -200,6 +203,11 @@ namespace wire {
void ClientDeviceRelease(Device* self) { void ClientDeviceRelease(Device* self) {
} }
void ClientDeviceSetErrorCallback(Device* self, nxtDeviceErrorCallback callback, nxtCallbackUserdata userdata) {
self->errorCallback = callback;
self->errorUserdata = userdata;
}
nxtProcTable GetProcs() { nxtProcTable GetProcs() {
nxtProcTable table; nxtProcTable table;
{% for type in by_category["object"] %} {% for type in by_category["object"] %}
@ -216,16 +224,72 @@ namespace wire {
} }
const uint8_t* HandleCommands(const uint8_t* commands, size_t size) override { const uint8_t* HandleCommands(const uint8_t* commands, size_t size) override {
// TODO(cwallez@chromium.org): process callbacks while (size > sizeof(ReturnWireCmd)) {
return nullptr; ReturnWireCmd cmdId = *reinterpret_cast<const ReturnWireCmd*>(commands);
}
void OnSynchronousError() override { bool success = false;
// TODO(cwallez@chromium.org): this will disappear switch (cmdId) {
case ReturnWireCmd::DeviceErrorCallback:
success = HandleDeviceErrorCallbackCmd(&commands, &size);
break;
default:
success = false;
}
if (!success) {
return nullptr;
}
}
if (size != 0) {
return nullptr;
}
return commands;
} }
private: private:
Device* device = nullptr; Device* device = nullptr;
//* Helper function for the getting of the command data in command handlers.
//* Checks there is enough data left, updates the buffer / size and returns
//* the command (or nullptr for an error).
template<typename T>
static const T* GetCommand(const uint8_t** commands, size_t* size) {
if (*size < sizeof(T)) {
return nullptr;
}
const T* cmd = reinterpret_cast<const T*>(*commands);
size_t cmdSize = cmd->GetRequiredSize();
if (*size < cmdSize) {
return nullptr;
}
*commands += cmdSize;
*size -= cmdSize;
return cmd;
}
bool HandleDeviceErrorCallbackCmd(const uint8_t** commands, size_t* size) {
const auto* cmd = GetCommand<ReturnDeviceErrorCallbackCmd>(commands, size);
if (cmd == nullptr) {
return false;
}
if (cmd->GetMessage()[cmd->messageStrlen] != '\0') {
return false;
}
if (device->errorCallback != nullptr) {
device->errorCallback(cmd->GetMessage(), device->errorUserdata);
}
return true;
}
}; };
} }

View File

@ -83,13 +83,18 @@ namespace wire {
//* The command structure used when sending that an ID is destroyed. //* The command structure used when sending that an ID is destroyed.
{% set Suffix = as_MethodSuffix(type.name, Name("destroy")) %} {% set Suffix = as_MethodSuffix(type.name, Name("destroy")) %}
struct {{Suffix}}Cmd { struct {{Suffix}}Cmd {
wire::WireCmd commandId = wire::WireCmd::{{Suffix}}; WireCmd commandId = WireCmd::{{Suffix}};
uint32_t objectId; uint32_t objectId;
size_t GetRequiredSize() const; size_t GetRequiredSize() const;
}; };
{% endfor %} {% endfor %}
//* Enum used as a prefix to each command on the return wire format.
enum class ReturnWireCmd : uint32_t {
DeviceErrorCallback,
};
} }
} }

View File

@ -28,7 +28,7 @@ namespace wire {
//* The backend-provided handle to this object. //* The backend-provided handle to this object.
T handle; T handle;
//* Used by the error-propagation mechanism to know if this object is an error. //* Used by the error-propagation mechanism to know if this object is an error.
//* TODO(cwallez@chromium.org): this is doubling the memory usae of //* TODO(cwallez@chromium.org): this is doubling the memory usage of
//* std::vector<ObjectDataBase> consider making it a special marker value in handle instead. //* std::vector<ObjectDataBase> consider making it a special marker value in handle instead.
bool valid; bool valid;
//* Whether this object has been allocated, used by the KnownObjects queries //* Whether this object has been allocated, used by the KnownObjects queries
@ -103,13 +103,28 @@ namespace wire {
std::vector<Data> known; std::vector<Data> known;
}; };
void ForwardDeviceErrorToServer(const char* message, nxtCallbackUserdata userdata);
class Server : public CommandHandler { class Server : public CommandHandler {
public: public:
Server(nxtDevice device, const nxtProcTable& procs) : procs(procs) { Server(nxtDevice device, const nxtProcTable& procs, CommandSerializer* serializer)
: procs(procs), serializer(serializer) {
//* The client-server knowledge is bootstrapped with device 1. //* The client-server knowledge is bootstrapped with device 1.
auto* deviceData = knownDevice.Allocate(1); auto* deviceData = knownDevice.Allocate(1);
deviceData->handle = device; deviceData->handle = device;
deviceData->valid = true; deviceData->valid = true;
auto userdata = static_cast<nxtCallbackUserdata>(reinterpret_cast<intptr_t>(this));
procs.deviceSetErrorCallback(device, ForwardDeviceErrorToServer, userdata);
}
void OnDeviceError(const char* message) {
ReturnDeviceErrorCallbackCmd cmd;
cmd.messageStrlen = std::strlen(message);
auto allocCmd = reinterpret_cast<ReturnDeviceErrorCallbackCmd*>(GetCmdSpace(cmd.GetRequiredSize()));
*allocCmd = cmd;
strcpy(allocCmd->GetMessage(), message);
} }
const uint8_t* HandleCommands(const uint8_t* commands, size_t size) override { const uint8_t* HandleCommands(const uint8_t* commands, size_t size) override {
@ -147,14 +162,15 @@ namespace wire {
return commands; return commands;
} }
void OnSynchronousError() override {
gotError = true;
}
private: private:
nxtProcTable procs; nxtProcTable procs;
CommandSerializer* serializer = nullptr;
bool gotError = false; bool gotError = false;
void* GetCmdSpace(size_t size) {
return serializer->GetCmdSpace(size);
}
//* The list of known IDs for each object type. //* The list of known IDs for each object type.
{% for type in by_category["object"] %} {% for type in by_category["object"] %}
KnownObjects<{{as_cType(type.name)}}> known{{type.name.CamelCase()}}; KnownObjects<{{as_cType(type.name)}}> known{{type.name.CamelCase()}};
@ -164,7 +180,7 @@ namespace wire {
//* Checks there is enough data left, updates the buffer / size and returns //* Checks there is enough data left, updates the buffer / size and returns
//* the command (or nullptr for an error). //* the command (or nullptr for an error).
template<typename T> template<typename T>
const T* GetCommand(const uint8_t** commands, size_t* size) { static const T* GetCommand(const uint8_t** commands, size_t* size) {
if (*size < sizeof(T)) { if (*size < sizeof(T)) {
return nullptr; return nullptr;
} }
@ -333,11 +349,14 @@ namespace wire {
{% endfor %} {% endfor %}
}; };
void ForwardDeviceErrorToServer(const char* message, nxtCallbackUserdata userdata) {
auto server = reinterpret_cast<Server*>(static_cast<intptr_t>(userdata));
server->OnDeviceError(message);
}
} }
CommandHandler* NewServerCommandHandler(nxtDevice device, const nxtProcTable& procs, CommandSerializer* serializer) { CommandHandler* NewServerCommandHandler(nxtDevice device, const nxtProcTable& procs, CommandSerializer* serializer) {
//TODO(cwallez@chromium.org) do something with the serializer return new server::Server(device, procs, serializer);
return new server::Server(device, procs);
} }
} }

View File

@ -190,6 +190,9 @@
} }
] ]
}, },
"callback userdata": {
"category": "natively defined"
},
"char": { "char": {
"category": "native" "category": "native"
}, },
@ -368,9 +371,19 @@
{"name": "source", "type": "bind group"}, {"name": "source", "type": "bind group"},
{"name": "target", "type": "bind group"} {"name": "target", "type": "bind group"}
] ]
},
{
"name": "set error callback",
"args": [
{"name": "callback", "type": "device error callback"},
{"name": "userdata", "type": "callback userdata"}
]
} }
] ]
}, },
"device error callback": {
"category": "natively defined"
},
"filter mode": { "filter mode": {
"category": "enum", "category": "enum",
"values": [ "values": [

View File

@ -30,11 +30,6 @@
namespace backend { namespace backend {
void RegisterSynchronousErrorCallback(nxtDevice device, ErrorCallback callback, void* userData) {
auto deviceBase = reinterpret_cast<DeviceBase*>(device);
deviceBase->RegisterErrorCallback(callback, userData);
}
// DeviceBase::Caches // DeviceBase::Caches
// The caches are unordered_sets of pointers with special hash and compare functions // The caches are unordered_sets of pointers with special hash and compare functions
@ -57,13 +52,13 @@ namespace backend {
void DeviceBase::HandleError(const char* message) { void DeviceBase::HandleError(const char* message) {
if (errorCallback) { if (errorCallback) {
errorCallback(message, errorUserData); errorCallback(message, errorUserdata);
} }
} }
void DeviceBase::RegisterErrorCallback(ErrorCallback callback, void* userData) { void DeviceBase::SetErrorCallback(nxt::DeviceErrorCallback callback, nxt::CallbackUserdata userdata) {
this->errorCallback = callback; this->errorCallback = callback;
this->errorUserData = userData; this->errorUserdata = userdata;
} }
BindGroupLayoutBase* DeviceBase::GetOrCreateBindGroupLayout(const BindGroupLayoutBase* blueprint, BindGroupLayoutBuilder* builder) { BindGroupLayoutBase* DeviceBase::GetOrCreateBindGroupLayout(const BindGroupLayoutBase* blueprint, BindGroupLayoutBuilder* builder) {

View File

@ -30,7 +30,7 @@ namespace backend {
~DeviceBase(); ~DeviceBase();
void HandleError(const char* message); void HandleError(const char* message);
void RegisterErrorCallback(ErrorCallback callback, void* userData); void SetErrorCallback(nxt::DeviceErrorCallback, nxt::CallbackUserdata userdata);
virtual BindGroupBase* CreateBindGroup(BindGroupBuilder* builder) = 0; virtual BindGroupBase* CreateBindGroup(BindGroupBuilder* builder) = 0;
virtual BindGroupLayoutBase* CreateBindGroupLayout(BindGroupLayoutBuilder* builder) = 0; virtual BindGroupLayoutBase* CreateBindGroupLayout(BindGroupLayoutBuilder* builder) = 0;
@ -85,8 +85,8 @@ namespace backend {
struct Caches; struct Caches;
Caches* caches = nullptr; Caches* caches = nullptr;
ErrorCallback errorCallback = nullptr; nxt::DeviceErrorCallback errorCallback = nullptr;
void* errorUserData = nullptr; nxt::CallbackUserdata errorUserdata;
}; };
} }

View File

@ -30,6 +30,8 @@ SetPic(wire_autogen)
add_library(nxt_wire SHARED add_library(nxt_wire SHARED
${WIRE_DIR}/TerribleCommandBuffer.h ${WIRE_DIR}/TerribleCommandBuffer.h
${WIRE_DIR}/WireCmd.cpp
${WIRE_DIR}/WireCmd.h
) )
target_link_libraries(nxt_wire wire_autogen) target_link_libraries(nxt_wire wire_autogen)
SetCXX14(nxt_wire) SetCXX14(nxt_wire)

View File

@ -33,8 +33,6 @@ namespace wire {
public: public:
virtual ~CommandHandler() = default; virtual ~CommandHandler() = default;
virtual const uint8_t* HandleCommands(const uint8_t* commands, size_t size) = 0; virtual const uint8_t* HandleCommands(const uint8_t* commands, size_t size) = 0;
virtual void OnSynchronousError() = 0;
}; };
CommandHandler* NewClientDevice(nxtProcTable* procs, nxtDevice* device, CommandSerializer* serializer); CommandHandler* NewClientDevice(nxtProcTable* procs, nxtDevice* device, CommandSerializer* serializer);

33
src/wire/WireCmd.cpp Normal file
View File

@ -0,0 +1,33 @@
// Copyright 2017 The NXT Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "WireCmd.h"
namespace nxt {
namespace wire {
size_t ReturnDeviceErrorCallbackCmd::GetRequiredSize() const {
return sizeof(*this) + messageStrlen + 1;
}
char* ReturnDeviceErrorCallbackCmd::GetMessage() {
return reinterpret_cast<char*>(this + 1);
}
const char* ReturnDeviceErrorCallbackCmd::GetMessage() const {
return reinterpret_cast<const char*>(this + 1);
}
}
}

View File

@ -17,4 +17,20 @@
#include "wire/WireCmd_autogen.h" #include "wire/WireCmd_autogen.h"
namespace nxt {
namespace wire {
struct ReturnDeviceErrorCallbackCmd {
wire::ReturnWireCmd commandId = ReturnWireCmd::DeviceErrorCallback;
size_t messageStrlen;
size_t GetRequiredSize() const;
char* GetMessage();
const char* GetMessage() const;
};
}
}
#endif // WIRE_WIRECMD_H_ #endif // WIRE_WIRECMD_H_