diff --git a/dawn.json b/dawn.json index a2516ba1a3..82b243a5ff 100644 --- a/dawn.json +++ b/dawn.json @@ -375,8 +375,7 @@ "extensible": true, "members": [ {"name": "layout", "type": "pipeline layout"}, - {"name": "module", "type": "shader module"}, - {"name": "entry point", "type": "char", "annotation": "const*", "length": "strlen"} + {"name": "compute stage", "type": "pipeline stage descriptor", "annotation": "const*"} ] }, "device": { diff --git a/examples/ComputeBoids.cpp b/examples/ComputeBoids.cpp index e9f44ac897..ae4234ce7a 100644 --- a/examples/ComputeBoids.cpp +++ b/examples/ComputeBoids.cpp @@ -241,9 +241,13 @@ void initSim() { dawn::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, &bgl); dawn::ComputePipelineDescriptor csDesc; - csDesc.module = module; - csDesc.entryPoint = "main"; csDesc.layout = pl; + + dawn::PipelineStageDescriptor computeStage; + computeStage.module = module; + computeStage.entryPoint = "main"; + csDesc.computeStage = &computeStage; + updatePipeline = device.CreateComputePipeline(&csDesc); for (uint32_t i = 0; i < 2; ++i) { diff --git a/src/dawn_native/ComputePipeline.cpp b/src/dawn_native/ComputePipeline.cpp index c95115b458..a2fb60f652 100644 --- a/src/dawn_native/ComputePipeline.cpp +++ b/src/dawn_native/ComputePipeline.cpp @@ -24,21 +24,9 @@ namespace dawn_native { return DAWN_VALIDATION_ERROR("nextInChain must be nullptr"); } - DAWN_TRY(device->ValidateObject(descriptor->module)); DAWN_TRY(device->ValidateObject(descriptor->layout)); - - if (descriptor->entryPoint != std::string("main")) { - return DAWN_VALIDATION_ERROR("Currently the entry point has to be main()"); - } - - if (descriptor->module->GetExecutionModel() != dawn::ShaderStage::Compute) { - return DAWN_VALIDATION_ERROR("Setting module with wrong execution model"); - } - - if (!descriptor->module->IsCompatibleWithPipelineLayout(descriptor->layout)) { - return DAWN_VALIDATION_ERROR("Stage not compatible with layout"); - } - + DAWN_TRY(ValidatePipelineStageDescriptor(device, descriptor->computeStage, + descriptor->layout, dawn::ShaderStage::Compute)); return {}; } @@ -47,7 +35,7 @@ namespace dawn_native { ComputePipelineBase::ComputePipelineBase(DeviceBase* device, const ComputePipelineDescriptor* descriptor) : PipelineBase(device, descriptor->layout, dawn::ShaderStageBit::Compute) { - ExtractModuleData(dawn::ShaderStage::Compute, descriptor->module); + ExtractModuleData(dawn::ShaderStage::Compute, descriptor->computeStage->module); } ComputePipelineBase::ComputePipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag) diff --git a/src/dawn_native/Pipeline.cpp b/src/dawn_native/Pipeline.cpp index 257bf7695e..e839b1b5b4 100644 --- a/src/dawn_native/Pipeline.cpp +++ b/src/dawn_native/Pipeline.cpp @@ -20,6 +20,24 @@ namespace dawn_native { + MaybeError ValidatePipelineStageDescriptor(DeviceBase* device, + const PipelineStageDescriptor* descriptor, + const PipelineLayoutBase* layout, + dawn::ShaderStage stage) { + DAWN_TRY(device->ValidateObject(descriptor->module)); + + if (descriptor->entryPoint != std::string("main")) { + return DAWN_VALIDATION_ERROR("Entry point must be \"main\""); + } + if (descriptor->module->GetExecutionModel() != stage) { + return DAWN_VALIDATION_ERROR("Setting module with wrong stages"); + } + if (!descriptor->module->IsCompatibleWithPipelineLayout(layout)) { + return DAWN_VALIDATION_ERROR("Stage not compatible with layout"); + } + return {}; + } + // PipelineBase PipelineBase::PipelineBase(DeviceBase* device, diff --git a/src/dawn_native/Pipeline.h b/src/dawn_native/Pipeline.h index 1f141f9906..c917125bd4 100644 --- a/src/dawn_native/Pipeline.h +++ b/src/dawn_native/Pipeline.h @@ -34,6 +34,11 @@ namespace dawn_native { Float, }; + MaybeError ValidatePipelineStageDescriptor(DeviceBase* device, + const PipelineStageDescriptor* descriptor, + const PipelineLayoutBase* layout, + dawn::ShaderStage stage); + class PipelineBase : public ObjectBase { public: struct PushConstantInfo { diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp index 9c2cf4e9e6..eb8a7b869a 100644 --- a/src/dawn_native/RenderPipeline.cpp +++ b/src/dawn_native/RenderPipeline.cpp @@ -97,24 +97,6 @@ namespace dawn_native { return {}; } - MaybeError ValidatePipelineStageDescriptor(DeviceBase* device, - const PipelineStageDescriptor* descriptor, - const PipelineLayoutBase* layout, - dawn::ShaderStage stage) { - DAWN_TRY(device->ValidateObject(descriptor->module)); - - if (descriptor->entryPoint != std::string("main")) { - return DAWN_VALIDATION_ERROR("Entry point must be \"main\""); - } - if (descriptor->module->GetExecutionModel() != stage) { - return DAWN_VALIDATION_ERROR("Setting module with wrong stages"); - } - if (!descriptor->module->IsCompatibleWithPipelineLayout(layout)) { - return DAWN_VALIDATION_ERROR("Stage not compatible with layout"); - } - return {}; - } - MaybeError ValidateColorStateDescriptor(const ColorStateDescriptor* descriptor) { if (descriptor->nextInChain != nullptr) { return DAWN_VALIDATION_ERROR("nextInChain must be nullptr"); diff --git a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp index 67f4cbc474..d70846ea09 100644 --- a/src/dawn_native/d3d12/ComputePipelineD3D12.cpp +++ b/src/dawn_native/d3d12/ComputePipelineD3D12.cpp @@ -32,7 +32,7 @@ namespace dawn_native { namespace d3d12 { // SPRIV-cross does matrix multiplication expecting row major matrices compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR; - const ShaderModule* module = ToBackend(descriptor->module); + const ShaderModule* module = ToBackend(descriptor->computeStage->module); const std::string& hlslSource = module->GetHLSLSource(ToBackend(GetLayout())); ComPtr compiledShader; @@ -40,8 +40,8 @@ namespace dawn_native { namespace d3d12 { const PlatformFunctions* functions = device->GetFunctions(); if (FAILED(functions->d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr, - nullptr, descriptor->entryPoint, "cs_5_1", compileFlags, 0, - &compiledShader, &errors))) { + nullptr, descriptor->computeStage->entryPoint, "cs_5_1", + compileFlags, 0, &compiledShader, &errors))) { printf("%s\n", reinterpret_cast(errors->GetBufferPointer())); ASSERT(false); } diff --git a/src/dawn_native/metal/ComputePipelineMTL.mm b/src/dawn_native/metal/ComputePipelineMTL.mm index 5c6eaa62f4..fc76ee195c 100644 --- a/src/dawn_native/metal/ComputePipelineMTL.mm +++ b/src/dawn_native/metal/ComputePipelineMTL.mm @@ -23,15 +23,14 @@ namespace dawn_native { namespace metal { : ComputePipelineBase(device, descriptor) { auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice(); - const auto& module = ToBackend(descriptor->module); - const char* entryPoint = descriptor->entryPoint; - - auto compilationData = - module->GetFunction(entryPoint, dawn::ShaderStage::Compute, ToBackend(GetLayout())); + const ShaderModule* computeModule = ToBackend(descriptor->computeStage->module); + const char* computeEntryPoint = descriptor->computeStage->entryPoint; + ShaderModule::MetalFunctionData computeData = computeModule->GetFunction( + computeEntryPoint, dawn::ShaderStage::Compute, ToBackend(GetLayout())); NSError* error = nil; mMtlComputePipelineState = - [mtlDevice newComputePipelineStateWithFunction:compilationData.function error:&error]; + [mtlDevice newComputePipelineStateWithFunction:computeData.function error:&error]; if (error != nil) { NSLog(@" error => %@", error); GetDevice()->HandleError("Error creating pipeline state"); @@ -39,7 +38,7 @@ namespace dawn_native { namespace metal { } // Copy over the local workgroup size as it is passed to dispatch explicitly in Metal - mLocalWorkgroupSize = compilationData.localWorkgroupSize; + mLocalWorkgroupSize = computeData.localWorkgroupSize; } ComputePipeline::~ComputePipeline() { diff --git a/src/dawn_native/opengl/ComputePipelineGL.cpp b/src/dawn_native/opengl/ComputePipelineGL.cpp index 815e4d7499..2cbad4efcf 100644 --- a/src/dawn_native/opengl/ComputePipelineGL.cpp +++ b/src/dawn_native/opengl/ComputePipelineGL.cpp @@ -21,7 +21,7 @@ namespace dawn_native { namespace opengl { ComputePipeline::ComputePipeline(Device* device, const ComputePipelineDescriptor* descriptor) : ComputePipelineBase(device, descriptor) { PerStage modules(nullptr); - modules[dawn::ShaderStage::Compute] = ToBackend(descriptor->module); + modules[dawn::ShaderStage::Compute] = ToBackend(descriptor->computeStage->module); PipelineGL::Initialize(ToBackend(descriptor->layout), modules); } diff --git a/src/dawn_native/vulkan/ComputePipelineVk.cpp b/src/dawn_native/vulkan/ComputePipelineVk.cpp index 06948b35a8..8e7c7aa746 100644 --- a/src/dawn_native/vulkan/ComputePipelineVk.cpp +++ b/src/dawn_native/vulkan/ComputePipelineVk.cpp @@ -35,8 +35,8 @@ namespace dawn_native { namespace vulkan { createInfo.stage.pNext = nullptr; createInfo.stage.flags = 0; createInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT; - createInfo.stage.module = ToBackend(descriptor->module)->GetHandle(); - createInfo.stage.pName = descriptor->entryPoint; + createInfo.stage.module = ToBackend(descriptor->computeStage->module)->GetHandle(); + createInfo.stage.pName = descriptor->computeStage->entryPoint; createInfo.stage.pSpecializationInfo = nullptr; if (device->fn.CreateComputePipelines(device->GetVkDevice(), VK_NULL_HANDLE, 1, &createInfo, diff --git a/src/tests/end2end/BindGroupTests.cpp b/src/tests/end2end/BindGroupTests.cpp index 9bf437e26c..3a46286e3e 100644 --- a/src/tests/end2end/BindGroupTests.cpp +++ b/src/tests/end2end/BindGroupTests.cpp @@ -68,9 +68,13 @@ TEST_P(BindGroupTests, ReusedBindGroupSingleSubmit) { dawn::ShaderModule module = utils::CreateShaderModule(device, dawn::ShaderStage::Compute, shader); dawn::ComputePipelineDescriptor cpDesc; - cpDesc.module = module; - cpDesc.entryPoint = "main"; cpDesc.layout = pl; + + dawn::PipelineStageDescriptor computeStage; + computeStage.module = module; + computeStage.entryPoint = "main"; + cpDesc.computeStage = &computeStage; + dawn::ComputePipeline cp = device.CreateComputePipeline(&cpDesc); dawn::BufferDescriptor bufferDesc; diff --git a/src/tests/end2end/ComputeCopyStorageBufferTests.cpp b/src/tests/end2end/ComputeCopyStorageBufferTests.cpp index 337f2f18fd..6cca2067cd 100644 --- a/src/tests/end2end/ComputeCopyStorageBufferTests.cpp +++ b/src/tests/end2end/ComputeCopyStorageBufferTests.cpp @@ -39,9 +39,13 @@ void ComputeCopyStorageBufferTests::BasicTest(const char* shader) { auto pl = utils::MakeBasicPipelineLayout(device, &bgl); dawn::ComputePipelineDescriptor csDesc; - csDesc.module = module; - csDesc.entryPoint = "main"; csDesc.layout = pl; + + dawn::PipelineStageDescriptor computeStage; + computeStage.module = module; + computeStage.entryPoint = "main"; + csDesc.computeStage = &computeStage; + dawn::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc); // Set up src storage buffer diff --git a/src/tests/end2end/PushConstantTests.cpp b/src/tests/end2end/PushConstantTests.cpp index c85aca631e..e385a17c1e 100644 --- a/src/tests/end2end/PushConstantTests.cpp +++ b/src/tests/end2end/PushConstantTests.cpp @@ -149,9 +149,13 @@ class PushConstantTest: public DawnTest { ); dawn::ComputePipelineDescriptor descriptor; - descriptor.module = module; - descriptor.entryPoint = "main"; descriptor.layout = pl; + + dawn::PipelineStageDescriptor computeStage; + computeStage.module = module; + computeStage.entryPoint = "main"; + descriptor.computeStage = &computeStage; + return device.CreateComputePipeline(&descriptor); }