// Copyright 2017 The Dawn Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "dawn_native/metal/ShaderModuleMTL.h" #include "dawn_native/BindGroupLayout.h" #include "dawn_native/SpirvUtils.h" #include "dawn_native/TintUtils.h" #include "dawn_native/metal/DeviceMTL.h" #include "dawn_native/metal/PipelineLayoutMTL.h" #include "dawn_native/metal/RenderPipelineMTL.h" #include // Tint include must be after spirv_msl.hpp, because spirv-cross has its own // version of spirv_headers. We also need to undef SPV_REVISION because SPIRV-Cross // is at 3 while spirv-headers is at 4. 
#undef SPV_REVISION
#include
#include

namespace dawn_native { namespace metal {

    // static
    // Allocates a ShaderModule for `device`/`descriptor`, runs Initialize() on it with the
    // given parse result, and returns it. Any error from Initialize() is propagated.
    ResultOrError> ShaderModule::Create(Device* device,
                                        const ShaderModuleDescriptor* descriptor,
                                        ShaderModuleParseResult* parseResult) {
        Ref module = AcquireRef(new ShaderModule(device, descriptor));
        DAWN_TRY(module->Initialize(parseResult));
        return module;
    }

    ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor)
        : ShaderModuleBase(device, descriptor) {
    }

    // Runs the backend-independent InitializeBase() while a ScopedTintICEHandler for this
    // device is alive.
    MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) {
        ScopedTintICEHandler scopedICEHandler(GetDevice());
        return InitializeBase(parseResult);
    }

    // Translates `entryPointName` of this module's Tint program into MSL.
    //
    // The steps are:
    //  - Each binding of `layout` that is visible to `stage` is remapped from
    //    (group, bindingNumber) to (0, per-stage shader index).
    //  - When vertex pulling is enabled for a vertex stage, the vertex-pulling transform is
    //    added, and each used vertex buffer slot, seen as (kPullingBufferBindingSet, slot), is
    //    remapped to its Metal vertex buffer index.
    //  - A robustness transform is added when the device has robustness enabled.
    //
    // Outputs:
    //  - remappedEntryPointName: entry point name after renaming, looked up in the renamer
    //    output.
    //  - needsStorageBufferLength / hasInvariantAttribute: copied from the MSL generator result.
    // Returns the generated MSL, or a validation error if transforms or generation fail.
    ResultOrError ShaderModule::TranslateToMSLWithTint(
        const char* entryPointName,
        SingleShaderStage stage,
        const PipelineLayout* layout,
        uint32_t sampleMask,
        const RenderPipeline* renderPipeline,
        const VertexState* vertexState,
        std::string* remappedEntryPointName,
        bool* needsStorageBufferLength,
        bool* hasInvariantAttribute) {
        ScopedTintICEHandler scopedICEHandler(GetDevice());

        std::ostringstream errorStream;
        errorStream << "Tint MSL failure:" << std::endl;

        // Remap BindingNumber to BindingIndex in WGSL shader
        using BindingRemapper = tint::transform::BindingRemapper;
        using BindingPoint = tint::transform::BindingPoint;
        BindingRemapper::BindingPoints bindingPoints;
        BindingRemapper::AccessControls accessControls;

        for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
            const BindGroupLayoutBase::BindingMap& bindingMap =
                layout->GetBindGroupLayout(group)->GetBindingMap();
            for (const auto& it : bindingMap) {
                BindingNumber bindingNumber = it.first;
                BindingIndex bindingIndex = it.second;

                const BindingInfo& bindingInfo =
                    layout->GetBindGroupLayout(group)->GetBindingInfo(bindingIndex);

                // Bindings not visible to this stage get no remapping entry.
                if (!(bindingInfo.visibility & StageBit(stage))) {
                    continue;
                }

                uint32_t shaderIndex = layout->GetBindingIndexInfo(stage)[group][bindingIndex];

                BindingPoint srcBindingPoint{static_cast(group),
                                             static_cast(bindingNumber)};
                BindingPoint dstBindingPoint{0, shaderIndex};
                // Identity mappings are skipped; only real moves are recorded.
                if (srcBindingPoint != dstBindingPoint) {
                    bindingPoints.emplace(srcBindingPoint, dstBindingPoint);
                }
            }
        }

        tint::transform::Manager transformManager;
        tint::transform::DataMap transformInputs;

        if (stage == SingleShaderStage::Vertex &&
            GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) {
            transformManager.Add();
            AddVertexPullingTransformConfig(*vertexState, entryPointName, kPullingBufferBindingSet,
                                            &transformInputs);

            for (VertexBufferSlot slot :
                 IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
                uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(slot);

                // Tell Tint to map (kPullingBufferBindingSet, slot) to this MSL buffer index.
                BindingPoint srcBindingPoint{static_cast(kPullingBufferBindingSet),
                                             static_cast(slot)};
                BindingPoint dstBindingPoint{0, metalIndex};
                if (srcBindingPoint != dstBindingPoint) {
                    bindingPoints.emplace(srcBindingPoint, dstBindingPoint);
                }
            }
        }
        if (GetDevice()->IsRobustnessEnabled()) {
            transformManager.Add();
        }
        transformManager.Add();
        transformManager.Add();
        transformInputs.Add(std::move(bindingPoints), std::move(accessControls),
                            /* mayCollide */ true);

        tint::Program program;
        tint::transform::DataMap transformOutputs;
        DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs,
                                               &transformOutputs, nullptr));

        // The renamer output is required: the caller looks the MSL function up by this name.
        if (auto* data = transformOutputs.Get()) {
            auto it = data->remappings.find(entryPointName);
            if (it == data->remappings.end()) {
                return DAWN_VALIDATION_ERROR("Could not find remapped name for entry point.");
            }
            *remappedEntryPointName = it->second;
        } else {
            return DAWN_VALIDATION_ERROR("Transform output missing renamer data.");
        }

        tint::writer::msl::Options options;
        options.buffer_size_ubo_index = kBufferLengthBufferSlot;
        options.fixed_sample_mask = sampleMask;
        auto result = tint::writer::msl::Generate(&program, options);
        if (!result.success) {
            errorStream << "Generator: " << result.error << std::endl;
            return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
        }

        *needsStorageBufferLength = result.needs_storage_buffer_sizes;
        *hasInvariantAttribute = result.has_invariant_attribute;

        return std::move(result.msl);
    }

    // Translates `entryPointName` of this module's SPIR-V into MSL using SPIRV-Cross.
    //
    // The steps are:
    //  - When vertex pulling is enabled for a vertex stage, the SPIR-V is first rewritten by
    //    GeneratePullingSpirv (from the Tint program or the raw SPIR-V, depending on the
    //    UseTintGenerator toggle).
    //  - Each binding of `layout` visible to `stage` is pinned to its per-stage shader index.
    //  - Each used vertex buffer slot is pinned to its Metal vertex buffer index.
    //
    // Outputs:
    //  - remappedEntryPointName: the entry point name as emitted by SPIRV-Cross.
    //  - needsStorageBufferLength: whether the buffer-size buffer is needed.
    // Returns the generated MSL.
    ResultOrError ShaderModule::TranslateToMSLWithSPIRVCross(
        const char* entryPointName,
        SingleShaderStage stage,
        const PipelineLayout* layout,
        uint32_t sampleMask,
        const RenderPipeline* renderPipeline,
        const VertexState* vertexState,
        std::string* remappedEntryPointName,
        bool* needsStorageBufferLength) {
        const std::vector* spirv = &GetSpirv();
        spv::ExecutionModel executionModel = ShaderStageToExecutionModel(stage);

        std::vector pullingSpirv;
        if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
            stage == SingleShaderStage::Vertex) {
            if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
                DAWN_TRY_ASSIGN(pullingSpirv,
                                GeneratePullingSpirv(GetTintProgram(), *vertexState, entryPointName,
                                                     kPullingBufferBindingSet));
            } else {
                DAWN_TRY_ASSIGN(pullingSpirv,
                                GeneratePullingSpirv(GetSpirv(), *vertexState, entryPointName,
                                                     kPullingBufferBindingSet));
            }
            spirv = &pullingSpirv;
        }

        // If these options are changed, the values in DawnSPIRVCrossMSLFastFuzzer.cpp need to
        // be updated.
        spirv_cross::CompilerMSL::Options options_msl;

        // Disable PointSize builtin for https://bugs.chromium.org/p/dawn/issues/detail?id=146
        // Because Metal will reject PointSize builtin if the shader is compiled into a render
        // pipeline that uses a non-point topology.
        // TODO (hao.x.li@intel.com): Remove this once WebGPU requires there is no
        // gl_PointSize builtin (https://github.com/gpuweb/gpuweb/issues/332).
        options_msl.enable_point_size_builtin = false;

        // Always use vertex buffer 30 (the last one in the vertex buffer table) to contain
        // the shader storage buffer lengths.
        options_msl.buffer_size_buffer_index = kBufferLengthBufferSlot;

        options_msl.additional_fixed_sample_mask = sampleMask;

        spirv_cross::CompilerMSL compiler(*spirv);
        compiler.set_msl_options(options_msl);
        compiler.set_entry_point(entryPointName, executionModel);

        // By default SPIRV-Cross will give MSL resources indices in increasing order.
        // To make the MSL indices match the indices chosen in the PipelineLayout, we build
        // a table of MSLResourceBinding to give to SPIRV-Cross.

        // Create one resource binding entry per stage per binding.
        for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
            const BindGroupLayoutBase::BindingMap& bindingMap =
                layout->GetBindGroupLayout(group)->GetBindingMap();

            for (const auto& it : bindingMap) {
                BindingNumber bindingNumber = it.first;
                BindingIndex bindingIndex = it.second;

                const BindingInfo& bindingInfo =
                    layout->GetBindGroupLayout(group)->GetBindingInfo(bindingIndex);

                if (!(bindingInfo.visibility & StageBit(stage))) {
                    continue;
                }

                uint32_t shaderIndex = layout->GetBindingIndexInfo(stage)[group][bindingIndex];

                spirv_cross::MSLResourceBinding mslBinding;
                mslBinding.stage = executionModel;
                mslBinding.desc_set = static_cast(group);
                mslBinding.binding = static_cast(bindingNumber);
                mslBinding.msl_buffer = mslBinding.msl_texture = mslBinding.msl_sampler =
                    shaderIndex;

                compiler.add_msl_resource_binding(mslBinding);
            }
        }

        // Add vertex buffers bound as storage buffers
        if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
            stage == SingleShaderStage::Vertex) {
            for (VertexBufferSlot slot :
                 IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
                uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(slot);

                spirv_cross::MSLResourceBinding mslBinding;

                mslBinding.stage = spv::ExecutionModelVertex;
                mslBinding.desc_set = static_cast(kPullingBufferBindingSet);
                mslBinding.binding = static_cast(slot);
                mslBinding.msl_buffer = metalIndex;

                compiler.add_msl_resource_binding(mslBinding);
            }
        }

        // SPIRV-Cross also supports re-ordering attributes but it seems to do the correct thing
        // by default.
        std::string msl = compiler.compile();

        // Some entry point names are forbidden in MSL so SPIRV-Cross modifies them. Query the
        // modified entryPointName from it.
        *remappedEntryPointName = compiler.get_entry_point(entryPointName, executionModel).name;
        *needsStorageBufferLength = compiler.needs_buffer_size_buffer();

        return std::move(msl);
    }

    // Translates `entryPointName` to MSL, compiles it into an MTLLibrary, and stores the
    // entry point function and its buffer-length requirement in `out`.
    //
    // The translator is Tint when the UseTintGenerator toggle is set, and SPIRV-Cross
    // otherwise. Vertex stages must pass a non-null `renderPipeline` and `vertexState`.
    //
    // Returns an error when any of the following fails:
    //  - translation to MSL;
    //  - Metal compilation (warnings are tolerated);
    //  - the entry point lookup in the compiled library.
    MaybeError ShaderModule::CreateFunction(const char* entryPointName,
                                            SingleShaderStage stage,
                                            const PipelineLayout* layout,
                                            ShaderModule::MetalFunctionData* out,
                                            uint32_t sampleMask,
                                            const RenderPipeline* renderPipeline,
                                            const VertexState* vertexState) {
        ASSERT(!IsError());
        ASSERT(out);

        // Vertex stages must specify a renderPipeline and vertexState
        if (stage == SingleShaderStage::Vertex) {
            ASSERT(renderPipeline != nullptr);
            ASSERT(vertexState != nullptr);
        }

        std::string remappedEntryPointName;
        std::string msl;
        bool hasInvariantAttribute = false;
        if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
            DAWN_TRY_ASSIGN(msl, TranslateToMSLWithTint(
                                     entryPointName, stage, layout, sampleMask, renderPipeline,
                                     vertexState, &remappedEntryPointName,
                                     &out->needsStorageBufferLength, &hasInvariantAttribute));
        } else {
            DAWN_TRY_ASSIGN(msl, TranslateToMSLWithSPIRVCross(entryPointName, stage, layout,
                                                              sampleMask, renderPipeline,
                                                              vertexState, &remappedEntryPointName,
                                                              &out->needsStorageBufferLength));
        }

        // Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall
        // category. -Wunused-variable in particular comes up a lot in generated code, and some
        // (old?) Metal drivers accidentally treat it as a MTLLibraryErrorCompileError instead
        // of a warning.
        msl = R"(\
#ifdef __clang__
#pragma clang diagnostic ignored "-Wall"
#endif
)" + msl;

        if (GetDevice()->IsToggleEnabled(Toggle::DumpShaders)) {
            std::ostringstream dumpedMsg;
            dumpedMsg << "/* Dumped generated MSL */" << std::endl << msl;
            GetDevice()->EmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str());
        }

        NSRef mslSource = AcquireNSRef([[NSString alloc] initWithUTF8String:msl.c_str()]);

        NSRef compileOptions = AcquireNSRef([[MTLCompileOptions alloc] init]);
        if (hasInvariantAttribute) {
            if (@available(macOS 11.0, iOS 13.0, *)) {
                (*compileOptions).preserveInvariance = true;
            }
        }
        auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice();
        NSError* error = nullptr;
        NSPRef> library =
            AcquireNSPRef([mtlDevice newLibraryWithSource:mslSource.Get()
                                                  options:compileOptions.Get()
                                                    error:&error]);
        // Compile warnings are also reported through `error`, so only non-warning codes are
        // treated as failures here.
        if (error != nullptr) {
            if (error.code != MTLLibraryErrorCompileWarning) {
                const char* errorString = [error.localizedDescription UTF8String];
                return DAWN_VALIDATION_ERROR(std::string("Unable to create library object: ") +
                                             errorString);
            }
        }
        // The returned object, not the NSError out-param, is the authoritative success signal.
        // Without this check, a nil library would be messaged below and silently produce a nil
        // function while still reporting success.
        if (library.Get() == nil) {
            return DAWN_VALIDATION_ERROR("Unable to create library object.");
        }

        NSRef name =
            AcquireNSRef([[NSString alloc] initWithUTF8String:remappedEntryPointName.c_str()]);
        out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]);
        // newFunctionWithName: yields nil when the library has no function with that name;
        // surface that here instead of handing a nil function to pipeline creation.
        if (out->function.Get() == nil) {
            return DAWN_VALIDATION_ERROR(std::string("Unable to find function \"") +
                                         remappedEntryPointName +
                                         "\" in the compiled MSL library.");
        }

        if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
            GetEntryPoint(entryPointName).usedVertexAttributes.any()) {
            out->needsStorageBufferLength = true;
        }

        return {};
    }
}}  // namespace dawn_native::metal