From 51791e04094f7f399dab8c507b3ab5336936f2fb Mon Sep 17 00:00:00 2001 From: Kai Ninomiya Date: Tue, 28 Sep 2021 11:52:17 +0000 Subject: [PATCH] Add 'reason' argument to device lost callback Breaking change, but it should only require small changes in any project that relies on it, so just doing this instead of a two-stage deprecation. Will require a manual roll into (at least) Chromium. Bug: dawn:1080, chromium:1253721 Change-Id: I6699e0629c3b2fe63e7f9d5ba0a928f00316a588 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/64520 Reviewed-by: Austin Eng Reviewed-by: Corentin Wallez --- dawn.json | 5 ++-- dawn_wire.json | 1 + src/dawn_native/Device.cpp | 6 ++-- src/dawn_wire/client/Client.cpp | 3 +- src/dawn_wire/client/ClientDoers.cpp | 6 ++-- src/dawn_wire/client/Device.cpp | 6 ++-- src/dawn_wire/client/Device.h | 2 +- src/dawn_wire/server/Server.cpp | 4 +-- src/dawn_wire/server/Server.h | 2 +- src/dawn_wire/server/ServerDevice.cpp | 5 +++- src/tests/DawnTest.cpp | 2 +- src/tests/DawnTest.h | 2 +- src/tests/end2end/DeviceLostTests.cpp | 15 ++++++---- .../end2end/SwapChainValidationTests.cpp | 4 ++- .../unittests/wire/WireDisconnectTests.cpp | 28 ++++++++++++------- .../unittests/wire/WireErrorCallbackTests.cpp | 15 ++++++---- 16 files changed, 67 insertions(+), 39 deletions(-) diff --git a/dawn.json b/dawn.json index 7fadb940f7..569dde198d 100644 --- a/dawn.json +++ b/dawn.json @@ -1004,6 +1004,7 @@ "device lost callback": { "category": "callback", "args": [ + {"name": "reason", "type": "device lost reason"}, {"name": "message", "type": "char", "annotation": "const*"}, {"name": "userdata", "type": "void", "annotation": "*"} ] @@ -1012,8 +1013,8 @@ "category": "enum", "emscripten_no_enum_table": true, "values": [ - {"name": "undefined", "value": 0, "jsrepr": "undefined"}, - {"name": "destroyed", "value": 1} + {"value": 0, "name": "undefined", "jsrepr": "undefined"}, + {"value": 1, "name": "destroyed"} ] }, "device properties": { diff --git a/dawn_wire.json b/dawn_wire.json index 7d4b89c397..75526d52fe 100644 --- a/dawn_wire.json +++ b/dawn_wire.json @@ -118,6 +118,7 @@ ], "device lost callback" : [ { "name": "device", "type": "ObjectHandle", "handle_type": "device" }, + { "name": "reason", "type": "device lost reason" }, { "name": "message", "type": "char", "annotation": "const*", "length": "strlen" } ], "device pop error scope callback": [ diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp index 64c180db62..35f105d4f4 100644 --- a/src/dawn_native/Device.cpp +++ b/src/dawn_native/Device.cpp @@ -209,7 +209,7 @@ namespace dawn_native { } }; - mDeviceLostCallback = [](char const*, void*) { + mDeviceLostCallback = [](WGPUDeviceLostReason, char const*, void*) { static bool calledOnce = false; if (!calledOnce) { calledOnce = true; @@ -365,7 +365,9 @@ namespace dawn_native { if (type == InternalErrorType::DeviceLost) { // The device was lost, call the application callback. if (mDeviceLostCallback != nullptr) { - mDeviceLostCallback(message, mDeviceLostUserdata); + // TODO(crbug.com/dawn/628): Make sure the "Destroyed" reason is passed if + // the device was destroyed. + mDeviceLostCallback(WGPUDeviceLostReason_Undefined, message, mDeviceLostUserdata); mDeviceLostCallback = nullptr; } diff --git a/src/dawn_wire/client/Client.cpp b/src/dawn_wire/client/Client.cpp index a00bb5e90f..2d4445e794 100644 --- a/src/dawn_wire/client/Client.cpp +++ b/src/dawn_wire/client/Client.cpp @@ -138,7 +138,8 @@ namespace dawn_wire { namespace client { { for (LinkNode* device = deviceList.head(); device != deviceList.end(); device = device->next()) { - static_cast(device->value())->HandleDeviceLost("GPU connection lost"); + static_cast(device->value()) + ->HandleDeviceLost(WGPUDeviceLostReason_Undefined, "GPU connection lost"); } } for (auto& objectList : mObjects) { diff --git a/src/dawn_wire/client/ClientDoers.cpp b/src/dawn_wire/client/ClientDoers.cpp index e3e34bde15..e6665abf01 100644 --- a/src/dawn_wire/client/ClientDoers.cpp +++ b/src/dawn_wire/client/ClientDoers.cpp @@ -52,12 +52,14 @@ namespace dawn_wire { namespace client { return true; } - bool Client::DoDeviceLostCallback(Device* device, char const* message) { + bool Client::DoDeviceLostCallback(Device* device, + WGPUDeviceLostReason reason, + char const* message) { if (device == nullptr) { // The device might have been deleted or recreated so this isn't an error. return true; } - device->HandleDeviceLost(message); + device->HandleDeviceLost(reason, message); return true; } diff --git a/src/dawn_wire/client/Device.cpp b/src/dawn_wire/client/Device.cpp index 21df71e9f5..8379d51b1c 100644 --- a/src/dawn_wire/client/Device.cpp +++ b/src/dawn_wire/client/Device.cpp @@ -35,7 +35,7 @@ namespace dawn_wire { namespace client { } }; - mDeviceLostCallback = [](char const*, void*) { + mDeviceLostCallback = [](WGPUDeviceLostReason, char const*, void*) { static bool calledOnce = false; if (!calledOnce) { calledOnce = true; @@ -80,10 +80,10 @@ namespace dawn_wire { namespace client { } } - void Device::HandleDeviceLost(const char* message) { + void Device::HandleDeviceLost(WGPUDeviceLostReason reason, const char* message) { if (mDeviceLostCallback && !mDidRunLostCallback) { mDidRunLostCallback = true; - mDeviceLostCallback(message, mDeviceLostUserdata); + mDeviceLostCallback(reason, message, mDeviceLostUserdata); } } diff --git a/src/dawn_wire/client/Device.h b/src/dawn_wire/client/Device.h index ae2d9fd5e4..426799c1eb 100644 --- a/src/dawn_wire/client/Device.h +++ b/src/dawn_wire/client/Device.h @@ -53,7 +53,7 @@ namespace dawn_wire { namespace client { void HandleError(WGPUErrorType errorType, const char* message); void HandleLogging(WGPULoggingType loggingType, const char* message); - void HandleDeviceLost(const char* message); + void HandleDeviceLost(WGPUDeviceLostReason reason, const char* message); bool OnPopErrorScopeCallback(uint64_t requestSerial, WGPUErrorType type, const char* message); diff --git a/src/dawn_wire/server/Server.cpp b/src/dawn_wire/server/Server.cpp index 7a504ef4de..8297cbdcf4 100644 --- a/src/dawn_wire/server/Server.cpp +++ b/src/dawn_wire/server/Server.cpp @@ -144,9 +144,9 @@ namespace dawn_wire { namespace server { data->info.get()); mProcs.deviceSetDeviceLostCallback( device, - [](const char* message, void* userdata) { + [](WGPUDeviceLostReason reason, const char* message, void* userdata) { DeviceInfo* info = static_cast(userdata); - info->server->OnDeviceLost(info->self, message); + info->server->OnDeviceLost(info->self, reason, message); }, data->info.get()); diff --git a/src/dawn_wire/server/Server.h b/src/dawn_wire/server/Server.h index 711813537d..b4429871f7 100644 --- a/src/dawn_wire/server/Server.h +++ b/src/dawn_wire/server/Server.h @@ -196,7 +196,7 @@ namespace dawn_wire { namespace server { // Error callbacks void OnUncapturedError(ObjectHandle device, WGPUErrorType type, const char* message); - void OnDeviceLost(ObjectHandle device, const char* message); + void OnDeviceLost(ObjectHandle device, WGPUDeviceLostReason reason, const char* message); void OnLogging(ObjectHandle device, WGPULoggingType type, const char* message); void OnDevicePopErrorScope(WGPUErrorType type, const char* message, diff --git a/src/dawn_wire/server/ServerDevice.cpp b/src/dawn_wire/server/ServerDevice.cpp index 939e632e56..c8cddf4e50 100644 --- a/src/dawn_wire/server/ServerDevice.cpp +++ b/src/dawn_wire/server/ServerDevice.cpp @@ -59,9 +59,12 @@ namespace dawn_wire { namespace server { SerializeCommand(cmd); } - void Server::OnDeviceLost(ObjectHandle device, const char* message) { + void Server::OnDeviceLost(ObjectHandle device, + WGPUDeviceLostReason reason, + const char* message) { ReturnDeviceLostCallbackCmd cmd; cmd.device = device; + cmd.reason = reason; cmd.message = message; SerializeCommand(cmd); diff --git a/src/tests/DawnTest.cpp b/src/tests/DawnTest.cpp index 23d9433a23..9f0c5021cf 100644 --- a/src/tests/DawnTest.cpp +++ b/src/tests/DawnTest.cpp @@ -1010,7 +1010,7 @@ void DawnTestBase::OnDeviceError(WGPUErrorType type, const char* message, void* self->mError = true; } -void DawnTestBase::OnDeviceLost(const char* message, void* userdata) { +void DawnTestBase::OnDeviceLost(WGPUDeviceLostReason reason, const char* message, void* userdata) { // Using ADD_FAILURE + ASSERT instead of FAIL to prevent the current test from continuing with a // corrupt state. ADD_FAILURE() << "Device Lost during test: " << message; diff --git a/src/tests/DawnTest.h b/src/tests/DawnTest.h index aec8a66387..206c15def7 100644 --- a/src/tests/DawnTest.h +++ b/src/tests/DawnTest.h @@ -492,7 +492,7 @@ class DawnTestBase { // Tracking for validation errors static void OnDeviceError(WGPUErrorType type, const char* message, void* userdata); - static void OnDeviceLost(const char* message, void* userdata); + static void OnDeviceLost(WGPUDeviceLostReason reason, const char* message, void* userdata); bool mExpectError = false; bool mError = false; diff --git a/src/tests/end2end/DeviceLostTests.cpp b/src/tests/end2end/DeviceLostTests.cpp index 3866a0747e..91bdabe973 100644 --- a/src/tests/end2end/DeviceLostTests.cpp +++ b/src/tests/end2end/DeviceLostTests.cpp @@ -25,12 +25,14 @@ using namespace testing; class MockDeviceLostCallback { public: - MOCK_METHOD(void, Call, (const char* message, void* userdata)); + MOCK_METHOD(void, Call, (WGPUDeviceLostReason reason, const char* message, void* userdata)); }; static std::unique_ptr mockDeviceLostCallback; -static void ToMockDeviceLostCallback(const char* message, void* userdata) { - mockDeviceLostCallback->Call(message, userdata); +static void ToMockDeviceLostCallback(WGPUDeviceLostReason reason, + const char* message, + void* userdata) { + mockDeviceLostCallback->Call(reason, message, userdata); DawnTestBase* self = static_cast(userdata); self->StartExpectDeviceError(); } @@ -67,7 +69,8 @@ class DeviceLostTest : public DawnTest { } void LoseForTesting() { - EXPECT_CALL(*mockDeviceLostCallback, Call(_, this)).Times(1); + EXPECT_CALL(*mockDeviceLostCallback, Call(WGPUDeviceLostReason_Undefined, _, this)) + .Times(1); device.LoseForTesting(); } @@ -427,13 +430,13 @@ TEST_P(DeviceLostTest, QueueOnSubmittedWorkDoneBeforeLossFails) { // Test that LostForTesting can only be called on one time TEST_P(DeviceLostTest, LoseForTestingOnce) { // First LoseForTesting call should occur normally. The callback is already set in SetUp. - EXPECT_CALL(*mockDeviceLostCallback, Call(_, this)).Times(1); + EXPECT_CALL(*mockDeviceLostCallback, Call(WGPUDeviceLostReason_Undefined, _, this)).Times(1); device.LoseForTesting(); // Second LoseForTesting call should result in no callbacks. The LoseForTesting will return // without doing anything when it sees that device has already been lost. device.SetDeviceLostCallback(ToMockDeviceLostCallback, this); - EXPECT_CALL(*mockDeviceLostCallback, Call(_, this)).Times(0); + EXPECT_CALL(*mockDeviceLostCallback, Call(_, _, this)).Times(0); device.LoseForTesting(); } diff --git a/src/tests/end2end/SwapChainValidationTests.cpp b/src/tests/end2end/SwapChainValidationTests.cpp index 77180db0d3..1926f627c2 100644 --- a/src/tests/end2end/SwapChainValidationTests.cpp +++ b/src/tests/end2end/SwapChainValidationTests.cpp @@ -319,7 +319,9 @@ TEST_P(SwapChainValidationTests, SwapChainIsInvalidAfterSurfaceDestruction_After } // Test that after Device is Lost, all swap chain operations fail -static void ToMockDeviceLostCallback(const char* message, void* userdata) { +static void ToMockDeviceLostCallback(WGPUDeviceLostReason reason, + const char* message, + void* userdata) { DawnTest* self = static_cast(userdata); self->StartExpectDeviceError(); } diff --git a/src/tests/unittests/wire/WireDisconnectTests.cpp b/src/tests/unittests/wire/WireDisconnectTests.cpp index 81b75b26f3..fee5858a41 100644 --- a/src/tests/unittests/wire/WireDisconnectTests.cpp +++ b/src/tests/unittests/wire/WireDisconnectTests.cpp @@ -69,7 +69,8 @@ TEST_F(WireDisconnectTests, CallsDeviceLostCallback) { mockDeviceLostCallback.MakeUserdata(this)); // Disconnect the wire client. We should receive device lost only once. - EXPECT_CALL(mockDeviceLostCallback, Call(_, this)).Times(Exactly(1)); + EXPECT_CALL(mockDeviceLostCallback, Call(WGPUDeviceLostReason_Undefined, _, this)) + .Times(Exactly(1)); GetWireClient()->Disconnect(); GetWireClient()->Disconnect(); } @@ -80,14 +81,17 @@ TEST_F(WireDisconnectTests, ServerLostThenDisconnect) { wgpuDeviceSetDeviceLostCallback(device, mockDeviceLostCallback.Callback(), mockDeviceLostCallback.MakeUserdata(this)); - api.CallDeviceSetDeviceLostCallbackCallback(apiDevice, "some reason"); + api.CallDeviceSetDeviceLostCallbackCallback(apiDevice, WGPUDeviceLostReason_Undefined, + "some reason"); // Flush the device lost return command. - EXPECT_CALL(mockDeviceLostCallback, Call(StrEq("some reason"), this)).Times(Exactly(1)); + EXPECT_CALL(mockDeviceLostCallback, + Call(WGPUDeviceLostReason_Undefined, StrEq("some reason"), this)) + .Times(Exactly(1)); FlushServer(); // Disconnect the client. We shouldn't see the lost callback again. - EXPECT_CALL(mockDeviceLostCallback, Call(_, _)).Times(Exactly(0)); + EXPECT_CALL(mockDeviceLostCallback, Call(_, _, _)).Times(Exactly(0)); GetWireClient()->Disconnect(); } @@ -98,13 +102,15 @@ TEST_F(WireDisconnectTests, ServerLostThenDisconnectInCallback) { wgpuDeviceSetDeviceLostCallback(device, mockDeviceLostCallback.Callback(), mockDeviceLostCallback.MakeUserdata(this)); - api.CallDeviceSetDeviceLostCallbackCallback(apiDevice, "lost reason"); + api.CallDeviceSetDeviceLostCallbackCallback(apiDevice, WGPUDeviceLostReason_Undefined, + "lost reason"); // Disconnect the client inside the lost callback. We should see the callback // only once. - EXPECT_CALL(mockDeviceLostCallback, Call(StrEq("lost reason"), this)) + EXPECT_CALL(mockDeviceLostCallback, + Call(WGPUDeviceLostReason_Undefined, StrEq("lost reason"), this)) .WillOnce(InvokeWithoutArgs([&]() { - EXPECT_CALL(mockDeviceLostCallback, Call(_, _)).Times(Exactly(0)); + EXPECT_CALL(mockDeviceLostCallback, Call(_, _, _)).Times(Exactly(0)); GetWireClient()->Disconnect(); })); FlushServer(); @@ -117,13 +123,15 @@ TEST_F(WireDisconnectTests, DisconnectThenServerLost) { mockDeviceLostCallback.MakeUserdata(this)); // Disconnect the client. We should see the callback once. - EXPECT_CALL(mockDeviceLostCallback, Call(_, this)).Times(Exactly(1)); + EXPECT_CALL(mockDeviceLostCallback, Call(WGPUDeviceLostReason_Undefined, _, this)) + .Times(Exactly(1)); GetWireClient()->Disconnect(); // Lose the device on the server. The client callback shouldn't be // called again. - api.CallDeviceSetDeviceLostCallbackCallback(apiDevice, "lost reason"); - EXPECT_CALL(mockDeviceLostCallback, Call(_, _)).Times(Exactly(0)); + api.CallDeviceSetDeviceLostCallbackCallback(apiDevice, WGPUDeviceLostReason_Undefined, + "lost reason"); + EXPECT_CALL(mockDeviceLostCallback, Call(_, _, _)).Times(Exactly(0)); FlushServer(); } diff --git a/src/tests/unittests/wire/WireErrorCallbackTests.cpp b/src/tests/unittests/wire/WireErrorCallbackTests.cpp index fb4cad3aec..32ba5f56c4 100644 --- a/src/tests/unittests/wire/WireErrorCallbackTests.cpp +++ b/src/tests/unittests/wire/WireErrorCallbackTests.cpp @@ -56,12 +56,14 @@ namespace { class MockDeviceLostCallback { public: - MOCK_METHOD(void, Call, (const char* message, void* userdata)); + MOCK_METHOD(void, Call, (WGPUDeviceLostReason reason, const char* message, void* userdata)); }; std::unique_ptr> mockDeviceLostCallback; - void ToMockDeviceLostCallback(const char* message, void* userdata) { - mockDeviceLostCallback->Call(message, userdata); + void ToMockDeviceLostCallback(WGPUDeviceLostReason reason, + const char* message, + void* userdata) { + mockDeviceLostCallback->Call(reason, message, userdata); } } // anonymous namespace @@ -319,9 +321,12 @@ TEST_F(WireErrorCallbackTests, DeviceLostCallback) { // Calling the callback on the server side will result in the callback being called on the // client side - api.CallDeviceSetDeviceLostCallbackCallback(apiDevice, "Some error message"); + api.CallDeviceSetDeviceLostCallbackCallback(apiDevice, WGPUDeviceLostReason_Undefined, + "Some error message"); - EXPECT_CALL(*mockDeviceLostCallback, Call(StrEq("Some error message"), this)).Times(1); + EXPECT_CALL(*mockDeviceLostCallback, + Call(WGPUDeviceLostReason_Undefined, StrEq("Some error message"), this)) + .Times(1); FlushServer(); }