dawn_native: deduplicate compute pipelines

BUG=dawn:143

Change-Id: I64e4660de2241bb72bb7c615a0bd1e675e043295
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/6863
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
This commit is contained in:
Corentin Wallez 2019-05-01 13:48:47 +00:00 committed by Commit Bot service account
parent c535198d96
commit 1152bbaf8e
7 changed files with 159 additions and 4 deletions

View File

@ -14,6 +14,7 @@
#include "dawn_native/ComputePipeline.h"
#include "common/HashUtils.h"
#include "dawn_native/Device.h"
namespace dawn_native {
@ -33,8 +34,12 @@ namespace dawn_native {
// ComputePipelineBase
ComputePipelineBase::ComputePipelineBase(DeviceBase* device,
const ComputePipelineDescriptor* descriptor)
: PipelineBase(device, descriptor->layout, dawn::ShaderStageBit::Compute) {
const ComputePipelineDescriptor* descriptor,
bool blueprint)
: PipelineBase(device, descriptor->layout, dawn::ShaderStageBit::Compute),
mModule(descriptor->computeStage->module),
mEntryPoint(descriptor->computeStage->entryPoint),
mIsBlueprint(blueprint) {
ExtractModuleData(dawn::ShaderStage::Compute, descriptor->computeStage->module);
}
@ -42,9 +47,29 @@ namespace dawn_native {
: PipelineBase(device, tag) {
}
ComputePipelineBase::~ComputePipelineBase() {
// Do not uncache the actual cached object if we are a blueprint
if (!mIsBlueprint) {
ASSERT(!IsError());
GetDevice()->UncacheComputePipeline(this);
}
}
// static
ComputePipelineBase* ComputePipelineBase::MakeError(DeviceBase* device) {
return new ComputePipelineBase(device, ObjectBase::kError);
}
size_t ComputePipelineBase::HashFunc::operator()(const ComputePipelineBase* pipeline) const {
size_t hash = 0;
HashCombine(&hash, pipeline->mModule.Get(), pipeline->mEntryPoint, pipeline->GetLayout());
return hash;
}
bool ComputePipelineBase::EqualityFunc::operator()(const ComputePipelineBase* a,
const ComputePipelineBase* b) const {
return a->mModule.Get() == b->mModule.Get() && a->mEntryPoint == b->mEntryPoint &&
a->GetLayout() == b->GetLayout();
}
} // namespace dawn_native

View File

@ -26,12 +26,28 @@ namespace dawn_native {
class ComputePipelineBase : public PipelineBase {
public:
ComputePipelineBase(DeviceBase* device, const ComputePipelineDescriptor* descriptor);
ComputePipelineBase(DeviceBase* device,
const ComputePipelineDescriptor* descriptor,
bool blueprint = false);
~ComputePipelineBase() override;
static ComputePipelineBase* MakeError(DeviceBase* device);
// Functors necessary for the unordered_set<ComputePipelineBase*>-based cache.
struct HashFunc {
size_t operator()(const ComputePipelineBase* pipeline) const;
};
struct EqualityFunc {
bool operator()(const ComputePipelineBase* a, const ComputePipelineBase* b) const;
};
private:
ComputePipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag);
// TODO(cwallez@chromium.org): Store a crypto hash of the module instead.
Ref<ShaderModuleBase> mModule;
std::string mEntryPoint;
bool mIsBlueprint = false;
};
} // namespace dawn_native

View File

@ -48,6 +48,7 @@ namespace dawn_native {
struct DeviceBase::Caches {
ContentLessObjectCache<BindGroupLayoutBase> bindGroupLayouts;
ContentLessObjectCache<ComputePipelineBase> computePipelines;
ContentLessObjectCache<PipelineLayoutBase> pipelineLayouts;
ContentLessObjectCache<ShaderModuleBase> shaderModules;
};
@ -121,6 +122,27 @@ namespace dawn_native {
ASSERT(removedCount == 1);
}
ResultOrError<ComputePipelineBase*> DeviceBase::GetOrCreateComputePipeline(
const ComputePipelineDescriptor* descriptor) {
ComputePipelineBase blueprint(this, descriptor, true);
auto iter = mCaches->computePipelines.find(&blueprint);
if (iter != mCaches->computePipelines.end()) {
(*iter)->Reference();
return *iter;
}
ComputePipelineBase* backendObj;
DAWN_TRY_ASSIGN(backendObj, CreateComputePipelineImpl(descriptor));
mCaches->computePipelines.insert(backendObj);
return backendObj;
}
void DeviceBase::UncacheComputePipeline(ComputePipelineBase* obj) {
size_t removedCount = mCaches->computePipelines.erase(obj);
ASSERT(removedCount == 1);
}
ResultOrError<PipelineLayoutBase*> DeviceBase::GetOrCreatePipelineLayout(
const PipelineLayoutDescriptor* descriptor) {
PipelineLayoutBase blueprint(this, descriptor, true);
@ -369,7 +391,7 @@ namespace dawn_native {
ComputePipelineBase** result,
const ComputePipelineDescriptor* descriptor) {
DAWN_TRY(ValidateComputePipelineDescriptor(this, descriptor));
DAWN_TRY_ASSIGN(*result, CreateComputePipelineImpl(descriptor));
DAWN_TRY_ASSIGN(*result, GetOrCreateComputePipeline(descriptor));
return {};
}

View File

@ -84,6 +84,10 @@ namespace dawn_native {
const BindGroupLayoutDescriptor* descriptor);
void UncacheBindGroupLayout(BindGroupLayoutBase* obj);
ResultOrError<ComputePipelineBase*> GetOrCreateComputePipeline(
const ComputePipelineDescriptor* descriptor);
void UncacheComputePipeline(ComputePipelineBase* obj);
ResultOrError<PipelineLayoutBase*> GetOrCreatePipelineLayout(
const PipelineLayoutDescriptor* descriptor);
void UncachePipelineLayout(PipelineLayoutBase* obj);

View File

@ -87,4 +87,9 @@ namespace dawn_native {
return mLayout.Get();
}
const PipelineLayoutBase* PipelineBase::GetLayout() const {
ASSERT(!IsError());
return mLayout.Get();
}
} // namespace dawn_native

View File

@ -48,6 +48,7 @@ namespace dawn_native {
const PushConstantInfo& GetPushConstants(dawn::ShaderStage stage) const;
dawn::ShaderStageBit GetStageMask() const;
PipelineLayoutBase* GetLayout();
const PipelineLayoutBase* GetLayout() const;
protected:
PipelineBase(DeviceBase* device, PipelineLayoutBase* layout, dawn::ShaderStageBit stages);

View File

@ -75,4 +75,86 @@ TEST_P(ObjectCachingTest, ShaderModuleDeduplication) {
EXPECT_EQ(module.Get() == sameModule.Get(), !UsesWire());
}
// Test that ComputePipeline are correctly deduplicated wrt. their ShaderModule
TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnShaderModule) {
dawn::ShaderModule module = utils::CreateShaderModule(device, dawn::ShaderStage::Compute, R"(
#version 450
void main() {
int i = 0;
})");
dawn::ShaderModule sameModule =
utils::CreateShaderModule(device, dawn::ShaderStage::Compute, R"(
#version 450
void main() {
int i = 0;
})");
dawn::ShaderModule otherModule =
utils::CreateShaderModule(device, dawn::ShaderStage::Compute, R"(
#version 450
void main() {
})");
EXPECT_NE(module.Get(), otherModule.Get());
EXPECT_EQ(module.Get() == sameModule.Get(), !UsesWire());
dawn::PipelineLayout layout = utils::MakeBasicPipelineLayout(device, nullptr);
dawn::PipelineStageDescriptor stageDesc;
stageDesc.entryPoint = "main";
stageDesc.module = module;
dawn::ComputePipelineDescriptor desc;
desc.computeStage = &stageDesc;
desc.layout = layout;
dawn::ComputePipeline pipeline = device.CreateComputePipeline(&desc);
stageDesc.module = sameModule;
dawn::ComputePipeline samePipeline = device.CreateComputePipeline(&desc);
stageDesc.module = otherModule;
dawn::ComputePipeline otherPipeline = device.CreateComputePipeline(&desc);
EXPECT_NE(pipeline.Get(), otherPipeline.Get());
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
}
// Test that ComputePipeline are correctly deduplicated wrt. their layout
TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnLayout) {
dawn::BindGroupLayout bgl = utils::MakeBindGroupLayout(
device, {{1, dawn::ShaderStageBit::Fragment, dawn::BindingType::UniformBuffer}});
dawn::BindGroupLayout otherBgl = utils::MakeBindGroupLayout(
device, {{1, dawn::ShaderStageBit::Vertex, dawn::BindingType::UniformBuffer}});
dawn::PipelineLayout pl = utils::MakeBasicPipelineLayout(device, &bgl);
dawn::PipelineLayout samePl = utils::MakeBasicPipelineLayout(device, &bgl);
dawn::PipelineLayout otherPl = utils::MakeBasicPipelineLayout(device, nullptr);
EXPECT_NE(pl.Get(), otherPl.Get());
EXPECT_EQ(pl.Get() == samePl.Get(), !UsesWire());
dawn::PipelineStageDescriptor stageDesc;
stageDesc.entryPoint = "main";
stageDesc.module = utils::CreateShaderModule(device, dawn::ShaderStage::Compute, R"(
#version 450
void main() {
int i = 0;
})");
dawn::ComputePipelineDescriptor desc;
desc.computeStage = &stageDesc;
desc.layout = pl;
dawn::ComputePipeline pipeline = device.CreateComputePipeline(&desc);
desc.layout = samePl;
dawn::ComputePipeline samePipeline = device.CreateComputePipeline(&desc);
desc.layout = otherPl;
dawn::ComputePipeline otherPipeline = device.CreateComputePipeline(&desc);
EXPECT_NE(pipeline.Get(), otherPipeline.Get());
EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire());
}
DAWN_INSTANTIATE_TEST(ObjectCachingTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);