Make ShaderModule reflection go through EntryPointMetadata

PipelineBase now collects the EntryPointMetadata for all its
stages which makes the rest of the code agnostic to the entrypoint
name (except D3D12 and OpenGL that required transition hacks and
will be fixed in follow-up CLs).

Bug: dawn:216

Change-Id: I643da198cb2a20a9d94d805a2dc783d6d4346ae9
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/27260
Reviewed-by: Ryan Harrison <rharrison@chromium.org>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
This commit is contained in:
Corentin Wallez 2020-09-02 15:57:39 +00:00 committed by Commit Bot service account
parent e9bc506e0a
commit 8ec8f31e3b
15 changed files with 167 additions and 137 deletions

View File

@ -20,6 +20,7 @@
namespace dawn_native {
class DeviceBase;
struct EntryPointMetadata;
MaybeError ValidateComputePipelineDescriptor(DeviceBase* device,
const ComputePipelineDescriptor* descriptor);
@ -31,6 +32,8 @@ namespace dawn_native {
static ComputePipelineBase* MakeError(DeviceBase* device);
const EntryPointMetadata& GetMetadata() const;
// Functors necessary for the unordered_set<ComputePipelineBase*>-based cache.
struct HashFunc {
size_t operator()(const ComputePipelineBase* pipeline) const;

View File

@ -877,9 +877,9 @@ namespace dawn_native {
if (descriptor->layout == nullptr) {
ComputePipelineDescriptor descriptorWithDefaultLayout = *descriptor;
DAWN_TRY_ASSIGN(
descriptorWithDefaultLayout.layout,
PipelineLayoutBase::CreateDefault(this, &descriptor->computeStage.module, 1));
DAWN_TRY_ASSIGN(descriptorWithDefaultLayout.layout,
PipelineLayoutBase::CreateDefault(
this, {{SingleShaderStage::Compute, &descriptor->computeStage}}));
// Ref will keep the pipeline layout alive until the end of the function where
// the pipeline will take another reference.
Ref<PipelineLayoutBase> layoutRef = AcquireRef(descriptorWithDefaultLayout.layout);
@ -934,18 +934,14 @@ namespace dawn_native {
if (descriptor->layout == nullptr) {
RenderPipelineDescriptor descriptorWithDefaultLayout = *descriptor;
const ShaderModuleBase* modules[2];
modules[0] = descriptor->vertexStage.module;
uint32_t count;
if (descriptor->fragmentStage == nullptr) {
count = 1;
} else {
modules[1] = descriptor->fragmentStage->module;
count = 2;
std::vector<StageAndDescriptor> stages;
stages.emplace_back(SingleShaderStage::Vertex, &descriptor->vertexStage);
if (descriptor->fragmentStage != nullptr) {
stages.emplace_back(SingleShaderStage::Fragment, descriptor->fragmentStage);
}
DAWN_TRY_ASSIGN(descriptorWithDefaultLayout.layout,
PipelineLayoutBase::CreateDefault(this, modules, count));
PipelineLayoutBase::CreateDefault(this, std::move(stages)));
// Ref will keep the pipeline layout alive until the end of the function where
// the pipeline will take another reference.
Ref<PipelineLayoutBase> layoutRef = AcquireRef(descriptorWithDefaultLayout.layout);

View File

@ -26,16 +26,17 @@ namespace dawn_native {
const ProgrammableStageDescriptor* descriptor,
const PipelineLayoutBase* layout,
SingleShaderStage stage) {
DAWN_TRY(device->ValidateObject(descriptor->module));
const ShaderModuleBase* module = descriptor->module;
DAWN_TRY(device->ValidateObject(module));
if (descriptor->entryPoint != std::string("main")) {
return DAWN_VALIDATION_ERROR("Entry point must be \"main\"");
}
if (descriptor->module->GetExecutionModel() != stage) {
return DAWN_VALIDATION_ERROR("Setting module with wrong stages");
if (!module->HasEntryPoint(descriptor->entryPoint, stage)) {
return DAWN_VALIDATION_ERROR("Entry point doesn't exist in the module");
}
if (layout != nullptr) {
DAWN_TRY(descriptor->module->ValidateCompatibilityWithPipelineLayout(layout));
const EntryPointMetadata& metadata =
module->GetEntryPoint(descriptor->entryPoint, stage);
DAWN_TRY(ValidateCompatibilityWithPipelineLayout(metadata, layout));
}
return {};
}
@ -49,13 +50,20 @@ namespace dawn_native {
ASSERT(!stages.empty());
for (const StageAndDescriptor& stage : stages) {
// Extract argument for this stage.
SingleShaderStage shaderStage = stage.first;
ShaderModuleBase* module = stage.second->module;
const char* entryPointName = stage.second->entryPoint;
const EntryPointMetadata& metadata = module->GetEntryPoint(entryPointName, shaderStage);
// Record them internally.
bool isFirstStage = mStageMask == wgpu::ShaderStage::None;
mStageMask |= StageBit(stage.first);
mStages[stage.first] = {stage.second->module, stage.second->entryPoint};
mStageMask |= StageBit(shaderStage);
mStages[shaderStage] = {module, entryPointName, &metadata};
// Compute the max() of all minBufferSizes across all stages.
RequiredBufferSizes stageMinBufferSizes =
stage.second->module->ComputeRequiredBufferSizesForLayout(layout);
ComputeRequiredBufferSizesForLayout(metadata, layout);
if (isFirstStage) {
mMinBufferSizes = std::move(stageMinBufferSizes);

View File

@ -36,6 +36,9 @@ namespace dawn_native {
struct ProgrammableStage {
Ref<ShaderModuleBase> module;
std::string entryPoint;
// The metadata lives as long as module, that's ref-ed in the same structure.
const EntryPointMetadata* metadata = nullptr;
};
class PipelineBase : public CachedObject {
@ -52,8 +55,6 @@ namespace dawn_native {
static bool EqualForCache(const PipelineBase* a, const PipelineBase* b);
protected:
using StageAndDescriptor = std::pair<SingleShaderStage, const ProgrammableStageDescriptor*>;
PipelineBase(DeviceBase* device,
PipelineLayoutBase* layout,
std::vector<StageAndDescriptor> stages);

View File

@ -114,9 +114,8 @@ namespace dawn_native {
// static
ResultOrError<PipelineLayoutBase*> PipelineLayoutBase::CreateDefault(
DeviceBase* device,
const ShaderModuleBase* const* modules,
uint32_t count) {
ASSERT(count > 0);
std::vector<StageAndDescriptor> stages) {
ASSERT(!stages.empty());
// Data which BindGroupLayoutDescriptor will point to for creation
ityp::array<
@ -134,20 +133,22 @@ namespace dawn_native {
BindingCounts bindingCounts = {};
BindGroupIndex bindGroupLayoutCount(0);
for (uint32_t moduleIndex = 0; moduleIndex < count; ++moduleIndex) {
const ShaderModuleBase* module = modules[moduleIndex];
const ShaderModuleBase::ModuleBindingInfo& info = module->GetBindingInfo();
for (const StageAndDescriptor& stage : stages) {
// Extract argument for this stage.
SingleShaderStage shaderStage = stage.first;
const EntryPointMetadata::BindingInfo& info =
stage.second->module->GetEntryPoint(stage.second->entryPoint, shaderStage).bindings;
for (BindGroupIndex group(0); group < info.size(); ++group) {
for (const auto& it : info[group]) {
BindingNumber bindingNumber = it.first;
const ShaderModuleBase::ShaderBindingInfo& bindingInfo = it.second;
const EntryPointMetadata::ShaderBindingInfo& bindingInfo = it.second;
BindGroupLayoutEntry bindingSlot;
bindingSlot.binding = static_cast<uint32_t>(bindingNumber);
DAWN_TRY(ValidateBindingTypeWithShaderStageVisibility(
bindingInfo.type, StageBit(module->GetExecutionModel())));
DAWN_TRY(ValidateBindingTypeWithShaderStageVisibility(bindingInfo.type,
StageBit(shaderStage)));
DAWN_TRY(ValidateStorageTextureFormat(device, bindingInfo.type,
bindingInfo.storageTextureFormat));
DAWN_TRY(ValidateStorageTextureViewDimension(bindingInfo.type,
@ -239,10 +240,10 @@ namespace dawn_native {
}
}
for (uint32_t moduleIndex = 0; moduleIndex < count; ++moduleIndex) {
ASSERT(modules[moduleIndex]
->ValidateCompatibilityWithPipelineLayout(pipelineLayout)
.IsSuccess());
for (const StageAndDescriptor& stage : stages) {
const EntryPointMetadata& metadata =
stage.second->module->GetEntryPoint(stage.second->entryPoint, stage.first);
ASSERT(ValidateCompatibilityWithPipelineLayout(metadata, pipelineLayout).IsSuccess());
}
return pipelineLayout;

View File

@ -37,14 +37,17 @@ namespace dawn_native {
ityp::array<BindGroupIndex, Ref<BindGroupLayoutBase>, kMaxBindGroups>;
using BindGroupLayoutMask = ityp::bitset<BindGroupIndex, kMaxBindGroups>;
using StageAndDescriptor = std::pair<SingleShaderStage, const ProgrammableStageDescriptor*>;
class PipelineLayoutBase : public CachedObject {
public:
PipelineLayoutBase(DeviceBase* device, const PipelineLayoutDescriptor* descriptor);
~PipelineLayoutBase() override;
static PipelineLayoutBase* MakeError(DeviceBase* device);
static ResultOrError<PipelineLayoutBase*>
CreateDefault(DeviceBase* device, const ShaderModuleBase* const* modules, uint32_t count);
static ResultOrError<PipelineLayoutBase*> CreateDefault(
DeviceBase* device,
std::vector<StageAndDescriptor> stages);
const BindGroupLayoutBase* GetBindGroupLayout(BindGroupIndex group) const;
BindGroupLayoutBase* GetBindGroupLayout(BindGroupIndex group);

View File

@ -333,8 +333,9 @@ namespace dawn_native {
DAWN_TRY(ValidateRasterizationStateDescriptor(descriptor->rasterizationState));
}
if ((descriptor->vertexStage.module->GetUsedVertexAttributes() & ~attributesSetMask)
.any()) {
const EntryPointMetadata& vertexMetadata = descriptor->vertexStage.module->GetEntryPoint(
descriptor->vertexStage.entryPoint, SingleShaderStage::Vertex);
if ((vertexMetadata.usedVertexAttributes & ~attributesSetMask).any()) {
return DAWN_VALIDATION_ERROR(
"Pipeline vertex stage uses vertex buffers not in the vertex state");
}
@ -352,11 +353,13 @@ namespace dawn_native {
}
ASSERT(descriptor->fragmentStage != nullptr);
const ShaderModuleBase::FragmentOutputBaseTypes& fragmentOutputBaseTypes =
descriptor->fragmentStage->module->GetFragmentOutputBaseTypes();
const EntryPointMetadata& fragmentMetadata =
descriptor->fragmentStage->module->GetEntryPoint(descriptor->fragmentStage->entryPoint,
SingleShaderStage::Fragment);
for (uint32_t i = 0; i < descriptor->colorStateCount; ++i) {
DAWN_TRY(ValidateColorStateDescriptor(device, descriptor->colorStates[i],
fragmentOutputBaseTypes[i]));
DAWN_TRY(
ValidateColorStateDescriptor(device, descriptor->colorStates[i],
fragmentMetadata.fragmentOutputFormatBaseTypes[i]));
}
if (descriptor->depthStencilState) {

View File

@ -28,6 +28,7 @@ namespace dawn_native {
struct BeginRenderPassCmd;
class DeviceBase;
struct EntryPointMetadata;
class RenderBundleEncoder;
MaybeError ValidateRenderPipelineDescriptor(DeviceBase* device,

View File

@ -549,7 +549,7 @@ namespace dawn_native {
#endif // DAWN_ENABLE_WGSL
std::vector<uint64_t> GetBindGroupMinBufferSizes(
const ShaderModuleBase::BindingInfoMap& shaderBindings,
const EntryPointMetadata::BindingGroupInfoMap& shaderBindings,
const BindGroupLayoutBase* layout) {
std::vector<uint64_t> requiredBufferSizes(layout->GetUnverifiedBufferCount());
uint32_t packedIdx = 0;
@ -578,17 +578,16 @@ namespace dawn_native {
return requiredBufferSizes;
}
MaybeError ValidateCompatibilityWithBindGroupLayout(
BindGroupIndex group,
const ShaderModuleBase::EntryPointMetadata& entryPoint,
const BindGroupLayoutBase* layout) {
MaybeError ValidateCompatibilityWithBindGroupLayout(BindGroupIndex group,
const EntryPointMetadata& entryPoint,
const BindGroupLayoutBase* layout) {
const BindGroupLayoutBase::BindingMap& layoutBindings = layout->GetBindingMap();
// Iterate over all bindings used by this group in the shader, and find the
// corresponding binding in the BindGroupLayout, if it exists.
for (const auto& it : entryPoint.bindings[group]) {
BindingNumber bindingNumber = it.first;
const ShaderModuleBase::ShaderBindingInfo& shaderInfo = it.second;
const EntryPointMetadata::ShaderBindingInfo& shaderInfo = it.second;
const auto& bindingIt = layoutBindings.find(bindingNumber);
if (bindingIt == layoutBindings.end()) {
@ -732,9 +731,8 @@ namespace dawn_native {
return {};
}
RequiredBufferSizes ComputeRequiredBufferSizesForLayout(
const ShaderModuleBase::EntryPointMetadata& entryPoint,
const PipelineLayoutBase* layout) {
RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint,
const PipelineLayoutBase* layout) {
RequiredBufferSizes bufferSizes;
for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
bufferSizes[group] = GetBindGroupMinBufferSizes(entryPoint.bindings[group],
@ -744,9 +742,8 @@ namespace dawn_native {
return bufferSizes;
}
MaybeError ValidateCompatibilityWithPipelineLayout(
const ShaderModuleBase::EntryPointMetadata& entryPoint,
const PipelineLayoutBase* layout) {
MaybeError ValidateCompatibilityWithPipelineLayout(const EntryPointMetadata& entryPoint,
const PipelineLayoutBase* layout) {
for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
DAWN_TRY(ValidateCompatibilityWithBindGroupLayout(group, entryPoint,
layout->GetBindGroupLayout(group)));
@ -766,7 +763,7 @@ namespace dawn_native {
// EntryPointMetadata
ShaderModuleBase::EntryPointMetadata::EntryPointMetadata() {
EntryPointMetadata::EntryPointMetadata() {
fragmentOutputFormatBaseTypes.fill(Format::Type::Other);
}
@ -814,6 +811,20 @@ namespace dawn_native {
return new ShaderModuleBase(device, ObjectBase::kError);
}
bool ShaderModuleBase::HasEntryPoint(const std::string& entryPoint,
SingleShaderStage stage) const {
// TODO(dawn:216): Properly extract all entryPoints from the shader module.
return entryPoint == "main" && stage == mMainEntryPoint->stage;
}
const EntryPointMetadata& ShaderModuleBase::GetEntryPoint(const std::string& entryPoint,
SingleShaderStage stage) const {
// TODO(dawn:216): Properly extract all entryPoints from the shader module.
ASSERT(entryPoint == "main");
ASSERT(stage == mMainEntryPoint->stage);
return *mMainEntryPoint;
}
MaybeError ShaderModuleBase::ExtractSpirvInfo(const spirv_cross::Compiler& compiler) {
ASSERT(!IsError());
if (GetDevice()->IsToggleEnabled(Toggle::UseSpvc)) {
@ -824,7 +835,7 @@ namespace dawn_native {
return {};
}
ResultOrError<std::unique_ptr<ShaderModuleBase::EntryPointMetadata>>
ResultOrError<std::unique_ptr<EntryPointMetadata>>
ShaderModuleBase::ExtractSpirvInfoWithSpvc() {
DeviceBase* device = GetDevice();
std::unique_ptr<EntryPointMetadata> metadata = std::make_unique<EntryPointMetadata>();
@ -848,7 +859,7 @@ namespace dawn_native {
// Fill in bindingInfo with the SPIRV bindings
auto ExtractResourcesBinding =
[](const DeviceBase* device, const std::vector<shaderc_spvc_binding_info>& spvcBindings,
ModuleBindingInfo* metadataBindings) -> MaybeError {
EntryPointMetadata::BindingInfo* metadataBindings) -> MaybeError {
for (const shaderc_spvc_binding_info& binding : spvcBindings) {
BindGroupIndex bindGroupIndex(binding.set);
@ -857,12 +868,12 @@ namespace dawn_native {
}
const auto& it = (*metadataBindings)[bindGroupIndex].emplace(
BindingNumber(binding.binding), ShaderBindingInfo{});
BindingNumber(binding.binding), EntryPointMetadata::ShaderBindingInfo{});
if (!it.second) {
return DAWN_VALIDATION_ERROR("Shader has duplicate bindings");
}
ShaderBindingInfo* info = &it.first->second;
EntryPointMetadata::ShaderBindingInfo* info = &it.first->second;
info->id = binding.id;
info->base_type_id = binding.base_type_id;
info->type = ToWGPUBindingType(binding.binding_type);
@ -994,7 +1005,7 @@ namespace dawn_native {
return {std::move(metadata)};
}
ResultOrError<std::unique_ptr<ShaderModuleBase::EntryPointMetadata>>
ResultOrError<std::unique_ptr<EntryPointMetadata>>
ShaderModuleBase::ExtractSpirvInfoWithSpirvCross(const spirv_cross::Compiler& compiler) {
DeviceBase* device = GetDevice();
std::unique_ptr<EntryPointMetadata> metadata = std::make_unique<EntryPointMetadata>();
@ -1031,7 +1042,7 @@ namespace dawn_native {
[](const DeviceBase* device,
const spirv_cross::SmallVector<spirv_cross::Resource>& resources,
const spirv_cross::Compiler& compiler, wgpu::BindingType bindingType,
ModuleBindingInfo* metadataBindings) -> MaybeError {
EntryPointMetadata::BindingInfo* metadataBindings) -> MaybeError {
for (const auto& resource : resources) {
if (!compiler.get_decoration_bitset(resource.id).get(spv::DecorationBinding)) {
return DAWN_VALIDATION_ERROR("No Binding decoration set for resource");
@ -1051,13 +1062,13 @@ namespace dawn_native {
return DAWN_VALIDATION_ERROR("Bind group index over limits in the SPIRV");
}
const auto& it =
(*metadataBindings)[bindGroupIndex].emplace(bindingNumber, ShaderBindingInfo{});
const auto& it = (*metadataBindings)[bindGroupIndex].emplace(
bindingNumber, EntryPointMetadata::ShaderBindingInfo{});
if (!it.second) {
return DAWN_VALIDATION_ERROR("Shader has duplicate bindings");
}
ShaderBindingInfo* info = &it.first->second;
EntryPointMetadata::ShaderBindingInfo* info = &it.first->second;
info->id = resource.id;
info->base_type_id = resource.base_type_id;
@ -1204,39 +1215,6 @@ namespace dawn_native {
return {std::move(metadata)};
}
const ShaderModuleBase::ModuleBindingInfo& ShaderModuleBase::GetBindingInfo() const {
ASSERT(!IsError());
return mMainEntryPoint->bindings;
}
const std::bitset<kMaxVertexAttributes>& ShaderModuleBase::GetUsedVertexAttributes() const {
ASSERT(!IsError());
return mMainEntryPoint->usedVertexAttributes;
}
const ShaderModuleBase::FragmentOutputBaseTypes& ShaderModuleBase::GetFragmentOutputBaseTypes()
const {
ASSERT(!IsError());
return mMainEntryPoint->fragmentOutputFormatBaseTypes;
}
SingleShaderStage ShaderModuleBase::GetExecutionModel() const {
ASSERT(!IsError());
return mMainEntryPoint->stage;
}
RequiredBufferSizes ShaderModuleBase::ComputeRequiredBufferSizesForLayout(
const PipelineLayoutBase* layout) const {
ASSERT(!IsError());
return ::dawn_native::ComputeRequiredBufferSizesForLayout(*mMainEntryPoint, layout);
}
MaybeError ShaderModuleBase::ValidateCompatibilityWithPipelineLayout(
const PipelineLayoutBase* layout) const {
ASSERT(!IsError());
return ::dawn_native::ValidateCompatibilityWithPipelineLayout(*mMainEntryPoint, layout);
}
size_t ShaderModuleBase::HashFunc::operator()(const ShaderModuleBase* module) const {
size_t hash = 0;
@ -1298,4 +1276,10 @@ namespace dawn_native {
return {};
}
SingleShaderStage ShaderModuleBase::GetMainEntryPointStageForTransition() const {
ASSERT(!IsError());
return mMainEntryPoint->stage;
}
} // namespace dawn_native

View File

@ -38,18 +38,25 @@ namespace spirv_cross {
namespace dawn_native {
struct EntryPointMetadata;
MaybeError ValidateShaderModuleDescriptor(DeviceBase* device,
const ShaderModuleDescriptor* descriptor);
MaybeError ValidateCompatibilityWithPipelineLayout(const EntryPointMetadata& entryPoint,
const PipelineLayoutBase* layout);
class ShaderModuleBase : public CachedObject {
public:
ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor);
~ShaderModuleBase() override;
RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint,
const PipelineLayoutBase* layout);
static ShaderModuleBase* MakeError(DeviceBase* device);
MaybeError ExtractSpirvInfo(const spirv_cross::Compiler& compiler);
// Contains all the reflection data for a valid (ShaderModule, entryPoint, stage). They are
// stored in the ShaderModuleBase and destroyed only when the shader module is destroyed so
// pointers to EntryPointMetadata are safe to store as long as you also keep a Ref to the
// ShaderModuleBase.
struct EntryPointMetadata {
EntryPointMetadata();
// Per-binding shader metadata contains some SPIRV specific information in addition to
// most of the frontend per-binding information.
struct ShaderBindingInfo : BindingInfo {
// The SPIRV ID of the resource.
uint32_t id;
@ -61,22 +68,42 @@ namespace dawn_native {
using BindingInfo::visibility;
};
using BindingInfoMap = std::map<BindingNumber, ShaderBindingInfo>;
using ModuleBindingInfo = ityp::array<BindGroupIndex, BindingInfoMap, kMaxBindGroups>;
// bindings[G][B] is the reflection data for the binding defined with
// [[group=G, binding=B]] in WGSL / SPIRV.
using BindingGroupInfoMap = std::map<BindingNumber, ShaderBindingInfo>;
using BindingInfo = ityp::array<BindGroupIndex, BindingGroupInfoMap, kMaxBindGroups>;
BindingInfo bindings;
const ModuleBindingInfo& GetBindingInfo() const;
const std::bitset<kMaxVertexAttributes>& GetUsedVertexAttributes() const;
SingleShaderStage GetExecutionModel() const;
// The set of vertex attributes this entryPoint uses.
std::bitset<kMaxVertexAttributes> usedVertexAttributes;
// An array to record the basic types (float, int and uint) of the fragment shader outputs
// or Format::Type::Other means the fragment shader output is unused.
using FragmentOutputBaseTypes = std::array<Format::Type, kMaxColorAttachments>;
const FragmentOutputBaseTypes& GetFragmentOutputBaseTypes() const;
FragmentOutputBaseTypes fragmentOutputFormatBaseTypes;
MaybeError ValidateCompatibilityWithPipelineLayout(const PipelineLayoutBase* layout) const;
// The shader stage for this binding, TODO(dawn:216): can likely be removed once we
// properly support multiple entrypoints per ShaderModule.
SingleShaderStage stage;
};
RequiredBufferSizes ComputeRequiredBufferSizesForLayout(
const PipelineLayoutBase* layout) const;
class ShaderModuleBase : public CachedObject {
public:
ShaderModuleBase(DeviceBase* device, const ShaderModuleDescriptor* descriptor);
~ShaderModuleBase() override;
static ShaderModuleBase* MakeError(DeviceBase* device);
// Return true iff the module has an entrypoint called `entryPoint` for stage `stage`.
bool HasEntryPoint(const std::string& entryPoint, SingleShaderStage stage) const;
// Returns the metadata for the given `entryPoint` and `stage`. HasEntryPoint with the same
// arguments must be true.
const EntryPointMetadata& GetEntryPoint(const std::string& entryPoint,
SingleShaderStage stage) const;
// TODO make this member protected, it is only used outside of child classes in DeviceNull.
MaybeError ExtractSpirvInfo(const spirv_cross::Compiler& compiler);
// Functors necessary for the unordered_set<ShaderModuleBase*>-based cache.
struct HashFunc {
@ -96,15 +123,6 @@ namespace dawn_native {
uint32_t pullingBufferBindingSet) const;
#endif
struct EntryPointMetadata {
EntryPointMetadata();
ModuleBindingInfo bindings;
std::bitset<kMaxVertexAttributes> usedVertexAttributes;
SingleShaderStage stage;
FragmentOutputBaseTypes fragmentOutputFormatBaseTypes;
};
protected:
static MaybeError CheckSpvcSuccess(shaderc_spvc_status status, const char* error_msg);
shaderc_spvc::CompileOptions GetCompileOptions() const;
@ -112,6 +130,11 @@ namespace dawn_native {
shaderc_spvc::Context mSpvcContext;
// Allows backends to get the stage for the "main" entrypoint while they are transitioned to
// support multiple entrypoints.
// TODO(dawn:216): Remove this once the transition is complete.
SingleShaderStage GetMainEntryPointStageForTransition() const;
private:
ShaderModuleBase(DeviceBase* device, ObjectBase::ErrorTag tag);

View File

@ -170,13 +170,15 @@ namespace dawn_native { namespace d3d12 {
compiler->set_hlsl_options(options_hlsl);
}
const ModuleBindingInfo& moduleBindingInfo = GetBindingInfo();
const EntryPointMetadata::BindingInfo& moduleBindingInfo =
GetEntryPoint("main", GetMainEntryPointStageForTransition()).bindings;
for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
const auto& bindingOffsets = bgl->GetBindingOffsets();
const auto& groupBindingInfo = moduleBindingInfo[group];
for (const auto& it : groupBindingInfo) {
const ShaderBindingInfo& bindingInfo = it.second;
const EntryPointMetadata::ShaderBindingInfo& bindingInfo = it.second;
BindingNumber bindingNumber = it.first;
BindingIndex bindingIndex = bgl->GetBindingIndex(bindingNumber);

View File

@ -368,8 +368,8 @@ namespace dawn_native { namespace metal {
}
}
const ShaderModuleBase::FragmentOutputBaseTypes& fragmentOutputBaseTypes =
descriptor->fragmentStage->module->GetFragmentOutputBaseTypes();
const EntryPointMetadata::FragmentOutputBaseTypes& fragmentOutputBaseTypes =
GetStage(SingleShaderStage::Fragment).metadata->fragmentOutputFormatBaseTypes;
for (uint32_t i : IterateBitSet(GetColorAttachmentsMask())) {
descriptorMTL.colorAttachments[i].pixelFormat =
MetalPixelFormat(GetColorAttachmentFormat(i));

View File

@ -259,11 +259,15 @@ namespace dawn_native { namespace metal {
// TODO(kainino@chromium.org): make this somehow more robust; it needs to behave like
// clean_func_name:
// https://github.com/KhronosGroup/SPIRV-Cross/blob/4e915e8c483e319d0dd7a1fa22318bef28f8cca3/spirv_msl.cpp#L1213
if (strcmp(functionName, "main") == 0) {
functionName = "main0";
const char* metalFunctionName = functionName;
if (strcmp(metalFunctionName, "main") == 0) {
metalFunctionName = "main0";
}
if (strcmp(metalFunctionName, "saturate") == 0) {
metalFunctionName = "saturate0";
}
NSString* name = [[NSString alloc] initWithUTF8String:functionName];
NSString* name = [[NSString alloc] initWithUTF8String:metalFunctionName];
out->function = [library newFunctionWithName:name];
[library release];
}
@ -277,7 +281,7 @@ namespace dawn_native { namespace metal {
}
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
functionStage == SingleShaderStage::Vertex && GetUsedVertexAttributes().any()) {
GetEntryPoint(functionName, functionStage).usedVertexAttributes.any()) {
out->needsStorageBufferLength = true;
}

View File

@ -125,8 +125,6 @@ namespace dawn_native { namespace opengl {
DAWN_TRY(ExtractSpirvInfo(*compiler));
const ShaderModuleBase::ModuleBindingInfo& bindingInfo = GetBindingInfo();
// Extract bindings names so that it can be used to get its location in program.
// Now translate the separate sampler / textures into combined ones and store their info.
// We need to do this before removing the set and binding decorations.
@ -182,6 +180,9 @@ namespace dawn_native { namespace opengl {
}
}
const EntryPointMetadata::BindingInfo& bindingInfo =
GetEntryPoint("main", GetMainEntryPointStageForTransition()).bindings;
// Change binding names to be "dawn_binding_<group>_<binding>".
// Also unsets the SPIRV "Binding" decoration as it outputs "layout(binding=)" which
// isn't supported on OSX's OpenGL.

View File

@ -425,8 +425,8 @@ namespace dawn_native { namespace vulkan {
// Initialize the "blend state info" that will be chained in the "create info" from the data
// pre-computed in the ColorState
std::array<VkPipelineColorBlendAttachmentState, kMaxColorAttachments> colorBlendAttachments;
const ShaderModuleBase::FragmentOutputBaseTypes& fragmentOutputBaseTypes =
descriptor->fragmentStage->module->GetFragmentOutputBaseTypes();
const EntryPointMetadata::FragmentOutputBaseTypes& fragmentOutputBaseTypes =
GetStage(SingleShaderStage::Fragment).metadata->fragmentOutputFormatBaseTypes;
for (uint32_t i : IterateBitSet(GetColorAttachmentsMask())) {
const ColorStateDescriptor* colorStateDescriptor = GetColorStateDescriptor(i);
bool isDeclaredInFragmentShader = fragmentOutputBaseTypes[i] != Format::Type::Other;