D3D12: Support caching DX shaders.

This change is a prerequisite to D3D pipeline caching.

This change introduces:
- Caching interface which enables the cache.
- Helper for backends to load/store blobs to be cached.
- Ability to cache HLSL shaders.

Bug:dawn:549
Change-Id: I2af759882d18b3f45dc63e49dcb6a3caa1be3485
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/32305
Commit-Queue: Bryan Bernhart <bryan.bernhart@intel.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
This commit is contained in:
Bryan Bernhart 2020-11-20 20:38:37 +00:00 committed by Commit Bot service account
parent cf89a68f46
commit 41b3f9c1e4
18 changed files with 683 additions and 89 deletions

View File

@ -223,6 +223,8 @@ source_set("dawn_native_sources") {
"PassResourceUsageTracker.h",
"PerStage.cpp",
"PerStage.h",
"PersistentCache.cpp",
"PersistentCache.h",
"Pipeline.cpp",
"Pipeline.h",
"PipelineLayout.cpp",

View File

@ -108,6 +108,8 @@ target_sources(dawn_native PRIVATE
"PassResourceUsage.h"
"PassResourceUsageTracker.cpp"
"PassResourceUsageTracker.h"
"PersistentCache.cpp"
"PersistentCache.h"
"PerStage.cpp"
"PerStage.h"
"Pipeline.cpp"

View File

@ -31,6 +31,7 @@
#include "dawn_native/Fence.h"
#include "dawn_native/Instance.h"
#include "dawn_native/InternalPipelineStore.h"
#include "dawn_native/PersistentCache.h"
#include "dawn_native/PipelineLayout.h"
#include "dawn_native/QuerySet.h"
#include "dawn_native/Queue.h"
@ -132,6 +133,7 @@ namespace dawn_native {
mCreateReadyPipelineTracker = std::make_unique<CreateReadyPipelineTracker>(this);
mDeprecationWarnings = std::make_unique<DeprecationWarnings>();
mInternalPipelineStore = std::make_unique<InternalPipelineStore>();
mPersistentCache = std::make_unique<PersistentCache>(this);
// Starting from now the backend can start doing reentrant calls so the device is marked as
// alive.
@ -196,6 +198,7 @@ namespace dawn_native {
mErrorScopeTracker = nullptr;
mDynamicUploader = nullptr;
mCreateReadyPipelineTracker = nullptr;
mPersistentCache = nullptr;
mEmptyBindGroupLayout = nullptr;
@ -299,6 +302,11 @@ namespace dawn_native {
return mCurrentErrorScope.Get();
}
PersistentCache* DeviceBase::GetPersistentCache() {
ASSERT(mPersistentCache.get() != nullptr);
return mPersistentCache.get();
}
MaybeError DeviceBase::ValidateObject(const ObjectBase* object) const {
ASSERT(object != nullptr);
if (DAWN_UNLIKELY(object->GetDevice() != this)) {

View File

@ -37,6 +37,7 @@ namespace dawn_native {
class DynamicUploader;
class ErrorScope;
class ErrorScopeTracker;
class PersistentCache;
class StagingBufferBase;
struct InternalPipelineStore;
@ -180,6 +181,8 @@ namespace dawn_native {
ErrorScope* GetCurrentErrorScope();
PersistentCache* GetPersistentCache();
void Reference();
void Release();
@ -388,6 +391,8 @@ namespace dawn_native {
ExtensionsSet mEnabledExtensions;
std::unique_ptr<InternalPipelineStore> mInternalPipelineStore;
std::unique_ptr<PersistentCache> mPersistentCache;
};
} // namespace dawn_native

View File

@ -0,0 +1,65 @@
// Copyright 2020 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dawn_native/PersistentCache.h"
#include "common/Assert.h"
#include "dawn_native/Device.h"
#include "dawn_platform/DawnPlatform.h"
namespace dawn_native {
PersistentCache::PersistentCache(DeviceBase* device)
: mDevice(device), mCache(GetPlatformCache()) {
}
ScopedCachedBlob PersistentCache::LoadData(const PersistentCacheKey& key) {
ScopedCachedBlob blob = {};
if (mCache == nullptr) {
return blob;
}
blob.bufferSize = mCache->LoadData(reinterpret_cast<WGPUDevice>(mDevice), key.data(),
key.size(), nullptr, 0);
if (blob.bufferSize > 0) {
blob.buffer.reset(new uint8_t[blob.bufferSize]);
const size_t bufferSize =
mCache->LoadData(reinterpret_cast<WGPUDevice>(mDevice), key.data(), key.size(),
blob.buffer.get(), blob.bufferSize);
ASSERT(bufferSize == blob.bufferSize);
return blob;
}
return blob;
}
void PersistentCache::StoreData(const PersistentCacheKey& key, const void* value, size_t size) {
if (mCache == nullptr) {
return;
}
ASSERT(value != nullptr);
ASSERT(size > 0);
mCache->StoreData(reinterpret_cast<WGPUDevice>(mDevice), key.data(), key.size(), value,
size);
}
dawn_platform::CachingInterface* PersistentCache::GetPlatformCache() {
// TODO(dawn:549): Create a fingerprint of concatenated version strings (ex. Tint commit
// hash, Dawn commit hash). This will be used by the client so it may know when to discard
// previously cached Dawn objects should this fingerprint change.
dawn_platform::Platform* platform = mDevice->GetPlatform();
if (platform != nullptr) {
return platform->GetCachingInterface(/*fingerprint*/ nullptr, /*fingerprintSize*/ 0);
}
return nullptr;
}
} // namespace dawn_native

View File

@ -0,0 +1,86 @@
// Copyright 2020 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef DAWNNATIVE_PERSISTENTCACHE_H_
#define DAWNNATIVE_PERSISTENTCACHE_H_
#include "dawn_native/Error.h"
#include <vector>
namespace dawn_platform {
class CachingInterface;
}
namespace dawn_native {
using PersistentCacheKey = std::vector<uint8_t>;
struct ScopedCachedBlob {
std::unique_ptr<uint8_t[]> buffer;
size_t bufferSize = 0;
};
class DeviceBase;
enum class PersistentKeyType { Shader };
class PersistentCache {
public:
PersistentCache(DeviceBase* device);
// Combines load/store operations into a single call.
// If the load was successful, a non-empty blob is returned to the caller.
// Else, the creation callback |createFn| gets invoked with a callback
// |doCache| to store the newly created blob back in the cache.
//
// Example usage:
//
// ScopedCachedBlob cachedBlob = {};
// DAWN_TRY_ASSIGN(cachedBlob, GetOrCreate(key, [&](auto doCache)) {
// // Create a new blob to be stored
// doCache(newBlobPtr, newBlobSize); // store
// }));
//
template <typename CreateFn>
ResultOrError<ScopedCachedBlob> GetOrCreate(const PersistentCacheKey& key,
CreateFn&& createFn) {
// Attempt to load an existing blob from the cache.
ScopedCachedBlob blob = LoadData(key);
if (blob.bufferSize > 0) {
return std::move(blob);
}
// Allow the caller to create a new blob to be stored for the given key.
DAWN_TRY(createFn([this, key](const void* value, size_t size) {
this->StoreData(key, value, size);
}));
return std::move(blob);
}
private:
// PersistentCache impl
ScopedCachedBlob LoadData(const PersistentCacheKey& key);
void StoreData(const PersistentCacheKey& key, const void* value, size_t size);
dawn_platform::CachingInterface* GetPlatformCache();
DeviceBase* mDevice = nullptr;
dawn_platform::CachingInterface* mCache = nullptr;
};
} // namespace dawn_native
#endif // DAWNNATIVE_PERSISTENTCACHE_H_

View File

@ -43,44 +43,14 @@ namespace dawn_native { namespace d3d12 {
ShaderModule* module = ToBackend(descriptor->computeStage.module);
const char* entryPoint = descriptor->computeStage.entryPoint;
std::string remappedEntryPoint;
std::string hlslSource;
if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
DAWN_TRY_ASSIGN(hlslSource, module->TranslateToHLSLWithTint(
entryPoint, SingleShaderStage::Compute,
ToBackend(GetLayout()), &remappedEntryPoint));
entryPoint = remappedEntryPoint.c_str();
} else {
DAWN_TRY_ASSIGN(hlslSource,
module->TranslateToHLSLWithSPIRVCross(
entryPoint, SingleShaderStage::Compute, ToBackend(GetLayout())));
// Note that the HLSL will always use entryPoint "main" under SPIRV-cross.
entryPoint = "main";
}
D3D12_COMPUTE_PIPELINE_STATE_DESC d3dDesc = {};
d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature();
ComPtr<IDxcBlob> compiledDXCShader;
ComPtr<ID3DBlob> compiledFXCShader;
if (device->IsToggleEnabled(Toggle::UseDXC)) {
DAWN_TRY_ASSIGN(compiledDXCShader,
CompileShaderDXC(device, SingleShaderStage::Compute, hlslSource,
entryPoint, compileFlags));
d3dDesc.CS.pShaderBytecode = compiledDXCShader->GetBufferPointer();
d3dDesc.CS.BytecodeLength = compiledDXCShader->GetBufferSize();
} else {
DAWN_TRY_ASSIGN(compiledFXCShader,
CompileShaderFXC(device, SingleShaderStage::Compute, hlslSource,
entryPoint, compileFlags));
d3dDesc.CS.pShaderBytecode = compiledFXCShader->GetBufferPointer();
d3dDesc.CS.BytecodeLength = compiledFXCShader->GetBufferSize();
}
CompiledShader compiledShader;
DAWN_TRY_ASSIGN(compiledShader, module->Compile(descriptor->computeStage.entryPoint,
SingleShaderStage::Compute,
ToBackend(GetLayout()), compileFlags));
d3dDesc.CS = compiledShader.GetD3D12ShaderBytecode();
device->GetD3D12Device()->CreateComputePipelineState(&d3dDesc,
IID_PPV_ARGS(&mPipelineState));
return {};

View File

@ -306,44 +306,13 @@ namespace dawn_native { namespace d3d12 {
shaders[SingleShaderStage::Vertex] = &descriptorD3D12.VS;
shaders[SingleShaderStage::Fragment] = &descriptorD3D12.PS;
PerStage<ComPtr<ID3DBlob>> compiledFXCShader;
PerStage<ComPtr<IDxcBlob>> compiledDXCShader;
PerStage<CompiledShader> compiledShader;
wgpu::ShaderStage renderStages = wgpu::ShaderStage::Vertex | wgpu::ShaderStage::Fragment;
for (auto stage : IterateStages(renderStages)) {
std::string hlslSource;
const char* entryPoint = GetStage(stage).entryPoint.c_str();
std::string remappedEntryPoint;
if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
DAWN_TRY_ASSIGN(hlslSource, modules[stage]->TranslateToHLSLWithTint(
entryPoint, stage, ToBackend(GetLayout()),
&remappedEntryPoint));
entryPoint = remappedEntryPoint.c_str();
} else {
DAWN_TRY_ASSIGN(hlslSource, modules[stage]->TranslateToHLSLWithSPIRVCross(
entryPoint, stage, ToBackend(GetLayout())));
// Note that the HLSL will always use entryPoint "main" under SPIRV-cross.
entryPoint = "main";
}
if (device->IsToggleEnabled(Toggle::UseDXC)) {
DAWN_TRY_ASSIGN(
compiledDXCShader[stage],
CompileShaderDXC(device, stage, hlslSource, entryPoint, compileFlags));
shaders[stage]->pShaderBytecode = compiledDXCShader[stage]->GetBufferPointer();
shaders[stage]->BytecodeLength = compiledDXCShader[stage]->GetBufferSize();
} else {
DAWN_TRY_ASSIGN(
compiledFXCShader[stage],
CompileShaderFXC(device, stage, hlslSource, entryPoint, compileFlags));
shaders[stage]->pShaderBytecode = compiledFXCShader[stage]->GetBufferPointer();
shaders[stage]->BytecodeLength = compiledFXCShader[stage]->GetBufferSize();
}
DAWN_TRY_ASSIGN(compiledShader[stage],
modules[stage]->Compile(entryPoints[stage], stage,
ToBackend(GetLayout()), compileFlags));
*shaders[stage] = compiledShader[stage].GetD3D12ShaderBytecode();
}
PipelineLayout* layout = ToBackend(GetLayout());

View File

@ -309,4 +309,102 @@ namespace dawn_native { namespace d3d12 {
return compiler.compile();
}
ResultOrError<CompiledShader> ShaderModule::Compile(const char* entryPointName,
SingleShaderStage stage,
PipelineLayout* layout,
uint32_t compileFlags) {
Device* device = ToBackend(GetDevice());
// Compile the source shader to HLSL.
std::string hlslSource;
std::string remappedEntryPoint;
if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
DAWN_TRY_ASSIGN(hlslSource, TranslateToHLSLWithTint(entryPointName, stage, layout,
&remappedEntryPoint));
entryPointName = remappedEntryPoint.c_str();
} else {
DAWN_TRY_ASSIGN(hlslSource,
TranslateToHLSLWithSPIRVCross(entryPointName, stage, layout));
// Note that the HLSL will always use entryPoint "main" under
// SPIRV-cross.
entryPointName = "main";
}
// Use HLSL source as the input for the key since it does need to know about the pipeline
// layout. The pipeline layout is only required if we key from WGSL: two different pipeline
// layouts could be used to produce different shader blobs and the wrong shader blob could
// be loaded since the pipeline layout was missing from the key.
// TODO(dawn:549): Consider keying from WGSL and serialize the pipeline layout it used.
const PersistentCacheKey& shaderCacheKey =
CreateHLSLKey(entryPointName, stage, hlslSource, compileFlags);
CompiledShader compiledShader = {};
DAWN_TRY_ASSIGN(compiledShader.cachedShader,
device->GetPersistentCache()->GetOrCreate(
shaderCacheKey, [&](auto doCache) -> MaybeError {
if (device->IsToggleEnabled(Toggle::UseDXC)) {
DAWN_TRY_ASSIGN(compiledShader.compiledDXCShader,
CompileShaderDXC(device, stage, hlslSource,
entryPointName, compileFlags));
} else {
DAWN_TRY_ASSIGN(compiledShader.compiledFXCShader,
CompileShaderFXC(device, stage, hlslSource,
entryPointName, compileFlags));
}
const D3D12_SHADER_BYTECODE shader =
compiledShader.GetD3D12ShaderBytecode();
doCache(shader.pShaderBytecode, shader.BytecodeLength);
return {};
}));
return std::move(compiledShader);
}
D3D12_SHADER_BYTECODE CompiledShader::GetD3D12ShaderBytecode() const {
if (cachedShader.buffer != nullptr) {
return {cachedShader.buffer.get(), cachedShader.bufferSize};
} else if (compiledFXCShader != nullptr) {
return {compiledFXCShader->GetBufferPointer(), compiledFXCShader->GetBufferSize()};
} else if (compiledDXCShader != nullptr) {
return {compiledDXCShader->GetBufferPointer(), compiledDXCShader->GetBufferSize()};
}
UNREACHABLE();
return {};
}
PersistentCacheKey ShaderModule::CreateHLSLKey(const char* entryPointName,
SingleShaderStage stage,
const std::string& hlslSource,
uint32_t compileFlags) const {
std::stringstream stream;
// Prefix the key with the type to avoid collisions from another type that could have the
// same key.
stream << static_cast<uint32_t>(PersistentKeyType::Shader);
// Provide "guard" strings that the user cannot provide to help ensure the generated HLSL
// used to create this key is not being manufactured by the user to load the wrong shader
// blob.
// These strings can be HLSL comments because Tint does not emit HLSL comments.
// TODO(dawn:549): Replace guards strings with something more secure.
ASSERT(hlslSource.find("//") == std::string::npos);
stream << "// Start shader autogenerated by Dawn.";
stream << hlslSource;
stream << "// End of shader autogenerated by Dawn.";
stream << compileFlags;
// TODO(dawn:549): add the HLSL compiler version for good measure.
// If the source contains multiple entry points, ensure they are cached seperately
// per stage since DX shader code can only be compiled per stage using the same
// entry point.
stream << static_cast<uint32_t>(stage);
stream << entryPointName;
return PersistentCacheKey(std::istreambuf_iterator<char>{stream},
std::istreambuf_iterator<char>{});
}
}} // namespace dawn_native::d3d12

View File

@ -15,6 +15,7 @@
#ifndef DAWNNATIVE_D3D12_SHADERMODULED3D12_H_
#define DAWNNATIVE_D3D12_SHADERMODULED3D12_H_
#include "dawn_native/PersistentCache.h"
#include "dawn_native/ShaderModule.h"
#include "dawn_native/d3d12/d3d12_platform.h"
@ -24,22 +25,28 @@ namespace dawn_native { namespace d3d12 {
class Device;
class PipelineLayout;
ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(Device* device,
SingleShaderStage stage,
const std::string& hlslSource,
const char* entryPoint,
uint32_t compileFlags);
ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(Device* device,
SingleShaderStage stage,
const std::string& hlslSource,
const char* entryPoint,
uint32_t compileFlags);
// Manages a ref to one of the various representations of shader blobs.
struct CompiledShader {
ScopedCachedBlob cachedShader;
ComPtr<ID3DBlob> compiledFXCShader;
ComPtr<IDxcBlob> compiledDXCShader;
D3D12_SHADER_BYTECODE GetD3D12ShaderBytecode() const;
};
class ShaderModule final : public ShaderModuleBase {
public:
static ResultOrError<ShaderModule*> Create(Device* device,
const ShaderModuleDescriptor* descriptor);
ResultOrError<CompiledShader> Compile(const char* entryPointName,
SingleShaderStage stage,
PipelineLayout* layout,
uint32_t compileFlags);
private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override = default;
ResultOrError<std::string> TranslateToHLSLWithTint(
const char* entryPointName,
SingleShaderStage stage,
@ -50,9 +57,10 @@ namespace dawn_native { namespace d3d12 {
SingleShaderStage stage,
PipelineLayout* layout) const;
private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override = default;
PersistentCacheKey CreateHLSLKey(const char* entryPointName,
SingleShaderStage stage,
const std::string& hlslSource,
uint32_t compileFlags) const;
};
}} // namespace dawn_native::d3d12

View File

@ -31,4 +31,9 @@ dawn_component("dawn_platform") {
]
deps = [ "${dawn_root}/src/common" ]
public_deps = [
# DawnPlatform.h has #include <dawn/webgpu.h>
"${dawn_root}/src/dawn:dawn_headers",
]
}

View File

@ -27,4 +27,4 @@ target_sources(dawn_platform PRIVATE
"tracing/EventTracer.h"
"tracing/TraceEvent.h"
)
target_link_libraries(dawn_platform PRIVATE dawn_internal_config dawn_common)
target_link_libraries(dawn_platform PUBLIC dawn_headers PRIVATE dawn_internal_config dawn_common)

View File

@ -14,10 +14,45 @@
#include "dawn_platform/DawnPlatform.h"
#include "common/Assert.h"
namespace dawn_platform {
CachingInterface::CachingInterface() = default;
CachingInterface::~CachingInterface() = default;
Platform::Platform() = default;
Platform::~Platform() = default;
const unsigned char* Platform::GetTraceCategoryEnabledFlag(TraceCategory category) {
static unsigned char disabled = 0;
return &disabled;
}
double Platform::MonotonicallyIncreasingTime() {
return 0;
}
uint64_t Platform::AddTraceEvent(char phase,
const unsigned char* categoryGroupEnabled,
const char* name,
uint64_t id,
double timestamp,
int numArgs,
const char** argNames,
const unsigned char* argTypes,
const uint64_t* argValues,
unsigned char flags) {
// AddTraceEvent cannot be called if events are disabled.
ASSERT(false);
return 0;
}
dawn_platform::CachingInterface* Platform::GetCachingInterface(const void* fingerprint,
size_t fingerprintSize) {
return nullptr;
}
} // namespace dawn_platform

View File

@ -17,8 +17,11 @@
#include "dawn_platform/dawn_platform_export.h"
#include <cstddef>
#include <cstdint>
#include <dawn/webgpu.h>
namespace dawn_platform {
enum class TraceCategory {
@ -28,14 +31,43 @@ namespace dawn_platform {
GPUWork, // Actual GPU work
};
class DAWN_PLATFORM_EXPORT CachingInterface {
public:
CachingInterface();
virtual ~CachingInterface();
// LoadData has two modes. The first mode is used to get a value which
// corresponds to the |key|. The |valueOut| is a caller provided buffer
// allocated to the size |valueSize| which is loaded with data of the
// size returned. The second mode is used to query for the existence of
// the |key| where |valueOut| is nullptr and |valueSize| must be 0.
// The return size is non-zero if the |key| exists.
virtual size_t LoadData(const WGPUDevice device,
const void* key,
size_t keySize,
void* valueOut,
size_t valueSize) = 0;
// StoreData puts a |value| in the cache which corresponds to the |key|.
virtual void StoreData(const WGPUDevice device,
const void* key,
size_t keySize,
const void* value,
size_t valueSize) = 0;
private:
CachingInterface(const CachingInterface&) = delete;
CachingInterface& operator=(const CachingInterface&) = delete;
};
class DAWN_PLATFORM_EXPORT Platform {
public:
Platform();
virtual ~Platform();
virtual const unsigned char* GetTraceCategoryEnabledFlag(TraceCategory category) = 0;
virtual const unsigned char* GetTraceCategoryEnabledFlag(TraceCategory category);
virtual double MonotonicallyIncreasingTime() = 0;
virtual double MonotonicallyIncreasingTime();
virtual uint64_t AddTraceEvent(char phase,
const unsigned char* categoryGroupEnabled,
@ -46,7 +78,13 @@ namespace dawn_platform {
const char** argNames,
const unsigned char* argTypes,
const uint64_t* argValues,
unsigned char flags) = 0;
unsigned char flags);
// The |fingerprint| is provided by Dawn to inform the client to discard the Dawn caches
// when the fingerprint changes. The returned CachingInterface is expected to outlive the
// device which uses it to persistently cache objects.
virtual CachingInterface* GetCachingInterface(const void* fingerprint,
size_t fingerprintSize);
private:
Platform(const Platform&) = delete;

View File

@ -333,7 +333,10 @@ source_set("dawn_end2end_tests_sources") {
libs = []
if (dawn_enable_d3d12) {
sources += [ "end2end/D3D12ResourceWrappingTests.cpp" ]
sources += [
"end2end/D3D12CachingTests.cpp",
"end2end/D3D12ResourceWrappingTests.cpp",
]
libs += [
"d3d11.lib",
"dxgi.lib",

View File

@ -744,6 +744,10 @@ void DawnTestBase::SetUp() {
mBackendAdapter = *it;
}
// Setup the per-test platform. Tests can provide one by overloading CreateTestPlatform.
mTestPlatform = CreateTestPlatform();
gTestEnv->GetInstance()->SetPlatform(mTestPlatform.get());
// Create the device from the adapter
for (const char* forceEnabledWorkaround : mParam.forceEnabledWorkarounds) {
ASSERT(gTestEnv->GetInstance()->GetToggleInfo(forceEnabledWorkaround) != nullptr);
@ -1080,6 +1084,10 @@ void DawnTestBase::ResolveExpectations() {
}
}
std::unique_ptr<dawn_platform::Platform> DawnTestBase::CreateTestPlatform() {
return nullptr;
}
bool RGBA8::operator==(const RGBA8& other) const {
return r == other.r && g == other.g && b == other.b && a == other.a;
}

View File

@ -20,6 +20,7 @@
#include "dawn/webgpu_cpp.h"
#include "dawn_native/DawnNative.h"
#include <dawn_platform/DawnPlatform.h>
#include <gtest/gtest.h>
#include <memory>
@ -268,6 +269,8 @@ class DawnTestBase {
wgpu::Instance GetInstance() const;
dawn_native::Adapter GetAdapter() const;
virtual std::unique_ptr<dawn_platform::Platform> CreateTestPlatform();
protected:
wgpu::Device device;
wgpu::Queue queue;
@ -403,6 +406,8 @@ class DawnTestBase {
void ResolveExpectations();
dawn_native::Adapter mBackendAdapter;
std::unique_ptr<dawn_platform::Platform> mTestPlatform;
};
// Skip a test when the given condition is satisfied.

View File

@ -0,0 +1,287 @@
// Copyright 2020 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "tests/DawnTest.h"
#include "utils/ComboRenderPipelineDescriptor.h"
#include "utils/WGPUHelpers.h"
#define EXPECT_CACHE_HIT(N, statement) \
do { \
size_t before = mPersistentCache.mHitCount; \
statement; \
FlushWire(); \
size_t after = mPersistentCache.mHitCount; \
EXPECT_EQ(N, after - before); \
} while (0)
// FakePersistentCache implements a in-memory persistent cache.
class FakePersistentCache : public dawn_platform::CachingInterface {
public:
// PersistentCache API
void StoreData(const WGPUDevice device,
const void* key,
size_t keySize,
const void* value,
size_t valueSize) override {
if (mIsDisabled)
return;
const std::string keyStr(reinterpret_cast<const char*>(key), keySize);
const uint8_t* value_start = reinterpret_cast<const uint8_t*>(value);
std::vector<uint8_t> entry_value(value_start, value_start + valueSize);
EXPECT_TRUE(mCache.insert({keyStr, std::move(entry_value)}).second);
}
size_t LoadData(const WGPUDevice device,
const void* key,
size_t keySize,
void* value,
size_t valueSize) override {
const std::string keyStr(reinterpret_cast<const char*>(key), keySize);
auto entry = mCache.find(keyStr);
if (entry == mCache.end()) {
return 0;
}
if (valueSize >= entry->second.size()) {
memcpy(value, entry->second.data(), entry->second.size());
}
mHitCount++;
return entry->second.size();
}
using Blob = std::vector<uint8_t>;
using FakeCache = std::unordered_map<std::string, Blob>;
FakeCache mCache;
size_t mHitCount = 0;
bool mIsDisabled = false;
};
// Test platform that only supports caching.
class DawnTestPlatform : public dawn_platform::Platform {
public:
DawnTestPlatform(dawn_platform::CachingInterface* cachingInterface)
: mCachingInterface(cachingInterface) {
}
~DawnTestPlatform() override = default;
dawn_platform::CachingInterface* GetCachingInterface(const void* fingerprint,
size_t fingerprintSize) override {
return mCachingInterface;
}
dawn_platform::CachingInterface* mCachingInterface = nullptr;
};
class D3D12CachingTests : public DawnTest {
protected:
std::unique_ptr<dawn_platform::Platform> CreateTestPlatform() override {
return std::make_unique<DawnTestPlatform>(&mPersistentCache);
}
FakePersistentCache mPersistentCache;
};
// Test that duplicate WGSL still re-compiles HLSL even when the cache is not enabled.
TEST_P(D3D12CachingTests, SameShaderNoCache) {
mPersistentCache.mIsDisabled = true;
wgpu::ShaderModule module = utils::CreateShaderModuleFromWGSL(device, R"(
[[builtin(position)]] var<out> Position : vec4<f32>;
[[stage(vertex)]]
fn vertex_main() -> void {
Position = vec4<f32>(0.0, 0.0, 0.0, 1.0);
return;
}
[[location(0)]] var<out> outColor : vec4<f32>;
[[stage(fragment)]]
fn fragment_main() -> void {
outColor = vec4<f32>(1.0, 0.0, 0.0, 1.0);
return;
}
)");
// Store the WGSL shader into the cache.
{
utils::ComboRenderPipelineDescriptor desc(device);
desc.vertexStage.module = module;
desc.vertexStage.entryPoint = "vertex_main";
desc.cFragmentStage.module = module;
desc.cFragmentStage.entryPoint = "fragment_main";
EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
}
EXPECT_EQ(mPersistentCache.mCache.size(), 0u);
// Load the same WGSL shader from the cache.
{
utils::ComboRenderPipelineDescriptor desc(device);
desc.vertexStage.module = module;
desc.vertexStage.entryPoint = "vertex_main";
desc.cFragmentStage.module = module;
desc.cFragmentStage.entryPoint = "fragment_main";
EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
}
EXPECT_EQ(mPersistentCache.mCache.size(), 0u);
}
// Test creating a pipeline from two entrypoints in multiple stages will cache the correct number
// of HLSL shaders. WGSL shader should result into caching 2 HLSL shaders (stage x
// entrypoints)
TEST_P(D3D12CachingTests, ReuseShaderWithMultipleEntryPointsPerStage) {
wgpu::ShaderModule module = utils::CreateShaderModuleFromWGSL(device, R"(
[[builtin(position)]] var<out> Position : vec4<f32>;
[[stage(vertex)]]
fn vertex_main() -> void {
Position = vec4<f32>(0.0, 0.0, 0.0, 1.0);
return;
}
[[location(0)]] var<out> outColor : vec4<f32>;
[[stage(fragment)]]
fn fragment_main() -> void {
outColor = vec4<f32>(1.0, 0.0, 0.0, 1.0);
return;
}
)");
// Store the WGSL shader into the cache.
{
utils::ComboRenderPipelineDescriptor desc(device);
desc.vertexStage.module = module;
desc.vertexStage.entryPoint = "vertex_main";
desc.cFragmentStage.module = module;
desc.cFragmentStage.entryPoint = "fragment_main";
EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
}
EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
// Load the same WGSL shader from the cache.
{
utils::ComboRenderPipelineDescriptor desc(device);
desc.vertexStage.module = module;
desc.vertexStage.entryPoint = "vertex_main";
desc.cFragmentStage.module = module;
desc.cFragmentStage.entryPoint = "fragment_main";
// Cached HLSL shader calls LoadData twice (once to peek, again to get), so check 2 x
// kNumOfShaders hits.
EXPECT_CACHE_HIT(4u, device.CreateRenderPipeline(&desc));
}
EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
// Modify the WGSL shader functions and make sure it doesn't hit.
wgpu::ShaderModule newModule = utils::CreateShaderModuleFromWGSL(device, R"(
[[builtin(position)]] var<out> Position : vec4<f32>;
[[stage(vertex)]]
fn vertex_main() -> void {
Position = vec4<f32>(1.0, 1.0, 1.0, 1.0);
return;
}
[[location(0)]] var<out> outColor : vec4<f32>;
[[stage(fragment)]]
fn fragment_main() -> void {
outColor = vec4<f32>(1.0, 1.0, 1.0, 1.0);
return;
}
)");
{
utils::ComboRenderPipelineDescriptor desc(device);
desc.vertexStage.module = newModule;
desc.vertexStage.entryPoint = "vertex_main";
desc.cFragmentStage.module = newModule;
desc.cFragmentStage.entryPoint = "fragment_main";
EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
}
// Cached HLSL shader calls LoadData twice (once to peek, again to get), so check 2 x
// kNumOfShaders hits.
EXPECT_EQ(mPersistentCache.mCache.size(), 4u);
}
// Test creating a WGSL shader with two entrypoints in the same stage will cache the correct number
// of HLSL shaders. WGSL shader should result into caching 1 HLSL shader (stage x entrypoints)
TEST_P(D3D12CachingTests, ReuseShaderWithMultipleEntryPoints) {
wgpu::ShaderModule module = utils::CreateShaderModuleFromWGSL(device, R"(
[[block]] struct Data {
[[offset(0)]] data : u32;
};
[[binding(0), set(0)]] var<storage_buffer> data : Data;
[[stage(compute)]]
fn write1() -> void {
data.data = 1u;
return;
}
[[stage(compute)]]
fn write42() -> void {
data.data = 42u;
return;
}
)");
// Store the WGSL shader into the cache.
{
wgpu::ComputePipelineDescriptor desc;
desc.computeStage.module = module;
desc.computeStage.entryPoint = "write1";
EXPECT_CACHE_HIT(0u, device.CreateComputePipeline(&desc));
desc.computeStage.module = module;
desc.computeStage.entryPoint = "write42";
EXPECT_CACHE_HIT(0u, device.CreateComputePipeline(&desc));
}
EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
// Load the same WGSL shader from the cache.
{
wgpu::ComputePipelineDescriptor desc;
desc.computeStage.module = module;
desc.computeStage.entryPoint = "write1";
// Cached HLSL shader calls LoadData twice (once to peek, again to get), so check 2 x
// kNumOfShaders hits.
EXPECT_CACHE_HIT(2u, device.CreateComputePipeline(&desc));
desc.computeStage.module = module;
desc.computeStage.entryPoint = "write42";
// Cached HLSL shader calls LoadData twice, so check 2 x kNumOfShaders hits.
EXPECT_CACHE_HIT(2u, device.CreateComputePipeline(&desc));
}
EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
}
DAWN_INSTANTIATE_TEST(D3D12CachingTests, D3D12Backend());