Add CacheRequest utilities and tests

This CL adds a DAWN_MAKE_CACHE_REQUEST X macro
which helps in building a CacheRequest struct.

A CacheRequest struct may be passed to LoadOrRun
which will generate a CacheKey from the struct and
load a result if there is a cache hit, or it will
call the provided cache miss function to compute a value.

The request struct helps enforce that precisely the
inputs that go into a computation are all also included
inside the CacheKey for that computation.

Bug: dawn:549
Change-Id: Id85eb95f1b944d5431f142162ffa9a384351be89
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91063
Reviewed-by: Loko Kung <lokokung@google.com>
Commit-Queue: Austin Eng <enga@chromium.org>
This commit is contained in:
Austin Eng 2022-06-11 03:50:33 +00:00 committed by Dawn LUCI CQ
parent 76d0454c66
commit d3fa3f0e23
14 changed files with 671 additions and 13 deletions

View File

@ -200,6 +200,9 @@ source_set("sources") {
"Buffer.h",
"CacheKey.cpp",
"CacheKey.h",
"CacheRequest.cpp",
"CacheRequest.h",
"CacheResult.h",
"CachedObject.cpp",
"CachedObject.h",
"CallbackTaskManager.cpp",

View File

@ -59,6 +59,9 @@ target_sources(dawn_native PRIVATE
"CachedObject.h"
"CacheKey.cpp"
"CacheKey.h"
"CacheRequest.cpp"
"CacheRequest.h"
"CacheResult.h"
"CallbackTaskManager.cpp"
"CallbackTaskManager.h"
"CommandAllocator.cpp"

View File

@ -18,8 +18,10 @@
#include <bitset>
#include <iostream>
#include <limits>
#include <memory>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
#include "dawn/common/TypedInteger.h"
@ -48,6 +50,19 @@ class CacheKey : public std::vector<uint8_t> {
enum class Type { ComputePipeline, RenderPipeline, Shader };
template <typename T>
class UnsafeUnkeyedValue {
public:
UnsafeUnkeyedValue() = default;
// NOLINTNEXTLINE(runtime/explicit) allow implicit construction to decrease verbosity
UnsafeUnkeyedValue(T&& value) : mValue(std::forward<T>(value)) {}
const T& UnsafeGetValue() const { return mValue; }
private:
T mValue;
};
template <typename T>
CacheKey& Record(const T& t) {
CacheKeySerializer<T>::Serialize(this, t);
@ -89,6 +104,18 @@ class CacheKey : public std::vector<uint8_t> {
}
};
template <typename T>
CacheKey::UnsafeUnkeyedValue<T> UnsafeUnkeyedValue(T&& value) {
return CacheKey::UnsafeUnkeyedValue<T>(std::forward<T>(value));
}
// Specialized overload for CacheKey::UnsafeIgnoredValue which does nothing.
template <typename T>
class CacheKeySerializer<CacheKey::UnsafeUnkeyedValue<T>> {
public:
constexpr static void Serialize(CacheKey* key, const CacheKey::UnsafeUnkeyedValue<T>&) {}
};
// Specialized overload for fundamental types.
template <typename T>
class CacheKeySerializer<T, std::enable_if_t<std::is_fundamental_v<T>>> {
@ -197,6 +224,13 @@ class CacheKeySerializer<T, std::enable_if_t<std::is_base_of_v<CachedObject, T>>
static void Serialize(CacheKey* key, const T& t) { key->Record(t.GetCacheKey()); }
};
// Specialized overload for std::vector.
template <typename T>
class CacheKeySerializer<std::vector<T>> {
public:
static void Serialize(CacheKey* key, const std::vector<T>& t) { key->RecordIterable(t); }
};
} // namespace dawn::native
#endif // SRC_DAWN_NATIVE_CACHEKEY_H_

View File

@ -0,0 +1,25 @@
// Copyright 2022 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dawn/native/CacheRequest.h"
#include "dawn/common/Log.h"
namespace dawn::native::detail {
void LogCacheHitError(std::unique_ptr<ErrorData> error) {
dawn::ErrorLog() << error->GetFormattedMessage();
}
} // namespace dawn::native::detail

View File

@ -0,0 +1,186 @@
// Copyright 2022 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_DAWN_NATIVE_CACHEREQUEST_H_
#define SRC_DAWN_NATIVE_CACHEREQUEST_H_
#include <memory>
#include <utility>
#include "dawn/common/Assert.h"
#include "dawn/common/Compiler.h"
#include "dawn/native/Blob.h"
#include "dawn/native/BlobCache.h"
#include "dawn/native/CacheKey.h"
#include "dawn/native/CacheResult.h"
#include "dawn/native/Device.h"
#include "dawn/native/Error.h"
namespace dawn::native {
namespace detail {
template <typename T>
struct UnwrapResultOrError {
using type = T;
};
template <typename T>
struct UnwrapResultOrError<ResultOrError<T>> {
using type = T;
};
template <typename T>
struct IsResultOrError {
static constexpr bool value = false;
};
template <typename T>
struct IsResultOrError<ResultOrError<T>> {
static constexpr bool value = true;
};
void LogCacheHitError(std::unique_ptr<ErrorData> error);
} // namespace detail
// Implementation of a CacheRequest which provides a LoadOrRun friend function which can be found
// via argument-dependent lookup. So, it doesn't need to be called with a fully qualified function
// name.
//
// Example usage:
// Request r = { ... };
// ResultOrError<CacheResult<T>> cacheResult =
// LoadOrRun(device, std::move(r),
// [](Blob blob) -> T { /* handle cache hit */ },
// [](Request r) -> ResultOrError<T> { /* handle cache miss */ }
// );
// Or with free functions:
/// T OnCacheHit(Blob blob) { ... }
// ResultOrError<T> OnCacheMiss(Request r) { ... }
// // ...
// Request r = { ... };
// auto result = LoadOrRun(device, std::move(r), OnCacheHit, OnCacheMiss);
//
// LoadOrRun generates a CacheKey from the request and loads from the device's BlobCache. On cache
// hit, calls CacheHitFn and returns a CacheResult<T>. On cache miss or if CacheHitFn returned an
// Error, calls CacheMissFn -> ResultOrError<T> with the request data and returns a
// ResultOrError<CacheResult<T>>. CacheHitFn must return the same unwrapped type as CacheMissFn.
// i.e. it doesn't need to be wrapped in ResultOrError.
//
// CacheMissFn may not have any additional data bound to it. It may not be a lambda or std::function
// which captures additional information, so it can only operate on the request data. This is
// enforced with a compile-time static_assert, and ensures that the result created from the
// computation is exactly the data included in the CacheKey.
template <typename Request>
class CacheRequestImpl {
public:
CacheRequestImpl() = default;
// Require CacheRequests to be move-only to avoid unnecessary copies.
CacheRequestImpl(CacheRequestImpl&&) = default;
CacheRequestImpl& operator=(CacheRequestImpl&&) = default;
CacheRequestImpl(const CacheRequestImpl&) = delete;
CacheRequestImpl& operator=(const CacheRequestImpl&) = delete;
template <typename CacheHitFn, typename CacheMissFn>
friend auto LoadOrRun(DeviceBase* device,
Request&& r,
CacheHitFn cacheHitFn,
CacheMissFn cacheMissFn) {
// Get return types and check that CacheMissReturnType can be cast to a raw function
// pointer. This means it's not a std::function or lambda that captures additional data.
using CacheHitReturnType = decltype(cacheHitFn(std::declval<Blob>()));
using CacheMissReturnType = decltype(cacheMissFn(std::declval<Request>()));
static_assert(
std::is_convertible_v<CacheMissFn, CacheMissReturnType (*)(Request)>,
"CacheMissFn function signature does not match, or it is not a free function.");
static_assert(detail::IsResultOrError<CacheMissReturnType>::value,
"CacheMissFn should return a ResultOrError.");
using UnwrappedReturnType = typename detail::UnwrapResultOrError<CacheMissReturnType>::type;
static_assert(std::is_same_v<typename detail::UnwrapResultOrError<CacheHitReturnType>::type,
UnwrappedReturnType>,
"If CacheMissFn returns T, CacheHitFn must return T or ResultOrError<T>.");
using CacheResultType = CacheResult<UnwrappedReturnType>;
using ReturnType = ResultOrError<CacheResultType>;
CacheKey key = r.CreateCacheKey(device);
BlobCache* cache = device->GetBlobCache();
Blob blob;
if (cache != nullptr) {
blob = cache->Load(key);
}
if (!blob.Empty()) {
// Cache hit. Handle the cached blob.
auto result = cacheHitFn(std::move(blob));
if constexpr (!detail::IsResultOrError<CacheHitReturnType>::value) {
// If the result type is not a ResultOrError, return it.
return ReturnType(CacheResultType::CacheHit(std::move(key), std::move(result)));
} else {
// Otherwise, if the value is a success, also return it.
if (DAWN_LIKELY(result.IsSuccess())) {
return ReturnType(
CacheResultType::CacheHit(std::move(key), result.AcquireSuccess()));
}
// On error, continue to the cache miss path and log the error.
detail::LogCacheHitError(result.AcquireError());
}
}
// Cache miss, or the CacheHitFn failed.
auto result = cacheMissFn(std::move(r));
if (DAWN_LIKELY(result.IsSuccess())) {
return ReturnType(CacheResultType::CacheMiss(std::move(key), result.AcquireSuccess()));
}
return ReturnType(result.AcquireError());
}
};
// Helper for X macro to declare a struct member.
#define DAWN_INTERNAL_CACHE_REQUEST_DECL_STRUCT_MEMBER(type, name) type name{};
// Helper for X macro for recording cache request fields into a CacheKey.
#define DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY(type, name) key.Record(name);
// Helper X macro to define a CacheRequest struct.
// Example usage:
// #define REQUEST_MEMBERS(X) \
// X(int, a) \
// X(float, b) \
// X(Foo, foo) \
// X(Bar, bar)
// DAWN_MAKE_CACHE_REQUEST(MyCacheRequest, REQUEST_MEMBERS)
// #undef REQUEST_MEMBERS
#define DAWN_MAKE_CACHE_REQUEST(Request, MEMBERS) \
class Request : public CacheRequestImpl<Request> { \
public: \
Request() = default; \
MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_DECL_STRUCT_MEMBER) \
\
/* Create a CacheKey from the request type and all members */ \
CacheKey CreateCacheKey(const DeviceBase* device) const { \
CacheKey key = device->GetCacheKey(); \
key.Record(#Request); \
MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY) \
return key; \
} \
};
} // namespace dawn::native
#endif // SRC_DAWN_NATIVE_CACHEREQUEST_H_

View File

@ -0,0 +1,76 @@
// Copyright 2022 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_DAWN_NATIVE_CACHERESULT_H_
#define SRC_DAWN_NATIVE_CACHERESULT_H_
#include <memory>
#include <utility>
#include "dawn/common/Assert.h"
namespace dawn::native {
template <typename T>
class CacheResult {
public:
static CacheResult CacheHit(CacheKey key, T value) {
return CacheResult(std::move(key), std::move(value), true);
}
static CacheResult CacheMiss(CacheKey key, T value) {
return CacheResult(std::move(key), std::move(value), false);
}
CacheResult() : mKey(), mValue(), mIsCached(false), mIsValid(false) {}
bool IsCached() const {
ASSERT(mIsValid);
return mIsCached;
}
const CacheKey& GetCacheKey() {
ASSERT(mIsValid);
return mKey;
}
// Note: Getting mValue is always const, since mutating it would invalidate consistency with
// mKey.
const T* operator->() const {
ASSERT(mIsValid);
return &mValue;
}
const T& operator*() const {
ASSERT(mIsValid);
return mValue;
}
T Acquire() {
ASSERT(mIsValid);
mIsValid = false;
return std::move(mValue);
}
private:
CacheResult(CacheKey key, T value, bool isCached)
: mKey(std::move(key)), mValue(std::move(value)), mIsCached(isCached), mIsValid(true) {}
CacheKey mKey;
T mValue;
bool mIsCached;
bool mIsValid;
};
} // namespace dawn::native
#endif // SRC_DAWN_NATIVE_CACHERESULT_H_

View File

@ -192,6 +192,7 @@ dawn_test("dawn_unittests") {
":gmock_and_gtest",
":mock_webgpu_gen",
":native_mocks_sources",
":platform_mocks_sources",
"${dawn_root}/src/dawn:cpp",
"${dawn_root}/src/dawn:proc",
"${dawn_root}/src/dawn/common",
@ -254,6 +255,7 @@ dawn_test("dawn_unittests") {
"unittests/VersionTests.cpp",
"unittests/native/BlobTests.cpp",
"unittests/native/CacheKeyTests.cpp",
"unittests/native/CacheRequestTests.cpp",
"unittests/native/CommandBufferEncodingTests.cpp",
"unittests/native/CreatePipelineAsyncTaskTests.cpp",
"unittests/native/DestroyObjectTests.cpp",
@ -380,9 +382,9 @@ source_set("test_infra_sources") {
# Dawn end2end tests targets
###############################################################################
# Source code for mocks used for end2end testing are separated from the rest of
# Source code for mocks used for platform testing are separated from the rest of
# sources so that they aren't included in non-test builds.
source_set("end2end_mocks_sources") {
source_set("platform_mocks_sources") {
configs += [ "${dawn_root}/src/dawn/native:internal" ]
testonly = true
@ -392,8 +394,8 @@ source_set("end2end_mocks_sources") {
]
sources = [
"end2end/mocks/CachingInterfaceMock.cpp",
"end2end/mocks/CachingInterfaceMock.h",
"mocks/platform/CachingInterfaceMock.cpp",
"mocks/platform/CachingInterfaceMock.h",
]
}
@ -401,7 +403,7 @@ source_set("end2end_tests_sources") {
testonly = true
deps = [
":end2end_mocks_sources",
":platform_mocks_sources",
":test_infra_sources",
"${dawn_root}/src/dawn:cpp",
"${dawn_root}/src/dawn:proc",

View File

@ -20,6 +20,9 @@
#include "dawn/common/Assert.h"
#include "dawn/dawn_proc.h"
#include "dawn/native/ErrorData.h"
#include "dawn/native/Instance.h"
#include "dawn/native/dawn_platform.h"
#include "dawn/platform/DawnPlatform.h"
namespace dawn::native {
@ -43,6 +46,9 @@ DawnNativeTest::~DawnNativeTest() {
void DawnNativeTest::SetUp() {
instance = std::make_unique<dawn::native::Instance>();
platform = CreateTestPlatform();
dawn::native::FromAPI(instance->Get())->SetPlatformForTesting(platform.get());
instance->DiscoverDefaultAdapters();
std::vector<dawn::native::Adapter> adapters = instance->GetAdapters();
@ -66,7 +72,9 @@ void DawnNativeTest::SetUp() {
device.SetUncapturedErrorCallback(DawnNativeTest::OnDeviceError, nullptr);
}
void DawnNativeTest::TearDown() {}
std::unique_ptr<dawn::platform::Platform> DawnNativeTest::CreateTestPlatform() {
return nullptr;
}
WGPUDevice DawnNativeTest::CreateTestDevice() {
// Disabled disallowing unsafe APIs so we can test them.

View File

@ -38,12 +38,13 @@ class DawnNativeTest : public ::testing::Test {
~DawnNativeTest() override;
void SetUp() override;
void TearDown() override;
virtual std::unique_ptr<dawn::platform::Platform> CreateTestPlatform();
virtual WGPUDevice CreateTestDevice();
protected:
std::unique_ptr<dawn::native::Instance> instance;
std::unique_ptr<dawn::platform::Platform> platform;
dawn::native::Adapter adapter;
wgpu::Device device;

View File

@ -17,7 +17,7 @@
#include <utility>
#include "dawn/tests/DawnTest.h"
#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h"
#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
#include "dawn/utils/ComboRenderPipelineDescriptor.h"
#include "dawn/utils/WGPUHelpers.h"

View File

@ -16,7 +16,7 @@
#include <string_view>
#include "dawn/tests/DawnTest.h"
#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h"
#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
#include "dawn/utils/ComboRenderPipelineDescriptor.h"
#include "dawn/utils/WGPUHelpers.h"

View File

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h"
#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
using ::testing::Invoke;

View File

@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_
#define SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_
#ifndef SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_
#define SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_
#include <dawn/platform/DawnPlatform.h>
#include <gmock/gmock.h>
@ -70,4 +70,4 @@ class DawnCachingMockPlatform : public dawn::platform::Platform {
dawn::platform::CachingInterface* mCachingInterface = nullptr;
};
#endif // SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_
#endif // SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_

View File

@ -0,0 +1,320 @@
// Copyright 2022 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <utility>
#include <vector>
#include "dawn/native/Blob.h"
#include "dawn/native/CacheRequest.h"
#include "dawn/tests/DawnNativeTest.h"
#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
namespace dawn::native {
namespace {
using ::testing::_;
using ::testing::ByMove;
using ::testing::Invoke;
using ::testing::MockFunction;
using ::testing::Return;
using ::testing::StrictMock;
using ::testing::WithArg;
class CacheRequestTests : public DawnNativeTest {
protected:
std::unique_ptr<dawn::platform::Platform> CreateTestPlatform() override {
return std::make_unique<DawnCachingMockPlatform>(&mMockCache);
}
WGPUDevice CreateTestDevice() override {
wgpu::DeviceDescriptor deviceDescriptor = {};
wgpu::DawnTogglesDeviceDescriptor togglesDesc = {};
deviceDescriptor.nextInChain = &togglesDesc;
const char* toggle = "enable_blob_cache";
togglesDesc.forceEnabledToggles = &toggle;
togglesDesc.forceEnabledTogglesCount = 1;
return adapter.CreateDevice(&deviceDescriptor);
}
DeviceBase* GetDevice() { return dawn::native::FromAPI(device.Get()); }
StrictMock<CachingInterfaceMock> mMockCache;
};
struct Foo {
int value;
};
#define REQUEST_MEMBERS(X) \
X(int, a) \
X(float, b) \
X(std::vector<unsigned int>, c) \
X(CacheKey::UnsafeUnkeyedValue<int*>, d) \
X(CacheKey::UnsafeUnkeyedValue<Foo>, e)
DAWN_MAKE_CACHE_REQUEST(CacheRequestForTesting, REQUEST_MEMBERS)
#undef REQUEST_MEMBERS
// static_assert the expected types for various return types from the cache hit handler and cache
// miss handler.
TEST_F(CacheRequestTests, CacheResultTypes) {
EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillRepeatedly(Return(0));
// (int, ResultOrError<int>), should be ResultOrError<CacheResult<int>>.
auto v1 = LoadOrRun(
GetDevice(), CacheRequestForTesting{}, [](Blob) -> int { return 0; },
[](CacheRequestForTesting) -> ResultOrError<int> { return 1; });
v1.AcquireSuccess();
static_assert(std::is_same_v<ResultOrError<CacheResult<int>>, decltype(v1)>);
// (ResultOrError<float>, ResultOrError<float>), should be ResultOrError<CacheResult<float>>.
auto v2 = LoadOrRun(
GetDevice(), CacheRequestForTesting{}, [](Blob) -> ResultOrError<float> { return 0.0; },
[](CacheRequestForTesting) -> ResultOrError<float> { return 1.0; });
v2.AcquireSuccess();
static_assert(std::is_same_v<ResultOrError<CacheResult<float>>, decltype(v2)>);
}
// Test that using a CacheRequest builds a key from the device key, the request type enum, and all
// of the request members.
TEST_F(CacheRequestTests, MakesCacheKey) {
// Make a request.
CacheRequestForTesting req;
req.a = 1;
req.b = 0.2;
req.c = {3, 4, 5};
// Make the expected key.
CacheKey expectedKey;
expectedKey.Record(GetDevice()->GetCacheKey(), "CacheRequestForTesting", req.a, req.b, req.c);
// Expect a call to LoadData with the expected key.
EXPECT_CALL(mMockCache, LoadData(_, expectedKey.size(), nullptr, 0))
.WillOnce(WithArg<0>(Invoke([&](const void* actualKeyData) {
EXPECT_EQ(memcmp(actualKeyData, expectedKey.data(), expectedKey.size()), 0);
return 0;
})));
// Load the request.
auto result = LoadOrRun(
GetDevice(), std::move(req), [](Blob) -> int { return 0; },
[](CacheRequestForTesting) -> ResultOrError<int> { return 0; })
.AcquireSuccess();
// The created cache key should be saved on the result.
EXPECT_EQ(result.GetCacheKey().size(), expectedKey.size());
EXPECT_EQ(memcmp(result.GetCacheKey().data(), expectedKey.data(), expectedKey.size()), 0);
}
// Test that members that are wrapped in UnsafeUnkeyedValue do not impact the key.
TEST_F(CacheRequestTests, CacheKeyIgnoresUnsafeIgnoredValue) {
// Make two requests with different UnsafeUnkeyedValues (UnsafeUnkeyed is declared on the struct
// definition).
int v1, v2;
CacheRequestForTesting req1;
req1.d = &v1;
req1.e = Foo{42};
CacheRequestForTesting req2;
req2.d = &v2;
req2.e = Foo{24};
EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0)).WillOnce(Return(0));
static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
// Load the first request, and check that the unsafe unkeyed values were passed though
EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([&](CacheRequestForTesting req) {
EXPECT_EQ(req.d.UnsafeGetValue(), &v1);
EXPECT_FLOAT_EQ(req.e.UnsafeGetValue().value, 42);
return 0;
})));
auto r1 = LoadOrRun(
GetDevice(), std::move(req1), [](Blob) { return 0; },
[](CacheRequestForTesting req) -> ResultOrError<int> {
return cacheMissFn.Call(std::move(req));
})
.AcquireSuccess();
// Load the second request, and check that the unsafe unkeyed values were passed though
EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([&](CacheRequestForTesting req) {
EXPECT_EQ(req.d.UnsafeGetValue(), &v2);
EXPECT_FLOAT_EQ(req.e.UnsafeGetValue().value, 24);
return 0;
})));
auto r2 = LoadOrRun(
GetDevice(), std::move(req2), [](Blob) { return 0; },
[](CacheRequestForTesting req) -> ResultOrError<int> {
return cacheMissFn.Call(std::move(req));
})
.AcquireSuccess();
// Expect their keys to be the same.
EXPECT_EQ(r1.GetCacheKey().size(), r2.GetCacheKey().size());
EXPECT_EQ(memcmp(r1.GetCacheKey().data(), r2.GetCacheKey().data(), r1.GetCacheKey().size()), 0);
}
// Test the expected code path when there is a cache miss.
TEST_F(CacheRequestTests, CacheMiss) {
// Make a request.
CacheRequestForTesting req;
req.a = 1;
req.b = 0.2;
req.c = {3, 4, 5};
unsigned int* cPtr = req.c.data();
static StrictMock<MockFunction<int(Blob)>> cacheHitFn;
static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
// Mock a cache miss.
EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0));
// Expect the cache miss, and return some value.
int rv = 42;
EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](CacheRequestForTesting req) {
// Expect the request contents to be the same. The data pointer for |c| is also the same
// since it was moved.
EXPECT_EQ(req.a, 1);
EXPECT_FLOAT_EQ(req.b, 0.2);
EXPECT_EQ(req.c.data(), cPtr);
return rv;
})));
// Load the request.
auto result = LoadOrRun(
GetDevice(), std::move(req),
[](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); },
[](CacheRequestForTesting req) -> ResultOrError<int> {
return cacheMissFn.Call(std::move(req));
})
.AcquireSuccess();
// Expect the result to store the value.
EXPECT_EQ(*result, rv);
EXPECT_FALSE(result.IsCached());
}
// Test the expected code path when there is a cache hit.
TEST_F(CacheRequestTests, CacheHit) {
// Make a request.
CacheRequestForTesting req;
req.a = 1;
req.b = 0.2;
req.c = {3, 4, 5};
static StrictMock<MockFunction<int(Blob)>> cacheHitFn;
static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
static constexpr char kCachedData[] = "hello world!";
// Mock a cache hit, and load the cached data.
EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(sizeof(kCachedData)));
EXPECT_CALL(mMockCache, LoadData(_, _, _, sizeof(kCachedData)))
.WillOnce(WithArg<2>(Invoke([](void* dataOut) {
memcpy(dataOut, kCachedData, sizeof(kCachedData));
return sizeof(kCachedData);
})));
// Expect the cache hit, and return some value.
int rv = 1337;
EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](Blob blob) {
// Expect the cached blob contents to match the cached data.
EXPECT_EQ(blob.Size(), sizeof(kCachedData));
EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0);
return rv;
})));
// Load the request.
auto result = LoadOrRun(
GetDevice(), std::move(req),
[](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); },
[](CacheRequestForTesting req) -> ResultOrError<int> {
return cacheMissFn.Call(std::move(req));
})
.AcquireSuccess();
// Expect the result to store the value.
EXPECT_EQ(*result, rv);
EXPECT_TRUE(result.IsCached());
}
// Test the expected code path when there is a cache hit but the handler errors.
TEST_F(CacheRequestTests, CacheHitError) {
// Make a request.
CacheRequestForTesting req;
req.a = 1;
req.b = 0.2;
req.c = {3, 4, 5};
unsigned int* cPtr = req.c.data();
static StrictMock<MockFunction<ResultOrError<int>(Blob)>> cacheHitFn;
static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
static constexpr char kCachedData[] = "hello world!";
// Mock a cache hit, and load the cached data.
EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(sizeof(kCachedData)));
EXPECT_CALL(mMockCache, LoadData(_, _, _, sizeof(kCachedData)))
.WillOnce(WithArg<2>(Invoke([](void* dataOut) {
memcpy(dataOut, kCachedData, sizeof(kCachedData));
return sizeof(kCachedData);
})));
// Expect the cache hit.
EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](Blob blob) {
// Expect the cached blob contents to match the cached data.
EXPECT_EQ(blob.Size(), sizeof(kCachedData));
EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0);
// Return an error.
return DAWN_VALIDATION_ERROR("fake test error");
})));
// Expect the cache miss handler since the cache hit errored.
int rv = 79;
EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](CacheRequestForTesting req) {
// Expect the request contents to be the same. The data pointer for |c| is also the same
// since it was moved.
EXPECT_EQ(req.a, 1);
EXPECT_FLOAT_EQ(req.b, 0.2);
EXPECT_EQ(req.c.data(), cPtr);
return rv;
})));
// Load the request.
auto result =
LoadOrRun(
GetDevice(), std::move(req),
[](Blob blob) -> ResultOrError<int> { return cacheHitFn.Call(std::move(blob)); },
[](CacheRequestForTesting req) -> ResultOrError<int> {
return cacheMissFn.Call(std::move(req));
})
.AcquireSuccess();
// Expect the result to store the value.
EXPECT_EQ(*result, rv);
EXPECT_FALSE(result.IsCached());
}
} // namespace
} // namespace dawn::native