diff --git a/src/dawn_native/ComputePipeline.h b/src/dawn_native/ComputePipeline.h index c2f118899c..c2b470a252 100644 --- a/src/dawn_native/ComputePipeline.h +++ b/src/dawn_native/ComputePipeline.h @@ -20,6 +20,7 @@ namespace dawn_native { class DeviceBase; + struct EntryPointMetadata; MaybeError ValidateComputePipelineDescriptor(DeviceBase* device, const ComputePipelineDescriptor* descriptor); @@ -31,6 +32,8 @@ namespace dawn_native { static ComputePipelineBase* MakeError(DeviceBase* device); + const EntryPointMetadata& GetMetadata() const; + // Functors necessary for the unordered_set-based cache. struct HashFunc { size_t operator()(const ComputePipelineBase* pipeline) const; diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp index 1f76e49b17..5f48898fc6 100644 --- a/src/dawn_native/Device.cpp +++ b/src/dawn_native/Device.cpp @@ -877,9 +877,9 @@ namespace dawn_native { if (descriptor->layout == nullptr) { ComputePipelineDescriptor descriptorWithDefaultLayout = *descriptor; - DAWN_TRY_ASSIGN( - descriptorWithDefaultLayout.layout, - PipelineLayoutBase::CreateDefault(this, &descriptor->computeStage.module, 1)); + DAWN_TRY_ASSIGN(descriptorWithDefaultLayout.layout, + PipelineLayoutBase::CreateDefault( + this, {{SingleShaderStage::Compute, &descriptor->computeStage}})); // Ref will keep the pipeline layout alive until the end of the function where // the pipeline will take another reference. Ref layoutRef = AcquireRef(descriptorWithDefaultLayout.layout); @@ -934,18 +934,14 @@ namespace dawn_native { if (descriptor->layout == nullptr) { RenderPipelineDescriptor descriptorWithDefaultLayout = *descriptor; - const ShaderModuleBase* modules[2]; - modules[0] = descriptor->vertexStage.module; - uint32_t count; - if (descriptor->fragmentStage == nullptr) { - count = 1; - } else { - modules[1] = descriptor->fragmentStage->module; - count = 2; + std::vector stages; + stages.emplace_back(SingleShaderStage::Vertex, &descriptor->vertexStage); + if (descriptor->fragmentStage != nullptr) { + stages.emplace_back(SingleShaderStage::Fragment, descriptor->fragmentStage); } DAWN_TRY_ASSIGN(descriptorWithDefaultLayout.layout, - PipelineLayoutBase::CreateDefault(this, modules, count)); + PipelineLayoutBase::CreateDefault(this, std::move(stages))); // Ref will keep the pipeline layout alive until the end of the function where // the pipeline will take another reference. Ref layoutRef = AcquireRef(descriptorWithDefaultLayout.layout); diff --git a/src/dawn_native/Pipeline.cpp b/src/dawn_native/Pipeline.cpp index f37e30cebc..691e3a976a 100644 --- a/src/dawn_native/Pipeline.cpp +++ b/src/dawn_native/Pipeline.cpp @@ -26,16 +26,17 @@ namespace dawn_native { const ProgrammableStageDescriptor* descriptor, const PipelineLayoutBase* layout, SingleShaderStage stage) { - DAWN_TRY(device->ValidateObject(descriptor->module)); + const ShaderModuleBase* module = descriptor->module; + DAWN_TRY(device->ValidateObject(module)); - if (descriptor->entryPoint != std::string("main")) { - return DAWN_VALIDATION_ERROR("Entry point must be \"main\""); - } - if (descriptor->module->GetExecutionModel() != stage) { - return DAWN_VALIDATION_ERROR("Setting module with wrong stages"); + if (!module->HasEntryPoint(descriptor->entryPoint, stage)) { + return DAWN_VALIDATION_ERROR("Entry point doesn't exist in the module"); } + if (layout != nullptr) { - DAWN_TRY(descriptor->module->ValidateCompatibilityWithPipelineLayout(layout)); + const EntryPointMetadata& metadata = + module->GetEntryPoint(descriptor->entryPoint, stage); + DAWN_TRY(ValidateCompatibilityWithPipelineLayout(metadata, layout)); } return {}; } @@ -49,13 +50,20 @@ namespace dawn_native { ASSERT(!stages.empty()); for (const StageAndDescriptor& stage : stages) { + // Extract argument for this stage. + SingleShaderStage shaderStage = stage.first; + ShaderModuleBase* module = stage.second->module; + const char* entryPointName = stage.second->entryPoint; + const EntryPointMetadata& metadata = module->GetEntryPoint(entryPointName, shaderStage); + + // Record them internally. bool isFirstStage = mStageMask == wgpu::ShaderStage::None; - mStageMask |= StageBit(stage.first); - mStages[stage.first] = {stage.second->module, stage.second->entryPoint}; + mStageMask |= StageBit(shaderStage); + mStages[shaderStage] = {module, entryPointName, &metadata}; // Compute the max() of all minBufferSizes across all stages. RequiredBufferSizes stageMinBufferSizes = - stage.second->module->ComputeRequiredBufferSizesForLayout(layout); + ComputeRequiredBufferSizesForLayout(metadata, layout); if (isFirstStage) { mMinBufferSizes = std::move(stageMinBufferSizes); diff --git a/src/dawn_native/Pipeline.h b/src/dawn_native/Pipeline.h index 9df2db6884..44e3f981d2 100644 --- a/src/dawn_native/Pipeline.h +++ b/src/dawn_native/Pipeline.h @@ -36,6 +36,9 @@ namespace dawn_native { struct ProgrammableStage { Ref module; std::string entryPoint; + + // The metadata lives as long as module, that's ref-ed in the same structure. + const EntryPointMetadata* metadata = nullptr; }; class PipelineBase : public CachedObject { @@ -52,8 +55,6 @@ namespace dawn_native { static bool EqualForCache(const PipelineBase* a, const PipelineBase* b); protected: - using StageAndDescriptor = std::pair; - PipelineBase(DeviceBase* device, PipelineLayoutBase* layout, std::vector stages); diff --git a/src/dawn_native/PipelineLayout.cpp b/src/dawn_native/PipelineLayout.cpp index f34003a5b7..61fdf74b95 100644 --- a/src/dawn_native/PipelineLayout.cpp +++ b/src/dawn_native/PipelineLayout.cpp @@ -114,9 +114,8 @@ namespace dawn_native { // static ResultOrError PipelineLayoutBase::CreateDefault( DeviceBase* device, - const ShaderModuleBase* const* modules, - uint32_t count) { - ASSERT(count > 0); + std::vector stages) { + ASSERT(!stages.empty()); // Data which BindGroupLayoutDescriptor will point to for creation ityp::array< @@ -134,20 +133,22 @@ namespace dawn_native { BindingCounts bindingCounts = {}; BindGroupIndex bindGroupLayoutCount(0); - for (uint32_t moduleIndex = 0; moduleIndex < count; ++moduleIndex) { - const ShaderModuleBase* module = modules[moduleIndex]; - const ShaderModuleBase::ModuleBindingInfo& info = module->GetBindingInfo(); + for (const StageAndDescriptor& stage : stages) { + // Extract argument for this stage. + SingleShaderStage shaderStage = stage.first; + const EntryPointMetadata::BindingInfo& info = + stage.second->module->GetEntryPoint(stage.second->entryPoint, shaderStage).bindings; for (BindGroupIndex group(0); group < info.size(); ++group) { for (const auto& it : info[group]) { BindingNumber bindingNumber = it.first; - const ShaderModuleBase::ShaderBindingInfo& bindingInfo = it.second; + const EntryPointMetadata::ShaderBindingInfo& bindingInfo = it.second; BindGroupLayoutEntry bindingSlot; bindingSlot.binding = static_cast(bindingNumber); - DAWN_TRY(ValidateBindingTypeWithShaderStageVisibility( - bindingInfo.type, StageBit(module->GetExecutionModel()))); + DAWN_TRY(ValidateBindingTypeWithShaderStageVisibility(bindingInfo.type, + StageBit(shaderStage))); DAWN_TRY(ValidateStorageTextureFormat(device, bindingInfo.type, bindingInfo.storageTextureFormat)); DAWN_TRY(ValidateStorageTextureViewDimension(bindingInfo.type, @@ -239,10 +240,10 @@ namespace dawn_native { } } - for (uint32_t moduleIndex = 0; moduleIndex < count; ++moduleIndex) { - ASSERT(modules[moduleIndex] - ->ValidateCompatibilityWithPipelineLayout(pipelineLayout) - .IsSuccess()); + for (const StageAndDescriptor& stage : stages) { + const EntryPointMetadata& metadata = + stage.second->module->GetEntryPoint(stage.second->entryPoint, stage.first); + ASSERT(ValidateCompatibilityWithPipelineLayout(metadata, pipelineLayout).IsSuccess()); } return pipelineLayout; diff --git a/src/dawn_native/PipelineLayout.h b/src/dawn_native/PipelineLayout.h index 862caaf8c7..be8a75cf74 100644 --- a/src/dawn_native/PipelineLayout.h +++ b/src/dawn_native/PipelineLayout.h @@ -37,14 +37,17 @@ namespace dawn_native { ityp::array, kMaxBindGroups>; using BindGroupLayoutMask = ityp::bitset; + using StageAndDescriptor = std::pair; + class PipelineLayoutBase : public CachedObject { public: PipelineLayoutBase(DeviceBase* device, const PipelineLayoutDescriptor* descriptor); ~PipelineLayoutBase() override; static PipelineLayoutBase* MakeError(DeviceBase* device); - static ResultOrError - CreateDefault(DeviceBase* device, const ShaderModuleBase* const* modules, uint32_t count); + static ResultOrError CreateDefault( + DeviceBase* device, + std::vector stages); const BindGroupLayoutBase* GetBindGroupLayout(BindGroupIndex group) const; BindGroupLayoutBase* GetBindGroupLayout(BindGroupIndex group); diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp index 2d91291a75..62ea4b3a4c 100644 --- a/src/dawn_native/RenderPipeline.cpp +++ b/src/dawn_native/RenderPipeline.cpp @@ -333,8 +333,9 @@ namespace dawn_native { DAWN_TRY(ValidateRasterizationStateDescriptor(descriptor->rasterizationState)); } - if ((descriptor->vertexStage.module->GetUsedVertexAttributes() & ~attributesSetMask) - .any()) { + const EntryPointMetadata& vertexMetadata = descriptor->vertexStage.module->GetEntryPoint( + descriptor->vertexStage.entryPoint, SingleShaderStage::Vertex); + if ((vertexMetadata.usedVertexAttributes & ~attributesSetMask).any()) { return DAWN_VALIDATION_ERROR( "Pipeline vertex stage uses vertex buffers not in the vertex state"); } @@ -352,11 +353,13 @@ namespace dawn_native { } ASSERT(descriptor->fragmentStage != nullptr); - const ShaderModuleBase::FragmentOutputBaseTypes& fragmentOutputBaseTypes = - descriptor->fragmentStage->module->GetFragmentOutputBaseTypes(); + const EntryPointMetadata& fragmentMetadata = + descriptor->fragmentStage->module->GetEntryPoint(descriptor->fragmentStage->entryPoint, + SingleShaderStage::Fragment); for (uint32_t i = 0; i < descriptor->colorStateCount; ++i) { - DAWN_TRY(ValidateColorStateDescriptor(device, descriptor->colorStates[i], - fragmentOutputBaseTypes[i])); + DAWN_TRY( + ValidateColorStateDescriptor(device, descriptor->colorStates[i], + fragmentMetadata.fragmentOutputFormatBaseTypes[i])); } if (descriptor->depthStencilState) { diff --git a/src/dawn_native/RenderPipeline.h b/src/dawn_native/RenderPipeline.h index f06abff5f8..002b330701 100644 --- a/src/dawn_native/RenderPipeline.h +++ b/src/dawn_native/RenderPipeline.h @@ -28,6 +28,7 @@ namespace dawn_native { struct BeginRenderPassCmd; class DeviceBase; + struct EntryPointMetadata; class RenderBundleEncoder; MaybeError ValidateRenderPipelineDescriptor(DeviceBase* device, diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp index 03ddf035e5..96cd483967 100644 --- a/src/dawn_native/ShaderModule.cpp +++ b/src/dawn_native/ShaderModule.cpp @@ -549,7 +549,7 @@ namespace dawn_native { #endif // DAWN_ENABLE_WGSL std::vector GetBindGroupMinBufferSizes( - const ShaderModuleBase::BindingInfoMap& shaderBindings, + const EntryPointMetadata::BindingGroupInfoMap& shaderBindings, const BindGroupLayoutBase* layout) { std::vector requiredBufferSizes(layout->GetUnverifiedBufferCount()); uint32_t packedIdx = 0; @@ -578,17 +578,16 @@ namespace dawn_native { return requiredBufferSizes; } - MaybeError ValidateCompatibilityWithBindGroupLayout( - BindGroupIndex group, - const ShaderModuleBase::EntryPointMetadata& entryPoint, - const BindGroupLayoutBase* layout) { + MaybeError ValidateCompatibilityWithBindGroupLayout(BindGroupIndex group, + const EntryPointMetadata& entryPoint, + const BindGroupLayoutBase* layout) { const BindGroupLayoutBase::BindingMap& layoutBindings = layout->GetBindingMap(); // Iterate over all bindings used by this group in the shader, and find the // corresponding binding in the BindGroupLayout, if it exists. for (const auto& it : entryPoint.bindings[group]) { BindingNumber bindingNumber = it.first; - const ShaderModuleBase::ShaderBindingInfo& shaderInfo = it.second; + const EntryPointMetadata::ShaderBindingInfo& shaderInfo = it.second; const auto& bindingIt = layoutBindings.find(bindingNumber); if (bindingIt == layoutBindings.end()) { @@ -732,9 +731,8 @@ namespace dawn_native { return {}; } - RequiredBufferSizes ComputeRequiredBufferSizesForLayout( - const ShaderModuleBase::EntryPointMetadata& entryPoint, - const PipelineLayoutBase* layout) { + RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint, + const PipelineLayoutBase* layout) { RequiredBufferSizes bufferSizes; for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { bufferSizes[group] = GetBindGroupMinBufferSizes(entryPoint.bindings[group], @@ -744,9 +742,8 @@ namespace dawn_native { return bufferSizes; } - MaybeError ValidateCompatibilityWithPipelineLayout( - const ShaderModuleBase::EntryPointMetadata& entryPoint, - const PipelineLayoutBase* layout) { + MaybeError ValidateCompatibilityWithPipelineLayout(const EntryPointMetadata& entryPoint, + const PipelineLayoutBase* layout) { for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { DAWN_TRY(ValidateCompatibilityWithBindGroupLayout(group, entryPoint, layout->GetBindGroupLayout(group))); @@ -766,7 +763,7 @@ namespace dawn_native { // EntryPointMetadata - ShaderModuleBase::EntryPointMetadata::EntryPointMetadata() { + EntryPointMetadata::EntryPointMetadata() { fragmentOutputFormatBaseTypes.fill(Format::Type::Other); } @@ -814,6 +811,20 @@ namespace dawn_native { return new ShaderModuleBase(device, ObjectBase::kError); } + bool ShaderModuleBase::HasEntryPoint(const std::string& entryPoint, + SingleShaderStage stage) const { + // TODO(dawn:216): Properly extract all entryPoints from the shader module. + return entryPoint == "main" && stage == mMainEntryPoint->stage; + } + + const EntryPointMetadata& ShaderModuleBase::GetEntryPoint(const std::string& entryPoint, + SingleShaderStage stage) const { + // TODO(dawn:216): Properly extract all entryPoints from the shader module. + ASSERT(entryPoint == "main"); + ASSERT(stage == mMainEntryPoint->stage); + return *mMainEntryPoint; + } + MaybeError ShaderModuleBase::ExtractSpirvInfo(const spirv_cross::Compiler& compiler) { ASSERT(!IsError()); if (GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) { @@ -824,7 +835,7 @@ namespace dawn_native { return {}; } - ResultOrError> + ResultOrError> ShaderModuleBase::ExtractSpirvInfoWithSpvc() { DeviceBase* device = GetDevice(); std::unique_ptr metadata = std::make_unique(); @@ -848,7 +859,7 @@ namespace dawn_native { // Fill in bindingInfo with the SPIRV bindings auto ExtractResourcesBinding = [](const DeviceBase* device, const std::vector& spvcBindings, - ModuleBindingInfo* metadataBindings) -> MaybeError { + EntryPointMetadata::BindingInfo* metadataBindings) -> MaybeError { for (const shaderc_spvc_binding_info& binding : spvcBindings) { BindGroupIndex bindGroupIndex(binding.set); @@ -857,12 +868,12 @@ namespace dawn_native { } const auto& it = (*metadataBindings)[bindGroupIndex].emplace( - BindingNumber(binding.binding), ShaderBindingInfo{}); + BindingNumber(binding.binding), EntryPointMetadata::ShaderBindingInfo{}); if (!it.second) { return DAWN_VALIDATION_ERROR("Shader has duplicate bindings"); } - ShaderBindingInfo* info = &it.first->second; + EntryPointMetadata::ShaderBindingInfo* info = &it.first->second; info->id = binding.id; info->base_type_id = binding.base_type_id; info->type = ToWGPUBindingType(binding.binding_type); @@ -994,7 +1005,7 @@ namespace dawn_native { return {std::move(metadata)}; } - ResultOrError> + ResultOrError> ShaderModuleBase::ExtractSpirvInfoWithSpirvCross(const spirv_cross::Compiler& compiler) { DeviceBase* device = GetDevice(); std::unique_ptr metadata = std::make_unique(); @@ -1031,7 +1042,7 @@ namespace dawn_native { [](const DeviceBase* device, const spirv_cross::SmallVector& resources, const spirv_cross::Compiler& compiler, wgpu::BindingType bindingType, - ModuleBindingInfo* metadataBindings) -> MaybeError { + EntryPointMetadata::BindingInfo* metadataBindings) -> MaybeError { for (const auto& resource : resources) { if (!compiler.get_decoration_bitset(resource.id).get(spv::DecorationBinding)) { return DAWN_VALIDATION_ERROR("No Binding decoration set for resource"); @@ -1051,13 +1062,13 @@ namespace dawn_native { return DAWN_VALIDATION_ERROR("Bind group index over limits in the SPIRV"); } - const auto& it = - (*metadataBindings)[bindGroupIndex].emplace(bindingNumber, ShaderBindingInfo{}); + const auto& it = (*metadataBindings)[bindGroupIndex].emplace( + bindingNumber, EntryPointMetadata::ShaderBindingInfo{}); if (!it.second) { return DAWN_VALIDATION_ERROR("Shader has duplicate bindings"); } - ShaderBindingInfo* info = &it.first->second; + EntryPointMetadata::ShaderBindingInfo* info = &it.first->second; info->id = resource.id; info->base_type_id = resource.base_type_id; @@ -1204,39 +1215,6 @@ namespace dawn_native { return {std::move(metadata)}; } - const ShaderModuleBase::ModuleBindingInfo& ShaderModuleBase::GetBindingInfo() const { - ASSERT(!IsError()); - return mMainEntryPoint->bindings; - } - - const std::bitset& ShaderModuleBase::GetUsedVertexAttributes() const { - ASSERT(!IsError()); - return mMainEntryPoint->usedVertexAttributes; - } - - const ShaderModuleBase::FragmentOutputBaseTypes& ShaderModuleBase::GetFragmentOutputBaseTypes() - const { - ASSERT(!IsError()); - return mMainEntryPoint->fragmentOutputFormatBaseTypes; - } - - SingleShaderStage ShaderModuleBase::GetExecutionModel() const { - ASSERT(!IsError()); - return mMainEntryPoint->stage; - } - - RequiredBufferSizes ShaderModuleBase::ComputeRequiredBufferSizesForLayout( - const PipelineLayoutBase* layout) const { - ASSERT(!IsError()); - return ::dawn_native::ComputeRequiredBufferSizesForLayout(*mMainEntryPoint, layout); - } - - MaybeError ShaderModuleBase::ValidateCompatibilityWithPipelineLayout( - const PipelineLayoutBase* layout) const { - ASSERT(!IsError()); - return ::dawn_native::ValidateCompatibilityWithPipelineLayout(*mMainEntryPoint, layout); - } - size_t ShaderModuleBase::HashFunc::operator()(const ShaderModuleBase* module) const { size_t hash = 0; @@ -1298,4 +1276,10 @@ namespace dawn_native { return {}; } + + SingleShaderStage ShaderModuleBase::GetMainEntryPointStageForTransition() const { + ASSERT(!IsError()); + return mMainEntryPoint->stage; + } + } // namespace dawn_native diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h index f6779aa404..15d31aecfb 100644 --- a/src/dawn_native/ShaderModule.h +++ b/src/dawn_native/ShaderModule.h @@ -38,18 +38,25 @@ namespace spirv_cross { namespace dawn_native { + struct EntryPointMetadata; + MaybeError ValidateShaderModuleDescriptor(DeviceBase* device, const ShaderModuleDescriptor* descriptor); + MaybeError ValidateCompatibilityWithPipelineLayout(const EntryPointMetadata& entryPoint, + const PipelineLayoutBase* layout); - class ShaderModuleBase : public CachedObject { - public: - ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor); - ~ShaderModuleBase() override; + RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint, + const PipelineLayoutBase* layout); - static ShaderModuleBase* MakeError(DeviceBase* device); - - MaybeError ExtractSpirvInfo(const spirv_cross::Compiler& compiler); + // Contains all the reflection data for a valid (ShaderModule, entryPoint, stage). They are + // stored in the ShaderModuleBase and destroyed only when the shader module is destroyed so + // pointers to EntryPointMetadata are safe to store as long as you also keep a Ref to the + // ShaderModuleBase. + struct EntryPointMetadata { + EntryPointMetadata(); + // Per-binding shader metadata contains some SPIRV specific information in addition to + // most of the frontend per-binding information. struct ShaderBindingInfo : BindingInfo { // The SPIRV ID of the resource. uint32_t id; @@ -61,22 +68,42 @@ namespace dawn_native { using BindingInfo::visibility; }; - using BindingInfoMap = std::map; - using ModuleBindingInfo = ityp::array; + // bindings[G][B] is the reflection data for the binding defined with + // [[group=G, binding=B]] in WGSL / SPIRV. + using BindingGroupInfoMap = std::map; + using BindingInfo = ityp::array; + BindingInfo bindings; - const ModuleBindingInfo& GetBindingInfo() const; - const std::bitset& GetUsedVertexAttributes() const; - SingleShaderStage GetExecutionModel() const; + // The set of vertex attributes this entryPoint uses. + std::bitset usedVertexAttributes; // An array to record the basic types (float, int and uint) of the fragment shader outputs // or Format::Type::Other means the fragment shader output is unused. using FragmentOutputBaseTypes = std::array; - const FragmentOutputBaseTypes& GetFragmentOutputBaseTypes() const; + FragmentOutputBaseTypes fragmentOutputFormatBaseTypes; - MaybeError ValidateCompatibilityWithPipelineLayout(const PipelineLayoutBase* layout) const; + // The shader stage for this binding, TODO(dawn:216): can likely be removed once we + // properly support multiple entrypoints per ShaderModule. + SingleShaderStage stage; + }; - RequiredBufferSizes ComputeRequiredBufferSizesForLayout( - const PipelineLayoutBase* layout) const; + class ShaderModuleBase : public CachedObject { + public: + ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor); + ~ShaderModuleBase() override; + + static ShaderModuleBase* MakeError(DeviceBase* device); + + // Return true iff the module has an entrypoint called `entryPoint` for stage `stage`. + bool HasEntryPoint(const std::string& entryPoint, SingleShaderStage stage) const; + + // Returns the metadata for the given `entryPoint` and `stage`. HasEntryPoint with the same + // arguments must be true. + const EntryPointMetadata& GetEntryPoint(const std::string& entryPoint, + SingleShaderStage stage) const; + + // TODO make this member protected, it is only used outside of child classes in DeviceNull. + MaybeError ExtractSpirvInfo(const spirv_cross::Compiler& compiler); // Functors necessary for the unordered_set-based cache. struct HashFunc { @@ -96,15 +123,6 @@ namespace dawn_native { uint32_t pullingBufferBindingSet) const; #endif - struct EntryPointMetadata { - EntryPointMetadata(); - - ModuleBindingInfo bindings; - std::bitset usedVertexAttributes; - SingleShaderStage stage; - FragmentOutputBaseTypes fragmentOutputFormatBaseTypes; - }; - protected: static MaybeError CheckSpvcSuccess(shaderc_spvc_status status, const char* error_msg); shaderc_spvc::CompileOptions GetCompileOptions() const; @@ -112,6 +130,11 @@ namespace dawn_native { shaderc_spvc::Context mSpvcContext; + // Allows backends to get the stage for the "main" entrypoint while they are transitioned to + // support multiple entrypoints. + // TODO(dawn:216): Remove this once the transition is complete. + SingleShaderStage GetMainEntryPointStageForTransition() const; + private: ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag); diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp index 0bdbafa12c..3b16696774 100644 --- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp +++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp @@ -170,13 +170,15 @@ namespace dawn_native { namespace d3d12 { compiler->set_hlsl_options(options_hlsl); } - const ModuleBindingInfo& moduleBindingInfo = GetBindingInfo(); + const EntryPointMetadata::BindingInfo& moduleBindingInfo = + GetEntryPoint("main", GetMainEntryPointStageForTransition()).bindings; + for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group)); const auto& bindingOffsets = bgl->GetBindingOffsets(); const auto& groupBindingInfo = moduleBindingInfo[group]; for (const auto& it : groupBindingInfo) { - const ShaderBindingInfo& bindingInfo = it.second; + const EntryPointMetadata::ShaderBindingInfo& bindingInfo = it.second; BindingNumber bindingNumber = it.first; BindingIndex bindingIndex = bgl->GetBindingIndex(bindingNumber); diff --git a/src/dawn_native/metal/RenderPipelineMTL.mm b/src/dawn_native/metal/RenderPipelineMTL.mm index 5823ffe989..88e905cef4 100644 --- a/src/dawn_native/metal/RenderPipelineMTL.mm +++ b/src/dawn_native/metal/RenderPipelineMTL.mm @@ -368,8 +368,8 @@ namespace dawn_native { namespace metal { } } - const ShaderModuleBase::FragmentOutputBaseTypes& fragmentOutputBaseTypes = - descriptor->fragmentStage->module->GetFragmentOutputBaseTypes(); + const EntryPointMetadata::FragmentOutputBaseTypes& fragmentOutputBaseTypes = + GetStage(SingleShaderStage::Fragment).metadata->fragmentOutputFormatBaseTypes; for (uint32_t i : IterateBitSet(GetColorAttachmentsMask())) { descriptorMTL.colorAttachments[i].pixelFormat = MetalPixelFormat(GetColorAttachmentFormat(i)); diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm index 208612e631..1e227ac012 100644 --- a/src/dawn_native/metal/ShaderModuleMTL.mm +++ b/src/dawn_native/metal/ShaderModuleMTL.mm @@ -259,11 +259,15 @@ namespace dawn_native { namespace metal { // TODO(kainino@chromium.org): make this somehow more robust; it needs to behave like // clean_func_name: // https://github.com/KhronosGroup/SPIRV-Cross/blob/4e915e8c483e319d0dd7a1fa22318bef28f8cca3/spirv_msl.cpp#L1213 - if (strcmp(functionName, "main") == 0) { - functionName = "main0"; + const char* metalFunctionName = functionName; + if (strcmp(metalFunctionName, "main") == 0) { + metalFunctionName = "main0"; + } + if (strcmp(metalFunctionName, "saturate") == 0) { + metalFunctionName = "saturate0"; } - NSString* name = [[NSString alloc] initWithUTF8String:functionName]; + NSString* name = [[NSString alloc] initWithUTF8String:metalFunctionName]; out->function = [library newFunctionWithName:name]; [library release]; } @@ -277,7 +281,7 @@ namespace dawn_native { namespace metal { } if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) && - functionStage == SingleShaderStage::Vertex && GetUsedVertexAttributes().any()) { + GetEntryPoint(functionName, functionStage).usedVertexAttributes.any()) { out->needsStorageBufferLength = true; } diff --git a/src/dawn_native/opengl/ShaderModuleGL.cpp b/src/dawn_native/opengl/ShaderModuleGL.cpp index 53aa101ad1..0b02d5306f 100644 --- a/src/dawn_native/opengl/ShaderModuleGL.cpp +++ b/src/dawn_native/opengl/ShaderModuleGL.cpp @@ -125,8 +125,6 @@ namespace dawn_native { namespace opengl { DAWN_TRY(ExtractSpirvInfo(*compiler)); - const ShaderModuleBase::ModuleBindingInfo& bindingInfo = GetBindingInfo(); - // Extract bindings names so that it can be used to get its location in program. // Now translate the separate sampler / textures into combined ones and store their info. // We need to do this before removing the set and binding decorations. @@ -182,6 +180,9 @@ namespace dawn_native { namespace opengl { } } + const EntryPointMetadata::BindingInfo& bindingInfo = + GetEntryPoint("main", GetMainEntryPointStageForTransition()).bindings; + // Change binding names to be "dawn_binding__". // Also unsets the SPIRV "Binding" decoration as it outputs "layout(binding=)" which // isn't supported on OSX's OpenGL. diff --git a/src/dawn_native/vulkan/RenderPipelineVk.cpp b/src/dawn_native/vulkan/RenderPipelineVk.cpp index 271aaaad6e..2c18cfd460 100644 --- a/src/dawn_native/vulkan/RenderPipelineVk.cpp +++ b/src/dawn_native/vulkan/RenderPipelineVk.cpp @@ -425,8 +425,8 @@ namespace dawn_native { namespace vulkan { // Initialize the "blend state info" that will be chained in the "create info" from the data // pre-computed in the ColorState std::array colorBlendAttachments; - const ShaderModuleBase::FragmentOutputBaseTypes& fragmentOutputBaseTypes = - descriptor->fragmentStage->module->GetFragmentOutputBaseTypes(); + const EntryPointMetadata::FragmentOutputBaseTypes& fragmentOutputBaseTypes = + GetStage(SingleShaderStage::Fragment).metadata->fragmentOutputFormatBaseTypes; for (uint32_t i : IterateBitSet(GetColorAttachmentsMask())) { const ColorStateDescriptor* colorStateDescriptor = GetColorStateDescriptor(i); bool isDeclaredInFragmentShader = fragmentOutputBaseTypes[i] != Format::Type::Other;