From d3fa3f0e23cf8da1db77a464260d0fa8dffc3c4e Mon Sep 17 00:00:00 2001 From: Austin Eng Date: Sat, 11 Jun 2022 03:50:33 +0000 Subject: [PATCH] Add CacheRequest utilities and tests This CL adds a DAWN_MAKE_CACHE_REQUEST X macro which helps in building a CacheRequest struct. A CacheRequest struct may be passed to LoadOrRun which will generate a CacheKey from the struct and load a result if there is a cache hit, or it will call the provided cache miss function to compute a value. The request struct helps enforce that precisely the inputs that go into a computation are all also included inside the CacheKey for that computation. Bug: dawn:549 Change-Id: Id85eb95f1b944d5431f142162ffa9a384351be89 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91063 Reviewed-by: Loko Kung Commit-Queue: Austin Eng --- src/dawn/native/BUILD.gn | 3 + src/dawn/native/CMakeLists.txt | 3 + src/dawn/native/CacheKey.h | 34 ++ src/dawn/native/CacheRequest.cpp | 25 ++ src/dawn/native/CacheRequest.h | 186 ++++++++++ src/dawn/native/CacheResult.h | 76 +++++ src/dawn/tests/BUILD.gn | 12 +- src/dawn/tests/DawnNativeTest.cpp | 10 +- src/dawn/tests/DawnNativeTest.h | 3 +- src/dawn/tests/end2end/D3D12CachingTests.cpp | 2 +- .../tests/end2end/PipelineCachingTests.cpp | 2 +- .../platform}/CachingInterfaceMock.cpp | 2 +- .../platform}/CachingInterfaceMock.h | 6 +- .../unittests/native/CacheRequestTests.cpp | 320 ++++++++++++++++++ 14 files changed, 671 insertions(+), 13 deletions(-) create mode 100644 src/dawn/native/CacheRequest.cpp create mode 100644 src/dawn/native/CacheRequest.h create mode 100644 src/dawn/native/CacheResult.h rename src/dawn/tests/{end2end/mocks => mocks/platform}/CachingInterfaceMock.cpp (97%) rename src/dawn/tests/{end2end/mocks => mocks/platform}/CachingInterfaceMock.h (93%) create mode 100644 src/dawn/tests/unittests/native/CacheRequestTests.cpp diff --git a/src/dawn/native/BUILD.gn b/src/dawn/native/BUILD.gn index 4294b26cf1..137ce1b2dd 100644 --- a/src/dawn/native/BUILD.gn +++ b/src/dawn/native/BUILD.gn @@ -200,6 +200,9 @@ source_set("sources") { "Buffer.h", "CacheKey.cpp", "CacheKey.h", + "CacheRequest.cpp", + "CacheRequest.h", + "CacheResult.h", "CachedObject.cpp", "CachedObject.h", "CallbackTaskManager.cpp", diff --git a/src/dawn/native/CMakeLists.txt b/src/dawn/native/CMakeLists.txt index bbb50a04d3..8542b337e9 100644 --- a/src/dawn/native/CMakeLists.txt +++ b/src/dawn/native/CMakeLists.txt @@ -59,6 +59,9 @@ target_sources(dawn_native PRIVATE "CachedObject.h" "CacheKey.cpp" "CacheKey.h" + "CacheRequest.cpp" + "CacheRequest.h" + "CacheResult.h" "CallbackTaskManager.cpp" "CallbackTaskManager.h" "CommandAllocator.cpp" diff --git a/src/dawn/native/CacheKey.h b/src/dawn/native/CacheKey.h index 357ce4b325..c2901db40f 100644 --- a/src/dawn/native/CacheKey.h +++ b/src/dawn/native/CacheKey.h @@ -18,8 +18,10 @@ #include #include #include +#include #include #include +#include #include #include "dawn/common/TypedInteger.h" @@ -48,6 +50,19 @@ class CacheKey : public std::vector { enum class Type { ComputePipeline, RenderPipeline, Shader }; + template + class UnsafeUnkeyedValue { + public: + UnsafeUnkeyedValue() = default; + // NOLINTNEXTLINE(runtime/explicit) allow implicit construction to decrease verbosity + UnsafeUnkeyedValue(T&& value) : mValue(std::forward(value)) {} + + const T& UnsafeGetValue() const { return mValue; } + + private: + T mValue; + }; + template CacheKey& Record(const T& t) { CacheKeySerializer::Serialize(this, t); @@ -89,6 +104,18 @@ class CacheKey : public std::vector { } }; +template +CacheKey::UnsafeUnkeyedValue UnsafeUnkeyedValue(T&& value) { + return CacheKey::UnsafeUnkeyedValue(std::forward(value)); +} + +// Specialized overload for CacheKey::UnsafeIgnoredValue which does nothing. +template +class CacheKeySerializer> { + public: + constexpr static void Serialize(CacheKey* key, const CacheKey::UnsafeUnkeyedValue&) {} +}; + // Specialized overload for fundamental types. template class CacheKeySerializer>> { @@ -197,6 +224,13 @@ class CacheKeySerializer> static void Serialize(CacheKey* key, const T& t) { key->Record(t.GetCacheKey()); } }; +// Specialized overload for std::vector. +template +class CacheKeySerializer> { + public: + static void Serialize(CacheKey* key, const std::vector& t) { key->RecordIterable(t); } +}; + } // namespace dawn::native #endif // SRC_DAWN_NATIVE_CACHEKEY_H_ diff --git a/src/dawn/native/CacheRequest.cpp b/src/dawn/native/CacheRequest.cpp new file mode 100644 index 0000000000..2b35b1224d --- /dev/null +++ b/src/dawn/native/CacheRequest.cpp @@ -0,0 +1,25 @@ +// Copyright 2022 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dawn/native/CacheRequest.h" + +#include "dawn/common/Log.h" + +namespace dawn::native::detail { + +void LogCacheHitError(std::unique_ptr error) { + dawn::ErrorLog() << error->GetFormattedMessage(); +} + +} // namespace dawn::native::detail diff --git a/src/dawn/native/CacheRequest.h b/src/dawn/native/CacheRequest.h new file mode 100644 index 0000000000..0fecb0da57 --- /dev/null +++ b/src/dawn/native/CacheRequest.h @@ -0,0 +1,186 @@ +// Copyright 2022 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_DAWN_NATIVE_CACHEREQUEST_H_ +#define SRC_DAWN_NATIVE_CACHEREQUEST_H_ + +#include +#include + +#include "dawn/common/Assert.h" +#include "dawn/common/Compiler.h" +#include "dawn/native/Blob.h" +#include "dawn/native/BlobCache.h" +#include "dawn/native/CacheKey.h" +#include "dawn/native/CacheResult.h" +#include "dawn/native/Device.h" +#include "dawn/native/Error.h" + +namespace dawn::native { + +namespace detail { + +template +struct UnwrapResultOrError { + using type = T; +}; + +template +struct UnwrapResultOrError> { + using type = T; +}; + +template +struct IsResultOrError { + static constexpr bool value = false; +}; + +template +struct IsResultOrError> { + static constexpr bool value = true; +}; + +void LogCacheHitError(std::unique_ptr error); + +} // namespace detail + +// Implementation of a CacheRequest which provides a LoadOrRun friend function which can be found +// via argument-dependent lookup. So, it doesn't need to be called with a fully qualified function +// name. +// +// Example usage: +// Request r = { ... }; +// ResultOrError> cacheResult = +// LoadOrRun(device, std::move(r), +// [](Blob blob) -> T { /* handle cache hit */ }, +// [](Request r) -> ResultOrError { /* handle cache miss */ } +// ); +// Or with free functions: +/// T OnCacheHit(Blob blob) { ... } +// ResultOrError OnCacheMiss(Request r) { ... } +// // ... +// Request r = { ... }; +// auto result = LoadOrRun(device, std::move(r), OnCacheHit, OnCacheMiss); +// +// LoadOrRun generates a CacheKey from the request and loads from the device's BlobCache. On cache +// hit, calls CacheHitFn and returns a CacheResult. On cache miss or if CacheHitFn returned an +// Error, calls CacheMissFn -> ResultOrError with the request data and returns a +// ResultOrError>. CacheHitFn must return the same unwrapped type as CacheMissFn. +// i.e. it doesn't need to be wrapped in ResultOrError. +// +// CacheMissFn may not have any additional data bound to it. It may not be a lambda or std::function +// which captures additional information, so it can only operate on the request data. This is +// enforced with a compile-time static_assert, and ensures that the result created from the +// computation is exactly the data included in the CacheKey. +template +class CacheRequestImpl { + public: + CacheRequestImpl() = default; + + // Require CacheRequests to be move-only to avoid unnecessary copies. + CacheRequestImpl(CacheRequestImpl&&) = default; + CacheRequestImpl& operator=(CacheRequestImpl&&) = default; + CacheRequestImpl(const CacheRequestImpl&) = delete; + CacheRequestImpl& operator=(const CacheRequestImpl&) = delete; + + template + friend auto LoadOrRun(DeviceBase* device, + Request&& r, + CacheHitFn cacheHitFn, + CacheMissFn cacheMissFn) { + // Get return types and check that CacheMissReturnType can be cast to a raw function + // pointer. This means it's not a std::function or lambda that captures additional data. + using CacheHitReturnType = decltype(cacheHitFn(std::declval())); + using CacheMissReturnType = decltype(cacheMissFn(std::declval())); + static_assert( + std::is_convertible_v, + "CacheMissFn function signature does not match, or it is not a free function."); + + static_assert(detail::IsResultOrError::value, + "CacheMissFn should return a ResultOrError."); + using UnwrappedReturnType = typename detail::UnwrapResultOrError::type; + + static_assert(std::is_same_v::type, + UnwrappedReturnType>, + "If CacheMissFn returns T, CacheHitFn must return T or ResultOrError."); + + using CacheResultType = CacheResult; + using ReturnType = ResultOrError; + + CacheKey key = r.CreateCacheKey(device); + BlobCache* cache = device->GetBlobCache(); + Blob blob; + if (cache != nullptr) { + blob = cache->Load(key); + } + + if (!blob.Empty()) { + // Cache hit. Handle the cached blob. + auto result = cacheHitFn(std::move(blob)); + + if constexpr (!detail::IsResultOrError::value) { + // If the result type is not a ResultOrError, return it. + return ReturnType(CacheResultType::CacheHit(std::move(key), std::move(result))); + } else { + // Otherwise, if the value is a success, also return it. + if (DAWN_LIKELY(result.IsSuccess())) { + return ReturnType( + CacheResultType::CacheHit(std::move(key), result.AcquireSuccess())); + } + // On error, continue to the cache miss path and log the error. + detail::LogCacheHitError(result.AcquireError()); + } + } + // Cache miss, or the CacheHitFn failed. + auto result = cacheMissFn(std::move(r)); + if (DAWN_LIKELY(result.IsSuccess())) { + return ReturnType(CacheResultType::CacheMiss(std::move(key), result.AcquireSuccess())); + } + return ReturnType(result.AcquireError()); + } +}; + +// Helper for X macro to declare a struct member. +#define DAWN_INTERNAL_CACHE_REQUEST_DECL_STRUCT_MEMBER(type, name) type name{}; + +// Helper for X macro for recording cache request fields into a CacheKey. +#define DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY(type, name) key.Record(name); + +// Helper X macro to define a CacheRequest struct. +// Example usage: +// #define REQUEST_MEMBERS(X) \ +// X(int, a) \ +// X(float, b) \ +// X(Foo, foo) \ +// X(Bar, bar) +// DAWN_MAKE_CACHE_REQUEST(MyCacheRequest, REQUEST_MEMBERS) +// #undef REQUEST_MEMBERS +#define DAWN_MAKE_CACHE_REQUEST(Request, MEMBERS) \ + class Request : public CacheRequestImpl { \ + public: \ + Request() = default; \ + MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_DECL_STRUCT_MEMBER) \ + \ + /* Create a CacheKey from the request type and all members */ \ + CacheKey CreateCacheKey(const DeviceBase* device) const { \ + CacheKey key = device->GetCacheKey(); \ + key.Record(#Request); \ + MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY) \ + return key; \ + } \ + }; + +} // namespace dawn::native + +#endif // SRC_DAWN_NATIVE_CACHEREQUEST_H_ diff --git a/src/dawn/native/CacheResult.h b/src/dawn/native/CacheResult.h new file mode 100644 index 0000000000..a2750fecba --- /dev/null +++ b/src/dawn/native/CacheResult.h @@ -0,0 +1,76 @@ +// Copyright 2022 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_DAWN_NATIVE_CACHERESULT_H_ +#define SRC_DAWN_NATIVE_CACHERESULT_H_ + +#include +#include + +#include "dawn/common/Assert.h" + +namespace dawn::native { + +template +class CacheResult { + public: + static CacheResult CacheHit(CacheKey key, T value) { + return CacheResult(std::move(key), std::move(value), true); + } + + static CacheResult CacheMiss(CacheKey key, T value) { + return CacheResult(std::move(key), std::move(value), false); + } + + CacheResult() : mKey(), mValue(), mIsCached(false), mIsValid(false) {} + + bool IsCached() const { + ASSERT(mIsValid); + return mIsCached; + } + const CacheKey& GetCacheKey() { + ASSERT(mIsValid); + return mKey; + } + + // Note: Getting mValue is always const, since mutating it would invalidate consistency with + // mKey. + const T* operator->() const { + ASSERT(mIsValid); + return &mValue; + } + const T& operator*() const { + ASSERT(mIsValid); + return mValue; + } + + T Acquire() { + ASSERT(mIsValid); + mIsValid = false; + return std::move(mValue); + } + + private: + CacheResult(CacheKey key, T value, bool isCached) + : mKey(std::move(key)), mValue(std::move(value)), mIsCached(isCached), mIsValid(true) {} + + CacheKey mKey; + T mValue; + bool mIsCached; + bool mIsValid; +}; + +} // namespace dawn::native + +#endif // SRC_DAWN_NATIVE_CACHERESULT_H_ diff --git a/src/dawn/tests/BUILD.gn b/src/dawn/tests/BUILD.gn index 69c41a2e1a..57bb666be1 100644 --- a/src/dawn/tests/BUILD.gn +++ b/src/dawn/tests/BUILD.gn @@ -192,6 +192,7 @@ dawn_test("dawn_unittests") { ":gmock_and_gtest", ":mock_webgpu_gen", ":native_mocks_sources", + ":platform_mocks_sources", "${dawn_root}/src/dawn:cpp", "${dawn_root}/src/dawn:proc", "${dawn_root}/src/dawn/common", @@ -254,6 +255,7 @@ dawn_test("dawn_unittests") { "unittests/VersionTests.cpp", "unittests/native/BlobTests.cpp", "unittests/native/CacheKeyTests.cpp", + "unittests/native/CacheRequestTests.cpp", "unittests/native/CommandBufferEncodingTests.cpp", "unittests/native/CreatePipelineAsyncTaskTests.cpp", "unittests/native/DestroyObjectTests.cpp", @@ -380,9 +382,9 @@ source_set("test_infra_sources") { # Dawn end2end tests targets ############################################################################### -# Source code for mocks used for end2end testing are separated from the rest of +# Source code for mocks used for platform testing are separated from the rest of # sources so that they aren't included in non-test builds. -source_set("end2end_mocks_sources") { +source_set("platform_mocks_sources") { configs += [ "${dawn_root}/src/dawn/native:internal" ] testonly = true @@ -392,8 +394,8 @@ source_set("end2end_mocks_sources") { ] sources = [ - "end2end/mocks/CachingInterfaceMock.cpp", - "end2end/mocks/CachingInterfaceMock.h", + "mocks/platform/CachingInterfaceMock.cpp", + "mocks/platform/CachingInterfaceMock.h", ] } @@ -401,7 +403,7 @@ source_set("end2end_tests_sources") { testonly = true deps = [ - ":end2end_mocks_sources", + ":platform_mocks_sources", ":test_infra_sources", "${dawn_root}/src/dawn:cpp", "${dawn_root}/src/dawn:proc", diff --git a/src/dawn/tests/DawnNativeTest.cpp b/src/dawn/tests/DawnNativeTest.cpp index 163413dc1e..fbf8030ab3 100644 --- a/src/dawn/tests/DawnNativeTest.cpp +++ b/src/dawn/tests/DawnNativeTest.cpp @@ -20,6 +20,9 @@ #include "dawn/common/Assert.h" #include "dawn/dawn_proc.h" #include "dawn/native/ErrorData.h" +#include "dawn/native/Instance.h" +#include "dawn/native/dawn_platform.h" +#include "dawn/platform/DawnPlatform.h" namespace dawn::native { @@ -43,6 +46,9 @@ DawnNativeTest::~DawnNativeTest() { void DawnNativeTest::SetUp() { instance = std::make_unique(); + platform = CreateTestPlatform(); + dawn::native::FromAPI(instance->Get())->SetPlatformForTesting(platform.get()); + instance->DiscoverDefaultAdapters(); std::vector adapters = instance->GetAdapters(); @@ -66,7 +72,9 @@ void DawnNativeTest::SetUp() { device.SetUncapturedErrorCallback(DawnNativeTest::OnDeviceError, nullptr); } -void DawnNativeTest::TearDown() {} +std::unique_ptr DawnNativeTest::CreateTestPlatform() { + return nullptr; +} WGPUDevice DawnNativeTest::CreateTestDevice() { // Disabled disallowing unsafe APIs so we can test them. diff --git a/src/dawn/tests/DawnNativeTest.h b/src/dawn/tests/DawnNativeTest.h index e92bf67f4a..dd3532f723 100644 --- a/src/dawn/tests/DawnNativeTest.h +++ b/src/dawn/tests/DawnNativeTest.h @@ -38,12 +38,13 @@ class DawnNativeTest : public ::testing::Test { ~DawnNativeTest() override; void SetUp() override; - void TearDown() override; + virtual std::unique_ptr CreateTestPlatform(); virtual WGPUDevice CreateTestDevice(); protected: std::unique_ptr instance; + std::unique_ptr platform; dawn::native::Adapter adapter; wgpu::Device device; diff --git a/src/dawn/tests/end2end/D3D12CachingTests.cpp b/src/dawn/tests/end2end/D3D12CachingTests.cpp index 9cd404226e..0d3bcdf024 100644 --- a/src/dawn/tests/end2end/D3D12CachingTests.cpp +++ b/src/dawn/tests/end2end/D3D12CachingTests.cpp @@ -17,7 +17,7 @@ #include #include "dawn/tests/DawnTest.h" -#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h" +#include "dawn/tests/mocks/platform/CachingInterfaceMock.h" #include "dawn/utils/ComboRenderPipelineDescriptor.h" #include "dawn/utils/WGPUHelpers.h" diff --git a/src/dawn/tests/end2end/PipelineCachingTests.cpp b/src/dawn/tests/end2end/PipelineCachingTests.cpp index 23ee708002..8318c48776 100644 --- a/src/dawn/tests/end2end/PipelineCachingTests.cpp +++ b/src/dawn/tests/end2end/PipelineCachingTests.cpp @@ -16,7 +16,7 @@ #include #include "dawn/tests/DawnTest.h" -#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h" +#include "dawn/tests/mocks/platform/CachingInterfaceMock.h" #include "dawn/utils/ComboRenderPipelineDescriptor.h" #include "dawn/utils/WGPUHelpers.h" diff --git a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp b/src/dawn/tests/mocks/platform/CachingInterfaceMock.cpp similarity index 97% rename from src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp rename to src/dawn/tests/mocks/platform/CachingInterfaceMock.cpp index 8507e9cffb..a52d4c2cec 100644 --- a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp +++ b/src/dawn/tests/mocks/platform/CachingInterfaceMock.cpp @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h" +#include "dawn/tests/mocks/platform/CachingInterfaceMock.h" using ::testing::Invoke; diff --git a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.h b/src/dawn/tests/mocks/platform/CachingInterfaceMock.h similarity index 93% rename from src/dawn/tests/end2end/mocks/CachingInterfaceMock.h rename to src/dawn/tests/mocks/platform/CachingInterfaceMock.h index cc61d80fc2..0e9e6aff0e 100644 --- a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.h +++ b/src/dawn/tests/mocks/platform/CachingInterfaceMock.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_ -#define SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_ +#ifndef SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_ +#define SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_ #include #include @@ -70,4 +70,4 @@ class DawnCachingMockPlatform : public dawn::platform::Platform { dawn::platform::CachingInterface* mCachingInterface = nullptr; }; -#endif // SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_ +#endif // SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_ diff --git a/src/dawn/tests/unittests/native/CacheRequestTests.cpp b/src/dawn/tests/unittests/native/CacheRequestTests.cpp new file mode 100644 index 0000000000..995de7f520 --- /dev/null +++ b/src/dawn/tests/unittests/native/CacheRequestTests.cpp @@ -0,0 +1,320 @@ +// Copyright 2022 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "dawn/native/Blob.h" +#include "dawn/native/CacheRequest.h" +#include "dawn/tests/DawnNativeTest.h" +#include "dawn/tests/mocks/platform/CachingInterfaceMock.h" + +namespace dawn::native { + +namespace { + +using ::testing::_; +using ::testing::ByMove; +using ::testing::Invoke; +using ::testing::MockFunction; +using ::testing::Return; +using ::testing::StrictMock; +using ::testing::WithArg; + +class CacheRequestTests : public DawnNativeTest { + protected: + std::unique_ptr CreateTestPlatform() override { + return std::make_unique(&mMockCache); + } + + WGPUDevice CreateTestDevice() override { + wgpu::DeviceDescriptor deviceDescriptor = {}; + wgpu::DawnTogglesDeviceDescriptor togglesDesc = {}; + deviceDescriptor.nextInChain = &togglesDesc; + + const char* toggle = "enable_blob_cache"; + togglesDesc.forceEnabledToggles = &toggle; + togglesDesc.forceEnabledTogglesCount = 1; + + return adapter.CreateDevice(&deviceDescriptor); + } + + DeviceBase* GetDevice() { return dawn::native::FromAPI(device.Get()); } + + StrictMock mMockCache; +}; + +struct Foo { + int value; +}; + +#define REQUEST_MEMBERS(X) \ + X(int, a) \ + X(float, b) \ + X(std::vector, c) \ + X(CacheKey::UnsafeUnkeyedValue, d) \ + X(CacheKey::UnsafeUnkeyedValue, e) + +DAWN_MAKE_CACHE_REQUEST(CacheRequestForTesting, REQUEST_MEMBERS) + +#undef REQUEST_MEMBERS + +// static_assert the expected types for various return types from the cache hit handler and cache +// miss handler. +TEST_F(CacheRequestTests, CacheResultTypes) { + EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillRepeatedly(Return(0)); + + // (int, ResultOrError), should be ResultOrError>. + auto v1 = LoadOrRun( + GetDevice(), CacheRequestForTesting{}, [](Blob) -> int { return 0; }, + [](CacheRequestForTesting) -> ResultOrError { return 1; }); + v1.AcquireSuccess(); + static_assert(std::is_same_v>, decltype(v1)>); + + // (ResultOrError, ResultOrError), should be ResultOrError>. + auto v2 = LoadOrRun( + GetDevice(), CacheRequestForTesting{}, [](Blob) -> ResultOrError { return 0.0; }, + [](CacheRequestForTesting) -> ResultOrError { return 1.0; }); + v2.AcquireSuccess(); + static_assert(std::is_same_v>, decltype(v2)>); +} + +// Test that using a CacheRequest builds a key from the device key, the request type enum, and all +// of the request members. +TEST_F(CacheRequestTests, MakesCacheKey) { + // Make a request. + CacheRequestForTesting req; + req.a = 1; + req.b = 0.2; + req.c = {3, 4, 5}; + + // Make the expected key. + CacheKey expectedKey; + expectedKey.Record(GetDevice()->GetCacheKey(), "CacheRequestForTesting", req.a, req.b, req.c); + + // Expect a call to LoadData with the expected key. + EXPECT_CALL(mMockCache, LoadData(_, expectedKey.size(), nullptr, 0)) + .WillOnce(WithArg<0>(Invoke([&](const void* actualKeyData) { + EXPECT_EQ(memcmp(actualKeyData, expectedKey.data(), expectedKey.size()), 0); + return 0; + }))); + + // Load the request. + auto result = LoadOrRun( + GetDevice(), std::move(req), [](Blob) -> int { return 0; }, + [](CacheRequestForTesting) -> ResultOrError { return 0; }) + .AcquireSuccess(); + + // The created cache key should be saved on the result. + EXPECT_EQ(result.GetCacheKey().size(), expectedKey.size()); + EXPECT_EQ(memcmp(result.GetCacheKey().data(), expectedKey.data(), expectedKey.size()), 0); +} + +// Test that members that are wrapped in UnsafeUnkeyedValue do not impact the key. +TEST_F(CacheRequestTests, CacheKeyIgnoresUnsafeIgnoredValue) { + // Make two requests with different UnsafeUnkeyedValues (UnsafeUnkeyed is declared on the struct + // definition). + int v1, v2; + CacheRequestForTesting req1; + req1.d = &v1; + req1.e = Foo{42}; + + CacheRequestForTesting req2; + req2.d = &v2; + req2.e = Foo{24}; + + EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0)).WillOnce(Return(0)); + + static StrictMock> cacheMissFn; + + // Load the first request, and check that the unsafe unkeyed values were passed though + EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([&](CacheRequestForTesting req) { + EXPECT_EQ(req.d.UnsafeGetValue(), &v1); + EXPECT_FLOAT_EQ(req.e.UnsafeGetValue().value, 42); + return 0; + }))); + auto r1 = LoadOrRun( + GetDevice(), std::move(req1), [](Blob) { return 0; }, + [](CacheRequestForTesting req) -> ResultOrError { + return cacheMissFn.Call(std::move(req)); + }) + .AcquireSuccess(); + + // Load the second request, and check that the unsafe unkeyed values were passed though + EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([&](CacheRequestForTesting req) { + EXPECT_EQ(req.d.UnsafeGetValue(), &v2); + EXPECT_FLOAT_EQ(req.e.UnsafeGetValue().value, 24); + return 0; + }))); + auto r2 = LoadOrRun( + GetDevice(), std::move(req2), [](Blob) { return 0; }, + [](CacheRequestForTesting req) -> ResultOrError { + return cacheMissFn.Call(std::move(req)); + }) + .AcquireSuccess(); + + // Expect their keys to be the same. + EXPECT_EQ(r1.GetCacheKey().size(), r2.GetCacheKey().size()); + EXPECT_EQ(memcmp(r1.GetCacheKey().data(), r2.GetCacheKey().data(), r1.GetCacheKey().size()), 0); +} + +// Test the expected code path when there is a cache miss. +TEST_F(CacheRequestTests, CacheMiss) { + // Make a request. + CacheRequestForTesting req; + req.a = 1; + req.b = 0.2; + req.c = {3, 4, 5}; + + unsigned int* cPtr = req.c.data(); + + static StrictMock> cacheHitFn; + static StrictMock> cacheMissFn; + + // Mock a cache miss. + EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0)); + + // Expect the cache miss, and return some value. + int rv = 42; + EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](CacheRequestForTesting req) { + // Expect the request contents to be the same. The data pointer for |c| is also the same + // since it was moved. + EXPECT_EQ(req.a, 1); + EXPECT_FLOAT_EQ(req.b, 0.2); + EXPECT_EQ(req.c.data(), cPtr); + return rv; + }))); + + // Load the request. + auto result = LoadOrRun( + GetDevice(), std::move(req), + [](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); }, + [](CacheRequestForTesting req) -> ResultOrError { + return cacheMissFn.Call(std::move(req)); + }) + .AcquireSuccess(); + + // Expect the result to store the value. + EXPECT_EQ(*result, rv); + EXPECT_FALSE(result.IsCached()); +} + +// Test the expected code path when there is a cache hit. +TEST_F(CacheRequestTests, CacheHit) { + // Make a request. + CacheRequestForTesting req; + req.a = 1; + req.b = 0.2; + req.c = {3, 4, 5}; + + static StrictMock> cacheHitFn; + static StrictMock> cacheMissFn; + + static constexpr char kCachedData[] = "hello world!"; + + // Mock a cache hit, and load the cached data. + EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(sizeof(kCachedData))); + EXPECT_CALL(mMockCache, LoadData(_, _, _, sizeof(kCachedData))) + .WillOnce(WithArg<2>(Invoke([](void* dataOut) { + memcpy(dataOut, kCachedData, sizeof(kCachedData)); + return sizeof(kCachedData); + }))); + + // Expect the cache hit, and return some value. + int rv = 1337; + EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](Blob blob) { + // Expect the cached blob contents to match the cached data. + EXPECT_EQ(blob.Size(), sizeof(kCachedData)); + EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0); + + return rv; + }))); + + // Load the request. + auto result = LoadOrRun( + GetDevice(), std::move(req), + [](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); }, + [](CacheRequestForTesting req) -> ResultOrError { + return cacheMissFn.Call(std::move(req)); + }) + .AcquireSuccess(); + + // Expect the result to store the value. + EXPECT_EQ(*result, rv); + EXPECT_TRUE(result.IsCached()); +} + +// Test the expected code path when there is a cache hit but the handler errors. +TEST_F(CacheRequestTests, CacheHitError) { + // Make a request. + CacheRequestForTesting req; + req.a = 1; + req.b = 0.2; + req.c = {3, 4, 5}; + + unsigned int* cPtr = req.c.data(); + + static StrictMock(Blob)>> cacheHitFn; + static StrictMock> cacheMissFn; + + static constexpr char kCachedData[] = "hello world!"; + + // Mock a cache hit, and load the cached data. + EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(sizeof(kCachedData))); + EXPECT_CALL(mMockCache, LoadData(_, _, _, sizeof(kCachedData))) + .WillOnce(WithArg<2>(Invoke([](void* dataOut) { + memcpy(dataOut, kCachedData, sizeof(kCachedData)); + return sizeof(kCachedData); + }))); + + // Expect the cache hit. + EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](Blob blob) { + // Expect the cached blob contents to match the cached data. + EXPECT_EQ(blob.Size(), sizeof(kCachedData)); + EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0); + + // Return an error. + return DAWN_VALIDATION_ERROR("fake test error"); + }))); + + // Expect the cache miss handler since the cache hit errored. + int rv = 79; + EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](CacheRequestForTesting req) { + // Expect the request contents to be the same. The data pointer for |c| is also the same + // since it was moved. + EXPECT_EQ(req.a, 1); + EXPECT_FLOAT_EQ(req.b, 0.2); + EXPECT_EQ(req.c.data(), cPtr); + return rv; + }))); + + // Load the request. + auto result = + LoadOrRun( + GetDevice(), std::move(req), + [](Blob blob) -> ResultOrError { return cacheHitFn.Call(std::move(blob)); }, + [](CacheRequestForTesting req) -> ResultOrError { + return cacheMissFn.Call(std::move(req)); + }) + .AcquireSuccess(); + + // Expect the result to store the value. + EXPECT_EQ(*result, rv); + EXPECT_FALSE(result.IsCached()); +} + +} // namespace + +} // namespace dawn::native