From 7f961772894af9492cdb9a90df7939edbbf13bad Mon Sep 17 00:00:00 2001 From: Corentin Wallez Date: Mon, 8 May 2017 15:17:44 +0200 Subject: [PATCH] Implement the builder error callback in the backends This makes the Builder base class retain the error status, if any, and call the callback on GetResult (or ~Builder, whichever comes first). --- generator/main.py | 19 ++++-- generator/templates/BackendProcTable.cpp | 27 ++++++-- generator/templates/api.h | 1 + generator/templates/wire/WireClient.cpp | 8 +++ next.json | 12 ++++ src/backend/common/BindGroup.cpp | 6 +- src/backend/common/BindGroup.h | 4 +- src/backend/common/BindGroupLayout.cpp | 3 +- src/backend/common/BindGroupLayout.h | 5 +- src/backend/common/Buffer.cpp | 6 +- src/backend/common/Buffer.h | 10 +-- src/backend/common/Builder.cpp | 63 ++++++++++++++++--- src/backend/common/Builder.h | 78 ++++++++++++++++++++++-- src/backend/common/CommandBuffer.cpp | 7 ++- src/backend/common/CommandBuffer.h | 6 +- src/backend/common/Device.h | 4 +- src/backend/common/InputState.cpp | 3 +- src/backend/common/InputState.h | 5 +- src/backend/common/Pipeline.cpp | 3 +- src/backend/common/Pipeline.h | 5 +- src/backend/common/PipelineLayout.cpp | 3 +- src/backend/common/PipelineLayout.h | 5 +- src/backend/common/Queue.cpp | 3 +- src/backend/common/Queue.h | 6 +- src/backend/common/Sampler.cpp | 10 +-- src/backend/common/Sampler.h | 5 +- src/backend/common/ShaderModule.cpp | 3 +- src/backend/common/ShaderModule.h | 5 +- src/backend/common/Texture.cpp | 6 +- src/backend/common/Texture.h | 12 ++-- 30 files changed, 247 insertions(+), 86 deletions(-) diff --git a/generator/main.py b/generator/main.py index 7b2970ea80..eff89a699c 100644 --- a/generator/main.py +++ b/generator/main.py @@ -309,15 +309,24 @@ def as_backendType(typ): else: return as_cType(typ.name) +def cpp_native_methods(types, typ): + methods = typ.methods + typ.native_methods + + if typ.is_builder: + methods.append(Method(Name('set error callback'), types['void'], [ + MethodArgument(Name('callback'), types['builder error callback'], 'value'), + MethodArgument(Name('userdata1'), types['callback userdata'], 'value'), + MethodArgument(Name('userdata2'), types['callback userdata'], 'value'), + ])) + + return methods + def c_native_methods(types, typ): - return cpp_native_methods(typ) + [ + return cpp_native_methods(types, typ) + [ Method(Name('reference'), types['void'], []), Method(Name('release'), types['void'], []), ] -def cpp_native_methods(typ): - return typ.methods + typ.native_methods - def debug(text): print(text) @@ -376,7 +385,7 @@ def main(): renders.append(FileRender('api.c', 'nxt/nxt.c', [base_params, api_params, c_params])) if 'nxtcpp' in targets: - additional_params = {'native_methods': cpp_native_methods} + additional_params = {'native_methods': lambda typ: cpp_native_methods(api_params['types'], typ)} renders.append(FileRender('apicpp.h', 'nxt/nxtcpp.h', [base_params, api_params, additional_params])) renders.append(FileRender('apicpp.cpp', 'nxt/nxtcpp.cpp', [base_params, api_params, additional_params])) diff --git a/generator/templates/BackendProcTable.cpp b/generator/templates/BackendProcTable.cpp index b92503bd5e..55a67ab1d5 100644 --- a/generator/templates/BackendProcTable.cpp +++ b/generator/templates/BackendProcTable.cpp @@ -91,7 +91,7 @@ namespace {{namespace}} { {%- endfor -%} ) { {% if type.is_builder and method.name.canonical_case() not in ("release", "reference") %} - if (self->WasConsumed()) return false; + if (!self->CanBeUsed()) return false; {% else %} (void) self; {% endif %} @@ -121,6 +121,8 @@ namespace {{namespace}} { {%- endfor -%} ); + //* Some function have very heavy checks in a seperate method, so that they + //* can be skipped in the NonValidatingEntryPoints. {% if suffix in methodsWithExtraValidation %} if (valid) { valid = self->Validate{{method.name.CamelCase()}}( @@ -130,12 +132,27 @@ namespace {{namespace}} { ); } {% endif %} - //* TODO Do the hand-written checks if necessary - //* On success, forward the arguments to the method, else error out without calling it + + //* If there is an error we forward it appropriately. if (!valid) { - // TODO get the device and give it the error? - std::cout << "Error in {{suffix}}" << std::endl; + //* An error in a builder methods is always handled by the builder + {% if type.is_builder %} + //* HACK(cwallez@chromium.org): special casing GetResult so that the error callback + //* is called if needed. Without this, no call to HandleResult would happen, and the + //* error callback would always get called with an Unknown status + {% if method.name.canonical_case() == "get result" %} + {{as_backendType(method.return_type)}} fakeResult = nullptr; + bool shouldBeFalse = self->HandleResult(fakeResult); + assert(shouldBeFalse == false); + {% else %} + self->HandleError("Error in {{suffix}}"); + {% endif %} + {% else %} + // TODO get the device or builder and give it the error? + std::cout << "Error in {{suffix}}" << std::endl; + {% endif %} } + {% if method.return_type.name.canonical_case() == "void" %} if (!valid) return; {% else %} diff --git a/generator/templates/api.h b/generator/templates/api.h index d627ff9dbb..e31f645328 100644 --- a/generator/templates/api.h +++ b/generator/templates/api.h @@ -36,6 +36,7 @@ // Custom types depending on the target language typedef uint64_t nxtCallbackUserdata; typedef void (*nxtDeviceErrorCallback)(const char* message, nxtCallbackUserdata userdata); +typedef void (*nxtBuilderErrorCallback)(nxtBuilderErrorStatus status, const char* message, nxtCallbackUserdata userdata1, nxtCallbackUserdata userdata2); #ifdef __cplusplus extern "C" { diff --git a/generator/templates/wire/WireClient.cpp b/generator/templates/wire/WireClient.cpp index 8331c88919..dd60bb84e2 100644 --- a/generator/templates/wire/WireClient.cpp +++ b/generator/templates/wire/WireClient.cpp @@ -172,6 +172,14 @@ namespace wire { } {% endfor %} + {% if type.is_builder %} + void Client{{as_MethodSuffix(type.name, Name("set error callback"))}}(nxtBuilderErrorCallback callback, + nxtCallbackUserdata userdata1, + nxtCallbackUserdata userdata2) { + //TODO(cwallez@chromium.org): will be implemented in a follow-up commit. + } + {% endif %} + {% if not type.name.canonical_case() == "device" %} //* When an object's refcount reaches 0, notify the server side of it and delete it. void Client{{as_MethodSuffix(type.name, Name("release"))}}({{Type}}* obj) { diff --git a/next.json b/next.json index de12d128a9..861d3d3117 100644 --- a/next.json +++ b/next.json @@ -102,6 +102,18 @@ {"value": 3, "name": "storage buffer"} ] }, + "builder error status": { + "category": "enum", + "values": [ + {"value": 0, "name": "success"}, + {"value": 1, "name": "error", "TODO": "cwallez@chromium.org: recoverable errors like GPU OOM"}, + {"value": 2, "name": "unknown"}, + {"value": 3, "name": "context lost"} + ] + }, + "builder error callback": { + "category": "natively defined" + }, "buffer": { "category": "object", "methods": [ diff --git a/src/backend/common/BindGroup.cpp b/src/backend/common/BindGroup.cpp index a4435e2b07..cd50ed8d42 100644 --- a/src/backend/common/BindGroup.cpp +++ b/src/backend/common/BindGroup.cpp @@ -64,11 +64,10 @@ namespace backend { BINDGROUP_PROPERTY_LAYOUT = 0x2, }; - BindGroupBuilder::BindGroupBuilder(DeviceBase* device) - : Builder(device) { + BindGroupBuilder::BindGroupBuilder(DeviceBase* device) : Builder(device) { } - BindGroupBase* BindGroupBuilder::GetResult() { + BindGroupBase* BindGroupBuilder::GetResultImpl() { constexpr int allProperties = BINDGROUP_PROPERTY_USAGE | BINDGROUP_PROPERTY_LAYOUT; if ((propertiesSet & allProperties) != allProperties) { HandleError("Bindgroup missing properties"); @@ -80,7 +79,6 @@ namespace backend { return nullptr; } - MarkConsumed(); return device->CreateBindGroup(this); } diff --git a/src/backend/common/BindGroup.h b/src/backend/common/BindGroup.h index 472e822b98..72ad021856 100644 --- a/src/backend/common/BindGroup.h +++ b/src/backend/common/BindGroup.h @@ -43,12 +43,11 @@ namespace backend { std::array, kMaxBindingsPerGroup> bindings; }; - class BindGroupBuilder : public Builder { + class BindGroupBuilder : public Builder { public: BindGroupBuilder(DeviceBase* device); // NXT API - BindGroupBase* GetResult(); void SetLayout(BindGroupLayoutBase* layout); void SetUsage(nxt::BindGroupUsage usage); @@ -76,6 +75,7 @@ namespace backend { private: friend class BindGroupBase; + BindGroupBase* GetResultImpl() override; void SetBindingsBase(uint32_t start, uint32_t count, RefCounted* const * objects); bool SetBindingsValidationBase(uint32_t start, uint32_t count); diff --git a/src/backend/common/BindGroupLayout.cpp b/src/backend/common/BindGroupLayout.cpp index 588d32a63a..0a17090df4 100644 --- a/src/backend/common/BindGroupLayout.cpp +++ b/src/backend/common/BindGroupLayout.cpp @@ -101,8 +101,7 @@ namespace backend { return bindingInfo; } - BindGroupLayoutBase* BindGroupLayoutBuilder::GetResult() { - MarkConsumed(); + BindGroupLayoutBase* BindGroupLayoutBuilder::GetResultImpl() { BindGroupLayoutBase blueprint(this, true); auto* result = device->GetOrCreateBindGroupLayout(&blueprint, this); diff --git a/src/backend/common/BindGroupLayout.h b/src/backend/common/BindGroupLayout.h index ef5c2a689a..d2334d2515 100644 --- a/src/backend/common/BindGroupLayout.h +++ b/src/backend/common/BindGroupLayout.h @@ -44,19 +44,20 @@ namespace backend { bool blueprint = false; }; - class BindGroupLayoutBuilder : public Builder { + class BindGroupLayoutBuilder : public Builder { public: BindGroupLayoutBuilder(DeviceBase* device); const BindGroupLayoutBase::LayoutBindingInfo& GetBindingInfo() const; // NXT API - BindGroupLayoutBase* GetResult(); void SetBindingsType(nxt::ShaderStageBit visibility, nxt::BindingType bindingType, uint32_t start, uint32_t count); private: friend class BindGroupLayoutBase; + BindGroupLayoutBase* GetResultImpl() override; + BindGroupLayoutBase::LayoutBindingInfo bindingInfo; }; diff --git a/src/backend/common/Buffer.cpp b/src/backend/common/Buffer.cpp index 6191763b1e..ded0f24490 100644 --- a/src/backend/common/Buffer.cpp +++ b/src/backend/common/Buffer.cpp @@ -121,7 +121,7 @@ namespace backend { BufferBuilder::BufferBuilder(DeviceBase* device) : Builder(device) { } - BufferBase* BufferBuilder::GetResult() { + BufferBase* BufferBuilder::GetResultImpl() { constexpr int allProperties = BUFFER_PROPERTY_ALLOWED_USAGE | BUFFER_PROPERTY_SIZE; if ((propertiesSet & allProperties) != allProperties) { HandleError("Buffer missing properties"); @@ -133,7 +133,6 @@ namespace backend { return nullptr; } - MarkConsumed(); return device->CreateBuffer(this); } @@ -195,14 +194,13 @@ namespace backend { : Builder(device), buffer(buffer) { } - BufferViewBase* BufferViewBuilder::GetResult() { + BufferViewBase* BufferViewBuilder::GetResultImpl() { constexpr int allProperties = BUFFER_VIEW_PROPERTY_EXTENT; if ((propertiesSet & allProperties) != allProperties) { HandleError("Buffer view missing properties"); return nullptr; } - MarkConsumed(); return device->CreateBufferView(this); } diff --git a/src/backend/common/Buffer.h b/src/backend/common/Buffer.h index 71ec79af4e..06f082d3a6 100644 --- a/src/backend/common/Buffer.h +++ b/src/backend/common/Buffer.h @@ -52,12 +52,11 @@ namespace backend { bool frozen = false; }; - class BufferBuilder : public Builder { + class BufferBuilder : public Builder { public: BufferBuilder(DeviceBase* device); // NXT API - BufferBase* GetResult(); void SetAllowedUsage(nxt::BufferUsageBit usage); void SetInitialUsage(nxt::BufferUsageBit usage); void SetSize(uint32_t size); @@ -65,6 +64,8 @@ namespace backend { private: friend class BufferBase; + BufferBase* GetResultImpl() override; + uint32_t size; nxt::BufferUsageBit allowedUsage = nxt::BufferUsageBit::None; nxt::BufferUsageBit currentUsage = nxt::BufferUsageBit::None; @@ -85,17 +86,18 @@ namespace backend { uint32_t offset; }; - class BufferViewBuilder : public Builder { + class BufferViewBuilder : public Builder { public: BufferViewBuilder(DeviceBase* device, BufferBase* buffer); // NXT API - BufferViewBase* GetResult(); void SetExtent(uint32_t offset, uint32_t size); private: friend class BufferViewBase; + BufferViewBase* GetResultImpl() override; + Ref buffer; uint32_t offset = 0; uint32_t size = 0; diff --git a/src/backend/common/Builder.cpp b/src/backend/common/Builder.cpp index cf3b63c5fa..97555c9485 100644 --- a/src/backend/common/Builder.cpp +++ b/src/backend/common/Builder.cpp @@ -18,20 +18,69 @@ namespace backend { - bool Builder::WasConsumed() const { - return consumed; + bool BuilderBase::CanBeUsed() const { + return !consumed && !gotStatus; } - Builder::Builder(DeviceBase* device) : device(device) { + void BuilderBase::HandleError(const char* message) { + SetStatus(nxt::BuilderErrorStatus::Error, message); } - void Builder::MarkConsumed() { + void BuilderBase::SetErrorCallback(nxt::BuilderErrorCallback callback, + nxt::CallbackUserdata userdata1, + nxt::CallbackUserdata userdata2) { + this->callback = callback; + this->userdata1 = userdata1; + this->userdata2 = userdata2; + } + + BuilderBase::BuilderBase(DeviceBase* device) : device(device) { + } + + BuilderBase::~BuilderBase() { + if (!consumed && callback != nullptr) { + callback(NXT_BUILDER_ERROR_STATUS_UNKNOWN, "Builder destroyed before GetResult", userdata1, userdata2); + } + } + + void BuilderBase::SetStatus(nxt::BuilderErrorStatus status, const char* message) { + ASSERT(status != nxt::BuilderErrorStatus::Success); + ASSERT(status != nxt::BuilderErrorStatus::Unknown); + ASSERT(!gotStatus); // This is not strictly necessary but something to strive for. + gotStatus = true; + + storedStatus = status; + storedMessage = std::move(message); + } + + bool BuilderBase::HandleResult(RefCounted* result) { + // GetResult can only be called once. ASSERT(!consumed); consumed = true; - } - void Builder::HandleError(const char* message) { - device->HandleError(message); + // result == nullptr implies there was an error which implies we should have a status set. + ASSERT(result != nullptr || gotStatus); + + // If we have any error, then we have to return nullptr + if (gotStatus) { + ASSERT(storedStatus != nxt::BuilderErrorStatus::Success); + + // The application will never see "result" so we need to remove the + // external ref here. + if (result != nullptr) { + result->Release(); + result = nullptr; + } + } else { + ASSERT(storedStatus == nxt::BuilderErrorStatus::Success); + ASSERT(storedMessage.empty()); + } + + if (callback) { + callback(static_cast(storedStatus), storedMessage.c_str(), userdata1, userdata2); + } + + return result != nullptr; } } diff --git a/src/backend/common/Builder.h b/src/backend/common/Builder.h index e3431c5b3b..98ae461d9a 100644 --- a/src/backend/common/Builder.h +++ b/src/backend/common/Builder.h @@ -18,24 +18,92 @@ #include "Forward.h" #include "RefCounted.h" +#include "nxt/nxtcpp.h" + +#include + namespace backend { - class Builder : public RefCounted { + // This class implements behavior shared by all builders: + // - Tracking whether GetResult has been called already, needed by the + // autogenerated code to prevent operations on "consumed" builders. + // - The error status callback of the API. The callback is guaranteed to be + // called exactly once with an error, a success, or "unknown" if the + // builder is destroyed; also the builder callback cannot be called before + // either the object is destroyed or GetResult is called. + // + // It is possible for error to be generated before the error callback is + // registered when a builder "set" function performance validation inline. + // Because of this we have to store the status in the builder and defer + // calling the callback to GetResult. + + class BuilderBase : public RefCounted { public: - bool WasConsumed() const; + // Used by the auto-generated validation to prevent usage of the builder + // after GetResult or an error. + bool CanBeUsed() const; + + // Set the status of the builder to an error. void HandleError(const char* message); - protected: - Builder(DeviceBase* device); + // Internal API, to be used by builder and BackendProcTable only. + // rReturns true for success cases, and calls the callback with appropriate status. + bool HandleResult(RefCounted* result); - void MarkConsumed(); + // NXT API + void SetErrorCallback(nxt::BuilderErrorCallback callback, + nxt::CallbackUserdata userdata1, + nxt::CallbackUserdata userdata2); + + protected: + BuilderBase(DeviceBase* device); + ~BuilderBase(); DeviceBase* const device; + bool gotStatus = false; private: + void SetStatus(nxt::BuilderErrorStatus status, const char* message); + + nxt::BuilderErrorCallback callback = nullptr; + nxt::CallbackUserdata userdata1 = 0; + nxt::CallbackUserdata userdata2 = 0; + + nxt::BuilderErrorStatus storedStatus = nxt::BuilderErrorStatus::Success; + std::string storedMessage; + bool consumed = false; }; + // This builder base class is used to capture the calls to GetResult and make sure + // that either: + // - There was an error, callback is called with an error and nullptr is returned. + // - There was no error, callback is called with success and a non-null T* is returned. + template + class Builder : public BuilderBase { + public: + // NXT API + T* GetResult(); + + protected: + using BuilderBase::BuilderBase; + + private: + virtual T* GetResultImpl() = 0; + }; + + template + T* Builder::GetResult() { + T* result = GetResultImpl(); + // An object can have been returned but failed its initialization, so if an error + // happened, return nullptr instead of result. + if (HandleResult(result)) { + return result; + } else { + return nullptr; + } + } + } #endif // BACKEND_COMMON_BUILDER_H_ diff --git a/src/backend/common/CommandBuffer.cpp b/src/backend/common/CommandBuffer.cpp index b95cbee59a..8fadbea77b 100644 --- a/src/backend/common/CommandBuffer.cpp +++ b/src/backend/common/CommandBuffer.cpp @@ -136,7 +136,7 @@ namespace backend { } CommandBufferBuilder::~CommandBufferBuilder() { - if (!WasConsumed()) { + if (!commandsAcquired) { MoveToIterator(); FreeCommands(&iterator); } @@ -484,12 +484,13 @@ namespace backend { } CommandIterator CommandBufferBuilder::AcquireCommands() { + ASSERT(!commandsAcquired); + commandsAcquired = true; return std::move(iterator); } - CommandBufferBase* CommandBufferBuilder::GetResult() { + CommandBufferBase* CommandBufferBuilder::GetResultImpl() { MoveToIterator(); - MarkConsumed(); return device->CreateCommandBuffer(this); } diff --git a/src/backend/common/CommandBuffer.h b/src/backend/common/CommandBuffer.h index 02b85c7de1..341a9d9020 100644 --- a/src/backend/common/CommandBuffer.h +++ b/src/backend/common/CommandBuffer.h @@ -45,7 +45,7 @@ namespace backend { std::set texturesTransitioned; }; - class CommandBufferBuilder : public Builder { + class CommandBufferBuilder : public Builder { public: CommandBufferBuilder(DeviceBase* device); ~CommandBufferBuilder(); @@ -55,8 +55,6 @@ namespace backend { CommandIterator AcquireCommands(); // NXT API - CommandBufferBase* GetResult(); - void CopyBufferToTexture(BufferBase* buffer, uint32_t bufferOffset, TextureBase* texture, uint32_t x, uint32_t y, uint32_t z, uint32_t width, uint32_t height, uint32_t depth, uint32_t level); @@ -81,11 +79,13 @@ namespace backend { private: friend class CommandBufferBase; + CommandBufferBase* GetResultImpl() override; void MoveToIterator(); CommandAllocator allocator; CommandIterator iterator; bool movedToIterator = false; + bool commandsAcquired = false; // These pointers will remain valid since they are referenced by // the bind groups which are referenced by this command buffer. std::set buffersTransitioned; diff --git a/src/backend/common/Device.h b/src/backend/common/Device.h index 03821d64c9..c502167e34 100644 --- a/src/backend/common/Device.h +++ b/src/backend/common/Device.h @@ -30,7 +30,7 @@ namespace backend { ~DeviceBase(); void HandleError(const char* message); - void SetErrorCallback(nxt::DeviceErrorCallback, nxt::CallbackUserdata userdata); + void SetErrorCallback(nxt::DeviceErrorCallback callback, nxt::CallbackUserdata userdata); virtual BindGroupBase* CreateBindGroup(BindGroupBuilder* builder) = 0; virtual BindGroupLayoutBase* CreateBindGroupLayout(BindGroupLayoutBuilder* builder) = 0; @@ -86,7 +86,7 @@ namespace backend { Caches* caches = nullptr; nxt::DeviceErrorCallback errorCallback = nullptr; - nxt::CallbackUserdata errorUserdata; + nxt::CallbackUserdata errorUserdata = 0; }; } diff --git a/src/backend/common/InputState.cpp b/src/backend/common/InputState.cpp index a0ea2be4b5..29a6bec759 100644 --- a/src/backend/common/InputState.cpp +++ b/src/backend/common/InputState.cpp @@ -81,7 +81,7 @@ namespace backend { InputStateBuilder::InputStateBuilder(DeviceBase* device) : Builder(device) { } - InputStateBase* InputStateBuilder::GetResult() { + InputStateBase* InputStateBuilder::GetResultImpl() { for (uint32_t location = 0; location < kMaxVertexAttributes; ++location) { if (attributesSetMask[location] && !inputsSetMask[attributeInfos[location].bindingSlot]) { @@ -90,7 +90,6 @@ namespace backend { } } - MarkConsumed(); return device->CreateInputState(this); } diff --git a/src/backend/common/InputState.h b/src/backend/common/InputState.h index 2fa48dbf6a..f14af2fb9f 100644 --- a/src/backend/common/InputState.h +++ b/src/backend/common/InputState.h @@ -57,12 +57,11 @@ namespace backend { std::array inputInfos; }; - class InputStateBuilder : public Builder { + class InputStateBuilder : public Builder { public: InputStateBuilder(DeviceBase* device); // NXT API - InputStateBase* GetResult(); void SetAttribute(uint32_t shaderLocation, uint32_t bindingSlot, nxt::VertexFormat format, uint32_t offset); void SetInput(uint32_t bindingSlot, uint32_t stride, @@ -71,6 +70,8 @@ namespace backend { private: friend class InputStateBase; + InputStateBase* GetResultImpl() override; + std::bitset attributesSetMask; std::array attributeInfos; std::bitset inputsSetMask; diff --git a/src/backend/common/Pipeline.cpp b/src/backend/common/Pipeline.cpp index ded1a0bb4c..cebd9b5cbc 100644 --- a/src/backend/common/Pipeline.cpp +++ b/src/backend/common/Pipeline.cpp @@ -98,7 +98,7 @@ namespace backend { return stages[stage]; } - PipelineBase* PipelineBuilder::GetResult() { + PipelineBase* PipelineBuilder::GetResultImpl() { // TODO(cwallez@chromium.org): the layout should be required, and put the default objects in the device if (!layout) { layout = device->CreatePipelineLayoutBuilder()->GetResult(); @@ -107,7 +107,6 @@ namespace backend { inputState = device->CreateInputStateBuilder()->GetResult(); } - MarkConsumed(); return device->CreatePipeline(this); } diff --git a/src/backend/common/Pipeline.h b/src/backend/common/Pipeline.h index e2fb2490a4..0d01f17410 100644 --- a/src/backend/common/Pipeline.h +++ b/src/backend/common/Pipeline.h @@ -59,7 +59,7 @@ namespace backend { Ref inputState; }; - class PipelineBuilder : public Builder { + class PipelineBuilder : public Builder { public: PipelineBuilder(DeviceBase* device); @@ -70,7 +70,6 @@ namespace backend { const StageInfo& GetStageInfo(nxt::ShaderStage stage) const; // NXT API - PipelineBase* GetResult(); void SetLayout(PipelineLayoutBase* layout); void SetStage(nxt::ShaderStage stage, ShaderModuleBase* module, const char* entryPoint); void SetInputState(InputStateBase* inputState); @@ -78,6 +77,8 @@ namespace backend { private: friend class PipelineBase; + PipelineBase* GetResultImpl() override; + Ref layout; nxt::ShaderStageBit stageMask; PerStage stages; diff --git a/src/backend/common/PipelineLayout.cpp b/src/backend/common/PipelineLayout.cpp index 0400c4d105..cbca824836 100644 --- a/src/backend/common/PipelineLayout.cpp +++ b/src/backend/common/PipelineLayout.cpp @@ -39,7 +39,7 @@ namespace backend { PipelineLayoutBuilder::PipelineLayoutBuilder(DeviceBase* device) : Builder(device) { } - PipelineLayoutBase* PipelineLayoutBuilder::GetResult() { + PipelineLayoutBase* PipelineLayoutBuilder::GetResultImpl() { // TODO(cwallez@chromium.org): this is a hack, have the null bind group layout somewhere in the device // once we have a cache of BGL for (size_t group = 0; group < kMaxBindGroups; ++group) { @@ -48,7 +48,6 @@ namespace backend { } } - MarkConsumed(); return device->CreatePipelineLayout(this); } diff --git a/src/backend/common/PipelineLayout.h b/src/backend/common/PipelineLayout.h index fcb74b910a..2f28f16e25 100644 --- a/src/backend/common/PipelineLayout.h +++ b/src/backend/common/PipelineLayout.h @@ -40,17 +40,18 @@ namespace backend { std::bitset mask; }; - class PipelineLayoutBuilder : public Builder { + class PipelineLayoutBuilder : public Builder { public: PipelineLayoutBuilder(DeviceBase* device); // NXT API - PipelineLayoutBase* GetResult(); void SetBindGroupLayout(uint32_t groupIndex, BindGroupLayoutBase* layout); private: friend class PipelineLayoutBase; + PipelineLayoutBase* GetResultImpl() override; + BindGroupLayoutArray bindGroupLayouts; std::bitset mask; }; diff --git a/src/backend/common/Queue.cpp b/src/backend/common/Queue.cpp index f6c89ca4a0..be2ed7c256 100644 --- a/src/backend/common/Queue.cpp +++ b/src/backend/common/Queue.cpp @@ -30,8 +30,7 @@ namespace backend { QueueBuilder::QueueBuilder(DeviceBase* device) : Builder(device) { } - QueueBase* QueueBuilder::GetResult() { - MarkConsumed(); + QueueBase* QueueBuilder::GetResultImpl() { return device->CreateQueue(this); } diff --git a/src/backend/common/Queue.h b/src/backend/common/Queue.h index 4cac9f5f68..da56f633cc 100644 --- a/src/backend/common/Queue.h +++ b/src/backend/common/Queue.h @@ -41,12 +41,12 @@ namespace backend { } }; - class QueueBuilder : public Builder { + class QueueBuilder : public Builder { public: QueueBuilder(DeviceBase* device); - // NXT API - QueueBase* GetResult(); + private: + QueueBase* GetResultImpl() override; }; } diff --git a/src/backend/common/Sampler.cpp b/src/backend/common/Sampler.cpp index 8e6ece2194..17cd7531e5 100644 --- a/src/backend/common/Sampler.cpp +++ b/src/backend/common/Sampler.cpp @@ -43,11 +43,6 @@ namespace backend { return mipMapFilter; } - SamplerBase* SamplerBuilder::GetResult() { - MarkConsumed(); - return device->CreateSampler(this); - } - void SamplerBuilder::SetFilterMode(nxt::FilterMode magFilter, nxt::FilterMode minFilter, nxt::FilterMode mipMapFilter) { if ((propertiesSet & SAMPLER_PROPERTY_FILTER) != 0) { HandleError("Sampler filter property set multiple times"); @@ -59,4 +54,9 @@ namespace backend { this->mipMapFilter = mipMapFilter; propertiesSet |= SAMPLER_PROPERTY_FILTER; } + + SamplerBase* SamplerBuilder::GetResultImpl() { + return device->CreateSampler(this); + } + } diff --git a/src/backend/common/Sampler.h b/src/backend/common/Sampler.h index 437b2918ee..f4fbcb637a 100644 --- a/src/backend/common/Sampler.h +++ b/src/backend/common/Sampler.h @@ -28,7 +28,7 @@ namespace backend { SamplerBase(SamplerBuilder* builder); }; - class SamplerBuilder : public Builder { + class SamplerBuilder : public Builder { public: SamplerBuilder(DeviceBase* device); @@ -37,12 +37,13 @@ namespace backend { nxt::FilterMode GetMipMapFilter() const; // NXT API - SamplerBase* GetResult(); void SetFilterMode(nxt::FilterMode magFilter, nxt::FilterMode minFilter, nxt::FilterMode mipMapFilter); private: friend class SamplerBase; + SamplerBase* GetResultImpl() override; + int propertiesSet = 0; nxt::FilterMode magFilter = nxt::FilterMode::Nearest; diff --git a/src/backend/common/ShaderModule.cpp b/src/backend/common/ShaderModule.cpp index 2878b94946..0769712b81 100644 --- a/src/backend/common/ShaderModule.cpp +++ b/src/backend/common/ShaderModule.cpp @@ -198,13 +198,12 @@ namespace backend { return std::move(spirv); } - ShaderModuleBase* ShaderModuleBuilder::GetResult() { + ShaderModuleBase* ShaderModuleBuilder::GetResultImpl() { if (spirv.size() == 0) { HandleError("Shader module needs to have the source set"); return nullptr; } - MarkConsumed(); return device->CreateShaderModule(this); } diff --git a/src/backend/common/ShaderModule.h b/src/backend/common/ShaderModule.h index 1a1591d107..816fd0d1a3 100644 --- a/src/backend/common/ShaderModule.h +++ b/src/backend/common/ShaderModule.h @@ -71,19 +71,20 @@ namespace backend { nxt::ShaderStage executionModel; }; - class ShaderModuleBuilder : public Builder { + class ShaderModuleBuilder : public Builder { public: ShaderModuleBuilder(DeviceBase* device); std::vector AcquireSpirv(); // NXT API - ShaderModuleBase* GetResult(); void SetSource(uint32_t codeSize, const uint32_t* code); private: friend class ShaderModuleBase; + ShaderModuleBase* GetResultImpl() override; + std::vector spirv; }; diff --git a/src/backend/common/Texture.cpp b/src/backend/common/Texture.cpp index f92a0edcd8..58a36657bf 100644 --- a/src/backend/common/Texture.cpp +++ b/src/backend/common/Texture.cpp @@ -121,7 +121,7 @@ namespace backend { : Builder(device) { } - TextureBase* TextureBuilder::GetResult() { + TextureBase* TextureBuilder::GetResultImpl() { constexpr int allProperties = TEXTURE_PROPERTY_DIMENSION | TEXTURE_PROPERTY_EXTENT | TEXTURE_PROPERTY_FORMAT | TEXTURE_PROPERTY_MIP_LEVELS | TEXTURE_PROPERTY_ALLOWED_USAGE; if ((propertiesSet & allProperties) != allProperties) { @@ -136,7 +136,6 @@ namespace backend { // TODO(cwallez@chromium.org): check stuff based on the dimension - MarkConsumed(); return device->CreateTexture(this); } @@ -223,8 +222,7 @@ namespace backend { : Builder(device), texture(texture) { } - TextureViewBase* TextureViewBuilder::GetResult() { - MarkConsumed(); + TextureViewBase* TextureViewBuilder::GetResultImpl() { return device->CreateTextureView(this); } diff --git a/src/backend/common/Texture.h b/src/backend/common/Texture.h index 8090f26b8c..32d176472c 100644 --- a/src/backend/common/Texture.h +++ b/src/backend/common/Texture.h @@ -60,12 +60,11 @@ namespace backend { bool frozen = false; }; - class TextureBuilder : public Builder { + class TextureBuilder : public Builder { public: TextureBuilder(DeviceBase* device); // NXT API - TextureBase* GetResult(); void SetDimension(nxt::TextureDimension dimension); void SetExtent(uint32_t width, uint32_t height, uint32_t depth); void SetFormat(nxt::TextureFormat format); @@ -76,6 +75,8 @@ namespace backend { private: friend class TextureBase; + TextureBase* GetResultImpl() override; + int propertiesSet = 0; nxt::TextureDimension dimension; @@ -96,16 +97,15 @@ namespace backend { Ref texture; }; - class TextureViewBuilder : public Builder { + class TextureViewBuilder : public Builder { public: TextureViewBuilder(DeviceBase* device, TextureBase* texture); - // NXT API - TextureViewBase* GetResult(); - private: friend class TextureViewBase; + TextureViewBase* GetResultImpl() override; + Ref texture; };