Fix D3D12 descriptor renumbering (#218)

Previously, the renumbering loop would sometimes iterate in the wrong order. To fix this, instead use the binding info already correctly extracted by `ExtractSpirvInfo`.

Fixes ComputeCopyStorageBufferTests.BasicTest/D3D12.
This commit is contained in:
Kai Ninomiya 2018-07-10 17:25:48 -07:00 committed by GitHub
parent 21006bbe6f
commit 0582bfdfda
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 39 additions and 17 deletions

View File

@ -14,10 +14,36 @@
#include "backend/d3d12/ShaderModuleD3D12.h"
#include "common/Assert.h"
#include <spirv-cross/spirv_hlsl.hpp>
namespace backend { namespace d3d12 {
// TODO(kainino@chromium.org): Consider replacing this with a generic enum_map.
template <typename T>
class BindingTypeMap {
public:
T& operator[](nxt::BindingType type) {
switch (type) {
case nxt::BindingType::UniformBuffer:
return mMap[0];
case nxt::BindingType::Sampler:
return mMap[1];
case nxt::BindingType::SampledTexture:
return mMap[2];
case nxt::BindingType::StorageBuffer:
return mMap[3];
default:
NXT_UNREACHABLE();
}
}
private:
static constexpr int kNumBindingTypes = 4;
std::array<T, kNumBindingTypes> mMap{};
};
ShaderModule::ShaderModule(Device* device, ShaderModuleBuilder* builder)
: ShaderModuleBase(builder), mDevice(device) {
spirv_cross::CompilerHLSL compiler(builder->AcquireSpirv());
@ -33,26 +59,22 @@ namespace backend { namespace d3d12 {
ExtractSpirvInfo(compiler);
// rename bindings so that each register type b/u/t/s starts at 0 and then offset by
// rename bindings so that each register type c/u/t/s starts at 0 and then offset by
// kMaxBindingsPerGroup * bindGroupIndex
auto RenumberBindings = [&](std::vector<spirv_cross::Resource> resources) {
std::array<uint32_t, kMaxBindGroups> baseRegisters = {};
const auto& moduleBindingInfo = GetBindingInfo();
for (uint32_t group = 0; group < moduleBindingInfo.size(); ++group) {
const auto& groupBindingInfo = moduleBindingInfo[group];
for (const auto& resource : resources) {
auto bindGroupIndex =
compiler.get_decoration(resource.id, spv::DecorationDescriptorSet);
auto& baseRegister = baseRegisters[bindGroupIndex];
auto bindGroupOffset = bindGroupIndex * kMaxBindingsPerGroup;
compiler.set_decoration(resource.id, spv::DecorationBinding,
bindGroupOffset + baseRegister++);
BindingTypeMap<uint32_t> baseRegisters{};
for (const auto& bindingInfo : groupBindingInfo) {
if (bindingInfo.used) {
uint32_t& baseRegister = baseRegisters[bindingInfo.type];
uint32_t bindGroupOffset = group * kMaxBindingsPerGroup;
compiler.set_decoration(bindingInfo.id, spv::DecorationBinding,
bindGroupOffset + baseRegister++);
}
}
};
const auto& resources = compiler.get_shader_resources();
RenumberBindings(resources.uniform_buffers); // c
RenumberBindings(resources.storage_buffers); // u
RenumberBindings(resources.separate_images); // t
RenumberBindings(resources.separate_samplers); // s
}
mHlslSource = compiler.compile();
}