Nuke the ClientMatches hack for same-device validation

The wire now supports more than one device, and Chrome is updated
to use the new code path. This fixes same-device validation for
createReadyPipeline.

Bug: dawn:565
Change-Id: Id05001ed1a7e535690c87f535da6f72a0e794c59
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/40460
Reviewed-by: Stephen White <senorblanco@chromium.org>
Commit-Queue: Austin Eng <enga@chromium.org>
This commit is contained in:
Austin Eng 2021-02-05 21:08:44 +00:00 committed by Commit Bot service account
parent fb2e77106a
commit 05d9e2cde2
4 changed files with 89 additions and 433 deletions

View File

@ -22,96 +22,15 @@
#include <vector>
namespace dawn_wire { namespace client {
namespace {
//* Outputs an rvalue that's the number of elements a pointer member points to.
{% macro member_length(member, accessor) -%}
{%- if member.length == "constant" -%}
{{member.constant_length}}
{%- else -%}
{{accessor}}{{as_varName(member.length.name)}}
{%- endif -%}
{%- endmacro %}
{% for type in by_category["object"] %}
DAWN_DECLARE_UNUSED bool ClientMatches(const Client* client, const {{as_cType(type.name)}} obj) {
return client == reinterpret_cast<const {{as_wireType(type)}}>(obj)->client;
}
DAWN_DECLARE_UNUSED bool ClientMatches(const Client* client, const {{as_cType(type.name)}} *const obj, uint32_t count = 1) {
ASSERT(count == 0 || obj != nullptr);
for (uint32_t i = 0; i < count; ++i) {
if (!ClientMatches(client, obj[i])) {
return false;
}
}
return true;
}
{% endfor %}
bool ClientMatches(const Client* client, WGPUChainedStruct const* chainedStruct);
{% for type in by_category["structure"] if type.may_have_dawn_object %}
DAWN_DECLARE_UNUSED bool ClientMatches(const Client* client, const {{as_cType(type.name)}}& obj) {
{% if type.extensible %}
if (!ClientMatches(client, obj.nextInChain)) {
return false;
}
{% endif %}
{% for member in type.members if member.type.may_have_dawn_object or member.type.category == "object" %}
{% if member.optional %}
if (obj.{{as_varName(member.name)}} != nullptr)
{% endif %}
{
if (!ClientMatches(client, obj.{{as_varName(member.name)}}
{%- if member.annotation != "value" and member.length != "strlen" -%}
, {{member_length(member, "obj.")}}
{%- endif -%})) {
return false;
}
}
{% endfor %}
return true;
}
DAWN_DECLARE_UNUSED bool ClientMatches(const Client* client, const {{as_cType(type.name)}} *const obj, uint32_t count = 1) {
for (uint32_t i = 0; i < count; ++i) {
if (!ClientMatches(client, obj[i])) {
return false;
}
}
return true;
}
{% endfor %}
bool ClientMatches(const Client* client, WGPUChainedStruct const* chainedStruct) {
while (chainedStruct != nullptr) {
switch (chainedStruct->sType) {
{% for sType in types["s type"].values if sType.valid %}
{% set CType = as_cType(sType.name) %}
case {{as_cEnum(types["s type"].name, sType.name)}}: {
{% if types[sType.name.get()].may_have_dawn_object %}
if (!ClientMatches(client, reinterpret_cast<const {{CType}}*>(chainedStruct))) {
return false;
}
{% endif %}
break;
}
{% endfor %}
case WGPUSType_Invalid:
break;
default:
dawn::WarningLog()
<< "All objects may not be from the same client. "
<< "Unknown sType " << chainedStruct->sType << " ignored.";
break;
}
chainedStruct = chainedStruct->next;
}
return true;
}
} // anonymous namespace
//* Outputs an rvalue that's the number of elements a pointer member points to.
{% macro member_length(member, accessor) -%}
{%- if member.length == "constant" -%}
{{member.constant_length}}
{%- else -%}
{{accessor}}{{as_varName(member.length.name)}}
{%- endif -%}
{%- endmacro %}
//* Implementation of the client API functions.
{% for type in by_category["object"] %}
@ -130,49 +49,6 @@ namespace dawn_wire { namespace client {
, {{as_annotated_cType(arg)}}
{%- endfor -%}
) {
{% if len(method.arguments) > 0 %}
{
bool sameClient = true;
auto self = reinterpret_cast<{{as_wireType(type)}}>(cSelf);
Client* client = self->client;
DAWN_UNUSED(client);
do {
{% for arg in method.arguments if arg.type.may_have_dawn_object or arg.type.category == "object" %}
{% if arg.optional %}
if ({{as_varName(arg.name)}} != nullptr)
{% endif %}
{
if (!ClientMatches(client, {{as_varName(arg.name)}}
{%- if arg.annotation != "value" and arg.length != "strlen" -%}
, {{member_length(arg, "")}}
{%- endif -%})) {
sameClient = false;
break;
}
}
{% endfor %}
} while (false);
if (DAWN_UNLIKELY(!sameClient)) {
reinterpret_cast<Device*>(client->GetDevice())->InjectError(WGPUErrorType_Validation,
"All objects must be from the same device.");
{% if method.return_type.category == "object" %}
// Allocate an object without registering it on the server. This is backed by a real allocation on
// the client so commands can be sent with it. But because it's not allocated on the server, it will
// be a fatal error to use it.
auto self = reinterpret_cast<{{as_wireType(type)}}>(cSelf);
auto* allocation = self->client->{{method.return_type.name.CamelCase()}}Allocator().New(self->client);
return reinterpret_cast<{{as_cType(method.return_type.name)}}>(allocation->object.get());
{% elif method.return_type.name.canonical_case() == "void" %}
return;
{% else %}
return {};
{% endif %}
}
}
{% endif %}
auto self = reinterpret_cast<{{as_wireType(type)}}>(cSelf);
{% if Suffix not in client_handwritten_commands %}
{{Suffix}}Cmd cmd;

View File

@ -196,6 +196,7 @@ test("dawn_unittests") {
"unittests/validation/GetBindGroupLayoutValidationTests.cpp",
"unittests/validation/IndexBufferValidationTests.cpp",
"unittests/validation/MinimumBufferSizeValidationTests.cpp",
"unittests/validation/MultipleDeviceTests.cpp",
"unittests/validation/QueryValidationTests.cpp",
"unittests/validation/QueueSubmitValidationTests.cpp",
"unittests/validation/QueueWriteTextureValidationTests.cpp",
@ -228,7 +229,6 @@ test("dawn_unittests") {
"unittests/wire/WireInjectDeviceTests.cpp",
"unittests/wire/WireInjectTextureTests.cpp",
"unittests/wire/WireMemoryTransferServiceTests.cpp",
"unittests/wire/WireMultipleDeviceTests.cpp",
"unittests/wire/WireOptionalTests.cpp",
"unittests/wire/WireTest.cpp",
"unittests/wire/WireTest.h",

View File

@ -0,0 +1,80 @@
// Copyright 2021 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tests/unittests/validation/ValidationTest.h"
#include "tests/MockCallback.h"
using namespace testing;
class MultipleDeviceTest : public ValidationTest {};
// Test that it is invalid to submit a command buffer created on a different device.
TEST_F(MultipleDeviceTest, ValidatesSameDevice) {
wgpu::Device device2 = RegisterDevice(CreateTestDevice());
wgpu::CommandBuffer commandBuffer = device2.CreateCommandEncoder().Finish();
ASSERT_DEVICE_ERROR(device.GetQueue().Submit(1, &commandBuffer));
}
// Test that CreateReadyPipeline fails creation with an Error status if it uses
// objects from a different device.
TEST_F(MultipleDeviceTest, ValidatesSameDeviceCreateReadyPipeline) {
wgpu::ShaderModuleWGSLDescriptor wgslDesc = {};
wgslDesc.source = R"(
[[stage(compute)]] fn main() -> void {
}
)";
wgpu::ShaderModuleDescriptor shaderModuleDesc = {};
shaderModuleDesc.nextInChain = &wgslDesc;
// Base case: CreateReadyComputePipeline succeeds.
{
wgpu::ShaderModule shaderModule = device.CreateShaderModule(&shaderModuleDesc);
wgpu::ComputePipelineDescriptor pipelineDesc = {};
pipelineDesc.computeStage.module = shaderModule;
pipelineDesc.computeStage.entryPoint = "main";
StrictMock<MockCallback<WGPUCreateReadyComputePipelineCallback>> creationCallback;
EXPECT_CALL(creationCallback,
Call(WGPUCreateReadyPipelineStatus_Success, NotNull(), _, this))
.WillOnce(WithArg<1>(Invoke(
[](WGPUComputePipeline pipeline) { wgpu::ComputePipeline::Acquire(pipeline); })));
device.CreateReadyComputePipeline(&pipelineDesc, creationCallback.Callback(),
creationCallback.MakeUserdata(this));
WaitForAllOperations(device);
}
// CreateReadyComputePipeline errors if the shader module is created on a different device.
{
wgpu::Device device2 = RegisterDevice(CreateTestDevice());
wgpu::ShaderModule shaderModule = device2.CreateShaderModule(&shaderModuleDesc);
wgpu::ComputePipelineDescriptor pipelineDesc = {};
pipelineDesc.computeStage.module = shaderModule;
pipelineDesc.computeStage.entryPoint = "main";
StrictMock<MockCallback<WGPUCreateReadyComputePipelineCallback>> creationCallback;
EXPECT_CALL(creationCallback,
Call(WGPUCreateReadyPipelineStatus_Error, nullptr, _, this + 1))
.Times(1);
device.CreateReadyComputePipeline(&pipelineDesc, creationCallback.Callback(),
creationCallback.MakeUserdata(this + 1));
WaitForAllOperations(device);
}
}

View File

@ -1,300 +0,0 @@
// Copyright 2020 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tests/unittests/wire/WireTest.h"
#include "common/Assert.h"
#include "dawn/dawn_proc.h"
#include "dawn_wire/WireClient.h"
#include "dawn_wire/WireServer.h"
#include "tests/MockCallback.h"
#include "utils/TerribleCommandBuffer.h"
#include <array>
using namespace testing;
using namespace dawn_wire;
class WireMultipleDeviceTests : public testing::Test {
protected:
void SetUp() override {
dawnProcSetProcs(&dawn_wire::client::GetProcs());
}
void TearDown() override {
dawnProcSetProcs(nullptr);
}
class WireHolder {
public:
WireHolder() {
DawnProcTable mockProcs;
mApi.GetProcTableAndDevice(&mockProcs, &mServerDevice);
// Ignore Tick()
EXPECT_CALL(mApi, DeviceTick(_)).Times(AnyNumber());
// This SetCallback call cannot be ignored because it is done as soon as we start the
// server
EXPECT_CALL(mApi, OnDeviceSetUncapturedErrorCallback(_, _, _)).Times(Exactly(1));
EXPECT_CALL(mApi, OnDeviceSetDeviceLostCallback(_, _, _)).Times(Exactly(1));
mS2cBuf = std::make_unique<utils::TerribleCommandBuffer>();
mC2sBuf = std::make_unique<utils::TerribleCommandBuffer>();
WireServerDescriptor serverDesc = {};
serverDesc.device = mServerDevice;
serverDesc.procs = &mockProcs;
serverDesc.serializer = mS2cBuf.get();
EXPECT_CALL(mApi, DeviceReference(mServerDevice));
mWireServer.reset(new WireServer(serverDesc));
mC2sBuf->SetHandler(mWireServer.get());
WireClientDescriptor clientDesc = {};
clientDesc.serializer = mC2sBuf.get();
mWireClient.reset(new WireClient(clientDesc));
mS2cBuf->SetHandler(mWireClient.get());
mClientDevice = mWireClient->GetDevice();
// The GetQueue is done on WireClient startup so we expect it now.
mClientQueue = wgpuDeviceGetQueue(mClientDevice);
mServerQueue = mApi.GetNewQueue();
EXPECT_CALL(mApi, DeviceGetQueue(mServerDevice)).WillOnce(Return(mServerQueue));
FlushClient();
}
~WireHolder() {
mApi.IgnoreAllReleaseCalls();
mWireClient = nullptr;
// These are called on server destruction to clear the callbacks. They must not be
// called after the server is destroyed.
EXPECT_CALL(mApi, OnDeviceSetUncapturedErrorCallback(mServerDevice, nullptr, nullptr))
.Times(Exactly(1));
EXPECT_CALL(mApi, OnDeviceSetDeviceLostCallback(mServerDevice, nullptr, nullptr))
.Times(Exactly(1));
mWireServer = nullptr;
}
void FlushClient(bool success = true) {
ASSERT_EQ(mC2sBuf->Flush(), success);
}
void FlushServer(bool success = true) {
ASSERT_EQ(mS2cBuf->Flush(), success);
}
testing::StrictMock<MockProcTable>* Api() {
return &mApi;
}
WGPUDevice ClientDevice() {
return mClientDevice;
}
WGPUDevice ServerDevice() {
return mServerDevice;
}
WGPUQueue ClientQueue() {
return mClientQueue;
}
WGPUQueue ServerQueue() {
return mServerQueue;
}
private:
testing::StrictMock<MockProcTable> mApi;
std::unique_ptr<dawn_wire::WireServer> mWireServer;
std::unique_ptr<dawn_wire::WireClient> mWireClient;
std::unique_ptr<utils::TerribleCommandBuffer> mS2cBuf;
std::unique_ptr<utils::TerribleCommandBuffer> mC2sBuf;
WGPUDevice mServerDevice;
WGPUDevice mClientDevice;
WGPUQueue mServerQueue;
WGPUQueue mClientQueue;
};
void ExpectInjectedError(WireHolder* wire) {
std::string errorMessage;
EXPECT_CALL(*wire->Api(),
DeviceInjectError(wire->ServerDevice(), WGPUErrorType_Validation, _))
.WillOnce(Invoke([&](WGPUDevice device, WGPUErrorType type, const char* message) {
errorMessage = message;
// Mock the call to the error callback.
wire->Api()->CallDeviceSetUncapturedErrorCallbackCallback(device, type, message);
}));
wire->FlushClient();
// The error callback should be forwarded to the client.
StrictMock<MockCallback<WGPUErrorCallback>> mockErrorCallback;
wgpuDeviceSetUncapturedErrorCallback(wire->ClientDevice(), mockErrorCallback.Callback(),
mockErrorCallback.MakeUserdata(this));
EXPECT_CALL(mockErrorCallback, Call(WGPUErrorType_Validation, StrEq(errorMessage), this))
.Times(1);
wire->FlushServer();
}
};
// Test that using objects from a different device is a validation error.
TEST_F(WireMultipleDeviceTests, ValidatesSameDevice) {
WireHolder wireA;
WireHolder wireB;
// Create the fence
WGPUFenceDescriptor desc = {};
WGPUFence fenceA = wgpuQueueCreateFence(wireA.ClientQueue(), &desc);
// Signal with a fence from a different wire.
wgpuQueueSignal(wireB.ClientQueue(), fenceA, 1u);
// We should inject an error into the server.
ExpectInjectedError(&wireB);
}
// Test that objects created from mixed devices are an error to use.
TEST_F(WireMultipleDeviceTests, DifferentDeviceObjectCreationIsError) {
WireHolder wireA;
WireHolder wireB;
// Create a bind group layout on wire A.
WGPUBindGroupLayoutDescriptor bglDesc = {};
WGPUBindGroupLayout bglA = wgpuDeviceCreateBindGroupLayout(wireA.ClientDevice(), &bglDesc);
EXPECT_CALL(*wireA.Api(), DeviceCreateBindGroupLayout(wireA.ServerDevice(), _))
.WillOnce(Return(wireA.Api()->GetNewBindGroupLayout()));
wireA.FlushClient();
std::array<WGPUBindGroupEntry, 2> entries = {};
// Create a buffer on wire A.
WGPUBufferDescriptor bufferDesc = {};
entries[0].buffer = wgpuDeviceCreateBuffer(wireA.ClientDevice(), &bufferDesc);
EXPECT_CALL(*wireA.Api(), DeviceCreateBuffer(wireA.ServerDevice(), _))
.WillOnce(Return(wireA.Api()->GetNewBuffer()));
wireA.FlushClient();
// Create a sampler on wire B.
WGPUSamplerDescriptor samplerDesc = {};
entries[1].sampler = wgpuDeviceCreateSampler(wireB.ClientDevice(), &samplerDesc);
EXPECT_CALL(*wireB.Api(), DeviceCreateSampler(wireB.ServerDevice(), _))
.WillOnce(Return(wireB.Api()->GetNewSampler()));
wireB.FlushClient();
// Create a bind group on wire A using the bgl (A), buffer (A), and sampler (B).
WGPUBindGroupDescriptor bgDesc = {};
bgDesc.layout = bglA;
bgDesc.entryCount = entries.size();
bgDesc.entries = entries.data();
WGPUBindGroup bindGroupA = wgpuDeviceCreateBindGroup(wireA.ClientDevice(), &bgDesc);
// It should inject an error because the sampler is from a different device.
ExpectInjectedError(&wireA);
// The bind group was never created on a server because it failed device validation.
// Any commands that use it should error.
wgpuBindGroupRelease(bindGroupA);
wireA.FlushClient(false);
}
// Test that using objects, included in an extension struct,
// from a difference device is a validation error.
TEST_F(WireMultipleDeviceTests, ValidatesSameDeviceInExtensionStruct) {
WireHolder wireA;
WireHolder wireB;
WGPUShaderModuleDescriptor shaderModuleDesc = {};
WGPUShaderModule shaderModuleA =
wgpuDeviceCreateShaderModule(wireA.ClientDevice(), &shaderModuleDesc);
// Flush on wire A. We should see the shader module created.
EXPECT_CALL(*wireA.Api(), DeviceCreateShaderModule(wireA.ServerDevice(), _))
.WillOnce(Return(wireA.Api()->GetNewShaderModule()));
wireA.FlushClient();
WGPURenderPipelineDescriptorDummyExtension extDesc = {};
extDesc.chain.sType = WGPUSType_RenderPipelineDescriptorDummyExtension;
extDesc.dummyStage.entryPoint = "main";
extDesc.dummyStage.module = shaderModuleA;
WGPURenderPipelineDescriptor pipelineDesc = {};
pipelineDesc.nextInChain = &extDesc.chain;
WGPURenderPipeline pipelineB =
wgpuDeviceCreateRenderPipeline(wireB.ClientDevice(), &pipelineDesc);
// We should inject an error into the server.
ExpectInjectedError(&wireB);
// The pipeline was never created on a server because it failed device validation.
// Any commands that use it should error.
wgpuRenderPipelineRelease(pipelineB);
wireB.FlushClient(false);
}
// Test that using objects, included in a chained extension struct,
// from a different device is a validation error.
TEST_F(WireMultipleDeviceTests, ValidatesSameDeviceSecondInExtensionStructChain) {
WireHolder wireA;
WireHolder wireB;
WGPUShaderModuleDescriptor shaderModuleDesc = {};
WGPUShaderModule shaderModuleA =
wgpuDeviceCreateShaderModule(wireA.ClientDevice(), &shaderModuleDesc);
// Flush on wire A. We should see the shader module created.
EXPECT_CALL(*wireA.Api(), DeviceCreateShaderModule(wireA.ServerDevice(), _))
.WillOnce(Return(wireA.Api()->GetNewShaderModule()));
wireA.FlushClient();
WGPUShaderModule shaderModuleB =
wgpuDeviceCreateShaderModule(wireB.ClientDevice(), &shaderModuleDesc);
// Flush on wire B. We should see the shader module created.
EXPECT_CALL(*wireB.Api(), DeviceCreateShaderModule(wireB.ServerDevice(), _))
.WillOnce(Return(wireB.Api()->GetNewShaderModule()));
wireB.FlushClient();
WGPURenderPipelineDescriptorDummyExtension extDescA = {};
extDescA.chain.sType = WGPUSType_RenderPipelineDescriptorDummyExtension;
extDescA.dummyStage.entryPoint = "main";
extDescA.dummyStage.module = shaderModuleA;
WGPURenderPipelineDescriptorDummyExtension extDescB = {};
extDescB.chain.sType = WGPUSType_RenderPipelineDescriptorDummyExtension;
extDescB.chain.next = &extDescA.chain;
extDescB.dummyStage.entryPoint = "main";
extDescB.dummyStage.module = shaderModuleB;
// The first extension struct is from Device B, and the second is from A.
// We should validate the second struct, is from the same device.
WGPURenderPipelineDescriptor pipelineDesc = {};
pipelineDesc.nextInChain = &extDescB.chain;
WGPURenderPipeline pipelineB =
wgpuDeviceCreateRenderPipeline(wireB.ClientDevice(), &pipelineDesc);
// We should inject an error into the server.
ExpectInjectedError(&wireB);
// The pipeline was never created on a server because it failed device validation.
// Any commands that use it should error.
wgpuRenderPipelineRelease(pipelineB);
wireB.FlushClient(false);
}