Implement the builder error callback in the backends

This makes the Builder base class retain the error status, if any, and
call the callback on GetResult (or ~Builder, whichever comes first).
This commit is contained in:
Corentin Wallez 2017-05-08 15:17:44 +02:00 committed by Corentin Wallez
parent 5dc7915d38
commit 7f96177289
30 changed files with 247 additions and 86 deletions

View File

@ -309,15 +309,24 @@ def as_backendType(typ):
else: else:
return as_cType(typ.name) return as_cType(typ.name)
def cpp_native_methods(types, typ):
methods = typ.methods + typ.native_methods
if typ.is_builder:
methods.append(Method(Name('set error callback'), types['void'], [
MethodArgument(Name('callback'), types['builder error callback'], 'value'),
MethodArgument(Name('userdata1'), types['callback userdata'], 'value'),
MethodArgument(Name('userdata2'), types['callback userdata'], 'value'),
]))
return methods
def c_native_methods(types, typ): def c_native_methods(types, typ):
return cpp_native_methods(typ) + [ return cpp_native_methods(types, typ) + [
Method(Name('reference'), types['void'], []), Method(Name('reference'), types['void'], []),
Method(Name('release'), types['void'], []), Method(Name('release'), types['void'], []),
] ]
def cpp_native_methods(typ):
return typ.methods + typ.native_methods
def debug(text): def debug(text):
print(text) print(text)
@ -376,7 +385,7 @@ def main():
renders.append(FileRender('api.c', 'nxt/nxt.c', [base_params, api_params, c_params])) renders.append(FileRender('api.c', 'nxt/nxt.c', [base_params, api_params, c_params]))
if 'nxtcpp' in targets: if 'nxtcpp' in targets:
additional_params = {'native_methods': cpp_native_methods} additional_params = {'native_methods': lambda typ: cpp_native_methods(api_params['types'], typ)}
renders.append(FileRender('apicpp.h', 'nxt/nxtcpp.h', [base_params, api_params, additional_params])) renders.append(FileRender('apicpp.h', 'nxt/nxtcpp.h', [base_params, api_params, additional_params]))
renders.append(FileRender('apicpp.cpp', 'nxt/nxtcpp.cpp', [base_params, api_params, additional_params])) renders.append(FileRender('apicpp.cpp', 'nxt/nxtcpp.cpp', [base_params, api_params, additional_params]))

View File

@ -91,7 +91,7 @@ namespace {{namespace}} {
{%- endfor -%} {%- endfor -%}
) { ) {
{% if type.is_builder and method.name.canonical_case() not in ("release", "reference") %} {% if type.is_builder and method.name.canonical_case() not in ("release", "reference") %}
if (self->WasConsumed()) return false; if (!self->CanBeUsed()) return false;
{% else %} {% else %}
(void) self; (void) self;
{% endif %} {% endif %}
@ -121,6 +121,8 @@ namespace {{namespace}} {
{%- endfor -%} {%- endfor -%}
); );
//* Some function have very heavy checks in a seperate method, so that they
//* can be skipped in the NonValidatingEntryPoints.
{% if suffix in methodsWithExtraValidation %} {% if suffix in methodsWithExtraValidation %}
if (valid) { if (valid) {
valid = self->Validate{{method.name.CamelCase()}}( valid = self->Validate{{method.name.CamelCase()}}(
@ -130,12 +132,27 @@ namespace {{namespace}} {
); );
} }
{% endif %} {% endif %}
//* TODO Do the hand-written checks if necessary
//* On success, forward the arguments to the method, else error out without calling it //* If there is an error we forward it appropriately.
if (!valid) { if (!valid) {
// TODO get the device and give it the error? //* An error in a builder methods is always handled by the builder
std::cout << "Error in {{suffix}}" << std::endl; {% if type.is_builder %}
//* HACK(cwallez@chromium.org): special casing GetResult so that the error callback
//* is called if needed. Without this, no call to HandleResult would happen, and the
//* error callback would always get called with an Unknown status
{% if method.name.canonical_case() == "get result" %}
{{as_backendType(method.return_type)}} fakeResult = nullptr;
bool shouldBeFalse = self->HandleResult(fakeResult);
assert(shouldBeFalse == false);
{% else %}
self->HandleError("Error in {{suffix}}");
{% endif %}
{% else %}
// TODO get the device or builder and give it the error?
std::cout << "Error in {{suffix}}" << std::endl;
{% endif %}
} }
{% if method.return_type.name.canonical_case() == "void" %} {% if method.return_type.name.canonical_case() == "void" %}
if (!valid) return; if (!valid) return;
{% else %} {% else %}

View File

@ -36,6 +36,7 @@
// Custom types depending on the target language // Custom types depending on the target language
typedef uint64_t nxtCallbackUserdata; typedef uint64_t nxtCallbackUserdata;
typedef void (*nxtDeviceErrorCallback)(const char* message, nxtCallbackUserdata userdata); typedef void (*nxtDeviceErrorCallback)(const char* message, nxtCallbackUserdata userdata);
typedef void (*nxtBuilderErrorCallback)(nxtBuilderErrorStatus status, const char* message, nxtCallbackUserdata userdata1, nxtCallbackUserdata userdata2);
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {

View File

@ -172,6 +172,14 @@ namespace wire {
} }
{% endfor %} {% endfor %}
{% if type.is_builder %}
void Client{{as_MethodSuffix(type.name, Name("set error callback"))}}(nxtBuilderErrorCallback callback,
nxtCallbackUserdata userdata1,
nxtCallbackUserdata userdata2) {
//TODO(cwallez@chromium.org): will be implemented in a follow-up commit.
}
{% endif %}
{% if not type.name.canonical_case() == "device" %} {% if not type.name.canonical_case() == "device" %}
//* When an object's refcount reaches 0, notify the server side of it and delete it. //* When an object's refcount reaches 0, notify the server side of it and delete it.
void Client{{as_MethodSuffix(type.name, Name("release"))}}({{Type}}* obj) { void Client{{as_MethodSuffix(type.name, Name("release"))}}({{Type}}* obj) {

View File

@ -102,6 +102,18 @@
{"value": 3, "name": "storage buffer"} {"value": 3, "name": "storage buffer"}
] ]
}, },
"builder error status": {
"category": "enum",
"values": [
{"value": 0, "name": "success"},
{"value": 1, "name": "error", "TODO": "cwallez@chromium.org: recoverable errors like GPU OOM"},
{"value": 2, "name": "unknown"},
{"value": 3, "name": "context lost"}
]
},
"builder error callback": {
"category": "natively defined"
},
"buffer": { "buffer": {
"category": "object", "category": "object",
"methods": [ "methods": [

View File

@ -64,11 +64,10 @@ namespace backend {
BINDGROUP_PROPERTY_LAYOUT = 0x2, BINDGROUP_PROPERTY_LAYOUT = 0x2,
}; };
BindGroupBuilder::BindGroupBuilder(DeviceBase* device) BindGroupBuilder::BindGroupBuilder(DeviceBase* device) : Builder(device) {
: Builder(device) {
} }
BindGroupBase* BindGroupBuilder::GetResult() { BindGroupBase* BindGroupBuilder::GetResultImpl() {
constexpr int allProperties = BINDGROUP_PROPERTY_USAGE | BINDGROUP_PROPERTY_LAYOUT; constexpr int allProperties = BINDGROUP_PROPERTY_USAGE | BINDGROUP_PROPERTY_LAYOUT;
if ((propertiesSet & allProperties) != allProperties) { if ((propertiesSet & allProperties) != allProperties) {
HandleError("Bindgroup missing properties"); HandleError("Bindgroup missing properties");
@ -80,7 +79,6 @@ namespace backend {
return nullptr; return nullptr;
} }
MarkConsumed();
return device->CreateBindGroup(this); return device->CreateBindGroup(this);
} }

View File

@ -43,12 +43,11 @@ namespace backend {
std::array<Ref<RefCounted>, kMaxBindingsPerGroup> bindings; std::array<Ref<RefCounted>, kMaxBindingsPerGroup> bindings;
}; };
class BindGroupBuilder : public Builder { class BindGroupBuilder : public Builder<BindGroupBase> {
public: public:
BindGroupBuilder(DeviceBase* device); BindGroupBuilder(DeviceBase* device);
// NXT API // NXT API
BindGroupBase* GetResult();
void SetLayout(BindGroupLayoutBase* layout); void SetLayout(BindGroupLayoutBase* layout);
void SetUsage(nxt::BindGroupUsage usage); void SetUsage(nxt::BindGroupUsage usage);
@ -76,6 +75,7 @@ namespace backend {
private: private:
friend class BindGroupBase; friend class BindGroupBase;
BindGroupBase* GetResultImpl() override;
void SetBindingsBase(uint32_t start, uint32_t count, RefCounted* const * objects); void SetBindingsBase(uint32_t start, uint32_t count, RefCounted* const * objects);
bool SetBindingsValidationBase(uint32_t start, uint32_t count); bool SetBindingsValidationBase(uint32_t start, uint32_t count);

View File

@ -101,8 +101,7 @@ namespace backend {
return bindingInfo; return bindingInfo;
} }
BindGroupLayoutBase* BindGroupLayoutBuilder::GetResult() { BindGroupLayoutBase* BindGroupLayoutBuilder::GetResultImpl() {
MarkConsumed();
BindGroupLayoutBase blueprint(this, true); BindGroupLayoutBase blueprint(this, true);
auto* result = device->GetOrCreateBindGroupLayout(&blueprint, this); auto* result = device->GetOrCreateBindGroupLayout(&blueprint, this);

View File

@ -44,19 +44,20 @@ namespace backend {
bool blueprint = false; bool blueprint = false;
}; };
class BindGroupLayoutBuilder : public Builder { class BindGroupLayoutBuilder : public Builder<BindGroupLayoutBase> {
public: public:
BindGroupLayoutBuilder(DeviceBase* device); BindGroupLayoutBuilder(DeviceBase* device);
const BindGroupLayoutBase::LayoutBindingInfo& GetBindingInfo() const; const BindGroupLayoutBase::LayoutBindingInfo& GetBindingInfo() const;
// NXT API // NXT API
BindGroupLayoutBase* GetResult();
void SetBindingsType(nxt::ShaderStageBit visibility, nxt::BindingType bindingType, uint32_t start, uint32_t count); void SetBindingsType(nxt::ShaderStageBit visibility, nxt::BindingType bindingType, uint32_t start, uint32_t count);
private: private:
friend class BindGroupLayoutBase; friend class BindGroupLayoutBase;
BindGroupLayoutBase* GetResultImpl() override;
BindGroupLayoutBase::LayoutBindingInfo bindingInfo; BindGroupLayoutBase::LayoutBindingInfo bindingInfo;
}; };

View File

@ -121,7 +121,7 @@ namespace backend {
BufferBuilder::BufferBuilder(DeviceBase* device) : Builder(device) { BufferBuilder::BufferBuilder(DeviceBase* device) : Builder(device) {
} }
BufferBase* BufferBuilder::GetResult() { BufferBase* BufferBuilder::GetResultImpl() {
constexpr int allProperties = BUFFER_PROPERTY_ALLOWED_USAGE | BUFFER_PROPERTY_SIZE; constexpr int allProperties = BUFFER_PROPERTY_ALLOWED_USAGE | BUFFER_PROPERTY_SIZE;
if ((propertiesSet & allProperties) != allProperties) { if ((propertiesSet & allProperties) != allProperties) {
HandleError("Buffer missing properties"); HandleError("Buffer missing properties");
@ -133,7 +133,6 @@ namespace backend {
return nullptr; return nullptr;
} }
MarkConsumed();
return device->CreateBuffer(this); return device->CreateBuffer(this);
} }
@ -195,14 +194,13 @@ namespace backend {
: Builder(device), buffer(buffer) { : Builder(device), buffer(buffer) {
} }
BufferViewBase* BufferViewBuilder::GetResult() { BufferViewBase* BufferViewBuilder::GetResultImpl() {
constexpr int allProperties = BUFFER_VIEW_PROPERTY_EXTENT; constexpr int allProperties = BUFFER_VIEW_PROPERTY_EXTENT;
if ((propertiesSet & allProperties) != allProperties) { if ((propertiesSet & allProperties) != allProperties) {
HandleError("Buffer view missing properties"); HandleError("Buffer view missing properties");
return nullptr; return nullptr;
} }
MarkConsumed();
return device->CreateBufferView(this); return device->CreateBufferView(this);
} }

View File

@ -52,12 +52,11 @@ namespace backend {
bool frozen = false; bool frozen = false;
}; };
class BufferBuilder : public Builder { class BufferBuilder : public Builder<BufferBase> {
public: public:
BufferBuilder(DeviceBase* device); BufferBuilder(DeviceBase* device);
// NXT API // NXT API
BufferBase* GetResult();
void SetAllowedUsage(nxt::BufferUsageBit usage); void SetAllowedUsage(nxt::BufferUsageBit usage);
void SetInitialUsage(nxt::BufferUsageBit usage); void SetInitialUsage(nxt::BufferUsageBit usage);
void SetSize(uint32_t size); void SetSize(uint32_t size);
@ -65,6 +64,8 @@ namespace backend {
private: private:
friend class BufferBase; friend class BufferBase;
BufferBase* GetResultImpl() override;
uint32_t size; uint32_t size;
nxt::BufferUsageBit allowedUsage = nxt::BufferUsageBit::None; nxt::BufferUsageBit allowedUsage = nxt::BufferUsageBit::None;
nxt::BufferUsageBit currentUsage = nxt::BufferUsageBit::None; nxt::BufferUsageBit currentUsage = nxt::BufferUsageBit::None;
@ -85,17 +86,18 @@ namespace backend {
uint32_t offset; uint32_t offset;
}; };
class BufferViewBuilder : public Builder { class BufferViewBuilder : public Builder<BufferViewBase> {
public: public:
BufferViewBuilder(DeviceBase* device, BufferBase* buffer); BufferViewBuilder(DeviceBase* device, BufferBase* buffer);
// NXT API // NXT API
BufferViewBase* GetResult();
void SetExtent(uint32_t offset, uint32_t size); void SetExtent(uint32_t offset, uint32_t size);
private: private:
friend class BufferViewBase; friend class BufferViewBase;
BufferViewBase* GetResultImpl() override;
Ref<BufferBase> buffer; Ref<BufferBase> buffer;
uint32_t offset = 0; uint32_t offset = 0;
uint32_t size = 0; uint32_t size = 0;

View File

@ -18,20 +18,69 @@
namespace backend { namespace backend {
bool Builder::WasConsumed() const { bool BuilderBase::CanBeUsed() const {
return consumed; return !consumed && !gotStatus;
} }
Builder::Builder(DeviceBase* device) : device(device) { void BuilderBase::HandleError(const char* message) {
SetStatus(nxt::BuilderErrorStatus::Error, message);
} }
void Builder::MarkConsumed() { void BuilderBase::SetErrorCallback(nxt::BuilderErrorCallback callback,
nxt::CallbackUserdata userdata1,
nxt::CallbackUserdata userdata2) {
this->callback = callback;
this->userdata1 = userdata1;
this->userdata2 = userdata2;
}
BuilderBase::BuilderBase(DeviceBase* device) : device(device) {
}
BuilderBase::~BuilderBase() {
if (!consumed && callback != nullptr) {
callback(NXT_BUILDER_ERROR_STATUS_UNKNOWN, "Builder destroyed before GetResult", userdata1, userdata2);
}
}
void BuilderBase::SetStatus(nxt::BuilderErrorStatus status, const char* message) {
ASSERT(status != nxt::BuilderErrorStatus::Success);
ASSERT(status != nxt::BuilderErrorStatus::Unknown);
ASSERT(!gotStatus); // This is not strictly necessary but something to strive for.
gotStatus = true;
storedStatus = status;
storedMessage = std::move(message);
}
bool BuilderBase::HandleResult(RefCounted* result) {
// GetResult can only be called once.
ASSERT(!consumed); ASSERT(!consumed);
consumed = true; consumed = true;
}
void Builder::HandleError(const char* message) { // result == nullptr implies there was an error which implies we should have a status set.
device->HandleError(message); ASSERT(result != nullptr || gotStatus);
// If we have any error, then we have to return nullptr
if (gotStatus) {
ASSERT(storedStatus != nxt::BuilderErrorStatus::Success);
// The application will never see "result" so we need to remove the
// external ref here.
if (result != nullptr) {
result->Release();
result = nullptr;
}
} else {
ASSERT(storedStatus == nxt::BuilderErrorStatus::Success);
ASSERT(storedMessage.empty());
}
if (callback) {
callback(static_cast<nxtBuilderErrorStatus>(storedStatus), storedMessage.c_str(), userdata1, userdata2);
}
return result != nullptr;
} }
} }

View File

@ -18,24 +18,92 @@
#include "Forward.h" #include "Forward.h"
#include "RefCounted.h" #include "RefCounted.h"
#include "nxt/nxtcpp.h"
#include <string>
namespace backend { namespace backend {
class Builder : public RefCounted { // This class implements behavior shared by all builders:
// - Tracking whether GetResult has been called already, needed by the
// autogenerated code to prevent operations on "consumed" builders.
// - The error status callback of the API. The callback is guaranteed to be
// called exactly once with an error, a success, or "unknown" if the
// builder is destroyed; also the builder callback cannot be called before
// either the object is destroyed or GetResult is called.
//
// It is possible for error to be generated before the error callback is
// registered when a builder "set" function performance validation inline.
// Because of this we have to store the status in the builder and defer
// calling the callback to GetResult.
class BuilderBase : public RefCounted {
public: public:
bool WasConsumed() const; // Used by the auto-generated validation to prevent usage of the builder
// after GetResult or an error.
bool CanBeUsed() const;
// Set the status of the builder to an error.
void HandleError(const char* message); void HandleError(const char* message);
protected: // Internal API, to be used by builder and BackendProcTable only.
Builder(DeviceBase* device); // rReturns true for success cases, and calls the callback with appropriate status.
bool HandleResult(RefCounted* result);
void MarkConsumed(); // NXT API
void SetErrorCallback(nxt::BuilderErrorCallback callback,
nxt::CallbackUserdata userdata1,
nxt::CallbackUserdata userdata2);
protected:
BuilderBase(DeviceBase* device);
~BuilderBase();
DeviceBase* const device; DeviceBase* const device;
bool gotStatus = false;
private: private:
void SetStatus(nxt::BuilderErrorStatus status, const char* message);
nxt::BuilderErrorCallback callback = nullptr;
nxt::CallbackUserdata userdata1 = 0;
nxt::CallbackUserdata userdata2 = 0;
nxt::BuilderErrorStatus storedStatus = nxt::BuilderErrorStatus::Success;
std::string storedMessage;
bool consumed = false; bool consumed = false;
}; };
// This builder base class is used to capture the calls to GetResult and make sure
// that either:
// - There was an error, callback is called with an error and nullptr is returned.
// - There was no error, callback is called with success and a non-null T* is returned.
template<typename T>
class Builder : public BuilderBase {
public:
// NXT API
T* GetResult();
protected:
using BuilderBase::BuilderBase;
private:
virtual T* GetResultImpl() = 0;
};
template<typename T>
T* Builder<T>::GetResult() {
T* result = GetResultImpl();
// An object can have been returned but failed its initialization, so if an error
// happened, return nullptr instead of result.
if (HandleResult(result)) {
return result;
} else {
return nullptr;
}
}
} }
#endif // BACKEND_COMMON_BUILDER_H_ #endif // BACKEND_COMMON_BUILDER_H_

View File

@ -136,7 +136,7 @@ namespace backend {
} }
CommandBufferBuilder::~CommandBufferBuilder() { CommandBufferBuilder::~CommandBufferBuilder() {
if (!WasConsumed()) { if (!commandsAcquired) {
MoveToIterator(); MoveToIterator();
FreeCommands(&iterator); FreeCommands(&iterator);
} }
@ -484,12 +484,13 @@ namespace backend {
} }
CommandIterator CommandBufferBuilder::AcquireCommands() { CommandIterator CommandBufferBuilder::AcquireCommands() {
ASSERT(!commandsAcquired);
commandsAcquired = true;
return std::move(iterator); return std::move(iterator);
} }
CommandBufferBase* CommandBufferBuilder::GetResult() { CommandBufferBase* CommandBufferBuilder::GetResultImpl() {
MoveToIterator(); MoveToIterator();
MarkConsumed();
return device->CreateCommandBuffer(this); return device->CreateCommandBuffer(this);
} }

View File

@ -45,7 +45,7 @@ namespace backend {
std::set<TextureBase*> texturesTransitioned; std::set<TextureBase*> texturesTransitioned;
}; };
class CommandBufferBuilder : public Builder { class CommandBufferBuilder : public Builder<CommandBufferBase> {
public: public:
CommandBufferBuilder(DeviceBase* device); CommandBufferBuilder(DeviceBase* device);
~CommandBufferBuilder(); ~CommandBufferBuilder();
@ -55,8 +55,6 @@ namespace backend {
CommandIterator AcquireCommands(); CommandIterator AcquireCommands();
// NXT API // NXT API
CommandBufferBase* GetResult();
void CopyBufferToTexture(BufferBase* buffer, uint32_t bufferOffset, void CopyBufferToTexture(BufferBase* buffer, uint32_t bufferOffset,
TextureBase* texture, uint32_t x, uint32_t y, uint32_t z, TextureBase* texture, uint32_t x, uint32_t y, uint32_t z,
uint32_t width, uint32_t height, uint32_t depth, uint32_t level); uint32_t width, uint32_t height, uint32_t depth, uint32_t level);
@ -81,11 +79,13 @@ namespace backend {
private: private:
friend class CommandBufferBase; friend class CommandBufferBase;
CommandBufferBase* GetResultImpl() override;
void MoveToIterator(); void MoveToIterator();
CommandAllocator allocator; CommandAllocator allocator;
CommandIterator iterator; CommandIterator iterator;
bool movedToIterator = false; bool movedToIterator = false;
bool commandsAcquired = false;
// These pointers will remain valid since they are referenced by // These pointers will remain valid since they are referenced by
// the bind groups which are referenced by this command buffer. // the bind groups which are referenced by this command buffer.
std::set<BufferBase*> buffersTransitioned; std::set<BufferBase*> buffersTransitioned;

View File

@ -30,7 +30,7 @@ namespace backend {
~DeviceBase(); ~DeviceBase();
void HandleError(const char* message); void HandleError(const char* message);
void SetErrorCallback(nxt::DeviceErrorCallback, nxt::CallbackUserdata userdata); void SetErrorCallback(nxt::DeviceErrorCallback callback, nxt::CallbackUserdata userdata);
virtual BindGroupBase* CreateBindGroup(BindGroupBuilder* builder) = 0; virtual BindGroupBase* CreateBindGroup(BindGroupBuilder* builder) = 0;
virtual BindGroupLayoutBase* CreateBindGroupLayout(BindGroupLayoutBuilder* builder) = 0; virtual BindGroupLayoutBase* CreateBindGroupLayout(BindGroupLayoutBuilder* builder) = 0;
@ -86,7 +86,7 @@ namespace backend {
Caches* caches = nullptr; Caches* caches = nullptr;
nxt::DeviceErrorCallback errorCallback = nullptr; nxt::DeviceErrorCallback errorCallback = nullptr;
nxt::CallbackUserdata errorUserdata; nxt::CallbackUserdata errorUserdata = 0;
}; };
} }

View File

@ -81,7 +81,7 @@ namespace backend {
InputStateBuilder::InputStateBuilder(DeviceBase* device) : Builder(device) { InputStateBuilder::InputStateBuilder(DeviceBase* device) : Builder(device) {
} }
InputStateBase* InputStateBuilder::GetResult() { InputStateBase* InputStateBuilder::GetResultImpl() {
for (uint32_t location = 0; location < kMaxVertexAttributes; ++location) { for (uint32_t location = 0; location < kMaxVertexAttributes; ++location) {
if (attributesSetMask[location] && if (attributesSetMask[location] &&
!inputsSetMask[attributeInfos[location].bindingSlot]) { !inputsSetMask[attributeInfos[location].bindingSlot]) {
@ -90,7 +90,6 @@ namespace backend {
} }
} }
MarkConsumed();
return device->CreateInputState(this); return device->CreateInputState(this);
} }

View File

@ -57,12 +57,11 @@ namespace backend {
std::array<InputInfo, kMaxVertexInputs> inputInfos; std::array<InputInfo, kMaxVertexInputs> inputInfos;
}; };
class InputStateBuilder : public Builder { class InputStateBuilder : public Builder<InputStateBase> {
public: public:
InputStateBuilder(DeviceBase* device); InputStateBuilder(DeviceBase* device);
// NXT API // NXT API
InputStateBase* GetResult();
void SetAttribute(uint32_t shaderLocation, uint32_t bindingSlot, void SetAttribute(uint32_t shaderLocation, uint32_t bindingSlot,
nxt::VertexFormat format, uint32_t offset); nxt::VertexFormat format, uint32_t offset);
void SetInput(uint32_t bindingSlot, uint32_t stride, void SetInput(uint32_t bindingSlot, uint32_t stride,
@ -71,6 +70,8 @@ namespace backend {
private: private:
friend class InputStateBase; friend class InputStateBase;
InputStateBase* GetResultImpl() override;
std::bitset<kMaxVertexAttributes> attributesSetMask; std::bitset<kMaxVertexAttributes> attributesSetMask;
std::array<InputStateBase::AttributeInfo, kMaxVertexAttributes> attributeInfos; std::array<InputStateBase::AttributeInfo, kMaxVertexAttributes> attributeInfos;
std::bitset<kMaxVertexInputs> inputsSetMask; std::bitset<kMaxVertexInputs> inputsSetMask;

View File

@ -98,7 +98,7 @@ namespace backend {
return stages[stage]; return stages[stage];
} }
PipelineBase* PipelineBuilder::GetResult() { PipelineBase* PipelineBuilder::GetResultImpl() {
// TODO(cwallez@chromium.org): the layout should be required, and put the default objects in the device // TODO(cwallez@chromium.org): the layout should be required, and put the default objects in the device
if (!layout) { if (!layout) {
layout = device->CreatePipelineLayoutBuilder()->GetResult(); layout = device->CreatePipelineLayoutBuilder()->GetResult();
@ -107,7 +107,6 @@ namespace backend {
inputState = device->CreateInputStateBuilder()->GetResult(); inputState = device->CreateInputStateBuilder()->GetResult();
} }
MarkConsumed();
return device->CreatePipeline(this); return device->CreatePipeline(this);
} }

View File

@ -59,7 +59,7 @@ namespace backend {
Ref<InputStateBase> inputState; Ref<InputStateBase> inputState;
}; };
class PipelineBuilder : public Builder { class PipelineBuilder : public Builder<PipelineBase> {
public: public:
PipelineBuilder(DeviceBase* device); PipelineBuilder(DeviceBase* device);
@ -70,7 +70,6 @@ namespace backend {
const StageInfo& GetStageInfo(nxt::ShaderStage stage) const; const StageInfo& GetStageInfo(nxt::ShaderStage stage) const;
// NXT API // NXT API
PipelineBase* GetResult();
void SetLayout(PipelineLayoutBase* layout); void SetLayout(PipelineLayoutBase* layout);
void SetStage(nxt::ShaderStage stage, ShaderModuleBase* module, const char* entryPoint); void SetStage(nxt::ShaderStage stage, ShaderModuleBase* module, const char* entryPoint);
void SetInputState(InputStateBase* inputState); void SetInputState(InputStateBase* inputState);
@ -78,6 +77,8 @@ namespace backend {
private: private:
friend class PipelineBase; friend class PipelineBase;
PipelineBase* GetResultImpl() override;
Ref<PipelineLayoutBase> layout; Ref<PipelineLayoutBase> layout;
nxt::ShaderStageBit stageMask; nxt::ShaderStageBit stageMask;
PerStage<StageInfo> stages; PerStage<StageInfo> stages;

View File

@ -39,7 +39,7 @@ namespace backend {
PipelineLayoutBuilder::PipelineLayoutBuilder(DeviceBase* device) : Builder(device) { PipelineLayoutBuilder::PipelineLayoutBuilder(DeviceBase* device) : Builder(device) {
} }
PipelineLayoutBase* PipelineLayoutBuilder::GetResult() { PipelineLayoutBase* PipelineLayoutBuilder::GetResultImpl() {
// TODO(cwallez@chromium.org): this is a hack, have the null bind group layout somewhere in the device // TODO(cwallez@chromium.org): this is a hack, have the null bind group layout somewhere in the device
// once we have a cache of BGL // once we have a cache of BGL
for (size_t group = 0; group < kMaxBindGroups; ++group) { for (size_t group = 0; group < kMaxBindGroups; ++group) {
@ -48,7 +48,6 @@ namespace backend {
} }
} }
MarkConsumed();
return device->CreatePipelineLayout(this); return device->CreatePipelineLayout(this);
} }

View File

@ -40,17 +40,18 @@ namespace backend {
std::bitset<kMaxBindGroups> mask; std::bitset<kMaxBindGroups> mask;
}; };
class PipelineLayoutBuilder : public Builder { class PipelineLayoutBuilder : public Builder<PipelineLayoutBase> {
public: public:
PipelineLayoutBuilder(DeviceBase* device); PipelineLayoutBuilder(DeviceBase* device);
// NXT API // NXT API
PipelineLayoutBase* GetResult();
void SetBindGroupLayout(uint32_t groupIndex, BindGroupLayoutBase* layout); void SetBindGroupLayout(uint32_t groupIndex, BindGroupLayoutBase* layout);
private: private:
friend class PipelineLayoutBase; friend class PipelineLayoutBase;
PipelineLayoutBase* GetResultImpl() override;
BindGroupLayoutArray bindGroupLayouts; BindGroupLayoutArray bindGroupLayouts;
std::bitset<kMaxBindGroups> mask; std::bitset<kMaxBindGroups> mask;
}; };

View File

@ -30,8 +30,7 @@ namespace backend {
QueueBuilder::QueueBuilder(DeviceBase* device) : Builder(device) { QueueBuilder::QueueBuilder(DeviceBase* device) : Builder(device) {
} }
QueueBase* QueueBuilder::GetResult() { QueueBase* QueueBuilder::GetResultImpl() {
MarkConsumed();
return device->CreateQueue(this); return device->CreateQueue(this);
} }

View File

@ -41,12 +41,12 @@ namespace backend {
} }
}; };
class QueueBuilder : public Builder { class QueueBuilder : public Builder<QueueBase> {
public: public:
QueueBuilder(DeviceBase* device); QueueBuilder(DeviceBase* device);
// NXT API private:
QueueBase* GetResult(); QueueBase* GetResultImpl() override;
}; };
} }

View File

@ -43,11 +43,6 @@ namespace backend {
return mipMapFilter; return mipMapFilter;
} }
SamplerBase* SamplerBuilder::GetResult() {
MarkConsumed();
return device->CreateSampler(this);
}
void SamplerBuilder::SetFilterMode(nxt::FilterMode magFilter, nxt::FilterMode minFilter, nxt::FilterMode mipMapFilter) { void SamplerBuilder::SetFilterMode(nxt::FilterMode magFilter, nxt::FilterMode minFilter, nxt::FilterMode mipMapFilter) {
if ((propertiesSet & SAMPLER_PROPERTY_FILTER) != 0) { if ((propertiesSet & SAMPLER_PROPERTY_FILTER) != 0) {
HandleError("Sampler filter property set multiple times"); HandleError("Sampler filter property set multiple times");
@ -59,4 +54,9 @@ namespace backend {
this->mipMapFilter = mipMapFilter; this->mipMapFilter = mipMapFilter;
propertiesSet |= SAMPLER_PROPERTY_FILTER; propertiesSet |= SAMPLER_PROPERTY_FILTER;
} }
SamplerBase* SamplerBuilder::GetResultImpl() {
return device->CreateSampler(this);
}
} }

View File

@ -28,7 +28,7 @@ namespace backend {
SamplerBase(SamplerBuilder* builder); SamplerBase(SamplerBuilder* builder);
}; };
class SamplerBuilder : public Builder { class SamplerBuilder : public Builder<SamplerBase> {
public: public:
SamplerBuilder(DeviceBase* device); SamplerBuilder(DeviceBase* device);
@ -37,12 +37,13 @@ namespace backend {
nxt::FilterMode GetMipMapFilter() const; nxt::FilterMode GetMipMapFilter() const;
// NXT API // NXT API
SamplerBase* GetResult();
void SetFilterMode(nxt::FilterMode magFilter, nxt::FilterMode minFilter, nxt::FilterMode mipMapFilter); void SetFilterMode(nxt::FilterMode magFilter, nxt::FilterMode minFilter, nxt::FilterMode mipMapFilter);
private: private:
friend class SamplerBase; friend class SamplerBase;
SamplerBase* GetResultImpl() override;
int propertiesSet = 0; int propertiesSet = 0;
nxt::FilterMode magFilter = nxt::FilterMode::Nearest; nxt::FilterMode magFilter = nxt::FilterMode::Nearest;

View File

@ -198,13 +198,12 @@ namespace backend {
return std::move(spirv); return std::move(spirv);
} }
ShaderModuleBase* ShaderModuleBuilder::GetResult() { ShaderModuleBase* ShaderModuleBuilder::GetResultImpl() {
if (spirv.size() == 0) { if (spirv.size() == 0) {
HandleError("Shader module needs to have the source set"); HandleError("Shader module needs to have the source set");
return nullptr; return nullptr;
} }
MarkConsumed();
return device->CreateShaderModule(this); return device->CreateShaderModule(this);
} }

View File

@ -71,19 +71,20 @@ namespace backend {
nxt::ShaderStage executionModel; nxt::ShaderStage executionModel;
}; };
class ShaderModuleBuilder : public Builder { class ShaderModuleBuilder : public Builder<ShaderModuleBase> {
public: public:
ShaderModuleBuilder(DeviceBase* device); ShaderModuleBuilder(DeviceBase* device);
std::vector<uint32_t> AcquireSpirv(); std::vector<uint32_t> AcquireSpirv();
// NXT API // NXT API
ShaderModuleBase* GetResult();
void SetSource(uint32_t codeSize, const uint32_t* code); void SetSource(uint32_t codeSize, const uint32_t* code);
private: private:
friend class ShaderModuleBase; friend class ShaderModuleBase;
ShaderModuleBase* GetResultImpl() override;
std::vector<uint32_t> spirv; std::vector<uint32_t> spirv;
}; };

View File

@ -121,7 +121,7 @@ namespace backend {
: Builder(device) { : Builder(device) {
} }
TextureBase* TextureBuilder::GetResult() { TextureBase* TextureBuilder::GetResultImpl() {
constexpr int allProperties = TEXTURE_PROPERTY_DIMENSION | TEXTURE_PROPERTY_EXTENT | constexpr int allProperties = TEXTURE_PROPERTY_DIMENSION | TEXTURE_PROPERTY_EXTENT |
TEXTURE_PROPERTY_FORMAT | TEXTURE_PROPERTY_MIP_LEVELS | TEXTURE_PROPERTY_ALLOWED_USAGE; TEXTURE_PROPERTY_FORMAT | TEXTURE_PROPERTY_MIP_LEVELS | TEXTURE_PROPERTY_ALLOWED_USAGE;
if ((propertiesSet & allProperties) != allProperties) { if ((propertiesSet & allProperties) != allProperties) {
@ -136,7 +136,6 @@ namespace backend {
// TODO(cwallez@chromium.org): check stuff based on the dimension // TODO(cwallez@chromium.org): check stuff based on the dimension
MarkConsumed();
return device->CreateTexture(this); return device->CreateTexture(this);
} }
@ -223,8 +222,7 @@ namespace backend {
: Builder(device), texture(texture) { : Builder(device), texture(texture) {
} }
TextureViewBase* TextureViewBuilder::GetResult() { TextureViewBase* TextureViewBuilder::GetResultImpl() {
MarkConsumed();
return device->CreateTextureView(this); return device->CreateTextureView(this);
} }

View File

@ -60,12 +60,11 @@ namespace backend {
bool frozen = false; bool frozen = false;
}; };
class TextureBuilder : public Builder { class TextureBuilder : public Builder<TextureBase> {
public: public:
TextureBuilder(DeviceBase* device); TextureBuilder(DeviceBase* device);
// NXT API // NXT API
TextureBase* GetResult();
void SetDimension(nxt::TextureDimension dimension); void SetDimension(nxt::TextureDimension dimension);
void SetExtent(uint32_t width, uint32_t height, uint32_t depth); void SetExtent(uint32_t width, uint32_t height, uint32_t depth);
void SetFormat(nxt::TextureFormat format); void SetFormat(nxt::TextureFormat format);
@ -76,6 +75,8 @@ namespace backend {
private: private:
friend class TextureBase; friend class TextureBase;
TextureBase* GetResultImpl() override;
int propertiesSet = 0; int propertiesSet = 0;
nxt::TextureDimension dimension; nxt::TextureDimension dimension;
@ -96,16 +97,15 @@ namespace backend {
Ref<TextureBase> texture; Ref<TextureBase> texture;
}; };
class TextureViewBuilder : public Builder { class TextureViewBuilder : public Builder<TextureViewBase> {
public: public:
TextureViewBuilder(DeviceBase* device, TextureBase* texture); TextureViewBuilder(DeviceBase* device, TextureBase* texture);
// NXT API
TextureViewBase* GetResult();
private: private:
friend class TextureViewBase; friend class TextureViewBase;
TextureViewBase* GetResultImpl() override;
Ref<TextureBase> texture; Ref<TextureBase> texture;
}; };