From 3e4b57b77ea03dd9fd81b7777f2ea295d0cef1c0 Mon Sep 17 00:00:00 2001 From: Corentin Wallez Date: Mon, 22 Mar 2021 18:24:16 +0000 Subject: [PATCH] ShaderModule: Store the tint::Program in the base class. This is in preparation of removing all the DAWN_ENABLE_WGSL logic: the ShaderModuleBase will have either mSpirv or mTintProgram set based on UseTintGenerator. Also improves the constness of some functions. Also simplifies a bit ShaderModuleBase::Initialize. Bug: dawn:706 Change-Id: Ib879e2aec8a004aeb8ac5dc6e1176b1667fc227d Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/45422 Commit-Queue: Austin Eng Auto-Submit: Corentin Wallez Reviewed-by: Ryan Harrison Reviewed-by: Austin Eng --- src/dawn_native/ShaderModule.cpp | 66 ++++++++------------- src/dawn_native/ShaderModule.h | 14 ++++- src/dawn_native/d3d12/ShaderModuleD3D12.cpp | 8 +-- src/dawn_native/d3d12/ShaderModuleD3D12.h | 4 -- src/dawn_native/metal/ShaderModuleMTL.h | 4 -- src/dawn_native/metal/ShaderModuleMTL.mm | 12 ++-- 6 files changed, 42 insertions(+), 66 deletions(-) diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp index e21907d181..2695cbaf6c 100644 --- a/src/dawn_native/ShaderModule.cpp +++ b/src/dawn_native/ShaderModule.cpp @@ -266,7 +266,7 @@ namespace dawn_native { return std::move(program); } - MaybeError ValidateModule(tint::Program* program) { + MaybeError ValidateModule(const tint::Program* program) { std::ostringstream errorStream; errorStream << "Tint program validation" << std::endl; @@ -729,14 +729,14 @@ namespace dawn_native { // PopulateMetadataUsingSPIRVCross will be removed. ResultOrError ReflectShaderUsingTint( DeviceBase*, - const tint::Program& program) { - ASSERT(program.IsValid()); + const tint::Program* program) { + ASSERT(program->IsValid()); EntryPointMetadataTable result; std::ostringstream errorStream; errorStream << "Tint Reflection failure:" << std::endl; - tint::inspector::Inspector inspector(&program); + tint::inspector::Inspector inspector(program); auto entryPoints = inspector.GetEntryPoints(); if (inspector.has_error()) { errorStream << "Inspector: " << inspector.error() << std::endl; @@ -842,7 +842,7 @@ namespace dawn_native { // fallback/source of truth. ResultOrError ReflectShaderUsingSPIRVCross( DeviceBase* device, - std::vector spirv) { + const std::vector& spirv) { EntryPointMetadataTable result; spirv_cross::Compiler compiler(spirv); for (const spirv_cross::EntryPoint& entryPoint : @@ -866,7 +866,7 @@ namespace dawn_native { // using the SPIRV-Cross implementation. Once the Tint implementation is // completed, this function will be removed. MaybeError PopulateMetadataUsingSPIRVCross(DeviceBase* device, - std::vector spirv, + const std::vector& spirv, EntryPointMetadataTable* tintTable) { EntryPointMetadataTable crossTable; DAWN_TRY_ASSIGN(crossTable, ReflectShaderUsingSPIRVCross(device, spirv)); @@ -998,6 +998,7 @@ namespace dawn_native { if (device->IsValidationEnabled()) { DAWN_TRY(ValidateModule(&program)); } + parseResult.tintProgram = std::make_unique(std::move(program)); } else { tint::transform::Manager transformManager; transformManager.append( @@ -1015,8 +1016,6 @@ namespace dawn_native { parseResult.spirv = std::move(spirv); } - - parseResult.tintProgram = std::make_unique(std::move(program)); break; #else return DAWN_VALIDATION_ERROR("Using Tint is not enabled in this build."); @@ -1042,7 +1041,7 @@ namespace dawn_native { #ifdef DAWN_ENABLE_WGSL ResultOrError RunTransforms(tint::transform::Transform* transform, - tint::Program* program) { + const tint::Program* program) { tint::transform::Transform::Output output = transform->Run(program); if (!output.program.IsValid()) { std::string err = "Tint program failure: " + output.program.Diagnostics().str(); @@ -1169,6 +1168,11 @@ namespace dawn_native { } #ifdef DAWN_ENABLE_WGSL + const tint::Program* ShaderModuleBase::GetTintProgram() const { + ASSERT(GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)); + return mTintProgram.get(); + } + ResultOrError> ShaderModuleBase::GeneratePullingSpirv( const std::vector& spirv, const VertexStateDescriptor& vertexState, @@ -1181,7 +1185,7 @@ namespace dawn_native { } ResultOrError> ShaderModuleBase::GeneratePullingSpirv( - tint::Program* programIn, + const tint::Program* programIn, const VertexStateDescriptor& vertexState, const std::string& entryPoint, BindGroupIndex pullingBufferBindingSet) const { @@ -1214,7 +1218,7 @@ namespace dawn_native { MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult) { #ifdef DAWN_ENABLE_WGSL - tint::Program* program = parseResult->tintProgram.get(); + mTintProgram = std::move(parseResult->tintProgram); #endif mSpirv = std::move(parseResult->spirv); @@ -1226,43 +1230,23 @@ namespace dawn_native { DAWN_TRY_ASSIGN(mSpirv, RunRobustBufferAccessPass(mSpirv)); } - // We still need the spirv for reflection. Remove this when we use the Tint inspector - // completely. - std::vector* spirvPtr = &mSpirv; - std::vector localSpirv; if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) { #ifdef DAWN_ENABLE_WGSL - ASSERT(program != nullptr); + // We still need the spirv for reflection. Remove this when we use the Tint inspector + // completely. + std::vector reflectionSpirv; + DAWN_TRY_ASSIGN(reflectionSpirv, ModuleToSPIRV(mTintProgram.get())); + DAWN_TRY(ValidateSpirv(reflectionSpirv.data(), reflectionSpirv.size())); - DAWN_TRY_ASSIGN(localSpirv, ModuleToSPIRV(program)); - DAWN_TRY(ValidateSpirv(localSpirv.data(), localSpirv.size())); - spirvPtr = &localSpirv; + EntryPointMetadataTable table; + DAWN_TRY_ASSIGN(table, ReflectShaderUsingTint(GetDevice(), mTintProgram.get())); + DAWN_TRY(PopulateMetadataUsingSPIRVCross(GetDevice(), reflectionSpirv, &table)); + mEntryPoints = std::move(table); #else UNREACHABLE(); -#endif - } - - if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) { -#ifdef DAWN_ENABLE_WGSL - tint::Program localProgram; - - tint::Program* programPtr = program; - if (!GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) { - // We have mSpirv, but no Tint program - DAWN_TRY_ASSIGN(localProgram, ParseSPIRV(mSpirv)); - DAWN_TRY(ValidateModule(&localProgram)); - programPtr = &localProgram; - } - - EntryPointMetadataTable table; - DAWN_TRY_ASSIGN(table, ReflectShaderUsingTint(GetDevice(), *programPtr)); - DAWN_TRY(PopulateMetadataUsingSPIRVCross(GetDevice(), *spirvPtr, &table)); - mEntryPoints = std::move(table); -#else - return DAWN_VALIDATION_ERROR("Using Tint is not enabled in this build."); #endif } else { - DAWN_TRY_ASSIGN(mEntryPoints, ReflectShaderUsingSPIRVCross(GetDevice(), *spirvPtr)); + DAWN_TRY_ASSIGN(mEntryPoints, ReflectShaderUsingSPIRVCross(GetDevice(), mSpirv)); } return {}; diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h index 61b6973f0f..c428fb40be 100644 --- a/src/dawn_native/ShaderModule.h +++ b/src/dawn_native/ShaderModule.h @@ -77,7 +77,7 @@ namespace dawn_native { const PipelineLayoutBase* layout); #ifdef DAWN_ENABLE_WGSL ResultOrError RunTransforms(tint::transform::Transform* transform, - tint::Program* program); + const tint::Program* program); std::unique_ptr MakeVertexPullingTransform( const VertexStateDescriptor& vertexState, @@ -147,6 +147,8 @@ namespace dawn_native { const std::vector& GetSpirv() const; #ifdef DAWN_ENABLE_WGSL + const tint::Program* GetTintProgram() const; + ResultOrError> GeneratePullingSpirv( const std::vector& spirv, const VertexStateDescriptor& vertexState, @@ -154,7 +156,7 @@ namespace dawn_native { BindGroupIndex pullingBufferBindingSet) const; ResultOrError> GeneratePullingSpirv( - tint::Program* program, + const tint::Program* program, const VertexStateDescriptor& vertexState, const std::string& entryPoint, BindGroupIndex pullingBufferBindingSet) const; @@ -166,13 +168,19 @@ namespace dawn_native { private: ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag); + // The original data in the descriptor for caching. enum class Type { Undefined, Spirv, Wgsl }; Type mType; std::vector mOriginalSpirv; - std::vector mSpirv; std::string mWgsl; + // Data computed from what is in the descriptor. mSpirv is set iff !UseTintGenerator while + // mTintProgram is set iff UseTintGenerator. EntryPointMetadataTable mEntryPoints; + std::vector mSpirv; +#ifdef DAWN_ENABLE_WGSL + std::unique_ptr mTintProgram; +#endif }; } // namespace dawn_native diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp index d9ce9a6391..bcd1844a39 100644 --- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp +++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp @@ -186,11 +186,7 @@ namespace dawn_native { namespace d3d12 { } MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) { - DAWN_TRY(InitializeBase(parseResult)); -#ifdef DAWN_ENABLE_WGSL - mTintProgram = std::move(parseResult->tintProgram); -#endif - return {}; + return InitializeBase(parseResult); } ResultOrError ShaderModule::TranslateToHLSLWithTint( @@ -215,7 +211,7 @@ namespace dawn_native { namespace d3d12 { transformManager.append(std::make_unique()); transformManager.append(std::make_unique()); - tint::transform::Transform::Output output = transformManager.Run(mTintProgram.get()); + tint::transform::Transform::Output output = transformManager.Run(GetTintProgram()); tint::Program& program = output.program; if (!program.IsValid()) { diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.h b/src/dawn_native/d3d12/ShaderModuleD3D12.h index b3850bca25..3cd44e480c 100644 --- a/src/dawn_native/d3d12/ShaderModuleD3D12.h +++ b/src/dawn_native/d3d12/ShaderModuleD3D12.h @@ -76,10 +76,6 @@ namespace dawn_native { namespace d3d12 { ResultOrError GetDXCompilerVersion() const; uint64_t GetD3DCompilerVersion() const; - -#ifdef DAWN_ENABLE_WGSL - std::unique_ptr mTintProgram; -#endif }; }} // namespace dawn_native::d3d12 diff --git a/src/dawn_native/metal/ShaderModuleMTL.h b/src/dawn_native/metal/ShaderModuleMTL.h index 621c198c3c..3d777e94e0 100644 --- a/src/dawn_native/metal/ShaderModuleMTL.h +++ b/src/dawn_native/metal/ShaderModuleMTL.h @@ -69,10 +69,6 @@ namespace dawn_native { namespace metal { ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor); ~ShaderModule() override = default; MaybeError Initialize(ShaderModuleParseResult* parseResult); - -#ifdef DAWN_ENABLE_WGSL - std::unique_ptr mTintProgram; -#endif }; }} // namespace dawn_native::metal diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm index 1f9df97781..3357444125 100644 --- a/src/dawn_native/metal/ShaderModuleMTL.mm +++ b/src/dawn_native/metal/ShaderModuleMTL.mm @@ -48,11 +48,7 @@ namespace dawn_native { namespace metal { } MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) { - DAWN_TRY(InitializeBase(parseResult)); -#ifdef DAWN_ENABLE_WGSL - mTintProgram = std::move(parseResult->tintProgram); -#endif - return {}; + return InitializeBase(parseResult); } ResultOrError ShaderModule::TranslateToMSLWithTint( @@ -90,7 +86,7 @@ namespace dawn_native { namespace metal { transformManager.append(std::make_unique()); transformManager.append(std::make_unique()); - tint::transform::Transform::Output output = transformManager.Run(mTintProgram.get()); + tint::transform::Transform::Output output = transformManager.Run(GetTintProgram()); tint::Program& program = output.program; if (!program.IsValid()) { @@ -137,9 +133,9 @@ namespace dawn_native { namespace metal { std::vector pullingSpirv; if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) && stage == SingleShaderStage::Vertex) { - if (mTintProgram) { + if (GetDevice()->IsToggleEnabled(Toggle::UseTintGenerator)) { DAWN_TRY_ASSIGN(pullingSpirv, - GeneratePullingSpirv(mTintProgram.get(), + GeneratePullingSpirv(GetTintProgram(), *renderPipeline->GetVertexStateDescriptor(), entryPointName, kPullingBufferBindingSet)); } else {