diff --git a/generator/templates/wire/WireClient.cpp b/generator/templates/wire/WireClient.cpp index 8277ab71b4..75b9d91f04 100644 --- a/generator/templates/wire/WireClient.cpp +++ b/generator/templates/wire/WireClient.cpp @@ -21,6 +21,7 @@ #include #include +#include namespace nxt { namespace wire { @@ -69,12 +70,51 @@ namespace wire { BuilderCallbackData builderCallback; }; - {% for type in by_category["object"] if not type.name.canonical_case() == "device" %} + {% set special_objects = [ + "device", + "buffer", + ] %} + {% for type in by_category["object"] if not type.name.canonical_case() in special_objects %} struct {{type.name.CamelCase()}} : ObjectBase { using ObjectBase::ObjectBase; }; {% endfor %} + struct Buffer : ObjectBase { + using ObjectBase::ObjectBase; + + ~Buffer() { + //* Callbacks need to be fired in all cases, as they can handle freeing resources + //* so we call them with "Unknown" status. + ClearMapRequests(NXT_BUFFER_MAP_READ_STATUS_UNKNOWN); + + if (mappedData) { + free(mappedData); + } + } + + void ClearMapRequests(nxtBufferMapReadStatus status) { + for (auto& it : readRequests) { + it.second.callback(status, nullptr, it.second.userdata); + } + readRequests.clear(); + } + + //* We want to defer all the validation to the server, which means we could have multiple + //* map request in flight at a single time and need to track them separately. + //* On well-behaved applications, only one request should exist at a single time. + struct MapReadRequestData { + nxtBufferMapReadCallback callback = nullptr; + nxtCallbackUserdata userdata = 0; + uint32_t size = 0; + }; + std::map readRequests; + uint32_t readRequestSerial = 0; + + //* Only one mapped pointer can be active at a time because Unmap clears all the in-flight requests. + void* mappedData = nullptr; + }; + //* TODO(cwallez@chromium.org): Do something with objects before they are destroyed ? //* - Call still uncalled builder callbacks template @@ -169,6 +209,12 @@ namespace wire { ObjectAllocator<{{type.name.CamelCase()}}> {{type.name.camelCase()}}; {% endfor %} + void HandleError(const char* message) { + if (errorCallback) { + errorCallback(message, errorUserdata); + } + } + nxtDeviceErrorCallback errorCallback = nullptr; nxtCallbackUserdata errorUserdata; @@ -290,6 +336,41 @@ namespace wire { {% endfor %} void ClientBufferMapReadAsync(Buffer* buffer, uint32_t start, uint32_t size, nxtBufferMapReadCallback callback, nxtCallbackUserdata userdata) { + uint32_t serial = buffer->readRequestSerial++; + assert(buffer->readRequests.find(serial) == buffer->readRequests.end()); + + Buffer::MapReadRequestData request; + request.callback = callback; + request.userdata = userdata; + request.size = size; + buffer->readRequests[serial] = request; + + wire::BufferMapReadAsyncCmd cmd; + cmd.bufferId = buffer->id; + cmd.requestSerial = serial; + cmd.start = start; + cmd.size = size; + + size_t requiredSize = cmd.GetRequiredSize(); + auto allocCmd = reinterpret_cast(buffer->device->GetCmdSpace(requiredSize)); + *allocCmd = cmd; + } + + void ProxyClientBufferUnmap(Buffer* buffer) { + //* Invalidate the local pointer, and cancel all other in-flight requests that would turn into + //* errors anyway (you can't double map). This prevents race when the following happens, where + //* the application code would have unmapped a buffer but still receive a callback: + //* - Client -> Server: MapRequest1, Unmap, MapRequest2 + //* - Server -> Client: Result of MapRequest1 + //* - Unmap locally on the client + //* - Server -> Client: Result of MapRequest2 + if (buffer->mappedData) { + free(buffer->mappedData); + buffer->mappedData = nullptr; + } + buffer->ClearMapRequests(NXT_BUFFER_MAP_READ_STATUS_UNKNOWN); + + ClientBufferUnmap(buffer); } void ClientDeviceReference(Device* self) { @@ -303,11 +384,23 @@ namespace wire { self->errorUserdata = userdata; } + // Some commands don't have a custom wire format, but need to be handled manually to update + // some client-side state tracking. For these we have to functions: + // - An autogenerated Client{{suffix}} method that sends the command on the wire + // - A manual ProxyClient{{suffix}} method that will be inserted in the proctable instead of + // the autogenerated one, and that will have to call Client{{suffix}} + {% set proxied_commands = ["BufferUnmap"] %} + nxtProcTable GetProcs() { nxtProcTable table; {% for type in by_category["object"] %} {% for method in native_methods(type) %} - table.{{as_varName(type.name, method.name)}} = reinterpret_cast<{{as_cProc(type.name, method.name)}}>(Client{{as_MethodSuffix(type.name, method.name)}}); + {% set suffix = as_MethodSuffix(type.name, method.name) %} + {% if suffix in proxied_commands %} + table.{{as_varName(type.name, method.name)}} = reinterpret_cast<{{as_cProc(type.name, method.name)}}>(ProxyClient{{suffix}}); + {% else %} + table.{{as_varName(type.name, method.name)}} = reinterpret_cast<{{as_cProc(type.name, method.name)}}>(Client{{suffix}}); + {% endif %} {% endfor %} {% endfor %} return table; @@ -332,6 +425,9 @@ namespace wire { success = Handle{{type.name.CamelCase()}}ErrorCallbackCmd(&commands, &size); break; {% endfor %} + case ReturnWireCmd::BufferMapReadAsyncCallback: + success = HandleBufferMapReadAsyncCallback(&commands, &size); + break; default: success = false; } @@ -383,9 +479,7 @@ namespace wire { return false; } - if (device->errorCallback != nullptr) { - device->errorCallback(cmd->GetMessage(), device->errorUserdata); - } + device->HandleError(cmd->GetMessage()); return true; } @@ -414,6 +508,49 @@ namespace wire { return true; } {% endfor %} + + bool HandleBufferMapReadAsyncCallback(const uint8_t** commands, size_t* size) { + const auto* cmd = GetCommand(commands, size); + if (cmd == nullptr) { + return false; + } + + auto* buffer = device->buffer.GetObject(cmd->bufferId); + uint32_t bufferSerial = device->buffer.GetSerial(cmd->bufferId); + + //* The buffer might have been deleted or recreated so this isn't an error. + if (buffer == nullptr || bufferSerial != cmd->bufferSerial) { + return true; + } + + //* The requests can have been deleted via an Unmap so this isn't an error. + auto requestIt = buffer->readRequests.find(cmd->requestSerial); + if (requestIt == buffer->readRequests.end()) { + return true; + } + + auto request = requestIt->second; + + //* On success, we copy the data locally because the IPC buffer isn't valid outside of this function + if (cmd->status == NXT_BUFFER_MAP_READ_STATUS_SUCCESS) { + + //* The server didn't send the right amount of data, this is an error and could cause + //* the application to crash if we did call the callback. + if (request.size != cmd->dataLength) { + return false; + } + + if (buffer->mappedData != nullptr) { + return false; + } + buffer->mappedData = malloc(request.size); + memcpy(buffer->mappedData, cmd->GetData(), request.size); + } + + request.callback(static_cast(cmd->status), buffer->mappedData, request.userdata); + buffer->readRequests.erase(requestIt); + return true; + } }; } diff --git a/generator/templates/wire/WireCmd.h b/generator/templates/wire/WireCmd.h index d6be3ace7a..b83a049fa2 100644 --- a/generator/templates/wire/WireCmd.h +++ b/generator/templates/wire/WireCmd.h @@ -28,6 +28,7 @@ namespace wire { {% endfor %} {{as_MethodSuffix(type.name, Name("destroy"))}}, {% endfor %} + BufferMapReadAsync, }; {% for type in by_category["object"] %} @@ -98,6 +99,7 @@ namespace wire { {% for type in by_category["object"] if type.is_builder %} {{type.name.CamelCase()}}ErrorCallback, {% endfor %} + BufferMapReadAsyncCallback, }; {% for type in by_category["object"] if type.is_builder %} diff --git a/generator/templates/wire/WireServer.cpp b/generator/templates/wire/WireServer.cpp index 556daa573b..db2e3eb45d 100644 --- a/generator/templates/wire/WireServer.cpp +++ b/generator/templates/wire/WireServer.cpp @@ -23,6 +23,16 @@ namespace nxt { namespace wire { namespace server { + class Server; + + struct MapReadUserdata { + Server* server; + uint32_t bufferId; + uint32_t bufferSerial; + uint32_t requestSerial; + uint32_t size; + }; + //* Stores what the backend knows about the type. template struct ObjectDataBase { @@ -117,6 +127,8 @@ namespace wire { void Forward{{type.name.CamelCase()}}ToClient(nxtBuilderErrorStatus status, const char* message, nxtCallbackUserdata userdata1, nxtCallbackUserdata userdata2); {% endfor %} + void ForwardBufferMapReadAsync(nxtBufferMapReadStatus status, const void* ptr, nxtCallbackUserdata userdata); + class Server : public CommandHandler { public: Server(nxtDevice device, const nxtProcTable& procs, CommandSerializer* serializer) @@ -170,6 +182,28 @@ namespace wire { } {% endfor %} + void OnMapReadAsyncCallback(nxtBufferMapReadStatus status, const void* ptr, MapReadUserdata* data) { + ReturnBufferMapReadAsyncCallbackCmd cmd; + cmd.bufferId = data->bufferId; + cmd.bufferSerial = data->bufferSerial; + cmd.requestSerial = data->requestSerial; + cmd.status = status; + + cmd.dataLength = 0; + if (status == NXT_BUFFER_MAP_READ_STATUS_SUCCESS) { + cmd.dataLength = data->size; + } + + auto allocCmd = reinterpret_cast(GetCmdSpace(cmd.GetRequiredSize())); + *allocCmd = cmd; + + if (status == NXT_BUFFER_MAP_READ_STATUS_SUCCESS) { + memcpy(allocCmd->GetData(), ptr, data->size); + } + + delete data; + } + const uint8_t* HandleCommands(const uint8_t* commands, size_t size) override { while (size > sizeof(WireCmd)) { WireCmd cmdId = *reinterpret_cast(commands); @@ -188,6 +222,9 @@ namespace wire { success = Handle{{Suffix}}(&commands, &size); break; {% endfor %} + case WireCmd::BufferMapReadAsync: + success = HandleBufferMapReadAsync(&commands, &size); + break; default: success = false; @@ -405,6 +442,39 @@ namespace wire { return true; } {% endfor %} + + bool HandleBufferMapReadAsync(const uint8_t** commands, size_t* size) { + //* These requests are just forwarded to the buffer, with userdata containing what the client + //* will require in the return command. + const auto* cmd = GetCommand(commands, size); + if (cmd == nullptr) { + return false; + } + + auto* buffer = knownBuffer.Get(cmd->bufferId); + if (buffer == nullptr) { + return false; + } + + auto* data = new MapReadUserdata; + data->server = this; + data->bufferId = cmd->bufferId; + data->bufferSerial = buffer->serial; + data->requestSerial = cmd->requestSerial; + data->size = cmd->size; + + auto userdata = static_cast(reinterpret_cast(data)); + + if (!buffer->valid) { + //* Fake the buffer returning a failure, data will be freed in this call. + ForwardBufferMapReadAsync(NXT_BUFFER_MAP_READ_STATUS_ERROR, nullptr, userdata); + return true; + } + + procs.bufferMapReadAsync(buffer->handle, cmd->start, cmd->size, ForwardBufferMapReadAsync, userdata); + + return true; + } }; void ForwardDeviceErrorToServer(const char* message, nxtCallbackUserdata userdata) { @@ -414,12 +484,17 @@ namespace wire { {% for type in by_category["object"] if type.is_builder%} void Forward{{type.name.CamelCase()}}ToClient(nxtBuilderErrorStatus status, const char* message, nxtCallbackUserdata userdata1, nxtCallbackUserdata userdata2) { - auto server = reinterpret_cast(static_cast(userdata1)); + auto server = reinterpret_cast(static_cast(userdata1)); uint32_t id = userdata2 & 0xFFFFFFFFu; uint32_t serial = userdata2 >> uint64_t(32); server->On{{type.name.CamelCase()}}Error(status, message, id, serial); } {% endfor %} + + void ForwardBufferMapReadAsync(nxtBufferMapReadStatus status, const void* ptr, nxtCallbackUserdata userdata) { + auto data = reinterpret_cast(static_cast(userdata)); + data->server->OnMapReadAsyncCallback(status, ptr, data); + } } CommandHandler* NewServerCommandHandler(nxtDevice device, const nxtProcTable& procs, CommandSerializer* serializer) { diff --git a/src/wire/WireCmd.cpp b/src/wire/WireCmd.cpp index a023a210a5..6033bb2598 100644 --- a/src/wire/WireCmd.cpp +++ b/src/wire/WireCmd.cpp @@ -29,5 +29,21 @@ namespace wire { return reinterpret_cast(this + 1); } + size_t BufferMapReadAsyncCmd::GetRequiredSize() const { + return sizeof(*this); + } + + size_t ReturnBufferMapReadAsyncCallbackCmd::GetRequiredSize() const { + return sizeof(*this) + dataLength; + } + + void* ReturnBufferMapReadAsyncCallbackCmd::GetData() { + return this + 1; + } + + const void* ReturnBufferMapReadAsyncCallbackCmd::GetData() const { + return this + 1; + } + } } diff --git a/src/wire/WireCmd.h b/src/wire/WireCmd.h index c261c309d7..d7cb8c4a68 100644 --- a/src/wire/WireCmd.h +++ b/src/wire/WireCmd.h @@ -30,6 +30,31 @@ namespace wire { const char* GetMessage() const; }; + struct BufferMapReadAsyncCmd { + wire::WireCmd commandId = WireCmd::BufferMapReadAsync; + + uint32_t bufferId; + uint32_t requestSerial; + uint32_t start; + uint32_t size; + + size_t GetRequiredSize() const; + }; + + struct ReturnBufferMapReadAsyncCallbackCmd { + wire::ReturnWireCmd commandId = ReturnWireCmd::BufferMapReadAsyncCallback; + + uint32_t bufferId; + uint32_t bufferSerial; + uint32_t requestSerial; + uint32_t status; + uint32_t dataLength; + + size_t GetRequiredSize() const; + void* GetData(); + const void* GetData() const; + }; + } }