Make ShaderModuleBase use its own spirv_cross for reflection.

Bug: dawn:216
Change-Id: Ie79aaeb3a878960606e8c09a4969bf7a1dbe1b13
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/28240
Commit-Queue: Kai Ninomiya <kainino@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
This commit is contained in:
Corentin Wallez 2020-09-09 22:44:57 +00:00 committed by Commit Bot service account
parent 97b880e6ff
commit 947201da19
8 changed files with 239 additions and 249 deletions

View File

@ -548,6 +548,225 @@ namespace dawn_native {
return {};
}
ResultOrError<std::unique_ptr<EntryPointMetadata>> ExtractSpirvInfo(
const DeviceBase* device,
const spirv_cross::Compiler& compiler) {
std::unique_ptr<EntryPointMetadata> metadata = std::make_unique<EntryPointMetadata>();
// TODO(cwallez@chromium.org): make errors here creation errors
// currently errors here do not prevent the shadermodule from being used
const auto& resources = compiler.get_shader_resources();
switch (compiler.get_execution_model()) {
case spv::ExecutionModelVertex:
metadata->stage = SingleShaderStage::Vertex;
break;
case spv::ExecutionModelFragment:
metadata->stage = SingleShaderStage::Fragment;
break;
case spv::ExecutionModelGLCompute:
metadata->stage = SingleShaderStage::Compute;
break;
default:
UNREACHABLE();
return DAWN_VALIDATION_ERROR("Unexpected shader execution model");
}
if (resources.push_constant_buffers.size() > 0) {
return DAWN_VALIDATION_ERROR("Push constants aren't supported.");
}
if (resources.sampled_images.size() > 0) {
return DAWN_VALIDATION_ERROR("Combined images and samplers aren't supported.");
}
// Fill in bindingInfo with the SPIRV bindings
auto ExtractResourcesBinding =
[](const DeviceBase* device,
const spirv_cross::SmallVector<spirv_cross::Resource>& resources,
const spirv_cross::Compiler& compiler, wgpu::BindingType bindingType,
EntryPointMetadata::BindingInfo* metadataBindings) -> MaybeError {
for (const auto& resource : resources) {
if (!compiler.get_decoration_bitset(resource.id).get(spv::DecorationBinding)) {
return DAWN_VALIDATION_ERROR("No Binding decoration set for resource");
}
if (!compiler.get_decoration_bitset(resource.id)
.get(spv::DecorationDescriptorSet)) {
return DAWN_VALIDATION_ERROR("No Descriptor Decoration set for resource");
}
BindingNumber bindingNumber(
compiler.get_decoration(resource.id, spv::DecorationBinding));
BindGroupIndex bindGroupIndex(
compiler.get_decoration(resource.id, spv::DecorationDescriptorSet));
if (bindGroupIndex >= kMaxBindGroupsTyped) {
return DAWN_VALIDATION_ERROR("Bind group index over limits in the SPIRV");
}
const auto& it = (*metadataBindings)[bindGroupIndex].emplace(
bindingNumber, EntryPointMetadata::ShaderBindingInfo{});
if (!it.second) {
return DAWN_VALIDATION_ERROR("Shader has duplicate bindings");
}
EntryPointMetadata::ShaderBindingInfo* info = &it.first->second;
info->id = resource.id;
info->base_type_id = resource.base_type_id;
if (bindingType == wgpu::BindingType::UniformBuffer ||
bindingType == wgpu::BindingType::StorageBuffer ||
bindingType == wgpu::BindingType::ReadonlyStorageBuffer) {
// Determine buffer size, with a minimum of 1 element in the runtime array
spirv_cross::SPIRType type = compiler.get_type(info->base_type_id);
info->minBufferBindingSize =
compiler.get_declared_struct_size_runtime_array(type, 1);
}
switch (bindingType) {
case wgpu::BindingType::SampledTexture: {
spirv_cross::SPIRType::ImageType imageType =
compiler.get_type(info->base_type_id).image;
spirv_cross::SPIRType::BaseType textureComponentType =
compiler.get_type(imageType.type).basetype;
info->multisampled = imageType.ms;
info->viewDimension =
SpirvDimToTextureViewDimension(imageType.dim, imageType.arrayed);
info->textureComponentType =
SpirvCrossBaseTypeToFormatType(textureComponentType);
info->type = bindingType;
break;
}
case wgpu::BindingType::StorageBuffer: {
// Differentiate between readonly storage bindings and writable ones
// based on the NonWritable decoration
spirv_cross::Bitset flags =
compiler.get_buffer_block_flags(resource.id);
if (flags.get(spv::DecorationNonWritable)) {
info->type = wgpu::BindingType::ReadonlyStorageBuffer;
} else {
info->type = wgpu::BindingType::StorageBuffer;
}
break;
}
case wgpu::BindingType::StorageTexture: {
spirv_cross::Bitset flags = compiler.get_decoration_bitset(resource.id);
if (flags.get(spv::DecorationNonReadable)) {
info->type = wgpu::BindingType::WriteonlyStorageTexture;
} else if (flags.get(spv::DecorationNonWritable)) {
info->type = wgpu::BindingType::ReadonlyStorageTexture;
} else {
info->type = wgpu::BindingType::StorageTexture;
}
spirv_cross::SPIRType::ImageType imageType =
compiler.get_type(info->base_type_id).image;
wgpu::TextureFormat storageTextureFormat =
ToWGPUTextureFormat(imageType.format);
if (storageTextureFormat == wgpu::TextureFormat::Undefined) {
return DAWN_VALIDATION_ERROR(
"Invalid image format declaration on storage image");
}
const Format& format =
device->GetValidInternalFormat(storageTextureFormat);
if (!format.supportsStorageUsage) {
return DAWN_VALIDATION_ERROR(
"The storage texture format is not supported");
}
info->multisampled = imageType.ms;
info->storageTextureFormat = storageTextureFormat;
info->viewDimension =
SpirvDimToTextureViewDimension(imageType.dim, imageType.arrayed);
break;
}
default:
info->type = bindingType;
}
}
return {};
};
DAWN_TRY(ExtractResourcesBinding(device, resources.uniform_buffers, compiler,
wgpu::BindingType::UniformBuffer,
&metadata->bindings));
DAWN_TRY(ExtractResourcesBinding(device, resources.separate_images, compiler,
wgpu::BindingType::SampledTexture,
&metadata->bindings));
DAWN_TRY(ExtractResourcesBinding(device, resources.separate_samplers, compiler,
wgpu::BindingType::Sampler, &metadata->bindings));
DAWN_TRY(ExtractResourcesBinding(device, resources.storage_buffers, compiler,
wgpu::BindingType::StorageBuffer,
&metadata->bindings));
DAWN_TRY(ExtractResourcesBinding(device, resources.storage_images, compiler,
wgpu::BindingType::StorageTexture,
&metadata->bindings));
// Extract the vertex attributes
if (metadata->stage == SingleShaderStage::Vertex) {
for (const auto& attrib : resources.stage_inputs) {
if (!(compiler.get_decoration_bitset(attrib.id).get(spv::DecorationLocation))) {
return DAWN_VALIDATION_ERROR(
"Unable to find Location decoration for Vertex input");
}
uint32_t location = compiler.get_decoration(attrib.id, spv::DecorationLocation);
if (location >= kMaxVertexAttributes) {
return DAWN_VALIDATION_ERROR("Attribute location over limits in the SPIRV");
}
metadata->usedVertexAttributes.set(location);
}
// Without a location qualifier on vertex outputs, spirv_cross::CompilerMSL gives
// them all the location 0, causing a compile error.
for (const auto& attrib : resources.stage_outputs) {
if (!compiler.get_decoration_bitset(attrib.id).get(spv::DecorationLocation)) {
return DAWN_VALIDATION_ERROR("Need location qualifier on vertex output");
}
}
}
if (metadata->stage == SingleShaderStage::Fragment) {
// Without a location qualifier on vertex inputs, spirv_cross::CompilerMSL gives
// them all the location 0, causing a compile error.
for (const auto& attrib : resources.stage_inputs) {
if (!compiler.get_decoration_bitset(attrib.id).get(spv::DecorationLocation)) {
return DAWN_VALIDATION_ERROR("Need location qualifier on fragment input");
}
}
for (const auto& fragmentOutput : resources.stage_outputs) {
if (!compiler.get_decoration_bitset(fragmentOutput.id)
.get(spv::DecorationLocation)) {
return DAWN_VALIDATION_ERROR(
"Unable to find Location decoration for Fragment output");
}
uint32_t unsanitizedAttachment =
compiler.get_decoration(fragmentOutput.id, spv::DecorationLocation);
if (unsanitizedAttachment >= kMaxColorAttachments) {
return DAWN_VALIDATION_ERROR(
"Fragment output attachment index must be less than max number of "
"color "
"attachments");
}
ColorAttachmentIndex attachment(static_cast<uint8_t>(unsanitizedAttachment));
spirv_cross::SPIRType::BaseType shaderFragmentOutputBaseType =
compiler.get_type(fragmentOutput.base_type_id).basetype;
Format::Type formatType =
SpirvCrossBaseTypeToFormatType(shaderFragmentOutputBaseType);
if (formatType == Format::Type::Other) {
return DAWN_VALIDATION_ERROR("Unexpected Fragment output type");
}
metadata->fragmentOutputFormatBaseTypes[attachment] = formatType;
}
}
return {std::move(metadata)};
}
} // anonymous namespace
MaybeError ValidateShaderModuleDescriptor(DeviceBase* device,
@ -677,224 +896,6 @@ namespace dawn_native {
return *mMainEntryPoint;
}
MaybeError ShaderModuleBase::ExtractSpirvInfo(const spirv_cross::Compiler& compiler) {
ASSERT(!IsError());
DAWN_TRY_ASSIGN(mMainEntryPoint, ExtractSpirvInfoImpl(compiler));
return {};
}
ResultOrError<std::unique_ptr<EntryPointMetadata>> ShaderModuleBase::ExtractSpirvInfoImpl(
const spirv_cross::Compiler& compiler) {
DeviceBase* device = GetDevice();
std::unique_ptr<EntryPointMetadata> metadata = std::make_unique<EntryPointMetadata>();
// TODO(cwallez@chromium.org): make errors here creation errors
// currently errors here do not prevent the shadermodule from being used
const auto& resources = compiler.get_shader_resources();
switch (compiler.get_execution_model()) {
case spv::ExecutionModelVertex:
metadata->stage = SingleShaderStage::Vertex;
break;
case spv::ExecutionModelFragment:
metadata->stage = SingleShaderStage::Fragment;
break;
case spv::ExecutionModelGLCompute:
metadata->stage = SingleShaderStage::Compute;
break;
default:
UNREACHABLE();
return DAWN_VALIDATION_ERROR("Unexpected shader execution model");
}
if (resources.push_constant_buffers.size() > 0) {
return DAWN_VALIDATION_ERROR("Push constants aren't supported.");
}
if (resources.sampled_images.size() > 0) {
return DAWN_VALIDATION_ERROR("Combined images and samplers aren't supported.");
}
// Fill in bindingInfo with the SPIRV bindings
auto ExtractResourcesBinding =
[](const DeviceBase* device,
const spirv_cross::SmallVector<spirv_cross::Resource>& resources,
const spirv_cross::Compiler& compiler, wgpu::BindingType bindingType,
EntryPointMetadata::BindingInfo* metadataBindings) -> MaybeError {
for (const auto& resource : resources) {
if (!compiler.get_decoration_bitset(resource.id).get(spv::DecorationBinding)) {
return DAWN_VALIDATION_ERROR("No Binding decoration set for resource");
}
if (!compiler.get_decoration_bitset(resource.id)
.get(spv::DecorationDescriptorSet)) {
return DAWN_VALIDATION_ERROR("No Descriptor Decoration set for resource");
}
BindingNumber bindingNumber(
compiler.get_decoration(resource.id, spv::DecorationBinding));
BindGroupIndex bindGroupIndex(
compiler.get_decoration(resource.id, spv::DecorationDescriptorSet));
if (bindGroupIndex >= kMaxBindGroupsTyped) {
return DAWN_VALIDATION_ERROR("Bind group index over limits in the SPIRV");
}
const auto& it = (*metadataBindings)[bindGroupIndex].emplace(
bindingNumber, EntryPointMetadata::ShaderBindingInfo{});
if (!it.second) {
return DAWN_VALIDATION_ERROR("Shader has duplicate bindings");
}
EntryPointMetadata::ShaderBindingInfo* info = &it.first->second;
info->id = resource.id;
info->base_type_id = resource.base_type_id;
if (bindingType == wgpu::BindingType::UniformBuffer ||
bindingType == wgpu::BindingType::StorageBuffer ||
bindingType == wgpu::BindingType::ReadonlyStorageBuffer) {
// Determine buffer size, with a minimum of 1 element in the runtime array
spirv_cross::SPIRType type = compiler.get_type(info->base_type_id);
info->minBufferBindingSize =
compiler.get_declared_struct_size_runtime_array(type, 1);
}
switch (bindingType) {
case wgpu::BindingType::SampledTexture: {
spirv_cross::SPIRType::ImageType imageType =
compiler.get_type(info->base_type_id).image;
spirv_cross::SPIRType::BaseType textureComponentType =
compiler.get_type(imageType.type).basetype;
info->multisampled = imageType.ms;
info->viewDimension =
SpirvDimToTextureViewDimension(imageType.dim, imageType.arrayed);
info->textureComponentType =
SpirvCrossBaseTypeToFormatType(textureComponentType);
info->type = bindingType;
break;
}
case wgpu::BindingType::StorageBuffer: {
// Differentiate between readonly storage bindings and writable ones
// based on the NonWritable decoration
spirv_cross::Bitset flags = compiler.get_buffer_block_flags(resource.id);
if (flags.get(spv::DecorationNonWritable)) {
info->type = wgpu::BindingType::ReadonlyStorageBuffer;
} else {
info->type = wgpu::BindingType::StorageBuffer;
}
break;
}
case wgpu::BindingType::StorageTexture: {
spirv_cross::Bitset flags = compiler.get_decoration_bitset(resource.id);
if (flags.get(spv::DecorationNonReadable)) {
info->type = wgpu::BindingType::WriteonlyStorageTexture;
} else if (flags.get(spv::DecorationNonWritable)) {
info->type = wgpu::BindingType::ReadonlyStorageTexture;
} else {
info->type = wgpu::BindingType::StorageTexture;
}
spirv_cross::SPIRType::ImageType imageType =
compiler.get_type(info->base_type_id).image;
wgpu::TextureFormat storageTextureFormat =
ToWGPUTextureFormat(imageType.format);
if (storageTextureFormat == wgpu::TextureFormat::Undefined) {
return DAWN_VALIDATION_ERROR(
"Invalid image format declaration on storage image");
}
const Format& format = device->GetValidInternalFormat(storageTextureFormat);
if (!format.supportsStorageUsage) {
return DAWN_VALIDATION_ERROR(
"The storage texture format is not supported");
}
info->multisampled = imageType.ms;
info->storageTextureFormat = storageTextureFormat;
info->viewDimension =
SpirvDimToTextureViewDimension(imageType.dim, imageType.arrayed);
break;
}
default:
info->type = bindingType;
}
}
return {};
};
DAWN_TRY(ExtractResourcesBinding(device, resources.uniform_buffers, compiler,
wgpu::BindingType::UniformBuffer, &metadata->bindings));
DAWN_TRY(ExtractResourcesBinding(device, resources.separate_images, compiler,
wgpu::BindingType::SampledTexture, &metadata->bindings));
DAWN_TRY(ExtractResourcesBinding(device, resources.separate_samplers, compiler,
wgpu::BindingType::Sampler, &metadata->bindings));
DAWN_TRY(ExtractResourcesBinding(device, resources.storage_buffers, compiler,
wgpu::BindingType::StorageBuffer, &metadata->bindings));
DAWN_TRY(ExtractResourcesBinding(device, resources.storage_images, compiler,
wgpu::BindingType::StorageTexture, &metadata->bindings));
// Extract the vertex attributes
if (metadata->stage == SingleShaderStage::Vertex) {
for (const auto& attrib : resources.stage_inputs) {
if (!(compiler.get_decoration_bitset(attrib.id).get(spv::DecorationLocation))) {
return DAWN_VALIDATION_ERROR(
"Unable to find Location decoration for Vertex input");
}
uint32_t location = compiler.get_decoration(attrib.id, spv::DecorationLocation);
if (location >= kMaxVertexAttributes) {
return DAWN_VALIDATION_ERROR("Attribute location over limits in the SPIRV");
}
metadata->usedVertexAttributes.set(location);
}
// Without a location qualifier on vertex outputs, spirv_cross::CompilerMSL gives
// them all the location 0, causing a compile error.
for (const auto& attrib : resources.stage_outputs) {
if (!compiler.get_decoration_bitset(attrib.id).get(spv::DecorationLocation)) {
return DAWN_VALIDATION_ERROR("Need location qualifier on vertex output");
}
}
}
if (metadata->stage == SingleShaderStage::Fragment) {
// Without a location qualifier on vertex inputs, spirv_cross::CompilerMSL gives
// them all the location 0, causing a compile error.
for (const auto& attrib : resources.stage_inputs) {
if (!compiler.get_decoration_bitset(attrib.id).get(spv::DecorationLocation)) {
return DAWN_VALIDATION_ERROR("Need location qualifier on fragment input");
}
}
for (const auto& fragmentOutput : resources.stage_outputs) {
if (!compiler.get_decoration_bitset(fragmentOutput.id)
.get(spv::DecorationLocation)) {
return DAWN_VALIDATION_ERROR(
"Unable to find Location decoration for Fragment output");
}
uint32_t unsanitizedAttachment =
compiler.get_decoration(fragmentOutput.id, spv::DecorationLocation);
if (unsanitizedAttachment >= kMaxColorAttachments) {
return DAWN_VALIDATION_ERROR(
"Fragment output attachment index must be less than max number of color "
"attachments");
}
ColorAttachmentIndex attachment(static_cast<uint8_t>(unsanitizedAttachment));
spirv_cross::SPIRType::BaseType shaderFragmentOutputBaseType =
compiler.get_type(fragmentOutput.base_type_id).basetype;
Format::Type formatType =
SpirvCrossBaseTypeToFormatType(shaderFragmentOutputBaseType);
if (formatType == Format::Type::Other) {
return DAWN_VALIDATION_ERROR("Unexpected Fragment output type");
}
metadata->fragmentOutputFormatBaseTypes[attachment] = formatType;
}
}
return {std::move(metadata)};
}
size_t ShaderModuleBase::HashFunc::operator()(const ShaderModuleBase* module) const {
size_t hash = 0;
@ -933,6 +934,9 @@ namespace dawn_native {
#endif // DAWN_ENABLE_WGSL
}
spirv_cross::Compiler compiler(mSpirv);
DAWN_TRY_ASSIGN(mMainEntryPoint, ExtractSpirvInfo(GetDevice(), compiler));
return {};
}

View File

@ -102,9 +102,6 @@ namespace dawn_native {
const EntryPointMetadata& GetEntryPoint(const std::string& entryPoint,
SingleShaderStage stage) const;
// TODO make this member protected, it is only used outside of child classes in DeviceNull.
MaybeError ExtractSpirvInfo(const spirv_cross::Compiler& compiler);
// Functors necessary for the unordered_set<ShaderModuleBase*>-based cache.
struct HashFunc {
size_t operator()(const ShaderModuleBase* module) const;
@ -133,9 +130,6 @@ namespace dawn_native {
private:
ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag);
ResultOrError<std::unique_ptr<EntryPointMetadata>> ExtractSpirvInfoImpl(
const spirv_cross::Compiler& compiler);
enum class Type { Undefined, Spirv, Wgsl };
Type mType;
std::vector<uint32_t> mSpirv;

View File

@ -97,13 +97,7 @@ namespace dawn_native { namespace d3d12 {
}
MaybeError ShaderModule::Initialize() {
DAWN_TRY(InitializeBase());
const std::vector<uint32_t>& spirv = GetSpirv();
spirv_cross::CompilerHLSL compiler(spirv);
DAWN_TRY(ExtractSpirvInfo(compiler));
return {};
return InitializeBase();
}
ResultOrError<std::string> ShaderModule::GetHLSLSource(PipelineLayout* layout) {

View File

@ -54,13 +54,7 @@ namespace dawn_native { namespace metal {
}
MaybeError ShaderModule::Initialize() {
DAWN_TRY(InitializeBase());
const std::vector<uint32_t>& spirv = GetSpirv();
spirv_cross::CompilerMSL compiler(spirv);
DAWN_TRY(ExtractSpirvInfo(compiler));
return {};
return InitializeBase();
}
MaybeError ShaderModule::GetFunction(const char* functionName,

View File

@ -129,8 +129,7 @@ namespace dawn_native { namespace null {
ResultOrError<ShaderModuleBase*> Device::CreateShaderModuleImpl(
const ShaderModuleDescriptor* descriptor) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(this, descriptor));
spirv_cross::Compiler compiler(module->GetSpirv());
DAWN_TRY(module->ExtractSpirvInfo(compiler));
DAWN_TRY(module->Initialize());
return module.Detach();
}
ResultOrError<SwapChainBase*> Device::CreateSwapChainImpl(
@ -386,6 +385,12 @@ namespace dawn_native { namespace null {
}
}
// ShaderModule
MaybeError ShaderModule::Initialize() {
return InitializeBase();
}
// OldSwapChain
OldSwapChain::OldSwapChain(Device* device, const SwapChainDescriptor* descriptor)

View File

@ -50,7 +50,7 @@ namespace dawn_native { namespace null {
class Queue;
using RenderPipeline = RenderPipelineBase;
using Sampler = SamplerBase;
using ShaderModule = ShaderModuleBase;
class ShaderModule;
class SwapChain;
using Texture = TextureBase;
using TextureView = TextureViewBase;
@ -217,7 +217,6 @@ namespace dawn_native { namespace null {
class CommandBuffer final : public CommandBufferBase {
public:
CommandBuffer(CommandEncoder* encoder, const CommandBufferDescriptor* descriptor);
};
class QuerySet final : public QuerySetBase {
@ -243,6 +242,13 @@ namespace dawn_native { namespace null {
size_t size) override;
};
class ShaderModule final : public ShaderModuleBase {
public:
using ShaderModuleBase::ShaderModuleBase;
MaybeError Initialize();
};
class SwapChain final : public NewSwapChainBase {
public:
SwapChain(Device* device,

View File

@ -95,8 +95,6 @@ namespace dawn_native { namespace opengl {
spirv_cross::CompilerGLSL compiler(spirv);
compiler.set_common_options(options);
DAWN_TRY(ExtractSpirvInfo(compiler));
// Extract bindings names so that it can be used to get its location in program.
// Now translate the separate sampler / textures into combined ones and store their info.
// We need to do this before removing the set and binding decorations.

View File

@ -40,11 +40,6 @@ namespace dawn_native { namespace vulkan {
DAWN_TRY(InitializeBase());
const std::vector<uint32_t>& spirv = GetSpirv();
// Use SPIRV-Cross to extract info from the SPIRV even if Vulkan consumes SPIRV. We want to
// have a translation step eventually anyway.
spirv_cross::Compiler compiler(spirv);
DAWN_TRY(ExtractSpirvInfo(compiler));
VkShaderModuleCreateInfo createInfo;
createInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
createInfo.pNext = nullptr;