Unify ProgrammableStageDescriptor handling in PipelineBase

Previously the Render and Compute pipelines each extracted data from
their ProgrammableStageDescriptors separately. Unify that handling in
PipelineBase, in preparation for gathering EntryPointMetadata in
PipelineBase as well.

Bug: dawn:216
Change-Id: I633dd2d8c9fdd0c08bb34cbf18955445951e312f
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/27263
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
Reviewed-by: Austin Eng <enga@chromium.org>
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
Author: Corentin Wallez
Date: 2020-08-28 14:26:00 +00:00 (committed by Commit Bot service account)
Parent: 900bd341a3
Commit: 9ed8d518ca

8 changed files with 104 additions and 93 deletions
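The net effect for subclasses: instead of passing a precomputed stage mask and buffer-size table down to PipelineBase, each pipeline now hands over its ProgrammableStageDescriptors and lets the base class derive both. A minimal sketch of the new calling convention, mirroring the constructor changes in the diffs below (the subclass name is hypothetical; the real call sites are in ComputePipeline.cpp and RenderPipeline.cpp):

    // Hypothetical subclass illustrating the unified constructor. Note that
    // vertexStage is stored by value in the descriptor (hence the &) while
    // fragmentStage is already a pointer.
    MyRenderPipeline::MyRenderPipeline(DeviceBase* device,
                                       const RenderPipelineDescriptor* descriptor)
        : PipelineBase(device,
                       descriptor->layout,
                       {{SingleShaderStage::Vertex, &descriptor->vertexStage},
                        {SingleShaderStage::Fragment, descriptor->fragmentStage}}) {
    }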

dawn_native/CommandBufferStateTracker.cpp

@@ -26,11 +26,11 @@ namespace dawn_native {
     namespace {
         bool BufferSizesAtLeastAsBig(const ityp::span<uint32_t, uint64_t> unverifiedBufferSizes,
-                                     const std::vector<uint64_t>& pipelineMinimumBufferSizes) {
-            ASSERT(unverifiedBufferSizes.size() == pipelineMinimumBufferSizes.size());
+                                     const std::vector<uint64_t>& pipelineMinBufferSizes) {
+            ASSERT(unverifiedBufferSizes.size() == pipelineMinBufferSizes.size());
             for (uint32_t i = 0; i < unverifiedBufferSizes.size(); ++i) {
-                if (unverifiedBufferSizes[i] < pipelineMinimumBufferSizes[i]) {
+                if (unverifiedBufferSizes[i] < pipelineMinBufferSizes[i]) {
                     return false;
                 }
             }
@@ -105,7 +105,7 @@ namespace dawn_native {
                 if (mBindgroups[i] == nullptr ||
                     mLastPipelineLayout->GetBindGroupLayout(i) != mBindgroups[i]->GetLayout() ||
                     !BufferSizesAtLeastAsBig(mBindgroups[i]->GetUnverifiedBufferSizes(),
-                                             (*mMinimumBufferSizes)[i])) {
+                                             (*mMinBufferSizes)[i])) {
                     matches = false;
                     break;
                 }
@@ -190,7 +190,7 @@ namespace dawn_native {
                     "Pipeline and bind group layout doesn't match for bind group " +
                     std::to_string(static_cast<uint32_t>(i)));
             } else if (!BufferSizesAtLeastAsBig(mBindgroups[i]->GetUnverifiedBufferSizes(),
-                                                (*mMinimumBufferSizes)[i])) {
+                                                (*mMinBufferSizes)[i])) {
                 return DAWN_VALIDATION_ERROR("Binding sizes too small for bind group " +
                                              std::to_string(static_cast<uint32_t>(i)));
             }
@@ -236,7 +236,7 @@ namespace dawn_native {
     void CommandBufferStateTracker::SetPipelineCommon(PipelineBase* pipeline) {
         mLastPipelineLayout = pipeline->GetLayout();
-        mMinimumBufferSizes = &pipeline->GetMinimumBufferSizes();
+        mMinBufferSizes = &pipeline->GetMinBufferSizes();
         mAspects.set(VALIDATION_ASPECT_PIPELINE);

dawn_native/CommandBufferStateTracker.h

@@ -61,7 +61,7 @@ namespace dawn_native {
         PipelineLayoutBase* mLastPipelineLayout = nullptr;
         RenderPipelineBase* mLastRenderPipeline = nullptr;
-        const RequiredBufferSizes* mMinimumBufferSizes = nullptr;
+        const RequiredBufferSizes* mMinBufferSizes = nullptr;
     };

 }  // namespace dawn_native

dawn_native/ComputePipeline.cpp

@@ -19,13 +19,6 @@
 namespace dawn_native {

-    namespace {
-        RequiredBufferSizes ComputeMinBufferSizes(const ComputePipelineDescriptor* descriptor) {
-            return descriptor->computeStage.module->ComputeRequiredBufferSizesForLayout(
-                descriptor->layout);
-        }
-    }  // anonymous namespace
-
     MaybeError ValidateComputePipelineDescriptor(DeviceBase* device,
                                                  const ComputePipelineDescriptor* descriptor) {
         if (descriptor->nextInChain != nullptr) {
@@ -47,10 +40,7 @@ namespace dawn_native {
                                              const ComputePipelineDescriptor* descriptor)
         : PipelineBase(device,
                        descriptor->layout,
-                       wgpu::ShaderStage::Compute,
-                       ComputeMinBufferSizes(descriptor)),
-          mModule(descriptor->computeStage.module),
-          mEntryPoint(descriptor->computeStage.entryPoint) {
+                       {{SingleShaderStage::Compute, &descriptor->computeStage}}) {
     }

     ComputePipelineBase::ComputePipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag)
@@ -70,15 +60,12 @@ namespace dawn_native {
     }

     size_t ComputePipelineBase::HashFunc::operator()(const ComputePipelineBase* pipeline) const {
-        size_t hash = 0;
-        HashCombine(&hash, pipeline->mModule.Get(), pipeline->mEntryPoint, pipeline->GetLayout());
-        return hash;
+        return PipelineBase::HashForCache(pipeline);
     }

     bool ComputePipelineBase::EqualityFunc::operator()(const ComputePipelineBase* a,
                                                        const ComputePipelineBase* b) const {
-        return a->mModule.Get() == b->mModule.Get() && a->mEntryPoint == b->mEntryPoint &&
-               a->GetLayout() == b->GetLayout();
+        return PipelineBase::EqualForCache(a, b);
     }

 }  // namespace dawn_native

dawn_native/ComputePipeline.h

@@ -41,10 +41,6 @@ namespace dawn_native {
       private:
         ComputePipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag);
-
-        // TODO(cwallez@chromium.org): Store a crypto hash of the module instead.
-        Ref<ShaderModuleBase> mModule;
-        std::string mEntryPoint;
     };

 }  // namespace dawn_native

dawn_native/Pipeline.cpp

@@ -14,6 +14,7 @@

 #include "dawn_native/Pipeline.h"

+#include "common/HashUtils.h"
 #include "dawn_native/BindGroupLayout.h"
 #include "dawn_native/Device.h"
 #include "dawn_native/PipelineLayout.h"
@@ -43,23 +44,38 @@ namespace dawn_native {
     PipelineBase::PipelineBase(DeviceBase* device,
                                PipelineLayoutBase* layout,
-                               wgpu::ShaderStage stages,
-                               RequiredBufferSizes minimumBufferSizes)
-        : CachedObject(device),
-          mStageMask(stages),
-          mLayout(layout),
-          mMinimumBufferSizes(std::move(minimumBufferSizes)) {
+                               std::vector<StageAndDescriptor> stages)
+        : CachedObject(device), mLayout(layout) {
+        ASSERT(!stages.empty());
+
+        for (const StageAndDescriptor& stage : stages) {
+            bool isFirstStage = mStageMask == wgpu::ShaderStage::None;
+            mStageMask |= StageBit(stage.first);
+            mStages[stage.first] = {stage.second->module, stage.second->entryPoint};
+
+            // Compute the max() of all minBufferSizes across all stages.
+            RequiredBufferSizes stageMinBufferSizes =
+                stage.second->module->ComputeRequiredBufferSizesForLayout(layout);
+            if (isFirstStage) {
+                mMinBufferSizes = std::move(stageMinBufferSizes);
+            } else {
+                for (BindGroupIndex group(0); group < mMinBufferSizes.size(); ++group) {
+                    ASSERT(stageMinBufferSizes[group].size() == mMinBufferSizes[group].size());
+                    for (size_t i = 0; i < stageMinBufferSizes[group].size(); ++i) {
+                        mMinBufferSizes[group][i] =
+                            std::max(mMinBufferSizes[group][i], stageMinBufferSizes[group][i]);
+                    }
+                }
+            }
+        }
     }

     PipelineBase::PipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag)
         : CachedObject(device, tag) {
     }

+    wgpu::ShaderStage PipelineBase::GetStageMask() const {
+        ASSERT(!IsError());
+        return mStageMask;
+    }
+
     PipelineLayoutBase* PipelineBase::GetLayout() {
         ASSERT(!IsError());
         return mLayout.Get();
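The constructor loop above folds each stage's requirements into one per-pipeline table by taking the per-binding max(): if, say, the vertex stage needs 32 bytes in a buffer binding and the fragment stage needs 48, the pipeline must validate against 48. A self-contained sketch of that merge semantics, using nested std::vector in place of Dawn's RequiredBufferSizes and ityp containers (all names here are illustrative, not Dawn API):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Stand-in for RequiredBufferSizes: minimum byte size per binding, per bind group.
    using MinSizes = std::vector<std::vector<uint64_t>>;

    // Merge two stages' requirements by keeping the larger size for each binding,
    // mirroring the max() loop in the PipelineBase constructor above.
    MinSizes MergeMinBufferSizes(const MinSizes& a, const MinSizes& b) {
        assert(a.size() == b.size());
        MinSizes merged = a;
        for (size_t group = 0; group < merged.size(); ++group) {
            assert(a[group].size() == b[group].size());
            for (size_t i = 0; i < merged[group].size(); ++i) {
                merged[group][i] = std::max(merged[group][i], b[group][i]);
            }
        }
        return merged;
    }

    // Example: MergeMinBufferSizes({{16, 32}}, {{16, 48}}) yields {{16, 48}}.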
@@ -70,9 +86,14 @@ namespace dawn_native {
         return mLayout.Get();
     }

-    const RequiredBufferSizes& PipelineBase::GetMinimumBufferSizes() const {
+    const RequiredBufferSizes& PipelineBase::GetMinBufferSizes() const {
         ASSERT(!IsError());
-        return mMinimumBufferSizes;
+        return mMinBufferSizes;
+    }
+
+    const ProgrammableStage& PipelineBase::GetStage(SingleShaderStage stage) const {
+        ASSERT(!IsError());
+        return mStages[stage];
     }

     MaybeError PipelineBase::ValidateGetBindGroupLayout(uint32_t groupIndex) {
@@ -102,4 +123,39 @@ namespace dawn_native {
         return bgl;
     }

+    // static
+    size_t PipelineBase::HashForCache(const PipelineBase* pipeline) {
+        size_t hash = 0;
+
+        // The layout is deduplicated so it can be hashed by pointer.
+        HashCombine(&hash, pipeline->mLayout.Get());
+
+        HashCombine(&hash, pipeline->mStageMask);
+        for (SingleShaderStage stage : IterateStages(pipeline->mStageMask)) {
+            // The module is deduplicated so it can be hashed by pointer.
+            HashCombine(&hash, pipeline->mStages[stage].module.Get());
+            HashCombine(&hash, pipeline->mStages[stage].entryPoint);
+        }
+
+        return hash;
+    }
+
+    // static
+    bool PipelineBase::EqualForCache(const PipelineBase* a, const PipelineBase* b) {
+        // The layout is deduplicated so it can be compared by pointer.
+        if (a->mLayout.Get() != b->mLayout.Get() || a->mStageMask != b->mStageMask) {
+            return false;
+        }
+
+        for (SingleShaderStage stage : IterateStages(a->mStageMask)) {
+            // The module is deduplicated so it can be compared by pointer.
+            if (a->mStages[stage].module.Get() != b->mStages[stage].module.Get() ||
+                a->mStages[stage].entryPoint != b->mStages[stage].entryPoint) {
+                return false;
+            }
+        }
+
+        return true;
+    }
+
 }  // namespace dawn_native
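HashForCache and EqualForCache exist to back the functors of the per-device pipeline caches, as the ComputePipeline.cpp hunks above already show. A minimal sketch of the pattern, assuming Dawn's Pipeline.h is included and an unordered_set-style cache (the PipelineCache alias and functor names here are illustrative, not Dawn's actual cache types):

    #include <unordered_set>

    // Illustrative functors delegating to the shared helpers; Dawn's real
    // HashFunc/EqualityFunc live on ComputePipelineBase and RenderPipelineBase.
    struct PipelineHashFunc {
        size_t operator()(const PipelineBase* pipeline) const {
            return PipelineBase::HashForCache(pipeline);
        }
    };
    struct PipelineEqualityFunc {
        bool operator()(const PipelineBase* a, const PipelineBase* b) const {
            return PipelineBase::EqualForCache(a, b);
        }
    };
    using PipelineCache =
        std::unordered_set<PipelineBase*, PipelineHashFunc, PipelineEqualityFunc>;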

dawn_native/Pipeline.h

@@ -33,27 +33,40 @@ namespace dawn_native {
                                               const PipelineLayoutBase* layout,
                                               SingleShaderStage stage);

+    struct ProgrammableStage {
+        Ref<ShaderModuleBase> module;
+        std::string entryPoint;
+    };
+
     class PipelineBase : public CachedObject {
       public:
+        wgpu::ShaderStage GetStageMask() const;
         PipelineLayoutBase* GetLayout();
         const PipelineLayoutBase* GetLayout() const;
+        const RequiredBufferSizes& GetMinBufferSizes() const;
+        const ProgrammableStage& GetStage(SingleShaderStage stage) const;

         BindGroupLayoutBase* GetBindGroupLayout(uint32_t groupIndex);
-        const RequiredBufferSizes& GetMinimumBufferSizes() const;
+
+        // Helper function for the functors for std::unordered_map-based pipeline caches.
+        static size_t HashForCache(const PipelineBase* pipeline);
+        static bool EqualForCache(const PipelineBase* a, const PipelineBase* b);

       protected:
+        using StageAndDescriptor = std::pair<SingleShaderStage, const ProgrammableStageDescriptor*>;
+
         PipelineBase(DeviceBase* device,
                      PipelineLayoutBase* layout,
-                     wgpu::ShaderStage stages,
-                     RequiredBufferSizes bufferSizes);
+                     std::vector<StageAndDescriptor> stages);
         PipelineBase(DeviceBase* device, ObjectBase::ErrorTag tag);

       private:
         MaybeError ValidateGetBindGroupLayout(uint32_t group);

-        wgpu::ShaderStage mStageMask;
+        wgpu::ShaderStage mStageMask = wgpu::ShaderStage::None;
+        PerStage<ProgrammableStage> mStages;
         Ref<PipelineLayoutBase> mLayout;
-        RequiredBufferSizes mMinimumBufferSizes;
+        RequiredBufferSizes mMinBufferSizes;
     };

 }  // namespace dawn_native
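With the stage data centralized behind GetStageMask() and GetStage(), consumers can treat compute and render pipelines uniformly instead of special-casing each. A hedged sketch of that consumption pattern, assuming Dawn's internal headers; CompileShader is a hypothetical placeholder, while SingleShaderStage, IterateStages, and ProgrammableStage are the names used in the diffs:

    // Hypothetical helper: visit every stage present on a pipeline, with no
    // vertex/fragment/compute special-casing.
    void CompileAllStages(const PipelineBase* pipeline) {
        for (SingleShaderStage stage : IterateStages(pipeline->GetStageMask())) {
            const ProgrammableStage& programmableStage = pipeline->GetStage(stage);
            // CompileShader stands in for backend-specific module compilation.
            CompileShader(programmableStage.module.Get(), programmableStage.entryPoint);
        }
    }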

dawn_native/RenderPipeline.cpp

@@ -193,29 +193,6 @@ namespace dawn_native {
             return {};
         }

-        RequiredBufferSizes ComputeMinBufferSizes(const RenderPipelineDescriptor* descriptor) {
-            RequiredBufferSizes bufferSizes =
-                descriptor->vertexStage.module->ComputeRequiredBufferSizesForLayout(
-                    descriptor->layout);
-
-            // Merge the two buffer size requirements by taking the larger element from each
-            if (descriptor->fragmentStage != nullptr) {
-                RequiredBufferSizes fragmentSizes =
-                    descriptor->fragmentStage->module->ComputeRequiredBufferSizesForLayout(
-                        descriptor->layout);
-
-                for (BindGroupIndex group(0); group < bufferSizes.size(); ++group) {
-                    ASSERT(bufferSizes[group].size() == fragmentSizes[group].size());
-                    for (size_t i = 0; i < bufferSizes[group].size(); ++i) {
-                        bufferSizes[group][i] =
-                            std::max(bufferSizes[group][i], fragmentSizes[group][i]);
-                    }
-                }
-            }
-
-            return bufferSizes;
-        }
-
     }  // anonymous namespace

     // Helper functions
@@ -411,16 +388,12 @@ namespace dawn_native {
                                            const RenderPipelineDescriptor* descriptor)
         : PipelineBase(device,
                        descriptor->layout,
-                       wgpu::ShaderStage::Vertex | wgpu::ShaderStage::Fragment,
-                       ComputeMinBufferSizes(descriptor)),
+                       {{SingleShaderStage::Vertex, &descriptor->vertexStage},
+                        {SingleShaderStage::Fragment, descriptor->fragmentStage}}),
           mAttachmentState(device->GetOrCreateAttachmentState(descriptor)),
           mPrimitiveTopology(descriptor->primitiveTopology),
           mSampleMask(descriptor->sampleMask),
-          mAlphaToCoverageEnabled(descriptor->alphaToCoverageEnabled),
-          mVertexModule(descriptor->vertexStage.module),
-          mVertexEntryPoint(descriptor->vertexStage.entryPoint),
-          mFragmentModule(descriptor->fragmentStage->module),
-          mFragmentEntryPoint(descriptor->fragmentStage->entryPoint) {
+          mAlphaToCoverageEnabled(descriptor->alphaToCoverageEnabled) {
         if (descriptor->vertexState != nullptr) {
             mVertexState = *descriptor->vertexState;
         } else {
@@ -608,12 +581,8 @@ namespace dawn_native {
     }

     size_t RenderPipelineBase::HashFunc::operator()(const RenderPipelineBase* pipeline) const {
-        size_t hash = 0;
-
         // Hash modules and layout
-        HashCombine(&hash, pipeline->GetLayout());
-        HashCombine(&hash, pipeline->mVertexModule.Get(), pipeline->mFragmentEntryPoint);
-        HashCombine(&hash, pipeline->mFragmentModule.Get(), pipeline->mFragmentEntryPoint);
+        size_t hash = PipelineBase::HashForCache(pipeline);

         // Hierarchically hash the attachment state.
         // It contains the attachments set, texture formats, and sample count.
@@ -671,11 +640,8 @@ namespace dawn_native {
     bool RenderPipelineBase::EqualityFunc::operator()(const RenderPipelineBase* a,
                                                       const RenderPipelineBase* b) const {
-        // Check modules and layout
-        if (a->GetLayout() != b->GetLayout() || a->mVertexModule.Get() != b->mVertexModule.Get() ||
-            a->mVertexEntryPoint != b->mVertexEntryPoint ||
-            a->mFragmentModule.Get() != b->mFragmentModule.Get() ||
-            a->mFragmentEntryPoint != b->mFragmentEntryPoint) {
+        // Check the layout and shader stages.
+        if (!PipelineBase::EqualForCache(a, b)) {
             return false;
         }

dawn_native/RenderPipeline.h

@@ -114,13 +114,6 @@ namespace dawn_native {
         RasterizationStateDescriptor mRasterizationState;
         uint32_t mSampleMask;
         bool mAlphaToCoverageEnabled;
-
-        // Stage information
-        // TODO(cwallez@chromium.org): Store a crypto hash of the modules instead.
-        Ref<ShaderModuleBase> mVertexModule;
-        std::string mVertexEntryPoint;
-        Ref<ShaderModuleBase> mFragmentModule;
-        std::string mFragmentEntryPoint;
     };

 }  // namespace dawn_native