diff --git a/src/dawn_native/BUILD.gn b/src/dawn_native/BUILD.gn index 6b70082ebf..4558f8cd73 100644 --- a/src/dawn_native/BUILD.gn +++ b/src/dawn_native/BUILD.gn @@ -253,6 +253,8 @@ source_set("dawn_native_sources") { "Sampler.h", "ShaderModule.cpp", "ShaderModule.h", + "SpirvUtils.cpp", + "SpirvUtils.h", "StagingBuffer.cpp", "StagingBuffer.h", "Surface.cpp", diff --git a/src/dawn_native/CMakeLists.txt b/src/dawn_native/CMakeLists.txt index b7db56139e..ac435bbc09 100644 --- a/src/dawn_native/CMakeLists.txt +++ b/src/dawn_native/CMakeLists.txt @@ -131,6 +131,8 @@ target_sources(dawn_native PRIVATE "Sampler.h" "ShaderModule.cpp" "ShaderModule.h" + "SpirvUtils.cpp" + "SpirvUtils.h" "StagingBuffer.cpp" "StagingBuffer.h" "Surface.cpp" diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp index 2bdadb9d53..3d68909ed8 100644 --- a/src/dawn_native/ShaderModule.cpp +++ b/src/dawn_native/ShaderModule.cpp @@ -19,6 +19,7 @@ #include "dawn_native/Device.h" #include "dawn_native/Pipeline.h" #include "dawn_native/PipelineLayout.h" +#include "dawn_native/SpirvUtils.h" #include #include @@ -36,114 +37,6 @@ namespace dawn_native { namespace { - Format::Type SpirvCrossBaseTypeToFormatType(spirv_cross::SPIRType::BaseType spirvBaseType) { - switch (spirvBaseType) { - case spirv_cross::SPIRType::Float: - return Format::Type::Float; - case spirv_cross::SPIRType::Int: - return Format::Type::Sint; - case spirv_cross::SPIRType::UInt: - return Format::Type::Uint; - default: - UNREACHABLE(); - return Format::Type::Other; - } - } - - wgpu::TextureViewDimension SpirvDimToTextureViewDimension(spv::Dim dim, bool arrayed) { - switch (dim) { - case spv::Dim::Dim1D: - return wgpu::TextureViewDimension::e1D; - case spv::Dim::Dim2D: - if (arrayed) { - return wgpu::TextureViewDimension::e2DArray; - } else { - return wgpu::TextureViewDimension::e2D; - } - case spv::Dim::Dim3D: - return wgpu::TextureViewDimension::e3D; - case spv::Dim::DimCube: - if (arrayed) { - return wgpu::TextureViewDimension::CubeArray; - } else { - return wgpu::TextureViewDimension::Cube; - } - default: - UNREACHABLE(); - return wgpu::TextureViewDimension::Undefined; - } - } - - wgpu::TextureFormat ToWGPUTextureFormat(spv::ImageFormat format) { - switch (format) { - case spv::ImageFormatR8: - return wgpu::TextureFormat::R8Unorm; - case spv::ImageFormatR8Snorm: - return wgpu::TextureFormat::R8Snorm; - case spv::ImageFormatR8ui: - return wgpu::TextureFormat::R8Uint; - case spv::ImageFormatR8i: - return wgpu::TextureFormat::R8Sint; - case spv::ImageFormatR16ui: - return wgpu::TextureFormat::R16Uint; - case spv::ImageFormatR16i: - return wgpu::TextureFormat::R16Sint; - case spv::ImageFormatR16f: - return wgpu::TextureFormat::R16Float; - case spv::ImageFormatRg8: - return wgpu::TextureFormat::RG8Unorm; - case spv::ImageFormatRg8Snorm: - return wgpu::TextureFormat::RG8Snorm; - case spv::ImageFormatRg8ui: - return wgpu::TextureFormat::RG8Uint; - case spv::ImageFormatRg8i: - return wgpu::TextureFormat::RG8Sint; - case spv::ImageFormatR32f: - return wgpu::TextureFormat::R32Float; - case spv::ImageFormatR32ui: - return wgpu::TextureFormat::R32Uint; - case spv::ImageFormatR32i: - return wgpu::TextureFormat::R32Sint; - case spv::ImageFormatRg16ui: - return wgpu::TextureFormat::RG16Uint; - case spv::ImageFormatRg16i: - return wgpu::TextureFormat::RG16Sint; - case spv::ImageFormatRg16f: - return wgpu::TextureFormat::RG16Float; - case spv::ImageFormatRgba8: - return wgpu::TextureFormat::RGBA8Unorm; - case spv::ImageFormatRgba8Snorm: - return wgpu::TextureFormat::RGBA8Snorm; - case spv::ImageFormatRgba8ui: - return wgpu::TextureFormat::RGBA8Uint; - case spv::ImageFormatRgba8i: - return wgpu::TextureFormat::RGBA8Sint; - case spv::ImageFormatRgb10A2: - return wgpu::TextureFormat::RGB10A2Unorm; - case spv::ImageFormatR11fG11fB10f: - return wgpu::TextureFormat::RG11B10Ufloat; - case spv::ImageFormatRg32f: - return wgpu::TextureFormat::RG32Float; - case spv::ImageFormatRg32ui: - return wgpu::TextureFormat::RG32Uint; - case spv::ImageFormatRg32i: - return wgpu::TextureFormat::RG32Sint; - case spv::ImageFormatRgba16ui: - return wgpu::TextureFormat::RGBA16Uint; - case spv::ImageFormatRgba16i: - return wgpu::TextureFormat::RGBA16Sint; - case spv::ImageFormatRgba16f: - return wgpu::TextureFormat::RGBA16Float; - case spv::ImageFormatRgba32f: - return wgpu::TextureFormat::RGBA32Float; - case spv::ImageFormatRgba32ui: - return wgpu::TextureFormat::RGBA32Uint; - case spv::ImageFormatRgba32i: - return wgpu::TextureFormat::RGBA32Sint; - default: - return wgpu::TextureFormat::Undefined; - } - } std::string GetShaderDeclarationString(BindGroupIndex group, BindingNumber binding) { std::ostringstream ostream; @@ -550,27 +443,15 @@ namespace dawn_native { ResultOrError> ExtractSpirvInfo( const DeviceBase* device, - const spirv_cross::Compiler& compiler) { + const spirv_cross::Compiler& compiler, + const char* entryPointName) { std::unique_ptr metadata = std::make_unique(); // TODO(cwallez@chromium.org): make errors here creation errors // currently errors here do not prevent the shadermodule from being used const auto& resources = compiler.get_shader_resources(); - switch (compiler.get_execution_model()) { - case spv::ExecutionModelVertex: - metadata->stage = SingleShaderStage::Vertex; - break; - case spv::ExecutionModelFragment: - metadata->stage = SingleShaderStage::Fragment; - break; - case spv::ExecutionModelGLCompute: - metadata->stage = SingleShaderStage::Compute; - break; - default: - UNREACHABLE(); - return DAWN_VALIDATION_ERROR("Unexpected shader execution model"); - } + metadata->stage = ExecutionModelToShaderStage(compiler.get_execution_model()); if (resources.push_constant_buffers.size() > 0) { return DAWN_VALIDATION_ERROR("Push constants aren't supported."); @@ -635,7 +516,7 @@ namespace dawn_native { info->viewDimension = SpirvDimToTextureViewDimension(imageType.dim, imageType.arrayed); info->textureComponentType = - SpirvCrossBaseTypeToFormatType(textureComponentType); + SpirvBaseTypeToFormatType(textureComponentType); info->type = bindingType; break; } @@ -664,7 +545,7 @@ namespace dawn_native { spirv_cross::SPIRType::ImageType imageType = compiler.get_type(info->base_type_id).image; wgpu::TextureFormat storageTextureFormat = - ToWGPUTextureFormat(imageType.format); + SpirvImageFormatToTextureFormat(imageType.format); if (storageTextureFormat == wgpu::TextureFormat::Undefined) { return DAWN_VALIDATION_ERROR( "Invalid image format declaration on storage image"); @@ -756,7 +637,7 @@ namespace dawn_native { spirv_cross::SPIRType::BaseType shaderFragmentOutputBaseType = compiler.get_type(fragmentOutput.base_type_id).basetype; Format::Type formatType = - SpirvCrossBaseTypeToFormatType(shaderFragmentOutputBaseType); + SpirvBaseTypeToFormatType(shaderFragmentOutputBaseType); if (formatType == Format::Type::Other) { return DAWN_VALIDATION_ERROR("Unexpected Fragment output type"); } @@ -764,6 +645,14 @@ namespace dawn_native { } } + if (metadata->stage == SingleShaderStage::Compute) { + const spirv_cross::SPIREntryPoint& spirEntryPoint = + compiler.get_entry_point(entryPointName, spv::ExecutionModelGLCompute); + metadata->localWorkgroupSize.x = spirEntryPoint.workgroup_size.x; + metadata->localWorkgroupSize.y = spirEntryPoint.workgroup_size.y; + metadata->localWorkgroupSize.z = spirEntryPoint.workgroup_size.z; + } + return {std::move(metadata)}; } @@ -935,7 +824,7 @@ namespace dawn_native { } spirv_cross::Compiler compiler(mSpirv); - DAWN_TRY_ASSIGN(mMainEntryPoint, ExtractSpirvInfo(GetDevice(), compiler)); + DAWN_TRY_ASSIGN(mMainEntryPoint, ExtractSpirvInfo(GetDevice(), compiler, "main")); return {}; } diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h index 36f7c7843c..aef5286641 100644 --- a/src/dawn_native/ShaderModule.h +++ b/src/dawn_native/ShaderModule.h @@ -24,7 +24,6 @@ #include "dawn_native/Forward.h" #include "dawn_native/IntegerTypes.h" #include "dawn_native/PerStage.h" - #include "dawn_native/dawn_platform.h" #include @@ -82,8 +81,10 @@ namespace dawn_native { ityp::array; FragmentOutputBaseTypes fragmentOutputFormatBaseTypes; - // The shader stage for this binding, TODO(dawn:216): can likely be removed once we - // properly support multiple entrypoints per ShaderModule. + // The local workgroup size declared for a compute entry point (or 0s otehrwise). + Origin3D localWorkgroupSize; + + // The shader stage for this binding. SingleShaderStage stage; }; diff --git a/src/dawn_native/SpirvUtils.cpp b/src/dawn_native/SpirvUtils.cpp new file mode 100644 index 0000000000..e462e0d1f7 --- /dev/null +++ b/src/dawn_native/SpirvUtils.cpp @@ -0,0 +1,154 @@ +// Copyright 2020 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dawn_native/SpirvUtils.h" + +namespace dawn_native { + + spv::ExecutionModel ShaderStageToExecutionModel(SingleShaderStage stage) { + switch (stage) { + case SingleShaderStage::Vertex: + return spv::ExecutionModelVertex; + case SingleShaderStage::Fragment: + return spv::ExecutionModelFragment; + case SingleShaderStage::Compute: + return spv::ExecutionModelGLCompute; + default: + UNREACHABLE(); + } + } + + SingleShaderStage ExecutionModelToShaderStage(spv::ExecutionModel model) { + switch (model) { + case spv::ExecutionModelVertex: + return SingleShaderStage::Vertex; + case spv::ExecutionModelFragment: + return SingleShaderStage::Fragment; + case spv::ExecutionModelGLCompute: + return SingleShaderStage::Compute; + default: + UNREACHABLE(); + } + } + + wgpu::TextureViewDimension SpirvDimToTextureViewDimension(spv::Dim dim, bool arrayed) { + switch (dim) { + case spv::Dim::Dim1D: + return wgpu::TextureViewDimension::e1D; + case spv::Dim::Dim2D: + if (arrayed) { + return wgpu::TextureViewDimension::e2DArray; + } else { + return wgpu::TextureViewDimension::e2D; + } + case spv::Dim::Dim3D: + return wgpu::TextureViewDimension::e3D; + case spv::Dim::DimCube: + if (arrayed) { + return wgpu::TextureViewDimension::CubeArray; + } else { + return wgpu::TextureViewDimension::Cube; + } + default: + UNREACHABLE(); + return wgpu::TextureViewDimension::Undefined; + } + } + + wgpu::TextureFormat SpirvImageFormatToTextureFormat(spv::ImageFormat format) { + switch (format) { + case spv::ImageFormatR8: + return wgpu::TextureFormat::R8Unorm; + case spv::ImageFormatR8Snorm: + return wgpu::TextureFormat::R8Snorm; + case spv::ImageFormatR8ui: + return wgpu::TextureFormat::R8Uint; + case spv::ImageFormatR8i: + return wgpu::TextureFormat::R8Sint; + case spv::ImageFormatR16ui: + return wgpu::TextureFormat::R16Uint; + case spv::ImageFormatR16i: + return wgpu::TextureFormat::R16Sint; + case spv::ImageFormatR16f: + return wgpu::TextureFormat::R16Float; + case spv::ImageFormatRg8: + return wgpu::TextureFormat::RG8Unorm; + case spv::ImageFormatRg8Snorm: + return wgpu::TextureFormat::RG8Snorm; + case spv::ImageFormatRg8ui: + return wgpu::TextureFormat::RG8Uint; + case spv::ImageFormatRg8i: + return wgpu::TextureFormat::RG8Sint; + case spv::ImageFormatR32f: + return wgpu::TextureFormat::R32Float; + case spv::ImageFormatR32ui: + return wgpu::TextureFormat::R32Uint; + case spv::ImageFormatR32i: + return wgpu::TextureFormat::R32Sint; + case spv::ImageFormatRg16ui: + return wgpu::TextureFormat::RG16Uint; + case spv::ImageFormatRg16i: + return wgpu::TextureFormat::RG16Sint; + case spv::ImageFormatRg16f: + return wgpu::TextureFormat::RG16Float; + case spv::ImageFormatRgba8: + return wgpu::TextureFormat::RGBA8Unorm; + case spv::ImageFormatRgba8Snorm: + return wgpu::TextureFormat::RGBA8Snorm; + case spv::ImageFormatRgba8ui: + return wgpu::TextureFormat::RGBA8Uint; + case spv::ImageFormatRgba8i: + return wgpu::TextureFormat::RGBA8Sint; + case spv::ImageFormatRgb10A2: + return wgpu::TextureFormat::RGB10A2Unorm; + case spv::ImageFormatR11fG11fB10f: + return wgpu::TextureFormat::RG11B10Ufloat; + case spv::ImageFormatRg32f: + return wgpu::TextureFormat::RG32Float; + case spv::ImageFormatRg32ui: + return wgpu::TextureFormat::RG32Uint; + case spv::ImageFormatRg32i: + return wgpu::TextureFormat::RG32Sint; + case spv::ImageFormatRgba16ui: + return wgpu::TextureFormat::RGBA16Uint; + case spv::ImageFormatRgba16i: + return wgpu::TextureFormat::RGBA16Sint; + case spv::ImageFormatRgba16f: + return wgpu::TextureFormat::RGBA16Float; + case spv::ImageFormatRgba32f: + return wgpu::TextureFormat::RGBA32Float; + case spv::ImageFormatRgba32ui: + return wgpu::TextureFormat::RGBA32Uint; + case spv::ImageFormatRgba32i: + return wgpu::TextureFormat::RGBA32Sint; + default: + return wgpu::TextureFormat::Undefined; + } + } + + Format::Type SpirvBaseTypeToFormatType(spirv_cross::SPIRType::BaseType spirvBaseType) { + switch (spirvBaseType) { + case spirv_cross::SPIRType::Float: + return Format::Type::Float; + case spirv_cross::SPIRType::Int: + return Format::Type::Sint; + case spirv_cross::SPIRType::UInt: + return Format::Type::Uint; + default: + UNREACHABLE(); + return Format::Type::Other; + } + } + +} // namespace dawn_native diff --git a/src/dawn_native/SpirvUtils.h b/src/dawn_native/SpirvUtils.h new file mode 100644 index 0000000000..ceb6fd6dcf --- /dev/null +++ b/src/dawn_native/SpirvUtils.h @@ -0,0 +1,44 @@ +// Copyright 2020 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// This file contains utilities to convert from-to spirv.hpp datatypes without polluting other +// headers with spirv.hpp + +#ifndef DAWNNATIVE_SPIRV_UTILS_H_ +#define DAWNNATIVE_SPIRV_UTILS_H_ + +#include "dawn_native/Format.h" +#include "dawn_native/PerStage.h" +#include "dawn_native/dawn_platform.h" + +#include + +namespace dawn_native { + + // Returns the spirv_cross equivalent for this shader stage and vice-versa. + spv::ExecutionModel ShaderStageToExecutionModel(SingleShaderStage stage); + SingleShaderStage ExecutionModelToShaderStage(spv::ExecutionModel model); + + // Returns the texture view dimension for corresponding to (dim, arrayed). + wgpu::TextureViewDimension SpirvDimToTextureViewDimension(spv::Dim dim, bool arrayed); + + // Returns the texture format corresponding to format. + wgpu::TextureFormat SpirvImageFormatToTextureFormat(spv::ImageFormat format); + + // Returns the format "component type" corresponding to the SPIRV base type. + Format::Type SpirvBaseTypeToFormatType(spirv_cross::SPIRType::BaseType spirvBaseType); + +} // namespace dawn_native + +#endif // DAWNNATIVE_SPIRV_UTILS_H_ diff --git a/src/dawn_native/metal/ComputePipelineMTL.mm b/src/dawn_native/metal/ComputePipelineMTL.mm index aca771d244..6c965152a4 100644 --- a/src/dawn_native/metal/ComputePipelineMTL.mm +++ b/src/dawn_native/metal/ComputePipelineMTL.mm @@ -34,8 +34,8 @@ namespace dawn_native { namespace metal { ShaderModule* computeModule = ToBackend(descriptor->computeStage.module); const char* computeEntryPoint = descriptor->computeStage.entryPoint; ShaderModule::MetalFunctionData computeData; - DAWN_TRY(computeModule->GetFunction(computeEntryPoint, SingleShaderStage::Compute, - ToBackend(GetLayout()), &computeData)); + DAWN_TRY(computeModule->CreateFunction(computeEntryPoint, SingleShaderStage::Compute, + ToBackend(GetLayout()), &computeData)); NSError* error = nil; mMtlComputePipelineState = @@ -46,7 +46,9 @@ namespace dawn_native { namespace metal { } // Copy over the local workgroup size as it is passed to dispatch explicitly in Metal - mLocalWorkgroupSize = computeData.localWorkgroupSize; + Origin3D localSize = GetStage(SingleShaderStage::Compute).metadata->localWorkgroupSize; + mLocalWorkgroupSize = MTLSizeMake(localSize.x, localSize.y, localSize.z); + mRequiresStorageBufferLength = computeData.needsStorageBufferLength; return {}; } diff --git a/src/dawn_native/metal/RenderPipelineMTL.mm b/src/dawn_native/metal/RenderPipelineMTL.mm index ede97aacb6..ee5fd0405d 100644 --- a/src/dawn_native/metal/RenderPipelineMTL.mm +++ b/src/dawn_native/metal/RenderPipelineMTL.mm @@ -335,8 +335,9 @@ namespace dawn_native { namespace metal { ShaderModule* vertexModule = ToBackend(descriptor->vertexStage.module); const char* vertexEntryPoint = descriptor->vertexStage.entryPoint; ShaderModule::MetalFunctionData vertexData; - DAWN_TRY(vertexModule->GetFunction(vertexEntryPoint, SingleShaderStage::Vertex, - ToBackend(GetLayout()), &vertexData, 0xFFFFFFFF, this)); + DAWN_TRY(vertexModule->CreateFunction(vertexEntryPoint, SingleShaderStage::Vertex, + ToBackend(GetLayout()), &vertexData, 0xFFFFFFFF, + this)); descriptorMTL.vertexFunction = vertexData.function; if (vertexData.needsStorageBufferLength) { @@ -346,9 +347,9 @@ namespace dawn_native { namespace metal { ShaderModule* fragmentModule = ToBackend(descriptor->fragmentStage->module); const char* fragmentEntryPoint = descriptor->fragmentStage->entryPoint; ShaderModule::MetalFunctionData fragmentData; - DAWN_TRY(fragmentModule->GetFunction(fragmentEntryPoint, SingleShaderStage::Fragment, - ToBackend(GetLayout()), &fragmentData, - descriptor->sampleMask)); + DAWN_TRY(fragmentModule->CreateFunction(fragmentEntryPoint, SingleShaderStage::Fragment, + ToBackend(GetLayout()), &fragmentData, + descriptor->sampleMask)); descriptorMTL.fragmentFunction = fragmentData.function; if (fragmentData.needsStorageBufferLength) { diff --git a/src/dawn_native/metal/ShaderModuleMTL.h b/src/dawn_native/metal/ShaderModuleMTL.h index 3a211e68c2..4e543c7c0f 100644 --- a/src/dawn_native/metal/ShaderModuleMTL.h +++ b/src/dawn_native/metal/ShaderModuleMTL.h @@ -38,18 +38,17 @@ namespace dawn_native { namespace metal { struct MetalFunctionData { id function = nil; - MTLSize localWorkgroupSize; bool needsStorageBufferLength; ~MetalFunctionData() { [function release]; } }; - MaybeError GetFunction(const char* functionName, - SingleShaderStage functionStage, - const PipelineLayout* layout, - MetalFunctionData* out, - uint32_t sampleMask = 0xFFFFFFFF, - const RenderPipeline* renderPipeline = nullptr); + MaybeError CreateFunction(const char* entryPointName, + SingleShaderStage stage, + const PipelineLayout* layout, + MetalFunctionData* out, + uint32_t sampleMask = 0xFFFFFFFF, + const RenderPipeline* renderPipeline = nullptr); private: ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor); diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm index 047d4539d1..2588aa051f 100644 --- a/src/dawn_native/metal/ShaderModuleMTL.mm +++ b/src/dawn_native/metal/ShaderModuleMTL.mm @@ -15,6 +15,7 @@ #include "dawn_native/metal/ShaderModuleMTL.h" #include "dawn_native/BindGroupLayout.h" +#include "dawn_native/SpirvUtils.h" #include "dawn_native/metal/DeviceMTL.h" #include "dawn_native/metal/PipelineLayoutMTL.h" #include "dawn_native/metal/RenderPipelineMTL.h" @@ -25,22 +26,6 @@ namespace dawn_native { namespace metal { - namespace { - - spv::ExecutionModel SpirvExecutionModelForStage(SingleShaderStage stage) { - switch (stage) { - case SingleShaderStage::Vertex: - return spv::ExecutionModelVertex; - case SingleShaderStage::Fragment: - return spv::ExecutionModelFragment; - case SingleShaderStage::Compute: - return spv::ExecutionModelGLCompute; - default: - UNREACHABLE(); - } - } - } // namespace - // static ResultOrError ShaderModule::Create(Device* device, const ShaderModuleDescriptor* descriptor) { @@ -57,25 +42,26 @@ namespace dawn_native { namespace metal { return InitializeBase(); } - MaybeError ShaderModule::GetFunction(const char* functionName, - SingleShaderStage functionStage, - const PipelineLayout* layout, - ShaderModule::MetalFunctionData* out, - uint32_t sampleMask, - const RenderPipeline* renderPipeline) { + MaybeError ShaderModule::CreateFunction(const char* entryPointName, + SingleShaderStage stage, + const PipelineLayout* layout, + ShaderModule::MetalFunctionData* out, + uint32_t sampleMask, + const RenderPipeline* renderPipeline) { ASSERT(!IsError()); ASSERT(out); const std::vector* spirv = &GetSpirv(); + spv::ExecutionModel executionModel = ShaderStageToExecutionModel(stage); #ifdef DAWN_ENABLE_WGSL // Use set 4 since it is bigger than what users can access currently static const uint32_t kPullingBufferBindingSet = 4; std::vector pullingSpirv; if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) && - functionStage == SingleShaderStage::Vertex) { + stage == SingleShaderStage::Vertex) { DAWN_TRY_ASSIGN(pullingSpirv, GeneratePullingSpirv(*renderPipeline->GetVertexStateDescriptor(), - functionName, kPullingBufferBindingSet)); + entryPointName, kPullingBufferBindingSet)); spirv = &pullingSpirv; } #endif @@ -99,6 +85,7 @@ namespace dawn_native { namespace metal { spirv_cross::CompilerMSL compiler(*spirv); compiler.set_msl_options(options_msl); + compiler.set_entry_point(entryPointName, executionModel); // By default SPIRV-Cross will give MSL resources indices in increasing order. // To make the MSL indices match the indices chosen in the PipelineLayout, we build @@ -116,30 +103,33 @@ namespace dawn_native { namespace metal { const BindingInfo& bindingInfo = layout->GetBindGroupLayout(group)->GetBindingInfo(bindingIndex); - for (auto stage : IterateStages(bindingInfo.visibility)) { - uint32_t shaderIndex = layout->GetBindingIndexInfo(stage)[group][bindingIndex]; - spirv_cross::MSLResourceBinding mslBinding; - mslBinding.stage = SpirvExecutionModelForStage(stage); - mslBinding.desc_set = static_cast(group); - mslBinding.binding = static_cast(bindingNumber); - mslBinding.msl_buffer = mslBinding.msl_texture = mslBinding.msl_sampler = - shaderIndex; - - compiler.add_msl_resource_binding(mslBinding); + if (!(bindingInfo.visibility & StageBit(stage))) { + continue; } + + uint32_t shaderIndex = layout->GetBindingIndexInfo(stage)[group][bindingIndex]; + + spirv_cross::MSLResourceBinding mslBinding; + mslBinding.stage = executionModel; + mslBinding.desc_set = static_cast(group); + mslBinding.binding = static_cast(bindingNumber); + mslBinding.msl_buffer = mslBinding.msl_texture = mslBinding.msl_sampler = + shaderIndex; + + compiler.add_msl_resource_binding(mslBinding); } } #ifdef DAWN_ENABLE_WGSL // Add vertex buffers bound as storage buffers if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) && - functionStage == SingleShaderStage::Vertex) { + stage == SingleShaderStage::Vertex) { for (uint32_t dawnIndex : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) { uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(dawnIndex); spirv_cross::MSLResourceBinding mslBinding; - mslBinding.stage = SpirvExecutionModelForStage(SingleShaderStage::Vertex); + mslBinding.stage = spv::ExecutionModelVertex; mslBinding.desc_set = kPullingBufferBindingSet; mslBinding.binding = dawnIndex; mslBinding.msl_buffer = metalIndex; @@ -148,17 +138,17 @@ namespace dawn_native { namespace metal { } #endif - { - spv::ExecutionModel executionModel = SpirvExecutionModelForStage(functionStage); - auto size = compiler.get_entry_point(functionName, executionModel).workgroup_size; - out->localWorkgroupSize = MTLSizeMake(size.x, size.y, size.z); - } - { // SPIRV-Cross also supports re-ordering attributes but it seems to do the correct thing // by default. NSString* mslSource; std::string msl = compiler.compile(); + + // Some entry point names are forbidden in MSL so SPIRV-Cross modifies them. Query the + // modified entryPointName from it. + const std::string& modifiedEntryPointName = + compiler.get_entry_point(entryPointName, executionModel).name; + // Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall // category. -Wunused-variable in particular comes up a lot in generated code, and some // (old?) Metal drivers accidentally treat it as a MTLLibraryErrorCompileError instead @@ -183,18 +173,7 @@ namespace dawn_native { namespace metal { } } - // TODO(kainino@chromium.org): make this somehow more robust; it needs to behave like - // clean_func_name: - // https://github.com/KhronosGroup/SPIRV-Cross/blob/4e915e8c483e319d0dd7a1fa22318bef28f8cca3/spirv_msl.cpp#L1213 - const char* metalFunctionName = functionName; - if (strcmp(metalFunctionName, "main") == 0) { - metalFunctionName = "main0"; - } - if (strcmp(metalFunctionName, "saturate") == 0) { - metalFunctionName = "saturate0"; - } - - NSString* name = [[NSString alloc] initWithUTF8String:metalFunctionName]; + NSString* name = [[NSString alloc] initWithUTF8String:modifiedEntryPointName.c_str()]; out->function = [library newFunctionWithName:name]; [library release]; } @@ -202,7 +181,7 @@ namespace dawn_native { namespace metal { out->needsStorageBufferLength = compiler.needs_buffer_size_buffer(); if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) && - GetEntryPoint(functionName, functionStage).usedVertexAttributes.any()) { + GetEntryPoint(entryPointName, stage).usedVertexAttributes.any()) { out->needsStorageBufferLength = true; }