Pipeline overridable constants: Metal backend
Also gate Dawn metal behind MacOS 10.12+ Bug: dawn:1136, dawn:1181 Change-Id: Id7bfaa2953b1acf08f37e6a08ddeadd9cde44657 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/67421 Reviewed-by: Austin Eng <enga@chromium.org> Commit-Queue: Shrek Shao <shrekshao@google.com>
This commit is contained in:
parent
f4a6ad8edd
commit
b0a5ed49b1
|
@ -54,7 +54,7 @@ namespace dawn_native {
|
||||||
// Validate if overridable constants exist in shader module
|
// Validate if overridable constants exist in shader module
|
||||||
// pipelineBase is not yet constructed at this moment so iterate constants from descriptor
|
// pipelineBase is not yet constructed at this moment so iterate constants from descriptor
|
||||||
size_t numUninitializedConstants = metadata.uninitializedOverridableConstants.size();
|
size_t numUninitializedConstants = metadata.uninitializedOverridableConstants.size();
|
||||||
// Keep a initialized constants sets to handle duplicate initialization cases
|
// Keep an initialized constants sets to handle duplicate initialization cases
|
||||||
// Only storing that of uninialized constants is needed
|
// Only storing that of uninialized constants is needed
|
||||||
std::unordered_set<std::string> stageInitializedConstantIdentifiers;
|
std::unordered_set<std::string> stageInitializedConstantIdentifiers;
|
||||||
for (uint32_t i = 0; i < constantCount; i++) {
|
for (uint32_t i = 0; i < constantCount; i++) {
|
||||||
|
|
|
@ -636,11 +636,29 @@ namespace dawn_native {
|
||||||
"are partially implemented.");
|
"are partially implemented.");
|
||||||
|
|
||||||
const auto& name2Id = inspector.GetConstantNameToIdMap();
|
const auto& name2Id = inspector.GetConstantNameToIdMap();
|
||||||
|
const auto& id2Scalar = inspector.GetConstantIDs();
|
||||||
|
|
||||||
for (auto& c : entryPoint.overridable_constants) {
|
for (auto& c : entryPoint.overridable_constants) {
|
||||||
|
uint32_t id = name2Id.at(c.name);
|
||||||
|
OverridableConstantScalar defaultValue;
|
||||||
|
if (c.is_initialized) {
|
||||||
|
// if it is initialized, the scalar must exist
|
||||||
|
const auto& scalar = id2Scalar.at(id);
|
||||||
|
if (scalar.IsBool()) {
|
||||||
|
defaultValue.b = scalar.AsBool();
|
||||||
|
} else if (scalar.IsU32()) {
|
||||||
|
defaultValue.u32 = scalar.AsU32();
|
||||||
|
} else if (scalar.IsI32()) {
|
||||||
|
defaultValue.i32 = scalar.AsI32();
|
||||||
|
} else if (scalar.IsFloat()) {
|
||||||
|
defaultValue.f32 = scalar.AsFloat();
|
||||||
|
} else {
|
||||||
|
UNREACHABLE();
|
||||||
|
}
|
||||||
|
}
|
||||||
EntryPointMetadata::OverridableConstant constant = {
|
EntryPointMetadata::OverridableConstant constant = {
|
||||||
name2Id.at(c.name), FromTintOverridableConstantType(c.type),
|
id, FromTintOverridableConstantType(c.type), c.is_initialized,
|
||||||
c.is_initialized};
|
defaultValue};
|
||||||
|
|
||||||
std::string identifier =
|
std::string identifier =
|
||||||
c.is_numeric_id_specified ? std::to_string(constant.id) : c.name;
|
c.is_numeric_id_specified ? std::to_string(constant.id) : c.name;
|
||||||
|
@ -651,6 +669,11 @@ namespace dawn_native {
|
||||||
std::move(identifier));
|
std::move(identifier));
|
||||||
// The insertion should have taken place
|
// The insertion should have taken place
|
||||||
ASSERT(it.second);
|
ASSERT(it.second);
|
||||||
|
} else {
|
||||||
|
auto it = metadata->initializedOverridableConstants.emplace(
|
||||||
|
std::move(identifier));
|
||||||
|
// The insertion should have taken place
|
||||||
|
ASSERT(it.second);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -150,6 +150,15 @@ namespace dawn_native {
|
||||||
using BindingGroupInfoMap = std::map<BindingNumber, ShaderBindingInfo>;
|
using BindingGroupInfoMap = std::map<BindingNumber, ShaderBindingInfo>;
|
||||||
using BindingInfoArray = ityp::array<BindGroupIndex, BindingGroupInfoMap, kMaxBindGroups>;
|
using BindingInfoArray = ityp::array<BindGroupIndex, BindingGroupInfoMap, kMaxBindGroups>;
|
||||||
|
|
||||||
|
// The WebGPU overridable constants only support these scalar types
|
||||||
|
union OverridableConstantScalar {
|
||||||
|
// Use int32_t for boolean to initialize the full 32bit
|
||||||
|
int32_t b;
|
||||||
|
float f32;
|
||||||
|
int32_t i32;
|
||||||
|
uint32_t u32;
|
||||||
|
};
|
||||||
|
|
||||||
// Contains all the reflection data for a valid (ShaderModule, entryPoint, stage). They are
|
// Contains all the reflection data for a valid (ShaderModule, entryPoint, stage). They are
|
||||||
// stored in the ShaderModuleBase and destroyed only when the shader program is destroyed so
|
// stored in the ShaderModuleBase and destroyed only when the shader program is destroyed so
|
||||||
// pointers to EntryPointMetadata are safe to store as long as you also keep a Ref to the
|
// pointers to EntryPointMetadata are safe to store as long as you also keep a Ref to the
|
||||||
|
@ -206,6 +215,11 @@ namespace dawn_native {
|
||||||
// Then it is required for the pipeline stage to have a constant record to initialize a
|
// Then it is required for the pipeline stage to have a constant record to initialize a
|
||||||
// value
|
// value
|
||||||
bool isInitialized;
|
bool isInitialized;
|
||||||
|
|
||||||
|
// Store the default initialized value in shader
|
||||||
|
// This is used by metal backend as the function_constant does not have dafault values
|
||||||
|
// Initialized when isInitialized == true
|
||||||
|
OverridableConstantScalar defaultValue;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Map identifier to overridable constant
|
// Map identifier to overridable constant
|
||||||
|
@ -216,6 +230,11 @@ namespace dawn_native {
|
||||||
// They need value initialization from pipeline stage or it is a validation error
|
// They need value initialization from pipeline stage or it is a validation error
|
||||||
std::unordered_set<std::string> uninitializedOverridableConstants;
|
std::unordered_set<std::string> uninitializedOverridableConstants;
|
||||||
|
|
||||||
|
// Store constants with shader initialized values as well
|
||||||
|
// This is used by metal backend to set values with default initializers that are not
|
||||||
|
// overridden
|
||||||
|
std::unordered_set<std::string> initializedOverridableConstants;
|
||||||
|
|
||||||
bool usesNumWorkgroups = false;
|
bool usesNumWorkgroups = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -156,7 +156,9 @@ namespace dawn_native { namespace metal {
|
||||||
|
|
||||||
bool IsMetalSupported() {
|
bool IsMetalSupported() {
|
||||||
// Metal was first introduced in macOS 10.11
|
// Metal was first introduced in macOS 10.11
|
||||||
return IsMacOSVersionAtLeast(10, 11);
|
// WebGPU is targeted at macOS 10.12+
|
||||||
|
// TODO(dawn:1181): Dawn native should allow non-conformant WebGPU on macOS 10.11
|
||||||
|
return IsMacOSVersionAtLeast(10, 12);
|
||||||
}
|
}
|
||||||
#elif defined(DAWN_PLATFORM_IOS)
|
#elif defined(DAWN_PLATFORM_IOS)
|
||||||
MaybeError GetDevicePCIInfo(id<MTLDevice> device, PCIIDs* ids) {
|
MaybeError GetDevicePCIInfo(id<MTLDevice> device, PCIIDs* ids) {
|
||||||
|
|
|
@ -17,6 +17,7 @@
|
||||||
#include "dawn_native/CreatePipelineAsyncTask.h"
|
#include "dawn_native/CreatePipelineAsyncTask.h"
|
||||||
#include "dawn_native/metal/DeviceMTL.h"
|
#include "dawn_native/metal/DeviceMTL.h"
|
||||||
#include "dawn_native/metal/ShaderModuleMTL.h"
|
#include "dawn_native/metal/ShaderModuleMTL.h"
|
||||||
|
#include "dawn_native/metal/UtilsMetal.h"
|
||||||
|
|
||||||
namespace dawn_native { namespace metal {
|
namespace dawn_native { namespace metal {
|
||||||
|
|
||||||
|
@ -31,11 +32,10 @@ namespace dawn_native { namespace metal {
|
||||||
auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice();
|
auto mtlDevice = ToBackend(GetDevice())->GetMTLDevice();
|
||||||
|
|
||||||
const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute);
|
const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute);
|
||||||
ShaderModule* computeModule = ToBackend(computeStage.module.Get());
|
|
||||||
const char* computeEntryPoint = computeStage.entryPoint.c_str();
|
|
||||||
ShaderModule::MetalFunctionData computeData;
|
ShaderModule::MetalFunctionData computeData;
|
||||||
DAWN_TRY(computeModule->CreateFunction(computeEntryPoint, SingleShaderStage::Compute,
|
|
||||||
ToBackend(GetLayout()), &computeData));
|
DAWN_TRY(CreateMTLFunction(computeStage, SingleShaderStage::Compute, ToBackend(GetLayout()),
|
||||||
|
&computeData));
|
||||||
|
|
||||||
NSError* error = nullptr;
|
NSError* error = nullptr;
|
||||||
mMtlComputePipelineState.Acquire([mtlDevice
|
mMtlComputePipelineState.Acquire([mtlDevice
|
||||||
|
|
|
@ -339,12 +339,9 @@ namespace dawn_native { namespace metal {
|
||||||
|
|
||||||
const PerStage<ProgrammableStage>& allStages = GetAllStages();
|
const PerStage<ProgrammableStage>& allStages = GetAllStages();
|
||||||
const ProgrammableStage& vertexStage = allStages[wgpu::ShaderStage::Vertex];
|
const ProgrammableStage& vertexStage = allStages[wgpu::ShaderStage::Vertex];
|
||||||
ShaderModule* vertexModule = ToBackend(vertexStage.module).Get();
|
|
||||||
const char* vertexEntryPoint = vertexStage.entryPoint.c_str();
|
|
||||||
ShaderModule::MetalFunctionData vertexData;
|
ShaderModule::MetalFunctionData vertexData;
|
||||||
DAWN_TRY(vertexModule->CreateFunction(vertexEntryPoint, SingleShaderStage::Vertex,
|
DAWN_TRY(CreateMTLFunction(vertexStage, SingleShaderStage::Vertex, ToBackend(GetLayout()),
|
||||||
ToBackend(GetLayout()), &vertexData, 0xFFFFFFFF,
|
&vertexData, 0xFFFFFFFF, this));
|
||||||
this));
|
|
||||||
|
|
||||||
descriptorMTL.vertexFunction = vertexData.function.Get();
|
descriptorMTL.vertexFunction = vertexData.function.Get();
|
||||||
if (vertexData.needsStorageBufferLength) {
|
if (vertexData.needsStorageBufferLength) {
|
||||||
|
@ -353,12 +350,9 @@ namespace dawn_native { namespace metal {
|
||||||
|
|
||||||
if (GetStageMask() & wgpu::ShaderStage::Fragment) {
|
if (GetStageMask() & wgpu::ShaderStage::Fragment) {
|
||||||
const ProgrammableStage& fragmentStage = allStages[wgpu::ShaderStage::Fragment];
|
const ProgrammableStage& fragmentStage = allStages[wgpu::ShaderStage::Fragment];
|
||||||
ShaderModule* fragmentModule = ToBackend(fragmentStage.module).Get();
|
|
||||||
const char* fragmentEntryPoint = fragmentStage.entryPoint.c_str();
|
|
||||||
ShaderModule::MetalFunctionData fragmentData;
|
ShaderModule::MetalFunctionData fragmentData;
|
||||||
DAWN_TRY(fragmentModule->CreateFunction(fragmentEntryPoint, SingleShaderStage::Fragment,
|
DAWN_TRY(CreateMTLFunction(fragmentStage, SingleShaderStage::Fragment,
|
||||||
ToBackend(GetLayout()), &fragmentData,
|
ToBackend(GetLayout()), &fragmentData, GetSampleMask()));
|
||||||
GetSampleMask()));
|
|
||||||
|
|
||||||
descriptorMTL.fragmentFunction = fragmentData.function.Get();
|
descriptorMTL.fragmentFunction = fragmentData.function.Get();
|
||||||
if (fragmentData.needsStorageBufferLength) {
|
if (fragmentData.needsStorageBufferLength) {
|
||||||
|
|
|
@ -39,10 +39,14 @@ namespace dawn_native { namespace metal {
|
||||||
bool needsStorageBufferLength;
|
bool needsStorageBufferLength;
|
||||||
std::vector<uint32_t> workgroupAllocations;
|
std::vector<uint32_t> workgroupAllocations;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// MTLFunctionConstantValues needs @available tag to compile
|
||||||
|
// Use id (like void*) in function signature as workaround and do static cast inside
|
||||||
MaybeError CreateFunction(const char* entryPointName,
|
MaybeError CreateFunction(const char* entryPointName,
|
||||||
SingleShaderStage stage,
|
SingleShaderStage stage,
|
||||||
const PipelineLayout* layout,
|
const PipelineLayout* layout,
|
||||||
MetalFunctionData* out,
|
MetalFunctionData* out,
|
||||||
|
id constantValues = nil,
|
||||||
uint32_t sampleMask = 0xFFFFFFFF,
|
uint32_t sampleMask = 0xFFFFFFFF,
|
||||||
const RenderPipeline* renderPipeline = nullptr);
|
const RenderPipeline* renderPipeline = nullptr);
|
||||||
|
|
||||||
|
|
|
@ -174,6 +174,7 @@ namespace dawn_native { namespace metal {
|
||||||
SingleShaderStage stage,
|
SingleShaderStage stage,
|
||||||
const PipelineLayout* layout,
|
const PipelineLayout* layout,
|
||||||
ShaderModule::MetalFunctionData* out,
|
ShaderModule::MetalFunctionData* out,
|
||||||
|
id constantValuesPointer,
|
||||||
uint32_t sampleMask,
|
uint32_t sampleMask,
|
||||||
const RenderPipeline* renderPipeline) {
|
const RenderPipeline* renderPipeline) {
|
||||||
ASSERT(!IsError());
|
ASSERT(!IsError());
|
||||||
|
@ -231,7 +232,26 @@ namespace dawn_native { namespace metal {
|
||||||
|
|
||||||
NSRef<NSString> name =
|
NSRef<NSString> name =
|
||||||
AcquireNSRef([[NSString alloc] initWithUTF8String:remappedEntryPointName.c_str()]);
|
AcquireNSRef([[NSString alloc] initWithUTF8String:remappedEntryPointName.c_str()]);
|
||||||
|
|
||||||
|
if (constantValuesPointer != nil) {
|
||||||
|
if (@available(macOS 10.12, *)) {
|
||||||
|
MTLFunctionConstantValues* constantValues = constantValuesPointer;
|
||||||
|
out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()
|
||||||
|
constantValues:constantValues
|
||||||
|
error:&error]);
|
||||||
|
if (error != nullptr) {
|
||||||
|
if (error.code != MTLLibraryErrorCompileWarning) {
|
||||||
|
return DAWN_VALIDATION_ERROR(std::string("Function compile error: ") +
|
||||||
|
[error.localizedDescription UTF8String]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ASSERT(out->function != nil);
|
||||||
|
} else {
|
||||||
|
UNREACHABLE();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]);
|
out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]);
|
||||||
|
}
|
||||||
|
|
||||||
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
|
if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) &&
|
||||||
GetEntryPoint(entryPointName).usedVertexInputs.any()) {
|
GetEntryPoint(entryPointName).usedVertexInputs.any()) {
|
||||||
|
|
|
@ -17,10 +17,17 @@
|
||||||
|
|
||||||
#include "dawn_native/dawn_platform.h"
|
#include "dawn_native/dawn_platform.h"
|
||||||
#include "dawn_native/metal/DeviceMTL.h"
|
#include "dawn_native/metal/DeviceMTL.h"
|
||||||
|
#include "dawn_native/metal/ShaderModuleMTL.h"
|
||||||
#include "dawn_native/metal/TextureMTL.h"
|
#include "dawn_native/metal/TextureMTL.h"
|
||||||
|
|
||||||
#import <Metal/Metal.h>
|
#import <Metal/Metal.h>
|
||||||
|
|
||||||
|
namespace dawn_native {
|
||||||
|
struct ProgrammableStage;
|
||||||
|
struct EntryPointMetadata;
|
||||||
|
enum class SingleShaderStage;
|
||||||
|
}
|
||||||
|
|
||||||
namespace dawn_native { namespace metal {
|
namespace dawn_native { namespace metal {
|
||||||
|
|
||||||
MTLCompareFunction ToMetalCompareFunction(wgpu::CompareFunction compareFunction);
|
MTLCompareFunction ToMetalCompareFunction(wgpu::CompareFunction compareFunction);
|
||||||
|
@ -65,6 +72,15 @@ namespace dawn_native { namespace metal {
|
||||||
|
|
||||||
MTLBlitOption ComputeMTLBlitOption(const Format& format, Aspect aspect);
|
MTLBlitOption ComputeMTLBlitOption(const Format& format, Aspect aspect);
|
||||||
|
|
||||||
|
// Helper function to create function with constant values wrapped in
|
||||||
|
// if available branch
|
||||||
|
MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage,
|
||||||
|
SingleShaderStage singleShaderStage,
|
||||||
|
PipelineLayout* pipelineLayout,
|
||||||
|
ShaderModule::MetalFunctionData* functionData,
|
||||||
|
uint32_t sampleMask = 0xFFFFFFFF,
|
||||||
|
const RenderPipeline* renderPipeline = nullptr);
|
||||||
|
|
||||||
}} // namespace dawn_native::metal
|
}} // namespace dawn_native::metal
|
||||||
|
|
||||||
#endif // DAWNNATIVE_METAL_UTILSMETAL_H_
|
#endif // DAWNNATIVE_METAL_UTILSMETAL_H_
|
||||||
|
|
|
@ -14,6 +14,8 @@
|
||||||
|
|
||||||
#include "dawn_native/metal/UtilsMetal.h"
|
#include "dawn_native/metal/UtilsMetal.h"
|
||||||
#include "dawn_native/CommandBuffer.h"
|
#include "dawn_native/CommandBuffer.h"
|
||||||
|
#include "dawn_native/Pipeline.h"
|
||||||
|
#include "dawn_native/ShaderModule.h"
|
||||||
|
|
||||||
#include "common/Assert.h"
|
#include "common/Assert.h"
|
||||||
|
|
||||||
|
@ -186,4 +188,106 @@ namespace dawn_native { namespace metal {
|
||||||
return MTLBlitOptionNone;
|
return MTLBlitOptionNone;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage,
|
||||||
|
SingleShaderStage singleShaderStage,
|
||||||
|
PipelineLayout* pipelineLayout,
|
||||||
|
ShaderModule::MetalFunctionData* functionData,
|
||||||
|
uint32_t sampleMask,
|
||||||
|
const RenderPipeline* renderPipeline) {
|
||||||
|
ShaderModule* shaderModule = ToBackend(programmableStage.module.Get());
|
||||||
|
const char* shaderEntryPoint = programmableStage.entryPoint.c_str();
|
||||||
|
const auto& entryPointMetadata = programmableStage.module->GetEntryPoint(shaderEntryPoint);
|
||||||
|
if (entryPointMetadata.overridableConstants.size() == 0) {
|
||||||
|
DAWN_TRY(shaderModule->CreateFunction(shaderEntryPoint, singleShaderStage,
|
||||||
|
pipelineLayout, functionData, nil, sampleMask,
|
||||||
|
renderPipeline));
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (@available(macOS 10.12, *)) {
|
||||||
|
// MTLFunctionConstantValues can only be created within the if available branch
|
||||||
|
NSRef<MTLFunctionConstantValues> constantValues =
|
||||||
|
AcquireNSRef([MTLFunctionConstantValues new]);
|
||||||
|
|
||||||
|
std::unordered_set<std::string> overriddenConstants;
|
||||||
|
|
||||||
|
auto switchType = [&](EntryPointMetadata::OverridableConstant::Type dawnType,
|
||||||
|
MTLDataType* type, OverridableConstantScalar* entry,
|
||||||
|
double value = 0) {
|
||||||
|
switch (dawnType) {
|
||||||
|
case EntryPointMetadata::OverridableConstant::Type::Boolean:
|
||||||
|
*type = MTLDataTypeBool;
|
||||||
|
if (entry) {
|
||||||
|
entry->b = static_cast<int32_t>(value);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case EntryPointMetadata::OverridableConstant::Type::Float32:
|
||||||
|
*type = MTLDataTypeFloat;
|
||||||
|
if (entry) {
|
||||||
|
entry->f32 = static_cast<float>(value);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case EntryPointMetadata::OverridableConstant::Type::Int32:
|
||||||
|
*type = MTLDataTypeInt;
|
||||||
|
if (entry) {
|
||||||
|
entry->i32 = static_cast<int32_t>(value);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case EntryPointMetadata::OverridableConstant::Type::Uint32:
|
||||||
|
*type = MTLDataTypeUInt;
|
||||||
|
if (entry) {
|
||||||
|
entry->u32 = static_cast<uint32_t>(value);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
UNREACHABLE();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
for (const auto& pipelineConstant : programmableStage.constants) {
|
||||||
|
const std::string& name = pipelineConstant.first;
|
||||||
|
double value = pipelineConstant.second;
|
||||||
|
|
||||||
|
overriddenConstants.insert(name);
|
||||||
|
|
||||||
|
// This is already validated so `name` must exist
|
||||||
|
const auto& moduleConstant = entryPointMetadata.overridableConstants.at(name);
|
||||||
|
|
||||||
|
MTLDataType type;
|
||||||
|
OverridableConstantScalar entry{};
|
||||||
|
|
||||||
|
switchType(moduleConstant.type, &type, &entry, value);
|
||||||
|
|
||||||
|
[constantValues.Get() setConstantValue:&entry type:type atIndex:moduleConstant.id];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set shader initialized default values because MSL function_constant
|
||||||
|
// has no default value
|
||||||
|
for (const std::string& name : entryPointMetadata.initializedOverridableConstants) {
|
||||||
|
if (overriddenConstants.count(name) != 0) {
|
||||||
|
// This constant already has overridden value
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Must exist because it is validated
|
||||||
|
const auto& moduleConstant = entryPointMetadata.overridableConstants.at(name);
|
||||||
|
ASSERT(moduleConstant.isInitialized);
|
||||||
|
MTLDataType type;
|
||||||
|
|
||||||
|
switchType(moduleConstant.type, &type, nullptr);
|
||||||
|
|
||||||
|
[constantValues.Get() setConstantValue:&moduleConstant.defaultValue
|
||||||
|
type:type
|
||||||
|
atIndex:moduleConstant.id];
|
||||||
|
}
|
||||||
|
|
||||||
|
DAWN_TRY(shaderModule->CreateFunction(
|
||||||
|
shaderEntryPoint, singleShaderStage, pipelineLayout, functionData,
|
||||||
|
constantValues.Get(), sampleMask, renderPipeline));
|
||||||
|
} else {
|
||||||
|
UNREACHABLE();
|
||||||
|
}
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
}} // namespace dawn_native::metal
|
}} // namespace dawn_native::metal
|
||||||
|
|
|
@ -53,7 +53,7 @@ namespace dawn_native { namespace vulkan {
|
||||||
|
|
||||||
createInfo.stage.pName = computeStage.entryPoint.c_str();
|
createInfo.stage.pName = computeStage.entryPoint.c_str();
|
||||||
|
|
||||||
std::vector<SpecializationDataEntry> specializationDataEntries;
|
std::vector<OverridableConstantScalar> specializationDataEntries;
|
||||||
std::vector<VkSpecializationMapEntry> specializationMapEntries;
|
std::vector<VkSpecializationMapEntry> specializationMapEntries;
|
||||||
VkSpecializationInfo specializationInfo{};
|
VkSpecializationInfo specializationInfo{};
|
||||||
createInfo.stage.pSpecializationInfo =
|
createInfo.stage.pSpecializationInfo =
|
||||||
|
|
|
@ -339,7 +339,7 @@ namespace dawn_native { namespace vulkan {
|
||||||
|
|
||||||
// There are at most 2 shader stages in render pipeline, i.e. vertex and fragment
|
// There are at most 2 shader stages in render pipeline, i.e. vertex and fragment
|
||||||
std::array<VkPipelineShaderStageCreateInfo, 2> shaderStages;
|
std::array<VkPipelineShaderStageCreateInfo, 2> shaderStages;
|
||||||
std::array<std::vector<SpecializationDataEntry>, 2> specializationDataEntriesPerStages;
|
std::array<std::vector<OverridableConstantScalar>, 2> specializationDataEntriesPerStages;
|
||||||
std::array<std::vector<VkSpecializationMapEntry>, 2> specializationMapEntriesPerStages;
|
std::array<std::vector<VkSpecializationMapEntry>, 2> specializationMapEntriesPerStages;
|
||||||
std::array<VkSpecializationInfo, 2> specializationInfoPerStages;
|
std::array<VkSpecializationInfo, 2> specializationInfoPerStages;
|
||||||
uint32_t stageCount = 0;
|
uint32_t stageCount = 0;
|
||||||
|
|
|
@ -201,7 +201,7 @@ namespace dawn_native { namespace vulkan {
|
||||||
VkSpecializationInfo* GetVkSpecializationInfo(
|
VkSpecializationInfo* GetVkSpecializationInfo(
|
||||||
const ProgrammableStage& programmableStage,
|
const ProgrammableStage& programmableStage,
|
||||||
VkSpecializationInfo* specializationInfo,
|
VkSpecializationInfo* specializationInfo,
|
||||||
std::vector<SpecializationDataEntry>* specializationDataEntries,
|
std::vector<OverridableConstantScalar>* specializationDataEntries,
|
||||||
std::vector<VkSpecializationMapEntry>* specializationMapEntries) {
|
std::vector<VkSpecializationMapEntry>* specializationMapEntries) {
|
||||||
ASSERT(specializationInfo);
|
ASSERT(specializationInfo);
|
||||||
ASSERT(specializationDataEntries);
|
ASSERT(specializationDataEntries);
|
||||||
|
@ -224,10 +224,10 @@ namespace dawn_native { namespace vulkan {
|
||||||
specializationMapEntries->push_back(
|
specializationMapEntries->push_back(
|
||||||
VkSpecializationMapEntry{moduleConstant.id,
|
VkSpecializationMapEntry{moduleConstant.id,
|
||||||
static_cast<uint32_t>(specializationDataEntries->size() *
|
static_cast<uint32_t>(specializationDataEntries->size() *
|
||||||
sizeof(SpecializationDataEntry)),
|
sizeof(OverridableConstantScalar)),
|
||||||
sizeof(SpecializationDataEntry)});
|
sizeof(OverridableConstantScalar)});
|
||||||
|
|
||||||
SpecializationDataEntry entry{};
|
OverridableConstantScalar entry{};
|
||||||
switch (moduleConstant.type) {
|
switch (moduleConstant.type) {
|
||||||
case EntryPointMetadata::OverridableConstant::Type::Boolean:
|
case EntryPointMetadata::OverridableConstant::Type::Boolean:
|
||||||
entry.b = static_cast<int32_t>(value);
|
entry.b = static_cast<int32_t>(value);
|
||||||
|
@ -250,7 +250,7 @@ namespace dawn_native { namespace vulkan {
|
||||||
specializationInfo->mapEntryCount = static_cast<uint32_t>(specializationMapEntries->size());
|
specializationInfo->mapEntryCount = static_cast<uint32_t>(specializationMapEntries->size());
|
||||||
specializationInfo->pMapEntries = specializationMapEntries->data();
|
specializationInfo->pMapEntries = specializationMapEntries->data();
|
||||||
specializationInfo->dataSize =
|
specializationInfo->dataSize =
|
||||||
specializationDataEntries->size() * sizeof(SpecializationDataEntry);
|
specializationDataEntries->size() * sizeof(OverridableConstantScalar);
|
||||||
specializationInfo->pData = specializationDataEntries->data();
|
specializationInfo->pData = specializationDataEntries->data();
|
||||||
|
|
||||||
return specializationInfo;
|
return specializationInfo;
|
||||||
|
|
|
@ -21,6 +21,7 @@
|
||||||
|
|
||||||
namespace dawn_native {
|
namespace dawn_native {
|
||||||
struct ProgrammableStage;
|
struct ProgrammableStage;
|
||||||
|
union OverridableConstantScalar;
|
||||||
} // namespace dawn_native
|
} // namespace dawn_native
|
||||||
|
|
||||||
namespace dawn_native { namespace vulkan {
|
namespace dawn_native { namespace vulkan {
|
||||||
|
@ -111,23 +112,13 @@ namespace dawn_native { namespace vulkan {
|
||||||
const char* prefix,
|
const char* prefix,
|
||||||
std::string label = "");
|
std::string label = "");
|
||||||
|
|
||||||
// Helpers for creating VkSpecializationInfo
|
|
||||||
// The WebGPU overridable constants only support these scalar types
|
|
||||||
union SpecializationDataEntry {
|
|
||||||
// Use int32_t for boolean to initialize the full 32bit
|
|
||||||
int32_t b;
|
|
||||||
float f32;
|
|
||||||
int32_t i32;
|
|
||||||
uint32_t u32;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Returns nullptr or &specializationInfo
|
// Returns nullptr or &specializationInfo
|
||||||
// specializationInfo, specializationDataEntries, specializationMapEntries needs to
|
// specializationInfo, specializationDataEntries, specializationMapEntries needs to
|
||||||
// be alive at least until VkSpecializationInfo is passed into Vulkan Create*Pipelines
|
// be alive at least until VkSpecializationInfo is passed into Vulkan Create*Pipelines
|
||||||
VkSpecializationInfo* GetVkSpecializationInfo(
|
VkSpecializationInfo* GetVkSpecializationInfo(
|
||||||
const ProgrammableStage& programmableStage,
|
const ProgrammableStage& programmableStage,
|
||||||
VkSpecializationInfo* specializationInfo,
|
VkSpecializationInfo* specializationInfo,
|
||||||
std::vector<SpecializationDataEntry>* specializationDataEntries,
|
std::vector<OverridableConstantScalar>* specializationDataEntries,
|
||||||
std::vector<VkSpecializationMapEntry>* specializationMapEntries);
|
std::vector<VkSpecializationMapEntry>* specializationMapEntries);
|
||||||
|
|
||||||
}} // namespace dawn_native::vulkan
|
}} // namespace dawn_native::vulkan
|
||||||
|
|
|
@ -391,8 +391,8 @@ fn main([[location(0)]] pos : vec4<f32>) -> [[builtin(position)]] vec4<f32> {
|
||||||
|
|
||||||
// Test overridable constants without numeric identifiers
|
// Test overridable constants without numeric identifiers
|
||||||
TEST_P(ShaderTests, OverridableConstants) {
|
TEST_P(ShaderTests, OverridableConstants) {
|
||||||
// TODO(dawn:1041): Only Vulkan backend is implemented
|
// TODO(dawn:1137): D3D12 backend is unimplemented
|
||||||
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan());
|
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal());
|
||||||
|
|
||||||
uint32_t const kCount = 11;
|
uint32_t const kCount = 11;
|
||||||
std::vector<uint32_t> expected(kCount);
|
std::vector<uint32_t> expected(kCount);
|
||||||
|
@ -469,8 +469,8 @@ TEST_P(ShaderTests, OverridableConstants) {
|
||||||
|
|
||||||
// Test overridable constants with numeric identifiers
|
// Test overridable constants with numeric identifiers
|
||||||
TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) {
|
TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) {
|
||||||
// TODO(dawn:1041): Only Vulkan backend is implemented
|
// TODO(dawn:1137): D3D12 backend is unimplemented
|
||||||
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan());
|
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal());
|
||||||
|
|
||||||
uint32_t const kCount = 4;
|
uint32_t const kCount = 4;
|
||||||
std::vector<uint32_t> expected{1u, 2u, 3u, 0u};
|
std::vector<uint32_t> expected{1u, 2u, 3u, 0u};
|
||||||
|
@ -525,8 +525,8 @@ TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) {
|
||||||
|
|
||||||
// Test overridable constants for different entry points
|
// Test overridable constants for different entry points
|
||||||
TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
|
TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
|
||||||
// TODO(dawn:1041): Only Vulkan backend is implemented
|
// TODO(dawn:1137): D3D12 backend is unimplemented
|
||||||
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan());
|
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal());
|
||||||
|
|
||||||
uint32_t const kCount = 1;
|
uint32_t const kCount = 1;
|
||||||
std::vector<uint32_t> expected1{1u};
|
std::vector<uint32_t> expected1{1u};
|
||||||
|
@ -607,8 +607,8 @@ TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) {
|
||||||
// Draw a triangle covering the render target, with vertex position and color values from
|
// Draw a triangle covering the render target, with vertex position and color values from
|
||||||
// overridable constants
|
// overridable constants
|
||||||
TEST_P(ShaderTests, OverridableConstantsRenderPipeline) {
|
TEST_P(ShaderTests, OverridableConstantsRenderPipeline) {
|
||||||
// TODO(dawn:1041): Only Vulkan backend is implemented
|
// TODO(dawn:1137): D3D12 backend is unimplemented
|
||||||
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan());
|
DAWN_TEST_UNSUPPORTED_IF(!IsVulkan() && !IsMetal());
|
||||||
|
|
||||||
wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, R"(
|
wgpu::ShaderModule vsModule = utils::CreateShaderModule(device, R"(
|
||||||
[[override(1111)]] let xright: f32;
|
[[override(1111)]] let xright: f32;
|
||||||
|
|
Loading…
Reference in New Issue