diff --git a/src/dawn/native/CacheKey.cpp b/src/dawn/native/CacheKey.cpp index ff2cec0495..3495577c06 100644 --- a/src/dawn/native/CacheKey.cpp +++ b/src/dawn/native/CacheKey.cpp @@ -18,15 +18,14 @@ namespace dawn::native { template <> void CacheKeySerializer::Serialize(CacheKey* key, const std::string& t) { - std::string len = std::to_string(t.length()); - key->insert(key->end(), len.begin(), len.end()); - key->push_back('"'); + key->Record(static_cast(t.length())); key->insert(key->end(), t.begin(), t.end()); - key->push_back('"'); } template <> void CacheKeySerializer::Serialize(CacheKey* key, const CacheKey& t) { + // For nested cache keys, we do not record the length, and just copy the key so that it + // appears we just flatten the keys into a single key. key->insert(key->end(), t.begin(), t.end()); } diff --git a/src/dawn/native/CacheKey.h b/src/dawn/native/CacheKey.h index cb58711e28..ce21f6d7bd 100644 --- a/src/dawn/native/CacheKey.h +++ b/src/dawn/native/CacheKey.h @@ -15,79 +15,84 @@ #ifndef DAWNNATIVE_CACHE_KEY_H_ #define DAWNNATIVE_CACHE_KEY_H_ +#include #include +#include #include -#include "dawn/common/Compiler.h" +#include "dawn/common/Assert.h" namespace dawn::native { - using CacheKey = std::vector; + // Forward declare CacheKey class because of co-dependency. + class CacheKey; // Overridable serializer struct that should be implemented for cache key serializable // types/classes. template - struct CacheKeySerializer { + class CacheKeySerializer { + public: static void Serialize(CacheKey* key, const T& t); }; - // Specialized overload for integral types. Note that we are currently serializing as a string - // to avoid handling null termiantors. - template - struct CacheKeySerializer>> { - static void Serialize(CacheKey* key, const Integer i) { - std::string str = std::to_string(i); - key->insert(key->end(), str.begin(), str.end()); + class CacheKey : public std::vector { + public: + using std::vector::vector; + + template + CacheKey& Record(const T& t) { + CacheKeySerializer::Serialize(this, t); + return *this; + } + template + CacheKey& Record(const T& t, const Args&... args) { + CacheKeySerializer::Serialize(this, t); + return Record(args...); + } + + // Records iterables by prepending the number of elements. Some common iterables are have a + // CacheKeySerializer implemented to avoid needing to split them out when recording, i.e. + // strings and CacheKeys, but they fundamentally do the same as this function. + template + CacheKey& RecordIterable(const IterableT& iterable) { + // Always record the size of generic iterables as a size_t for now. + Record(static_cast(iterable.size())); + for (auto it = iterable.begin(); it != iterable.end(); ++it) { + Record(*it); + } + return *this; + } + template + CacheKey& RecordIterable(const Ptr* ptr, size_t n) { + Record(n); + for (size_t i = 0; i < n; ++i) { + Record(ptr[i]); + } + return *this; } }; - // Specialized overload for floating point types. Note that we are currently serializing as a - // string to avoid handling null termiantors. - template - struct CacheKeySerializer>> { - static void Serialize(CacheKey* key, const Float f) { - std::string str = std::to_string(f); - key->insert(key->end(), str.begin(), str.end()); + // Specialized overload for fundamental types. + template + class CacheKeySerializer>> { + public: + static void Serialize(CacheKey* key, const T t) { + const char* it = reinterpret_cast(&t); + key->insert(key->end(), it, (it + sizeof(T))); } }; // Specialized overload for string literals. Note we drop the null-terminator. template - struct CacheKeySerializer { + class CacheKeySerializer { + public: static void Serialize(CacheKey* key, const char (&t)[N]) { - std::string len = std::to_string(N - 1); - key->insert(key->end(), len.begin(), len.end()); - key->push_back('"'); - key->insert(key->end(), t, t + N - 1); - key->push_back('"'); + static_assert(N > 0); + key->Record(static_cast(N)); + key->insert(key->end(), t, t + N); } }; - // Helper template function that defers to underlying static functions. - template - void SerializeInto(CacheKey* key, const T& t) { - CacheKeySerializer::Serialize(key, t); - } - - // Given list of arguments of types with a free implementation of SerializeIntoImpl in the - // dawn::native namespace, serializes each argument and appends them to the CacheKey while - // prepending member ids before each argument. - template - CacheKey GetCacheKey(const Ts&... inputs) { - CacheKey key; - key.push_back('{'); - int memberId = 0; - auto Serialize = [&](const auto& input) { - std::string memberIdStr = (memberId == 0 ? "" : ",") + std::to_string(memberId) + ":"; - key.insert(key.end(), memberIdStr.begin(), memberIdStr.end()); - SerializeInto(&key, input); - memberId++; - }; - (Serialize(inputs), ...); - key.push_back('}'); - return key; - } - } // namespace dawn::native #endif // DAWNNATIVE_CACHE_KEY_H_ diff --git a/src/dawn/native/CachedObject.cpp b/src/dawn/native/CachedObject.cpp index 538b7b5ccf..e7e7cd84d5 100644 --- a/src/dawn/native/CachedObject.cpp +++ b/src/dawn/native/CachedObject.cpp @@ -42,26 +42,12 @@ namespace dawn::native { mIsContentHashInitialized = true; } - const std::string& CachedObject::GetCacheKey() const { - ASSERT(mIsCacheKeyBaseInitialized); - return mCacheKeyBase; + const CacheKey& CachedObject::GetCacheKey() const { + return mCacheKey; } - std::string CachedObject::GetCacheKey(DeviceBase* device) const { - ASSERT(mIsCacheKeyBaseInitialized); - // TODO(dawn:549) Prepend/append with device/adapter information. - return mCacheKeyBase; - } - - void CachedObject::SetCacheKey(const std::string& cacheKey) { - ASSERT(!mIsContentHashInitialized); - mCacheKeyBase = cacheKey; - mIsCacheKeyBaseInitialized = true; - } - - std::string CachedObject::ComputeCacheKeyBase() const { - // This implementation should never be called. Only overrides should be called. - UNREACHABLE(); + CacheKey* CachedObject::GetCacheKey() { + return &mCacheKey; } } // namespace dawn::native diff --git a/src/dawn/native/CachedObject.h b/src/dawn/native/CachedObject.h index 3cf79a2190..7d28ae8f12 100644 --- a/src/dawn/native/CachedObject.h +++ b/src/dawn/native/CachedObject.h @@ -15,6 +15,7 @@ #ifndef DAWNNATIVE_CACHED_OBJECT_H_ #define DAWNNATIVE_CACHED_OBJECT_H_ +#include "dawn/native/CacheKey.h" #include "dawn/native/Forward.h" #include @@ -38,13 +39,12 @@ namespace dawn::native { size_t GetContentHash() const; void SetContentHash(size_t contentHash); - // Two versions of GetCacheKey, when passed a device, prepends the stored cache - // key base with device and adapter information. When called without passing a - // device, returns the stored cache key base. This is useful when the instance - // is a member to a parent class. - const std::string& GetCacheKey() const; - std::string GetCacheKey(DeviceBase* device) const; - void SetCacheKey(const std::string& cacheKey); + // Returns the cache key for the object only, i.e. without device/adapter information. + const CacheKey& GetCacheKey() const; + + protected: + // Protected accessor for derived classes to access and modify the key. + CacheKey* GetCacheKey(); private: friend class DeviceBase; @@ -55,13 +55,9 @@ namespace dawn::native { // Called by ObjectContentHasher upon creation to record the object. virtual size_t ComputeContentHash() = 0; - // Not all classes implement cache key computation, so by default we assert. - virtual std::string ComputeCacheKeyBase() const; - size_t mContentHash = 0; bool mIsContentHashInitialized = false; - std::string mCacheKeyBase = ""; - bool mIsCacheKeyBaseInitialized = false; + CacheKey mCacheKey; }; } // namespace dawn::native diff --git a/src/dawn/tests/unittests/native/CacheKeyTests.cpp b/src/dawn/tests/unittests/native/CacheKeyTests.cpp index 3228e1e658..45fd360279 100644 --- a/src/dawn/tests/unittests/native/CacheKeyTests.cpp +++ b/src/dawn/tests/unittests/native/CacheKeyTests.cpp @@ -15,73 +15,168 @@ #include #include +#include +#include #include #include "dawn/native/CacheKey.h" namespace dawn::native { - // Testing classes/structs with serializing implemented for testing. - struct A {}; + // Testing classes with mock serializing implemented for testing. + class A { + public: + MOCK_METHOD(void, SerializeMock, (CacheKey*, const A&), (const)); + }; template <> void CacheKeySerializer::Serialize(CacheKey* key, const A& t) { - std::string str = "structA"; - key->insert(key->end(), str.begin(), str.end()); + t.SerializeMock(key, t); } - class B {}; - template <> - void CacheKeySerializer::Serialize(CacheKey* key, const B& t) { - std::string str = "classB"; - key->insert(key->end(), str.begin(), str.end()); + // Custom printer for CacheKey for clearer debug testing messages. + void PrintTo(const CacheKey& key, std::ostream* stream) { + *stream << std::hex; + for (const int b : key) { + *stream << std::setfill('0') << std::setw(2) << b << " "; + } + *stream << std::dec; } namespace { - // Matcher to compare CacheKey to a string for easier testing. - MATCHER_P(CacheKeyEq, - key, - "cache key " + std::string(negation ? "not" : "") + "equal to " + key) { - return std::string(arg.begin(), arg.end()) == key; + using ::testing::InSequence; + using ::testing::NotNull; + using ::testing::PrintToString; + using ::testing::Ref; + + // Matcher to compare CacheKeys for easier testing. + MATCHER_P(CacheKeyEq, key, PrintToString(key)) { + return memcmp(arg.data(), key.data(), arg.size()) == 0; } - TEST(CacheKeyTest, IntegralTypes) { - EXPECT_THAT(GetCacheKey((int)-1), CacheKeyEq("{0:-1}")); - EXPECT_THAT(GetCacheKey((uint8_t)2), CacheKeyEq("{0:2}")); - EXPECT_THAT(GetCacheKey((uint16_t)4), CacheKeyEq("{0:4}")); - EXPECT_THAT(GetCacheKey((uint32_t)8), CacheKeyEq("{0:8}")); - EXPECT_THAT(GetCacheKey((uint64_t)16), CacheKeyEq("{0:16}")); + TEST(CacheKeyTests, RecordSingleMember) { + CacheKey key; - EXPECT_THAT(GetCacheKey((int)-1, (uint8_t)2, (uint16_t)4, (uint32_t)8, (uint64_t)16), - CacheKeyEq("{0:-1,1:2,2:4,3:8,4:16}")); + A a; + EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1); + EXPECT_THAT(key.Record(a), CacheKeyEq(CacheKey())); } - TEST(CacheKeyTest, FloatingTypes) { - EXPECT_THAT(GetCacheKey((float)0.5), CacheKeyEq("{0:0.500000}")); - EXPECT_THAT(GetCacheKey((double)32.0), CacheKeyEq("{0:32.000000}")); + TEST(CacheKeyTests, RecordManyMembers) { + constexpr size_t kNumMembers = 100; - EXPECT_THAT(GetCacheKey((float)0.5, (double)32.0), - CacheKeyEq("{0:0.500000,1:32.000000}")); + CacheKey key; + for (size_t i = 0; i < kNumMembers; ++i) { + A a; + EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1); + key.Record(a); + } + EXPECT_THAT(key, CacheKeyEq(CacheKey())); } - TEST(CacheKeyTest, Strings) { - std::string str0 = "string0"; - std::string str1 = "string1"; + TEST(CacheKeyTests, RecordIterable) { + constexpr size_t kIterableSize = 100; - EXPECT_THAT(GetCacheKey("string0"), CacheKeyEq(R"({0:7"string0"})")); - EXPECT_THAT(GetCacheKey(str0), CacheKeyEq(R"({0:7"string0"})")); - EXPECT_THAT(GetCacheKey("string0", str1), CacheKeyEq(R"({0:7"string0",1:7"string1"})")); + // Expecting the size of the container. + CacheKey expected; + expected.Record(kIterableSize); + + std::vector iterable(kIterableSize); + { + InSequence seq; + for (const auto& a : iterable) { + EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1); + } + for (const auto& a : iterable) { + EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1); + } + } + + EXPECT_THAT(CacheKey().RecordIterable(iterable), CacheKeyEq(expected)); + EXPECT_THAT(CacheKey().RecordIterable(iterable.data(), kIterableSize), + CacheKeyEq(expected)); } - TEST(CacheKeyTest, NestedCacheKey) { - EXPECT_THAT(GetCacheKey(GetCacheKey((int)-1)), CacheKeyEq("{0:{0:-1}}")); - EXPECT_THAT(GetCacheKey(GetCacheKey("string")), CacheKeyEq(R"({0:{0:6"string"}})")); - EXPECT_THAT(GetCacheKey(GetCacheKey(A{})), CacheKeyEq("{0:{0:structA}}")); - EXPECT_THAT(GetCacheKey(GetCacheKey(B())), CacheKeyEq("{0:{0:classB}}")); + TEST(CacheKeyTests, RecordNested) { + CacheKey expected; + CacheKey actual; + { + // Recording a single member. + A a; + EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1); + actual.Record(CacheKey().Record(a)); + } + { + // Recording multiple members. + constexpr size_t kNumMembers = 2; + CacheKey sub; + for (size_t i = 0; i < kNumMembers; ++i) { + A a; + EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1); + sub.Record(a); + } + actual.Record(sub); + } + { + // Record an iterable. + constexpr size_t kIterableSize = 2; + expected.Record(kIterableSize); + std::vector iterable(kIterableSize); + { + InSequence seq; + for (const auto& a : iterable) { + EXPECT_CALL(a, SerializeMock(NotNull(), Ref(a))).Times(1); + } + } + actual.Record(CacheKey().RecordIterable(iterable)); + } + EXPECT_THAT(actual, CacheKeyEq(expected)); + } - EXPECT_THAT(GetCacheKey(GetCacheKey((int)-1), GetCacheKey("string"), GetCacheKey(A{}), - GetCacheKey(B())), - CacheKeyEq(R"({0:{0:-1},1:{0:6"string"},2:{0:structA},3:{0:classB}})")); + TEST(CacheKeySerializerTests, IntegralTypes) { + // Only testing explicitly sized types for simplicity, and using 0s for larger types to + // avoid dealing with endianess. + EXPECT_THAT(CacheKey().Record('c'), CacheKeyEq(CacheKey({'c'}))); + EXPECT_THAT(CacheKey().Record(uint8_t(255)), CacheKeyEq(CacheKey({255}))); + EXPECT_THAT(CacheKey().Record(uint16_t(0)), CacheKeyEq(CacheKey({0, 0}))); + EXPECT_THAT(CacheKey().Record(uint32_t(0)), CacheKeyEq(CacheKey({0, 0, 0, 0}))); + } + + TEST(CacheKeySerializerTests, FloatingTypes) { + // Using 0s to avoid dealing with implementation specific float details. + EXPECT_THAT(CacheKey().Record(float(0)), CacheKeyEq(CacheKey(sizeof(float), 0))); + EXPECT_THAT(CacheKey().Record(double(0)), CacheKeyEq(CacheKey(sizeof(double), 0))); + } + + TEST(CacheKeySerializerTests, LiteralStrings) { + // Using a std::string here to help with creating the expected result. + std::string str = "string"; + + CacheKey expected; + expected.Record(size_t(7)); + expected.insert(expected.end(), str.begin(), str.end()); + expected.push_back('\0'); + + EXPECT_THAT(CacheKey().Record("string"), CacheKeyEq(expected)); + } + + TEST(CacheKeySerializerTests, StdStrings) { + std::string str = "string"; + + CacheKey expected; + expected.Record((size_t)6); + expected.insert(expected.end(), str.begin(), str.end()); + + EXPECT_THAT(CacheKey().Record(str), CacheKeyEq(expected)); + } + + TEST(CacheKeySerializerTests, CacheKeys) { + CacheKey data = {'d', 'a', 't', 'a'}; + + CacheKey expected; + expected.insert(expected.end(), data.begin(), data.end()); + + EXPECT_THAT(CacheKey().Record(data), CacheKeyEq(expected)); } } // namespace