Remap BindGroup bindingIndex for vulkan backend when using Tint Generator

Bug: dawn:750
Change-Id: I239f5544a5822422d61a249f2ef028df326f90ed
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/47380
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Shrek Shao <shrekshao@google.com>
This commit is contained in:
shrekshao 2021-04-13 15:38:24 +00:00 committed by Commit Bot service account
parent 9e0debd91e
commit 417d91cd1e
11 changed files with 164 additions and 6 deletions

View File

@ -14,6 +14,7 @@
#include "dawn_native/ShaderModule.h"
#include "common/HashUtils.h"
#include "common/VertexFormatUtils.h"
#include "dawn_native/BindGroupLayout.h"
#include "dawn_native/CompilationMessages.h"
@ -1376,4 +1377,11 @@ namespace dawn_native {
return std::move(result);
}
size_t PipelineLayoutEntryPointPairHashFunc::operator()(
const PipelineLayoutEntryPointPair& pair) const {
size_t hash = 0;
HashCombine(&hash, pair.first, pair.second);
return hash;
}
} // namespace dawn_native

View File

@ -51,6 +51,11 @@ namespace dawn_native {
struct EntryPointMetadata;
using PipelineLayoutEntryPointPair = std::pair<PipelineLayoutBase*, std::string>;
struct PipelineLayoutEntryPointPairHashFunc {
size_t operator()(const PipelineLayoutEntryPointPair& pair) const;
};
// A map from name to EntryPointMetadata.
using EntryPointMetadataTable =
std::unordered_map<std::string, std::unique_ptr<EntryPointMetadata>>;

View File

@ -261,7 +261,7 @@ namespace dawn_native { namespace d3d12 {
tint::transform::Transform::Output output =
transformManager.Run(GetTintProgram(), transformInputs);
tint::Program& program = output.program;
const tint::Program& program = output.program;
if (!program.IsValid()) {
errorStream << "Tint program transform error: " << program.Diagnostics().str()
<< std::endl;

View File

@ -89,13 +89,16 @@ namespace dawn_native { namespace vulkan {
ityp::vector<BindingIndex, VkDescriptorSetLayoutBinding> bindings;
bindings.reserve(GetBindingCount());
bool useBindingIndex = GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator);
for (const auto& it : GetBindingMap()) {
BindingNumber bindingNumber = it.first;
BindingIndex bindingIndex = it.second;
const BindingInfo& bindingInfo = GetBindingInfo(bindingIndex);
VkDescriptorSetLayoutBinding vkBinding;
vkBinding.binding = static_cast<uint32_t>(bindingNumber);
vkBinding.binding = useBindingIndex ? static_cast<uint32_t>(bindingIndex)
: static_cast<uint32_t>(bindingNumber);
vkBinding.descriptorType = VulkanDescriptorType(bindingInfo);
vkBinding.descriptorCount = 1;
vkBinding.stageFlags = VulkanShaderStageFlags(bindingInfo.visibility);

View File

@ -43,6 +43,10 @@ namespace dawn_native { namespace vulkan {
// the pools are reused when no longer used. Minimizing the number of descriptor pool allocation
// is important because creating them can incur GPU memory allocation which is usually an
// expensive syscall.
//
// The Vulkan BindGroupLayout is dependent on UseTintGenerator or not.
// When UseTintGenerator is on, VkDescriptorSetLayoutBinding::binding is set to BindingIndex,
// otherwise it is set to BindingNumber.
class BindGroupLayout final : public BindGroupLayoutBase {
public:
static ResultOrError<Ref<BindGroupLayout>> Create(

View File

@ -47,6 +47,8 @@ namespace dawn_native { namespace vulkan {
ityp::stack_vec<uint32_t, VkDescriptorImageInfo, kMaxOptimalBindingsPerGroup>
writeImageInfo(bindingCount);
bool useBindingIndex = device->IsToggleEnabled(Toggle::UseTintGenerator);
uint32_t numWrites = 0;
for (const auto& it : GetLayout()->GetBindingMap()) {
BindingNumber bindingNumber = it.first;
@ -57,7 +59,8 @@ namespace dawn_native { namespace vulkan {
write.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
write.pNext = nullptr;
write.dstSet = GetHandle();
write.dstBinding = static_cast<uint32_t>(bindingNumber);
write.dstBinding = useBindingIndex ? static_cast<uint32_t>(bindingIndex)
: static_cast<uint32_t>(bindingNumber);
write.dstArrayElement = 0;
write.descriptorCount = 1;
write.descriptorType = VulkanDescriptorType(bindingInfo);

View File

@ -26,6 +26,9 @@ namespace dawn_native { namespace vulkan {
class Device;
// The Vulkan BindGroup is dependent on UseTintGenerator or not.
// When UseTintGenerator is on, VkWriteDescriptorSet::dstBinding is set to BindingIndex,
// otherwise it is set to BindingNumber.
class BindGroup final : public BindGroupBase, public PlacementAllocated {
public:
static ResultOrError<Ref<BindGroup>> Create(Device* device,

View File

@ -45,7 +45,15 @@ namespace dawn_native { namespace vulkan {
createInfo.stage.pNext = nullptr;
createInfo.stage.flags = 0;
createInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
createInfo.stage.module = ToBackend(descriptor->computeStage.module)->GetHandle();
if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) {
// Generate a new VkShaderModule with BindingRemapper tint transform for each pipeline
DAWN_TRY_ASSIGN(createInfo.stage.module,
ToBackend(descriptor->computeStage.module)
->GetTransformedModuleHandle(descriptor->computeStage.entryPoint,
ToBackend(GetLayout())));
} else {
createInfo.stage.module = ToBackend(descriptor->computeStage.module)->GetHandle();
}
createInfo.stage.pName = descriptor->computeStage.entryPoint;
createInfo.stage.pSpecializationInfo = nullptr;

View File

@ -332,12 +332,27 @@ namespace dawn_native { namespace vulkan {
VkPipelineShaderStageCreateInfo shaderStages[2];
{
if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
// Generate a new VkShaderModule with BindingRemapper tint transform for each
// pipeline
DAWN_TRY_ASSIGN(shaderStages[0].module,
ToBackend(descriptor->vertex.module)
->GetTransformedModuleHandle(descriptor->vertex.entryPoint,
ToBackend(GetLayout())));
DAWN_TRY_ASSIGN(shaderStages[1].module,
ToBackend(descriptor->fragment->module)
->GetTransformedModuleHandle(descriptor->fragment->entryPoint,
ToBackend(GetLayout())));
} else {
shaderStages[0].module = ToBackend(descriptor->vertex.module)->GetHandle();
shaderStages[1].module = ToBackend(descriptor->fragment->module)->GetHandle();
}
shaderStages[0].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
shaderStages[0].pNext = nullptr;
shaderStages[0].flags = 0;
shaderStages[0].stage = VK_SHADER_STAGE_VERTEX_BIT;
shaderStages[0].pSpecializationInfo = nullptr;
shaderStages[0].module = ToBackend(descriptor->vertex.module)->GetHandle();
shaderStages[0].pName = descriptor->vertex.entryPoint;
shaderStages[1].sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
@ -345,7 +360,6 @@ namespace dawn_native { namespace vulkan {
shaderStages[1].flags = 0;
shaderStages[1].stage = VK_SHADER_STAGE_FRAGMENT_BIT;
shaderStages[1].pSpecializationInfo = nullptr;
shaderStages[1].module = ToBackend(descriptor->fragment->module)->GetHandle();
shaderStages[1].pName = descriptor->fragment->entryPoint;
}

View File

@ -15,8 +15,10 @@
#include "dawn_native/vulkan/ShaderModuleVk.h"
#include "dawn_native/TintUtils.h"
#include "dawn_native/vulkan/BindGroupLayoutVk.h"
#include "dawn_native/vulkan/DeviceVk.h"
#include "dawn_native/vulkan/FencedDeleter.h"
#include "dawn_native/vulkan/PipelineLayoutVk.h"
#include "dawn_native/vulkan/VulkanError.h"
#include <spirv_cross.hpp>
@ -103,10 +105,106 @@ namespace dawn_native { namespace vulkan {
device->GetFencedDeleter()->DeleteWhenUnused(mHandle);
mHandle = VK_NULL_HANDLE;
}
for (const auto& iter : mTransformedShaderModuleCache) {
device->GetFencedDeleter()->DeleteWhenUnused(iter.second);
}
}
VkShaderModule ShaderModule::GetHandle() const {
ASSERT(!GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator));
return mHandle;
}
ResultOrError<VkShaderModule> ShaderModule::GetTransformedModuleHandle(
const char* entryPointName,
PipelineLayout* layout) {
ScopedTintICEHandler scopedICEHandler(GetDevice());
ASSERT(GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator));
auto cacheKey = std::make_pair(layout, entryPointName);
auto iter = mTransformedShaderModuleCache.find(cacheKey);
if (iter != mTransformedShaderModuleCache.end()) {
auto cached = iter->second;
return cached;
}
// Creation of VkShaderModule is deferred to this point when using tint generator
std::ostringstream errorStream;
errorStream << "Tint SPIR-V writer failure:" << std::endl;
// Remap BindingNumber to BindingIndex in WGSL shader
using BindingRemapper = tint::transform::BindingRemapper;
using BindingPoint = tint::transform::BindingPoint;
BindingRemapper::BindingPoints bindingPoints;
BindingRemapper::AccessControls accessControls;
const EntryPointMetadata::BindingInfoArray& moduleBindingInfo =
GetEntryPoint(entryPointName).bindings;
for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
const auto& groupBindingInfo = moduleBindingInfo[group];
for (const auto& it : groupBindingInfo) {
BindingNumber binding = it.first;
BindingIndex bindingIndex = bgl->GetBindingIndex(binding);
BindingPoint srcBindingPoint{static_cast<uint32_t>(group),
static_cast<uint32_t>(binding)};
BindingPoint dstBindingPoint{static_cast<uint32_t>(group),
static_cast<uint32_t>(bindingIndex)};
if (srcBindingPoint != dstBindingPoint) {
bindingPoints.emplace(srcBindingPoint, dstBindingPoint);
}
}
}
tint::transform::Manager transformManager;
transformManager.append(std::make_unique<tint::transform::BindingRemapper>());
tint::transform::DataMap transformInputs;
transformInputs.Add<BindingRemapper::Remappings>(std::move(bindingPoints),
std::move(accessControls));
tint::transform::Transform::Output output =
transformManager.Run(GetTintProgram(), transformInputs);
const tint::Program& program = output.program;
if (!program.IsValid()) {
errorStream << "Tint program transform error: " << program.Diagnostics().str()
<< std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
tint::writer::spirv::Generator generator(&program);
if (!generator.Generate()) {
errorStream << "Generator: " << generator.error() << std::endl;
return DAWN_VALIDATION_ERROR(errorStream.str().c_str());
}
std::vector<uint32_t> spirv = generator.result();
// Don't save the transformedParseResult but just create a VkShaderModule
VkShaderModuleCreateInfo createInfo;
createInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
createInfo.pNext = nullptr;
createInfo.flags = 0;
std::vector<uint32_t> vulkanSource;
createInfo.codeSize = spirv.size() * sizeof(uint32_t);
createInfo.pCode = spirv.data();
Device* device = ToBackend(GetDevice());
VkShaderModule newHandle = VK_NULL_HANDLE;
DAWN_TRY(CheckVkSuccess(
device->fn.CreateShaderModule(device->GetVkDevice(), &createInfo, nullptr, &*newHandle),
"CreateShaderModule"));
if (newHandle != VK_NULL_HANDLE) {
mTransformedShaderModuleCache.emplace(cacheKey, newHandle);
}
return newHandle;
}
}} // namespace dawn_native::vulkan

View File

@ -23,6 +23,11 @@
namespace dawn_native { namespace vulkan {
class Device;
class PipelineLayout;
using TransformedShaderModuleCache = std::unordered_map<PipelineLayoutEntryPointPair,
VkShaderModule,
PipelineLayoutEntryPointPairHashFunc>;
class ShaderModule final : public ShaderModuleBase {
public:
@ -32,12 +37,19 @@ namespace dawn_native { namespace vulkan {
VkShaderModule GetHandle() const;
// This is only called when UseTintGenerator is on
ResultOrError<VkShaderModule> GetTransformedModuleHandle(const char* entryPointName,
PipelineLayout* layout);
private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override;
MaybeError Initialize(ShaderModuleParseResult* parseResult);
VkShaderModule mHandle = VK_NULL_HANDLE;
// New handles created by GetTransformedModuleHandle at pipeline creation time
TransformedShaderModuleCache mTransformedShaderModuleCache;
};
}} // namespace dawn_native::vulkan