dawn_wire: Tag deserialize commands with volatile pointer

This prevents bugs where the compiler assumes a piece of memory
will be the same if read from twice.

Bug: dawn:230
Change-Id: Ib3358e56b6cf8f1fbf449c5d564ef85c969d695b
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/11840
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
This commit is contained in:
Austin Eng 2019-10-07 20:38:47 +00:00 committed by Commit Bot service account
parent d561448d0d
commit 8eb8385e2e
13 changed files with 47 additions and 27 deletions

View File

@ -16,6 +16,7 @@
#include "common/Assert.h" #include "common/Assert.h"
#include <algorithm>
#include <cstring> #include <cstring>
#include <limits> #include <limits>
@ -225,8 +226,8 @@
//* Deserializes `transfer` into `record` getting more serialized data from `buffer` and `size` //* Deserializes `transfer` into `record` getting more serialized data from `buffer` and `size`
//* if needed, using `allocator` to store pointed-to values and `resolver` to translate object //* if needed, using `allocator` to store pointed-to values and `resolver` to translate object
//* Ids to actual objects. //* Ids to actual objects.
DAWN_DECLARE_UNUSED DeserializeResult {{Return}}{{name}}Deserialize({{Return}}{{name}}{{Cmd}}* record, const {{Return}}{{name}}Transfer* transfer, DAWN_DECLARE_UNUSED DeserializeResult {{Return}}{{name}}Deserialize({{Return}}{{name}}{{Cmd}}* record, const volatile {{Return}}{{name}}Transfer* transfer,
const char** buffer, size_t* size, DeserializeAllocator* allocator const volatile char** buffer, size_t* size, DeserializeAllocator* allocator
{%- if record.has_dawn_object -%} {%- if record.has_dawn_object -%}
, const ObjectIdResolver& resolver , const ObjectIdResolver& resolver
{%- endif -%} {%- endif -%}
@ -257,18 +258,19 @@
{% for member in members if member.length == "strlen" %} {% for member in members if member.length == "strlen" %}
{% set memberName = as_varName(member.name) %} {% set memberName = as_varName(member.name) %}
record->{{memberName}} = nullptr;
{% if member.optional %} {% if member.optional %}
if (transfer->has_{{memberName}}) bool has_{{memberName}} = transfer->has_{{memberName}};
record->{{memberName}} = nullptr;
if (has_{{memberName}})
{% endif %} {% endif %}
{ {
size_t stringLength = transfer->{{memberName}}Strlen; size_t stringLength = transfer->{{memberName}}Strlen;
const char* stringInBuffer = nullptr; const volatile char* stringInBuffer = nullptr;
DESERIALIZE_TRY(GetPtrFromBuffer(buffer, size, stringLength, &stringInBuffer)); DESERIALIZE_TRY(GetPtrFromBuffer(buffer, size, stringLength, &stringInBuffer));
char* copiedString = nullptr; char* copiedString = nullptr;
DESERIALIZE_TRY(GetSpace(allocator, stringLength + 1, &copiedString)); DESERIALIZE_TRY(GetSpace(allocator, stringLength + 1, &copiedString));
memcpy(copiedString, stringInBuffer, stringLength); std::copy(stringInBuffer, stringInBuffer + stringLength, copiedString);
copiedString[stringLength] = '\0'; copiedString[stringLength] = '\0';
record->{{memberName}} = copiedString; record->{{memberName}} = copiedString;
} }
@ -285,7 +287,7 @@
{% endif %} {% endif %}
{ {
size_t memberLength = {{member_length(member, "record->")}}; size_t memberLength = {{member_length(member, "record->")}};
auto memberBuffer = reinterpret_cast<const {{member_transfer_type(member)}}*>(buffer); auto memberBuffer = reinterpret_cast<const volatile {{member_transfer_type(member)}}*>(buffer);
DESERIALIZE_TRY(GetPtrFromBuffer(buffer, size, memberLength, &memberBuffer)); DESERIALIZE_TRY(GetPtrFromBuffer(buffer, size, memberLength, &memberBuffer));
{{as_cType(member.type.name)}}* copiedMembers = nullptr; {{as_cType(member.type.name)}}* copiedMembers = nullptr;
@ -337,12 +339,12 @@
); );
} }
DeserializeResult {{Cmd}}::Deserialize(const char** buffer, size_t* size, DeserializeAllocator* allocator DeserializeResult {{Cmd}}::Deserialize(const volatile char** buffer, size_t* size, DeserializeAllocator* allocator
{%- if command.has_dawn_object -%} {%- if command.has_dawn_object -%}
, const ObjectIdResolver& resolver , const ObjectIdResolver& resolver
{%- endif -%} {%- endif -%}
) { ) {
const {{Name}}Transfer* transfer = nullptr; const volatile {{Name}}Transfer* transfer = nullptr;
DESERIALIZE_TRY(GetPtrFromBuffer(buffer, size, 1, &transfer)); DESERIALIZE_TRY(GetPtrFromBuffer(buffer, size, 1, &transfer));
return {{Name}}Deserialize(this, transfer, buffer, size, allocator return {{Name}}Deserialize(this, transfer, buffer, size, allocator
@ -364,12 +366,22 @@ namespace dawn_wire {
} \ } \
} }
ObjectHandle::ObjectHandle() = default;
ObjectHandle::ObjectHandle(ObjectId id, ObjectSerial serial) : id(id), serial(serial) {}
ObjectHandle::ObjectHandle(const volatile ObjectHandle& rhs) : id(rhs.id), serial(rhs.serial) {}
ObjectHandle& ObjectHandle::operator=(const ObjectHandle& rhs) = default;
ObjectHandle& ObjectHandle::operator=(const volatile ObjectHandle& rhs) {
id = rhs.id;
serial = rhs.serial;
return *this;
}
namespace { namespace {
// Consumes from (buffer, size) enough memory to contain T[count] and return it in data. // Consumes from (buffer, size) enough memory to contain T[count] and return it in data.
// Returns FatalError if not enough memory was available // Returns FatalError if not enough memory was available
template <typename T> template <typename T>
DeserializeResult GetPtrFromBuffer(const char** buffer, size_t* size, size_t count, const T** data) { DeserializeResult GetPtrFromBuffer(const volatile char** buffer, size_t* size, size_t count, const volatile T** data) {
constexpr size_t kMaxCountWithoutOverflows = std::numeric_limits<size_t>::max() / sizeof(T); constexpr size_t kMaxCountWithoutOverflows = std::numeric_limits<size_t>::max() / sizeof(T);
if (count > kMaxCountWithoutOverflows) { if (count > kMaxCountWithoutOverflows) {
return DeserializeResult::FatalError; return DeserializeResult::FatalError;
@ -380,7 +392,7 @@ namespace dawn_wire {
return DeserializeResult::FatalError; return DeserializeResult::FatalError;
} }
*data = reinterpret_cast<const T*>(*buffer); *data = reinterpret_cast<const volatile T*>(*buffer);
*buffer += totalSize; *buffer += totalSize;
*size -= totalSize; *size -= totalSize;

View File

@ -24,6 +24,12 @@ namespace dawn_wire {
struct ObjectHandle { struct ObjectHandle {
ObjectId id; ObjectId id;
ObjectSerial serial; ObjectSerial serial;
ObjectHandle();
ObjectHandle(ObjectId id, ObjectSerial serial);
ObjectHandle(const volatile ObjectHandle& rhs);
ObjectHandle& operator=(const ObjectHandle& rhs);
ObjectHandle& operator=(const volatile ObjectHandle& rhs);
}; };
enum class DeserializeResult { enum class DeserializeResult {
@ -99,7 +105,7 @@ namespace dawn_wire {
//* Deserialize returns: //* Deserialize returns:
//* - Success if everything went well (yay!) //* - Success if everything went well (yay!)
//* - FatalError is something bad happened (buffer too small for example) //* - FatalError is something bad happened (buffer too small for example)
DeserializeResult Deserialize(const char** buffer, size_t* size, DeserializeAllocator* allocator DeserializeResult Deserialize(const volatile char** buffer, size_t* size, DeserializeAllocator* allocator
{%- if command.has_dawn_object -%} {%- if command.has_dawn_object -%}
, const ObjectIdResolver& resolver , const ObjectIdResolver& resolver
{%- endif -%} {%- endif -%}

View File

@ -19,7 +19,7 @@
namespace dawn_wire { namespace client { namespace dawn_wire { namespace client {
{% for command in cmd_records["return command"] %} {% for command in cmd_records["return command"] %}
bool Client::Handle{{command.name.CamelCase()}}(const char** commands, size_t* size) { bool Client::Handle{{command.name.CamelCase()}}(const volatile char** commands, size_t* size) {
Return{{command.name.CamelCase()}}Cmd cmd; Return{{command.name.CamelCase()}}Cmd cmd;
DeserializeResult deserializeResult = cmd.Deserialize(commands, size, &mAllocator); DeserializeResult deserializeResult = cmd.Deserialize(commands, size, &mAllocator);
@ -53,9 +53,9 @@ namespace dawn_wire { namespace client {
} }
{% endfor %} {% endfor %}
const char* Client::HandleCommands(const char* commands, size_t size) { const volatile char* Client::HandleCommands(const volatile char* commands, size_t size) {
while (size >= sizeof(ReturnWireCmd)) { while (size >= sizeof(ReturnWireCmd)) {
ReturnWireCmd cmdId = *reinterpret_cast<const ReturnWireCmd*>(commands); ReturnWireCmd cmdId = *reinterpret_cast<const volatile ReturnWireCmd*>(commands);
bool success = false; bool success = false;
switch (cmdId) { switch (cmdId) {

View File

@ -14,7 +14,7 @@
//* Return command handlers //* Return command handlers
{% for command in cmd_records["return command"] %} {% for command in cmd_records["return command"] %}
bool Handle{{command.name.CamelCase()}}(const char** commands, size_t* size); bool Handle{{command.name.CamelCase()}}(const volatile char** commands, size_t* size);
{% endfor %} {% endfor %}
//* Return command doers //* Return command doers

View File

@ -25,7 +25,7 @@ namespace dawn_wire { namespace server {
{% set Suffix = command.name.CamelCase() %} {% set Suffix = command.name.CamelCase() %}
{% if Suffix not in client_side_commands %} {% if Suffix not in client_side_commands %}
//* The generic command handlers //* The generic command handlers
bool Server::Handle{{Suffix}}(const char** commands, size_t* size) { bool Server::Handle{{Suffix}}(const volatile char** commands, size_t* size) {
{{Suffix}}Cmd cmd; {{Suffix}}Cmd cmd;
DeserializeResult deserializeResult = cmd.Deserialize(commands, size, &mAllocator DeserializeResult deserializeResult = cmd.Deserialize(commands, size, &mAllocator
{%- if command.has_dawn_object -%} {%- if command.has_dawn_object -%}
@ -91,11 +91,11 @@ namespace dawn_wire { namespace server {
{% endif %} {% endif %}
{% endfor %} {% endfor %}
const char* Server::HandleCommands(const char* commands, size_t size) { const volatile char* Server::HandleCommands(const volatile char* commands, size_t size) {
mProcs.deviceTick(DeviceObjects().Get(1)->handle); mProcs.deviceTick(DeviceObjects().Get(1)->handle);
while (size >= sizeof(WireCmd)) { while (size >= sizeof(WireCmd)) {
WireCmd cmdId = *reinterpret_cast<const WireCmd*>(commands); WireCmd cmdId = *reinterpret_cast<const volatile WireCmd*>(commands);
bool success = false; bool success = false;
switch (cmdId) { switch (cmdId) {

View File

@ -15,7 +15,7 @@
// Command handlers & doers // Command handlers & doers
{% for command in cmd_records["command"] if command.name.CamelCase() not in client_side_commands %} {% for command in cmd_records["command"] if command.name.CamelCase() not in client_side_commands %}
{% set Suffix = command.name.CamelCase() %} {% set Suffix = command.name.CamelCase() %}
bool Handle{{Suffix}}(const char** commands, size_t* size); bool Handle{{Suffix}}(const volatile char** commands, size_t* size);
bool Do{{Suffix}}( bool Do{{Suffix}}(
{%- for member in command.members -%} {%- for member in command.members -%}

View File

@ -33,7 +33,7 @@ namespace dawn_wire {
return client::GetProcs(); return client::GetProcs();
} }
const char* WireClient::HandleCommands(const char* commands, size_t size) { const volatile char* WireClient::HandleCommands(const volatile char* commands, size_t size) {
return mImpl->HandleCommands(commands, size); return mImpl->HandleCommands(commands, size);
} }

View File

@ -28,7 +28,7 @@ namespace dawn_wire {
mImpl.reset(); mImpl.reset();
} }
const char* WireServer::HandleCommands(const char* commands, size_t size) { const volatile char* WireServer::HandleCommands(const volatile char* commands, size_t size) {
return mImpl->HandleCommands(commands, size); return mImpl->HandleCommands(commands, size);
} }

View File

@ -33,7 +33,7 @@ namespace dawn_wire { namespace client {
Client(CommandSerializer* serializer, MemoryTransferService* memoryTransferService); Client(CommandSerializer* serializer, MemoryTransferService* memoryTransferService);
~Client(); ~Client();
const char* HandleCommands(const char* commands, size_t size); const volatile char* HandleCommands(const volatile char* commands, size_t size);
ReservedTexture ReserveTexture(DawnDevice device); ReservedTexture ReserveTexture(DawnDevice device);
void* GetCmdSpace(size_t size) { void* GetCmdSpace(size_t size) {

View File

@ -53,7 +53,7 @@ namespace dawn_wire { namespace server {
MemoryTransferService* memoryTransferService); MemoryTransferService* memoryTransferService);
~Server(); ~Server();
const char* HandleCommands(const char* commands, size_t size); const volatile char* HandleCommands(const volatile char* commands, size_t size);
bool InjectTexture(DawnTexture texture, uint32_t id, uint32_t generation); bool InjectTexture(DawnTexture texture, uint32_t id, uint32_t generation);

View File

@ -32,7 +32,7 @@ namespace dawn_wire {
class DAWN_WIRE_EXPORT CommandHandler { class DAWN_WIRE_EXPORT CommandHandler {
public: public:
virtual ~CommandHandler() = default; virtual ~CommandHandler() = default;
virtual const char* HandleCommands(const char* commands, size_t size) = 0; virtual const volatile char* HandleCommands(const volatile char* commands, size_t size) = 0;
}; };
} // namespace dawn_wire } // namespace dawn_wire

View File

@ -44,7 +44,8 @@ namespace dawn_wire {
DawnDevice GetDevice() const; DawnDevice GetDevice() const;
DawnProcTable GetProcs() const; DawnProcTable GetProcs() const;
const char* HandleCommands(const char* commands, size_t size) override final; const volatile char* HandleCommands(const volatile char* commands,
size_t size) override final;
ReservedTexture ReserveTexture(DawnDevice device); ReservedTexture ReserveTexture(DawnDevice device);

View File

@ -38,7 +38,8 @@ namespace dawn_wire {
WireServer(const WireServerDescriptor& descriptor); WireServer(const WireServerDescriptor& descriptor);
~WireServer(); ~WireServer();
const char* HandleCommands(const char* commands, size_t size) override final; const volatile char* HandleCommands(const volatile char* commands,
size_t size) override final;
bool InjectTexture(DawnTexture texture, uint32_t id, uint32_t generation); bool InjectTexture(DawnTexture texture, uint32_t id, uint32_t generation);