Descriptorize ComputePipeline

Change-Id: Ic9d7014ba44d927d7f9ddf81a8870432c68941e8
This commit is contained in:
Corentin Wallez 2018-08-27 23:12:56 +02:00 committed by Corentin Wallez
parent eb7d64a17f
commit 8e335a5585
35 changed files with 239 additions and 166 deletions

View File

@ -500,27 +500,13 @@
"compute pipeline": {
"category": "object"
},
"compute pipeline builder": {
"category": "object",
"methods": [
{
"name": "get result",
"returns": "compute pipeline"
},
{
"name": "set layout",
"args": [
{"name": "layout", "type": "pipeline layout"}
]
},
{
"name": "set stage",
"args": [
{"name": "stage", "type": "shader stage"},
{"name": "module", "type": "shader module"},
{"name": "entry point", "type": "char", "annotation": "const*", "length": "strlen"}
]
}
"compute pipeline descriptor": {
"category": "structure",
"extensible": true,
"members": [
{"name": "layout", "type": "pipeline layout"},
{"name": "module", "type": "shader module"},
{"name": "entry point", "type": "char", "annotation": "const*", "length": "strlen"}
]
},
"device": {
@ -569,8 +555,11 @@
"returns": "input state builder"
},
{
"name": "create compute pipeline builder",
"returns": "compute pipeline builder"
"name": "create compute pipeline",
"returns": "compute pipeline",
"args": [
{"name": "descriptor", "type": "compute pipeline descriptor", "annotation": "const*"}
]
},
{
"name": "create render pipeline builder",

View File

@ -231,10 +231,11 @@ void initSim() {
dawn::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, &bgl);
updatePipeline = device.CreateComputePipelineBuilder()
.SetLayout(pl)
.SetStage(dawn::ShaderStage::Compute, module, "main")
.GetResult();
dawn::ComputePipelineDescriptor csDesc;
csDesc.module = module.Clone();
csDesc.entryPoint = "main";
csDesc.layout = pl.Clone();
updatePipeline = device.CreateComputePipeline(&csDesc);
dawn::BufferView updateParamsView = updateParams.CreateBufferViewBuilder()
.SetExtent(0, sizeof(SimParams))

View File

@ -166,7 +166,10 @@ def link_structure(struct, types):
for (member, m) in zip(members, struct.record['members']):
# TODO(kainino@chromium.org): More robust pointer/length handling?
if 'length' in m:
member.length = members_by_name[m['length']]
if m['length'] == 'strlen':
member.length = 'strlen'
else:
member.length = members_by_name[m['length']]
def parse_json(json):
category_to_parser = {

View File

@ -56,15 +56,7 @@ namespace dawn {
{% endfor %}
{% for type in by_category["structure"] %}
struct {{as_cppType(type.name)}} {
{% if type.extensible %}
const void* nextInChain = nullptr;
{% endif %}
{% for member in type.members %}
{{as_annotated_cppType(member)}};
{% endfor %}
};
struct {{as_cppType(type.name)}};
{% endfor %}
template<typename Derived, typename CType>
@ -158,6 +150,18 @@ namespace dawn {
{% endfor %}
{% for type in by_category["structure"] %}
struct {{as_cppType(type.name)}} {
{% if type.extensible %}
const void* nextInChain = nullptr;
{% endif %}
{% for member in type.members %}
{{as_annotated_cppType(member)}};
{% endfor %}
};
{% endfor %}
} // namespace dawn
#endif // DAWN_DAWNCPP_H_

View File

@ -18,24 +18,31 @@
namespace dawn_native {
MaybeError ValidateComputePipelineDescriptor(DeviceBase*,
const ComputePipelineDescriptor* descriptor) {
DAWN_TRY_ASSERT(descriptor->nextInChain == nullptr, "nextInChain must be nullptr");
if (descriptor->entryPoint != std::string("main")) {
DAWN_RETURN_ERROR("Currently the entry point has to be main()");
}
if (descriptor->module->GetExecutionModel() != dawn::ShaderStage::Compute) {
DAWN_RETURN_ERROR("Setting module with wrong execution model");
}
if (!descriptor->module->IsCompatibleWithPipelineLayout(descriptor->layout)) {
DAWN_RETURN_ERROR("Stage not compatible with layout");
}
return {};
}
// ComputePipelineBase
ComputePipelineBase::ComputePipelineBase(ComputePipelineBuilder* builder)
: PipelineBase(builder) {
if (GetStageMask() != dawn::ShaderStageBit::Compute) {
builder->HandleError("Compute pipeline should have exactly a compute stage");
return;
}
}
// ComputePipelineBuilder
ComputePipelineBuilder::ComputePipelineBuilder(DeviceBase* device)
: Builder(device), PipelineBuilder(this) {
}
ComputePipelineBase* ComputePipelineBuilder::GetResultImpl() {
return mDevice->CreateComputePipeline(this);
ComputePipelineBase::ComputePipelineBase(DeviceBase* device,
const ComputePipelineDescriptor* descriptor)
: PipelineBase(device, descriptor->layout, dawn::ShaderStageBit::Compute) {
ExtractModuleData(dawn::ShaderStage::Compute, descriptor->module);
}
} // namespace dawn_native

View File

@ -19,17 +19,14 @@
namespace dawn_native {
class DeviceBase;
MaybeError ValidateComputePipelineDescriptor(DeviceBase* device,
const ComputePipelineDescriptor* descriptor);
class ComputePipelineBase : public RefCounted, public PipelineBase {
public:
ComputePipelineBase(ComputePipelineBuilder* builder);
};
class ComputePipelineBuilder : public Builder<ComputePipelineBase>, public PipelineBuilder {
public:
ComputePipelineBuilder(DeviceBase* device);
private:
ComputePipelineBase* GetResultImpl() override;
ComputePipelineBase(DeviceBase* device, const ComputePipelineDescriptor* descriptor);
};
} // namespace dawn_native

View File

@ -123,8 +123,15 @@ namespace dawn_native {
CommandBufferBuilder* DeviceBase::CreateCommandBufferBuilder() {
return new CommandBufferBuilder(this);
}
ComputePipelineBuilder* DeviceBase::CreateComputePipelineBuilder() {
return new ComputePipelineBuilder(this);
ComputePipelineBase* DeviceBase::CreateComputePipeline(
const ComputePipelineDescriptor* descriptor) {
ComputePipelineBase* result = nullptr;
if (ConsumedError(CreateComputePipelineInternal(&result, descriptor))) {
return nullptr;
}
return result;
}
DepthStencilStateBuilder* DeviceBase::CreateDepthStencilStateBuilder() {
return new DepthStencilStateBuilder(this);
@ -223,6 +230,14 @@ namespace dawn_native {
return {};
}
MaybeError DeviceBase::CreateComputePipelineInternal(
ComputePipelineBase** result,
const ComputePipelineDescriptor* descriptor) {
DAWN_TRY(ValidateComputePipelineDescriptor(this, descriptor));
DAWN_TRY_ASSIGN(*result, CreateComputePipelineImpl(descriptor));
return {};
}
MaybeError DeviceBase::CreatePipelineLayoutInternal(
PipelineLayoutBase** result,
const PipelineLayoutDescriptor* descriptor) {

View File

@ -47,7 +47,6 @@ namespace dawn_native {
virtual BlendStateBase* CreateBlendState(BlendStateBuilder* builder) = 0;
virtual BufferViewBase* CreateBufferView(BufferViewBuilder* builder) = 0;
virtual CommandBufferBase* CreateCommandBuffer(CommandBufferBuilder* builder) = 0;
virtual ComputePipelineBase* CreateComputePipeline(ComputePipelineBuilder* builder) = 0;
virtual DepthStencilStateBase* CreateDepthStencilState(
DepthStencilStateBuilder* builder) = 0;
virtual InputStateBase* CreateInputState(InputStateBuilder* builder) = 0;
@ -83,7 +82,7 @@ namespace dawn_native {
BlendStateBuilder* CreateBlendStateBuilder();
BufferBase* CreateBuffer(const BufferDescriptor* descriptor);
CommandBufferBuilder* CreateCommandBufferBuilder();
ComputePipelineBuilder* CreateComputePipelineBuilder();
ComputePipelineBase* CreateComputePipeline(const ComputePipelineDescriptor* descriptor);
DepthStencilStateBuilder* CreateDepthStencilStateBuilder();
InputStateBuilder* CreateInputStateBuilder();
PipelineLayoutBase* CreatePipelineLayout(const PipelineLayoutDescriptor* descriptor);
@ -108,6 +107,8 @@ namespace dawn_native {
virtual ResultOrError<BindGroupLayoutBase*> CreateBindGroupLayoutImpl(
const BindGroupLayoutDescriptor* descriptor) = 0;
virtual ResultOrError<BufferBase*> CreateBufferImpl(const BufferDescriptor* descriptor) = 0;
virtual ResultOrError<ComputePipelineBase*> CreateComputePipelineImpl(
const ComputePipelineDescriptor* descriptor) = 0;
virtual ResultOrError<PipelineLayoutBase*> CreatePipelineLayoutImpl(
const PipelineLayoutDescriptor* descriptor) = 0;
virtual ResultOrError<QueueBase*> CreateQueueImpl() = 0;
@ -121,6 +122,8 @@ namespace dawn_native {
MaybeError CreateBindGroupLayoutInternal(BindGroupLayoutBase** result,
const BindGroupLayoutDescriptor* descriptor);
MaybeError CreateBufferInternal(BufferBase** result, const BufferDescriptor* descriptor);
MaybeError CreateComputePipelineInternal(ComputePipelineBase** result,
const ComputePipelineDescriptor* descriptor);
MaybeError CreatePipelineLayoutInternal(PipelineLayoutBase** result,
const PipelineLayoutDescriptor* descriptor);
MaybeError CreateQueueInternal(QueueBase** result);

View File

@ -48,6 +48,11 @@ namespace dawn_native {
template <typename T>
class PerStage {
public:
PerStage() = default;
PerStage(const T& initialValue) {
mData.fill(initialValue);
}
T& operator[](dawn::ShaderStage stage) {
DAWN_ASSERT(static_cast<uint32_t>(stage) < kNumStages);
return mData[static_cast<uint32_t>(stage)];

View File

@ -24,8 +24,14 @@ namespace dawn_native {
// PipelineBase
PipelineBase::PipelineBase(PipelineBuilder* builder)
: mStageMask(builder->mStageMask), mLayout(std::move(builder->mLayout)) {
PipelineBase::PipelineBase(DeviceBase* device,
PipelineLayoutBase* layout,
dawn::ShaderStageBit stages)
: mStageMask(stages), mLayout(layout), mDevice(device) {
}
PipelineBase::PipelineBase(DeviceBase* device, PipelineBuilder* builder)
: mStageMask(builder->mStageMask), mLayout(std::move(builder->mLayout)), mDevice(device) {
if (!mLayout) {
PipelineLayoutDescriptor descriptor;
descriptor.numBindGroupLayouts = 0;
@ -35,30 +41,32 @@ namespace dawn_native {
mLayout->Release();
}
auto FillPushConstants = [](const ShaderModuleBase* module, PushConstantInfo* info) {
const auto& moduleInfo = module->GetPushConstants();
info->mask = moduleInfo.mask;
for (uint32_t i = 0; i < moduleInfo.names.size(); i++) {
uint32_t size = moduleInfo.sizes[i];
if (size == 0) {
continue;
}
for (uint32_t offset = 0; offset < size; offset++) {
info->types[i + offset] = moduleInfo.types[i];
}
i += size - 1;
}
};
for (auto stageBit : IterateStages(builder->mStageMask)) {
if (!builder->mStages[stageBit].module->IsCompatibleWithPipelineLayout(mLayout.Get())) {
for (auto stage : IterateStages(builder->mStageMask)) {
if (!builder->mStages[stage].module->IsCompatibleWithPipelineLayout(mLayout.Get())) {
builder->GetParentBuilder()->HandleError("Stage not compatible with layout");
return;
}
FillPushConstants(builder->mStages[stageBit].module.Get(), &mPushConstants[stageBit]);
ExtractModuleData(stage, builder->mStages[stage].module.Get());
}
}
void PipelineBase::ExtractModuleData(dawn::ShaderStage stage, ShaderModuleBase* module) {
PushConstantInfo* info = &mPushConstants[stage];
const auto& moduleInfo = module->GetPushConstants();
info->mask = moduleInfo.mask;
for (uint32_t i = 0; i < moduleInfo.names.size(); i++) {
uint32_t size = moduleInfo.sizes[i];
if (size == 0) {
continue;
}
for (uint32_t offset = 0; offset < size; offset++) {
info->types[i + offset] = moduleInfo.types[i];
}
i += size - 1;
}
}
@ -75,6 +83,10 @@ namespace dawn_native {
return mLayout.Get();
}
DeviceBase* PipelineBase::GetDevice() const {
return mDevice;
}
// PipelineBuilder
PipelineBuilder::PipelineBuilder(BuilderBase* parentBuilder)

View File

@ -39,7 +39,8 @@ namespace dawn_native {
class PipelineBase {
public:
PipelineBase(PipelineBuilder* builder);
PipelineBase(DeviceBase* device, PipelineLayoutBase* layout, dawn::ShaderStageBit stages);
PipelineBase(DeviceBase* device, PipelineBuilder* builder);
struct PushConstantInfo {
std::bitset<kMaxPushConstants> mask;
@ -49,11 +50,16 @@ namespace dawn_native {
dawn::ShaderStageBit GetStageMask() const;
PipelineLayoutBase* GetLayout();
DeviceBase* GetDevice() const;
protected:
void ExtractModuleData(dawn::ShaderStage stage, ShaderModuleBase* module);
private:
dawn::ShaderStageBit mStageMask;
Ref<PipelineLayoutBase> mLayout;
PerStage<PushConstantInfo> mPushConstants;
DeviceBase* mDevice;
};
class PipelineBuilder {

View File

@ -27,7 +27,7 @@ namespace dawn_native {
// RenderPipelineBase
RenderPipelineBase::RenderPipelineBase(RenderPipelineBuilder* builder)
: PipelineBase(builder),
: PipelineBase(builder->mDevice, builder),
mDepthStencilState(std::move(builder->mDepthStencilState)),
mIndexFormat(builder->mIndexFormat),
mInputState(std::move(builder->mInputState)),

View File

@ -22,8 +22,8 @@
namespace dawn_native { namespace d3d12 {
ComputePipeline::ComputePipeline(ComputePipelineBuilder* builder)
: ComputePipelineBase(builder), mDevice(ToBackend(builder->GetDevice())) {
ComputePipeline::ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor)
: ComputePipelineBase(device, descriptor) {
uint32_t compileFlags = 0;
#if defined(_DEBUG)
// Enable better shader debugging with the graphics debugging tools.
@ -32,33 +32,31 @@ namespace dawn_native { namespace d3d12 {
// SPRIV-cross does matrix multiplication expecting row major matrices
compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR;
const auto& module = ToBackend(builder->GetStageInfo(dawn::ShaderStage::Compute).module);
const auto& entryPoint = builder->GetStageInfo(dawn::ShaderStage::Compute).entryPoint;
const auto& hlslSource = module->GetHLSLSource();
const ShaderModule* module = ToBackend(descriptor->module);
const std::string& hlslSource = module->GetHLSLSource();
ComPtr<ID3DBlob> compiledShader;
ComPtr<ID3DBlob> errors;
const PlatformFunctions* functions = ToBackend(builder->GetDevice())->GetFunctions();
const PlatformFunctions* functions = device->GetFunctions();
if (FAILED(functions->d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr,
nullptr, entryPoint.c_str(), "cs_5_1", compileFlags, 0,
nullptr, descriptor->entryPoint, "cs_5_1", compileFlags, 0,
&compiledShader, &errors))) {
printf("%s\n", reinterpret_cast<char*>(errors->GetBufferPointer()));
ASSERT(false);
}
D3D12_COMPUTE_PIPELINE_STATE_DESC descriptor = {};
descriptor.pRootSignature = ToBackend(GetLayout())->GetRootSignature().Get();
descriptor.CS.pShaderBytecode = compiledShader->GetBufferPointer();
descriptor.CS.BytecodeLength = compiledShader->GetBufferSize();
D3D12_COMPUTE_PIPELINE_STATE_DESC d3dDesc = {};
d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature().Get();
d3dDesc.CS.pShaderBytecode = compiledShader->GetBufferPointer();
d3dDesc.CS.BytecodeLength = compiledShader->GetBufferSize();
Device* device = ToBackend(builder->GetDevice());
device->GetD3D12Device()->CreateComputePipelineState(&descriptor,
device->GetD3D12Device()->CreateComputePipelineState(&d3dDesc,
IID_PPV_ARGS(&mPipelineState));
}
ComputePipeline::~ComputePipeline() {
mDevice->ReferenceUntilUnused(mPipelineState);
ToBackend(GetDevice())->ReferenceUntilUnused(mPipelineState);
}
ComPtr<ID3D12PipelineState> ComputePipeline::GetPipelineState() {

View File

@ -25,14 +25,13 @@ namespace dawn_native { namespace d3d12 {
class ComputePipeline : public ComputePipelineBase {
public:
ComputePipeline(ComputePipelineBuilder* builder);
ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor);
~ComputePipeline();
ComPtr<ID3D12PipelineState> GetPipelineState();
private:
ComPtr<ID3D12PipelineState> mPipelineState;
Device* mDevice = nullptr;
};
}} // namespace dawn_native::d3d12

View File

@ -301,8 +301,9 @@ namespace dawn_native { namespace d3d12 {
CommandBufferBase* Device::CreateCommandBuffer(CommandBufferBuilder* builder) {
return new CommandBuffer(builder);
}
ComputePipelineBase* Device::CreateComputePipeline(ComputePipelineBuilder* builder) {
return new ComputePipeline(builder);
ResultOrError<ComputePipelineBase*> Device::CreateComputePipelineImpl(
const ComputePipelineDescriptor* descriptor) {
return new ComputePipeline(this, descriptor);
}
DepthStencilStateBase* Device::CreateDepthStencilState(DepthStencilStateBuilder* builder) {
return new DepthStencilState(builder);

View File

@ -43,7 +43,6 @@ namespace dawn_native { namespace d3d12 {
BlendStateBase* CreateBlendState(BlendStateBuilder* builder) override;
BufferViewBase* CreateBufferView(BufferViewBuilder* builder) override;
CommandBufferBase* CreateCommandBuffer(CommandBufferBuilder* builder) override;
ComputePipelineBase* CreateComputePipeline(ComputePipelineBuilder* builder) override;
DepthStencilStateBase* CreateDepthStencilState(DepthStencilStateBuilder* builder) override;
InputStateBase* CreateInputState(InputStateBuilder* builder) override;
RenderPassDescriptorBase* CreateRenderPassDescriptor(
@ -79,6 +78,8 @@ namespace dawn_native { namespace d3d12 {
ResultOrError<BindGroupLayoutBase*> CreateBindGroupLayoutImpl(
const BindGroupLayoutDescriptor* descriptor) override;
ResultOrError<BufferBase*> CreateBufferImpl(const BufferDescriptor* descriptor) override;
ResultOrError<ComputePipelineBase*> CreateComputePipelineImpl(
const ComputePipelineDescriptor* descriptor) override;
ResultOrError<PipelineLayoutBase*> CreatePipelineLayoutImpl(
const PipelineLayoutDescriptor* descriptor) override;
ResultOrError<QueueBase*> CreateQueueImpl() override;

View File

@ -21,9 +21,11 @@
namespace dawn_native { namespace metal {
class Device;
class ComputePipeline : public ComputePipelineBase {
public:
ComputePipeline(ComputePipelineBuilder* builder);
ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor);
~ComputePipeline();
void Encode(id<MTLComputeCommandEncoder> encoder);

View File

@ -19,22 +19,22 @@
namespace dawn_native { namespace metal {
ComputePipeline::ComputePipeline(ComputePipelineBuilder* builder)
: ComputePipelineBase(builder) {
auto mtlDevice = ToBackend(builder->GetDevice())->GetMTLDevice();
ComputePipeline::ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor)
: ComputePipelineBase(device, descriptor) {
auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice();
const auto& module = ToBackend(builder->GetStageInfo(dawn::ShaderStage::Compute).module);
const auto& entryPoint = builder->GetStageInfo(dawn::ShaderStage::Compute).entryPoint;
const auto& module = ToBackend(descriptor->module);
const char* entryPoint = descriptor->entryPoint;
auto compilationData = module->GetFunction(entryPoint.c_str(), dawn::ShaderStage::Compute,
ToBackend(GetLayout()));
auto compilationData =
module->GetFunction(entryPoint, dawn::ShaderStage::Compute, ToBackend(GetLayout()));
NSError* error = nil;
mMtlComputePipelineState =
[mtlDevice newComputePipelineStateWithFunction:compilationData.function error:&error];
if (error != nil) {
NSLog(@" error => %@", error);
builder->HandleError("Error creating pipeline state");
GetDevice()->HandleError("Error creating pipeline state");
return;
}

View File

@ -39,7 +39,6 @@ namespace dawn_native { namespace metal {
BlendStateBase* CreateBlendState(BlendStateBuilder* builder) override;
BufferViewBase* CreateBufferView(BufferViewBuilder* builder) override;
CommandBufferBase* CreateCommandBuffer(CommandBufferBuilder* builder) override;
ComputePipelineBase* CreateComputePipeline(ComputePipelineBuilder* builder) override;
DepthStencilStateBase* CreateDepthStencilState(DepthStencilStateBuilder* builder) override;
InputStateBase* CreateInputState(InputStateBuilder* builder) override;
RenderPassDescriptorBase* CreateRenderPassDescriptor(
@ -63,6 +62,8 @@ namespace dawn_native { namespace metal {
ResultOrError<BindGroupLayoutBase*> CreateBindGroupLayoutImpl(
const BindGroupLayoutDescriptor* descriptor) override;
ResultOrError<BufferBase*> CreateBufferImpl(const BufferDescriptor* descriptor) override;
ResultOrError<ComputePipelineBase*> CreateComputePipelineImpl(
const ComputePipelineDescriptor* descriptor) override;
ResultOrError<PipelineLayoutBase*> CreatePipelineLayoutImpl(
const PipelineLayoutDescriptor* descriptor) override;
ResultOrError<QueueBase*> CreateQueueImpl() override;

View File

@ -97,8 +97,9 @@ namespace dawn_native { namespace metal {
CommandBufferBase* Device::CreateCommandBuffer(CommandBufferBuilder* builder) {
return new CommandBuffer(builder);
}
ComputePipelineBase* Device::CreateComputePipeline(ComputePipelineBuilder* builder) {
return new ComputePipeline(builder);
ResultOrError<ComputePipelineBase*> Device::CreateComputePipelineImpl(
const ComputePipelineDescriptor* descriptor) {
return new ComputePipeline(this, descriptor);
}
DepthStencilStateBase* Device::CreateDepthStencilState(DepthStencilStateBuilder* builder) {
return new DepthStencilState(builder);

View File

@ -52,8 +52,9 @@ namespace dawn_native { namespace null {
CommandBufferBase* Device::CreateCommandBuffer(CommandBufferBuilder* builder) {
return new CommandBuffer(builder);
}
ComputePipelineBase* Device::CreateComputePipeline(ComputePipelineBuilder* builder) {
return new ComputePipeline(builder);
ResultOrError<ComputePipelineBase*> Device::CreateComputePipelineImpl(
const ComputePipelineDescriptor* descriptor) {
return new ComputePipeline(this, descriptor);
}
DepthStencilStateBase* Device::CreateDepthStencilState(DepthStencilStateBuilder* builder) {
return new DepthStencilState(builder);

View File

@ -99,7 +99,6 @@ namespace dawn_native { namespace null {
BlendStateBase* CreateBlendState(BlendStateBuilder* builder) override;
BufferViewBase* CreateBufferView(BufferViewBuilder* builder) override;
CommandBufferBase* CreateCommandBuffer(CommandBufferBuilder* builder) override;
ComputePipelineBase* CreateComputePipeline(ComputePipelineBuilder* builder) override;
DepthStencilStateBase* CreateDepthStencilState(DepthStencilStateBuilder* builder) override;
InputStateBase* CreateInputState(InputStateBuilder* builder) override;
RenderPassDescriptorBase* CreateRenderPassDescriptor(
@ -117,6 +116,8 @@ namespace dawn_native { namespace null {
ResultOrError<BindGroupLayoutBase*> CreateBindGroupLayoutImpl(
const BindGroupLayoutDescriptor* descriptor) override;
ResultOrError<BufferBase*> CreateBufferImpl(const BufferDescriptor* descriptor) override;
ResultOrError<ComputePipelineBase*> CreateComputePipelineImpl(
const ComputePipelineDescriptor* descriptor) override;
ResultOrError<PipelineLayoutBase*> CreatePipelineLayoutImpl(
const PipelineLayoutDescriptor* descriptor) override;
ResultOrError<QueueBase*> CreateQueueImpl() override;

View File

@ -14,10 +14,16 @@
#include "dawn_native/opengl/ComputePipelineGL.h"
#include "dawn_native/opengl/DeviceGL.h"
namespace dawn_native { namespace opengl {
ComputePipeline::ComputePipeline(ComputePipelineBuilder* builder)
: ComputePipelineBase(builder), PipelineGL(this, builder) {
ComputePipeline::ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor)
: ComputePipelineBase(device, descriptor) {
PerStage<const ShaderModule*> modules(nullptr);
modules[dawn::ShaderStage::Compute] = ToBackend(descriptor->module);
PipelineGL::Initialize(ToBackend(descriptor->layout), modules);
}
void ComputePipeline::ApplyNow() {

View File

@ -23,9 +23,11 @@
namespace dawn_native { namespace opengl {
class Device;
class ComputePipeline : public ComputePipelineBase, public PipelineGL {
public:
ComputePipeline(ComputePipelineBuilder* builder);
ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor);
void ApplyNow();
};

View File

@ -65,8 +65,9 @@ namespace dawn_native { namespace opengl {
CommandBufferBase* Device::CreateCommandBuffer(CommandBufferBuilder* builder) {
return new CommandBuffer(builder);
}
ComputePipelineBase* Device::CreateComputePipeline(ComputePipelineBuilder* builder) {
return new ComputePipeline(builder);
ResultOrError<ComputePipelineBase*> Device::CreateComputePipelineImpl(
const ComputePipelineDescriptor* descriptor) {
return new ComputePipeline(this, descriptor);
}
DepthStencilStateBase* Device::CreateDepthStencilState(DepthStencilStateBuilder* builder) {
return new DepthStencilState(builder);

View File

@ -36,7 +36,6 @@ namespace dawn_native { namespace opengl {
BlendStateBase* CreateBlendState(BlendStateBuilder* builder) override;
BufferViewBase* CreateBufferView(BufferViewBuilder* builder) override;
CommandBufferBase* CreateCommandBuffer(CommandBufferBuilder* builder) override;
ComputePipelineBase* CreateComputePipeline(ComputePipelineBuilder* builder) override;
DepthStencilStateBase* CreateDepthStencilState(DepthStencilStateBuilder* builder) override;
InputStateBase* CreateInputState(InputStateBuilder* builder) override;
RenderPassDescriptorBase* CreateRenderPassDescriptor(
@ -51,6 +50,8 @@ namespace dawn_native { namespace opengl {
ResultOrError<BindGroupLayoutBase*> CreateBindGroupLayoutImpl(
const BindGroupLayoutDescriptor* descriptor) override;
ResultOrError<BufferBase*> CreateBufferImpl(const BufferDescriptor* descriptor) override;
ResultOrError<ComputePipelineBase*> CreateComputePipelineImpl(
const ComputePipelineDescriptor* descriptor) override;
ResultOrError<PipelineLayoutBase*> CreatePipelineLayoutImpl(
const PipelineLayoutDescriptor* descriptor) override;
ResultOrError<QueueBase*> CreateQueueImpl() override;

View File

@ -43,7 +43,11 @@ namespace dawn_native { namespace opengl {
} // namespace
PipelineGL::PipelineGL(PipelineBase* parent, PipelineBuilder* builder) {
PipelineGL::PipelineGL() {
}
void PipelineGL::Initialize(const PipelineLayout* layout,
const PerStage<const ShaderModule*>& modules) {
auto CreateShader = [](GLenum type, const char* source) -> GLuint {
GLuint shader = glCreateShader(type);
glShaderSource(shader, 1, &source, nullptr);
@ -91,10 +95,15 @@ namespace dawn_native { namespace opengl {
mProgram = glCreateProgram();
for (auto stage : IterateStages(parent->GetStageMask())) {
const ShaderModule* module = ToBackend(builder->GetStageInfo(stage).module.Get());
dawn::ShaderStageBit activeStages = dawn::ShaderStageBit::None;
for (dawn::ShaderStage stage : IterateStages(kAllStages)) {
if (modules[stage] != nullptr) {
activeStages |= StageBit(stage);
}
}
GLuint shader = CreateShader(GLShaderType(stage), module->GetSource());
for (dawn::ShaderStage stage : IterateStages(activeStages)) {
GLuint shader = CreateShader(GLShaderType(stage), modules[stage]->GetSource());
glAttachShader(mProgram, shader);
}
@ -114,16 +123,14 @@ namespace dawn_native { namespace opengl {
}
}
for (auto stage : IterateStages(parent->GetStageMask())) {
const ShaderModule* module = ToBackend(builder->GetStageInfo(stage).module.Get());
FillPushConstants(module, &mGlPushConstants[stage], mProgram);
for (dawn::ShaderStage stage : IterateStages(activeStages)) {
FillPushConstants(modules[stage], &mGlPushConstants[stage], mProgram);
}
glUseProgram(mProgram);
// The uniforms are part of the program state so we can pre-bind buffer units, texture units
// etc.
const auto& layout = ToBackend(parent->GetLayout());
const auto& indices = layout->GetBindingIndexInfo();
for (uint32_t group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
@ -159,10 +166,8 @@ namespace dawn_native { namespace opengl {
// Compute links between stages for combined samplers, then bind them to texture units
{
std::set<CombinedSampler> combinedSamplersSet;
for (auto stage : IterateStages(parent->GetStageMask())) {
const auto& module = ToBackend(builder->GetStageInfo(stage).module);
for (const auto& combined : module->GetCombinedSamplerInfo()) {
for (dawn::ShaderStage stage : IterateStages(activeStages)) {
for (const auto& combined : modules[stage]->GetCombinedSamplerInfo()) {
combinedSamplersSet.insert(combined);
}
}

View File

@ -25,11 +25,14 @@ namespace dawn_native { namespace opengl {
class Device;
class PersistentPipelineState;
class PipelineLayout;
class ShaderModule;
class PipelineGL {
public:
PipelineGL(PipelineBase* parent, PipelineBuilder* builder);
PipelineGL();
void Initialize(const PipelineLayout* layout, const PerStage<const ShaderModule*>& modules);
using GLPushConstantInfo = std::array<GLint, kMaxPushConstants>;
using BindingLocations =

View File

@ -43,8 +43,13 @@ namespace dawn_native { namespace opengl {
RenderPipeline::RenderPipeline(RenderPipelineBuilder* builder)
: RenderPipelineBase(builder),
PipelineGL(this, builder),
mGlPrimitiveTopology(GLPrimitiveTopology(GetPrimitiveTopology())) {
PerStage<const ShaderModule*> modules(nullptr);
for (dawn::ShaderStage stage : IterateStages(GetStageMask())) {
modules[stage] = ToBackend(builder->GetStageInfo(stage).module.Get());
}
PipelineGL::Initialize(ToBackend(GetLayout()), modules);
}
GLenum RenderPipeline::GetGLPrimitiveTopology() const {

View File

@ -21,34 +21,33 @@
namespace dawn_native { namespace vulkan {
ComputePipeline::ComputePipeline(ComputePipelineBuilder* builder)
: ComputePipelineBase(builder), mDevice(ToBackend(builder->GetDevice())) {
ComputePipeline::ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor)
: ComputePipelineBase(device, descriptor) {
VkComputePipelineCreateInfo createInfo;
createInfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
createInfo.pNext = nullptr;
createInfo.flags = 0;
createInfo.layout = ToBackend(GetLayout())->GetHandle();
createInfo.layout = ToBackend(descriptor->layout)->GetHandle();
createInfo.basePipelineHandle = VK_NULL_HANDLE;
createInfo.basePipelineIndex = -1;
const auto& stageInfo = builder->GetStageInfo(dawn::ShaderStage::Compute);
createInfo.stage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
createInfo.stage.pNext = nullptr;
createInfo.stage.flags = 0;
createInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
createInfo.stage.module = ToBackend(stageInfo.module)->GetHandle();
createInfo.stage.pName = stageInfo.entryPoint.c_str();
createInfo.stage.module = ToBackend(descriptor->module)->GetHandle();
createInfo.stage.pName = descriptor->entryPoint;
createInfo.stage.pSpecializationInfo = nullptr;
if (mDevice->fn.CreateComputePipelines(mDevice->GetVkDevice(), VK_NULL_HANDLE, 1,
&createInfo, nullptr, &mHandle) != VK_SUCCESS) {
if (device->fn.CreateComputePipelines(device->GetVkDevice(), VK_NULL_HANDLE, 1, &createInfo,
nullptr, &mHandle) != VK_SUCCESS) {
ASSERT(false);
}
}
ComputePipeline::~ComputePipeline() {
if (mHandle != VK_NULL_HANDLE) {
mDevice->GetFencedDeleter()->DeleteWhenUnused(mHandle);
ToBackend(GetDevice())->GetFencedDeleter()->DeleteWhenUnused(mHandle);
mHandle = VK_NULL_HANDLE;
}
}

View File

@ -25,14 +25,13 @@ namespace dawn_native { namespace vulkan {
class ComputePipeline : public ComputePipelineBase {
public:
ComputePipeline(ComputePipelineBuilder* builder);
ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor);
~ComputePipeline();
VkPipeline GetHandle() const;
private:
VkPipeline mHandle = VK_NULL_HANDLE;
Device* mDevice = nullptr;
};
}} // namespace dawn_native::vulkan

View File

@ -236,8 +236,9 @@ namespace dawn_native { namespace vulkan {
CommandBufferBase* Device::CreateCommandBuffer(CommandBufferBuilder* builder) {
return new CommandBuffer(builder);
}
ComputePipelineBase* Device::CreateComputePipeline(ComputePipelineBuilder* builder) {
return new ComputePipeline(builder);
ResultOrError<ComputePipelineBase*> Device::CreateComputePipelineImpl(
const ComputePipelineDescriptor* descriptor) {
return new ComputePipeline(this, descriptor);
}
DepthStencilStateBase* Device::CreateDepthStencilState(DepthStencilStateBuilder* builder) {
return new DepthStencilState(builder);

View File

@ -67,7 +67,6 @@ namespace dawn_native { namespace vulkan {
BlendStateBase* CreateBlendState(BlendStateBuilder* builder) override;
BufferViewBase* CreateBufferView(BufferViewBuilder* builder) override;
CommandBufferBase* CreateCommandBuffer(CommandBufferBuilder* builder) override;
ComputePipelineBase* CreateComputePipeline(ComputePipelineBuilder* builder) override;
DepthStencilStateBase* CreateDepthStencilState(DepthStencilStateBuilder* builder) override;
InputStateBase* CreateInputState(InputStateBuilder* builder) override;
RenderPassDescriptorBase* CreateRenderPassDescriptor(
@ -82,6 +81,8 @@ namespace dawn_native { namespace vulkan {
ResultOrError<BindGroupLayoutBase*> CreateBindGroupLayoutImpl(
const BindGroupLayoutDescriptor* descriptor) override;
ResultOrError<BufferBase*> CreateBufferImpl(const BufferDescriptor* descriptor) override;
ResultOrError<ComputePipelineBase*> CreateComputePipelineImpl(
const ComputePipelineDescriptor* descriptor) override;
ResultOrError<PipelineLayoutBase*> CreatePipelineLayoutImpl(
const PipelineLayoutDescriptor* descriptor) override;
ResultOrError<QueueBase*> CreateQueueImpl() override;

View File

@ -37,10 +37,12 @@ void ComputeCopyStorageBufferTests::BasicTest(const char* shader) {
// Set up shader and pipeline
auto module = utils::CreateShaderModule(device, dawn::ShaderStage::Compute, shader);
auto pl = utils::MakeBasicPipelineLayout(device, &bgl);
auto pipeline = device.CreateComputePipelineBuilder()
.SetLayout(pl)
.SetStage(dawn::ShaderStage::Compute, module, "main")
.GetResult();
dawn::ComputePipelineDescriptor csDesc;
csDesc.module = module.Clone();
csDesc.entryPoint = "main";
csDesc.layout = pl.Clone();
dawn::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc);
// Set up src storage buffer
dawn::BufferDescriptor srcDesc;

View File

@ -145,10 +145,11 @@ class PushConstantTest: public DawnTest {
})").c_str()
);
return device.CreateComputePipelineBuilder()
.SetLayout(pl)
.SetStage(dawn::ShaderStage::Compute, module, "main")
.GetResult();
dawn::ComputePipelineDescriptor descriptor;
descriptor.module = module.Clone();
descriptor.entryPoint = "main";
descriptor.layout = pl.Clone();
return device.CreateComputePipeline(&descriptor);
}
dawn::PipelineLayout MakeEmptyLayout() {