mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-05-15 20:01:22 +00:00
dawn_native: deduplicate compute pipelines
BUG=dawn:143 Change-Id: I64e4660de2241bb72bb7c615a0bd1e675e043295 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/6863 Commit-Queue: Corentin Wallez <cwallez@chromium.org> Reviewed-by: Kai Ninomiya <kainino@chromium.org>
This commit is contained in:
parent
c535198d96
commit
1152bbaf8e
@ -14,6 +14,7 @@
|
|||||||
|
|
||||||
#include "dawn_native/ComputePipeline.h"
|
#include "dawn_native/ComputePipeline.h"
|
||||||
|
|
||||||
|
#include "common/HashUtils.h"
|
||||||
#include "dawn_native/Device.h"
|
#include "dawn_native/Device.h"
|
||||||
|
|
||||||
namespace dawn_native {
|
namespace dawn_native {
|
||||||
@ -33,8 +34,12 @@ namespace dawn_native {
|
|||||||
// ComputePipelineBase
|
// ComputePipelineBase
|
||||||
|
|
||||||
ComputePipelineBase::ComputePipelineBase(DeviceBase* device,
|
ComputePipelineBase::ComputePipelineBase(DeviceBase* device,
|
||||||
const ComputePipelineDescriptor* descriptor)
|
const ComputePipelineDescriptor* descriptor,
|
||||||
: PipelineBase(device, descriptor->layout, dawn::ShaderStageBit::Compute) {
|
bool blueprint)
|
||||||
|
: PipelineBase(device, descriptor->layout, dawn::ShaderStageBit::Compute),
|
||||||
|
mModule(descriptor->computeStage->module),
|
||||||
|
mEntryPoint(descriptor->computeStage->entryPoint),
|
||||||
|
mIsBlueprint(blueprint) {
|
||||||
ExtractModuleData(dawn::ShaderStage::Compute, descriptor->computeStage->module);
|
ExtractModuleData(dawn::ShaderStage::Compute, descriptor->computeStage->module);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -42,9 +47,29 @@ namespace dawn_native {
|
|||||||
: PipelineBase(device, tag) {
|
: PipelineBase(device, tag) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ComputePipelineBase::~ComputePipelineBase() {
|
||||||
|
// Do not uncache the actual cached object if we are a blueprint
|
||||||
|
if (!mIsBlueprint) {
|
||||||
|
ASSERT(!IsError());
|
||||||
|
GetDevice()->UncacheComputePipeline(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// static
|
// static
|
||||||
ComputePipelineBase* ComputePipelineBase::MakeError(DeviceBase* device) {
|
ComputePipelineBase* ComputePipelineBase::MakeError(DeviceBase* device) {
|
||||||
return new ComputePipelineBase(device, ObjectBase::kError);
|
return new ComputePipelineBase(device, ObjectBase::kError);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t ComputePipelineBase::HashFunc::operator()(const ComputePipelineBase* pipeline) const {
|
||||||
|
size_t hash = 0;
|
||||||
|
HashCombine(&hash, pipeline->mModule.Get(), pipeline->mEntryPoint, pipeline->GetLayout());
|
||||||
|
return hash;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ComputePipelineBase::EqualityFunc::operator()(const ComputePipelineBase* a,
|
||||||
|
const ComputePipelineBase* b) const {
|
||||||
|
return a->mModule.Get() == b->mModule.Get() && a->mEntryPoint == b->mEntryPoint &&
|
||||||
|
a->GetLayout() == b->GetLayout();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace dawn_native
|
} // namespace dawn_native
|
||||||
|
@ -26,12 +26,28 @@ namespace dawn_native {
|
|||||||
|
|
||||||
class ComputePipelineBase : public PipelineBase {
|
class ComputePipelineBase : public PipelineBase {
|
||||||
public:
|
public:
|
||||||
ComputePipelineBase(DeviceBase* device, const ComputePipelineDescriptor* descriptor);
|
ComputePipelineBase(DeviceBase* device,
|
||||||
|
const ComputePipelineDescriptor* descriptor,
|
||||||
|
bool blueprint = false);
|
||||||
|
~ComputePipelineBase() override;
|
||||||
|
|
||||||
static ComputePipelineBase* MakeError(DeviceBase* device);
|
static ComputePipelineBase* MakeError(DeviceBase* device);
|
||||||
|
|
||||||
|
// Functors necessary for the unordered_set<ComputePipelineBase*>-based cache.
|
||||||
|
struct HashFunc {
|
||||||
|
size_t operator()(const ComputePipelineBase* pipeline) const;
|
||||||
|
};
|
||||||
|
struct EqualityFunc {
|
||||||
|
bool operator()(const ComputePipelineBase* a, const ComputePipelineBase* b) const;
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ComputePipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag);
|
ComputePipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag);
|
||||||
|
|
||||||
|
// TODO(cwallez@chromium.org): Store a crypto hash of the module instead.
|
||||||
|
Ref<ShaderModuleBase> mModule;
|
||||||
|
std::string mEntryPoint;
|
||||||
|
bool mIsBlueprint = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace dawn_native
|
} // namespace dawn_native
|
||||||
|
@ -48,6 +48,7 @@ namespace dawn_native {
|
|||||||
|
|
||||||
struct DeviceBase::Caches {
|
struct DeviceBase::Caches {
|
||||||
ContentLessObjectCache<BindGroupLayoutBase> bindGroupLayouts;
|
ContentLessObjectCache<BindGroupLayoutBase> bindGroupLayouts;
|
||||||
|
ContentLessObjectCache<ComputePipelineBase> computePipelines;
|
||||||
ContentLessObjectCache<PipelineLayoutBase> pipelineLayouts;
|
ContentLessObjectCache<PipelineLayoutBase> pipelineLayouts;
|
||||||
ContentLessObjectCache<ShaderModuleBase> shaderModules;
|
ContentLessObjectCache<ShaderModuleBase> shaderModules;
|
||||||
};
|
};
|
||||||
@ -121,6 +122,27 @@ namespace dawn_native {
|
|||||||
ASSERT(removedCount == 1);
|
ASSERT(removedCount == 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ResultOrError<ComputePipelineBase*> DeviceBase::GetOrCreateComputePipeline(
|
||||||
|
const ComputePipelineDescriptor* descriptor) {
|
||||||
|
ComputePipelineBase blueprint(this, descriptor, true);
|
||||||
|
|
||||||
|
auto iter = mCaches->computePipelines.find(&blueprint);
|
||||||
|
if (iter != mCaches->computePipelines.end()) {
|
||||||
|
(*iter)->Reference();
|
||||||
|
return *iter;
|
||||||
|
}
|
||||||
|
|
||||||
|
ComputePipelineBase* backendObj;
|
||||||
|
DAWN_TRY_ASSIGN(backendObj, CreateComputePipelineImpl(descriptor));
|
||||||
|
mCaches->computePipelines.insert(backendObj);
|
||||||
|
return backendObj;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DeviceBase::UncacheComputePipeline(ComputePipelineBase* obj) {
|
||||||
|
size_t removedCount = mCaches->computePipelines.erase(obj);
|
||||||
|
ASSERT(removedCount == 1);
|
||||||
|
}
|
||||||
|
|
||||||
ResultOrError<PipelineLayoutBase*> DeviceBase::GetOrCreatePipelineLayout(
|
ResultOrError<PipelineLayoutBase*> DeviceBase::GetOrCreatePipelineLayout(
|
||||||
const PipelineLayoutDescriptor* descriptor) {
|
const PipelineLayoutDescriptor* descriptor) {
|
||||||
PipelineLayoutBase blueprint(this, descriptor, true);
|
PipelineLayoutBase blueprint(this, descriptor, true);
|
||||||
@ -369,7 +391,7 @@ namespace dawn_native {
|
|||||||
ComputePipelineBase** result,
|
ComputePipelineBase** result,
|
||||||
const ComputePipelineDescriptor* descriptor) {
|
const ComputePipelineDescriptor* descriptor) {
|
||||||
DAWN_TRY(ValidateComputePipelineDescriptor(this, descriptor));
|
DAWN_TRY(ValidateComputePipelineDescriptor(this, descriptor));
|
||||||
DAWN_TRY_ASSIGN(*result, CreateComputePipelineImpl(descriptor));
|
DAWN_TRY_ASSIGN(*result, GetOrCreateComputePipeline(descriptor));
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,6 +84,10 @@ namespace dawn_native {
|
|||||||
const BindGroupLayoutDescriptor* descriptor);
|
const BindGroupLayoutDescriptor* descriptor);
|
||||||
void UncacheBindGroupLayout(BindGroupLayoutBase* obj);
|
void UncacheBindGroupLayout(BindGroupLayoutBase* obj);
|
||||||
|
|
||||||
|
ResultOrError<ComputePipelineBase*> GetOrCreateComputePipeline(
|
||||||
|
const ComputePipelineDescriptor* descriptor);
|
||||||
|
void UncacheComputePipeline(ComputePipelineBase* obj);
|
||||||
|
|
||||||
ResultOrError<PipelineLayoutBase*> GetOrCreatePipelineLayout(
|
ResultOrError<PipelineLayoutBase*> GetOrCreatePipelineLayout(
|
||||||
const PipelineLayoutDescriptor* descriptor);
|
const PipelineLayoutDescriptor* descriptor);
|
||||||
void UncachePipelineLayout(PipelineLayoutBase* obj);
|
void UncachePipelineLayout(PipelineLayoutBase* obj);
|
||||||
|
@ -87,4 +87,9 @@ namespace dawn_native {
|
|||||||
return mLayout.Get();
|
return mLayout.Get();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const PipelineLayoutBase* PipelineBase::GetLayout() const {
|
||||||
|
ASSERT(!IsError());
|
||||||
|
return mLayout.Get();
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace dawn_native
|
} // namespace dawn_native
|
||||||
|
@ -48,6 +48,7 @@ namespace dawn_native {
|
|||||||
const PushConstantInfo& GetPushConstants(dawn::ShaderStage stage) const;
|
const PushConstantInfo& GetPushConstants(dawn::ShaderStage stage) const;
|
||||||
dawn::ShaderStageBit GetStageMask() const;
|
dawn::ShaderStageBit GetStageMask() const;
|
||||||
PipelineLayoutBase* GetLayout();
|
PipelineLayoutBase* GetLayout();
|
||||||
|
const PipelineLayoutBase* GetLayout() const;
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
PipelineBase(DeviceBase* device, PipelineLayoutBase* layout, dawn::ShaderStageBit stages);
|
PipelineBase(DeviceBase* device, PipelineLayoutBase* layout, dawn::ShaderStageBit stages);
|
||||||
|
@ -75,4 +75,86 @@ TEST_P(ObjectCachingTest, ShaderModuleDeduplication) {
|
|||||||
EXPECT_EQ(module.Get() == sameModule.Get(), !UsesWire());
|
EXPECT_EQ(module.Get() == sameModule.Get(), !UsesWire());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that ComputePipeline are correctly deduplicated wrt. their ShaderModule
|
||||||
|
TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnShaderModule) {
|
||||||
|
dawn::ShaderModule module = utils::CreateShaderModule(device, dawn::ShaderStage::Compute, R"(
|
||||||
|
#version 450
|
||||||
|
void main() {
|
||||||
|
int i = 0;
|
||||||
|
})");
|
||||||
|
dawn::ShaderModule sameModule =
|
||||||
|
utils::CreateShaderModule(device, dawn::ShaderStage::Compute, R"(
|
||||||
|
#version 450
|
||||||
|
void main() {
|
||||||
|
int i = 0;
|
||||||
|
})");
|
||||||
|
dawn::ShaderModule otherModule =
|
||||||
|
utils::CreateShaderModule(device, dawn::ShaderStage::Compute, R"(
|
||||||
|
#version 450
|
||||||
|
void main() {
|
||||||
|
})");
|
||||||
|
|
||||||
|
EXPECT_NE(module.Get(), otherModule.Get());
|
||||||
|
EXPECT_EQ(module.Get() == sameModule.Get(), !UsesWire());
|
||||||
|
|
||||||
|
dawn::PipelineLayout layout = utils::MakeBasicPipelineLayout(device, nullptr);
|
||||||
|
|
||||||
|
dawn::PipelineStageDescriptor stageDesc;
|
||||||
|
stageDesc.entryPoint = "main";
|
||||||
|
stageDesc.module = module;
|
||||||
|
|
||||||
|
dawn::ComputePipelineDescriptor desc;
|
||||||
|
desc.computeStage = &stageDesc;
|
||||||
|
desc.layout = layout;
|
||||||
|
|
||||||
|
dawn::ComputePipeline pipeline = device.CreateComputePipeline(&desc);
|
||||||
|
|
||||||
|
stageDesc.module = sameModule;
|
||||||
|
dawn::ComputePipeline samePipeline = device.CreateComputePipeline(&desc);
|
||||||
|
|
||||||
|
stageDesc.module = otherModule;
|
||||||
|
dawn::ComputePipeline otherPipeline = device.CreateComputePipeline(&desc);
|
||||||
|
|
||||||
|
EXPECT_NE(pipeline.Get(), otherPipeline.Get());
|
||||||
|
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test that ComputePipeline are correctly deduplicated wrt. their layout
|
||||||
|
TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnLayout) {
|
||||||
|
dawn::BindGroupLayout bgl = utils::MakeBindGroupLayout(
|
||||||
|
device, {{1, dawn::ShaderStageBit::Fragment, dawn::BindingType::UniformBuffer}});
|
||||||
|
dawn::BindGroupLayout otherBgl = utils::MakeBindGroupLayout(
|
||||||
|
device, {{1, dawn::ShaderStageBit::Vertex, dawn::BindingType::UniformBuffer}});
|
||||||
|
|
||||||
|
dawn::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, &bgl);
|
||||||
|
dawn::PipelineLayout samePl = utils::MakeBasicPipelineLayout(device, &bgl);
|
||||||
|
dawn::PipelineLayout otherPl = utils::MakeBasicPipelineLayout(device, nullptr);
|
||||||
|
|
||||||
|
EXPECT_NE(pl.Get(), otherPl.Get());
|
||||||
|
EXPECT_EQ(pl.Get() == samePl.Get(), !UsesWire());
|
||||||
|
|
||||||
|
dawn::PipelineStageDescriptor stageDesc;
|
||||||
|
stageDesc.entryPoint = "main";
|
||||||
|
stageDesc.module = utils::CreateShaderModule(device, dawn::ShaderStage::Compute, R"(
|
||||||
|
#version 450
|
||||||
|
void main() {
|
||||||
|
int i = 0;
|
||||||
|
})");
|
||||||
|
|
||||||
|
dawn::ComputePipelineDescriptor desc;
|
||||||
|
desc.computeStage = &stageDesc;
|
||||||
|
|
||||||
|
desc.layout = pl;
|
||||||
|
dawn::ComputePipeline pipeline = device.CreateComputePipeline(&desc);
|
||||||
|
|
||||||
|
desc.layout = samePl;
|
||||||
|
dawn::ComputePipeline samePipeline = device.CreateComputePipeline(&desc);
|
||||||
|
|
||||||
|
desc.layout = otherPl;
|
||||||
|
dawn::ComputePipeline otherPipeline = device.CreateComputePipeline(&desc);
|
||||||
|
|
||||||
|
EXPECT_NE(pipeline.Get(), otherPipeline.Get());
|
||||||
|
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
|
||||||
|
}
|
||||||
|
|
||||||
DAWN_INSTANTIATE_TEST(ObjectCachingTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);
|
DAWN_INSTANTIATE_TEST(ObjectCachingTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);
|
||||||
|
Loading…
x
Reference in New Issue
Block a user