Add a render pipeline cache to dawn_native::DeviceBase.

DeviceBase::CreateRenderPipelineInternal now goes through
GetOrCreateRenderPipeline, which looks up a stack-allocated "blueprint"
RenderPipelineBase in mCaches->renderPipelines and only calls
CreateRenderPipelineImpl when no equal pipeline is already cached. Cached
pipelines remove themselves from the cache in ~RenderPipelineBase; blueprints
and error objects never do.

RenderPipelineBase::HashFunc combines the vertex module with the vertex entry
point and the fragment module with the fragment entry point, so the hash
covers the same stage fields that EqualityFunc compares.

diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp
index c11a67e8e6..753f709a4e 100644
--- a/src/dawn_native/Device.cpp
+++ b/src/dawn_native/Device.cpp
@@ -50,6 +50,7 @@ namespace dawn_native {
         ContentLessObjectCache<BindGroupLayoutBase> bindGroupLayouts;
         ContentLessObjectCache<ComputePipelineBase> computePipelines;
         ContentLessObjectCache<PipelineLayoutBase> pipelineLayouts;
+        ContentLessObjectCache<RenderPipelineBase> renderPipelines;
         ContentLessObjectCache<ShaderModuleBase> shaderModules;
     };
 
@@ -164,6 +165,27 @@ namespace dawn_native {
         ASSERT(removedCount == 1);
     }
 
+    ResultOrError<RenderPipelineBase*> DeviceBase::GetOrCreateRenderPipeline(
+        const RenderPipelineDescriptor* descriptor) {
+        RenderPipelineBase blueprint(this, descriptor, true);
+
+        auto iter = mCaches->renderPipelines.find(&blueprint);
+        if (iter != mCaches->renderPipelines.end()) {
+            (*iter)->Reference();
+            return *iter;
+        }
+
+        RenderPipelineBase* backendObj;
+        DAWN_TRY_ASSIGN(backendObj, CreateRenderPipelineImpl(descriptor));
+        mCaches->renderPipelines.insert(backendObj);
+        return backendObj;
+    }
+
+    void DeviceBase::UncacheRenderPipeline(RenderPipelineBase* obj) {
+        size_t removedCount = mCaches->renderPipelines.erase(obj);
+        ASSERT(removedCount == 1);
+    }
+
     ResultOrError<ShaderModuleBase*> DeviceBase::GetOrCreateShaderModule(
         const ShaderModuleDescriptor* descriptor) {
         ShaderModuleBase blueprint(this, descriptor, true);
@@ -412,7 +434,7 @@ namespace dawn_native {
         RenderPipelineBase** result,
         const RenderPipelineDescriptor* descriptor) {
         DAWN_TRY(ValidateRenderPipelineDescriptor(this, descriptor));
-        DAWN_TRY_ASSIGN(*result, CreateRenderPipelineImpl(descriptor));
+        DAWN_TRY_ASSIGN(*result, GetOrCreateRenderPipeline(descriptor));
         return {};
     }
 
diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h
index 0addd7803c..b9039e43d1 100644
--- a/src/dawn_native/Device.h
+++ b/src/dawn_native/Device.h
@@ -92,6 +92,10 @@ namespace dawn_native {
             const PipelineLayoutDescriptor* descriptor);
         void UncachePipelineLayout(PipelineLayoutBase* obj);
 
+        ResultOrError<RenderPipelineBase*> GetOrCreateRenderPipeline(
+            const RenderPipelineDescriptor* descriptor);
+        void UncacheRenderPipeline(RenderPipelineBase* obj);
+
         ResultOrError<ShaderModuleBase*> GetOrCreateShaderModule(
             const ShaderModuleDescriptor* descriptor);
         void UncacheShaderModule(ShaderModuleBase* obj);
diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp
index 84344ee816..e44d2879ee 100644
--- a/src/dawn_native/RenderPipeline.cpp
+++ b/src/dawn_native/RenderPipeline.cpp
@@ -15,6 +15,7 @@
 #include "dawn_native/RenderPipeline.h"
 
 #include "common/BitSetIterator.h"
+#include "common/HashUtils.h"
 #include "dawn_native/Commands.h"
 #include "dawn_native/Device.h"
 #include "dawn_native/Texture.h"
@@ -328,15 +329,21 @@ namespace dawn_native {
     // RenderPipelineBase
 
     RenderPipelineBase::RenderPipelineBase(DeviceBase* device,
-                                           const RenderPipelineDescriptor* descriptor)
+                                           const RenderPipelineDescriptor* descriptor,
+                                           bool blueprint)
         : PipelineBase(device,
                        descriptor->layout,
                        dawn::ShaderStageBit::Vertex | dawn::ShaderStageBit::Fragment),
           mInputState(*descriptor->inputState),
+          mHasDepthStencilAttachment(descriptor->depthStencilState != nullptr),
           mPrimitiveTopology(descriptor->primitiveTopology),
           mRasterizationState(*descriptor->rasterizationState),
-          mHasDepthStencilAttachment(descriptor->depthStencilState != nullptr),
-          mSampleCount(descriptor->sampleCount) {
+          mSampleCount(descriptor->sampleCount),
+          mVertexModule(descriptor->vertexStage->module),
+          mVertexEntryPoint(descriptor->vertexStage->entryPoint),
+          mFragmentModule(descriptor->fragmentStage->module),
+          mFragmentEntryPoint(descriptor->fragmentStage->entryPoint),
+          mIsBlueprint(blueprint) {
         uint32_t location = 0;
         for (uint32_t i = 0; i < mInputState.numAttributes; ++i) {
             location = mInputState.attributes[i].shaderLocation;
@@ -391,6 +398,13 @@ namespace dawn_native {
         return new RenderPipelineBase(device, ObjectBase::kError);
     }
 
+    RenderPipelineBase::~RenderPipelineBase() {
+        // Do not uncache the actual cached object if we are a blueprint
+        if (!mIsBlueprint && !IsError()) {
+            GetDevice()->UncacheRenderPipeline(this);
+        }
+    }
+
     const InputStateDescriptor* RenderPipelineBase::GetInputStateDescriptor() const {
         ASSERT(!IsError());
         return &mInputState;
@@ -419,13 +433,13 @@ namespace dawn_native {
     }
 
     const ColorStateDescriptor* RenderPipelineBase::GetColorStateDescriptor(
-        uint32_t attachmentSlot) {
+        uint32_t attachmentSlot) const {
         ASSERT(!IsError());
         ASSERT(attachmentSlot < mColorStates.size());
         return &mColorStates[attachmentSlot];
     }
 
-    const DepthStencilStateDescriptor* RenderPipelineBase::GetDepthStencilStateDescriptor() {
+    const DepthStencilStateDescriptor* RenderPipelineBase::GetDepthStencilStateDescriptor() const {
         ASSERT(!IsError());
         return &mDepthStencilState;
     }
@@ -509,4 +523,175 @@ namespace dawn_native {
         return attributesUsingInput[slot];
     }
 
+    size_t RenderPipelineBase::HashFunc::operator()(const RenderPipelineBase* pipeline) const {
+        size_t hash = 0;
+
+        // Hash modules and layout
+        HashCombine(&hash, pipeline->GetLayout());
+        HashCombine(&hash, pipeline->mVertexModule.Get(), pipeline->mVertexEntryPoint);
+        HashCombine(&hash, pipeline->mFragmentModule.Get(), pipeline->mFragmentEntryPoint);
+
+        // Hash attachments
+        HashCombine(&hash, pipeline->mColorAttachmentsSet);
+        for (uint32_t i : IterateBitSet(pipeline->mColorAttachmentsSet)) {
+            const ColorStateDescriptor& desc = *pipeline->GetColorStateDescriptor(i);
+            HashCombine(&hash, desc.format, desc.writeMask);
+            HashCombine(&hash, desc.colorBlend.operation, desc.colorBlend.srcFactor,
+                        desc.colorBlend.dstFactor);
+            HashCombine(&hash, desc.alphaBlend.operation, desc.alphaBlend.srcFactor,
+                        desc.alphaBlend.dstFactor);
+        }
+
+        if (pipeline->mHasDepthStencilAttachment) {
+            const DepthStencilStateDescriptor& desc = pipeline->mDepthStencilState;
+            HashCombine(&hash, desc.format, desc.depthWriteEnabled, desc.depthCompare);
+            HashCombine(&hash, desc.stencilReadMask, desc.stencilWriteMask);
+            HashCombine(&hash, desc.stencilFront.compare, desc.stencilFront.failOp,
+                        desc.stencilFront.depthFailOp, desc.stencilFront.passOp);
+            HashCombine(&hash, desc.stencilBack.compare, desc.stencilBack.failOp,
+                        desc.stencilBack.depthFailOp, desc.stencilBack.passOp);
+        }
+
+        // Hash vertex input state
+        HashCombine(&hash, pipeline->mAttributesSetMask);
+        for (uint32_t i : IterateBitSet(pipeline->mAttributesSetMask)) {
+            const VertexAttributeDescriptor& desc = pipeline->GetAttribute(i);
+            HashCombine(&hash, desc.shaderLocation, desc.inputSlot, desc.offset, desc.format);
+        }
+
+        HashCombine(&hash, pipeline->mInputsSetMask);
+        for (uint32_t i : IterateBitSet(pipeline->mInputsSetMask)) {
+            const VertexInputDescriptor& desc = pipeline->GetInput(i);
+            HashCombine(&hash, desc.inputSlot, desc.stride, desc.stepMode);
+        }
+
+        HashCombine(&hash, pipeline->mInputState.indexFormat);
+
+        // Hash rasterization state
+        {
+            const RasterizationStateDescriptor& desc = pipeline->mRasterizationState;
+            HashCombine(&hash, desc.frontFace, desc.cullMode);
+            HashCombine(&hash, desc.depthBias, desc.depthBiasSlopeScale, desc.depthBiasClamp);
+        }
+
+        // Hash other state
+        HashCombine(&hash, pipeline->mSampleCount, pipeline->mPrimitiveTopology);
+
+        return hash;
+    }
+
+    bool RenderPipelineBase::EqualityFunc::operator()(const RenderPipelineBase* a,
+                                                      const RenderPipelineBase* b) const {
+        // Check modules and layout
+        if (a->GetLayout() != b->GetLayout() || a->mVertexModule.Get() != b->mVertexModule.Get() ||
+            a->mVertexEntryPoint != b->mVertexEntryPoint ||
+            a->mFragmentModule.Get() != b->mFragmentModule.Get() ||
+            a->mFragmentEntryPoint != b->mFragmentEntryPoint) {
+            return false;
+        }
+
+        // Check attachments
+        if (a->mColorAttachmentsSet != b->mColorAttachmentsSet ||
+            a->mHasDepthStencilAttachment != b->mHasDepthStencilAttachment) {
+            return false;
+        }
+
+        for (uint32_t i : IterateBitSet(a->mColorAttachmentsSet)) {
+            const ColorStateDescriptor& descA = *a->GetColorStateDescriptor(i);
+            const ColorStateDescriptor& descB = *b->GetColorStateDescriptor(i);
+            if (descA.format != descB.format || descA.writeMask != descB.writeMask) {
+                return false;
+            }
+            if (descA.colorBlend.operation != descB.colorBlend.operation ||
+                descA.colorBlend.srcFactor != descB.colorBlend.srcFactor ||
+                descA.colorBlend.dstFactor != descB.colorBlend.dstFactor) {
+                return false;
+            }
+            if (descA.alphaBlend.operation != descB.alphaBlend.operation ||
+                descA.alphaBlend.srcFactor != descB.alphaBlend.srcFactor ||
+                descA.alphaBlend.dstFactor != descB.alphaBlend.dstFactor) {
+                return false;
+            }
+        }
+
+        if (a->mHasDepthStencilAttachment) {
+            const DepthStencilStateDescriptor& descA = a->mDepthStencilState;
+            const DepthStencilStateDescriptor& descB = b->mDepthStencilState;
+            if (descA.format != descB.format ||
+                descA.depthWriteEnabled != descB.depthWriteEnabled ||
+                descA.depthCompare != descB.depthCompare) {
+                return false;
+            }
+            if (descA.stencilReadMask != descB.stencilReadMask ||
+                descA.stencilWriteMask != descB.stencilWriteMask) {
+                return false;
+            }
+            if (descA.stencilFront.compare != descB.stencilFront.compare ||
+                descA.stencilFront.failOp != descB.stencilFront.failOp ||
+                descA.stencilFront.depthFailOp != descB.stencilFront.depthFailOp ||
+                descA.stencilFront.passOp != descB.stencilFront.passOp) {
+                return false;
+            }
+            if (descA.stencilBack.compare != descB.stencilBack.compare ||
+                descA.stencilBack.failOp != descB.stencilBack.failOp ||
+                descA.stencilBack.depthFailOp != descB.stencilBack.depthFailOp ||
+                descA.stencilBack.passOp != descB.stencilBack.passOp) {
+                return false;
+            }
+        }
+
+        // Check vertex input state
+        if (a->mAttributesSetMask != b->mAttributesSetMask) {
+            return false;
+        }
+
+        for (uint32_t i : IterateBitSet(a->mAttributesSetMask)) {
+            const VertexAttributeDescriptor& descA = a->GetAttribute(i);
+            const VertexAttributeDescriptor& descB = b->GetAttribute(i);
+            if (descA.shaderLocation != descB.shaderLocation ||
+                descA.inputSlot != descB.inputSlot || descA.offset != descB.offset ||
+                descA.format != descB.format) {
+                return false;
+            }
+        }
+
+        if (a->mInputsSetMask != b->mInputsSetMask) {
+            return false;
+        }
+
+        for (uint32_t i : IterateBitSet(a->mInputsSetMask)) {
+            const VertexInputDescriptor& descA = a->GetInput(i);
+            const VertexInputDescriptor& descB = b->GetInput(i);
+            if (descA.inputSlot != descB.inputSlot || descA.stride != descB.stride ||
+                descA.stepMode != descB.stepMode) {
+                return false;
+            }
+        }
+
+        if (a->mInputState.indexFormat != b->mInputState.indexFormat) {
+            return false;
+        }
+
+        // Check rasterization state
+        {
+            const RasterizationStateDescriptor& descA = a->mRasterizationState;
+            const RasterizationStateDescriptor& descB = b->mRasterizationState;
+            if (descA.frontFace != descB.frontFace || descA.cullMode != descB.cullMode) {
+                return false;
+            }
+            if (descA.depthBias != descB.depthBias ||
+                descA.depthBiasSlopeScale != descB.depthBiasSlopeScale ||
+                descA.depthBiasClamp != descB.depthBiasClamp) {
+                return false;
+            }
+        }
+
+        // Check other state
+        if (a->mSampleCount != b->mSampleCount || a->mPrimitiveTopology != b->mPrimitiveTopology) {
+            return false;
+        }
+
+        return true;
+    }
+
 }  // namespace dawn_native
diff --git a/src/dawn_native/RenderPipeline.h b/src/dawn_native/RenderPipeline.h
index f72afd7e97..98fc1adfb1 100644
--- a/src/dawn_native/RenderPipeline.h
+++ b/src/dawn_native/RenderPipeline.h
@@ -40,7 +40,10 @@ namespace dawn_native {
 
     class RenderPipelineBase : public PipelineBase {
       public:
-        RenderPipelineBase(DeviceBase* device, const RenderPipelineDescriptor* descriptor);
+        RenderPipelineBase(DeviceBase* device,
+                           const RenderPipelineDescriptor* descriptor,
+                           bool blueprint = false);
+        ~RenderPipelineBase() override;
 
         static RenderPipelineBase* MakeError(DeviceBase* device);
 
@@ -50,8 +53,8 @@ namespace dawn_native {
         const std::bitset<kMaxVertexInputs>& GetInputsSetMask() const;
         const VertexInputDescriptor& GetInput(uint32_t slot) const;
 
-        const ColorStateDescriptor* GetColorStateDescriptor(uint32_t attachmentSlot);
-        const DepthStencilStateDescriptor* GetDepthStencilStateDescriptor();
+        const ColorStateDescriptor* GetColorStateDescriptor(uint32_t attachmentSlot) const;
+        const DepthStencilStateDescriptor* GetDepthStencilStateDescriptor() const;
         dawn::PrimitiveTopology GetPrimitiveTopology() const;
         dawn::CullMode GetCullMode() const;
         dawn::FrontFace GetFrontFace() const;
@@ -68,23 +71,43 @@ namespace dawn_native {
         std::bitset<kMaxVertexAttributes> GetAttributesUsingInput(uint32_t slot) const;
         std::array<std::bitset<kMaxVertexAttributes>, kMaxVertexInputs> attributesUsingInput;
 
+        // Functors necessary for the unordered_set-based cache.
+        struct HashFunc {
+            size_t operator()(const RenderPipelineBase* pipeline) const;
+        };
+        struct EqualityFunc {
+            bool operator()(const RenderPipelineBase* a, const RenderPipelineBase* b) const;
+        };
+
       private:
         RenderPipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag);
 
+        // Vertex input
         InputStateDescriptor mInputState;
         std::bitset<kMaxVertexAttributes> mAttributesSetMask;
         std::array<VertexAttributeDescriptor, kMaxVertexAttributes> mAttributeInfos;
         std::bitset<kMaxVertexInputs> mInputsSetMask;
         std::array<VertexInputDescriptor, kMaxVertexInputs> mInputInfos;
-        dawn::PrimitiveTopology mPrimitiveTopology;
-        RasterizationStateDescriptor mRasterizationState;
+
+        // Attachments
+        bool mHasDepthStencilAttachment = false;
         DepthStencilStateDescriptor mDepthStencilState;
+        std::bitset<kMaxColorAttachments> mColorAttachmentsSet;
         std::array<ColorStateDescriptor, kMaxColorAttachments> mColorStates;
 
-        std::bitset<kMaxColorAttachments> mColorAttachmentsSet;
-        bool mHasDepthStencilAttachment = false;
-
+        // Other state
+        dawn::PrimitiveTopology mPrimitiveTopology;
+        RasterizationStateDescriptor mRasterizationState;
         uint32_t mSampleCount;
+
+        // Stage information
+        // TODO(cwallez@chromium.org): Store a crypto hash of the modules instead.
+        Ref<ShaderModuleBase> mVertexModule;
+        std::string mVertexEntryPoint;
+        Ref<ShaderModuleBase> mFragmentModule;
+        std::string mFragmentEntryPoint;
+
+        bool mIsBlueprint = false;
     };
 
 }  // namespace dawn_native
diff --git a/src/tests/end2end/ObjectCachingTests.cpp b/src/tests/end2end/ObjectCachingTests.cpp
index 5172a6b77d..8b8e300981 100644
--- a/src/tests/end2end/ObjectCachingTests.cpp
+++ b/src/tests/end2end/ObjectCachingTests.cpp
@@ -14,6 +14,7 @@
 
 #include "tests/DawnTest.h"
 
+#include "utils/ComboRenderPipelineDescriptor.h"
 #include "utils/DawnHelpers.h"
 
 class ObjectCachingTest : public DawnTest {};
@@ -165,4 +166,124 @@ TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnLayout) {
     EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
 }
 
+// Test that RenderPipelines are correctly deduplicated wrt. their layout
+TEST_P(ObjectCachingTest, RenderPipelineDeduplicationOnLayout) {
+    dawn::BindGroupLayout bgl = utils::MakeBindGroupLayout(
+        device, {{1, dawn::ShaderStageBit::Fragment, dawn::BindingType::UniformBuffer}});
+    dawn::BindGroupLayout otherBgl = utils::MakeBindGroupLayout(
+        device, {{1, dawn::ShaderStageBit::Vertex, dawn::BindingType::UniformBuffer}});
+
+    dawn::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, &bgl);
+    dawn::PipelineLayout samePl = utils::MakeBasicPipelineLayout(device, &bgl);
+    dawn::PipelineLayout otherPl = utils::MakeBasicPipelineLayout(device, nullptr);
+
+    EXPECT_NE(pl.Get(), otherPl.Get());
+    EXPECT_EQ(pl.Get() == samePl.Get(), !UsesWire());
+
+    utils::ComboRenderPipelineDescriptor desc(device);
+    desc.cVertexStage.module = utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        })");
+    desc.cFragmentStage.module = utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
+        #version 450
+        void main() {
+        })");
+
+    desc.layout = pl;
+    dawn::RenderPipeline pipeline = device.CreateRenderPipeline(&desc);
+
+    desc.layout = samePl;
+    dawn::RenderPipeline samePipeline = device.CreateRenderPipeline(&desc);
+
+    desc.layout = otherPl;
+    dawn::RenderPipeline otherPipeline = device.CreateRenderPipeline(&desc);
+
+    EXPECT_NE(pipeline.Get(), otherPipeline.Get());
+    EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
+}
+
+// Test that RenderPipelines are correctly deduplicated wrt. their vertex module
+TEST_P(ObjectCachingTest, RenderPipelineDeduplicationOnVertexModule) {
+    dawn::ShaderModule module = utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        })");
+    dawn::ShaderModule sameModule = utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        })");
+    dawn::ShaderModule otherModule =
+        utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(1.0);
+        })");
+
+    EXPECT_NE(module.Get(), otherModule.Get());
+    EXPECT_EQ(module.Get() == sameModule.Get(), !UsesWire());
+
+    utils::ComboRenderPipelineDescriptor desc(device);
+    desc.cFragmentStage.module = utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
+        #version 450
+        void main() {
+        })");
+
+    desc.cVertexStage.module = module;
+    dawn::RenderPipeline pipeline = device.CreateRenderPipeline(&desc);
+
+    desc.cVertexStage.module = sameModule;
+    dawn::RenderPipeline samePipeline = device.CreateRenderPipeline(&desc);
+
+    desc.cVertexStage.module = otherModule;
+    dawn::RenderPipeline otherPipeline = device.CreateRenderPipeline(&desc);
+
+    EXPECT_NE(pipeline.Get(), otherPipeline.Get());
+    EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
+}
+
+// Test that RenderPipelines are correctly deduplicated wrt. their fragment module
+TEST_P(ObjectCachingTest, RenderPipelineDeduplicationOnFragmentModule) {
+    dawn::ShaderModule module = utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
+        #version 450
+        void main() {
+        })");
+    dawn::ShaderModule sameModule =
+        utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
+        #version 450
+        void main() {
+        })");
+    dawn::ShaderModule otherModule =
+        utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
+        #version 450
+        void main() {
+            int i = 0;
+        })");
+
+    EXPECT_NE(module.Get(), otherModule.Get());
+    EXPECT_EQ(module.Get() == sameModule.Get(), !UsesWire());
+
+    utils::ComboRenderPipelineDescriptor desc(device);
+    desc.cVertexStage.module = utils::CreateShaderModule(device, dawn::ShaderStage::Vertex, R"(
+        #version 450
+        void main() {
+            gl_Position = vec4(0.0);
+        })");
+
+    desc.cFragmentStage.module = module;
+    dawn::RenderPipeline pipeline = device.CreateRenderPipeline(&desc);
+
+    desc.cFragmentStage.module = sameModule;
+    dawn::RenderPipeline samePipeline = device.CreateRenderPipeline(&desc);
+
+    desc.cFragmentStage.module = otherModule;
+    dawn::RenderPipeline otherPipeline = device.CreateRenderPipeline(&desc);
+
+    EXPECT_NE(pipeline.Get(), otherPipeline.Get());
+    EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
+}
+
 DAWN_INSTANTIATE_TEST(ObjectCachingTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);