dawn_native: deduplicate shader modules
BUG=dawn:143 Change-Id: I2c0fa63e3a6d77c137418f12b9807d16a0636d57 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/6862 Commit-Queue: Corentin Wallez <cwallez@chromium.org> Reviewed-by: Kai Ninomiya <kainino@chromium.org>
This commit is contained in:
parent
0ee9859c91
commit
c535198d96
|
@ -49,6 +49,7 @@ namespace dawn_native {
|
||||||
struct DeviceBase::Caches {
|
struct DeviceBase::Caches {
|
||||||
ContentLessObjectCache<BindGroupLayoutBase> bindGroupLayouts;
|
ContentLessObjectCache<BindGroupLayoutBase> bindGroupLayouts;
|
||||||
ContentLessObjectCache<PipelineLayoutBase> pipelineLayouts;
|
ContentLessObjectCache<PipelineLayoutBase> pipelineLayouts;
|
||||||
|
ContentLessObjectCache<ShaderModuleBase> shaderModules;
|
||||||
};
|
};
|
||||||
|
|
||||||
// DeviceBase
|
// DeviceBase
|
||||||
|
@ -141,6 +142,27 @@ namespace dawn_native {
|
||||||
ASSERT(removedCount == 1);
|
ASSERT(removedCount == 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ResultOrError<ShaderModuleBase*> DeviceBase::GetOrCreateShaderModule(
|
||||||
|
const ShaderModuleDescriptor* descriptor) {
|
||||||
|
ShaderModuleBase blueprint(this, descriptor, true);
|
||||||
|
|
||||||
|
auto iter = mCaches->shaderModules.find(&blueprint);
|
||||||
|
if (iter != mCaches->shaderModules.end()) {
|
||||||
|
(*iter)->Reference();
|
||||||
|
return *iter;
|
||||||
|
}
|
||||||
|
|
||||||
|
ShaderModuleBase* backendObj;
|
||||||
|
DAWN_TRY_ASSIGN(backendObj, CreateShaderModuleImpl(descriptor));
|
||||||
|
mCaches->shaderModules.insert(backendObj);
|
||||||
|
return backendObj;
|
||||||
|
}
|
||||||
|
|
||||||
|
void DeviceBase::UncacheShaderModule(ShaderModuleBase* obj) {
|
||||||
|
size_t removedCount = mCaches->shaderModules.erase(obj);
|
||||||
|
ASSERT(removedCount == 1);
|
||||||
|
}
|
||||||
|
|
||||||
// Object creation API methods
|
// Object creation API methods
|
||||||
|
|
||||||
BindGroupBase* DeviceBase::CreateBindGroup(const BindGroupDescriptor* descriptor) {
|
BindGroupBase* DeviceBase::CreateBindGroup(const BindGroupDescriptor* descriptor) {
|
||||||
|
@ -382,7 +404,7 @@ namespace dawn_native {
|
||||||
MaybeError DeviceBase::CreateShaderModuleInternal(ShaderModuleBase** result,
|
MaybeError DeviceBase::CreateShaderModuleInternal(ShaderModuleBase** result,
|
||||||
const ShaderModuleDescriptor* descriptor) {
|
const ShaderModuleDescriptor* descriptor) {
|
||||||
DAWN_TRY(ValidateShaderModuleDescriptor(this, descriptor));
|
DAWN_TRY(ValidateShaderModuleDescriptor(this, descriptor));
|
||||||
DAWN_TRY_ASSIGN(*result, CreateShaderModuleImpl(descriptor));
|
DAWN_TRY_ASSIGN(*result, GetOrCreateShaderModule(descriptor));
|
||||||
return {};
|
return {};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -88,6 +88,10 @@ namespace dawn_native {
|
||||||
const PipelineLayoutDescriptor* descriptor);
|
const PipelineLayoutDescriptor* descriptor);
|
||||||
void UncachePipelineLayout(PipelineLayoutBase* obj);
|
void UncachePipelineLayout(PipelineLayoutBase* obj);
|
||||||
|
|
||||||
|
ResultOrError<ShaderModuleBase*> GetOrCreateShaderModule(
|
||||||
|
const ShaderModuleDescriptor* descriptor);
|
||||||
|
void UncacheShaderModule(ShaderModuleBase* obj);
|
||||||
|
|
||||||
// Dawn API
|
// Dawn API
|
||||||
BindGroupBase* CreateBindGroup(const BindGroupDescriptor* descriptor);
|
BindGroupBase* CreateBindGroup(const BindGroupDescriptor* descriptor);
|
||||||
BindGroupLayoutBase* CreateBindGroupLayout(const BindGroupLayoutDescriptor* descriptor);
|
BindGroupLayoutBase* CreateBindGroupLayout(const BindGroupLayoutDescriptor* descriptor);
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
|
|
||||||
#include "dawn_native/ShaderModule.h"
|
#include "dawn_native/ShaderModule.h"
|
||||||
|
|
||||||
|
#include "common/HashUtils.h"
|
||||||
#include "dawn_native/BindGroupLayout.h"
|
#include "dawn_native/BindGroupLayout.h"
|
||||||
#include "dawn_native/Device.h"
|
#include "dawn_native/Device.h"
|
||||||
#include "dawn_native/Pipeline.h"
|
#include "dawn_native/Pipeline.h"
|
||||||
|
@ -67,14 +68,26 @@ namespace dawn_native {
|
||||||
|
|
||||||
// ShaderModuleBase
|
// ShaderModuleBase
|
||||||
|
|
||||||
ShaderModuleBase::ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor*)
|
ShaderModuleBase::ShaderModuleBase(DeviceBase* device,
|
||||||
: ObjectBase(device) {
|
const ShaderModuleDescriptor* descriptor,
|
||||||
|
bool blueprint)
|
||||||
|
: ObjectBase(device),
|
||||||
|
mCode(descriptor->code, descriptor->code + descriptor->codeSize),
|
||||||
|
mIsBlueprint(blueprint) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ShaderModuleBase::ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag)
|
ShaderModuleBase::ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag)
|
||||||
: ObjectBase(device, tag) {
|
: ObjectBase(device, tag) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ShaderModuleBase::~ShaderModuleBase() {
|
||||||
|
// Do not uncache the actual cached object if we are a blueprint
|
||||||
|
if (!mIsBlueprint) {
|
||||||
|
ASSERT(!IsError());
|
||||||
|
GetDevice()->UncacheShaderModule(this);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// static
|
// static
|
||||||
ShaderModuleBase* ShaderModuleBase::MakeError(DeviceBase* device) {
|
ShaderModuleBase* ShaderModuleBase::MakeError(DeviceBase* device) {
|
||||||
return new ShaderModuleBase(device, ObjectBase::kError);
|
return new ShaderModuleBase(device, ObjectBase::kError);
|
||||||
|
@ -287,4 +300,19 @@ namespace dawn_native {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
size_t ShaderModuleBase::HashFunc::operator()(const ShaderModuleBase* module) const {
|
||||||
|
size_t hash = 0;
|
||||||
|
|
||||||
|
for (uint32_t word : module->mCode) {
|
||||||
|
HashCombine(&hash, word);
|
||||||
|
}
|
||||||
|
|
||||||
|
return hash;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool ShaderModuleBase::EqualityFunc::operator()(const ShaderModuleBase* a,
|
||||||
|
const ShaderModuleBase* b) const {
|
||||||
|
return a->mCode == b->mCode;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace dawn_native
|
} // namespace dawn_native
|
||||||
|
|
|
@ -37,7 +37,10 @@ namespace dawn_native {
|
||||||
|
|
||||||
class ShaderModuleBase : public ObjectBase {
|
class ShaderModuleBase : public ObjectBase {
|
||||||
public:
|
public:
|
||||||
ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor);
|
ShaderModuleBase(DeviceBase* device,
|
||||||
|
const ShaderModuleDescriptor* descriptor,
|
||||||
|
bool blueprint = false);
|
||||||
|
~ShaderModuleBase() override;
|
||||||
|
|
||||||
static ShaderModuleBase* MakeError(DeviceBase* device);
|
static ShaderModuleBase* MakeError(DeviceBase* device);
|
||||||
|
|
||||||
|
@ -68,11 +71,24 @@ namespace dawn_native {
|
||||||
|
|
||||||
bool IsCompatibleWithPipelineLayout(const PipelineLayoutBase* layout);
|
bool IsCompatibleWithPipelineLayout(const PipelineLayoutBase* layout);
|
||||||
|
|
||||||
|
// Functors necessary for the unordered_set<ShaderModuleBase*>-based cache.
|
||||||
|
struct HashFunc {
|
||||||
|
size_t operator()(const ShaderModuleBase* module) const;
|
||||||
|
};
|
||||||
|
struct EqualityFunc {
|
||||||
|
bool operator()(const ShaderModuleBase* a, const ShaderModuleBase* b) const;
|
||||||
|
};
|
||||||
|
|
||||||
private:
|
private:
|
||||||
ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag);
|
ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag);
|
||||||
|
|
||||||
bool IsCompatibleWithBindGroupLayout(size_t group, const BindGroupLayoutBase* layout);
|
bool IsCompatibleWithBindGroupLayout(size_t group, const BindGroupLayoutBase* layout);
|
||||||
|
|
||||||
|
// TODO(cwallez@chromium.org): The code is only stored for deduplication. We could maybe
|
||||||
|
// store a cryptographic hash of the code instead?
|
||||||
|
std::vector<uint32_t> mCode;
|
||||||
|
bool mIsBlueprint = false;
|
||||||
|
|
||||||
PushConstantInfo mPushConstants = {};
|
PushConstantInfo mPushConstants = {};
|
||||||
ModuleBindingInfo mBindingInfo;
|
ModuleBindingInfo mBindingInfo;
|
||||||
std::bitset<kMaxVertexAttributes> mUsedVertexAttributes;
|
std::bitset<kMaxVertexAttributes> mUsedVertexAttributes;
|
||||||
|
|
|
@ -48,4 +48,31 @@ TEST_P(ObjectCachingTest, PipelineLayoutDeduplication) {
|
||||||
EXPECT_EQ(pl.Get() == samePl.Get(), !UsesWire());
|
EXPECT_EQ(pl.Get() == samePl.Get(), !UsesWire());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Test that ShaderModules are correctly deduplicated.
|
||||||
|
TEST_P(ObjectCachingTest, ShaderModuleDeduplication) {
|
||||||
|
dawn::ShaderModule module = utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
|
||||||
|
#version 450
|
||||||
|
layout(location = 0) out vec4 fragColor;
|
||||||
|
void main() {
|
||||||
|
fragColor = vec4(0.0, 1.0, 0.0, 1.0);
|
||||||
|
})");
|
||||||
|
dawn::ShaderModule sameModule =
|
||||||
|
utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
|
||||||
|
#version 450
|
||||||
|
layout(location = 0) out vec4 fragColor;
|
||||||
|
void main() {
|
||||||
|
fragColor = vec4(0.0, 1.0, 0.0, 1.0);
|
||||||
|
})");
|
||||||
|
dawn::ShaderModule otherModule =
|
||||||
|
utils::CreateShaderModule(device, dawn::ShaderStage::Fragment, R"(
|
||||||
|
#version 450
|
||||||
|
layout(location = 0) out vec4 fragColor;
|
||||||
|
void main() {
|
||||||
|
fragColor = vec4(0.0);
|
||||||
|
})");
|
||||||
|
|
||||||
|
EXPECT_NE(module.Get(), otherModule.Get());
|
||||||
|
EXPECT_EQ(module.Get() == sameModule.Get(), !UsesWire());
|
||||||
|
}
|
||||||
|
|
||||||
DAWN_INSTANTIATE_TEST(ObjectCachingTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);
|
DAWN_INSTANTIATE_TEST(ObjectCachingTest, D3D12Backend, MetalBackend, OpenGLBackend, VulkanBackend);
|
||||||
|
|
Loading…
Reference in New Issue