Improve D3D12 pipeline cache implementation code

Bug: dawn:549
Change-Id: I84eaabdb2b72e73e37cd840632a4180acf2253e9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/92680
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Shrek Shao <shrekshao@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
This commit is contained in:
shrekshao 2022-06-07 17:21:34 +00:00 committed by Dawn LUCI CQ
parent 37b1f0fb5a
commit fc95c27933
13 changed files with 60 additions and 37 deletions

View File

@ -377,6 +377,7 @@ source_set("sources") {
"d3d12/BindGroupLayoutD3D12.cpp", "d3d12/BindGroupLayoutD3D12.cpp",
"d3d12/BindGroupLayoutD3D12.h", "d3d12/BindGroupLayoutD3D12.h",
"d3d12/BlobD3D12.cpp", "d3d12/BlobD3D12.cpp",
"d3d12/BlobD3D12.h",
"d3d12/BufferD3D12.cpp", "d3d12/BufferD3D12.cpp",
"d3d12/BufferD3D12.h", "d3d12/BufferD3D12.h",
"d3d12/CPUDescriptorHeapAllocationD3D12.cpp", "d3d12/CPUDescriptorHeapAllocationD3D12.cpp",

View File

@ -18,16 +18,20 @@
namespace dawn::native { namespace dawn::native {
// static Blob CreateBlob(size_t size) {
Blob Blob::Create(size_t size) {
if (size > 0) { if (size > 0) {
uint8_t* data = new uint8_t[size]; uint8_t* data = new uint8_t[size];
return Blob(data, size, [=]() { delete[] data; }); return Blob::UnsafeCreateWithDeleter(data, size, [=]() { delete[] data; });
} else { } else {
return Blob(); return Blob();
} }
} }
// static
Blob Blob::UnsafeCreateWithDeleter(uint8_t* data, size_t size, std::function<void()> deleter) {
return Blob(data, size, deleter);
}
Blob::Blob() : mData(nullptr), mSize(0), mDeleter({}) {} Blob::Blob() : mData(nullptr), mSize(0), mDeleter({}) {}
Blob::Blob(uint8_t* data, size_t size, std::function<void()> deleter) Blob::Blob(uint8_t* data, size_t size, std::function<void()> deleter)

View File

@ -18,12 +18,6 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include "dawn/common/Platform.h"
#if defined(DAWN_PLATFORM_WINDOWS)
#include "dawn/native/d3d12/d3d12_platform.h"
#endif // DAWN_PLATFORM_WINDOWS
namespace dawn::native { namespace dawn::native {
// Blob represents a block of bytes. It may be constructed from // Blob represents a block of bytes. It may be constructed from
@ -31,11 +25,9 @@ namespace dawn::native {
// ownership of the container and release its memory on destruction. // ownership of the container and release its memory on destruction.
class Blob { class Blob {
public: public:
static Blob Create(size_t size); // This function is used to create Blob with actual data.
// Make sure the creation and deleter handles the data ownership and lifetime correctly.
#if defined(DAWN_PLATFORM_WINDOWS) static Blob UnsafeCreateWithDeleter(uint8_t* data, size_t size, std::function<void()> deleter);
static Blob Create(Microsoft::WRL::ComPtr<ID3DBlob> blob);
#endif // DAWN_PLATFORM_WINDOWS
Blob(); Blob();
~Blob(); ~Blob();
@ -52,6 +44,8 @@ class Blob {
size_t Size() const; size_t Size() const;
private: private:
// The constructor should be responsible to take ownership of |data| and releases ownership by
// calling |deleter|. The deleter function is called at ~Blob() and during std::move.
explicit Blob(uint8_t* data, size_t size, std::function<void()> deleter); explicit Blob(uint8_t* data, size_t size, std::function<void()> deleter);
uint8_t* mData; uint8_t* mData;
@ -59,6 +53,8 @@ class Blob {
std::function<void()> mDeleter; std::function<void()> mDeleter;
}; };
Blob CreateBlob(size_t size);
} // namespace dawn::native } // namespace dawn::native
#endif // SRC_DAWN_NATIVE_BLOB_H_ #endif // SRC_DAWN_NATIVE_BLOB_H_

View File

@ -45,7 +45,7 @@ Blob BlobCache::LoadInternal(const CacheKey& key) {
const size_t expectedSize = mCache->LoadData(key.data(), key.size(), nullptr, 0); const size_t expectedSize = mCache->LoadData(key.data(), key.size(), nullptr, 0);
if (expectedSize > 0) { if (expectedSize > 0) {
// Need to put this inside to trigger copy elision. // Need to put this inside to trigger copy elision.
Blob result = Blob::Create(expectedSize); Blob result = CreateBlob(expectedSize);
const size_t actualSize = const size_t actualSize =
mCache->LoadData(key.data(), key.size(), result.Data(), expectedSize); mCache->LoadData(key.data(), key.size(), result.Data(), expectedSize);
ASSERT(expectedSize == actualSize); ASSERT(expectedSize == actualSize);

View File

@ -244,6 +244,7 @@ if (DAWN_ENABLE_D3D12)
"d3d12/BindGroupLayoutD3D12.cpp" "d3d12/BindGroupLayoutD3D12.cpp"
"d3d12/BindGroupLayoutD3D12.h" "d3d12/BindGroupLayoutD3D12.h"
"d3d12/BlobD3D12.cpp" "d3d12/BlobD3D12.cpp"
"d3d12/BlobD3D12.h"
"d3d12/BufferD3D12.cpp" "d3d12/BufferD3D12.cpp"
"d3d12/BufferD3D12.h" "d3d12/BufferD3D12.h"
"d3d12/CPUDescriptorHeapAllocationD3D12.cpp" "d3d12/CPUDescriptorHeapAllocationD3D12.cpp"

View File

@ -22,7 +22,6 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "dawn/native/BlobCache.h"
#include "dawn/native/CacheKey.h" #include "dawn/native/CacheKey.h"
#include "dawn/native/Commands.h" #include "dawn/native/Commands.h"
#include "dawn/native/ComputePipeline.h" #include "dawn/native/ComputePipeline.h"
@ -48,6 +47,8 @@ namespace dawn::native {
class AsyncTaskManager; class AsyncTaskManager;
class AttachmentState; class AttachmentState;
class AttachmentStateBlueprint; class AttachmentStateBlueprint;
class Blob;
class BlobCache;
class CallbackTaskManager; class CallbackTaskManager;
class DynamicUploader; class DynamicUploader;
class ErrorScopeStack; class ErrorScopeStack;

View File

@ -12,21 +12,20 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "dawn/native/Blob.h" #include "dawn/native/d3d12/BlobD3D12.h"
#include "dawn/native/d3d12/d3d12_platform.h"
namespace dawn::native { namespace dawn::native {
// static Blob CreateBlob(ComPtr<ID3DBlob> blob) {
Blob Blob::Create(ComPtr<ID3DBlob> blob) {
// Detach so the deleter callback can "own" the reference // Detach so the deleter callback can "own" the reference
ID3DBlob* ptr = blob.Detach(); ID3DBlob* ptr = blob.Detach();
return Blob(reinterpret_cast<uint8_t*>(ptr->GetBufferPointer()), ptr->GetBufferSize(), [=]() { return Blob::UnsafeCreateWithDeleter(reinterpret_cast<uint8_t*>(ptr->GetBufferPointer()),
// Reattach and drop to delete it. ptr->GetBufferSize(), [=]() {
ComPtr<ID3DBlob> b; // Reattach and drop to delete it.
b.Attach(ptr); ComPtr<ID3DBlob> b;
b = nullptr; b.Attach(ptr);
}); b = nullptr;
});
} }
} // namespace dawn::native } // namespace dawn::native

View File

@ -0,0 +1,22 @@
// Copyright 2022 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dawn/native/Blob.h"
#include "dawn/native/d3d12/d3d12_platform.h"
namespace dawn::native {
Blob CreateBlob(ComPtr<ID3DBlob> blob);
} // namespace dawn::native

View File

@ -18,6 +18,7 @@
#include <utility> #include <utility>
#include "dawn/native/CreatePipelineAsyncTask.h" #include "dawn/native/CreatePipelineAsyncTask.h"
#include "dawn/native/d3d12/BlobD3D12.h"
#include "dawn/native/d3d12/D3D12Error.h" #include "dawn/native/d3d12/D3D12Error.h"
#include "dawn/native/d3d12/DeviceD3D12.h" #include "dawn/native/d3d12/DeviceD3D12.h"
#include "dawn/native/d3d12/PipelineLayoutD3D12.h" #include "dawn/native/d3d12/PipelineLayoutD3D12.h"
@ -82,7 +83,7 @@ MaybeError ComputePipeline::Initialize() {
ComPtr<ID3DBlob> d3dBlob; ComPtr<ID3DBlob> d3dBlob;
DAWN_TRY(CheckHRESULT(GetPipelineState()->GetCachedBlob(&d3dBlob), DAWN_TRY(CheckHRESULT(GetPipelineState()->GetCachedBlob(&d3dBlob),
"D3D12 compute pipeline state get cached blob")); "D3D12 compute pipeline state get cached blob"));
device->StoreCachedBlob(GetCacheKey(), Blob::Create(std::move(d3dBlob))); device->StoreCachedBlob(GetCacheKey(), CreateBlob(std::move(d3dBlob)));
} }
SetLabelImpl(); SetLabelImpl();

View File

@ -441,9 +441,6 @@ ResultOrError<Ref<TextureViewBase>> Device::CreateTextureViewImpl(
const TextureViewDescriptor* descriptor) { const TextureViewDescriptor* descriptor) {
return TextureView::Create(texture, descriptor); return TextureView::Create(texture, descriptor);
} }
Ref<PipelineCacheBase> Device::GetOrCreatePipelineCacheImpl(const CacheKey& key) {
UNREACHABLE();
}
void Device::InitializeComputePipelineAsyncImpl(Ref<ComputePipelineBase> computePipeline, void Device::InitializeComputePipelineAsyncImpl(Ref<ComputePipelineBase> computePipeline,
WGPUCreateComputePipelineAsyncCallback callback, WGPUCreateComputePipelineAsyncCallback callback,
void* userdata) { void* userdata) {

View File

@ -188,7 +188,6 @@ class Device final : public DeviceBase {
const ComputePipelineDescriptor* descriptor) override; const ComputePipelineDescriptor* descriptor) override;
Ref<RenderPipelineBase> CreateUninitializedRenderPipelineImpl( Ref<RenderPipelineBase> CreateUninitializedRenderPipelineImpl(
const RenderPipelineDescriptor* descriptor) override; const RenderPipelineDescriptor* descriptor) override;
Ref<PipelineCacheBase> GetOrCreatePipelineCacheImpl(const CacheKey& key) override;
void InitializeComputePipelineAsyncImpl(Ref<ComputePipelineBase> computePipeline, void InitializeComputePipelineAsyncImpl(Ref<ComputePipelineBase> computePipeline,
WGPUCreateComputePipelineAsyncCallback callback, WGPUCreateComputePipelineAsyncCallback callback,
void* userdata) override; void* userdata) override;

View File

@ -22,6 +22,7 @@
#include "dawn/common/Assert.h" #include "dawn/common/Assert.h"
#include "dawn/common/Log.h" #include "dawn/common/Log.h"
#include "dawn/native/CreatePipelineAsyncTask.h" #include "dawn/native/CreatePipelineAsyncTask.h"
#include "dawn/native/d3d12/BlobD3D12.h"
#include "dawn/native/d3d12/D3D12Error.h" #include "dawn/native/d3d12/D3D12Error.h"
#include "dawn/native/d3d12/DeviceD3D12.h" #include "dawn/native/d3d12/DeviceD3D12.h"
#include "dawn/native/d3d12/PipelineLayoutD3D12.h" #include "dawn/native/d3d12/PipelineLayoutD3D12.h"
@ -449,7 +450,7 @@ MaybeError RenderPipeline::Initialize() {
ComPtr<ID3DBlob> d3dBlob; ComPtr<ID3DBlob> d3dBlob;
DAWN_TRY(CheckHRESULT(GetPipelineState()->GetCachedBlob(&d3dBlob), DAWN_TRY(CheckHRESULT(GetPipelineState()->GetCachedBlob(&d3dBlob),
"D3D12 render pipeline state get cached blob")); "D3D12 render pipeline state get cached blob"));
device->StoreCachedBlob(GetCacheKey(), Blob::Create(std::move(d3dBlob))); device->StoreCachedBlob(GetCacheKey(), CreateBlob(std::move(d3dBlob)));
} }
SetLabelImpl(); SetLabelImpl();

View File

@ -61,12 +61,13 @@ MaybeError PipelineCache::SerializeToBlobImpl(Blob* blob) {
DAWN_TRY(CheckVkSuccess( DAWN_TRY(CheckVkSuccess(
device->fn.GetPipelineCacheData(device->GetVkDevice(), mHandle, &bufferSize, nullptr), device->fn.GetPipelineCacheData(device->GetVkDevice(), mHandle, &bufferSize, nullptr),
"GetPipelineCacheData")); "GetPipelineCacheData"));
if (bufferSize > 0) { if (bufferSize == 0) {
*blob = Blob::Create(bufferSize); return {};
DAWN_TRY(CheckVkSuccess(device->fn.GetPipelineCacheData(device->GetVkDevice(), mHandle,
&bufferSize, blob->Data()),
"GetPipelineCacheData"));
} }
*blob = CreateBlob(bufferSize);
DAWN_TRY(CheckVkSuccess(
device->fn.GetPipelineCacheData(device->GetVkDevice(), mHandle, &bufferSize, blob->Data()),
"GetPipelineCacheData"));
return {}; return {};
} }