Add compute pipeline cache key generation for Vulkan.

- Adds dependency to vulkan-tools for pNext chain helpers.
- Adds extra caching to vulkan shaders to keep the spirv in the in-memory cache as well.
- Adds pNext chain serializer infra for Vulkan.

Change-Id: Ibe73183fbff15f7310eaaeae92fbd622be1ac096
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/85022
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
Commit-Queue: Loko Kung <lokokung@google.com>
This commit is contained in:
Loko Kung 2022-04-09 00:10:08 +00:00 committed by Dawn LUCI CQ
parent 30353998c1
commit 23d09c6114
16 changed files with 340 additions and 52 deletions

View File

@ -142,7 +142,8 @@ set_if_not_defined(DAWN_JINJA2_DIR "${DAWN_THIRD_PARTY_DIR}/jinja2" "Directory i
set_if_not_defined(DAWN_SPIRV_HEADERS_DIR "${DAWN_THIRD_PARTY_DIR}/vulkan-deps/spirv-headers/src" "Directory in which to find SPIRV-Headers") set_if_not_defined(DAWN_SPIRV_HEADERS_DIR "${DAWN_THIRD_PARTY_DIR}/vulkan-deps/spirv-headers/src" "Directory in which to find SPIRV-Headers")
set_if_not_defined(DAWN_SPIRV_TOOLS_DIR "${DAWN_THIRD_PARTY_DIR}/vulkan-deps/spirv-tools/src" "Directory in which to find SPIRV-Tools") set_if_not_defined(DAWN_SPIRV_TOOLS_DIR "${DAWN_THIRD_PARTY_DIR}/vulkan-deps/spirv-tools/src" "Directory in which to find SPIRV-Tools")
set_if_not_defined(DAWN_TINT_DIR "${Dawn_SOURCE_DIR}" "Directory in which to find Tint") set_if_not_defined(DAWN_TINT_DIR "${Dawn_SOURCE_DIR}" "Directory in which to find Tint")
set_if_not_defined(DAWN_VULKAN_HEADERS_DIR "${DAWN_THIRD_PARTY_DIR}/vulkan-deps/vulkan-headers/src" "Directory in which to find Vulkan-Headers") set_if_not_defined(DAWN_VULKAN_DEPS_DIR "${DAWN_THIRD_PARTY_DIR}/vulkan-deps" "Directory in which to find vulkan-deps")
set_if_not_defined(DAWN_VULKAN_HEADERS_DIR "${DAWN_VULKAN_DEPS_DIR}/vulkan-headers/src" "Directory in which to find Vulkan-Headers")
# Dependencies for DAWN_BUILD_NODE_BINDINGS # Dependencies for DAWN_BUILD_NODE_BINDINGS
set_if_not_defined(NODE_ADDON_API_DIR "${DAWN_THIRD_PARTY_DIR}/node-addon-api" "Directory in which to find node-addon-api") set_if_not_defined(NODE_ADDON_API_DIR "${DAWN_THIRD_PARTY_DIR}/node-addon-api" "Directory in which to find node-addon-api")

View File

@ -54,14 +54,17 @@ if (!defined(dawn_swiftshader_dir)) {
dawn_swiftshader_dir = "" dawn_swiftshader_dir = ""
} }
if (!defined(dawn_vulkan_headers_dir)) { if (!defined(dawn_vulkan_deps_dir)) {
dawn_vulkan_headers_dir = "//third_party/vulkan-deps/vulkan-headers/src" dawn_vulkan_deps_dir = "//third_party/vulkan-deps"
if (dawn_standalone) { if (dawn_standalone) {
dawn_vulkan_headers_dir = dawn_vulkan_deps_dir = "${dawn_root}/third_party/vulkan-deps"
"${dawn_root}/third_party/vulkan-deps/vulkan-headers/src"
} }
} }
if (!defined(dawn_vulkan_headers_dir)) {
dawn_vulkan_headers_dir = "${dawn_vulkan_deps_dir}/vulkan-headers/src"
}
if (!defined(dawn_vulkan_loader_dir)) { if (!defined(dawn_vulkan_loader_dir)) {
# Default to the Vulkan loader not being available except in standalone. # Default to the Vulkan loader not being available except in standalone.
dawn_vulkan_loader_dir = "" dawn_vulkan_loader_dir = ""
@ -70,6 +73,10 @@ if (!defined(dawn_vulkan_loader_dir)) {
} }
} }
if (!defined(dawn_vulkan_tools_dir)) {
dawn_vulkan_tools_dir = "${dawn_vulkan_deps_dir}/vulkan-tools/src"
}
if (!defined(dawn_vulkan_validation_layers_dir)) { if (!defined(dawn_vulkan_validation_layers_dir)) {
# Default to VVLs not being available. # Default to VVLs not being available.
dawn_vulkan_validation_layers_dir = "" dawn_vulkan_validation_layers_dir = ""

View File

@ -99,6 +99,11 @@ config("vulkan_rpath") {
} }
} }
# Config that adds include directory for vulkan-deps, specifically for Vulkan-Tools.
config("vulkan_deps_include") {
include_dirs = [ "${dawn_vulkan_deps_dir}" ]
}
dawn_json_generator("utils_gen") { dawn_json_generator("utils_gen") {
target = "native_utils" target = "native_utils"
outputs = [ outputs = [
@ -571,6 +576,8 @@ source_set("sources") {
} }
if (dawn_enable_vulkan) { if (dawn_enable_vulkan) {
configs += [ ":vulkan_deps_include" ]
deps += [ "${dawn_vulkan_tools_dir}:vulkan_tools_headers" ]
public_deps += [ "${dawn_vulkan_headers_dir}:vulkan_headers" ] public_deps += [ "${dawn_vulkan_headers_dir}:vulkan_headers" ]
sources += [ sources += [
"vulkan/AdapterVk.cpp", "vulkan/AdapterVk.cpp",
@ -583,6 +590,8 @@ source_set("sources") {
"vulkan/BindGroupVk.h", "vulkan/BindGroupVk.h",
"vulkan/BufferVk.cpp", "vulkan/BufferVk.cpp",
"vulkan/BufferVk.h", "vulkan/BufferVk.h",
"vulkan/CacheKeyVk.cpp",
"vulkan/CacheKeyVk.h",
"vulkan/CommandBufferVk.cpp", "vulkan/CommandBufferVk.cpp",
"vulkan/CommandBufferVk.h", "vulkan/CommandBufferVk.h",
"vulkan/CommandRecordingContext.h", "vulkan/CommandRecordingContext.h",

View File

@ -455,6 +455,8 @@ if (DAWN_ENABLE_VULKAN)
"vulkan/BindGroupVk.h" "vulkan/BindGroupVk.h"
"vulkan/BufferVk.cpp" "vulkan/BufferVk.cpp"
"vulkan/BufferVk.h" "vulkan/BufferVk.h"
"vulkan/CacheKeyVk.cpp"
"vulkan/CacheKeyVk.h"
"vulkan/CommandBufferVk.cpp" "vulkan/CommandBufferVk.cpp"
"vulkan/CommandBufferVk.h" "vulkan/CommandBufferVk.h"
"vulkan/CommandRecordingContext.h" "vulkan/CommandRecordingContext.h"
@ -510,6 +512,7 @@ if (DAWN_ENABLE_VULKAN)
) )
target_link_libraries(dawn_native PUBLIC Vulkan-Headers) target_link_libraries(dawn_native PUBLIC Vulkan-Headers)
target_include_directories(dawn_native PRIVATE ${DAWN_VULKAN_DEPS_DIR})
if (UNIX AND NOT APPLE) if (UNIX AND NOT APPLE)
target_sources(dawn_native PRIVATE target_sources(dawn_native PRIVATE

View File

@ -14,8 +14,19 @@
#include "dawn/native/CacheKey.h" #include "dawn/native/CacheKey.h"
#include <iomanip>
namespace dawn::native { namespace dawn::native {
std::ostream& operator<<(std::ostream& os, const CacheKey& key) {
os << std::hex;
for (const int b : key) {
os << std::setfill('0') << std::setw(2) << b << " ";
}
os << std::dec;
return os;
}
template <> template <>
void CacheKeySerializer<std::string>::Serialize(CacheKey* key, const std::string& t) { void CacheKeySerializer<std::string>::Serialize(CacheKey* key, const std::string& t) {
key->Record(static_cast<size_t>(t.length())); key->Record(static_cast<size_t>(t.length()));

View File

@ -15,17 +15,20 @@
#ifndef DAWNNATIVE_CACHE_KEY_H_ #ifndef DAWNNATIVE_CACHE_KEY_H_
#define DAWNNATIVE_CACHE_KEY_H_ #define DAWNNATIVE_CACHE_KEY_H_
#include <iostream>
#include <limits> #include <limits>
#include <string> #include <string>
#include <type_traits> #include <type_traits>
#include <vector> #include <vector>
#include "dawn/common/Assert.h"
namespace dawn::native { namespace dawn::native {
// Forward declare CacheKey class because of co-dependency. // Forward declare classes because of co-dependency.
class CacheKey; class CacheKey;
class CachedObject;
// Stream operator for CacheKey for debugging.
std::ostream& operator<<(std::ostream& os, const CacheKey& key);
// Overridable serializer struct that should be implemented for cache key serializable // Overridable serializer struct that should be implemented for cache key serializable
// types/classes. // types/classes.
@ -82,7 +85,31 @@ namespace dawn::native {
} }
}; };
// Specialized overload for string literals. Note we drop the null-terminator. // Specialized overload for enums.
template <typename T>
class CacheKeySerializer<T, std::enable_if_t<std::is_enum_v<T>>> {
public:
static void Serialize(CacheKey* key, const T t) {
CacheKeySerializer<std::underlying_type_t<T>>::Serialize(
key, static_cast<std::underlying_type_t<T>>(t));
}
};
// Specialized overload for pointers. Since we are serializing for a cache key, we always
// serialize via value, not by pointer. To handle nullptr scenarios, we always serialize whether
// the pointer was nullptr followed by the contents if applicable.
template <typename T>
class CacheKeySerializer<T, std::enable_if_t<std::is_pointer_v<T>>> {
public:
static void Serialize(CacheKey* key, const T t) {
key->Record(t == nullptr);
if (t != nullptr) {
CacheKeySerializer<std::remove_cv_t<std::remove_pointer_t<T>>>::Serialize(key, *t);
}
}
};
// Specialized overload for string literals.
template <size_t N> template <size_t N>
class CacheKeySerializer<char[N]> { class CacheKeySerializer<char[N]> {
public: public:
@ -93,6 +120,15 @@ namespace dawn::native {
} }
}; };
// Specialized overload for CachedObjects.
template <typename T>
class CacheKeySerializer<T, std::enable_if_t<std::is_base_of_v<CachedObject, T>>> {
public:
static void Serialize(CacheKey* key, const T& t) {
key->Record(t.GetCacheKey());
}
};
} // namespace dawn::native } // namespace dawn::native
#endif // DAWNNATIVE_CACHE_KEY_H_ #endif // DAWNNATIVE_CACHE_KEY_H_

View File

@ -16,6 +16,7 @@
#include "dawn/common/BitSetIterator.h" #include "dawn/common/BitSetIterator.h"
#include "dawn/common/ityp_vector.h" #include "dawn/common/ityp_vector.h"
#include "dawn/native/CacheKey.h"
#include "dawn/native/vulkan/BindGroupVk.h" #include "dawn/native/vulkan/BindGroupVk.h"
#include "dawn/native/vulkan/DescriptorSetAllocator.h" #include "dawn/native/vulkan/DescriptorSetAllocator.h"
#include "dawn/native/vulkan/DeviceVk.h" #include "dawn/native/vulkan/DeviceVk.h"
@ -115,6 +116,9 @@ namespace dawn::native::vulkan {
createInfo.bindingCount = static_cast<uint32_t>(bindings.size()); createInfo.bindingCount = static_cast<uint32_t>(bindings.size());
createInfo.pBindings = bindings.data(); createInfo.pBindings = bindings.data();
// Record cache key information now since the createInfo is not stored.
GetCacheKey()->Record(createInfo);
Device* device = ToBackend(GetDevice()); Device* device = ToBackend(GetDevice());
DAWN_TRY(CheckVkSuccess(device->fn.CreateDescriptorSetLayout( DAWN_TRY(CheckVkSuccess(device->fn.CreateDescriptorSetLayout(
device->GetVkDevice(), &createInfo, nullptr, &*mHandle), device->GetVkDevice(), &createInfo, nullptr, &*mHandle),

View File

@ -22,6 +22,10 @@
#include <vector> #include <vector>
namespace dawn::native {
class CacheKey;
} // namespace dawn::native
namespace dawn::native::vulkan { namespace dawn::native::vulkan {
class BindGroup; class BindGroup;

View File

@ -0,0 +1,97 @@
// Copyright 2022 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dawn/native/vulkan/CacheKeyVk.h"
#include <cstring>
namespace dawn::native {
template <>
void CacheKeySerializer<VkDescriptorSetLayoutBinding>::Serialize(
CacheKey* key,
const VkDescriptorSetLayoutBinding& t) {
key->Record(t.binding, t.descriptorType, t.descriptorCount, t.stageFlags);
}
template <>
void CacheKeySerializer<VkDescriptorSetLayoutCreateInfo>::Serialize(
CacheKey* key,
const VkDescriptorSetLayoutCreateInfo& t) {
key->Record(t.flags).RecordIterable(t.pBindings, t.bindingCount);
vulkan::SerializePnext<>(key, reinterpret_cast<const VkBaseOutStructure*>(&t));
}
template <>
void CacheKeySerializer<VkPushConstantRange>::Serialize(CacheKey* key,
const VkPushConstantRange& t) {
key->Record(t.stageFlags, t.offset, t.size);
}
template <>
void CacheKeySerializer<VkPipelineLayoutCreateInfo>::Serialize(
CacheKey* key,
const VkPipelineLayoutCreateInfo& t) {
// The set layouts are not serialized here because they are pointers to backend objects.
// They need to be cross-referenced with the frontend objects and serialized from there.
key->Record(t.flags).RecordIterable(t.pPushConstantRanges, t.pushConstantRangeCount);
vulkan::SerializePnext<>(key, reinterpret_cast<const VkBaseOutStructure*>(&t));
}
template <>
void CacheKeySerializer<VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT>::Serialize(
CacheKey* key,
const VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT& t) {
key->Record(t.requiredSubgroupSize);
}
template <>
void CacheKeySerializer<VkSpecializationMapEntry>::Serialize(
CacheKey* key,
const VkSpecializationMapEntry& t) {
key->Record(t.constantID, t.offset, t.size);
}
template <>
void CacheKeySerializer<VkSpecializationInfo>::Serialize(CacheKey* key,
const VkSpecializationInfo& t) {
key->RecordIterable(t.pMapEntries, t.mapEntryCount)
.RecordIterable(static_cast<const uint8_t*>(t.pData), t.dataSize);
}
template <>
void CacheKeySerializer<VkPipelineShaderStageCreateInfo>::Serialize(
CacheKey* key,
const VkPipelineShaderStageCreateInfo& t) {
// The shader module is not serialized here because it is a pointer to a backend object.
key->Record(t.flags, t.stage)
.RecordIterable(t.pName, strlen(t.pName))
.Record(t.pSpecializationInfo);
vulkan::SerializePnext<VkPipelineShaderStageRequiredSubgroupSizeCreateInfoEXT>(
key, reinterpret_cast<const VkBaseOutStructure*>(&t));
}
template <>
void CacheKeySerializer<VkComputePipelineCreateInfo>::Serialize(
CacheKey* key,
const VkComputePipelineCreateInfo& t) {
// The pipeline layout is not serialized here because it is a pointer to a backend object.
// It needs to be cross-referenced with the frontend objects and serialized from there. The
// base pipeline information is also currently not recorded since we do not use them in our
// backend implementation. If we decide to use them later on, they also need to be
// cross-referenced from the frontend.
key->Record(t.flags, t.stage);
}
} // namespace dawn::native

View File

@ -0,0 +1,85 @@
// Copyright 2022 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dawn/common/Assert.h"
#include "dawn/common/vulkan_platform.h"
#include "dawn/native/CacheKey.h"
#include "vulkan-tools/src/icd/generated/vk_typemap_helper.h"
#include <map>
namespace dawn::native::vulkan {
namespace detail {
template <typename... VK_STRUCT_TYPES>
void ValidatePnextImpl(const VkBaseOutStructure* root) {
const VkBaseOutStructure* next =
reinterpret_cast<const VkBaseOutStructure*>(root->pNext);
while (next != nullptr) {
// Assert that the type of each pNext struct is exactly one of the specified
// templates.
ASSERT(((LvlTypeMap<VK_STRUCT_TYPES>::kSType == next->sType ? 1 : 0) + ... + 0) ==
1);
next = reinterpret_cast<const VkBaseOutStructure*>(next->pNext);
}
}
template <typename VK_STRUCT_TYPE>
void SerializePnextImpl(CacheKey* key, const VkBaseOutStructure* root) {
const VkBaseOutStructure* next =
reinterpret_cast<const VkBaseOutStructure*>(root->pNext);
const VK_STRUCT_TYPE* found = nullptr;
while (next != nullptr) {
if (LvlTypeMap<VK_STRUCT_TYPE>::kSType == next->sType) {
if (found == nullptr) {
found = reinterpret_cast<const VK_STRUCT_TYPE*>(next);
} else {
// Fail an assert here since that means that the chain had more than one of
// the same typed chained object.
ASSERT(false);
}
}
next = reinterpret_cast<const VkBaseOutStructure*>(next->pNext);
}
if (found != nullptr) {
key->Record(found);
}
}
template <typename VK_STRUCT_TYPE,
typename... VK_STRUCT_TYPES,
typename = std::enable_if_t<(sizeof...(VK_STRUCT_TYPES) > 0)>>
void SerializePnextImpl(CacheKey* key, const VkBaseOutStructure* root) {
SerializePnextImpl<VK_STRUCT_TYPE>(key, root);
SerializePnextImpl<VK_STRUCT_TYPES...>(key, root);
}
} // namespace detail
template <typename... VK_STRUCT_TYPES>
void SerializePnext(CacheKey* key, const VkBaseOutStructure* root) {
detail::ValidatePnextImpl<VK_STRUCT_TYPES...>(root);
detail::SerializePnextImpl<VK_STRUCT_TYPES...>(key, root);
}
// Empty template specialization so that we can put this in to ensure failures occur if new
// extensions are added without updating serialization.
template <>
void SerializePnext(CacheKey* key, const VkBaseOutStructure* root) {
detail::ValidatePnextImpl<>(root);
}
} // namespace dawn::native::vulkan

View File

@ -22,6 +22,8 @@
#include "dawn/native/vulkan/UtilsVulkan.h" #include "dawn/native/vulkan/UtilsVulkan.h"
#include "dawn/native/vulkan/VulkanError.h" #include "dawn/native/vulkan/VulkanError.h"
#include <utility>
namespace dawn::native::vulkan { namespace dawn::native::vulkan {
// static // static
@ -46,10 +48,11 @@ namespace dawn::native::vulkan {
createInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT; createInfo.stage.stage = VK_SHADER_STAGE_COMPUTE_BIT;
// Generate a new VkShaderModule with BindingRemapper tint transform for each pipeline // Generate a new VkShaderModule with BindingRemapper tint transform for each pipeline
const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute); const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute);
DAWN_TRY_ASSIGN(createInfo.stage.module, ShaderModule* module = ToBackend(computeStage.module.Get());
ToBackend(computeStage.module.Get()) PipelineLayout* layout = ToBackend(GetLayout());
->GetTransformedModuleHandle(computeStage.entryPoint.c_str(), const ShaderModule::Spirv* spirv;
ToBackend(GetLayout()))); DAWN_TRY_ASSIGN((std::tie(createInfo.stage.module, spirv)),
module->GetHandleAndSpirv(computeStage.entryPoint.c_str(), layout));
createInfo.stage.pName = computeStage.entryPoint.c_str(); createInfo.stage.pName = computeStage.entryPoint.c_str();
@ -74,6 +77,11 @@ namespace dawn::native::vulkan {
VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT); VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO_EXT);
} }
// Record cache key information now since the createInfo is not stored.
GetCacheKey()
->Record(createInfo, static_cast<const ComputePipeline*>(this)->GetLayout())
.RecordIterable(*spirv);
DAWN_TRY(CheckVkSuccess( DAWN_TRY(CheckVkSuccess(
device->fn.CreateComputePipelines(device->GetVkDevice(), ::VK_NULL_HANDLE, 1, device->fn.CreateComputePipelines(device->GetVkDevice(), ::VK_NULL_HANDLE, 1,
&createInfo, nullptr, &*mHandle), &createInfo, nullptr, &*mHandle),

View File

@ -38,8 +38,11 @@ namespace dawn::native::vulkan {
// this constraints at the Dawn level? // this constraints at the Dawn level?
uint32_t numSetLayouts = 0; uint32_t numSetLayouts = 0;
std::array<VkDescriptorSetLayout, kMaxBindGroups> setLayouts; std::array<VkDescriptorSetLayout, kMaxBindGroups> setLayouts;
std::array<const CachedObject*, kMaxBindGroups> cachedObjects;
for (BindGroupIndex setIndex : IterateBitSet(GetBindGroupLayoutsMask())) { for (BindGroupIndex setIndex : IterateBitSet(GetBindGroupLayoutsMask())) {
setLayouts[numSetLayouts] = ToBackend(GetBindGroupLayout(setIndex))->GetHandle(); const BindGroupLayoutBase* bindGroupLayout = GetBindGroupLayout(setIndex);
setLayouts[numSetLayouts] = ToBackend(bindGroupLayout)->GetHandle();
cachedObjects[numSetLayouts] = bindGroupLayout;
numSetLayouts++; numSetLayouts++;
} }
@ -52,6 +55,9 @@ namespace dawn::native::vulkan {
createInfo.pushConstantRangeCount = 0; createInfo.pushConstantRangeCount = 0;
createInfo.pPushConstantRanges = nullptr; createInfo.pPushConstantRanges = nullptr;
// Record cache key information now since the createInfo is not stored.
GetCacheKey()->RecordIterable(cachedObjects.data(), numSetLayouts).Record(createInfo);
Device* device = ToBackend(GetDevice()); Device* device = ToBackend(GetDevice());
DAWN_TRY(CheckVkSuccess( DAWN_TRY(CheckVkSuccess(
device->fn.CreatePipelineLayout(device->GetVkDevice(), &createInfo, nullptr, &*mHandle), device->fn.CreatePipelineLayout(device->GetVkDevice(), &createInfo, nullptr, &*mHandle),

View File

@ -332,6 +332,7 @@ namespace dawn::native::vulkan {
MaybeError RenderPipeline::Initialize() { MaybeError RenderPipeline::Initialize() {
Device* device = ToBackend(GetDevice()); Device* device = ToBackend(GetDevice());
PipelineLayout* layout = ToBackend(GetLayout());
// There are at most 2 shader stages in render pipeline, i.e. vertex and fragment // There are at most 2 shader stages in render pipeline, i.e. vertex and fragment
std::array<VkPipelineShaderStageCreateInfo, 2> shaderStages; std::array<VkPipelineShaderStageCreateInfo, 2> shaderStages;
@ -344,10 +345,11 @@ namespace dawn::native::vulkan {
VkPipelineShaderStageCreateInfo shaderStage; VkPipelineShaderStageCreateInfo shaderStage;
const ProgrammableStage& programmableStage = GetStage(stage); const ProgrammableStage& programmableStage = GetStage(stage);
DAWN_TRY_ASSIGN(shaderStage.module, ShaderModule* module = ToBackend(programmableStage.module.Get());
ToBackend(programmableStage.module) const ShaderModule::Spirv* spirv;
->GetTransformedModuleHandle(programmableStage.entryPoint.c_str(), DAWN_TRY_ASSIGN(
ToBackend(GetLayout()))); std::tie(shaderStage.module, spirv),
module->GetHandleAndSpirv(programmableStage.entryPoint.c_str(), layout));
shaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; shaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
shaderStage.pNext = nullptr; shaderStage.pNext = nullptr;

View File

@ -38,35 +38,38 @@ namespace dawn::native::vulkan {
ShaderModule::ConcurrentTransformedShaderModuleCache:: ShaderModule::ConcurrentTransformedShaderModuleCache::
~ConcurrentTransformedShaderModuleCache() { ~ConcurrentTransformedShaderModuleCache() {
std::lock_guard<std::mutex> lock(mMutex); std::lock_guard<std::mutex> lock(mMutex);
for (const auto& [_, module] : mTransformedShaderModuleCache) { for (const auto& [_, moduleAndSpirv] : mTransformedShaderModuleCache) {
mDevice->GetFencedDeleter()->DeleteWhenUnused(module); mDevice->GetFencedDeleter()->DeleteWhenUnused(moduleAndSpirv.first);
} }
} }
VkShaderModule ShaderModule::ConcurrentTransformedShaderModuleCache::FindShaderModule( std::optional<ShaderModule::ModuleAndSpirv>
ShaderModule::ConcurrentTransformedShaderModuleCache::Find(
const PipelineLayoutEntryPointPair& key) { const PipelineLayoutEntryPointPair& key) {
std::lock_guard<std::mutex> lock(mMutex); std::lock_guard<std::mutex> lock(mMutex);
auto iter = mTransformedShaderModuleCache.find(key); auto iter = mTransformedShaderModuleCache.find(key);
if (iter != mTransformedShaderModuleCache.end()) { if (iter != mTransformedShaderModuleCache.end()) {
auto cached = iter->second; return std::make_pair(iter->second.first, iter->second.second.get());
return cached;
} }
return VK_NULL_HANDLE; return {};
} }
VkShaderModule ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGetCachedShaderModule( ShaderModule::ModuleAndSpirv ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGet(
const PipelineLayoutEntryPointPair& key, const PipelineLayoutEntryPointPair& key,
VkShaderModule value) { VkShaderModule module,
ASSERT(value != VK_NULL_HANDLE); std::vector<uint32_t>&& spirv) {
ASSERT(module != VK_NULL_HANDLE);
std::lock_guard<std::mutex> lock(mMutex); std::lock_guard<std::mutex> lock(mMutex);
auto iter = mTransformedShaderModuleCache.find(key); auto iter = mTransformedShaderModuleCache.find(key);
if (iter == mTransformedShaderModuleCache.end()) { if (iter == mTransformedShaderModuleCache.end()) {
mTransformedShaderModuleCache.emplace(key, value); mTransformedShaderModuleCache.emplace(
return value; key, std::make_pair(module, std::unique_ptr<Spirv>(new Spirv(spirv))));
} else { } else {
mDevice->GetFencedDeleter()->DeleteWhenUnused(value); mDevice->GetFencedDeleter()->DeleteWhenUnused(module);
return iter->second;
} }
// Now the key should exist in the map, so find it again and return it.
iter = mTransformedShaderModuleCache.find(key);
return std::make_pair(iter->second.first, iter->second.second.get());
} }
// static // static
@ -109,25 +112,24 @@ namespace dawn::native::vulkan {
ShaderModule::~ShaderModule() = default; ShaderModule::~ShaderModule() = default;
ResultOrError<VkShaderModule> ShaderModule::GetTransformedModuleHandle( ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
const char* entryPointName, const char* entryPointName,
PipelineLayout* layout) { PipelineLayout* layout) {
TRACE_EVENT0(GetDevice()->GetPlatform(), General, TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleVk::GetHandleAndSpirv");
"ShaderModuleVk::GetTransformedModuleHandle");
// If the shader was destroyed, we should never call this function. // If the shader was destroyed, we should never call this function.
ASSERT(IsAlive()); ASSERT(IsAlive());
ScopedTintICEHandler scopedICEHandler(GetDevice()); ScopedTintICEHandler scopedICEHandler(GetDevice());
// Check to see if we have the handle and spirv cached already.
auto cacheKey = std::make_pair(layout, entryPointName); auto cacheKey = std::make_pair(layout, entryPointName);
VkShaderModule cachedShaderModule = auto handleAndSpirv = mTransformedShaderModuleCache->Find(cacheKey);
mTransformedShaderModuleCache->FindShaderModule(cacheKey); if (handleAndSpirv.has_value()) {
if (cachedShaderModule != VK_NULL_HANDLE) { return std::move(*handleAndSpirv);
return cachedShaderModule;
} }
// Creation of VkShaderModule is deferred to this point when using tint generator // Creation of module and spirv is deferred to this point when using tint generator
// Remap BindingNumber to BindingIndex in WGSL shader // Remap BindingNumber to BindingIndex in WGSL shader
using BindingRemapper = tint::transform::BindingRemapper; using BindingRemapper = tint::transform::BindingRemapper;
@ -207,7 +209,7 @@ namespace dawn::native::vulkan {
options.use_zero_initialize_workgroup_memory_extension = options.use_zero_initialize_workgroup_memory_extension =
GetDevice()->IsToggleEnabled(Toggle::VulkanUseZeroInitializeWorkgroupMemoryExtension); GetDevice()->IsToggleEnabled(Toggle::VulkanUseZeroInitializeWorkgroupMemoryExtension);
std::vector<uint32_t> spirv; Spirv spirv;
{ {
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "tint::writer::spirv::Generate()"); TRACE_EVENT0(GetDevice()->GetPlatform(), General, "tint::writer::spirv::Generate()");
auto result = tint::writer::spirv::Generate(&program, options); auto result = tint::writer::spirv::Generate(&program, options);
@ -236,15 +238,17 @@ namespace dawn::native::vulkan {
device->GetVkDevice(), &createInfo, nullptr, &*newHandle), device->GetVkDevice(), &createInfo, nullptr, &*newHandle),
"CreateShaderModule")); "CreateShaderModule"));
} }
ModuleAndSpirv moduleAndSpirv;
if (newHandle != VK_NULL_HANDLE) { if (newHandle != VK_NULL_HANDLE) {
newHandle = moduleAndSpirv =
mTransformedShaderModuleCache->AddOrGetCachedShaderModule(cacheKey, newHandle); mTransformedShaderModuleCache->AddOrGet(cacheKey, newHandle, std::move(spirv));
} }
SetDebugName(ToBackend(GetDevice()), VK_OBJECT_TYPE_SHADER_MODULE, SetDebugName(ToBackend(GetDevice()), VK_OBJECT_TYPE_SHADER_MODULE,
reinterpret_cast<uint64_t&>(newHandle), "Dawn_ShaderModule", GetLabel()); reinterpret_cast<uint64_t&>(moduleAndSpirv.first), "Dawn_ShaderModule",
GetLabel());
return newHandle; return std::move(moduleAndSpirv);
} }
} // namespace dawn::native::vulkan } // namespace dawn::native::vulkan

View File

@ -20,7 +20,11 @@
#include "dawn/common/vulkan_platform.h" #include "dawn/common/vulkan_platform.h"
#include "dawn/native/Error.h" #include "dawn/native/Error.h"
#include <memory>
#include <mutex> #include <mutex>
#include <optional>
#include <utility>
#include <vector>
namespace dawn::native::vulkan { namespace dawn::native::vulkan {
@ -29,11 +33,14 @@ namespace dawn::native::vulkan {
class ShaderModule final : public ShaderModuleBase { class ShaderModule final : public ShaderModuleBase {
public: public:
using Spirv = std::vector<uint32_t>;
using ModuleAndSpirv = std::pair<VkShaderModule, const Spirv*>;
static ResultOrError<Ref<ShaderModule>> Create(Device* device, static ResultOrError<Ref<ShaderModule>> Create(Device* device,
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
ShaderModuleParseResult* parseResult); ShaderModuleParseResult* parseResult);
ResultOrError<VkShaderModule> GetTransformedModuleHandle(const char* entryPointName, ResultOrError<ModuleAndSpirv> GetHandleAndSpirv(const char* entryPointName,
PipelineLayout* layout); PipelineLayout* layout);
private: private:
@ -42,20 +49,24 @@ namespace dawn::native::vulkan {
MaybeError Initialize(ShaderModuleParseResult* parseResult); MaybeError Initialize(ShaderModuleParseResult* parseResult);
void DestroyImpl() override; void DestroyImpl() override;
// New handles created by GetTransformedModuleHandle at pipeline creation time // New handles created by GetHandleAndSpirv at pipeline creation time.
class ConcurrentTransformedShaderModuleCache { class ConcurrentTransformedShaderModuleCache {
public: public:
explicit ConcurrentTransformedShaderModuleCache(Device* device); explicit ConcurrentTransformedShaderModuleCache(Device* device);
~ConcurrentTransformedShaderModuleCache(); ~ConcurrentTransformedShaderModuleCache();
VkShaderModule FindShaderModule(const PipelineLayoutEntryPointPair& key);
VkShaderModule AddOrGetCachedShaderModule(const PipelineLayoutEntryPointPair& key, std::optional<ModuleAndSpirv> Find(const PipelineLayoutEntryPointPair& key);
VkShaderModule value); ModuleAndSpirv AddOrGet(const PipelineLayoutEntryPointPair& key,
VkShaderModule module,
std::vector<uint32_t>&& spirv);
private: private:
using Entry = std::pair<VkShaderModule, std::unique_ptr<Spirv>>;
Device* mDevice; Device* mDevice;
std::mutex mMutex; std::mutex mMutex;
std::unordered_map<PipelineLayoutEntryPointPair, std::unordered_map<PipelineLayoutEntryPointPair,
VkShaderModule, Entry,
PipelineLayoutEntryPointPairHashFunc> PipelineLayoutEntryPointPairHashFunc>
mTransformedShaderModuleCache; mTransformedShaderModuleCache;
}; };

View File

@ -51,7 +51,7 @@ namespace dawn::native {
// Matcher to compare CacheKeys for easier testing. // Matcher to compare CacheKeys for easier testing.
MATCHER_P(CacheKeyEq, key, PrintToString(key)) { MATCHER_P(CacheKeyEq, key, PrintToString(key)) {
return memcmp(arg.data(), key.data(), arg.size()) == 0; return arg.size() == key.size() && memcmp(arg.data(), key.data(), key.size()) == 0;
} }
TEST(CacheKeyTests, RecordSingleMember) { TEST(CacheKeyTests, RecordSingleMember) {