diff --git a/generator/main.py b/generator/main.py index 9f27d581f4..e58daac59e 100644 --- a/generator/main.py +++ b/generator/main.py @@ -79,11 +79,12 @@ class NativelyDefined(Type): Type.__init__(self, name, record) class MethodArgument: - def __init__(self, name, typ, annotation): + def __init__(self, name, typ, annotation, optional): self.name = name self.type = typ self.annotation = annotation self.length = None + self.optional = optional Method = namedtuple('Method', ['name', 'return_type', 'arguments']) class ObjectType(Type): @@ -94,11 +95,12 @@ class ObjectType(Type): self.built_type = None class StructureMember: - def __init__(self, name, typ, annotation): + def __init__(self, name, typ, annotation, optional): self.name = name self.type = typ self.annotation = annotation self.length = None + self.optional = optional class StructureType(Type): def __init__(self, name, record): @@ -120,7 +122,8 @@ def link_object(obj, types): arguments = [] arguments_by_name = {} for a in record.get('args', []): - arg = MethodArgument(Name(a['name']), types[a['type']], a.get('annotation', 'value')) + arg = MethodArgument(Name(a['name']), types[a['type']], + a.get('annotation', 'value'), a.get('optional', False)) arguments.append(arg) arguments_by_name[arg.name.canonical_case()] = arg @@ -153,7 +156,8 @@ def link_object(obj, types): def link_structure(struct, types): def make_member(m): - return StructureMember(Name(m['name']), types[m['type']], m.get('annotation', 'value')) + return StructureMember(Name(m['name']), types[m['type']], + m.get('annotation', 'value'), m.get('optional', False)) members = [] members_by_name = {} @@ -410,9 +414,9 @@ def cpp_native_methods(types, typ): if typ.is_builder: methods.append(Method(Name('set error callback'), types['void'], [ - MethodArgument(Name('callback'), types['builder error callback'], 'value'), - MethodArgument(Name('userdata1'), types['callback userdata'], 'value'), - MethodArgument(Name('userdata2'), types['callback userdata'], 'value'), + MethodArgument(Name('callback'), types['builder error callback'], 'value', False), + MethodArgument(Name('userdata1'), types['callback userdata'], 'value', False), + MethodArgument(Name('userdata2'), types['callback userdata'], 'value', False), ])) return methods diff --git a/generator/templates/dawn_wire/WireClient.cpp b/generator/templates/dawn_wire/WireClient.cpp index 6641a14580..117cfad5dc 100644 --- a/generator/templates/dawn_wire/WireClient.cpp +++ b/generator/templates/dawn_wire/WireClient.cpp @@ -214,11 +214,14 @@ namespace dawn_wire { // Implementation of the ObjectIdProvider interface {% for type in by_category["object"] %} - ObjectId GetId({{as_cType(type.name)}} object) const override { + ObjectId GetId({{as_cType(type.name)}} object) const final { + return reinterpret_cast<{{as_wireType(type)}}>(object)->id; + } + ObjectId GetOptionalId({{as_cType(type.name)}} object) const final { if (object == nullptr) { return 0; } - return reinterpret_cast<{{as_wireType(type)}}>(object)->id; + return GetId(object); } {% endfor %} diff --git a/generator/templates/dawn_wire/WireCmd.cpp b/generator/templates/dawn_wire/WireCmd.cpp index a72c9f4d5d..ceb828f123 100644 --- a/generator/templates/dawn_wire/WireCmd.cpp +++ b/generator/templates/dawn_wire/WireCmd.cpp @@ -48,7 +48,8 @@ //* Outputs the serialization code to put `in` in `out` {% macro serialize_member(member, in, out) %} {%- if member.type.category == "object" -%} - {{out}} = provider.GetId({{in}}); + {% set Optional = "Optional" if member.optional else "" %} + {{out}} = provider.Get{{Optional}}Id({{in}}); {% elif member.type.category == "structure"%} {{as_cType(member.type.name)}}Serialize({{in}}, &{{out}}, buffer, provider); {%- else -%} @@ -59,7 +60,8 @@ //* Outputs the deserialization code to put `in` in `out` {% macro deserialize_member(member, in, out) %} {%- if member.type.category == "object" -%} - DESERIALIZE_TRY(resolver.GetFromId({{in}}, &{{out}})); + {% set Optional = "Optional" if member.optional else "" %} + DESERIALIZE_TRY(resolver.Get{{Optional}}FromId({{in}}, &{{out}})); {% elif member.type.category == "structure"%} DESERIALIZE_TRY({{as_cType(member.type.name)}}Deserialize(&{{out}}, &{{in}}, buffer, size, allocator, resolver)); {%- else -%} diff --git a/generator/templates/dawn_wire/WireCmd.h b/generator/templates/dawn_wire/WireCmd.h index 9e451ef191..2d5300872d 100644 --- a/generator/templates/dawn_wire/WireCmd.h +++ b/generator/templates/dawn_wire/WireCmd.h @@ -40,6 +40,7 @@ namespace dawn_wire { public: {% for type in by_category["object"] %} virtual DeserializeResult GetFromId(ObjectId id, {{as_cType(type.name)}}* out) const = 0; + virtual DeserializeResult GetOptionalFromId(ObjectId id, {{as_cType(type.name)}}* out) const = 0; {% endfor %} }; @@ -48,6 +49,7 @@ namespace dawn_wire { public: {% for type in by_category["object"] %} virtual ObjectId GetId({{as_cType(type.name)}} object) const = 0; + virtual ObjectId GetOptionalId({{as_cType(type.name)}} object) const = 0; {% endfor %} }; diff --git a/generator/templates/dawn_wire/WireServer.cpp b/generator/templates/dawn_wire/WireServer.cpp index 945f56539a..ebc341760b 100644 --- a/generator/templates/dawn_wire/WireServer.cpp +++ b/generator/templates/dawn_wire/WireServer.cpp @@ -357,12 +357,7 @@ namespace dawn_wire { // Implementation of the ObjectIdResolver interface {% for type in by_category["object"] %} - DeserializeResult GetFromId(ObjectId id, {{as_cType(type.name)}}* out) const override { - if (id == 0) { - *out = nullptr; - return DeserializeResult::Success; - } - + DeserializeResult GetFromId(ObjectId id, {{as_cType(type.name)}}* out) const final { auto data = mKnown{{type.name.CamelCase()}}.Get(id); if (data == nullptr) { return DeserializeResult::FatalError; @@ -375,6 +370,15 @@ namespace dawn_wire { return DeserializeResult::ErrorObject; } } + + DeserializeResult GetOptionalFromId(ObjectId id, {{as_cType(type.name)}}* out) const final { + if (id == 0) { + *out = nullptr; + return DeserializeResult::Success; + } + + return GetFromId(id, out); + } {% endfor %} //* The list of known IDs for each object type. diff --git a/src/tests/unittests/WireTests.cpp b/src/tests/unittests/WireTests.cpp index 731ee96233..4670cab974 100644 --- a/src/tests/unittests/WireTests.cpp +++ b/src/tests/unittests/WireTests.cpp @@ -329,32 +329,27 @@ TEST_F(WireTests, CStringArgument) { // Test that the wire is able to send objects as value arguments TEST_F(WireTests, ObjectAsValueArgument) { - // Create pipeline - dawnComputePipelineDescriptor pipelineDesc; - pipelineDesc.nextInChain = nullptr; - pipelineDesc.layout = nullptr; - pipelineDesc.entryPoint = "main"; - pipelineDesc.module = nullptr; - dawnComputePipeline pipeline = dawnDeviceCreateComputePipeline(device, &pipelineDesc); + // Create a RenderPassDescriptor + dawnRenderPassDescriptorBuilder renderPassBuilder = dawnDeviceCreateRenderPassDescriptorBuilder(device); + dawnRenderPassDescriptor renderPass = dawnRenderPassDescriptorBuilderGetResult(renderPassBuilder); - dawnComputePipeline apiPipeline = api.GetNewComputePipeline(); - EXPECT_CALL(api, DeviceCreateComputePipeline(apiDevice, _)) - .WillOnce(Return(apiPipeline)); + dawnRenderPassDescriptorBuilder apiRenderPassBuilder = api.GetNewRenderPassDescriptorBuilder(); + EXPECT_CALL(api, DeviceCreateRenderPassDescriptorBuilder(apiDevice)) + .WillOnce(Return(apiRenderPassBuilder)); + dawnRenderPassDescriptor apiRenderPass = api.GetNewRenderPassDescriptor(); + EXPECT_CALL(api, RenderPassDescriptorBuilderGetResult(apiRenderPassBuilder)) + .WillOnce(Return(apiRenderPass)); - // Create command buffer builder, setting pipeline + // Create command buffer builder, setting render pass descriptor dawnCommandBufferBuilder cmdBufBuilder = dawnDeviceCreateCommandBufferBuilder(device); - dawnComputePassEncoder pass = dawnCommandBufferBuilderBeginComputePass(cmdBufBuilder); - dawnComputePassEncoderSetComputePipeline(pass, pipeline); + dawnCommandBufferBuilderBeginRenderPass(cmdBufBuilder, renderPass); dawnCommandBufferBuilder apiCmdBufBuilder = api.GetNewCommandBufferBuilder(); EXPECT_CALL(api, DeviceCreateCommandBufferBuilder(apiDevice)) .WillOnce(Return(apiCmdBufBuilder)); - dawnComputePassEncoder apiPass = api.GetNewComputePassEncoder(); - EXPECT_CALL(api, CommandBufferBuilderBeginComputePass(apiCmdBufBuilder)) - .WillOnce(Return(apiPass)); - - EXPECT_CALL(api, ComputePassEncoderSetComputePipeline(apiPass, apiPipeline)); + EXPECT_CALL(api, CommandBufferBuilderBeginRenderPass(apiCmdBufBuilder, apiRenderPass)) + .Times(1); FlushClient(); } @@ -486,7 +481,7 @@ TEST_F(WireTests, StructureOfStructureArrayArgument) { } // Test passing nullptr instead of objects - object as value version -TEST_F(WireTests, NullptrAsValue) { +TEST_F(WireTests, DISABLED_NullptrAsValue) { dawnCommandBufferBuilder builder = dawnDeviceCreateCommandBufferBuilder(device); dawnComputePassEncoder pass = dawnCommandBufferBuilderBeginComputePass(builder); dawnComputePassEncoderSetComputePipeline(pass, nullptr); @@ -506,7 +501,7 @@ TEST_F(WireTests, NullptrAsValue) { } // Test passing nullptr instead of objects - array of objects version -TEST_F(WireTests, NullptrInArray) { +TEST_F(WireTests, DISABLED_NullptrInArray) { dawnBindGroupLayout nullBGL = nullptr; dawnPipelineLayoutDescriptor descriptor;