Forward unhandled builder errors to the device

This commit is contained in:
Corentin Wallez 2017-07-27 19:43:43 -04:00 committed by Corentin Wallez
parent ba6a36c974
commit 3818e18c5c
3 changed files with 63 additions and 15 deletions

View File

@ -18,9 +18,10 @@
#include "common/Assert.h"
#include <cstring>
#include <iostream>
#include <cstdlib>
#include <map>
#include <memory>
#include <string>
#include <vector>
namespace nxt {
@ -31,24 +32,19 @@ namespace wire {
class Device;
void PrintBuilderError(nxtBuilderErrorStatus status, const char* message, nxtCallbackUserdata, nxtCallbackUserdata) {
if (status == NXT_BUILDER_ERROR_STATUS_SUCCESS || status == NXT_BUILDER_ERROR_STATUS_UNKNOWN) {
return;
}
std::cout << "Got a builder error " << status << ": " << message << std::endl;
}
struct BuilderCallbackData {
void Call(nxtBuilderErrorStatus status, const char* message) {
bool Call(nxtBuilderErrorStatus status, const char* message) {
if (canCall && callback != nullptr) {
canCall = true;
callback(status, message, userdata1, userdata2);
return true;
}
return false;
}
//* For help with development, prints all builder errors by default.
nxtBuilderErrorCallback callback = PrintBuilderError;
nxtBuilderErrorCallback callback = nullptr;
nxtCallbackUserdata userdata1 = 0;
nxtCallbackUserdata userdata2 = 0;
bool canCall = true;
@ -504,7 +500,13 @@ namespace wire {
return true;
}
builtObject->builderCallback.Call(static_cast<nxtBuilderErrorStatus>(cmd->status), cmd->GetMessage());
bool called = builtObject->builderCallback.Call(static_cast<nxtBuilderErrorStatus>(cmd->status), cmd->GetMessage());
// Unhandled builder errors are forwarded to the device
if (!called && cmd->status != NXT_BUILDER_ERROR_STATUS_SUCCESS && cmd->status != NXT_BUILDER_ERROR_STATUS_UNKNOWN) {
builtObject->device->HandleError(("Unhandled builder error: " + std::string(cmd->GetMessage())).c_str());
}
return true;
}
{% endfor %}

View File

@ -17,8 +17,6 @@
#include "backend/Device.h"
#include "common/Assert.h"
#include <iostream>
namespace backend {
bool BuilderBase::CanBeUsed() const {
@ -79,7 +77,8 @@ namespace backend {
result = nullptr;
}
if (!callback) std::cout << storedMessage << std::endl;
// Unhandled builder errors are promoted to device errors
if (!callback) device->HandleError(("Unhandled builder error: " + storedMessage).c_str());
} else {
ASSERT(storedStatus == nxt::BuilderErrorStatus::Success);
ASSERT(storedMessage.empty());

View File

@ -440,6 +440,53 @@ TEST_F(WireTests, UnknownBuilderErrorStatusCallback) {
}
}
// Test that a builder success status doesn't get forwarded to the device
TEST_F(WireTests, SuccessCallbackNotForwardedToDevice) {
nxtDeviceSetErrorCallback(device, ToMockDeviceErrorCallback, 0);
nxtBufferBuilder bufferBuilder = nxtDeviceCreateBufferBuilder(device);
nxtBufferBuilderGetResult(bufferBuilder);
nxtBufferBuilder apiBufferBuilder = api.GetNewBufferBuilder();
EXPECT_CALL(api, DeviceCreateBufferBuilder(apiDevice))
.WillOnce(Return(apiBufferBuilder));
nxtBuffer apiBuffer = api.GetNewBuffer();
EXPECT_CALL(api, BufferBuilderGetResult(apiBufferBuilder))
.WillOnce(InvokeWithoutArgs([&]() -> nxtBuffer {
api.CallBuilderErrorCallback(apiBufferBuilder, NXT_BUILDER_ERROR_STATUS_SUCCESS, "I like cheese");
return apiBuffer;
}));
FlushClient();
FlushServer();
}
// Test that a builder error status gets forwarded to the device
TEST_F(WireTests, ErrorCallbackForwardedToDevice) {
uint64_t userdata = 30495;
nxtDeviceSetErrorCallback(device, ToMockDeviceErrorCallback, userdata);
nxtBufferBuilder bufferBuilder = nxtDeviceCreateBufferBuilder(device);
nxtBufferBuilderGetResult(bufferBuilder);
nxtBufferBuilder apiBufferBuilder = api.GetNewBufferBuilder();
EXPECT_CALL(api, DeviceCreateBufferBuilder(apiDevice))
.WillOnce(Return(apiBufferBuilder));
EXPECT_CALL(api, BufferBuilderGetResult(apiBufferBuilder))
.WillOnce(InvokeWithoutArgs([&]() -> nxtBuffer {
api.CallBuilderErrorCallback(apiBufferBuilder, NXT_BUILDER_ERROR_STATUS_ERROR, "Error :(");
return nullptr;
}));
FlushClient();
EXPECT_CALL(*mockDeviceErrorCallback, Call(_, userdata)).Times(1);
FlushServer();
}
class WireSetCallbackTests : public WireTestsBase {
public:
WireSetCallbackTests() : WireTestsBase(false) {