mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-05-15 11:51:22 +00:00
dawn_native: deduplicate compute pipelines
BUG=dawn:143 Change-Id: I64e4660de2241bb72bb7c615a0bd1e675e043295 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/6863 Commit-Queue: Corentin Wallez <cwallez@chromium.org> Reviewed-by: Kai Ninomiya <kainino@chromium.org>
This commit is contained in:
parent
c535198d96
commit
1152bbaf8e
@ -14,6 +14,7 @@
|
||||
|
||||
#include "dawn_native/ComputePipeline.h"
|
||||
|
||||
#include "common/HashUtils.h"
|
||||
#include "dawn_native/Device.h"
|
||||
|
||||
namespace dawn_native {
|
||||
@ -33,8 +34,12 @@ namespace dawn_native {
|
||||
// ComputePipelineBase
|
||||
|
||||
ComputePipelineBase::ComputePipelineBase(DeviceBase* device,
|
||||
const ComputePipelineDescriptor* descriptor)
|
||||
: PipelineBase(device, descriptor->layout, dawn::ShaderStageBit::Compute) {
|
||||
const ComputePipelineDescriptor* descriptor,
|
||||
bool blueprint)
|
||||
: PipelineBase(device, descriptor->layout, dawn::ShaderStageBit::Compute),
|
||||
mModule(descriptor->computeStage->module),
|
||||
mEntryPoint(descriptor->computeStage->entryPoint),
|
||||
mIsBlueprint(blueprint) {
|
||||
ExtractModuleData(dawn::ShaderStage::Compute, descriptor->computeStage->module);
|
||||
}
|
||||
|
||||
@ -42,9 +47,29 @@ namespace dawn_native {
|
||||
: PipelineBase(device, tag) {
|
||||
}
|
||||
|
||||
ComputePipelineBase::~ComputePipelineBase() {
|
||||
// Do not uncache the actual cached object if we are a blueprint
|
||||
if (!mIsBlueprint) {
|
||||
ASSERT(!IsError());
|
||||
GetDevice()->UncacheComputePipeline(this);
|
||||
}
|
||||
}
|
||||
|
||||
// static
|
||||
ComputePipelineBase* ComputePipelineBase::MakeError(DeviceBase* device) {
|
||||
return new ComputePipelineBase(device, ObjectBase::kError);
|
||||
}
|
||||
|
||||
size_t ComputePipelineBase::HashFunc::operator()(const ComputePipelineBase* pipeline) const {
|
||||
size_t hash = 0;
|
||||
HashCombine(&hash, pipeline->mModule.Get(), pipeline->mEntryPoint, pipeline->GetLayout());
|
||||
return hash;
|
||||
}
|
||||
|
||||
bool ComputePipelineBase::EqualityFunc::operator()(const ComputePipelineBase* a,
|
||||
const ComputePipelineBase* b) const {
|
||||
return a->mModule.Get() == b->mModule.Get() && a->mEntryPoint == b->mEntryPoint &&
|
||||
a->GetLayout() == b->GetLayout();
|
||||
}
|
||||
|
||||
} // namespace dawn_native
|
||||
|
@ -26,12 +26,28 @@ namespace dawn_native {
|
||||
|
||||
class ComputePipelineBase : public PipelineBase {
|
||||
public:
|
||||
ComputePipelineBase(DeviceBase* device, const ComputePipelineDescriptor* descriptor);
|
||||
ComputePipelineBase(DeviceBase* device,
|
||||
const ComputePipelineDescriptor* descriptor,
|
||||
bool blueprint = false);
|
||||
~ComputePipelineBase() override;
|
||||
|
||||
static ComputePipelineBase* MakeError(DeviceBase* device);
|
||||
|
||||
// Functors necessary for the unordered_set<ComputePipelineBase*>-based cache.
|
||||
struct HashFunc {
|
||||
size_t operator()(const ComputePipelineBase* pipeline) const;
|
||||
};
|
||||
struct EqualityFunc {
|
||||
bool operator()(const ComputePipelineBase* a, const ComputePipelineBase* b) const;
|
||||
};
|
||||
|
||||
private:
|
||||
ComputePipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag);
|
||||
|
||||
// TODO(cwallez@chromium.org): Store a crypto hash of the module instead.
|
||||
Ref<ShaderModuleBase> mModule;
|
||||
std::string mEntryPoint;
|
||||
bool mIsBlueprint = false;
|
||||
};
|
||||
|
||||
} // namespace dawn_native
|
||||
|
@ -48,6 +48,7 @@ namespace dawn_native {
|
||||
|
||||
struct DeviceBase::Caches {
|
||||
ContentLessObjectCache<BindGroupLayoutBase> bindGroupLayouts;
|
||||
ContentLessObjectCache<ComputePipelineBase> computePipelines;
|
||||
ContentLessObjectCache<PipelineLayoutBase> pipelineLayouts;
|
||||
ContentLessObjectCache<ShaderModuleBase> shaderModules;
|
||||
};
|
||||
@ -121,6 +122,27 @@ namespace dawn_native {
|
||||
ASSERT(removedCount == 1);
|
||||
}
|
||||
|
||||
ResultOrError<ComputePipelineBase*> DeviceBase::GetOrCreateComputePipeline(
|
||||
const ComputePipelineDescriptor* descriptor) {
|
||||
ComputePipelineBase blueprint(this, descriptor, true);
|
||||
|
||||
auto iter = mCaches->computePipelines.find(&blueprint);
|
||||
if (iter != mCaches->computePipelines.end()) {
|
||||
(*iter)->Reference();
|
||||
return *iter;
|
||||
}
|
||||
|
||||
ComputePipelineBase* backendObj;
|
||||
DAWN_TRY_ASSIGN(backendObj, CreateComputePipelineImpl(descriptor));
|
||||
mCaches->computePipelines.insert(backendObj);
|
||||
return backendObj;
|
||||
}
|
||||
|
||||
void DeviceBase::UncacheComputePipeline(ComputePipelineBase* obj) {
|
||||
size_t removedCount = mCaches->computePipelines.erase(obj);
|
||||
ASSERT(removedCount == 1);
|
||||
}
|
||||
|
||||
ResultOrError<PipelineLayoutBase*> DeviceBase::GetOrCreatePipelineLayout(
|
||||
const PipelineLayoutDescriptor* descriptor) {
|
||||
PipelineLayoutBase blueprint(this, descriptor, true);
|
||||
@ -369,7 +391,7 @@ namespace dawn_native {
|
||||
ComputePipelineBase** result,
|
||||
const ComputePipelineDescriptor* descriptor) {
|
||||
DAWN_TRY(ValidateComputePipelineDescriptor(this, descriptor));
|
||||
DAWN_TRY_ASSIGN(*result, CreateComputePipelineImpl(descriptor));
|
||||
DAWN_TRY_ASSIGN(*result, GetOrCreateComputePipeline(descriptor));
|
||||
return {};
|
||||
}
|
||||
|
||||
|
@ -84,6 +84,10 @@ namespace dawn_native {
|
||||
const BindGroupLayoutDescriptor* descriptor);
|
||||
void UncacheBindGroupLayout(BindGroupLayoutBase* obj);
|
||||
|
||||
ResultOrError<ComputePipelineBase*> GetOrCreateComputePipeline(
|
||||
const ComputePipelineDescriptor* descriptor);
|
||||
void UncacheComputePipeline(ComputePipelineBase* obj);
|
||||
|
||||
ResultOrError<PipelineLayoutBase*> GetOrCreatePipelineLayout(
|
||||
const PipelineLayoutDescriptor* descriptor);
|
||||
void UncachePipelineLayout(PipelineLayoutBase* obj);
|
||||
|
@ -87,4 +87,9 @@ namespace dawn_native {
|
||||
return mLayout.Get();
|
||||
}
|
||||
|
||||
const PipelineLayoutBase* PipelineBase::GetLayout() const {
|
||||
ASSERT(!IsError());
|
||||
return mLayout.Get();
|
||||
}
|
||||
|
||||
} // namespace dawn_native
|
||||
|
@ -48,6 +48,7 @@ namespace dawn_native {
|
||||
const PushConstantInfo& GetPushConstants(dawn::ShaderStage stage) const;
|
||||
dawn::ShaderStageBit GetStageMask() const;
|
||||
PipelineLayoutBase* GetLayout();
|
||||
const PipelineLayoutBase* GetLayout() const;
|
||||
|
||||
protected:
|
||||
PipelineBase(DeviceBase* device, PipelineLayoutBase* layout, dawn::ShaderStageBit stages);
|
||||
|
@ -75,4 +75,86 @@ TEST_P(ObjectCachingTest, ShaderModuleDeduplication) {
|
||||
EXPECT_EQ(module.Get() == sameModule.Get(), !UsesWire());
|
||||
}
|
||||
|
||||
// Test that ComputePipeline are correctly deduplicated wrt. their ShaderModule
|
||||
TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnShaderModule) {
|
||||
dawn::ShaderModule module = utils::CreateShaderModule(device, dawn::ShaderStage::Compute, R"(
|
||||
#version 450
|
||||
void main() {
|
||||
int i = 0;
|
||||
})");
|
||||
dawn::ShaderModule sameModule =
|
||||
utils::CreateShaderModule(device, dawn::ShaderStage::Compute, R"(
|
||||
#version 450
|
||||
void main() {
|
||||
int i = 0;
|
||||
})");
|
||||
dawn::ShaderModule otherModule =
|
||||
utils::CreateShaderModule(device, dawn::ShaderStage::Compute, R"(
|
||||
#version 450
|
||||
void main() {
|
||||
})");
|
||||
|
||||
EXPECT_NE(module.Get(), otherModule.Get());
|
||||
EXPECT_EQ(module.Get() == sameModule.Get(), !UsesWire());
|
||||
|
||||
dawn::PipelineLayout layout = utils::MakeBasicPipelineLayout(device, nullptr);
|
||||
|
||||
dawn::PipelineStageDescriptor stageDesc;
|
||||
stageDesc.entryPoint = "main";
|
||||
stageDesc.module = module;
|
||||
|
||||
dawn::ComputePipelineDescriptor desc;
|
||||
desc.computeStage = &stageDesc;
|
||||
desc.layout = layout;
|
||||
|
||||
dawn::ComputePipeline pipeline = device.CreateComputePipeline(&desc);
|
||||
|
||||
stageDesc.module = sameModule;
|
||||
dawn::ComputePipeline samePipeline = device.CreateComputePipeline(&desc);
|
||||
|
||||
stageDesc.module = otherModule;
|
||||
dawn::ComputePipeline otherPipeline = device.CreateComputePipeline(&desc);
|
||||
|
||||
EXPECT_NE(pipeline.Get(), otherPipeline.Get());
|
||||
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
|
||||
}
|
||||
|
||||
// Test that ComputePipeline are correctly deduplicated wrt. their layout
|
||||
TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnLayout) {
|
||||
dawn::BindGroupLayout bgl = utils::MakeBindGroupLayout(
|
||||
device, {{1, dawn::ShaderStageBit::Fragment, dawn::BindingType::UniformBuffer}});
|
||||
dawn::BindGroupLayout otherBgl = utils::MakeBindGroupLayout(
|
||||
device, {{1, dawn::ShaderStageBit::Vertex, dawn::BindingType::UniformBuffer}});
|
||||
|
||||
dawn::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, &bgl);
|
||||
dawn::PipelineLayout samePl = utils::MakeBasicPipelineLayout(device, &bgl);
|
||||
dawn::PipelineLayout otherPl = utils::MakeBasicPipelineLayout(device, nullptr);
|
||||
|
||||
EXPECT_NE(pl.Get(), otherPl.Get());
|
||||
EXPECT_EQ(pl.Get() == samePl.Get(), !UsesWire());
|
||||
|
||||
dawn::PipelineStageDescriptor stageDesc;
|
||||
stageDesc.entryPoint = "main";
|
||||
stageDesc.module = utils::CreateShaderModule(device, dawn::ShaderStage::Compute, R"(
|
||||
#version 450
|
||||
void main() {
|
||||
int i = 0;
|
||||
})");
|
||||
|
||||
dawn::ComputePipelineDescriptor desc;
|
||||
desc.computeStage = &stageDesc;
|
||||
|
||||
desc.layout = pl;
|
||||
dawn::ComputePipeline pipeline = device.CreateComputePipeline(&desc);
|
||||
|
||||
desc.layout = samePl;
|
||||
dawn::ComputePipeline samePipeline = device.CreateComputePipeline(&desc);
|
||||
|
||||
desc.layout = otherPl;
|
||||
dawn::ComputePipeline otherPipeline = device.CreateComputePipeline(&desc);
|
||||
|
||||
EXPECT_NE(pipeline.Get(), otherPipeline.Get());
|
||||
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
|
||||
}
|
||||
|
||||
DAWN_INSTANTIATE_TEST(ObjectCachingTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);
|
||||
|
Loading…
x
Reference in New Issue
Block a user