Pipeline overridable constants: Metal backend

Also gate Dawn metal behind MacOS 10.12+

Bug: dawn:1136, dawn:1181
Change-Id: Id7bfaa2953b1acf08f37e6a08ddeadd9cde44657
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/67421
Reviewed-by: Austin Eng <enga@chromium.org>
Commit-Queue: Shrek Shao <shrekshao@google.com>
This commit is contained in:
Shrek Shao 2021-11-04 18:43:10 +00:00 committed by Dawn LUCI CQ
parent f4a6ad8edd
commit b0a5ed49b1
15 changed files with 218 additions and 45 deletions

View File

@ -54,7 +54,7 @@ namespace dawn_native {
// Validate if overridable constants exist in shader module
// pipelineBase is not yet constructed at this moment so iterate constants from descriptor
size_t numUninitializedConstants = metadata.uninitializedOverridableConstants.size();
// Keep a initialized constants sets to handle duplicate initialization cases
// Keep an initialized constants sets to handle duplicate initialization cases
// Only storing that of uninialized constants is needed
std::unordered_set<std::string> stageInitializedConstantIdentifiers;
for (uint32_t i = 0; i < constantCount; i++) {

View File

@ -636,11 +636,29 @@ namespace dawn_native {
"are partially implemented.");
const auto& name2Id = inspector.GetConstantNameToIdMap();
const auto& id2Scalar = inspector.GetConstantIDs();
for (auto& c : entryPoint.overridable_constants) {
uint32_t id = name2Id.at(c.name);
OverridableConstantScalar defaultValue;
if (c.is_initialized) {
// if it is initialized, the scalar must exist
const auto& scalar = id2Scalar.at(id);
if (scalar.IsBool()) {
defaultValue.b = scalar.AsBool();
} else if (scalar.IsU32()) {
defaultValue.u32 = scalar.AsU32();
} else if (scalar.IsI32()) {
defaultValue.i32 = scalar.AsI32();
} else if (scalar.IsFloat()) {
defaultValue.f32 = scalar.AsFloat();
} else {
UNREACHABLE();
}
}
EntryPointMetadata::OverridableConstant constant = {
name2Id.at(c.name), FromTintOverridableConstantType(c.type),
c.is_initialized};
id, FromTintOverridableConstantType(c.type), c.is_initialized,
defaultValue};
std::string identifier =
c.is_numeric_id_specified ? std::to_string(constant.id) : c.name;
@ -651,6 +669,11 @@ namespace dawn_native {
std::move(identifier));
// The insertion should have taken place
ASSERT(it.second);
} else {
auto it = metadata->initializedOverridableConstants.emplace(
std::move(identifier));
// The insertion should have taken place
ASSERT(it.second);
}
}
}

View File

@ -150,6 +150,15 @@ namespace dawn_native {
using BindingGroupInfoMap = std::map<BindingNumber, ShaderBindingInfo>;
using BindingInfoArray = ityp::array<BindGroupIndex, BindingGroupInfoMap, kMaxBindGroups>;
// The WebGPU overridable constants only support these scalar types
union OverridableConstantScalar {
// Use int32_t for boolean to initialize the full 32bit
int32_t b;
float f32;
int32_t i32;
uint32_t u32;
};
// Contains all the reflection data for a valid (ShaderModule, entryPoint, stage). They are
// stored in the ShaderModuleBase and destroyed only when the shader program is destroyed so
// pointers to EntryPointMetadata are safe to store as long as you also keep a Ref to the
@ -206,6 +215,11 @@ namespace dawn_native {
// Then it is required for the pipeline stage to have a constant record to initialize a
// value
bool isInitialized;
// Store the default initialized value in shader
// This is used by metal backend as the function_constant does not have dafault values
// Initialized when isInitialized == true
OverridableConstantScalar defaultValue;
};
// Map identifier to overridable constant
@ -216,6 +230,11 @@ namespace dawn_native {
// They need value initialization from pipeline stage or it is a validation error
std::unordered_set<std::string> uninitializedOverridableConstants;
// Store constants with shader initialized values as well
// This is used by metal backend to set values with default initializers that are not
// overridden
std::unordered_set<std::string> initializedOverridableConstants;
bool usesNumWorkgroups = false;
};

View File

@ -156,7 +156,9 @@ namespace dawn_native { namespace metal {
bool IsMetalSupported() {
// Metal was first introduced in macOS 10.11
return IsMacOSVersionAtLeast(10, 11);
// WebGPU is targeted at macOS 10.12+
// TODO(dawn:1181): Dawn native should allow non-conformant WebGPU on macOS 10.11
return IsMacOSVersionAtLeast(10, 12);
}
#elif defined(DAWN_PLATFORM_IOS)
MaybeError GetDevicePCIInfo(id<MTLDevice> device, PCIIDs* ids) {

View File

@ -17,6 +17,7 @@
#include "dawn_native/CreatePipelineAsyncTask.h"
#include "dawn_native/metal/DeviceMTL.h"
#include "dawn_native/metal/ShaderModuleMTL.h"
#include "dawn_native/metal/UtilsMetal.h"
namespace dawn_native { namespace metal {
@ -31,11 +32,10 @@ namespace dawn_native { namespace metal {
auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice();
const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute);
ShaderModule* computeModule = ToBackend(computeStage.module.Get());
const char* computeEntryPoint = computeStage.entryPoint.c_str();
ShaderModule::MetalFunctionData computeData;
DAWN_TRY(computeModule->CreateFunction(computeEntryPoint, SingleShaderStage::Compute,
ToBackend(GetLayout()), &computeData));
DAWN_TRY(CreateMTLFunction(computeStage, SingleShaderStage::Compute, ToBackend(GetLayout()),
&computeData));
NSError* error = nullptr;
mMtlComputePipelineState.Acquire([mtlDevice

View File

@ -339,12 +339,9 @@ namespace dawn_native { namespace metal {
const PerStage<ProgrammableStage>& allStages = GetAllStages();
const ProgrammableStage& vertexStage = allStages[wgpu::ShaderStage::Vertex];
ShaderModule* vertexModule = ToBackend(vertexStage.module).Get();
const char* vertexEntryPoint = vertexStage.entryPoint.c_str();
ShaderModule::MetalFunctionData vertexData;
DAWN_TRY(vertexModule->CreateFunction(vertexEntryPoint, SingleShaderStage::Vertex,
ToBackend(GetLayout()), &vertexData, 0xFFFFFFFF,
this));
DAWN_TRY(CreateMTLFunction(vertexStage, SingleShaderStage::Vertex, ToBackend(GetLayout()),
&vertexData, 0xFFFFFFFF, this));
descriptorMTL.vertexFunction = vertexData.function.Get();
if (vertexData.needsStorageBufferLength) {
@ -353,12 +350,9 @@ namespace dawn_native { namespace metal {
if (GetStageMask() & wgpu::ShaderStage::Fragment) {
const ProgrammableStage& fragmentStage = allStages[wgpu::ShaderStage::Fragment];
ShaderModule* fragmentModule = ToBackend(fragmentStage.module).Get();
const char* fragmentEntryPoint = fragmentStage.entryPoint.c_str();
ShaderModule::MetalFunctionData fragmentData;
DAWN_TRY(fragmentModule->CreateFunction(fragmentEntryPoint, SingleShaderStage::Fragment,
ToBackend(GetLayout()), &fragmentData,
GetSampleMask()));
DAWN_TRY(CreateMTLFunction(fragmentStage, SingleShaderStage::Fragment,
ToBackend(GetLayout()), &fragmentData, GetSampleMask()));
descriptorMTL.fragmentFunction = fragmentData.function.Get();
if (fragmentData.needsStorageBufferLength) {

View File

@ -39,10 +39,14 @@ namespace dawn_native { namespace metal {
bool needsStorageBufferLength;
std::vector<uint32_t> workgroupAllocations;
};
// MTLFunctionConstantValues needs @available tag to compile
// Use id (like void*) in function signature as workaround and do static cast inside
MaybeError CreateFunction(const char* entryPointName,
SingleShaderStage stage,
const PipelineLayout* layout,
MetalFunctionData* out,
id constantValues = nil,
uint32_t sampleMask = 0xFFFFFFFF,
const RenderPipeline* renderPipeline = nullptr);

View File

@ -174,6 +174,7 @@ namespace dawn_native { namespace metal {
SingleShaderStage stage,
const PipelineLayout* layout,
ShaderModule::MetalFunctionData* out,
id constantValuesPointer,
uint32_t sampleMask,
const RenderPipeline* renderPipeline) {
ASSERT(!IsError());
@ -231,7 +232,26 @@ namespace dawn_native { namespace metal {
NSRef<NSString> name =
AcquireNSRef([[NSString alloc] initWithUTF8String:remappedEntryPointName.c_str()]);
if (constantValuesPointer != nil) {
if (@available(macOS 10.12, *)) {
MTLFunctionConstantValues* constantValues = constantValuesPointer;
out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()
constantValues:constantValues
error:&error]);
if (error != nullptr) {
if (error.code != MTLLibraryErrorCompileWarning) {
return DAWN_VALIDATION_ERROR(std::string("Function compile error: ") +
[error.localizedDescription UTF8String]);
}
}
ASSERT(out->function != nil);
} else {
UNREACHABLE();
}
} else {
out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]);
}
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
GetEntryPoint(entryPointName).usedVertexInputs.any()) {

View File

@ -17,10 +17,17 @@
#include "dawn_native/dawn_platform.h"
#include "dawn_native/metal/DeviceMTL.h"
#include "dawn_native/metal/ShaderModuleMTL.h"
#include "dawn_native/metal/TextureMTL.h"
#import <Metal/Metal.h>
namespace dawn_native {
struct ProgrammableStage;
struct EntryPointMetadata;
enum class SingleShaderStage;
}
namespace dawn_native { namespace metal {
MTLCompareFunction ToMetalCompareFunction(wgpu::CompareFunction compareFunction);
@ -65,6 +72,15 @@ namespace dawn_native { namespace metal {
MTLBlitOption ComputeMTLBlitOption(const Format& format, Aspect aspect);
// Helper function to create function with constant values wrapped in
// if available branch
MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage,
SingleShaderStage singleShaderStage,
PipelineLayout* pipelineLayout,
ShaderModule::MetalFunctionData* functionData,
uint32_t sampleMask = 0xFFFFFFFF,
const RenderPipeline* renderPipeline = nullptr);
}} // namespace dawn_native::metal
#endif // DAWNNATIVE_METAL_UTILSMETAL_H_

View File

@ -14,6 +14,8 @@
#include "dawn_native/metal/UtilsMetal.h"
#include "dawn_native/CommandBuffer.h"
#include "dawn_native/Pipeline.h"
#include "dawn_native/ShaderModule.h"
#include "common/Assert.h"
@ -186,4 +188,106 @@ namespace dawn_native { namespace metal {
return MTLBlitOptionNone;
}
MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage,
SingleShaderStage singleShaderStage,
PipelineLayout* pipelineLayout,
ShaderModule::MetalFunctionData* functionData,
uint32_t sampleMask,
const RenderPipeline* renderPipeline) {
ShaderModule* shaderModule = ToBackend(programmableStage.module.Get());
const char* shaderEntryPoint = programmableStage.entryPoint.c_str();
const auto& entryPointMetadata = programmableStage.module->GetEntryPoint(shaderEntryPoint);
if (entryPointMetadata.overridableConstants.size() == 0) {
DAWN_TRY(shaderModule->CreateFunction(shaderEntryPoint, singleShaderStage,
pipelineLayout, functionData, nil, sampleMask,
renderPipeline));
return {};
}
if (@available(macOS 10.12, *)) {
// MTLFunctionConstantValues can only be created within the if available branch
NSRef<MTLFunctionConstantValues> constantValues =
AcquireNSRef([MTLFunctionConstantValues new]);
std::unordered_set<std::string> overriddenConstants;
auto switchType = [&](EntryPointMetadata::OverridableConstant::Type dawnType,
MTLDataType* type, OverridableConstantScalar* entry,
double value = 0) {
switch (dawnType) {
case EntryPointMetadata::OverridableConstant::Type::Boolean:
*type = MTLDataTypeBool;
if (entry) {
entry->b = static_cast<int32_t>(value);
}
break;
case EntryPointMetadata::OverridableConstant::Type::Float32:
*type = MTLDataTypeFloat;
if (entry) {
entry->f32 = static_cast<float>(value);
}
break;
case EntryPointMetadata::OverridableConstant::Type::Int32:
*type = MTLDataTypeInt;
if (entry) {
entry->i32 = static_cast<int32_t>(value);
}
break;
case EntryPointMetadata::OverridableConstant::Type::Uint32:
*type = MTLDataTypeUInt;
if (entry) {
entry->u32 = static_cast<uint32_t>(value);
}
break;
default:
UNREACHABLE();
}
};
for (const auto& pipelineConstant : programmableStage.constants) {
const std::string& name = pipelineConstant.first;
double value = pipelineConstant.second;
overriddenConstants.insert(name);
// This is already validated so `name` must exist
const auto& moduleConstant = entryPointMetadata.overridableConstants.at(name);
MTLDataType type;
OverridableConstantScalar entry{};
switchType(moduleConstant.type, &type, &entry, value);
[constantValues.Get() setConstantValue:&entry type:type atIndex:moduleConstant.id];
}
// Set shader initialized default values because MSL function_constant
// has no default value
for (const std::string& name : entryPointMetadata.initializedOverridableConstants) {
if (overriddenConstants.count(name) != 0) {
// This constant already has overridden value
continue;
}
// Must exist because it is validated
const auto& moduleConstant = entryPointMetadata.overridableConstants.at(name);
ASSERT(moduleConstant.isInitialized);
MTLDataType type;
switchType(moduleConstant.type, &type, nullptr);
[constantValues.Get() setConstantValue:&moduleConstant.defaultValue
type:type
atIndex:moduleConstant.id];
}
DAWN_TRY(shaderModule->CreateFunction(
shaderEntryPoint, singleShaderStage, pipelineLayout, functionData,
constantValues.Get(), sampleMask, renderPipeline));
} else {
UNREACHABLE();
}
return {};
}
}} // namespace dawn_native::metal

View File

@ -53,7 +53,7 @@ namespace dawn_native { namespace vulkan {
createInfo.stage.pName = computeStage.entryPoint.c_str();
std::vector<SpecializationDataEntry> specializationDataEntries;
std::vector<OverridableConstantScalar> specializationDataEntries;
std::vector<VkSpecializationMapEntry> specializationMapEntries;
VkSpecializationInfo specializationInfo{};
createInfo.stage.pSpecializationInfo =

View File

@ -339,7 +339,7 @@ namespace dawn_native { namespace vulkan {
// There are at most 2 shader stages in render pipeline, i.e. vertex and fragment
std::array<VkPipelineShaderStageCreateInfo, 2> shaderStages;
std::array<std::vector<SpecializationDataEntry>, 2> specializationDataEntriesPerStages;
std::array<std::vector<OverridableConstantScalar>, 2> specializationDataEntriesPerStages;
std::array<std::vector<VkSpecializationMapEntry>, 2> specializationMapEntriesPerStages;
std::array<VkSpecializationInfo, 2> specializationInfoPerStages;
uint32_t stageCount = 0;

View File

@ -201,7 +201,7 @@ namespace dawn_native { namespace vulkan {
VkSpecializationInfo* GetVkSpecializationInfo(
const ProgrammableStage& programmableStage,
VkSpecializationInfo* specializationInfo,
std::vector<SpecializationDataEntry>* specializationDataEntries,
std::vector<OverridableConstantScalar>* specializationDataEntries,
std::vector<VkSpecializationMapEntry>* specializationMapEntries) {
ASSERT(specializationInfo);
ASSERT(specializationDataEntries);
@ -224,10 +224,10 @@ namespace dawn_native { namespace vulkan {
specializationMapEntries->push_back(
VkSpecializationMapEntry{moduleConstant.id,
static_cast<uint32_t>(specializationDataEntries->size() *
sizeof(SpecializationDataEntry)),
sizeof(SpecializationDataEntry)});
sizeof(OverridableConstantScalar)),
sizeof(OverridableConstantScalar)});
SpecializationDataEntry entry{};
OverridableConstantScalar entry{};
switch (moduleConstant.type) {
case EntryPointMetadata::OverridableConstant::Type::Boolean:
entry.b = static_cast<int32_t>(value);
@ -250,7 +250,7 @@ namespace dawn_native { namespace vulkan {
specializationInfo->mapEntryCount = static_cast<uint32_t>(specializationMapEntries->size());
specializationInfo->pMapEntries = specializationMapEntries->data();
specializationInfo->dataSize =
specializationDataEntries->size() * sizeof(SpecializationDataEntry);
specializationDataEntries->size() * sizeof(OverridableConstantScalar);
specializationInfo->pData = specializationDataEntries->data();
return specializationInfo;

View File

@ -21,6 +21,7 @@
namespace dawn_native {
struct ProgrammableStage;
union OverridableConstantScalar;
} // namespace dawn_native
namespace dawn_native { namespace vulkan {
@ -111,23 +112,13 @@ namespace dawn_native { namespace vulkan {
const char* prefix,
std::string label = "");
// Helpers for creating VkSpecializationInfo
// The WebGPU overridable constants only support these scalar types
union SpecializationDataEntry {
// Use int32_t for boolean to initialize the full 32bit
int32_t b;
float f32;
int32_t i32;
uint32_t u32;
};
// Returns nullptr or &specializationInfo
// specializationInfo, specializationDataEntries, specializationMapEntries needs to
// be alive at least until VkSpecializationInfo is passed into Vulkan Create*Pipelines
VkSpecializationInfo* GetVkSpecializationInfo(
const ProgrammableStage& programmableStage,
VkSpecializationInfo* specializationInfo,
std::vector<SpecializationDataEntry>* specializationDataEntries,
std::vector<OverridableConstantScalar>* specializationDataEntries,
std::vector<VkSpecializationMapEntry>* specializationMapEntries);
}} // namespace dawn_native::vulkan

View File

@ -391,8 +391,8 @@ fn main([[location(0)]] pos : vec4<f32>) -> [[builtin(position)]] vec4<f32> {
// Test overridable constants without numeric identifiers
TEST_P(ShaderTests, OverridableConstants) {
// TODO(dawn:1041): Only Vulkan backend is implemented
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan());
// TODO(dawn:1137): D3D12 backend is unimplemented
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal());
uint32_t const kCount = 11;
std::vector<uint32_t> expected(kCount);
@ -469,8 +469,8 @@ TEST_P(ShaderTests, OverridableConstants) {
// Test overridable constants with numeric identifiers
TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) {
// TODO(dawn:1041): Only Vulkan backend is implemented
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan());
// TODO(dawn:1137): D3D12 backend is unimplemented
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal());
uint32_t const kCount = 4;
std::vector<uint32_t> expected{1u, 2u, 3u, 0u};
@ -525,8 +525,8 @@ TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) {
// Test overridable constants for different entry points
TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
// TODO(dawn:1041): Only Vulkan backend is implemented
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan());
// TODO(dawn:1137): D3D12 backend is unimplemented
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal());
uint32_t const kCount = 1;
std::vector<uint32_t> expected1{1u};
@ -607,8 +607,8 @@ TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
// Draw a triangle covering the render target, with vertex position and color values from
// overridable constants
TEST_P(ShaderTests, OverridableConstantsRenderPipeline) {
// TODO(dawn:1041): Only Vulkan backend is implemented
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan());
// TODO(dawn:1137): D3D12 backend is unimplemented
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal());
wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, R"(
[[override(1111)]] let xright: f32;