Vulkan: Implement WGSL->SPIRV caching

Bug: dawn:1480
Change-Id: I77facc854ce9d5fe41c2332236113f266178470a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/94660
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Austin Eng <enga@chromium.org>
This commit is contained in:
Austin Eng 2022-07-08 21:30:25 +00:00 committed by Dawn LUCI CQ
parent 6d41e60a77
commit 1b4da5d28f
21 changed files with 609 additions and 181 deletions

View File

@ -200,6 +200,7 @@ source_set("sources") {
"Buffer.h", "Buffer.h",
"CacheKey.cpp", "CacheKey.cpp",
"CacheKey.h", "CacheKey.h",
"CacheKeyImplTint.cpp",
"CacheRequest.cpp", "CacheRequest.cpp",
"CacheRequest.h", "CacheRequest.h",
"CacheResult.h", "CacheResult.h",

View File

@ -16,7 +16,9 @@
#define SRC_DAWN_NATIVE_BLOB_H_ #define SRC_DAWN_NATIVE_BLOB_H_
#include <functional> #include <functional>
#include <memory> #include <type_traits>
#include <utility>
#include <vector>
namespace dawn::native { namespace dawn::native {
@ -59,6 +61,15 @@ class Blob {
Blob CreateBlob(size_t size, size_t alignment = 1); Blob CreateBlob(size_t size, size_t alignment = 1);
template <typename T, typename = std::enable_if_t<std::is_fundamental_v<T>>>
Blob CreateBlob(std::vector<T> vec) {
uint8_t* data = reinterpret_cast<uint8_t*>(vec.data());
size_t size = vec.size() * sizeof(T);
// Move the vector into a new allocation so we can destruct it in the deleter.
auto* wrapped_vec = new std::vector<T>(std::move(vec));
return Blob::UnsafeCreateWithDeleter(data, size, [wrapped_vec]() { delete wrapped_vec; });
}
} // namespace dawn::native } // namespace dawn::native
#endif // SRC_DAWN_NATIVE_BLOB_H_ #endif // SRC_DAWN_NATIVE_BLOB_H_

View File

@ -19,6 +19,7 @@
#include "dawn/common/Platform.h" #include "dawn/common/Platform.h"
#include "dawn/native/Blob.h" #include "dawn/native/Blob.h"
#include "dawn/native/CacheResult.h"
namespace dawn::platform { namespace dawn::platform {
class CachingInterface; class CachingInterface;
@ -42,6 +43,19 @@ class BlobCache {
void Store(const CacheKey& key, size_t valueSize, const void* value); void Store(const CacheKey& key, size_t valueSize, const void* value);
void Store(const CacheKey& key, const Blob& value); void Store(const CacheKey& key, const Blob& value);
// Other types may specialize BlobCache::Store<T> to define how T is serialized into the cache.
template <typename T>
void Store(const CacheKey& key, const T& value);
// Store a CacheResult into the cache if it isn't cached yet.
// Calls Store<T> which should be defined elsewhere.
template <typename T>
void EnsureStored(const CacheResult<T>& cacheResult) {
if (!cacheResult.IsCached()) {
Store(cacheResult.GetCacheKey(), *cacheResult);
}
}
private: private:
// Non-thread safe internal implementations of load and store. Exposed callers that use // Non-thread safe internal implementations of load and store. Exposed callers that use
// these helpers need to make sure that these are entered with `mMutex` held. // these helpers need to make sure that these are entered with `mMutex` held.

View File

@ -59,6 +59,7 @@ target_sources(dawn_native PRIVATE
"CachedObject.h" "CachedObject.h"
"CacheKey.cpp" "CacheKey.cpp"
"CacheKey.h" "CacheKey.h"
"CacheKeyImplTint.cpp"
"CacheRequest.cpp" "CacheRequest.cpp"
"CacheRequest.h" "CacheRequest.h"
"CacheResult.h" "CacheResult.h"

View File

@ -15,11 +15,14 @@
#ifndef SRC_DAWN_NATIVE_CACHEKEY_H_ #ifndef SRC_DAWN_NATIVE_CACHEKEY_H_
#define SRC_DAWN_NATIVE_CACHEKEY_H_ #define SRC_DAWN_NATIVE_CACHEKEY_H_
#include <algorithm>
#include <bitset> #include <bitset>
#include <functional>
#include <iostream> #include <iostream>
#include <limits> #include <limits>
#include <memory> #include <memory>
#include <type_traits> #include <type_traits>
#include <unordered_map>
#include <utility> #include <utility>
#include <vector> #include <vector>
@ -230,6 +233,29 @@ class CacheKeySerializer<std::vector<T>> {
static void Serialize(CacheKey* key, const std::vector<T>& t) { key->RecordIterable(t); } static void Serialize(CacheKey* key, const std::vector<T>& t) { key->RecordIterable(t); }
}; };
// Specialized overload for std::pair<A, B>
template <typename A, typename B>
class CacheKeySerializer<std::pair<A, B>> {
public:
static void Serialize(CacheKey* key, const std::pair<A, B>& p) {
key->Record(p.first, p.second);
}
};
// Specialized overload for std::unordered_map<K, V>
template <typename K, typename V>
class CacheKeySerializer<std::unordered_map<K, V>> {
public:
static void Serialize(CacheKey* key, const std::unordered_map<K, V>& m) {
std::vector<std::pair<K, V>> ordered(m.begin(), m.end());
std::sort(ordered.begin(), ordered.end(),
[](const std::pair<K, V>& a, const std::pair<K, V>& b) {
return std::less<K>{}(a.first, b.first);
});
key->RecordIterable(ordered);
}
};
} // namespace dawn::native } // namespace dawn::native
#endif // SRC_DAWN_NATIVE_CACHEKEY_H_ #endif // SRC_DAWN_NATIVE_CACHEKEY_H_

View File

@ -0,0 +1,62 @@
// Copyright 2022 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dawn/native/CacheKey.h"
#include "tint/tint.h"
namespace dawn::native {
// static
template <>
void CacheKeySerializer<tint::Program>::Serialize(CacheKey* key, const tint::Program& p) {
#if TINT_BUILD_WGSL_WRITER
tint::writer::wgsl::Options options{};
key->Record(tint::writer::wgsl::Generate(&p, options).wgsl);
#else
// TODO(crbug.com/dawn/1481): We shouldn't need to write back to WGSL if we have a CacheKey
// built from the initial shader module input. Then, we would never need to parse the program
// and write back out to WGSL.
UNREACHABLE();
#endif
}
// static
template <>
void CacheKeySerializer<tint::transform::BindingPoints>::Serialize(
CacheKey* key,
const tint::transform::BindingPoints& points) {
static_assert(offsetof(tint::transform::BindingPoints, plane_1) == 0,
"Please update serialization for tint::transform::BindingPoints");
static_assert(offsetof(tint::transform::BindingPoints, params) == 8,
"Please update serialization for tint::transform::BindingPoints");
static_assert(sizeof(tint::transform::BindingPoints) == 16,
"Please update serialization for tint::transform::BindingPoints");
key->Record(points.plane_1, points.params);
}
// static
template <>
void CacheKeySerializer<tint::sem::BindingPoint>::Serialize(CacheKey* key,
const tint::sem::BindingPoint& p) {
static_assert(offsetof(tint::sem::BindingPoint, group) == 0,
"Please update serialization for tint::sem::BindingPoint");
static_assert(offsetof(tint::sem::BindingPoint, binding) == 4,
"Please update serialization for tint::sem::BindingPoint");
static_assert(sizeof(tint::sem::BindingPoint) == 8,
"Please update serialization for tint::sem::BindingPoint");
key->Record(p.group, p.binding);
}
} // namespace dawn::native

View File

@ -151,6 +151,8 @@ class CacheRequestImpl {
} }
}; };
} // namespace dawn::native
// Helper for X macro to declare a struct member. // Helper for X macro to declare a struct member.
#define DAWN_INTERNAL_CACHE_REQUEST_DECL_STRUCT_MEMBER(type, name) type name{}; #define DAWN_INTERNAL_CACHE_REQUEST_DECL_STRUCT_MEMBER(type, name) type name{};
@ -166,21 +168,23 @@ class CacheRequestImpl {
// X(Bar, bar) // X(Bar, bar)
// DAWN_MAKE_CACHE_REQUEST(MyCacheRequest, REQUEST_MEMBERS) // DAWN_MAKE_CACHE_REQUEST(MyCacheRequest, REQUEST_MEMBERS)
// #undef REQUEST_MEMBERS // #undef REQUEST_MEMBERS
#define DAWN_MAKE_CACHE_REQUEST(Request, MEMBERS) \ #define DAWN_MAKE_CACHE_REQUEST(Request, MEMBERS) \
class Request : public CacheRequestImpl<Request> { \ class Request : public ::dawn::native::CacheRequestImpl<Request> { \
public: \ public: \
Request() = default; \ Request() = default; \
MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_DECL_STRUCT_MEMBER) \ MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_DECL_STRUCT_MEMBER) \
\ \
/* Create a CacheKey from the request type and all members */ \ /* Create a CacheKey from the request type and all members */ \
CacheKey CreateCacheKey(const DeviceBase* device) const { \ ::dawn::native::CacheKey CreateCacheKey(const ::dawn::native::DeviceBase* device) const { \
CacheKey key = device->GetCacheKey(); \ ::dawn::native::CacheKey key = device->GetCacheKey(); \
key.Record(#Request); \ key.Record(#Request); \
MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY) \ MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY) \
return key; \ return key; \
} \ } \
}; };
} // namespace dawn::native // Helper macro for the common pattern of DAWN_TRY_ASSIGN around LoadOrRun.
// Requires an #include of dawn/native/Error.h
#define DAWN_TRY_LOAD_OR_RUN(var, ...) DAWN_TRY_ASSIGN(var, LoadOrRun(__VA_ARGS__))
#endif // SRC_DAWN_NATIVE_CACHEREQUEST_H_ #endif // SRC_DAWN_NATIVE_CACHEREQUEST_H_

View File

@ -19,6 +19,7 @@
#include <utility> #include <utility>
#include "dawn/common/Assert.h" #include "dawn/common/Assert.h"
#include "dawn/native/CacheKey.h"
namespace dawn::native { namespace dawn::native {
@ -39,7 +40,7 @@ class CacheResult {
ASSERT(mIsValid); ASSERT(mIsValid);
return mIsCached; return mIsCached;
} }
const CacheKey& GetCacheKey() { const CacheKey& GetCacheKey() const {
ASSERT(mIsValid); ASSERT(mIsValid);
return mKey; return mKey;
} }

View File

@ -623,9 +623,14 @@ bool DeviceBase::APIPopErrorScope(wgpu::ErrorCallback callback, void* userdata)
} }
BlobCache* DeviceBase::GetBlobCache() { BlobCache* DeviceBase::GetBlobCache() {
#if TINT_BUILD_WGSL_WRITER
// TODO(crbug.com/dawn/1481): Shader caching currently has a dependency on the WGSL writer to
// generate cache keys. We can lift the dependency once we also cache frontend parsing,
// transformations, and reflection.
if (IsToggleEnabled(Toggle::EnableBlobCache)) { if (IsToggleEnabled(Toggle::EnableBlobCache)) {
return mInstance->GetBlobCache(); return mInstance->GetBlobCache();
} }
#endif
return nullptr; return nullptr;
} }

View File

@ -23,7 +23,10 @@
namespace dawn::native { namespace dawn::native {
MaybeError ValidateSpirv(DeviceBase* device, const std::vector<uint32_t>& spirv, bool dumpSpirv) { MaybeError ValidateSpirv(DeviceBase* device,
const uint32_t* spirv,
size_t wordCount,
bool dumpSpirv) {
spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1); spvtools::SpirvTools spirvTools(SPV_ENV_VULKAN_1_1);
spirvTools.SetMessageConsumer([device](spv_message_level_t level, const char*, spirvTools.SetMessageConsumer([device](spv_message_level_t level, const char*,
const spv_position_t& position, const char* message) { const spv_position_t& position, const char* message) {
@ -50,12 +53,12 @@ MaybeError ValidateSpirv(DeviceBase* device, const std::vector<uint32_t>& spirv,
device->EmitLog(wgpuLogLevel, ss.str().c_str()); device->EmitLog(wgpuLogLevel, ss.str().c_str());
}); });
const bool valid = spirvTools.Validate(spirv); const bool valid = spirvTools.Validate(spirv, wordCount);
if (dumpSpirv || !valid) { if (dumpSpirv || !valid) {
std::ostringstream dumpedMsg; std::ostringstream dumpedMsg;
std::string disassembly; std::string disassembly;
if (spirvTools.Disassemble( if (spirvTools.Disassemble(
spirv, &disassembly, spirv, wordCount, &disassembly,
SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES | SPV_BINARY_TO_TEXT_OPTION_INDENT)) { SPV_BINARY_TO_TEXT_OPTION_FRIENDLY_NAMES | SPV_BINARY_TO_TEXT_OPTION_INDENT)) {
dumpedMsg << "/* Dumped generated SPIRV disassembly */" << std::endl << disassembly; dumpedMsg << "/* Dumped generated SPIRV disassembly */" << std::endl << disassembly;
} else { } else {

View File

@ -15,15 +15,16 @@
#ifndef SRC_DAWN_NATIVE_SPIRVVALIDATION_H_ #ifndef SRC_DAWN_NATIVE_SPIRVVALIDATION_H_
#define SRC_DAWN_NATIVE_SPIRVVALIDATION_H_ #define SRC_DAWN_NATIVE_SPIRVVALIDATION_H_
#include <vector>
#include "dawn/native/Error.h" #include "dawn/native/Error.h"
namespace dawn::native { namespace dawn::native {
class DeviceBase; class DeviceBase;
MaybeError ValidateSpirv(DeviceBase* device, const std::vector<uint32_t>& spirv, bool dumpSpirv); MaybeError ValidateSpirv(DeviceBase* device,
const uint32_t* spirv,
size_t wordCount,
bool dumpSpirv);
} // namespace dawn::native } // namespace dawn::native

View File

@ -54,3 +54,8 @@ ScopedTintICEHandler::~ScopedTintICEHandler() {
} }
} // namespace dawn::native } // namespace dawn::native
bool std::less<tint::sem::BindingPoint>::operator()(const tint::sem::BindingPoint& a,
const tint::sem::BindingPoint& b) const {
return std::tie(a.group, a.binding) < std::tie(b.group, b.binding);
}

View File

@ -15,8 +15,14 @@
#ifndef SRC_DAWN_NATIVE_TINTUTILS_H_ #ifndef SRC_DAWN_NATIVE_TINTUTILS_H_
#define SRC_DAWN_NATIVE_TINTUTILS_H_ #define SRC_DAWN_NATIVE_TINTUTILS_H_
#include <functional>
#include "dawn/common/NonCopyable.h" #include "dawn/common/NonCopyable.h"
namespace tint::sem {
struct BindingPoint;
}
namespace dawn::native { namespace dawn::native {
class DeviceBase; class DeviceBase;
@ -34,4 +40,10 @@ class ScopedTintICEHandler : public NonCopyable {
} // namespace dawn::native } // namespace dawn::native
// std::less operator for std::map containing BindingPoint
template <>
struct std::less<tint::sem::BindingPoint> {
bool operator()(const tint::sem::BindingPoint& a, const tint::sem::BindingPoint& b) const;
};
#endif // SRC_DAWN_NATIVE_TINTUTILS_H_ #endif // SRC_DAWN_NATIVE_TINTUTILS_H_

View File

@ -58,10 +58,12 @@ MaybeError ComputePipeline::Initialize() {
// Generate a new VkShaderModule with BindingRemapper tint transform for each pipeline // Generate a new VkShaderModule with BindingRemapper tint transform for each pipeline
const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute); const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute);
ShaderModule* module = ToBackend(computeStage.module.Get()); ShaderModule* module = ToBackend(computeStage.module.Get());
const ShaderModule::Spirv* spirv;
DAWN_TRY_ASSIGN((std::tie(createInfo.stage.module, spirv)), ShaderModule::ModuleAndSpirv moduleAndSpirv;
DAWN_TRY_ASSIGN(moduleAndSpirv,
module->GetHandleAndSpirv(computeStage.entryPoint.c_str(), layout)); module->GetHandleAndSpirv(computeStage.entryPoint.c_str(), layout));
createInfo.stage.module = moduleAndSpirv.module;
createInfo.stage.pName = computeStage.entryPoint.c_str(); createInfo.stage.pName = computeStage.entryPoint.c_str();
std::vector<OverridableConstantScalar> specializationDataEntries; std::vector<OverridableConstantScalar> specializationDataEntries;
@ -83,7 +85,8 @@ MaybeError ComputePipeline::Initialize() {
} }
// Record cache key information now since the createInfo is not stored. // Record cache key information now since the createInfo is not stored.
mCacheKey.Record(createInfo, layout).RecordIterable(*spirv); mCacheKey.Record(createInfo, layout)
.RecordIterable(moduleAndSpirv.spirv, moduleAndSpirv.wordCount);
// Try to see if we have anything in the blob cache. // Try to see if we have anything in the blob cache.
Ref<PipelineCache> cache = ToBackend(GetDevice()->GetOrCreatePipelineCache(GetCacheKey())); Ref<PipelineCache> cache = ToBackend(GetDevice()->GetOrCreatePipelineCache(GetCacheKey()));

View File

@ -351,10 +351,12 @@ MaybeError RenderPipeline::Initialize() {
const ProgrammableStage& programmableStage = GetStage(stage); const ProgrammableStage& programmableStage = GetStage(stage);
ShaderModule* module = ToBackend(programmableStage.module.Get()); ShaderModule* module = ToBackend(programmableStage.module.Get());
const ShaderModule::Spirv* spirv;
DAWN_TRY_ASSIGN(std::tie(shaderStage.module, spirv), ShaderModule::ModuleAndSpirv moduleAndSpirv;
DAWN_TRY_ASSIGN(moduleAndSpirv,
module->GetHandleAndSpirv(programmableStage.entryPoint.c_str(), layout)); module->GetHandleAndSpirv(programmableStage.entryPoint.c_str(), layout));
shaderStage.module = moduleAndSpirv.module;
shaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; shaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
shaderStage.pNext = nullptr; shaderStage.pNext = nullptr;
shaderStage.flags = 0; shaderStage.flags = 0;
@ -387,7 +389,7 @@ MaybeError RenderPipeline::Initialize() {
stageCount++; stageCount++;
// Record cache key for each shader since it will become inaccessible later on. // Record cache key for each shader since it will become inaccessible later on.
mCacheKey.Record(stage).RecordIterable(*spirv); mCacheKey.Record(stage).RecordIterable(moduleAndSpirv.spirv, moduleAndSpirv.wordCount);
} }
PipelineVertexInputStateCreateInfoTemporaryAllocations tempAllocations; PipelineVertexInputStateCreateInfoTemporaryAllocations tempAllocations;

View File

@ -17,7 +17,10 @@
#include <spirv-tools/libspirv.hpp> #include <spirv-tools/libspirv.hpp>
#include <map> #include <map>
#include <string>
#include <vector>
#include "dawn/native/CacheRequest.h"
#include "dawn/native/SpirvValidation.h" #include "dawn/native/SpirvValidation.h"
#include "dawn/native/TintUtils.h" #include "dawn/native/TintUtils.h"
#include "dawn/native/vulkan/BindGroupLayoutVk.h" #include "dawn/native/vulkan/BindGroupLayoutVk.h"
@ -32,6 +35,59 @@
namespace dawn::native::vulkan { namespace dawn::native::vulkan {
// Spirv is a wrapper around Blob that exposes the data as uint32_t words.
class ShaderModule::Spirv : private Blob {
public:
static Spirv FromBlob(Blob&& blob) {
// Vulkan drivers expect the SPIRV to be aligned like an array of uint32_t values.
blob.AlignTo(alignof(uint32_t));
return static_cast<Spirv&&>(blob);
}
static Spirv Create(std::vector<uint32_t> code) {
Blob blob = CreateBlob(std::move(code));
ASSERT(IsPtrAligned(blob.Data(), alignof(uint32_t)));
return static_cast<Spirv&&>(std::move(blob));
}
const uint32_t* Code() const { return reinterpret_cast<const uint32_t*>(Data()); }
size_t WordCount() const { return Size() / sizeof(uint32_t); }
};
} // namespace dawn::native::vulkan
namespace dawn::native {
// Define the implementation to store vulkan::ShaderModule::Spirv into the BlobCache.
template <>
void BlobCache::Store<vulkan::ShaderModule::Spirv>(const CacheKey& key,
const vulkan::ShaderModule::Spirv& spirv) {
Store(key, spirv.WordCount() * sizeof(uint32_t), spirv.Code());
}
} // namespace dawn::native
namespace dawn::native::vulkan {
class ShaderModule::ConcurrentTransformedShaderModuleCache {
public:
explicit ConcurrentTransformedShaderModuleCache(Device* device);
~ConcurrentTransformedShaderModuleCache();
std::optional<ModuleAndSpirv> Find(const PipelineLayoutEntryPointPair& key);
ModuleAndSpirv AddOrGet(const PipelineLayoutEntryPointPair& key,
VkShaderModule module,
Spirv&& spirv);
private:
using Entry = std::pair<VkShaderModule, Spirv>;
Device* mDevice;
std::mutex mMutex;
std::unordered_map<PipelineLayoutEntryPointPair, Entry, PipelineLayoutEntryPointPairHashFunc>
mTransformedShaderModuleCache;
};
ShaderModule::ConcurrentTransformedShaderModuleCache::ConcurrentTransformedShaderModuleCache( ShaderModule::ConcurrentTransformedShaderModuleCache::ConcurrentTransformedShaderModuleCache(
Device* device) Device* device)
: mDevice(device) {} : mDevice(device) {}
@ -49,7 +105,11 @@ ShaderModule::ConcurrentTransformedShaderModuleCache::Find(
std::lock_guard<std::mutex> lock(mMutex); std::lock_guard<std::mutex> lock(mMutex);
auto iter = mTransformedShaderModuleCache.find(key); auto iter = mTransformedShaderModuleCache.find(key);
if (iter != mTransformedShaderModuleCache.end()) { if (iter != mTransformedShaderModuleCache.end()) {
return std::make_pair(iter->second.first, iter->second.second.get()); return ModuleAndSpirv{
iter->second.first,
iter->second.second.Code(),
iter->second.second.WordCount(),
};
} }
return {}; return {};
} }
@ -57,19 +117,22 @@ ShaderModule::ConcurrentTransformedShaderModuleCache::Find(
ShaderModule::ModuleAndSpirv ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGet( ShaderModule::ModuleAndSpirv ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGet(
const PipelineLayoutEntryPointPair& key, const PipelineLayoutEntryPointPair& key,
VkShaderModule module, VkShaderModule module,
std::vector<uint32_t>&& spirv) { Spirv&& spirv) {
ASSERT(module != VK_NULL_HANDLE); ASSERT(module != VK_NULL_HANDLE);
std::lock_guard<std::mutex> lock(mMutex); std::lock_guard<std::mutex> lock(mMutex);
auto iter = mTransformedShaderModuleCache.find(key); auto iter = mTransformedShaderModuleCache.find(key);
if (iter == mTransformedShaderModuleCache.end()) { if (iter == mTransformedShaderModuleCache.end()) {
mTransformedShaderModuleCache.emplace( mTransformedShaderModuleCache.emplace(key, std::make_pair(module, std::move(spirv)));
key, std::make_pair(module, std::unique_ptr<Spirv>(new Spirv(spirv))));
} else { } else {
mDevice->GetFencedDeleter()->DeleteWhenUnused(module); mDevice->GetFencedDeleter()->DeleteWhenUnused(module);
} }
// Now the key should exist in the map, so find it again and return it. // Now the key should exist in the map, so find it again and return it.
iter = mTransformedShaderModuleCache.find(key); iter = mTransformedShaderModuleCache.find(key);
return std::make_pair(iter->second.first, iter->second.second.get()); return ModuleAndSpirv{
iter->second.first,
iter->second.second.Code(),
iter->second.second.WordCount(),
};
} }
// static // static
@ -114,6 +177,18 @@ void ShaderModule::DestroyImpl() {
ShaderModule::~ShaderModule() = default; ShaderModule::~ShaderModule() = default;
#define SPIRV_COMPILATION_REQUEST_MEMBERS(X) \
X(const tint::Program*, inputProgram) \
X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \
X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \
X(std::string_view, entryPointName) \
X(bool, disableWorkgroupInit) \
X(bool, useZeroInitializeWorkgroupMemoryExtension) \
X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, tracePlatform)
DAWN_MAKE_CACHE_REQUEST(SpirvCompilationRequest, SPIRV_COMPILATION_REQUEST_MEMBERS);
#undef SPIRV_COMPILATION_REQUEST_MEMBERS
ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv( ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
const char* entryPointName, const char* entryPointName,
const PipelineLayout* layout) { const PipelineLayout* layout) {
@ -137,15 +212,13 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
using BindingRemapper = tint::transform::BindingRemapper; using BindingRemapper = tint::transform::BindingRemapper;
using BindingPoint = tint::transform::BindingPoint; using BindingPoint = tint::transform::BindingPoint;
BindingRemapper::BindingPoints bindingPoints; BindingRemapper::BindingPoints bindingPoints;
BindingRemapper::AccessControls accessControls;
const BindingInfoArray& moduleBindingInfo = GetEntryPoint(entryPointName).bindings; const BindingInfoArray& moduleBindingInfo = GetEntryPoint(entryPointName).bindings;
for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group)); const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
const auto& groupBindingInfo = moduleBindingInfo[group]; const auto& groupBindingInfo = moduleBindingInfo[group];
for (const auto& it : groupBindingInfo) { for (const auto& [binding, _] : groupBindingInfo) {
BindingNumber binding = it.first;
BindingIndex bindingIndex = bgl->GetBindingIndex(binding); BindingIndex bindingIndex = bgl->GetBindingIndex(binding);
BindingPoint srcBindingPoint{static_cast<uint32_t>(group), BindingPoint srcBindingPoint{static_cast<uint32_t>(group),
static_cast<uint32_t>(binding)}; static_cast<uint32_t>(binding)};
@ -158,79 +231,84 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
} }
} }
tint::transform::Manager transformManager;
// Many Vulkan drivers can't handle multi-entrypoint shader modules.
transformManager.append(std::make_unique<tint::transform::SingleEntryPoint>());
// Run the binding remapper after SingleEntryPoint to avoid collisions with unused entryPoints.
transformManager.append(std::make_unique<tint::transform::BindingRemapper>());
tint::transform::DataMap transformInputs;
transformInputs.Add<tint::transform::SingleEntryPoint::Config>(entryPointName);
transformInputs.Add<BindingRemapper::Remappings>(std::move(bindingPoints),
std::move(accessControls),
/* mayCollide */ false);
// Transform external textures into the binding locations specified in the bgl // Transform external textures into the binding locations specified in the bgl
// TODO(dawn:1082): Replace this block with ShaderModuleBase::AddExternalTextureTransform. // TODO(dawn:1082): Replace this block with ShaderModuleBase::AddExternalTextureTransform.
tint::transform::MultiplanarExternalTexture::BindingsMap newBindingsMap; tint::transform::MultiplanarExternalTexture::BindingsMap newBindingsMap;
for (BindGroupIndex i : IterateBitSet(layout->GetBindGroupLayoutsMask())) { for (BindGroupIndex i : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
const BindGroupLayoutBase* bgl = layout->GetBindGroupLayout(i); const BindGroupLayoutBase* bgl = layout->GetBindGroupLayout(i);
ExternalTextureBindingExpansionMap expansions = for (const auto& [_, expansion] : bgl->GetExternalTextureBindingExpansionMap()) {
bgl->GetExternalTextureBindingExpansionMap();
std::map<BindingNumber, dawn_native::ExternalTextureBindingExpansion>::iterator it =
expansions.begin();
while (it != expansions.end()) {
newBindingsMap[{static_cast<uint32_t>(i), newBindingsMap[{static_cast<uint32_t>(i),
static_cast<uint32_t>(bgl->GetBindingIndex(it->second.plane0))}] = { static_cast<uint32_t>(bgl->GetBindingIndex(expansion.plane0))}] = {
{static_cast<uint32_t>(i), {static_cast<uint32_t>(i),
static_cast<uint32_t>(bgl->GetBindingIndex(it->second.plane1))}, static_cast<uint32_t>(bgl->GetBindingIndex(expansion.plane1))},
{static_cast<uint32_t>(i), {static_cast<uint32_t>(i),
static_cast<uint32_t>(bgl->GetBindingIndex(it->second.params))}}; static_cast<uint32_t>(bgl->GetBindingIndex(expansion.params))}};
it++;
} }
} }
if (!newBindingsMap.empty()) {
transformManager.Add<tint::transform::MultiplanarExternalTexture>();
transformInputs.Add<tint::transform::MultiplanarExternalTexture::NewBindingPoints>(
newBindingsMap);
}
tint::Program program;
{
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "RunTransforms");
DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, GetTintProgram(), transformInputs,
nullptr, nullptr));
}
#if TINT_BUILD_SPV_WRITER #if TINT_BUILD_SPV_WRITER
tint::writer::spirv::Options options; SpirvCompilationRequest req = {};
options.emit_vertex_point_size = true; req.inputProgram = GetTintProgram();
options.disable_workgroup_init = GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit); req.bindingPoints = std::move(bindingPoints);
options.use_zero_initialize_workgroup_memory_extension = req.newBindingsMap = std::move(newBindingsMap);
req.entryPointName = entryPointName;
req.disableWorkgroupInit = GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit);
req.useZeroInitializeWorkgroupMemoryExtension =
GetDevice()->IsToggleEnabled(Toggle::VulkanUseZeroInitializeWorkgroupMemoryExtension); GetDevice()->IsToggleEnabled(Toggle::VulkanUseZeroInitializeWorkgroupMemoryExtension);
req.tracePlatform = UnsafeUnkeyedValue(GetDevice()->GetPlatform());
Spirv spirv; CacheResult<Spirv> spirv;
{ DAWN_TRY_LOAD_OR_RUN(
TRACE_EVENT0(GetDevice()->GetPlatform(), General, "tint::writer::spirv::Generate()"); spirv, GetDevice(), std::move(req), Spirv::FromBlob,
auto result = tint::writer::spirv::Generate(&program, options); [](SpirvCompilationRequest r) -> ResultOrError<Spirv> {
DAWN_INVALID_IF(!result.success, "An error occured while generating SPIR-V: %s.", tint::transform::Manager transformManager;
result.error); // Many Vulkan drivers can't handle multi-entrypoint shader modules.
transformManager.append(std::make_unique<tint::transform::SingleEntryPoint>());
// Run the binding remapper after SingleEntryPoint to avoid collisions with
// unused entryPoints.
transformManager.append(std::make_unique<tint::transform::BindingRemapper>());
tint::transform::DataMap transformInputs;
transformInputs.Add<tint::transform::SingleEntryPoint::Config>(
std::string(r.entryPointName));
transformInputs.Add<BindingRemapper::Remappings>(std::move(r.bindingPoints),
BindingRemapper::AccessControls{},
/* mayCollide */ false);
if (!r.newBindingsMap.empty()) {
transformManager.Add<tint::transform::MultiplanarExternalTexture>();
transformInputs.Add<tint::transform::MultiplanarExternalTexture::NewBindingPoints>(
r.newBindingsMap);
}
tint::Program program;
{
TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "RunTransforms");
DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, r.inputProgram,
transformInputs, nullptr, nullptr));
}
tint::writer::spirv::Options options;
options.emit_vertex_point_size = true;
options.disable_workgroup_init = r.disableWorkgroupInit;
options.use_zero_initialize_workgroup_memory_extension =
r.useZeroInitializeWorkgroupMemoryExtension;
spirv = std::move(result.spirv); TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General,
} "tint::writer::spirv::Generate()");
auto result = tint::writer::spirv::Generate(&program, options);
DAWN_INVALID_IF(!result.success, "An error occured while generating SPIR-V: %s.",
result.error);
DAWN_TRY(ValidateSpirv(GetDevice(), spirv, GetDevice()->IsToggleEnabled(Toggle::DumpShaders))); return Spirv::Create(std::move(result.spirv));
});
DAWN_TRY(ValidateSpirv(GetDevice(), spirv->Code(), spirv->WordCount(),
GetDevice()->IsToggleEnabled(Toggle::DumpShaders)));
VkShaderModuleCreateInfo createInfo; VkShaderModuleCreateInfo createInfo;
createInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; createInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
createInfo.pNext = nullptr; createInfo.pNext = nullptr;
createInfo.flags = 0; createInfo.flags = 0;
createInfo.codeSize = spirv.size() * sizeof(uint32_t); createInfo.codeSize = spirv->WordCount() * sizeof(uint32_t);
createInfo.pCode = spirv.data(); createInfo.pCode = spirv->Code();
Device* device = ToBackend(GetDevice()); Device* device = ToBackend(GetDevice());
@ -241,13 +319,17 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
device->fn.CreateShaderModule(device->GetVkDevice(), &createInfo, nullptr, &*newHandle), device->fn.CreateShaderModule(device->GetVkDevice(), &createInfo, nullptr, &*newHandle),
"CreateShaderModule")); "CreateShaderModule"));
} }
ModuleAndSpirv moduleAndSpirv; ModuleAndSpirv moduleAndSpirv;
if (newHandle != VK_NULL_HANDLE) { if (newHandle != VK_NULL_HANDLE) {
if (BlobCache* cache = device->GetBlobCache()) {
cache->EnsureStored(spirv);
}
moduleAndSpirv = moduleAndSpirv =
mTransformedShaderModuleCache->AddOrGet(cacheKey, newHandle, std::move(spirv)); mTransformedShaderModuleCache->AddOrGet(cacheKey, newHandle, spirv.Acquire());
} }
SetDebugName(ToBackend(GetDevice()), moduleAndSpirv.first, "Dawn_ShaderModule", GetLabel()); SetDebugName(ToBackend(GetDevice()), moduleAndSpirv.module, "Dawn_ShaderModule", GetLabel());
return std::move(moduleAndSpirv); return std::move(moduleAndSpirv);
#else #else

View File

@ -20,12 +20,10 @@
#include <optional> #include <optional>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <vector>
#include "dawn/native/ShaderModule.h"
#include "dawn/common/vulkan_platform.h" #include "dawn/common/vulkan_platform.h"
#include "dawn/native/Error.h" #include "dawn/native/Error.h"
#include "dawn/native/ShaderModule.h"
namespace dawn::native::vulkan { namespace dawn::native::vulkan {
@ -34,8 +32,12 @@ class PipelineLayout;
class ShaderModule final : public ShaderModuleBase { class ShaderModule final : public ShaderModuleBase {
public: public:
using Spirv = std::vector<uint32_t>; class Spirv;
using ModuleAndSpirv = std::pair<VkShaderModule, const Spirv*>; struct ModuleAndSpirv {
VkShaderModule module;
const uint32_t* spirv;
size_t wordCount;
};
static ResultOrError<Ref<ShaderModule>> Create(Device* device, static ResultOrError<Ref<ShaderModule>> Create(Device* device,
const ShaderModuleDescriptor* descriptor, const ShaderModuleDescriptor* descriptor,
@ -53,25 +55,7 @@ class ShaderModule final : public ShaderModuleBase {
void DestroyImpl() override; void DestroyImpl() override;
// New handles created by GetHandleAndSpirv at pipeline creation time. // New handles created by GetHandleAndSpirv at pipeline creation time.
class ConcurrentTransformedShaderModuleCache { class ConcurrentTransformedShaderModuleCache;
public:
explicit ConcurrentTransformedShaderModuleCache(Device* device);
~ConcurrentTransformedShaderModuleCache();
std::optional<ModuleAndSpirv> Find(const PipelineLayoutEntryPointPair& key);
ModuleAndSpirv AddOrGet(const PipelineLayoutEntryPointPair& key,
VkShaderModule module,
std::vector<uint32_t>&& spirv);
private:
using Entry = std::pair<VkShaderModule, std::unique_ptr<Spirv>>;
Device* mDevice;
std::mutex mMutex;
std::
unordered_map<PipelineLayoutEntryPointPair, Entry, PipelineLayoutEntryPointPairHashFunc>
mTransformedShaderModuleCache;
};
std::unique_ptr<ConcurrentTransformedShaderModuleCache> mTransformedShaderModuleCache; std::unique_ptr<ConcurrentTransformedShaderModuleCache> mTransformedShaderModuleCache;
}; };

View File

@ -389,6 +389,7 @@ source_set("platform_mocks_sources") {
deps = [ deps = [
":gmock_and_gtest", ":gmock_and_gtest",
"${dawn_root}/src/dawn/common",
"${dawn_root}/src/dawn/platform", "${dawn_root}/src/dawn/platform",
] ]

View File

@ -71,12 +71,46 @@ static constexpr std::string_view kFragmentShaderMultipleOutput = R"(
} }
)"; )";
static constexpr std::string_view kFragmentShaderBindGroup00Uniform = R"(
struct S {
value : f32
};
@group(0) @binding(0) var<uniform> uBuffer : S;
@fragment fn main() -> @location(0) vec4<f32> {
return vec4<f32>(uBuffer.value, 0.2, 0.3, 0.4);
}
)";
static constexpr std::string_view kFragmentShaderBindGroup01Uniform = R"(
struct S {
value : f32
};
@group(0) @binding(1) var<uniform> uBuffer : S;
@fragment fn main() -> @location(0) vec4<f32> {
return vec4<f32>(uBuffer.value, 0.2, 0.3, 0.4);
}
)";
class PipelineCachingTests : public DawnTest { class PipelineCachingTests : public DawnTest {
protected: protected:
std::unique_ptr<dawn::platform::Platform> CreateTestPlatform() override { std::unique_ptr<dawn::platform::Platform> CreateTestPlatform() override {
return std::make_unique<DawnCachingMockPlatform>(&mMockCache); return std::make_unique<DawnCachingMockPlatform>(&mMockCache);
} }
struct EntryCounts {
unsigned pipeline;
unsigned shaderModule;
};
const EntryCounts counts = {
// pipeline caching is only implemented on D3D12/Vulkan
IsD3D12() || IsVulkan() ? 1u : 0u,
// shader module caching is only implemented on Vulkan
IsVulkan() ? 1u : 0u,
};
NiceMock<CachingInterfaceMock> mMockCache; NiceMock<CachingInterfaceMock> mMockCache;
}; };
@ -95,9 +129,8 @@ TEST_P(SinglePipelineCachingTests, ComputePipelineNoCache) {
wgpu::ComputePipelineDescriptor desc; wgpu::ComputePipelineDescriptor desc;
desc.compute.module = utils::CreateShaderModule(device, kComputeShaderDefault.data()); desc.compute.module = utils::CreateShaderModule(device, kComputeShaderDefault.data());
desc.compute.entryPoint = "main"; desc.compute.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateComputePipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(0), device.CreateComputePipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 0u);
// Second time should create fine with no cache hits since cache is disabled. // Second time should create fine with no cache hits since cache is disabled.
{ {
@ -105,9 +138,8 @@ TEST_P(SinglePipelineCachingTests, ComputePipelineNoCache) {
wgpu::ComputePipelineDescriptor desc; wgpu::ComputePipelineDescriptor desc;
desc.compute.module = utils::CreateShaderModule(device, kComputeShaderDefault.data()); desc.compute.module = utils::CreateShaderModule(device, kComputeShaderDefault.data());
desc.compute.entryPoint = "main"; desc.compute.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateComputePipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(0), device.CreateComputePipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 0u);
} }
// Tests that pipeline creation on the same device uses frontend cache when possible. // Tests that pipeline creation on the same device uses frontend cache when possible.
@ -118,14 +150,15 @@ TEST_P(SinglePipelineCachingTests, ComputePipelineFrontedCache) {
// First creation should create a cache entry. // First creation should create a cache entry.
wgpu::ComputePipeline pipeline; wgpu::ComputePipeline pipeline;
EXPECT_CACHE_HIT(mMockCache, 0u, pipeline = device.CreateComputePipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(counts.shaderModule + counts.pipeline),
EXPECT_EQ(mMockCache.GetNumEntries(), 1u); pipeline = device.CreateComputePipeline(&desc));
// Second creation on the same device should just return from frontend cache and should not // Second creation on the same device should just return from frontend cache and should not
// call out to the blob cache. // call out to the blob cache.
EXPECT_CALL(mMockCache, LoadData).Times(0); EXPECT_CALL(mMockCache, LoadData).Times(0);
wgpu::ComputePipeline samePipeline; wgpu::ComputePipeline samePipeline;
EXPECT_CACHE_HIT(mMockCache, 0u, samePipeline = device.CreateComputePipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(0),
samePipeline = device.CreateComputePipeline(&desc));
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire()); EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
} }
@ -139,9 +172,9 @@ TEST_P(SinglePipelineCachingTests, ComputePipelineBlobCache) {
wgpu::ComputePipelineDescriptor desc; wgpu::ComputePipelineDescriptor desc;
desc.compute.module = utils::CreateShaderModule(device, kComputeShaderDefault.data()); desc.compute.module = utils::CreateShaderModule(device, kComputeShaderDefault.data());
desc.compute.entryPoint = "main"; desc.compute.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateComputePipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(counts.shaderModule + counts.pipeline),
device.CreateComputePipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 1u);
// Second time should create using the cache. // Second time should create using the cache.
{ {
@ -149,9 +182,9 @@ TEST_P(SinglePipelineCachingTests, ComputePipelineBlobCache) {
wgpu::ComputePipelineDescriptor desc; wgpu::ComputePipelineDescriptor desc;
desc.compute.module = utils::CreateShaderModule(device, kComputeShaderDefault.data()); desc.compute.module = utils::CreateShaderModule(device, kComputeShaderDefault.data());
desc.compute.entryPoint = "main"; desc.compute.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 1u, device.CreateComputePipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(counts.shaderModule + counts.pipeline), Add(0),
device.CreateComputePipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 1u);
} }
// Tests that pipeline creation hits the cache when using the same pipeline but with explicit // Tests that pipeline creation hits the cache when using the same pipeline but with explicit
@ -163,9 +196,9 @@ TEST_P(SinglePipelineCachingTests, ComputePipelineBlobCacheExplictLayout) {
wgpu::ComputePipelineDescriptor desc; wgpu::ComputePipelineDescriptor desc;
desc.compute.module = utils::CreateShaderModule(device, kComputeShaderDefault.data()); desc.compute.module = utils::CreateShaderModule(device, kComputeShaderDefault.data());
desc.compute.entryPoint = "main"; desc.compute.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateComputePipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(counts.shaderModule + counts.pipeline),
device.CreateComputePipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 1u);
// Cache should hit: use the same pipeline but with explicit pipeline layout. // Cache should hit: use the same pipeline but with explicit pipeline layout.
{ {
@ -174,23 +207,22 @@ TEST_P(SinglePipelineCachingTests, ComputePipelineBlobCacheExplictLayout) {
desc.compute.module = utils::CreateShaderModule(device, kComputeShaderDefault.data()); desc.compute.module = utils::CreateShaderModule(device, kComputeShaderDefault.data());
desc.compute.entryPoint = "main"; desc.compute.entryPoint = "main";
desc.layout = utils::MakeBasicPipelineLayout(device, {}); desc.layout = utils::MakeBasicPipelineLayout(device, {});
EXPECT_CACHE_HIT(mMockCache, 1u, device.CreateComputePipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(counts.shaderModule + counts.pipeline), Add(0),
device.CreateComputePipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 1u);
} }
// Tests that pipeline creation wouldn't hit the cache if the pipelines are not exactly the same. // Tests that pipeline creation wouldn't hit the cache if the pipelines are not exactly the same.
TEST_P(SinglePipelineCachingTests, ComputePipelineBlobCacheShaderNegativeCases) { TEST_P(SinglePipelineCachingTests, ComputePipelineBlobCacheShaderNegativeCases) {
size_t numCacheEntries = 0u;
// First time should create and write out to the cache. // First time should create and write out to the cache.
{ {
wgpu::Device device = CreateDevice(); wgpu::Device device = CreateDevice();
wgpu::ComputePipelineDescriptor desc; wgpu::ComputePipelineDescriptor desc;
desc.compute.module = utils::CreateShaderModule(device, kComputeShaderDefault.data()); desc.compute.module = utils::CreateShaderModule(device, kComputeShaderDefault.data());
desc.compute.entryPoint = "main"; desc.compute.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateComputePipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(counts.shaderModule + counts.pipeline),
device.CreateComputePipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), ++numCacheEntries);
// Cache should not hit: different shader module. // Cache should not hit: different shader module.
{ {
@ -199,9 +231,9 @@ TEST_P(SinglePipelineCachingTests, ComputePipelineBlobCacheShaderNegativeCases)
desc.compute.module = desc.compute.module =
utils::CreateShaderModule(device, kComputeShaderMultipleEntryPoints.data()); utils::CreateShaderModule(device, kComputeShaderMultipleEntryPoints.data());
desc.compute.entryPoint = "main"; desc.compute.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateComputePipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(counts.shaderModule + counts.pipeline),
device.CreateComputePipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), ++numCacheEntries);
// Cache should not hit: same shader module but different shader entry point. // Cache should not hit: same shader module but different shader entry point.
{ {
@ -210,9 +242,9 @@ TEST_P(SinglePipelineCachingTests, ComputePipelineBlobCacheShaderNegativeCases)
desc.compute.module = desc.compute.module =
utils::CreateShaderModule(device, kComputeShaderMultipleEntryPoints.data()); utils::CreateShaderModule(device, kComputeShaderMultipleEntryPoints.data());
desc.compute.entryPoint = "main2"; desc.compute.entryPoint = "main2";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateComputePipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(counts.shaderModule + counts.pipeline),
device.CreateComputePipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), ++numCacheEntries);
} }
// Tests that pipeline creation does not hits the cache when it is enabled but we use different // Tests that pipeline creation does not hits the cache when it is enabled but we use different
@ -224,9 +256,9 @@ TEST_P(SinglePipelineCachingTests, ComputePipelineBlobCacheIsolationKey) {
wgpu::ComputePipelineDescriptor desc; wgpu::ComputePipelineDescriptor desc;
desc.compute.module = utils::CreateShaderModule(device, kComputeShaderDefault.data()); desc.compute.module = utils::CreateShaderModule(device, kComputeShaderDefault.data());
desc.compute.entryPoint = "main"; desc.compute.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateComputePipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(counts.shaderModule + counts.pipeline),
device.CreateComputePipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 1u);
// Second time should also create and write out to the cache. // Second time should also create and write out to the cache.
{ {
@ -234,9 +266,9 @@ TEST_P(SinglePipelineCachingTests, ComputePipelineBlobCacheIsolationKey) {
wgpu::ComputePipelineDescriptor desc; wgpu::ComputePipelineDescriptor desc;
desc.compute.module = utils::CreateShaderModule(device, kComputeShaderDefault.data()); desc.compute.module = utils::CreateShaderModule(device, kComputeShaderDefault.data());
desc.compute.entryPoint = "main"; desc.compute.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateComputePipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(counts.shaderModule + counts.pipeline),
device.CreateComputePipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 2u);
} }
// Tests that pipeline creation works fine even if the cache is disabled. // Tests that pipeline creation works fine even if the cache is disabled.
@ -254,9 +286,8 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineNoCache) {
desc.vertex.entryPoint = "main"; desc.vertex.entryPoint = "main";
desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data()); desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data());
desc.cFragment.entryPoint = "main"; desc.cFragment.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(0), device.CreateRenderPipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 0u);
// Second time should create fine with no cache hits since cache is disabled. // Second time should create fine with no cache hits since cache is disabled.
{ {
@ -266,9 +297,8 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineNoCache) {
desc.vertex.entryPoint = "main"; desc.vertex.entryPoint = "main";
desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data()); desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data());
desc.cFragment.entryPoint = "main"; desc.cFragment.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(0), device.CreateRenderPipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 0u);
} }
// Tests that pipeline creation on the same device uses frontend cache when possible. // Tests that pipeline creation on the same device uses frontend cache when possible.
@ -281,14 +311,15 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineFrontedCache) {
// First creation should create a cache entry. // First creation should create a cache entry.
wgpu::RenderPipeline pipeline; wgpu::RenderPipeline pipeline;
EXPECT_CACHE_HIT(mMockCache, 0u, pipeline = device.CreateRenderPipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(2 * counts.shaderModule + counts.pipeline),
EXPECT_EQ(mMockCache.GetNumEntries(), 1u); pipeline = device.CreateRenderPipeline(&desc));
// Second creation on the same device should just return from frontend cache and should not // Second creation on the same device should just return from frontend cache and should not
// call out to the blob cache. // call out to the blob cache.
EXPECT_CALL(mMockCache, LoadData).Times(0); EXPECT_CALL(mMockCache, LoadData).Times(0);
wgpu::RenderPipeline samePipeline; wgpu::RenderPipeline samePipeline;
EXPECT_CACHE_HIT(mMockCache, 0u, samePipeline = device.CreateRenderPipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(0),
samePipeline = device.CreateRenderPipeline(&desc));
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire()); EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
} }
@ -307,9 +338,9 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCache) {
desc.vertex.entryPoint = "main"; desc.vertex.entryPoint = "main";
desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data()); desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data());
desc.cFragment.entryPoint = "main"; desc.cFragment.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(2 * counts.shaderModule + counts.pipeline),
device.CreateRenderPipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 1u);
// Second time should create using the cache. // Second time should create using the cache.
{ {
@ -319,9 +350,9 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCache) {
desc.vertex.entryPoint = "main"; desc.vertex.entryPoint = "main";
desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data()); desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data());
desc.cFragment.entryPoint = "main"; desc.cFragment.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 1u, device.CreateRenderPipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(2 * counts.shaderModule + counts.pipeline), Add(0),
device.CreateRenderPipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 1u);
} }
// Tests that pipeline creation hits the cache when using the same pipeline but with explicit // Tests that pipeline creation hits the cache when using the same pipeline but with explicit
@ -338,9 +369,9 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheExplictLayout) {
desc.vertex.entryPoint = "main"; desc.vertex.entryPoint = "main";
desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data()); desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data());
desc.cFragment.entryPoint = "main"; desc.cFragment.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(2 * counts.shaderModule + counts.pipeline),
device.CreateRenderPipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 1u);
// Cache should hit: use the same pipeline but with explicit pipeline layout. // Cache should hit: use the same pipeline but with explicit pipeline layout.
{ {
@ -351,9 +382,9 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheExplictLayout) {
desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data()); desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data());
desc.cFragment.entryPoint = "main"; desc.cFragment.entryPoint = "main";
desc.layout = utils::MakeBasicPipelineLayout(device, {}); desc.layout = utils::MakeBasicPipelineLayout(device, {});
EXPECT_CACHE_HIT(mMockCache, 1u, device.CreateRenderPipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(2 * counts.shaderModule + counts.pipeline), Add(0),
device.CreateRenderPipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 1u);
} }
// Tests that pipeline creation wouldn't hit the cache if the pipelines have different state set in // Tests that pipeline creation wouldn't hit the cache if the pipelines have different state set in
@ -367,11 +398,11 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheDescriptorNegativeCase
desc.vertex.entryPoint = "main"; desc.vertex.entryPoint = "main";
desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data()); desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data());
desc.cFragment.entryPoint = "main"; desc.cFragment.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(2 * counts.shaderModule + counts.pipeline),
device.CreateRenderPipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 1u);
// Cache should not hit: different pipeline descriptor state. // Cache should hit for shaders, but not pipeline: different pipeline descriptor state.
{ {
wgpu::Device device = CreateDevice(); wgpu::Device device = CreateDevice();
utils::ComboRenderPipelineDescriptor desc; utils::ComboRenderPipelineDescriptor desc;
@ -380,15 +411,14 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheDescriptorNegativeCase
desc.vertex.entryPoint = "main"; desc.vertex.entryPoint = "main";
desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data()); desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data());
desc.cFragment.entryPoint = "main"; desc.cFragment.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(2 * counts.shaderModule), Add(counts.pipeline),
device.CreateRenderPipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 2u);
} }
// Tests that pipeline creation wouldn't hit the cache if the pipelines are not exactly the same in // Tests that pipeline creation wouldn't hit the cache if the pipelines are not exactly the same in
// terms of shader. // terms of shader.
TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheShaderNegativeCases) { TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheShaderNegativeCases) {
size_t numCacheEntries = 0u;
// First time should create and write out to the cache. // First time should create and write out to the cache.
{ {
wgpu::Device device = CreateDevice(); wgpu::Device device = CreateDevice();
@ -397,11 +427,12 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheShaderNegativeCases) {
desc.vertex.entryPoint = "main"; desc.vertex.entryPoint = "main";
desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data()); desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data());
desc.cFragment.entryPoint = "main"; desc.cFragment.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(2 * counts.shaderModule + counts.pipeline),
device.CreateRenderPipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), ++numCacheEntries);
// Cache should not hit: different shader module. // Cache should not hit for different vertex shader module,
// Cache should still hit for the same fragment shader module.
{ {
wgpu::Device device = CreateDevice(); wgpu::Device device = CreateDevice();
utils::ComboRenderPipelineDescriptor desc; utils::ComboRenderPipelineDescriptor desc;
@ -410,11 +441,13 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheShaderNegativeCases) {
desc.vertex.entryPoint = "main"; desc.vertex.entryPoint = "main";
desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data()); desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data());
desc.cFragment.entryPoint = "main"; desc.cFragment.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(counts.shaderModule),
Add(counts.shaderModule + counts.pipeline),
device.CreateRenderPipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), ++numCacheEntries);
// Cache should not hit: same shader module but different shader entry point. // Cache should not hit: same shader module but different shader entry point.
// Cache should still hit for the same shader module.
{ {
wgpu::Device device = CreateDevice(); wgpu::Device device = CreateDevice();
utils::ComboRenderPipelineDescriptor desc; utils::ComboRenderPipelineDescriptor desc;
@ -423,15 +456,15 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheShaderNegativeCases) {
desc.vertex.entryPoint = "main2"; desc.vertex.entryPoint = "main2";
desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data()); desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data());
desc.cFragment.entryPoint = "main"; desc.cFragment.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(counts.shaderModule),
Add(counts.shaderModule + counts.pipeline),
device.CreateRenderPipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), ++numCacheEntries);
} }
// Tests that pipeline creation wouldn't hit the cache if the pipelines are not exactly the same // Tests that pipeline creation wouldn't hit the cache if the pipelines are not exactly the same
// (fragment color targets differences). // (fragment color targets differences).
TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheNegativeCasesFragmentColorTargets) { TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheNegativeCasesFragmentColorTargets) {
size_t numCacheEntries = 0u;
// First time should create and write out to the cache. // First time should create and write out to the cache.
{ {
wgpu::Device device = CreateDevice(); wgpu::Device device = CreateDevice();
@ -445,11 +478,11 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheNegativeCasesFragmentC
desc.cFragment.module = desc.cFragment.module =
utils::CreateShaderModule(device, kFragmentShaderMultipleOutput.data()); utils::CreateShaderModule(device, kFragmentShaderMultipleOutput.data());
desc.cFragment.entryPoint = "main"; desc.cFragment.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(2 * counts.shaderModule + counts.pipeline),
device.CreateRenderPipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), ++numCacheEntries);
// Cache should not hit: different fragment color target state (sparse). // Cache should not hit for the pipeline: different fragment color target state (sparse).
{ {
wgpu::Device device = CreateDevice(); wgpu::Device device = CreateDevice();
utils::ComboRenderPipelineDescriptor desc; utils::ComboRenderPipelineDescriptor desc;
@ -462,9 +495,9 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheNegativeCasesFragmentC
desc.cFragment.module = desc.cFragment.module =
utils::CreateShaderModule(device, kFragmentShaderMultipleOutput.data()); utils::CreateShaderModule(device, kFragmentShaderMultipleOutput.data());
desc.cFragment.entryPoint = "main"; desc.cFragment.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(2 * counts.shaderModule), Add(counts.pipeline),
device.CreateRenderPipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), ++numCacheEntries);
// Cache should not hit: different fragment color target state (trailing empty). // Cache should not hit: different fragment color target state (trailing empty).
{ {
@ -479,9 +512,101 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheNegativeCasesFragmentC
desc.cFragment.module = desc.cFragment.module =
utils::CreateShaderModule(device, kFragmentShaderMultipleOutput.data()); utils::CreateShaderModule(device, kFragmentShaderMultipleOutput.data());
desc.cFragment.entryPoint = "main"; desc.cFragment.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(2 * counts.shaderModule), Add(counts.pipeline),
device.CreateRenderPipeline(&desc));
}
}
// Tests that pipeline creation hits the cache for shaders, but not the pipeline if the
// shaders aren't impacted by the layout. This test is a bit change detecting - but all
// cached backends currently remap shader bindings based on the layout. It can be split
// per-backend as needed.
TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheLayout) {
// First time should create and write out to the cache.
{
wgpu::Device device = CreateDevice();
utils::ComboRenderPipelineDescriptor desc;
desc.vertex.module = utils::CreateShaderModule(device, kVertexShaderDefault.data());
desc.vertex.entryPoint = "main";
desc.cFragment.module =
utils::CreateShaderModule(device, kFragmentShaderBindGroup00Uniform.data());
desc.cFragment.entryPoint = "main";
desc.layout = utils::MakePipelineLayout(
device, {
utils::MakeBindGroupLayout(
device,
{
{0, wgpu::ShaderStage::Fragment, wgpu::BufferBindingType::Uniform},
}),
});
EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(2 * counts.shaderModule + counts.pipeline),
device.CreateRenderPipeline(&desc));
}
// Cache should hit for the shaders, but not for the pipeline: different layout.
{
wgpu::Device device = CreateDevice();
utils::ComboRenderPipelineDescriptor desc;
desc.vertex.module = utils::CreateShaderModule(device, kVertexShaderDefault.data());
desc.vertex.entryPoint = "main";
desc.cFragment.module =
utils::CreateShaderModule(device, kFragmentShaderBindGroup00Uniform.data());
desc.cFragment.entryPoint = "main";
desc.layout = utils::MakePipelineLayout(
device, {
utils::MakeBindGroupLayout(
device,
{
{0, wgpu::ShaderStage::Fragment, wgpu::BufferBindingType::Uniform},
{1, wgpu::ShaderStage::Fragment, wgpu::BufferBindingType::Uniform},
}),
});
EXPECT_CACHE_STATS(mMockCache, Hit(2 * counts.shaderModule), Add(counts.pipeline),
device.CreateRenderPipeline(&desc));
}
// Cache should hit for the shaders, but not for the pipeline: different layout (dynamic).
{
wgpu::Device device = CreateDevice();
utils::ComboRenderPipelineDescriptor desc;
desc.vertex.module = utils::CreateShaderModule(device, kVertexShaderDefault.data());
desc.vertex.entryPoint = "main";
desc.cFragment.module =
utils::CreateShaderModule(device, kFragmentShaderBindGroup00Uniform.data());
desc.cFragment.entryPoint = "main";
desc.layout = utils::MakePipelineLayout(
device, {
utils::MakeBindGroupLayout(device,
{
{0, wgpu::ShaderStage::Fragment,
wgpu::BufferBindingType::Uniform, true},
}),
});
EXPECT_CACHE_STATS(mMockCache, Hit(2 * counts.shaderModule), Add(counts.pipeline),
device.CreateRenderPipeline(&desc));
}
// Cache should hit for the shaders, but not for the pipeline.
// The shader is different but compiles to the same due to binding number remapping.
{
wgpu::Device device = CreateDevice();
utils::ComboRenderPipelineDescriptor desc;
desc.vertex.module = utils::CreateShaderModule(device, kVertexShaderDefault.data());
desc.vertex.entryPoint = "main";
desc.cFragment.module =
utils::CreateShaderModule(device, kFragmentShaderBindGroup01Uniform.data());
desc.cFragment.entryPoint = "main";
desc.layout = utils::MakePipelineLayout(
device, {
utils::MakeBindGroupLayout(
device,
{
{1, wgpu::ShaderStage::Fragment, wgpu::BufferBindingType::Uniform},
}),
});
EXPECT_CACHE_STATS(mMockCache, Hit(2 * counts.shaderModule), Add(counts.pipeline),
device.CreateRenderPipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), ++numCacheEntries);
} }
// Tests that pipeline creation does not hits the cache when it is enabled but we use different // Tests that pipeline creation does not hits the cache when it is enabled but we use different
@ -495,9 +620,9 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheIsolationKey) {
desc.vertex.entryPoint = "main"; desc.vertex.entryPoint = "main";
desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data()); desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data());
desc.cFragment.entryPoint = "main"; desc.cFragment.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(2 * counts.shaderModule + counts.pipeline),
device.CreateRenderPipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 1u);
// Second time should also create and write out to the cache. // Second time should also create and write out to the cache.
{ {
@ -507,13 +632,16 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheIsolationKey) {
desc.vertex.entryPoint = "main"; desc.vertex.entryPoint = "main";
desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data()); desc.cFragment.module = utils::CreateShaderModule(device, kFragmentShaderDefault.data());
desc.cFragment.entryPoint = "main"; desc.cFragment.entryPoint = "main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc)); EXPECT_CACHE_STATS(mMockCache, Hit(0), Add(2 * counts.shaderModule + counts.pipeline),
device.CreateRenderPipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 2u);
} }
DAWN_INSTANTIATE_TEST(SinglePipelineCachingTests, DAWN_INSTANTIATE_TEST(SinglePipelineCachingTests,
VulkanBackend({"enable_blob_cache"}), D3D12Backend({"enable_blob_cache"}),
D3D12Backend({"enable_blob_cache"})); MetalBackend({"enable_blob_cache"}),
OpenGLBackend({"enable_blob_cache"}),
OpenGLESBackend({"enable_blob_cache"}),
VulkanBackend({"enable_blob_cache"}));
} // namespace } // namespace

View File

@ -22,6 +22,8 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "dawn/common/TypedInteger.h"
#define EXPECT_CACHE_HIT(cache, N, statement) \ #define EXPECT_CACHE_HIT(cache, N, statement) \
do { \ do { \
FlushWire(); \ FlushWire(); \
@ -32,6 +34,26 @@
EXPECT_EQ(N, after - before); \ EXPECT_EQ(N, after - before); \
} while (0) } while (0)
// Check that |HitN| cache hits occured, and |AddN| entries were added.
// Usage: EXPECT_CACHE_STATS(myMockCache, Hit(42), Add(3), ...)
// Hit / Add help readability, and enforce the args are passed correctly in the expected order.
#define EXPECT_CACHE_STATS(cache, HitN, AddN, statement) \
do { \
using Hit = TypedInteger<struct HitT, size_t>; \
using Add = TypedInteger<struct AddT, size_t>; \
static_assert(std::is_same_v<decltype(HitN), Hit>); \
static_assert(std::is_same_v<decltype(AddN), Add>); \
FlushWire(); \
size_t hitBefore = cache.GetHitCount(); \
size_t entriesBefore = cache.GetNumEntries(); \
statement; \
FlushWire(); \
size_t hitAfter = cache.GetHitCount(); \
size_t entriesAfter = cache.GetNumEntries(); \
EXPECT_EQ(static_cast<size_t>(HitN), hitAfter - hitBefore); \
EXPECT_EQ(static_cast<size_t>(AddN), entriesAfter - entriesBefore); \
} while (0)
// A mock caching interface class that also supplies an in memory cache for testing. // A mock caching interface class that also supplies an in memory cache for testing.
class CachingInterfaceMock : public dawn::platform::CachingInterface { class CachingInterfaceMock : public dawn::platform::CachingInterface {
public: public:

View File

@ -15,11 +15,14 @@
#include <cstring> #include <cstring>
#include <iomanip> #include <iomanip>
#include <string> #include <string>
#include <unordered_map>
#include <utility>
#include <vector> #include <vector>
#include "dawn/native/CacheKey.h" #include "dawn/native/CacheKey.h"
#include "gmock/gmock.h" #include "gmock/gmock.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "tint/tint.h"
namespace dawn::native { namespace dawn::native {
@ -54,6 +57,7 @@ MATCHER_P(CacheKeyEq, key, PrintToString(key)) {
return arg.size() == key.size() && memcmp(arg.data(), key.data(), key.size()) == 0; return arg.size() == key.size() && memcmp(arg.data(), key.data(), key.size()) == 0;
} }
// Test that CacheKey::Record calls serialize on the single member of a struct.
TEST(CacheKeyTests, RecordSingleMember) { TEST(CacheKeyTests, RecordSingleMember) {
CacheKey key; CacheKey key;
@ -62,6 +66,7 @@ TEST(CacheKeyTests, RecordSingleMember) {
EXPECT_THAT(key.Record(a), CacheKeyEq(CacheKey())); EXPECT_THAT(key.Record(a), CacheKeyEq(CacheKey()));
} }
// Test that CacheKey::Record calls serialize on all members of a struct.
TEST(CacheKeyTests, RecordManyMembers) { TEST(CacheKeyTests, RecordManyMembers) {
constexpr size_t kNumMembers = 100; constexpr size_t kNumMembers = 100;
@ -74,6 +79,7 @@ TEST(CacheKeyTests, RecordManyMembers) {
EXPECT_THAT(key, CacheKeyEq(CacheKey())); EXPECT_THAT(key, CacheKeyEq(CacheKey()));
} }
// Test that CacheKey::Record calls serialize on all elements of an iterable.
TEST(CacheKeyTests, RecordIterable) { TEST(CacheKeyTests, RecordIterable) {
constexpr size_t kIterableSize = 100; constexpr size_t kIterableSize = 100;
@ -96,6 +102,7 @@ TEST(CacheKeyTests, RecordIterable) {
EXPECT_THAT(CacheKey().RecordIterable(iterable.data(), kIterableSize), CacheKeyEq(expected)); EXPECT_THAT(CacheKey().RecordIterable(iterable.data(), kIterableSize), CacheKeyEq(expected));
} }
// Test that CacheKey::Record calls serialize on all members and nested struct members.
TEST(CacheKeyTests, RecordNested) { TEST(CacheKeyTests, RecordNested) {
CacheKey expected; CacheKey expected;
CacheKey actual; CacheKey actual;
@ -132,6 +139,7 @@ TEST(CacheKeyTests, RecordNested) {
EXPECT_THAT(actual, CacheKeyEq(expected)); EXPECT_THAT(actual, CacheKeyEq(expected));
} }
// Test that CacheKey::Record serializes integral data as expected.
TEST(CacheKeySerializerTests, IntegralTypes) { TEST(CacheKeySerializerTests, IntegralTypes) {
// Only testing explicitly sized types for simplicity, and using 0s for larger types to // Only testing explicitly sized types for simplicity, and using 0s for larger types to
// avoid dealing with endianess. // avoid dealing with endianess.
@ -141,12 +149,14 @@ TEST(CacheKeySerializerTests, IntegralTypes) {
EXPECT_THAT(CacheKey().Record(uint32_t(0)), CacheKeyEq(CacheKey({0, 0, 0, 0}))); EXPECT_THAT(CacheKey().Record(uint32_t(0)), CacheKeyEq(CacheKey({0, 0, 0, 0})));
} }
// Test that CacheKey::Record serializes floating-point data as expected.
TEST(CacheKeySerializerTests, FloatingTypes) { TEST(CacheKeySerializerTests, FloatingTypes) {
// Using 0s to avoid dealing with implementation specific float details. // Using 0s to avoid dealing with implementation specific float details.
EXPECT_THAT(CacheKey().Record(float{0}), CacheKeyEq(CacheKey(sizeof(float), 0))); EXPECT_THAT(CacheKey().Record(float{0}), CacheKeyEq(CacheKey(sizeof(float), 0)));
EXPECT_THAT(CacheKey().Record(double{0}), CacheKeyEq(CacheKey(sizeof(double), 0))); EXPECT_THAT(CacheKey().Record(double{0}), CacheKeyEq(CacheKey(sizeof(double), 0)));
} }
// Test that CacheKey::Record serializes literal strings as expected.
TEST(CacheKeySerializerTests, LiteralStrings) { TEST(CacheKeySerializerTests, LiteralStrings) {
// Using a std::string here to help with creating the expected result. // Using a std::string here to help with creating the expected result.
std::string str = "string"; std::string str = "string";
@ -159,6 +169,7 @@ TEST(CacheKeySerializerTests, LiteralStrings) {
EXPECT_THAT(CacheKey().Record("string"), CacheKeyEq(expected)); EXPECT_THAT(CacheKey().Record("string"), CacheKeyEq(expected));
} }
// Test that CacheKey::Record serializes std::strings as expected.
TEST(CacheKeySerializerTests, StdStrings) { TEST(CacheKeySerializerTests, StdStrings) {
std::string str = "string"; std::string str = "string";
@ -169,6 +180,7 @@ TEST(CacheKeySerializerTests, StdStrings) {
EXPECT_THAT(CacheKey().Record(str), CacheKeyEq(expected)); EXPECT_THAT(CacheKey().Record(str), CacheKeyEq(expected));
} }
// Test that CacheKey::Record serializes std::string_views as expected.
TEST(CacheKeySerializerTests, StdStringViews) { TEST(CacheKeySerializerTests, StdStringViews) {
static constexpr std::string_view str("string"); static constexpr std::string_view str("string");
@ -179,6 +191,7 @@ TEST(CacheKeySerializerTests, StdStringViews) {
EXPECT_THAT(CacheKey().Record(str), CacheKeyEq(expected)); EXPECT_THAT(CacheKey().Record(str), CacheKeyEq(expected));
} }
// Test that CacheKey::Record serializes other CacheKeys as expected.
TEST(CacheKeySerializerTests, CacheKeys) { TEST(CacheKeySerializerTests, CacheKeys) {
CacheKey data = {'d', 'a', 't', 'a'}; CacheKey data = {'d', 'a', 't', 'a'};
@ -188,6 +201,53 @@ TEST(CacheKeySerializerTests, CacheKeys) {
EXPECT_THAT(CacheKey().Record(data), CacheKeyEq(expected)); EXPECT_THAT(CacheKey().Record(data), CacheKeyEq(expected));
} }
// Test that CacheKey::Record serializes std::pair as expected.
TEST(CacheKeySerializerTests, StdPair) {
std::string_view s = "hi!";
CacheKey expected;
expected.Record(s);
expected.Record(uint32_t(42));
EXPECT_THAT(CacheKey().Record(std::make_pair(s, uint32_t(42))), CacheKeyEq(expected));
}
// Test that CacheKey::Record serializes std::unordered_map as expected.
TEST(CacheKeySerializerTests, StdUnorderedMap) {
std::unordered_map<uint32_t, std::string_view> m;
m[4] = "hello";
m[1] = "world";
m[7] = "test";
m[3] = "data";
// Expect the number of entries, followed by (K, V) pairs sorted in order of key.
CacheKey expected;
expected.Record(size_t(4));
expected.Record(std::make_pair(uint32_t(1), m[1]));
expected.Record(std::make_pair(uint32_t(3), m[3]));
expected.Record(std::make_pair(uint32_t(4), m[4]));
expected.Record(std::make_pair(uint32_t(7), m[7]));
EXPECT_THAT(CacheKey().Record(m), CacheKeyEq(expected));
}
// Test that CacheKey::Record serializes tint::sem::BindingPoint as expected.
TEST(CacheKeySerializerTests, TintSemBindingPoint) {
tint::sem::BindingPoint bp{3, 6};
EXPECT_THAT(CacheKey().Record(bp), CacheKeyEq(CacheKey().Record(uint32_t(3), uint32_t(6))));
}
// Test that CacheKey::Record serializes tint::transform::BindingPoints as expected.
TEST(CacheKeySerializerTests, TintTransformBindingPoints) {
tint::transform::BindingPoints points{
tint::sem::BindingPoint{1, 4},
tint::sem::BindingPoint{3, 7},
};
EXPECT_THAT(CacheKey().Record(points),
CacheKeyEq(CacheKey().Record(uint32_t(1), uint32_t(4), uint32_t(3), uint32_t(7))));
}
} // namespace } // namespace
} // namespace dawn::native } // namespace dawn::native