Descriptorize ShaderModule

Change-Id: Ic79d00380f583485de0fb05bd47b1f869919ebe6
This commit is contained in:
Corentin Wallez 2018-08-20 17:01:20 +02:00 committed by Corentin Wallez
parent 3ccde9ce72
commit df6710358b
28 changed files with 113 additions and 127 deletions

View File

@ -596,8 +596,11 @@
] ]
}, },
{ {
"name": "create shader module builder", "name": "create shader module",
"returns": "shader module builder" "returns": "shader module",
"args": [
{"name": "descriptor", "type": "shader module descriptor", "annotation": "const*"}
]
}, },
{ {
"name": "create swap chain builder", "name": "create swap chain builder",
@ -910,21 +913,13 @@
"shader module": { "shader module": {
"category": "object" "category": "object"
}, },
"shader module builder": { "shader module descriptor": {
"category": "object", "category": "structure",
"methods": [ "extensible": true,
{ "members": [
"name": "get result",
"returns": "shader module"
},
{
"name": "set source",
"args": [
{"name": "code size", "type": "uint32_t"}, {"name": "code size", "type": "uint32_t"},
{"name": "code", "type": "uint32_t", "annotation": "const*", "length": "code size"} {"name": "code", "type": "uint32_t", "annotation": "const*", "length": "code size"}
] ]
}
]
}, },
"shader stage": { "shader stage": {
"category": "enum", "category": "enum",

View File

@ -160,8 +160,14 @@ namespace dawn_native {
return result; return result;
} }
ShaderModuleBuilder* DeviceBase::CreateShaderModuleBuilder() { ShaderModuleBase* DeviceBase::CreateShaderModule(const ShaderModuleDescriptor* descriptor) {
return new ShaderModuleBuilder(this); ShaderModuleBase* result = nullptr;
if (ConsumedError(CreateShaderModuleInternal(&result, descriptor))) {
return nullptr;
}
return result;
} }
SwapChainBuilder* DeviceBase::CreateSwapChainBuilder() { SwapChainBuilder* DeviceBase::CreateSwapChainBuilder() {
return new SwapChainBuilder(this); return new SwapChainBuilder(this);
@ -219,6 +225,13 @@ namespace dawn_native {
return {}; return {};
} }
MaybeError DeviceBase::CreateShaderModuleInternal(ShaderModuleBase** result,
const ShaderModuleDescriptor* descriptor) {
DAWN_TRY(ValidateShaderModuleDescriptor(this, descriptor));
DAWN_TRY_ASSIGN(*result, CreateShaderModuleImpl(descriptor));
return {};
}
// Other implementation details // Other implementation details
void DeviceBase::ConsumeError(ErrorData* error) { void DeviceBase::ConsumeError(ErrorData* error) {

View File

@ -55,7 +55,6 @@ namespace dawn_native {
virtual RenderPassDescriptorBase* CreateRenderPassDescriptor( virtual RenderPassDescriptorBase* CreateRenderPassDescriptor(
RenderPassDescriptorBuilder* builder) = 0; RenderPassDescriptorBuilder* builder) = 0;
virtual RenderPipelineBase* CreateRenderPipeline(RenderPipelineBuilder* builder) = 0; virtual RenderPipelineBase* CreateRenderPipeline(RenderPipelineBuilder* builder) = 0;
virtual ShaderModuleBase* CreateShaderModule(ShaderModuleBuilder* builder) = 0;
virtual SwapChainBase* CreateSwapChain(SwapChainBuilder* builder) = 0; virtual SwapChainBase* CreateSwapChain(SwapChainBuilder* builder) = 0;
virtual TextureBase* CreateTexture(TextureBuilder* builder) = 0; virtual TextureBase* CreateTexture(TextureBuilder* builder) = 0;
virtual TextureViewBase* CreateTextureView(TextureViewBuilder* builder) = 0; virtual TextureViewBase* CreateTextureView(TextureViewBuilder* builder) = 0;
@ -94,7 +93,7 @@ namespace dawn_native {
RenderPassDescriptorBuilder* CreateRenderPassDescriptorBuilder(); RenderPassDescriptorBuilder* CreateRenderPassDescriptorBuilder();
RenderPipelineBuilder* CreateRenderPipelineBuilder(); RenderPipelineBuilder* CreateRenderPipelineBuilder();
SamplerBase* CreateSampler(const SamplerDescriptor* descriptor); SamplerBase* CreateSampler(const SamplerDescriptor* descriptor);
ShaderModuleBuilder* CreateShaderModuleBuilder(); ShaderModuleBase* CreateShaderModule(const ShaderModuleDescriptor* descriptor);
SwapChainBuilder* CreateSwapChainBuilder(); SwapChainBuilder* CreateSwapChainBuilder();
TextureBuilder* CreateTextureBuilder(); TextureBuilder* CreateTextureBuilder();
@ -111,6 +110,8 @@ namespace dawn_native {
virtual ResultOrError<QueueBase*> CreateQueueImpl() = 0; virtual ResultOrError<QueueBase*> CreateQueueImpl() = 0;
virtual ResultOrError<SamplerBase*> CreateSamplerImpl( virtual ResultOrError<SamplerBase*> CreateSamplerImpl(
const SamplerDescriptor* descriptor) = 0; const SamplerDescriptor* descriptor) = 0;
virtual ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor) = 0;
MaybeError CreateBindGroupLayoutInternal(BindGroupLayoutBase** result, MaybeError CreateBindGroupLayoutInternal(BindGroupLayoutBase** result,
const BindGroupLayoutDescriptor* descriptor); const BindGroupLayoutDescriptor* descriptor);
@ -118,6 +119,8 @@ namespace dawn_native {
const PipelineLayoutDescriptor* descriptor); const PipelineLayoutDescriptor* descriptor);
MaybeError CreateQueueInternal(QueueBase** result); MaybeError CreateQueueInternal(QueueBase** result);
MaybeError CreateSamplerInternal(SamplerBase** result, const SamplerDescriptor* descriptor); MaybeError CreateSamplerInternal(SamplerBase** result, const SamplerDescriptor* descriptor);
MaybeError CreateShaderModuleInternal(ShaderModuleBase** result,
const ShaderModuleDescriptor* descriptor);
void ConsumeError(ErrorData* error); void ConsumeError(ErrorData* error);

View File

@ -23,7 +23,15 @@
namespace dawn_native { namespace dawn_native {
ShaderModuleBase::ShaderModuleBase(ShaderModuleBuilder* builder) : mDevice(builder->mDevice) { MaybeError ValidateShaderModuleDescriptor(DeviceBase*, const ShaderModuleDescriptor*) {
// TODO(cwallez@chromium.org): Use spirv-val to check the module is well-formed
return {};
}
// ShaderModuleBase
ShaderModuleBase::ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor*)
: mDevice(device) {
} }
DeviceBase* ShaderModuleBase::GetDevice() const { DeviceBase* ShaderModuleBase::GetDevice() const {
@ -218,24 +226,4 @@ namespace dawn_native {
return true; return true;
} }
ShaderModuleBuilder::ShaderModuleBuilder(DeviceBase* device) : Builder(device) {
}
std::vector<uint32_t> ShaderModuleBuilder::AcquireSpirv() {
return std::move(mSpirv);
}
ShaderModuleBase* ShaderModuleBuilder::GetResultImpl() {
if (mSpirv.size() == 0) {
HandleError("Shader module needs to have the source set");
return nullptr;
}
return mDevice->CreateShaderModule(this);
}
void ShaderModuleBuilder::SetSource(uint32_t codeSize, const uint32_t* code) {
mSpirv.assign(code, code + codeSize);
}
} // namespace dawn_native } // namespace dawn_native

View File

@ -17,6 +17,7 @@
#include "common/Constants.h" #include "common/Constants.h"
#include "dawn_native/Builder.h" #include "dawn_native/Builder.h"
#include "dawn_native/Error.h"
#include "dawn_native/Forward.h" #include "dawn_native/Forward.h"
#include "dawn_native/RefCounted.h" #include "dawn_native/RefCounted.h"
@ -32,9 +33,12 @@ namespace spirv_cross {
namespace dawn_native { namespace dawn_native {
MaybeError ValidateShaderModuleDescriptor(DeviceBase* device,
const ShaderModuleDescriptor* descriptor);
class ShaderModuleBase : public RefCounted { class ShaderModuleBase : public RefCounted {
public: public:
ShaderModuleBase(ShaderModuleBuilder* builder); ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor);
DeviceBase* GetDevice() const; DeviceBase* GetDevice() const;
@ -75,23 +79,6 @@ namespace dawn_native {
dawn::ShaderStage mExecutionModel; dawn::ShaderStage mExecutionModel;
}; };
class ShaderModuleBuilder : public Builder<ShaderModuleBase> {
public:
ShaderModuleBuilder(DeviceBase* device);
std::vector<uint32_t> AcquireSpirv();
// Dawn API
void SetSource(uint32_t codeSize, const uint32_t* code);
private:
friend class ShaderModuleBase;
ShaderModuleBase* GetResultImpl() override;
std::vector<uint32_t> mSpirv;
};
} // namespace dawn_native } // namespace dawn_native
#endif // DAWNNATIVE_SHADERMODULE_H_ #endif // DAWNNATIVE_SHADERMODULE_H_

View File

@ -304,8 +304,9 @@ namespace dawn_native { namespace d3d12 {
ResultOrError<SamplerBase*> Device::CreateSamplerImpl(const SamplerDescriptor* descriptor) { ResultOrError<SamplerBase*> Device::CreateSamplerImpl(const SamplerDescriptor* descriptor) {
return new Sampler(this, descriptor); return new Sampler(this, descriptor);
} }
ShaderModuleBase* Device::CreateShaderModule(ShaderModuleBuilder* builder) { ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
return new ShaderModule(builder); const ShaderModuleDescriptor* descriptor) {
return new ShaderModule(this, descriptor);
} }
SwapChainBase* Device::CreateSwapChain(SwapChainBuilder* builder) { SwapChainBase* Device::CreateSwapChain(SwapChainBuilder* builder) {
return new SwapChain(builder); return new SwapChain(builder);

View File

@ -49,7 +49,6 @@ namespace dawn_native { namespace d3d12 {
RenderPassDescriptorBase* CreateRenderPassDescriptor( RenderPassDescriptorBase* CreateRenderPassDescriptor(
RenderPassDescriptorBuilder* builder) override; RenderPassDescriptorBuilder* builder) override;
RenderPipelineBase* CreateRenderPipeline(RenderPipelineBuilder* builder) override; RenderPipelineBase* CreateRenderPipeline(RenderPipelineBuilder* builder) override;
ShaderModuleBase* CreateShaderModule(ShaderModuleBuilder* builder) override;
SwapChainBase* CreateSwapChain(SwapChainBuilder* builder) override; SwapChainBase* CreateSwapChain(SwapChainBuilder* builder) override;
TextureBase* CreateTexture(TextureBuilder* builder) override; TextureBase* CreateTexture(TextureBuilder* builder) override;
TextureViewBase* CreateTextureView(TextureViewBuilder* builder) override; TextureViewBase* CreateTextureView(TextureViewBuilder* builder) override;
@ -83,6 +82,8 @@ namespace dawn_native { namespace d3d12 {
const PipelineLayoutDescriptor* descriptor) override; const PipelineLayoutDescriptor* descriptor) override;
ResultOrError<QueueBase*> CreateQueueImpl() override; ResultOrError<QueueBase*> CreateQueueImpl() override;
ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override; ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor) override;
uint64_t mSerial = 0; uint64_t mSerial = 0;
ComPtr<ID3D12Fence> mFence; ComPtr<ID3D12Fence> mFence;

View File

@ -15,6 +15,7 @@
#include "dawn_native/d3d12/ShaderModuleD3D12.h" #include "dawn_native/d3d12/ShaderModuleD3D12.h"
#include "common/Assert.h" #include "common/Assert.h"
#include "dawn_native/d3d12/DeviceD3D12.h"
#include <spirv-cross/spirv_hlsl.hpp> #include <spirv-cross/spirv_hlsl.hpp>
@ -44,8 +45,9 @@ namespace dawn_native { namespace d3d12 {
std::array<T, kNumBindingTypes> mMap{}; std::array<T, kNumBindingTypes> mMap{};
}; };
ShaderModule::ShaderModule(ShaderModuleBuilder* builder) : ShaderModuleBase(builder) { ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
spirv_cross::CompilerHLSL compiler(builder->AcquireSpirv()); : ShaderModuleBase(device, descriptor) {
spirv_cross::CompilerHLSL compiler(descriptor->code, descriptor->codeSize);
spirv_cross::CompilerGLSL::Options options_glsl; spirv_cross::CompilerGLSL::Options options_glsl;
options_glsl.vertex.fixup_clipspace = true; options_glsl.vertex.fixup_clipspace = true;

View File

@ -23,7 +23,7 @@ namespace dawn_native { namespace d3d12 {
class ShaderModule : public ShaderModuleBase { class ShaderModule : public ShaderModuleBase {
public: public:
ShaderModule(ShaderModuleBuilder* builder); ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
const std::string& GetHLSLSource() const; const std::string& GetHLSLSource() const;

View File

@ -46,7 +46,6 @@ namespace dawn_native { namespace metal {
RenderPassDescriptorBase* CreateRenderPassDescriptor( RenderPassDescriptorBase* CreateRenderPassDescriptor(
RenderPassDescriptorBuilder* builder) override; RenderPassDescriptorBuilder* builder) override;
RenderPipelineBase* CreateRenderPipeline(RenderPipelineBuilder* builder) override; RenderPipelineBase* CreateRenderPipeline(RenderPipelineBuilder* builder) override;
ShaderModuleBase* CreateShaderModule(ShaderModuleBuilder* builder) override;
SwapChainBase* CreateSwapChain(SwapChainBuilder* builder) override; SwapChainBase* CreateSwapChain(SwapChainBuilder* builder) override;
TextureBase* CreateTexture(TextureBuilder* builder) override; TextureBase* CreateTexture(TextureBuilder* builder) override;
TextureViewBase* CreateTextureView(TextureViewBuilder* builder) override; TextureViewBase* CreateTextureView(TextureViewBuilder* builder) override;
@ -69,6 +68,8 @@ namespace dawn_native { namespace metal {
const PipelineLayoutDescriptor* descriptor) override; const PipelineLayoutDescriptor* descriptor) override;
ResultOrError<QueueBase*> CreateQueueImpl() override; ResultOrError<QueueBase*> CreateQueueImpl() override;
ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override; ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor) override;
void OnCompletedHandler(); void OnCompletedHandler();

View File

@ -123,8 +123,9 @@ namespace dawn_native { namespace metal {
ResultOrError<SamplerBase*> Device::CreateSamplerImpl(const SamplerDescriptor* descriptor) { ResultOrError<SamplerBase*> Device::CreateSamplerImpl(const SamplerDescriptor* descriptor) {
return new Sampler(this, descriptor); return new Sampler(this, descriptor);
} }
ShaderModuleBase* Device::CreateShaderModule(ShaderModuleBuilder* builder) { ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
return new ShaderModule(builder); const ShaderModuleDescriptor* descriptor) {
return new ShaderModule(this, descriptor);
} }
SwapChainBase* Device::CreateSwapChain(SwapChainBuilder* builder) { SwapChainBase* Device::CreateSwapChain(SwapChainBuilder* builder) {
return new SwapChain(builder); return new SwapChain(builder);

View File

@ -25,11 +25,12 @@ namespace spirv_cross {
namespace dawn_native { namespace metal { namespace dawn_native { namespace metal {
class Device;
class PipelineLayout; class PipelineLayout;
class ShaderModule : public ShaderModuleBase { class ShaderModule : public ShaderModuleBase {
public: public:
ShaderModule(ShaderModuleBuilder* builder); ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
struct MetalFunctionData { struct MetalFunctionData {
id<MTLFunction> function; id<MTLFunction> function;

View File

@ -40,8 +40,9 @@ namespace dawn_native { namespace metal {
} }
} }
ShaderModule::ShaderModule(ShaderModuleBuilder* builder) ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
: ShaderModuleBase(builder), mSpirv(builder->AcquireSpirv()) { : ShaderModuleBase(device, descriptor) {
mSpirv.assign(descriptor->code, descriptor->code + descriptor->codeSize);
spirv_cross::CompilerMSL compiler(mSpirv); spirv_cross::CompilerMSL compiler(mSpirv);
ExtractSpirvInfo(compiler); ExtractSpirvInfo(compiler);
} }

View File

@ -78,10 +78,11 @@ namespace dawn_native { namespace null {
ResultOrError<SamplerBase*> Device::CreateSamplerImpl(const SamplerDescriptor* descriptor) { ResultOrError<SamplerBase*> Device::CreateSamplerImpl(const SamplerDescriptor* descriptor) {
return new Sampler(this, descriptor); return new Sampler(this, descriptor);
} }
ShaderModuleBase* Device::CreateShaderModule(ShaderModuleBuilder* builder) { ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
auto module = new ShaderModule(builder); const ShaderModuleDescriptor* descriptor) {
auto module = new ShaderModule(this, descriptor);
spirv_cross::Compiler compiler(builder->AcquireSpirv()); spirv_cross::Compiler compiler(descriptor->code, descriptor->codeSize);
module->ExtractSpirvInfo(compiler); module->ExtractSpirvInfo(compiler);
return module; return module;

View File

@ -106,7 +106,6 @@ namespace dawn_native { namespace null {
RenderPassDescriptorBase* CreateRenderPassDescriptor( RenderPassDescriptorBase* CreateRenderPassDescriptor(
RenderPassDescriptorBuilder* builder) override; RenderPassDescriptorBuilder* builder) override;
RenderPipelineBase* CreateRenderPipeline(RenderPipelineBuilder* builder) override; RenderPipelineBase* CreateRenderPipeline(RenderPipelineBuilder* builder) override;
ShaderModuleBase* CreateShaderModule(ShaderModuleBuilder* builder) override;
SwapChainBase* CreateSwapChain(SwapChainBuilder* builder) override; SwapChainBase* CreateSwapChain(SwapChainBuilder* builder) override;
TextureBase* CreateTexture(TextureBuilder* builder) override; TextureBase* CreateTexture(TextureBuilder* builder) override;
TextureViewBase* CreateTextureView(TextureViewBuilder* builder) override; TextureViewBase* CreateTextureView(TextureViewBuilder* builder) override;
@ -123,6 +122,8 @@ namespace dawn_native { namespace null {
const PipelineLayoutDescriptor* descriptor) override; const PipelineLayoutDescriptor* descriptor) override;
ResultOrError<QueueBase*> CreateQueueImpl() override; ResultOrError<QueueBase*> CreateQueueImpl() override;
ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override; ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor) override;
std::vector<std::unique_ptr<PendingOperation>> mPendingOperations; std::vector<std::unique_ptr<PendingOperation>> mPendingOperations;
}; };

View File

@ -91,8 +91,9 @@ namespace dawn_native { namespace opengl {
ResultOrError<SamplerBase*> Device::CreateSamplerImpl(const SamplerDescriptor* descriptor) { ResultOrError<SamplerBase*> Device::CreateSamplerImpl(const SamplerDescriptor* descriptor) {
return new Sampler(this, descriptor); return new Sampler(this, descriptor);
} }
ShaderModuleBase* Device::CreateShaderModule(ShaderModuleBuilder* builder) { ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
return new ShaderModule(builder); const ShaderModuleDescriptor* descriptor) {
return new ShaderModule(this, descriptor);
} }
SwapChainBase* Device::CreateSwapChain(SwapChainBuilder* builder) { SwapChainBase* Device::CreateSwapChain(SwapChainBuilder* builder) {
return new SwapChain(builder); return new SwapChain(builder);

View File

@ -43,7 +43,6 @@ namespace dawn_native { namespace opengl {
RenderPassDescriptorBase* CreateRenderPassDescriptor( RenderPassDescriptorBase* CreateRenderPassDescriptor(
RenderPassDescriptorBuilder* builder) override; RenderPassDescriptorBuilder* builder) override;
RenderPipelineBase* CreateRenderPipeline(RenderPipelineBuilder* builder) override; RenderPipelineBase* CreateRenderPipeline(RenderPipelineBuilder* builder) override;
ShaderModuleBase* CreateShaderModule(ShaderModuleBuilder* builder) override;
SwapChainBase* CreateSwapChain(SwapChainBuilder* builder) override; SwapChainBase* CreateSwapChain(SwapChainBuilder* builder) override;
TextureBase* CreateTexture(TextureBuilder* builder) override; TextureBase* CreateTexture(TextureBuilder* builder) override;
TextureViewBase* CreateTextureView(TextureViewBuilder* builder) override; TextureViewBase* CreateTextureView(TextureViewBuilder* builder) override;
@ -57,6 +56,8 @@ namespace dawn_native { namespace opengl {
const PipelineLayoutDescriptor* descriptor) override; const PipelineLayoutDescriptor* descriptor) override;
ResultOrError<QueueBase*> CreateQueueImpl() override; ResultOrError<QueueBase*> CreateQueueImpl() override;
ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override; ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor) override;
}; };
}} // namespace dawn_native::opengl }} // namespace dawn_native::opengl

View File

@ -16,6 +16,7 @@
#include "common/Assert.h" #include "common/Assert.h"
#include "common/Platform.h" #include "common/Platform.h"
#include "dawn_native/opengl/DeviceGL.h"
#include <spirv-cross/spirv_glsl.hpp> #include <spirv-cross/spirv_glsl.hpp>
@ -46,8 +47,9 @@ namespace dawn_native { namespace opengl {
return o.str(); return o.str();
} }
ShaderModule::ShaderModule(ShaderModuleBuilder* builder) : ShaderModuleBase(builder) { ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
spirv_cross::CompilerGLSL compiler(builder->AcquireSpirv()); : ShaderModuleBase(device, descriptor) {
spirv_cross::CompilerGLSL compiler(descriptor->code, descriptor->codeSize);
spirv_cross::CompilerGLSL::Options options; spirv_cross::CompilerGLSL::Options options;
// TODO(cwallez@chromium.org): discover the backing context version and use that. // TODO(cwallez@chromium.org): discover the backing context version and use that.

View File

@ -40,7 +40,7 @@ namespace dawn_native { namespace opengl {
class ShaderModule : public ShaderModuleBase { class ShaderModule : public ShaderModuleBase {
public: public:
ShaderModule(ShaderModuleBuilder* builder); ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
using CombinedSamplerInfo = std::vector<CombinedSampler>; using CombinedSamplerInfo = std::vector<CombinedSampler>;

View File

@ -262,8 +262,9 @@ namespace dawn_native { namespace vulkan {
ResultOrError<SamplerBase*> Device::CreateSamplerImpl(const SamplerDescriptor* descriptor) { ResultOrError<SamplerBase*> Device::CreateSamplerImpl(const SamplerDescriptor* descriptor) {
return new Sampler(this, descriptor); return new Sampler(this, descriptor);
} }
ShaderModuleBase* Device::CreateShaderModule(ShaderModuleBuilder* builder) { ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
return new ShaderModule(builder); const ShaderModuleDescriptor* descriptor) {
return new ShaderModule(this, descriptor);
} }
SwapChainBase* Device::CreateSwapChain(SwapChainBuilder* builder) { SwapChainBase* Device::CreateSwapChain(SwapChainBuilder* builder) {
return new SwapChain(builder); return new SwapChain(builder);

View File

@ -74,7 +74,6 @@ namespace dawn_native { namespace vulkan {
RenderPassDescriptorBase* CreateRenderPassDescriptor( RenderPassDescriptorBase* CreateRenderPassDescriptor(
RenderPassDescriptorBuilder* builder) override; RenderPassDescriptorBuilder* builder) override;
RenderPipelineBase* CreateRenderPipeline(RenderPipelineBuilder* builder) override; RenderPipelineBase* CreateRenderPipeline(RenderPipelineBuilder* builder) override;
ShaderModuleBase* CreateShaderModule(ShaderModuleBuilder* builder) override;
SwapChainBase* CreateSwapChain(SwapChainBuilder* builder) override; SwapChainBase* CreateSwapChain(SwapChainBuilder* builder) override;
TextureBase* CreateTexture(TextureBuilder* builder) override; TextureBase* CreateTexture(TextureBuilder* builder) override;
TextureViewBase* CreateTextureView(TextureViewBuilder* builder) override; TextureViewBase* CreateTextureView(TextureViewBuilder* builder) override;
@ -88,6 +87,8 @@ namespace dawn_native { namespace vulkan {
const PipelineLayoutDescriptor* descriptor) override; const PipelineLayoutDescriptor* descriptor) override;
ResultOrError<QueueBase*> CreateQueueImpl() override; ResultOrError<QueueBase*> CreateQueueImpl() override;
ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override; ResultOrError<SamplerBase*> CreateSamplerImpl(const SamplerDescriptor* descriptor) override;
ResultOrError<ShaderModuleBase*> CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor) override;
bool CreateInstance(VulkanGlobalKnobs* usedKnobs, bool CreateInstance(VulkanGlobalKnobs* usedKnobs,
const std::vector<const char*>& requiredExtensions); const std::vector<const char*>& requiredExtensions);

View File

@ -21,22 +21,19 @@
namespace dawn_native { namespace vulkan { namespace dawn_native { namespace vulkan {
ShaderModule::ShaderModule(ShaderModuleBuilder* builder) : ShaderModuleBase(builder) { ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
std::vector<uint32_t> spirv = builder->AcquireSpirv(); : ShaderModuleBase(device, descriptor) {
// Use SPIRV-Cross to extract info from the SPIRV even if Vulkan consumes SPIRV. We want to // Use SPIRV-Cross to extract info from the SPIRV even if Vulkan consumes SPIRV. We want to
// have a translation step eventually anyway. // have a translation step eventually anyway.
spirv_cross::Compiler compiler(spirv); spirv_cross::Compiler compiler(descriptor->code, descriptor->codeSize);
ExtractSpirvInfo(compiler); ExtractSpirvInfo(compiler);
VkShaderModuleCreateInfo createInfo; VkShaderModuleCreateInfo createInfo;
createInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; createInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
createInfo.pNext = nullptr; createInfo.pNext = nullptr;
createInfo.flags = 0; createInfo.flags = 0;
createInfo.codeSize = spirv.size() * sizeof(uint32_t); createInfo.codeSize = descriptor->codeSize * sizeof(uint32_t);
createInfo.pCode = spirv.data(); createInfo.pCode = descriptor->code;
Device* device = ToBackend(GetDevice());
if (device->fn.CreateShaderModule(device->GetVkDevice(), &createInfo, nullptr, &mHandle) != if (device->fn.CreateShaderModule(device->GetVkDevice(), &createInfo, nullptr, &mHandle) !=
VK_SUCCESS) { VK_SUCCESS) {

View File

@ -21,9 +21,11 @@
namespace dawn_native { namespace vulkan { namespace dawn_native { namespace vulkan {
class Device;
class ShaderModule : public ShaderModuleBase { class ShaderModule : public ShaderModuleBase {
public: public:
ShaderModule(ShaderModuleBuilder* builder); ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule(); ~ShaderModule();
VkShaderModule GetHandle() const; VkShaderModule GetHandle() const;

View File

@ -295,15 +295,13 @@ TEST_F(WireTests, ValueArrayArgument) {
// Test that the wire is able to send C strings // Test that the wire is able to send C strings
TEST_F(WireTests, CStringArgument) { TEST_F(WireTests, CStringArgument) {
// Create shader module // Create shader module
dawnShaderModuleBuilder shaderModuleBuilder = dawnDeviceCreateShaderModuleBuilder(device); dawnShaderModuleDescriptor descriptor;
dawnShaderModule shaderModule = dawnShaderModuleBuilderGetResult(shaderModuleBuilder); descriptor.nextInChain = nullptr;
descriptor.codeSize = 0;
dawnShaderModuleBuilder apiShaderModuleBuilder = api.GetNewShaderModuleBuilder(); dawnShaderModule shaderModule = dawnDeviceCreateShaderModule(device, &descriptor);
EXPECT_CALL(api, DeviceCreateShaderModuleBuilder(apiDevice))
.WillOnce(Return(apiShaderModuleBuilder));
dawnShaderModule apiShaderModule = api.GetNewShaderModule(); dawnShaderModule apiShaderModule = api.GetNewShaderModule();
EXPECT_CALL(api, ShaderModuleBuilderGetResult(apiShaderModuleBuilder)) EXPECT_CALL(api, DeviceCreateShaderModule(apiDevice, _))
.WillOnce(Return(apiShaderModule)); .WillOnce(Return(apiShaderModule));
// Create pipeline // Create pipeline

View File

@ -25,19 +25,14 @@ class InputStateTest : public ValidationTest {
dawn::RenderPipeline CreatePipeline(bool success, const dawn::InputState& inputState, std::string vertexSource) { dawn::RenderPipeline CreatePipeline(bool success, const dawn::InputState& inputState, std::string vertexSource) {
DummyRenderPass renderpassData = CreateDummyRenderPass(); DummyRenderPass renderpassData = CreateDummyRenderPass();
dawn::ShaderModuleBuilder vsModuleBuilder = AssertWillBeSuccess(device.CreateShaderModuleBuilder()); dawn::ShaderModule vsModule = utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, vertexSource.c_str());
utils::FillShaderModuleBuilder(vsModuleBuilder, dawn::ShaderStage::Vertex, vertexSource.c_str()); dawn::ShaderModule fsModule = utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
dawn::ShaderModule vsModule = vsModuleBuilder.GetResult();
dawn::ShaderModuleBuilder fsModuleBuilder = AssertWillBeSuccess(device.CreateShaderModuleBuilder());
utils::FillShaderModuleBuilder(fsModuleBuilder, dawn::ShaderStage::Fragment, R"(
#version 450 #version 450
layout(location = 0) out vec4 fragColor; layout(location = 0) out vec4 fragColor;
void main() { void main() {
fragColor = vec4(1.0, 0.0, 0.0, 1.0); fragColor = vec4(1.0, 0.0, 0.0, 1.0);
} }
)"); )");
dawn::ShaderModule fsModule = fsModuleBuilder.GetResult();
dawn::RenderPipelineBuilder builder; dawn::RenderPipelineBuilder builder;
if (success) { if (success) {

View File

@ -27,14 +27,12 @@ class PushConstantTest : public ValidationTest {
uint32_t constants[kMaxPushConstants] = {0}; uint32_t constants[kMaxPushConstants] = {0};
void TestCreateShaderModule(bool success, std::string vertexSource) { void TestCreateShaderModule(bool success, std::string vertexSource) {
dawn::ShaderModuleBuilder builder; dawn::ShaderModule module;
if (success) { if (success) {
builder = AssertWillBeSuccess(device.CreateShaderModuleBuilder()); module = utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, vertexSource.c_str());
} else { } else {
builder = AssertWillBeError(device.CreateShaderModuleBuilder()); ASSERT_DEVICE_ERROR(module = utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, vertexSource.c_str()));
} }
utils::FillShaderModuleBuilder(builder, dawn::ShaderStage::Vertex, vertexSource.c_str());
builder.GetResult();
} }
private: private:

View File

@ -25,7 +25,7 @@
namespace utils { namespace utils {
void FillShaderModuleBuilder(const dawn::ShaderModuleBuilder& builder, dawn::ShaderModule CreateShaderModule(const dawn::Device& device,
dawn::ShaderStage stage, dawn::ShaderStage stage,
const char* source) { const char* source) {
shaderc::Compiler compiler; shaderc::Compiler compiler;
@ -49,7 +49,7 @@ namespace utils {
auto result = compiler.CompileGlslToSpv(source, strlen(source), kind, "myshader?", options); auto result = compiler.CompileGlslToSpv(source, strlen(source), kind, "myshader?", options);
if (result.GetCompilationStatus() != shaderc_compilation_status_success) { if (result.GetCompilationStatus() != shaderc_compilation_status_success) {
std::cerr << result.GetErrorMessage(); std::cerr << result.GetErrorMessage();
return; return {};
} }
// result.cend and result.cbegin return pointers to uint32_t. // result.cend and result.cbegin return pointers to uint32_t.
@ -58,7 +58,10 @@ namespace utils {
// So this size is in units of sizeof(uint32_t). // So this size is in units of sizeof(uint32_t).
ptrdiff_t resultSize = resultEnd - resultBegin; ptrdiff_t resultSize = resultEnd - resultBegin;
// SetSource takes data as uint32_t*. // SetSource takes data as uint32_t*.
builder.SetSource(static_cast<uint32_t>(resultSize), result.cbegin());
dawn::ShaderModuleDescriptor descriptor;
descriptor.codeSize = static_cast<uint32_t>(resultSize);
descriptor.code = result.cbegin();
#ifdef DUMP_SPIRV_ASSEMBLY #ifdef DUMP_SPIRV_ASSEMBLY
{ {
@ -87,14 +90,8 @@ namespace utils {
printf("\n"); printf("\n");
printf("SPIRV JS ARRAY DUMP END\n"); printf("SPIRV JS ARRAY DUMP END\n");
#endif #endif
}
dawn::ShaderModule CreateShaderModule(const dawn::Device& device, return device.CreateShaderModule(&descriptor);
dawn::ShaderStage stage,
const char* source) {
dawn::ShaderModuleBuilder builder = device.CreateShaderModuleBuilder();
FillShaderModuleBuilder(builder, stage, source);
return builder.GetResult();
} }
dawn::Buffer CreateBufferFromData(const dawn::Device& device, dawn::Buffer CreateBufferFromData(const dawn::Device& device,

View File

@ -18,9 +18,6 @@
namespace utils { namespace utils {
void FillShaderModuleBuilder(const dawn::ShaderModuleBuilder& builder,
dawn::ShaderStage stage,
const char* source);
dawn::ShaderModule CreateShaderModule(const dawn::Device& device, dawn::ShaderModule CreateShaderModule(const dawn::Device& device,
dawn::ShaderStage stage, dawn::ShaderStage stage,
const char* source); const char* source);