Metal: Use ShaderModule reflection when possible.

This change the Metal backend in preparation for supporting multiple
entrypoints:

 - Explicitly set the spirv_cross entry point before compiling.
 - Moves gathering of the local size to the frontend as it will be
   useful for validation in the future.
 - Query spirv-cross for the modified entrypoint name instead of
   duplicating the code in Dawn.
 - Move some conversion helpers from ShaderModule.cpp to their own
   SpirvUtils file.

Bug: dawn:216
Change-Id: I87d4953428e0bfeb97e39ed22f94d86ae7987782
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/28241
Commit-Queue: Kai Ninomiya <kainino@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
This commit is contained in:
Corentin Wallez 2020-09-09 22:47:07 +00:00 committed by Commit Bot service account
parent 947201da19
commit 676e8f39df
10 changed files with 273 additions and 200 deletions

View File

@ -253,6 +253,8 @@ source_set("dawn_native_sources") {
"Sampler.h", "Sampler.h",
"ShaderModule.cpp", "ShaderModule.cpp",
"ShaderModule.h", "ShaderModule.h",
"SpirvUtils.cpp",
"SpirvUtils.h",
"StagingBuffer.cpp", "StagingBuffer.cpp",
"StagingBuffer.h", "StagingBuffer.h",
"Surface.cpp", "Surface.cpp",

View File

@ -131,6 +131,8 @@ target_sources(dawn_native PRIVATE
"Sampler.h" "Sampler.h"
"ShaderModule.cpp" "ShaderModule.cpp"
"ShaderModule.h" "ShaderModule.h"
"SpirvUtils.cpp"
"SpirvUtils.h"
"StagingBuffer.cpp" "StagingBuffer.cpp"
"StagingBuffer.h" "StagingBuffer.h"
"Surface.cpp" "Surface.cpp"

View File

@ -19,6 +19,7 @@
#include "dawn_native/Device.h" #include "dawn_native/Device.h"
#include "dawn_native/Pipeline.h" #include "dawn_native/Pipeline.h"
#include "dawn_native/PipelineLayout.h" #include "dawn_native/PipelineLayout.h"
#include "dawn_native/SpirvUtils.h"
#include <spirv-tools/libspirv.hpp> #include <spirv-tools/libspirv.hpp>
#include <spirv_cross.hpp> #include <spirv_cross.hpp>
@ -36,114 +37,6 @@
namespace dawn_native { namespace dawn_native {
namespace { namespace {
Format::Type SpirvCrossBaseTypeToFormatType(spirv_cross::SPIRType::BaseType spirvBaseType) {
switch (spirvBaseType) {
case spirv_cross::SPIRType::Float:
return Format::Type::Float;
case spirv_cross::SPIRType::Int:
return Format::Type::Sint;
case spirv_cross::SPIRType::UInt:
return Format::Type::Uint;
default:
UNREACHABLE();
return Format::Type::Other;
}
}
wgpu::TextureViewDimension SpirvDimToTextureViewDimension(spv::Dim dim, bool arrayed) {
switch (dim) {
case spv::Dim::Dim1D:
return wgpu::TextureViewDimension::e1D;
case spv::Dim::Dim2D:
if (arrayed) {
return wgpu::TextureViewDimension::e2DArray;
} else {
return wgpu::TextureViewDimension::e2D;
}
case spv::Dim::Dim3D:
return wgpu::TextureViewDimension::e3D;
case spv::Dim::DimCube:
if (arrayed) {
return wgpu::TextureViewDimension::CubeArray;
} else {
return wgpu::TextureViewDimension::Cube;
}
default:
UNREACHABLE();
return wgpu::TextureViewDimension::Undefined;
}
}
wgpu::TextureFormat ToWGPUTextureFormat(spv::ImageFormat format) {
switch (format) {
case spv::ImageFormatR8:
return wgpu::TextureFormat::R8Unorm;
case spv::ImageFormatR8Snorm:
return wgpu::TextureFormat::R8Snorm;
case spv::ImageFormatR8ui:
return wgpu::TextureFormat::R8Uint;
case spv::ImageFormatR8i:
return wgpu::TextureFormat::R8Sint;
case spv::ImageFormatR16ui:
return wgpu::TextureFormat::R16Uint;
case spv::ImageFormatR16i:
return wgpu::TextureFormat::R16Sint;
case spv::ImageFormatR16f:
return wgpu::TextureFormat::R16Float;
case spv::ImageFormatRg8:
return wgpu::TextureFormat::RG8Unorm;
case spv::ImageFormatRg8Snorm:
return wgpu::TextureFormat::RG8Snorm;
case spv::ImageFormatRg8ui:
return wgpu::TextureFormat::RG8Uint;
case spv::ImageFormatRg8i:
return wgpu::TextureFormat::RG8Sint;
case spv::ImageFormatR32f:
return wgpu::TextureFormat::R32Float;
case spv::ImageFormatR32ui:
return wgpu::TextureFormat::R32Uint;
case spv::ImageFormatR32i:
return wgpu::TextureFormat::R32Sint;
case spv::ImageFormatRg16ui:
return wgpu::TextureFormat::RG16Uint;
case spv::ImageFormatRg16i:
return wgpu::TextureFormat::RG16Sint;
case spv::ImageFormatRg16f:
return wgpu::TextureFormat::RG16Float;
case spv::ImageFormatRgba8:
return wgpu::TextureFormat::RGBA8Unorm;
case spv::ImageFormatRgba8Snorm:
return wgpu::TextureFormat::RGBA8Snorm;
case spv::ImageFormatRgba8ui:
return wgpu::TextureFormat::RGBA8Uint;
case spv::ImageFormatRgba8i:
return wgpu::TextureFormat::RGBA8Sint;
case spv::ImageFormatRgb10A2:
return wgpu::TextureFormat::RGB10A2Unorm;
case spv::ImageFormatR11fG11fB10f:
return wgpu::TextureFormat::RG11B10Ufloat;
case spv::ImageFormatRg32f:
return wgpu::TextureFormat::RG32Float;
case spv::ImageFormatRg32ui:
return wgpu::TextureFormat::RG32Uint;
case spv::ImageFormatRg32i:
return wgpu::TextureFormat::RG32Sint;
case spv::ImageFormatRgba16ui:
return wgpu::TextureFormat::RGBA16Uint;
case spv::ImageFormatRgba16i:
return wgpu::TextureFormat::RGBA16Sint;
case spv::ImageFormatRgba16f:
return wgpu::TextureFormat::RGBA16Float;
case spv::ImageFormatRgba32f:
return wgpu::TextureFormat::RGBA32Float;
case spv::ImageFormatRgba32ui:
return wgpu::TextureFormat::RGBA32Uint;
case spv::ImageFormatRgba32i:
return wgpu::TextureFormat::RGBA32Sint;
default:
return wgpu::TextureFormat::Undefined;
}
}
std::string GetShaderDeclarationString(BindGroupIndex group, BindingNumber binding) { std::string GetShaderDeclarationString(BindGroupIndex group, BindingNumber binding) {
std::ostringstream ostream; std::ostringstream ostream;
@ -550,27 +443,15 @@ namespace dawn_native {
ResultOrError<std::unique_ptr<EntryPointMetadata>> ExtractSpirvInfo( ResultOrError<std::unique_ptr<EntryPointMetadata>> ExtractSpirvInfo(
const DeviceBase* device, const DeviceBase* device,
const spirv_cross::Compiler& compiler) { const spirv_cross::Compiler& compiler,
const char* entryPointName) {
std::unique_ptr<EntryPointMetadata> metadata = std::make_unique<EntryPointMetadata>(); std::unique_ptr<EntryPointMetadata> metadata = std::make_unique<EntryPointMetadata>();
// TODO(cwallez@chromium.org): make errors here creation errors // TODO(cwallez@chromium.org): make errors here creation errors
// currently errors here do not prevent the shadermodule from being used // currently errors here do not prevent the shadermodule from being used
const auto& resources = compiler.get_shader_resources(); const auto& resources = compiler.get_shader_resources();
switch (compiler.get_execution_model()) { metadata->stage = ExecutionModelToShaderStage(compiler.get_execution_model());
case spv::ExecutionModelVertex:
metadata->stage = SingleShaderStage::Vertex;
break;
case spv::ExecutionModelFragment:
metadata->stage = SingleShaderStage::Fragment;
break;
case spv::ExecutionModelGLCompute:
metadata->stage = SingleShaderStage::Compute;
break;
default:
UNREACHABLE();
return DAWN_VALIDATION_ERROR("Unexpected shader execution model");
}
if (resources.push_constant_buffers.size() > 0) { if (resources.push_constant_buffers.size() > 0) {
return DAWN_VALIDATION_ERROR("Push constants aren't supported."); return DAWN_VALIDATION_ERROR("Push constants aren't supported.");
@ -635,7 +516,7 @@ namespace dawn_native {
info->viewDimension = info->viewDimension =
SpirvDimToTextureViewDimension(imageType.dim, imageType.arrayed); SpirvDimToTextureViewDimension(imageType.dim, imageType.arrayed);
info->textureComponentType = info->textureComponentType =
SpirvCrossBaseTypeToFormatType(textureComponentType); SpirvBaseTypeToFormatType(textureComponentType);
info->type = bindingType; info->type = bindingType;
break; break;
} }
@ -664,7 +545,7 @@ namespace dawn_native {
spirv_cross::SPIRType::ImageType imageType = spirv_cross::SPIRType::ImageType imageType =
compiler.get_type(info->base_type_id).image; compiler.get_type(info->base_type_id).image;
wgpu::TextureFormat storageTextureFormat = wgpu::TextureFormat storageTextureFormat =
ToWGPUTextureFormat(imageType.format); SpirvImageFormatToTextureFormat(imageType.format);
if (storageTextureFormat == wgpu::TextureFormat::Undefined) { if (storageTextureFormat == wgpu::TextureFormat::Undefined) {
return DAWN_VALIDATION_ERROR( return DAWN_VALIDATION_ERROR(
"Invalid image format declaration on storage image"); "Invalid image format declaration on storage image");
@ -756,7 +637,7 @@ namespace dawn_native {
spirv_cross::SPIRType::BaseType shaderFragmentOutputBaseType = spirv_cross::SPIRType::BaseType shaderFragmentOutputBaseType =
compiler.get_type(fragmentOutput.base_type_id).basetype; compiler.get_type(fragmentOutput.base_type_id).basetype;
Format::Type formatType = Format::Type formatType =
SpirvCrossBaseTypeToFormatType(shaderFragmentOutputBaseType); SpirvBaseTypeToFormatType(shaderFragmentOutputBaseType);
if (formatType == Format::Type::Other) { if (formatType == Format::Type::Other) {
return DAWN_VALIDATION_ERROR("Unexpected Fragment output type"); return DAWN_VALIDATION_ERROR("Unexpected Fragment output type");
} }
@ -764,6 +645,14 @@ namespace dawn_native {
} }
} }
if (metadata->stage == SingleShaderStage::Compute) {
const spirv_cross::SPIREntryPoint& spirEntryPoint =
compiler.get_entry_point(entryPointName, spv::ExecutionModelGLCompute);
metadata->localWorkgroupSize.x = spirEntryPoint.workgroup_size.x;
metadata->localWorkgroupSize.y = spirEntryPoint.workgroup_size.y;
metadata->localWorkgroupSize.z = spirEntryPoint.workgroup_size.z;
}
return {std::move(metadata)}; return {std::move(metadata)};
} }
@ -935,7 +824,7 @@ namespace dawn_native {
} }
spirv_cross::Compiler compiler(mSpirv); spirv_cross::Compiler compiler(mSpirv);
DAWN_TRY_ASSIGN(mMainEntryPoint, ExtractSpirvInfo(GetDevice(), compiler)); DAWN_TRY_ASSIGN(mMainEntryPoint, ExtractSpirvInfo(GetDevice(), compiler, "main"));
return {}; return {};
} }

View File

@ -24,7 +24,6 @@
#include "dawn_native/Forward.h" #include "dawn_native/Forward.h"
#include "dawn_native/IntegerTypes.h" #include "dawn_native/IntegerTypes.h"
#include "dawn_native/PerStage.h" #include "dawn_native/PerStage.h"
#include "dawn_native/dawn_platform.h" #include "dawn_native/dawn_platform.h"
#include <bitset> #include <bitset>
@ -82,8 +81,10 @@ namespace dawn_native {
ityp::array<ColorAttachmentIndex, Format::Type, kMaxColorAttachments>; ityp::array<ColorAttachmentIndex, Format::Type, kMaxColorAttachments>;
FragmentOutputBaseTypes fragmentOutputFormatBaseTypes; FragmentOutputBaseTypes fragmentOutputFormatBaseTypes;
// The shader stage for this binding, TODO(dawn:216): can likely be removed once we // The local workgroup size declared for a compute entry point (or 0s otehrwise).
// properly support multiple entrypoints per ShaderModule. Origin3D localWorkgroupSize;
// The shader stage for this binding.
SingleShaderStage stage; SingleShaderStage stage;
}; };

View File

@ -0,0 +1,154 @@
// Copyright 2020 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "dawn_native/SpirvUtils.h"
namespace dawn_native {
spv::ExecutionModel ShaderStageToExecutionModel(SingleShaderStage stage) {
switch (stage) {
case SingleShaderStage::Vertex:
return spv::ExecutionModelVertex;
case SingleShaderStage::Fragment:
return spv::ExecutionModelFragment;
case SingleShaderStage::Compute:
return spv::ExecutionModelGLCompute;
default:
UNREACHABLE();
}
}
SingleShaderStage ExecutionModelToShaderStage(spv::ExecutionModel model) {
switch (model) {
case spv::ExecutionModelVertex:
return SingleShaderStage::Vertex;
case spv::ExecutionModelFragment:
return SingleShaderStage::Fragment;
case spv::ExecutionModelGLCompute:
return SingleShaderStage::Compute;
default:
UNREACHABLE();
}
}
wgpu::TextureViewDimension SpirvDimToTextureViewDimension(spv::Dim dim, bool arrayed) {
switch (dim) {
case spv::Dim::Dim1D:
return wgpu::TextureViewDimension::e1D;
case spv::Dim::Dim2D:
if (arrayed) {
return wgpu::TextureViewDimension::e2DArray;
} else {
return wgpu::TextureViewDimension::e2D;
}
case spv::Dim::Dim3D:
return wgpu::TextureViewDimension::e3D;
case spv::Dim::DimCube:
if (arrayed) {
return wgpu::TextureViewDimension::CubeArray;
} else {
return wgpu::TextureViewDimension::Cube;
}
default:
UNREACHABLE();
return wgpu::TextureViewDimension::Undefined;
}
}
wgpu::TextureFormat SpirvImageFormatToTextureFormat(spv::ImageFormat format) {
switch (format) {
case spv::ImageFormatR8:
return wgpu::TextureFormat::R8Unorm;
case spv::ImageFormatR8Snorm:
return wgpu::TextureFormat::R8Snorm;
case spv::ImageFormatR8ui:
return wgpu::TextureFormat::R8Uint;
case spv::ImageFormatR8i:
return wgpu::TextureFormat::R8Sint;
case spv::ImageFormatR16ui:
return wgpu::TextureFormat::R16Uint;
case spv::ImageFormatR16i:
return wgpu::TextureFormat::R16Sint;
case spv::ImageFormatR16f:
return wgpu::TextureFormat::R16Float;
case spv::ImageFormatRg8:
return wgpu::TextureFormat::RG8Unorm;
case spv::ImageFormatRg8Snorm:
return wgpu::TextureFormat::RG8Snorm;
case spv::ImageFormatRg8ui:
return wgpu::TextureFormat::RG8Uint;
case spv::ImageFormatRg8i:
return wgpu::TextureFormat::RG8Sint;
case spv::ImageFormatR32f:
return wgpu::TextureFormat::R32Float;
case spv::ImageFormatR32ui:
return wgpu::TextureFormat::R32Uint;
case spv::ImageFormatR32i:
return wgpu::TextureFormat::R32Sint;
case spv::ImageFormatRg16ui:
return wgpu::TextureFormat::RG16Uint;
case spv::ImageFormatRg16i:
return wgpu::TextureFormat::RG16Sint;
case spv::ImageFormatRg16f:
return wgpu::TextureFormat::RG16Float;
case spv::ImageFormatRgba8:
return wgpu::TextureFormat::RGBA8Unorm;
case spv::ImageFormatRgba8Snorm:
return wgpu::TextureFormat::RGBA8Snorm;
case spv::ImageFormatRgba8ui:
return wgpu::TextureFormat::RGBA8Uint;
case spv::ImageFormatRgba8i:
return wgpu::TextureFormat::RGBA8Sint;
case spv::ImageFormatRgb10A2:
return wgpu::TextureFormat::RGB10A2Unorm;
case spv::ImageFormatR11fG11fB10f:
return wgpu::TextureFormat::RG11B10Ufloat;
case spv::ImageFormatRg32f:
return wgpu::TextureFormat::RG32Float;
case spv::ImageFormatRg32ui:
return wgpu::TextureFormat::RG32Uint;
case spv::ImageFormatRg32i:
return wgpu::TextureFormat::RG32Sint;
case spv::ImageFormatRgba16ui:
return wgpu::TextureFormat::RGBA16Uint;
case spv::ImageFormatRgba16i:
return wgpu::TextureFormat::RGBA16Sint;
case spv::ImageFormatRgba16f:
return wgpu::TextureFormat::RGBA16Float;
case spv::ImageFormatRgba32f:
return wgpu::TextureFormat::RGBA32Float;
case spv::ImageFormatRgba32ui:
return wgpu::TextureFormat::RGBA32Uint;
case spv::ImageFormatRgba32i:
return wgpu::TextureFormat::RGBA32Sint;
default:
return wgpu::TextureFormat::Undefined;
}
}
Format::Type SpirvBaseTypeToFormatType(spirv_cross::SPIRType::BaseType spirvBaseType) {
switch (spirvBaseType) {
case spirv_cross::SPIRType::Float:
return Format::Type::Float;
case spirv_cross::SPIRType::Int:
return Format::Type::Sint;
case spirv_cross::SPIRType::UInt:
return Format::Type::Uint;
default:
UNREACHABLE();
return Format::Type::Other;
}
}
} // namespace dawn_native

View File

@ -0,0 +1,44 @@
// Copyright 2020 The Dawn Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This file contains utilities to convert from-to spirv.hpp datatypes without polluting other
// headers with spirv.hpp
#ifndef DAWNNATIVE_SPIRV_UTILS_H_
#define DAWNNATIVE_SPIRV_UTILS_H_
#include "dawn_native/Format.h"
#include "dawn_native/PerStage.h"
#include "dawn_native/dawn_platform.h"
#include <spirv_cross.hpp>
namespace dawn_native {
// Returns the spirv_cross equivalent for this shader stage and vice-versa.
spv::ExecutionModel ShaderStageToExecutionModel(SingleShaderStage stage);
SingleShaderStage ExecutionModelToShaderStage(spv::ExecutionModel model);
// Returns the texture view dimension for corresponding to (dim, arrayed).
wgpu::TextureViewDimension SpirvDimToTextureViewDimension(spv::Dim dim, bool arrayed);
// Returns the texture format corresponding to format.
wgpu::TextureFormat SpirvImageFormatToTextureFormat(spv::ImageFormat format);
// Returns the format "component type" corresponding to the SPIRV base type.
Format::Type SpirvBaseTypeToFormatType(spirv_cross::SPIRType::BaseType spirvBaseType);
} // namespace dawn_native
#endif // DAWNNATIVE_SPIRV_UTILS_H_

View File

@ -34,8 +34,8 @@ namespace dawn_native { namespace metal {
ShaderModule* computeModule = ToBackend(descriptor->computeStage.module); ShaderModule* computeModule = ToBackend(descriptor->computeStage.module);
const char* computeEntryPoint = descriptor->computeStage.entryPoint; const char* computeEntryPoint = descriptor->computeStage.entryPoint;
ShaderModule::MetalFunctionData computeData; ShaderModule::MetalFunctionData computeData;
DAWN_TRY(computeModule->GetFunction(computeEntryPoint, SingleShaderStage::Compute, DAWN_TRY(computeModule->CreateFunction(computeEntryPoint, SingleShaderStage::Compute,
ToBackend(GetLayout()), &computeData)); ToBackend(GetLayout()), &computeData));
NSError* error = nil; NSError* error = nil;
mMtlComputePipelineState = mMtlComputePipelineState =
@ -46,7 +46,9 @@ namespace dawn_native { namespace metal {
} }
// Copy over the local workgroup size as it is passed to dispatch explicitly in Metal // Copy over the local workgroup size as it is passed to dispatch explicitly in Metal
mLocalWorkgroupSize = computeData.localWorkgroupSize; Origin3D localSize = GetStage(SingleShaderStage::Compute).metadata->localWorkgroupSize;
mLocalWorkgroupSize = MTLSizeMake(localSize.x, localSize.y, localSize.z);
mRequiresStorageBufferLength = computeData.needsStorageBufferLength; mRequiresStorageBufferLength = computeData.needsStorageBufferLength;
return {}; return {};
} }

View File

@ -335,8 +335,9 @@ namespace dawn_native { namespace metal {
ShaderModule* vertexModule = ToBackend(descriptor->vertexStage.module); ShaderModule* vertexModule = ToBackend(descriptor->vertexStage.module);
const char* vertexEntryPoint = descriptor->vertexStage.entryPoint; const char* vertexEntryPoint = descriptor->vertexStage.entryPoint;
ShaderModule::MetalFunctionData vertexData; ShaderModule::MetalFunctionData vertexData;
DAWN_TRY(vertexModule->GetFunction(vertexEntryPoint, SingleShaderStage::Vertex, DAWN_TRY(vertexModule->CreateFunction(vertexEntryPoint, SingleShaderStage::Vertex,
ToBackend(GetLayout()), &vertexData, 0xFFFFFFFF, this)); ToBackend(GetLayout()), &vertexData, 0xFFFFFFFF,
this));
descriptorMTL.vertexFunction = vertexData.function; descriptorMTL.vertexFunction = vertexData.function;
if (vertexData.needsStorageBufferLength) { if (vertexData.needsStorageBufferLength) {
@ -346,9 +347,9 @@ namespace dawn_native { namespace metal {
ShaderModule* fragmentModule = ToBackend(descriptor->fragmentStage->module); ShaderModule* fragmentModule = ToBackend(descriptor->fragmentStage->module);
const char* fragmentEntryPoint = descriptor->fragmentStage->entryPoint; const char* fragmentEntryPoint = descriptor->fragmentStage->entryPoint;
ShaderModule::MetalFunctionData fragmentData; ShaderModule::MetalFunctionData fragmentData;
DAWN_TRY(fragmentModule->GetFunction(fragmentEntryPoint, SingleShaderStage::Fragment, DAWN_TRY(fragmentModule->CreateFunction(fragmentEntryPoint, SingleShaderStage::Fragment,
ToBackend(GetLayout()), &fragmentData, ToBackend(GetLayout()), &fragmentData,
descriptor->sampleMask)); descriptor->sampleMask));
descriptorMTL.fragmentFunction = fragmentData.function; descriptorMTL.fragmentFunction = fragmentData.function;
if (fragmentData.needsStorageBufferLength) { if (fragmentData.needsStorageBufferLength) {

View File

@ -38,18 +38,17 @@ namespace dawn_native { namespace metal {
struct MetalFunctionData { struct MetalFunctionData {
id<MTLFunction> function = nil; id<MTLFunction> function = nil;
MTLSize localWorkgroupSize;
bool needsStorageBufferLength; bool needsStorageBufferLength;
~MetalFunctionData() { ~MetalFunctionData() {
[function release]; [function release];
} }
}; };
MaybeError GetFunction(const char* functionName, MaybeError CreateFunction(const char* entryPointName,
SingleShaderStage functionStage, SingleShaderStage stage,
const PipelineLayout* layout, const PipelineLayout* layout,
MetalFunctionData* out, MetalFunctionData* out,
uint32_t sampleMask = 0xFFFFFFFF, uint32_t sampleMask = 0xFFFFFFFF,
const RenderPipeline* renderPipeline = nullptr); const RenderPipeline* renderPipeline = nullptr);
private: private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor); ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);

View File

@ -15,6 +15,7 @@
#include "dawn_native/metal/ShaderModuleMTL.h" #include "dawn_native/metal/ShaderModuleMTL.h"
#include "dawn_native/BindGroupLayout.h" #include "dawn_native/BindGroupLayout.h"
#include "dawn_native/SpirvUtils.h"
#include "dawn_native/metal/DeviceMTL.h" #include "dawn_native/metal/DeviceMTL.h"
#include "dawn_native/metal/PipelineLayoutMTL.h" #include "dawn_native/metal/PipelineLayoutMTL.h"
#include "dawn_native/metal/RenderPipelineMTL.h" #include "dawn_native/metal/RenderPipelineMTL.h"
@ -25,22 +26,6 @@
namespace dawn_native { namespace metal { namespace dawn_native { namespace metal {
namespace {
spv::ExecutionModel SpirvExecutionModelForStage(SingleShaderStage stage) {
switch (stage) {
case SingleShaderStage::Vertex:
return spv::ExecutionModelVertex;
case SingleShaderStage::Fragment:
return spv::ExecutionModelFragment;
case SingleShaderStage::Compute:
return spv::ExecutionModelGLCompute;
default:
UNREACHABLE();
}
}
} // namespace
// static // static
ResultOrError<ShaderModule*> ShaderModule::Create(Device* device, ResultOrError<ShaderModule*> ShaderModule::Create(Device* device,
const ShaderModuleDescriptor* descriptor) { const ShaderModuleDescriptor* descriptor) {
@ -57,25 +42,26 @@ namespace dawn_native { namespace metal {
return InitializeBase(); return InitializeBase();
} }
MaybeError ShaderModule::GetFunction(const char* functionName, MaybeError ShaderModule::CreateFunction(const char* entryPointName,
SingleShaderStage functionStage, SingleShaderStage stage,
const PipelineLayout* layout, const PipelineLayout* layout,
ShaderModule::MetalFunctionData* out, ShaderModule::MetalFunctionData* out,
uint32_t sampleMask, uint32_t sampleMask,
const RenderPipeline* renderPipeline) { const RenderPipeline* renderPipeline) {
ASSERT(!IsError()); ASSERT(!IsError());
ASSERT(out); ASSERT(out);
const std::vector<uint32_t>* spirv = &GetSpirv(); const std::vector<uint32_t>* spirv = &GetSpirv();
spv::ExecutionModel executionModel = ShaderStageToExecutionModel(stage);
#ifdef DAWN_ENABLE_WGSL #ifdef DAWN_ENABLE_WGSL
// Use set 4 since it is bigger than what users can access currently // Use set 4 since it is bigger than what users can access currently
static const uint32_t kPullingBufferBindingSet = 4; static const uint32_t kPullingBufferBindingSet = 4;
std::vector<uint32_t> pullingSpirv; std::vector<uint32_t> pullingSpirv;
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) && if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
functionStage == SingleShaderStage::Vertex) { stage == SingleShaderStage::Vertex) {
DAWN_TRY_ASSIGN(pullingSpirv, DAWN_TRY_ASSIGN(pullingSpirv,
GeneratePullingSpirv(*renderPipeline->GetVertexStateDescriptor(), GeneratePullingSpirv(*renderPipeline->GetVertexStateDescriptor(),
functionName, kPullingBufferBindingSet)); entryPointName, kPullingBufferBindingSet));
spirv = &pullingSpirv; spirv = &pullingSpirv;
} }
#endif #endif
@ -99,6 +85,7 @@ namespace dawn_native { namespace metal {
spirv_cross::CompilerMSL compiler(*spirv); spirv_cross::CompilerMSL compiler(*spirv);
compiler.set_msl_options(options_msl); compiler.set_msl_options(options_msl);
compiler.set_entry_point(entryPointName, executionModel);
// By default SPIRV-Cross will give MSL resources indices in increasing order. // By default SPIRV-Cross will give MSL resources indices in increasing order.
// To make the MSL indices match the indices chosen in the PipelineLayout, we build // To make the MSL indices match the indices chosen in the PipelineLayout, we build
@ -116,30 +103,33 @@ namespace dawn_native { namespace metal {
const BindingInfo& bindingInfo = const BindingInfo& bindingInfo =
layout->GetBindGroupLayout(group)->GetBindingInfo(bindingIndex); layout->GetBindGroupLayout(group)->GetBindingInfo(bindingIndex);
for (auto stage : IterateStages(bindingInfo.visibility)) { if (!(bindingInfo.visibility & StageBit(stage))) {
uint32_t shaderIndex = layout->GetBindingIndexInfo(stage)[group][bindingIndex]; continue;
spirv_cross::MSLResourceBinding mslBinding;
mslBinding.stage = SpirvExecutionModelForStage(stage);
mslBinding.desc_set = static_cast<uint32_t>(group);
mslBinding.binding = static_cast<uint32_t>(bindingNumber);
mslBinding.msl_buffer = mslBinding.msl_texture = mslBinding.msl_sampler =
shaderIndex;
compiler.add_msl_resource_binding(mslBinding);
} }
uint32_t shaderIndex = layout->GetBindingIndexInfo(stage)[group][bindingIndex];
spirv_cross::MSLResourceBinding mslBinding;
mslBinding.stage = executionModel;
mslBinding.desc_set = static_cast<uint32_t>(group);
mslBinding.binding = static_cast<uint32_t>(bindingNumber);
mslBinding.msl_buffer = mslBinding.msl_texture = mslBinding.msl_sampler =
shaderIndex;
compiler.add_msl_resource_binding(mslBinding);
} }
} }
#ifdef DAWN_ENABLE_WGSL #ifdef DAWN_ENABLE_WGSL
// Add vertex buffers bound as storage buffers // Add vertex buffers bound as storage buffers
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) && if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
functionStage == SingleShaderStage::Vertex) { stage == SingleShaderStage::Vertex) {
for (uint32_t dawnIndex : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) { for (uint32_t dawnIndex : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) {
uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(dawnIndex); uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(dawnIndex);
spirv_cross::MSLResourceBinding mslBinding; spirv_cross::MSLResourceBinding mslBinding;
mslBinding.stage = SpirvExecutionModelForStage(SingleShaderStage::Vertex); mslBinding.stage = spv::ExecutionModelVertex;
mslBinding.desc_set = kPullingBufferBindingSet; mslBinding.desc_set = kPullingBufferBindingSet;
mslBinding.binding = dawnIndex; mslBinding.binding = dawnIndex;
mslBinding.msl_buffer = metalIndex; mslBinding.msl_buffer = metalIndex;
@ -148,17 +138,17 @@ namespace dawn_native { namespace metal {
} }
#endif #endif
{
spv::ExecutionModel executionModel = SpirvExecutionModelForStage(functionStage);
auto size = compiler.get_entry_point(functionName, executionModel).workgroup_size;
out->localWorkgroupSize = MTLSizeMake(size.x, size.y, size.z);
}
{ {
// SPIRV-Cross also supports re-ordering attributes but it seems to do the correct thing // SPIRV-Cross also supports re-ordering attributes but it seems to do the correct thing
// by default. // by default.
NSString* mslSource; NSString* mslSource;
std::string msl = compiler.compile(); std::string msl = compiler.compile();
// Some entry point names are forbidden in MSL so SPIRV-Cross modifies them. Query the
// modified entryPointName from it.
const std::string& modifiedEntryPointName =
compiler.get_entry_point(entryPointName, executionModel).name;
// Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall // Metal uses Clang to compile the shader as C++14. Disable everything in the -Wall
// category. -Wunused-variable in particular comes up a lot in generated code, and some // category. -Wunused-variable in particular comes up a lot in generated code, and some
// (old?) Metal drivers accidentally treat it as a MTLLibraryErrorCompileError instead // (old?) Metal drivers accidentally treat it as a MTLLibraryErrorCompileError instead
@ -183,18 +173,7 @@ namespace dawn_native { namespace metal {
} }
} }
// TODO(kainino@chromium.org): make this somehow more robust; it needs to behave like NSString* name = [[NSString alloc] initWithUTF8String:modifiedEntryPointName.c_str()];
// clean_func_name:
// https://github.com/KhronosGroup/SPIRV-Cross/blob/4e915e8c483e319d0dd7a1fa22318bef28f8cca3/spirv_msl.cpp#L1213
const char* metalFunctionName = functionName;
if (strcmp(metalFunctionName, "main") == 0) {
metalFunctionName = "main0";
}
if (strcmp(metalFunctionName, "saturate") == 0) {
metalFunctionName = "saturate0";
}
NSString* name = [[NSString alloc] initWithUTF8String:metalFunctionName];
out->function = [library newFunctionWithName:name]; out->function = [library newFunctionWithName:name];
[library release]; [library release];
} }
@ -202,7 +181,7 @@ namespace dawn_native { namespace metal {
out->needsStorageBufferLength = compiler.needs_buffer_size_buffer(); out->needsStorageBufferLength = compiler.needs_buffer_size_buffer();
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) && if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
GetEntryPoint(functionName, functionStage).usedVertexAttributes.any()) { GetEntryPoint(entryPointName, stage).usedVertexAttributes.any()) {
out->needsStorageBufferLength = true; out->needsStorageBufferLength = true;
} }