diff --git a/generator/templates/dawn_wire/WireClient.cpp b/generator/templates/dawn_wire/WireClient.cpp index 49b0f0db03..4dd37eb347 100644 --- a/generator/templates/dawn_wire/WireClient.cpp +++ b/generator/templates/dawn_wire/WireClient.cpp @@ -215,6 +215,9 @@ namespace dawn_wire { // Implementation of the ObjectIdProvider interface {% for type in by_category["object"] %} ObjectId GetId({{as_cType(type.name)}} object) const override { + if (object == nullptr) { + return 0; + } return reinterpret_cast<{{as_wireType(type)}}>(object)->id; } {% endfor %} diff --git a/generator/templates/dawn_wire/WireServer.cpp b/generator/templates/dawn_wire/WireServer.cpp index 77afff0c21..6b00cb225d 100644 --- a/generator/templates/dawn_wire/WireServer.cpp +++ b/generator/templates/dawn_wire/WireServer.cpp @@ -358,6 +358,11 @@ namespace dawn_wire { // Implementation of the ObjectIdResolver interface {% for type in by_category["object"] %} DeserializeResult GetFromId(ObjectId id, {{as_cType(type.name)}}* out) const override { + if (id == 0) { + *out = nullptr; + return DeserializeResult::Success; + } + auto data = mKnown{{type.name.CamelCase()}}.Get(id); if (data == nullptr) { return DeserializeResult::FatalError; diff --git a/src/tests/unittests/WireTests.cpp b/src/tests/unittests/WireTests.cpp index ba29834e1b..731ee96233 100644 --- a/src/tests/unittests/WireTests.cpp +++ b/src/tests/unittests/WireTests.cpp @@ -328,7 +328,7 @@ TEST_F(WireTests, CStringArgument) { } // Test that the wire is able to send objects as value arguments -TEST_F(WireTests, DISABLED_ObjectAsValueArgument) { +TEST_F(WireTests, ObjectAsValueArgument) { // Create pipeline dawnComputePipelineDescriptor pipelineDesc; pipelineDesc.nextInChain = nullptr; @@ -485,6 +485,46 @@ TEST_F(WireTests, StructureOfStructureArrayArgument) { FlushClient(); } +// Test passing nullptr instead of objects - object as value version +TEST_F(WireTests, NullptrAsValue) { + dawnCommandBufferBuilder builder = dawnDeviceCreateCommandBufferBuilder(device); + dawnComputePassEncoder pass = dawnCommandBufferBuilderBeginComputePass(builder); + dawnComputePassEncoderSetComputePipeline(pass, nullptr); + + dawnCommandBufferBuilder apiBuilder = api.GetNewCommandBufferBuilder(); + EXPECT_CALL(api, DeviceCreateCommandBufferBuilder(apiDevice)) + .WillOnce(Return(apiBuilder)); + + dawnComputePassEncoder apiPass = api.GetNewComputePassEncoder(); + EXPECT_CALL(api, CommandBufferBuilderBeginComputePass(apiBuilder)) + .WillOnce(Return(apiPass)); + + EXPECT_CALL(api, ComputePassEncoderSetComputePipeline(apiPass, nullptr)) + .Times(1); + + FlushClient(); +} + +// Test passing nullptr instead of objects - array of objects version +TEST_F(WireTests, NullptrInArray) { + dawnBindGroupLayout nullBGL = nullptr; + + dawnPipelineLayoutDescriptor descriptor; + descriptor.nextInChain = nullptr; + descriptor.numBindGroupLayouts = 1; + descriptor.bindGroupLayouts = &nullBGL; + + dawnDeviceCreatePipelineLayout(device, &descriptor); + EXPECT_CALL(api, DeviceCreatePipelineLayout(apiDevice, MatchesLambda([](const dawnPipelineLayoutDescriptor* desc) -> bool { + return desc->nextInChain == nullptr && + desc->numBindGroupLayouts == 1 && + desc->bindGroupLayouts[0] == nullptr; + }))) + .WillOnce(Return(nullptr)); + + FlushClient(); +} + // Test that the server doesn't forward calls to error objects or with error objects // Also test that when GetResult is called on an error builder, the error callback is fired // TODO(cwallez@chromium.org): This test is disabled because the introduction of encoders breaks