diff --git a/generator/templates/mock_api.cpp b/generator/templates/mock_api.cpp index 83d136dcbe..f6adc7879e 100644 --- a/generator/templates/mock_api.cpp +++ b/generator/templates/mock_api.cpp @@ -56,6 +56,15 @@ void ProcTableAsClass::DeviceSetErrorCallback(nxtDevice self, nxtDeviceErrorCall this->OnDeviceSetErrorCallback(self, callback, userdata); } +void ProcTableAsClass::CallDeviceErrorCallback(nxtDevice device, const char* message) { + auto object = reinterpret_cast(device); + object->deviceErrorCallback(message, object->userdata1); +} +void ProcTableAsClass::CallBuilderErrorCallback(void* builder , nxtBuilderErrorStatus status, const char* message) { + auto object = reinterpret_cast(builder); + object->builderErrorCallback(status, message, object->userdata1, object->userdata2); +} + {% for type in by_category["object"] if type.is_builder %} void ProcTableAsClass::{{as_MethodSuffix(type.name, Name("set error callback"))}}({{as_cType(type.name)}} self, nxtBuilderErrorCallback callback, nxtCallbackUserdata userdata1, nxtCallbackUserdata userdata2) { auto object = reinterpret_cast(self); diff --git a/generator/templates/mock_api.h b/generator/templates/mock_api.h index e5c70f86ed..7d3074e760 100644 --- a/generator/templates/mock_api.h +++ b/generator/templates/mock_api.h @@ -60,6 +60,10 @@ class ProcTableAsClass { virtual void OnDeviceSetErrorCallback(nxtDevice device, nxtDeviceErrorCallback callback, nxtCallbackUserdata userdata) = 0; virtual void OnBuilderSetErrorCallback(nxtBufferBuilder builder, nxtBuilderErrorCallback callback, nxtCallbackUserdata userdata1, nxtCallbackUserdata userdata2) = 0; + // Calls the stored callbacks + void CallDeviceErrorCallback(nxtDevice device, const char* message); + void CallBuilderErrorCallback(void* builder , nxtBuilderErrorStatus status, const char* message); + struct Object { ProcTableAsClass* procs = nullptr; nxtDeviceErrorCallback deviceErrorCallback = nullptr; diff --git a/src/tests/unittests/WireTests.cpp b/src/tests/unittests/WireTests.cpp index fa2938095a..e15e6cc5a8 100644 --- a/src/tests/unittests/WireTests.cpp +++ b/src/tests/unittests/WireTests.cpp @@ -23,6 +23,27 @@ using namespace testing; using namespace nxt::wire; + +class MockDeviceErrorCallback { + public: + MOCK_METHOD2(Call, void(const char* message, nxtCallbackUserdata userdata)); +}; + +static MockDeviceErrorCallback* mockDeviceErrorCallback = nullptr; +static void ToMockDeviceErrorCallback(const char* message, nxtCallbackUserdata userdata) { + mockDeviceErrorCallback->Call(message, userdata); +} + +class MockBuilderErrorCallback { + public: + MOCK_METHOD4(Call, void(nxtBuilderErrorStatus status, const char* message, nxtCallbackUserdata userdata1, nxtCallbackUserdata userdata2)); +}; + +static MockBuilderErrorCallback* mockBuilderErrorCallback = nullptr; +static void ToMockBuilderErrorCallback(nxtBuilderErrorStatus status, const char* message, nxtCallbackUserdata userdata1, nxtCallbackUserdata userdata2) { + mockBuilderErrorCallback->Call(status, message, userdata1, userdata2); +} + class WireTestsBase : public Test { protected: WireTestsBase(bool ignoreSetCallbackCalls) @@ -30,12 +51,16 @@ class WireTestsBase : public Test { } void SetUp() override { + mockDeviceErrorCallback = new MockDeviceErrorCallback; + mockBuilderErrorCallback = new MockBuilderErrorCallback; + nxtProcTable mockProcs; nxtDevice mockDevice; api.GetProcTableAndDevice(&mockProcs, &mockDevice); + // This SetCallback call cannot be ignored because it is done as soon as we start the server + EXPECT_CALL(api, OnDeviceSetErrorCallback(_, _, _)).Times(Exactly(1)); if (ignoreSetCallbackCalls) { - EXPECT_CALL(api, OnDeviceSetErrorCallback(_, _, _)).Times(Exactly(1)); EXPECT_CALL(api, OnBuilderSetErrorCallback(_, _, _, _)).Times(AnyNumber()); } @@ -59,6 +84,8 @@ class WireTestsBase : public Test { delete wireClient; delete c2sBuf; delete s2cBuf; + delete mockDeviceErrorCallback; + delete mockBuilderErrorCallback; } void FlushClient() { @@ -299,12 +326,73 @@ TEST_F(WireTests, ObjectsAsPointerArgument) { FlushClient(); } +class WireSetCallbackTests : public WireTestsBase { + public: + WireSetCallbackTests() : WireTestsBase(false) { + } +}; + +// Test the return wire for device error callbacks +TEST_F(WireSetCallbackTests, DeviceErrorCallback) { + uint64_t userdata = 3049785; + nxtDeviceSetErrorCallback(device, ToMockDeviceErrorCallback, userdata); + + // Setting the error callback should stay on the client side and do nothing + FlushClient(); + + // Calling the callback on the server side will result in the callback being called on the client side + api.CallDeviceErrorCallback(apiDevice, "Some error message"); + + EXPECT_CALL(*mockDeviceErrorCallback, Call(StrEq("Some error message"), userdata)) + .Times(1); + + FlushServer(); +} + +// Test the return wire for device error callbacks +TEST_F(WireSetCallbackTests, BuilderErrorCallback) { + uint64_t userdata1 = 982734; + uint64_t userdata2 = 982734239028; + + // Create the buffer builder, the callback is set immediately on the server side + nxtBufferBuilder bufferBuilder = nxtDeviceCreateBufferBuilder(device); + + nxtBufferBuilder apiBufferBuilder = api.GetNewBufferBuilder(); + EXPECT_CALL(api, DeviceCreateBufferBuilder(apiDevice)) + .WillOnce(Return(apiBufferBuilder)); + + EXPECT_CALL(api, OnBuilderSetErrorCallback(apiBufferBuilder, _, _, _)) + .Times(1); + + FlushClient(); + + // Setting the callback on the client side doesn't do anything on the server side + nxtBufferBuilderSetErrorCallback(bufferBuilder, ToMockBuilderErrorCallback, userdata1, userdata2); + FlushClient(); + + // Create an object so that it is a valid case to call the error callback + nxtBuffer buffer = nxtBufferBuilderGetResult(bufferBuilder); + + nxtBuffer apiBuffer = api.GetNewBuffer(); + EXPECT_CALL(api, BufferBuilderGetResult(apiBufferBuilder)) + .WillOnce(InvokeWithoutArgs([&]() -> nxtBuffer { + api.CallBuilderErrorCallback(apiBufferBuilder, NXT_BUILDER_ERROR_STATUS_SUCCESS, "Success!"); + return apiBuffer; + })); + + FlushClient(); + + // The error callback gets called on the client side + EXPECT_CALL(*mockBuilderErrorCallback, Call(NXT_BUILDER_ERROR_STATUS_SUCCESS, StrEq("Success!"), userdata1, userdata2)) + .Times(1); + + FlushServer(); +} + // TODO // - Object creation, then calls do nothing after error on builder // - Object creation then error then create object, then should do nothing. -// - Device error gets forwarded properly // - Builder error -// - An error gets forwarded properly // - No other call to builder after error // - No call to object after error // - No error -> success