diff --git a/generator/templates/dawn_wire/WireCmd.cpp b/generator/templates/dawn_wire/WireCmd.cpp index 87c1fe369b..f44fd6edf0 100644 --- a/generator/templates/dawn_wire/WireCmd.cpp +++ b/generator/templates/dawn_wire/WireCmd.cpp @@ -201,8 +201,8 @@ namespace { //* Serializes `record` into `transfer`, using `buffer` to get more space for pointed-to data //* and `provider` to serialize objects. - DAWN_DECLARE_UNUSED void {{Return}}{{name}}Serialize(const {{Return}}{{name}}{{Cmd}}& record, {{Return}}{{name}}Transfer* transfer, - char** buffer + DAWN_DECLARE_UNUSED bool {{Return}}{{name}}Serialize(const {{Return}}{{name}}{{Cmd}}& record, {{Return}}{{name}}Transfer* transfer, + SerializeBuffer* buffer {%- if record.may_have_dawn_object -%} , const ObjectIdProvider& provider {%- endif -%} @@ -223,7 +223,7 @@ namespace { {% if record.extensible %} if (record.nextInChain != nullptr) { transfer->hasNextInChain = true; - SerializeChainedStruct(record.nextInChain, buffer, provider); + SERIALIZE_TRY(SerializeChainedStruct(record.nextInChain, buffer, provider)); } else { transfer->hasNextInChain = false; } @@ -245,10 +245,11 @@ namespace { if (has_{{memberName}}) {% endif %} { - transfer->{{memberName}}Strlen = std::strlen(record.{{memberName}}); + transfer->{{memberName}}Strlen = std::strlen(record.{{memberName}}); - memcpy(*buffer, record.{{memberName}}, transfer->{{memberName}}Strlen); - *buffer += transfer->{{memberName}}Strlen; + char* stringInBuffer; + SERIALIZE_TRY(buffer->NextN(transfer->{{memberName}}Strlen, &stringInBuffer)); + memcpy(stringInBuffer, record.{{memberName}}, transfer->{{memberName}}Strlen); } {% endfor %} @@ -263,14 +264,16 @@ namespace { {% endif %} { size_t memberLength = {{member_length(member, "record.")}}; - auto memberBuffer = reinterpret_cast<{{member_transfer_type(member)}}*>(*buffer); - *buffer += memberLength * {{member_transfer_sizeof(member)}}; + + {{member_transfer_type(member)}}* memberBuffer; + SERIALIZE_TRY(buffer->NextN(memberLength, &memberBuffer)); for (size_t i = 0; i < memberLength; ++i) { {{serialize_member(member, "record." + memberName + "[i]", "memberBuffer[i]" )}} } } {% endfor %} + return true; } DAWN_UNUSED_FUNC({{Return}}{{name}}Serialize); @@ -386,20 +389,21 @@ namespace { return size; } - void {{Cmd}}::Serialize(size_t commandSize, char* buffer + bool {{Cmd}}::Serialize(size_t commandSize, SerializeBuffer* buffer {%- if not is_return -%} , const ObjectIdProvider& objectIdProvider {%- endif -%} ) const { - auto transfer = reinterpret_cast<{{Name}}Transfer*>(buffer); + {{Name}}Transfer* transfer; + SERIALIZE_TRY(buffer->Next(&transfer)); transfer->commandSize = commandSize; - buffer += sizeof({{Name}}Transfer); - {{Name}}Serialize(*this, transfer, &buffer + SERIALIZE_TRY({{Name}}Serialize(*this, transfer, buffer {%- if command.may_have_dawn_object -%} , objectIdProvider {%- endif -%} - ); + )); + return true; } DeserializeResult {{Cmd}}::Deserialize(const volatile char** buffer, size_t* size, DeserializeAllocator* allocator @@ -429,6 +433,13 @@ namespace dawn_wire { } \ } while (0) +#define SERIALIZE_TRY(EXPR) \ + do { \ + if (!(EXPR)) { \ + return false; \ + } \ + } while (0) + ObjectHandle::ObjectHandle() = default; ObjectHandle::ObjectHandle(ObjectId id, ObjectGeneration generation) : id(id), generation(generation) { @@ -454,27 +465,53 @@ namespace dawn_wire { return *this; } + template + template + bool BufferConsumer::Next(T** data) { + if (sizeof(T) > mSize) { + return false; + } + + *data = reinterpret_cast(mBuffer); + mBuffer += sizeof(T); + mSize -= sizeof(T); + return true; + } + + template + template + bool BufferConsumer::NextN(N count, T** data) { + static_assert(std::is_unsigned::value, "|count| argument of NextN must be unsigned."); + + constexpr size_t kMaxCountWithoutOverflows = std::numeric_limits::max() / sizeof(T); + if (count > kMaxCountWithoutOverflows) { + return false; + } + + // Cannot overflow because |count| is not greater than |kMaxCountWithoutOverflows|. + size_t totalSize = sizeof(T) * count; + if (totalSize > mSize) { + return false; + } + + *data = reinterpret_cast(mBuffer); + mBuffer += totalSize; + mSize -= totalSize; + return true; + } namespace { // Consumes from (buffer, size) enough memory to contain T[count] and return it in data. // Returns FatalError if not enough memory was available template DeserializeResult GetPtrFromBuffer(const volatile char** buffer, size_t* size, size_t count, const volatile T** data) { - constexpr size_t kMaxCountWithoutOverflows = std::numeric_limits::max() / sizeof(T); - if (count > kMaxCountWithoutOverflows) { - return DeserializeResult::FatalError; + DeserializeBuffer deserializeBuffer(*buffer, *size); + DeserializeResult result = deserializeBuffer.ReadN(count, data); + if (result == DeserializeResult::Success) { + *buffer = deserializeBuffer.Buffer(); + *size = deserializeBuffer.AvailableSize(); } - - size_t totalSize = sizeof(T) * count; - if (totalSize > *size) { - return DeserializeResult::FatalError; - } - - *data = reinterpret_cast(*buffer); - *buffer += totalSize; - *size -= totalSize; - - return DeserializeResult::Success; + return result; } // Allocates enough space from allocator to countain T[count] and return it in out. @@ -496,9 +533,9 @@ namespace dawn_wire { } size_t GetChainedStructExtraRequiredSize(const WGPUChainedStruct* chainedStruct); - void SerializeChainedStruct(WGPUChainedStruct const* chainedStruct, - char** buffer, - const ObjectIdProvider& provider); + DAWN_NO_DISCARD bool SerializeChainedStruct(WGPUChainedStruct const* chainedStruct, + SerializeBuffer* buffer, + const ObjectIdProvider& provider); DeserializeResult DeserializeChainedStruct(const WGPUChainedStruct** outChainNext, const volatile char** buffer, size_t* size, @@ -538,9 +575,9 @@ namespace dawn_wire { return result; } - void SerializeChainedStruct(WGPUChainedStruct const* chainedStruct, - char** buffer, - const ObjectIdProvider& provider) { + DAWN_NO_DISCARD bool SerializeChainedStruct(WGPUChainedStruct const* chainedStruct, + SerializeBuffer* buffer, + const ObjectIdProvider& provider) { ASSERT(chainedStruct != nullptr); ASSERT(buffer != nullptr); do { @@ -549,16 +586,16 @@ namespace dawn_wire { {% set CType = as_cType(sType.name) %} case {{as_cEnum(types["s type"].name, sType.name)}}: { - auto* transfer = reinterpret_cast<{{CType}}Transfer*>(*buffer); + {{CType}}Transfer* transfer; + SERIALIZE_TRY(buffer->Next(&transfer)); transfer->chain.sType = chainedStruct->sType; transfer->chain.hasNext = chainedStruct->next != nullptr; - *buffer += sizeof({{CType}}Transfer); - {{CType}}Serialize(*reinterpret_cast<{{CType}} const*>(chainedStruct), transfer, buffer + SERIALIZE_TRY({{CType}}Serialize(*reinterpret_cast<{{CType}} const*>(chainedStruct), transfer, buffer {%- if types[sType.name.get()].may_have_dawn_object -%} , provider {%- endif -%} - ); + )); chainedStruct = chainedStruct->next; } break; @@ -570,18 +607,18 @@ namespace dawn_wire { dawn::WarningLog() << "Unknown sType " << chainedStruct->sType << " discarded."; } - WGPUChainedStructTransfer* transfer = reinterpret_cast(*buffer); + WGPUChainedStructTransfer* transfer; + SERIALIZE_TRY(buffer->Next(&transfer)); transfer->sType = WGPUSType_Invalid; transfer->hasNext = chainedStruct->next != nullptr; - *buffer += sizeof(WGPUChainedStructTransfer); - // Still move on in case there are valid structs after this. chainedStruct = chainedStruct->next; break; } } } while (chainedStruct != nullptr); + return true; } DeserializeResult DeserializeChainedStruct(const WGPUChainedStruct** outChainNext, @@ -677,13 +714,14 @@ namespace dawn_wire { } void SerializeWGPUDeviceProperties(const WGPUDeviceProperties* deviceProperties, - char* serializeBuffer) { - size_t devicePropertiesSize = SerializedWGPUDevicePropertiesSize(deviceProperties); - WGPUDevicePropertiesTransfer* transfer = - reinterpret_cast(serializeBuffer); - serializeBuffer += devicePropertiesSize; + char* buffer) { + SerializeBuffer serializeBuffer(buffer, SerializedWGPUDevicePropertiesSize(deviceProperties)); - WGPUDevicePropertiesSerialize(*deviceProperties, transfer, &serializeBuffer); + WGPUDevicePropertiesTransfer* transfer; + bool success = + serializeBuffer.Next(&transfer) && + WGPUDevicePropertiesSerialize(*deviceProperties, transfer, &serializeBuffer); + ASSERT(success); } bool DeserializeWGPUDeviceProperties(WGPUDeviceProperties* deviceProperties, diff --git a/generator/templates/dawn_wire/WireCmd.h b/generator/templates/dawn_wire/WireCmd.h index a592d2f8d2..a1898b708d 100644 --- a/generator/templates/dawn_wire/WireCmd.h +++ b/generator/templates/dawn_wire/WireCmd.h @@ -48,6 +48,52 @@ namespace dawn_wire { FatalError, }; + template + class BufferConsumer { + public: + BufferConsumer(BufferT* buffer, size_t size) : mBuffer(buffer), mSize(size) {} + + BufferT* Buffer() const { return mBuffer; } + size_t AvailableSize() const { return mSize; } + + protected: + template + DAWN_NO_DISCARD bool NextN(N count, T** data); + + template + DAWN_NO_DISCARD bool Next(T** data); + + private: + BufferT* mBuffer; + size_t mSize; + }; + + class SerializeBuffer : public BufferConsumer { + public: + using BufferConsumer::BufferConsumer; + using BufferConsumer::NextN; + using BufferConsumer::Next; + }; + + class DeserializeBuffer : public BufferConsumer { + public: + using BufferConsumer::BufferConsumer; + + template + DAWN_NO_DISCARD DeserializeResult ReadN(N count, const volatile T** data) { + return NextN(count, data) + ? DeserializeResult::Success + : DeserializeResult::FatalError; + } + + template + DAWN_NO_DISCARD DeserializeResult Read(const volatile T** data) { + return Next(data) + ? DeserializeResult::Success + : DeserializeResult::FatalError; + } + }; + // Interface to allocate more space to deserialize pointed-to data. // nullptr is treated as an error. class DeserializeAllocator { @@ -101,7 +147,7 @@ namespace dawn_wire { //* Serialize the structure and everything it points to into serializeBuffer which must be //* big enough to contain all the data (as queried from GetRequiredSize). - void Serialize(size_t commandSize, char* serializeBuffer + DAWN_NO_DISCARD bool Serialize(size_t commandSize, SerializeBuffer* serializeBuffer {%- if not is_return_command -%} , const ObjectIdProvider& objectIdProvider {%- endif -%} diff --git a/src/dawn_wire/BUILD.gn b/src/dawn_wire/BUILD.gn index 04340b006d..4b212996ce 100644 --- a/src/dawn_wire/BUILD.gn +++ b/src/dawn_wire/BUILD.gn @@ -63,6 +63,7 @@ dawn_component("dawn_wire") { "ChunkedCommandHandler.h", "ChunkedCommandSerializer.cpp", "ChunkedCommandSerializer.h", + "Wire.cpp", "WireClient.cpp", "WireDeserializeAllocator.cpp", "WireDeserializeAllocator.h", diff --git a/src/dawn_wire/CMakeLists.txt b/src/dawn_wire/CMakeLists.txt index 2776bb4a8c..d6d430f588 100644 --- a/src/dawn_wire/CMakeLists.txt +++ b/src/dawn_wire/CMakeLists.txt @@ -35,6 +35,7 @@ target_sources(dawn_wire PRIVATE "ChunkedCommandHandler.h" "ChunkedCommandSerializer.cpp" "ChunkedCommandSerializer.h" + "Wire.cpp" "WireClient.cpp" "WireDeserializeAllocator.cpp" "WireDeserializeAllocator.h" diff --git a/src/dawn_wire/ChunkedCommandSerializer.h b/src/dawn_wire/ChunkedCommandSerializer.h index 1f21dcdcab..e62cb99d79 100644 --- a/src/dawn_wire/ChunkedCommandSerializer.h +++ b/src/dawn_wire/ChunkedCommandSerializer.h @@ -32,7 +32,7 @@ namespace dawn_wire { template void SerializeCommand(const Cmd& cmd) { - SerializeCommand(cmd, 0, [](char*) {}); + SerializeCommand(cmd, 0, [](SerializeBuffer*) { return true; }); } template @@ -41,15 +41,15 @@ namespace dawn_wire { ExtraSizeSerializeFn&& SerializeExtraSize) { SerializeCommandImpl( cmd, - [](const Cmd& cmd, size_t requiredSize, char* allocatedBuffer) { - cmd.Serialize(requiredSize, allocatedBuffer); + [](const Cmd& cmd, size_t requiredSize, SerializeBuffer* serializeBuffer) { + return cmd.Serialize(requiredSize, serializeBuffer); }, extraSize, std::forward(SerializeExtraSize)); } template void SerializeCommand(const Cmd& cmd, const ObjectIdProvider& objectIdProvider) { - SerializeCommand(cmd, objectIdProvider, 0, [](char*) {}); + SerializeCommand(cmd, objectIdProvider, 0, [](SerializeBuffer*) { return true; }); } template @@ -59,8 +59,9 @@ namespace dawn_wire { ExtraSizeSerializeFn&& SerializeExtraSize) { SerializeCommandImpl( cmd, - [&objectIdProvider](const Cmd& cmd, size_t requiredSize, char* allocatedBuffer) { - cmd.Serialize(requiredSize, allocatedBuffer, objectIdProvider); + [&objectIdProvider](const Cmd& cmd, size_t requiredSize, + SerializeBuffer* serializeBuffer) { + return cmd.Serialize(requiredSize, serializeBuffer, objectIdProvider); }, extraSize, std::forward(SerializeExtraSize)); } @@ -77,8 +78,13 @@ namespace dawn_wire { if (requiredSize <= mMaxAllocationSize) { char* allocatedBuffer = static_cast(mSerializer->GetCmdSpace(requiredSize)); if (allocatedBuffer != nullptr) { - SerializeCmd(cmd, requiredSize, allocatedBuffer); - SerializeExtraSize(allocatedBuffer + commandSize); + SerializeBuffer serializeBuffer(allocatedBuffer, requiredSize); + bool success = true; + success &= SerializeCmd(cmd, requiredSize, &serializeBuffer); + success &= SerializeExtraSize(&serializeBuffer); + if (DAWN_UNLIKELY(!success)) { + mSerializer->OnSerializeError(); + } } return; } @@ -87,8 +93,14 @@ namespace dawn_wire { if (!cmdSpace) { return; } - SerializeCmd(cmd, requiredSize, cmdSpace.get()); - SerializeExtraSize(cmdSpace.get() + commandSize); + SerializeBuffer serializeBuffer(cmdSpace.get(), requiredSize); + bool success = true; + success &= SerializeCmd(cmd, requiredSize, &serializeBuffer); + success &= SerializeExtraSize(&serializeBuffer); + if (DAWN_UNLIKELY(!success)) { + mSerializer->OnSerializeError(); + return; + } SerializeChunkedCommand(cmdSpace.get(), requiredSize); } diff --git a/src/dawn_wire/Wire.cpp b/src/dawn_wire/Wire.cpp new file mode 100644 index 0000000000..89e5ac192b --- /dev/null +++ b/src/dawn_wire/Wire.cpp @@ -0,0 +1,26 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dawn_wire/Wire.h" + +namespace dawn_wire { + + CommandSerializer::~CommandSerializer() = default; + + void CommandSerializer::OnSerializeError() { + } + + CommandHandler::~CommandHandler() = default; + +} // namespace dawn_wire diff --git a/src/dawn_wire/client/Buffer.cpp b/src/dawn_wire/client/Buffer.cpp index da3691c584..4f4598423f 100644 --- a/src/dawn_wire/client/Buffer.cpp +++ b/src/dawn_wire/client/Buffer.cpp @@ -14,6 +14,7 @@ #include "dawn_wire/client/Buffer.h" +#include "dawn_wire/WireCmd_autogen.h" #include "dawn_wire/client/Client.h" #include "dawn_wire/client/Device.h" @@ -73,19 +74,24 @@ namespace dawn_wire { namespace client { cmd.handleCreateInfoLength = writeHandleCreateInfoLength; cmd.handleCreateInfo = nullptr; - wireClient->SerializeCommand(cmd, writeHandleCreateInfoLength, [&](char* cmdSpace) { - if (descriptor->mappedAtCreation) { - // Serialize the WriteHandle into the space after the command. - writeHandle->SerializeCreate(cmdSpace); + wireClient->SerializeCommand( + cmd, writeHandleCreateInfoLength, [&](SerializeBuffer* serializeBuffer) { + if (descriptor->mappedAtCreation) { + if (serializeBuffer->AvailableSize() != writeHandleCreateInfoLength) { + return false; + } + // Serialize the WriteHandle into the space after the command. + writeHandle->SerializeCreate(serializeBuffer->Buffer()); - // Set the buffer state for the mapping at creation. The buffer now owns the write - // handle.. - buffer->mWriteHandle = std::move(writeHandle); - buffer->mMappedData = writeData; - buffer->mMapOffset = 0; - buffer->mMapSize = buffer->mSize; - } - }); + // Set the buffer state for the mapping at creation. The buffer now owns the + // write handle.. + buffer->mWriteHandle = std::move(writeHandle); + buffer->mMappedData = writeData; + buffer->mMapOffset = 0; + buffer->mMapSize = buffer->mSize; + } + return true; + }); return ToAPI(buffer); } @@ -199,15 +205,25 @@ namespace dawn_wire { namespace client { // Step 3a. Fill the handle create info in the command. if (isReadMode) { cmd.handleCreateInfoLength = request.readHandle->SerializeCreateSize(); - client->SerializeCommand(cmd, cmd.handleCreateInfoLength, [&](char* cmdSpace) { - request.readHandle->SerializeCreate(cmdSpace); - }); + client->SerializeCommand( + cmd, cmd.handleCreateInfoLength, [&](SerializeBuffer* serializeBuffer) { + bool success = serializeBuffer->AvailableSize() == cmd.handleCreateInfoLength; + if (success) { + request.readHandle->SerializeCreate(serializeBuffer->Buffer()); + } + return success; + }); } else { ASSERT(isWriteMode); cmd.handleCreateInfoLength = request.writeHandle->SerializeCreateSize(); - client->SerializeCommand(cmd, cmd.handleCreateInfoLength, [&](char* cmdSpace) { - request.writeHandle->SerializeCreate(cmdSpace); - }); + client->SerializeCommand( + cmd, cmd.handleCreateInfoLength, [&](SerializeBuffer* serializeBuffer) { + bool success = serializeBuffer->AvailableSize() == cmd.handleCreateInfoLength; + if (success) { + request.writeHandle->SerializeCreate(serializeBuffer->Buffer()); + } + return success; + }); } // Step 4. Register this request so that we can retrieve it from its serial when the server @@ -334,11 +350,16 @@ namespace dawn_wire { namespace client { cmd.writeFlushInfoLength = writeFlushInfoLength; cmd.writeFlushInfo = nullptr; - client->SerializeCommand(cmd, writeFlushInfoLength, [&](char* cmdSpace) { - // Serialize flush metadata into the space after the command. - // This closes the handle for writing. - mWriteHandle->SerializeFlush(cmdSpace); - }); + client->SerializeCommand( + cmd, writeFlushInfoLength, [&](SerializeBuffer* serializeBuffer) { + bool success = serializeBuffer->AvailableSize() == writeFlushInfoLength; + if (success) { + // Serialize flush metadata into the space after the command. + // This closes the handle for writing. + mWriteHandle->SerializeFlush(serializeBuffer->Buffer()); + } + return success; + }); mWriteHandle = nullptr; } else if (mReadHandle) { diff --git a/src/dawn_wire/server/ServerBuffer.cpp b/src/dawn_wire/server/ServerBuffer.cpp index 7cc5c9b745..c2798457cd 100644 --- a/src/dawn_wire/server/ServerBuffer.cpp +++ b/src/dawn_wire/server/ServerBuffer.cpp @@ -13,6 +13,7 @@ // limitations under the License. #include "common/Assert.h" +#include "dawn_wire/WireCmd_autogen.h" #include "dawn_wire/server/Server.h" #include @@ -242,11 +243,15 @@ namespace dawn_wire { namespace server { data->readHandle->SerializeInitialDataSize(readData, data->size); } - SerializeCommand(cmd, cmd.readInitialDataInfoLength, [&](char* cmdSpace) { + SerializeCommand(cmd, cmd.readInitialDataInfoLength, [&](SerializeBuffer* serializeBuffer) { if (isSuccess) { if (isRead) { + if (serializeBuffer->AvailableSize() != cmd.readInitialDataInfoLength) { + return false; + } // Serialize the initialization message into the space after the command. - data->readHandle->SerializeInitialData(readData, data->size, cmdSpace); + data->readHandle->SerializeInitialData(readData, data->size, + serializeBuffer->Buffer()); // The in-flight map request returned successfully. // Move the ReadHandle so it is owned by the buffer. bufferData->readHandle = std::move(data->readHandle); @@ -261,6 +266,7 @@ namespace dawn_wire { namespace server { data->size); } } + return true; }); } diff --git a/src/include/dawn_wire/Wire.h b/src/include/dawn_wire/Wire.h index bb4670158c..590a6cce1a 100644 --- a/src/include/dawn_wire/Wire.h +++ b/src/include/dawn_wire/Wire.h @@ -25,7 +25,7 @@ namespace dawn_wire { class DAWN_WIRE_EXPORT CommandSerializer { public: - virtual ~CommandSerializer() = default; + virtual ~CommandSerializer(); // Get space for serializing commands. // GetCmdSpace will never be called with a value larger than @@ -34,11 +34,12 @@ namespace dawn_wire { virtual void* GetCmdSpace(size_t size) = 0; virtual bool Flush() = 0; virtual size_t GetMaximumAllocationSize() const = 0; + virtual void OnSerializeError(); }; class DAWN_WIRE_EXPORT CommandHandler { public: - virtual ~CommandHandler() = default; + virtual ~CommandHandler(); virtual const volatile char* HandleCommands(const volatile char* commands, size_t size) = 0; };