From 4f8bdaf473fb7c95ab0c4f6da72f93601a00e019 Mon Sep 17 00:00:00 2001 From: Corentin Wallez Date: Wed, 26 Aug 2020 09:57:52 +0000 Subject: [PATCH] Make ShaderModuleBase use an internal EntryPointMetadata WGSL and SPIR-V modules can contain multiple entrypoints, for different shader stages, that the pipelines can choose from. This is the first CL in a stack that will change Dawn internals to not rely on ShaderModules having a single entrypoint. EntryPointMetadata is introduced that will contain all reflection data for an entrypoint of a shader module. To ease review this CL doesn't introduce any functional changes and doesn't expose the EntryPointMetadata at the ShaderModuleBase interface. Instead ShaderModuleBase contains a single metadata object for its single entry point, and layout-related queries and proxied to the EntryPointMetadata object. Finally some small renames and formatting changes are done. Bug: dawn:216 Change-Id: I0f4d12a5075ba14c5e8fd666be4073d34288f6f9 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/27240 Commit-Queue: Corentin Wallez Reviewed-by: Austin Eng --- src/dawn_native/ShaderModule.cpp | 787 ++++++++++++++++--------------- src/dawn_native/ShaderModule.h | 30 +- 2 files changed, 422 insertions(+), 395 deletions(-) diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp index 6057e70d6e..03ddf035e5 100644 --- a/src/dawn_native/ShaderModule.cpp +++ b/src/dawn_native/ShaderModule.cpp @@ -371,182 +371,330 @@ namespace dawn_native { } } #endif - } // anonymous namespace - MaybeError ValidateSpirv(DeviceBase*, const uint32_t* code, uint32_t codeSize) { - spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1); + MaybeError ValidateSpirv(DeviceBase*, const uint32_t* code, uint32_t codeSize) { + spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1); - std::ostringstream errorStream; - errorStream << "SPIRV Validation failure:" << std::endl; + std::ostringstream errorStream; + errorStream << "SPIRV Validation failure:" << std::endl; - spirvTools.SetMessageConsumer([&errorStream](spv_message_level_t level, const char*, - const spv_position_t& position, - const char* message) { - switch (level) { - case SPV_MSG_FATAL: - case SPV_MSG_INTERNAL_ERROR: - case SPV_MSG_ERROR: - errorStream << "error: line " << position.index << ": " << message << std::endl; - break; - case SPV_MSG_WARNING: - errorStream << "warning: line " << position.index << ": " << message - << std::endl; - break; - case SPV_MSG_INFO: - errorStream << "info: line " << position.index << ": " << message << std::endl; - break; - default: - break; + spirvTools.SetMessageConsumer([&errorStream](spv_message_level_t level, const char*, + const spv_position_t& position, + const char* message) { + switch (level) { + case SPV_MSG_FATAL: + case SPV_MSG_INTERNAL_ERROR: + case SPV_MSG_ERROR: + errorStream << "error: line " << position.index << ": " << message + << std::endl; + break; + case SPV_MSG_WARNING: + errorStream << "warning: line " << position.index << ": " << message + << std::endl; + break; + case SPV_MSG_INFO: + errorStream << "info: line " << position.index << ": " << message + << std::endl; + break; + default: + break; + } + }); + + if (!spirvTools.Validate(code, codeSize)) { + return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); } - }); - if (!spirvTools.Validate(code, codeSize)) { - return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); + return {}; } - return {}; - } - #ifdef DAWN_ENABLE_WGSL - MaybeError ValidateWGSL(const char* source) { - std::ostringstream errorStream; - errorStream << "Tint WGSL failure:" << std::endl; + MaybeError ValidateWGSL(const char* source) { + std::ostringstream errorStream; + errorStream << "Tint WGSL failure:" << std::endl; - tint::Context context; - tint::reader::wgsl::Parser parser(&context, source); + tint::Context context; + tint::reader::wgsl::Parser parser(&context, source); - if (!parser.Parse()) { - errorStream << "Parser: " << parser.error() << std::endl; - return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); - } - - tint::ast::Module module = parser.module(); - if (!module.IsValid()) { - errorStream << "Invalid module generated..." << std::endl; - return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); - } - - tint::TypeDeterminer type_determiner(&context, &module); - if (!type_determiner.Determine()) { - errorStream << "Type Determination: " << type_determiner.error(); - return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); - } - - tint::Validator validator; - if (!validator.Validate(&module)) { - errorStream << "Validation: " << validator.error() << std::endl; - return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); - } - - return {}; - } - - ResultOrError> ConvertWGSLToSPIRV(const char* source) { - std::ostringstream errorStream; - errorStream << "Tint WGSL->SPIR-V failure:" << std::endl; - - tint::Context context; - tint::reader::wgsl::Parser parser(&context, source); - - // TODO: This is a duplicate parse with ValidateWGSL, need to store - // state between calls to avoid this. - if (!parser.Parse()) { - errorStream << "Parser: " << parser.error() << std::endl; - return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); - } - - tint::ast::Module module = parser.module(); - if (!module.IsValid()) { - errorStream << "Invalid module generated..." << std::endl; - return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); - } - - tint::TypeDeterminer type_determiner(&context, &module); - if (!type_determiner.Determine()) { - errorStream << "Type Determination: " << type_determiner.error(); - return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); - } - - tint::writer::spirv::Generator generator(std::move(module)); - if (!generator.Generate()) { - errorStream << "Generator: " << generator.error() << std::endl; - return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); - } - - std::vector spirv = generator.result(); - return std::move(spirv); - } - - ResultOrError> ConvertWGSLToSPIRVWithPulling( - const char* source, - const VertexStateDescriptor& vertexState, - const std::string& entryPoint, - uint32_t pullingBufferBindingSet) { - std::ostringstream errorStream; - errorStream << "Tint WGSL->SPIR-V failure:" << std::endl; - - tint::Context context; - tint::reader::wgsl::Parser parser(&context, source); - - // TODO: This is a duplicate parse with ValidateWGSL, need to store - // state between calls to avoid this. - if (!parser.Parse()) { - errorStream << "Parser: " << parser.error() << std::endl; - return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); - } - - tint::ast::Module module = parser.module(); - if (!module.IsValid()) { - errorStream << "Invalid module generated..." << std::endl; - return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); - } - - tint::ast::transform::VertexPullingTransform transform(&context, &module); - auto state = std::make_unique(); - for (uint32_t i = 0; i < vertexState.vertexBufferCount; ++i) { - auto& vertexBuffer = vertexState.vertexBuffers[i]; - tint::ast::transform::VertexBufferLayoutDescriptor layout; - layout.array_stride = vertexBuffer.arrayStride; - layout.step_mode = ToTintInputStepMode(vertexBuffer.stepMode); - - for (uint32_t j = 0; j < vertexBuffer.attributeCount; ++j) { - auto& attribute = vertexBuffer.attributes[j]; - tint::ast::transform::VertexAttributeDescriptor attr; - attr.format = ToTintVertexFormat(attribute.format); - attr.offset = attribute.offset; - attr.shader_location = attribute.shaderLocation; - - layout.attributes.push_back(std::move(attr)); + if (!parser.Parse()) { + errorStream << "Parser: " << parser.error() << std::endl; + return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); } - state->vertex_buffers.push_back(std::move(layout)); - } - transform.SetVertexState(std::move(state)); - transform.SetEntryPoint(entryPoint); - transform.SetPullingBufferBindingSet(pullingBufferBindingSet); + tint::ast::Module module = parser.module(); + if (!module.IsValid()) { + errorStream << "Invalid module generated..." << std::endl; + return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); + } - if (!transform.Run()) { - errorStream << "Vertex pulling transform: " << transform.GetError(); - return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); + tint::TypeDeterminer type_determiner(&context, &module); + if (!type_determiner.Determine()) { + errorStream << "Type Determination: " << type_determiner.error(); + return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); + } + + tint::Validator validator; + if (!validator.Validate(&module)) { + errorStream << "Validation: " << validator.error() << std::endl; + return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); + } + + return {}; } - tint::TypeDeterminer type_determiner(&context, &module); - if (!type_determiner.Determine()) { - errorStream << "Type Determination: " << type_determiner.error(); - return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); + ResultOrError> ConvertWGSLToSPIRV(const char* source) { + std::ostringstream errorStream; + errorStream << "Tint WGSL->SPIR-V failure:" << std::endl; + + tint::Context context; + tint::reader::wgsl::Parser parser(&context, source); + + // TODO: This is a duplicate parse with ValidateWGSL, need to store + // state between calls to avoid this. + if (!parser.Parse()) { + errorStream << "Parser: " << parser.error() << std::endl; + return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); + } + + tint::ast::Module module = parser.module(); + if (!module.IsValid()) { + errorStream << "Invalid module generated..." << std::endl; + return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); + } + + tint::TypeDeterminer type_determiner(&context, &module); + if (!type_determiner.Determine()) { + errorStream << "Type Determination: " << type_determiner.error(); + return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); + } + + tint::writer::spirv::Generator generator(std::move(module)); + if (!generator.Generate()) { + errorStream << "Generator: " << generator.error() << std::endl; + return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); + } + + std::vector spirv = generator.result(); + return std::move(spirv); } - tint::writer::spirv::Generator generator(std::move(module)); - if (!generator.Generate()) { - errorStream << "Generator: " << generator.error() << std::endl; - return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); - } + ResultOrError> ConvertWGSLToSPIRVWithPulling( + const char* source, + const VertexStateDescriptor& vertexState, + const std::string& entryPoint, + uint32_t pullingBufferBindingSet) { + std::ostringstream errorStream; + errorStream << "Tint WGSL->SPIR-V failure:" << std::endl; - std::vector spirv = generator.result(); - return std::move(spirv); - } + tint::Context context; + tint::reader::wgsl::Parser parser(&context, source); + + // TODO: This is a duplicate parse with ValidateWGSL, need to store + // state between calls to avoid this. + if (!parser.Parse()) { + errorStream << "Parser: " << parser.error() << std::endl; + return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); + } + + tint::ast::Module module = parser.module(); + if (!module.IsValid()) { + errorStream << "Invalid module generated..." << std::endl; + return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); + } + + tint::ast::transform::VertexPullingTransform transform(&context, &module); + auto state = std::make_unique(); + for (uint32_t i = 0; i < vertexState.vertexBufferCount; ++i) { + auto& vertexBuffer = vertexState.vertexBuffers[i]; + tint::ast::transform::VertexBufferLayoutDescriptor layout; + layout.array_stride = vertexBuffer.arrayStride; + layout.step_mode = ToTintInputStepMode(vertexBuffer.stepMode); + + for (uint32_t j = 0; j < vertexBuffer.attributeCount; ++j) { + auto& attribute = vertexBuffer.attributes[j]; + tint::ast::transform::VertexAttributeDescriptor attr; + attr.format = ToTintVertexFormat(attribute.format); + attr.offset = attribute.offset; + attr.shader_location = attribute.shaderLocation; + + layout.attributes.push_back(std::move(attr)); + } + + state->vertex_buffers.push_back(std::move(layout)); + } + transform.SetVertexState(std::move(state)); + transform.SetEntryPoint(entryPoint); + transform.SetPullingBufferBindingSet(pullingBufferBindingSet); + + if (!transform.Run()) { + errorStream << "Vertex pulling transform: " << transform.GetError(); + return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); + } + + tint::TypeDeterminer type_determiner(&context, &module); + if (!type_determiner.Determine()) { + errorStream << "Type Determination: " << type_determiner.error(); + return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); + } + + tint::writer::spirv::Generator generator(std::move(module)); + if (!generator.Generate()) { + errorStream << "Generator: " << generator.error() << std::endl; + return DAWN_VALIDATION_ERROR(errorStream.str().c_str()); + } + + std::vector spirv = generator.result(); + return std::move(spirv); + } #endif // DAWN_ENABLE_WGSL + std::vector GetBindGroupMinBufferSizes( + const ShaderModuleBase::BindingInfoMap& shaderBindings, + const BindGroupLayoutBase* layout) { + std::vector requiredBufferSizes(layout->GetUnverifiedBufferCount()); + uint32_t packedIdx = 0; + + for (BindingIndex bindingIndex{0}; bindingIndex < layout->GetBufferCount(); + ++bindingIndex) { + const BindingInfo& bindingInfo = layout->GetBindingInfo(bindingIndex); + if (bindingInfo.minBufferBindingSize != 0) { + // Skip bindings that have minimum buffer size set in the layout + continue; + } + + ASSERT(packedIdx < requiredBufferSizes.size()); + const auto& shaderInfo = shaderBindings.find(bindingInfo.binding); + if (shaderInfo != shaderBindings.end()) { + requiredBufferSizes[packedIdx] = shaderInfo->second.minBufferBindingSize; + } else { + // We have to include buffers if they are included in the bind group's + // packed vector. We don't actually need to check these at draw time, so + // if this is a problem in the future we can optimize it further. + requiredBufferSizes[packedIdx] = 0; + } + ++packedIdx; + } + + return requiredBufferSizes; + } + + MaybeError ValidateCompatibilityWithBindGroupLayout( + BindGroupIndex group, + const ShaderModuleBase::EntryPointMetadata& entryPoint, + const BindGroupLayoutBase* layout) { + const BindGroupLayoutBase::BindingMap& layoutBindings = layout->GetBindingMap(); + + // Iterate over all bindings used by this group in the shader, and find the + // corresponding binding in the BindGroupLayout, if it exists. + for (const auto& it : entryPoint.bindings[group]) { + BindingNumber bindingNumber = it.first; + const ShaderModuleBase::ShaderBindingInfo& shaderInfo = it.second; + + const auto& bindingIt = layoutBindings.find(bindingNumber); + if (bindingIt == layoutBindings.end()) { + return DAWN_VALIDATION_ERROR("Missing bind group layout entry for " + + GetShaderDeclarationString(group, bindingNumber)); + } + BindingIndex bindingIndex(bindingIt->second); + const BindingInfo& layoutInfo = layout->GetBindingInfo(bindingIndex); + + if (layoutInfo.type != shaderInfo.type) { + // Binding mismatch between shader and bind group is invalid. For example, a + // writable binding in the shader with a readonly storage buffer in the bind + // group layout is invalid. However, a readonly binding in the shader with a + // writable storage buffer in the bind group layout is valid. + bool validBindingConversion = + layoutInfo.type == wgpu::BindingType::StorageBuffer && + shaderInfo.type == wgpu::BindingType::ReadonlyStorageBuffer; + + // TODO(crbug.com/dawn/367): Temporarily allow using either a sampler or a + // comparison sampler until we can perform the proper shader analysis of what + // type is used in the shader module. + validBindingConversion |= + (layoutInfo.type == wgpu::BindingType::Sampler && + shaderInfo.type == wgpu::BindingType::ComparisonSampler); + validBindingConversion |= + (layoutInfo.type == wgpu::BindingType::ComparisonSampler && + shaderInfo.type == wgpu::BindingType::Sampler); + + if (!validBindingConversion) { + return DAWN_VALIDATION_ERROR( + "The binding type of the bind group layout entry conflicts " + + GetShaderDeclarationString(group, bindingNumber)); + } + } + + if ((layoutInfo.visibility & StageBit(entryPoint.stage)) == 0) { + return DAWN_VALIDATION_ERROR("The bind group layout entry for " + + GetShaderDeclarationString(group, bindingNumber) + + " is not visible for the shader stage"); + } + + switch (layoutInfo.type) { + case wgpu::BindingType::SampledTexture: { + if (layoutInfo.textureComponentType != shaderInfo.textureComponentType) { + return DAWN_VALIDATION_ERROR( + "The textureComponentType of the bind group layout entry is " + "different from " + + GetShaderDeclarationString(group, bindingNumber)); + } + + if (layoutInfo.viewDimension != shaderInfo.viewDimension) { + return DAWN_VALIDATION_ERROR( + "The viewDimension of the bind group layout entry is different " + "from " + + GetShaderDeclarationString(group, bindingNumber)); + } + break; + } + + case wgpu::BindingType::ReadonlyStorageTexture: + case wgpu::BindingType::WriteonlyStorageTexture: { + ASSERT(layoutInfo.storageTextureFormat != wgpu::TextureFormat::Undefined); + ASSERT(shaderInfo.storageTextureFormat != wgpu::TextureFormat::Undefined); + if (layoutInfo.storageTextureFormat != shaderInfo.storageTextureFormat) { + return DAWN_VALIDATION_ERROR( + "The storageTextureFormat of the bind group layout entry is " + "different from " + + GetShaderDeclarationString(group, bindingNumber)); + } + if (layoutInfo.viewDimension != shaderInfo.viewDimension) { + return DAWN_VALIDATION_ERROR( + "The viewDimension of the bind group layout entry is different " + "from " + + GetShaderDeclarationString(group, bindingNumber)); + } + break; + } + + case wgpu::BindingType::UniformBuffer: + case wgpu::BindingType::ReadonlyStorageBuffer: + case wgpu::BindingType::StorageBuffer: { + if (layoutInfo.minBufferBindingSize != 0 && + shaderInfo.minBufferBindingSize > layoutInfo.minBufferBindingSize) { + return DAWN_VALIDATION_ERROR( + "The minimum buffer size of the bind group layout entry is smaller " + "than " + + GetShaderDeclarationString(group, bindingNumber)); + } + break; + } + case wgpu::BindingType::Sampler: + case wgpu::BindingType::ComparisonSampler: + break; + + case wgpu::BindingType::StorageTexture: + default: + UNREACHABLE(); + return DAWN_VALIDATION_ERROR("Unsupported binding type"); + } + } + + return {}; + } + + } // anonymous namespace + MaybeError ValidateShaderModuleDescriptor(DeviceBase* device, const ShaderModuleDescriptor* descriptor) { const ChainedStruct* chainedDescriptor = descriptor->nextInChain; @@ -582,7 +730,45 @@ namespace dawn_native { } return {}; - } // namespace + } + + RequiredBufferSizes ComputeRequiredBufferSizesForLayout( + const ShaderModuleBase::EntryPointMetadata& entryPoint, + const PipelineLayoutBase* layout) { + RequiredBufferSizes bufferSizes; + for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { + bufferSizes[group] = GetBindGroupMinBufferSizes(entryPoint.bindings[group], + layout->GetBindGroupLayout(group)); + } + + return bufferSizes; + } + + MaybeError ValidateCompatibilityWithPipelineLayout( + const ShaderModuleBase::EntryPointMetadata& entryPoint, + const PipelineLayoutBase* layout) { + for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { + DAWN_TRY(ValidateCompatibilityWithBindGroupLayout(group, entryPoint, + layout->GetBindGroupLayout(group))); + } + + for (BindGroupIndex group : IterateBitSet(~layout->GetBindGroupLayoutsMask())) { + if (entryPoint.bindings[group].size() > 0) { + std::ostringstream ostream; + ostream << "No bind group layout entry matches the declaration set " + << static_cast(group) << " in the shader module"; + return DAWN_VALIDATION_ERROR(ostream.str()); + } + } + + return {}; + } + + // EntryPointMetadata + + ShaderModuleBase::EntryPointMetadata::EntryPointMetadata() { + fragmentOutputFormatBaseTypes.fill(Format::Type::Other); + } // ShaderModuleBase @@ -608,7 +794,6 @@ namespace dawn_native { UNREACHABLE(); } - mFragmentOutputFormatBaseTypes.fill(Format::Type::Other); if (GetDevice()->IsToggleEnabled(Toggle::UseSpvcParser)) { mSpvcContext.SetUseSpvcParser(true); } @@ -632,18 +817,22 @@ namespace dawn_native { MaybeError ShaderModuleBase::ExtractSpirvInfo(const spirv_cross::Compiler& compiler) { ASSERT(!IsError()); if (GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) { - DAWN_TRY(ExtractSpirvInfoWithSpvc()); + DAWN_TRY_ASSIGN(mMainEntryPoint, ExtractSpirvInfoWithSpvc()); } else { - DAWN_TRY(ExtractSpirvInfoWithSpirvCross(compiler)); + DAWN_TRY_ASSIGN(mMainEntryPoint, ExtractSpirvInfoWithSpirvCross(compiler)); } return {}; } - MaybeError ShaderModuleBase::ExtractSpirvInfoWithSpvc() { + ResultOrError> + ShaderModuleBase::ExtractSpirvInfoWithSpvc() { + DeviceBase* device = GetDevice(); + std::unique_ptr metadata = std::make_unique(); + shaderc_spvc_execution_model execution_model; DAWN_TRY(CheckSpvcSuccess(mSpvcContext.GetExecutionModel(&execution_model), "Unable to get execution model for shader.")); - mExecutionModel = ToSingleShaderStage(execution_model); + metadata->stage = ToSingleShaderStage(execution_model); size_t push_constant_buffers_count; DAWN_TRY( @@ -658,15 +847,16 @@ namespace dawn_native { // Fill in bindingInfo with the SPIRV bindings auto ExtractResourcesBinding = - [this](std::vector bindings) -> MaybeError { - for (const auto& binding : bindings) { + [](const DeviceBase* device, const std::vector& spvcBindings, + ModuleBindingInfo* metadataBindings) -> MaybeError { + for (const shaderc_spvc_binding_info& binding : spvcBindings) { BindGroupIndex bindGroupIndex(binding.set); if (bindGroupIndex >= kMaxBindGroupsTyped) { return DAWN_VALIDATION_ERROR("Bind group index over limits in the SPIRV"); } - const auto& it = mBindingInfo[bindGroupIndex].emplace( + const auto& it = (*metadataBindings)[bindGroupIndex].emplace( BindingNumber(binding.binding), ShaderBindingInfo{}); if (!it.second) { return DAWN_VALIDATION_ERROR("Shader has duplicate bindings"); @@ -694,8 +884,7 @@ namespace dawn_native { return DAWN_VALIDATION_ERROR( "Invalid image format declaration on storage image"); } - const Format& format = - GetDevice()->GetValidInternalFormat(storageTextureFormat); + const Format& format = device->GetValidInternalFormat(storageTextureFormat); if (!format.supportsStorageUsage) { return DAWN_VALIDATION_ERROR( "The storage texture format is not supported"); @@ -722,45 +911,45 @@ namespace dawn_native { shaderc_spvc_shader_resource_uniform_buffers, shaderc_spvc_binding_type_uniform_buffer, &resource_bindings), "Unable to get binding info for uniform buffers from shader")); - DAWN_TRY(ExtractResourcesBinding(resource_bindings)); + DAWN_TRY(ExtractResourcesBinding(device, resource_bindings, &metadata->bindings)); DAWN_TRY(CheckSpvcSuccess( mSpvcContext.GetBindingInfo(shaderc_spvc_shader_resource_separate_images, shaderc_spvc_binding_type_sampled_texture, &resource_bindings), "Unable to get binding info for sampled textures from shader")); - DAWN_TRY(ExtractResourcesBinding(resource_bindings)); + DAWN_TRY(ExtractResourcesBinding(device, resource_bindings, &metadata->bindings)); DAWN_TRY(CheckSpvcSuccess( mSpvcContext.GetBindingInfo(shaderc_spvc_shader_resource_separate_samplers, shaderc_spvc_binding_type_sampler, &resource_bindings), "Unable to get binding info for samples from shader")); - DAWN_TRY(ExtractResourcesBinding(resource_bindings)); + DAWN_TRY(ExtractResourcesBinding(device, resource_bindings, &metadata->bindings)); DAWN_TRY(CheckSpvcSuccess(mSpvcContext.GetBindingInfo( shaderc_spvc_shader_resource_storage_buffers, shaderc_spvc_binding_type_storage_buffer, &resource_bindings), "Unable to get binding info for storage buffers from shader")); - DAWN_TRY(ExtractResourcesBinding(resource_bindings)); + DAWN_TRY(ExtractResourcesBinding(device, resource_bindings, &metadata->bindings)); DAWN_TRY(CheckSpvcSuccess( mSpvcContext.GetBindingInfo(shaderc_spvc_shader_resource_storage_images, shaderc_spvc_binding_type_storage_texture, &resource_bindings), "Unable to get binding info for storage textures from shader")); - DAWN_TRY(ExtractResourcesBinding(resource_bindings)); + DAWN_TRY(ExtractResourcesBinding(device, resource_bindings, &metadata->bindings)); std::vector input_stage_locations; DAWN_TRY(CheckSpvcSuccess(mSpvcContext.GetInputStageLocationInfo(&input_stage_locations), "Unable to get input stage location information from shader")); for (const auto& input : input_stage_locations) { - if (mExecutionModel == SingleShaderStage::Vertex) { + if (metadata->stage == SingleShaderStage::Vertex) { if (input.location >= kMaxVertexAttributes) { return DAWN_VALIDATION_ERROR("Attribute location over limits in the SPIRV"); } - mUsedVertexAttributes.set(input.location); - } else if (mExecutionModel == SingleShaderStage::Fragment) { + metadata->usedVertexAttributes.set(input.location); + } else if (metadata->stage == SingleShaderStage::Fragment) { // Without a location qualifier on vertex inputs, spirv_cross::CompilerMSL gives // them all the location 0, causing a compile error. if (!input.has_location) { @@ -774,13 +963,13 @@ namespace dawn_native { "Unable to get output stage location information from shader")); for (const auto& output : output_stage_locations) { - if (mExecutionModel == SingleShaderStage::Vertex) { + if (metadata->stage == SingleShaderStage::Vertex) { // Without a location qualifier on vertex outputs, spirv_cross::CompilerMSL // gives them all the location 0, causing a compile error. if (!output.has_location) { return DAWN_VALIDATION_ERROR("Need location qualifier on vertex output"); } - } else if (mExecutionModel == SingleShaderStage::Fragment) { + } else if (metadata->stage == SingleShaderStage::Fragment) { if (output.location >= kMaxColorAttachments) { return DAWN_VALIDATION_ERROR( "Fragment output location over limits in the SPIRV"); @@ -788,7 +977,7 @@ namespace dawn_native { } } - if (mExecutionModel == SingleShaderStage::Fragment) { + if (metadata->stage == SingleShaderStage::Fragment) { std::vector output_types; DAWN_TRY(CheckSpvcSuccess(mSpvcContext.GetOutputStageTypeInfo(&output_types), "Unable to get output stage type information from shader")); @@ -797,27 +986,32 @@ namespace dawn_native { if (output.type == shaderc_spvc_texture_format_type_other) { return DAWN_VALIDATION_ERROR("Unexpected Fragment output type"); } - mFragmentOutputFormatBaseTypes[output.location] = ToDawnFormatType(output.type); + metadata->fragmentOutputFormatBaseTypes[output.location] = + ToDawnFormatType(output.type); } } - return {}; + + return {std::move(metadata)}; } - MaybeError ShaderModuleBase::ExtractSpirvInfoWithSpirvCross( - const spirv_cross::Compiler& compiler) { + ResultOrError> + ShaderModuleBase::ExtractSpirvInfoWithSpirvCross(const spirv_cross::Compiler& compiler) { + DeviceBase* device = GetDevice(); + std::unique_ptr metadata = std::make_unique(); + // TODO(cwallez@chromium.org): make errors here creation errors // currently errors here do not prevent the shadermodule from being used const auto& resources = compiler.get_shader_resources(); switch (compiler.get_execution_model()) { case spv::ExecutionModelVertex: - mExecutionModel = SingleShaderStage::Vertex; + metadata->stage = SingleShaderStage::Vertex; break; case spv::ExecutionModelFragment: - mExecutionModel = SingleShaderStage::Fragment; + metadata->stage = SingleShaderStage::Fragment; break; case spv::ExecutionModelGLCompute: - mExecutionModel = SingleShaderStage::Compute; + metadata->stage = SingleShaderStage::Compute; break; default: UNREACHABLE(); @@ -834,9 +1028,10 @@ namespace dawn_native { // Fill in bindingInfo with the SPIRV bindings auto ExtractResourcesBinding = - [this](const spirv_cross::SmallVector& resources, - const spirv_cross::Compiler& compiler, - wgpu::BindingType bindingType) -> MaybeError { + [](const DeviceBase* device, + const spirv_cross::SmallVector& resources, + const spirv_cross::Compiler& compiler, wgpu::BindingType bindingType, + ModuleBindingInfo* metadataBindings) -> MaybeError { for (const auto& resource : resources) { if (!compiler.get_decoration_bitset(resource.id).get(spv::DecorationBinding)) { return DAWN_VALIDATION_ERROR("No Binding decoration set for resource"); @@ -857,7 +1052,7 @@ namespace dawn_native { } const auto& it = - mBindingInfo[bindGroupIndex].emplace(bindingNumber, ShaderBindingInfo{}); + (*metadataBindings)[bindGroupIndex].emplace(bindingNumber, ShaderBindingInfo{}); if (!it.second) { return DAWN_VALIDATION_ERROR("Shader has duplicate bindings"); } @@ -919,8 +1114,7 @@ namespace dawn_native { return DAWN_VALIDATION_ERROR( "Invalid image format declaration on storage image"); } - const Format& format = - GetDevice()->GetValidInternalFormat(storageTextureFormat); + const Format& format = device->GetValidInternalFormat(storageTextureFormat); if (!format.supportsStorageUsage) { return DAWN_VALIDATION_ERROR( "The storage texture format is not supported"); @@ -938,19 +1132,19 @@ namespace dawn_native { return {}; }; - DAWN_TRY(ExtractResourcesBinding(resources.uniform_buffers, compiler, - wgpu::BindingType::UniformBuffer)); - DAWN_TRY(ExtractResourcesBinding(resources.separate_images, compiler, - wgpu::BindingType::SampledTexture)); - DAWN_TRY(ExtractResourcesBinding(resources.separate_samplers, compiler, - wgpu::BindingType::Sampler)); - DAWN_TRY(ExtractResourcesBinding(resources.storage_buffers, compiler, - wgpu::BindingType::StorageBuffer)); - DAWN_TRY(ExtractResourcesBinding(resources.storage_images, compiler, - wgpu::BindingType::StorageTexture)); + DAWN_TRY(ExtractResourcesBinding(device, resources.uniform_buffers, compiler, + wgpu::BindingType::UniformBuffer, &metadata->bindings)); + DAWN_TRY(ExtractResourcesBinding(device, resources.separate_images, compiler, + wgpu::BindingType::SampledTexture, &metadata->bindings)); + DAWN_TRY(ExtractResourcesBinding(device, resources.separate_samplers, compiler, + wgpu::BindingType::Sampler, &metadata->bindings)); + DAWN_TRY(ExtractResourcesBinding(device, resources.storage_buffers, compiler, + wgpu::BindingType::StorageBuffer, &metadata->bindings)); + DAWN_TRY(ExtractResourcesBinding(device, resources.storage_images, compiler, + wgpu::BindingType::StorageTexture, &metadata->bindings)); // Extract the vertex attributes - if (mExecutionModel == SingleShaderStage::Vertex) { + if (metadata->stage == SingleShaderStage::Vertex) { for (const auto& attrib : resources.stage_inputs) { if (!(compiler.get_decoration_bitset(attrib.id).get(spv::DecorationLocation))) { return DAWN_VALIDATION_ERROR( @@ -962,7 +1156,7 @@ namespace dawn_native { return DAWN_VALIDATION_ERROR("Attribute location over limits in the SPIRV"); } - mUsedVertexAttributes.set(location); + metadata->usedVertexAttributes.set(location); } // Without a location qualifier on vertex outputs, spirv_cross::CompilerMSL gives @@ -974,7 +1168,7 @@ namespace dawn_native { } } - if (mExecutionModel == SingleShaderStage::Fragment) { + if (metadata->stage == SingleShaderStage::Fragment) { // Without a location qualifier on vertex inputs, spirv_cross::CompilerMSL gives // them all the location 0, causing a compile error. for (const auto& attrib : resources.stage_inputs) { @@ -1003,209 +1197,44 @@ namespace dawn_native { if (formatType == Format::Type::Other) { return DAWN_VALIDATION_ERROR("Unexpected Fragment output type"); } - mFragmentOutputFormatBaseTypes[location] = formatType; + metadata->fragmentOutputFormatBaseTypes[location] = formatType; } } - return {}; + + return {std::move(metadata)}; } const ShaderModuleBase::ModuleBindingInfo& ShaderModuleBase::GetBindingInfo() const { ASSERT(!IsError()); - return mBindingInfo; + return mMainEntryPoint->bindings; } const std::bitset& ShaderModuleBase::GetUsedVertexAttributes() const { ASSERT(!IsError()); - return mUsedVertexAttributes; + return mMainEntryPoint->usedVertexAttributes; } const ShaderModuleBase::FragmentOutputBaseTypes& ShaderModuleBase::GetFragmentOutputBaseTypes() const { ASSERT(!IsError()); - return mFragmentOutputFormatBaseTypes; + return mMainEntryPoint->fragmentOutputFormatBaseTypes; } SingleShaderStage ShaderModuleBase::GetExecutionModel() const { ASSERT(!IsError()); - return mExecutionModel; + return mMainEntryPoint->stage; } RequiredBufferSizes ShaderModuleBase::ComputeRequiredBufferSizesForLayout( const PipelineLayoutBase* layout) const { - RequiredBufferSizes bufferSizes; - for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { - bufferSizes[group] = - GetBindGroupMinBufferSizes(mBindingInfo[group], layout->GetBindGroupLayout(group)); - } - - return bufferSizes; - } - - std::vector ShaderModuleBase::GetBindGroupMinBufferSizes( - const BindingInfoMap& shaderMap, - const BindGroupLayoutBase* layout) const { - std::vector requiredBufferSizes(layout->GetUnverifiedBufferCount()); - uint32_t packedIdx = 0; - - for (BindingIndex bindingIndex{0}; bindingIndex < layout->GetBufferCount(); - ++bindingIndex) { - const BindingInfo& bindingInfo = layout->GetBindingInfo(bindingIndex); - if (bindingInfo.minBufferBindingSize != 0) { - // Skip bindings that have minimum buffer size set in the layout - continue; - } - - ASSERT(packedIdx < requiredBufferSizes.size()); - const auto& shaderInfo = shaderMap.find(bindingInfo.binding); - if (shaderInfo != shaderMap.end()) { - requiredBufferSizes[packedIdx] = shaderInfo->second.minBufferBindingSize; - } else { - // We have to include buffers if they are included in the bind group's - // packed vector. We don't actually need to check these at draw time, so - // if this is a problem in the future we can optimize it further. - requiredBufferSizes[packedIdx] = 0; - } - ++packedIdx; - } - - return requiredBufferSizes; + ASSERT(!IsError()); + return ::dawn_native::ComputeRequiredBufferSizesForLayout(*mMainEntryPoint, layout); } MaybeError ShaderModuleBase::ValidateCompatibilityWithPipelineLayout( const PipelineLayoutBase* layout) const { ASSERT(!IsError()); - - for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { - DAWN_TRY( - ValidateCompatibilityWithBindGroupLayout(group, layout->GetBindGroupLayout(group))); - } - - for (BindGroupIndex group : IterateBitSet(~layout->GetBindGroupLayoutsMask())) { - if (mBindingInfo[group].size() > 0) { - std::ostringstream ostream; - ostream << "No bind group layout entry matches the declaration set " - << static_cast(group) << " in the shader module"; - return DAWN_VALIDATION_ERROR(ostream.str()); - } - } - - return {}; - } - - MaybeError ShaderModuleBase::ValidateCompatibilityWithBindGroupLayout( - BindGroupIndex group, - const BindGroupLayoutBase* layout) const { - ASSERT(!IsError()); - - const BindGroupLayoutBase::BindingMap& bindingMap = layout->GetBindingMap(); - - // Iterate over all bindings used by this group in the shader, and find the - // corresponding binding in the BindGroupLayout, if it exists. - for (const auto& it : mBindingInfo[group]) { - BindingNumber bindingNumber = it.first; - const ShaderBindingInfo& moduleInfo = it.second; - - const auto& bindingIt = bindingMap.find(bindingNumber); - if (bindingIt == bindingMap.end()) { - return DAWN_VALIDATION_ERROR("Missing bind group layout entry for " + - GetShaderDeclarationString(group, bindingNumber)); - } - BindingIndex bindingIndex(bindingIt->second); - - const BindingInfo& bindingInfo = layout->GetBindingInfo(bindingIndex); - - if (bindingInfo.type != moduleInfo.type) { - // Binding mismatch between shader and bind group is invalid. For example, a - // writable binding in the shader with a readonly storage buffer in the bind group - // layout is invalid. However, a readonly binding in the shader with a writable - // storage buffer in the bind group layout is valid. - bool validBindingConversion = - bindingInfo.type == wgpu::BindingType::StorageBuffer && - moduleInfo.type == wgpu::BindingType::ReadonlyStorageBuffer; - - // TODO(crbug.com/dawn/367): Temporarily allow using either a sampler or a - // comparison sampler until we can perform the proper shader analysis of what type - // is used in the shader module. - validBindingConversion |= (bindingInfo.type == wgpu::BindingType::Sampler && - moduleInfo.type == wgpu::BindingType::ComparisonSampler); - validBindingConversion |= - (bindingInfo.type == wgpu::BindingType::ComparisonSampler && - moduleInfo.type == wgpu::BindingType::Sampler); - - if (!validBindingConversion) { - return DAWN_VALIDATION_ERROR( - "The binding type of the bind group layout entry conflicts " + - GetShaderDeclarationString(group, bindingNumber)); - } - } - - if ((bindingInfo.visibility & StageBit(mExecutionModel)) == 0) { - return DAWN_VALIDATION_ERROR("The bind group layout entry for " + - GetShaderDeclarationString(group, bindingNumber) + - " is not visible for the shader stage"); - } - - switch (bindingInfo.type) { - case wgpu::BindingType::SampledTexture: { - if (bindingInfo.textureComponentType != moduleInfo.textureComponentType) { - return DAWN_VALIDATION_ERROR( - "The textureComponentType of the bind group layout entry is different " - "from " + - GetShaderDeclarationString(group, bindingNumber)); - } - - if (bindingInfo.viewDimension != moduleInfo.viewDimension) { - return DAWN_VALIDATION_ERROR( - "The viewDimension of the bind group layout entry is different " - "from " + - GetShaderDeclarationString(group, bindingNumber)); - } - break; - } - - case wgpu::BindingType::ReadonlyStorageTexture: - case wgpu::BindingType::WriteonlyStorageTexture: { - ASSERT(bindingInfo.storageTextureFormat != wgpu::TextureFormat::Undefined); - ASSERT(moduleInfo.storageTextureFormat != wgpu::TextureFormat::Undefined); - if (bindingInfo.storageTextureFormat != moduleInfo.storageTextureFormat) { - return DAWN_VALIDATION_ERROR( - "The storageTextureFormat of the bind group layout entry is different " - "from " + - GetShaderDeclarationString(group, bindingNumber)); - } - if (bindingInfo.viewDimension != moduleInfo.viewDimension) { - return DAWN_VALIDATION_ERROR( - "The viewDimension of the bind group layout entry is different " - "from " + - GetShaderDeclarationString(group, bindingNumber)); - } - break; - } - - case wgpu::BindingType::UniformBuffer: - case wgpu::BindingType::ReadonlyStorageBuffer: - case wgpu::BindingType::StorageBuffer: { - if (bindingInfo.minBufferBindingSize != 0 && - moduleInfo.minBufferBindingSize > bindingInfo.minBufferBindingSize) { - return DAWN_VALIDATION_ERROR( - "The minimum buffer size of the bind group layout entry is smaller " - "than " + - GetShaderDeclarationString(group, bindingNumber)); - } - break; - } - case wgpu::BindingType::Sampler: - case wgpu::BindingType::ComparisonSampler: - break; - - case wgpu::BindingType::StorageTexture: - default: - UNREACHABLE(); - return DAWN_VALIDATION_ERROR("Unsupported binding type"); - } - } - - return {}; + return ::dawn_native::ValidateCompatibilityWithPipelineLayout(*mMainEntryPoint, layout); } size_t ShaderModuleBase::HashFunc::operator()(const ShaderModuleBase* module) const { diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h index 336551ee85..f6779aa404 100644 --- a/src/dawn_native/ShaderModule.h +++ b/src/dawn_native/ShaderModule.h @@ -43,8 +43,6 @@ namespace dawn_native { class ShaderModuleBase : public CachedObject { public: - enum class Type { Undefined, Spirv, Wgsl }; - ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor); ~ShaderModuleBase() override; @@ -98,6 +96,15 @@ namespace dawn_native { uint32_t pullingBufferBindingSet) const; #endif + struct EntryPointMetadata { + EntryPointMetadata(); + + ModuleBindingInfo bindings; + std::bitset usedVertexAttributes; + SingleShaderStage stage; + FragmentOutputBaseTypes fragmentOutputFormatBaseTypes; + }; + protected: static MaybeError CheckSpvcSuccess(shaderc_spvc_status status, const char* error_msg); shaderc_spvc::CompileOptions GetCompileOptions() const; @@ -108,27 +115,18 @@ namespace dawn_native { private: ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag); - MaybeError ValidateCompatibilityWithBindGroupLayout( - BindGroupIndex group, - const BindGroupLayoutBase* layout) const; - - std::vector GetBindGroupMinBufferSizes(const BindingInfoMap& shaderMap, - const BindGroupLayoutBase* layout) const; - // Different implementations reflection into the shader depending on // whether using spvc, or directly accessing spirv-cross. - MaybeError ExtractSpirvInfoWithSpvc(); - MaybeError ExtractSpirvInfoWithSpirvCross(const spirv_cross::Compiler& compiler); + ResultOrError> ExtractSpirvInfoWithSpvc(); + ResultOrError> ExtractSpirvInfoWithSpirvCross( + const spirv_cross::Compiler& compiler); + enum class Type { Undefined, Spirv, Wgsl }; Type mType; std::vector mSpirv; std::string mWgsl; - ModuleBindingInfo mBindingInfo; - std::bitset mUsedVertexAttributes; - SingleShaderStage mExecutionModel; - - FragmentOutputBaseTypes mFragmentOutputFormatBaseTypes; + std::unique_ptr mMainEntryPoint; }; } // namespace dawn_native