Factor out common cache testing code.

- To be used for pipeline cache testing.
- Plumbs overriding the platform to the adapters for testing.
- Restructures build a little bit so that the test infrastructure can have full access to dawn native internals. Also differentiates end2end and white_box a bit more to make it clear that end2end should not have access to dawn native internals.

Bug: dawn:549, dawn:1374
Change-Id: Ibcc6c44a116c7967ee2317c74409f613e896eb0a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/86841
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Loko Kung <lokokung@google.com>
This commit is contained in:
Loko Kung 2022-04-21 03:14:37 +00:00 committed by Dawn LUCI CQ
parent 12d45e2068
commit 4091e0fa9c
9 changed files with 274 additions and 175 deletions

View File

@ -164,6 +164,7 @@ namespace dawn::native {
// Enable debug capture on Dawn startup // Enable debug capture on Dawn startup
void EnableBeginCaptureOnStartup(bool beginCaptureOnStartup); void EnableBeginCaptureOnStartup(bool beginCaptureOnStartup);
// TODO(dawn:1374) Deprecate this once it is passed via the descriptor.
void SetPlatform(dawn::platform::Platform* platform); void SetPlatform(dawn::platform::Platform* platform);
// Returns the underlying WGPUInstance object. // Returns the underlying WGPUInstance object.

View File

@ -240,6 +240,7 @@ namespace dawn::native {
mImpl->EnableBeginCaptureOnStartup(beginCaptureOnStartup); mImpl->EnableBeginCaptureOnStartup(beginCaptureOnStartup);
} }
// TODO(dawn:1374) Deprecate this once it is passed via the descriptor.
void Instance::SetPlatform(dawn::platform::Platform* platform) { void Instance::SetPlatform(dawn::platform::Platform* platform) {
mImpl->SetPlatform(platform); mImpl->SetPlatform(platform);
} }

View File

@ -422,15 +422,12 @@ namespace dawn::native {
mBlobCache = std::make_unique<BlobCache>(GetCachingInterface(platform)); mBlobCache = std::make_unique<BlobCache>(GetCachingInterface(platform));
} }
dawn::platform::Platform* InstanceBase::GetPlatform() { void InstanceBase::SetPlatformForTesting(dawn::platform::Platform* platform) {
if (mPlatform != nullptr) { SetPlatform(platform);
return mPlatform;
} }
if (mDefaultPlatform == nullptr) { dawn::platform::Platform* InstanceBase::GetPlatform() {
mDefaultPlatform = std::make_unique<dawn::platform::Platform>(); return mPlatform;
}
return mDefaultPlatform.get();
} }
BlobCache* InstanceBase::GetBlobCache() { BlobCache* InstanceBase::GetBlobCache() {

View File

@ -77,9 +77,10 @@ namespace dawn::native {
void EnableBeginCaptureOnStartup(bool beginCaptureOnStartup); void EnableBeginCaptureOnStartup(bool beginCaptureOnStartup);
bool IsBeginCaptureOnStartupEnabled() const; bool IsBeginCaptureOnStartupEnabled() const;
// TODO(dawn:1374): SetPlatform should become a private helper, and a NOT thread-safe // TODO(dawn:1374): SetPlatform should become a private helper, and SetPlatformForTesting
// testing version exposed for special testing cases. // will become the NOT thread-safe testing version exposed for special testing cases.
void SetPlatform(dawn::platform::Platform* platform); void SetPlatform(dawn::platform::Platform* platform);
void SetPlatformForTesting(dawn::platform::Platform* platform);
dawn::platform::Platform* GetPlatform(); dawn::platform::Platform* GetPlatform();
BlobCache* GetBlobCache(); BlobCache* GetBlobCache();

View File

@ -330,32 +330,76 @@ dawn_test("dawn_unittests") {
} }
############################################################################### ###############################################################################
# Dawn end2end tests targets # Dawn test infrastructure targets
############################################################################### ###############################################################################
source_set("end2end_tests_sources") { source_set("test_infra_sources") {
configs += [ "${dawn_root}/src/dawn/common:internal_config" ] configs += [ "${dawn_root}/src/dawn/native:internal" ]
testonly = true testonly = true
deps = [ deps = [
":gmock_and_gtest",
"${dawn_root}/src/dawn:cpp", "${dawn_root}/src/dawn:cpp",
"${dawn_root}/src/dawn:proc", "${dawn_root}/src/dawn:proc",
"${dawn_root}/src/dawn/common", "${dawn_root}/src/dawn/common",
"${dawn_root}/src/dawn/native:sources",
# Statically linked because the end2end white_box tests use Dawn internals.
"${dawn_root}/src/dawn/native:static", "${dawn_root}/src/dawn/native:static",
"${dawn_root}/src/dawn/platform:platform",
"${dawn_root}/src/dawn/utils", "${dawn_root}/src/dawn/utils",
"${dawn_root}/src/dawn/wire", "${dawn_root}/src/dawn/wire",
] ]
public_deps = [ ":gmock_and_gtest" ]
if (dawn_supports_glfw_for_windowing || dawn_enable_opengl) {
assert(dawn_supports_glfw_for_windowing)
public_deps += [ "${dawn_root}/src/dawn/utils:glfw" ]
}
sources = [ sources = [
"DawnTest.cpp",
"DawnTest.h", "DawnTest.h",
"MockCallback.h", "MockCallback.h",
"ParamGenerator.h", "ParamGenerator.h",
"ToggleParser.cpp", "ToggleParser.cpp",
"ToggleParser.h", "ToggleParser.h",
]
}
###############################################################################
# Dawn end2end tests targets
###############################################################################
# Source code for mocks used for end2end testing are separated from the rest of
# sources so that they aren't included in non-test builds.
source_set("end2end_mocks_sources") {
configs += [ "${dawn_root}/src/dawn/native:internal" ]
testonly = true
deps = [
":gmock_and_gtest",
"${dawn_root}/src/dawn/platform",
]
sources = [
"end2end/mocks/CachingInterfaceMock.cpp",
"end2end/mocks/CachingInterfaceMock.h",
]
}
source_set("end2end_tests_sources") {
testonly = true
deps = [
":end2end_mocks_sources",
":test_infra_sources",
"${dawn_root}/src/dawn:cpp",
"${dawn_root}/src/dawn:proc",
"${dawn_root}/src/dawn/common",
"${dawn_root}/src/dawn/native:headers",
"${dawn_root}/src/dawn/utils",
"${dawn_root}/src/dawn/wire",
]
sources = [
"end2end/AdapterDiscoveryTests.cpp", "end2end/AdapterDiscoveryTests.cpp",
"end2end/BasicTests.cpp", "end2end/BasicTests.cpp",
"end2end/BindGroupTests.cpp", "end2end/BindGroupTests.cpp",
@ -461,17 +505,12 @@ source_set("end2end_tests_sources") {
frameworks = [ "IOSurface.framework" ] frameworks = [ "IOSurface.framework" ]
} }
if (dawn_enable_opengl) {
assert(dawn_supports_glfw_for_windowing)
}
if (dawn_supports_glfw_for_windowing) { if (dawn_supports_glfw_for_windowing) {
sources += [ sources += [
"end2end/SwapChainTests.cpp", "end2end/SwapChainTests.cpp",
"end2end/SwapChainValidationTests.cpp", "end2end/SwapChainValidationTests.cpp",
"end2end/WindowSurfaceTests.cpp", "end2end/WindowSurfaceTests.cpp",
] ]
deps += [ "${dawn_root}/src/dawn/utils:glfw" ]
} }
if (dawn_enable_d3d12 || (dawn_enable_vulkan && is_chromeos) || if (dawn_enable_d3d12 || (dawn_enable_vulkan && is_chromeos) ||
@ -492,23 +531,19 @@ source_set("white_box_tests_sources") {
testonly = true testonly = true
deps = [ deps = [
":gmock_and_gtest", ":test_infra_sources",
"${dawn_root}/src/dawn:cpp", "${dawn_root}/src/dawn:cpp",
"${dawn_root}/src/dawn:proc", "${dawn_root}/src/dawn:proc",
"${dawn_root}/src/dawn/common", "${dawn_root}/src/dawn/common",
"${dawn_root}/src/dawn/native:sources",
# Statically linked because the end2end white_box tests use Dawn internals. # Statically linked and with sources because the end2end white_box tests use Dawn internals.
"${dawn_root}/src/dawn/native:sources",
"${dawn_root}/src/dawn/native:static", "${dawn_root}/src/dawn/native:static",
"${dawn_root}/src/dawn/utils", "${dawn_root}/src/dawn/utils",
"${dawn_root}/src/dawn/wire", "${dawn_root}/src/dawn/wire",
] ]
sources = [ sources = []
"DawnTest.h",
"ParamGenerator.h",
"ToggleParser.h",
]
if (dawn_enable_vulkan) { if (dawn_enable_vulkan) {
deps += [ "${dawn_vulkan_headers_dir}:vulkan_headers" ] deps += [ "${dawn_vulkan_headers_dir}:vulkan_headers" ]
@ -552,10 +587,6 @@ source_set("white_box_tests_sources") {
sources += [ "white_box/MetalAutoreleasePoolTests.mm" ] sources += [ "white_box/MetalAutoreleasePoolTests.mm" ]
} }
if (dawn_enable_opengl) {
deps += [ "${dawn_root}/src/dawn/utils:glfw" ]
}
if (dawn_enable_opengles && defined(dawn_angle_dir)) { if (dawn_enable_opengles && defined(dawn_angle_dir)) {
sources += [ "white_box/EGLImageWrappingTests.cpp" ] sources += [ "white_box/EGLImageWrappingTests.cpp" ]
deps += [ "${dawn_angle_dir}:libEGL" ] deps += [ "${dawn_angle_dir}:libEGL" ]
@ -567,22 +598,11 @@ source_set("white_box_tests_sources") {
dawn_test("dawn_end2end_tests") { dawn_test("dawn_end2end_tests") {
deps = [ deps = [
":end2end_tests_sources", ":end2end_tests_sources",
":gmock_and_gtest", ":test_infra_sources",
":white_box_tests_sources", ":white_box_tests_sources",
"${dawn_root}/src/dawn:cpp",
"${dawn_root}/src/dawn:proc",
"${dawn_root}/src/dawn/common",
"${dawn_root}/src/dawn/native:static",
"${dawn_root}/src/dawn/platform:platform",
"${dawn_root}/src/dawn/utils",
"${dawn_root}/src/dawn/wire",
]
sources = [
"DawnTest.cpp",
"DawnTest.h",
] ]
sources = []
libs = [] libs = []
# When building inside Chromium, use their gtest main function because it is # When building inside Chromium, use their gtest main function because it is
@ -593,10 +613,6 @@ dawn_test("dawn_end2end_tests") {
sources += [ "End2EndTestsMain.cpp" ] sources += [ "End2EndTestsMain.cpp" ]
} }
if (dawn_enable_opengl) {
deps += [ "${dawn_root}/src/dawn/utils:glfw" ]
}
if (is_chromeos) { if (is_chromeos) {
libs += [ "gbm" ] libs += [ "gbm" ]
} }
@ -608,22 +624,16 @@ dawn_test("dawn_end2end_tests") {
dawn_test("dawn_perf_tests") { dawn_test("dawn_perf_tests") {
deps = [ deps = [
":gmock_and_gtest", ":test_infra_sources",
"${dawn_root}/src/dawn:cpp", "${dawn_root}/src/dawn:cpp",
"${dawn_root}/src/dawn:proc", "${dawn_root}/src/dawn:proc",
"${dawn_root}/src/dawn/common", "${dawn_root}/src/dawn/common",
"${dawn_root}/src/dawn/native",
"${dawn_root}/src/dawn/platform", "${dawn_root}/src/dawn/platform",
"${dawn_root}/src/dawn/utils", "${dawn_root}/src/dawn/utils",
"${dawn_root}/src/dawn/wire", "${dawn_root}/src/dawn/wire",
] ]
sources = [ sources = [
"DawnTest.cpp",
"DawnTest.h",
"ParamGenerator.h",
"ToggleParser.cpp",
"ToggleParser.h",
"perf_tests/BufferUploadPerf.cpp", "perf_tests/BufferUploadPerf.cpp",
"perf_tests/DawnPerfTest.cpp", "perf_tests/DawnPerfTest.cpp",
"perf_tests/DawnPerfTest.h", "perf_tests/DawnPerfTest.h",
@ -648,8 +658,4 @@ dawn_test("dawn_perf_tests") {
if (dawn_enable_metal) { if (dawn_enable_metal) {
frameworks = [ "IOSurface.framework" ] frameworks = [ "IOSurface.framework" ]
} }
if (dawn_enable_opengl) {
deps += [ "${dawn_root}/src/dawn/utils:glfw" ]
}
} }

View File

@ -30,6 +30,8 @@
#include "dawn/common/Platform.h" #include "dawn/common/Platform.h"
#include "dawn/common/SystemUtils.h" #include "dawn/common/SystemUtils.h"
#include "dawn/dawn_proc.h" #include "dawn/dawn_proc.h"
#include "dawn/native/Instance.h"
#include "dawn/native/dawn_platform.h"
#include "dawn/utils/ComboRenderPipelineDescriptor.h" #include "dawn/utils/ComboRenderPipelineDescriptor.h"
#include "dawn/utils/PlatformDebugLogger.h" #include "dawn/utils/PlatformDebugLogger.h"
#include "dawn/utils/SystemUtils.h" #include "dawn/utils/SystemUtils.h"
@ -921,9 +923,11 @@ void DawnTestBase::SetUp() {
mBackendAdapter = *it; mBackendAdapter = *it;
} }
// Setup the per-test platform. Tests can provide one by overloading CreateTestPlatform. // Setup the per-test platform. Tests can provide one by overloading CreateTestPlatform. This is
// NOT a thread-safe operation and is allowed here for testing only.
mTestPlatform = CreateTestPlatform(); mTestPlatform = CreateTestPlatform();
gTestEnv->GetInstance()->SetPlatform(mTestPlatform.get()); dawn::native::FromAPI(gTestEnv->GetInstance()->Get())
->SetPlatformForTesting(mTestPlatform.get());
// Create the device from the adapter // Create the device from the adapter
for (const char* forceEnabledWorkaround : mParam.forceEnabledWorkarounds) { for (const char* forceEnabledWorkaround : mParam.forceEnabledWorkarounds) {

View File

@ -13,97 +13,30 @@
// limitations under the License. // limitations under the License.
#include <memory> #include <memory>
#include <string>
#include <unordered_map> #include <unordered_map>
#include <utility> #include <utility>
#include <vector>
#include "dawn/tests/DawnTest.h" #include "dawn/tests/DawnTest.h"
#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h"
#include "dawn/utils/ComboRenderPipelineDescriptor.h" #include "dawn/utils/ComboRenderPipelineDescriptor.h"
#include "dawn/utils/WGPUHelpers.h" #include "dawn/utils/WGPUHelpers.h"
#define EXPECT_CACHE_HIT(N, statement) \ namespace {
do { \ using ::testing::NiceMock;
size_t before = mPersistentCache.mHitCount; \ } // namespace
statement; \
FlushWire(); \
size_t after = mPersistentCache.mHitCount; \
EXPECT_EQ(N, after - before); \
} while (0)
// FakePersistentCache implements a in-memory persistent cache.
class FakePersistentCache : public dawn::platform::CachingInterface {
public:
// PersistentCache API
void StoreData(const WGPUDevice device,
const void* key,
size_t keySize,
const void* value,
size_t valueSize) override {
if (mIsDisabled)
return;
const std::string keyStr(reinterpret_cast<const char*>(key), keySize);
const uint8_t* value_start = reinterpret_cast<const uint8_t*>(value);
std::vector<uint8_t> entry_value(value_start, value_start + valueSize);
EXPECT_TRUE(mCache.insert({keyStr, std::move(entry_value)}).second);
}
size_t LoadData(const WGPUDevice device,
const void* key,
size_t keySize,
void* value,
size_t valueSize) override {
const std::string keyStr(reinterpret_cast<const char*>(key), keySize);
auto entry = mCache.find(keyStr);
if (entry == mCache.end()) {
return 0;
}
if (valueSize >= entry->second.size()) {
memcpy(value, entry->second.data(), entry->second.size());
}
mHitCount++;
return entry->second.size();
}
using Blob = std::vector<uint8_t>;
using FakeCache = std::unordered_map<std::string, Blob>;
FakeCache mCache;
size_t mHitCount = 0;
bool mIsDisabled = false;
};
// Test platform that only supports caching.
class DawnTestPlatform : public dawn::platform::Platform {
public:
explicit DawnTestPlatform(dawn::platform::CachingInterface* cachingInterface)
: mCachingInterface(cachingInterface) {
}
~DawnTestPlatform() override = default;
dawn::platform::CachingInterface* GetCachingInterface(const void* fingerprint,
size_t fingerprintSize) override {
return mCachingInterface;
}
dawn::platform::CachingInterface* mCachingInterface = nullptr;
};
class D3D12CachingTests : public DawnTest { class D3D12CachingTests : public DawnTest {
protected: protected:
std::unique_ptr<dawn::platform::Platform> CreateTestPlatform() override { std::unique_ptr<dawn::platform::Platform> CreateTestPlatform() override {
return std::make_unique<DawnTestPlatform>(&mPersistentCache); return std::make_unique<DawnCachingMockPlatform>(&mMockCache);
} }
FakePersistentCache mPersistentCache; NiceMock<CachingInterfaceMock> mMockCache;
}; };
// Test that duplicate WGSL still re-compiles HLSL even when the cache is not enabled. // Test that duplicate WGSL still works (and re-compiles HLSL) when the cache is not enabled.
TEST_P(D3D12CachingTests, SameShaderNoCache) { TEST_P(D3D12CachingTests, SameShaderNoCache) {
mPersistentCache.mIsDisabled = true; mMockCache.Disable();
wgpu::ShaderModule module = utils::CreateShaderModule(device, R"( wgpu::ShaderModule module = utils::CreateShaderModule(device, R"(
@stage(vertex) fn vertex_main() -> @builtin(position) vec4<f32> { @stage(vertex) fn vertex_main() -> @builtin(position) vec4<f32> {
@ -122,11 +55,9 @@ TEST_P(D3D12CachingTests, SameShaderNoCache) {
desc.vertex.entryPoint = "vertex_main"; desc.vertex.entryPoint = "vertex_main";
desc.cFragment.module = module; desc.cFragment.module = module;
desc.cFragment.entryPoint = "fragment_main"; desc.cFragment.entryPoint = "fragment_main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc));
EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 0u);
EXPECT_EQ(mPersistentCache.mCache.size(), 0u);
// Load the same WGSL shader from the cache. // Load the same WGSL shader from the cache.
{ {
@ -135,11 +66,9 @@ TEST_P(D3D12CachingTests, SameShaderNoCache) {
desc.vertex.entryPoint = "vertex_main"; desc.vertex.entryPoint = "vertex_main";
desc.cFragment.module = module; desc.cFragment.module = module;
desc.cFragment.entryPoint = "fragment_main"; desc.cFragment.entryPoint = "fragment_main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc));
EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 0u);
EXPECT_EQ(mPersistentCache.mCache.size(), 0u);
} }
// Test creating a pipeline from two entrypoints in multiple stages will cache the correct number // Test creating a pipeline from two entrypoints in multiple stages will cache the correct number
@ -163,11 +92,9 @@ TEST_P(D3D12CachingTests, ReuseShaderWithMultipleEntryPointsPerStage) {
desc.vertex.entryPoint = "vertex_main"; desc.vertex.entryPoint = "vertex_main";
desc.cFragment.module = module; desc.cFragment.module = module;
desc.cFragment.entryPoint = "fragment_main"; desc.cFragment.entryPoint = "fragment_main";
EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc));
EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 2u);
EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
// Load the same WGSL shader from the cache. // Load the same WGSL shader from the cache.
{ {
@ -176,13 +103,9 @@ TEST_P(D3D12CachingTests, ReuseShaderWithMultipleEntryPointsPerStage) {
desc.vertex.entryPoint = "vertex_main"; desc.vertex.entryPoint = "vertex_main";
desc.cFragment.module = module; desc.cFragment.module = module;
desc.cFragment.entryPoint = "fragment_main"; desc.cFragment.entryPoint = "fragment_main";
EXPECT_CACHE_HIT(mMockCache, 2u, device.CreateRenderPipeline(&desc));
// Cached HLSL shader calls LoadData twice (once to peek, again to get), so check 2 x
// kNumOfShaders hits.
EXPECT_CACHE_HIT(4u, device.CreateRenderPipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 2u);
EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
// Modify the WGSL shader functions and make sure it doesn't hit. // Modify the WGSL shader functions and make sure it doesn't hit.
wgpu::ShaderModule newModule = utils::CreateShaderModule(device, R"( wgpu::ShaderModule newModule = utils::CreateShaderModule(device, R"(
@ -201,12 +124,9 @@ TEST_P(D3D12CachingTests, ReuseShaderWithMultipleEntryPointsPerStage) {
desc.vertex.entryPoint = "vertex_main"; desc.vertex.entryPoint = "vertex_main";
desc.cFragment.module = newModule; desc.cFragment.module = newModule;
desc.cFragment.entryPoint = "fragment_main"; desc.cFragment.entryPoint = "fragment_main";
EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc)); EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateRenderPipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 4u);
// Cached HLSL shader calls LoadData twice (once to peek, again to get), so check 2 x
// kNumOfShaders hits.
EXPECT_EQ(mPersistentCache.mCache.size(), 4u);
} }
// Test creating a WGSL shader with two entrypoints in the same stage will cache the correct number // Test creating a WGSL shader with two entrypoints in the same stage will cache the correct number
@ -232,33 +152,26 @@ TEST_P(D3D12CachingTests, ReuseShaderWithMultipleEntryPoints) {
wgpu::ComputePipelineDescriptor desc; wgpu::ComputePipelineDescriptor desc;
desc.compute.module = module; desc.compute.module = module;
desc.compute.entryPoint = "write1"; desc.compute.entryPoint = "write1";
EXPECT_CACHE_HIT(0u, device.CreateComputePipeline(&desc)); EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateComputePipeline(&desc));
desc.compute.module = module; desc.compute.module = module;
desc.compute.entryPoint = "write42"; desc.compute.entryPoint = "write42";
EXPECT_CACHE_HIT(0u, device.CreateComputePipeline(&desc)); EXPECT_CACHE_HIT(mMockCache, 0u, device.CreateComputePipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 2u);
EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
// Load the same WGSL shader from the cache. // Load the same WGSL shader from the cache.
{ {
wgpu::ComputePipelineDescriptor desc; wgpu::ComputePipelineDescriptor desc;
desc.compute.module = module; desc.compute.module = module;
desc.compute.entryPoint = "write1"; desc.compute.entryPoint = "write1";
EXPECT_CACHE_HIT(mMockCache, 1u, device.CreateComputePipeline(&desc));
// Cached HLSL shader calls LoadData twice (once to peek, again to get), so check 2 x
// kNumOfShaders hits.
EXPECT_CACHE_HIT(2u, device.CreateComputePipeline(&desc));
desc.compute.module = module; desc.compute.module = module;
desc.compute.entryPoint = "write42"; desc.compute.entryPoint = "write42";
EXPECT_CACHE_HIT(mMockCache, 1u, device.CreateComputePipeline(&desc));
// Cached HLSL shader calls LoadData twice, so check 2 x kNumOfShaders hits.
EXPECT_CACHE_HIT(2u, device.CreateComputePipeline(&desc));
} }
EXPECT_EQ(mMockCache.GetNumEntries(), 2u);
EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
} }
DAWN_INSTANTIATE_TEST(D3D12CachingTests, D3D12Backend()); DAWN_INSTANTIATE_TEST(D3D12CachingTests, D3D12Backend());

View File

@ -0,0 +1,89 @@
// Copyright 2022 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dawn/tests/end2end/mocks/CachingInterfaceMock.h"
using ::testing::Invoke;
CachingInterfaceMock::CachingInterfaceMock() {
ON_CALL(*this, LoadData).WillByDefault(Invoke([=](auto&&... args) {
return LoadDataDefault(args...);
}));
ON_CALL(*this, StoreData).WillByDefault(Invoke([=](auto&&... args) {
return StoreDataDefault(args...);
}));
}
void CachingInterfaceMock::Enable() {
mEnabled = true;
}
void CachingInterfaceMock::Disable() {
mEnabled = false;
}
size_t CachingInterfaceMock::GetHitCount() const {
return mHitCount;
}
size_t CachingInterfaceMock::GetNumEntries() const {
return mCache.size();
}
size_t CachingInterfaceMock::LoadDataDefault(const WGPUDevice device,
const void* key,
size_t keySize,
void* value,
size_t valueSize) {
if (!mEnabled) {
return 0;
}
const std::string keyStr(reinterpret_cast<const char*>(key), keySize);
auto entry = mCache.find(keyStr);
if (entry == mCache.end()) {
return 0;
}
if (valueSize >= entry->second.size()) {
// Only consider a cache-hit on the memcpy, since peeks are implementation detail.
memcpy(value, entry->second.data(), entry->second.size());
mHitCount++;
}
return entry->second.size();
}
void CachingInterfaceMock::StoreDataDefault(const WGPUDevice device,
const void* key,
size_t keySize,
const void* value,
size_t valueSize) {
if (!mEnabled) {
return;
}
const std::string keyStr(reinterpret_cast<const char*>(key), keySize);
const uint8_t* it = reinterpret_cast<const uint8_t*>(value);
std::vector<uint8_t> entry(it, it + valueSize);
mCache.insert_or_assign(keyStr, entry);
}
DawnCachingMockPlatform::DawnCachingMockPlatform(dawn::platform::CachingInterface* cachingInterface)
: mCachingInterface(cachingInterface) {
}
dawn::platform::CachingInterface* DawnCachingMockPlatform::GetCachingInterface(
const void* fingerprint,
size_t fingerprintSize) {
return mCachingInterface;
}

View File

@ -0,0 +1,87 @@
// Copyright 2022 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_
#define SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_
#include <dawn/platform/DawnPlatform.h>
#include <gmock/gmock.h>
#include <string>
#include <unordered_map>
#include <vector>
#define EXPECT_CACHE_HIT(cache, N, statement) \
do { \
FlushWire(); \
size_t before = cache.GetHitCount(); \
statement; \
FlushWire(); \
size_t after = cache.GetHitCount(); \
EXPECT_EQ(N, after - before); \
} while (0)
// A mock caching interface class that also supplies an in memory cache for testing.
class CachingInterfaceMock : public dawn::platform::CachingInterface {
public:
CachingInterfaceMock();
// Toggles to disable/enable caching.
void Disable();
void Enable();
// Returns the number of cache hits up to this point.
size_t GetHitCount() const;
// Returns the number of entries in the cache.
size_t GetNumEntries() const;
MOCK_METHOD(size_t,
LoadData,
(const WGPUDevice, const void*, size_t, void*, size_t),
(override));
MOCK_METHOD(void,
StoreData,
(const WGPUDevice, const void*, size_t, const void*, size_t),
(override));
private:
size_t LoadDataDefault(const WGPUDevice device,
const void* key,
size_t keySize,
void* value,
size_t valueSize);
void StoreDataDefault(const WGPUDevice device,
const void* key,
size_t keySize,
const void* value,
size_t valueSize);
bool mEnabled = true;
size_t mHitCount = 0;
std::unordered_map<std::string, std::vector<uint8_t>> mCache;
};
// Dawn platform used for testing with a mock caching interface.
class DawnCachingMockPlatform : public dawn::platform::Platform {
public:
explicit DawnCachingMockPlatform(dawn::platform::CachingInterface* cachingInterface);
dawn::platform::CachingInterface* GetCachingInterface(const void* fingerprint,
size_t fingerprintSize) override;
private:
dawn::platform::CachingInterface* mCachingInterface = nullptr;
};
#endif // SRC_DAWN_TESTS_END2END_MOCKS_CACHINGINTERFACEMOCK_H_