diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp index 7774c61bf5..90115f0d24 100644 --- a/src/dawn_native/ShaderModule.cpp +++ b/src/dawn_native/ShaderModule.cpp @@ -14,6 +14,7 @@ #include "dawn_native/ShaderModule.h" +#include "common/HashUtils.h" #include "common/VertexFormatUtils.h" #include "dawn_native/BindGroupLayout.h" #include "dawn_native/CompilationMessages.h" @@ -1376,4 +1377,11 @@ namespace dawn_native { return std::move(result); } + size_t PipelineLayoutEntryPointPairHashFunc::operator()( + const PipelineLayoutEntryPointPair& pair) const { + size_t hash = 0; + HashCombine(&hash, pair.first, pair.second); + return hash; + } + } // namespace dawn_native diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h index ab3a271558..556d604c16 100644 --- a/src/dawn_native/ShaderModule.h +++ b/src/dawn_native/ShaderModule.h @@ -51,6 +51,11 @@ namespace dawn_native { struct EntryPointMetadata; + using PipelineLayoutEntryPointPair = std::pair; + struct PipelineLayoutEntryPointPairHashFunc { + size_t operator()(const PipelineLayoutEntryPointPair& pair) const; + }; + // A map from name to EntryPointMetadata. using EntryPointMetadataTable = std::unordered_map>; diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp index 33ea80b456..8b14210112 100644 --- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp +++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp @@ -261,7 +261,7 @@ namespace dawn_native { namespace d3d12 { tint::transform::Transform::Output output = transformManager.Run(GetTintProgram(), transformInputs); - tint::Program& program = output.program; + const tint::Program& program = output.program; if (!program.IsValid()) { errorStream << "Tint program transform error: " << program.Diagnostics().str() << std::endl; diff --git a/src/dawn_native/vulkan/BindGroupLayoutVk.cpp b/src/dawn_native/vulkan/BindGroupLayoutVk.cpp index 700e850318..78f7a7a58e 100644 --- a/src/dawn_native/vulkan/BindGroupLayoutVk.cpp +++ b/src/dawn_native/vulkan/BindGroupLayoutVk.cpp @@ -89,13 +89,16 @@ namespace dawn_native { namespace vulkan { ityp::vector bindings; bindings.reserve(GetBindingCount()); + bool useBindingIndex = GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator); + for (const auto& it : GetBindingMap()) { BindingNumber bindingNumber = it.first; BindingIndex bindingIndex = it.second; const BindingInfo& bindingInfo = GetBindingInfo(bindingIndex); VkDescriptorSetLayoutBinding vkBinding; - vkBinding.binding = static_cast(bindingNumber); + vkBinding.binding = useBindingIndex ? static_cast(bindingIndex) + : static_cast(bindingNumber); vkBinding.descriptorType = VulkanDescriptorType(bindingInfo); vkBinding.descriptorCount = 1; vkBinding.stageFlags = VulkanShaderStageFlags(bindingInfo.visibility); diff --git a/src/dawn_native/vulkan/BindGroupLayoutVk.h b/src/dawn_native/vulkan/BindGroupLayoutVk.h index 72f8b698d7..cc502c94d9 100644 --- a/src/dawn_native/vulkan/BindGroupLayoutVk.h +++ b/src/dawn_native/vulkan/BindGroupLayoutVk.h @@ -43,6 +43,10 @@ namespace dawn_native { namespace vulkan { // the pools are reused when no longer used. Minimizing the number of descriptor pool allocation // is important because creating them can incur GPU memory allocation which is usually an // expensive syscall. + // + // The Vulkan BindGroupLayout is dependent on UseTintGenerator or not. + // When UseTintGenerator is on, VkDescriptorSetLayoutBinding::binding is set to BindingIndex, + // otherwise it is set to BindingNumber. class BindGroupLayout final : public BindGroupLayoutBase { public: static ResultOrError> Create( diff --git a/src/dawn_native/vulkan/BindGroupVk.cpp b/src/dawn_native/vulkan/BindGroupVk.cpp index 07653e8bf5..b2334d1095 100644 --- a/src/dawn_native/vulkan/BindGroupVk.cpp +++ b/src/dawn_native/vulkan/BindGroupVk.cpp @@ -47,6 +47,8 @@ namespace dawn_native { namespace vulkan { ityp::stack_vec writeImageInfo(bindingCount); + bool useBindingIndex = device->IsToggleEnabled(Toggle::UseTintGenerator); + uint32_t numWrites = 0; for (const auto& it : GetLayout()->GetBindingMap()) { BindingNumber bindingNumber = it.first; @@ -57,7 +59,8 @@ namespace dawn_native { namespace vulkan { write.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; write.pNext = nullptr; write.dstSet = GetHandle(); - write.dstBinding = static_cast(bindingNumber); + write.dstBinding = useBindingIndex ? static_cast(bindingIndex) + : static_cast(bindingNumber); write.dstArrayElement = 0; write.descriptorCount = 1; write.descriptorType = VulkanDescriptorType(bindingInfo); diff --git a/src/dawn_native/vulkan/BindGroupVk.h b/src/dawn_native/vulkan/BindGroupVk.h index dac780bf0b..14b6940eb9 100644 --- a/src/dawn_native/vulkan/BindGroupVk.h +++ b/src/dawn_native/vulkan/BindGroupVk.h @@ -26,6 +26,9 @@ namespace dawn_native { namespace vulkan { class Device; + // The Vulkan BindGroup is dependent on UseTintGenerator or not. + // When UseTintGenerator is on, VkWriteDescriptorSet::dstBinding is set to BindingIndex, + // otherwise it is set to BindingNumber. class BindGroup final : public BindGroupBase, public PlacementAllocated { public: static ResultOrError> Create(Device* device, diff --git a/src/dawn_native/vulkan/ComputePipelineVk.cpp b/src/dawn_native/vulkan/ComputePipelineVk.cpp index a81dee9039..322c026220 100644 --- a/src/dawn_native/vulkan/ComputePipelineVk.cpp +++ b/src/dawn_native/vulkan/ComputePipelineVk.cpp @@ -45,7 +45,15 @@ namespace dawn_native { namespace vulkan { createInfo.stage.pNext = nullptr; createInfo.stage.flags = 0; createInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT; - createInfo.stage.module = ToBackend(descriptor->computeStage.module)->GetHandle(); + if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) { + // Generate a new VkShaderModule with BindingRemapper tint transform for each pipeline + DAWN_TRY_ASSIGN(createInfo.stage.module, + ToBackend(descriptor->computeStage.module) + ->GetTransformedModuleHandle(descriptor->computeStage.entryPoint, + ToBackend(GetLayout()))); + } else { + createInfo.stage.module = ToBackend(descriptor->computeStage.module)->GetHandle(); + } createInfo.stage.pName = descriptor->computeStage.entryPoint; createInfo.stage.pSpecializationInfo = nullptr; diff --git a/src/dawn_native/vulkan/RenderPipelineVk.cpp b/src/dawn_native/vulkan/RenderPipelineVk.cpp index bbb4f8eed4..a743b8d4f3 100644 --- a/src/dawn_native/vulkan/RenderPipelineVk.cpp +++ b/src/dawn_native/vulkan/RenderPipelineVk.cpp @@ -332,12 +332,27 @@ namespace dawn_native { namespace vulkan { VkPipelineShaderStageCreateInfo shaderStages[2]; { + if (device->IsToggleEnabled(Toggle::UseTintGenerator)) { + // Generate a new VkShaderModule with BindingRemapper tint transform for each + // pipeline + DAWN_TRY_ASSIGN(shaderStages[0].module, + ToBackend(descriptor->vertex.module) + ->GetTransformedModuleHandle(descriptor->vertex.entryPoint, + ToBackend(GetLayout()))); + DAWN_TRY_ASSIGN(shaderStages[1].module, + ToBackend(descriptor->fragment->module) + ->GetTransformedModuleHandle(descriptor->fragment->entryPoint, + ToBackend(GetLayout()))); + } else { + shaderStages[0].module = ToBackend(descriptor->vertex.module)->GetHandle(); + shaderStages[1].module = ToBackend(descriptor->fragment->module)->GetHandle(); + } + shaderStages[0].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; shaderStages[0].pNext = nullptr; shaderStages[0].flags = 0; shaderStages[0].stage = VK_SHADER_STAGE_VERTEX_BIT; shaderStages[0].pSpecializationInfo = nullptr; - shaderStages[0].module = ToBackend(descriptor->vertex.module)->GetHandle(); shaderStages[0].pName = descriptor->vertex.entryPoint; shaderStages[1].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; @@ -345,7 +360,6 @@ namespace dawn_native { namespace vulkan { shaderStages[1].flags = 0; shaderStages[1].stage = VK_SHADER_STAGE_FRAGMENT_BIT; shaderStages[1].pSpecializationInfo = nullptr; - shaderStages[1].module = ToBackend(descriptor->fragment->module)->GetHandle(); shaderStages[1].pName = descriptor->fragment->entryPoint; } diff --git a/src/dawn_native/vulkan/ShaderModuleVk.cpp b/src/dawn_native/vulkan/ShaderModuleVk.cpp index 0fb4c610b4..b8a2b23a5c 100644 --- a/src/dawn_native/vulkan/ShaderModuleVk.cpp +++ b/src/dawn_native/vulkan/ShaderModuleVk.cpp @@ -15,8 +15,10 @@ #include "dawn_native/vulkan/ShaderModuleVk.h" #include "dawn_native/TintUtils.h" +#include "dawn_native/vulkan/BindGroupLayoutVk.h" #include "dawn_native/vulkan/DeviceVk.h" #include "dawn_native/vulkan/FencedDeleter.h" +#include "dawn_native/vulkan/PipelineLayoutVk.h" #include "dawn_native/vulkan/VulkanError.h" #include @@ -103,10 +105,106 @@ namespace dawn_native { namespace vulkan { device->GetFencedDeleter()->DeleteWhenUnused(mHandle); mHandle = VK_NULL_HANDLE; } + + for (const auto& iter : mTransformedShaderModuleCache) { + device->GetFencedDeleter()->DeleteWhenUnused(iter.second); + } } VkShaderModule ShaderModule::GetHandle() const { + ASSERT(!GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)); return mHandle; } + ResultOrError ShaderModule::GetTransformedModuleHandle( + const char* entryPointName, + PipelineLayout* layout) { + ScopedTintICEHandler scopedICEHandler(GetDevice()); + + ASSERT(GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)); + + auto cacheKey = std::make_pair(layout, entryPointName); + auto iter = mTransformedShaderModuleCache.find(cacheKey); + if (iter != mTransformedShaderModuleCache.end()) { + auto cached = iter->second; + return cached; + } + + // Creation of VkShaderModule is deferred to this point when using tint generator + std::ostringstream errorStream; + errorStream << "Tint SPIR-V writer failure:" << std::endl; + + // Remap BindingNumber to BindingIndex in WGSL shader + using BindingRemapper = tint::transform::BindingRemapper; + using BindingPoint = tint::transform::BindingPoint; + BindingRemapper::BindingPoints bindingPoints; + BindingRemapper::AccessControls accessControls; + + const EntryPointMetadata::BindingInfoArray& moduleBindingInfo = + GetEntryPoint(entryPointName).bindings; + + for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { + const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group)); + const auto& groupBindingInfo = moduleBindingInfo[group]; + for (const auto& it : groupBindingInfo) { + BindingNumber binding = it.first; + BindingIndex bindingIndex = bgl->GetBindingIndex(binding); + BindingPoint srcBindingPoint{static_cast(group), + static_cast(binding)}; + + BindingPoint dstBindingPoint{static_cast(group), + static_cast(bindingIndex)}; + if (srcBindingPoint != dstBindingPoint) { + bindingPoints.emplace(srcBindingPoint, dstBindingPoint); + } + } + } + + tint::transform::Manager transformManager; + transformManager.append(std::make_unique()); + + tint::transform::DataMap transformInputs; + transformInputs.Add(std::move(bindingPoints), + std::move(accessControls)); + tint::transform::Transform::Output output = + transformManager.Run(GetTintProgram(), transformInputs); + + const tint::Program& program = output.program; + if (!program.IsValid()) { + errorStream << "Tint program transform error: " << program.Diagnostics().str() + << std::endl; + return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); + } + + tint::writer::spirv::Generator generator(&program); + if (!generator.Generate()) { + errorStream << "Generator: " << generator.error() << std::endl; + return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); + } + + std::vector spirv = generator.result(); + + // Don't save the transformedParseResult but just create a VkShaderModule + VkShaderModuleCreateInfo createInfo; + createInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; + createInfo.pNext = nullptr; + createInfo.flags = 0; + std::vector vulkanSource; + createInfo.codeSize = spirv.size() * sizeof(uint32_t); + createInfo.pCode = spirv.data(); + + Device* device = ToBackend(GetDevice()); + + VkShaderModule newHandle = VK_NULL_HANDLE; + + DAWN_TRY(CheckVkSuccess( + device->fn.CreateShaderModule(device->GetVkDevice(), &createInfo, nullptr, &*newHandle), + "CreateShaderModule")); + if (newHandle != VK_NULL_HANDLE) { + mTransformedShaderModuleCache.emplace(cacheKey, newHandle); + } + + return newHandle; + } + }} // namespace dawn_native::vulkan diff --git a/src/dawn_native/vulkan/ShaderModuleVk.h b/src/dawn_native/vulkan/ShaderModuleVk.h index 7c0d8ef841..9dd7817d7e 100644 --- a/src/dawn_native/vulkan/ShaderModuleVk.h +++ b/src/dawn_native/vulkan/ShaderModuleVk.h @@ -23,6 +23,11 @@ namespace dawn_native { namespace vulkan { class Device; + class PipelineLayout; + + using TransformedShaderModuleCache = std::unordered_map; class ShaderModule final : public ShaderModuleBase { public: @@ -32,12 +37,19 @@ namespace dawn_native { namespace vulkan { VkShaderModule GetHandle() const; + // This is only called when UseTintGenerator is on + ResultOrError GetTransformedModuleHandle(const char* entryPointName, + PipelineLayout* layout); + private: ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor); ~ShaderModule() override; MaybeError Initialize(ShaderModuleParseResult* parseResult); VkShaderModule mHandle = VK_NULL_HANDLE; + + // New handles created by GetTransformedModuleHandle at pipeline creation time + TransformedShaderModuleCache mTransformedShaderModuleCache; }; }} // namespace dawn_native::vulkan