Add wire serialized command buffer padding.

Pads serialized wire command buffers to 8 bytes so that we don't have
misaligned write/reads which can cause SIGILL depending on platform and
compilation mode, i.e. -c dbg in google3 builds.

- Adds helpers for aligning sizeof calls.
- Adds constant for wire padding (8u).
- Modifies BufferConsumer to allocate according to padding. This
  guarantees that when we [de]serialize stuff, the padding should be
  equal on both sides.
- Modifies extra byte serialization code (adding CommandExtension
  struct). This makes it clearer that each extension needs to be
  padded independently. Otherwise, before in wire/client/Buffer.cpp,
  since the read/write handle sizes were being passed as a sum, but
  read out separately from the BufferConsumer, we corrupt our pointers.
- Adds some simple unit tests.

Bug: dawn:1334
Change-Id: Id80e7c01a34b9f01c3f02b3e6c04c3bb3ad0eff9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/110501
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Loko Kung <lokokung@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Loko Kung 2022-11-22 23:19:43 +00:00 committed by Dawn LUCI CQ
parent c982cd45c4
commit 2e1b359087
11 changed files with 208 additions and 105 deletions

View File

@ -165,7 +165,7 @@
if (has_{{memberName}}) if (has_{{memberName}})
{% endif %} {% endif %}
{ {
result += std::strlen(record.{{memberName}}); result += Align(std::strlen(record.{{memberName}}), kWireBufferAlignment);
} }
{% endfor %} {% endfor %}
@ -178,7 +178,9 @@
{% if member.annotation != "value" %} {% if member.annotation != "value" %}
{{ assert(member.annotation != "const*const*") }} {{ assert(member.annotation != "const*const*") }}
auto memberLength = {{member_length(member, "record.")}}; auto memberLength = {{member_length(member, "record.")}};
result += memberLength * {{member_transfer_sizeof(member)}}; auto size = WireAlignSizeofN<{{member_transfer_type(member)}}>(memberLength);
ASSERT(size);
result += *size;
//* Structures might contain more pointers so we need to add their extra size as well. //* Structures might contain more pointers so we need to add their extra size as well.
{% if member.type.category == "structure" %} {% if member.type.category == "structure" %}
for (decltype(memberLength) i = 0; i < memberLength; ++i) { for (decltype(memberLength) i = 0; i < memberLength; ++i) {
@ -431,7 +433,7 @@
{% set Cmd = Name + "Cmd" %} {% set Cmd = Name + "Cmd" %}
size_t {{Cmd}}::GetRequiredSize() const { size_t {{Cmd}}::GetRequiredSize() const {
size_t size = sizeof({{Name}}Transfer) + {{Name}}GetExtraRequiredSize(*this); size_t size = WireAlignSizeof<{{Name}}Transfer>() + {{Name}}GetExtraRequiredSize(*this);
return size; return size;
} }
@ -509,7 +511,7 @@
) %} ) %}
case {{as_cEnum(types["s type"].name, sType.name)}}: { case {{as_cEnum(types["s type"].name, sType.name)}}: {
const auto& typedStruct = *reinterpret_cast<{{as_cType(sType.name)}} const *>(chainedStruct); const auto& typedStruct = *reinterpret_cast<{{as_cType(sType.name)}} const *>(chainedStruct);
result += sizeof({{as_cType(sType.name)}}Transfer); result += WireAlignSizeof<{{as_cType(sType.name)}}Transfer>();
result += {{as_cType(sType.name)}}GetExtraRequiredSize(typedStruct); result += {{as_cType(sType.name)}}GetExtraRequiredSize(typedStruct);
chainedStruct = typedStruct.chain.next; chainedStruct = typedStruct.chain.next;
break; break;
@ -519,7 +521,7 @@
case WGPUSType_Invalid: case WGPUSType_Invalid:
default: default:
// Invalid enum. Reserve space just for the transfer header (sType and hasNext). // Invalid enum. Reserve space just for the transfer header (sType and hasNext).
result += sizeof(WGPUChainedStructTransfer); result += WireAlignSizeof<WGPUChainedStructTransfer>();
chainedStruct = chainedStruct->next; chainedStruct = chainedStruct->next;
break; break;
} }
@ -600,7 +602,7 @@
WIRE_TRY(deserializeBuffer->Read(&transfer)); WIRE_TRY(deserializeBuffer->Read(&transfer));
{{CType}}* outStruct; {{CType}}* outStruct;
WIRE_TRY(GetSpace(allocator, sizeof({{CType}}), &outStruct)); WIRE_TRY(GetSpace(allocator, 1u, &outStruct));
outStruct->chain.sType = sType; outStruct->chain.sType = sType;
outStruct->chain.next = nullptr; outStruct->chain.next = nullptr;
@ -629,7 +631,7 @@
WIRE_TRY(deserializeBuffer->Read(&transfer)); WIRE_TRY(deserializeBuffer->Read(&transfer));
{{ChainedStruct}}* outStruct; {{ChainedStruct}}* outStruct;
WIRE_TRY(GetSpace(allocator, sizeof({{ChainedStruct}}), &outStruct)); WIRE_TRY(GetSpace(allocator, 1u, &outStruct));
outStruct->sType = WGPUSType_Invalid; outStruct->sType = WGPUSType_Invalid;
outStruct->next = nullptr; outStruct->next = nullptr;
@ -654,13 +656,23 @@ namespace dawn::wire {
// Always writes to |out| on success. // Always writes to |out| on success.
template <typename T, typename N> template <typename T, typename N>
WireResult GetSpace(DeserializeAllocator* allocator, N count, T** out) { WireResult GetSpace(DeserializeAllocator* allocator, N count, T** out) {
constexpr size_t kMaxCountWithoutOverflows = std::numeric_limits<size_t>::max() / sizeof(T); // Because we use this function extensively when `count` == 1, we can optimize the
if (count > kMaxCountWithoutOverflows) { // size computations a bit more for those cases via constexpr version of the
// alignment computation.
constexpr size_t kSizeofT = WireAlignSizeof<T>();
size_t size = 0;
if (count == 1) {
size = kSizeofT;
} else {
auto sizeN = WireAlignSizeofN<T>(count);
// A size of 0 indicates an overflow, so return an error.
if (!sizeN) {
return WireResult::FatalError; return WireResult::FatalError;
} }
size = *sizeN;
}
size_t totalSize = sizeof(T) * count; *out = static_cast<T*>(allocator->GetSpace(size));
*out = static_cast<T*>(allocator->GetSpace(totalSize));
if (*out == nullptr) { if (*out == nullptr) {
return WireResult::FatalError; return WireResult::FatalError;
} }

View File

@ -15,6 +15,7 @@
#ifndef SRC_DAWN_COMMON_CONSTANTS_H_ #ifndef SRC_DAWN_COMMON_CONSTANTS_H_
#define SRC_DAWN_COMMON_CONSTANTS_H_ #define SRC_DAWN_COMMON_CONSTANTS_H_
#include <cstddef>
#include <cstdint> #include <cstdint>
static constexpr uint32_t kMaxBindGroups = 4u; static constexpr uint32_t kMaxBindGroups = 4u;
@ -65,4 +66,7 @@ static constexpr uint8_t kSampledTexturesPerExternalTexture = 4u;
static constexpr uint8_t kSamplersPerExternalTexture = 1u; static constexpr uint8_t kSamplersPerExternalTexture = 1u;
static constexpr uint8_t kUniformsPerExternalTexture = 1u; static constexpr uint8_t kUniformsPerExternalTexture = 1u;
// Wire buffer alignments.
static constexpr size_t kWireBufferAlignment = 8u;
#endif // SRC_DAWN_COMMON_CONSTANTS_H_ #endif // SRC_DAWN_COMMON_CONSTANTS_H_

View File

@ -20,6 +20,7 @@
#include <cstring> #include <cstring>
#include <limits> #include <limits>
#include <optional>
#include <type_traits> #include <type_traits>
#include "dawn/common/Assert.h" #include "dawn/common/Assert.h"
@ -61,6 +62,26 @@ T Align(T value, size_t alignment) {
return (value + (alignmentT - 1)) & ~(alignmentT - 1); return (value + (alignmentT - 1)) & ~(alignmentT - 1);
} }
template <typename T, size_t Alignment>
constexpr size_t AlignSizeof() {
static_assert(Alignment != 0 && (Alignment & (Alignment - 1)) == 0,
"Alignment must be a valid power of 2.");
static_assert(sizeof(T) <= std::numeric_limits<size_t>::max() - (Alignment - 1));
return (sizeof(T) + (Alignment - 1)) & ~(Alignment - 1);
}
// Returns an aligned size for an n-sized array of T elements. If the size would overflow, returns
// nullopt instead.
template <typename T, size_t Alignment>
std::optional<size_t> AlignSizeofN(uint64_t n) {
constexpr uint64_t kMaxCountWithoutOverflows =
(std::numeric_limits<size_t>::max() - Alignment + 1) / sizeof(T);
if (n > kMaxCountWithoutOverflows) {
return std::nullopt;
}
return Align(sizeof(T) * n, Alignment);
}
template <typename T> template <typename T>
DAWN_FORCE_INLINE T* AlignPtr(T* ptr, size_t alignment) { DAWN_FORCE_INLINE T* AlignPtr(T* ptr, size_t alignment) {
ASSERT(IsPowerOfTwo(alignment)); ASSERT(IsPowerOfTwo(alignment));

View File

@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include <cmath> #include <cmath>
#include <limits>
#include <vector> #include <vector>
#include "dawn/EnumClassBitmasks.h" #include "dawn/EnumClassBitmasks.h"
@ -180,6 +181,41 @@ TEST(Math, Align) {
ASSERT_EQ(Align(static_cast<uint64_t>(0xFFFFFFFFFFFFFFFF), 1), 0xFFFFFFFFFFFFFFFFull); ASSERT_EQ(Align(static_cast<uint64_t>(0xFFFFFFFFFFFFFFFF), 1), 0xFFFFFFFFFFFFFFFFull);
} }
TEST(Math, AlignSizeof) {
// Basic types should align to self if alignment is a divisor.
ASSERT_EQ((AlignSizeof<uint8_t, 1>()), 1u);
ASSERT_EQ((AlignSizeof<uint16_t, 1>()), 2u);
ASSERT_EQ((AlignSizeof<uint16_t, 2>()), 2u);
ASSERT_EQ((AlignSizeof<uint32_t, 1>()), 4u);
ASSERT_EQ((AlignSizeof<uint32_t, 2>()), 4u);
ASSERT_EQ((AlignSizeof<uint32_t, 4>()), 4u);
ASSERT_EQ((AlignSizeof<uint64_t, 1>()), 8u);
ASSERT_EQ((AlignSizeof<uint64_t, 2>()), 8u);
ASSERT_EQ((AlignSizeof<uint64_t, 4>()), 8u);
ASSERT_EQ((AlignSizeof<uint64_t, 8>()), 8u);
// Everything in range (align, 2*align] aligns to 2*align.
ASSERT_EQ((AlignSizeof<char[5], 4>()), 8u);
ASSERT_EQ((AlignSizeof<char[6], 4>()), 8u);
ASSERT_EQ((AlignSizeof<char[7], 4>()), 8u);
ASSERT_EQ((AlignSizeof<char[8], 4>()), 8u);
}
TEST(Math, AlignSizeofN) {
// Everything in range (align, 2*align] aligns to 2*align.
ASSERT_EQ(*(AlignSizeofN<char, 4>(5)), 8u);
ASSERT_EQ(*(AlignSizeofN<char, 4>(6)), 8u);
ASSERT_EQ(*(AlignSizeofN<char, 4>(7)), 8u);
ASSERT_EQ(*(AlignSizeofN<char, 4>(8)), 8u);
// Extremes should return nullopt.
ASSERT_EQ((AlignSizeofN<char, 4>(std::numeric_limits<size_t>::max())), std::nullopt);
ASSERT_EQ((AlignSizeofN<char, 4>(std::numeric_limits<uint64_t>::max())), std::nullopt);
}
// Tests for IsPtrAligned // Tests for IsPtrAligned
TEST(Math, IsPtrAligned) { TEST(Math, IsPtrAligned) {
constexpr size_t kTestAlignment = 8; constexpr size_t kTestAlignment = 8;

View File

@ -17,10 +17,22 @@
#include <cstddef> #include <cstddef>
#include "dawn/common/Constants.h"
#include "dawn/common/Math.h"
#include "dawn/wire/WireResult.h" #include "dawn/wire/WireResult.h"
namespace dawn::wire { namespace dawn::wire {
// Wire specific alignment helpers.
template <typename T>
constexpr size_t WireAlignSizeof() {
return AlignSizeof<T, kWireBufferAlignment>();
}
template <typename T>
std::optional<size_t> WireAlignSizeofN(size_t n) {
return AlignSizeofN<T, kWireBufferAlignment>(n);
}
// BufferConsumer is a utility class that allows reading bytes from a buffer // BufferConsumer is a utility class that allows reading bytes from a buffer
// while simultaneously decrementing the amount of remaining space by exactly // while simultaneously decrementing the amount of remaining space by exactly
// the amount read. It helps prevent bugs where incrementing a pointer and // the amount read. It helps prevent bugs where incrementing a pointer and

View File

@ -15,11 +15,11 @@
#ifndef SRC_DAWN_WIRE_BUFFERCONSUMER_IMPL_H_ #ifndef SRC_DAWN_WIRE_BUFFERCONSUMER_IMPL_H_
#define SRC_DAWN_WIRE_BUFFERCONSUMER_IMPL_H_ #define SRC_DAWN_WIRE_BUFFERCONSUMER_IMPL_H_
#include "dawn/wire/BufferConsumer.h"
#include <limits> #include <limits>
#include <type_traits> #include <type_traits>
#include "dawn/wire/BufferConsumer.h"
namespace dawn::wire { namespace dawn::wire {
template <typename BufferT> template <typename BufferT>
@ -36,13 +36,14 @@ WireResult BufferConsumer<BufferT>::Peek(T** data) {
template <typename BufferT> template <typename BufferT>
template <typename T> template <typename T>
WireResult BufferConsumer<BufferT>::Next(T** data) { WireResult BufferConsumer<BufferT>::Next(T** data) {
if (sizeof(T) > mSize) { constexpr size_t kSize = WireAlignSizeof<T>();
if (kSize > mSize) {
return WireResult::FatalError; return WireResult::FatalError;
} }
*data = reinterpret_cast<T*>(mBuffer); *data = reinterpret_cast<T*>(mBuffer);
mBuffer += sizeof(T); mBuffer += kSize;
mSize -= sizeof(T); mSize -= kSize;
return WireResult::Success; return WireResult::Success;
} }
@ -51,20 +52,15 @@ template <typename T, typename N>
WireResult BufferConsumer<BufferT>::NextN(N count, T** data) { WireResult BufferConsumer<BufferT>::NextN(N count, T** data) {
static_assert(std::is_unsigned<N>::value, "|count| argument of NextN must be unsigned."); static_assert(std::is_unsigned<N>::value, "|count| argument of NextN must be unsigned.");
constexpr size_t kMaxCountWithoutOverflows = std::numeric_limits<size_t>::max() / sizeof(T); // If size is zero then it indicates an overflow.
if (count > kMaxCountWithoutOverflows) { auto size = WireAlignSizeofN<T>(count);
return WireResult::FatalError; if (size && *size > mSize) {
}
// Cannot overflow because |count| is not greater than |kMaxCountWithoutOverflows|.
size_t totalSize = sizeof(T) * count;
if (totalSize > mSize) {
return WireResult::FatalError; return WireResult::FatalError;
} }
*data = reinterpret_cast<T*>(mBuffer); *data = reinterpret_cast<T*>(mBuffer);
mBuffer += totalSize; mBuffer += *size;
mSize -= totalSize; mSize -= *size;
return WireResult::Success; return WireResult::Success;
} }

View File

@ -17,73 +17,97 @@
#include <algorithm> #include <algorithm>
#include <cstring> #include <cstring>
#include <functional>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include "dawn/common/Alloc.h" #include "dawn/common/Alloc.h"
#include "dawn/common/Compiler.h" #include "dawn/common/Compiler.h"
#include "dawn/common/Constants.h"
#include "dawn/common/Math.h"
#include "dawn/wire/Wire.h" #include "dawn/wire/Wire.h"
#include "dawn/wire/WireCmd_autogen.h" #include "dawn/wire/WireCmd_autogen.h"
namespace dawn::wire { namespace dawn::wire {
// Simple command extension struct used when a command needs to serialize additional information
// that is not baked directly into the command already.
struct CommandExtension {
size_t size;
std::function<void(char*)> serialize;
};
namespace detail {
inline WireResult SerializeCommandExtension(SerializeBuffer* serializeBuffer) {
return WireResult::Success;
}
template <typename Extension, typename... Extensions>
WireResult SerializeCommandExtension(SerializeBuffer* serializeBuffer,
Extension&& e,
Extensions&&... es) {
char* buffer;
WIRE_TRY(serializeBuffer->NextN(e.size, &buffer));
e.serialize(buffer);
WIRE_TRY(SerializeCommandExtension(serializeBuffer, std::forward<Extensions>(es)...));
return WireResult::Success;
}
} // namespace detail
class ChunkedCommandSerializer { class ChunkedCommandSerializer {
public: public:
explicit ChunkedCommandSerializer(CommandSerializer* serializer); explicit ChunkedCommandSerializer(CommandSerializer* serializer);
template <typename Cmd> template <typename Cmd>
void SerializeCommand(const Cmd& cmd) { void SerializeCommand(const Cmd& cmd) {
SerializeCommand(cmd, 0, [](SerializeBuffer*) { return WireResult::Success; }); SerializeCommandImpl(
cmd, [](const Cmd& cmd, size_t requiredSize, SerializeBuffer* serializeBuffer) {
return cmd.Serialize(requiredSize, serializeBuffer);
});
} }
template <typename Cmd, typename ExtraSizeSerializeFn> template <typename Cmd, typename... Extensions>
void SerializeCommand(const Cmd& cmd, void SerializeCommand(const Cmd& cmd, CommandExtension&& e, Extensions&&... es) {
size_t extraSize,
ExtraSizeSerializeFn&& SerializeExtraSize) {
SerializeCommandImpl( SerializeCommandImpl(
cmd, cmd,
[](const Cmd& cmd, size_t requiredSize, SerializeBuffer* serializeBuffer) { [](const Cmd& cmd, size_t requiredSize, SerializeBuffer* serializeBuffer) {
return cmd.Serialize(requiredSize, serializeBuffer); return cmd.Serialize(requiredSize, serializeBuffer);
}, },
extraSize, std::forward<ExtraSizeSerializeFn>(SerializeExtraSize)); std::forward<CommandExtension>(e), std::forward<Extensions>(es)...);
} }
template <typename Cmd> template <typename Cmd, typename... Extensions>
void SerializeCommand(const Cmd& cmd, const ObjectIdProvider& objectIdProvider) {
SerializeCommand(cmd, objectIdProvider, 0,
[](SerializeBuffer*) { return WireResult::Success; });
}
template <typename Cmd, typename ExtraSizeSerializeFn>
void SerializeCommand(const Cmd& cmd, void SerializeCommand(const Cmd& cmd,
const ObjectIdProvider& objectIdProvider, const ObjectIdProvider& objectIdProvider,
size_t extraSize, Extensions&&... extensions) {
ExtraSizeSerializeFn&& SerializeExtraSize) {
SerializeCommandImpl( SerializeCommandImpl(
cmd, cmd,
[&objectIdProvider](const Cmd& cmd, size_t requiredSize, [&objectIdProvider](const Cmd& cmd, size_t requiredSize,
SerializeBuffer* serializeBuffer) { SerializeBuffer* serializeBuffer) {
return cmd.Serialize(requiredSize, serializeBuffer, objectIdProvider); return cmd.Serialize(requiredSize, serializeBuffer, objectIdProvider);
}, },
extraSize, std::forward<ExtraSizeSerializeFn>(SerializeExtraSize)); std::forward<Extensions>(extensions)...);
} }
private: private:
template <typename Cmd, typename SerializeCmdFn, typename ExtraSizeSerializeFn> template <typename Cmd, typename SerializeCmdFn, typename... Extensions>
void SerializeCommandImpl(const Cmd& cmd, void SerializeCommandImpl(const Cmd& cmd,
SerializeCmdFn&& SerializeCmd, SerializeCmdFn&& SerializeCmd,
size_t extraSize, Extensions&&... extensions) {
ExtraSizeSerializeFn&& SerializeExtraSize) {
size_t commandSize = cmd.GetRequiredSize(); size_t commandSize = cmd.GetRequiredSize();
size_t requiredSize = commandSize + extraSize; size_t requiredSize = (Align(extensions.size, kWireBufferAlignment) + ... + commandSize);
if (requiredSize <= mMaxAllocationSize) { if (requiredSize <= mMaxAllocationSize) {
char* allocatedBuffer = static_cast<char*>(mSerializer->GetCmdSpace(requiredSize)); char* allocatedBuffer = static_cast<char*>(mSerializer->GetCmdSpace(requiredSize));
if (allocatedBuffer != nullptr) { if (allocatedBuffer != nullptr) {
SerializeBuffer serializeBuffer(allocatedBuffer, requiredSize); SerializeBuffer serializeBuffer(allocatedBuffer, requiredSize);
WireResult r1 = SerializeCmd(cmd, requiredSize, &serializeBuffer); WireResult rCmd = SerializeCmd(cmd, requiredSize, &serializeBuffer);
WireResult r2 = SerializeExtraSize(&serializeBuffer); WireResult rExts =
if (DAWN_UNLIKELY(r1 != WireResult::Success || r2 != WireResult::Success)) { detail::SerializeCommandExtension(&serializeBuffer, extensions...);
if (DAWN_UNLIKELY(rCmd != WireResult::Success || rExts != WireResult::Success)) {
mSerializer->OnSerializeError(); mSerializer->OnSerializeError();
} }
} }
@ -95,9 +119,9 @@ class ChunkedCommandSerializer {
return; return;
} }
SerializeBuffer serializeBuffer(cmdSpace.get(), requiredSize); SerializeBuffer serializeBuffer(cmdSpace.get(), requiredSize);
WireResult r1 = SerializeCmd(cmd, requiredSize, &serializeBuffer); WireResult rCmd = SerializeCmd(cmd, requiredSize, &serializeBuffer);
WireResult r2 = SerializeExtraSize(&serializeBuffer); WireResult rExts = detail::SerializeCommandExtension(&serializeBuffer, extensions...);
if (DAWN_UNLIKELY(r1 != WireResult::Success || r2 != WireResult::Success)) { if (DAWN_UNLIKELY(rCmd != WireResult::Success || rExts != WireResult::Success)) {
mSerializer->OnSerializeError(); mSerializer->OnSerializeError();
return; return;
} }

View File

@ -47,6 +47,8 @@ WGPUBuffer Buffer::Create(Device* device, const WGPUBufferDescriptor* descriptor
cmd.writeHandleCreateInfoLength = 0; cmd.writeHandleCreateInfoLength = 0;
cmd.writeHandleCreateInfo = nullptr; cmd.writeHandleCreateInfo = nullptr;
size_t readHandleCreateInfoLength = 0;
size_t writeHandleCreateInfoLength = 0;
if (mappable) { if (mappable) {
if ((descriptor->usage & WGPUBufferUsage_MapRead) != 0) { if ((descriptor->usage & WGPUBufferUsage_MapRead) != 0) {
// Create the read handle on buffer creation. // Create the read handle on buffer creation.
@ -56,7 +58,8 @@ WGPUBuffer Buffer::Create(Device* device, const WGPUBufferDescriptor* descriptor
device->InjectError(WGPUErrorType_OutOfMemory, "Failed to create buffer mapping"); device->InjectError(WGPUErrorType_OutOfMemory, "Failed to create buffer mapping");
return CreateError(device, descriptor); return CreateError(device, descriptor);
} }
cmd.readHandleCreateInfoLength = readHandle->SerializeCreateSize(); readHandleCreateInfoLength = readHandle->SerializeCreateSize();
cmd.readHandleCreateInfoLength = readHandleCreateInfoLength;
} }
if ((descriptor->usage & WGPUBufferUsage_MapWrite) != 0 || descriptor->mappedAtCreation) { if ((descriptor->usage & WGPUBufferUsage_MapWrite) != 0 || descriptor->mappedAtCreation) {
@ -67,7 +70,8 @@ WGPUBuffer Buffer::Create(Device* device, const WGPUBufferDescriptor* descriptor
device->InjectError(WGPUErrorType_OutOfMemory, "Failed to create buffer mapping"); device->InjectError(WGPUErrorType_OutOfMemory, "Failed to create buffer mapping");
return CreateError(device, descriptor); return CreateError(device, descriptor);
} }
cmd.writeHandleCreateInfoLength = writeHandle->SerializeCreateSize(); writeHandleCreateInfoLength = writeHandle->SerializeCreateSize();
cmd.writeHandleCreateInfoLength = writeHandleCreateInfoLength;
} }
} }
@ -95,27 +99,28 @@ WGPUBuffer Buffer::Create(Device* device, const WGPUBufferDescriptor* descriptor
cmd.result = buffer->GetWireHandle(); cmd.result = buffer->GetWireHandle();
// clang-format off
// Turning off clang format here because for some reason it does not format the
// CommandExtensions consistently, making it harder to read.
wireClient->SerializeCommand( wireClient->SerializeCommand(
cmd, cmd.readHandleCreateInfoLength + cmd.writeHandleCreateInfoLength, cmd,
[&](SerializeBuffer* serializeBuffer) { CommandExtension{readHandleCreateInfoLength,
[&](char* readHandleBuffer) {
if (readHandle != nullptr) { if (readHandle != nullptr) {
char* readHandleBuffer;
WIRE_TRY(serializeBuffer->NextN(cmd.readHandleCreateInfoLength, &readHandleBuffer));
// Serialize the ReadHandle into the space after the command. // Serialize the ReadHandle into the space after the command.
readHandle->SerializeCreate(readHandleBuffer); readHandle->SerializeCreate(readHandleBuffer);
buffer->mReadHandle = std::move(readHandle); buffer->mReadHandle = std::move(readHandle);
} }
}},
CommandExtension{writeHandleCreateInfoLength,
[&](char* writeHandleBuffer) {
if (writeHandle != nullptr) { if (writeHandle != nullptr) {
char* writeHandleBuffer;
WIRE_TRY(
serializeBuffer->NextN(cmd.writeHandleCreateInfoLength, &writeHandleBuffer));
// Serialize the WriteHandle into the space after the command. // Serialize the WriteHandle into the space after the command.
writeHandle->SerializeCreate(writeHandleBuffer); writeHandle->SerializeCreate(writeHandleBuffer);
buffer->mWriteHandle = std::move(writeHandle); buffer->mWriteHandle = std::move(writeHandle);
} }
}});
return WireResult::Success; // clang-format on
});
return ToAPI(buffer); return ToAPI(buffer);
} }
@ -310,16 +315,12 @@ void Buffer::Unmap() {
cmd.size = mMapSize; cmd.size = mMapSize;
client->SerializeCommand( client->SerializeCommand(
cmd, writeDataUpdateInfoLength, [&](SerializeBuffer* serializeBuffer) { cmd, CommandExtension{writeDataUpdateInfoLength, [&](char* writeHandleBuffer) {
char* writeHandleBuffer;
WIRE_TRY(serializeBuffer->NextN(writeDataUpdateInfoLength, &writeHandleBuffer));
// Serialize flush metadata into the space after the command. // Serialize flush metadata into the space after the command.
// This closes the handle for writing. // This closes the handle for writing.
mWriteHandle->SerializeDataUpdate(writeHandleBuffer, cmd.offset, cmd.size); mWriteHandle->SerializeDataUpdate(writeHandleBuffer,
cmd.offset, cmd.size);
return WireResult::Success; }});
});
// If mDestructWriteHandleOnUnmap is true, that means the write handle is merely // If mDestructWriteHandleOnUnmap is true, that means the write handle is merely
// for mappedAtCreation usage. It is destroyed on unmap after flush to server // for mappedAtCreation usage. It is destroyed on unmap after flush to server

View File

@ -85,11 +85,9 @@ class Client : public ClientBase {
mSerializer.SerializeCommand(cmd, *this); mSerializer.SerializeCommand(cmd, *this);
} }
template <typename Cmd, typename ExtraSizeSerializeFn> template <typename Cmd, typename... Extensions>
void SerializeCommand(const Cmd& cmd, void SerializeCommand(const Cmd& cmd, Extensions&&... es) {
size_t extraSize, mSerializer.SerializeCommand(cmd, *this, std::forward<Extensions>(es)...);
ExtraSizeSerializeFn&& SerializeExtraSize) {
mSerializer.SerializeCommand(cmd, *this, extraSize, SerializeExtraSize);
} }
void Disconnect(); void Disconnect();

View File

@ -184,11 +184,9 @@ class Server : public ServerBase {
mSerializer.SerializeCommand(cmd); mSerializer.SerializeCommand(cmd);
} }
template <typename Cmd, typename ExtraSizeSerializeFn> template <typename Cmd, typename... Extensions>
void SerializeCommand(const Cmd& cmd, void SerializeCommand(const Cmd& cmd, Extensions&&... es) {
size_t extraSize, mSerializer.SerializeCommand(cmd, std::forward<Extensions>(es)...);
ExtraSizeSerializeFn&& SerializeExtraSize) {
mSerializer.SerializeCommand(cmd, extraSize, SerializeExtraSize);
} }
void SetForwardingDeviceCallbacks(ObjectData<WGPUDevice>* deviceObject); void SetForwardingDeviceCallbacks(ObjectData<WGPUDevice>* deviceObject);

View File

@ -237,12 +237,14 @@ void Server::OnBufferMapAsyncCallback(MapUserdata* data, WGPUBufferMapAsyncStatu
cmd.readDataUpdateInfo = nullptr; cmd.readDataUpdateInfo = nullptr;
const void* readData = nullptr; const void* readData = nullptr;
size_t readDataUpdateInfoLength = 0;
if (isSuccess) { if (isSuccess) {
if (isRead) { if (isRead) {
// Get the serialization size of the message to initialize ReadHandle data. // Get the serialization size of the message to initialize ReadHandle data.
readData = mProcs.bufferGetConstMappedRange(data->bufferObj, data->offset, data->size); readData = mProcs.bufferGetConstMappedRange(data->bufferObj, data->offset, data->size);
cmd.readDataUpdateInfoLength = readDataUpdateInfoLength =
bufferData->readHandle->SizeOfSerializeDataUpdate(data->offset, data->size); bufferData->readHandle->SizeOfSerializeDataUpdate(data->offset, data->size);
cmd.readDataUpdateInfoLength = readDataUpdateInfoLength;
} else { } else {
ASSERT(data->mode & WGPUMapMode_Write); ASSERT(data->mode & WGPUMapMode_Write);
// The in-flight map request returned successfully. // The in-flight map request returned successfully.
@ -259,16 +261,15 @@ void Server::OnBufferMapAsyncCallback(MapUserdata* data, WGPUBufferMapAsyncStatu
} }
} }
SerializeCommand(cmd, cmd.readDataUpdateInfoLength, [&](SerializeBuffer* serializeBuffer) { SerializeCommand(cmd, CommandExtension{readDataUpdateInfoLength, [&](char* readHandleBuffer) {
if (isSuccess && isRead) { if (isSuccess && isRead) {
char* readHandleBuffer; // The in-flight map request returned
WIRE_TRY(serializeBuffer->NextN(cmd.readDataUpdateInfoLength, &readHandleBuffer)); // successfully.
// The in-flight map request returned successfully. bufferData->readHandle->SerializeDataUpdate(
bufferData->readHandle->SerializeDataUpdate(readData, data->offset, data->size, readData, data->offset, data->size,
readHandleBuffer); readHandleBuffer);
} }
return WireResult::Success; }});
});
} }
} // namespace dawn::wire::server } // namespace dawn::wire::server