Fixes cache key generation to handle binary data.
Bug: dawn:549 Change-Id: Ie6b3ceb610b362adfed96a0982d7541002660809 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/84920 Reviewed-by: Austin Eng <enga@chromium.org> Commit-Queue: Loko Kung <lokokung@google.com>
This commit is contained in:
parent
eab5300e87
commit
c53bc6f698
|
@ -18,15 +18,14 @@ namespace dawn::native {
|
|||
|
||||
template <>
|
||||
void CacheKeySerializer<std::string>::Serialize(CacheKey* key, const std::string& t) {
|
||||
std::string len = std::to_string(t.length());
|
||||
key->insert(key->end(), len.begin(), len.end());
|
||||
key->push_back('"');
|
||||
key->Record(static_cast<size_t>(t.length()));
|
||||
key->insert(key->end(), t.begin(), t.end());
|
||||
key->push_back('"');
|
||||
}
|
||||
|
||||
template <>
|
||||
void CacheKeySerializer<CacheKey>::Serialize(CacheKey* key, const CacheKey& t) {
|
||||
// For nested cache keys, we do not record the length, and just copy the key so that it
|
||||
// appears we just flatten the keys into a single key.
|
||||
key->insert(key->end(), t.begin(), t.end());
|
||||
}
|
||||
|
||||
|
|
|
@ -15,79 +15,84 @@
|
|||
#ifndef DAWNNATIVE_CACHE_KEY_H_
|
||||
#define DAWNNATIVE_CACHE_KEY_H_
|
||||
|
||||
#include <limits>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "dawn/common/Compiler.h"
|
||||
#include "dawn/common/Assert.h"
|
||||
|
||||
namespace dawn::native {
|
||||
|
||||
using CacheKey = std::vector<uint8_t>;
|
||||
// Forward declare CacheKey class because of co-dependency.
|
||||
class CacheKey;
|
||||
|
||||
// Overridable serializer struct that should be implemented for cache key serializable
|
||||
// types/classes.
|
||||
template <typename T, typename SFINAE = void>
|
||||
struct CacheKeySerializer {
|
||||
class CacheKeySerializer {
|
||||
public:
|
||||
static void Serialize(CacheKey* key, const T& t);
|
||||
};
|
||||
|
||||
// Specialized overload for integral types. Note that we are currently serializing as a string
|
||||
// to avoid handling null termiantors.
|
||||
template <typename Integer>
|
||||
struct CacheKeySerializer<Integer, std::enable_if_t<std::is_integral_v<Integer>>> {
|
||||
static void Serialize(CacheKey* key, const Integer i) {
|
||||
std::string str = std::to_string(i);
|
||||
key->insert(key->end(), str.begin(), str.end());
|
||||
class CacheKey : public std::vector<uint8_t> {
|
||||
public:
|
||||
using std::vector<uint8_t>::vector;
|
||||
|
||||
template <typename T>
|
||||
CacheKey& Record(const T& t) {
|
||||
CacheKeySerializer<T>::Serialize(this, t);
|
||||
return *this;
|
||||
}
|
||||
template <typename T, typename... Args>
|
||||
CacheKey& Record(const T& t, const Args&... args) {
|
||||
CacheKeySerializer<T>::Serialize(this, t);
|
||||
return Record(args...);
|
||||
}
|
||||
|
||||
// Records iterables by prepending the number of elements. Some common iterables are have a
|
||||
// CacheKeySerializer implemented to avoid needing to split them out when recording, i.e.
|
||||
// strings and CacheKeys, but they fundamentally do the same as this function.
|
||||
template <typename IterableT>
|
||||
CacheKey& RecordIterable(const IterableT& iterable) {
|
||||
// Always record the size of generic iterables as a size_t for now.
|
||||
Record(static_cast<size_t>(iterable.size()));
|
||||
for (auto it = iterable.begin(); it != iterable.end(); ++it) {
|
||||
Record(*it);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
template <typename Ptr>
|
||||
CacheKey& RecordIterable(const Ptr* ptr, size_t n) {
|
||||
Record(n);
|
||||
for (size_t i = 0; i < n; ++i) {
|
||||
Record(ptr[i]);
|
||||
}
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
// Specialized overload for floating point types. Note that we are currently serializing as a
|
||||
// string to avoid handling null termiantors.
|
||||
template <typename Float>
|
||||
struct CacheKeySerializer<Float, std::enable_if_t<std::is_floating_point_v<Float>>> {
|
||||
static void Serialize(CacheKey* key, const Float f) {
|
||||
std::string str = std::to_string(f);
|
||||
key->insert(key->end(), str.begin(), str.end());
|
||||
// Specialized overload for fundamental types.
|
||||
template <typename T>
|
||||
class CacheKeySerializer<T, std::enable_if_t<std::is_fundamental_v<T>>> {
|
||||
public:
|
||||
static void Serialize(CacheKey* key, const T t) {
|
||||
const char* it = reinterpret_cast<const char*>(&t);
|
||||
key->insert(key->end(), it, (it + sizeof(T)));
|
||||
}
|
||||
};
|
||||
|
||||
// Specialized overload for string literals. Note we drop the null-terminator.
|
||||
template <size_t N>
|
||||
struct CacheKeySerializer<char[N]> {
|
||||
class CacheKeySerializer<char[N]> {
|
||||
public:
|
||||
static void Serialize(CacheKey* key, const char (&t)[N]) {
|
||||
std::string len = std::to_string(N - 1);
|
||||
key->insert(key->end(), len.begin(), len.end());
|
||||
key->push_back('"');
|
||||
key->insert(key->end(), t, t + N - 1);
|
||||
key->push_back('"');
|
||||
static_assert(N > 0);
|
||||
key->Record(static_cast<size_t>(N));
|
||||
key->insert(key->end(), t, t + N);
|
||||
}
|
||||
};
|
||||
|
||||
// Helper template function that defers to underlying static functions.
|
||||
template <typename T>
|
||||
void SerializeInto(CacheKey* key, const T& t) {
|
||||
CacheKeySerializer<T>::Serialize(key, t);
|
||||
}
|
||||
|
||||
// Given list of arguments of types with a free implementation of SerializeIntoImpl in the
|
||||
// dawn::native namespace, serializes each argument and appends them to the CacheKey while
|
||||
// prepending member ids before each argument.
|
||||
template <typename... Ts>
|
||||
CacheKey GetCacheKey(const Ts&... inputs) {
|
||||
CacheKey key;
|
||||
key.push_back('{');
|
||||
int memberId = 0;
|
||||
auto Serialize = [&](const auto& input) {
|
||||
std::string memberIdStr = (memberId == 0 ? "" : ",") + std::to_string(memberId) + ":";
|
||||
key.insert(key.end(), memberIdStr.begin(), memberIdStr.end());
|
||||
SerializeInto(&key, input);
|
||||
memberId++;
|
||||
};
|
||||
(Serialize(inputs), ...);
|
||||
key.push_back('}');
|
||||
return key;
|
||||
}
|
||||
|
||||
} // namespace dawn::native
|
||||
|
||||
#endif // DAWNNATIVE_CACHE_KEY_H_
|
||||
|
|
|
@ -42,26 +42,12 @@ namespace dawn::native {
|
|||
mIsContentHashInitialized = true;
|
||||
}
|
||||
|
||||
const std::string& CachedObject::GetCacheKey() const {
|
||||
ASSERT(mIsCacheKeyBaseInitialized);
|
||||
return mCacheKeyBase;
|
||||
const CacheKey& CachedObject::GetCacheKey() const {
|
||||
return mCacheKey;
|
||||
}
|
||||
|
||||
std::string CachedObject::GetCacheKey(DeviceBase* device) const {
|
||||
ASSERT(mIsCacheKeyBaseInitialized);
|
||||
// TODO(dawn:549) Prepend/append with device/adapter information.
|
||||
return mCacheKeyBase;
|
||||
}
|
||||
|
||||
void CachedObject::SetCacheKey(const std::string& cacheKey) {
|
||||
ASSERT(!mIsContentHashInitialized);
|
||||
mCacheKeyBase = cacheKey;
|
||||
mIsCacheKeyBaseInitialized = true;
|
||||
}
|
||||
|
||||
std::string CachedObject::ComputeCacheKeyBase() const {
|
||||
// This implementation should never be called. Only overrides should be called.
|
||||
UNREACHABLE();
|
||||
CacheKey* CachedObject::GetCacheKey() {
|
||||
return &mCacheKey;
|
||||
}
|
||||
|
||||
} // namespace dawn::native
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#ifndef DAWNNATIVE_CACHED_OBJECT_H_
|
||||
#define DAWNNATIVE_CACHED_OBJECT_H_
|
||||
|
||||
#include "dawn/native/CacheKey.h"
|
||||
#include "dawn/native/Forward.h"
|
||||
|
||||
#include <cstddef>
|
||||
|
@ -38,13 +39,12 @@ namespace dawn::native {
|
|||
size_t GetContentHash() const;
|
||||
void SetContentHash(size_t contentHash);
|
||||
|
||||
// Two versions of GetCacheKey, when passed a device, prepends the stored cache
|
||||
// key base with device and adapter information. When called without passing a
|
||||
// device, returns the stored cache key base. This is useful when the instance
|
||||
// is a member to a parent class.
|
||||
const std::string& GetCacheKey() const;
|
||||
std::string GetCacheKey(DeviceBase* device) const;
|
||||
void SetCacheKey(const std::string& cacheKey);
|
||||
// Returns the cache key for the object only, i.e. without device/adapter information.
|
||||
const CacheKey& GetCacheKey() const;
|
||||
|
||||
protected:
|
||||
// Protected accessor for derived classes to access and modify the key.
|
||||
CacheKey* GetCacheKey();
|
||||
|
||||
private:
|
||||
friend class DeviceBase;
|
||||
|
@ -55,13 +55,9 @@ namespace dawn::native {
|
|||
// Called by ObjectContentHasher upon creation to record the object.
|
||||
virtual size_t ComputeContentHash() = 0;
|
||||
|
||||
// Not all classes implement cache key computation, so by default we assert.
|
||||
virtual std::string ComputeCacheKeyBase() const;
|
||||
|
||||
size_t mContentHash = 0;
|
||||
bool mIsContentHashInitialized = false;
|
||||
std::string mCacheKeyBase = "";
|
||||
bool mIsCacheKeyBaseInitialized = false;
|
||||
CacheKey mCacheKey;
|
||||
};
|
||||
|
||||
} // namespace dawn::native
|
||||
|
|
|
@ -15,73 +15,168 @@
|
|||
#include <gmock/gmock.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iomanip>
|
||||
#include <string>
|
||||
|
||||
#include "dawn/native/CacheKey.h"
|
||||
|
||||
namespace dawn::native {
|
||||
|
||||
// Testing classes/structs with serializing implemented for testing.
|
||||
struct A {};
|
||||
// Testing classes with mock serializing implemented for testing.
|
||||
class A {
|
||||
public:
|
||||
MOCK_METHOD(void, SerializeMock, (CacheKey*, const A&), (const));
|
||||
};
|
||||
template <>
|
||||
void CacheKeySerializer<A>::Serialize(CacheKey* key, const A& t) {
|
||||
std::string str = "structA";
|
||||
key->insert(key->end(), str.begin(), str.end());
|
||||
t.SerializeMock(key, t);
|
||||
}
|
||||
|
||||
class B {};
|
||||
template <>
|
||||
void CacheKeySerializer<B>::Serialize(CacheKey* key, const B& t) {
|
||||
std::string str = "classB";
|
||||
key->insert(key->end(), str.begin(), str.end());
|
||||
// Custom printer for CacheKey for clearer debug testing messages.
|
||||
void PrintTo(const CacheKey& key, std::ostream* stream) {
|
||||
*stream << std::hex;
|
||||
for (const int b : key) {
|
||||
*stream << std::setfill('0') << std::setw(2) << b << " ";
|
||||
}
|
||||
*stream << std::dec;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Matcher to compare CacheKey to a string for easier testing.
|
||||
MATCHER_P(CacheKeyEq,
|
||||
key,
|
||||
"cache key " + std::string(negation ? "not" : "") + "equal to " + key) {
|
||||
return std::string(arg.begin(), arg.end()) == key;
|
||||
using ::testing::InSequence;
|
||||
using ::testing::NotNull;
|
||||
using ::testing::PrintToString;
|
||||
using ::testing::Ref;
|
||||
|
||||
// Matcher to compare CacheKeys for easier testing.
|
||||
MATCHER_P(CacheKeyEq, key, PrintToString(key)) {
|
||||
return memcmp(arg.data(), key.data(), arg.size()) == 0;
|
||||
}
|
||||
|
||||
TEST(CacheKeyTest, IntegralTypes) {
|
||||
EXPECT_THAT(GetCacheKey((int)-1), CacheKeyEq("{0:-1}"));
|
||||
EXPECT_THAT(GetCacheKey((uint8_t)2), CacheKeyEq("{0:2}"));
|
||||
EXPECT_THAT(GetCacheKey((uint16_t)4), CacheKeyEq("{0:4}"));
|
||||
EXPECT_THAT(GetCacheKey((uint32_t)8), CacheKeyEq("{0:8}"));
|
||||
EXPECT_THAT(GetCacheKey((uint64_t)16), CacheKeyEq("{0:16}"));
|
||||
TEST(CacheKeyTests, RecordSingleMember) {
|
||||
CacheKey key;
|
||||
|
||||
EXPECT_THAT(GetCacheKey((int)-1, (uint8_t)2, (uint16_t)4, (uint32_t)8, (uint64_t)16),
|
||||
CacheKeyEq("{0:-1,1:2,2:4,3:8,4:16}"));
|
||||
A a;
|
||||
EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1);
|
||||
EXPECT_THAT(key.Record(a), CacheKeyEq(CacheKey()));
|
||||
}
|
||||
|
||||
TEST(CacheKeyTest, FloatingTypes) {
|
||||
EXPECT_THAT(GetCacheKey((float)0.5), CacheKeyEq("{0:0.500000}"));
|
||||
EXPECT_THAT(GetCacheKey((double)32.0), CacheKeyEq("{0:32.000000}"));
|
||||
TEST(CacheKeyTests, RecordManyMembers) {
|
||||
constexpr size_t kNumMembers = 100;
|
||||
|
||||
EXPECT_THAT(GetCacheKey((float)0.5, (double)32.0),
|
||||
CacheKeyEq("{0:0.500000,1:32.000000}"));
|
||||
CacheKey key;
|
||||
for (size_t i = 0; i < kNumMembers; ++i) {
|
||||
A a;
|
||||
EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1);
|
||||
key.Record(a);
|
||||
}
|
||||
EXPECT_THAT(key, CacheKeyEq(CacheKey()));
|
||||
}
|
||||
|
||||
TEST(CacheKeyTest, Strings) {
|
||||
std::string str0 = "string0";
|
||||
std::string str1 = "string1";
|
||||
TEST(CacheKeyTests, RecordIterable) {
|
||||
constexpr size_t kIterableSize = 100;
|
||||
|
||||
EXPECT_THAT(GetCacheKey("string0"), CacheKeyEq(R"({0:7"string0"})"));
|
||||
EXPECT_THAT(GetCacheKey(str0), CacheKeyEq(R"({0:7"string0"})"));
|
||||
EXPECT_THAT(GetCacheKey("string0", str1), CacheKeyEq(R"({0:7"string0",1:7"string1"})"));
|
||||
// Expecting the size of the container.
|
||||
CacheKey expected;
|
||||
expected.Record(kIterableSize);
|
||||
|
||||
std::vector<A> iterable(kIterableSize);
|
||||
{
|
||||
InSequence seq;
|
||||
for (const auto& a : iterable) {
|
||||
EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1);
|
||||
}
|
||||
for (const auto& a : iterable) {
|
||||
EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1);
|
||||
}
|
||||
}
|
||||
|
||||
EXPECT_THAT(CacheKey().RecordIterable(iterable), CacheKeyEq(expected));
|
||||
EXPECT_THAT(CacheKey().RecordIterable(iterable.data(), kIterableSize),
|
||||
CacheKeyEq(expected));
|
||||
}
|
||||
|
||||
TEST(CacheKeyTest, NestedCacheKey) {
|
||||
EXPECT_THAT(GetCacheKey(GetCacheKey((int)-1)), CacheKeyEq("{0:{0:-1}}"));
|
||||
EXPECT_THAT(GetCacheKey(GetCacheKey("string")), CacheKeyEq(R"({0:{0:6"string"}})"));
|
||||
EXPECT_THAT(GetCacheKey(GetCacheKey(A{})), CacheKeyEq("{0:{0:structA}}"));
|
||||
EXPECT_THAT(GetCacheKey(GetCacheKey(B())), CacheKeyEq("{0:{0:classB}}"));
|
||||
TEST(CacheKeyTests, RecordNested) {
|
||||
CacheKey expected;
|
||||
CacheKey actual;
|
||||
{
|
||||
// Recording a single member.
|
||||
A a;
|
||||
EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1);
|
||||
actual.Record(CacheKey().Record(a));
|
||||
}
|
||||
{
|
||||
// Recording multiple members.
|
||||
constexpr size_t kNumMembers = 2;
|
||||
CacheKey sub;
|
||||
for (size_t i = 0; i < kNumMembers; ++i) {
|
||||
A a;
|
||||
EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1);
|
||||
sub.Record(a);
|
||||
}
|
||||
actual.Record(sub);
|
||||
}
|
||||
{
|
||||
// Record an iterable.
|
||||
constexpr size_t kIterableSize = 2;
|
||||
expected.Record(kIterableSize);
|
||||
std::vector<A> iterable(kIterableSize);
|
||||
{
|
||||
InSequence seq;
|
||||
for (const auto& a : iterable) {
|
||||
EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1);
|
||||
}
|
||||
}
|
||||
actual.Record(CacheKey().RecordIterable(iterable));
|
||||
}
|
||||
EXPECT_THAT(actual, CacheKeyEq(expected));
|
||||
}
|
||||
|
||||
EXPECT_THAT(GetCacheKey(GetCacheKey((int)-1), GetCacheKey("string"), GetCacheKey(A{}),
|
||||
GetCacheKey(B())),
|
||||
CacheKeyEq(R"({0:{0:-1},1:{0:6"string"},2:{0:structA},3:{0:classB}})"));
|
||||
TEST(CacheKeySerializerTests, IntegralTypes) {
|
||||
// Only testing explicitly sized types for simplicity, and using 0s for larger types to
|
||||
// avoid dealing with endianess.
|
||||
EXPECT_THAT(CacheKey().Record('c'), CacheKeyEq(CacheKey({'c'})));
|
||||
EXPECT_THAT(CacheKey().Record(uint8_t(255)), CacheKeyEq(CacheKey({255})));
|
||||
EXPECT_THAT(CacheKey().Record(uint16_t(0)), CacheKeyEq(CacheKey({0, 0})));
|
||||
EXPECT_THAT(CacheKey().Record(uint32_t(0)), CacheKeyEq(CacheKey({0, 0, 0, 0})));
|
||||
}
|
||||
|
||||
TEST(CacheKeySerializerTests, FloatingTypes) {
|
||||
// Using 0s to avoid dealing with implementation specific float details.
|
||||
EXPECT_THAT(CacheKey().Record(float(0)), CacheKeyEq(CacheKey(sizeof(float), 0)));
|
||||
EXPECT_THAT(CacheKey().Record(double(0)), CacheKeyEq(CacheKey(sizeof(double), 0)));
|
||||
}
|
||||
|
||||
TEST(CacheKeySerializerTests, LiteralStrings) {
|
||||
// Using a std::string here to help with creating the expected result.
|
||||
std::string str = "string";
|
||||
|
||||
CacheKey expected;
|
||||
expected.Record(size_t(7));
|
||||
expected.insert(expected.end(), str.begin(), str.end());
|
||||
expected.push_back('\0');
|
||||
|
||||
EXPECT_THAT(CacheKey().Record("string"), CacheKeyEq(expected));
|
||||
}
|
||||
|
||||
TEST(CacheKeySerializerTests, StdStrings) {
|
||||
std::string str = "string";
|
||||
|
||||
CacheKey expected;
|
||||
expected.Record((size_t)6);
|
||||
expected.insert(expected.end(), str.begin(), str.end());
|
||||
|
||||
EXPECT_THAT(CacheKey().Record(str), CacheKeyEq(expected));
|
||||
}
|
||||
|
||||
TEST(CacheKeySerializerTests, CacheKeys) {
|
||||
CacheKey data = {'d', 'a', 't', 'a'};
|
||||
|
||||
CacheKey expected;
|
||||
expected.insert(expected.end(), data.begin(), data.end());
|
||||
|
||||
EXPECT_THAT(CacheKey().Record(data), CacheKeyEq(expected));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
Loading…
Reference in New Issue