From 145337f309abad41995dbd6740976c72b9d2968f Mon Sep 17 00:00:00 2001 From: shrekshao Date: Wed, 7 Sep 2022 20:09:54 +0000 Subject: [PATCH] Use SubstituteOverride transform to implement overrides Remove the old backend specific implementation for overrides. Use tint SubstituteOverride transform to replace overrides with const expressions and use the updated program at pipeline creation time. This CL also adds support for overrides used as workgroup size and related tests. Workgroup size validation now happens in backend code and at compute pipeline creation time. Bug: dawn:1504 Change-Id: I7df1fe9c3e358caa23235eacd6d13ba0b2998aec Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/99821 Commit-Queue: Shrek Shao Reviewed-by: Austin Eng Kokoro: Kokoro --- .../templates/dawn/native/api_StreamImpl.cpp | 57 ++- src/dawn/native/Limits.cpp | 18 + src/dawn/native/Limits.h | 15 + src/dawn/native/ObjectContentHasher.h | 17 + src/dawn/native/Pipeline.cpp | 16 +- src/dawn/native/Pipeline.h | 3 - src/dawn/native/ShaderModule.cpp | 90 ++-- src/dawn/native/ShaderModule.h | 22 +- src/dawn/native/StreamImplTint.cpp | 17 + src/dawn/native/TintUtils.cpp | 16 + src/dawn/native/TintUtils.h | 4 + src/dawn/native/d3d12/RenderPipelineD3D12.cpp | 7 +- src/dawn/native/d3d12/ShaderModuleD3D12.cpp | 186 +++------ src/dawn/native/metal/ComputePipelineMTL.mm | 8 +- src/dawn/native/metal/RenderPipelineMTL.mm | 10 +- src/dawn/native/metal/ShaderModuleMTL.h | 12 +- src/dawn/native/metal/ShaderModuleMTL.mm | 119 +++--- src/dawn/native/metal/UtilsMetal.h | 9 - src/dawn/native/metal/UtilsMetal.mm | 98 ----- src/dawn/native/stream/Stream.cpp | 2 + src/dawn/native/vulkan/ComputePipelineVk.cpp | 8 +- src/dawn/native/vulkan/RenderPipelineVk.cpp | 10 +- src/dawn/native/vulkan/ShaderModuleVk.cpp | 87 +++- src/dawn/native/vulkan/ShaderModuleVk.h | 27 +- src/dawn/native/vulkan/UtilsVulkan.cpp | 56 --- src/dawn/native/vulkan/UtilsVulkan.h | 9 - src/dawn/tests/BUILD.gn | 2 + src/dawn/tests/end2end/ObjectCachingTests.cpp | 84 ++++ src/dawn/tests/end2end/ShaderTests.cpp | 130 +++++- .../tests/end2end/ShaderValidationTests.cpp | 395 ++++++++++++++++++ .../native/ObjectContentHasherTests.cpp | 81 ++++ .../ShaderModuleValidationTests.cpp | 81 ---- .../validation/UnsafeAPIValidationTests.cpp | 40 -- src/tint/inspector/inspector.cc | 196 +++++---- src/tint/inspector/inspector.h | 8 + src/tint/transform/substitute_override.h | 4 + 36 files changed, 1246 insertions(+), 698 deletions(-) create mode 100644 src/dawn/tests/end2end/ShaderValidationTests.cpp create mode 100644 src/dawn/tests/unittests/native/ObjectContentHasherTests.cpp diff --git a/generator/templates/dawn/native/api_StreamImpl.cpp b/generator/templates/dawn/native/api_StreamImpl.cpp index 82b64043c2..e6617aacb9 100644 --- a/generator/templates/dawn/native/api_StreamImpl.cpp +++ b/generator/templates/dawn/native/api_StreamImpl.cpp @@ -25,7 +25,15 @@ namespace {{native_namespace}} { // -// Cache key writers for wgpu structures used in caching. +// Streaming readers for wgpu structures. +// +{% macro render_reader(member) %} + {%- set name = member.name.camelCase() -%} + DAWN_TRY(StreamOut(source, &t->{{name}})); +{% endmacro %} + +// +// Streaming writers for wgpu structures. // {% macro render_writer(member) %} {%- set name = member.name.camelCase() -%} @@ -38,31 +46,50 @@ namespace {{native_namespace}} { {% endif %} {% endmacro %} -{# Helper macro to render writers. Should be used in a call block to provide additional custom +{# Helper macro to render readers and writers. Should be used in a call block to provide additional custom handling when necessary. The optional `omit` field can be used to omit fields that are either handled in the custom code, or unnecessary in the serialized output. Example: - {% call render_cache_key_writer("struct name", omits=["omit field"]) %} + {% call render_streaming_impl("struct name", writer=true, reader=false, omits=["omit field"]) %} // Custom C++ code to handle special types/members that are hard to generate code for {% endcall %} + One day we should probably make the generator smart enough to generate everything it can + instead of manually adding streaming implementations here. #} -{% macro render_cache_key_writer(json_type, omits=[]) %} +{% macro render_streaming_impl(json_type, writer, reader, omits=[]) %} {%- set cpp_type = types[json_type].name.CamelCase() -%} - template <> - void stream::Stream<{{cpp_type}}>::Write(stream::Sink* sink, const {{cpp_type}}& t) { - {{ caller() }} - {% for member in types[json_type].members %} - {%- if not member.name.get() in omits %} - {{render_writer(member)}} - {%- endif %} - {% endfor %} - } + {% if reader %} + template <> + MaybeError stream::Stream<{{cpp_type}}>::Read(stream::Source* source, {{cpp_type}}* t) { + {{ caller() }} + {% for member in types[json_type].members %} + {% if not member.name.get() in omits %} + {{render_reader(member)}} + {% endif %} + {% endfor %} + return {}; + } + {% endif %} + {% if writer %} + template <> + void stream::Stream<{{cpp_type}}>::Write(stream::Sink* sink, const {{cpp_type}}& t) { + {{ caller() }} + {% for member in types[json_type].members %} + {% if not member.name.get() in omits %} + {{render_writer(member)}} + {% endif %} + {% endfor %} + } + {% endif %} {% endmacro %} -{% call render_cache_key_writer("adapter properties") %} +{% call render_streaming_impl("adapter properties", true, false) %} {% endcall %} -{% call render_cache_key_writer("dawn cache device descriptor") %} +{% call render_streaming_impl("dawn cache device descriptor", true, false) %} +{% endcall %} + +{% call render_streaming_impl("extent 3D", true, true) %} {% endcall %} } // namespace {{native_namespace}} diff --git a/src/dawn/native/Limits.cpp b/src/dawn/native/Limits.cpp index ef285b2a4d..215609c126 100644 --- a/src/dawn/native/Limits.cpp +++ b/src/dawn/native/Limits.cpp @@ -215,4 +215,22 @@ Limits ApplyLimitTiers(Limits limits) { return limits; } +#define DAWN_INTERNAL_LIMITS_MEMBER_ASSIGNMENT(type, name) \ + { result.name = limits.name; } +#define DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT(MEMBERS) \ + MEMBERS(DAWN_INTERNAL_LIMITS_MEMBER_ASSIGNMENT) +LimitsForCompilationRequest LimitsForCompilationRequest::Create(const Limits& limits) { + LimitsForCompilationRequest result; + DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT(LIMITS_FOR_COMPILATION_REQUEST_MEMBERS) + return result; +} +#undef DAWN_INTERNAL_LIMITS_FOREACH_MEMBER_ASSIGNMENT +#undef DAWN_INTERNAL_LIMITS_MEMBER_ASSIGNMENT + +template <> +void stream::Stream::Write(Sink* s, + const LimitsForCompilationRequest& t) { + t.VisitAll([&](const auto&... members) { StreamIn(s, members...); }); +} + } // namespace dawn::native diff --git a/src/dawn/native/Limits.h b/src/dawn/native/Limits.h index cc81742053..8091a31cd2 100644 --- a/src/dawn/native/Limits.h +++ b/src/dawn/native/Limits.h @@ -16,6 +16,7 @@ #define SRC_DAWN_NATIVE_LIMITS_H_ #include "dawn/native/Error.h" +#include "dawn/native/VisitableMembers.h" #include "dawn/native/dawn_platform.h" namespace dawn::native { @@ -38,6 +39,20 @@ MaybeError ValidateLimits(const Limits& supportedLimits, const Limits& requiredL // Returns a copy of |limits| where limit tiers are applied. Limits ApplyLimitTiers(Limits limits); +// If there are new limit member needed at shader compilation time +// Simply append a new X(type, name) here. +#define LIMITS_FOR_COMPILATION_REQUEST_MEMBERS(X) \ + X(uint32_t, maxComputeWorkgroupSizeX) \ + X(uint32_t, maxComputeWorkgroupSizeY) \ + X(uint32_t, maxComputeWorkgroupSizeZ) \ + X(uint32_t, maxComputeInvocationsPerWorkgroup) \ + X(uint32_t, maxComputeWorkgroupStorageSize) + +struct LimitsForCompilationRequest { + static LimitsForCompilationRequest Create(const Limits& limits); + DAWN_VISITABLE_MEMBERS(LIMITS_FOR_COMPILATION_REQUEST_MEMBERS) +}; + } // namespace dawn::native #endif // SRC_DAWN_NATIVE_LIMITS_H_ diff --git a/src/dawn/native/ObjectContentHasher.h b/src/dawn/native/ObjectContentHasher.h index 4211fb3514..40ba602649 100644 --- a/src/dawn/native/ObjectContentHasher.h +++ b/src/dawn/native/ObjectContentHasher.h @@ -15,7 +15,9 @@ #ifndef SRC_DAWN_NATIVE_OBJECTCONTENTHASHER_H_ #define SRC_DAWN_NATIVE_OBJECTCONTENTHASHER_H_ +#include #include +#include #include #include "dawn/common/HashUtils.h" @@ -60,6 +62,13 @@ class ObjectContentHasher { } }; + template + struct RecordImpl> { + static constexpr void Call(ObjectContentHasher* recorder, const std::map& map) { + recorder->RecordIterable>(map); + } + }; + template constexpr void RecordIterable(const IteratorT& iterable) { for (auto it = iterable.begin(); it != iterable.end(); ++it) { @@ -67,6 +76,14 @@ class ObjectContentHasher { } } + template + struct RecordImpl> { + static constexpr void Call(ObjectContentHasher* recorder, const std::pair& pair) { + recorder->Record(pair.first); + recorder->Record(pair.second); + } + }; + size_t mContentHash = 0; }; diff --git a/src/dawn/native/Pipeline.cpp b/src/dawn/native/Pipeline.cpp index 50a4b9c105..dc55f637d1 100644 --- a/src/dawn/native/Pipeline.cpp +++ b/src/dawn/native/Pipeline.cpp @@ -58,12 +58,6 @@ MaybeError ValidateProgrammableStage(DeviceBase* device, DAWN_TRY(ValidateCompatibilityWithPipelineLayout(device, metadata, layout)); } - if (constantCount > 0u && device->IsToggleEnabled(Toggle::DisallowUnsafeAPIs)) { - return DAWN_VALIDATION_ERROR( - "Pipeline overridable constants are disallowed because they are partially " - "implemented."); - } - // Validate if overridable constants exist in shader module // pipelineBase is not yet constructed at this moment so iterate constants from descriptor size_t numUninitializedConstants = metadata.uninitializedOverrides.size(); @@ -233,6 +227,7 @@ size_t PipelineBase::ComputeContentHash() { for (SingleShaderStage stage : IterateStages(mStageMask)) { recorder.Record(mStages[stage].module->GetContentHash()); recorder.Record(mStages[stage].entryPoint); + recorder.Record(mStages[stage].constants); } return recorder.GetContentHash(); @@ -248,7 +243,14 @@ bool PipelineBase::EqualForCache(const PipelineBase* a, const PipelineBase* b) { for (SingleShaderStage stage : IterateStages(a->mStageMask)) { // The module is deduplicated so it can be compared by pointer. if (a->mStages[stage].module.Get() != b->mStages[stage].module.Get() || - a->mStages[stage].entryPoint != b->mStages[stage].entryPoint) { + a->mStages[stage].entryPoint != b->mStages[stage].entryPoint || + a->mStages[stage].constants.size() != b->mStages[stage].constants.size()) { + return false; + } + + // If the constants.size are the same, we still need to compare the key and value. + if (!std::equal(a->mStages[stage].constants.begin(), a->mStages[stage].constants.end(), + b->mStages[stage].constants.begin())) { return false; } } diff --git a/src/dawn/native/Pipeline.h b/src/dawn/native/Pipeline.h index 2d5b6dfc65..adf8f8d107 100644 --- a/src/dawn/native/Pipeline.h +++ b/src/dawn/native/Pipeline.h @@ -40,9 +40,6 @@ MaybeError ValidateProgrammableStage(DeviceBase* device, const PipelineLayoutBase* layout, SingleShaderStage stage); -// Use map to make sure constant keys are sorted for creating shader cache keys -using PipelineConstantEntries = std::map; - struct ProgrammableStage { Ref module; std::string entryPoint; diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp index 6769f7e4ff..e527634b2e 100644 --- a/src/dawn/native/ShaderModule.cpp +++ b/src/dawn/native/ShaderModule.cpp @@ -20,7 +20,6 @@ #include "absl/strings/str_format.h" #include "dawn/common/BitSetIterator.h" #include "dawn/common/Constants.h" -#include "dawn/common/HashUtils.h" #include "dawn/native/BindGroupLayout.h" #include "dawn/native/ChainUtils_autogen.h" #include "dawn/native/CompilationMessages.h" @@ -511,7 +510,6 @@ ResultOrError> ReflectEntryPointUsingTint( const DeviceBase* device, tint::inspector::Inspector* inspector, const tint::inspector::EntryPoint& entryPoint) { - const CombinedLimits& limits = device->GetLimits(); constexpr uint32_t kMaxInterStageShaderLocation = kMaxInterStageShaderVariables - 1; std::unique_ptr metadata = std::make_unique(); @@ -528,10 +526,6 @@ ResultOrError> ReflectEntryPointUsingTint( })() if (!entryPoint.overrides.empty()) { - DAWN_INVALID_IF(device->IsToggleEnabled(Toggle::DisallowUnsafeAPIs), - "Pipeline overridable constants are disallowed because they " - "are partially implemented."); - const auto& name2Id = inspector->GetNamedOverrideIds(); const auto& id2Scalar = inspector->GetOverrideDefaultValues(); @@ -553,10 +547,10 @@ ResultOrError> ReflectEntryPointUsingTint( UNREACHABLE(); } } - EntryPointMetadata::Override override = {id.value, FromTintOverrideType(c.type), + EntryPointMetadata::Override override = {id, FromTintOverrideType(c.type), c.is_initialized, defaultValue}; - std::string identifier = c.is_id_specified ? std::to_string(override.id) : c.name; + std::string identifier = c.is_id_specified ? std::to_string(override.id.value) : c.name; metadata->overrides[identifier] = override; if (!c.is_initialized) { @@ -575,39 +569,6 @@ ResultOrError> ReflectEntryPointUsingTint( DAWN_TRY_ASSIGN(metadata->stage, TintPipelineStageToShaderStage(entryPoint.stage)); if (metadata->stage == SingleShaderStage::Compute) { - auto workgroup_size = entryPoint.workgroup_size; - DAWN_INVALID_IF( - !workgroup_size.has_value(), - "TODO(crbug.com/dawn/1504): Dawn does not currently support @workgroup_size " - "attributes using override-expressions"); - DelayedInvalidIf(workgroup_size->x > limits.v1.maxComputeWorkgroupSizeX || - workgroup_size->y > limits.v1.maxComputeWorkgroupSizeY || - workgroup_size->z > limits.v1.maxComputeWorkgroupSizeZ, - "Entry-point uses workgroup_size(%u, %u, %u) that exceeds the " - "maximum allowed (%u, %u, %u).", - workgroup_size->x, workgroup_size->y, workgroup_size->z, - limits.v1.maxComputeWorkgroupSizeX, limits.v1.maxComputeWorkgroupSizeY, - limits.v1.maxComputeWorkgroupSizeZ); - - // Dimensions have already been validated against their individual limits above. - // Cast to uint64_t to avoid overflow in this multiplication. - uint64_t numInvocations = - static_cast(workgroup_size->x) * workgroup_size->y * workgroup_size->z; - DelayedInvalidIf(numInvocations > limits.v1.maxComputeInvocationsPerWorkgroup, - "The total number of workgroup invocations (%u) exceeds the " - "maximum allowed (%u).", - numInvocations, limits.v1.maxComputeInvocationsPerWorkgroup); - - const size_t workgroupStorageSize = inspector->GetWorkgroupStorageSize(entryPoint.name); - DelayedInvalidIf(workgroupStorageSize > limits.v1.maxComputeWorkgroupStorageSize, - "The total use of workgroup storage (%u bytes) is larger than " - "the maximum allowed (%u bytes).", - workgroupStorageSize, limits.v1.maxComputeWorkgroupStorageSize); - - metadata->localWorkgroupSize.x = workgroup_size->x; - metadata->localWorkgroupSize.y = workgroup_size->y; - metadata->localWorkgroupSize.z = workgroup_size->z; - metadata->usesNumWorkgroups = entryPoint.num_workgroups_used; } @@ -883,6 +844,46 @@ MaybeError ReflectShaderUsingTint(const DeviceBase* device, } } // anonymous namespace +ResultOrError ValidateComputeStageWorkgroupSize( + const tint::Program& program, + const char* entryPointName, + const LimitsForCompilationRequest& limits) { + tint::inspector::Inspector inspector(&program); + // At this point the entry point must exist and must have workgroup size values. + tint::inspector::EntryPoint entryPoint = inspector.GetEntryPoint(entryPointName); + ASSERT(entryPoint.workgroup_size.has_value()); + const tint::inspector::WorkgroupSize& workgroup_size = entryPoint.workgroup_size.value(); + + DAWN_INVALID_IF(workgroup_size.x < 1 || workgroup_size.y < 1 || workgroup_size.z < 1, + "Entry-point uses workgroup_size(%u, %u, %u) that are below the " + "minimum allowed (1, 1, 1).", + workgroup_size.x, workgroup_size.y, workgroup_size.z); + + DAWN_INVALID_IF(workgroup_size.x > limits.maxComputeWorkgroupSizeX || + workgroup_size.y > limits.maxComputeWorkgroupSizeY || + workgroup_size.z > limits.maxComputeWorkgroupSizeZ, + "Entry-point uses workgroup_size(%u, %u, %u) that exceeds the " + "maximum allowed (%u, %u, %u).", + workgroup_size.x, workgroup_size.y, workgroup_size.z, + limits.maxComputeWorkgroupSizeX, limits.maxComputeWorkgroupSizeY, + limits.maxComputeWorkgroupSizeZ); + + uint64_t numInvocations = + static_cast(workgroup_size.x) * workgroup_size.y * workgroup_size.z; + DAWN_INVALID_IF(numInvocations > limits.maxComputeInvocationsPerWorkgroup, + "The total number of workgroup invocations (%u) exceeds the " + "maximum allowed (%u).", + numInvocations, limits.maxComputeInvocationsPerWorkgroup); + + const size_t workgroupStorageSize = inspector.GetWorkgroupStorageSize(entryPointName); + DAWN_INVALID_IF(workgroupStorageSize > limits.maxComputeWorkgroupStorageSize, + "The total use of workgroup storage (%u bytes) is larger than " + "the maximum allowed (%u bytes).", + workgroupStorageSize, limits.maxComputeWorkgroupStorageSize); + + return Extent3D{workgroup_size.x, workgroup_size.y, workgroup_size.z}; +} + ShaderModuleParseResult::ShaderModuleParseResult() = default; ShaderModuleParseResult::~ShaderModuleParseResult() = default; @@ -1200,11 +1201,4 @@ MaybeError ShaderModuleBase::InitializeBase(ShaderModuleParseResult* parseResult return {}; } -size_t PipelineLayoutEntryPointPairHashFunc::operator()( - const PipelineLayoutEntryPointPair& pair) const { - size_t hash = 0; - HashCombine(&hash, pair.first, pair.second); - return hash; -} - } // namespace dawn::native diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h index 170b388000..b04e9ed7e1 100644 --- a/src/dawn/native/ShaderModule.h +++ b/src/dawn/native/ShaderModule.h @@ -33,10 +33,12 @@ #include "dawn/native/Format.h" #include "dawn/native/Forward.h" #include "dawn/native/IntegerTypes.h" +#include "dawn/native/Limits.h" #include "dawn/native/ObjectBase.h" #include "dawn/native/PerStage.h" #include "dawn/native/VertexFormat.h" #include "dawn/native/dawn_platform.h" +#include "tint/override_id.h" namespace tint { @@ -76,10 +78,8 @@ enum class InterpolationSampling { Sample, }; -using PipelineLayoutEntryPointPair = std::pair; -struct PipelineLayoutEntryPointPairHashFunc { - size_t operator()(const PipelineLayoutEntryPointPair& pair) const; -}; +// Use map to make sure constant keys are sorted for creating shader cache keys +using PipelineConstantEntries = std::map; // A map from name to EntryPointMetadata. using EntryPointMetadataTable = @@ -108,6 +108,13 @@ MaybeError ValidateCompatibilityWithPipelineLayout(DeviceBase* device, const EntryPointMetadata& entryPoint, const PipelineLayoutBase* layout); +// Return extent3D with workgroup size dimension info if it is valid +// width = x, height = y, depthOrArrayLength = z +ResultOrError ValidateComputeStageWorkgroupSize( + const tint::Program& program, + const char* entryPointName, + const LimitsForCompilationRequest& limits); + RequiredBufferSizes ComputeRequiredBufferSizesForLayout(const EntryPointMetadata& entryPoint, const PipelineLayoutBase* layout); ResultOrError RunTransforms(tint::transform::Transform* transform, @@ -204,14 +211,12 @@ struct EntryPointMetadata { std::bitset usedInterStageVariables; std::array interStageVariables; - // The local workgroup size declared for a compute entry point (or 0s otehrwise). - Origin3D localWorkgroupSize; - // The shader stage for this binding. SingleShaderStage stage; struct Override { - uint32_t id; + tint::OverrideId id; + // Match tint::inspector::Override::Type // Bool is defined as a macro on linux X11 and cannot compile enum class Type { Boolean, Float32, Uint32, Int32 } type; @@ -273,6 +278,7 @@ class ShaderModuleBase : public ApiObjectBase, public CachedObject { bool operator()(const ShaderModuleBase* a, const ShaderModuleBase* b) const; }; + // This returns tint program before running transforms. const tint::Program* GetTintProgram() const; void APIGetCompilationInfo(wgpu::CompilationInfoCallback callback, void* userdata); diff --git a/src/dawn/native/StreamImplTint.cpp b/src/dawn/native/StreamImplTint.cpp index deeaa0f11a..5f077ec129 100644 --- a/src/dawn/native/StreamImplTint.cpp +++ b/src/dawn/native/StreamImplTint.cpp @@ -65,6 +65,23 @@ void stream::Stream::Write( StreamInTintObject(cfg, sink); } +// static +template <> +void stream::Stream::Write( + stream::Sink* sink, + const tint::transform::SubstituteOverride::Config& cfg) { + StreamInTintObject(cfg, sink); +} + +// static +template <> +void stream::Stream::Write(stream::Sink* sink, const tint::OverrideId& id) { + // TODO(tint:1640): fix the include build issues and use StreamInTintObject instead. + static_assert(offsetof(tint::OverrideId, value) == 0, + "Please update serialization for tint::OverrideId"); + StreamIn(sink, id.value); +} + // static template <> void stream::Stream::Write( diff --git a/src/dawn/native/TintUtils.cpp b/src/dawn/native/TintUtils.cpp index 24d4cc6afb..a2cf5dc4e7 100644 --- a/src/dawn/native/TintUtils.cpp +++ b/src/dawn/native/TintUtils.cpp @@ -16,6 +16,7 @@ #include "dawn/native/BindGroupLayout.h" #include "dawn/native/Device.h" +#include "dawn/native/Pipeline.h" #include "dawn/native/PipelineLayout.h" #include "dawn/native/RenderPipeline.h" @@ -183,6 +184,21 @@ tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig( return cfg; } +tint::transform::SubstituteOverride::Config BuildSubstituteOverridesTransformConfig( + const ProgrammableStage& stage) { + const EntryPointMetadata& metadata = *stage.metadata; + const auto& constants = stage.constants; + + tint::transform::SubstituteOverride::Config cfg; + + for (const auto& [key, value] : constants) { + const auto& o = metadata.overrides.at(key); + cfg.map.insert({o.id, value}); + } + + return cfg; +} + } // namespace dawn::native namespace tint::sem { diff --git a/src/dawn/native/TintUtils.h b/src/dawn/native/TintUtils.h index 7c03881502..d6bea1d88f 100644 --- a/src/dawn/native/TintUtils.h +++ b/src/dawn/native/TintUtils.h @@ -26,6 +26,7 @@ namespace dawn::native { class DeviceBase; class PipelineLayoutBase; +struct ProgrammableStage; class RenderPipelineBase; // Indicates that for the lifetime of this object tint internal compiler errors should be @@ -47,6 +48,9 @@ tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig( const std::string_view& entryPoint, BindGroupIndex pullingBufferBindingSet); +tint::transform::SubstituteOverride::Config BuildSubstituteOverridesTransformConfig( + const ProgrammableStage& stage); + } // namespace dawn::native namespace tint::sem { diff --git a/src/dawn/native/d3d12/RenderPipelineD3D12.cpp b/src/dawn/native/d3d12/RenderPipelineD3D12.cpp index b1da726691..80bc6fce9c 100644 --- a/src/dawn/native/d3d12/RenderPipelineD3D12.cpp +++ b/src/dawn/native/d3d12/RenderPipelineD3D12.cpp @@ -351,8 +351,6 @@ MaybeError RenderPipeline::Initialize() { D3D12_GRAPHICS_PIPELINE_STATE_DESC descriptorD3D12 = {}; - PerStage pipelineStages = GetAllStages(); - PerStage shaders; shaders[SingleShaderStage::Vertex] = &descriptorD3D12.VS; shaders[SingleShaderStage::Fragment] = &descriptorD3D12.PS; @@ -360,8 +358,9 @@ MaybeError RenderPipeline::Initialize() { PerStage compiledShader; for (auto stage : IterateStages(GetStageMask())) { - DAWN_TRY_ASSIGN(compiledShader[stage], ToBackend(pipelineStages[stage].module) - ->Compile(pipelineStages[stage], stage, + const ProgrammableStage& programmableStage = GetStage(stage); + DAWN_TRY_ASSIGN(compiledShader[stage], ToBackend(programmableStage.module) + ->Compile(programmableStage, stage, ToBackend(GetLayout()), compileFlags)); *shaders[stage] = compiledShader[stage].GetD3D12ShaderBytecode(); } diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp index 9e1b9e9a11..bb0077aa29 100644 --- a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp +++ b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp @@ -65,100 +65,35 @@ namespace dawn::native::d3d12 { namespace { -// 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough -std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) { - std::ostringstream out; - out.precision(n); - out << std::fixed << v; - return out.str(); -} - -std::string GetHLSLValueString(EntryPointMetadata::Override::Type dawnType, - const OverrideScalar* entry, - double value = 0) { - switch (dawnType) { - case EntryPointMetadata::Override::Type::Boolean: - return std::to_string(entry ? entry->b : static_cast(value)); - case EntryPointMetadata::Override::Type::Float32: - return FloatToStringWithPrecision(entry ? entry->f32 : static_cast(value)); - case EntryPointMetadata::Override::Type::Int32: - return std::to_string(entry ? entry->i32 : static_cast(value)); - case EntryPointMetadata::Override::Type::Uint32: - return std::to_string(entry ? entry->u32 : static_cast(value)); - default: - UNREACHABLE(); - } -} - -constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_"; - -using DefineStrings = std::vector>; - -DefineStrings GetOverridableConstantsDefines( - const PipelineConstantEntries& pipelineConstantEntries, - const EntryPointMetadata::OverridesMap& shaderEntryPointConstants) { - DefineStrings defineStrings; - std::unordered_set overriddenConstants; - - // Set pipeline overridden values - for (const auto& [name, value] : pipelineConstantEntries) { - overriddenConstants.insert(name); - - // This is already validated so `name` must exist - const auto& moduleConstant = shaderEntryPointConstants.at(name); - - defineStrings.emplace_back( - kSpecConstantPrefix + std::to_string(static_cast(moduleConstant.id)), - GetHLSLValueString(moduleConstant.type, nullptr, value)); - } - - // Set shader initialized default values - for (const auto& iter : shaderEntryPointConstants) { - const std::string& name = iter.first; - if (overriddenConstants.count(name) != 0) { - // This constant already has overridden value - continue; - } - - const auto& moduleConstant = shaderEntryPointConstants.at(name); - - // Uninitialized default values are okay since they ar only defined to pass - // compilation but not used - defineStrings.emplace_back( - kSpecConstantPrefix + std::to_string(static_cast(moduleConstant.id)), - GetHLSLValueString(moduleConstant.type, &moduleConstant.defaultValue)); - } - return defineStrings; -} - enum class Compiler { FXC, DXC }; -#define HLSL_COMPILATION_REQUEST_MEMBERS(X) \ - X(const tint::Program*, inputProgram) \ - X(std::string_view, entryPointName) \ - X(SingleShaderStage, stage) \ - X(uint32_t, shaderModel) \ - X(uint32_t, compileFlags) \ - X(Compiler, compiler) \ - X(uint64_t, compilerVersion) \ - X(std::wstring_view, dxcShaderProfile) \ - X(std::string_view, fxcShaderProfile) \ - X(pD3DCompile, d3dCompile) \ - X(IDxcLibrary*, dxcLibrary) \ - X(IDxcCompiler*, dxcCompiler) \ - X(uint32_t, firstIndexOffsetShaderRegister) \ - X(uint32_t, firstIndexOffsetRegisterSpace) \ - X(bool, usesNumWorkgroups) \ - X(uint32_t, numWorkgroupsShaderRegister) \ - X(uint32_t, numWorkgroupsRegisterSpace) \ - X(DefineStrings, defineStrings) \ - X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \ - X(tint::writer::ArrayLengthFromUniformOptions, arrayLengthFromUniform) \ - X(tint::transform::BindingRemapper::BindingPoints, remappedBindingPoints) \ - X(tint::transform::BindingRemapper::AccessControls, remappedAccessControls) \ - X(bool, disableSymbolRenaming) \ - X(bool, isRobustnessEnabled) \ - X(bool, disableWorkgroupInit) \ +#define HLSL_COMPILATION_REQUEST_MEMBERS(X) \ + X(const tint::Program*, inputProgram) \ + X(std::string_view, entryPointName) \ + X(SingleShaderStage, stage) \ + X(uint32_t, shaderModel) \ + X(uint32_t, compileFlags) \ + X(Compiler, compiler) \ + X(uint64_t, compilerVersion) \ + X(std::wstring_view, dxcShaderProfile) \ + X(std::string_view, fxcShaderProfile) \ + X(pD3DCompile, d3dCompile) \ + X(IDxcLibrary*, dxcLibrary) \ + X(IDxcCompiler*, dxcCompiler) \ + X(uint32_t, firstIndexOffsetShaderRegister) \ + X(uint32_t, firstIndexOffsetRegisterSpace) \ + X(bool, usesNumWorkgroups) \ + X(uint32_t, numWorkgroupsShaderRegister) \ + X(uint32_t, numWorkgroupsRegisterSpace) \ + X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \ + X(tint::writer::ArrayLengthFromUniformOptions, arrayLengthFromUniform) \ + X(tint::transform::BindingRemapper::BindingPoints, remappedBindingPoints) \ + X(tint::transform::BindingRemapper::AccessControls, remappedAccessControls) \ + X(std::optional, substituteOverrideConfig) \ + X(LimitsForCompilationRequest, limits) \ + X(bool, disableSymbolRenaming) \ + X(bool, isRobustnessEnabled) \ + X(bool, disableWorkgroupInit) \ X(bool, dumpShaders) #define D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS(X) \ @@ -170,8 +105,7 @@ enum class Compiler { FXC, DXC }; X(std::string_view, fxcShaderProfile) \ X(pD3DCompile, d3dCompile) \ X(IDxcLibrary*, dxcLibrary) \ - X(IDxcCompiler*, dxcCompiler) \ - X(DefineStrings, defineStrings) + X(IDxcCompiler*, dxcCompiler) DAWN_SERIALIZABLE(struct, HlslCompilationRequest, HLSL_COMPILATION_REQUEST_MEMBERS){}; #undef HLSL_COMPILATION_REQUEST_MEMBERS @@ -255,25 +189,11 @@ ResultOrError> CompileShaderDXC(const D3DBytecodeCompilationReq std::vector arguments = GetDXCArguments(r.compileFlags, r.hasShaderFloat16Feature); - // Build defines for overridable constants - std::vector> defineStrings; - defineStrings.reserve(r.defineStrings.size()); - for (const auto& [name, value] : r.defineStrings) { - defineStrings.emplace_back(UTF8ToWStr(name.c_str()), UTF8ToWStr(value.c_str())); - } - - std::vector dxcDefines; - dxcDefines.reserve(defineStrings.size()); - for (const auto& [name, value] : defineStrings) { - dxcDefines.push_back({name.c_str(), value.c_str()}); - } - ComPtr result; - DAWN_TRY(CheckHRESULT( - r.dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(), - r.dxcShaderProfile.data(), arguments.data(), arguments.size(), - dxcDefines.data(), dxcDefines.size(), nullptr, &result), - "DXC compile")); + DAWN_TRY(CheckHRESULT(r.dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(), + r.dxcShaderProfile.data(), arguments.data(), + arguments.size(), nullptr, 0, nullptr, &result), + "DXC compile")); HRESULT hr; DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status")); @@ -359,20 +279,7 @@ ResultOrError> CompileShaderFXC(const D3DBytecodeCompilationReq ComPtr compiledShader; ComPtr errors; - // Build defines for overridable constants - const D3D_SHADER_MACRO* pDefines = nullptr; - std::vector fxcDefines; - if (r.defineStrings.size() > 0) { - fxcDefines.reserve(r.defineStrings.size() + 1); - for (const auto& [name, value] : r.defineStrings) { - fxcDefines.push_back({name.c_str(), value.c_str()}); - } - // d3dCompile D3D_SHADER_MACRO* pDefines is a nullptr terminated array - fxcDefines.push_back({nullptr, nullptr}); - pDefines = fxcDefines.data(); - } - - DAWN_INVALID_IF(FAILED(r.d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, pDefines, + DAWN_INVALID_IF(FAILED(r.d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr, nullptr, entryPointName.c_str(), r.fxcShaderProfile.data(), r.compileFlags, 0, &compiledShader, &errors)), "D3D compile failed with: %s", static_cast(errors->GetBufferPointer())); @@ -420,6 +327,14 @@ ResultOrError TranslateToHLSL( tint::transform::Renamer::Target::kHlslKeywords); } + if (r.substituteOverrideConfig) { + // This needs to run after SingleEntryPoint transform to get rid of overrides not used for + // the current entry point. + transformManager.Add(); + transformInputs.Add( + std::move(r.substituteOverrideConfig).value()); + } + // D3D12 registers like `t3` and `c3` have the same bindingOffset number in // the remapping but should not be considered a collision because they have // different types. @@ -450,6 +365,13 @@ ResultOrError TranslateToHLSL( return DAWN_VALIDATION_ERROR("Transform output missing renamer data."); } + if (r.stage == SingleShaderStage::Compute) { + // Validate workgroup size after program runs transforms. + Extent3D _; + DAWN_TRY_ASSIGN(_, ValidateComputeStageWorkgroupSize( + transformedProgram, remappedEntryPointName->data(), r.limits)); + } + if (r.stage == SingleShaderStage::Vertex) { if (auto* data = transformOutputs.Get()) { *usesVertexOrInstanceIndex = data->has_vertex_or_instance_index; @@ -555,8 +477,7 @@ ResultOrError ShaderModule::Compile(const ProgrammableStage& pro req.bytecode.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16); req.bytecode.compileFlags = compileFlags; - req.bytecode.defineStrings = - GetOverridableConstantsDefines(programmableStage.constants, entryPoint.overrides); + if (device->IsToggleEnabled(Toggle::UseDXC)) { req.bytecode.compiler = Compiler::DXC; req.bytecode.dxcLibrary = device->GetDxcLibrary().Get(); @@ -645,6 +566,11 @@ ResultOrError ShaderModule::Compile(const ProgrammableStage& pro } } + std::optional substituteOverrideConfig; + if (!programmableStage.metadata->overrides.empty()) { + substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage); + } + req.hlsl.inputProgram = GetTintProgram(); req.hlsl.entryPointName = programmableStage.entryPoint.c_str(); req.hlsl.stage = stage; @@ -657,6 +583,10 @@ ResultOrError ShaderModule::Compile(const ProgrammableStage& pro req.hlsl.remappedAccessControls = std::move(remappedAccessControls); req.hlsl.newBindingsMap = BuildExternalTextureTransformBindings(layout); req.hlsl.arrayLengthFromUniform = std::move(arrayLengthFromUniform); + req.hlsl.substituteOverrideConfig = std::move(substituteOverrideConfig); + + const CombinedLimits& limits = device->GetLimits(); + req.hlsl.limits = LimitsForCompilationRequest::Create(limits.v1); CacheResult compiledShader; DAWN_TRY_LOAD_OR_RUN(compiledShader, device, std::move(req), CompiledShader::FromBlob, diff --git a/src/dawn/native/metal/ComputePipelineMTL.mm b/src/dawn/native/metal/ComputePipelineMTL.mm index 855cd7b4bf..3ed84dcc13 100644 --- a/src/dawn/native/metal/ComputePipelineMTL.mm +++ b/src/dawn/native/metal/ComputePipelineMTL.mm @@ -40,8 +40,9 @@ MaybeError ComputePipeline::Initialize() { const ProgrammableStage& computeStage = GetStage(SingleShaderStage::Compute); ShaderModule::MetalFunctionData computeData; - DAWN_TRY(CreateMTLFunction(computeStage, SingleShaderStage::Compute, ToBackend(GetLayout()), - &computeData)); + DAWN_TRY(ToBackend(computeStage.module.Get()) + ->CreateFunction(SingleShaderStage::Compute, computeStage, ToBackend(GetLayout()), + &computeData)); NSError* error = nullptr; mMtlComputePipelineState.Acquire( @@ -53,8 +54,7 @@ MaybeError ComputePipeline::Initialize() { ASSERT(mMtlComputePipelineState != nil); // Copy over the local workgroup size as it is passed to dispatch explicitly in Metal - Origin3D localSize = GetStage(SingleShaderStage::Compute).metadata->localWorkgroupSize; - mLocalWorkgroupSize = MTLSizeMake(localSize.x, localSize.y, localSize.z); + mLocalWorkgroupSize = computeData.localWorkgroupSize; mRequiresStorageBufferLength = computeData.needsStorageBufferLength; mWorkgroupAllocations = std::move(computeData.workgroupAllocations); diff --git a/src/dawn/native/metal/RenderPipelineMTL.mm b/src/dawn/native/metal/RenderPipelineMTL.mm index 6e5afd54fe..9b9ccb4295 100644 --- a/src/dawn/native/metal/RenderPipelineMTL.mm +++ b/src/dawn/native/metal/RenderPipelineMTL.mm @@ -340,8 +340,9 @@ MaybeError RenderPipeline::Initialize() { const PerStage& allStages = GetAllStages(); const ProgrammableStage& vertexStage = allStages[wgpu::ShaderStage::Vertex]; ShaderModule::MetalFunctionData vertexData; - DAWN_TRY(CreateMTLFunction(vertexStage, SingleShaderStage::Vertex, ToBackend(GetLayout()), - &vertexData, 0xFFFFFFFF, this)); + DAWN_TRY(ToBackend(vertexStage.module.Get()) + ->CreateFunction(SingleShaderStage::Vertex, vertexStage, ToBackend(GetLayout()), + &vertexData, 0xFFFFFFFF, this)); descriptorMTL.vertexFunction = vertexData.function.Get(); if (vertexData.needsStorageBufferLength) { @@ -351,8 +352,9 @@ MaybeError RenderPipeline::Initialize() { if (GetStageMask() & wgpu::ShaderStage::Fragment) { const ProgrammableStage& fragmentStage = allStages[wgpu::ShaderStage::Fragment]; ShaderModule::MetalFunctionData fragmentData; - DAWN_TRY(CreateMTLFunction(fragmentStage, SingleShaderStage::Fragment, - ToBackend(GetLayout()), &fragmentData, GetSampleMask())); + DAWN_TRY(ToBackend(fragmentStage.module.Get()) + ->CreateFunction(SingleShaderStage::Fragment, fragmentStage, + ToBackend(GetLayout()), &fragmentData, GetSampleMask())); descriptorMTL.fragmentFunction = fragmentData.function.Get(); if (fragmentData.needsStorageBufferLength) { diff --git a/src/dawn/native/metal/ShaderModuleMTL.h b/src/dawn/native/metal/ShaderModuleMTL.h index 27f1def213..d79dbb3a72 100644 --- a/src/dawn/native/metal/ShaderModuleMTL.h +++ b/src/dawn/native/metal/ShaderModuleMTL.h @@ -25,6 +25,10 @@ #import +namespace dawn::native { +struct ProgrammableStage; +} + namespace dawn::native::metal { class Device; @@ -42,15 +46,13 @@ class ShaderModule final : public ShaderModuleBase { NSPRef> function; bool needsStorageBufferLength; std::vector workgroupAllocations; + MTLSize localWorkgroupSize; }; - // MTLFunctionConstantValues needs @available tag to compile - // Use id (like void*) in function signature as workaround and do static cast inside - MaybeError CreateFunction(const char* entryPointName, - SingleShaderStage stage, + MaybeError CreateFunction(SingleShaderStage stage, + const ProgrammableStage& programmableStage, const PipelineLayout* layout, MetalFunctionData* out, - id constantValues = nil, uint32_t sampleMask = 0xFFFFFFFF, const RenderPipeline* renderPipeline = nullptr); diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm index 5d4196b8af..b0f59a3b7e 100644 --- a/src/dawn/native/metal/ShaderModuleMTL.mm +++ b/src/dawn/native/metal/ShaderModuleMTL.mm @@ -35,17 +35,20 @@ namespace { using OptionalVertexPullingTransformConfig = std::optional; -#define MSL_COMPILATION_REQUEST_MEMBERS(X) \ - X(const tint::Program*, inputProgram) \ - X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \ - X(tint::transform::MultiplanarExternalTexture::BindingsMap, externalTextureBindings) \ - X(OptionalVertexPullingTransformConfig, vertexPullingTransformConfig) \ - X(std::string, entryPointName) \ - X(uint32_t, sampleMask) \ - X(bool, emitVertexPointSize) \ - X(bool, isRobustnessEnabled) \ - X(bool, disableSymbolRenaming) \ - X(bool, disableWorkgroupInit) \ +#define MSL_COMPILATION_REQUEST_MEMBERS(X) \ + X(SingleShaderStage, stage) \ + X(const tint::Program*, inputProgram) \ + X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \ + X(tint::transform::MultiplanarExternalTexture::BindingsMap, externalTextureBindings) \ + X(OptionalVertexPullingTransformConfig, vertexPullingTransformConfig) \ + X(std::optional, substituteOverrideConfig) \ + X(LimitsForCompilationRequest, limits) \ + X(std::string, entryPointName) \ + X(uint32_t, sampleMask) \ + X(bool, emitVertexPointSize) \ + X(bool, isRobustnessEnabled) \ + X(bool, disableSymbolRenaming) \ + X(bool, disableWorkgroupInit) \ X(CacheKey::UnsafeUnkeyedValue, tracePlatform) DAWN_MAKE_CACHE_REQUEST(MslCompilationRequest, MSL_COMPILATION_REQUEST_MEMBERS); @@ -53,12 +56,13 @@ DAWN_MAKE_CACHE_REQUEST(MslCompilationRequest, MSL_COMPILATION_REQUEST_MEMBERS); using WorkgroupAllocations = std::vector; -#define MSL_COMPILATION_MEMBERS(X) \ - X(std::string, msl) \ - X(std::string, remappedEntryPointName) \ - X(bool, needsStorageBufferLength) \ - X(bool, hasInvariantAttribute) \ - X(WorkgroupAllocations, workgroupAllocations) +#define MSL_COMPILATION_MEMBERS(X) \ + X(std::string, msl) \ + X(std::string, remappedEntryPointName) \ + X(bool, needsStorageBufferLength) \ + X(bool, hasInvariantAttribute) \ + X(WorkgroupAllocations, workgroupAllocations) \ + X(Extent3D, localWorkgroupSize) DAWN_SERIALIZABLE(struct, MslCompilation, MSL_COMPILATION_MEMBERS){}; #undef MSL_COMPILATION_MEMBERS @@ -92,13 +96,14 @@ MaybeError ShaderModule::Initialize(ShaderModuleParseResult* parseResult, namespace { -ResultOrError> TranslateToMSL(DeviceBase* device, - const tint::Program* inputProgram, - const char* entryPointName, - SingleShaderStage stage, - const PipelineLayout* layout, - uint32_t sampleMask, - const RenderPipeline* renderPipeline) { +ResultOrError> TranslateToMSL( + DeviceBase* device, + const ProgrammableStage& programmableStage, + SingleShaderStage stage, + const PipelineLayout* layout, + ShaderModule::MetalFunctionData* out, + uint32_t sampleMask, + const RenderPipeline* renderPipeline) { ScopedTintICEHandler scopedICEHandler(device); std::ostringstream errorStream; @@ -137,7 +142,7 @@ ResultOrError> TranslateToMSL(DeviceBase* device, if (stage == SingleShaderStage::Vertex && device->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) { vertexPullingTransformConfig = BuildVertexPullingTransformConfig( - *renderPipeline, entryPointName, kPullingBufferBindingSet); + *renderPipeline, programmableStage.entryPoint.c_str(), kPullingBufferBindingSet); for (VertexBufferSlot slot : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) { uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(slot); @@ -152,12 +157,19 @@ ResultOrError> TranslateToMSL(DeviceBase* device, } } + std::optional substituteOverrideConfig; + if (!programmableStage.metadata->overrides.empty()) { + substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage); + } + MslCompilationRequest req = {}; - req.inputProgram = inputProgram; + req.stage = stage; + req.inputProgram = programmableStage.module->GetTintProgram(); req.bindingPoints = std::move(bindingPoints); req.externalTextureBindings = std::move(externalTextureBindings); req.vertexPullingTransformConfig = std::move(vertexPullingTransformConfig); - req.entryPointName = entryPointName; + req.substituteOverrideConfig = std::move(substituteOverrideConfig); + req.entryPointName = programmableStage.entryPoint.c_str(); req.sampleMask = sampleMask; req.emitVertexPointSize = stage == SingleShaderStage::Vertex && @@ -166,6 +178,9 @@ ResultOrError> TranslateToMSL(DeviceBase* device, req.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming); req.tracePlatform = UnsafeUnkeyedValue(device->GetPlatform()); + const CombinedLimits& limits = device->GetLimits(); + req.limits = LimitsForCompilationRequest::Create(limits.v1); + CacheResult mslCompilation; DAWN_TRY_LOAD_OR_RUN( mslCompilation, device, std::move(req), MslCompilation::FromBlob, @@ -190,6 +205,14 @@ ResultOrError> TranslateToMSL(DeviceBase* device, std::move(r.vertexPullingTransformConfig).value()); } + if (r.substituteOverrideConfig) { + // This needs to run after SingleEntryPoint transform to get rid of overrides not + // used for the current entry point. + transformManager.Add(); + transformInputs.Add( + std::move(r.substituteOverrideConfig).value()); + } + if (r.isRobustnessEnabled) { transformManager.Add(); } @@ -230,6 +253,13 @@ ResultOrError> TranslateToMSL(DeviceBase* device, return DAWN_VALIDATION_ERROR("Transform output missing renamer data."); } + Extent3D localSize{0, 0, 0}; + if (r.stage == SingleShaderStage::Compute) { + // Validate workgroup size after program runs transforms. + DAWN_TRY_ASSIGN(localSize, ValidateComputeStageWorkgroupSize( + program, remappedEntryPointName.data(), r.limits)); + } + tint::writer::msl::Options options; options.buffer_size_ubo_index = kBufferLengthBufferSlot; options.fixed_sample_mask = r.sampleMask; @@ -258,6 +288,7 @@ ResultOrError> TranslateToMSL(DeviceBase* device, result.needs_storage_buffer_sizes, result.has_invariant_attribute, std::move(workgroupAllocations), + localSize, }}; }); @@ -272,11 +303,10 @@ ResultOrError> TranslateToMSL(DeviceBase* device, } // namespace -MaybeError ShaderModule::CreateFunction(const char* entryPointName, - SingleShaderStage stage, +MaybeError ShaderModule::CreateFunction(SingleShaderStage stage, + const ProgrammableStage& programmableStage, const PipelineLayout* layout, ShaderModule::MetalFunctionData* out, - id constantValuesPointer, uint32_t sampleMask, const RenderPipeline* renderPipeline) { TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleMTL::CreateFunction"); @@ -284,16 +314,21 @@ MaybeError ShaderModule::CreateFunction(const char* entryPointName, ASSERT(!IsError()); ASSERT(out); + const char* entryPointName = programmableStage.entryPoint.c_str(); + // Vertex stages must specify a renderPipeline if (stage == SingleShaderStage::Vertex) { ASSERT(renderPipeline != nullptr); } CacheResult mslCompilation; - DAWN_TRY_ASSIGN(mslCompilation, TranslateToMSL(GetDevice(), GetTintProgram(), entryPointName, - stage, layout, sampleMask, renderPipeline)); + DAWN_TRY_ASSIGN(mslCompilation, TranslateToMSL(GetDevice(), programmableStage, stage, layout, + out, sampleMask, renderPipeline)); out->needsStorageBufferLength = mslCompilation->needsStorageBufferLength; out->workgroupAllocations = std::move(mslCompilation->workgroupAllocations); + out->localWorkgroupSize = MTLSizeMake(mslCompilation->localWorkgroupSize.width, + mslCompilation->localWorkgroupSize.height, + mslCompilation->localWorkgroupSize.depthOrArrayLayers); NSRef mslSource = AcquireNSRef([[NSString alloc] initWithUTF8String:mslCompilation->msl.c_str()]); @@ -327,25 +362,7 @@ MaybeError ShaderModule::CreateFunction(const char* entryPointName, { TRACE_EVENT0(GetDevice()->GetPlatform(), General, "MTLLibrary::newFunctionWithName"); - if (constantValuesPointer != nil) { - if (@available(macOS 10.12, *)) { - MTLFunctionConstantValues* constantValues = constantValuesPointer; - out->function = AcquireNSPRef([*library newFunctionWithName:name.Get() - constantValues:constantValues - error:&error]); - if (error != nullptr) { - if (error.code != MTLLibraryErrorCompileWarning) { - return DAWN_VALIDATION_ERROR("Function compile error: %s", - [error.localizedDescription UTF8String]); - } - } - ASSERT(out->function != nil); - } else { - UNREACHABLE(); - } - } else { - out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]); - } + out->function = AcquireNSPRef([*library newFunctionWithName:name.Get()]); } if (BlobCache* cache = GetDevice()->GetBlobCache()) { diff --git a/src/dawn/native/metal/UtilsMetal.h b/src/dawn/native/metal/UtilsMetal.h index 9ee31bd618..418c4a6864 100644 --- a/src/dawn/native/metal/UtilsMetal.h +++ b/src/dawn/native/metal/UtilsMetal.h @@ -68,15 +68,6 @@ void EnsureDestinationTextureInitialized(CommandRecordingContext* commandContext MTLBlitOption ComputeMTLBlitOption(const Format& format, Aspect aspect); -// Helper function to create function with constant values wrapped in -// if available branch -MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage, - SingleShaderStage singleShaderStage, - PipelineLayout* pipelineLayout, - ShaderModule::MetalFunctionData* functionData, - uint32_t sampleMask = 0xFFFFFFFF, - const RenderPipeline* renderPipeline = nullptr); - // Allow use MTLStoreActionStoreAndMultismapleResolve because the logic in the backend is // first to compute what the "best" Metal render pass descriptor is, then fix it up if we // are not on macOS 10.12 (i.e. the EmulateStoreAndMSAAResolve toggle is on). diff --git a/src/dawn/native/metal/UtilsMetal.mm b/src/dawn/native/metal/UtilsMetal.mm index df122f3e3e..4eb5fb6e81 100644 --- a/src/dawn/native/metal/UtilsMetal.mm +++ b/src/dawn/native/metal/UtilsMetal.mm @@ -323,104 +323,6 @@ MTLBlitOption ComputeMTLBlitOption(const Format& format, Aspect aspect) { return MTLBlitOptionNone; } -MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage, - SingleShaderStage singleShaderStage, - PipelineLayout* pipelineLayout, - ShaderModule::MetalFunctionData* functionData, - uint32_t sampleMask, - const RenderPipeline* renderPipeline) { - ShaderModule* shaderModule = ToBackend(programmableStage.module.Get()); - const char* shaderEntryPoint = programmableStage.entryPoint.c_str(); - const auto& entryPointMetadata = programmableStage.module->GetEntryPoint(shaderEntryPoint); - if (entryPointMetadata.overrides.size() == 0) { - DAWN_TRY(shaderModule->CreateFunction(shaderEntryPoint, singleShaderStage, pipelineLayout, - functionData, nil, sampleMask, renderPipeline)); - return {}; - } - - if (@available(macOS 10.12, *)) { - // MTLFunctionConstantValues can only be created within the if available branch - NSRef constantValues = - AcquireNSRef([MTLFunctionConstantValues new]); - - std::unordered_set overriddenConstants; - - auto switchType = [&](EntryPointMetadata::Override::Type dawnType, - MTLDataType* type, OverrideScalar* entry, - double value = 0) { - switch (dawnType) { - case EntryPointMetadata::Override::Type::Boolean: - *type = MTLDataTypeBool; - if (entry) { - entry->b = static_cast(value); - } - break; - case EntryPointMetadata::Override::Type::Float32: - *type = MTLDataTypeFloat; - if (entry) { - entry->f32 = static_cast(value); - } - break; - case EntryPointMetadata::Override::Type::Int32: - *type = MTLDataTypeInt; - if (entry) { - entry->i32 = static_cast(value); - } - break; - case EntryPointMetadata::Override::Type::Uint32: - *type = MTLDataTypeUInt; - if (entry) { - entry->u32 = static_cast(value); - } - break; - default: - UNREACHABLE(); - } - }; - - for (const auto& [name, value] : programmableStage.constants) { - overriddenConstants.insert(name); - - // This is already validated so `name` must exist - const auto& moduleConstant = entryPointMetadata.overrides.at(name); - - MTLDataType type; - OverrideScalar entry{}; - - switchType(moduleConstant.type, &type, &entry, value); - - [constantValues.Get() setConstantValue:&entry type:type atIndex:moduleConstant.id]; - } - - // Set shader initialized default values because MSL function_constant - // has no default value - for (const std::string& name : entryPointMetadata.initializedOverrides) { - if (overriddenConstants.count(name) != 0) { - // This constant already has overridden value - continue; - } - - // Must exist because it is validated - const auto& moduleConstant = entryPointMetadata.overrides.at(name); - ASSERT(moduleConstant.isInitialized); - MTLDataType type; - - switchType(moduleConstant.type, &type, nullptr); - - [constantValues.Get() setConstantValue:&moduleConstant.defaultValue - type:type - atIndex:moduleConstant.id]; - } - - DAWN_TRY(shaderModule->CreateFunction(shaderEntryPoint, singleShaderStage, pipelineLayout, - functionData, constantValues.Get(), sampleMask, - renderPipeline)); - } else { - UNREACHABLE(); - } - return {}; -} - MaybeError EncodeMetalRenderPass(Device* device, CommandRecordingContext* commandContext, MTLRenderPassDescriptor* mtlRenderPass, diff --git a/src/dawn/native/stream/Stream.cpp b/src/dawn/native/stream/Stream.cpp index beb3823da0..eba4f5386c 100644 --- a/src/dawn/native/stream/Stream.cpp +++ b/src/dawn/native/stream/Stream.cpp @@ -16,6 +16,8 @@ #include +#include "dawn/native/Limits.h" + namespace dawn::native::stream { template <> diff --git a/src/dawn/native/vulkan/ComputePipelineVk.cpp b/src/dawn/native/vulkan/ComputePipelineVk.cpp index ad6fbf8410..ea8c59a6bf 100644 --- a/src/dawn/native/vulkan/ComputePipelineVk.cpp +++ b/src/dawn/native/vulkan/ComputePipelineVk.cpp @@ -61,16 +61,12 @@ MaybeError ComputePipeline::Initialize() { ShaderModule::ModuleAndSpirv moduleAndSpirv; DAWN_TRY_ASSIGN(moduleAndSpirv, - module->GetHandleAndSpirv(computeStage.entryPoint.c_str(), layout)); + module->GetHandleAndSpirv(SingleShaderStage::Compute, computeStage, layout)); createInfo.stage.module = moduleAndSpirv.module; createInfo.stage.pName = computeStage.entryPoint.c_str(); - std::vector specializationDataEntries; - std::vector specializationMapEntries; - VkSpecializationInfo specializationInfo{}; - createInfo.stage.pSpecializationInfo = GetVkSpecializationInfo( - computeStage, &specializationInfo, &specializationDataEntries, &specializationMapEntries); + createInfo.stage.pSpecializationInfo = nullptr; PNextChainBuilder stageExtChain(&createInfo.stage); diff --git a/src/dawn/native/vulkan/RenderPipelineVk.cpp b/src/dawn/native/vulkan/RenderPipelineVk.cpp index 372abb515b..830e012437 100644 --- a/src/dawn/native/vulkan/RenderPipelineVk.cpp +++ b/src/dawn/native/vulkan/RenderPipelineVk.cpp @@ -341,9 +341,6 @@ MaybeError RenderPipeline::Initialize() { // There are at most 2 shader stages in render pipeline, i.e. vertex and fragment std::array shaderStages; - std::array, 2> specializationDataEntriesPerStages; - std::array, 2> specializationMapEntriesPerStages; - std::array specializationInfoPerStages; uint32_t stageCount = 0; for (auto stage : IterateStages(this->GetStageMask())) { @@ -354,7 +351,7 @@ MaybeError RenderPipeline::Initialize() { ShaderModule::ModuleAndSpirv moduleAndSpirv; DAWN_TRY_ASSIGN(moduleAndSpirv, - module->GetHandleAndSpirv(programmableStage.entryPoint.c_str(), layout)); + module->GetHandleAndSpirv(stage, programmableStage, layout)); shaderStage.module = moduleAndSpirv.module; shaderStage.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO; @@ -379,11 +376,6 @@ MaybeError RenderPipeline::Initialize() { } } - shaderStage.pSpecializationInfo = - GetVkSpecializationInfo(programmableStage, &specializationInfoPerStages[stageCount], - &specializationDataEntriesPerStages[stageCount], - &specializationMapEntriesPerStages[stageCount]); - DAWN_ASSERT(stageCount < 2); shaderStages[stageCount] = shaderStage; stageCount++; diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp index 5477d84784..6588627d96 100644 --- a/src/dawn/native/vulkan/ShaderModuleVk.cpp +++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp @@ -60,13 +60,35 @@ class ShaderModule::Spirv : private Blob { namespace dawn::native::vulkan { +bool TransformedShaderModuleCacheKey::operator==( + const TransformedShaderModuleCacheKey& other) const { + if (layout != other.layout || entryPoint != other.entryPoint || + constants.size() != other.constants.size()) { + return false; + } + if (!std::equal(constants.begin(), constants.end(), other.constants.begin())) { + return false; + } + return true; +} + +size_t TransformedShaderModuleCacheKeyHashFunc::operator()( + const TransformedShaderModuleCacheKey& key) const { + size_t hash = 0; + HashCombine(&hash, key.layout, key.entryPoint); + for (const auto& entry : key.constants) { + HashCombine(&hash, entry.first, entry.second); + } + return hash; +} + class ShaderModule::ConcurrentTransformedShaderModuleCache { public: explicit ConcurrentTransformedShaderModuleCache(Device* device); ~ConcurrentTransformedShaderModuleCache(); - std::optional Find(const PipelineLayoutEntryPointPair& key); - ModuleAndSpirv AddOrGet(const PipelineLayoutEntryPointPair& key, + std::optional Find(const TransformedShaderModuleCacheKey& key); + ModuleAndSpirv AddOrGet(const TransformedShaderModuleCacheKey& key, VkShaderModule module, Spirv&& spirv); @@ -75,7 +97,9 @@ class ShaderModule::ConcurrentTransformedShaderModuleCache { Device* mDevice; std::mutex mMutex; - std::unordered_map + std::unordered_map mTransformedShaderModuleCache; }; @@ -92,7 +116,7 @@ ShaderModule::ConcurrentTransformedShaderModuleCache::~ConcurrentTransformedShad std::optional ShaderModule::ConcurrentTransformedShaderModuleCache::Find( - const PipelineLayoutEntryPointPair& key) { + const TransformedShaderModuleCacheKey& key) { std::lock_guard lock(mMutex); auto iter = mTransformedShaderModuleCache.find(key); if (iter != mTransformedShaderModuleCache.end()) { @@ -106,7 +130,7 @@ ShaderModule::ConcurrentTransformedShaderModuleCache::Find( } ShaderModule::ModuleAndSpirv ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGet( - const PipelineLayoutEntryPointPair& key, + const TransformedShaderModuleCacheKey& key, VkShaderModule module, Spirv&& spirv) { ASSERT(module != VK_NULL_HANDLE); @@ -168,20 +192,24 @@ void ShaderModule::DestroyImpl() { ShaderModule::~ShaderModule() = default; -#define SPIRV_COMPILATION_REQUEST_MEMBERS(X) \ - X(const tint::Program*, inputProgram) \ - X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \ - X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \ - X(std::string_view, entryPointName) \ - X(bool, disableWorkgroupInit) \ - X(bool, useZeroInitializeWorkgroupMemoryExtension) \ +#define SPIRV_COMPILATION_REQUEST_MEMBERS(X) \ + X(SingleShaderStage, stage) \ + X(const tint::Program*, inputProgram) \ + X(tint::transform::BindingRemapper::BindingPoints, bindingPoints) \ + X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \ + X(std::optional, substituteOverrideConfig) \ + X(LimitsForCompilationRequest, limits) \ + X(std::string_view, entryPointName) \ + X(bool, disableWorkgroupInit) \ + X(bool, useZeroInitializeWorkgroupMemoryExtension) \ X(CacheKey::UnsafeUnkeyedValue, tracePlatform) DAWN_MAKE_CACHE_REQUEST(SpirvCompilationRequest, SPIRV_COMPILATION_REQUEST_MEMBERS); #undef SPIRV_COMPILATION_REQUEST_MEMBERS ResultOrError ShaderModule::GetHandleAndSpirv( - const char* entryPointName, + SingleShaderStage stage, + const ProgrammableStage& programmableStage, const PipelineLayout* layout) { TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleVk::GetHandleAndSpirv"); @@ -191,7 +219,8 @@ ResultOrError ShaderModule::GetHandleAndSpirv( ScopedTintICEHandler scopedICEHandler(GetDevice()); // Check to see if we have the handle and spirv cached already. - auto cacheKey = std::make_pair(layout, entryPointName); + auto cacheKey = TransformedShaderModuleCacheKey{layout, programmableStage.entryPoint.c_str(), + programmableStage.constants}; auto handleAndSpirv = mTransformedShaderModuleCache->Find(cacheKey); if (handleAndSpirv.has_value()) { return std::move(*handleAndSpirv); @@ -204,7 +233,8 @@ ResultOrError ShaderModule::GetHandleAndSpirv( using BindingPoint = tint::transform::BindingPoint; BindingRemapper::BindingPoints bindingPoints; - const BindingInfoArray& moduleBindingInfo = GetEntryPoint(entryPointName).bindings; + const BindingInfoArray& moduleBindingInfo = + GetEntryPoint(programmableStage.entryPoint.c_str()).bindings; for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group)); @@ -238,16 +268,26 @@ ResultOrError ShaderModule::GetHandleAndSpirv( } } + std::optional substituteOverrideConfig; + if (!programmableStage.metadata->overrides.empty()) { + substituteOverrideConfig = BuildSubstituteOverridesTransformConfig(programmableStage); + } + #if TINT_BUILD_SPV_WRITER SpirvCompilationRequest req = {}; + req.stage = stage; req.inputProgram = GetTintProgram(); req.bindingPoints = std::move(bindingPoints); req.newBindingsMap = std::move(newBindingsMap); - req.entryPointName = entryPointName; + req.entryPointName = programmableStage.entryPoint; req.disableWorkgroupInit = GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit); req.useZeroInitializeWorkgroupMemoryExtension = GetDevice()->IsToggleEnabled(Toggle::VulkanUseZeroInitializeWorkgroupMemoryExtension); req.tracePlatform = UnsafeUnkeyedValue(GetDevice()->GetPlatform()); + req.substituteOverrideConfig = std::move(substituteOverrideConfig); + + const CombinedLimits& limits = GetDevice()->GetLimits(); + req.limits = LimitsForCompilationRequest::Create(limits.v1); CacheResult spirv; DAWN_TRY_LOAD_OR_RUN( @@ -270,12 +310,27 @@ ResultOrError ShaderModule::GetHandleAndSpirv( transformInputs.Add( r.newBindingsMap); } + if (r.substituteOverrideConfig) { + // This needs to run after SingleEntryPoint transform to get rid of overrides not + // used for the current entry point. + transformManager.Add(); + transformInputs.Add( + std::move(r.substituteOverrideConfig).value()); + } tint::Program program; { TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "RunTransforms"); DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, r.inputProgram, transformInputs, nullptr, nullptr)); } + + if (r.stage == SingleShaderStage::Compute) { + // Validate workgroup size after program runs transforms. + Extent3D _; + DAWN_TRY_ASSIGN(_, ValidateComputeStageWorkgroupSize( + program, r.entryPointName.data(), r.limits)); + } + tint::writer::spirv::Options options; options.emit_vertex_point_size = true; options.disable_workgroup_init = r.disableWorkgroupInit; diff --git a/src/dawn/native/vulkan/ShaderModuleVk.h b/src/dawn/native/vulkan/ShaderModuleVk.h index c0741eff6e..7c6a5dff46 100644 --- a/src/dawn/native/vulkan/ShaderModuleVk.h +++ b/src/dawn/native/vulkan/ShaderModuleVk.h @@ -18,14 +18,32 @@ #include #include #include +#include #include #include +#include "dawn/common/HashUtils.h" #include "dawn/common/vulkan_platform.h" #include "dawn/native/Error.h" #include "dawn/native/ShaderModule.h" -namespace dawn::native::vulkan { +namespace dawn::native { + +struct ProgrammableStage; + +namespace vulkan { + +struct TransformedShaderModuleCacheKey { + const PipelineLayoutBase* layout; + std::string entryPoint; + PipelineConstantEntries constants; + + bool operator==(const TransformedShaderModuleCacheKey& other) const; +}; + +struct TransformedShaderModuleCacheKeyHashFunc { + size_t operator()(const TransformedShaderModuleCacheKey& key) const; +}; class Device; class PipelineLayout; @@ -44,7 +62,8 @@ class ShaderModule final : public ShaderModuleBase { ShaderModuleParseResult* parseResult, OwnedCompilationMessages* compilationMessages); - ResultOrError GetHandleAndSpirv(const char* entryPointName, + ResultOrError GetHandleAndSpirv(SingleShaderStage stage, + const ProgrammableStage& programmableStage, const PipelineLayout* layout); private: @@ -59,6 +78,8 @@ class ShaderModule final : public ShaderModuleBase { std::unique_ptr mTransformedShaderModuleCache; }; -} // namespace dawn::native::vulkan +} // namespace vulkan + +} // namespace dawn::native #endif // SRC_DAWN_NATIVE_VULKAN_SHADERMODULEVK_H_ diff --git a/src/dawn/native/vulkan/UtilsVulkan.cpp b/src/dawn/native/vulkan/UtilsVulkan.cpp index ca8b4052ca..aa02ec35e5 100644 --- a/src/dawn/native/vulkan/UtilsVulkan.cpp +++ b/src/dawn/native/vulkan/UtilsVulkan.cpp @@ -258,60 +258,4 @@ std::string GetDeviceDebugPrefixFromDebugName(const char* debugName) { return std::string(debugName, length); } -VkSpecializationInfo* GetVkSpecializationInfo( - const ProgrammableStage& programmableStage, - VkSpecializationInfo* specializationInfo, - std::vector* specializationDataEntries, - std::vector* specializationMapEntries) { - ASSERT(specializationInfo); - ASSERT(specializationDataEntries); - ASSERT(specializationMapEntries); - - if (programmableStage.constants.size() == 0) { - return nullptr; - } - - const EntryPointMetadata& entryPointMetaData = - programmableStage.module->GetEntryPoint(programmableStage.entryPoint); - - for (const auto& pipelineConstant : programmableStage.constants) { - const std::string& identifier = pipelineConstant.first; - double value = pipelineConstant.second; - - // This is already validated so `identifier` must exist - const auto& moduleConstant = entryPointMetaData.overrides.at(identifier); - - specializationMapEntries->push_back(VkSpecializationMapEntry{ - moduleConstant.id, - static_cast(specializationDataEntries->size() * sizeof(OverrideScalar)), - sizeof(OverrideScalar)}); - - OverrideScalar entry{}; - switch (moduleConstant.type) { - case EntryPointMetadata::Override::Type::Boolean: - entry.b = static_cast(value); - break; - case EntryPointMetadata::Override::Type::Float32: - entry.f32 = static_cast(value); - break; - case EntryPointMetadata::Override::Type::Int32: - entry.i32 = static_cast(value); - break; - case EntryPointMetadata::Override::Type::Uint32: - entry.u32 = static_cast(value); - break; - default: - UNREACHABLE(); - } - specializationDataEntries->push_back(entry); - } - - specializationInfo->mapEntryCount = static_cast(specializationMapEntries->size()); - specializationInfo->pMapEntries = specializationMapEntries->data(); - specializationInfo->dataSize = specializationDataEntries->size() * sizeof(OverrideScalar); - specializationInfo->pData = specializationDataEntries->data(); - - return specializationInfo; -} - } // namespace dawn::native::vulkan diff --git a/src/dawn/native/vulkan/UtilsVulkan.h b/src/dawn/native/vulkan/UtilsVulkan.h index 3e2748e173..5ccdf0adf6 100644 --- a/src/dawn/native/vulkan/UtilsVulkan.h +++ b/src/dawn/native/vulkan/UtilsVulkan.h @@ -144,15 +144,6 @@ void SetDebugName(Device* device, std::string GetNextDeviceDebugPrefix(); std::string GetDeviceDebugPrefixFromDebugName(const char* debugName); -// Returns nullptr or &specializationInfo -// specializationInfo, specializationDataEntries, specializationMapEntries needs to -// be alive at least until VkSpecializationInfo is passed into Vulkan Create*Pipelines -VkSpecializationInfo* GetVkSpecializationInfo( - const ProgrammableStage& programmableStage, - VkSpecializationInfo* specializationInfo, - std::vector* specializationDataEntries, - std::vector* specializationMapEntries); - } // namespace dawn::native::vulkan #endif // SRC_DAWN_NATIVE_VULKAN_UTILSVULKAN_H_ diff --git a/src/dawn/tests/BUILD.gn b/src/dawn/tests/BUILD.gn index eb1ed6c186..2955a8b94a 100644 --- a/src/dawn/tests/BUILD.gn +++ b/src/dawn/tests/BUILD.gn @@ -265,6 +265,7 @@ dawn_test("dawn_unittests") { "unittests/native/CreatePipelineAsyncTaskTests.cpp", "unittests/native/DestroyObjectTests.cpp", "unittests/native/DeviceCreationTests.cpp", + "unittests/native/ObjectContentHasherTests.cpp", "unittests/native/StreamTests.cpp", "unittests/validation/BindGroupValidationTests.cpp", "unittests/validation/BufferValidationTests.cpp", @@ -489,6 +490,7 @@ source_set("end2end_tests_sources") { "end2end/ScissorTests.cpp", "end2end/ShaderFloat16Tests.cpp", "end2end/ShaderTests.cpp", + "end2end/ShaderValidationTests.cpp", "end2end/StorageTextureTests.cpp", "end2end/SubresourceRenderAttachmentTests.cpp", "end2end/Texture3DTests.cpp", diff --git a/src/dawn/tests/end2end/ObjectCachingTests.cpp b/src/dawn/tests/end2end/ObjectCachingTests.cpp index cda1fca1ea..aeebd30c3e 100644 --- a/src/dawn/tests/end2end/ObjectCachingTests.cpp +++ b/src/dawn/tests/end2end/ObjectCachingTests.cpp @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + #include "dawn/tests/DawnTest.h" #include "dawn/utils/ComboRenderPipelineDescriptor.h" @@ -158,6 +160,46 @@ TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnShaderModule) { EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire()); } +// Test that ComputePipeline are correctly deduplicated wrt. their constants override values +TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnOverrides) { + wgpu::ShaderModule module = utils::CreateShaderModule(device, R"( + override x: u32 = 1u; + var i : u32; + @compute @workgroup_size(x) fn main() { + i = 0u; + })"); + + wgpu::PipelineLayout layout = utils::MakeBasicPipelineLayout(device, nullptr); + + wgpu::ComputePipelineDescriptor desc; + desc.compute.entryPoint = "main"; + desc.layout = layout; + desc.compute.module = module; + + std::vector constants{{nullptr, "x", 16}}; + desc.compute.constantCount = constants.size(); + desc.compute.constants = constants.data(); + wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&desc); + + std::vector sameConstants{{nullptr, "x", 16}}; + desc.compute.constantCount = sameConstants.size(); + desc.compute.constants = sameConstants.data(); + wgpu::ComputePipeline samePipeline = device.CreateComputePipeline(&desc); + + desc.compute.constantCount = 0; + desc.compute.constants = nullptr; + wgpu::ComputePipeline otherPipeline1 = device.CreateComputePipeline(&desc); + + std::vector otherConstants{{nullptr, "x", 4}}; + desc.compute.constantCount = otherConstants.size(); + desc.compute.constants = otherConstants.data(); + wgpu::ComputePipeline otherPipeline2 = device.CreateComputePipeline(&desc); + + EXPECT_NE(pipeline.Get(), otherPipeline1.Get()); + EXPECT_NE(pipeline.Get(), otherPipeline2.Get()); + EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire()); +} + // Test that ComputePipeline are correctly deduplicated wrt. their layout TEST_P(ObjectCachingTest, ComputePipelineDeduplicationOnLayout) { wgpu::BindGroupLayout bgl = utils::MakeBindGroupLayout( @@ -303,6 +345,48 @@ TEST_P(ObjectCachingTest, RenderPipelineDeduplicationOnFragmentModule) { EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire()); } +// Test that Renderpipelines are correctly deduplicated wrt. their constants override values +TEST_P(ObjectCachingTest, RenderPipelineDeduplicationOnOverrides) { + wgpu::ShaderModule module = utils::CreateShaderModule(device, R"( + override a: f32 = 1.0; + @vertex fn vertexMain() -> @builtin(position) vec4 { + return vec4(0.0, 0.0, 0.0, 0.0); + } + @fragment fn fragmentMain() -> @location(0) vec4 { + return vec4(0.0, 0.0, 0.0, a); + })"); + + utils::ComboRenderPipelineDescriptor desc; + desc.vertex.module = module; + desc.vertex.entryPoint = "vertexMain"; + desc.cFragment.module = module; + desc.cFragment.entryPoint = "fragmentMain"; + desc.cTargets[0].writeMask = wgpu::ColorWriteMask::None; + + std::vector constants{{nullptr, "a", 0.5}}; + desc.cFragment.constantCount = constants.size(); + desc.cFragment.constants = constants.data(); + wgpu::RenderPipeline pipeline = device.CreateRenderPipeline(&desc); + + std::vector sameConstants{{nullptr, "a", 0.5}}; + desc.cFragment.constantCount = sameConstants.size(); + desc.cFragment.constants = sameConstants.data(); + wgpu::RenderPipeline samePipeline = device.CreateRenderPipeline(&desc); + + std::vector otherConstants{{nullptr, "a", 1.0}}; + desc.cFragment.constantCount = otherConstants.size(); + desc.cFragment.constants = otherConstants.data(); + wgpu::RenderPipeline otherPipeline1 = device.CreateRenderPipeline(&desc); + + desc.cFragment.constantCount = 0; + desc.cFragment.constants = nullptr; + wgpu::RenderPipeline otherPipeline2 = device.CreateRenderPipeline(&desc); + + EXPECT_NE(pipeline.Get(), otherPipeline1.Get()); + EXPECT_NE(pipeline.Get(), otherPipeline2.Get()); + EXPECT_EQ(pipeline.Get() == samePipeline.Get(), !UsesWire()); +} + // Test that Samplers are correctly deduplicated. TEST_P(ObjectCachingTest, SamplerDeduplication) { wgpu::SamplerDescriptor samplerDesc; diff --git a/src/dawn/tests/end2end/ShaderTests.cpp b/src/dawn/tests/end2end/ShaderTests.cpp index 82159953ae..45ba077516 100644 --- a/src/dawn/tests/end2end/ShaderTests.cpp +++ b/src/dawn/tests/end2end/ShaderTests.cpp @@ -471,6 +471,127 @@ struct Buf { EXPECT_BUFFER_U32_RANGE_EQ(expected.data(), buffer, 0, kCount); } +// Test one shader shared by two pipelines with different constants overridden +TEST_P(ShaderTests, OverridableConstantsSharedShader) { + DAWN_TEST_UNSUPPORTED_IF(IsOpenGL()); + DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES()); + + std::vector expected1{1}; + wgpu::Buffer buffer1 = CreateBuffer(expected1.size()); + std::vector expected2{2}; + wgpu::Buffer buffer2 = CreateBuffer(expected2.size()); + + std::string shader = R"( +override a: u32; + +struct Buf { + data : array +} + +@group(0) @binding(0) var buf : Buf; + +@compute @workgroup_size(1) fn main() { + buf.data[0] = a; +})"; + + std::vector constants1; + constants1.push_back({nullptr, "a", 1}); + std::vector constants2; + constants2.push_back({nullptr, "a", 2}); + + wgpu::ComputePipeline pipeline1 = CreateComputePipeline(shader, "main", &constants1); + wgpu::ComputePipeline pipeline2 = CreateComputePipeline(shader, "main", &constants2); + + wgpu::BindGroup bindGroup1 = + utils::MakeBindGroup(device, pipeline1.GetBindGroupLayout(0), {{0, buffer1}}); + wgpu::BindGroup bindGroup2 = + utils::MakeBindGroup(device, pipeline2.GetBindGroupLayout(0), {{0, buffer2}}); + + wgpu::CommandBuffer commands; + { + wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + pass.SetPipeline(pipeline1); + pass.SetBindGroup(0, bindGroup1); + pass.DispatchWorkgroups(1); + pass.SetPipeline(pipeline2); + pass.SetBindGroup(0, bindGroup2); + pass.DispatchWorkgroups(1); + pass.End(); + + commands = encoder.Finish(); + } + + queue.Submit(1, &commands); + + EXPECT_BUFFER_U32_RANGE_EQ(expected1.data(), buffer1, 0, expected1.size()); + EXPECT_BUFFER_U32_RANGE_EQ(expected2.data(), buffer2, 0, expected2.size()); +} + +// Test overridable constants work with workgroup size +TEST_P(ShaderTests, OverridableConstantsWorkgroupSize) { + DAWN_TEST_UNSUPPORTED_IF(IsOpenGL()); + DAWN_TEST_UNSUPPORTED_IF(IsOpenGLES()); + + std::string shader = R"( +override x: u32; + +struct Buf { + data : array +} + +@group(0) @binding(0) var buf : Buf; + +@compute @workgroup_size(x) fn main( + @builtin(local_invocation_id) local_invocation_id : vec3 +) { + if (local_invocation_id.x >= x - 1) { + buf.data[0] = local_invocation_id.x + 1; + } +})"; + + const uint32_t workgroup_size_x_1 = 16u; + const uint32_t workgroup_size_x_2 = 64u; + + std::vector expected1{workgroup_size_x_1}; + wgpu::Buffer buffer1 = CreateBuffer(expected1.size()); + std::vector expected2{workgroup_size_x_2}; + wgpu::Buffer buffer2 = CreateBuffer(expected2.size()); + + std::vector constants1; + constants1.push_back({nullptr, "x", static_cast(workgroup_size_x_1)}); + std::vector constants2; + constants2.push_back({nullptr, "x", static_cast(workgroup_size_x_2)}); + + wgpu::ComputePipeline pipeline1 = CreateComputePipeline(shader, "main", &constants1); + wgpu::ComputePipeline pipeline2 = CreateComputePipeline(shader, "main", &constants2); + + wgpu::BindGroup bindGroup1 = + utils::MakeBindGroup(device, pipeline1.GetBindGroupLayout(0), {{0, buffer1}}); + wgpu::BindGroup bindGroup2 = + utils::MakeBindGroup(device, pipeline2.GetBindGroupLayout(0), {{0, buffer2}}); + + wgpu::CommandBuffer commands; + { + wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + pass.SetPipeline(pipeline1); + pass.SetBindGroup(0, bindGroup1); + pass.DispatchWorkgroups(1); + pass.SetPipeline(pipeline2); + pass.SetBindGroup(0, bindGroup2); + pass.DispatchWorkgroups(1); + pass.End(); + + commands = encoder.Finish(); + } + + queue.Submit(1, &commands); + + EXPECT_BUFFER_U32_RANGE_EQ(expected1.data(), buffer1, 0, expected1.size()); + EXPECT_BUFFER_U32_RANGE_EQ(expected2.data(), buffer2, 0, expected2.size()); +} + // Test overridable constants with numeric identifiers TEST_P(ShaderTests, OverridableConstantsNumericIdentifiers) { DAWN_TEST_UNSUPPORTED_IF(IsOpenGL()); @@ -596,6 +717,7 @@ TEST_P(ShaderTests, OverridableConstantsMultipleEntryPoints) { std::string shader = R"( @id(1001) override c1: u32; @id(1002) override c2: u32; +@id(1003) override c3: u32; struct Buf { data : array @@ -611,7 +733,7 @@ struct Buf { buf.data[0] = c2; } -@compute @workgroup_size(1) fn main3() { +@compute @workgroup_size(c3) fn main3() { buf.data[0] = 3u; } )"; @@ -620,6 +742,8 @@ struct Buf { constants1.push_back({nullptr, "1001", 1}); std::vector constants2; constants2.push_back({nullptr, "1002", 2}); + std::vector constants3; + constants3.push_back({nullptr, "1003", 1}); wgpu::ShaderModule shaderModule = utils::CreateShaderModule(device, shader.c_str()); @@ -640,6 +764,8 @@ struct Buf { wgpu::ComputePipelineDescriptor csDesc3; csDesc3.compute.module = shaderModule; csDesc3.compute.entryPoint = "main3"; + csDesc3.compute.constants = constants3.data(); + csDesc3.compute.constantCount = constants3.size(); wgpu::ComputePipeline pipeline3 = device.CreateComputePipeline(&csDesc3); wgpu::BindGroup bindGroup1 = @@ -765,8 +891,6 @@ TEST_P(ShaderTests, ConflictingBindingsDueToTransformOrder) { device.CreateRenderPipeline(&desc); } -// TODO(tint:1155): Test overridable constants used for workgroup size - DAWN_INSTANTIATE_TEST(ShaderTests, D3D12Backend(), MetalBackend(), diff --git a/src/dawn/tests/end2end/ShaderValidationTests.cpp b/src/dawn/tests/end2end/ShaderValidationTests.cpp new file mode 100644 index 0000000000..c8b03f31db --- /dev/null +++ b/src/dawn/tests/end2end/ShaderValidationTests.cpp @@ -0,0 +1,395 @@ +// Copyright 2022 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "dawn/tests/DawnTest.h" +#include "dawn/utils/ComboRenderPipelineDescriptor.h" +#include "dawn/utils/WGPUHelpers.h" + +// The compute shader workgroup size is settled at compute pipeline creation time. +// The validation code in dawn is in each backend (not including Null backend) thus this test needs +// to be as part of a dawn_end2end_tests instead of the dawn_unittests +// TODO(dawn:1504): Add support for GL backend. +class WorkgroupSizeValidationTest : public DawnTest { + public: + wgpu::ShaderModule SetUpShaderWithValidDefaultValueConstants() { + return utils::CreateShaderModule(device, R"( +override x: u32 = 1u; +override y: u32 = 1u; +override z: u32 = 1u; + +@compute @workgroup_size(x, y, z) fn main() { + _ = 0u; +})"); + } + + wgpu::ShaderModule SetUpShaderWithZeroDefaultValueConstants() { + return utils::CreateShaderModule(device, R"( +override x: u32 = 0u; +override y: u32 = 0u; +override z: u32 = 0u; + +@compute @workgroup_size(x, y, z) fn main() { + _ = 0u; +})"); + } + + wgpu::ShaderModule SetUpShaderWithOutOfLimitsDefaultValueConstants() { + return utils::CreateShaderModule(device, R"( +override x: u32 = 1u; +override y: u32 = 1u; +override z: u32 = 9999u; + +@compute @workgroup_size(x, y, z) fn main() { + _ = 0u; +})"); + } + + wgpu::ShaderModule SetUpShaderWithUninitializedConstants() { + return utils::CreateShaderModule(device, R"( +override x: u32; +override y: u32; +override z: u32; + +@compute @workgroup_size(x, y, z) fn main() { + _ = 0u; +})"); + } + + wgpu::ShaderModule SetUpShaderWithPartialConstants() { + return utils::CreateShaderModule(device, R"( +override x: u32; + +@compute @workgroup_size(x, 1, 1) fn main() { + _ = 0u; +})"); + } + + void TestCreatePipeline(const wgpu::ShaderModule& module) { + wgpu::ComputePipelineDescriptor csDesc; + csDesc.compute.module = module; + csDesc.compute.entryPoint = "main"; + csDesc.compute.constants = nullptr; + csDesc.compute.constantCount = 0; + wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc); + } + + void TestCreatePipeline(const wgpu::ShaderModule& module, + const std::vector& constants) { + wgpu::ComputePipelineDescriptor csDesc; + csDesc.compute.module = module; + csDesc.compute.entryPoint = "main"; + csDesc.compute.constants = constants.data(); + csDesc.compute.constantCount = constants.size(); + wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&csDesc); + } + + void TestInitializedWithZero(const wgpu::ShaderModule& module) { + std::vector constants{ + {nullptr, "x", 0}, {nullptr, "y", 0}, {nullptr, "z", 0}}; + TestCreatePipeline(module, constants); + } + + void TestInitializedWithOutOfLimitValue(const wgpu::ShaderModule& module) { + std::vector constants{ + {nullptr, "x", + static_cast(GetSupportedLimits().limits.maxComputeWorkgroupSizeX + 1)}, + {nullptr, "y", 1}, + {nullptr, "z", 1}}; + TestCreatePipeline(module, constants); + } + + void TestInitializedWithValidValue(const wgpu::ShaderModule& module) { + std::vector constants{ + {nullptr, "x", 4}, {nullptr, "y", 4}, {nullptr, "z", 4}}; + TestCreatePipeline(module, constants); + } + + void TestInitializedPartially(const wgpu::ShaderModule& module) { + std::vector constants{{nullptr, "y", 4}}; + TestCreatePipeline(module, constants); + } + + wgpu::Buffer buffer; +}; + +// Test workgroup size validation with fixed values. +TEST_P(WorkgroupSizeValidationTest, WithFixedValues) { + auto CheckShaderWithWorkgroupSize = [this](bool success, uint32_t x, uint32_t y, uint32_t z) { + std::ostringstream ss; + ss << "@compute @workgroup_size(" << x << "," << y << "," << z << ") fn main() {}"; + + wgpu::ComputePipelineDescriptor desc; + desc.compute.entryPoint = "main"; + desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str()); + + if (success) { + device.CreateComputePipeline(&desc); + } else { + ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc)); + } + }; + + wgpu::Limits supportedLimits = GetSupportedLimits().limits; + + CheckShaderWithWorkgroupSize(true, 1, 1, 1); + CheckShaderWithWorkgroupSize(true, supportedLimits.maxComputeWorkgroupSizeX, 1, 1); + CheckShaderWithWorkgroupSize(true, 1, supportedLimits.maxComputeWorkgroupSizeY, 1); + CheckShaderWithWorkgroupSize(true, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ); + + CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX + 1, 1, 1); + CheckShaderWithWorkgroupSize(false, 1, supportedLimits.maxComputeWorkgroupSizeY + 1, 1); + CheckShaderWithWorkgroupSize(false, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ + 1); + + // No individual dimension exceeds its limit, but the combined size should definitely exceed the + // total invocation limit. + DAWN_ASSERT(supportedLimits.maxComputeWorkgroupSizeX * + supportedLimits.maxComputeWorkgroupSizeY * + supportedLimits.maxComputeWorkgroupSizeZ > + supportedLimits.maxComputeInvocationsPerWorkgroup); + CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX, + supportedLimits.maxComputeWorkgroupSizeY, + supportedLimits.maxComputeWorkgroupSizeZ); +} + +// Test workgroup size validation with fixed values (storage size limits validation). +TEST_P(WorkgroupSizeValidationTest, WithFixedValuesStorageSizeLimits) { + wgpu::Limits supportedLimits = GetSupportedLimits().limits; + + constexpr uint32_t kVec4Size = 16; + const uint32_t maxVec4Count = supportedLimits.maxComputeWorkgroupStorageSize / kVec4Size; + constexpr uint32_t kMat4Size = 64; + const uint32_t maxMat4Count = supportedLimits.maxComputeWorkgroupStorageSize / kMat4Size; + + auto CheckPipelineWithWorkgroupStorage = [this](bool success, uint32_t vec4_count, + uint32_t mat4_count) { + std::ostringstream ss; + std::ostringstream body; + if (vec4_count > 0) { + ss << "var vec4_data: array, " << vec4_count << ">;"; + body << "_ = vec4_data;"; + } + if (mat4_count > 0) { + ss << "var mat4_data: array, " << mat4_count << ">;"; + body << "_ = mat4_data;"; + } + ss << "@compute @workgroup_size(1) fn main() { " << body.str() << " }"; + + wgpu::ComputePipelineDescriptor desc; + desc.compute.entryPoint = "main"; + desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str()); + + if (success) { + device.CreateComputePipeline(&desc); + } else { + ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc)); + } + }; + + CheckPipelineWithWorkgroupStorage(true, 1, 1); + CheckPipelineWithWorkgroupStorage(true, maxVec4Count, 0); + CheckPipelineWithWorkgroupStorage(true, 0, maxMat4Count); + CheckPipelineWithWorkgroupStorage(true, maxVec4Count - 4, 1); + CheckPipelineWithWorkgroupStorage(true, 4, maxMat4Count - 1); + + CheckPipelineWithWorkgroupStorage(false, maxVec4Count + 1, 0); + CheckPipelineWithWorkgroupStorage(false, maxVec4Count - 3, 1); + CheckPipelineWithWorkgroupStorage(false, 0, maxMat4Count + 1); + CheckPipelineWithWorkgroupStorage(false, 4, maxMat4Count); +} + +// Test workgroup size validation with valid overrides default values. +TEST_P(WorkgroupSizeValidationTest, OverridesWithValidDefault) { + wgpu::ShaderModule module = SetUpShaderWithValidDefaultValueConstants(); + { + // Valid default + TestCreatePipeline(module); + } + { + // Error: invalid value (zero) + ASSERT_DEVICE_ERROR(TestInitializedWithZero(module)); + } + { + // Error: invalid value (out of device limits) + ASSERT_DEVICE_ERROR(TestInitializedWithOutOfLimitValue(module)); + } + { + // Valid: initialized partially + TestInitializedPartially(module); + } + { + // Valid + TestInitializedWithValidValue(module); + } +} + +// Test workgroup size validation with zero as the overrides default values. +TEST_P(WorkgroupSizeValidationTest, OverridesWithZeroDefault) { + // Error: zero is detected as invalid at shader creation time + ASSERT_DEVICE_ERROR(SetUpShaderWithZeroDefaultValueConstants()); +} + +// Test workgroup size validation with out-of-limits overrides default values. +TEST_P(WorkgroupSizeValidationTest, OverridesWithOutOfLimitsDefault) { + wgpu::ShaderModule module = SetUpShaderWithOutOfLimitsDefaultValueConstants(); + { + // Error: invalid default + ASSERT_DEVICE_ERROR(TestCreatePipeline(module)); + } + { + // Error: invalid value (zero) + ASSERT_DEVICE_ERROR(TestInitializedWithZero(module)); + } + { + // Error: invalid value (out of device limits) + ASSERT_DEVICE_ERROR(TestInitializedWithOutOfLimitValue(module)); + } + { + // Error: initialized partially + ASSERT_DEVICE_ERROR(TestInitializedPartially(module)); + } + { + // Valid + TestInitializedWithValidValue(module); + } +} + +// Test workgroup size validation without overrides default values specified. +TEST_P(WorkgroupSizeValidationTest, OverridesWithUninitialized) { + wgpu::ShaderModule module = SetUpShaderWithUninitializedConstants(); + { + // Error: uninitialized + ASSERT_DEVICE_ERROR(TestCreatePipeline(module)); + } + { + // Error: invalid value (zero) + ASSERT_DEVICE_ERROR(TestInitializedWithZero(module)); + } + { + // Error: invalid value (out of device limits) + ASSERT_DEVICE_ERROR(TestInitializedWithOutOfLimitValue(module)); + } + { + // Error: initialized partially + ASSERT_DEVICE_ERROR(TestInitializedPartially(module)); + } + { + // Valid + TestInitializedWithValidValue(module); + } +} + +// Test workgroup size validation with only partial dimensions are overrides. +TEST_P(WorkgroupSizeValidationTest, PartialOverrides) { + wgpu::ShaderModule module = SetUpShaderWithPartialConstants(); + { + // Error: uninitialized + ASSERT_DEVICE_ERROR(TestCreatePipeline(module)); + } + { + // Error: invalid value (zero) + std::vector constants{{nullptr, "x", 0}}; + ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants)); + } + { + // Error: invalid value (out of device limits) + std::vector constants{ + {nullptr, "x", + static_cast(GetSupportedLimits().limits.maxComputeWorkgroupSizeX + 1)}}; + ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants)); + } + { + // Valid + std::vector constants{{nullptr, "x", 16}}; + TestCreatePipeline(module, constants); + } +} + +// Test workgroup size validation after being overrided with invalid values. +TEST_P(WorkgroupSizeValidationTest, ValidationAfterOverride) { + wgpu::ShaderModule module = SetUpShaderWithUninitializedConstants(); + wgpu::Limits supportedLimits = GetSupportedLimits().limits; + { + // Error: exceed maxComputeWorkgroupSizeZ + std::vector constants{ + {nullptr, "x", 1}, + {nullptr, "y", 1}, + {nullptr, "z", static_cast(supportedLimits.maxComputeWorkgroupSizeZ + 1)}}; + ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants)); + } + { + // Error: exceed maxComputeInvocationsPerWorkgroup + DAWN_ASSERT(supportedLimits.maxComputeWorkgroupSizeX * + supportedLimits.maxComputeWorkgroupSizeY * + supportedLimits.maxComputeWorkgroupSizeZ > + supportedLimits.maxComputeInvocationsPerWorkgroup); + std::vector constants{ + {nullptr, "x", static_cast(supportedLimits.maxComputeWorkgroupSizeX)}, + {nullptr, "y", static_cast(supportedLimits.maxComputeWorkgroupSizeY)}, + {nullptr, "z", static_cast(supportedLimits.maxComputeWorkgroupSizeZ)}}; + ASSERT_DEVICE_ERROR(TestCreatePipeline(module, constants)); + } +} + +// Test workgroup size validation after being overrided with invalid values (storage size limits +// validation). +// TODO(tint:1660): re-enable after override can be used as array size. +TEST_P(WorkgroupSizeValidationTest, DISABLED_ValidationAfterOverrideStorageSize) { + wgpu::Limits supportedLimits = GetSupportedLimits().limits; + + constexpr uint32_t kVec4Size = 16; + const uint32_t maxVec4Count = supportedLimits.maxComputeWorkgroupStorageSize / kVec4Size; + constexpr uint32_t kMat4Size = 64; + const uint32_t maxMat4Count = supportedLimits.maxComputeWorkgroupStorageSize / kMat4Size; + + auto CheckPipelineWithWorkgroupStorage = [this](bool success, uint32_t vec4_count, + uint32_t mat4_count) { + std::ostringstream ss; + std::ostringstream body; + ss << "override a: u32;"; + ss << "override b: u32;"; + if (vec4_count > 0) { + ss << "var vec4_data: array, a>;"; + body << "_ = vec4_data;"; + } + if (mat4_count > 0) { + ss << "var mat4_data: array, b>;"; + body << "_ = mat4_data;"; + } + ss << "@compute @workgroup_size(1) fn main() { " << body.str() << " }"; + + wgpu::ComputePipelineDescriptor desc; + desc.compute.entryPoint = "main"; + desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str()); + + std::vector constants{{nullptr, "a", static_cast(vec4_count)}, + {nullptr, "b", static_cast(mat4_count)}}; + desc.compute.constants = constants.data(); + desc.compute.constantCount = constants.size(); + + if (success) { + device.CreateComputePipeline(&desc); + } else { + ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc)); + } + }; + + CheckPipelineWithWorkgroupStorage(false, maxVec4Count + 1, 0); + CheckPipelineWithWorkgroupStorage(false, 0, maxMat4Count + 1); +} + +DAWN_INSTANTIATE_TEST(WorkgroupSizeValidationTest, D3D12Backend(), MetalBackend(), VulkanBackend()); diff --git a/src/dawn/tests/unittests/native/ObjectContentHasherTests.cpp b/src/dawn/tests/unittests/native/ObjectContentHasherTests.cpp new file mode 100644 index 0000000000..bf926fe304 --- /dev/null +++ b/src/dawn/tests/unittests/native/ObjectContentHasherTests.cpp @@ -0,0 +1,81 @@ +// Copyright 2022 The Dawn Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include + +#include "dawn/native/ObjectContentHasher.h" +#include "dawn/tests/DawnNativeTest.h" + +namespace dawn::native { + +namespace { + +class ObjectContentHasherTests : public DawnNativeTest {}; + +#define EXPECT_IF_HASH_EQ(eq, a, b) \ + do { \ + ObjectContentHasher ra, rb; \ + ra.Record(a); \ + rb.Record(b); \ + EXPECT_EQ(eq, ra.GetContentHash() == rb.GetContentHash()); \ + } while (0) + +TEST(ObjectContentHasherTests, Pair) { + EXPECT_IF_HASH_EQ(true, (std::pair{"a", 1}), + (std::pair{"a", 1})); + EXPECT_IF_HASH_EQ(false, (std::pair{1, "a"}), + (std::pair{"a", 1})); + EXPECT_IF_HASH_EQ(false, (std::pair{"a", 1}), + (std::pair{"a", 2})); + EXPECT_IF_HASH_EQ(false, (std::pair{"a", 1}), + (std::pair{"b", 1})); +} + +TEST(ObjectContentHasherTests, Vector) { + EXPECT_IF_HASH_EQ(true, (std::vector{0, 1}), (std::vector{0, 1})); + EXPECT_IF_HASH_EQ(false, (std::vector{0, 1}), (std::vector{0, 1, 2})); + EXPECT_IF_HASH_EQ(false, (std::vector{0, 1}), (std::vector{1, 0})); + EXPECT_IF_HASH_EQ(false, (std::vector{0, 1}), (std::vector{})); + EXPECT_IF_HASH_EQ(false, (std::vector{0, 1}), (std::vector{0, 1})); +} + +TEST(ObjectContentHasherTests, Map) { + EXPECT_IF_HASH_EQ(true, (std::map{{"a", 1}, {"b", 2}}), + (std::map{{"b", 2}, {"a", 1}})); + EXPECT_IF_HASH_EQ(false, (std::map{{"a", 1}, {"b", 2}}), + (std::map{{"a", 2}, {"b", 1}})); + EXPECT_IF_HASH_EQ(false, (std::map{{"a", 1}, {"b", 2}}), + (std::map{{"a", 1}, {"b", 2}, {"c", 1}})); + EXPECT_IF_HASH_EQ(false, (std::map{{"a", 1}, {"b", 2}}), + (std::map{})); +} + +TEST(ObjectContentHasherTests, HashCombine) { + ObjectContentHasher ra, rb; + + ra.Record(std::vector{0, 1}); + ra.Record(std::map{{"a", 1}, {"b", 2}}); + + rb.Record(std::map{{"a", 1}, {"b", 2}}); + rb.Record(std::vector{0, 1}); + + EXPECT_NE(ra.GetContentHash(), rb.GetContentHash()); +} + +} // namespace + +} // namespace dawn::native diff --git a/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp b/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp index 17e25f96ee..1520bb8275 100644 --- a/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp +++ b/src/dawn/tests/unittests/validation/ShaderModuleValidationTests.cpp @@ -450,87 +450,6 @@ TEST_F(ShaderModuleValidationTest, MaximumInterStageShaderComponents) { } } -// Tests that we validate workgroup size limits. -TEST_F(ShaderModuleValidationTest, ComputeWorkgroupSizeLimits) { - auto CheckShaderWithWorkgroupSize = [this](bool success, uint32_t x, uint32_t y, uint32_t z) { - std::ostringstream ss; - ss << "@compute @workgroup_size(" << x << "," << y << "," << z << ") fn main() {}"; - - wgpu::ComputePipelineDescriptor desc; - desc.compute.entryPoint = "main"; - desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str()); - - if (success) { - device.CreateComputePipeline(&desc); - } else { - ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc)); - } - }; - - wgpu::Limits supportedLimits = GetSupportedLimits().limits; - - CheckShaderWithWorkgroupSize(true, 1, 1, 1); - CheckShaderWithWorkgroupSize(true, supportedLimits.maxComputeWorkgroupSizeX, 1, 1); - CheckShaderWithWorkgroupSize(true, 1, supportedLimits.maxComputeWorkgroupSizeY, 1); - CheckShaderWithWorkgroupSize(true, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ); - - CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX + 1, 1, 1); - CheckShaderWithWorkgroupSize(false, 1, supportedLimits.maxComputeWorkgroupSizeY + 1, 1); - CheckShaderWithWorkgroupSize(false, 1, 1, supportedLimits.maxComputeWorkgroupSizeZ + 1); - - // No individual dimension exceeds its limit, but the combined size should definitely exceed the - // total invocation limit. - CheckShaderWithWorkgroupSize(false, supportedLimits.maxComputeWorkgroupSizeX, - supportedLimits.maxComputeWorkgroupSizeY, - supportedLimits.maxComputeWorkgroupSizeZ); -} - -// Tests that we validate workgroup storage size limits. -TEST_F(ShaderModuleValidationTest, ComputeWorkgroupStorageSizeLimits) { - wgpu::Limits supportedLimits = GetSupportedLimits().limits; - - constexpr uint32_t kVec4Size = 16; - const uint32_t maxVec4Count = supportedLimits.maxComputeWorkgroupStorageSize / kVec4Size; - constexpr uint32_t kMat4Size = 64; - const uint32_t maxMat4Count = supportedLimits.maxComputeWorkgroupStorageSize / kMat4Size; - - auto CheckPipelineWithWorkgroupStorage = [this](bool success, uint32_t vec4_count, - uint32_t mat4_count) { - std::ostringstream ss; - std::ostringstream body; - if (vec4_count > 0) { - ss << "var vec4_data: array, " << vec4_count << ">;"; - body << "_ = vec4_data;"; - } - if (mat4_count > 0) { - ss << "var mat4_data: array, " << mat4_count << ">;"; - body << "_ = mat4_data;"; - } - ss << "@compute @workgroup_size(1) fn main() { " << body.str() << " }"; - - wgpu::ComputePipelineDescriptor desc; - desc.compute.entryPoint = "main"; - desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str()); - - if (success) { - device.CreateComputePipeline(&desc); - } else { - ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&desc)); - } - }; - - CheckPipelineWithWorkgroupStorage(true, 1, 1); - CheckPipelineWithWorkgroupStorage(true, maxVec4Count, 0); - CheckPipelineWithWorkgroupStorage(true, 0, maxMat4Count); - CheckPipelineWithWorkgroupStorage(true, maxVec4Count - 4, 1); - CheckPipelineWithWorkgroupStorage(true, 4, maxMat4Count - 1); - - CheckPipelineWithWorkgroupStorage(false, maxVec4Count + 1, 0); - CheckPipelineWithWorkgroupStorage(false, maxVec4Count - 3, 1); - CheckPipelineWithWorkgroupStorage(false, 0, maxMat4Count + 1); - CheckPipelineWithWorkgroupStorage(false, 4, maxMat4Count); -} - // Test that numeric ID must be unique TEST_F(ShaderModuleValidationTest, OverridableConstantsNumericIDConflicts) { ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, R"( diff --git a/src/dawn/tests/unittests/validation/UnsafeAPIValidationTests.cpp b/src/dawn/tests/unittests/validation/UnsafeAPIValidationTests.cpp index 10d7a0049e..171e037bee 100644 --- a/src/dawn/tests/unittests/validation/UnsafeAPIValidationTests.cpp +++ b/src/dawn/tests/unittests/validation/UnsafeAPIValidationTests.cpp @@ -37,46 +37,6 @@ class UnsafeAPIValidationTest : public ValidationTest { } }; -// Check that pipeline overridable constants are disallowed as part of unsafe APIs. -// TODO(dawn:1041) Remove when implementation for all backend is added -TEST_F(UnsafeAPIValidationTest, PipelineOverridableConstants) { - // Create the placeholder compute pipeline. - wgpu::ComputePipelineDescriptor pipelineDescBase; - pipelineDescBase.compute.entryPoint = "main"; - - // Control case: shader without overridable constant is allowed. - { - wgpu::ComputePipelineDescriptor pipelineDesc = pipelineDescBase; - pipelineDesc.compute.module = - utils::CreateShaderModule(device, "@compute @workgroup_size(1) fn main() {}"); - - device.CreateComputePipeline(&pipelineDesc); - } - - // Error case: shader with overridable constant with default value - { - ASSERT_DEVICE_ERROR(utils::CreateShaderModule(device, R"( -@id(1000) override c0: u32 = 1u; -@id(1000) override c1: u32; - -@compute @workgroup_size(1) fn main() { - _ = c0; - _ = c1; -})")); - } - - // Error case: pipeline stage with constant entry is disallowed - { - wgpu::ComputePipelineDescriptor pipelineDesc = pipelineDescBase; - pipelineDesc.compute.module = - utils::CreateShaderModule(device, "@compute @workgroup_size(1) fn main() {}"); - std::vector constants{{nullptr, "c", 1u}}; - pipelineDesc.compute.constants = constants.data(); - pipelineDesc.compute.constantCount = constants.size(); - ASSERT_DEVICE_ERROR(device.CreateComputePipeline(&pipelineDesc)); - } -} - class UnsafeQueryAPIValidationTest : public ValidationTest { protected: WGPUDevice CreateTestDevice(dawn::native::Adapter dawnAdapter) override { diff --git a/src/tint/inspector/inspector.cc b/src/tint/inspector/inspector.cc index 2f8d09ac02..4dd3c75d50 100644 --- a/src/tint/inspector/inspector.cc +++ b/src/tint/inspector/inspector.cc @@ -133,6 +133,110 @@ Inspector::Inspector(const Program* program) : program_(program) {} Inspector::~Inspector() = default; +EntryPoint Inspector::GetEntryPoint(const tint::ast::Function* func) { + EntryPoint entry_point; + TINT_ASSERT(Inspector, func != nullptr); + TINT_ASSERT(Inspector, func->IsEntryPoint()); + + auto* sem = program_->Sem().Get(func); + + entry_point.name = program_->Symbols().NameFor(func->symbol); + entry_point.remapped_name = program_->Symbols().NameFor(func->symbol); + + switch (func->PipelineStage()) { + case ast::PipelineStage::kCompute: { + entry_point.stage = PipelineStage::kCompute; + + auto wgsize = sem->WorkgroupSize(); + if (!wgsize[0].overridable_const && !wgsize[1].overridable_const && + !wgsize[2].overridable_const) { + entry_point.workgroup_size = {wgsize[0].value, wgsize[1].value, wgsize[2].value}; + } + break; + } + case ast::PipelineStage::kFragment: { + entry_point.stage = PipelineStage::kFragment; + break; + } + case ast::PipelineStage::kVertex: { + entry_point.stage = PipelineStage::kVertex; + break; + } + default: { + TINT_UNREACHABLE(Inspector, diagnostics_) + << "invalid pipeline stage for entry point '" << entry_point.name << "'"; + break; + } + } + + for (auto* param : sem->Parameters()) { + AddEntryPointInOutVariables(program_->Symbols().NameFor(param->Declaration()->symbol), + param->Type(), param->Declaration()->attributes, + entry_point.input_variables); + + entry_point.input_position_used |= ContainsBuiltin( + ast::BuiltinValue::kPosition, param->Type(), param->Declaration()->attributes); + entry_point.front_facing_used |= ContainsBuiltin( + ast::BuiltinValue::kFrontFacing, param->Type(), param->Declaration()->attributes); + entry_point.sample_index_used |= ContainsBuiltin( + ast::BuiltinValue::kSampleIndex, param->Type(), param->Declaration()->attributes); + entry_point.input_sample_mask_used |= ContainsBuiltin( + ast::BuiltinValue::kSampleMask, param->Type(), param->Declaration()->attributes); + entry_point.num_workgroups_used |= ContainsBuiltin( + ast::BuiltinValue::kNumWorkgroups, param->Type(), param->Declaration()->attributes); + } + + if (!sem->ReturnType()->Is()) { + AddEntryPointInOutVariables("", sem->ReturnType(), func->return_type_attributes, + entry_point.output_variables); + + entry_point.output_sample_mask_used = ContainsBuiltin( + ast::BuiltinValue::kSampleMask, sem->ReturnType(), func->return_type_attributes); + } + + for (auto* var : sem->TransitivelyReferencedGlobals()) { + auto* decl = var->Declaration(); + + auto name = program_->Symbols().NameFor(decl->symbol); + + auto* global = var->As(); + if (global && global->Declaration()->Is()) { + Override override; + override.name = name; + override.id = global->OverrideId(); + auto* type = var->Type(); + TINT_ASSERT(Inspector, type->is_scalar()); + if (type->is_bool_scalar_or_vector()) { + override.type = Override::Type::kBool; + } else if (type->is_float_scalar()) { + override.type = Override::Type::kFloat32; + } else if (type->is_signed_integer_scalar()) { + override.type = Override::Type::kInt32; + } else if (type->is_unsigned_integer_scalar()) { + override.type = Override::Type::kUint32; + } else { + TINT_UNREACHABLE(Inspector, diagnostics_); + } + + override.is_initialized = global->Declaration()->constructor; + override.is_id_specified = + ast::HasAttribute(global->Declaration()->attributes); + + entry_point.overrides.push_back(override); + } + } + + return entry_point; +} + +EntryPoint Inspector::GetEntryPoint(const std::string& entry_point_name) { + auto* func = FindEntryPointByName(entry_point_name); + if (!func) { + return EntryPoint(); + } + return GetEntryPoint(func); +} + std::vector Inspector::GetEntryPoints() { std::vector result; @@ -141,97 +245,7 @@ std::vector Inspector::GetEntryPoints() { continue; } - auto* sem = program_->Sem().Get(func); - - EntryPoint entry_point; - entry_point.name = program_->Symbols().NameFor(func->symbol); - entry_point.remapped_name = program_->Symbols().NameFor(func->symbol); - - switch (func->PipelineStage()) { - case ast::PipelineStage::kCompute: { - entry_point.stage = PipelineStage::kCompute; - - auto wgsize = sem->WorkgroupSize(); - if (!wgsize[0].overridable_const && !wgsize[1].overridable_const && - !wgsize[2].overridable_const) { - entry_point.workgroup_size = {wgsize[0].value, wgsize[1].value, - wgsize[2].value}; - } - break; - } - case ast::PipelineStage::kFragment: { - entry_point.stage = PipelineStage::kFragment; - break; - } - case ast::PipelineStage::kVertex: { - entry_point.stage = PipelineStage::kVertex; - break; - } - default: { - TINT_UNREACHABLE(Inspector, diagnostics_) - << "invalid pipeline stage for entry point '" << entry_point.name << "'"; - break; - } - } - - for (auto* param : sem->Parameters()) { - AddEntryPointInOutVariables(program_->Symbols().NameFor(param->Declaration()->symbol), - param->Type(), param->Declaration()->attributes, - entry_point.input_variables); - - entry_point.input_position_used |= ContainsBuiltin( - ast::BuiltinValue::kPosition, param->Type(), param->Declaration()->attributes); - entry_point.front_facing_used |= ContainsBuiltin( - ast::BuiltinValue::kFrontFacing, param->Type(), param->Declaration()->attributes); - entry_point.sample_index_used |= ContainsBuiltin( - ast::BuiltinValue::kSampleIndex, param->Type(), param->Declaration()->attributes); - entry_point.input_sample_mask_used |= ContainsBuiltin( - ast::BuiltinValue::kSampleMask, param->Type(), param->Declaration()->attributes); - entry_point.num_workgroups_used |= ContainsBuiltin( - ast::BuiltinValue::kNumWorkgroups, param->Type(), param->Declaration()->attributes); - } - - if (!sem->ReturnType()->Is()) { - AddEntryPointInOutVariables("", sem->ReturnType(), func->return_type_attributes, - entry_point.output_variables); - - entry_point.output_sample_mask_used = ContainsBuiltin( - ast::BuiltinValue::kSampleMask, sem->ReturnType(), func->return_type_attributes); - } - - for (auto* var : sem->TransitivelyReferencedGlobals()) { - auto* decl = var->Declaration(); - - auto name = program_->Symbols().NameFor(decl->symbol); - - auto* global = var->As(); - if (global && global->Declaration()->Is()) { - Override override; - override.name = name; - override.id = global->OverrideId(); - auto* type = var->Type(); - TINT_ASSERT(Inspector, type->is_scalar()); - if (type->is_bool_scalar_or_vector()) { - override.type = Override::Type::kBool; - } else if (type->is_float_scalar()) { - override.type = Override::Type::kFloat32; - } else if (type->is_signed_integer_scalar()) { - override.type = Override::Type::kInt32; - } else if (type->is_unsigned_integer_scalar()) { - override.type = Override::Type::kUint32; - } else { - TINT_UNREACHABLE(Inspector, diagnostics_); - } - - override.is_initialized = global->Declaration()->constructor; - override.is_id_specified = - ast::HasAttribute(global->Declaration()->attributes); - - entry_point.overrides.push_back(override); - } - } - - result.push_back(std::move(entry_point)); + result.push_back(GetEntryPoint(func)); } return result; diff --git a/src/tint/inspector/inspector.h b/src/tint/inspector/inspector.h index f3fe27008a..684852ce28 100644 --- a/src/tint/inspector/inspector.h +++ b/src/tint/inspector/inspector.h @@ -55,6 +55,10 @@ class Inspector { /// @returns vector of entry point information std::vector GetEntryPoints(); + /// @param entry_point name of the entry point to get information about + /// @returns the entry point information + EntryPoint GetEntryPoint(const std::string& entry_point); + /// @returns map of override identifier to initial value std::map GetOverrideDefaultValues(); @@ -230,6 +234,10 @@ class Inspector { /// whenever a set of expressions are resolved to globals. template void GetOriginatingResources(std::array exprs, F&& cb); + + /// @param func the function of the entry point. Must be non-nullptr and true for IsEntryPoint() + /// @returns the entry point information + EntryPoint GetEntryPoint(const tint::ast::Function* func); }; } // namespace tint::inspector diff --git a/src/tint/transform/substitute_override.h b/src/tint/transform/substitute_override.h index 9ea315d92d..940e11d2fa 100644 --- a/src/tint/transform/substitute_override.h +++ b/src/tint/transform/substitute_override.h @@ -20,6 +20,7 @@ #include "tint/override_id.h" +#include "src/tint/reflection.h" #include "src/tint/transform/transform.h" namespace tint::transform { @@ -63,6 +64,9 @@ class SubstituteOverride final : public Castable /// The value is always a double coming into the transform and will be /// converted to the correct type through and initializer. std::unordered_map map; + + /// Reflect the fields of this class so that it can be used by tint::ForeachField() + TINT_REFLECT(map); }; /// Constructor