From c535198d96f9047bf14c603677449326d0f1342a Mon Sep 17 00:00:00 2001 From: Corentin Wallez Date: Wed, 1 May 2019 13:27:07 +0000 Subject: [PATCH] dawn_native: deduplicate shader modules BUG=dawn:143 Change-Id: I2c0fa63e3a6d77c137418f12b9807d16a0636d57 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/6862 Commit-Queue: Corentin Wallez Reviewed-by: Kai Ninomiya --- src/dawn_native/Device.cpp | 24 +++++++++++++++++- src/dawn_native/Device.h | 4 +++ src/dawn_native/ShaderModule.cpp | 32 ++++++++++++++++++++++-- src/dawn_native/ShaderModule.h | 18 ++++++++++++- src/tests/end2end/ObjectCachingTests.cpp | 27 ++++++++++++++++++++ 5 files changed, 101 insertions(+), 4 deletions(-) diff --git a/src/dawn_native/Device.cpp b/src/dawn_native/Device.cpp index 33bbdec36e..844feb0778 100644 --- a/src/dawn_native/Device.cpp +++ b/src/dawn_native/Device.cpp @@ -49,6 +49,7 @@ namespace dawn_native { struct DeviceBase::Caches { ContentLessObjectCache bindGroupLayouts; ContentLessObjectCache pipelineLayouts; + ContentLessObjectCache shaderModules; }; // DeviceBase @@ -141,6 +142,27 @@ namespace dawn_native { ASSERT(removedCount == 1); } + ResultOrError DeviceBase::GetOrCreateShaderModule( + const ShaderModuleDescriptor* descriptor) { + ShaderModuleBase blueprint(this, descriptor, true); + + auto iter = mCaches->shaderModules.find(&blueprint); + if (iter != mCaches->shaderModules.end()) { + (*iter)->Reference(); + return *iter; + } + + ShaderModuleBase* backendObj; + DAWN_TRY_ASSIGN(backendObj, CreateShaderModuleImpl(descriptor)); + mCaches->shaderModules.insert(backendObj); + return backendObj; + } + + void DeviceBase::UncacheShaderModule(ShaderModuleBase* obj) { + size_t removedCount = mCaches->shaderModules.erase(obj); + ASSERT(removedCount == 1); + } + // Object creation API methods BindGroupBase* DeviceBase::CreateBindGroup(const BindGroupDescriptor* descriptor) { @@ -382,7 +404,7 @@ namespace dawn_native { MaybeError DeviceBase::CreateShaderModuleInternal(ShaderModuleBase** result, const ShaderModuleDescriptor* descriptor) { DAWN_TRY(ValidateShaderModuleDescriptor(this, descriptor)); - DAWN_TRY_ASSIGN(*result, CreateShaderModuleImpl(descriptor)); + DAWN_TRY_ASSIGN(*result, GetOrCreateShaderModule(descriptor)); return {}; } diff --git a/src/dawn_native/Device.h b/src/dawn_native/Device.h index 3f4a913ef1..aae0751cc5 100644 --- a/src/dawn_native/Device.h +++ b/src/dawn_native/Device.h @@ -88,6 +88,10 @@ namespace dawn_native { const PipelineLayoutDescriptor* descriptor); void UncachePipelineLayout(PipelineLayoutBase* obj); + ResultOrError GetOrCreateShaderModule( + const ShaderModuleDescriptor* descriptor); + void UncacheShaderModule(ShaderModuleBase* obj); + // Dawn API BindGroupBase* CreateBindGroup(const BindGroupDescriptor* descriptor); BindGroupLayoutBase* CreateBindGroupLayout(const BindGroupLayoutDescriptor* descriptor); diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp index ca6548d2bc..1cdf18fdf1 100644 --- a/src/dawn_native/ShaderModule.cpp +++ b/src/dawn_native/ShaderModule.cpp @@ -14,6 +14,7 @@ #include "dawn_native/ShaderModule.h" +#include "common/HashUtils.h" #include "dawn_native/BindGroupLayout.h" #include "dawn_native/Device.h" #include "dawn_native/Pipeline.h" @@ -67,14 +68,26 @@ namespace dawn_native { // ShaderModuleBase - ShaderModuleBase::ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor*) - : ObjectBase(device) { + ShaderModuleBase::ShaderModuleBase(DeviceBase* device, + const ShaderModuleDescriptor* descriptor, + bool blueprint) + : ObjectBase(device), + mCode(descriptor->code, descriptor->code + descriptor->codeSize), + mIsBlueprint(blueprint) { } ShaderModuleBase::ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag) : ObjectBase(device, tag) { } + ShaderModuleBase::~ShaderModuleBase() { + // Do not uncache the actual cached object if we are a blueprint + if (!mIsBlueprint) { + ASSERT(!IsError()); + GetDevice()->UncacheShaderModule(this); + } + } + // static ShaderModuleBase* ShaderModuleBase::MakeError(DeviceBase* device) { return new ShaderModuleBase(device, ObjectBase::kError); @@ -287,4 +300,19 @@ namespace dawn_native { return true; } + size_t ShaderModuleBase::HashFunc::operator()(const ShaderModuleBase* module) const { + size_t hash = 0; + + for (uint32_t word : module->mCode) { + HashCombine(&hash, word); + } + + return hash; + } + + bool ShaderModuleBase::EqualityFunc::operator()(const ShaderModuleBase* a, + const ShaderModuleBase* b) const { + return a->mCode == b->mCode; + } + } // namespace dawn_native diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h index b8020f9b7b..ab00c273fb 100644 --- a/src/dawn_native/ShaderModule.h +++ b/src/dawn_native/ShaderModule.h @@ -37,7 +37,10 @@ namespace dawn_native { class ShaderModuleBase : public ObjectBase { public: - ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor); + ShaderModuleBase(DeviceBase* device, + const ShaderModuleDescriptor* descriptor, + bool blueprint = false); + ~ShaderModuleBase() override; static ShaderModuleBase* MakeError(DeviceBase* device); @@ -68,11 +71,24 @@ namespace dawn_native { bool IsCompatibleWithPipelineLayout(const PipelineLayoutBase* layout); + // Functors necessary for the unordered_set-based cache. + struct HashFunc { + size_t operator()(const ShaderModuleBase* module) const; + }; + struct EqualityFunc { + bool operator()(const ShaderModuleBase* a, const ShaderModuleBase* b) const; + }; + private: ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag); bool IsCompatibleWithBindGroupLayout(size_t group, const BindGroupLayoutBase* layout); + // TODO(cwallez@chromium.org): The code is only stored for deduplication. We could maybe + // store a cryptographic hash of the code instead? + std::vector mCode; + bool mIsBlueprint = false; + PushConstantInfo mPushConstants = {}; ModuleBindingInfo mBindingInfo; std::bitset mUsedVertexAttributes; diff --git a/src/tests/end2end/ObjectCachingTests.cpp b/src/tests/end2end/ObjectCachingTests.cpp index d2a591b819..64b6ae15f1 100644 --- a/src/tests/end2end/ObjectCachingTests.cpp +++ b/src/tests/end2end/ObjectCachingTests.cpp @@ -48,4 +48,31 @@ TEST_P(ObjectCachingTest, PipelineLayoutDeduplication) { EXPECT_EQ(pl.Get() == samePl.Get(), !UsesWire()); } +// Test that ShaderModules are correctly deduplicated. +TEST_P(ObjectCachingTest, ShaderModuleDeduplication) { + dawn::ShaderModule module = utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"( + #version 450 + layout(location = 0) out vec4 fragColor; + void main() { + fragColor = vec4(0.0, 1.0, 0.0, 1.0); + })"); + dawn::ShaderModule sameModule = + utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"( + #version 450 + layout(location = 0) out vec4 fragColor; + void main() { + fragColor = vec4(0.0, 1.0, 0.0, 1.0); + })"); + dawn::ShaderModule otherModule = + utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"( + #version 450 + layout(location = 0) out vec4 fragColor; + void main() { + fragColor = vec4(0.0); + })"); + + EXPECT_NE(module.Get(), otherModule.Get()); + EXPECT_EQ(module.Get() == sameModule.Get(), !UsesWire()); +} + DAWN_INSTANTIATE_TEST(ObjectCachingTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);