diff --git a/src/dawn_native/PipelineLayout.cpp b/src/dawn_native/PipelineLayout.cpp index a3c2eae258..ef163a5090 100644 --- a/src/dawn_native/PipelineLayout.cpp +++ b/src/dawn_native/PipelineLayout.cpp @@ -144,7 +144,7 @@ namespace dawn_native { // Does the trivial conversions from a ShaderBindingInfo to a BindGroupLayoutEntry auto ConvertMetadataToEntry = - [](const EntryPointMetadata::ShaderBindingInfo& shaderBinding, + [](const ShaderBindingInfo& shaderBinding, const ExternalTextureBindingLayout* externalTextureBindingEntry) -> BindGroupLayoutEntry { BindGroupLayoutEntry entry = {}; @@ -242,7 +242,7 @@ namespace dawn_native { for (BindGroupIndex group(0); group < metadata.bindings.size(); ++group) { for (const auto& bindingIt : metadata.bindings[group]) { BindingNumber bindingNumber = bindingIt.first; - const EntryPointMetadata::ShaderBindingInfo& shaderBinding = bindingIt.second; + const ShaderBindingInfo& shaderBinding = bindingIt.second; // Create the BindGroupLayoutEntry BindGroupLayoutEntry entry = diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp index 0941922823..c7a9ba1950 100644 --- a/src/dawn_native/ShaderModule.cpp +++ b/src/dawn_native/ShaderModule.cpp @@ -431,9 +431,8 @@ namespace dawn_native { return std::move(program); } - std::vector GetBindGroupMinBufferSizes( - const EntryPointMetadata::BindingGroupInfoMap& shaderBindings, - const BindGroupLayoutBase* layout) { + std::vector GetBindGroupMinBufferSizes(const BindingGroupInfoMap& shaderBindings, + const BindGroupLayoutBase* layout) { std::vector requiredBufferSizes(layout->GetUnverifiedBufferCount()); uint32_t packedIdx = 0; @@ -471,7 +470,7 @@ namespace dawn_native { // corresponding binding in the BindGroupLayout, if it exists. for (const auto& it : entryPoint.bindings[group]) { BindingNumber bindingNumber = it.first; - const EntryPointMetadata::ShaderBindingInfo& shaderInfo = it.second; + const ShaderBindingInfo& shaderInfo = it.second; const auto& bindingIt = layoutBindings.find(bindingNumber); if (bindingIt == layoutBindings.end()) { @@ -825,12 +824,12 @@ namespace dawn_native { } const auto& it = metadata->bindings[bindGroupIndex].emplace( - bindingNumber, EntryPointMetadata::ShaderBindingInfo{}); + bindingNumber, ShaderBindingInfo{}); if (!it.second) { return DAWN_VALIDATION_ERROR("Shader has duplicate bindings"); } - EntryPointMetadata::ShaderBindingInfo* info = &it.first->second; + ShaderBindingInfo* info = &it.first->second; info->bindingType = TintResourceTypeToBindingInfoType(resource.resource_type); switch (info->bindingType) { diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h index 6240cecd08..a4e8d3dce7 100644 --- a/src/dawn_native/ShaderModule.h +++ b/src/dawn_native/ShaderModule.h @@ -115,45 +115,46 @@ namespace dawn_native { BindGroupIndex pullingBufferBindingSet, tint::transform::DataMap* transformInputs); + // Mirrors wgpu::SamplerBindingLayout but instead stores a single boolean + // for isComparison instead of a wgpu::SamplerBindingType enum. + struct ShaderSamplerBindingInfo { + bool isComparison; + }; + + // Mirrors wgpu::TextureBindingLayout but instead has a set of compatible sampleTypes + // instead of a single enum. + struct ShaderTextureBindingInfo { + SampleTypeBit compatibleSampleTypes; + wgpu::TextureViewDimension viewDimension; + bool multisampled; + }; + + // Per-binding shader metadata contains some SPIRV specific information in addition to + // most of the frontend per-binding information. + struct ShaderBindingInfo { + // The SPIRV ID of the resource. + uint32_t id; + uint32_t base_type_id; + + BindingNumber binding; + BindingInfoType bindingType; + + BufferBindingLayout buffer; + ShaderSamplerBindingInfo sampler; + ShaderTextureBindingInfo texture; + StorageTextureBindingLayout storageTexture; + }; + + using BindingGroupInfoMap = std::map; + using BindingInfoArray = ityp::array; + // Contains all the reflection data for a valid (ShaderModule, entryPoint, stage). They are // stored in the ShaderModuleBase and destroyed only when the shader program is destroyed so // pointers to EntryPointMetadata are safe to store as long as you also keep a Ref to the // ShaderModuleBase. struct EntryPointMetadata { - // Mirrors wgpu::SamplerBindingLayout but instead stores a single boolean - // for isComparison instead of a wgpu::SamplerBindingType enum. - struct ShaderSamplerBindingInfo { - bool isComparison; - }; - - // Mirrors wgpu::TextureBindingLayout but instead has a set of compatible sampleTypes - // instead of a single enum. - struct ShaderTextureBindingInfo { - SampleTypeBit compatibleSampleTypes; - wgpu::TextureViewDimension viewDimension; - bool multisampled; - }; - - // Per-binding shader metadata contains some SPIRV specific information in addition to - // most of the frontend per-binding information. - struct ShaderBindingInfo { - // The SPIRV ID of the resource. - uint32_t id; - uint32_t base_type_id; - - BindingNumber binding; - BindingInfoType bindingType; - - BufferBindingLayout buffer; - ShaderSamplerBindingInfo sampler; - ShaderTextureBindingInfo texture; - StorageTextureBindingLayout storageTexture; - }; - // bindings[G][B] is the reflection data for the binding defined with // [[group=G, binding=B]] in WGSL / SPIRV. - using BindingGroupInfoMap = std::map; - using BindingInfoArray = ityp::array; BindingInfoArray bindings; struct SamplerTexturePair { diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp index 3a598901e5..eac2a4f29a 100644 --- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp +++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp @@ -195,8 +195,7 @@ namespace dawn_native { namespace d3d12 { BindingRemapper::BindingPoints bindingPoints; BindingRemapper::AccessControls accessControls; - const EntryPointMetadata::BindingInfoArray& moduleBindingInfo = - GetEntryPoint(entryPointName).bindings; + const BindingInfoArray& moduleBindingInfo = GetEntryPoint(entryPointName).bindings; // d3d12::BindGroupLayout packs the bindings per HLSL register-space. // We modify the Tint AST to make the "bindings" decoration match the diff --git a/src/dawn_native/opengl/ShaderModuleGL.cpp b/src/dawn_native/opengl/ShaderModuleGL.cpp index 1ea18a33bd..9412a2e9e9 100644 --- a/src/dawn_native/opengl/ShaderModuleGL.cpp +++ b/src/dawn_native/opengl/ShaderModuleGL.cpp @@ -66,31 +66,19 @@ namespace dawn_native { namespace opengl { return o.str(); } - ResultOrError> ExtractSpirvInfo( + ResultOrError> ExtractSpirvInfo( const DeviceBase* device, const spirv_cross::Compiler& compiler, const std::string& entryPointName, SingleShaderStage stage) { - std::unique_ptr metadata = std::make_unique(); - metadata->stage = stage; - const auto& resources = compiler.get_shader_resources(); - if (resources.push_constant_buffers.size() > 0) { - return DAWN_VALIDATION_ERROR("Push constants aren't supported."); - } - - if (resources.sampled_images.size() > 0) { - return DAWN_VALIDATION_ERROR("Combined images and samplers aren't supported."); - } - // Fill in bindingInfo with the SPIRV bindings auto ExtractResourcesBinding = [](const DeviceBase* device, const spirv_cross::SmallVector& resources, const spirv_cross::Compiler& compiler, BindingInfoType bindingType, - EntryPointMetadata::BindingInfoArray* metadataBindings, - bool isStorageBuffer = false) -> MaybeError { + BindingInfoArray* bindings, bool isStorageBuffer = false) -> MaybeError { for (const auto& resource : resources) { if (!compiler.get_decoration_bitset(resource.id).get(spv::DecorationBinding)) { return DAWN_VALIDATION_ERROR("No Binding decoration set for resource"); @@ -110,13 +98,13 @@ namespace dawn_native { namespace opengl { return DAWN_VALIDATION_ERROR("Bind group index over limits in the SPIRV"); } - const auto& it = (*metadataBindings)[bindGroupIndex].emplace( - bindingNumber, EntryPointMetadata::ShaderBindingInfo{}); + const auto& it = + (*bindings)[bindGroupIndex].emplace(bindingNumber, ShaderBindingInfo{}); if (!it.second) { return DAWN_VALIDATION_ERROR("Shader has duplicate bindings"); } - EntryPointMetadata::ShaderBindingInfo* info = &it.first->second; + ShaderBindingInfo* info = &it.first->second; info->id = resource.id; info->base_type_id = resource.base_type_id; info->bindingType = bindingType; @@ -220,92 +208,21 @@ namespace dawn_native { namespace opengl { return {}; }; + std::unique_ptr resultBindings = std::make_unique(); + BindingInfoArray* bindings = resultBindings.get(); DAWN_TRY(ExtractResourcesBinding(device, resources.uniform_buffers, compiler, - BindingInfoType::Buffer, &metadata->bindings)); + BindingInfoType::Buffer, bindings)); DAWN_TRY(ExtractResourcesBinding(device, resources.separate_images, compiler, - BindingInfoType::Texture, &metadata->bindings)); + BindingInfoType::Texture, bindings)); DAWN_TRY(ExtractResourcesBinding(device, resources.separate_samplers, compiler, - BindingInfoType::Sampler, &metadata->bindings)); + BindingInfoType::Sampler, bindings)); DAWN_TRY(ExtractResourcesBinding(device, resources.storage_buffers, compiler, - BindingInfoType::Buffer, &metadata->bindings, true)); + BindingInfoType::Buffer, bindings, true)); // ReadonlyStorageTexture is used as a tag to do general storage texture handling. DAWN_TRY(ExtractResourcesBinding(device, resources.storage_images, compiler, - BindingInfoType::StorageTexture, &metadata->bindings)); + BindingInfoType::StorageTexture, resultBindings.get())); - // Extract the vertex attributes - if (stage == SingleShaderStage::Vertex) { - for (const auto& attrib : resources.stage_inputs) { - if (!(compiler.get_decoration_bitset(attrib.id).get(spv::DecorationLocation))) { - return DAWN_VALIDATION_ERROR( - "Unable to find Location decoration for Vertex input"); - } - uint32_t unsanitizedLocation = - compiler.get_decoration(attrib.id, spv::DecorationLocation); - - if (unsanitizedLocation >= kMaxVertexAttributes) { - return DAWN_VALIDATION_ERROR("Attribute location over limits in the SPIRV"); - } - VertexAttributeLocation location(static_cast(unsanitizedLocation)); - - spirv_cross::SPIRType::BaseType inputBaseType = - compiler.get_type(attrib.base_type_id).basetype; - metadata->vertexInputBaseTypes[location] = - SpirvBaseTypeToVertexFormatBaseType(inputBaseType); - metadata->usedVertexInputs.set(location); - } - - // Without a location qualifier on vertex outputs, spirv_cross::CompilerMSL gives - // them all the location 0, causing a compile error. - for (const auto& attrib : resources.stage_outputs) { - if (!compiler.get_decoration_bitset(attrib.id).get(spv::DecorationLocation)) { - return DAWN_VALIDATION_ERROR("Need location qualifier on vertex output"); - } - } - } - - if (stage == SingleShaderStage::Fragment) { - // Without a location qualifier on vertex inputs, spirv_cross::CompilerMSL gives - // them all the location 0, causing a compile error. - for (const auto& attrib : resources.stage_inputs) { - if (!compiler.get_decoration_bitset(attrib.id).get(spv::DecorationLocation)) { - return DAWN_VALIDATION_ERROR("Need location qualifier on fragment input"); - } - } - - for (const auto& fragmentOutput : resources.stage_outputs) { - if (!compiler.get_decoration_bitset(fragmentOutput.id) - .get(spv::DecorationLocation)) { - return DAWN_VALIDATION_ERROR( - "Unable to find Location decoration for Fragment output"); - } - uint32_t unsanitizedAttachment = - compiler.get_decoration(fragmentOutput.id, spv::DecorationLocation); - - if (unsanitizedAttachment >= kMaxColorAttachments) { - return DAWN_VALIDATION_ERROR( - "Fragment output index must be less than max number of color " - "attachments"); - } - ColorAttachmentIndex attachment(static_cast(unsanitizedAttachment)); - - spirv_cross::SPIRType::BaseType shaderFragmentOutputBaseType = - compiler.get_type(fragmentOutput.base_type_id).basetype; - // spriv path so temporarily always set to 4u to always pass validation - metadata->fragmentOutputVariables[attachment] = { - SpirvBaseTypeToTextureComponentType(shaderFragmentOutputBaseType), 4u}; - metadata->fragmentOutputsWritten.set(attachment); - } - } - - if (stage == SingleShaderStage::Compute) { - const spirv_cross::SPIREntryPoint& spirEntryPoint = - compiler.get_entry_point(entryPointName, spv::ExecutionModelGLCompute); - metadata->localWorkgroupSize.x = spirEntryPoint.workgroup_size.x; - metadata->localWorkgroupSize.y = spirEntryPoint.workgroup_size.y; - metadata->localWorkgroupSize.z = spirEntryPoint.workgroup_size.z; - } - - return {std::move(metadata)}; + return {std::move(resultBindings)}; } // static @@ -322,10 +239,10 @@ namespace dawn_native { namespace opengl { } // static - ResultOrError ShaderModule::ReflectShaderUsingSPIRVCross( + ResultOrError ShaderModule::ReflectShaderUsingSPIRVCross( DeviceBase* device, const std::vector& spirv) { - EntryPointMetadataTable result; + BindingInfoArrayTable result; spirv_cross::Compiler compiler(spirv); for (const spirv_cross::EntryPoint& entryPoint : compiler.get_entry_points_and_stages()) { ASSERT(result.count(entryPoint.name) == 0); @@ -333,9 +250,9 @@ namespace dawn_native { namespace opengl { SingleShaderStage stage = ExecutionModelToShaderStage(entryPoint.execution_model); compiler.set_entry_point(entryPoint.name, entryPoint.execution_model); - std::unique_ptr metadata; - DAWN_TRY_ASSIGN(metadata, ExtractSpirvInfo(device, compiler, entryPoint.name, stage)); - result[entryPoint.name] = std::move(metadata); + std::unique_ptr bindings; + DAWN_TRY_ASSIGN(bindings, ExtractSpirvInfo(device, compiler, entryPoint.name, stage)); + result[entryPoint.name] = std::move(bindings); } return std::move(result); } @@ -355,7 +272,7 @@ namespace dawn_native { namespace opengl { return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); } - DAWN_TRY_ASSIGN(mGLEntryPoints, ReflectShaderUsingSPIRVCross(GetDevice(), result.spirv)); + DAWN_TRY_ASSIGN(mGLBindings, ReflectShaderUsingSPIRVCross(GetDevice(), result.spirv)); return {}; } @@ -445,8 +362,7 @@ namespace dawn_native { namespace opengl { compiler.set_name(combined.combined_id, info->GetName()); } - const EntryPointMetadata::BindingInfoArray& bindingInfo = - (*mGLEntryPoints.at(entryPointName)).bindings; + const BindingInfoArray& bindingInfo = *(mGLBindings.at(entryPointName)); // Change binding names to be "dawn_binding__". // Also unsets the SPIRV "Binding" decoration as it outputs "layout(binding=)" which diff --git a/src/dawn_native/opengl/ShaderModuleGL.h b/src/dawn_native/opengl/ShaderModuleGL.h index 78a2f2a272..d955225357 100644 --- a/src/dawn_native/opengl/ShaderModuleGL.h +++ b/src/dawn_native/opengl/ShaderModuleGL.h @@ -44,6 +44,9 @@ namespace dawn_native { namespace opengl { using CombinedSamplerInfo = std::vector; + using BindingInfoArrayTable = + std::unordered_map>; + class ShaderModule final : public ShaderModuleBase { public: static ResultOrError> Create(Device* device, @@ -60,11 +63,11 @@ namespace dawn_native { namespace opengl { ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor); ~ShaderModule() override = default; MaybeError Initialize(ShaderModuleParseResult* parseResult); - static ResultOrError ReflectShaderUsingSPIRVCross( + static ResultOrError ReflectShaderUsingSPIRVCross( DeviceBase* device, const std::vector& spirv); - EntryPointMetadataTable mGLEntryPoints; + BindingInfoArrayTable mGLBindings; }; }} // namespace dawn_native::opengl diff --git a/src/dawn_native/vulkan/ShaderModuleVk.cpp b/src/dawn_native/vulkan/ShaderModuleVk.cpp index 0282b77126..798424e560 100644 --- a/src/dawn_native/vulkan/ShaderModuleVk.cpp +++ b/src/dawn_native/vulkan/ShaderModuleVk.cpp @@ -121,8 +121,7 @@ namespace dawn_native { namespace vulkan { BindingRemapper::BindingPoints bindingPoints; BindingRemapper::AccessControls accessControls; - const EntryPointMetadata::BindingInfoArray& moduleBindingInfo = - GetEntryPoint(entryPointName).bindings; + const BindingInfoArray& moduleBindingInfo = GetEntryPoint(entryPointName).bindings; for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));