diff --git a/src/dawn_native/ComputePipeline.cpp b/src/dawn_native/ComputePipeline.cpp index a2fb60f652..0ae2911969 100644 --- a/src/dawn_native/ComputePipeline.cpp +++ b/src/dawn_native/ComputePipeline.cpp @@ -14,6 +14,7 @@ #include "dawn_native/ComputePipeline.h" +#include "common/HashUtils.h" #include "dawn_native/Device.h" namespace dawn_native { @@ -33,8 +34,12 @@ namespace dawn_native { // ComputePipelineBase ComputePipelineBase::ComputePipelineBase(DeviceBase* device, - const ComputePipelineDescriptor* descriptor) - : PipelineBase(device, descriptor->layout, dawn::ShaderStageBit::Compute) { + const ComputePipelineDescriptor* descriptor, + bool blueprint) + : PipelineBase(device, descriptor->layout, dawn::ShaderStageBit::Compute), + mModule(descriptor->computeStage->module), + mEntryPoint(descriptor->computeStage->entryPoint), + mIsBlueprint(blueprint) { ExtractModuleData(dawn::ShaderStage::Compute, descriptor->computeStage->module); } @@ -42,9 +47,29 @@ namespace dawn_native { : PipelineBase(device, tag) { } + ComputePipelineBase::~ComputePipelineBase() { + // Do not uncache the actual cached object if we are a blueprint + if (!mIsBlueprint) { + ASSERT(!IsError()); + GetDevice()->UncacheComputePipeline(this); + } + } + // static ComputePipelineBase* ComputePipelineBase::MakeError(DeviceBase* device) { return new ComputePipelineBase(device, ObjectBase::kError); } + size_t ComputePipelineBase::HashFunc::operator()(const ComputePipelineBase* pipeline) const { + size_t hash = 0; + HashCombine(&hash, pipeline->mModule.Get(), pipeline->mEntryPoint, pipeline->GetLayout()); + return hash; + } + + bool ComputePipelineBase::EqualityFunc::operator()(const ComputePipelineBase* a, + const ComputePipelineBase* b) const { + return a->mModule.Get() == b->mModule.Get() && a->mEntryPoint == b->mEntryPoint && + a->GetLayout() == b->GetLayout(); + } + } // namespace dawn_native diff --git a/src/dawn_native/ComputePipeline.h b/src/dawn_native/ComputePipeline.h index c1450d3a9c..006c469d96 100644 --- a/src/dawn_native/ComputePipeline.h +++ b/src/dawn_native/ComputePipeline.h @@ -26,12 +26,28 @@ namespace dawn_native { class ComputePipelineBase : public PipelineBase { public: - ComputePipelineBase(DeviceBase* device, const ComputePipelineDescriptor* descriptor); + ComputePipelineBase(DeviceBase* device, + const ComputePipelineDescriptor* descriptor, + bool blueprint = false); + ~ComputePipelineBase() override; static ComputePipelineBase* MakeError(DeviceBase* device); + // Functors necessary for the unordered_set-based cache. + struct HashFunc { + size_t operator()(const ComputePipelineBase* pipeline) const; + }; + struct EqualityFunc { + bool operator()(const ComputePipelineBase* a, const ComputePipelineBase* b) const; + }; + private: ComputePipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag); + + // TODO(cwallez@chromium.org): Store a crypto hash of the module instead. + Ref mModule; + std::string mEntryPoint; + bool mIsBlueprint = false; }; } // namespace dawn_native diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp index 844feb0778..c11a67e8e6 100644 --- a/src/dawn_native/Device.cpp +++ b/src/dawn_native/Device.cpp @@ -48,6 +48,7 @@ namespace dawn_native { struct DeviceBase::Caches { ContentLessObjectCache bindGroupLayouts; + ContentLessObjectCache computePipelines; ContentLessObjectCache pipelineLayouts; ContentLessObjectCache shaderModules; }; @@ -121,6 +122,27 @@ namespace dawn_native { ASSERT(removedCount == 1); } + ResultOrError DeviceBase::GetOrCreateComputePipeline( + const ComputePipelineDescriptor* descriptor) { + ComputePipelineBase blueprint(this, descriptor, true); + + auto iter = mCaches->computePipelines.find(&blueprint); + if (iter != mCaches->computePipelines.end()) { + (*iter)->Reference(); + return *iter; + } + + ComputePipelineBase* backendObj; + DAWN_TRY_ASSIGN(backendObj, CreateComputePipelineImpl(descriptor)); + mCaches->computePipelines.insert(backendObj); + return backendObj; + } + + void DeviceBase::UncacheComputePipeline(ComputePipelineBase* obj) { + size_t removedCount = mCaches->computePipelines.erase(obj); + ASSERT(removedCount == 1); + } + ResultOrError DeviceBase::GetOrCreatePipelineLayout( const PipelineLayoutDescriptor* descriptor) { PipelineLayoutBase blueprint(this, descriptor, true); @@ -369,7 +391,7 @@ namespace dawn_native { ComputePipelineBase** result, const ComputePipelineDescriptor* descriptor) { DAWN_TRY(ValidateComputePipelineDescriptor(this, descriptor)); - DAWN_TRY_ASSIGN(*result, CreateComputePipelineImpl(descriptor)); + DAWN_TRY_ASSIGN(*result, GetOrCreateComputePipeline(descriptor)); return {}; } diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h index aae0751cc5..0addd7803c 100644 --- a/src/dawn_native/Device.h +++ b/src/dawn_native/Device.h @@ -84,6 +84,10 @@ namespace dawn_native { const BindGroupLayoutDescriptor* descriptor); void UncacheBindGroupLayout(BindGroupLayoutBase* obj); + ResultOrError GetOrCreateComputePipeline( + const ComputePipelineDescriptor* descriptor); + void UncacheComputePipeline(ComputePipelineBase* obj); + ResultOrError GetOrCreatePipelineLayout( const PipelineLayoutDescriptor* descriptor); void UncachePipelineLayout(PipelineLayoutBase* obj); diff --git a/src/dawn_native/Pipeline.cpp b/src/dawn_native/Pipeline.cpp index e839b1b5b4..8c245e1aa9 100644 --- a/src/dawn_native/Pipeline.cpp +++ b/src/dawn_native/Pipeline.cpp @@ -87,4 +87,9 @@ namespace dawn_native { return mLayout.Get(); } + const PipelineLayoutBase* PipelineBase::GetLayout() const { + ASSERT(!IsError()); + return mLayout.Get(); + } + } // namespace dawn_native diff --git a/src/dawn_native/Pipeline.h b/src/dawn_native/Pipeline.h index c917125bd4..55d57bf02e 100644 --- a/src/dawn_native/Pipeline.h +++ b/src/dawn_native/Pipeline.h @@ -48,6 +48,7 @@ namespace dawn_native { const PushConstantInfo& GetPushConstants(dawn::ShaderStage stage) const; dawn::ShaderStageBit GetStageMask() const; PipelineLayoutBase* GetLayout(); + const PipelineLayoutBase* GetLayout() const; protected: PipelineBase(DeviceBase* device, PipelineLayoutBase* layout, dawn::ShaderStageBit stages); diff --git a/src/tests/end2end/ObjectCachingTests.cpp b/src/tests/end2end/ObjectCachingTests.cpp index 64b6ae15f1..43feae8561 100644 --- a/src/tests/end2end/ObjectCachingTests.cpp +++ b/src/tests/end2end/ObjectCachingTests.cpp @@ -75,4 +75,86 @@ TEST_P(ObjectCachingTest, ShaderModuleDeduplication) { EXPECT_EQ(module.Get() == sameModule.Get(), !UsesWire()); } +// Test that ComputePipeline are correctly deduplicated wrt. their ShaderModule +TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnShaderModule) { + dawn::ShaderModule module = utils::CreateShaderModule(device, dawn::ShaderStage::Compute, R"( + #version 450 + void main() { + int i = 0; + })"); + dawn::ShaderModule sameModule = + utils::CreateShaderModule(device, dawn::ShaderStage::Compute, R"( + #version 450 + void main() { + int i = 0; + })"); + dawn::ShaderModule otherModule = + utils::CreateShaderModule(device, dawn::ShaderStage::Compute, R"( + #version 450 + void main() { + })"); + + EXPECT_NE(module.Get(), otherModule.Get()); + EXPECT_EQ(module.Get() == sameModule.Get(), !UsesWire()); + + dawn::PipelineLayout layout = utils::MakeBasicPipelineLayout(device, nullptr); + + dawn::PipelineStageDescriptor stageDesc; + stageDesc.entryPoint = "main"; + stageDesc.module = module; + + dawn::ComputePipelineDescriptor desc; + desc.computeStage = &stageDesc; + desc.layout = layout; + + dawn::ComputePipeline pipeline = device.CreateComputePipeline(&desc); + + stageDesc.module = sameModule; + dawn::ComputePipeline samePipeline = device.CreateComputePipeline(&desc); + + stageDesc.module = otherModule; + dawn::ComputePipeline otherPipeline = device.CreateComputePipeline(&desc); + + EXPECT_NE(pipeline.Get(), otherPipeline.Get()); + EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire()); +} + +// Test that ComputePipeline are correctly deduplicated wrt. their layout +TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnLayout) { + dawn::BindGroupLayout bgl = utils::MakeBindGroupLayout( + device, {{1, dawn::ShaderStageBit::Fragment, dawn::BindingType::UniformBuffer}}); + dawn::BindGroupLayout otherBgl = utils::MakeBindGroupLayout( + device, {{1, dawn::ShaderStageBit::Vertex, dawn::BindingType::UniformBuffer}}); + + dawn::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, &bgl); + dawn::PipelineLayout samePl = utils::MakeBasicPipelineLayout(device, &bgl); + dawn::PipelineLayout otherPl = utils::MakeBasicPipelineLayout(device, nullptr); + + EXPECT_NE(pl.Get(), otherPl.Get()); + EXPECT_EQ(pl.Get() == samePl.Get(), !UsesWire()); + + dawn::PipelineStageDescriptor stageDesc; + stageDesc.entryPoint = "main"; + stageDesc.module = utils::CreateShaderModule(device, dawn::ShaderStage::Compute, R"( + #version 450 + void main() { + int i = 0; + })"); + + dawn::ComputePipelineDescriptor desc; + desc.computeStage = &stageDesc; + + desc.layout = pl; + dawn::ComputePipeline pipeline = device.CreateComputePipeline(&desc); + + desc.layout = samePl; + dawn::ComputePipeline samePipeline = device.CreateComputePipeline(&desc); + + desc.layout = otherPl; + dawn::ComputePipeline otherPipeline = device.CreateComputePipeline(&desc); + + EXPECT_NE(pipeline.Get(), otherPipeline.Get()); + EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire()); +} + DAWN_INSTANTIATE_TEST(ObjectCachingTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);