From 4091e0fa9c1e2814ffefdafd01bd8c196c3bfdff Mon Sep 17 00:00:00 2001 From: Loko Kung Date: Thu, 21 Apr 2022 03:14:37 +0000 Subject: [PATCH] Factor out common cache testing code. - To be used for pipeline cache testing. - Plumbs overriding the platform to the adapters for testing. - Restructures build a little bit so that the test infrastructure can have full access to dawn native internals. Also differentiates end2end and white_box a bit more to make it clear that end2end should not have access to dawn native internals. Bug: dawn:549, dawn:1374 Change-Id: Ibcc6c44a116c7967ee2317c74409f613e896eb0a Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/86841 Reviewed-by: Austin Eng Commit-Queue: Loko Kung --- include/dawn/native/DawnNative.h | 1 + src/dawn/native/DawnNative.cpp | 1 + src/dawn/native/Instance.cpp | 13 +- src/dawn/native/Instance.h | 5 +- src/dawn/tests/BUILD.gn | 110 +++++++------- src/dawn/tests/DawnTest.cpp | 8 +- src/dawn/tests/end2end/D3D12CachingTests.cpp | 135 ++++-------------- .../end2end/mocks/CachingInterfaceMock.cpp | 89 ++++++++++++ .../end2end/mocks/CachingInterfaceMock.h | 87 +++++++++++ 9 files changed, 274 insertions(+), 175 deletions(-) create mode 100644 src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp create mode 100644 src/dawn/tests/end2end/mocks/CachingInterfaceMock.h diff --git a/include/dawn/native/DawnNative.h b/include/dawn/native/DawnNative.h index 81c3b6b08b..a208fcfc64 100644 --- a/include/dawn/native/DawnNative.h +++ b/include/dawn/native/DawnNative.h @@ -164,6 +164,7 @@ namespace dawn::native { // Enable debug capture on Dawn startup void EnableBeginCaptureOnStartup(bool beginCaptureOnStartup); + // TODO(dawn:1374) Deprecate this once it is passed via the descriptor. void SetPlatform(dawn::platform::Platform* platform); // Returns the underlying WGPUInstance object. diff --git a/src/dawn/native/DawnNative.cpp b/src/dawn/native/DawnNative.cpp index 282e88e3a2..4787bbe890 100644 --- a/src/dawn/native/DawnNative.cpp +++ b/src/dawn/native/DawnNative.cpp @@ -240,6 +240,7 @@ namespace dawn::native { mImpl->EnableBeginCaptureOnStartup(beginCaptureOnStartup); } + // TODO(dawn:1374) Deprecate this once it is passed via the descriptor. void Instance::SetPlatform(dawn::platform::Platform* platform) { mImpl->SetPlatform(platform); } diff --git a/src/dawn/native/Instance.cpp b/src/dawn/native/Instance.cpp index 3a27b1bbf0..6c666213cc 100644 --- a/src/dawn/native/Instance.cpp +++ b/src/dawn/native/Instance.cpp @@ -422,15 +422,12 @@ namespace dawn::native { mBlobCache = std::make_unique(GetCachingInterface(platform)); } - dawn::platform::Platform* InstanceBase::GetPlatform() { - if (mPlatform != nullptr) { - return mPlatform; - } + void InstanceBase::SetPlatformForTesting(dawn::platform::Platform* platform) { + SetPlatform(platform); + } - if (mDefaultPlatform == nullptr) { - mDefaultPlatform = std::make_unique(); - } - return mDefaultPlatform.get(); + dawn::platform::Platform* InstanceBase::GetPlatform() { + return mPlatform; } BlobCache* InstanceBase::GetBlobCache() { diff --git a/src/dawn/native/Instance.h b/src/dawn/native/Instance.h index 13aa74e8fc..73f1084c09 100644 --- a/src/dawn/native/Instance.h +++ b/src/dawn/native/Instance.h @@ -77,9 +77,10 @@ namespace dawn::native { void EnableBeginCaptureOnStartup(bool beginCaptureOnStartup); bool IsBeginCaptureOnStartupEnabled() const; - // TODO(dawn:1374): SetPlatform should become a private helper, and a NOT thread-safe - // testing version exposed for special testing cases. + // TODO(dawn:1374): SetPlatform should become a private helper, and SetPlatformForTesting + // will become the NOT thread-safe testing version exposed for special testing cases. void SetPlatform(dawn::platform::Platform* platform); + void SetPlatformForTesting(dawn::platform::Platform* platform); dawn::platform::Platform* GetPlatform(); BlobCache* GetBlobCache(); diff --git a/src/dawn/tests/BUILD.gn b/src/dawn/tests/BUILD.gn index 4b3a942fa2..090b3cf0c8 100644 --- a/src/dawn/tests/BUILD.gn +++ b/src/dawn/tests/BUILD.gn @@ -330,32 +330,76 @@ dawn_test("dawn_unittests") { } ############################################################################### -# Dawn end2end tests targets +# Dawn test infrastructure targets ############################################################################### -source_set("end2end_tests_sources") { - configs += [ "${dawn_root}/src/dawn/common:internal_config" ] +source_set("test_infra_sources") { + configs += [ "${dawn_root}/src/dawn/native:internal" ] testonly = true deps = [ - ":gmock_and_gtest", "${dawn_root}/src/dawn:cpp", "${dawn_root}/src/dawn:proc", "${dawn_root}/src/dawn/common", - - # Statically linked because the end2end white_box tests use Dawn internals. + "${dawn_root}/src/dawn/native:sources", "${dawn_root}/src/dawn/native:static", - "${dawn_root}/src/dawn/platform:platform", "${dawn_root}/src/dawn/utils", "${dawn_root}/src/dawn/wire", ] + public_deps = [ ":gmock_and_gtest" ] + + if (dawn_supports_glfw_for_windowing || dawn_enable_opengl) { + assert(dawn_supports_glfw_for_windowing) + public_deps += [ "${dawn_root}/src/dawn/utils:glfw" ] + } + sources = [ + "DawnTest.cpp", "DawnTest.h", "MockCallback.h", "ParamGenerator.h", "ToggleParser.cpp", "ToggleParser.h", + ] +} + +############################################################################### +# Dawn end2end tests targets +############################################################################### + +# Source code for mocks used for end2end testing are separated from the rest of +# sources so that they aren't included in non-test builds. +source_set("end2end_mocks_sources") { + configs += [ "${dawn_root}/src/dawn/native:internal" ] + testonly = true + + deps = [ + ":gmock_and_gtest", + "${dawn_root}/src/dawn/platform", + ] + + sources = [ + "end2end/mocks/CachingInterfaceMock.cpp", + "end2end/mocks/CachingInterfaceMock.h", + ] +} + +source_set("end2end_tests_sources") { + testonly = true + + deps = [ + ":end2end_mocks_sources", + ":test_infra_sources", + "${dawn_root}/src/dawn:cpp", + "${dawn_root}/src/dawn:proc", + "${dawn_root}/src/dawn/common", + "${dawn_root}/src/dawn/native:headers", + "${dawn_root}/src/dawn/utils", + "${dawn_root}/src/dawn/wire", + ] + + sources = [ "end2end/AdapterDiscoveryTests.cpp", "end2end/BasicTests.cpp", "end2end/BindGroupTests.cpp", @@ -461,17 +505,12 @@ source_set("end2end_tests_sources") { frameworks = [ "IOSurface.framework" ] } - if (dawn_enable_opengl) { - assert(dawn_supports_glfw_for_windowing) - } - if (dawn_supports_glfw_for_windowing) { sources += [ "end2end/SwapChainTests.cpp", "end2end/SwapChainValidationTests.cpp", "end2end/WindowSurfaceTests.cpp", ] - deps += [ "${dawn_root}/src/dawn/utils:glfw" ] } if (dawn_enable_d3d12 || (dawn_enable_vulkan && is_chromeos) || @@ -492,23 +531,19 @@ source_set("white_box_tests_sources") { testonly = true deps = [ - ":gmock_and_gtest", + ":test_infra_sources", "${dawn_root}/src/dawn:cpp", "${dawn_root}/src/dawn:proc", "${dawn_root}/src/dawn/common", - "${dawn_root}/src/dawn/native:sources", - # Statically linked because the end2end white_box tests use Dawn internals. + # Statically linked and with sources because the end2end white_box tests use Dawn internals. + "${dawn_root}/src/dawn/native:sources", "${dawn_root}/src/dawn/native:static", "${dawn_root}/src/dawn/utils", "${dawn_root}/src/dawn/wire", ] - sources = [ - "DawnTest.h", - "ParamGenerator.h", - "ToggleParser.h", - ] + sources = [] if (dawn_enable_vulkan) { deps += [ "${dawn_vulkan_headers_dir}:vulkan_headers" ] @@ -552,10 +587,6 @@ source_set("white_box_tests_sources") { sources += [ "white_box/MetalAutoreleasePoolTests.mm" ] } - if (dawn_enable_opengl) { - deps += [ "${dawn_root}/src/dawn/utils:glfw" ] - } - if (dawn_enable_opengles && defined(dawn_angle_dir)) { sources += [ "white_box/EGLImageWrappingTests.cpp" ] deps += [ "${dawn_angle_dir}:libEGL" ] @@ -567,22 +598,11 @@ source_set("white_box_tests_sources") { dawn_test("dawn_end2end_tests") { deps = [ ":end2end_tests_sources", - ":gmock_and_gtest", + ":test_infra_sources", ":white_box_tests_sources", - "${dawn_root}/src/dawn:cpp", - "${dawn_root}/src/dawn:proc", - "${dawn_root}/src/dawn/common", - "${dawn_root}/src/dawn/native:static", - "${dawn_root}/src/dawn/platform:platform", - "${dawn_root}/src/dawn/utils", - "${dawn_root}/src/dawn/wire", - ] - - sources = [ - "DawnTest.cpp", - "DawnTest.h", ] + sources = [] libs = [] # When building inside Chromium, use their gtest main function because it is @@ -593,10 +613,6 @@ dawn_test("dawn_end2end_tests") { sources += [ "End2EndTestsMain.cpp" ] } - if (dawn_enable_opengl) { - deps += [ "${dawn_root}/src/dawn/utils:glfw" ] - } - if (is_chromeos) { libs += [ "gbm" ] } @@ -608,22 +624,16 @@ dawn_test("dawn_end2end_tests") { dawn_test("dawn_perf_tests") { deps = [ - ":gmock_and_gtest", + ":test_infra_sources", "${dawn_root}/src/dawn:cpp", "${dawn_root}/src/dawn:proc", "${dawn_root}/src/dawn/common", - "${dawn_root}/src/dawn/native", "${dawn_root}/src/dawn/platform", "${dawn_root}/src/dawn/utils", "${dawn_root}/src/dawn/wire", ] sources = [ - "DawnTest.cpp", - "DawnTest.h", - "ParamGenerator.h", - "ToggleParser.cpp", - "ToggleParser.h", "perf_tests/BufferUploadPerf.cpp", "perf_tests/DawnPerfTest.cpp", "perf_tests/DawnPerfTest.h", @@ -648,8 +658,4 @@ dawn_test("dawn_perf_tests") { if (dawn_enable_metal) { frameworks = [ "IOSurface.framework" ] } - - if (dawn_enable_opengl) { - deps += [ "${dawn_root}/src/dawn/utils:glfw" ] - } } diff --git a/src/dawn/tests/DawnTest.cpp b/src/dawn/tests/DawnTest.cpp index 9dda7f2a9a..1bbb61562e 100644 --- a/src/dawn/tests/DawnTest.cpp +++ b/src/dawn/tests/DawnTest.cpp @@ -30,6 +30,8 @@ #include "dawn/common/Platform.h" #include "dawn/common/SystemUtils.h" #include "dawn/dawn_proc.h" +#include "dawn/native/Instance.h" +#include "dawn/native/dawn_platform.h" #include "dawn/utils/ComboRenderPipelineDescriptor.h" #include "dawn/utils/PlatformDebugLogger.h" #include "dawn/utils/SystemUtils.h" @@ -921,9 +923,11 @@ void DawnTestBase::SetUp() { mBackendAdapter = *it; } - // Setup the per-test platform. Tests can provide one by overloading CreateTestPlatform. + // Setup the per-test platform. Tests can provide one by overloading CreateTestPlatform. This is + // NOT a thread-safe operation and is allowed here for testing only. mTestPlatform = CreateTestPlatform(); - gTestEnv->GetInstance()->SetPlatform(mTestPlatform.get()); + dawn::native::FromAPI(gTestEnv->GetInstance()->Get()) + ->SetPlatformForTesting(mTestPlatform.get()); // Create the device from the adapter for (const char* forceEnabledWorkaround : mParam.forceEnabledWorkarounds) { diff --git a/src/dawn/tests/end2end/D3D12CachingTests.cpp b/src/dawn/tests/end2end/D3D12CachingTests.cpp index 745ae2122c..5a4da5bedf 100644 --- a/src/dawn/tests/end2end/D3D12CachingTests.cpp +++ b/src/dawn/tests/end2end/D3D12CachingTests.cpp @@ -13,97 +13,30 @@ // limitations under the License. #include -#include #include #include -#include #include "dawn/tests/DawnTest.h" +#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h" #include "dawn/utils/ComboRenderPipelineDescriptor.h" #include "dawn/utils/WGPUHelpers.h" -#define EXPECT_CACHE_HIT(N, statement) \ - do { \ - size_t before = mPersistentCache.mHitCount; \ - statement; \ - FlushWire(); \ - size_t after = mPersistentCache.mHitCount; \ - EXPECT_EQ(N, after - before); \ - } while (0) - -// FakePersistentCache implements a in-memory persistent cache. -class FakePersistentCache : public dawn::platform::CachingInterface { - public: - // PersistentCache API - void StoreData(const WGPUDevice device, - const void* key, - size_t keySize, - const void* value, - size_t valueSize) override { - if (mIsDisabled) - return; - const std::string keyStr(reinterpret_cast(key), keySize); - - const uint8_t* value_start = reinterpret_cast(value); - std::vector entry_value(value_start, value_start + valueSize); - - EXPECT_TRUE(mCache.insert({keyStr, std::move(entry_value)}).second); - } - - size_t LoadData(const WGPUDevice device, - const void* key, - size_t keySize, - void* value, - size_t valueSize) override { - const std::string keyStr(reinterpret_cast(key), keySize); - auto entry = mCache.find(keyStr); - if (entry == mCache.end()) { - return 0; - } - if (valueSize >= entry->second.size()) { - memcpy(value, entry->second.data(), entry->second.size()); - } - mHitCount++; - return entry->second.size(); - } - - using Blob = std::vector; - using FakeCache = std::unordered_map; - - FakeCache mCache; - - size_t mHitCount = 0; - bool mIsDisabled = false; -}; - -// Test platform that only supports caching. -class DawnTestPlatform : public dawn::platform::Platform { - public: - explicit DawnTestPlatform(dawn::platform::CachingInterface* cachingInterface) - : mCachingInterface(cachingInterface) { - } - ~DawnTestPlatform() override = default; - - dawn::platform::CachingInterface* GetCachingInterface(const void* fingerprint, - size_t fingerprintSize) override { - return mCachingInterface; - } - - dawn::platform::CachingInterface* mCachingInterface = nullptr; -}; +namespace { + using ::testing::NiceMock; +} // namespace class D3D12CachingTests : public DawnTest { protected: std::unique_ptr CreateTestPlatform() override { - return std::make_unique(&mPersistentCache); + return std::make_unique(&mMockCache); } - FakePersistentCache mPersistentCache; + NiceMock mMockCache; }; -// Test that duplicate WGSL still re-compiles HLSL even when the cache is not enabled. +// Test that duplicate WGSL still works (and re-compiles HLSL) when the cache is not enabled. TEST_P(D3D12CachingTests, SameShaderNoCache) { - mPersistentCache.mIsDisabled = true; + mMockCache.Disable(); wgpu::ShaderModule module = utils::CreateShaderModule(device, R"( @stage(vertex) fn vertex_main() -> @builtin(position) vec4 { @@ -122,11 +55,9 @@ TEST_P(D3D12CachingTests, SameShaderNoCache) { desc.vertex.entryPoint = "vertex_main"; desc.cFragment.module = module; desc.cFragment.entryPoint = "fragment_main"; - - EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc)); + EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc)); } - - EXPECT_EQ(mPersistentCache.mCache.size(), 0u); + EXPECT_EQ(mMockCache.GetNumEntries(), 0u); // Load the same WGSL shader from the cache. { @@ -135,11 +66,9 @@ TEST_P(D3D12CachingTests, SameShaderNoCache) { desc.vertex.entryPoint = "vertex_main"; desc.cFragment.module = module; desc.cFragment.entryPoint = "fragment_main"; - - EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc)); + EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc)); } - - EXPECT_EQ(mPersistentCache.mCache.size(), 0u); + EXPECT_EQ(mMockCache.GetNumEntries(), 0u); } // Test creating a pipeline from two entrypoints in multiple stages will cache the correct number @@ -163,11 +92,9 @@ TEST_P(D3D12CachingTests, ReuseShaderWithMultipleEntryPointsPerStage) { desc.vertex.entryPoint = "vertex_main"; desc.cFragment.module = module; desc.cFragment.entryPoint = "fragment_main"; - - EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc)); + EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc)); } - - EXPECT_EQ(mPersistentCache.mCache.size(), 2u); + EXPECT_EQ(mMockCache.GetNumEntries(), 2u); // Load the same WGSL shader from the cache. { @@ -176,13 +103,9 @@ TEST_P(D3D12CachingTests, ReuseShaderWithMultipleEntryPointsPerStage) { desc.vertex.entryPoint = "vertex_main"; desc.cFragment.module = module; desc.cFragment.entryPoint = "fragment_main"; - - // Cached HLSL shader calls LoadData twice (once to peek, again to get), so check 2 x - // kNumOfShaders hits. - EXPECT_CACHE_HIT(4u, device.CreateRenderPipeline(&desc)); + EXPECT_CACHE_HIT(mMockCache, 2u, device.CreateRenderPipeline(&desc)); } - - EXPECT_EQ(mPersistentCache.mCache.size(), 2u); + EXPECT_EQ(mMockCache.GetNumEntries(), 2u); // Modify the WGSL shader functions and make sure it doesn't hit. wgpu::ShaderModule newModule = utils::CreateShaderModule(device, R"( @@ -201,12 +124,9 @@ TEST_P(D3D12CachingTests, ReuseShaderWithMultipleEntryPointsPerStage) { desc.vertex.entryPoint = "vertex_main"; desc.cFragment.module = newModule; desc.cFragment.entryPoint = "fragment_main"; - EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc)); + EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc)); } - - // Cached HLSL shader calls LoadData twice (once to peek, again to get), so check 2 x - // kNumOfShaders hits. - EXPECT_EQ(mPersistentCache.mCache.size(), 4u); + EXPECT_EQ(mMockCache.GetNumEntries(), 4u); } // Test creating a WGSL shader with two entrypoints in the same stage will cache the correct number @@ -232,33 +152,26 @@ TEST_P(D3D12CachingTests, ReuseShaderWithMultipleEntryPoints) { wgpu::ComputePipelineDescriptor desc; desc.compute.module = module; desc.compute.entryPoint = "write1"; - EXPECT_CACHE_HIT(0u, device.CreateComputePipeline(&desc)); + EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateComputePipeline(&desc)); desc.compute.module = module; desc.compute.entryPoint = "write42"; - EXPECT_CACHE_HIT(0u, device.CreateComputePipeline(&desc)); + EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateComputePipeline(&desc)); } - - EXPECT_EQ(mPersistentCache.mCache.size(), 2u); + EXPECT_EQ(mMockCache.GetNumEntries(), 2u); // Load the same WGSL shader from the cache. { wgpu::ComputePipelineDescriptor desc; desc.compute.module = module; desc.compute.entryPoint = "write1"; - - // Cached HLSL shader calls LoadData twice (once to peek, again to get), so check 2 x - // kNumOfShaders hits. - EXPECT_CACHE_HIT(2u, device.CreateComputePipeline(&desc)); + EXPECT_CACHE_HIT(mMockCache, 1u, device.CreateComputePipeline(&desc)); desc.compute.module = module; desc.compute.entryPoint = "write42"; - - // Cached HLSL shader calls LoadData twice, so check 2 x kNumOfShaders hits. - EXPECT_CACHE_HIT(2u, device.CreateComputePipeline(&desc)); + EXPECT_CACHE_HIT(mMockCache, 1u, device.CreateComputePipeline(&desc)); } - - EXPECT_EQ(mPersistentCache.mCache.size(), 2u); + EXPECT_EQ(mMockCache.GetNumEntries(), 2u); } DAWN_INSTANTIATE_TEST(D3D12CachingTests, D3D12Backend()); diff --git a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp b/src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp new file mode 100644 index 0000000000..d07b18d9d8 --- /dev/null +++ b/src/dawn/tests/end2end/mocks/CachingInterfaceMock.cpp @@ -0,0 +1,89 @@ +// Copyright 2022 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h" + +using ::testing::Invoke; + +CachingInterfaceMock::CachingInterfaceMock() { + ON_CALL(*this, LoadData).WillByDefault(Invoke([=](auto&&... args) { + return LoadDataDefault(args...); + })); + ON_CALL(*this, StoreData).WillByDefault(Invoke([=](auto&&... args) { + return StoreDataDefault(args...); + })); +} + +void CachingInterfaceMock::Enable() { + mEnabled = true; +} + +void CachingInterfaceMock::Disable() { + mEnabled = false; +} + +size_t CachingInterfaceMock::GetHitCount() const { + return mHitCount; +} + +size_t CachingInterfaceMock::GetNumEntries() const { + return mCache.size(); +} + +size_t CachingInterfaceMock::LoadDataDefault(const WGPUDevice device, + const void* key, + size_t keySize, + void* value, + size_t valueSize) { + if (!mEnabled) { + return 0; + } + + const std::string keyStr(reinterpret_cast(key), keySize); + auto entry = mCache.find(keyStr); + if (entry == mCache.end()) { + return 0; + } + if (valueSize >= entry->second.size()) { + // Only consider a cache-hit on the memcpy, since peeks are implementation detail. + memcpy(value, entry->second.data(), entry->second.size()); + mHitCount++; + } + return entry->second.size(); +} + +void CachingInterfaceMock::StoreDataDefault(const WGPUDevice device, + const void* key, + size_t keySize, + const void* value, + size_t valueSize) { + if (!mEnabled) { + return; + } + + const std::string keyStr(reinterpret_cast(key), keySize); + const uint8_t* it = reinterpret_cast(value); + std::vector entry(it, it + valueSize); + mCache.insert_or_assign(keyStr, entry); +} + +DawnCachingMockPlatform::DawnCachingMockPlatform(dawn::platform::CachingInterface* cachingInterface) + : mCachingInterface(cachingInterface) { +} + +dawn::platform::CachingInterface* DawnCachingMockPlatform::GetCachingInterface( + const void* fingerprint, + size_t fingerprintSize) { + return mCachingInterface; +} diff --git a/src/dawn/tests/end2end/mocks/CachingInterfaceMock.h b/src/dawn/tests/end2end/mocks/CachingInterfaceMock.h new file mode 100644 index 0000000000..8cb1e1cfbe --- /dev/null +++ b/src/dawn/tests/end2end/mocks/CachingInterfaceMock.h @@ -0,0 +1,87 @@ +// Copyright 2022 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_ +#define SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_ + +#include +#include + +#include +#include +#include + +#define EXPECT_CACHE_HIT(cache, N, statement) \ + do { \ + FlushWire(); \ + size_t before = cache.GetHitCount(); \ + statement; \ + FlushWire(); \ + size_t after = cache.GetHitCount(); \ + EXPECT_EQ(N, after - before); \ + } while (0) + +// A mock caching interface class that also supplies an in memory cache for testing. +class CachingInterfaceMock : public dawn::platform::CachingInterface { + public: + CachingInterfaceMock(); + + // Toggles to disable/enable caching. + void Disable(); + void Enable(); + + // Returns the number of cache hits up to this point. + size_t GetHitCount() const; + + // Returns the number of entries in the cache. + size_t GetNumEntries() const; + + MOCK_METHOD(size_t, + LoadData, + (const WGPUDevice, const void*, size_t, void*, size_t), + (override)); + MOCK_METHOD(void, + StoreData, + (const WGPUDevice, const void*, size_t, const void*, size_t), + (override)); + + private: + size_t LoadDataDefault(const WGPUDevice device, + const void* key, + size_t keySize, + void* value, + size_t valueSize); + void StoreDataDefault(const WGPUDevice device, + const void* key, + size_t keySize, + const void* value, + size_t valueSize); + + bool mEnabled = true; + size_t mHitCount = 0; + std::unordered_map> mCache; +}; + +// Dawn platform used for testing with a mock caching interface. +class DawnCachingMockPlatform : public dawn::platform::Platform { + public: + explicit DawnCachingMockPlatform(dawn::platform::CachingInterface* cachingInterface); + dawn::platform::CachingInterface* GetCachingInterface(const void* fingerprint, + size_t fingerprintSize) override; + + private: + dawn::platform::CachingInterface* mCachingInterface = nullptr; +}; + +#endif // SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_