diff --git a/src/dawn_native/CommandBufferStateTracker.cpp b/src/dawn_native/CommandBufferStateTracker.cpp index 3e806a18e3..3418753d53 100644 --- a/src/dawn_native/CommandBufferStateTracker.cpp +++ b/src/dawn_native/CommandBufferStateTracker.cpp @@ -26,11 +26,11 @@ namespace dawn_native { namespace { bool BufferSizesAtLeastAsBig(const ityp::span unverifiedBufferSizes, - const std::vector& pipelineMinimumBufferSizes) { - ASSERT(unverifiedBufferSizes.size() == pipelineMinimumBufferSizes.size()); + const std::vector& pipelineMinBufferSizes) { + ASSERT(unverifiedBufferSizes.size() == pipelineMinBufferSizes.size()); for (uint32_t i = 0; i < unverifiedBufferSizes.size(); ++i) { - if (unverifiedBufferSizes[i] < pipelineMinimumBufferSizes[i]) { + if (unverifiedBufferSizes[i] < pipelineMinBufferSizes[i]) { return false; } } @@ -105,7 +105,7 @@ namespace dawn_native { if (mBindgroups[i] == nullptr || mLastPipelineLayout->GetBindGroupLayout(i) != mBindgroups[i]->GetLayout() || !BufferSizesAtLeastAsBig(mBindgroups[i]->GetUnverifiedBufferSizes(), - (*mMinimumBufferSizes)[i])) { + (*mMinBufferSizes)[i])) { matches = false; break; } @@ -190,7 +190,7 @@ namespace dawn_native { "Pipeline and bind group layout doesn't match for bind group " + std::to_string(static_cast(i))); } else if (!BufferSizesAtLeastAsBig(mBindgroups[i]->GetUnverifiedBufferSizes(), - (*mMinimumBufferSizes)[i])) { + (*mMinBufferSizes)[i])) { return DAWN_VALIDATION_ERROR("Binding sizes too small for bind group " + std::to_string(static_cast(i))); } @@ -236,7 +236,7 @@ namespace dawn_native { void CommandBufferStateTracker::SetPipelineCommon(PipelineBase* pipeline) { mLastPipelineLayout = pipeline->GetLayout(); - mMinimumBufferSizes = &pipeline->GetMinimumBufferSizes(); + mMinBufferSizes = &pipeline->GetMinBufferSizes(); mAspects.set(VALIDATION_ASPECT_PIPELINE); diff --git a/src/dawn_native/CommandBufferStateTracker.h b/src/dawn_native/CommandBufferStateTracker.h index 67645e4580..146214db8e 100644 --- a/src/dawn_native/CommandBufferStateTracker.h +++ b/src/dawn_native/CommandBufferStateTracker.h @@ -61,7 +61,7 @@ namespace dawn_native { PipelineLayoutBase* mLastPipelineLayout = nullptr; RenderPipelineBase* mLastRenderPipeline = nullptr; - const RequiredBufferSizes* mMinimumBufferSizes = nullptr; + const RequiredBufferSizes* mMinBufferSizes = nullptr; }; } // namespace dawn_native diff --git a/src/dawn_native/ComputePipeline.cpp b/src/dawn_native/ComputePipeline.cpp index ee49b1151d..793765dc47 100644 --- a/src/dawn_native/ComputePipeline.cpp +++ b/src/dawn_native/ComputePipeline.cpp @@ -19,13 +19,6 @@ namespace dawn_native { - namespace { - RequiredBufferSizes ComputeMinBufferSizes(const ComputePipelineDescriptor* descriptor) { - return descriptor->computeStage.module->ComputeRequiredBufferSizesForLayout( - descriptor->layout); - } - } // anonymous namespace - MaybeError ValidateComputePipelineDescriptor(DeviceBase* device, const ComputePipelineDescriptor* descriptor) { if (descriptor->nextInChain != nullptr) { @@ -47,10 +40,7 @@ namespace dawn_native { const ComputePipelineDescriptor* descriptor) : PipelineBase(device, descriptor->layout, - wgpu::ShaderStage::Compute, - ComputeMinBufferSizes(descriptor)), - mModule(descriptor->computeStage.module), - mEntryPoint(descriptor->computeStage.entryPoint) { + {{SingleShaderStage::Compute, &descriptor->computeStage}}) { } ComputePipelineBase::ComputePipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag) @@ -70,15 +60,12 @@ namespace dawn_native { } size_t ComputePipelineBase::HashFunc::operator()(const ComputePipelineBase* pipeline) const { - size_t hash = 0; - HashCombine(&hash, pipeline->mModule.Get(), pipeline->mEntryPoint, pipeline->GetLayout()); - return hash; + return PipelineBase::HashForCache(pipeline); } bool ComputePipelineBase::EqualityFunc::operator()(const ComputePipelineBase* a, const ComputePipelineBase* b) const { - return a->mModule.Get() == b->mModule.Get() && a->mEntryPoint == b->mEntryPoint && - a->GetLayout() == b->GetLayout(); + return PipelineBase::EqualForCache(a, b); } } // namespace dawn_native diff --git a/src/dawn_native/ComputePipeline.h b/src/dawn_native/ComputePipeline.h index 43d7966568..c2f118899c 100644 --- a/src/dawn_native/ComputePipeline.h +++ b/src/dawn_native/ComputePipeline.h @@ -41,10 +41,6 @@ namespace dawn_native { private: ComputePipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag); - - // TODO(cwallez@chromium.org): Store a crypto hash of the module instead. - Ref mModule; - std::string mEntryPoint; }; } // namespace dawn_native diff --git a/src/dawn_native/Pipeline.cpp b/src/dawn_native/Pipeline.cpp index ae05c02f4b..f37e30cebc 100644 --- a/src/dawn_native/Pipeline.cpp +++ b/src/dawn_native/Pipeline.cpp @@ -14,6 +14,7 @@ #include "dawn_native/Pipeline.h" +#include "common/HashUtils.h" #include "dawn_native/BindGroupLayout.h" #include "dawn_native/Device.h" #include "dawn_native/PipelineLayout.h" @@ -43,23 +44,38 @@ namespace dawn_native { PipelineBase::PipelineBase(DeviceBase* device, PipelineLayoutBase* layout, - wgpu::ShaderStage stages, - RequiredBufferSizes minimumBufferSizes) - : CachedObject(device), - mStageMask(stages), - mLayout(layout), - mMinimumBufferSizes(std::move(minimumBufferSizes)) { + std::vector stages) + : CachedObject(device), mLayout(layout) { + ASSERT(!stages.empty()); + + for (const StageAndDescriptor& stage : stages) { + bool isFirstStage = mStageMask == wgpu::ShaderStage::None; + mStageMask |= StageBit(stage.first); + mStages[stage.first] = {stage.second->module, stage.second->entryPoint}; + + // Compute the max() of all minBufferSizes across all stages. + RequiredBufferSizes stageMinBufferSizes = + stage.second->module->ComputeRequiredBufferSizesForLayout(layout); + + if (isFirstStage) { + mMinBufferSizes = std::move(stageMinBufferSizes); + } else { + for (BindGroupIndex group(0); group < mMinBufferSizes.size(); ++group) { + ASSERT(stageMinBufferSizes[group].size() == mMinBufferSizes[group].size()); + + for (size_t i = 0; i < stageMinBufferSizes[group].size(); ++i) { + mMinBufferSizes[group][i] = + std::max(mMinBufferSizes[group][i], stageMinBufferSizes[group][i]); + } + } + } + } } PipelineBase::PipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag) : CachedObject(device, tag) { } - wgpu::ShaderStage PipelineBase::GetStageMask() const { - ASSERT(!IsError()); - return mStageMask; - } - PipelineLayoutBase* PipelineBase::GetLayout() { ASSERT(!IsError()); return mLayout.Get(); @@ -70,9 +86,14 @@ namespace dawn_native { return mLayout.Get(); } - const RequiredBufferSizes& PipelineBase::GetMinimumBufferSizes() const { + const RequiredBufferSizes& PipelineBase::GetMinBufferSizes() const { ASSERT(!IsError()); - return mMinimumBufferSizes; + return mMinBufferSizes; + } + + const ProgrammableStage& PipelineBase::GetStage(SingleShaderStage stage) const { + ASSERT(!IsError()); + return mStages[stage]; } MaybeError PipelineBase::ValidateGetBindGroupLayout(uint32_t groupIndex) { @@ -102,4 +123,39 @@ namespace dawn_native { return bgl; } + // static + size_t PipelineBase::HashForCache(const PipelineBase* pipeline) { + size_t hash = 0; + + // The layout is deduplicated so it can be hashed by pointer. + HashCombine(&hash, pipeline->mLayout.Get()); + + HashCombine(&hash, pipeline->mStageMask); + for (SingleShaderStage stage : IterateStages(pipeline->mStageMask)) { + // The module is deduplicated so it can be hashed by pointer. + HashCombine(&hash, pipeline->mStages[stage].module.Get()); + HashCombine(&hash, pipeline->mStages[stage].entryPoint); + } + + return hash; + } + + // static + bool PipelineBase::EqualForCache(const PipelineBase* a, const PipelineBase* b) { + // The layout is deduplicated so it can be compared by pointer. + if (a->mLayout.Get() != b->mLayout.Get() || a->mStageMask != b->mStageMask) { + return false; + } + + for (SingleShaderStage stage : IterateStages(a->mStageMask)) { + // The module is deduplicated so it can be compared by pointer. + if (a->mStages[stage].module.Get() != b->mStages[stage].module.Get() || + a->mStages[stage].entryPoint != b->mStages[stage].entryPoint) { + return false; + } + } + + return true; + } + } // namespace dawn_native diff --git a/src/dawn_native/Pipeline.h b/src/dawn_native/Pipeline.h index bfc846bcde..9df2db6884 100644 --- a/src/dawn_native/Pipeline.h +++ b/src/dawn_native/Pipeline.h @@ -33,27 +33,40 @@ namespace dawn_native { const PipelineLayoutBase* layout, SingleShaderStage stage); + struct ProgrammableStage { + Ref module; + std::string entryPoint; + }; + class PipelineBase : public CachedObject { public: - wgpu::ShaderStage GetStageMask() const; PipelineLayoutBase* GetLayout(); const PipelineLayoutBase* GetLayout() const; + const RequiredBufferSizes& GetMinBufferSizes() const; + const ProgrammableStage& GetStage(SingleShaderStage stage) const; + BindGroupLayoutBase* GetBindGroupLayout(uint32_t groupIndex); - const RequiredBufferSizes& GetMinimumBufferSizes() const; + + // Helper function for the functors for std::unordered_map-based pipeline caches. + static size_t HashForCache(const PipelineBase* pipeline); + static bool EqualForCache(const PipelineBase* a, const PipelineBase* b); protected: + using StageAndDescriptor = std::pair; + PipelineBase(DeviceBase* device, PipelineLayoutBase* layout, - wgpu::ShaderStage stages, - RequiredBufferSizes bufferSizes); + std::vector stages); PipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag); private: MaybeError ValidateGetBindGroupLayout(uint32_t group); - wgpu::ShaderStage mStageMask; + wgpu::ShaderStage mStageMask = wgpu::ShaderStage::None; + PerStage mStages; + Ref mLayout; - RequiredBufferSizes mMinimumBufferSizes; + RequiredBufferSizes mMinBufferSizes; }; } // namespace dawn_native diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp index 83d9b11631..1341ee5331 100644 --- a/src/dawn_native/RenderPipeline.cpp +++ b/src/dawn_native/RenderPipeline.cpp @@ -193,29 +193,6 @@ namespace dawn_native { return {}; } - RequiredBufferSizes ComputeMinBufferSizes(const RenderPipelineDescriptor* descriptor) { - RequiredBufferSizes bufferSizes = - descriptor->vertexStage.module->ComputeRequiredBufferSizesForLayout( - descriptor->layout); - - // Merge the two buffer size requirements by taking the larger element from each - if (descriptor->fragmentStage != nullptr) { - RequiredBufferSizes fragmentSizes = - descriptor->fragmentStage->module->ComputeRequiredBufferSizesForLayout( - descriptor->layout); - - for (BindGroupIndex group(0); group < bufferSizes.size(); ++group) { - ASSERT(bufferSizes[group].size() == fragmentSizes[group].size()); - for (size_t i = 0; i < bufferSizes[group].size(); ++i) { - bufferSizes[group][i] = - std::max(bufferSizes[group][i], fragmentSizes[group][i]); - } - } - } - - return bufferSizes; - } - } // anonymous namespace // Helper functions @@ -411,16 +388,12 @@ namespace dawn_native { const RenderPipelineDescriptor* descriptor) : PipelineBase(device, descriptor->layout, - wgpu::ShaderStage::Vertex | wgpu::ShaderStage::Fragment, - ComputeMinBufferSizes(descriptor)), + {{SingleShaderStage::Vertex, &descriptor->vertexStage}, + {SingleShaderStage::Fragment, descriptor->fragmentStage}}), mAttachmentState(device->GetOrCreateAttachmentState(descriptor)), mPrimitiveTopology(descriptor->primitiveTopology), mSampleMask(descriptor->sampleMask), - mAlphaToCoverageEnabled(descriptor->alphaToCoverageEnabled), - mVertexModule(descriptor->vertexStage.module), - mVertexEntryPoint(descriptor->vertexStage.entryPoint), - mFragmentModule(descriptor->fragmentStage->module), - mFragmentEntryPoint(descriptor->fragmentStage->entryPoint) { + mAlphaToCoverageEnabled(descriptor->alphaToCoverageEnabled) { if (descriptor->vertexState != nullptr) { mVertexState = *descriptor->vertexState; } else { @@ -608,12 +581,8 @@ namespace dawn_native { } size_t RenderPipelineBase::HashFunc::operator()(const RenderPipelineBase* pipeline) const { - size_t hash = 0; - // Hash modules and layout - HashCombine(&hash, pipeline->GetLayout()); - HashCombine(&hash, pipeline->mVertexModule.Get(), pipeline->mFragmentEntryPoint); - HashCombine(&hash, pipeline->mFragmentModule.Get(), pipeline->mFragmentEntryPoint); + size_t hash = PipelineBase::HashForCache(pipeline); // Hierarchically hash the attachment state. // It contains the attachments set, texture formats, and sample count. @@ -671,11 +640,8 @@ namespace dawn_native { bool RenderPipelineBase::EqualityFunc::operator()(const RenderPipelineBase* a, const RenderPipelineBase* b) const { - // Check modules and layout - if (a->GetLayout() != b->GetLayout() || a->mVertexModule.Get() != b->mVertexModule.Get() || - a->mVertexEntryPoint != b->mVertexEntryPoint || - a->mFragmentModule.Get() != b->mFragmentModule.Get() || - a->mFragmentEntryPoint != b->mFragmentEntryPoint) { + // Check the layout and shader stages. + if (!PipelineBase::EqualForCache(a, b)) { return false; } diff --git a/src/dawn_native/RenderPipeline.h b/src/dawn_native/RenderPipeline.h index 5c328f9f28..aee8d3dbaf 100644 --- a/src/dawn_native/RenderPipeline.h +++ b/src/dawn_native/RenderPipeline.h @@ -114,13 +114,6 @@ namespace dawn_native { RasterizationStateDescriptor mRasterizationState; uint32_t mSampleMask; bool mAlphaToCoverageEnabled; - - // Stage information - // TODO(cwallez@chromium.org): Store a crypto hash of the modules instead. - Ref mVertexModule; - std::string mVertexEntryPoint; - Ref mFragmentModule; - std::string mFragmentEntryPoint; }; } // namespace dawn_native