From 51db53fa06b44e9394b0d1f6b08f4a1b5edb9c60 Mon Sep 17 00:00:00 2001 From: Austin Eng Date: Thu, 18 Feb 2021 19:28:29 +0000 Subject: [PATCH] dawn_wire: Harden deserialization routines - Encapsulate deserialize buffer and size into a DeserializeBuffer class. This limits the possible operations so we can be sure buffer/size are not manually mutated such that we consume more bytes than available. - Ensure that memberLength (on deserialization) doesn't narrow (or widen). Previously, values were always implicitly cast to size_t. - Slight optimization that removes "= nullptr" initialization for pointers written by DeserializeBuffer::Read. These are always written to on success, so we don't need to initialize to nullptr. Bug: dawn:680 Change-Id: I3779a343e85ff90810707148a952c6ba27cf9d22 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/41521 Commit-Queue: Austin Eng Reviewed-by: Stephen White --- generator/templates/dawn_wire/WireCmd.cpp | 118 ++++++++---------- generator/templates/dawn_wire/WireCmd.h | 12 +- .../dawn_wire/client/ClientHandlers.cpp | 17 +-- .../dawn_wire/client/ClientPrototypes.inc | 2 +- .../dawn_wire/server/ServerHandlers.cpp | 17 +-- .../dawn_wire/server/ServerPrototypes.inc | 2 +- 6 files changed, 88 insertions(+), 80 deletions(-) diff --git a/generator/templates/dawn_wire/WireCmd.cpp b/generator/templates/dawn_wire/WireCmd.cpp index 4ad397611b..a368ef283e 100644 --- a/generator/templates/dawn_wire/WireCmd.cpp +++ b/generator/templates/dawn_wire/WireCmd.cpp @@ -27,7 +27,7 @@ //* Outputs an rvalue that's the number of elements a pointer member points to. {% macro member_length(member, record_accessor) -%} {%- if member.length == "constant" -%} - {{member.constant_length}} + {{member.constant_length}}u {%- else -%} {{record_accessor}}{{as_varName(member.length.name)}} {%- endif -%} @@ -74,7 +74,7 @@ {%- set Optional = "Optional" if member.optional else "" -%} DESERIALIZE_TRY(resolver.Get{{Optional}}FromId({{in}}, &{{out}})); {%- elif member.type.category == "structure" -%} - DESERIALIZE_TRY({{as_cType(member.type.name)}}Deserialize(&{{out}}, &{{in}}, buffer, size, allocator + DESERIALIZE_TRY({{as_cType(member.type.name)}}Deserialize(&{{out}}, &{{in}}, deserializeBuffer, allocator {%- if member.type.may_have_dawn_object -%} , resolver {%- endif -%} @@ -174,11 +174,11 @@ namespace { {% endif %} { {% if member.annotation != "value" %} - size_t memberLength = {{member_length(member, "record.")}}; + auto memberLength = {{member_length(member, "record.")}}; result += memberLength * {{member_transfer_sizeof(member)}}; //* Structures might contain more pointers so we need to add their extra size as well. {% if member.type.category == "structure" %} - for (size_t i = 0; i < memberLength; ++i) { + for (decltype(memberLength) i = 0; i < memberLength; ++i) { {% if member.annotation == "const*const*" %} result += {{as_cType(member.type.name)}}GetExtraRequiredSize(*record.{{as_varName(member.name)}}[i]); {% else %} @@ -263,12 +263,12 @@ namespace { if (has_{{memberName}}) {% endif %} { - size_t memberLength = {{member_length(member, "record.")}}; + auto memberLength = {{member_length(member, "record.")}}; {{member_transfer_type(member)}}* memberBuffer; SERIALIZE_TRY(buffer->NextN(memberLength, &memberBuffer)); - for (size_t i = 0; i < memberLength; ++i) { + for (decltype(memberLength) i = 0; i < memberLength; ++i) { {{serialize_member(member, "record." + memberName + "[i]", "memberBuffer[i]" )}} } } @@ -281,14 +281,12 @@ namespace { //* if needed, using `allocator` to store pointed-to values and `resolver` to translate object //* Ids to actual objects. DAWN_DECLARE_UNUSED DeserializeResult {{Return}}{{name}}Deserialize({{Return}}{{name}}{{Cmd}}* record, const volatile {{Return}}{{name}}Transfer* transfer, - const volatile char** buffer, size_t* size, DeserializeAllocator* allocator + DeserializeBuffer* deserializeBuffer, DeserializeAllocator* allocator {%- if record.may_have_dawn_object -%} , const ObjectIdResolver& resolver {%- endif -%} ) { DAWN_UNUSED(allocator); - DAWN_UNUSED(buffer); - DAWN_UNUSED(size); {% if is_cmd %} ASSERT(transfer->commandId == {{Return}}WireCmd::{{name}}); @@ -307,7 +305,7 @@ namespace { {% if record.extensible %} record->nextInChain = nullptr; if (transfer->hasNextInChain) { - DESERIALIZE_TRY(DeserializeChainedStruct(&record->nextInChain, buffer, size, allocator, resolver)); + DESERIALIZE_TRY(DeserializeChainedStruct(&record->nextInChain, deserializeBuffer, allocator, resolver)); } {% endif %} @@ -331,14 +329,14 @@ namespace { { size_t stringLength = transfer->{{memberName}}Strlen; if (stringLength == std::numeric_limits::max()) { - //* Cannot allocate space for the null terminator. + //* Cannot allocate enough space for the null terminator. return DeserializeResult::FatalError; } - const volatile char* stringInBuffer = nullptr; - DESERIALIZE_TRY(GetPtrFromBuffer(buffer, size, stringLength, &stringInBuffer)); + const volatile char* stringInBuffer; + DESERIALIZE_TRY(deserializeBuffer->ReadN(stringLength, &stringInBuffer)); - char* copiedString = nullptr; + char* copiedString; DESERIALIZE_TRY(GetSpace(allocator, stringLength + 1, &copiedString)); //* We can cast away the volatile qualifier because GetPtrFromBuffer already validated //* that the range [stringInBuffer, stringInBuffer + stringLength) is valid. @@ -360,16 +358,16 @@ namespace { if (has_{{memberName}}) {% endif %} { - size_t memberLength = {{member_length(member, "record->")}}; - auto memberBuffer = reinterpret_cast(buffer); - DESERIALIZE_TRY(GetPtrFromBuffer(buffer, size, memberLength, &memberBuffer)); + auto memberLength = {{member_length(member, "record->")}}; + const volatile {{member_transfer_type(member)}}* memberBuffer; + DESERIALIZE_TRY(deserializeBuffer->ReadN(memberLength, &memberBuffer)); - {{as_cType(member.type.name)}}* copiedMembers = nullptr; + {{as_cType(member.type.name)}}* copiedMembers; DESERIALIZE_TRY(GetSpace(allocator, memberLength, &copiedMembers)); {% if member.annotation == "const*const*" %} - {{as_cType(member.type.name)}}** pointerArray = nullptr; + {{as_cType(member.type.name)}}** pointerArray; DESERIALIZE_TRY(GetSpace(allocator, memberLength, &pointerArray)); - for (size_t i = 0; i < memberLength; ++i) { + for (decltype(memberLength) i = 0; i < memberLength; ++i) { pointerArray[i] = &copiedMembers[i]; } record->{{memberName}} = pointerArray; @@ -377,7 +375,7 @@ namespace { record->{{memberName}} = copiedMembers; {% endif %} - for (size_t i = 0; i < memberLength; ++i) { + for (decltype(memberLength) i = 0; i < memberLength; ++i) { {{deserialize_member(member, "memberBuffer[i]", "copiedMembers[i]")}} } } @@ -415,15 +413,15 @@ namespace { return true; } - DeserializeResult {{Cmd}}::Deserialize(const volatile char** buffer, size_t* size, DeserializeAllocator* allocator + DeserializeResult {{Cmd}}::Deserialize(DeserializeBuffer* deserializeBuffer, DeserializeAllocator* allocator {%- if command.may_have_dawn_object -%} , const ObjectIdResolver& resolver {%- endif -%} ) { - const volatile {{Name}}Transfer* transfer = nullptr; - DESERIALIZE_TRY(GetPtrFromBuffer(buffer, size, 1, &transfer)); + const volatile {{Name}}Transfer* transfer; + DESERIALIZE_TRY(deserializeBuffer->Read(&transfer)); - return {{Name}}Deserialize(this, transfer, buffer, size, allocator + return {{Name}}Deserialize(this, transfer, deserializeBuffer, allocator {%- if command.may_have_dawn_object -%} , resolver {%- endif -%} @@ -474,6 +472,17 @@ namespace dawn_wire { return *this; } + template + template + bool BufferConsumer::Peek(T** data) { + if (sizeof(T) > mSize) { + return false; + } + + *data = reinterpret_cast(mBuffer); + return true; + } + template template bool BufferConsumer::Next(T** data) { @@ -508,25 +517,13 @@ namespace dawn_wire { mSize -= totalSize; return true; } + namespace { - - // Consumes from (buffer, size) enough memory to contain T[count] and return it in data. - // Returns FatalError if not enough memory was available - template - DeserializeResult GetPtrFromBuffer(const volatile char** buffer, size_t* size, size_t count, const volatile T** data) { - DeserializeBuffer deserializeBuffer(*buffer, *size); - DeserializeResult result = deserializeBuffer.ReadN(count, data); - if (result == DeserializeResult::Success) { - *buffer = deserializeBuffer.Buffer(); - *size = deserializeBuffer.AvailableSize(); - } - return result; - } - // Allocates enough space from allocator to countain T[count] and return it in out. // Return FatalError if the allocator couldn't allocate the memory. - template - DeserializeResult GetSpace(DeserializeAllocator* allocator, size_t count, T** out) { + // Always writes to |out| on success. + template + DeserializeResult GetSpace(DeserializeAllocator* allocator, N count, T** out) { constexpr size_t kMaxCountWithoutOverflows = std::numeric_limits::max() / sizeof(T); if (count > kMaxCountWithoutOverflows) { return DeserializeResult::FatalError; @@ -546,8 +543,7 @@ namespace dawn_wire { SerializeBuffer* buffer, const ObjectIdProvider& provider); DeserializeResult DeserializeChainedStruct(const WGPUChainedStruct** outChainNext, - const volatile char** buffer, - size_t* size, + DeserializeBuffer* deserializeBuffer, DeserializeAllocator* allocator, const ObjectIdResolver& resolver); @@ -631,25 +627,22 @@ namespace dawn_wire { } DeserializeResult DeserializeChainedStruct(const WGPUChainedStruct** outChainNext, - const volatile char** buffer, - size_t* size, + DeserializeBuffer* deserializeBuffer, DeserializeAllocator* allocator, const ObjectIdResolver& resolver) { bool hasNext; do { - if (*size < sizeof(WGPUChainedStructTransfer)) { - return DeserializeResult::FatalError; - } - WGPUSType sType = - reinterpret_cast(*buffer)->sType; + const volatile WGPUChainedStructTransfer* header; + DESERIALIZE_TRY(deserializeBuffer->Peek(&header)); + WGPUSType sType = header->sType; switch (sType) { {% for sType in types["s type"].values if sType.valid and sType.name.CamelCase() not in client_side_structures %} {% set CType = as_cType(sType.name) %} case {{as_cEnum(types["s type"].name, sType.name)}}: { - const volatile {{CType}}Transfer* transfer = nullptr; - DESERIALIZE_TRY(GetPtrFromBuffer(buffer, size, 1, &transfer)); + const volatile {{CType}}Transfer* transfer; + DESERIALIZE_TRY(deserializeBuffer->Read(&transfer)); - {{CType}}* outStruct = nullptr; + {{CType}}* outStruct; DESERIALIZE_TRY(GetSpace(allocator, sizeof({{CType}}), &outStruct)); outStruct->chain.sType = sType; outStruct->chain.next = nullptr; @@ -657,7 +650,7 @@ namespace dawn_wire { *outChainNext = &outStruct->chain; outChainNext = &outStruct->chain.next; - DESERIALIZE_TRY({{CType}}Deserialize(outStruct, transfer, buffer, size, allocator + DESERIALIZE_TRY({{CType}}Deserialize(outStruct, transfer, deserializeBuffer, allocator {%- if types[sType.name.get()].may_have_dawn_object -%} , resolver {%- endif -%} @@ -673,10 +666,10 @@ namespace dawn_wire { dawn::WarningLog() << "Unknown sType " << sType << " discarded."; } - const volatile WGPUChainedStructTransfer* transfer = nullptr; - DESERIALIZE_TRY(GetPtrFromBuffer(buffer, size, 1, &transfer)); + const volatile WGPUChainedStructTransfer* transfer; + DESERIALIZE_TRY(deserializeBuffer->Read(&transfer)); - WGPUChainedStruct* outStruct = nullptr; + WGPUChainedStruct* outStruct; DESERIALIZE_TRY(GetSpace(allocator, sizeof(WGPUChainedStruct), &outStruct)); outStruct->sType = WGPUSType_Invalid; outStruct->next = nullptr; @@ -734,16 +727,15 @@ namespace dawn_wire { } bool DeserializeWGPUDeviceProperties(WGPUDeviceProperties* deviceProperties, - const volatile char* deserializeBuffer, - size_t deserializeBufferSize) { - const volatile WGPUDevicePropertiesTransfer* transfer = nullptr; - if (GetPtrFromBuffer(&deserializeBuffer, &deserializeBufferSize, 1, &transfer) != - DeserializeResult::Success) { + const volatile char* buffer, + size_t size) { + const volatile WGPUDevicePropertiesTransfer* transfer; + DeserializeBuffer deserializeBuffer(buffer, size); + if (deserializeBuffer.Read(&transfer) != DeserializeResult::Success) { return false; } return WGPUDevicePropertiesDeserialize(deviceProperties, transfer, &deserializeBuffer, - &deserializeBufferSize, nullptr) == DeserializeResult::Success; } diff --git a/generator/templates/dawn_wire/WireCmd.h b/generator/templates/dawn_wire/WireCmd.h index a1898b708d..e5be81e034 100644 --- a/generator/templates/dawn_wire/WireCmd.h +++ b/generator/templates/dawn_wire/WireCmd.h @@ -63,6 +63,9 @@ namespace dawn_wire { template DAWN_NO_DISCARD bool Next(T** data); + template + DAWN_NO_DISCARD bool Peek(T** data); + private: BufferT* mBuffer; size_t mSize; @@ -92,6 +95,13 @@ namespace dawn_wire { ? DeserializeResult::Success : DeserializeResult::FatalError; } + + template + DAWN_NO_DISCARD DeserializeResult Peek(const volatile T** data) { + return BufferConsumer::Peek(data) + ? DeserializeResult::Success + : DeserializeResult::FatalError; + } }; // Interface to allocate more space to deserialize pointed-to data. @@ -160,7 +170,7 @@ namespace dawn_wire { //* Deserialize returns: //* - Success if everything went well (yay!) //* - FatalError is something bad happened (buffer too small for example) - DeserializeResult Deserialize(const volatile char** buffer, size_t* size, DeserializeAllocator* allocator + DeserializeResult Deserialize(DeserializeBuffer* deserializeBuffer, DeserializeAllocator* allocator {%- if command.may_have_dawn_object -%} , const ObjectIdResolver& resolver {%- endif -%} diff --git a/generator/templates/dawn_wire/client/ClientHandlers.cpp b/generator/templates/dawn_wire/client/ClientHandlers.cpp index 51122bf807..de1ca3c619 100644 --- a/generator/templates/dawn_wire/client/ClientHandlers.cpp +++ b/generator/templates/dawn_wire/client/ClientHandlers.cpp @@ -19,9 +19,9 @@ namespace dawn_wire { namespace client { {% for command in cmd_records["return command"] %} - bool Client::Handle{{command.name.CamelCase()}}(const volatile char** commands, size_t* size) { + bool Client::Handle{{command.name.CamelCase()}}(DeserializeBuffer* deserializeBuffer) { Return{{command.name.CamelCase()}}Cmd cmd; - DeserializeResult deserializeResult = cmd.Deserialize(commands, size, &mAllocator); + DeserializeResult deserializeResult = cmd.Deserialize(deserializeBuffer, &mAllocator); if (deserializeResult == DeserializeResult::FatalError) { return false; @@ -54,10 +54,12 @@ namespace dawn_wire { namespace client { {% endfor %} const volatile char* Client::HandleCommandsImpl(const volatile char* commands, size_t size) { - while (size >= sizeof(CmdHeader) + sizeof(ReturnWireCmd)) { + DeserializeBuffer deserializeBuffer(commands, size); + + while (deserializeBuffer.AvailableSize() >= sizeof(CmdHeader) + sizeof(ReturnWireCmd)) { // Start by chunked command handling, if it is done, then it means the whole buffer // was consumed by it, so we return a pointer to the end of the commands. - switch (HandleChunkedCommands(commands, size)) { + switch (HandleChunkedCommands(deserializeBuffer.Buffer(), deserializeBuffer.AvailableSize())) { case ChunkedCommandsResult::Consumed: return commands + size; case ChunkedCommandsResult::Error: @@ -66,13 +68,14 @@ namespace dawn_wire { namespace client { break; } - ReturnWireCmd cmdId = *reinterpret_cast(commands + sizeof(CmdHeader)); + ReturnWireCmd cmdId = *static_cast(static_cast( + deserializeBuffer.Buffer() + sizeof(CmdHeader))); bool success = false; switch (cmdId) { {% for command in cmd_records["return command"] %} {% set Suffix = command.name.CamelCase() %} case ReturnWireCmd::{{Suffix}}: - success = Handle{{Suffix}}(&commands, &size); + success = Handle{{Suffix}}(&deserializeBuffer); break; {% endfor %} default: @@ -85,7 +88,7 @@ namespace dawn_wire { namespace client { mAllocator.Reset(); } - if (size != 0) { + if (deserializeBuffer.AvailableSize() != 0) { return nullptr; } diff --git a/generator/templates/dawn_wire/client/ClientPrototypes.inc b/generator/templates/dawn_wire/client/ClientPrototypes.inc index df18896587..3a5f62fa32 100644 --- a/generator/templates/dawn_wire/client/ClientPrototypes.inc +++ b/generator/templates/dawn_wire/client/ClientPrototypes.inc @@ -14,7 +14,7 @@ //* Return command handlers {% for command in cmd_records["return command"] %} - bool Handle{{command.name.CamelCase()}}(const volatile char** commands, size_t* size); + bool Handle{{command.name.CamelCase()}}(DeserializeBuffer* deserializeBuffer); {% endfor %} //* Return command doers diff --git a/generator/templates/dawn_wire/server/ServerHandlers.cpp b/generator/templates/dawn_wire/server/ServerHandlers.cpp index f23a684c63..a544a505d1 100644 --- a/generator/templates/dawn_wire/server/ServerHandlers.cpp +++ b/generator/templates/dawn_wire/server/ServerHandlers.cpp @@ -23,9 +23,9 @@ namespace dawn_wire { namespace server { {% set Suffix = command.name.CamelCase() %} //* The generic command handlers - bool Server::Handle{{Suffix}}(const volatile char** commands, size_t* size) { + bool Server::Handle{{Suffix}}(DeserializeBuffer* deserializeBuffer) { {{Suffix}}Cmd cmd; - DeserializeResult deserializeResult = cmd.Deserialize(commands, size, &mAllocator + DeserializeResult deserializeResult = cmd.Deserialize(deserializeBuffer, &mAllocator {%- if command.may_have_dawn_object -%} , *this {%- endif -%} @@ -107,10 +107,12 @@ namespace dawn_wire { namespace server { {% endfor %} const volatile char* Server::HandleCommandsImpl(const volatile char* commands, size_t size) { - while (size >= sizeof(CmdHeader) + sizeof(WireCmd)) { + DeserializeBuffer deserializeBuffer(commands, size); + + while (deserializeBuffer.AvailableSize() >= sizeof(CmdHeader) + sizeof(WireCmd)) { // Start by chunked command handling, if it is done, then it means the whole buffer // was consumed by it, so we return a pointer to the end of the commands. - switch (HandleChunkedCommands(commands, size)) { + switch (HandleChunkedCommands(deserializeBuffer.Buffer(), deserializeBuffer.AvailableSize())) { case ChunkedCommandsResult::Consumed: return commands + size; case ChunkedCommandsResult::Error: @@ -119,12 +121,13 @@ namespace dawn_wire { namespace server { break; } - WireCmd cmdId = *reinterpret_cast(commands + sizeof(CmdHeader)); + WireCmd cmdId = *static_cast(static_cast( + deserializeBuffer.Buffer() + sizeof(CmdHeader))); bool success = false; switch (cmdId) { {% for command in cmd_records["command"] %} case WireCmd::{{command.name.CamelCase()}}: - success = Handle{{command.name.CamelCase()}}(&commands, &size); + success = Handle{{command.name.CamelCase()}}(&deserializeBuffer); break; {% endfor %} default: @@ -137,7 +140,7 @@ namespace dawn_wire { namespace server { mAllocator.Reset(); } - if (size != 0) { + if (deserializeBuffer.AvailableSize() != 0) { return nullptr; } diff --git a/generator/templates/dawn_wire/server/ServerPrototypes.inc b/generator/templates/dawn_wire/server/ServerPrototypes.inc index a9c03e2dbb..31af0ed13a 100644 --- a/generator/templates/dawn_wire/server/ServerPrototypes.inc +++ b/generator/templates/dawn_wire/server/ServerPrototypes.inc @@ -15,7 +15,7 @@ // Command handlers & doers {% for command in cmd_records["command"] %} {% set Suffix = command.name.CamelCase() %} - bool Handle{{Suffix}}(const volatile char** commands, size_t* size); + bool Handle{{Suffix}}(DeserializeBuffer* deserializeBuffer); bool Do{{Suffix}}( {%- for member in command.members -%}