diff --git a/src/dawn_native/vulkan/ShaderModuleVk.cpp b/src/dawn_native/vulkan/ShaderModuleVk.cpp index b8b74753f4..bd8988065c 100644 --- a/src/dawn_native/vulkan/ShaderModuleVk.cpp +++ b/src/dawn_native/vulkan/ShaderModuleVk.cpp @@ -77,7 +77,9 @@ namespace dawn_native { namespace vulkan { } ShaderModule::ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor) - : ShaderModuleBase(device, descriptor), mTransformedShaderModuleCache(device) { + : ShaderModuleBase(device, descriptor), + mTransformedShaderModuleCache( + std::make_unique(device)) { } MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult) { @@ -97,16 +99,24 @@ namespace dawn_native { namespace vulkan { return InitializeBase(parseResult); } + void ShaderModule::DestroyImpl() { + // Remove reference to internal cache to trigger cleanup. + mTransformedShaderModuleCache = nullptr; + } + ShaderModule::~ShaderModule() = default; ResultOrError ShaderModule::GetTransformedModuleHandle( const char* entryPointName, PipelineLayout* layout) { + // If the shader was destroyed, we should never call this function. + ASSERT(IsAlive()); + ScopedTintICEHandler scopedICEHandler(GetDevice()); auto cacheKey = std::make_pair(layout, entryPointName); VkShaderModule cachedShaderModule = - mTransformedShaderModuleCache.FindShaderModule(cacheKey); + mTransformedShaderModuleCache->FindShaderModule(cacheKey); if (cachedShaderModule != VK_NULL_HANDLE) { return cachedShaderModule; } @@ -180,7 +190,7 @@ namespace dawn_native { namespace vulkan { "CreateShaderModule")); if (newHandle != VK_NULL_HANDLE) { newHandle = - mTransformedShaderModuleCache.AddOrGetCachedShaderModule(cacheKey, newHandle); + mTransformedShaderModuleCache->AddOrGetCachedShaderModule(cacheKey, newHandle); } SetDebugName(ToBackend(GetDevice()), VK_OBJECT_TYPE_SHADER_MODULE, diff --git a/src/dawn_native/vulkan/ShaderModuleVk.h b/src/dawn_native/vulkan/ShaderModuleVk.h index b0a7bc33fc..9d28f91d4d 100644 --- a/src/dawn_native/vulkan/ShaderModuleVk.h +++ b/src/dawn_native/vulkan/ShaderModuleVk.h @@ -40,6 +40,7 @@ namespace dawn_native { namespace vulkan { ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor); ~ShaderModule() override; MaybeError Initialize(ShaderModuleParseResult* parseResult); + void DestroyImpl() override; // New handles created by GetTransformedModuleHandle at pipeline creation time class ConcurrentTransformedShaderModuleCache { @@ -58,7 +59,7 @@ namespace dawn_native { namespace vulkan { PipelineLayoutEntryPointPairHashFunc> mTransformedShaderModuleCache; }; - ConcurrentTransformedShaderModuleCache mTransformedShaderModuleCache; + std::unique_ptr mTransformedShaderModuleCache; }; }} // namespace dawn_native::vulkan