Add CacheRequest utilities and tests
This CL adds a DAWN_MAKE_CACHE_REQUEST X macro which helps in building a CacheRequest struct. A CacheRequest struct may be passed to LoadOrRun which will generate a CacheKey from the struct and load a result if there is a cache hit, or it will call the provided cache miss function to compute a value. The request struct helps enforce that precisely the inputs that go into a computation are all also included inside the CacheKey for that computation. Bug: dawn:549 Change-Id: Id85eb95f1b944d5431f142162ffa9a384351be89 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/91063 Reviewed-by: Loko Kung <lokokung@google.com> Commit-Queue: Austin Eng <enga@chromium.org>
This commit is contained in:
parent
76d0454c66
commit
d3fa3f0e23
|
@ -200,6 +200,9 @@ source_set("sources") {
|
|||
"Buffer.h",
|
||||
"CacheKey.cpp",
|
||||
"CacheKey.h",
|
||||
"CacheRequest.cpp",
|
||||
"CacheRequest.h",
|
||||
"CacheResult.h",
|
||||
"CachedObject.cpp",
|
||||
"CachedObject.h",
|
||||
"CallbackTaskManager.cpp",
|
||||
|
|
|
@ -59,6 +59,9 @@ target_sources(dawn_native PRIVATE
|
|||
"CachedObject.h"
|
||||
"CacheKey.cpp"
|
||||
"CacheKey.h"
|
||||
"CacheRequest.cpp"
|
||||
"CacheRequest.h"
|
||||
"CacheResult.h"
|
||||
"CallbackTaskManager.cpp"
|
||||
"CallbackTaskManager.h"
|
||||
"CommandAllocator.cpp"
|
||||
|
|
|
@ -18,8 +18,10 @@
|
|||
#include <bitset>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "dawn/common/TypedInteger.h"
|
||||
|
@ -48,6 +50,19 @@ class CacheKey : public std::vector<uint8_t> {
|
|||
|
||||
enum class Type { ComputePipeline, RenderPipeline, Shader };
|
||||
|
||||
template <typename T>
|
||||
class UnsafeUnkeyedValue {
|
||||
public:
|
||||
UnsafeUnkeyedValue() = default;
|
||||
// NOLINTNEXTLINE(runtime/explicit) allow implicit construction to decrease verbosity
|
||||
UnsafeUnkeyedValue(T&& value) : mValue(std::forward<T>(value)) {}
|
||||
|
||||
const T& UnsafeGetValue() const { return mValue; }
|
||||
|
||||
private:
|
||||
T mValue;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CacheKey& Record(const T& t) {
|
||||
CacheKeySerializer<T>::Serialize(this, t);
|
||||
|
@ -89,6 +104,18 @@ class CacheKey : public std::vector<uint8_t> {
|
|||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CacheKey::UnsafeUnkeyedValue<T> UnsafeUnkeyedValue(T&& value) {
|
||||
return CacheKey::UnsafeUnkeyedValue<T>(std::forward<T>(value));
|
||||
}
|
||||
|
||||
// Specialized overload for CacheKey::UnsafeIgnoredValue which does nothing.
|
||||
template <typename T>
|
||||
class CacheKeySerializer<CacheKey::UnsafeUnkeyedValue<T>> {
|
||||
public:
|
||||
constexpr static void Serialize(CacheKey* key, const CacheKey::UnsafeUnkeyedValue<T>&) {}
|
||||
};
|
||||
|
||||
// Specialized overload for fundamental types.
|
||||
template <typename T>
|
||||
class CacheKeySerializer<T, std::enable_if_t<std::is_fundamental_v<T>>> {
|
||||
|
@ -197,6 +224,13 @@ class CacheKeySerializer<T, std::enable_if_t<std::is_base_of_v<CachedObject, T>>
|
|||
static void Serialize(CacheKey* key, const T& t) { key->Record(t.GetCacheKey()); }
|
||||
};
|
||||
|
||||
// Specialized overload for std::vector.
|
||||
template <typename T>
|
||||
class CacheKeySerializer<std::vector<T>> {
|
||||
public:
|
||||
static void Serialize(CacheKey* key, const std::vector<T>& t) { key->RecordIterable(t); }
|
||||
};
|
||||
|
||||
} // namespace dawn::native
|
||||
|
||||
#endif // SRC_DAWN_NATIVE_CACHEKEY_H_
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
// Copyright 2022 The Dawn Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "dawn/native/CacheRequest.h"
|
||||
|
||||
#include "dawn/common/Log.h"
|
||||
|
||||
namespace dawn::native::detail {
|
||||
|
||||
void LogCacheHitError(std::unique_ptr<ErrorData> error) {
|
||||
dawn::ErrorLog() << error->GetFormattedMessage();
|
||||
}
|
||||
|
||||
} // namespace dawn::native::detail
|
|
@ -0,0 +1,186 @@
|
|||
// Copyright 2022 The Dawn Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_DAWN_NATIVE_CACHEREQUEST_H_
|
||||
#define SRC_DAWN_NATIVE_CACHEREQUEST_H_
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "dawn/common/Assert.h"
|
||||
#include "dawn/common/Compiler.h"
|
||||
#include "dawn/native/Blob.h"
|
||||
#include "dawn/native/BlobCache.h"
|
||||
#include "dawn/native/CacheKey.h"
|
||||
#include "dawn/native/CacheResult.h"
|
||||
#include "dawn/native/Device.h"
|
||||
#include "dawn/native/Error.h"
|
||||
|
||||
namespace dawn::native {
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename T>
|
||||
struct UnwrapResultOrError {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct UnwrapResultOrError<ResultOrError<T>> {
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct IsResultOrError {
|
||||
static constexpr bool value = false;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct IsResultOrError<ResultOrError<T>> {
|
||||
static constexpr bool value = true;
|
||||
};
|
||||
|
||||
void LogCacheHitError(std::unique_ptr<ErrorData> error);
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Implementation of a CacheRequest which provides a LoadOrRun friend function which can be found
|
||||
// via argument-dependent lookup. So, it doesn't need to be called with a fully qualified function
|
||||
// name.
|
||||
//
|
||||
// Example usage:
|
||||
// Request r = { ... };
|
||||
// ResultOrError<CacheResult<T>> cacheResult =
|
||||
// LoadOrRun(device, std::move(r),
|
||||
// [](Blob blob) -> T { /* handle cache hit */ },
|
||||
// [](Request r) -> ResultOrError<T> { /* handle cache miss */ }
|
||||
// );
|
||||
// Or with free functions:
|
||||
/// T OnCacheHit(Blob blob) { ... }
|
||||
// ResultOrError<T> OnCacheMiss(Request r) { ... }
|
||||
// // ...
|
||||
// Request r = { ... };
|
||||
// auto result = LoadOrRun(device, std::move(r), OnCacheHit, OnCacheMiss);
|
||||
//
|
||||
// LoadOrRun generates a CacheKey from the request and loads from the device's BlobCache. On cache
|
||||
// hit, calls CacheHitFn and returns a CacheResult<T>. On cache miss or if CacheHitFn returned an
|
||||
// Error, calls CacheMissFn -> ResultOrError<T> with the request data and returns a
|
||||
// ResultOrError<CacheResult<T>>. CacheHitFn must return the same unwrapped type as CacheMissFn.
|
||||
// i.e. it doesn't need to be wrapped in ResultOrError.
|
||||
//
|
||||
// CacheMissFn may not have any additional data bound to it. It may not be a lambda or std::function
|
||||
// which captures additional information, so it can only operate on the request data. This is
|
||||
// enforced with a compile-time static_assert, and ensures that the result created from the
|
||||
// computation is exactly the data included in the CacheKey.
|
||||
template <typename Request>
|
||||
class CacheRequestImpl {
|
||||
public:
|
||||
CacheRequestImpl() = default;
|
||||
|
||||
// Require CacheRequests to be move-only to avoid unnecessary copies.
|
||||
CacheRequestImpl(CacheRequestImpl&&) = default;
|
||||
CacheRequestImpl& operator=(CacheRequestImpl&&) = default;
|
||||
CacheRequestImpl(const CacheRequestImpl&) = delete;
|
||||
CacheRequestImpl& operator=(const CacheRequestImpl&) = delete;
|
||||
|
||||
template <typename CacheHitFn, typename CacheMissFn>
|
||||
friend auto LoadOrRun(DeviceBase* device,
|
||||
Request&& r,
|
||||
CacheHitFn cacheHitFn,
|
||||
CacheMissFn cacheMissFn) {
|
||||
// Get return types and check that CacheMissReturnType can be cast to a raw function
|
||||
// pointer. This means it's not a std::function or lambda that captures additional data.
|
||||
using CacheHitReturnType = decltype(cacheHitFn(std::declval<Blob>()));
|
||||
using CacheMissReturnType = decltype(cacheMissFn(std::declval<Request>()));
|
||||
static_assert(
|
||||
std::is_convertible_v<CacheMissFn, CacheMissReturnType (*)(Request)>,
|
||||
"CacheMissFn function signature does not match, or it is not a free function.");
|
||||
|
||||
static_assert(detail::IsResultOrError<CacheMissReturnType>::value,
|
||||
"CacheMissFn should return a ResultOrError.");
|
||||
using UnwrappedReturnType = typename detail::UnwrapResultOrError<CacheMissReturnType>::type;
|
||||
|
||||
static_assert(std::is_same_v<typename detail::UnwrapResultOrError<CacheHitReturnType>::type,
|
||||
UnwrappedReturnType>,
|
||||
"If CacheMissFn returns T, CacheHitFn must return T or ResultOrError<T>.");
|
||||
|
||||
using CacheResultType = CacheResult<UnwrappedReturnType>;
|
||||
using ReturnType = ResultOrError<CacheResultType>;
|
||||
|
||||
CacheKey key = r.CreateCacheKey(device);
|
||||
BlobCache* cache = device->GetBlobCache();
|
||||
Blob blob;
|
||||
if (cache != nullptr) {
|
||||
blob = cache->Load(key);
|
||||
}
|
||||
|
||||
if (!blob.Empty()) {
|
||||
// Cache hit. Handle the cached blob.
|
||||
auto result = cacheHitFn(std::move(blob));
|
||||
|
||||
if constexpr (!detail::IsResultOrError<CacheHitReturnType>::value) {
|
||||
// If the result type is not a ResultOrError, return it.
|
||||
return ReturnType(CacheResultType::CacheHit(std::move(key), std::move(result)));
|
||||
} else {
|
||||
// Otherwise, if the value is a success, also return it.
|
||||
if (DAWN_LIKELY(result.IsSuccess())) {
|
||||
return ReturnType(
|
||||
CacheResultType::CacheHit(std::move(key), result.AcquireSuccess()));
|
||||
}
|
||||
// On error, continue to the cache miss path and log the error.
|
||||
detail::LogCacheHitError(result.AcquireError());
|
||||
}
|
||||
}
|
||||
// Cache miss, or the CacheHitFn failed.
|
||||
auto result = cacheMissFn(std::move(r));
|
||||
if (DAWN_LIKELY(result.IsSuccess())) {
|
||||
return ReturnType(CacheResultType::CacheMiss(std::move(key), result.AcquireSuccess()));
|
||||
}
|
||||
return ReturnType(result.AcquireError());
|
||||
}
|
||||
};
|
||||
|
||||
// Helper for X macro to declare a struct member.
|
||||
#define DAWN_INTERNAL_CACHE_REQUEST_DECL_STRUCT_MEMBER(type, name) type name{};
|
||||
|
||||
// Helper for X macro for recording cache request fields into a CacheKey.
|
||||
#define DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY(type, name) key.Record(name);
|
||||
|
||||
// Helper X macro to define a CacheRequest struct.
|
||||
// Example usage:
|
||||
// #define REQUEST_MEMBERS(X) \
|
||||
// X(int, a) \
|
||||
// X(float, b) \
|
||||
// X(Foo, foo) \
|
||||
// X(Bar, bar)
|
||||
// DAWN_MAKE_CACHE_REQUEST(MyCacheRequest, REQUEST_MEMBERS)
|
||||
// #undef REQUEST_MEMBERS
|
||||
#define DAWN_MAKE_CACHE_REQUEST(Request, MEMBERS) \
|
||||
class Request : public CacheRequestImpl<Request> { \
|
||||
public: \
|
||||
Request() = default; \
|
||||
MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_DECL_STRUCT_MEMBER) \
|
||||
\
|
||||
/* Create a CacheKey from the request type and all members */ \
|
||||
CacheKey CreateCacheKey(const DeviceBase* device) const { \
|
||||
CacheKey key = device->GetCacheKey(); \
|
||||
key.Record(#Request); \
|
||||
MEMBERS(DAWN_INTERNAL_CACHE_REQUEST_RECORD_KEY) \
|
||||
return key; \
|
||||
} \
|
||||
};
|
||||
|
||||
} // namespace dawn::native
|
||||
|
||||
#endif // SRC_DAWN_NATIVE_CACHEREQUEST_H_
|
|
@ -0,0 +1,76 @@
|
|||
// Copyright 2022 The Dawn Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_DAWN_NATIVE_CACHERESULT_H_
|
||||
#define SRC_DAWN_NATIVE_CACHERESULT_H_
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
||||
#include "dawn/common/Assert.h"
|
||||
|
||||
namespace dawn::native {
|
||||
|
||||
template <typename T>
|
||||
class CacheResult {
|
||||
public:
|
||||
static CacheResult CacheHit(CacheKey key, T value) {
|
||||
return CacheResult(std::move(key), std::move(value), true);
|
||||
}
|
||||
|
||||
static CacheResult CacheMiss(CacheKey key, T value) {
|
||||
return CacheResult(std::move(key), std::move(value), false);
|
||||
}
|
||||
|
||||
CacheResult() : mKey(), mValue(), mIsCached(false), mIsValid(false) {}
|
||||
|
||||
bool IsCached() const {
|
||||
ASSERT(mIsValid);
|
||||
return mIsCached;
|
||||
}
|
||||
const CacheKey& GetCacheKey() {
|
||||
ASSERT(mIsValid);
|
||||
return mKey;
|
||||
}
|
||||
|
||||
// Note: Getting mValue is always const, since mutating it would invalidate consistency with
|
||||
// mKey.
|
||||
const T* operator->() const {
|
||||
ASSERT(mIsValid);
|
||||
return &mValue;
|
||||
}
|
||||
const T& operator*() const {
|
||||
ASSERT(mIsValid);
|
||||
return mValue;
|
||||
}
|
||||
|
||||
T Acquire() {
|
||||
ASSERT(mIsValid);
|
||||
mIsValid = false;
|
||||
return std::move(mValue);
|
||||
}
|
||||
|
||||
private:
|
||||
CacheResult(CacheKey key, T value, bool isCached)
|
||||
: mKey(std::move(key)), mValue(std::move(value)), mIsCached(isCached), mIsValid(true) {}
|
||||
|
||||
CacheKey mKey;
|
||||
T mValue;
|
||||
bool mIsCached;
|
||||
bool mIsValid;
|
||||
};
|
||||
|
||||
} // namespace dawn::native
|
||||
|
||||
#endif // SRC_DAWN_NATIVE_CACHERESULT_H_
|
|
@ -192,6 +192,7 @@ dawn_test("dawn_unittests") {
|
|||
":gmock_and_gtest",
|
||||
":mock_webgpu_gen",
|
||||
":native_mocks_sources",
|
||||
":platform_mocks_sources",
|
||||
"${dawn_root}/src/dawn:cpp",
|
||||
"${dawn_root}/src/dawn:proc",
|
||||
"${dawn_root}/src/dawn/common",
|
||||
|
@ -254,6 +255,7 @@ dawn_test("dawn_unittests") {
|
|||
"unittests/VersionTests.cpp",
|
||||
"unittests/native/BlobTests.cpp",
|
||||
"unittests/native/CacheKeyTests.cpp",
|
||||
"unittests/native/CacheRequestTests.cpp",
|
||||
"unittests/native/CommandBufferEncodingTests.cpp",
|
||||
"unittests/native/CreatePipelineAsyncTaskTests.cpp",
|
||||
"unittests/native/DestroyObjectTests.cpp",
|
||||
|
@ -380,9 +382,9 @@ source_set("test_infra_sources") {
|
|||
# Dawn end2end tests targets
|
||||
###############################################################################
|
||||
|
||||
# Source code for mocks used for end2end testing are separated from the rest of
|
||||
# Source code for mocks used for platform testing are separated from the rest of
|
||||
# sources so that they aren't included in non-test builds.
|
||||
source_set("end2end_mocks_sources") {
|
||||
source_set("platform_mocks_sources") {
|
||||
configs += [ "${dawn_root}/src/dawn/native:internal" ]
|
||||
testonly = true
|
||||
|
||||
|
@ -392,8 +394,8 @@ source_set("end2end_mocks_sources") {
|
|||
]
|
||||
|
||||
sources = [
|
||||
"end2end/mocks/CachingInterfaceMock.cpp",
|
||||
"end2end/mocks/CachingInterfaceMock.h",
|
||||
"mocks/platform/CachingInterfaceMock.cpp",
|
||||
"mocks/platform/CachingInterfaceMock.h",
|
||||
]
|
||||
}
|
||||
|
||||
|
@ -401,7 +403,7 @@ source_set("end2end_tests_sources") {
|
|||
testonly = true
|
||||
|
||||
deps = [
|
||||
":end2end_mocks_sources",
|
||||
":platform_mocks_sources",
|
||||
":test_infra_sources",
|
||||
"${dawn_root}/src/dawn:cpp",
|
||||
"${dawn_root}/src/dawn:proc",
|
||||
|
|
|
@ -20,6 +20,9 @@
|
|||
#include "dawn/common/Assert.h"
|
||||
#include "dawn/dawn_proc.h"
|
||||
#include "dawn/native/ErrorData.h"
|
||||
#include "dawn/native/Instance.h"
|
||||
#include "dawn/native/dawn_platform.h"
|
||||
#include "dawn/platform/DawnPlatform.h"
|
||||
|
||||
namespace dawn::native {
|
||||
|
||||
|
@ -43,6 +46,9 @@ DawnNativeTest::~DawnNativeTest() {
|
|||
|
||||
void DawnNativeTest::SetUp() {
|
||||
instance = std::make_unique<dawn::native::Instance>();
|
||||
platform = CreateTestPlatform();
|
||||
dawn::native::FromAPI(instance->Get())->SetPlatformForTesting(platform.get());
|
||||
|
||||
instance->DiscoverDefaultAdapters();
|
||||
|
||||
std::vector<dawn::native::Adapter> adapters = instance->GetAdapters();
|
||||
|
@ -66,7 +72,9 @@ void DawnNativeTest::SetUp() {
|
|||
device.SetUncapturedErrorCallback(DawnNativeTest::OnDeviceError, nullptr);
|
||||
}
|
||||
|
||||
void DawnNativeTest::TearDown() {}
|
||||
std::unique_ptr<dawn::platform::Platform> DawnNativeTest::CreateTestPlatform() {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
WGPUDevice DawnNativeTest::CreateTestDevice() {
|
||||
// Disabled disallowing unsafe APIs so we can test them.
|
||||
|
|
|
@ -38,12 +38,13 @@ class DawnNativeTest : public ::testing::Test {
|
|||
~DawnNativeTest() override;
|
||||
|
||||
void SetUp() override;
|
||||
void TearDown() override;
|
||||
|
||||
virtual std::unique_ptr<dawn::platform::Platform> CreateTestPlatform();
|
||||
virtual WGPUDevice CreateTestDevice();
|
||||
|
||||
protected:
|
||||
std::unique_ptr<dawn::native::Instance> instance;
|
||||
std::unique_ptr<dawn::platform::Platform> platform;
|
||||
dawn::native::Adapter adapter;
|
||||
wgpu::Device device;
|
||||
|
||||
|
|
|
@ -17,7 +17,7 @@
|
|||
#include <utility>
|
||||
|
||||
#include "dawn/tests/DawnTest.h"
|
||||
#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h"
|
||||
#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
|
||||
#include "dawn/utils/ComboRenderPipelineDescriptor.h"
|
||||
#include "dawn/utils/WGPUHelpers.h"
|
||||
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
#include <string_view>
|
||||
|
||||
#include "dawn/tests/DawnTest.h"
|
||||
#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h"
|
||||
#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
|
||||
#include "dawn/utils/ComboRenderPipelineDescriptor.h"
|
||||
#include "dawn/utils/WGPUHelpers.h"
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h"
|
||||
#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
|
||||
|
||||
using ::testing::Invoke;
|
||||
|
|
@ -12,8 +12,8 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_
|
||||
#define SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_
|
||||
#ifndef SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_
|
||||
#define SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_
|
||||
|
||||
#include <dawn/platform/DawnPlatform.h>
|
||||
#include <gmock/gmock.h>
|
||||
|
@ -70,4 +70,4 @@ class DawnCachingMockPlatform : public dawn::platform::Platform {
|
|||
dawn::platform::CachingInterface* mCachingInterface = nullptr;
|
||||
};
|
||||
|
||||
#endif // SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_
|
||||
#endif // SRC_DAWN_TESTS_MOCKS_PLATFORM_CACHINGINTERFACEMOCK_H_
|
|
@ -0,0 +1,320 @@
|
|||
// Copyright 2022 The Dawn Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "dawn/native/Blob.h"
|
||||
#include "dawn/native/CacheRequest.h"
|
||||
#include "dawn/tests/DawnNativeTest.h"
|
||||
#include "dawn/tests/mocks/platform/CachingInterfaceMock.h"
|
||||
|
||||
namespace dawn::native {
|
||||
|
||||
namespace {
|
||||
|
||||
using ::testing::_;
|
||||
using ::testing::ByMove;
|
||||
using ::testing::Invoke;
|
||||
using ::testing::MockFunction;
|
||||
using ::testing::Return;
|
||||
using ::testing::StrictMock;
|
||||
using ::testing::WithArg;
|
||||
|
||||
class CacheRequestTests : public DawnNativeTest {
|
||||
protected:
|
||||
std::unique_ptr<dawn::platform::Platform> CreateTestPlatform() override {
|
||||
return std::make_unique<DawnCachingMockPlatform>(&mMockCache);
|
||||
}
|
||||
|
||||
WGPUDevice CreateTestDevice() override {
|
||||
wgpu::DeviceDescriptor deviceDescriptor = {};
|
||||
wgpu::DawnTogglesDeviceDescriptor togglesDesc = {};
|
||||
deviceDescriptor.nextInChain = &togglesDesc;
|
||||
|
||||
const char* toggle = "enable_blob_cache";
|
||||
togglesDesc.forceEnabledToggles = &toggle;
|
||||
togglesDesc.forceEnabledTogglesCount = 1;
|
||||
|
||||
return adapter.CreateDevice(&deviceDescriptor);
|
||||
}
|
||||
|
||||
DeviceBase* GetDevice() { return dawn::native::FromAPI(device.Get()); }
|
||||
|
||||
StrictMock<CachingInterfaceMock> mMockCache;
|
||||
};
|
||||
|
||||
struct Foo {
|
||||
int value;
|
||||
};
|
||||
|
||||
#define REQUEST_MEMBERS(X) \
|
||||
X(int, a) \
|
||||
X(float, b) \
|
||||
X(std::vector<unsigned int>, c) \
|
||||
X(CacheKey::UnsafeUnkeyedValue<int*>, d) \
|
||||
X(CacheKey::UnsafeUnkeyedValue<Foo>, e)
|
||||
|
||||
DAWN_MAKE_CACHE_REQUEST(CacheRequestForTesting, REQUEST_MEMBERS)
|
||||
|
||||
#undef REQUEST_MEMBERS
|
||||
|
||||
// static_assert the expected types for various return types from the cache hit handler and cache
|
||||
// miss handler.
|
||||
TEST_F(CacheRequestTests, CacheResultTypes) {
|
||||
EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillRepeatedly(Return(0));
|
||||
|
||||
// (int, ResultOrError<int>), should be ResultOrError<CacheResult<int>>.
|
||||
auto v1 = LoadOrRun(
|
||||
GetDevice(), CacheRequestForTesting{}, [](Blob) -> int { return 0; },
|
||||
[](CacheRequestForTesting) -> ResultOrError<int> { return 1; });
|
||||
v1.AcquireSuccess();
|
||||
static_assert(std::is_same_v<ResultOrError<CacheResult<int>>, decltype(v1)>);
|
||||
|
||||
// (ResultOrError<float>, ResultOrError<float>), should be ResultOrError<CacheResult<float>>.
|
||||
auto v2 = LoadOrRun(
|
||||
GetDevice(), CacheRequestForTesting{}, [](Blob) -> ResultOrError<float> { return 0.0; },
|
||||
[](CacheRequestForTesting) -> ResultOrError<float> { return 1.0; });
|
||||
v2.AcquireSuccess();
|
||||
static_assert(std::is_same_v<ResultOrError<CacheResult<float>>, decltype(v2)>);
|
||||
}
|
||||
|
||||
// Test that using a CacheRequest builds a key from the device key, the request type enum, and all
|
||||
// of the request members.
|
||||
TEST_F(CacheRequestTests, MakesCacheKey) {
|
||||
// Make a request.
|
||||
CacheRequestForTesting req;
|
||||
req.a = 1;
|
||||
req.b = 0.2;
|
||||
req.c = {3, 4, 5};
|
||||
|
||||
// Make the expected key.
|
||||
CacheKey expectedKey;
|
||||
expectedKey.Record(GetDevice()->GetCacheKey(), "CacheRequestForTesting", req.a, req.b, req.c);
|
||||
|
||||
// Expect a call to LoadData with the expected key.
|
||||
EXPECT_CALL(mMockCache, LoadData(_, expectedKey.size(), nullptr, 0))
|
||||
.WillOnce(WithArg<0>(Invoke([&](const void* actualKeyData) {
|
||||
EXPECT_EQ(memcmp(actualKeyData, expectedKey.data(), expectedKey.size()), 0);
|
||||
return 0;
|
||||
})));
|
||||
|
||||
// Load the request.
|
||||
auto result = LoadOrRun(
|
||||
GetDevice(), std::move(req), [](Blob) -> int { return 0; },
|
||||
[](CacheRequestForTesting) -> ResultOrError<int> { return 0; })
|
||||
.AcquireSuccess();
|
||||
|
||||
// The created cache key should be saved on the result.
|
||||
EXPECT_EQ(result.GetCacheKey().size(), expectedKey.size());
|
||||
EXPECT_EQ(memcmp(result.GetCacheKey().data(), expectedKey.data(), expectedKey.size()), 0);
|
||||
}
|
||||
|
||||
// Test that members that are wrapped in UnsafeUnkeyedValue do not impact the key.
|
||||
TEST_F(CacheRequestTests, CacheKeyIgnoresUnsafeIgnoredValue) {
|
||||
// Make two requests with different UnsafeUnkeyedValues (UnsafeUnkeyed is declared on the struct
|
||||
// definition).
|
||||
int v1, v2;
|
||||
CacheRequestForTesting req1;
|
||||
req1.d = &v1;
|
||||
req1.e = Foo{42};
|
||||
|
||||
CacheRequestForTesting req2;
|
||||
req2.d = &v2;
|
||||
req2.e = Foo{24};
|
||||
|
||||
EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0)).WillOnce(Return(0));
|
||||
|
||||
static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
|
||||
|
||||
// Load the first request, and check that the unsafe unkeyed values were passed though
|
||||
EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([&](CacheRequestForTesting req) {
|
||||
EXPECT_EQ(req.d.UnsafeGetValue(), &v1);
|
||||
EXPECT_FLOAT_EQ(req.e.UnsafeGetValue().value, 42);
|
||||
return 0;
|
||||
})));
|
||||
auto r1 = LoadOrRun(
|
||||
GetDevice(), std::move(req1), [](Blob) { return 0; },
|
||||
[](CacheRequestForTesting req) -> ResultOrError<int> {
|
||||
return cacheMissFn.Call(std::move(req));
|
||||
})
|
||||
.AcquireSuccess();
|
||||
|
||||
// Load the second request, and check that the unsafe unkeyed values were passed though
|
||||
EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([&](CacheRequestForTesting req) {
|
||||
EXPECT_EQ(req.d.UnsafeGetValue(), &v2);
|
||||
EXPECT_FLOAT_EQ(req.e.UnsafeGetValue().value, 24);
|
||||
return 0;
|
||||
})));
|
||||
auto r2 = LoadOrRun(
|
||||
GetDevice(), std::move(req2), [](Blob) { return 0; },
|
||||
[](CacheRequestForTesting req) -> ResultOrError<int> {
|
||||
return cacheMissFn.Call(std::move(req));
|
||||
})
|
||||
.AcquireSuccess();
|
||||
|
||||
// Expect their keys to be the same.
|
||||
EXPECT_EQ(r1.GetCacheKey().size(), r2.GetCacheKey().size());
|
||||
EXPECT_EQ(memcmp(r1.GetCacheKey().data(), r2.GetCacheKey().data(), r1.GetCacheKey().size()), 0);
|
||||
}
|
||||
|
||||
// Test the expected code path when there is a cache miss.
|
||||
TEST_F(CacheRequestTests, CacheMiss) {
|
||||
// Make a request.
|
||||
CacheRequestForTesting req;
|
||||
req.a = 1;
|
||||
req.b = 0.2;
|
||||
req.c = {3, 4, 5};
|
||||
|
||||
unsigned int* cPtr = req.c.data();
|
||||
|
||||
static StrictMock<MockFunction<int(Blob)>> cacheHitFn;
|
||||
static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
|
||||
|
||||
// Mock a cache miss.
|
||||
EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(0));
|
||||
|
||||
// Expect the cache miss, and return some value.
|
||||
int rv = 42;
|
||||
EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](CacheRequestForTesting req) {
|
||||
// Expect the request contents to be the same. The data pointer for |c| is also the same
|
||||
// since it was moved.
|
||||
EXPECT_EQ(req.a, 1);
|
||||
EXPECT_FLOAT_EQ(req.b, 0.2);
|
||||
EXPECT_EQ(req.c.data(), cPtr);
|
||||
return rv;
|
||||
})));
|
||||
|
||||
// Load the request.
|
||||
auto result = LoadOrRun(
|
||||
GetDevice(), std::move(req),
|
||||
[](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); },
|
||||
[](CacheRequestForTesting req) -> ResultOrError<int> {
|
||||
return cacheMissFn.Call(std::move(req));
|
||||
})
|
||||
.AcquireSuccess();
|
||||
|
||||
// Expect the result to store the value.
|
||||
EXPECT_EQ(*result, rv);
|
||||
EXPECT_FALSE(result.IsCached());
|
||||
}
|
||||
|
||||
// Test the expected code path when there is a cache hit.
|
||||
TEST_F(CacheRequestTests, CacheHit) {
|
||||
// Make a request.
|
||||
CacheRequestForTesting req;
|
||||
req.a = 1;
|
||||
req.b = 0.2;
|
||||
req.c = {3, 4, 5};
|
||||
|
||||
static StrictMock<MockFunction<int(Blob)>> cacheHitFn;
|
||||
static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
|
||||
|
||||
static constexpr char kCachedData[] = "hello world!";
|
||||
|
||||
// Mock a cache hit, and load the cached data.
|
||||
EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(sizeof(kCachedData)));
|
||||
EXPECT_CALL(mMockCache, LoadData(_, _, _, sizeof(kCachedData)))
|
||||
.WillOnce(WithArg<2>(Invoke([](void* dataOut) {
|
||||
memcpy(dataOut, kCachedData, sizeof(kCachedData));
|
||||
return sizeof(kCachedData);
|
||||
})));
|
||||
|
||||
// Expect the cache hit, and return some value.
|
||||
int rv = 1337;
|
||||
EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](Blob blob) {
|
||||
// Expect the cached blob contents to match the cached data.
|
||||
EXPECT_EQ(blob.Size(), sizeof(kCachedData));
|
||||
EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0);
|
||||
|
||||
return rv;
|
||||
})));
|
||||
|
||||
// Load the request.
|
||||
auto result = LoadOrRun(
|
||||
GetDevice(), std::move(req),
|
||||
[](Blob blob) -> int { return cacheHitFn.Call(std::move(blob)); },
|
||||
[](CacheRequestForTesting req) -> ResultOrError<int> {
|
||||
return cacheMissFn.Call(std::move(req));
|
||||
})
|
||||
.AcquireSuccess();
|
||||
|
||||
// Expect the result to store the value.
|
||||
EXPECT_EQ(*result, rv);
|
||||
EXPECT_TRUE(result.IsCached());
|
||||
}
|
||||
|
||||
// Test the expected code path when there is a cache hit but the handler errors.
|
||||
TEST_F(CacheRequestTests, CacheHitError) {
|
||||
// Make a request.
|
||||
CacheRequestForTesting req;
|
||||
req.a = 1;
|
||||
req.b = 0.2;
|
||||
req.c = {3, 4, 5};
|
||||
|
||||
unsigned int* cPtr = req.c.data();
|
||||
|
||||
static StrictMock<MockFunction<ResultOrError<int>(Blob)>> cacheHitFn;
|
||||
static StrictMock<MockFunction<int(CacheRequestForTesting)>> cacheMissFn;
|
||||
|
||||
static constexpr char kCachedData[] = "hello world!";
|
||||
|
||||
// Mock a cache hit, and load the cached data.
|
||||
EXPECT_CALL(mMockCache, LoadData(_, _, nullptr, 0)).WillOnce(Return(sizeof(kCachedData)));
|
||||
EXPECT_CALL(mMockCache, LoadData(_, _, _, sizeof(kCachedData)))
|
||||
.WillOnce(WithArg<2>(Invoke([](void* dataOut) {
|
||||
memcpy(dataOut, kCachedData, sizeof(kCachedData));
|
||||
return sizeof(kCachedData);
|
||||
})));
|
||||
|
||||
// Expect the cache hit.
|
||||
EXPECT_CALL(cacheHitFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](Blob blob) {
|
||||
// Expect the cached blob contents to match the cached data.
|
||||
EXPECT_EQ(blob.Size(), sizeof(kCachedData));
|
||||
EXPECT_EQ(memcmp(blob.Data(), kCachedData, sizeof(kCachedData)), 0);
|
||||
|
||||
// Return an error.
|
||||
return DAWN_VALIDATION_ERROR("fake test error");
|
||||
})));
|
||||
|
||||
// Expect the cache miss handler since the cache hit errored.
|
||||
int rv = 79;
|
||||
EXPECT_CALL(cacheMissFn, Call(_)).WillOnce(WithArg<0>(Invoke([=](CacheRequestForTesting req) {
|
||||
// Expect the request contents to be the same. The data pointer for |c| is also the same
|
||||
// since it was moved.
|
||||
EXPECT_EQ(req.a, 1);
|
||||
EXPECT_FLOAT_EQ(req.b, 0.2);
|
||||
EXPECT_EQ(req.c.data(), cPtr);
|
||||
return rv;
|
||||
})));
|
||||
|
||||
// Load the request.
|
||||
auto result =
|
||||
LoadOrRun(
|
||||
GetDevice(), std::move(req),
|
||||
[](Blob blob) -> ResultOrError<int> { return cacheHitFn.Call(std::move(blob)); },
|
||||
[](CacheRequestForTesting req) -> ResultOrError<int> {
|
||||
return cacheMissFn.Call(std::move(req));
|
||||
})
|
||||
.AcquireSuccess();
|
||||
|
||||
// Expect the result to store the value.
|
||||
EXPECT_EQ(*result, rv);
|
||||
EXPECT_FALSE(result.IsCached());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace dawn::native
|
Loading…
Reference in New Issue