// Copyright 2017 The NXT Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "backend/ShaderModule.h" #include "backend/BindGroupLayout.h" #include "backend/Device.h" #include "backend/Pipeline.h" #include "backend/PipelineLayout.h" #include namespace backend { ShaderModuleBase::ShaderModuleBase(ShaderModuleBuilder* builder) : device(builder->device) { } void ShaderModuleBase::ExtractSpirvInfo(const spirv_cross::Compiler& compiler) { // TODO(cwallez@chromium.org): make errors here builder-level // currently errors here do not prevent the shadermodule from being used const auto& resources = compiler.get_shader_resources(); switch (compiler.get_execution_model()) { case spv::ExecutionModelVertex: executionModel = nxt::ShaderStage::Vertex; break; case spv::ExecutionModelFragment: executionModel = nxt::ShaderStage::Fragment; break; case spv::ExecutionModelGLCompute: executionModel = nxt::ShaderStage::Compute; break; default: UNREACHABLE(); } // Extract push constants pushConstants.mask.reset(); pushConstants.sizes.fill(0); pushConstants.types.fill(PushConstantType::Int); if (resources.push_constant_buffers.size() > 0) { auto interfaceBlock = resources.push_constant_buffers[0]; const auto& blockType = compiler.get_type(interfaceBlock.type_id); ASSERT(blockType.basetype == spirv_cross::SPIRType::Struct); for (uint32_t i = 0; i < blockType.member_types.size(); i++) { ASSERT(compiler.get_member_decoration_mask(blockType.self, i) & 1ull << spv::DecorationOffset); uint32_t offset = compiler.get_member_decoration(blockType.self, i, spv::DecorationOffset); ASSERT(offset % 4 == 0); offset /= 4; ASSERT(offset < kMaxPushConstants); auto memberType = compiler.get_type(blockType.member_types[i]); PushConstantType constantType; if (memberType.basetype == spirv_cross::SPIRType::Int) { constantType = PushConstantType::Int; } else if (memberType.basetype == spirv_cross::SPIRType::UInt) { constantType = PushConstantType::UInt; } else { ASSERT(memberType.basetype == spirv_cross::SPIRType::Float); constantType = PushConstantType::Float; } pushConstants.mask.set(offset); pushConstants.names[offset] = interfaceBlock.name + "." + compiler.get_member_name(blockType.self, i); pushConstants.sizes[offset] = memberType.vecsize * memberType.columns; pushConstants.types[offset] = constantType; } } // Fill in bindingInfo with the SPIRV bindings auto ExtractResourcesBinding = [this](const std::vector& resources, const spirv_cross::Compiler& compiler, nxt::BindingType type) { constexpr uint64_t requiredBindingDecorationMask = (1ull << spv::DecorationBinding) | (1ull << spv::DecorationDescriptorSet); for (const auto& resource : resources) { ASSERT((compiler.get_decoration_mask(resource.id) & requiredBindingDecorationMask) == requiredBindingDecorationMask); uint32_t binding = compiler.get_decoration(resource.id, spv::DecorationBinding); uint32_t set = compiler.get_decoration(resource.id, spv::DecorationDescriptorSet); if (binding >= kMaxBindingsPerGroup || set >= kMaxBindGroups) { device->HandleError("Binding over limits in the SPIRV"); continue; } auto& info = bindingInfo[set][binding]; info.used = true; info.id = resource.id; info.base_type_id = resource.base_type_id; info.type = type; } }; ExtractResourcesBinding(resources.uniform_buffers, compiler, nxt::BindingType::UniformBuffer); ExtractResourcesBinding(resources.separate_images, compiler, nxt::BindingType::SampledTexture); ExtractResourcesBinding(resources.separate_samplers, compiler, nxt::BindingType::Sampler); ExtractResourcesBinding(resources.storage_buffers, compiler, nxt::BindingType::StorageBuffer); // Extract the vertex attributes if (executionModel == nxt::ShaderStage::Vertex) { for (const auto& attrib : resources.stage_inputs) { ASSERT(compiler.get_decoration_mask(attrib.id) & (1ull << spv::DecorationLocation)); uint32_t location = compiler.get_decoration(attrib.id, spv::DecorationLocation); if (location >= kMaxVertexAttributes) { device->HandleError("Attribute location over limits in the SPIRV"); return; } usedVertexAttributes.set(location); } // Without a location qualifier on vertex outputs, spirv_cross::CompilerMSL gives them all // the location 0, causing a compile error. for (const auto& attrib : resources.stage_outputs) { if (!(compiler.get_decoration_mask(attrib.id) & (1ull << spv::DecorationLocation))) { device->HandleError("Need location qualifier on vertex output"); return; } } } if (executionModel == nxt::ShaderStage::Fragment) { // Without a location qualifier on vertex inputs, spirv_cross::CompilerMSL gives them all // the location 0, causing a compile error. for (const auto& attrib : resources.stage_inputs) { if (!(compiler.get_decoration_mask(attrib.id) & (1ull << spv::DecorationLocation))) { device->HandleError("Need location qualifier on fragment input"); return; } } } } const ShaderModuleBase::PushConstantInfo& ShaderModuleBase::GetPushConstants() const { return pushConstants; } const ShaderModuleBase::ModuleBindingInfo& ShaderModuleBase::GetBindingInfo() const { return bindingInfo; } const std::bitset& ShaderModuleBase::GetUsedVertexAttributes() const { return usedVertexAttributes; } nxt::ShaderStage ShaderModuleBase::GetExecutionModel() const { return executionModel; } bool ShaderModuleBase::IsCompatibleWithPipelineLayout(const PipelineLayoutBase* layout) { for (size_t group = 0; group < kMaxBindGroups; ++group) { if (!IsCompatibleWithBindGroupLayout(group, layout->GetBindGroupLayout(group))) { return false; } } return true; } bool ShaderModuleBase::IsCompatibleWithBindGroupLayout(size_t group, const BindGroupLayoutBase* layout) { const auto& layoutInfo = layout->GetBindingInfo(); for (size_t i = 0; i < kMaxBindingsPerGroup; ++i) { const auto& moduleInfo = bindingInfo[group][i]; if (!moduleInfo.used) { continue; } if (moduleInfo.type != layoutInfo.types[i]) { return false; } if ((layoutInfo.visibilities[i] & StageBit(executionModel)) == 0) { return false; } } return true; } ShaderModuleBuilder::ShaderModuleBuilder(DeviceBase* device) : Builder(device) { } std::vector ShaderModuleBuilder::AcquireSpirv() { return std::move(spirv); } ShaderModuleBase* ShaderModuleBuilder::GetResultImpl() { if (spirv.size() == 0) { HandleError("Shader module needs to have the source set"); return nullptr; } return device->CreateShaderModule(this); } void ShaderModuleBuilder::SetSource(uint32_t codeSize, const uint32_t* code) { spirv.assign(code, code + codeSize); } }