Fix D3D12 descriptor renumbering (#218)
Previously, the renumbering loop would sometimes iterate in the wrong order. To fix this, instead use the binding info already correctly extracted by `ExtractSpirvInfo`. Fixes ComputeCopyStorageBufferTests.BasicTest/D3D12.
This commit is contained in:
parent
21006bbe6f
commit
0582bfdfda
|
@ -14,10 +14,36 @@
|
|||
|
||||
#include "backend/d3d12/ShaderModuleD3D12.h"
|
||||
|
||||
#include "common/Assert.h"
|
||||
|
||||
#include <spirv-cross/spirv_hlsl.hpp>
|
||||
|
||||
namespace backend { namespace d3d12 {
|
||||
|
||||
// TODO(kainino@chromium.org): Consider replacing this with a generic enum_map.
|
||||
template <typename T>
|
||||
class BindingTypeMap {
|
||||
public:
|
||||
T& operator[](nxt::BindingType type) {
|
||||
switch (type) {
|
||||
case nxt::BindingType::UniformBuffer:
|
||||
return mMap[0];
|
||||
case nxt::BindingType::Sampler:
|
||||
return mMap[1];
|
||||
case nxt::BindingType::SampledTexture:
|
||||
return mMap[2];
|
||||
case nxt::BindingType::StorageBuffer:
|
||||
return mMap[3];
|
||||
default:
|
||||
NXT_UNREACHABLE();
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr int kNumBindingTypes = 4;
|
||||
std::array<T, kNumBindingTypes> mMap{};
|
||||
};
|
||||
|
||||
ShaderModule::ShaderModule(Device* device, ShaderModuleBuilder* builder)
|
||||
: ShaderModuleBase(builder), mDevice(device) {
|
||||
spirv_cross::CompilerHLSL compiler(builder->AcquireSpirv());
|
||||
|
@ -33,26 +59,22 @@ namespace backend { namespace d3d12 {
|
|||
|
||||
ExtractSpirvInfo(compiler);
|
||||
|
||||
// rename bindings so that each register type b/u/t/s starts at 0 and then offset by
|
||||
// rename bindings so that each register type c/u/t/s starts at 0 and then offset by
|
||||
// kMaxBindingsPerGroup * bindGroupIndex
|
||||
auto RenumberBindings = [&](std::vector<spirv_cross::Resource> resources) {
|
||||
std::array<uint32_t, kMaxBindGroups> baseRegisters = {};
|
||||
const auto& moduleBindingInfo = GetBindingInfo();
|
||||
for (uint32_t group = 0; group < moduleBindingInfo.size(); ++group) {
|
||||
const auto& groupBindingInfo = moduleBindingInfo[group];
|
||||
|
||||
for (const auto& resource : resources) {
|
||||
auto bindGroupIndex =
|
||||
compiler.get_decoration(resource.id, spv::DecorationDescriptorSet);
|
||||
auto& baseRegister = baseRegisters[bindGroupIndex];
|
||||
auto bindGroupOffset = bindGroupIndex * kMaxBindingsPerGroup;
|
||||
compiler.set_decoration(resource.id, spv::DecorationBinding,
|
||||
BindingTypeMap<uint32_t> baseRegisters{};
|
||||
for (const auto& bindingInfo : groupBindingInfo) {
|
||||
if (bindingInfo.used) {
|
||||
uint32_t& baseRegister = baseRegisters[bindingInfo.type];
|
||||
uint32_t bindGroupOffset = group * kMaxBindingsPerGroup;
|
||||
compiler.set_decoration(bindingInfo.id, spv::DecorationBinding,
|
||||
bindGroupOffset + baseRegister++);
|
||||
}
|
||||
};
|
||||
|
||||
const auto& resources = compiler.get_shader_resources();
|
||||
RenumberBindings(resources.uniform_buffers); // c
|
||||
RenumberBindings(resources.storage_buffers); // u
|
||||
RenumberBindings(resources.separate_images); // t
|
||||
RenumberBindings(resources.separate_samplers); // s
|
||||
}
|
||||
}
|
||||
|
||||
mHlslSource = compiler.compile();
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue