D3D12: Support caching DX shaders.
This change is a prerequisite to D3D pipeline caching. This change introduces: - Caching interface which enables the cache. - Helper for backends to load/store blobs to be cached. - Ability to cache HLSL shaders. Bug:dawn:549 Change-Id: I2af759882d18b3f45dc63e49dcb6a3caa1be3485 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/32305 Commit-Queue: Bryan Bernhart <bryan.bernhart@intel.com> Reviewed-by: Corentin Wallez <cwallez@chromium.org>
This commit is contained in:
parent
cf89a68f46
commit
41b3f9c1e4
|
@ -223,6 +223,8 @@ source_set("dawn_native_sources") {
|
|||
"PassResourceUsageTracker.h",
|
||||
"PerStage.cpp",
|
||||
"PerStage.h",
|
||||
"PersistentCache.cpp",
|
||||
"PersistentCache.h",
|
||||
"Pipeline.cpp",
|
||||
"Pipeline.h",
|
||||
"PipelineLayout.cpp",
|
||||
|
|
|
@ -108,6 +108,8 @@ target_sources(dawn_native PRIVATE
|
|||
"PassResourceUsage.h"
|
||||
"PassResourceUsageTracker.cpp"
|
||||
"PassResourceUsageTracker.h"
|
||||
"PersistentCache.cpp"
|
||||
"PersistentCache.h"
|
||||
"PerStage.cpp"
|
||||
"PerStage.h"
|
||||
"Pipeline.cpp"
|
||||
|
|
|
@ -31,6 +31,7 @@
|
|||
#include "dawn_native/Fence.h"
|
||||
#include "dawn_native/Instance.h"
|
||||
#include "dawn_native/InternalPipelineStore.h"
|
||||
#include "dawn_native/PersistentCache.h"
|
||||
#include "dawn_native/PipelineLayout.h"
|
||||
#include "dawn_native/QuerySet.h"
|
||||
#include "dawn_native/Queue.h"
|
||||
|
@ -132,6 +133,7 @@ namespace dawn_native {
|
|||
mCreateReadyPipelineTracker = std::make_unique<CreateReadyPipelineTracker>(this);
|
||||
mDeprecationWarnings = std::make_unique<DeprecationWarnings>();
|
||||
mInternalPipelineStore = std::make_unique<InternalPipelineStore>();
|
||||
mPersistentCache = std::make_unique<PersistentCache>(this);
|
||||
|
||||
// Starting from now the backend can start doing reentrant calls so the device is marked as
|
||||
// alive.
|
||||
|
@ -196,6 +198,7 @@ namespace dawn_native {
|
|||
mErrorScopeTracker = nullptr;
|
||||
mDynamicUploader = nullptr;
|
||||
mCreateReadyPipelineTracker = nullptr;
|
||||
mPersistentCache = nullptr;
|
||||
|
||||
mEmptyBindGroupLayout = nullptr;
|
||||
|
||||
|
@ -299,6 +302,11 @@ namespace dawn_native {
|
|||
return mCurrentErrorScope.Get();
|
||||
}
|
||||
|
||||
PersistentCache* DeviceBase::GetPersistentCache() {
|
||||
ASSERT(mPersistentCache.get() != nullptr);
|
||||
return mPersistentCache.get();
|
||||
}
|
||||
|
||||
MaybeError DeviceBase::ValidateObject(const ObjectBase* object) const {
|
||||
ASSERT(object != nullptr);
|
||||
if (DAWN_UNLIKELY(object->GetDevice() != this)) {
|
||||
|
|
|
@ -37,6 +37,7 @@ namespace dawn_native {
|
|||
class DynamicUploader;
|
||||
class ErrorScope;
|
||||
class ErrorScopeTracker;
|
||||
class PersistentCache;
|
||||
class StagingBufferBase;
|
||||
struct InternalPipelineStore;
|
||||
|
||||
|
@ -180,6 +181,8 @@ namespace dawn_native {
|
|||
|
||||
ErrorScope* GetCurrentErrorScope();
|
||||
|
||||
PersistentCache* GetPersistentCache();
|
||||
|
||||
void Reference();
|
||||
void Release();
|
||||
|
||||
|
@ -388,6 +391,8 @@ namespace dawn_native {
|
|||
ExtensionsSet mEnabledExtensions;
|
||||
|
||||
std::unique_ptr<InternalPipelineStore> mInternalPipelineStore;
|
||||
|
||||
std::unique_ptr<PersistentCache> mPersistentCache;
|
||||
};
|
||||
|
||||
} // namespace dawn_native
|
||||
|
|
|
@ -0,0 +1,65 @@
|
|||
// Copyright 2020 The Dawn Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "dawn_native/PersistentCache.h"
|
||||
|
||||
#include "common/Assert.h"
|
||||
#include "dawn_native/Device.h"
|
||||
#include "dawn_platform/DawnPlatform.h"
|
||||
|
||||
namespace dawn_native {
|
||||
|
||||
PersistentCache::PersistentCache(DeviceBase* device)
|
||||
: mDevice(device), mCache(GetPlatformCache()) {
|
||||
}
|
||||
|
||||
ScopedCachedBlob PersistentCache::LoadData(const PersistentCacheKey& key) {
|
||||
ScopedCachedBlob blob = {};
|
||||
if (mCache == nullptr) {
|
||||
return blob;
|
||||
}
|
||||
blob.bufferSize = mCache->LoadData(reinterpret_cast<WGPUDevice>(mDevice), key.data(),
|
||||
key.size(), nullptr, 0);
|
||||
if (blob.bufferSize > 0) {
|
||||
blob.buffer.reset(new uint8_t[blob.bufferSize]);
|
||||
const size_t bufferSize =
|
||||
mCache->LoadData(reinterpret_cast<WGPUDevice>(mDevice), key.data(), key.size(),
|
||||
blob.buffer.get(), blob.bufferSize);
|
||||
ASSERT(bufferSize == blob.bufferSize);
|
||||
return blob;
|
||||
}
|
||||
return blob;
|
||||
}
|
||||
|
||||
void PersistentCache::StoreData(const PersistentCacheKey& key, const void* value, size_t size) {
|
||||
if (mCache == nullptr) {
|
||||
return;
|
||||
}
|
||||
ASSERT(value != nullptr);
|
||||
ASSERT(size > 0);
|
||||
mCache->StoreData(reinterpret_cast<WGPUDevice>(mDevice), key.data(), key.size(), value,
|
||||
size);
|
||||
}
|
||||
|
||||
dawn_platform::CachingInterface* PersistentCache::GetPlatformCache() {
|
||||
// TODO(dawn:549): Create a fingerprint of concatenated version strings (ex. Tint commit
|
||||
// hash, Dawn commit hash). This will be used by the client so it may know when to discard
|
||||
// previously cached Dawn objects should this fingerprint change.
|
||||
dawn_platform::Platform* platform = mDevice->GetPlatform();
|
||||
if (platform != nullptr) {
|
||||
return platform->GetCachingInterface(/*fingerprint*/ nullptr, /*fingerprintSize*/ 0);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
} // namespace dawn_native
|
|
@ -0,0 +1,86 @@
|
|||
// Copyright 2020 The Dawn Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef DAWNNATIVE_PERSISTENTCACHE_H_
|
||||
#define DAWNNATIVE_PERSISTENTCACHE_H_
|
||||
|
||||
#include "dawn_native/Error.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
namespace dawn_platform {
|
||||
class CachingInterface;
|
||||
}
|
||||
|
||||
namespace dawn_native {
|
||||
|
||||
using PersistentCacheKey = std::vector<uint8_t>;
|
||||
|
||||
struct ScopedCachedBlob {
|
||||
std::unique_ptr<uint8_t[]> buffer;
|
||||
size_t bufferSize = 0;
|
||||
};
|
||||
|
||||
class DeviceBase;
|
||||
|
||||
enum class PersistentKeyType { Shader };
|
||||
|
||||
class PersistentCache {
|
||||
public:
|
||||
PersistentCache(DeviceBase* device);
|
||||
|
||||
// Combines load/store operations into a single call.
|
||||
// If the load was successful, a non-empty blob is returned to the caller.
|
||||
// Else, the creation callback |createFn| gets invoked with a callback
|
||||
// |doCache| to store the newly created blob back in the cache.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// ScopedCachedBlob cachedBlob = {};
|
||||
// DAWN_TRY_ASSIGN(cachedBlob, GetOrCreate(key, [&](auto doCache)) {
|
||||
// // Create a new blob to be stored
|
||||
// doCache(newBlobPtr, newBlobSize); // store
|
||||
// }));
|
||||
//
|
||||
template <typename CreateFn>
|
||||
ResultOrError<ScopedCachedBlob> GetOrCreate(const PersistentCacheKey& key,
|
||||
CreateFn&& createFn) {
|
||||
// Attempt to load an existing blob from the cache.
|
||||
ScopedCachedBlob blob = LoadData(key);
|
||||
if (blob.bufferSize > 0) {
|
||||
return std::move(blob);
|
||||
}
|
||||
|
||||
// Allow the caller to create a new blob to be stored for the given key.
|
||||
DAWN_TRY(createFn([this, key](const void* value, size_t size) {
|
||||
this->StoreData(key, value, size);
|
||||
}));
|
||||
|
||||
return std::move(blob);
|
||||
}
|
||||
|
||||
private:
|
||||
// PersistentCache impl
|
||||
ScopedCachedBlob LoadData(const PersistentCacheKey& key);
|
||||
void StoreData(const PersistentCacheKey& key, const void* value, size_t size);
|
||||
|
||||
dawn_platform::CachingInterface* GetPlatformCache();
|
||||
|
||||
DeviceBase* mDevice = nullptr;
|
||||
|
||||
dawn_platform::CachingInterface* mCache = nullptr;
|
||||
};
|
||||
} // namespace dawn_native
|
||||
|
||||
#endif // DAWNNATIVE_PERSISTENTCACHE_H_
|
|
@ -43,44 +43,14 @@ namespace dawn_native { namespace d3d12 {
|
|||
|
||||
ShaderModule* module = ToBackend(descriptor->computeStage.module);
|
||||
|
||||
const char* entryPoint = descriptor->computeStage.entryPoint;
|
||||
std::string remappedEntryPoint;
|
||||
std::string hlslSource;
|
||||
if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
|
||||
DAWN_TRY_ASSIGN(hlslSource, module->TranslateToHLSLWithTint(
|
||||
entryPoint, SingleShaderStage::Compute,
|
||||
ToBackend(GetLayout()), &remappedEntryPoint));
|
||||
entryPoint = remappedEntryPoint.c_str();
|
||||
|
||||
} else {
|
||||
DAWN_TRY_ASSIGN(hlslSource,
|
||||
module->TranslateToHLSLWithSPIRVCross(
|
||||
entryPoint, SingleShaderStage::Compute, ToBackend(GetLayout())));
|
||||
|
||||
// Note that the HLSL will always use entryPoint "main" under SPIRV-cross.
|
||||
entryPoint = "main";
|
||||
}
|
||||
|
||||
D3D12_COMPUTE_PIPELINE_STATE_DESC d3dDesc = {};
|
||||
d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature();
|
||||
|
||||
ComPtr<IDxcBlob> compiledDXCShader;
|
||||
ComPtr<ID3DBlob> compiledFXCShader;
|
||||
if (device->IsToggleEnabled(Toggle::UseDXC)) {
|
||||
DAWN_TRY_ASSIGN(compiledDXCShader,
|
||||
CompileShaderDXC(device, SingleShaderStage::Compute, hlslSource,
|
||||
entryPoint, compileFlags));
|
||||
|
||||
d3dDesc.CS.pShaderBytecode = compiledDXCShader->GetBufferPointer();
|
||||
d3dDesc.CS.BytecodeLength = compiledDXCShader->GetBufferSize();
|
||||
} else {
|
||||
DAWN_TRY_ASSIGN(compiledFXCShader,
|
||||
CompileShaderFXC(device, SingleShaderStage::Compute, hlslSource,
|
||||
entryPoint, compileFlags));
|
||||
d3dDesc.CS.pShaderBytecode = compiledFXCShader->GetBufferPointer();
|
||||
d3dDesc.CS.BytecodeLength = compiledFXCShader->GetBufferSize();
|
||||
}
|
||||
|
||||
CompiledShader compiledShader;
|
||||
DAWN_TRY_ASSIGN(compiledShader, module->Compile(descriptor->computeStage.entryPoint,
|
||||
SingleShaderStage::Compute,
|
||||
ToBackend(GetLayout()), compileFlags));
|
||||
d3dDesc.CS = compiledShader.GetD3D12ShaderBytecode();
|
||||
device->GetD3D12Device()->CreateComputePipelineState(&d3dDesc,
|
||||
IID_PPV_ARGS(&mPipelineState));
|
||||
return {};
|
||||
|
|
|
@ -306,44 +306,13 @@ namespace dawn_native { namespace d3d12 {
|
|||
shaders[SingleShaderStage::Vertex] = &descriptorD3D12.VS;
|
||||
shaders[SingleShaderStage::Fragment] = &descriptorD3D12.PS;
|
||||
|
||||
PerStage<ComPtr<ID3DBlob>> compiledFXCShader;
|
||||
PerStage<ComPtr<IDxcBlob>> compiledDXCShader;
|
||||
|
||||
PerStage<CompiledShader> compiledShader;
|
||||
wgpu::ShaderStage renderStages = wgpu::ShaderStage::Vertex | wgpu::ShaderStage::Fragment;
|
||||
for (auto stage : IterateStages(renderStages)) {
|
||||
std::string hlslSource;
|
||||
const char* entryPoint = GetStage(stage).entryPoint.c_str();
|
||||
std::string remappedEntryPoint;
|
||||
|
||||
if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
|
||||
DAWN_TRY_ASSIGN(hlslSource, modules[stage]->TranslateToHLSLWithTint(
|
||||
entryPoint, stage, ToBackend(GetLayout()),
|
||||
&remappedEntryPoint));
|
||||
entryPoint = remappedEntryPoint.c_str();
|
||||
|
||||
} else {
|
||||
DAWN_TRY_ASSIGN(hlslSource, modules[stage]->TranslateToHLSLWithSPIRVCross(
|
||||
entryPoint, stage, ToBackend(GetLayout())));
|
||||
|
||||
// Note that the HLSL will always use entryPoint "main" under SPIRV-cross.
|
||||
entryPoint = "main";
|
||||
}
|
||||
|
||||
if (device->IsToggleEnabled(Toggle::UseDXC)) {
|
||||
DAWN_TRY_ASSIGN(
|
||||
compiledDXCShader[stage],
|
||||
CompileShaderDXC(device, stage, hlslSource, entryPoint, compileFlags));
|
||||
|
||||
shaders[stage]->pShaderBytecode = compiledDXCShader[stage]->GetBufferPointer();
|
||||
shaders[stage]->BytecodeLength = compiledDXCShader[stage]->GetBufferSize();
|
||||
} else {
|
||||
DAWN_TRY_ASSIGN(
|
||||
compiledFXCShader[stage],
|
||||
CompileShaderFXC(device, stage, hlslSource, entryPoint, compileFlags));
|
||||
|
||||
shaders[stage]->pShaderBytecode = compiledFXCShader[stage]->GetBufferPointer();
|
||||
shaders[stage]->BytecodeLength = compiledFXCShader[stage]->GetBufferSize();
|
||||
}
|
||||
DAWN_TRY_ASSIGN(compiledShader[stage],
|
||||
modules[stage]->Compile(entryPoints[stage], stage,
|
||||
ToBackend(GetLayout()), compileFlags));
|
||||
*shaders[stage] = compiledShader[stage].GetD3D12ShaderBytecode();
|
||||
}
|
||||
|
||||
PipelineLayout* layout = ToBackend(GetLayout());
|
||||
|
|
|
@ -309,4 +309,102 @@ namespace dawn_native { namespace d3d12 {
|
|||
return compiler.compile();
|
||||
}
|
||||
|
||||
ResultOrError<CompiledShader> ShaderModule::Compile(const char* entryPointName,
|
||||
SingleShaderStage stage,
|
||||
PipelineLayout* layout,
|
||||
uint32_t compileFlags) {
|
||||
Device* device = ToBackend(GetDevice());
|
||||
|
||||
// Compile the source shader to HLSL.
|
||||
std::string hlslSource;
|
||||
std::string remappedEntryPoint;
|
||||
if (device->IsToggleEnabled(Toggle::UseTintGenerator)) {
|
||||
DAWN_TRY_ASSIGN(hlslSource, TranslateToHLSLWithTint(entryPointName, stage, layout,
|
||||
&remappedEntryPoint));
|
||||
entryPointName = remappedEntryPoint.c_str();
|
||||
} else {
|
||||
DAWN_TRY_ASSIGN(hlslSource,
|
||||
TranslateToHLSLWithSPIRVCross(entryPointName, stage, layout));
|
||||
|
||||
// Note that the HLSL will always use entryPoint "main" under
|
||||
// SPIRV-cross.
|
||||
entryPointName = "main";
|
||||
}
|
||||
|
||||
// Use HLSL source as the input for the key since it does need to know about the pipeline
|
||||
// layout. The pipeline layout is only required if we key from WGSL: two different pipeline
|
||||
// layouts could be used to produce different shader blobs and the wrong shader blob could
|
||||
// be loaded since the pipeline layout was missing from the key.
|
||||
// TODO(dawn:549): Consider keying from WGSL and serialize the pipeline layout it used.
|
||||
const PersistentCacheKey& shaderCacheKey =
|
||||
CreateHLSLKey(entryPointName, stage, hlslSource, compileFlags);
|
||||
|
||||
CompiledShader compiledShader = {};
|
||||
DAWN_TRY_ASSIGN(compiledShader.cachedShader,
|
||||
device->GetPersistentCache()->GetOrCreate(
|
||||
shaderCacheKey, [&](auto doCache) -> MaybeError {
|
||||
if (device->IsToggleEnabled(Toggle::UseDXC)) {
|
||||
DAWN_TRY_ASSIGN(compiledShader.compiledDXCShader,
|
||||
CompileShaderDXC(device, stage, hlslSource,
|
||||
entryPointName, compileFlags));
|
||||
} else {
|
||||
DAWN_TRY_ASSIGN(compiledShader.compiledFXCShader,
|
||||
CompileShaderFXC(device, stage, hlslSource,
|
||||
entryPointName, compileFlags));
|
||||
}
|
||||
const D3D12_SHADER_BYTECODE shader =
|
||||
compiledShader.GetD3D12ShaderBytecode();
|
||||
doCache(shader.pShaderBytecode, shader.BytecodeLength);
|
||||
return {};
|
||||
}));
|
||||
|
||||
return std::move(compiledShader);
|
||||
}
|
||||
|
||||
D3D12_SHADER_BYTECODE CompiledShader::GetD3D12ShaderBytecode() const {
|
||||
if (cachedShader.buffer != nullptr) {
|
||||
return {cachedShader.buffer.get(), cachedShader.bufferSize};
|
||||
} else if (compiledFXCShader != nullptr) {
|
||||
return {compiledFXCShader->GetBufferPointer(), compiledFXCShader->GetBufferSize()};
|
||||
} else if (compiledDXCShader != nullptr) {
|
||||
return {compiledDXCShader->GetBufferPointer(), compiledDXCShader->GetBufferSize()};
|
||||
}
|
||||
UNREACHABLE();
|
||||
return {};
|
||||
}
|
||||
|
||||
PersistentCacheKey ShaderModule::CreateHLSLKey(const char* entryPointName,
|
||||
SingleShaderStage stage,
|
||||
const std::string& hlslSource,
|
||||
uint32_t compileFlags) const {
|
||||
std::stringstream stream;
|
||||
|
||||
// Prefix the key with the type to avoid collisions from another type that could have the
|
||||
// same key.
|
||||
stream << static_cast<uint32_t>(PersistentKeyType::Shader);
|
||||
|
||||
// Provide "guard" strings that the user cannot provide to help ensure the generated HLSL
|
||||
// used to create this key is not being manufactured by the user to load the wrong shader
|
||||
// blob.
|
||||
// These strings can be HLSL comments because Tint does not emit HLSL comments.
|
||||
// TODO(dawn:549): Replace guards strings with something more secure.
|
||||
ASSERT(hlslSource.find("//") == std::string::npos);
|
||||
|
||||
stream << "// Start shader autogenerated by Dawn.";
|
||||
stream << hlslSource;
|
||||
stream << "// End of shader autogenerated by Dawn.";
|
||||
|
||||
stream << compileFlags;
|
||||
|
||||
// TODO(dawn:549): add the HLSL compiler version for good measure.
|
||||
|
||||
// If the source contains multiple entry points, ensure they are cached seperately
|
||||
// per stage since DX shader code can only be compiled per stage using the same
|
||||
// entry point.
|
||||
stream << static_cast<uint32_t>(stage);
|
||||
stream << entryPointName;
|
||||
|
||||
return PersistentCacheKey(std::istreambuf_iterator<char>{stream},
|
||||
std::istreambuf_iterator<char>{});
|
||||
}
|
||||
}} // namespace dawn_native::d3d12
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#ifndef DAWNNATIVE_D3D12_SHADERMODULED3D12_H_
|
||||
#define DAWNNATIVE_D3D12_SHADERMODULED3D12_H_
|
||||
|
||||
#include "dawn_native/PersistentCache.h"
|
||||
#include "dawn_native/ShaderModule.h"
|
||||
|
||||
#include "dawn_native/d3d12/d3d12_platform.h"
|
||||
|
@ -24,22 +25,28 @@ namespace dawn_native { namespace d3d12 {
|
|||
class Device;
|
||||
class PipelineLayout;
|
||||
|
||||
ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(Device* device,
|
||||
SingleShaderStage stage,
|
||||
const std::string& hlslSource,
|
||||
const char* entryPoint,
|
||||
uint32_t compileFlags);
|
||||
ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(Device* device,
|
||||
SingleShaderStage stage,
|
||||
const std::string& hlslSource,
|
||||
const char* entryPoint,
|
||||
uint32_t compileFlags);
|
||||
// Manages a ref to one of the various representations of shader blobs.
|
||||
struct CompiledShader {
|
||||
ScopedCachedBlob cachedShader;
|
||||
ComPtr<ID3DBlob> compiledFXCShader;
|
||||
ComPtr<IDxcBlob> compiledDXCShader;
|
||||
D3D12_SHADER_BYTECODE GetD3D12ShaderBytecode() const;
|
||||
};
|
||||
|
||||
class ShaderModule final : public ShaderModuleBase {
|
||||
public:
|
||||
static ResultOrError<ShaderModule*> Create(Device* device,
|
||||
const ShaderModuleDescriptor* descriptor);
|
||||
|
||||
ResultOrError<CompiledShader> Compile(const char* entryPointName,
|
||||
SingleShaderStage stage,
|
||||
PipelineLayout* layout,
|
||||
uint32_t compileFlags);
|
||||
|
||||
private:
|
||||
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
|
||||
~ShaderModule() override = default;
|
||||
|
||||
ResultOrError<std::string> TranslateToHLSLWithTint(
|
||||
const char* entryPointName,
|
||||
SingleShaderStage stage,
|
||||
|
@ -50,9 +57,10 @@ namespace dawn_native { namespace d3d12 {
|
|||
SingleShaderStage stage,
|
||||
PipelineLayout* layout) const;
|
||||
|
||||
private:
|
||||
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
|
||||
~ShaderModule() override = default;
|
||||
PersistentCacheKey CreateHLSLKey(const char* entryPointName,
|
||||
SingleShaderStage stage,
|
||||
const std::string& hlslSource,
|
||||
uint32_t compileFlags) const;
|
||||
};
|
||||
|
||||
}} // namespace dawn_native::d3d12
|
||||
|
|
|
@ -31,4 +31,9 @@ dawn_component("dawn_platform") {
|
|||
]
|
||||
|
||||
deps = [ "${dawn_root}/src/common" ]
|
||||
|
||||
public_deps = [
|
||||
# DawnPlatform.h has #include <dawn/webgpu.h>
|
||||
"${dawn_root}/src/dawn:dawn_headers",
|
||||
]
|
||||
}
|
||||
|
|
|
@ -27,4 +27,4 @@ target_sources(dawn_platform PRIVATE
|
|||
"tracing/EventTracer.h"
|
||||
"tracing/TraceEvent.h"
|
||||
)
|
||||
target_link_libraries(dawn_platform PRIVATE dawn_internal_config dawn_common)
|
||||
target_link_libraries(dawn_platform PUBLIC dawn_headers PRIVATE dawn_internal_config dawn_common)
|
||||
|
|
|
@ -14,10 +14,45 @@
|
|||
|
||||
#include "dawn_platform/DawnPlatform.h"
|
||||
|
||||
#include "common/Assert.h"
|
||||
|
||||
namespace dawn_platform {
|
||||
|
||||
CachingInterface::CachingInterface() = default;
|
||||
|
||||
CachingInterface::~CachingInterface() = default;
|
||||
|
||||
Platform::Platform() = default;
|
||||
|
||||
Platform::~Platform() = default;
|
||||
|
||||
const unsigned char* Platform::GetTraceCategoryEnabledFlag(TraceCategory category) {
|
||||
static unsigned char disabled = 0;
|
||||
return &disabled;
|
||||
}
|
||||
|
||||
double Platform::MonotonicallyIncreasingTime() {
|
||||
return 0;
|
||||
}
|
||||
|
||||
uint64_t Platform::AddTraceEvent(char phase,
|
||||
const unsigned char* categoryGroupEnabled,
|
||||
const char* name,
|
||||
uint64_t id,
|
||||
double timestamp,
|
||||
int numArgs,
|
||||
const char** argNames,
|
||||
const unsigned char* argTypes,
|
||||
const uint64_t* argValues,
|
||||
unsigned char flags) {
|
||||
// AddTraceEvent cannot be called if events are disabled.
|
||||
ASSERT(false);
|
||||
return 0;
|
||||
}
|
||||
|
||||
dawn_platform::CachingInterface* Platform::GetCachingInterface(const void* fingerprint,
|
||||
size_t fingerprintSize) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
} // namespace dawn_platform
|
|
@ -17,8 +17,11 @@
|
|||
|
||||
#include "dawn_platform/dawn_platform_export.h"
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
#include <dawn/webgpu.h>
|
||||
|
||||
namespace dawn_platform {
|
||||
|
||||
enum class TraceCategory {
|
||||
|
@ -28,14 +31,43 @@ namespace dawn_platform {
|
|||
GPUWork, // Actual GPU work
|
||||
};
|
||||
|
||||
class DAWN_PLATFORM_EXPORT CachingInterface {
|
||||
public:
|
||||
CachingInterface();
|
||||
virtual ~CachingInterface();
|
||||
|
||||
// LoadData has two modes. The first mode is used to get a value which
|
||||
// corresponds to the |key|. The |valueOut| is a caller provided buffer
|
||||
// allocated to the size |valueSize| which is loaded with data of the
|
||||
// size returned. The second mode is used to query for the existence of
|
||||
// the |key| where |valueOut| is nullptr and |valueSize| must be 0.
|
||||
// The return size is non-zero if the |key| exists.
|
||||
virtual size_t LoadData(const WGPUDevice device,
|
||||
const void* key,
|
||||
size_t keySize,
|
||||
void* valueOut,
|
||||
size_t valueSize) = 0;
|
||||
|
||||
// StoreData puts a |value| in the cache which corresponds to the |key|.
|
||||
virtual void StoreData(const WGPUDevice device,
|
||||
const void* key,
|
||||
size_t keySize,
|
||||
const void* value,
|
||||
size_t valueSize) = 0;
|
||||
|
||||
private:
|
||||
CachingInterface(const CachingInterface&) = delete;
|
||||
CachingInterface& operator=(const CachingInterface&) = delete;
|
||||
};
|
||||
|
||||
class DAWN_PLATFORM_EXPORT Platform {
|
||||
public:
|
||||
Platform();
|
||||
virtual ~Platform();
|
||||
|
||||
virtual const unsigned char* GetTraceCategoryEnabledFlag(TraceCategory category) = 0;
|
||||
virtual const unsigned char* GetTraceCategoryEnabledFlag(TraceCategory category);
|
||||
|
||||
virtual double MonotonicallyIncreasingTime() = 0;
|
||||
virtual double MonotonicallyIncreasingTime();
|
||||
|
||||
virtual uint64_t AddTraceEvent(char phase,
|
||||
const unsigned char* categoryGroupEnabled,
|
||||
|
@ -46,7 +78,13 @@ namespace dawn_platform {
|
|||
const char** argNames,
|
||||
const unsigned char* argTypes,
|
||||
const uint64_t* argValues,
|
||||
unsigned char flags) = 0;
|
||||
unsigned char flags);
|
||||
|
||||
// The |fingerprint| is provided by Dawn to inform the client to discard the Dawn caches
|
||||
// when the fingerprint changes. The returned CachingInterface is expected to outlive the
|
||||
// device which uses it to persistently cache objects.
|
||||
virtual CachingInterface* GetCachingInterface(const void* fingerprint,
|
||||
size_t fingerprintSize);
|
||||
|
||||
private:
|
||||
Platform(const Platform&) = delete;
|
||||
|
|
|
@ -333,7 +333,10 @@ source_set("dawn_end2end_tests_sources") {
|
|||
libs = []
|
||||
|
||||
if (dawn_enable_d3d12) {
|
||||
sources += [ "end2end/D3D12ResourceWrappingTests.cpp" ]
|
||||
sources += [
|
||||
"end2end/D3D12CachingTests.cpp",
|
||||
"end2end/D3D12ResourceWrappingTests.cpp",
|
||||
]
|
||||
libs += [
|
||||
"d3d11.lib",
|
||||
"dxgi.lib",
|
||||
|
|
|
@ -744,6 +744,10 @@ void DawnTestBase::SetUp() {
|
|||
mBackendAdapter = *it;
|
||||
}
|
||||
|
||||
// Setup the per-test platform. Tests can provide one by overloading CreateTestPlatform.
|
||||
mTestPlatform = CreateTestPlatform();
|
||||
gTestEnv->GetInstance()->SetPlatform(mTestPlatform.get());
|
||||
|
||||
// Create the device from the adapter
|
||||
for (const char* forceEnabledWorkaround : mParam.forceEnabledWorkarounds) {
|
||||
ASSERT(gTestEnv->GetInstance()->GetToggleInfo(forceEnabledWorkaround) != nullptr);
|
||||
|
@ -1080,6 +1084,10 @@ void DawnTestBase::ResolveExpectations() {
|
|||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<dawn_platform::Platform> DawnTestBase::CreateTestPlatform() {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
bool RGBA8::operator==(const RGBA8& other) const {
|
||||
return r == other.r && g == other.g && b == other.b && a == other.a;
|
||||
}
|
||||
|
|
|
@ -20,6 +20,7 @@
|
|||
#include "dawn/webgpu_cpp.h"
|
||||
#include "dawn_native/DawnNative.h"
|
||||
|
||||
#include <dawn_platform/DawnPlatform.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <memory>
|
||||
|
@ -268,6 +269,8 @@ class DawnTestBase {
|
|||
wgpu::Instance GetInstance() const;
|
||||
dawn_native::Adapter GetAdapter() const;
|
||||
|
||||
virtual std::unique_ptr<dawn_platform::Platform> CreateTestPlatform();
|
||||
|
||||
protected:
|
||||
wgpu::Device device;
|
||||
wgpu::Queue queue;
|
||||
|
@ -403,6 +406,8 @@ class DawnTestBase {
|
|||
void ResolveExpectations();
|
||||
|
||||
dawn_native::Adapter mBackendAdapter;
|
||||
|
||||
std::unique_ptr<dawn_platform::Platform> mTestPlatform;
|
||||
};
|
||||
|
||||
// Skip a test when the given condition is satisfied.
|
||||
|
|
|
@ -0,0 +1,287 @@
|
|||
// Copyright 2020 The Dawn Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "tests/DawnTest.h"
|
||||
|
||||
#include "utils/ComboRenderPipelineDescriptor.h"
|
||||
#include "utils/WGPUHelpers.h"
|
||||
|
||||
#define EXPECT_CACHE_HIT(N, statement) \
|
||||
do { \
|
||||
size_t before = mPersistentCache.mHitCount; \
|
||||
statement; \
|
||||
FlushWire(); \
|
||||
size_t after = mPersistentCache.mHitCount; \
|
||||
EXPECT_EQ(N, after - before); \
|
||||
} while (0)
|
||||
|
||||
// FakePersistentCache implements a in-memory persistent cache.
|
||||
class FakePersistentCache : public dawn_platform::CachingInterface {
|
||||
public:
|
||||
// PersistentCache API
|
||||
void StoreData(const WGPUDevice device,
|
||||
const void* key,
|
||||
size_t keySize,
|
||||
const void* value,
|
||||
size_t valueSize) override {
|
||||
if (mIsDisabled)
|
||||
return;
|
||||
const std::string keyStr(reinterpret_cast<const char*>(key), keySize);
|
||||
|
||||
const uint8_t* value_start = reinterpret_cast<const uint8_t*>(value);
|
||||
std::vector<uint8_t> entry_value(value_start, value_start + valueSize);
|
||||
|
||||
EXPECT_TRUE(mCache.insert({keyStr, std::move(entry_value)}).second);
|
||||
}
|
||||
|
||||
size_t LoadData(const WGPUDevice device,
|
||||
const void* key,
|
||||
size_t keySize,
|
||||
void* value,
|
||||
size_t valueSize) override {
|
||||
const std::string keyStr(reinterpret_cast<const char*>(key), keySize);
|
||||
auto entry = mCache.find(keyStr);
|
||||
if (entry == mCache.end()) {
|
||||
return 0;
|
||||
}
|
||||
if (valueSize >= entry->second.size()) {
|
||||
memcpy(value, entry->second.data(), entry->second.size());
|
||||
}
|
||||
mHitCount++;
|
||||
return entry->second.size();
|
||||
}
|
||||
|
||||
using Blob = std::vector<uint8_t>;
|
||||
using FakeCache = std::unordered_map<std::string, Blob>;
|
||||
|
||||
FakeCache mCache;
|
||||
|
||||
size_t mHitCount = 0;
|
||||
bool mIsDisabled = false;
|
||||
};
|
||||
|
||||
// Test platform that only supports caching.
|
||||
class DawnTestPlatform : public dawn_platform::Platform {
|
||||
public:
|
||||
DawnTestPlatform(dawn_platform::CachingInterface* cachingInterface)
|
||||
: mCachingInterface(cachingInterface) {
|
||||
}
|
||||
~DawnTestPlatform() override = default;
|
||||
|
||||
dawn_platform::CachingInterface* GetCachingInterface(const void* fingerprint,
|
||||
size_t fingerprintSize) override {
|
||||
return mCachingInterface;
|
||||
}
|
||||
|
||||
dawn_platform::CachingInterface* mCachingInterface = nullptr;
|
||||
};
|
||||
|
||||
class D3D12CachingTests : public DawnTest {
|
||||
protected:
|
||||
std::unique_ptr<dawn_platform::Platform> CreateTestPlatform() override {
|
||||
return std::make_unique<DawnTestPlatform>(&mPersistentCache);
|
||||
}
|
||||
|
||||
FakePersistentCache mPersistentCache;
|
||||
};
|
||||
|
||||
// Test that duplicate WGSL still re-compiles HLSL even when the cache is not enabled.
|
||||
TEST_P(D3D12CachingTests, SameShaderNoCache) {
|
||||
mPersistentCache.mIsDisabled = true;
|
||||
|
||||
wgpu::ShaderModule module = utils::CreateShaderModuleFromWGSL(device, R"(
|
||||
[[builtin(position)]] var<out> Position : vec4<f32>;
|
||||
|
||||
[[stage(vertex)]]
|
||||
fn vertex_main() -> void {
|
||||
Position = vec4<f32>(0.0, 0.0, 0.0, 1.0);
|
||||
return;
|
||||
}
|
||||
|
||||
[[location(0)]] var<out> outColor : vec4<f32>;
|
||||
|
||||
[[stage(fragment)]]
|
||||
fn fragment_main() -> void {
|
||||
outColor = vec4<f32>(1.0, 0.0, 0.0, 1.0);
|
||||
return;
|
||||
}
|
||||
)");
|
||||
|
||||
// Store the WGSL shader into the cache.
|
||||
{
|
||||
utils::ComboRenderPipelineDescriptor desc(device);
|
||||
desc.vertexStage.module = module;
|
||||
desc.vertexStage.entryPoint = "vertex_main";
|
||||
desc.cFragmentStage.module = module;
|
||||
desc.cFragmentStage.entryPoint = "fragment_main";
|
||||
|
||||
EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
|
||||
}
|
||||
|
||||
EXPECT_EQ(mPersistentCache.mCache.size(), 0u);
|
||||
|
||||
// Load the same WGSL shader from the cache.
|
||||
{
|
||||
utils::ComboRenderPipelineDescriptor desc(device);
|
||||
desc.vertexStage.module = module;
|
||||
desc.vertexStage.entryPoint = "vertex_main";
|
||||
desc.cFragmentStage.module = module;
|
||||
desc.cFragmentStage.entryPoint = "fragment_main";
|
||||
|
||||
EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
|
||||
}
|
||||
|
||||
EXPECT_EQ(mPersistentCache.mCache.size(), 0u);
|
||||
}
|
||||
|
||||
// Test creating a pipeline from two entrypoints in multiple stages will cache the correct number
|
||||
// of HLSL shaders. WGSL shader should result into caching 2 HLSL shaders (stage x
|
||||
// entrypoints)
|
||||
TEST_P(D3D12CachingTests, ReuseShaderWithMultipleEntryPointsPerStage) {
|
||||
wgpu::ShaderModule module = utils::CreateShaderModuleFromWGSL(device, R"(
|
||||
[[builtin(position)]] var<out> Position : vec4<f32>;
|
||||
|
||||
[[stage(vertex)]]
|
||||
fn vertex_main() -> void {
|
||||
Position = vec4<f32>(0.0, 0.0, 0.0, 1.0);
|
||||
return;
|
||||
}
|
||||
|
||||
[[location(0)]] var<out> outColor : vec4<f32>;
|
||||
|
||||
[[stage(fragment)]]
|
||||
fn fragment_main() -> void {
|
||||
outColor = vec4<f32>(1.0, 0.0, 0.0, 1.0);
|
||||
return;
|
||||
}
|
||||
)");
|
||||
|
||||
// Store the WGSL shader into the cache.
|
||||
{
|
||||
utils::ComboRenderPipelineDescriptor desc(device);
|
||||
desc.vertexStage.module = module;
|
||||
desc.vertexStage.entryPoint = "vertex_main";
|
||||
desc.cFragmentStage.module = module;
|
||||
desc.cFragmentStage.entryPoint = "fragment_main";
|
||||
|
||||
EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
|
||||
}
|
||||
|
||||
EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
|
||||
|
||||
// Load the same WGSL shader from the cache.
|
||||
{
|
||||
utils::ComboRenderPipelineDescriptor desc(device);
|
||||
desc.vertexStage.module = module;
|
||||
desc.vertexStage.entryPoint = "vertex_main";
|
||||
desc.cFragmentStage.module = module;
|
||||
desc.cFragmentStage.entryPoint = "fragment_main";
|
||||
|
||||
// Cached HLSL shader calls LoadData twice (once to peek, again to get), so check 2 x
|
||||
// kNumOfShaders hits.
|
||||
EXPECT_CACHE_HIT(4u, device.CreateRenderPipeline(&desc));
|
||||
}
|
||||
|
||||
EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
|
||||
|
||||
// Modify the WGSL shader functions and make sure it doesn't hit.
|
||||
wgpu::ShaderModule newModule = utils::CreateShaderModuleFromWGSL(device, R"(
|
||||
[[builtin(position)]] var<out> Position : vec4<f32>;
|
||||
|
||||
[[stage(vertex)]]
|
||||
fn vertex_main() -> void {
|
||||
Position = vec4<f32>(1.0, 1.0, 1.0, 1.0);
|
||||
return;
|
||||
}
|
||||
|
||||
[[location(0)]] var<out> outColor : vec4<f32>;
|
||||
|
||||
[[stage(fragment)]]
|
||||
fn fragment_main() -> void {
|
||||
outColor = vec4<f32>(1.0, 1.0, 1.0, 1.0);
|
||||
return;
|
||||
}
|
||||
)");
|
||||
|
||||
{
|
||||
utils::ComboRenderPipelineDescriptor desc(device);
|
||||
desc.vertexStage.module = newModule;
|
||||
desc.vertexStage.entryPoint = "vertex_main";
|
||||
desc.cFragmentStage.module = newModule;
|
||||
desc.cFragmentStage.entryPoint = "fragment_main";
|
||||
EXPECT_CACHE_HIT(0u, device.CreateRenderPipeline(&desc));
|
||||
}
|
||||
|
||||
// Cached HLSL shader calls LoadData twice (once to peek, again to get), so check 2 x
|
||||
// kNumOfShaders hits.
|
||||
EXPECT_EQ(mPersistentCache.mCache.size(), 4u);
|
||||
}
|
||||
|
||||
// Test creating a WGSL shader with two entrypoints in the same stage will cache the correct number
|
||||
// of HLSL shaders. WGSL shader should result into caching 1 HLSL shader (stage x entrypoints)
|
||||
TEST_P(D3D12CachingTests, ReuseShaderWithMultipleEntryPoints) {
|
||||
wgpu::ShaderModule module = utils::CreateShaderModuleFromWGSL(device, R"(
|
||||
[[block]] struct Data {
|
||||
[[offset(0)]] data : u32;
|
||||
};
|
||||
[[binding(0), set(0)]] var<storage_buffer> data : Data;
|
||||
|
||||
[[stage(compute)]]
|
||||
fn write1() -> void {
|
||||
data.data = 1u;
|
||||
return;
|
||||
}
|
||||
|
||||
[[stage(compute)]]
|
||||
fn write42() -> void {
|
||||
data.data = 42u;
|
||||
return;
|
||||
}
|
||||
)");
|
||||
|
||||
// Store the WGSL shader into the cache.
|
||||
{
|
||||
wgpu::ComputePipelineDescriptor desc;
|
||||
desc.computeStage.module = module;
|
||||
desc.computeStage.entryPoint = "write1";
|
||||
EXPECT_CACHE_HIT(0u, device.CreateComputePipeline(&desc));
|
||||
|
||||
desc.computeStage.module = module;
|
||||
desc.computeStage.entryPoint = "write42";
|
||||
EXPECT_CACHE_HIT(0u, device.CreateComputePipeline(&desc));
|
||||
}
|
||||
|
||||
EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
|
||||
|
||||
// Load the same WGSL shader from the cache.
|
||||
{
|
||||
wgpu::ComputePipelineDescriptor desc;
|
||||
desc.computeStage.module = module;
|
||||
desc.computeStage.entryPoint = "write1";
|
||||
|
||||
// Cached HLSL shader calls LoadData twice (once to peek, again to get), so check 2 x
|
||||
// kNumOfShaders hits.
|
||||
EXPECT_CACHE_HIT(2u, device.CreateComputePipeline(&desc));
|
||||
|
||||
desc.computeStage.module = module;
|
||||
desc.computeStage.entryPoint = "write42";
|
||||
|
||||
// Cached HLSL shader calls LoadData twice, so check 2 x kNumOfShaders hits.
|
||||
EXPECT_CACHE_HIT(2u, device.CreateComputePipeline(&desc));
|
||||
}
|
||||
|
||||
EXPECT_EQ(mPersistentCache.mCache.size(), 2u);
|
||||
}
|
||||
|
||||
DAWN_INSTANTIATE_TEST(D3D12CachingTests, D3D12Backend());
|
Loading…
Reference in New Issue