Add a BufferConsumer primitive for wire [de]serialization

BufferConsumer wraps a buffer pointer and size and exposes a
limited number of operations to get data while decrementing
the remaining available size. This makes it so that code
reading or writing into a buffer cannot easily consume more
bytes than available.

This CL guards against serialization overflows using
BufferConsumer, and it implements GetPtrFromBuffer
(for deserialization) on top of BufferConsumer. A future patch
will make the rest of the deserialization code use BufferConsumer.

Bug: dawn:680
Change-Id: Ic2bd6e7039e83ce70307c2ff47aaca9891c16d91
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/41780
Commit-Queue: Austin Eng <enga@chromium.org>
Reviewed-by: Stephen White <senorblanco@chromium.org>
This commit is contained in:
Austin Eng 2021-02-17 22:14:56 +00:00 committed by Commit Bot service account
parent eb71aaf689
commit 1b31dc0bb2
9 changed files with 236 additions and 84 deletions

View File

@ -201,8 +201,8 @@ namespace {
//* Serializes `record` into `transfer`, using `buffer` to get more space for pointed-to data //* Serializes `record` into `transfer`, using `buffer` to get more space for pointed-to data
//* and `provider` to serialize objects. //* and `provider` to serialize objects.
DAWN_DECLARE_UNUSED void {{Return}}{{name}}Serialize(const {{Return}}{{name}}{{Cmd}}& record, {{Return}}{{name}}Transfer* transfer, DAWN_DECLARE_UNUSED bool {{Return}}{{name}}Serialize(const {{Return}}{{name}}{{Cmd}}& record, {{Return}}{{name}}Transfer* transfer,
char** buffer SerializeBuffer* buffer
{%- if record.may_have_dawn_object -%} {%- if record.may_have_dawn_object -%}
, const ObjectIdProvider& provider , const ObjectIdProvider& provider
{%- endif -%} {%- endif -%}
@ -223,7 +223,7 @@ namespace {
{% if record.extensible %} {% if record.extensible %}
if (record.nextInChain != nullptr) { if (record.nextInChain != nullptr) {
transfer->hasNextInChain = true; transfer->hasNextInChain = true;
SerializeChainedStruct(record.nextInChain, buffer, provider); SERIALIZE_TRY(SerializeChainedStruct(record.nextInChain, buffer, provider));
} else { } else {
transfer->hasNextInChain = false; transfer->hasNextInChain = false;
} }
@ -247,8 +247,9 @@ namespace {
{ {
transfer->{{memberName}}Strlen = std::strlen(record.{{memberName}}); transfer->{{memberName}}Strlen = std::strlen(record.{{memberName}});
memcpy(*buffer, record.{{memberName}}, transfer->{{memberName}}Strlen); char* stringInBuffer;
*buffer += transfer->{{memberName}}Strlen; SERIALIZE_TRY(buffer->NextN(transfer->{{memberName}}Strlen, &stringInBuffer));
memcpy(stringInBuffer, record.{{memberName}}, transfer->{{memberName}}Strlen);
} }
{% endfor %} {% endfor %}
@ -263,14 +264,16 @@ namespace {
{% endif %} {% endif %}
{ {
size_t memberLength = {{member_length(member, "record.")}}; size_t memberLength = {{member_length(member, "record.")}};
auto memberBuffer = reinterpret_cast<{{member_transfer_type(member)}}*>(*buffer);
*buffer += memberLength * {{member_transfer_sizeof(member)}}; {{member_transfer_type(member)}}* memberBuffer;
SERIALIZE_TRY(buffer->NextN(memberLength, &memberBuffer));
for (size_t i = 0; i < memberLength; ++i) { for (size_t i = 0; i < memberLength; ++i) {
{{serialize_member(member, "record." + memberName + "[i]", "memberBuffer[i]" )}} {{serialize_member(member, "record." + memberName + "[i]", "memberBuffer[i]" )}}
} }
} }
{% endfor %} {% endfor %}
return true;
} }
DAWN_UNUSED_FUNC({{Return}}{{name}}Serialize); DAWN_UNUSED_FUNC({{Return}}{{name}}Serialize);
@ -386,20 +389,21 @@ namespace {
return size; return size;
} }
void {{Cmd}}::Serialize(size_t commandSize, char* buffer bool {{Cmd}}::Serialize(size_t commandSize, SerializeBuffer* buffer
{%- if not is_return -%} {%- if not is_return -%}
, const ObjectIdProvider& objectIdProvider , const ObjectIdProvider& objectIdProvider
{%- endif -%} {%- endif -%}
) const { ) const {
auto transfer = reinterpret_cast<{{Name}}Transfer*>(buffer); {{Name}}Transfer* transfer;
SERIALIZE_TRY(buffer->Next(&transfer));
transfer->commandSize = commandSize; transfer->commandSize = commandSize;
buffer += sizeof({{Name}}Transfer);
{{Name}}Serialize(*this, transfer, &buffer SERIALIZE_TRY({{Name}}Serialize(*this, transfer, buffer
{%- if command.may_have_dawn_object -%} {%- if command.may_have_dawn_object -%}
, objectIdProvider , objectIdProvider
{%- endif -%} {%- endif -%}
); ));
return true;
} }
DeserializeResult {{Cmd}}::Deserialize(const volatile char** buffer, size_t* size, DeserializeAllocator* allocator DeserializeResult {{Cmd}}::Deserialize(const volatile char** buffer, size_t* size, DeserializeAllocator* allocator
@ -429,6 +433,13 @@ namespace dawn_wire {
} \ } \
} while (0) } while (0)
#define SERIALIZE_TRY(EXPR) \
do { \
if (!(EXPR)) { \
return false; \
} \
} while (0)
ObjectHandle::ObjectHandle() = default; ObjectHandle::ObjectHandle() = default;
ObjectHandle::ObjectHandle(ObjectId id, ObjectGeneration generation) ObjectHandle::ObjectHandle(ObjectId id, ObjectGeneration generation)
: id(id), generation(generation) { : id(id), generation(generation) {
@ -454,27 +465,53 @@ namespace dawn_wire {
return *this; return *this;
} }
template <typename BufferT>
template <typename T>
bool BufferConsumer<BufferT>::Next(T** data) {
if (sizeof(T) > mSize) {
return false;
}
*data = reinterpret_cast<T*>(mBuffer);
mBuffer += sizeof(T);
mSize -= sizeof(T);
return true;
}
template <typename BufferT>
template <typename T, typename N>
bool BufferConsumer<BufferT>::NextN(N count, T** data) {
static_assert(std::is_unsigned<N>::value, "|count| argument of NextN must be unsigned.");
constexpr size_t kMaxCountWithoutOverflows = std::numeric_limits<size_t>::max() / sizeof(T);
if (count > kMaxCountWithoutOverflows) {
return false;
}
// Cannot overflow because |count| is not greater than |kMaxCountWithoutOverflows|.
size_t totalSize = sizeof(T) * count;
if (totalSize > mSize) {
return false;
}
*data = reinterpret_cast<T*>(mBuffer);
mBuffer += totalSize;
mSize -= totalSize;
return true;
}
namespace { namespace {
// Consumes from (buffer, size) enough memory to contain T[count] and return it in data. // Consumes from (buffer, size) enough memory to contain T[count] and return it in data.
// Returns FatalError if not enough memory was available // Returns FatalError if not enough memory was available
template <typename T> template <typename T>
DeserializeResult GetPtrFromBuffer(const volatile char** buffer, size_t* size, size_t count, const volatile T** data) { DeserializeResult GetPtrFromBuffer(const volatile char** buffer, size_t* size, size_t count, const volatile T** data) {
constexpr size_t kMaxCountWithoutOverflows = std::numeric_limits<size_t>::max() / sizeof(T); DeserializeBuffer deserializeBuffer(*buffer, *size);
if (count > kMaxCountWithoutOverflows) { DeserializeResult result = deserializeBuffer.ReadN(count, data);
return DeserializeResult::FatalError; if (result == DeserializeResult::Success) {
*buffer = deserializeBuffer.Buffer();
*size = deserializeBuffer.AvailableSize();
} }
return result;
size_t totalSize = sizeof(T) * count;
if (totalSize > *size) {
return DeserializeResult::FatalError;
}
*data = reinterpret_cast<const volatile T*>(*buffer);
*buffer += totalSize;
*size -= totalSize;
return DeserializeResult::Success;
} }
// Allocates enough space from allocator to countain T[count] and return it in out. // Allocates enough space from allocator to countain T[count] and return it in out.
@ -496,8 +533,8 @@ namespace dawn_wire {
} }
size_t GetChainedStructExtraRequiredSize(const WGPUChainedStruct* chainedStruct); size_t GetChainedStructExtraRequiredSize(const WGPUChainedStruct* chainedStruct);
void SerializeChainedStruct(WGPUChainedStruct const* chainedStruct, DAWN_NO_DISCARD bool SerializeChainedStruct(WGPUChainedStruct const* chainedStruct,
char** buffer, SerializeBuffer* buffer,
const ObjectIdProvider& provider); const ObjectIdProvider& provider);
DeserializeResult DeserializeChainedStruct(const WGPUChainedStruct** outChainNext, DeserializeResult DeserializeChainedStruct(const WGPUChainedStruct** outChainNext,
const volatile char** buffer, const volatile char** buffer,
@ -538,8 +575,8 @@ namespace dawn_wire {
return result; return result;
} }
void SerializeChainedStruct(WGPUChainedStruct const* chainedStruct, DAWN_NO_DISCARD bool SerializeChainedStruct(WGPUChainedStruct const* chainedStruct,
char** buffer, SerializeBuffer* buffer,
const ObjectIdProvider& provider) { const ObjectIdProvider& provider) {
ASSERT(chainedStruct != nullptr); ASSERT(chainedStruct != nullptr);
ASSERT(buffer != nullptr); ASSERT(buffer != nullptr);
@ -549,16 +586,16 @@ namespace dawn_wire {
{% set CType = as_cType(sType.name) %} {% set CType = as_cType(sType.name) %}
case {{as_cEnum(types["s type"].name, sType.name)}}: { case {{as_cEnum(types["s type"].name, sType.name)}}: {
auto* transfer = reinterpret_cast<{{CType}}Transfer*>(*buffer); {{CType}}Transfer* transfer;
SERIALIZE_TRY(buffer->Next(&transfer));
transfer->chain.sType = chainedStruct->sType; transfer->chain.sType = chainedStruct->sType;
transfer->chain.hasNext = chainedStruct->next != nullptr; transfer->chain.hasNext = chainedStruct->next != nullptr;
*buffer += sizeof({{CType}}Transfer); SERIALIZE_TRY({{CType}}Serialize(*reinterpret_cast<{{CType}} const*>(chainedStruct), transfer, buffer
{{CType}}Serialize(*reinterpret_cast<{{CType}} const*>(chainedStruct), transfer, buffer
{%- if types[sType.name.get()].may_have_dawn_object -%} {%- if types[sType.name.get()].may_have_dawn_object -%}
, provider , provider
{%- endif -%} {%- endif -%}
); ));
chainedStruct = chainedStruct->next; chainedStruct = chainedStruct->next;
} break; } break;
@ -570,18 +607,18 @@ namespace dawn_wire {
dawn::WarningLog() << "Unknown sType " << chainedStruct->sType << " discarded."; dawn::WarningLog() << "Unknown sType " << chainedStruct->sType << " discarded.";
} }
WGPUChainedStructTransfer* transfer = reinterpret_cast<WGPUChainedStructTransfer*>(*buffer); WGPUChainedStructTransfer* transfer;
SERIALIZE_TRY(buffer->Next(&transfer));
transfer->sType = WGPUSType_Invalid; transfer->sType = WGPUSType_Invalid;
transfer->hasNext = chainedStruct->next != nullptr; transfer->hasNext = chainedStruct->next != nullptr;
*buffer += sizeof(WGPUChainedStructTransfer);
// Still move on in case there are valid structs after this. // Still move on in case there are valid structs after this.
chainedStruct = chainedStruct->next; chainedStruct = chainedStruct->next;
break; break;
} }
} }
} while (chainedStruct != nullptr); } while (chainedStruct != nullptr);
return true;
} }
DeserializeResult DeserializeChainedStruct(const WGPUChainedStruct** outChainNext, DeserializeResult DeserializeChainedStruct(const WGPUChainedStruct** outChainNext,
@ -677,13 +714,14 @@ namespace dawn_wire {
} }
void SerializeWGPUDeviceProperties(const WGPUDeviceProperties* deviceProperties, void SerializeWGPUDeviceProperties(const WGPUDeviceProperties* deviceProperties,
char* serializeBuffer) { char* buffer) {
size_t devicePropertiesSize = SerializedWGPUDevicePropertiesSize(deviceProperties); SerializeBuffer serializeBuffer(buffer, SerializedWGPUDevicePropertiesSize(deviceProperties));
WGPUDevicePropertiesTransfer* transfer =
reinterpret_cast<WGPUDevicePropertiesTransfer*>(serializeBuffer);
serializeBuffer += devicePropertiesSize;
WGPUDevicePropertiesTransfer* transfer;
bool success =
serializeBuffer.Next(&transfer) &&
WGPUDevicePropertiesSerialize(*deviceProperties, transfer, &serializeBuffer); WGPUDevicePropertiesSerialize(*deviceProperties, transfer, &serializeBuffer);
ASSERT(success);
} }
bool DeserializeWGPUDeviceProperties(WGPUDeviceProperties* deviceProperties, bool DeserializeWGPUDeviceProperties(WGPUDeviceProperties* deviceProperties,

View File

@ -48,6 +48,52 @@ namespace dawn_wire {
FatalError, FatalError,
}; };
template <typename BufferT>
class BufferConsumer {
public:
BufferConsumer(BufferT* buffer, size_t size) : mBuffer(buffer), mSize(size) {}
BufferT* Buffer() const { return mBuffer; }
size_t AvailableSize() const { return mSize; }
protected:
template <typename T, typename N>
DAWN_NO_DISCARD bool NextN(N count, T** data);
template <typename T>
DAWN_NO_DISCARD bool Next(T** data);
private:
BufferT* mBuffer;
size_t mSize;
};
class SerializeBuffer : public BufferConsumer<char> {
public:
using BufferConsumer::BufferConsumer;
using BufferConsumer::NextN;
using BufferConsumer::Next;
};
class DeserializeBuffer : public BufferConsumer<const volatile char> {
public:
using BufferConsumer::BufferConsumer;
template <typename T, typename N>
DAWN_NO_DISCARD DeserializeResult ReadN(N count, const volatile T** data) {
return NextN(count, data)
? DeserializeResult::Success
: DeserializeResult::FatalError;
}
template <typename T>
DAWN_NO_DISCARD DeserializeResult Read(const volatile T** data) {
return Next(data)
? DeserializeResult::Success
: DeserializeResult::FatalError;
}
};
// Interface to allocate more space to deserialize pointed-to data. // Interface to allocate more space to deserialize pointed-to data.
// nullptr is treated as an error. // nullptr is treated as an error.
class DeserializeAllocator { class DeserializeAllocator {
@ -101,7 +147,7 @@ namespace dawn_wire {
//* Serialize the structure and everything it points to into serializeBuffer which must be //* Serialize the structure and everything it points to into serializeBuffer which must be
//* big enough to contain all the data (as queried from GetRequiredSize). //* big enough to contain all the data (as queried from GetRequiredSize).
void Serialize(size_t commandSize, char* serializeBuffer DAWN_NO_DISCARD bool Serialize(size_t commandSize, SerializeBuffer* serializeBuffer
{%- if not is_return_command -%} {%- if not is_return_command -%}
, const ObjectIdProvider& objectIdProvider , const ObjectIdProvider& objectIdProvider
{%- endif -%} {%- endif -%}

View File

@ -63,6 +63,7 @@ dawn_component("dawn_wire") {
"ChunkedCommandHandler.h", "ChunkedCommandHandler.h",
"ChunkedCommandSerializer.cpp", "ChunkedCommandSerializer.cpp",
"ChunkedCommandSerializer.h", "ChunkedCommandSerializer.h",
"Wire.cpp",
"WireClient.cpp", "WireClient.cpp",
"WireDeserializeAllocator.cpp", "WireDeserializeAllocator.cpp",
"WireDeserializeAllocator.h", "WireDeserializeAllocator.h",

View File

@ -35,6 +35,7 @@ target_sources(dawn_wire PRIVATE
"ChunkedCommandHandler.h" "ChunkedCommandHandler.h"
"ChunkedCommandSerializer.cpp" "ChunkedCommandSerializer.cpp"
"ChunkedCommandSerializer.h" "ChunkedCommandSerializer.h"
"Wire.cpp"
"WireClient.cpp" "WireClient.cpp"
"WireDeserializeAllocator.cpp" "WireDeserializeAllocator.cpp"
"WireDeserializeAllocator.h" "WireDeserializeAllocator.h"

View File

@ -32,7 +32,7 @@ namespace dawn_wire {
template <typename Cmd> template <typename Cmd>
void SerializeCommand(const Cmd& cmd) { void SerializeCommand(const Cmd& cmd) {
SerializeCommand(cmd, 0, [](char*) {}); SerializeCommand(cmd, 0, [](SerializeBuffer*) { return true; });
} }
template <typename Cmd, typename ExtraSizeSerializeFn> template <typename Cmd, typename ExtraSizeSerializeFn>
@ -41,15 +41,15 @@ namespace dawn_wire {
ExtraSizeSerializeFn&& SerializeExtraSize) { ExtraSizeSerializeFn&& SerializeExtraSize) {
SerializeCommandImpl( SerializeCommandImpl(
cmd, cmd,
[](const Cmd& cmd, size_t requiredSize, char* allocatedBuffer) { [](const Cmd& cmd, size_t requiredSize, SerializeBuffer* serializeBuffer) {
cmd.Serialize(requiredSize, allocatedBuffer); return cmd.Serialize(requiredSize, serializeBuffer);
}, },
extraSize, std::forward<ExtraSizeSerializeFn>(SerializeExtraSize)); extraSize, std::forward<ExtraSizeSerializeFn>(SerializeExtraSize));
} }
template <typename Cmd> template <typename Cmd>
void SerializeCommand(const Cmd& cmd, const ObjectIdProvider& objectIdProvider) { void SerializeCommand(const Cmd& cmd, const ObjectIdProvider& objectIdProvider) {
SerializeCommand(cmd, objectIdProvider, 0, [](char*) {}); SerializeCommand(cmd, objectIdProvider, 0, [](SerializeBuffer*) { return true; });
} }
template <typename Cmd, typename ExtraSizeSerializeFn> template <typename Cmd, typename ExtraSizeSerializeFn>
@ -59,8 +59,9 @@ namespace dawn_wire {
ExtraSizeSerializeFn&& SerializeExtraSize) { ExtraSizeSerializeFn&& SerializeExtraSize) {
SerializeCommandImpl( SerializeCommandImpl(
cmd, cmd,
[&objectIdProvider](const Cmd& cmd, size_t requiredSize, char* allocatedBuffer) { [&objectIdProvider](const Cmd& cmd, size_t requiredSize,
cmd.Serialize(requiredSize, allocatedBuffer, objectIdProvider); SerializeBuffer* serializeBuffer) {
return cmd.Serialize(requiredSize, serializeBuffer, objectIdProvider);
}, },
extraSize, std::forward<ExtraSizeSerializeFn>(SerializeExtraSize)); extraSize, std::forward<ExtraSizeSerializeFn>(SerializeExtraSize));
} }
@ -77,8 +78,13 @@ namespace dawn_wire {
if (requiredSize <= mMaxAllocationSize) { if (requiredSize <= mMaxAllocationSize) {
char* allocatedBuffer = static_cast<char*>(mSerializer->GetCmdSpace(requiredSize)); char* allocatedBuffer = static_cast<char*>(mSerializer->GetCmdSpace(requiredSize));
if (allocatedBuffer != nullptr) { if (allocatedBuffer != nullptr) {
SerializeCmd(cmd, requiredSize, allocatedBuffer); SerializeBuffer serializeBuffer(allocatedBuffer, requiredSize);
SerializeExtraSize(allocatedBuffer + commandSize); bool success = true;
success &= SerializeCmd(cmd, requiredSize, &serializeBuffer);
success &= SerializeExtraSize(&serializeBuffer);
if (DAWN_UNLIKELY(!success)) {
mSerializer->OnSerializeError();
}
} }
return; return;
} }
@ -87,8 +93,14 @@ namespace dawn_wire {
if (!cmdSpace) { if (!cmdSpace) {
return; return;
} }
SerializeCmd(cmd, requiredSize, cmdSpace.get()); SerializeBuffer serializeBuffer(cmdSpace.get(), requiredSize);
SerializeExtraSize(cmdSpace.get() + commandSize); bool success = true;
success &= SerializeCmd(cmd, requiredSize, &serializeBuffer);
success &= SerializeExtraSize(&serializeBuffer);
if (DAWN_UNLIKELY(!success)) {
mSerializer->OnSerializeError();
return;
}
SerializeChunkedCommand(cmdSpace.get(), requiredSize); SerializeChunkedCommand(cmdSpace.get(), requiredSize);
} }

26
src/dawn_wire/Wire.cpp Normal file
View File

@ -0,0 +1,26 @@
// Copyright 2021 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dawn_wire/Wire.h"
namespace dawn_wire {
CommandSerializer::~CommandSerializer() = default;
void CommandSerializer::OnSerializeError() {
}
CommandHandler::~CommandHandler() = default;
} // namespace dawn_wire

View File

@ -14,6 +14,7 @@
#include "dawn_wire/client/Buffer.h" #include "dawn_wire/client/Buffer.h"
#include "dawn_wire/WireCmd_autogen.h"
#include "dawn_wire/client/Client.h" #include "dawn_wire/client/Client.h"
#include "dawn_wire/client/Device.h" #include "dawn_wire/client/Device.h"
@ -73,18 +74,23 @@ namespace dawn_wire { namespace client {
cmd.handleCreateInfoLength = writeHandleCreateInfoLength; cmd.handleCreateInfoLength = writeHandleCreateInfoLength;
cmd.handleCreateInfo = nullptr; cmd.handleCreateInfo = nullptr;
wireClient->SerializeCommand(cmd, writeHandleCreateInfoLength, [&](char* cmdSpace) { wireClient->SerializeCommand(
cmd, writeHandleCreateInfoLength, [&](SerializeBuffer* serializeBuffer) {
if (descriptor->mappedAtCreation) { if (descriptor->mappedAtCreation) {
if (serializeBuffer->AvailableSize() != writeHandleCreateInfoLength) {
return false;
}
// Serialize the WriteHandle into the space after the command. // Serialize the WriteHandle into the space after the command.
writeHandle->SerializeCreate(cmdSpace); writeHandle->SerializeCreate(serializeBuffer->Buffer());
// Set the buffer state for the mapping at creation. The buffer now owns the write // Set the buffer state for the mapping at creation. The buffer now owns the
// handle.. // write handle..
buffer->mWriteHandle = std::move(writeHandle); buffer->mWriteHandle = std::move(writeHandle);
buffer->mMappedData = writeData; buffer->mMappedData = writeData;
buffer->mMapOffset = 0; buffer->mMapOffset = 0;
buffer->mMapSize = buffer->mSize; buffer->mMapSize = buffer->mSize;
} }
return true;
}); });
return ToAPI(buffer); return ToAPI(buffer);
} }
@ -199,14 +205,24 @@ namespace dawn_wire { namespace client {
// Step 3a. Fill the handle create info in the command. // Step 3a. Fill the handle create info in the command.
if (isReadMode) { if (isReadMode) {
cmd.handleCreateInfoLength = request.readHandle->SerializeCreateSize(); cmd.handleCreateInfoLength = request.readHandle->SerializeCreateSize();
client->SerializeCommand(cmd, cmd.handleCreateInfoLength, [&](char* cmdSpace) { client->SerializeCommand(
request.readHandle->SerializeCreate(cmdSpace); cmd, cmd.handleCreateInfoLength, [&](SerializeBuffer* serializeBuffer) {
bool success = serializeBuffer->AvailableSize() == cmd.handleCreateInfoLength;
if (success) {
request.readHandle->SerializeCreate(serializeBuffer->Buffer());
}
return success;
}); });
} else { } else {
ASSERT(isWriteMode); ASSERT(isWriteMode);
cmd.handleCreateInfoLength = request.writeHandle->SerializeCreateSize(); cmd.handleCreateInfoLength = request.writeHandle->SerializeCreateSize();
client->SerializeCommand(cmd, cmd.handleCreateInfoLength, [&](char* cmdSpace) { client->SerializeCommand(
request.writeHandle->SerializeCreate(cmdSpace); cmd, cmd.handleCreateInfoLength, [&](SerializeBuffer* serializeBuffer) {
bool success = serializeBuffer->AvailableSize() == cmd.handleCreateInfoLength;
if (success) {
request.writeHandle->SerializeCreate(serializeBuffer->Buffer());
}
return success;
}); });
} }
@ -334,10 +350,15 @@ namespace dawn_wire { namespace client {
cmd.writeFlushInfoLength = writeFlushInfoLength; cmd.writeFlushInfoLength = writeFlushInfoLength;
cmd.writeFlushInfo = nullptr; cmd.writeFlushInfo = nullptr;
client->SerializeCommand(cmd, writeFlushInfoLength, [&](char* cmdSpace) { client->SerializeCommand(
cmd, writeFlushInfoLength, [&](SerializeBuffer* serializeBuffer) {
bool success = serializeBuffer->AvailableSize() == writeFlushInfoLength;
if (success) {
// Serialize flush metadata into the space after the command. // Serialize flush metadata into the space after the command.
// This closes the handle for writing. // This closes the handle for writing.
mWriteHandle->SerializeFlush(cmdSpace); mWriteHandle->SerializeFlush(serializeBuffer->Buffer());
}
return success;
}); });
mWriteHandle = nullptr; mWriteHandle = nullptr;

View File

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "common/Assert.h" #include "common/Assert.h"
#include "dawn_wire/WireCmd_autogen.h"
#include "dawn_wire/server/Server.h" #include "dawn_wire/server/Server.h"
#include <memory> #include <memory>
@ -242,11 +243,15 @@ namespace dawn_wire { namespace server {
data->readHandle->SerializeInitialDataSize(readData, data->size); data->readHandle->SerializeInitialDataSize(readData, data->size);
} }
SerializeCommand(cmd, cmd.readInitialDataInfoLength, [&](char* cmdSpace) { SerializeCommand(cmd, cmd.readInitialDataInfoLength, [&](SerializeBuffer* serializeBuffer) {
if (isSuccess) { if (isSuccess) {
if (isRead) { if (isRead) {
if (serializeBuffer->AvailableSize() != cmd.readInitialDataInfoLength) {
return false;
}
// Serialize the initialization message into the space after the command. // Serialize the initialization message into the space after the command.
data->readHandle->SerializeInitialData(readData, data->size, cmdSpace); data->readHandle->SerializeInitialData(readData, data->size,
serializeBuffer->Buffer());
// The in-flight map request returned successfully. // The in-flight map request returned successfully.
// Move the ReadHandle so it is owned by the buffer. // Move the ReadHandle so it is owned by the buffer.
bufferData->readHandle = std::move(data->readHandle); bufferData->readHandle = std::move(data->readHandle);
@ -261,6 +266,7 @@ namespace dawn_wire { namespace server {
data->size); data->size);
} }
} }
return true;
}); });
} }

View File

@ -25,7 +25,7 @@ namespace dawn_wire {
class DAWN_WIRE_EXPORT CommandSerializer { class DAWN_WIRE_EXPORT CommandSerializer {
public: public:
virtual ~CommandSerializer() = default; virtual ~CommandSerializer();
// Get space for serializing commands. // Get space for serializing commands.
// GetCmdSpace will never be called with a value larger than // GetCmdSpace will never be called with a value larger than
@ -34,11 +34,12 @@ namespace dawn_wire {
virtual void* GetCmdSpace(size_t size) = 0; virtual void* GetCmdSpace(size_t size) = 0;
virtual bool Flush() = 0; virtual bool Flush() = 0;
virtual size_t GetMaximumAllocationSize() const = 0; virtual size_t GetMaximumAllocationSize() const = 0;
virtual void OnSerializeError();
}; };
class DAWN_WIRE_EXPORT CommandHandler { class DAWN_WIRE_EXPORT CommandHandler {
public: public:
virtual ~CommandHandler() = default; virtual ~CommandHandler();
virtual const volatile char* HandleCommands(const volatile char* commands, size_t size) = 0; virtual const volatile char* HandleCommands(const volatile char* commands, size_t size) = 0;
}; };