Remove descriptor from the parameter of ComputePipeline::Initialize()

This patch removes the parameter "descriptor" in the function
ComputePipeline::Initialize() so that we don't need to define
FlatComputePipelineDescriptor right now.

For render pipeline, as descriptor->vertex is being used for vertex
pulling (passed into vertexModule->CreateFunction()), we will first
refactor the related code in vertex pulling before removing the
parameter "descriptor" in the function RenderPipeline::Initialize().

BUG=dawn:529

Change-Id: Ib172ac0c76fa24070e78c0e57c3262acad9399b9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/64000
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Jiawei Shao <jiawei.shao@intel.com>
This commit is contained in:
Jiawei Shao 2021-09-11 09:04:34 +00:00 committed by Dawn LUCI CQ
parent b21ccebac8
commit 42448dafb4
20 changed files with 111 additions and 154 deletions

View File

@ -19,31 +19,6 @@
namespace dawn_native { namespace dawn_native {
FlatComputePipelineDescriptor::FlatComputePipelineDescriptor(
const ComputePipelineDescriptor* descriptor)
: mLabel(descriptor->label != nullptr ? descriptor->label : ""),
mLayout(descriptor->layout) {
label = mLabel.c_str();
layout = mLayout.Get();
// TODO(dawn:800): Remove after deprecation period.
if (descriptor->compute.module == nullptr && descriptor->computeStage.module != nullptr) {
mComputeModule = descriptor->computeStage.module;
mEntryPoint = descriptor->computeStage.entryPoint;
} else {
mComputeModule = descriptor->compute.module;
mEntryPoint = descriptor->compute.entryPoint;
}
compute.entryPoint = mEntryPoint.c_str();
compute.module = mComputeModule.Get();
}
void FlatComputePipelineDescriptor::SetLayout(Ref<PipelineLayoutBase> appliedLayout) {
mLayout = std::move(appliedLayout);
layout = mLayout.Get();
}
MaybeError ValidateComputePipelineDescriptor(DeviceBase* device, MaybeError ValidateComputePipelineDescriptor(DeviceBase* device,
const ComputePipelineDescriptor* descriptor) { const ComputePipelineDescriptor* descriptor) {
if (descriptor->nextInChain != nullptr) { if (descriptor->nextInChain != nullptr) {
@ -92,7 +67,7 @@ namespace dawn_native {
} }
} }
MaybeError ComputePipelineBase::Initialize(const ComputePipelineDescriptor* descriptor) { MaybeError ComputePipelineBase::Initialize() {
return {}; return {};
} }

View File

@ -23,22 +23,6 @@ namespace dawn_native {
class DeviceBase; class DeviceBase;
struct EntryPointMetadata; struct EntryPointMetadata;
// We use FlatComputePipelineDescriptor to keep all the members of ComputePipelineDescriptor
// (especially the members in pointers) valid in CreateComputePipelineAsyncTask when the
// creation of the compute pipeline is executed asynchronously.
struct FlatComputePipelineDescriptor : public ComputePipelineDescriptor, public NonMovable {
public:
explicit FlatComputePipelineDescriptor(const ComputePipelineDescriptor* descriptor);
void SetLayout(Ref<PipelineLayoutBase> appliedLayout);
private:
std::string mLabel;
Ref<PipelineLayoutBase> mLayout;
std::string mEntryPoint;
Ref<ShaderModuleBase> mComputeModule;
};
MaybeError ValidateComputePipelineDescriptor(DeviceBase* device, MaybeError ValidateComputePipelineDescriptor(DeviceBase* device,
const ComputePipelineDescriptor* descriptor); const ComputePipelineDescriptor* descriptor);
@ -60,7 +44,7 @@ namespace dawn_native {
// CreateComputePipelineAsyncTask is declared as a friend of ComputePipelineBase as it // CreateComputePipelineAsyncTask is declared as a friend of ComputePipelineBase as it
// needs to call the private member function ComputePipelineBase::Initialize(). // needs to call the private member function ComputePipelineBase::Initialize().
friend class CreateComputePipelineAsyncTask; friend class CreateComputePipelineAsyncTask;
virtual MaybeError Initialize(const ComputePipelineDescriptor* descriptor); virtual MaybeError Initialize();
}; };
} // namespace dawn_native } // namespace dawn_native

View File

@ -103,23 +103,18 @@ namespace dawn_native {
CreateComputePipelineAsyncTask::CreateComputePipelineAsyncTask( CreateComputePipelineAsyncTask::CreateComputePipelineAsyncTask(
Ref<ComputePipelineBase> nonInitializedComputePipeline, Ref<ComputePipelineBase> nonInitializedComputePipeline,
std::unique_ptr<FlatComputePipelineDescriptor> descriptor,
size_t blueprintHash, size_t blueprintHash,
WGPUCreateComputePipelineAsyncCallback callback, WGPUCreateComputePipelineAsyncCallback callback,
void* userdata) void* userdata)
: mComputePipeline(nonInitializedComputePipeline), : mComputePipeline(nonInitializedComputePipeline),
mBlueprintHash(blueprintHash), mBlueprintHash(blueprintHash),
mCallback(callback), mCallback(callback),
mUserdata(userdata), mUserdata(userdata) {
mAppliedDescriptor(std::move(descriptor)) {
ASSERT(mComputePipeline != nullptr); ASSERT(mComputePipeline != nullptr);
// TODO(jiawei.shao@intel.com): save nextInChain when it is supported in Dawn.
ASSERT(mAppliedDescriptor->nextInChain == nullptr);
} }
void CreateComputePipelineAsyncTask::Run() { void CreateComputePipelineAsyncTask::Run() {
MaybeError maybeError = mComputePipeline->Initialize(mAppliedDescriptor.get()); MaybeError maybeError = mComputePipeline->Initialize();
std::string errorMessage; std::string errorMessage;
if (maybeError.IsError()) { if (maybeError.IsError()) {
mComputePipeline = nullptr; mComputePipeline = nullptr;

View File

@ -72,7 +72,6 @@ namespace dawn_native {
class CreateComputePipelineAsyncTask { class CreateComputePipelineAsyncTask {
public: public:
CreateComputePipelineAsyncTask(Ref<ComputePipelineBase> nonInitializedComputePipeline, CreateComputePipelineAsyncTask(Ref<ComputePipelineBase> nonInitializedComputePipeline,
std::unique_ptr<FlatComputePipelineDescriptor> descriptor,
size_t blueprintHash, size_t blueprintHash,
WGPUCreateComputePipelineAsyncCallback callback, WGPUCreateComputePipelineAsyncCallback callback,
void* userdata); void* userdata);
@ -87,8 +86,6 @@ namespace dawn_native {
size_t mBlueprintHash; size_t mBlueprintHash;
WGPUCreateComputePipelineAsyncCallback mCallback; WGPUCreateComputePipelineAsyncCallback mCallback;
void* mUserdata; void* mUserdata;
std::unique_ptr<FlatComputePipelineDescriptor> mAppliedDescriptor;
}; };
} // namespace dawn_native } // namespace dawn_native

View File

@ -123,18 +123,29 @@ namespace dawn_native {
void* mUserdata; void* mUserdata;
}; };
MaybeError ValidateLayoutAndSetDefaultLayout( ResultOrError<Ref<PipelineLayoutBase>>
ValidateLayoutAndGetComputePipelineDescriptorWithDefaults(
DeviceBase* device, DeviceBase* device,
FlatComputePipelineDescriptor* appliedDescriptor) { const ComputePipelineDescriptor& descriptor,
if (appliedDescriptor->layout == nullptr) { ComputePipelineDescriptor* outDescriptor) {
Ref<PipelineLayoutBase> layoutRef; Ref<PipelineLayoutBase> layoutRef;
*outDescriptor = descriptor;
// TODO(dawn:800): Remove after deprecation period.
if (outDescriptor->compute.module == nullptr &&
outDescriptor->computeStage.module != nullptr) {
outDescriptor->compute.module = outDescriptor->computeStage.module;
outDescriptor->compute.entryPoint = outDescriptor->computeStage.entryPoint;
}
if (outDescriptor->layout == nullptr) {
DAWN_TRY_ASSIGN(layoutRef, PipelineLayoutBase::CreateDefault( DAWN_TRY_ASSIGN(layoutRef, PipelineLayoutBase::CreateDefault(
device, {{SingleShaderStage::Compute, device, {{SingleShaderStage::Compute,
appliedDescriptor->compute.module, outDescriptor->compute.module,
appliedDescriptor->compute.entryPoint}})); outDescriptor->compute.entryPoint}}));
appliedDescriptor->SetLayout(std::move(layoutRef)); outDescriptor->layout = layoutRef.Get();
} }
return {};
return layoutRef;
} }
ResultOrError<Ref<PipelineLayoutBase>> ResultOrError<Ref<PipelineLayoutBase>>
@ -1129,8 +1140,12 @@ namespace dawn_native {
DAWN_TRY(ValidateComputePipelineDescriptor(this, descriptor)); DAWN_TRY(ValidateComputePipelineDescriptor(this, descriptor));
} }
FlatComputePipelineDescriptor appliedDescriptor(descriptor); // Ref will keep the pipeline layout alive until the end of the function where
DAWN_TRY(ValidateLayoutAndSetDefaultLayout(this, &appliedDescriptor)); // the pipeline will take another reference.
Ref<PipelineLayoutBase> layoutRef;
ComputePipelineDescriptor appliedDescriptor;
DAWN_TRY_ASSIGN(layoutRef, ValidateLayoutAndGetComputePipelineDescriptorWithDefaults(
this, *descriptor, &appliedDescriptor));
auto pipelineAndBlueprintFromCache = GetCachedComputePipeline(&appliedDescriptor); auto pipelineAndBlueprintFromCache = GetCachedComputePipeline(&appliedDescriptor);
if (pipelineAndBlueprintFromCache.first.Get() != nullptr) { if (pipelineAndBlueprintFromCache.first.Get() != nullptr) {
@ -1152,12 +1167,13 @@ namespace dawn_native {
DAWN_TRY(ValidateComputePipelineDescriptor(this, descriptor)); DAWN_TRY(ValidateComputePipelineDescriptor(this, descriptor));
} }
std::unique_ptr<FlatComputePipelineDescriptor> appliedDescriptor = Ref<PipelineLayoutBase> layoutRef;
std::make_unique<FlatComputePipelineDescriptor>(descriptor); ComputePipelineDescriptor appliedDescriptor;
DAWN_TRY(ValidateLayoutAndSetDefaultLayout(this, appliedDescriptor.get())); DAWN_TRY_ASSIGN(layoutRef, ValidateLayoutAndGetComputePipelineDescriptorWithDefaults(
this, *descriptor, &appliedDescriptor));
// Call the callback directly when we can get a cached compute pipeline object. // Call the callback directly when we can get a cached compute pipeline object.
auto pipelineAndBlueprintFromCache = GetCachedComputePipeline(appliedDescriptor.get()); auto pipelineAndBlueprintFromCache = GetCachedComputePipeline(&appliedDescriptor);
if (pipelineAndBlueprintFromCache.first.Get() != nullptr) { if (pipelineAndBlueprintFromCache.first.Get() != nullptr) {
Ref<ComputePipelineBase> result = std::move(pipelineAndBlueprintFromCache.first); Ref<ComputePipelineBase> result = std::move(pipelineAndBlueprintFromCache.first);
callback(WGPUCreatePipelineAsyncStatus_Success, callback(WGPUCreatePipelineAsyncStatus_Success,
@ -1167,24 +1183,22 @@ namespace dawn_native {
// where the pipeline object may be created asynchronously and the result will be saved // where the pipeline object may be created asynchronously and the result will be saved
// to mCreatePipelineAsyncTracker. // to mCreatePipelineAsyncTracker.
const size_t blueprintHash = pipelineAndBlueprintFromCache.second; const size_t blueprintHash = pipelineAndBlueprintFromCache.second;
CreateComputePipelineAsyncImpl(std::move(appliedDescriptor), blueprintHash, callback, CreateComputePipelineAsyncImpl(&appliedDescriptor, blueprintHash, callback, userdata);
userdata);
} }
return {}; return {};
} }
// This function is overwritten with the async version on the backends // This function is overwritten with the async version on the backends that supports creating
// that supports creating compute pipeline asynchronously // compute pipeline asynchronously.
void DeviceBase::CreateComputePipelineAsyncImpl( void DeviceBase::CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
std::unique_ptr<FlatComputePipelineDescriptor> descriptor, size_t blueprintHash,
size_t blueprintHash, WGPUCreateComputePipelineAsyncCallback callback,
WGPUCreateComputePipelineAsyncCallback callback, void* userdata) {
void* userdata) {
Ref<ComputePipelineBase> result; Ref<ComputePipelineBase> result;
std::string errorMessage; std::string errorMessage;
auto resultOrError = CreateComputePipelineImpl(descriptor.get()); auto resultOrError = CreateComputePipelineImpl(descriptor);
if (resultOrError.IsError()) { if (resultOrError.IsError()) {
std::unique_ptr<ErrorData> error = resultOrError.AcquireError(); std::unique_ptr<ErrorData> error = resultOrError.AcquireError();
errorMessage = error->GetMessage(); errorMessage = error->GetMessage();

View File

@ -46,7 +46,6 @@ namespace dawn_native {
class PersistentCache; class PersistentCache;
class StagingBufferBase; class StagingBufferBase;
struct CallbackTask; struct CallbackTask;
struct FlatComputePipelineDescriptor;
struct InternalPipelineStore; struct InternalPipelineStore;
struct ShaderModuleParseResult; struct ShaderModuleParseResult;
@ -360,11 +359,10 @@ namespace dawn_native {
size_t blueprintHash); size_t blueprintHash);
Ref<RenderPipelineBase> AddOrGetCachedRenderPipeline(Ref<RenderPipelineBase> renderPipeline, Ref<RenderPipelineBase> AddOrGetCachedRenderPipeline(Ref<RenderPipelineBase> renderPipeline,
size_t blueprintHash); size_t blueprintHash);
virtual void CreateComputePipelineAsyncImpl( virtual void CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
std::unique_ptr<FlatComputePipelineDescriptor> descriptor, size_t blueprintHash,
size_t blueprintHash, WGPUCreateComputePipelineAsyncCallback callback,
WGPUCreateComputePipelineAsyncCallback callback, void* userdata);
void* userdata);
virtual void CreateRenderPipelineAsyncImpl(const RenderPipelineDescriptor* descriptor, virtual void CreateRenderPipelineAsyncImpl(const RenderPipelineDescriptor* descriptor,
size_t blueprintHash, size_t blueprintHash,
WGPUCreateRenderPipelineAsyncCallback callback, WGPUCreateRenderPipelineAsyncCallback callback,

View File

@ -28,11 +28,11 @@ namespace dawn_native { namespace d3d12 {
Device* device, Device* device,
const ComputePipelineDescriptor* descriptor) { const ComputePipelineDescriptor* descriptor) {
Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor)); Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor));
DAWN_TRY(pipeline->Initialize(descriptor)); DAWN_TRY(pipeline->Initialize());
return pipeline; return pipeline;
} }
MaybeError ComputePipeline::Initialize(const ComputePipelineDescriptor* descriptor) { MaybeError ComputePipeline::Initialize() {
Device* device = ToBackend(GetDevice()); Device* device = ToBackend(GetDevice());
uint32_t compileFlags = 0; uint32_t compileFlags = 0;
@ -43,14 +43,15 @@ namespace dawn_native { namespace d3d12 {
// SPRIV-cross does matrix multiplication expecting row major matrices // SPRIV-cross does matrix multiplication expecting row major matrices
compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR; compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR;
ShaderModule* module = ToBackend(descriptor->compute.module); const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute);
ShaderModule* module = ToBackend(computeStage.module.Get());
D3D12_COMPUTE_PIPELINE_STATE_DESC d3dDesc = {}; D3D12_COMPUTE_PIPELINE_STATE_DESC d3dDesc = {};
d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature(); d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature();
CompiledShader compiledShader; CompiledShader compiledShader;
DAWN_TRY_ASSIGN(compiledShader, DAWN_TRY_ASSIGN(compiledShader,
module->Compile(descriptor->compute.entryPoint, SingleShaderStage::Compute, module->Compile(computeStage.entryPoint.c_str(), SingleShaderStage::Compute,
ToBackend(GetLayout()), compileFlags)); ToBackend(GetLayout()), compileFlags));
d3dDesc.CS = compiledShader.GetD3D12ShaderBytecode(); d3dDesc.CS = compiledShader.GetD3D12ShaderBytecode();
auto* d3d12Device = device->GetD3D12Device(); auto* d3d12Device = device->GetD3D12Device();
@ -77,14 +78,14 @@ namespace dawn_native { namespace d3d12 {
} }
void ComputePipeline::CreateAsync(Device* device, void ComputePipeline::CreateAsync(Device* device,
std::unique_ptr<FlatComputePipelineDescriptor> descriptor, const ComputePipelineDescriptor* descriptor,
size_t blueprintHash, size_t blueprintHash,
WGPUCreateComputePipelineAsyncCallback callback, WGPUCreateComputePipelineAsyncCallback callback,
void* userdata) { void* userdata) {
Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor.get())); Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor));
std::unique_ptr<CreateComputePipelineAsyncTask> asyncTask = std::unique_ptr<CreateComputePipelineAsyncTask> asyncTask =
std::make_unique<CreateComputePipelineAsyncTask>(pipeline, std::move(descriptor), std::make_unique<CreateComputePipelineAsyncTask>(pipeline, blueprintHash, callback,
blueprintHash, callback, userdata); userdata);
CreateComputePipelineAsyncTask::RunAsync(std::move(asyncTask)); CreateComputePipelineAsyncTask::RunAsync(std::move(asyncTask));
} }

View File

@ -29,7 +29,7 @@ namespace dawn_native { namespace d3d12 {
Device* device, Device* device,
const ComputePipelineDescriptor* descriptor); const ComputePipelineDescriptor* descriptor);
static void CreateAsync(Device* device, static void CreateAsync(Device* device,
std::unique_ptr<FlatComputePipelineDescriptor> descriptor, const ComputePipelineDescriptor* descriptor,
size_t blueprintHash, size_t blueprintHash,
WGPUCreateComputePipelineAsyncCallback callback, WGPUCreateComputePipelineAsyncCallback callback,
void* userdata); void* userdata);
@ -43,7 +43,7 @@ namespace dawn_native { namespace d3d12 {
private: private:
~ComputePipeline() override; ~ComputePipeline() override;
using ComputePipelineBase::ComputePipelineBase; using ComputePipelineBase::ComputePipelineBase;
MaybeError Initialize(const ComputePipelineDescriptor* descriptor) override; MaybeError Initialize() override;
ComPtr<ID3D12PipelineState> mPipelineState; ComPtr<ID3D12PipelineState> mPipelineState;
}; };

View File

@ -376,13 +376,11 @@ namespace dawn_native { namespace d3d12 {
const TextureViewDescriptor* descriptor) { const TextureViewDescriptor* descriptor) {
return TextureView::Create(texture, descriptor); return TextureView::Create(texture, descriptor);
} }
void Device::CreateComputePipelineAsyncImpl( void Device::CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
std::unique_ptr<FlatComputePipelineDescriptor> descriptor, size_t blueprintHash,
size_t blueprintHash, WGPUCreateComputePipelineAsyncCallback callback,
WGPUCreateComputePipelineAsyncCallback callback, void* userdata) {
void* userdata) { ComputePipeline::CreateAsync(this, descriptor, blueprintHash, callback, userdata);
ComputePipeline::CreateAsync(this, std::move(descriptor), blueprintHash, callback,
userdata);
} }
ResultOrError<std::unique_ptr<StagingBufferBase>> Device::CreateStagingBuffer(size_t size) { ResultOrError<std::unique_ptr<StagingBufferBase>> Device::CreateStagingBuffer(size_t size) {

View File

@ -173,11 +173,10 @@ namespace dawn_native { namespace d3d12 {
ResultOrError<Ref<TextureViewBase>> CreateTextureViewImpl( ResultOrError<Ref<TextureViewBase>> CreateTextureViewImpl(
TextureBase* texture, TextureBase* texture,
const TextureViewDescriptor* descriptor) override; const TextureViewDescriptor* descriptor) override;
void CreateComputePipelineAsyncImpl( void CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
std::unique_ptr<FlatComputePipelineDescriptor> descriptor, size_t blueprintHash,
size_t blueprintHash, WGPUCreateComputePipelineAsyncCallback callback,
WGPUCreateComputePipelineAsyncCallback callback, void* userdata) override;
void* userdata) override;
void ShutDownImpl() override; void ShutDownImpl() override;
MaybeError WaitForIdleForDestruction() override; MaybeError WaitForIdleForDestruction() override;

View File

@ -31,7 +31,7 @@ namespace dawn_native { namespace metal {
Device* device, Device* device,
const ComputePipelineDescriptor* descriptor); const ComputePipelineDescriptor* descriptor);
static void CreateAsync(Device* device, static void CreateAsync(Device* device,
std::unique_ptr<FlatComputePipelineDescriptor> descriptor, const ComputePipelineDescriptor* descriptor,
size_t blueprintHash, size_t blueprintHash,
WGPUCreateComputePipelineAsyncCallback callback, WGPUCreateComputePipelineAsyncCallback callback,
void* userdata); void* userdata);
@ -42,7 +42,7 @@ namespace dawn_native { namespace metal {
private: private:
using ComputePipelineBase::ComputePipelineBase; using ComputePipelineBase::ComputePipelineBase;
MaybeError Initialize(const ComputePipelineDescriptor* descriptor) override; MaybeError Initialize() override;
NSPRef<id<MTLComputePipelineState>> mMtlComputePipelineState; NSPRef<id<MTLComputePipelineState>> mMtlComputePipelineState;
MTLSize mLocalWorkgroupSize; MTLSize mLocalWorkgroupSize;

View File

@ -25,15 +25,16 @@ namespace dawn_native { namespace metal {
Device* device, Device* device,
const ComputePipelineDescriptor* descriptor) { const ComputePipelineDescriptor* descriptor) {
Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor)); Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor));
DAWN_TRY(pipeline->Initialize(descriptor)); DAWN_TRY(pipeline->Initialize());
return pipeline; return pipeline;
} }
MaybeError ComputePipeline::Initialize(const ComputePipelineDescriptor* descriptor) { MaybeError ComputePipeline::Initialize() {
auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice(); auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice();
ShaderModule* computeModule = ToBackend(descriptor->compute.module); const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute);
const char* computeEntryPoint = descriptor->compute.entryPoint; ShaderModule* computeModule = ToBackend(computeStage.module.Get());
const char* computeEntryPoint = computeStage.entryPoint.c_str();
ShaderModule::MetalFunctionData computeData; ShaderModule::MetalFunctionData computeData;
DAWN_TRY(computeModule->CreateFunction(computeEntryPoint, SingleShaderStage::Compute, DAWN_TRY(computeModule->CreateFunction(computeEntryPoint, SingleShaderStage::Compute,
ToBackend(GetLayout()), &computeData)); ToBackend(GetLayout()), &computeData));
@ -69,14 +70,14 @@ namespace dawn_native { namespace metal {
} }
void ComputePipeline::CreateAsync(Device* device, void ComputePipeline::CreateAsync(Device* device,
std::unique_ptr<FlatComputePipelineDescriptor> descriptor, const ComputePipelineDescriptor* descriptor,
size_t blueprintHash, size_t blueprintHash,
WGPUCreateComputePipelineAsyncCallback callback, WGPUCreateComputePipelineAsyncCallback callback,
void* userdata) { void* userdata) {
Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor.get())); Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor));
std::unique_ptr<CreateComputePipelineAsyncTask> asyncTask = std::unique_ptr<CreateComputePipelineAsyncTask> asyncTask =
std::make_unique<CreateComputePipelineAsyncTask>(pipeline, std::move(descriptor), std::make_unique<CreateComputePipelineAsyncTask>(pipeline, blueprintHash, callback,
blueprintHash, callback, userdata); userdata);
CreateComputePipelineAsyncTask::RunAsync(std::move(asyncTask)); CreateComputePipelineAsyncTask::RunAsync(std::move(asyncTask));
} }

View File

@ -113,11 +113,10 @@ namespace dawn_native { namespace metal {
ResultOrError<Ref<TextureViewBase>> CreateTextureViewImpl( ResultOrError<Ref<TextureViewBase>> CreateTextureViewImpl(
TextureBase* texture, TextureBase* texture,
const TextureViewDescriptor* descriptor) override; const TextureViewDescriptor* descriptor) override;
void CreateComputePipelineAsyncImpl( void CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
std::unique_ptr<FlatComputePipelineDescriptor> descriptor, size_t blueprintHash,
size_t blueprintHash, WGPUCreateComputePipelineAsyncCallback callback,
WGPUCreateComputePipelineAsyncCallback callback, void* userdata) override;
void* userdata) override;
void InitTogglesFromDriver(); void InitTogglesFromDriver();
void ShutDownImpl() override; void ShutDownImpl() override;

View File

@ -267,13 +267,11 @@ namespace dawn_native { namespace metal {
const TextureViewDescriptor* descriptor) { const TextureViewDescriptor* descriptor) {
return TextureView::Create(texture, descriptor); return TextureView::Create(texture, descriptor);
} }
void Device::CreateComputePipelineAsyncImpl( void Device::CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
std::unique_ptr<FlatComputePipelineDescriptor> descriptor, size_t blueprintHash,
size_t blueprintHash, WGPUCreateComputePipelineAsyncCallback callback,
WGPUCreateComputePipelineAsyncCallback callback, void* userdata) {
void* userdata) { ComputePipeline::CreateAsync(this, descriptor, blueprintHash, callback, userdata);
ComputePipeline::CreateAsync(this, std::move(descriptor), blueprintHash, callback,
userdata);
} }
ResultOrError<ExecutionSerial> Device::CheckAndUpdateCompletedSerials() { ResultOrError<ExecutionSerial> Device::CheckAndUpdateCompletedSerials() {

View File

@ -23,11 +23,11 @@ namespace dawn_native { namespace opengl {
Device* device, Device* device,
const ComputePipelineDescriptor* descriptor) { const ComputePipelineDescriptor* descriptor) {
Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor)); Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor));
DAWN_TRY(pipeline->Initialize(descriptor)); DAWN_TRY(pipeline->Initialize());
return pipeline; return pipeline;
} }
MaybeError ComputePipeline::Initialize(const ComputePipelineDescriptor*) { MaybeError ComputePipeline::Initialize() {
DAWN_TRY( DAWN_TRY(
InitializeBase(ToBackend(GetDevice())->gl, ToBackend(GetLayout()), GetAllStages())); InitializeBase(ToBackend(GetDevice())->gl, ToBackend(GetLayout()), GetAllStages()));
return {}; return {};

View File

@ -36,7 +36,7 @@ namespace dawn_native { namespace opengl {
private: private:
using ComputePipelineBase::ComputePipelineBase; using ComputePipelineBase::ComputePipelineBase;
~ComputePipeline() override = default; ~ComputePipeline() override = default;
MaybeError Initialize(const ComputePipelineDescriptor* descriptor) override; MaybeError Initialize() override;
}; };
}} // namespace dawn_native::opengl }} // namespace dawn_native::opengl

View File

@ -29,16 +29,16 @@ namespace dawn_native { namespace vulkan {
Device* device, Device* device,
const ComputePipelineDescriptor* descriptor) { const ComputePipelineDescriptor* descriptor) {
Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor)); Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor));
DAWN_TRY(pipeline->Initialize(descriptor)); DAWN_TRY(pipeline->Initialize());
return pipeline; return pipeline;
} }
MaybeError ComputePipeline::Initialize(const ComputePipelineDescriptor* descriptor) { MaybeError ComputePipeline::Initialize() {
VkComputePipelineCreateInfo createInfo; VkComputePipelineCreateInfo createInfo;
createInfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO; createInfo.sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
createInfo.pNext = nullptr; createInfo.pNext = nullptr;
createInfo.flags = 0; createInfo.flags = 0;
createInfo.layout = ToBackend(descriptor->layout)->GetHandle(); createInfo.layout = ToBackend(GetLayout())->GetHandle();
createInfo.basePipelineHandle = ::VK_NULL_HANDLE; createInfo.basePipelineHandle = ::VK_NULL_HANDLE;
createInfo.basePipelineIndex = -1; createInfo.basePipelineIndex = -1;
@ -47,11 +47,12 @@ namespace dawn_native { namespace vulkan {
createInfo.stage.flags = 0; createInfo.stage.flags = 0;
createInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT; createInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
// Generate a new VkShaderModule with BindingRemapper tint transform for each pipeline // Generate a new VkShaderModule with BindingRemapper tint transform for each pipeline
const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute);
DAWN_TRY_ASSIGN(createInfo.stage.module, DAWN_TRY_ASSIGN(createInfo.stage.module,
ToBackend(descriptor->compute.module) ToBackend(computeStage.module.Get())
->GetTransformedModuleHandle(descriptor->compute.entryPoint, ->GetTransformedModuleHandle(computeStage.entryPoint.c_str(),
ToBackend(GetLayout()))); ToBackend(GetLayout())));
createInfo.stage.pName = descriptor->compute.entryPoint; createInfo.stage.pName = computeStage.entryPoint.c_str();
createInfo.stage.pSpecializationInfo = nullptr; createInfo.stage.pSpecializationInfo = nullptr;
Device* device = ToBackend(GetDevice()); Device* device = ToBackend(GetDevice());
@ -95,14 +96,14 @@ namespace dawn_native { namespace vulkan {
} }
void ComputePipeline::CreateAsync(Device* device, void ComputePipeline::CreateAsync(Device* device,
std::unique_ptr<FlatComputePipelineDescriptor> descriptor, const ComputePipelineDescriptor* descriptor,
size_t blueprintHash, size_t blueprintHash,
WGPUCreateComputePipelineAsyncCallback callback, WGPUCreateComputePipelineAsyncCallback callback,
void* userdata) { void* userdata) {
Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor.get())); Ref<ComputePipeline> pipeline = AcquireRef(new ComputePipeline(device, descriptor));
std::unique_ptr<CreateComputePipelineAsyncTask> asyncTask = std::unique_ptr<CreateComputePipelineAsyncTask> asyncTask =
std::make_unique<CreateComputePipelineAsyncTask>(pipeline, std::move(descriptor), std::make_unique<CreateComputePipelineAsyncTask>(pipeline, blueprintHash, callback,
blueprintHash, callback, userdata); userdata);
CreateComputePipelineAsyncTask::RunAsync(std::move(asyncTask)); CreateComputePipelineAsyncTask::RunAsync(std::move(asyncTask));
} }

View File

@ -30,7 +30,7 @@ namespace dawn_native { namespace vulkan {
Device* device, Device* device,
const ComputePipelineDescriptor* descriptor); const ComputePipelineDescriptor* descriptor);
static void CreateAsync(Device* device, static void CreateAsync(Device* device,
std::unique_ptr<FlatComputePipelineDescriptor> descriptor, const ComputePipelineDescriptor* descriptor,
size_t blueprintHash, size_t blueprintHash,
WGPUCreateComputePipelineAsyncCallback callback, WGPUCreateComputePipelineAsyncCallback callback,
void* userdata); void* userdata);
@ -43,7 +43,7 @@ namespace dawn_native { namespace vulkan {
private: private:
~ComputePipeline() override; ~ComputePipeline() override;
using ComputePipelineBase::ComputePipelineBase; using ComputePipelineBase::ComputePipelineBase;
MaybeError Initialize(const ComputePipelineDescriptor* descriptor) override; MaybeError Initialize() override;
VkPipeline mHandle = VK_NULL_HANDLE; VkPipeline mHandle = VK_NULL_HANDLE;
}; };

View File

@ -162,13 +162,11 @@ namespace dawn_native { namespace vulkan {
const TextureViewDescriptor* descriptor) { const TextureViewDescriptor* descriptor) {
return TextureView::Create(texture, descriptor); return TextureView::Create(texture, descriptor);
} }
void Device::CreateComputePipelineAsyncImpl( void Device::CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
std::unique_ptr<FlatComputePipelineDescriptor> descriptor, size_t blueprintHash,
size_t blueprintHash, WGPUCreateComputePipelineAsyncCallback callback,
WGPUCreateComputePipelineAsyncCallback callback, void* userdata) {
void* userdata) { ComputePipeline::CreateAsync(this, descriptor, blueprintHash, callback, userdata);
ComputePipeline::CreateAsync(this, std::move(descriptor), blueprintHash, callback,
userdata);
} }
MaybeError Device::TickImpl() { MaybeError Device::TickImpl() {

View File

@ -137,11 +137,10 @@ namespace dawn_native { namespace vulkan {
ResultOrError<Ref<TextureViewBase>> CreateTextureViewImpl( ResultOrError<Ref<TextureViewBase>> CreateTextureViewImpl(
TextureBase* texture, TextureBase* texture,
const TextureViewDescriptor* descriptor) override; const TextureViewDescriptor* descriptor) override;
void CreateComputePipelineAsyncImpl( void CreateComputePipelineAsyncImpl(const ComputePipelineDescriptor* descriptor,
std::unique_ptr<FlatComputePipelineDescriptor> descriptor, size_t blueprintHash,
size_t blueprintHash, WGPUCreateComputePipelineAsyncCallback callback,
WGPUCreateComputePipelineAsyncCallback callback, void* userdata) override;
void* userdata) override;
ResultOrError<VulkanDeviceKnobs> CreateDevice(VkPhysicalDevice physicalDevice); ResultOrError<VulkanDeviceKnobs> CreateDevice(VkPhysicalDevice physicalDevice);
void GatherQueueFromDevice(); void GatherQueueFromDevice();