diff --git a/generator/templates/wire/WireClient.cpp b/generator/templates/wire/WireClient.cpp index acd1588681..ad377d22fe 100644 --- a/generator/templates/wire/WireClient.cpp +++ b/generator/templates/wire/WireClient.cpp @@ -18,9 +18,10 @@ #include "common/Assert.h" #include -#include +#include #include #include +#include #include namespace nxt { @@ -31,24 +32,19 @@ namespace wire { class Device; - void PrintBuilderError(nxtBuilderErrorStatus status, const char* message, nxtCallbackUserdata, nxtCallbackUserdata) { - if (status == NXT_BUILDER_ERROR_STATUS_SUCCESS || status == NXT_BUILDER_ERROR_STATUS_UNKNOWN) { - return; - } - - std::cout << "Got a builder error " << status << ": " << message << std::endl; - } - struct BuilderCallbackData { - void Call(nxtBuilderErrorStatus status, const char* message) { + bool Call(nxtBuilderErrorStatus status, const char* message) { if (canCall && callback != nullptr) { canCall = true; callback(status, message, userdata1, userdata2); + return true; } + + return false; } //* For help with development, prints all builder errors by default. - nxtBuilderErrorCallback callback = PrintBuilderError; + nxtBuilderErrorCallback callback = nullptr; nxtCallbackUserdata userdata1 = 0; nxtCallbackUserdata userdata2 = 0; bool canCall = true; @@ -504,7 +500,13 @@ namespace wire { return true; } - builtObject->builderCallback.Call(static_cast(cmd->status), cmd->GetMessage()); + bool called = builtObject->builderCallback.Call(static_cast(cmd->status), cmd->GetMessage()); + + // Unhandled builder errors are forwarded to the device + if (!called && cmd->status != NXT_BUILDER_ERROR_STATUS_SUCCESS && cmd->status != NXT_BUILDER_ERROR_STATUS_UNKNOWN) { + builtObject->device->HandleError(("Unhandled builder error: " + std::string(cmd->GetMessage())).c_str()); + } + return true; } {% endfor %} diff --git a/src/backend/Builder.cpp b/src/backend/Builder.cpp index 83ebf2c82f..2e77ec2a4a 100644 --- a/src/backend/Builder.cpp +++ b/src/backend/Builder.cpp @@ -17,8 +17,6 @@ #include "backend/Device.h" #include "common/Assert.h" -#include - namespace backend { bool BuilderBase::CanBeUsed() const { @@ -79,7 +77,8 @@ namespace backend { result = nullptr; } - if (!callback) std::cout << storedMessage << std::endl; + // Unhandled builder errors are promoted to device errors + if (!callback) device->HandleError(("Unhandled builder error: " + storedMessage).c_str()); } else { ASSERT(storedStatus == nxt::BuilderErrorStatus::Success); ASSERT(storedMessage.empty()); diff --git a/src/tests/unittests/WireTests.cpp b/src/tests/unittests/WireTests.cpp index 61a2e11938..73a4976d6c 100644 --- a/src/tests/unittests/WireTests.cpp +++ b/src/tests/unittests/WireTests.cpp @@ -440,6 +440,53 @@ TEST_F(WireTests, UnknownBuilderErrorStatusCallback) { } } +// Test that a builder success status doesn't get forwarded to the device +TEST_F(WireTests, SuccessCallbackNotForwardedToDevice) { + nxtDeviceSetErrorCallback(device, ToMockDeviceErrorCallback, 0); + + nxtBufferBuilder bufferBuilder = nxtDeviceCreateBufferBuilder(device); + nxtBufferBuilderGetResult(bufferBuilder); + + nxtBufferBuilder apiBufferBuilder = api.GetNewBufferBuilder(); + EXPECT_CALL(api, DeviceCreateBufferBuilder(apiDevice)) + .WillOnce(Return(apiBufferBuilder)); + + nxtBuffer apiBuffer = api.GetNewBuffer(); + EXPECT_CALL(api, BufferBuilderGetResult(apiBufferBuilder)) + .WillOnce(InvokeWithoutArgs([&]() -> nxtBuffer { + api.CallBuilderErrorCallback(apiBufferBuilder, NXT_BUILDER_ERROR_STATUS_SUCCESS, "I like cheese"); + return apiBuffer; + })); + + FlushClient(); + FlushServer(); +} + +// Test that a builder error status gets forwarded to the device +TEST_F(WireTests, ErrorCallbackForwardedToDevice) { + uint64_t userdata = 30495; + nxtDeviceSetErrorCallback(device, ToMockDeviceErrorCallback, userdata); + + nxtBufferBuilder bufferBuilder = nxtDeviceCreateBufferBuilder(device); + nxtBufferBuilderGetResult(bufferBuilder); + + nxtBufferBuilder apiBufferBuilder = api.GetNewBufferBuilder(); + EXPECT_CALL(api, DeviceCreateBufferBuilder(apiDevice)) + .WillOnce(Return(apiBufferBuilder)); + + EXPECT_CALL(api, BufferBuilderGetResult(apiBufferBuilder)) + .WillOnce(InvokeWithoutArgs([&]() -> nxtBuffer { + api.CallBuilderErrorCallback(apiBufferBuilder, NXT_BUILDER_ERROR_STATUS_ERROR, "Error :("); + return nullptr; + })); + + FlushClient(); + + EXPECT_CALL(*mockDeviceErrorCallback, Call(_, userdata)).Times(1); + + FlushServer(); +} + class WireSetCallbackTests : public WireTestsBase { public: WireSetCallbackTests() : WireTestsBase(false) {