diff --git a/src/dawn_native/Pipeline.cpp b/src/dawn_native/Pipeline.cpp index bb846fdb15..2a5e31696e 100644 --- a/src/dawn_native/Pipeline.cpp +++ b/src/dawn_native/Pipeline.cpp @@ -54,7 +54,7 @@ namespace dawn_native { // Validate if overridable constants exist in shader module // pipelineBase is not yet constructed at this moment so iterate constants from descriptor size_t numUninitializedConstants = metadata.uninitializedOverridableConstants.size(); - // Keep a initialized constants sets to handle duplicate initialization cases + // Keep an initialized constants sets to handle duplicate initialization cases // Only storing that of uninialized constants is needed std::unordered_set stageInitializedConstantIdentifiers; for (uint32_t i = 0; i < constantCount; i++) { diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp index 3a8c26ca76..4f4cc3c418 100644 --- a/src/dawn_native/ShaderModule.cpp +++ b/src/dawn_native/ShaderModule.cpp @@ -636,11 +636,29 @@ namespace dawn_native { "are partially implemented."); const auto& name2Id = inspector.GetConstantNameToIdMap(); + const auto& id2Scalar = inspector.GetConstantIDs(); for (auto& c : entryPoint.overridable_constants) { + uint32_t id = name2Id.at(c.name); + OverridableConstantScalar defaultValue; + if (c.is_initialized) { + // if it is initialized, the scalar must exist + const auto& scalar = id2Scalar.at(id); + if (scalar.IsBool()) { + defaultValue.b = scalar.AsBool(); + } else if (scalar.IsU32()) { + defaultValue.u32 = scalar.AsU32(); + } else if (scalar.IsI32()) { + defaultValue.i32 = scalar.AsI32(); + } else if (scalar.IsFloat()) { + defaultValue.f32 = scalar.AsFloat(); + } else { + UNREACHABLE(); + } + } EntryPointMetadata::OverridableConstant constant = { - name2Id.at(c.name), FromTintOverridableConstantType(c.type), - c.is_initialized}; + id, FromTintOverridableConstantType(c.type), c.is_initialized, + defaultValue}; std::string identifier = c.is_numeric_id_specified ? std::to_string(constant.id) : c.name; @@ -651,6 +669,11 @@ namespace dawn_native { std::move(identifier)); // The insertion should have taken place ASSERT(it.second); + } else { + auto it = metadata->initializedOverridableConstants.emplace( + std::move(identifier)); + // The insertion should have taken place + ASSERT(it.second); } } } diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h index 30f32946bb..2b13a2162a 100644 --- a/src/dawn_native/ShaderModule.h +++ b/src/dawn_native/ShaderModule.h @@ -150,6 +150,15 @@ namespace dawn_native { using BindingGroupInfoMap = std::map; using BindingInfoArray = ityp::array; + // The WebGPU overridable constants only support these scalar types + union OverridableConstantScalar { + // Use int32_t for boolean to initialize the full 32bit + int32_t b; + float f32; + int32_t i32; + uint32_t u32; + }; + // Contains all the reflection data for a valid (ShaderModule, entryPoint, stage). They are // stored in the ShaderModuleBase and destroyed only when the shader program is destroyed so // pointers to EntryPointMetadata are safe to store as long as you also keep a Ref to the @@ -206,6 +215,11 @@ namespace dawn_native { // Then it is required for the pipeline stage to have a constant record to initialize a // value bool isInitialized; + + // Store the default initialized value in shader + // This is used by metal backend as the function_constant does not have dafault values + // Initialized when isInitialized == true + OverridableConstantScalar defaultValue; }; // Map identifier to overridable constant @@ -216,6 +230,11 @@ namespace dawn_native { // They need value initialization from pipeline stage or it is a validation error std::unordered_set uninitializedOverridableConstants; + // Store constants with shader initialized values as well + // This is used by metal backend to set values with default initializers that are not + // overridden + std::unordered_set initializedOverridableConstants; + bool usesNumWorkgroups = false; }; diff --git a/src/dawn_native/metal/BackendMTL.mm b/src/dawn_native/metal/BackendMTL.mm index 9ebccdb517..ccbc46785e 100644 --- a/src/dawn_native/metal/BackendMTL.mm +++ b/src/dawn_native/metal/BackendMTL.mm @@ -156,7 +156,9 @@ namespace dawn_native { namespace metal { bool IsMetalSupported() { // Metal was first introduced in macOS 10.11 - return IsMacOSVersionAtLeast(10, 11); + // WebGPU is targeted at macOS 10.12+ + // TODO(dawn:1181): Dawn native should allow non-conformant WebGPU on macOS 10.11 + return IsMacOSVersionAtLeast(10, 12); } #elif defined(DAWN_PLATFORM_IOS) MaybeError GetDevicePCIInfo(id device, PCIIDs* ids) { diff --git a/src/dawn_native/metal/ComputePipelineMTL.mm b/src/dawn_native/metal/ComputePipelineMTL.mm index 18edc566fb..48d36ae14a 100644 --- a/src/dawn_native/metal/ComputePipelineMTL.mm +++ b/src/dawn_native/metal/ComputePipelineMTL.mm @@ -17,6 +17,7 @@ #include "dawn_native/CreatePipelineAsyncTask.h" #include "dawn_native/metal/DeviceMTL.h" #include "dawn_native/metal/ShaderModuleMTL.h" +#include "dawn_native/metal/UtilsMetal.h" namespace dawn_native { namespace metal { @@ -31,11 +32,10 @@ namespace dawn_native { namespace metal { auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice(); const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute); - ShaderModule* computeModule = ToBackend(computeStage.module.Get()); - const char* computeEntryPoint = computeStage.entryPoint.c_str(); ShaderModule::MetalFunctionData computeData; - DAWN_TRY(computeModule->CreateFunction(computeEntryPoint, SingleShaderStage::Compute, - ToBackend(GetLayout()), &computeData)); + + DAWN_TRY(CreateMTLFunction(computeStage, SingleShaderStage::Compute, ToBackend(GetLayout()), + &computeData)); NSError* error = nullptr; mMtlComputePipelineState.Acquire([mtlDevice diff --git a/src/dawn_native/metal/RenderPipelineMTL.mm b/src/dawn_native/metal/RenderPipelineMTL.mm index 1537abfa44..a4ca812e98 100644 --- a/src/dawn_native/metal/RenderPipelineMTL.mm +++ b/src/dawn_native/metal/RenderPipelineMTL.mm @@ -339,12 +339,9 @@ namespace dawn_native { namespace metal { const PerStage& allStages = GetAllStages(); const ProgrammableStage& vertexStage = allStages[wgpu::ShaderStage::Vertex]; - ShaderModule* vertexModule = ToBackend(vertexStage.module).Get(); - const char* vertexEntryPoint = vertexStage.entryPoint.c_str(); ShaderModule::MetalFunctionData vertexData; - DAWN_TRY(vertexModule->CreateFunction(vertexEntryPoint, SingleShaderStage::Vertex, - ToBackend(GetLayout()), &vertexData, 0xFFFFFFFF, - this)); + DAWN_TRY(CreateMTLFunction(vertexStage, SingleShaderStage::Vertex, ToBackend(GetLayout()), + &vertexData, 0xFFFFFFFF, this)); descriptorMTL.vertexFunction = vertexData.function.Get(); if (vertexData.needsStorageBufferLength) { @@ -353,12 +350,9 @@ namespace dawn_native { namespace metal { if (GetStageMask() & wgpu::ShaderStage::Fragment) { const ProgrammableStage& fragmentStage = allStages[wgpu::ShaderStage::Fragment]; - ShaderModule* fragmentModule = ToBackend(fragmentStage.module).Get(); - const char* fragmentEntryPoint = fragmentStage.entryPoint.c_str(); ShaderModule::MetalFunctionData fragmentData; - DAWN_TRY(fragmentModule->CreateFunction(fragmentEntryPoint, SingleShaderStage::Fragment, - ToBackend(GetLayout()), &fragmentData, - GetSampleMask())); + DAWN_TRY(CreateMTLFunction(fragmentStage, SingleShaderStage::Fragment, + ToBackend(GetLayout()), &fragmentData, GetSampleMask())); descriptorMTL.fragmentFunction = fragmentData.function.Get(); if (fragmentData.needsStorageBufferLength) { diff --git a/src/dawn_native/metal/ShaderModuleMTL.h b/src/dawn_native/metal/ShaderModuleMTL.h index 4cb91a4db6..e82ffad3f8 100644 --- a/src/dawn_native/metal/ShaderModuleMTL.h +++ b/src/dawn_native/metal/ShaderModuleMTL.h @@ -39,10 +39,14 @@ namespace dawn_native { namespace metal { bool needsStorageBufferLength; std::vector workgroupAllocations; }; + + // MTLFunctionConstantValues needs @available tag to compile + // Use id (like void*) in function signature as workaround and do static cast inside MaybeError CreateFunction(const char* entryPointName, SingleShaderStage stage, const PipelineLayout* layout, MetalFunctionData* out, + id constantValues = nil, uint32_t sampleMask = 0xFFFFFFFF, const RenderPipeline* renderPipeline = nullptr); diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm index 4c0aaba951..9189cf84c6 100644 --- a/src/dawn_native/metal/ShaderModuleMTL.mm +++ b/src/dawn_native/metal/ShaderModuleMTL.mm @@ -174,6 +174,7 @@ namespace dawn_native { namespace metal { SingleShaderStage stage, const PipelineLayout* layout, ShaderModule::MetalFunctionData* out, + id constantValuesPointer, uint32_t sampleMask, const RenderPipeline* renderPipeline) { ASSERT(!IsError()); @@ -231,7 +232,26 @@ namespace dawn_native { namespace metal { NSRef name = AcquireNSRef([[NSString alloc] initWithUTF8String:remappedEntryPointName.c_str()]); - out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]); + + if (constantValuesPointer != nil) { + if (@available(macOS 10.12, *)) { + MTLFunctionConstantValues* constantValues = constantValuesPointer; + out->function = AcquireNSPRef([*library newFunctionWithName:name.Get() + constantValues:constantValues + error:&error]); + if (error != nullptr) { + if (error.code != MTLLibraryErrorCompileWarning) { + return DAWN_VALIDATION_ERROR(std::string("Function compile error: ") + + [error.localizedDescription UTF8String]); + } + } + ASSERT(out->function != nil); + } else { + UNREACHABLE(); + } + } else { + out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]); + } if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) && GetEntryPoint(entryPointName).usedVertexInputs.any()) { diff --git a/src/dawn_native/metal/UtilsMetal.h b/src/dawn_native/metal/UtilsMetal.h index 6855734f8e..3a17c99b27 100644 --- a/src/dawn_native/metal/UtilsMetal.h +++ b/src/dawn_native/metal/UtilsMetal.h @@ -17,10 +17,17 @@ #include "dawn_native/dawn_platform.h" #include "dawn_native/metal/DeviceMTL.h" +#include "dawn_native/metal/ShaderModuleMTL.h" #include "dawn_native/metal/TextureMTL.h" #import +namespace dawn_native { + struct ProgrammableStage; + struct EntryPointMetadata; + enum class SingleShaderStage; +} + namespace dawn_native { namespace metal { MTLCompareFunction ToMetalCompareFunction(wgpu::CompareFunction compareFunction); @@ -65,6 +72,15 @@ namespace dawn_native { namespace metal { MTLBlitOption ComputeMTLBlitOption(const Format& format, Aspect aspect); + // Helper function to create function with constant values wrapped in + // if available branch + MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage, + SingleShaderStage singleShaderStage, + PipelineLayout* pipelineLayout, + ShaderModule::MetalFunctionData* functionData, + uint32_t sampleMask = 0xFFFFFFFF, + const RenderPipeline* renderPipeline = nullptr); + }} // namespace dawn_native::metal #endif // DAWNNATIVE_METAL_UTILSMETAL_H_ diff --git a/src/dawn_native/metal/UtilsMetal.mm b/src/dawn_native/metal/UtilsMetal.mm index 51fa99325c..1a7962afde 100644 --- a/src/dawn_native/metal/UtilsMetal.mm +++ b/src/dawn_native/metal/UtilsMetal.mm @@ -14,6 +14,8 @@ #include "dawn_native/metal/UtilsMetal.h" #include "dawn_native/CommandBuffer.h" +#include "dawn_native/Pipeline.h" +#include "dawn_native/ShaderModule.h" #include "common/Assert.h" @@ -186,4 +188,106 @@ namespace dawn_native { namespace metal { return MTLBlitOptionNone; } + MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage, + SingleShaderStage singleShaderStage, + PipelineLayout* pipelineLayout, + ShaderModule::MetalFunctionData* functionData, + uint32_t sampleMask, + const RenderPipeline* renderPipeline) { + ShaderModule* shaderModule = ToBackend(programmableStage.module.Get()); + const char* shaderEntryPoint = programmableStage.entryPoint.c_str(); + const auto& entryPointMetadata = programmableStage.module->GetEntryPoint(shaderEntryPoint); + if (entryPointMetadata.overridableConstants.size() == 0) { + DAWN_TRY(shaderModule->CreateFunction(shaderEntryPoint, singleShaderStage, + pipelineLayout, functionData, nil, sampleMask, + renderPipeline)); + return {}; + } + + if (@available(macOS 10.12, *)) { + // MTLFunctionConstantValues can only be created within the if available branch + NSRef constantValues = + AcquireNSRef([MTLFunctionConstantValues new]); + + std::unordered_set overriddenConstants; + + auto switchType = [&](EntryPointMetadata::OverridableConstant::Type dawnType, + MTLDataType* type, OverridableConstantScalar* entry, + double value = 0) { + switch (dawnType) { + case EntryPointMetadata::OverridableConstant::Type::Boolean: + *type = MTLDataTypeBool; + if (entry) { + entry->b = static_cast(value); + } + break; + case EntryPointMetadata::OverridableConstant::Type::Float32: + *type = MTLDataTypeFloat; + if (entry) { + entry->f32 = static_cast(value); + } + break; + case EntryPointMetadata::OverridableConstant::Type::Int32: + *type = MTLDataTypeInt; + if (entry) { + entry->i32 = static_cast(value); + } + break; + case EntryPointMetadata::OverridableConstant::Type::Uint32: + *type = MTLDataTypeUInt; + if (entry) { + entry->u32 = static_cast(value); + } + break; + default: + UNREACHABLE(); + } + }; + + for (const auto& pipelineConstant : programmableStage.constants) { + const std::string& name = pipelineConstant.first; + double value = pipelineConstant.second; + + overriddenConstants.insert(name); + + // This is already validated so `name` must exist + const auto& moduleConstant = entryPointMetadata.overridableConstants.at(name); + + MTLDataType type; + OverridableConstantScalar entry{}; + + switchType(moduleConstant.type, &type, &entry, value); + + [constantValues.Get() setConstantValue:&entry type:type atIndex:moduleConstant.id]; + } + + // Set shader initialized default values because MSL function_constant + // has no default value + for (const std::string& name : entryPointMetadata.initializedOverridableConstants) { + if (overriddenConstants.count(name) != 0) { + // This constant already has overridden value + continue; + } + + // Must exist because it is validated + const auto& moduleConstant = entryPointMetadata.overridableConstants.at(name); + ASSERT(moduleConstant.isInitialized); + MTLDataType type; + + switchType(moduleConstant.type, &type, nullptr); + + [constantValues.Get() setConstantValue:&moduleConstant.defaultValue + type:type + atIndex:moduleConstant.id]; + } + + DAWN_TRY(shaderModule->CreateFunction( + shaderEntryPoint, singleShaderStage, pipelineLayout, functionData, + constantValues.Get(), sampleMask, renderPipeline)); + } else { + UNREACHABLE(); + } + return {}; + } + }} // namespace dawn_native::metal diff --git a/src/dawn_native/vulkan/ComputePipelineVk.cpp b/src/dawn_native/vulkan/ComputePipelineVk.cpp index c30f8c2caf..c6c3278b3d 100644 --- a/src/dawn_native/vulkan/ComputePipelineVk.cpp +++ b/src/dawn_native/vulkan/ComputePipelineVk.cpp @@ -53,7 +53,7 @@ namespace dawn_native { namespace vulkan { createInfo.stage.pName = computeStage.entryPoint.c_str(); - std::vector specializationDataEntries; + std::vector specializationDataEntries; std::vector specializationMapEntries; VkSpecializationInfo specializationInfo{}; createInfo.stage.pSpecializationInfo = diff --git a/src/dawn_native/vulkan/RenderPipelineVk.cpp b/src/dawn_native/vulkan/RenderPipelineVk.cpp index 93e6eb3ae4..ff96ec1df0 100644 --- a/src/dawn_native/vulkan/RenderPipelineVk.cpp +++ b/src/dawn_native/vulkan/RenderPipelineVk.cpp @@ -339,7 +339,7 @@ namespace dawn_native { namespace vulkan { // There are at most 2 shader stages in render pipeline, i.e. vertex and fragment std::array shaderStages; - std::array, 2> specializationDataEntriesPerStages; + std::array, 2> specializationDataEntriesPerStages; std::array, 2> specializationMapEntriesPerStages; std::array specializationInfoPerStages; uint32_t stageCount = 0; diff --git a/src/dawn_native/vulkan/UtilsVulkan.cpp b/src/dawn_native/vulkan/UtilsVulkan.cpp index 87d896d4aa..6e316c567d 100644 --- a/src/dawn_native/vulkan/UtilsVulkan.cpp +++ b/src/dawn_native/vulkan/UtilsVulkan.cpp @@ -201,7 +201,7 @@ namespace dawn_native { namespace vulkan { VkSpecializationInfo* GetVkSpecializationInfo( const ProgrammableStage& programmableStage, VkSpecializationInfo* specializationInfo, - std::vector* specializationDataEntries, + std::vector* specializationDataEntries, std::vector* specializationMapEntries) { ASSERT(specializationInfo); ASSERT(specializationDataEntries); @@ -224,10 +224,10 @@ namespace dawn_native { namespace vulkan { specializationMapEntries->push_back( VkSpecializationMapEntry{moduleConstant.id, static_cast(specializationDataEntries->size() * - sizeof(SpecializationDataEntry)), - sizeof(SpecializationDataEntry)}); + sizeof(OverridableConstantScalar)), + sizeof(OverridableConstantScalar)}); - SpecializationDataEntry entry{}; + OverridableConstantScalar entry{}; switch (moduleConstant.type) { case EntryPointMetadata::OverridableConstant::Type::Boolean: entry.b = static_cast(value); @@ -250,7 +250,7 @@ namespace dawn_native { namespace vulkan { specializationInfo->mapEntryCount = static_cast(specializationMapEntries->size()); specializationInfo->pMapEntries = specializationMapEntries->data(); specializationInfo->dataSize = - specializationDataEntries->size() * sizeof(SpecializationDataEntry); + specializationDataEntries->size() * sizeof(OverridableConstantScalar); specializationInfo->pData = specializationDataEntries->data(); return specializationInfo; diff --git a/src/dawn_native/vulkan/UtilsVulkan.h b/src/dawn_native/vulkan/UtilsVulkan.h index c3859172d9..53b6d41e26 100644 --- a/src/dawn_native/vulkan/UtilsVulkan.h +++ b/src/dawn_native/vulkan/UtilsVulkan.h @@ -21,6 +21,7 @@ namespace dawn_native { struct ProgrammableStage; + union OverridableConstantScalar; } // namespace dawn_native namespace dawn_native { namespace vulkan { @@ -111,23 +112,13 @@ namespace dawn_native { namespace vulkan { const char* prefix, std::string label = ""); - // Helpers for creating VkSpecializationInfo - // The WebGPU overridable constants only support these scalar types - union SpecializationDataEntry { - // Use int32_t for boolean to initialize the full 32bit - int32_t b; - float f32; - int32_t i32; - uint32_t u32; - }; - // Returns nullptr or &specializationInfo // specializationInfo, specializationDataEntries, specializationMapEntries needs to // be alive at least until VkSpecializationInfo is passed into Vulkan Create*Pipelines VkSpecializationInfo* GetVkSpecializationInfo( const ProgrammableStage& programmableStage, VkSpecializationInfo* specializationInfo, - std::vector* specializationDataEntries, + std::vector* specializationDataEntries, std::vector* specializationMapEntries); }} // namespace dawn_native::vulkan diff --git a/src/tests/end2end/ShaderTests.cpp b/src/tests/end2end/ShaderTests.cpp index 4650d7b684..19865e01dd 100644 --- a/src/tests/end2end/ShaderTests.cpp +++ b/src/tests/end2end/ShaderTests.cpp @@ -391,8 +391,8 @@ fn main([[location(0)]] pos : vec4) -> [[builtin(position)]] vec4 { // Test overridable constants without numeric identifiers TEST_P(ShaderTests, OverridableConstants) { - // TODO(dawn:1041): Only Vulkan backend is implemented - DAWN_TEST_UNSUPPORTED_IF(!IsVulkan()); + // TODO(dawn:1137): D3D12 backend is unimplemented + DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal()); uint32_t const kCount = 11; std::vector expected(kCount); @@ -469,8 +469,8 @@ TEST_P(ShaderTests, OverridableConstants) { // Test overridable constants with numeric identifiers TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) { - // TODO(dawn:1041): Only Vulkan backend is implemented - DAWN_TEST_UNSUPPORTED_IF(!IsVulkan()); + // TODO(dawn:1137): D3D12 backend is unimplemented + DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal()); uint32_t const kCount = 4; std::vector expected{1u, 2u, 3u, 0u}; @@ -525,8 +525,8 @@ TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) { // Test overridable constants for different entry points TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) { - // TODO(dawn:1041): Only Vulkan backend is implemented - DAWN_TEST_UNSUPPORTED_IF(!IsVulkan()); + // TODO(dawn:1137): D3D12 backend is unimplemented + DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal()); uint32_t const kCount = 1; std::vector expected1{1u}; @@ -607,8 +607,8 @@ TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) { // Draw a triangle covering the render target, with vertex position and color values from // overridable constants TEST_P(ShaderTests, OverridableConstantsRenderPipeline) { - // TODO(dawn:1041): Only Vulkan backend is implemented - DAWN_TEST_UNSUPPORTED_IF(!IsVulkan()); + // TODO(dawn:1137): D3D12 backend is unimplemented + DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal()); wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, R"( [[override(1111)]] let xright: f32;