diff --git a/generator/templates/dawn_wire/WireCmd.cpp b/generator/templates/dawn_wire/WireCmd.cpp index 07c5e22484..4816ee122c 100644 --- a/generator/templates/dawn_wire/WireCmd.cpp +++ b/generator/templates/dawn_wire/WireCmd.cpp @@ -16,6 +16,7 @@ #include "common/Assert.h" #include "common/Log.h" +#include "dawn_wire/BufferConsumer_impl.h" #include "dawn_wire/Wire.h" #include @@ -60,9 +61,9 @@ {% elif member.type.category == "structure"%} {%- set Provider = ", provider" if member.type.may_have_dawn_object else "" -%} {% if member.annotation == "const*const*" %} - {{as_cType(member.type.name)}}Serialize(*{{in}}, &{{out}}, buffer{{Provider}}); + WIRE_TRY({{as_cType(member.type.name)}}Serialize(*{{in}}, &{{out}}, buffer{{Provider}})); {% else %} - {{as_cType(member.type.name)}}Serialize({{in}}, &{{out}}, buffer{{Provider}}); + WIRE_TRY({{as_cType(member.type.name)}}Serialize({{in}}, &{{out}}, buffer{{Provider}})); {% endif %} {%- else -%} {{out}} = {{in}}; @@ -73,9 +74,9 @@ {% macro deserialize_member(member, in, out) %} {%- if member.type.category == "object" -%} {%- set Optional = "Optional" if member.optional else "" -%} - DESERIALIZE_TRY(resolver.Get{{Optional}}FromId({{in}}, &{{out}})); + WIRE_TRY(resolver.Get{{Optional}}FromId({{in}}, &{{out}})); {%- elif member.type.category == "structure" -%} - DESERIALIZE_TRY({{as_cType(member.type.name)}}Deserialize(&{{out}}, &{{in}}, deserializeBuffer, allocator + WIRE_TRY({{as_cType(member.type.name)}}Deserialize(&{{out}}, &{{in}}, deserializeBuffer, allocator {%- if member.type.may_have_dawn_object -%} , resolver {%- endif -%} @@ -203,7 +204,7 @@ namespace { //* Serializes `record` into `transfer`, using `buffer` to get more space for pointed-to data //* and `provider` to serialize objects. - DAWN_DECLARE_UNUSED bool {{Return}}{{name}}Serialize(const {{Return}}{{name}}{{Cmd}}& record, {{Return}}{{name}}Transfer* transfer, + DAWN_DECLARE_UNUSED WireResult {{Return}}{{name}}Serialize(const {{Return}}{{name}}{{Cmd}}& record, {{Return}}{{name}}Transfer* transfer, SerializeBuffer* buffer {%- if record.may_have_dawn_object -%} , const ObjectIdProvider& provider @@ -225,7 +226,7 @@ namespace { {% if record.extensible %} if (record.nextInChain != nullptr) { transfer->hasNextInChain = true; - SERIALIZE_TRY(SerializeChainedStruct(record.nextInChain, buffer, provider)); + WIRE_TRY(SerializeChainedStruct(record.nextInChain, buffer, provider)); } else { transfer->hasNextInChain = false; } @@ -250,7 +251,7 @@ namespace { transfer->{{memberName}}Strlen = std::strlen(record.{{memberName}}); char* stringInBuffer; - SERIALIZE_TRY(buffer->NextN(transfer->{{memberName}}Strlen, &stringInBuffer)); + WIRE_TRY(buffer->NextN(transfer->{{memberName}}Strlen, &stringInBuffer)); memcpy(stringInBuffer, record.{{memberName}}, transfer->{{memberName}}Strlen); } {% endfor %} @@ -268,7 +269,7 @@ namespace { auto memberLength = {{member_length(member, "record.")}}; {{member_transfer_type(member)}}* memberBuffer; - SERIALIZE_TRY(buffer->NextN(memberLength, &memberBuffer)); + WIRE_TRY(buffer->NextN(memberLength, &memberBuffer)); //* This loop cannot overflow because it iterates up to |memberLength|. Even if //* memberLength were the maximum integer value, |i| would become equal to it just before @@ -278,14 +279,14 @@ namespace { } } {% endfor %} - return true; + return WireResult::Success; } DAWN_UNUSED_FUNC({{Return}}{{name}}Serialize); //* Deserializes `transfer` into `record` getting more serialized data from `buffer` and `size` //* if needed, using `allocator` to store pointed-to values and `resolver` to translate object //* Ids to actual objects. - DAWN_DECLARE_UNUSED DeserializeResult {{Return}}{{name}}Deserialize({{Return}}{{name}}{{Cmd}}* record, const volatile {{Return}}{{name}}Transfer* transfer, + DAWN_DECLARE_UNUSED WireResult {{Return}}{{name}}Deserialize({{Return}}{{name}}{{Cmd}}* record, const volatile {{Return}}{{name}}Transfer* transfer, DeserializeBuffer* deserializeBuffer, DeserializeAllocator* allocator {%- if record.may_have_dawn_object -%} , const ObjectIdResolver& resolver @@ -310,7 +311,7 @@ namespace { {% if record.extensible %} record->nextInChain = nullptr; if (transfer->hasNextInChain) { - DESERIALIZE_TRY(DeserializeChainedStruct(&record->nextInChain, deserializeBuffer, allocator, resolver)); + WIRE_TRY(DeserializeChainedStruct(&record->nextInChain, deserializeBuffer, allocator, resolver)); } {% endif %} @@ -336,15 +337,15 @@ namespace { if (stringLength64 >= std::numeric_limits::max()) { //* Cannot allocate space for the string. It can be at most //* size_t::max() - 1. We need 1 byte for the null-terminator. - return DeserializeResult::FatalError; + return WireResult::FatalError; } size_t stringLength = static_cast(stringLength64); const volatile char* stringInBuffer; - DESERIALIZE_TRY(deserializeBuffer->ReadN(stringLength, &stringInBuffer)); + WIRE_TRY(deserializeBuffer->ReadN(stringLength, &stringInBuffer)); char* copiedString; - DESERIALIZE_TRY(GetSpace(allocator, stringLength + 1, &copiedString)); + WIRE_TRY(GetSpace(allocator, stringLength + 1, &copiedString)); //* We can cast away the volatile qualifier because DeserializeBuffer::ReadN already //* validated that the range [stringInBuffer, stringInBuffer + stringLength) is valid. //* memcpy may have an unknown access pattern, but this is fine since the string is only @@ -367,13 +368,14 @@ namespace { { auto memberLength = {{member_length(member, "record->")}}; const volatile {{member_transfer_type(member)}}* memberBuffer; - DESERIALIZE_TRY(deserializeBuffer->ReadN(memberLength, &memberBuffer)); + WIRE_TRY(deserializeBuffer->ReadN(memberLength, &memberBuffer)); {{as_cType(member.type.name)}}* copiedMembers; - DESERIALIZE_TRY(GetSpace(allocator, memberLength, &copiedMembers)); + WIRE_TRY(GetSpace(allocator, memberLength, &copiedMembers)); {% if member.annotation == "const*const*" %} {{as_cType(member.type.name)}}** pointerArray; - DESERIALIZE_TRY(GetSpace(allocator, memberLength, &pointerArray)); + WIRE_TRY(GetSpace(allocator, memberLength, &pointerArray)); + //* This loop cannot overflow because it iterates up to |memberLength|. Even if //* memberLength were the maximum integer value, |i| would become equal to it just before //* exiting the loop, but not increment past or wrap around. @@ -394,7 +396,7 @@ namespace { } {% endfor %} - return DeserializeResult::Success; + return WireResult::Success; } DAWN_UNUSED_FUNC({{Return}}{{name}}Deserialize); {% endmacro %} @@ -409,30 +411,30 @@ namespace { return size; } - bool {{Cmd}}::Serialize(size_t commandSize, SerializeBuffer* buffer + WireResult {{Cmd}}::Serialize(size_t commandSize, SerializeBuffer* buffer {%- if not is_return -%} , const ObjectIdProvider& objectIdProvider {%- endif -%} ) const { {{Name}}Transfer* transfer; - SERIALIZE_TRY(buffer->Next(&transfer)); + WIRE_TRY(buffer->Next(&transfer)); transfer->commandSize = commandSize; - SERIALIZE_TRY({{Name}}Serialize(*this, transfer, buffer + WIRE_TRY({{Name}}Serialize(*this, transfer, buffer {%- if command.may_have_dawn_object -%} , objectIdProvider {%- endif -%} )); - return true; + return WireResult::Success; } - DeserializeResult {{Cmd}}::Deserialize(DeserializeBuffer* deserializeBuffer, DeserializeAllocator* allocator + WireResult {{Cmd}}::Deserialize(DeserializeBuffer* deserializeBuffer, DeserializeAllocator* allocator {%- if command.may_have_dawn_object -%} , const ObjectIdResolver& resolver {%- endif -%} ) { const volatile {{Name}}Transfer* transfer; - DESERIALIZE_TRY(deserializeBuffer->Read(&transfer)); + WIRE_TRY(deserializeBuffer->Read(&transfer)); return {{Name}}Deserialize(this, transfer, deserializeBuffer, allocator {%- if command.may_have_dawn_object -%} @@ -444,22 +446,6 @@ namespace { namespace dawn_wire { - // Macro to simplify error handling, similar to DAWN_TRY but for DeserializeResult. -#define DESERIALIZE_TRY(EXPR) \ - do { \ - DeserializeResult exprResult = EXPR; \ - if (exprResult != DeserializeResult::Success) { \ - return exprResult; \ - } \ - } while (0) - -#define SERIALIZE_TRY(EXPR) \ - do { \ - if (!(EXPR)) { \ - return false; \ - } \ - } while (0) - ObjectHandle::ObjectHandle() = default; ObjectHandle::ObjectHandle(ObjectId id, ObjectGeneration generation) : id(id), generation(generation) { @@ -485,77 +471,31 @@ namespace dawn_wire { return *this; } - template - template - bool BufferConsumer::Peek(T** data) { - if (sizeof(T) > mSize) { - return false; - } - - *data = reinterpret_cast(mBuffer); - return true; - } - - template - template - bool BufferConsumer::Next(T** data) { - if (sizeof(T) > mSize) { - return false; - } - - *data = reinterpret_cast(mBuffer); - mBuffer += sizeof(T); - mSize -= sizeof(T); - return true; - } - - template - template - bool BufferConsumer::NextN(N count, T** data) { - static_assert(std::is_unsigned::value, "|count| argument of NextN must be unsigned."); - - constexpr size_t kMaxCountWithoutOverflows = std::numeric_limits::max() / sizeof(T); - if (count > kMaxCountWithoutOverflows) { - return false; - } - - // Cannot overflow because |count| is not greater than |kMaxCountWithoutOverflows|. - size_t totalSize = sizeof(T) * count; - if (totalSize > mSize) { - return false; - } - - *data = reinterpret_cast(mBuffer); - mBuffer += totalSize; - mSize -= totalSize; - return true; - } - namespace { // Allocates enough space from allocator to countain T[count] and return it in out. // Return FatalError if the allocator couldn't allocate the memory. // Always writes to |out| on success. template - DeserializeResult GetSpace(DeserializeAllocator* allocator, N count, T** out) { + WireResult GetSpace(DeserializeAllocator* allocator, N count, T** out) { constexpr size_t kMaxCountWithoutOverflows = std::numeric_limits::max() / sizeof(T); if (count > kMaxCountWithoutOverflows) { - return DeserializeResult::FatalError; + return WireResult::FatalError; } size_t totalSize = sizeof(T) * count; *out = static_cast(allocator->GetSpace(totalSize)); if (*out == nullptr) { - return DeserializeResult::FatalError; + return WireResult::FatalError; } - return DeserializeResult::Success; + return WireResult::Success; } size_t GetChainedStructExtraRequiredSize(const WGPUChainedStruct* chainedStruct); - DAWN_NO_DISCARD bool SerializeChainedStruct(WGPUChainedStruct const* chainedStruct, - SerializeBuffer* buffer, - const ObjectIdProvider& provider); - DeserializeResult DeserializeChainedStruct(const WGPUChainedStruct** outChainNext, + DAWN_NO_DISCARD WireResult SerializeChainedStruct(WGPUChainedStruct const* chainedStruct, + SerializeBuffer* buffer, + const ObjectIdProvider& provider); + WireResult DeserializeChainedStruct(const WGPUChainedStruct** outChainNext, DeserializeBuffer* deserializeBuffer, DeserializeAllocator* allocator, const ObjectIdResolver& resolver); @@ -593,9 +533,9 @@ namespace dawn_wire { return result; } - DAWN_NO_DISCARD bool SerializeChainedStruct(WGPUChainedStruct const* chainedStruct, - SerializeBuffer* buffer, - const ObjectIdProvider& provider) { + DAWN_NO_DISCARD WireResult SerializeChainedStruct(WGPUChainedStruct const* chainedStruct, + SerializeBuffer* buffer, + const ObjectIdProvider& provider) { ASSERT(chainedStruct != nullptr); ASSERT(buffer != nullptr); do { @@ -605,11 +545,11 @@ namespace dawn_wire { case {{as_cEnum(types["s type"].name, sType.name)}}: { {{CType}}Transfer* transfer; - SERIALIZE_TRY(buffer->Next(&transfer)); + WIRE_TRY(buffer->Next(&transfer)); transfer->chain.sType = chainedStruct->sType; transfer->chain.hasNext = chainedStruct->next != nullptr; - SERIALIZE_TRY({{CType}}Serialize(*reinterpret_cast<{{CType}} const*>(chainedStruct), transfer, buffer + WIRE_TRY({{CType}}Serialize(*reinterpret_cast<{{CType}} const*>(chainedStruct), transfer, buffer {%- if types[sType.name.get()].may_have_dawn_object -%} , provider {%- endif -%} @@ -626,7 +566,7 @@ namespace dawn_wire { } WGPUChainedStructTransfer* transfer; - SERIALIZE_TRY(buffer->Next(&transfer)); + WIRE_TRY(buffer->Next(&transfer)); transfer->sType = WGPUSType_Invalid; transfer->hasNext = chainedStruct->next != nullptr; @@ -636,34 +576,34 @@ namespace dawn_wire { } } } while (chainedStruct != nullptr); - return true; + return WireResult::Success; } - DeserializeResult DeserializeChainedStruct(const WGPUChainedStruct** outChainNext, + WireResult DeserializeChainedStruct(const WGPUChainedStruct** outChainNext, DeserializeBuffer* deserializeBuffer, DeserializeAllocator* allocator, const ObjectIdResolver& resolver) { bool hasNext; do { const volatile WGPUChainedStructTransfer* header; - DESERIALIZE_TRY(deserializeBuffer->Peek(&header)); + WIRE_TRY(deserializeBuffer->Peek(&header)); WGPUSType sType = header->sType; switch (sType) { {% for sType in types["s type"].values if sType.valid and sType.name.CamelCase() not in client_side_structures %} {% set CType = as_cType(sType.name) %} case {{as_cEnum(types["s type"].name, sType.name)}}: { const volatile {{CType}}Transfer* transfer; - DESERIALIZE_TRY(deserializeBuffer->Read(&transfer)); + WIRE_TRY(deserializeBuffer->Read(&transfer)); {{CType}}* outStruct; - DESERIALIZE_TRY(GetSpace(allocator, sizeof({{CType}}), &outStruct)); + WIRE_TRY(GetSpace(allocator, sizeof({{CType}}), &outStruct)); outStruct->chain.sType = sType; outStruct->chain.next = nullptr; *outChainNext = &outStruct->chain; outChainNext = &outStruct->chain.next; - DESERIALIZE_TRY({{CType}}Deserialize(outStruct, transfer, deserializeBuffer, allocator + WIRE_TRY({{CType}}Deserialize(outStruct, transfer, deserializeBuffer, allocator {%- if types[sType.name.get()].may_have_dawn_object -%} , resolver {%- endif -%} @@ -680,10 +620,10 @@ namespace dawn_wire { } const volatile WGPUChainedStructTransfer* transfer; - DESERIALIZE_TRY(deserializeBuffer->Read(&transfer)); + WIRE_TRY(deserializeBuffer->Read(&transfer)); WGPUChainedStruct* outStruct; - DESERIALIZE_TRY(GetSpace(allocator, sizeof(WGPUChainedStruct), &outStruct)); + WIRE_TRY(GetSpace(allocator, sizeof(WGPUChainedStruct), &outStruct)); outStruct->sType = WGPUSType_Invalid; outStruct->next = nullptr; @@ -696,7 +636,7 @@ namespace dawn_wire { } } while (hasNext); - return DeserializeResult::Success; + return WireResult::Success; } //* Output [de]serialization helpers for commands @@ -733,10 +673,12 @@ namespace dawn_wire { SerializeBuffer serializeBuffer(buffer, SerializedWGPUDevicePropertiesSize(deviceProperties)); WGPUDevicePropertiesTransfer* transfer; - bool success = - serializeBuffer.Next(&transfer) && - WGPUDevicePropertiesSerialize(*deviceProperties, transfer, &serializeBuffer); - ASSERT(success); + + WireResult result = serializeBuffer.Next(&transfer); + ASSERT(result == WireResult::Success); + + result = WGPUDevicePropertiesSerialize(*deviceProperties, transfer, &serializeBuffer); + ASSERT(result == WireResult::Success); } bool DeserializeWGPUDeviceProperties(WGPUDeviceProperties* deviceProperties, @@ -744,12 +686,12 @@ namespace dawn_wire { size_t size) { const volatile WGPUDevicePropertiesTransfer* transfer; DeserializeBuffer deserializeBuffer(buffer, size); - if (deserializeBuffer.Read(&transfer) != DeserializeResult::Success) { + if (deserializeBuffer.Read(&transfer) != WireResult::Success) { return false; } return WGPUDevicePropertiesDeserialize(deviceProperties, transfer, &deserializeBuffer, - nullptr) == DeserializeResult::Success; + nullptr) == WireResult::Success; } } // namespace dawn_wire diff --git a/generator/templates/dawn_wire/WireCmd.h b/generator/templates/dawn_wire/WireCmd.h index e5be81e034..68f365c20a 100644 --- a/generator/templates/dawn_wire/WireCmd.h +++ b/generator/templates/dawn_wire/WireCmd.h @@ -17,7 +17,9 @@ #include +#include "dawn_wire/BufferConsumer.h" #include "dawn_wire/ObjectType_autogen.h" +#include "dawn_wire/WireResult.h" namespace dawn_wire { @@ -43,67 +45,6 @@ namespace dawn_wire { ObjectHandle& AssignFrom(const volatile ObjectHandle& rhs); }; - enum class DeserializeResult { - Success, - FatalError, - }; - - template - class BufferConsumer { - public: - BufferConsumer(BufferT* buffer, size_t size) : mBuffer(buffer), mSize(size) {} - - BufferT* Buffer() const { return mBuffer; } - size_t AvailableSize() const { return mSize; } - - protected: - template - DAWN_NO_DISCARD bool NextN(N count, T** data); - - template - DAWN_NO_DISCARD bool Next(T** data); - - template - DAWN_NO_DISCARD bool Peek(T** data); - - private: - BufferT* mBuffer; - size_t mSize; - }; - - class SerializeBuffer : public BufferConsumer { - public: - using BufferConsumer::BufferConsumer; - using BufferConsumer::NextN; - using BufferConsumer::Next; - }; - - class DeserializeBuffer : public BufferConsumer { - public: - using BufferConsumer::BufferConsumer; - - template - DAWN_NO_DISCARD DeserializeResult ReadN(N count, const volatile T** data) { - return NextN(count, data) - ? DeserializeResult::Success - : DeserializeResult::FatalError; - } - - template - DAWN_NO_DISCARD DeserializeResult Read(const volatile T** data) { - return Next(data) - ? DeserializeResult::Success - : DeserializeResult::FatalError; - } - - template - DAWN_NO_DISCARD DeserializeResult Peek(const volatile T** data) { - return BufferConsumer::Peek(data) - ? DeserializeResult::Success - : DeserializeResult::FatalError; - } - }; - // Interface to allocate more space to deserialize pointed-to data. // nullptr is treated as an error. class DeserializeAllocator { @@ -116,8 +57,8 @@ namespace dawn_wire { class ObjectIdResolver { public: {% for type in by_category["object"] %} - virtual DeserializeResult GetFromId(ObjectId id, {{as_cType(type.name)}}* out) const = 0; - virtual DeserializeResult GetOptionalFromId(ObjectId id, {{as_cType(type.name)}}* out) const = 0; + virtual WireResult GetFromId(ObjectId id, {{as_cType(type.name)}}* out) const = 0; + virtual WireResult GetOptionalFromId(ObjectId id, {{as_cType(type.name)}}* out) const = 0; {% endfor %} }; @@ -157,7 +98,7 @@ namespace dawn_wire { //* Serialize the structure and everything it points to into serializeBuffer which must be //* big enough to contain all the data (as queried from GetRequiredSize). - DAWN_NO_DISCARD bool Serialize(size_t commandSize, SerializeBuffer* serializeBuffer + WireResult Serialize(size_t commandSize, SerializeBuffer* serializeBuffer {%- if not is_return_command -%} , const ObjectIdProvider& objectIdProvider {%- endif -%} @@ -170,7 +111,7 @@ namespace dawn_wire { //* Deserialize returns: //* - Success if everything went well (yay!) //* - FatalError is something bad happened (buffer too small for example) - DeserializeResult Deserialize(DeserializeBuffer* deserializeBuffer, DeserializeAllocator* allocator + WireResult Deserialize(DeserializeBuffer* deserializeBuffer, DeserializeAllocator* allocator {%- if command.may_have_dawn_object -%} , const ObjectIdResolver& resolver {%- endif -%} diff --git a/generator/templates/dawn_wire/client/ClientHandlers.cpp b/generator/templates/dawn_wire/client/ClientHandlers.cpp index de1ca3c619..13ac79c13b 100644 --- a/generator/templates/dawn_wire/client/ClientHandlers.cpp +++ b/generator/templates/dawn_wire/client/ClientHandlers.cpp @@ -21,9 +21,9 @@ namespace dawn_wire { namespace client { {% for command in cmd_records["return command"] %} bool Client::Handle{{command.name.CamelCase()}}(DeserializeBuffer* deserializeBuffer) { Return{{command.name.CamelCase()}}Cmd cmd; - DeserializeResult deserializeResult = cmd.Deserialize(deserializeBuffer, &mAllocator); + WireResult deserializeResult = cmd.Deserialize(deserializeBuffer, &mAllocator); - if (deserializeResult == DeserializeResult::FatalError) { + if (deserializeResult == WireResult::FatalError) { return false; } diff --git a/generator/templates/dawn_wire/server/ServerBase.h b/generator/templates/dawn_wire/server/ServerBase.h index 66193a477e..eb0aab8c3b 100644 --- a/generator/templates/dawn_wire/server/ServerBase.h +++ b/generator/templates/dawn_wire/server/ServerBase.h @@ -70,20 +70,20 @@ namespace dawn_wire { namespace server { private: // Implementation of the ObjectIdResolver interface {% for type in by_category["object"] %} - DeserializeResult GetFromId(ObjectId id, {{as_cType(type.name)}}* out) const final { + WireResult GetFromId(ObjectId id, {{as_cType(type.name)}}* out) const final { auto data = mKnown{{type.name.CamelCase()}}.Get(id); if (data == nullptr) { - return DeserializeResult::FatalError; + return WireResult::FatalError; } *out = data->handle; - return DeserializeResult::Success; + return WireResult::Success; } - DeserializeResult GetOptionalFromId(ObjectId id, {{as_cType(type.name)}}* out) const final { + WireResult GetOptionalFromId(ObjectId id, {{as_cType(type.name)}}* out) const final { if (id == 0) { *out = nullptr; - return DeserializeResult::Success; + return WireResult::Success; } return GetFromId(id, out); diff --git a/generator/templates/dawn_wire/server/ServerHandlers.cpp b/generator/templates/dawn_wire/server/ServerHandlers.cpp index a544a505d1..ea3da6cae3 100644 --- a/generator/templates/dawn_wire/server/ServerHandlers.cpp +++ b/generator/templates/dawn_wire/server/ServerHandlers.cpp @@ -25,13 +25,13 @@ namespace dawn_wire { namespace server { //* The generic command handlers bool Server::Handle{{Suffix}}(DeserializeBuffer* deserializeBuffer) { {{Suffix}}Cmd cmd; - DeserializeResult deserializeResult = cmd.Deserialize(deserializeBuffer, &mAllocator + WireResult deserializeResult = cmd.Deserialize(deserializeBuffer, &mAllocator {%- if command.may_have_dawn_object -%} , *this {%- endif -%} ); - if (deserializeResult == DeserializeResult::FatalError) { + if (deserializeResult == WireResult::FatalError) { return false; } diff --git a/src/dawn_wire/BUILD.gn b/src/dawn_wire/BUILD.gn index 4b212996ce..ad72394bb5 100644 --- a/src/dawn_wire/BUILD.gn +++ b/src/dawn_wire/BUILD.gn @@ -59,6 +59,8 @@ dawn_component("dawn_wire") { configs = [ "${dawn_root}/src/common:dawn_internal" ] sources = get_target_outputs(":dawn_wire_gen") sources += [ + "BufferConsumer.h", + "BufferConsumer_impl.h", "ChunkedCommandHandler.cpp", "ChunkedCommandHandler.h", "ChunkedCommandSerializer.cpp", @@ -67,6 +69,7 @@ dawn_component("dawn_wire") { "WireClient.cpp", "WireDeserializeAllocator.cpp", "WireDeserializeAllocator.h", + "WireResult.h", "WireServer.cpp", "client/ApiObjects.h", "client/Buffer.cpp", diff --git a/src/dawn_wire/BufferConsumer.h b/src/dawn_wire/BufferConsumer.h new file mode 100644 index 0000000000..3797bf40c8 --- /dev/null +++ b/src/dawn_wire/BufferConsumer.h @@ -0,0 +1,85 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DAWNWIRE_BUFFERCONSUMER_H_ +#define DAWNWIRE_BUFFERCONSUMER_H_ + +#include "dawn_wire/WireResult.h" + +#include + +namespace dawn_wire { + + // BufferConsumer is a utility class that allows reading bytes from a buffer + // while simultaneously decrementing the amount of remaining space by exactly + // the amount read. It helps prevent bugs where incrementing a pointer and + // decrementing a size value are not kept in sync. + // BufferConsumer also contains bounds checks to prevent reading out-of-bounds. + template + class BufferConsumer { + static_assert(sizeof(BufferT) == 1, + "BufferT must be 1-byte, but may have const/volatile qualifiers."); + + public: + BufferConsumer(BufferT* buffer, size_t size) : mBuffer(buffer), mSize(size) { + } + + BufferT* Buffer() const { + return mBuffer; + } + size_t AvailableSize() const { + return mSize; + } + + protected: + template + WireResult NextN(N count, T** data); + + template + WireResult Next(T** data); + + template + WireResult Peek(T** data); + + private: + BufferT* mBuffer; + size_t mSize; + }; + + class SerializeBuffer : public BufferConsumer { + public: + using BufferConsumer::BufferConsumer; + using BufferConsumer::Next; + using BufferConsumer::NextN; + }; + + class DeserializeBuffer : public BufferConsumer { + public: + using BufferConsumer::BufferConsumer; + using BufferConsumer::Peek; + + template + WireResult ReadN(N count, const volatile T** data) { + return NextN(count, data); + } + + template + WireResult Read(const volatile T** data) { + return Next(data); + } + }; + +} // namespace dawn_wire + +#endif // DAWNWIRE_BUFFERCONSUMER_H_ \ No newline at end of file diff --git a/src/dawn_wire/BufferConsumer_impl.h b/src/dawn_wire/BufferConsumer_impl.h new file mode 100644 index 0000000000..f815dec0cd --- /dev/null +++ b/src/dawn_wire/BufferConsumer_impl.h @@ -0,0 +1,72 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DAWNWIRE_BUFFERCONSUMER_IMPL_H_ +#define DAWNWIRE_BUFFERCONSUMER_IMPL_H_ + +#include "dawn_wire/BufferConsumer.h" + +#include + +namespace dawn_wire { + + template + template + WireResult BufferConsumer::Peek(T** data) { + if (sizeof(T) > mSize) { + return WireResult::FatalError; + } + + *data = reinterpret_cast(mBuffer); + return WireResult::Success; + } + + template + template + WireResult BufferConsumer::Next(T** data) { + if (sizeof(T) > mSize) { + return WireResult::FatalError; + } + + *data = reinterpret_cast(mBuffer); + mBuffer += sizeof(T); + mSize -= sizeof(T); + return WireResult::Success; + } + + template + template + WireResult BufferConsumer::NextN(N count, T** data) { + static_assert(std::is_unsigned::value, "|count| argument of NextN must be unsigned."); + + constexpr size_t kMaxCountWithoutOverflows = std::numeric_limits::max() / sizeof(T); + if (count > kMaxCountWithoutOverflows) { + return WireResult::FatalError; + } + + // Cannot overflow because |count| is not greater than |kMaxCountWithoutOverflows|. + size_t totalSize = sizeof(T) * count; + if (totalSize > mSize) { + return WireResult::FatalError; + } + + *data = reinterpret_cast(mBuffer); + mBuffer += totalSize; + mSize -= totalSize; + return WireResult::Success; + } + +} // namespace dawn_wire + +#endif // DAWNWIRE_BUFFERCONSUMER_IMPL_H_ \ No newline at end of file diff --git a/src/dawn_wire/CMakeLists.txt b/src/dawn_wire/CMakeLists.txt index d6d430f588..77e96c0b69 100644 --- a/src/dawn_wire/CMakeLists.txt +++ b/src/dawn_wire/CMakeLists.txt @@ -31,6 +31,8 @@ target_sources(dawn_wire PRIVATE "${DAWN_INCLUDE_DIR}/dawn_wire/WireServer.h" "${DAWN_INCLUDE_DIR}/dawn_wire/dawn_wire_export.h" ${DAWN_WIRE_GEN_SOURCES} + "BufferConsumer.h" + "BufferConsumer_impl.h" "ChunkedCommandHandler.cpp" "ChunkedCommandHandler.h" "ChunkedCommandSerializer.cpp" @@ -39,6 +41,7 @@ target_sources(dawn_wire PRIVATE "WireClient.cpp" "WireDeserializeAllocator.cpp" "WireDeserializeAllocator.h" + "WireResult.h" "WireServer.cpp" "client/ApiObjects.h" "client/Buffer.cpp" diff --git a/src/dawn_wire/ChunkedCommandSerializer.h b/src/dawn_wire/ChunkedCommandSerializer.h index e62cb99d79..2465f8153d 100644 --- a/src/dawn_wire/ChunkedCommandSerializer.h +++ b/src/dawn_wire/ChunkedCommandSerializer.h @@ -32,7 +32,7 @@ namespace dawn_wire { template void SerializeCommand(const Cmd& cmd) { - SerializeCommand(cmd, 0, [](SerializeBuffer*) { return true; }); + SerializeCommand(cmd, 0, [](SerializeBuffer*) { return WireResult::Success; }); } template @@ -49,7 +49,8 @@ namespace dawn_wire { template void SerializeCommand(const Cmd& cmd, const ObjectIdProvider& objectIdProvider) { - SerializeCommand(cmd, objectIdProvider, 0, [](SerializeBuffer*) { return true; }); + SerializeCommand(cmd, objectIdProvider, 0, + [](SerializeBuffer*) { return WireResult::Success; }); } template @@ -79,10 +80,9 @@ namespace dawn_wire { char* allocatedBuffer = static_cast(mSerializer->GetCmdSpace(requiredSize)); if (allocatedBuffer != nullptr) { SerializeBuffer serializeBuffer(allocatedBuffer, requiredSize); - bool success = true; - success &= SerializeCmd(cmd, requiredSize, &serializeBuffer); - success &= SerializeExtraSize(&serializeBuffer); - if (DAWN_UNLIKELY(!success)) { + WireResult r1 = SerializeCmd(cmd, requiredSize, &serializeBuffer); + WireResult r2 = SerializeExtraSize(&serializeBuffer); + if (DAWN_UNLIKELY(r1 != WireResult::Success || r2 != WireResult::Success)) { mSerializer->OnSerializeError(); } } @@ -94,10 +94,9 @@ namespace dawn_wire { return; } SerializeBuffer serializeBuffer(cmdSpace.get(), requiredSize); - bool success = true; - success &= SerializeCmd(cmd, requiredSize, &serializeBuffer); - success &= SerializeExtraSize(&serializeBuffer); - if (DAWN_UNLIKELY(!success)) { + WireResult r1 = SerializeCmd(cmd, requiredSize, &serializeBuffer); + WireResult r2 = SerializeExtraSize(&serializeBuffer); + if (DAWN_UNLIKELY(r1 != WireResult::Success || r2 != WireResult::Success)) { mSerializer->OnSerializeError(); return; } diff --git a/src/dawn_wire/WireResult.h b/src/dawn_wire/WireResult.h new file mode 100644 index 0000000000..026a98e3c8 --- /dev/null +++ b/src/dawn_wire/WireResult.h @@ -0,0 +1,38 @@ +// Copyright 2021 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DAWNWIRE_WIRERESULT_H_ +#define DAWNWIRE_WIRERESULT_H_ + +#include "common/Compiler.h" + +namespace dawn_wire { + + enum DAWN_NO_DISCARD class WireResult { + Success, + FatalError, + }; + +// Macro to simplify error handling, similar to DAWN_TRY but for WireResult. +#define WIRE_TRY(EXPR) \ + do { \ + WireResult exprResult = EXPR; \ + if (DAWN_UNLIKELY(exprResult != WireResult::Success)) { \ + return exprResult; \ + } \ + } while (0) + +} // namespace dawn_wire + +#endif // DAWNWIRE_WIRERESULT_H_ \ No newline at end of file diff --git a/src/dawn_wire/client/Buffer.cpp b/src/dawn_wire/client/Buffer.cpp index 4f4598423f..3c7519c6ed 100644 --- a/src/dawn_wire/client/Buffer.cpp +++ b/src/dawn_wire/client/Buffer.cpp @@ -14,6 +14,7 @@ #include "dawn_wire/client/Buffer.h" +#include "dawn_wire/BufferConsumer_impl.h" #include "dawn_wire/WireCmd_autogen.h" #include "dawn_wire/client/Client.h" #include "dawn_wire/client/Device.h" @@ -77,11 +78,11 @@ namespace dawn_wire { namespace client { wireClient->SerializeCommand( cmd, writeHandleCreateInfoLength, [&](SerializeBuffer* serializeBuffer) { if (descriptor->mappedAtCreation) { - if (serializeBuffer->AvailableSize() != writeHandleCreateInfoLength) { - return false; - } + char* writeHandleBuffer; + WIRE_TRY( + serializeBuffer->NextN(writeHandleCreateInfoLength, &writeHandleBuffer)); // Serialize the WriteHandle into the space after the command. - writeHandle->SerializeCreate(serializeBuffer->Buffer()); + writeHandle->SerializeCreate(writeHandleBuffer); // Set the buffer state for the mapping at creation. The buffer now owns the // write handle.. @@ -90,7 +91,7 @@ namespace dawn_wire { namespace client { buffer->mMapOffset = 0; buffer->mMapSize = buffer->mSize; } - return true; + return WireResult::Success; }); return ToAPI(buffer); } @@ -207,22 +208,21 @@ namespace dawn_wire { namespace client { cmd.handleCreateInfoLength = request.readHandle->SerializeCreateSize(); client->SerializeCommand( cmd, cmd.handleCreateInfoLength, [&](SerializeBuffer* serializeBuffer) { - bool success = serializeBuffer->AvailableSize() == cmd.handleCreateInfoLength; - if (success) { - request.readHandle->SerializeCreate(serializeBuffer->Buffer()); - } - return success; + char* readHandleBuffer; + WIRE_TRY(serializeBuffer->NextN(cmd.handleCreateInfoLength, &readHandleBuffer)); + request.readHandle->SerializeCreate(readHandleBuffer); + return WireResult::Success; }); } else { ASSERT(isWriteMode); cmd.handleCreateInfoLength = request.writeHandle->SerializeCreateSize(); client->SerializeCommand( cmd, cmd.handleCreateInfoLength, [&](SerializeBuffer* serializeBuffer) { - bool success = serializeBuffer->AvailableSize() == cmd.handleCreateInfoLength; - if (success) { - request.writeHandle->SerializeCreate(serializeBuffer->Buffer()); - } - return success; + char* writeHandleBuffer; + WIRE_TRY( + serializeBuffer->NextN(cmd.handleCreateInfoLength, &writeHandleBuffer)); + request.writeHandle->SerializeCreate(writeHandleBuffer); + return WireResult::Success; }); } @@ -352,13 +352,13 @@ namespace dawn_wire { namespace client { client->SerializeCommand( cmd, writeFlushInfoLength, [&](SerializeBuffer* serializeBuffer) { - bool success = serializeBuffer->AvailableSize() == writeFlushInfoLength; - if (success) { - // Serialize flush metadata into the space after the command. - // This closes the handle for writing. - mWriteHandle->SerializeFlush(serializeBuffer->Buffer()); - } - return success; + char* writeHandleBuffer; + WIRE_TRY(serializeBuffer->NextN(writeFlushInfoLength, &writeHandleBuffer)); + + // Serialize flush metadata into the space after the command. + // This closes the handle for writing. + mWriteHandle->SerializeFlush(writeHandleBuffer); + return WireResult::Success; }); mWriteHandle = nullptr; diff --git a/src/dawn_wire/server/ServerBuffer.cpp b/src/dawn_wire/server/ServerBuffer.cpp index 8aaf4ffda7..b9fda44317 100644 --- a/src/dawn_wire/server/ServerBuffer.cpp +++ b/src/dawn_wire/server/ServerBuffer.cpp @@ -13,6 +13,7 @@ // limitations under the License. #include "common/Assert.h" +#include "dawn_wire/BufferConsumer_impl.h" #include "dawn_wire/WireCmd_autogen.h" #include "dawn_wire/server/Server.h" @@ -251,12 +252,12 @@ namespace dawn_wire { namespace server { SerializeCommand(cmd, cmd.readInitialDataInfoLength, [&](SerializeBuffer* serializeBuffer) { if (isSuccess) { if (isRead) { - if (serializeBuffer->AvailableSize() != cmd.readInitialDataInfoLength) { - return false; - } + char* readHandleBuffer; + WIRE_TRY( + serializeBuffer->NextN(cmd.readInitialDataInfoLength, &readHandleBuffer)); + // Serialize the initialization message into the space after the command. - data->readHandle->SerializeInitialData(readData, data->size, - serializeBuffer->Buffer()); + data->readHandle->SerializeInitialData(readData, data->size, readHandleBuffer); // The in-flight map request returned successfully. // Move the ReadHandle so it is owned by the buffer. bufferData->readHandle = std::move(data->readHandle); @@ -271,7 +272,7 @@ namespace dawn_wire { namespace server { data->size); } } - return true; + return WireResult::Success; }); }