From 0582bfdfda914441a9ee4b53f4d749153c096c6a Mon Sep 17 00:00:00 2001 From: Kai Ninomiya Date: Tue, 10 Jul 2018 17:25:48 -0700 Subject: [PATCH] Fix D3D12 descriptor renumbering (#218) Previously, the renumbering loop would sometimes iterate in the wrong order. To fix this, instead use the binding info already correctly extracted by `ExtractSpirvInfo`. Fixes ComputeCopyStorageBufferTests.BasicTest/D3D12. --- src/backend/d3d12/ShaderModuleD3D12.cpp | 56 +++++++++++++++++-------- 1 file changed, 39 insertions(+), 17 deletions(-) diff --git a/src/backend/d3d12/ShaderModuleD3D12.cpp b/src/backend/d3d12/ShaderModuleD3D12.cpp index d2968fefac..0e86600c21 100644 --- a/src/backend/d3d12/ShaderModuleD3D12.cpp +++ b/src/backend/d3d12/ShaderModuleD3D12.cpp @@ -14,10 +14,36 @@ #include "backend/d3d12/ShaderModuleD3D12.h" +#include "common/Assert.h" + #include namespace backend { namespace d3d12 { + // TODO(kainino@chromium.org): Consider replacing this with a generic enum_map. + template + class BindingTypeMap { + public: + T& operator[](nxt::BindingType type) { + switch (type) { + case nxt::BindingType::UniformBuffer: + return mMap[0]; + case nxt::BindingType::Sampler: + return mMap[1]; + case nxt::BindingType::SampledTexture: + return mMap[2]; + case nxt::BindingType::StorageBuffer: + return mMap[3]; + default: + NXT_UNREACHABLE(); + } + } + + private: + static constexpr int kNumBindingTypes = 4; + std::array mMap{}; + }; + ShaderModule::ShaderModule(Device* device, ShaderModuleBuilder* builder) : ShaderModuleBase(builder), mDevice(device) { spirv_cross::CompilerHLSL compiler(builder->AcquireSpirv()); @@ -33,26 +59,22 @@ namespace backend { namespace d3d12 { ExtractSpirvInfo(compiler); - // rename bindings so that each register type b/u/t/s starts at 0 and then offset by + // rename bindings so that each register type c/u/t/s starts at 0 and then offset by // kMaxBindingsPerGroup * bindGroupIndex - auto RenumberBindings = [&](std::vector resources) { - std::array baseRegisters = {}; + const auto& moduleBindingInfo = GetBindingInfo(); + for (uint32_t group = 0; group < moduleBindingInfo.size(); ++group) { + const auto& groupBindingInfo = moduleBindingInfo[group]; - for (const auto& resource : resources) { - auto bindGroupIndex = - compiler.get_decoration(resource.id, spv::DecorationDescriptorSet); - auto& baseRegister = baseRegisters[bindGroupIndex]; - auto bindGroupOffset = bindGroupIndex * kMaxBindingsPerGroup; - compiler.set_decoration(resource.id, spv::DecorationBinding, - bindGroupOffset + baseRegister++); + BindingTypeMap baseRegisters{}; + for (const auto& bindingInfo : groupBindingInfo) { + if (bindingInfo.used) { + uint32_t& baseRegister = baseRegisters[bindingInfo.type]; + uint32_t bindGroupOffset = group * kMaxBindingsPerGroup; + compiler.set_decoration(bindingInfo.id, spv::DecorationBinding, + bindGroupOffset + baseRegister++); + } } - }; - - const auto& resources = compiler.get_shader_resources(); - RenumberBindings(resources.uniform_buffers); // c - RenumberBindings(resources.storage_buffers); // u - RenumberBindings(resources.separate_images); // t - RenumberBindings(resources.separate_samplers); // s + } mHlslSource = compiler.compile(); }