From 9a6acc419eb550a46bf77c4174c521ddfc5ad98d Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Wed, 27 Jul 2022 20:50:40 +0000 Subject: [PATCH] tint: Add tint::OverrideId This is a public API definition of a program-unique override identifier. Bug: tint:1155 Change-Id: I6e55d43208e72a7a316557a89e2169d1b952f9bf Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/97006 Kokoro: Kokoro Reviewed-by: Dan Sinclair Commit-Queue: Ben Clayton --- include/tint/override_id.h | 60 +++++ src/dawn/native/Pipeline.cpp | 8 +- src/dawn/native/ShaderModule.cpp | 45 ++-- src/dawn/native/ShaderModule.h | 22 +- src/dawn/native/d3d12/ShaderModuleD3D12.cpp | 17 +- src/dawn/native/metal/UtilsMetal.mm | 22 +- src/dawn/native/vulkan/ComputePipelineVk.cpp | 2 +- src/dawn/native/vulkan/RenderPipelineVk.cpp | 2 +- src/dawn/native/vulkan/UtilsVulkan.cpp | 26 +- src/dawn/native/vulkan/UtilsVulkan.h | 4 +- src/tint/BUILD.gn | 2 +- src/tint/CMakeLists.txt | 2 +- src/tint/fuzzers/tint_common_fuzzer.cc | 4 +- src/tint/inspector/entry_point.h | 22 +- src/tint/inspector/inspector.cc | 54 ++-- src/tint/inspector/inspector.h | 8 +- src/tint/inspector/inspector_test.cc | 239 +++++++++--------- src/tint/inspector/test_inspector_builder.h | 25 -- src/tint/program_builder.h | 17 +- ...able_constant_test.cc => override_test.cc} | 42 +-- src/tint/resolver/resolver.cc | 51 ++-- src/tint/resolver/resolver.h | 5 +- src/tint/resolver/validator.cc | 26 +- src/tint/resolver/validator.h | 8 +- src/tint/resolver/variable_validation_test.cc | 31 +++ src/tint/sem/variable.h | 8 +- src/tint/transform/single_entry_point.cc | 6 +- src/tint/writer/glsl/generator_impl.cc | 12 +- src/tint/writer/hlsl/generator_impl.cc | 12 +- .../generator_impl_module_constant_test.cc | 15 +- src/tint/writer/msl/generator_impl.cc | 2 +- src/tint/writer/spirv/builder.cc | 6 +- 32 files changed, 446 insertions(+), 359 deletions(-) create mode 100644 include/tint/override_id.h rename src/tint/resolver/{pipeline_overridable_constant_test.cc => override_test.cc} (71%) diff --git a/include/tint/override_id.h b/include/tint/override_id.h new file mode 100644 index 0000000000..957673de19 --- /dev/null +++ b/include/tint/override_id.h @@ -0,0 +1,60 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TINT_OVERRIDE_ID_H_ +#define SRC_TINT_OVERRIDE_ID_H_ + +#include + +namespace tint { + +/// OverrideId is a numerical identifier for an override variable, unique per program. +struct OverrideId { + uint16_t value = 0; +}; + +/// Equality operator for OverrideId +/// @param lhs the OverrideId on the left of the '=' operator +/// @param rhs the OverrideId on the right of the '=' operator +/// @returns true if `lhs` is equal to `rhs` +inline bool operator==(OverrideId lhs, OverrideId rhs) { + return lhs.value == rhs.value; +} + +/// Less-than operator for OverrideId +/// @param lhs the OverrideId on the left of the '<' operator +/// @param rhs the OverrideId on the right of the '<' operator +/// @returns true if `lhs` comes before `rhs` +inline bool operator<(OverrideId lhs, OverrideId rhs) { + return lhs.value < rhs.value; +} + +} // namespace tint + +namespace std { + +/// Custom std::hash specialization for tint::OverrideId. +template <> +class hash { + public: + /// @param id the override identifier + /// @return the hash of the override identifier + inline std::size_t operator()(tint::OverrideId id) const { + return std::hash()(id.value); + } +}; + +} // namespace std + +#endif // SRC_TINT_OVERRIDE_ID_H_ diff --git a/src/dawn/native/Pipeline.cpp b/src/dawn/native/Pipeline.cpp index 6bee1eb5a5..513bd2ac4c 100644 --- a/src/dawn/native/Pipeline.cpp +++ b/src/dawn/native/Pipeline.cpp @@ -66,16 +66,16 @@ MaybeError ValidateProgrammableStage(DeviceBase* device, // Validate if overridable constants exist in shader module // pipelineBase is not yet constructed at this moment so iterate constants from descriptor - size_t numUninitializedConstants = metadata.uninitializedOverridableConstants.size(); + size_t numUninitializedConstants = metadata.uninitializedOverrides.size(); // Keep an initialized constants sets to handle duplicate initialization cases std::unordered_set stageInitializedConstantIdentifiers; for (uint32_t i = 0; i < constantCount; i++) { - DAWN_INVALID_IF(metadata.overridableConstants.count(constants[i].key) == 0, + DAWN_INVALID_IF(metadata.overrides.count(constants[i].key) == 0, "Pipeline overridable constant \"%s\" not found in %s.", constants[i].key, module); if (stageInitializedConstantIdentifiers.count(constants[i].key) == 0) { - if (metadata.uninitializedOverridableConstants.count(constants[i].key) > 0) { + if (metadata.uninitializedOverrides.count(constants[i].key) > 0) { numUninitializedConstants--; } stageInitializedConstantIdentifiers.insert(constants[i].key); @@ -91,7 +91,7 @@ MaybeError ValidateProgrammableStage(DeviceBase* device, if (DAWN_UNLIKELY(numUninitializedConstants > 0)) { std::string uninitializedConstantsArray; bool isFirst = true; - for (std::string identifier : metadata.uninitializedOverridableConstants) { + for (std::string identifier : metadata.uninitializedOverrides) { if (stageInitializedConstantIdentifiers.count(identifier) > 0) { continue; } diff --git a/src/dawn/native/ShaderModule.cpp b/src/dawn/native/ShaderModule.cpp index c851bdbdd3..d06665045c 100644 --- a/src/dawn/native/ShaderModule.cpp +++ b/src/dawn/native/ShaderModule.cpp @@ -358,17 +358,16 @@ ResultOrError TintInterpolationSamplingToInterpolationSam UNREACHABLE(); } -EntryPointMetadata::OverridableConstant::Type FromTintOverridableConstantType( - tint::inspector::OverridableConstant::Type type) { +EntryPointMetadata::Override::Type FromTintOverrideType(tint::inspector::Override::Type type) { switch (type) { - case tint::inspector::OverridableConstant::Type::kBool: - return EntryPointMetadata::OverridableConstant::Type::Boolean; - case tint::inspector::OverridableConstant::Type::kFloat32: - return EntryPointMetadata::OverridableConstant::Type::Float32; - case tint::inspector::OverridableConstant::Type::kInt32: - return EntryPointMetadata::OverridableConstant::Type::Int32; - case tint::inspector::OverridableConstant::Type::kUint32: - return EntryPointMetadata::OverridableConstant::Type::Uint32; + case tint::inspector::Override::Type::kBool: + return EntryPointMetadata::Override::Type::Boolean; + case tint::inspector::Override::Type::kFloat32: + return EntryPointMetadata::Override::Type::Float32; + case tint::inspector::Override::Type::kInt32: + return EntryPointMetadata::Override::Type::Int32; + case tint::inspector::Override::Type::kUint32: + return EntryPointMetadata::Override::Type::Uint32; } UNREACHABLE(); } @@ -610,17 +609,17 @@ ResultOrError> ReflectEntryPointUsingTint( return invalid; \ })() - if (!entryPoint.overridable_constants.empty()) { + if (!entryPoint.overrides.empty()) { DAWN_INVALID_IF(device->IsToggleEnabled(Toggle::DisallowUnsafeAPIs), "Pipeline overridable constants are disallowed because they " "are partially implemented."); - const auto& name2Id = inspector->GetConstantNameToIdMap(); - const auto& id2Scalar = inspector->GetConstantIDs(); + const auto& name2Id = inspector->GetNamedOverrideIds(); + const auto& id2Scalar = inspector->GetOverrideDefaultValues(); - for (auto& c : entryPoint.overridable_constants) { - uint32_t id = name2Id.at(c.name); - OverridableConstantScalar defaultValue; + for (auto& c : entryPoint.overrides) { + auto id = name2Id.at(c.name); + OverrideScalar defaultValue; if (c.is_initialized) { // if it is initialized, the scalar must exist const auto& scalar = id2Scalar.at(id); @@ -636,21 +635,19 @@ ResultOrError> ReflectEntryPointUsingTint( UNREACHABLE(); } } - EntryPointMetadata::OverridableConstant constant = { - id, FromTintOverridableConstantType(c.type), c.is_initialized, defaultValue}; + EntryPointMetadata::Override override = {id.value, FromTintOverrideType(c.type), + c.is_initialized, defaultValue}; - std::string identifier = - c.is_numeric_id_specified ? std::to_string(constant.id) : c.name; - metadata->overridableConstants[identifier] = constant; + std::string identifier = c.is_id_specified ? std::to_string(override.id) : c.name; + metadata->overrides[identifier] = override; if (!c.is_initialized) { auto [_, inserted] = - metadata->uninitializedOverridableConstants.emplace(std::move(identifier)); + metadata->uninitializedOverrides.emplace(std::move(identifier)); // The insertion should have taken place ASSERT(inserted); } else { - auto [_, inserted] = - metadata->initializedOverridableConstants.emplace(std::move(identifier)); + auto [_, inserted] = metadata->initializedOverrides.emplace(std::move(identifier)); // The insertion should have taken place ASSERT(inserted); } diff --git a/src/dawn/native/ShaderModule.h b/src/dawn/native/ShaderModule.h index a491a2141b..d1e450caec 100644 --- a/src/dawn/native/ShaderModule.h +++ b/src/dawn/native/ShaderModule.h @@ -155,8 +155,8 @@ struct ShaderBindingInfo { using BindingGroupInfoMap = std::map; using BindingInfoArray = ityp::array; -// The WebGPU overridable constants only support these scalar types -union OverridableConstantScalar { +// The WebGPU override variables only support these scalar types +union OverrideScalar { // Use int32_t for boolean to initialize the full 32bit int32_t b; float f32; @@ -216,9 +216,9 @@ struct EntryPointMetadata { // The shader stage for this binding. SingleShaderStage stage; - struct OverridableConstant { + struct Override { uint32_t id; - // Match tint::inspector::OverridableConstant::Type + // Match tint::inspector::Override::Type // Bool is defined as a macro on linux X11 and cannot compile enum class Type { Boolean, Float32, Uint32, Int32 } type; @@ -230,23 +230,23 @@ struct EntryPointMetadata { // Store the default initialized value in shader // This is used by metal backend as the function_constant does not have dafault values // Initialized when isInitialized == true - OverridableConstantScalar defaultValue; + OverrideScalar defaultValue; }; - using OverridableConstantsMap = std::unordered_map; + using OverridesMap = std::unordered_map; - // Map identifier to overridable constant + // Map identifier to override variable // Identifier is unique: either the variable name or the numeric ID if specified - OverridableConstantsMap overridableConstants; + OverridesMap overrides; - // Overridable constants that are not initialized in shaders + // Override variables that are not initialized in shaders // They need value initialization from pipeline stage or it is a validation error - std::unordered_set uninitializedOverridableConstants; + std::unordered_set uninitializedOverrides; // Store constants with shader initialized values as well // This is used by metal backend to set values with default initializers that are not // overridden - std::unordered_set initializedOverridableConstants; + std::unordered_set initializedOverrides; bool usesNumWorkgroups = false; // Used at render pipeline validation. diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp index 7a23be5e7f..320449a1fa 100644 --- a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp +++ b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp @@ -111,17 +111,17 @@ std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) { return out.str(); } -std::string GetHLSLValueString(EntryPointMetadata::OverridableConstant::Type dawnType, - const OverridableConstantScalar* entry, +std::string GetHLSLValueString(EntryPointMetadata::Override::Type dawnType, + const OverrideScalar* entry, double value = 0) { switch (dawnType) { - case EntryPointMetadata::OverridableConstant::Type::Boolean: + case EntryPointMetadata::Override::Type::Boolean: return std::to_string(entry ? entry->b : static_cast(value)); - case EntryPointMetadata::OverridableConstant::Type::Float32: + case EntryPointMetadata::Override::Type::Float32: return FloatToStringWithPrecision(entry ? entry->f32 : static_cast(value)); - case EntryPointMetadata::OverridableConstant::Type::Int32: + case EntryPointMetadata::Override::Type::Int32: return std::to_string(entry ? entry->i32 : static_cast(value)); - case EntryPointMetadata::OverridableConstant::Type::Uint32: + case EntryPointMetadata::Override::Type::Uint32: return std::to_string(entry ? entry->u32 : static_cast(value)); default: UNREACHABLE(); @@ -133,7 +133,7 @@ constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_"; void GetOverridableConstantsDefines( std::vector>* defineStrings, const PipelineConstantEntries* pipelineConstantEntries, - const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants) { + const EntryPointMetadata::OverridesMap* shaderEntryPointConstants) { std::unordered_set overriddenConstants; // Set pipeline overridden values @@ -305,8 +305,7 @@ struct ShaderCompilationRequest { GetOverridableConstantsDefines( &request.defineStrings, &programmableStage.constants, - &programmableStage.module->GetEntryPoint(programmableStage.entryPoint) - .overridableConstants); + &programmableStage.module->GetEntryPoint(programmableStage.entryPoint).overrides); return std::move(request); } diff --git a/src/dawn/native/metal/UtilsMetal.mm b/src/dawn/native/metal/UtilsMetal.mm index 22d3681ab5..df122f3e3e 100644 --- a/src/dawn/native/metal/UtilsMetal.mm +++ b/src/dawn/native/metal/UtilsMetal.mm @@ -332,7 +332,7 @@ MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage, ShaderModule* shaderModule = ToBackend(programmableStage.module.Get()); const char* shaderEntryPoint = programmableStage.entryPoint.c_str(); const auto& entryPointMetadata = programmableStage.module->GetEntryPoint(shaderEntryPoint); - if (entryPointMetadata.overridableConstants.size() == 0) { + if (entryPointMetadata.overrides.size() == 0) { DAWN_TRY(shaderModule->CreateFunction(shaderEntryPoint, singleShaderStage, pipelineLayout, functionData, nil, sampleMask, renderPipeline)); return {}; @@ -345,29 +345,29 @@ MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage, std::unordered_set overriddenConstants; - auto switchType = [&](EntryPointMetadata::OverridableConstant::Type dawnType, - MTLDataType* type, OverridableConstantScalar* entry, + auto switchType = [&](EntryPointMetadata::Override::Type dawnType, + MTLDataType* type, OverrideScalar* entry, double value = 0) { switch (dawnType) { - case EntryPointMetadata::OverridableConstant::Type::Boolean: + case EntryPointMetadata::Override::Type::Boolean: *type = MTLDataTypeBool; if (entry) { entry->b = static_cast(value); } break; - case EntryPointMetadata::OverridableConstant::Type::Float32: + case EntryPointMetadata::Override::Type::Float32: *type = MTLDataTypeFloat; if (entry) { entry->f32 = static_cast(value); } break; - case EntryPointMetadata::OverridableConstant::Type::Int32: + case EntryPointMetadata::Override::Type::Int32: *type = MTLDataTypeInt; if (entry) { entry->i32 = static_cast(value); } break; - case EntryPointMetadata::OverridableConstant::Type::Uint32: + case EntryPointMetadata::Override::Type::Uint32: *type = MTLDataTypeUInt; if (entry) { entry->u32 = static_cast(value); @@ -382,10 +382,10 @@ MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage, overriddenConstants.insert(name); // This is already validated so `name` must exist - const auto& moduleConstant = entryPointMetadata.overridableConstants.at(name); + const auto& moduleConstant = entryPointMetadata.overrides.at(name); MTLDataType type; - OverridableConstantScalar entry{}; + OverrideScalar entry{}; switchType(moduleConstant.type, &type, &entry, value); @@ -394,14 +394,14 @@ MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage, // Set shader initialized default values because MSL function_constant // has no default value - for (const std::string& name : entryPointMetadata.initializedOverridableConstants) { + for (const std::string& name : entryPointMetadata.initializedOverrides) { if (overriddenConstants.count(name) != 0) { // This constant already has overridden value continue; } // Must exist because it is validated - const auto& moduleConstant = entryPointMetadata.overridableConstants.at(name); + const auto& moduleConstant = entryPointMetadata.overrides.at(name); ASSERT(moduleConstant.isInitialized); MTLDataType type; diff --git a/src/dawn/native/vulkan/ComputePipelineVk.cpp b/src/dawn/native/vulkan/ComputePipelineVk.cpp index b12006688a..25bb888a06 100644 --- a/src/dawn/native/vulkan/ComputePipelineVk.cpp +++ b/src/dawn/native/vulkan/ComputePipelineVk.cpp @@ -66,7 +66,7 @@ MaybeError ComputePipeline::Initialize() { createInfo.stage.module = moduleAndSpirv.module; createInfo.stage.pName = computeStage.entryPoint.c_str(); - std::vector specializationDataEntries; + std::vector specializationDataEntries; std::vector specializationMapEntries; VkSpecializationInfo specializationInfo{}; createInfo.stage.pSpecializationInfo = GetVkSpecializationInfo( diff --git a/src/dawn/native/vulkan/RenderPipelineVk.cpp b/src/dawn/native/vulkan/RenderPipelineVk.cpp index 39a24b7627..6f07ee39ca 100644 --- a/src/dawn/native/vulkan/RenderPipelineVk.cpp +++ b/src/dawn/native/vulkan/RenderPipelineVk.cpp @@ -341,7 +341,7 @@ MaybeError RenderPipeline::Initialize() { // There are at most 2 shader stages in render pipeline, i.e. vertex and fragment std::array shaderStages; - std::array, 2> specializationDataEntriesPerStages; + std::array, 2> specializationDataEntriesPerStages; std::array, 2> specializationMapEntriesPerStages; std::array specializationInfoPerStages; uint32_t stageCount = 0; diff --git a/src/dawn/native/vulkan/UtilsVulkan.cpp b/src/dawn/native/vulkan/UtilsVulkan.cpp index 2595112faa..ca8b4052ca 100644 --- a/src/dawn/native/vulkan/UtilsVulkan.cpp +++ b/src/dawn/native/vulkan/UtilsVulkan.cpp @@ -261,7 +261,7 @@ std::string GetDeviceDebugPrefixFromDebugName(const char* debugName) { VkSpecializationInfo* GetVkSpecializationInfo( const ProgrammableStage& programmableStage, VkSpecializationInfo* specializationInfo, - std::vector* specializationDataEntries, + std::vector* specializationDataEntries, std::vector* specializationMapEntries) { ASSERT(specializationInfo); ASSERT(specializationDataEntries); @@ -279,26 +279,25 @@ VkSpecializationInfo* GetVkSpecializationInfo( double value = pipelineConstant.second; // This is already validated so `identifier` must exist - const auto& moduleConstant = entryPointMetaData.overridableConstants.at(identifier); + const auto& moduleConstant = entryPointMetaData.overrides.at(identifier); - specializationMapEntries->push_back( - VkSpecializationMapEntry{moduleConstant.id, - static_cast(specializationDataEntries->size() * - sizeof(OverridableConstantScalar)), - sizeof(OverridableConstantScalar)}); + specializationMapEntries->push_back(VkSpecializationMapEntry{ + moduleConstant.id, + static_cast(specializationDataEntries->size() * sizeof(OverrideScalar)), + sizeof(OverrideScalar)}); - OverridableConstantScalar entry{}; + OverrideScalar entry{}; switch (moduleConstant.type) { - case EntryPointMetadata::OverridableConstant::Type::Boolean: + case EntryPointMetadata::Override::Type::Boolean: entry.b = static_cast(value); break; - case EntryPointMetadata::OverridableConstant::Type::Float32: + case EntryPointMetadata::Override::Type::Float32: entry.f32 = static_cast(value); break; - case EntryPointMetadata::OverridableConstant::Type::Int32: + case EntryPointMetadata::Override::Type::Int32: entry.i32 = static_cast(value); break; - case EntryPointMetadata::OverridableConstant::Type::Uint32: + case EntryPointMetadata::Override::Type::Uint32: entry.u32 = static_cast(value); break; default: @@ -309,8 +308,7 @@ VkSpecializationInfo* GetVkSpecializationInfo( specializationInfo->mapEntryCount = static_cast(specializationMapEntries->size()); specializationInfo->pMapEntries = specializationMapEntries->data(); - specializationInfo->dataSize = - specializationDataEntries->size() * sizeof(OverridableConstantScalar); + specializationInfo->dataSize = specializationDataEntries->size() * sizeof(OverrideScalar); specializationInfo->pData = specializationDataEntries->data(); return specializationInfo; diff --git a/src/dawn/native/vulkan/UtilsVulkan.h b/src/dawn/native/vulkan/UtilsVulkan.h index 7c63b1dfcd..3e2748e173 100644 --- a/src/dawn/native/vulkan/UtilsVulkan.h +++ b/src/dawn/native/vulkan/UtilsVulkan.h @@ -24,7 +24,7 @@ namespace dawn::native { struct ProgrammableStage; -union OverridableConstantScalar; +union OverrideScalar; } // namespace dawn::native namespace dawn::native::vulkan { @@ -150,7 +150,7 @@ std::string GetDeviceDebugPrefixFromDebugName(const char* debugName); VkSpecializationInfo* GetVkSpecializationInfo( const ProgrammableStage& programmableStage, VkSpecializationInfo* specializationInfo, - std::vector* specializationDataEntries, + std::vector* specializationDataEntries, std::vector* specializationMapEntries); } // namespace dawn::native::vulkan diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 86410edd73..1b3342e979 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -1110,7 +1110,7 @@ if (tint_build_unittests) { "resolver/is_host_shareable_test.cc", "resolver/is_storeable_test.cc", "resolver/materialize_test.cc", - "resolver/pipeline_overridable_constant_test.cc", + "resolver/override_test.cc", "resolver/ptr_ref_test.cc", "resolver/ptr_ref_validation_test.cc", "resolver/resolver_behavior_test.cc", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index 25c2723a85..6308c5f9ef 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -794,7 +794,7 @@ if(TINT_BUILD_TESTS) resolver/is_host_shareable_test.cc resolver/is_storeable_test.cc resolver/materialize_test.cc - resolver/pipeline_overridable_constant_test.cc + resolver/override_test.cc resolver/ptr_ref_test.cc resolver/ptr_ref_validation_test.cc resolver/resolver_behavior_test.cc diff --git a/src/tint/fuzzers/tint_common_fuzzer.cc b/src/tint/fuzzers/tint_common_fuzzer.cc index ac57071ad0..f547d559c7 100644 --- a/src/tint/fuzzers/tint_common_fuzzer.cc +++ b/src/tint/fuzzers/tint_common_fuzzer.cc @@ -268,10 +268,10 @@ void CommonFuzzer::RunInspector(Program* program) { auto entry_points = inspector.GetEntryPoints(); CHECK_INSPECTOR(program, inspector); - auto constant_ids = inspector.GetConstantIDs(); + auto override_ids = inspector.GetOverrideDefaultValues(); CHECK_INSPECTOR(program, inspector); - auto constant_name_to_id = inspector.GetConstantNameToIdMap(); + auto override_name_to_id = inspector.GetNamedOverrideIds(); CHECK_INSPECTOR(program, inspector); for (auto& ep : entry_points) { diff --git a/src/tint/inspector/entry_point.h b/src/tint/inspector/entry_point.h index 99683b07e5..5d119d4a40 100644 --- a/src/tint/inspector/entry_point.h +++ b/src/tint/inspector/entry_point.h @@ -20,6 +20,8 @@ #include #include +#include "tint/override_id.h" + #include "src/tint/ast/interpolate_attribute.h" #include "src/tint/ast/pipeline_stage.h" @@ -93,14 +95,13 @@ InterpolationType ASTToInspectorInterpolationType(ast::InterpolationType ast_typ /// @returns the publicly visible equivalent InterpolationSampling ASTToInspectorInterpolationSampling(ast::InterpolationSampling sampling); -/// Reflection data about a pipeline overridable constant referenced by an entry -/// point -struct OverridableConstant { - /// Name of the constant +/// Reflection data about an override variable referenced by an entry point +struct Override { + /// Name of the override std::string name; - /// ID of the constant - uint16_t numeric_id; + /// ID of the override + OverrideId id; /// Type of the scalar enum class Type { @@ -113,12 +114,11 @@ struct OverridableConstant { /// Type of the scalar Type type; - /// Does this pipeline overridable constant have an initializer? + /// Does this override have an initializer? bool is_initialized = false; - /// Does this pipeline overridable constant have a numeric ID specified - /// explicitly? - bool is_numeric_id_specified = false; + /// Does this override have a numeric ID specified explicitly? + bool is_id_specified = false; }; /// The pipeline stage @@ -159,7 +159,7 @@ struct EntryPoint { /// List of the output variable accessed via this entry point. std::vector output_variables; /// List of the pipeline overridable constants accessed via this entry point. - std::vector overridable_constants; + std::vector overrides; /// Does the entry point use the sample_mask builtin as an input builtin /// variable. bool input_sample_mask_used = false; diff --git a/src/tint/inspector/inspector.cc b/src/tint/inspector/inspector.cc index 931cd80d1e..ae56de506d 100644 --- a/src/tint/inspector/inspector.cc +++ b/src/tint/inspector/inspector.cc @@ -206,28 +206,28 @@ std::vector Inspector::GetEntryPoints() { auto* global = var->As(); if (global && global->Declaration()->Is()) { - OverridableConstant overridable_constant; - overridable_constant.name = name; - overridable_constant.numeric_id = global->ConstantId(); + Override override; + override.name = name; + override.id = global->OverrideId(); auto* type = var->Type(); TINT_ASSERT(Inspector, type->is_scalar()); if (type->is_bool_scalar_or_vector()) { - overridable_constant.type = OverridableConstant::Type::kBool; + override.type = Override::Type::kBool; } else if (type->is_float_scalar()) { - overridable_constant.type = OverridableConstant::Type::kFloat32; + override.type = Override::Type::kFloat32; } else if (type->is_signed_integer_scalar()) { - overridable_constant.type = OverridableConstant::Type::kInt32; + override.type = Override::Type::kInt32; } else if (type->is_unsigned_integer_scalar()) { - overridable_constant.type = OverridableConstant::Type::kUint32; + override.type = Override::Type::kUint32; } else { TINT_UNREACHABLE(Inspector, diagnostics_); } - overridable_constant.is_initialized = global->Declaration()->constructor; - overridable_constant.is_numeric_id_specified = + override.is_initialized = global->Declaration()->constructor; + override.is_id_specified = ast::HasAttribute(global->Declaration()->attributes); - entry_point.overridable_constants.push_back(overridable_constant); + entry_point.overrides.push_back(override); } } @@ -237,37 +237,37 @@ std::vector Inspector::GetEntryPoints() { return result; } -std::map Inspector::GetConstantIDs() { - std::map result; +std::map Inspector::GetOverrideDefaultValues() { + std::map result; for (auto* var : program_->AST().GlobalVariables()) { auto* global = program_->Sem().Get(var); if (!global || !global->Declaration()->Is()) { continue; } - // If there are conflicting defintions for a constant id, that is invalid + // If there are conflicting defintions for an override id, that is invalid // WGSL, so the resolver should catch it. Thus here the inspector just - // assumes all definitions of the constant id are the same, so only needs - // to find the first reference to constant id. - uint32_t constant_id = global->ConstantId(); - if (result.find(constant_id) != result.end()) { + // assumes all definitions of the override id are the same, so only needs + // to find the first reference to override id. + OverrideId override_id = global->OverrideId(); + if (result.find(override_id) != result.end()) { continue; } if (!var->constructor) { - result[constant_id] = Scalar(); + result[override_id] = Scalar(); continue; } auto* literal = var->constructor->As(); if (!literal) { // This is invalid WGSL, but handling gracefully. - result[constant_id] = Scalar(); + result[override_id] = Scalar(); continue; } if (auto* l = literal->As()) { - result[constant_id] = Scalar(l->value); + result[override_id] = Scalar(l->value); continue; } @@ -275,32 +275,32 @@ std::map Inspector::GetConstantIDs() { switch (l->suffix) { case ast::IntLiteralExpression::Suffix::kNone: case ast::IntLiteralExpression::Suffix::kI: - result[constant_id] = Scalar(static_cast(l->value)); + result[override_id] = Scalar(static_cast(l->value)); continue; case ast::IntLiteralExpression::Suffix::kU: - result[constant_id] = Scalar(static_cast(l->value)); + result[override_id] = Scalar(static_cast(l->value)); continue; } } if (auto* l = literal->As()) { - result[constant_id] = Scalar(static_cast(l->value)); + result[override_id] = Scalar(static_cast(l->value)); continue; } - result[constant_id] = Scalar(); + result[override_id] = Scalar(); } return result; } -std::map Inspector::GetConstantNameToIdMap() { - std::map result; +std::map Inspector::GetNamedOverrideIds() { + std::map result; for (auto* var : program_->AST().GlobalVariables()) { auto* global = program_->Sem().Get(var); if (global && global->Declaration()->Is()) { auto name = program_->Symbols().NameFor(var->symbol); - result[name] = global->ConstantId(); + result[name] = global->OverrideId(); } } return result; diff --git a/src/tint/inspector/inspector.h b/src/tint/inspector/inspector.h index bc31fda6ee..2690581d3b 100644 --- a/src/tint/inspector/inspector.h +++ b/src/tint/inspector/inspector.h @@ -23,6 +23,8 @@ #include #include +#include "tint/override_id.h" + #include "src/tint/inspector/entry_point.h" #include "src/tint/inspector/resource_binding.h" #include "src/tint/inspector/scalar.h" @@ -53,11 +55,11 @@ class Inspector { /// @returns vector of entry point information std::vector GetEntryPoints(); - /// @returns map of const_id to initial value - std::map GetConstantIDs(); + /// @returns map of override identifier to initial value + std::map GetOverrideDefaultValues(); /// @returns map of module-constant name to pipeline constant ID - std::map GetConstantNameToIdMap(); + std::map GetNamedOverrideIds(); /// @param entry_point name of the entry point to get information about. /// @returns the total size of shared storage required by an entry point, diff --git a/src/tint/inspector/inspector_test.cc b/src/tint/inspector/inspector_test.cc index 31cd803263..2bb862d19c 100644 --- a/src/tint/inspector/inspector_test.cc +++ b/src/tint/inspector/inspector_test.cc @@ -60,7 +60,7 @@ struct InspectorGetEntryPointInterpolateTestParams { class InspectorGetEntryPointInterpolateTest : public InspectorBuilder, public testing::TestWithParam {}; -class InspectorGetConstantIDsTest : public InspectorBuilder, public testing::Test {}; +class InspectorGetOverrideDefaultValuesTest : public InspectorBuilder, public testing::Test {}; class InspectorGetConstantNameToIdMapTest : public InspectorBuilder, public testing::Test {}; class InspectorGetStorageSizeTest : public InspectorBuilder, public testing::Test {}; class InspectorGetResourceBindingsTest : public InspectorBuilder, public testing::Test {}; @@ -605,8 +605,8 @@ TEST_F(InspectorGetEntryPointTest, MixInOutVariablesAndStruct) { EXPECT_EQ(ComponentType::kUInt, result[0].output_variables[1].component_type); } -TEST_F(InspectorGetEntryPointTest, OverridableConstantUnreferenced) { - AddOverridableConstantWithoutID("foo", ty.f32(), nullptr); +TEST_F(InspectorGetEntryPointTest, OverrideUnreferenced) { + Override("foo", ty.f32(), nullptr); MakeEmptyBodyFunction("ep_func", { Stage(ast::PipelineStage::kCompute), WorkgroupSize(1_i), @@ -617,11 +617,11 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantUnreferenced) { auto result = inspector.GetEntryPoints(); ASSERT_EQ(1u, result.size()); - EXPECT_EQ(0u, result[0].overridable_constants.size()); + EXPECT_EQ(0u, result[0].overrides.size()); } -TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByEntryPoint) { - AddOverridableConstantWithoutID("foo", ty.f32(), nullptr); +TEST_F(InspectorGetEntryPointTest, OverrideReferencedByEntryPoint) { + Override("foo", ty.f32(), nullptr); MakePlainGlobalReferenceBodyFunction("ep_func", "foo", ty.f32(), { Stage(ast::PipelineStage::kCompute), @@ -633,12 +633,12 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByEntryPoint) { auto result = inspector.GetEntryPoints(); ASSERT_EQ(1u, result.size()); - ASSERT_EQ(1u, result[0].overridable_constants.size()); - EXPECT_EQ("foo", result[0].overridable_constants[0].name); + ASSERT_EQ(1u, result[0].overrides.size()); + EXPECT_EQ("foo", result[0].overrides[0].name); } -TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByCallee) { - AddOverridableConstantWithoutID("foo", ty.f32(), nullptr); +TEST_F(InspectorGetEntryPointTest, OverrideReferencedByCallee) { + Override("foo", ty.f32(), nullptr); MakePlainGlobalReferenceBodyFunction("callee_func", "foo", ty.f32(), {}); MakeCallerBodyFunction("ep_func", {"callee_func"}, { @@ -651,13 +651,13 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByCallee) { auto result = inspector.GetEntryPoints(); ASSERT_EQ(1u, result.size()); - ASSERT_EQ(1u, result[0].overridable_constants.size()); - EXPECT_EQ("foo", result[0].overridable_constants[0].name); + ASSERT_EQ(1u, result[0].overrides.size()); + EXPECT_EQ("foo", result[0].overrides[0].name); } -TEST_F(InspectorGetEntryPointTest, OverridableConstantSomeReferenced) { - AddOverridableConstantWithID("foo", 1, ty.f32(), nullptr); - AddOverridableConstantWithID("bar", 2, ty.f32(), nullptr); +TEST_F(InspectorGetEntryPointTest, OverrideSomeReferenced) { + Override("foo", ty.f32(), nullptr, {Id(1)}); + Override("bar", ty.f32(), nullptr, {Id(2)}); MakePlainGlobalReferenceBodyFunction("callee_func", "foo", ty.f32(), {}); MakeCallerBodyFunction("ep_func", {"callee_func"}, { @@ -670,16 +670,16 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantSomeReferenced) { auto result = inspector.GetEntryPoints(); ASSERT_EQ(1u, result.size()); - ASSERT_EQ(1u, result[0].overridable_constants.size()); - EXPECT_EQ("foo", result[0].overridable_constants[0].name); - EXPECT_EQ(1, result[0].overridable_constants[0].numeric_id); + ASSERT_EQ(1u, result[0].overrides.size()); + EXPECT_EQ("foo", result[0].overrides[0].name); + EXPECT_EQ(1, result[0].overrides[0].id.value); } -TEST_F(InspectorGetEntryPointTest, OverridableConstantTypes) { - AddOverridableConstantWithoutID("bool_var", ty.bool_(), nullptr); - AddOverridableConstantWithoutID("float_var", ty.f32(), nullptr); - AddOverridableConstantWithoutID("u32_var", ty.u32(), nullptr); - AddOverridableConstantWithoutID("i32_var", ty.i32(), nullptr); +TEST_F(InspectorGetEntryPointTest, OverrideTypes) { + Override("bool_var", ty.bool_(), nullptr); + Override("float_var", ty.f32(), nullptr); + Override("u32_var", ty.u32(), nullptr); + Override("i32_var", ty.i32(), nullptr); MakePlainGlobalReferenceBodyFunction("bool_func", "bool_var", ty.bool_(), {}); MakePlainGlobalReferenceBodyFunction("float_func", "float_var", ty.f32(), {}); @@ -697,22 +697,19 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantTypes) { auto result = inspector.GetEntryPoints(); ASSERT_EQ(1u, result.size()); - ASSERT_EQ(4u, result[0].overridable_constants.size()); - EXPECT_EQ("bool_var", result[0].overridable_constants[0].name); - EXPECT_EQ(inspector::OverridableConstant::Type::kBool, result[0].overridable_constants[0].type); - EXPECT_EQ("float_var", result[0].overridable_constants[1].name); - EXPECT_EQ(inspector::OverridableConstant::Type::kFloat32, - result[0].overridable_constants[1].type); - EXPECT_EQ("u32_var", result[0].overridable_constants[2].name); - EXPECT_EQ(inspector::OverridableConstant::Type::kUint32, - result[0].overridable_constants[2].type); - EXPECT_EQ("i32_var", result[0].overridable_constants[3].name); - EXPECT_EQ(inspector::OverridableConstant::Type::kInt32, - result[0].overridable_constants[3].type); + ASSERT_EQ(4u, result[0].overrides.size()); + EXPECT_EQ("bool_var", result[0].overrides[0].name); + EXPECT_EQ(inspector::Override::Type::kBool, result[0].overrides[0].type); + EXPECT_EQ("float_var", result[0].overrides[1].name); + EXPECT_EQ(inspector::Override::Type::kFloat32, result[0].overrides[1].type); + EXPECT_EQ("u32_var", result[0].overrides[2].name); + EXPECT_EQ(inspector::Override::Type::kUint32, result[0].overrides[2].type); + EXPECT_EQ("i32_var", result[0].overrides[3].name); + EXPECT_EQ(inspector::Override::Type::kInt32, result[0].overrides[3].type); } -TEST_F(InspectorGetEntryPointTest, OverridableConstantInitialized) { - AddOverridableConstantWithoutID("foo", ty.f32(), Expr(0_f)); +TEST_F(InspectorGetEntryPointTest, OverrideInitialized) { + Override("foo", ty.f32(), Expr(0_f)); MakePlainGlobalReferenceBodyFunction("ep_func", "foo", ty.f32(), { Stage(ast::PipelineStage::kCompute), @@ -724,13 +721,13 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantInitialized) { auto result = inspector.GetEntryPoints(); ASSERT_EQ(1u, result.size()); - ASSERT_EQ(1u, result[0].overridable_constants.size()); - EXPECT_EQ("foo", result[0].overridable_constants[0].name); - EXPECT_TRUE(result[0].overridable_constants[0].is_initialized); + ASSERT_EQ(1u, result[0].overrides.size()); + EXPECT_EQ("foo", result[0].overrides[0].name); + EXPECT_TRUE(result[0].overrides[0].is_initialized); } -TEST_F(InspectorGetEntryPointTest, OverridableConstantUninitialized) { - AddOverridableConstantWithoutID("foo", ty.f32(), nullptr); +TEST_F(InspectorGetEntryPointTest, OverrideUninitialized) { + Override("foo", ty.f32(), nullptr); MakePlainGlobalReferenceBodyFunction("ep_func", "foo", ty.f32(), { Stage(ast::PipelineStage::kCompute), @@ -742,15 +739,15 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantUninitialized) { auto result = inspector.GetEntryPoints(); ASSERT_EQ(1u, result.size()); - ASSERT_EQ(1u, result[0].overridable_constants.size()); - EXPECT_EQ("foo", result[0].overridable_constants[0].name); + ASSERT_EQ(1u, result[0].overrides.size()); + EXPECT_EQ("foo", result[0].overrides[0].name); - EXPECT_FALSE(result[0].overridable_constants[0].is_initialized); + EXPECT_FALSE(result[0].overrides[0].is_initialized); } -TEST_F(InspectorGetEntryPointTest, OverridableConstantNumericIDSpecified) { - AddOverridableConstantWithoutID("foo_no_id", ty.f32(), nullptr); - AddOverridableConstantWithID("foo_id", 1234, ty.f32(), nullptr); +TEST_F(InspectorGetEntryPointTest, OverrideNumericIDSpecified) { + Override("foo_no_id", ty.f32(), nullptr); + Override("foo_id", ty.f32(), nullptr, {Id(1234)}); MakePlainGlobalReferenceBodyFunction("no_id_func", "foo_no_id", ty.f32(), {}); MakePlainGlobalReferenceBodyFunction("id_func", "foo_id", ty.f32(), {}); @@ -766,16 +763,16 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantNumericIDSpecified) { auto result = inspector.GetEntryPoints(); ASSERT_EQ(1u, result.size()); - ASSERT_EQ(2u, result[0].overridable_constants.size()); - EXPECT_EQ("foo_no_id", result[0].overridable_constants[0].name); - EXPECT_EQ("foo_id", result[0].overridable_constants[1].name); - EXPECT_EQ(1234, result[0].overridable_constants[1].numeric_id); + ASSERT_EQ(2u, result[0].overrides.size()); + EXPECT_EQ("foo_no_id", result[0].overrides[0].name); + EXPECT_EQ("foo_id", result[0].overrides[1].name); + EXPECT_EQ(1234, result[0].overrides[1].id.value); - EXPECT_FALSE(result[0].overridable_constants[0].is_numeric_id_specified); - EXPECT_TRUE(result[0].overridable_constants[1].is_numeric_id_specified); + EXPECT_FALSE(result[0].overrides[0].is_id_specified); + EXPECT_TRUE(result[0].overrides[1].is_id_specified); } -TEST_F(InspectorGetEntryPointTest, NonOverridableConstantSkipped) { +TEST_F(InspectorGetEntryPointTest, NonOverrideSkipped) { auto* foo_struct_type = MakeUniformBufferType("foo_type", {ty.i32()}); AddUniformBuffer("foo_ub", ty.Of(foo_struct_type), 0, 0); MakeStructVariableReferenceBodyFunction("ub_func", "foo_ub", {{0, ty.i32()}}); @@ -789,7 +786,7 @@ TEST_F(InspectorGetEntryPointTest, NonOverridableConstantSkipped) { auto result = inspector.GetEntryPoints(); ASSERT_EQ(1u, result.size()); - EXPECT_EQ(0u, result[0].overridable_constants.size()); + EXPECT_EQ(0u, result[0].overrides.size()); } TEST_F(InspectorGetEntryPointTest, BuiltinNotReferenced) { @@ -1172,127 +1169,127 @@ INSTANTIATE_TEST_SUITE_P( ast::InterpolationType::kFlat, ast::InterpolationSampling::kNone, InterpolationType::kFlat, InterpolationSampling::kNone})); -TEST_F(InspectorGetConstantIDsTest, Bool) { - AddOverridableConstantWithID("foo", 1, ty.bool_(), nullptr); - AddOverridableConstantWithID("bar", 20, ty.bool_(), Expr(true)); - AddOverridableConstantWithID("baz", 300, ty.bool_(), Expr(false)); +TEST_F(InspectorGetOverrideDefaultValuesTest, Bool) { + Override("foo", ty.bool_(), nullptr, {Id(1)}); + Override("bar", ty.bool_(), Expr(true), {Id(20)}); + Override("baz", ty.bool_(), Expr(false), {Id(300)}); Inspector& inspector = Build(); - auto result = inspector.GetConstantIDs(); + auto result = inspector.GetOverrideDefaultValues(); ASSERT_EQ(3u, result.size()); - ASSERT_TRUE(result.find(1) != result.end()); - EXPECT_TRUE(result[1].IsNull()); + ASSERT_TRUE(result.find(OverrideId{1}) != result.end()); + EXPECT_TRUE(result[OverrideId{1}].IsNull()); - ASSERT_TRUE(result.find(20) != result.end()); - EXPECT_TRUE(result[20].IsBool()); - EXPECT_TRUE(result[20].AsBool()); + ASSERT_TRUE(result.find(OverrideId{20}) != result.end()); + EXPECT_TRUE(result[OverrideId{20}].IsBool()); + EXPECT_TRUE(result[OverrideId{20}].AsBool()); - ASSERT_TRUE(result.find(300) != result.end()); - EXPECT_TRUE(result[300].IsBool()); - EXPECT_FALSE(result[300].AsBool()); + ASSERT_TRUE(result.find(OverrideId{300}) != result.end()); + EXPECT_TRUE(result[OverrideId{300}].IsBool()); + EXPECT_FALSE(result[OverrideId{300}].AsBool()); } -TEST_F(InspectorGetConstantIDsTest, U32) { - AddOverridableConstantWithID("foo", 1, ty.u32(), nullptr); - AddOverridableConstantWithID("bar", 20, ty.u32(), Expr(42_u)); +TEST_F(InspectorGetOverrideDefaultValuesTest, U32) { + Override("foo", ty.u32(), nullptr, {Id(1)}); + Override("bar", ty.u32(), Expr(42_u), {Id(20)}); Inspector& inspector = Build(); - auto result = inspector.GetConstantIDs(); + auto result = inspector.GetOverrideDefaultValues(); ASSERT_EQ(2u, result.size()); - ASSERT_TRUE(result.find(1) != result.end()); - EXPECT_TRUE(result[1].IsNull()); + ASSERT_TRUE(result.find(OverrideId{1}) != result.end()); + EXPECT_TRUE(result[OverrideId{1}].IsNull()); - ASSERT_TRUE(result.find(20) != result.end()); - EXPECT_TRUE(result[20].IsU32()); - EXPECT_EQ(42u, result[20].AsU32()); + ASSERT_TRUE(result.find(OverrideId{20}) != result.end()); + EXPECT_TRUE(result[OverrideId{20}].IsU32()); + EXPECT_EQ(42u, result[OverrideId{20}].AsU32()); } -TEST_F(InspectorGetConstantIDsTest, I32) { - AddOverridableConstantWithID("foo", 1, ty.i32(), nullptr); - AddOverridableConstantWithID("bar", 20, ty.i32(), Expr(-42_i)); - AddOverridableConstantWithID("baz", 300, ty.i32(), Expr(42_i)); +TEST_F(InspectorGetOverrideDefaultValuesTest, I32) { + Override("foo", ty.i32(), nullptr, {Id(1)}); + Override("bar", ty.i32(), Expr(-42_i), {Id(20)}); + Override("baz", ty.i32(), Expr(42_i), {Id(300)}); Inspector& inspector = Build(); - auto result = inspector.GetConstantIDs(); + auto result = inspector.GetOverrideDefaultValues(); ASSERT_EQ(3u, result.size()); - ASSERT_TRUE(result.find(1) != result.end()); - EXPECT_TRUE(result[1].IsNull()); + ASSERT_TRUE(result.find(OverrideId{1}) != result.end()); + EXPECT_TRUE(result[OverrideId{1}].IsNull()); - ASSERT_TRUE(result.find(20) != result.end()); - EXPECT_TRUE(result[20].IsI32()); - EXPECT_EQ(-42, result[20].AsI32()); + ASSERT_TRUE(result.find(OverrideId{20}) != result.end()); + EXPECT_TRUE(result[OverrideId{20}].IsI32()); + EXPECT_EQ(-42, result[OverrideId{20}].AsI32()); - ASSERT_TRUE(result.find(300) != result.end()); - EXPECT_TRUE(result[300].IsI32()); - EXPECT_EQ(42, result[300].AsI32()); + ASSERT_TRUE(result.find(OverrideId{300}) != result.end()); + EXPECT_TRUE(result[OverrideId{300}].IsI32()); + EXPECT_EQ(42, result[OverrideId{300}].AsI32()); } -TEST_F(InspectorGetConstantIDsTest, Float) { - AddOverridableConstantWithID("foo", 1, ty.f32(), nullptr); - AddOverridableConstantWithID("bar", 20, ty.f32(), Expr(0_f)); - AddOverridableConstantWithID("baz", 300, ty.f32(), Expr(-10_f)); - AddOverridableConstantWithID("x", 4000, ty.f32(), Expr(15_f)); +TEST_F(InspectorGetOverrideDefaultValuesTest, Float) { + Override("foo", ty.f32(), nullptr, {Id(1)}); + Override("bar", ty.f32(), Expr(0_f), {Id(20)}); + Override("baz", ty.f32(), Expr(-10_f), {Id(300)}); + Override("x", ty.f32(), Expr(15_f), {Id(4000)}); Inspector& inspector = Build(); - auto result = inspector.GetConstantIDs(); + auto result = inspector.GetOverrideDefaultValues(); ASSERT_EQ(4u, result.size()); - ASSERT_TRUE(result.find(1) != result.end()); - EXPECT_TRUE(result[1].IsNull()); + ASSERT_TRUE(result.find(OverrideId{1}) != result.end()); + EXPECT_TRUE(result[OverrideId{1}].IsNull()); - ASSERT_TRUE(result.find(20) != result.end()); - EXPECT_TRUE(result[20].IsFloat()); - EXPECT_FLOAT_EQ(0.0f, result[20].AsFloat()); + ASSERT_TRUE(result.find(OverrideId{20}) != result.end()); + EXPECT_TRUE(result[OverrideId{20}].IsFloat()); + EXPECT_FLOAT_EQ(0.0f, result[OverrideId{20}].AsFloat()); - ASSERT_TRUE(result.find(300) != result.end()); - EXPECT_TRUE(result[300].IsFloat()); - EXPECT_FLOAT_EQ(-10.0f, result[300].AsFloat()); + ASSERT_TRUE(result.find(OverrideId{300}) != result.end()); + EXPECT_TRUE(result[OverrideId{300}].IsFloat()); + EXPECT_FLOAT_EQ(-10.0f, result[OverrideId{300}].AsFloat()); - ASSERT_TRUE(result.find(4000) != result.end()); - EXPECT_TRUE(result[4000].IsFloat()); - EXPECT_FLOAT_EQ(15.0f, result[4000].AsFloat()); + ASSERT_TRUE(result.find(OverrideId{4000}) != result.end()); + EXPECT_TRUE(result[OverrideId{4000}].IsFloat()); + EXPECT_FLOAT_EQ(15.0f, result[OverrideId{4000}].AsFloat()); } TEST_F(InspectorGetConstantNameToIdMapTest, WithAndWithoutIds) { - AddOverridableConstantWithID("v1", 1, ty.f32(), nullptr); - AddOverridableConstantWithID("v20", 20, ty.f32(), nullptr); - AddOverridableConstantWithID("v300", 300, ty.f32(), nullptr); - auto* a = AddOverridableConstantWithoutID("a", ty.f32(), nullptr); - auto* b = AddOverridableConstantWithoutID("b", ty.f32(), nullptr); - auto* c = AddOverridableConstantWithoutID("c", ty.f32(), nullptr); + Override("v1", ty.f32(), nullptr, {Id(1)}); + Override("v20", ty.f32(), nullptr, {Id(20)}); + Override("v300", ty.f32(), nullptr, {Id(300)}); + auto* a = Override("a", ty.f32(), nullptr); + auto* b = Override("b", ty.f32(), nullptr); + auto* c = Override("c", ty.f32(), nullptr); Inspector& inspector = Build(); - auto result = inspector.GetConstantNameToIdMap(); + auto result = inspector.GetNamedOverrideIds(); ASSERT_EQ(6u, result.size()); ASSERT_TRUE(result.count("v1")); - EXPECT_EQ(result["v1"], 1u); + EXPECT_EQ(result["v1"].value, 1u); ASSERT_TRUE(result.count("v20")); - EXPECT_EQ(result["v20"], 20u); + EXPECT_EQ(result["v20"].value, 20u); ASSERT_TRUE(result.count("v300")); - EXPECT_EQ(result["v300"], 300u); + EXPECT_EQ(result["v300"].value, 300u); ASSERT_TRUE(result.count("a")); ASSERT_TRUE(program_->Sem().Get(a)); - EXPECT_EQ(result["a"], program_->Sem().Get(a)->ConstantId()); + EXPECT_EQ(result["a"], program_->Sem().Get(a)->OverrideId()); ASSERT_TRUE(result.count("b")); ASSERT_TRUE(program_->Sem().Get(b)); - EXPECT_EQ(result["b"], program_->Sem().Get(b)->ConstantId()); + EXPECT_EQ(result["b"], program_->Sem().Get(b)->OverrideId()); ASSERT_TRUE(result.count("c")); ASSERT_TRUE(program_->Sem().Get(c)); - EXPECT_EQ(result["c"], program_->Sem().Get(c)->ConstantId()); + EXPECT_EQ(result["c"], program_->Sem().Get(c)->OverrideId()); } TEST_F(InspectorGetStorageSizeTest, Empty) { diff --git a/src/tint/inspector/test_inspector_builder.h b/src/tint/inspector/test_inspector_builder.h index 391e1a1479..a3ccb1351c 100644 --- a/src/tint/inspector/test_inspector_builder.h +++ b/src/tint/inspector/test_inspector_builder.h @@ -92,31 +92,6 @@ class InspectorBuilder : public ProgramBuilder { std::vector> inout_vars, ast::AttributeList attributes); - /// Add a pipeline constant to the global variables, with a specific ID. - /// @param name name of the variable to add - /// @param id id number for the constant id - /// @param type type of the variable - /// @param constructor val to initialize the constant with, if NULL no - /// constructor will be added. - /// @returns the constant that was created - const ast::Variable* AddOverridableConstantWithID(std::string name, - uint32_t id, - const ast::Type* type, - const ast::Expression* constructor) { - return Override(name, type, constructor, {Id(id)}); - } - - /// Add a pipeline constant to the global variables, without a specific ID. - /// @param name name of the variable to add - /// @param type type of the variable - /// @param constructor val to initialize the constant with, if NULL no - /// constructor will be added. - /// @returns the constant that was created - const ast::Variable* AddOverridableConstantWithoutID(std::string name, - const ast::Type* type, - const ast::Expression* constructor) { - return Override(name, type, constructor); - } /// Generates a function that references module-scoped, plain-typed constant /// or variable. diff --git a/src/tint/program_builder.h b/src/tint/program_builder.h index f7d91162f5..61954ee171 100644 --- a/src/tint/program_builder.h +++ b/src/tint/program_builder.h @@ -19,6 +19,8 @@ #include #include +#include "tint/override_id.h" + #include "src/tint/ast/alias.h" #include "src/tint/ast/array.h" #include "src/tint/ast/assignment_statement.h" @@ -2588,6 +2590,19 @@ class ProgramBuilder { return create(source_, location); } + /// Creates an ast::IdAttribute + /// @param source the source information + /// @param id the id value + /// @returns the override attribute pointer + const ast::IdAttribute* Id(const Source& source, OverrideId id) { + return create(source, id.value); + } + + /// Creates an ast::IdAttribute with an override identifier + /// @param id the optional id value + /// @returns the override attribute pointer + const ast::IdAttribute* Id(OverrideId id) { return Id(source_, id); } + /// Creates an ast::IdAttribute /// @param source the source information /// @param id the id value @@ -2596,7 +2611,7 @@ class ProgramBuilder { return create(source, id); } - /// Creates an ast::IdAttribute with a constant ID + /// Creates an ast::IdAttribute with an override identifier /// @param id the optional id value /// @returns the override attribute pointer const ast::IdAttribute* Id(uint32_t id) { return Id(source_, id); } diff --git a/src/tint/resolver/pipeline_overridable_constant_test.cc b/src/tint/resolver/override_test.cc similarity index 71% rename from src/tint/resolver/pipeline_overridable_constant_test.cc rename to src/tint/resolver/override_test.cc index 361450038a..f1e13e9585 100644 --- a/src/tint/resolver/pipeline_overridable_constant_test.cc +++ b/src/tint/resolver/override_test.cc @@ -21,23 +21,23 @@ using namespace tint::number_suffixes; // NOLINT namespace tint::resolver { namespace { -class ResolverPipelineOverridableConstantTest : public ResolverTest { +class ResolverOverrideTest : public ResolverTest { protected: /// Verify that the AST node `var` was resolved to an overridable constant /// with an ID equal to `id`. /// @param var the overridable constant AST node /// @param id the expected constant ID - void ExpectConstantId(const ast::Variable* var, uint16_t id) { + void ExpectOverrideId(const ast::Variable* var, uint16_t id) { auto* sem = Sem().Get(var); ASSERT_NE(sem, nullptr); EXPECT_EQ(sem->Declaration(), var); EXPECT_TRUE(sem->Declaration()->Is()); - EXPECT_EQ(sem->ConstantId(), id); + EXPECT_EQ(sem->OverrideId().value, id); EXPECT_FALSE(sem->ConstantValue()); } }; -TEST_F(ResolverPipelineOverridableConstantTest, NonOverridable) { +TEST_F(ResolverOverrideTest, NonOverridable) { auto* a = GlobalConst("a", ty.f32(), Expr(1_f)); EXPECT_TRUE(r()->Resolve()) << r()->error(); @@ -49,23 +49,23 @@ TEST_F(ResolverPipelineOverridableConstantTest, NonOverridable) { EXPECT_TRUE(sem_a->ConstantValue()); } -TEST_F(ResolverPipelineOverridableConstantTest, WithId) { +TEST_F(ResolverOverrideTest, WithId) { auto* a = Override("a", ty.f32(), Expr(1_f), {Id(7u)}); EXPECT_TRUE(r()->Resolve()) << r()->error(); - ExpectConstantId(a, 7u); + ExpectOverrideId(a, 7u); } -TEST_F(ResolverPipelineOverridableConstantTest, WithoutId) { +TEST_F(ResolverOverrideTest, WithoutId) { auto* a = Override("a", ty.f32(), Expr(1_f)); EXPECT_TRUE(r()->Resolve()) << r()->error(); - ExpectConstantId(a, 0u); + ExpectOverrideId(a, 0u); } -TEST_F(ResolverPipelineOverridableConstantTest, WithAndWithoutIds) { +TEST_F(ResolverOverrideTest, WithAndWithoutIds) { std::vector variables; auto* a = Override("a", ty.f32(), Expr(1_f)); auto* b = Override("b", ty.f32(), Expr(1_f)); @@ -77,33 +77,33 @@ TEST_F(ResolverPipelineOverridableConstantTest, WithAndWithoutIds) { EXPECT_TRUE(r()->Resolve()) << r()->error(); // Verify that constant id allocation order is deterministic. - ExpectConstantId(a, 0u); - ExpectConstantId(b, 3u); - ExpectConstantId(c, 2u); - ExpectConstantId(d, 4u); - ExpectConstantId(e, 5u); - ExpectConstantId(f, 1u); + ExpectOverrideId(a, 0u); + ExpectOverrideId(b, 3u); + ExpectOverrideId(c, 2u); + ExpectOverrideId(d, 4u); + ExpectOverrideId(e, 5u); + ExpectOverrideId(f, 1u); } -TEST_F(ResolverPipelineOverridableConstantTest, DuplicateIds) { +TEST_F(ResolverOverrideTest, DuplicateIds) { Override("a", ty.f32(), Expr(1_f), {Id(Source{{12, 34}}, 7u)}); Override("b", ty.f32(), Expr(1_f), {Id(Source{{56, 78}}, 7u)}); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), R"(56:78 error: pipeline constant IDs must be unique -12:34 note: a pipeline constant with an ID of 7 was previously declared here:)"); + EXPECT_EQ(r()->error(), R"(56:78 error: override IDs must be unique +12:34 note: a override with an ID of 7 was previously declared here:)"); } -TEST_F(ResolverPipelineOverridableConstantTest, IdTooLarge) { +TEST_F(ResolverOverrideTest, IdTooLarge) { Override("a", ty.f32(), Expr(1_f), {Id(Source{{12, 34}}, 65536u)}); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: pipeline constant IDs must be between 0 and 65535"); + EXPECT_EQ(r()->error(), "12:34 error: override IDs must be between 0 and 65535"); } -TEST_F(ResolverPipelineOverridableConstantTest, F16_TemporallyBan) { +TEST_F(ResolverOverrideTest, F16_TemporallyBan) { Enable(ast::Extension::kF16); Override(Source{{12, 34}}, "a", ty.f16(), Expr(1_h), {Id(1u)}); diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index b7b7586f9d..d861acf76d 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -149,7 +149,9 @@ bool Resolver::ResolveInternal() { } } - AllocateOverridableConstantIds(); + if (!AllocateOverridableConstantIds()) { + return false; + } SetShadows(); @@ -432,7 +434,7 @@ sem::Variable* Resolver::Override(const ast::Override* v) { /* constant_value */ nullptr, sem::BindingPoint{}); if (auto* id = ast::GetAttribute(v->attributes)) { - sem->SetConstantId(static_cast(id->value)); + sem->SetOverrideId(OverrideId{static_cast(id->value)}); } sem->SetConstructor(rhs); @@ -641,9 +643,19 @@ ast::Access Resolver::DefaultAccessForStorageClass(ast::StorageClass storage_cla return ast::Access::kReadWrite; } -void Resolver::AllocateOverridableConstantIds() { +bool Resolver::AllocateOverridableConstantIds() { + constexpr size_t kLimit = std::numeric_limits::max(); // The next pipeline constant ID to try to allocate. - uint16_t next_constant_id = 0; + OverrideId next_id; + bool ids_exhausted = false; + + auto increment_next_id = [&] { + if (next_id.value == kLimit) { + ids_exhausted = true; + } else { + next_id.value = next_id.value + 1; + } + }; // Allocate constant IDs in global declaration order, so that they are // deterministic. @@ -655,26 +667,28 @@ void Resolver::AllocateOverridableConstantIds() { continue; } - uint16_t constant_id; + OverrideId id; if (auto* id_attr = ast::GetAttribute(override->attributes)) { - constant_id = static_cast(id_attr->value); + id = OverrideId{static_cast(id_attr->value)}; } else { // No ID was specified, so allocate the next available ID. - constant_id = next_constant_id; - while (constant_ids_.count(constant_id)) { - if (constant_id == UINT16_MAX) { - TINT_ICE(Resolver, builder_->Diagnostics()) - << "no more pipeline constant IDs available"; - return; - } - constant_id++; + while (!ids_exhausted && override_ids_.count(next_id)) { + increment_next_id(); } - next_constant_id = constant_id + 1; + if (ids_exhausted) { + AddError( + "number of 'override' variables exceeded limit of " + std::to_string(kLimit), + decl->source); + return false; + } + id = next_id; + increment_next_id(); } auto* sem = sem_.Get(override); - const_cast(sem)->SetConstantId(constant_id); + const_cast(sem)->SetOverrideId(id); } + return true; } void Resolver::SetShadows() { @@ -697,7 +711,8 @@ sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* v) { if (auto* id_attr = attr->As()) { // Track the constant IDs that are specified in the shader. - constant_ids_.emplace(id_attr->value, sem); + override_ids_.emplace( + OverrideId{static_cast(id_attr->value)}, sem); } } @@ -705,7 +720,7 @@ sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* v) { return nullptr; } - if (!validator_.GlobalVariable(sem, constant_ids_, atomic_composite_info_)) { + if (!validator_.GlobalVariable(sem, override_ids_, atomic_composite_info_)) { return nullptr; } diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index 03f1b9e956..ccfaf6c502 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -362,7 +362,8 @@ class Resolver { ast::Access DefaultAccessForStorageClass(ast::StorageClass storage_class); /// Allocate constant IDs for pipeline-overridable constants. - void AllocateOverridableConstantIds(); + /// @returns true on success, false on error + bool AllocateOverridableConstantIds(); /// Set the shadowing information on variable declarations. /// @note this method must only be called after all semantic nodes are built. @@ -422,7 +423,7 @@ class Resolver { std::vector entry_points_; std::unordered_map atomic_composite_info_; utils::Bitset<0> marked_; - std::unordered_map constant_ids_; + std::unordered_map override_ids_; std::unordered_map array_ctors_; std::unordered_map struct_ctors_; diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc index b8a62475fc..940e154a3d 100644 --- a/src/tint/resolver/validator.cc +++ b/src/tint/resolver/validator.cc @@ -562,7 +562,7 @@ bool Validator::LocalVariable(const sem::Variable* v) const { bool Validator::GlobalVariable( const sem::GlobalVariable* global, - const std::unordered_map& constant_ids, + const std::unordered_map& override_ids, const std::unordered_map& atomic_composite_info) const { auto* decl = global->Declaration(); bool ok = Switch( @@ -627,7 +627,7 @@ bool Validator::GlobalVariable( return Var(global); }, - [&](const ast::Override*) { return Override(global, constant_ids); }, + [&](const ast::Override*) { return Override(global, override_ids); }, [&](const ast::Const*) { if (!decl->attributes.empty()) { AddError("attribute is not valid for module-scope 'const' declaration", @@ -763,7 +763,7 @@ bool Validator::Let(const sem::Variable* v) const { bool Validator::Override( const sem::Variable* v, - const std::unordered_map& constant_ids) const { + const std::unordered_map& override_ids) const { auto* decl = v->Declaration(); auto* storage_ty = v->Type()->UnwrapRef(); @@ -776,19 +776,23 @@ bool Validator::Override( for (auto* attr : decl->attributes) { if (auto* id_attr = attr->As()) { uint32_t id = id_attr->value; - auto it = constant_ids.find(id); - if (it != constant_ids.end() && it->second != v) { - AddError("pipeline constant IDs must be unique", attr->source); - AddNote("a pipeline constant with an ID of " + std::to_string(id) + + if (id > std::numeric_limits::max()) { + AddError( + "override IDs must be between 0 and " + + std::to_string(std::numeric_limits::max()), + attr->source); + return false; + } + if (auto it = + override_ids.find(OverrideId{static_cast(id)}); + it != override_ids.end() && it->second != v) { + AddError("override IDs must be unique", attr->source); + AddNote("a override with an ID of " + std::to_string(id) + " was previously declared here:", ast::GetAttribute(it->second->Declaration()->attributes) ->source); return false; } - if (id > 65535) { - AddError("pipeline constant IDs must be between 0 and 65535", attr->source); - return false; - } } else { AddError("attribute is not valid for 'override' declaration", attr->source); return false; diff --git a/src/tint/resolver/validator.h b/src/tint/resolver/validator.h index ed20af0787..e9f9ceac4c 100644 --- a/src/tint/resolver/validator.h +++ b/src/tint/resolver/validator.h @@ -234,12 +234,12 @@ class Validator { /// Validates a global variable /// @param var the global variable to validate - /// @param constant_ids the set of constant ids in the module + /// @param override_id the set of override ids in the module /// @param atomic_composite_info atomic composite info in the module /// @returns true on success, false otherwise bool GlobalVariable( const sem::GlobalVariable* var, - const std::unordered_map& constant_ids, + const std::unordered_map& override_id, const std::unordered_map& atomic_composite_info) const; /// Validates an if statement @@ -371,10 +371,10 @@ class Validator { /// Validates a 'override' variable declaration /// @param v the variable to validate - /// @param constant_ids the set of constant ids in the module + /// @param override_id the set of override ids in the module /// @returns true on success, false otherwise. bool Override(const sem::Variable* v, - const std::unordered_map& constant_ids) const; + const std::unordered_map& override_id) const; /// Validates a 'const' variable declaration /// @param v the variable to validate diff --git a/src/tint/resolver/variable_validation_test.cc b/src/tint/resolver/variable_validation_test.cc index e13585f2cc..08312fef53 100644 --- a/src/tint/resolver/variable_validation_test.cc +++ b/src/tint/resolver/variable_validation_test.cc @@ -77,6 +77,37 @@ TEST_F(ResolverVariableValidationTest, OverrideNoInitializerNoType) { EXPECT_EQ(r()->error(), "12:34 error: override declaration requires a type or initializer"); } +TEST_F(ResolverVariableValidationTest, OverrideExceedsIDLimit_LastUnreserved) { + // override o0 : i32; + // override o1 : i32; + // ... + // override bang : i32; + constexpr size_t kLimit = std::numeric_limits::max(); + for (size_t i = 0; i <= kLimit; i++) { + Override("o" + std::to_string(i), ty.i32(), nullptr); + } + Override(Source{{12, 34}}, "bang", ty.i32(), nullptr); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "12:34 error: number of 'override' variables exceeded limit of 65535"); +} + +TEST_F(ResolverVariableValidationTest, OverrideExceedsIDLimit_LastReserved) { + // override o0 : i32; + // override o1 : i32; + // ... + // @id(N) override oN : i32; + constexpr size_t kLimit = std::numeric_limits::max(); + Override("reserved", ty.i32(), nullptr, {Id(kLimit)}); + for (size_t i = 0; i < kLimit; i++) { + Override("o" + std::to_string(i), ty.i32(), nullptr); + } + Override(Source{{12, 34}}, "bang", ty.i32(), nullptr); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "12:34 error: number of 'override' variables exceeded limit of 65535"); +} + TEST_F(ResolverVariableValidationTest, VarTypeNotConstructible) { // var i : i32; // var p : pointer = &v; diff --git a/src/tint/sem/variable.h b/src/tint/sem/variable.h index ecc536b2e8..2bf6d23ee6 100644 --- a/src/tint/sem/variable.h +++ b/src/tint/sem/variable.h @@ -18,6 +18,8 @@ #include #include +#include "tint/override_id.h" + #include "src/tint/ast/access.h" #include "src/tint/ast/storage_class.h" #include "src/tint/sem/binding_point.h" @@ -165,15 +167,15 @@ class GlobalVariable final : public Castable { sem::BindingPoint BindingPoint() const { return binding_point_; } /// @param id the constant identifier to assign to this variable - void SetConstantId(uint16_t id) { constant_id_ = id; } + void SetOverrideId(OverrideId id) { override_id_ = id; } /// @returns the pipeline constant ID associated with the variable - uint16_t ConstantId() const { return constant_id_; } + tint::OverrideId OverrideId() const { return override_id_; } private: const sem::BindingPoint binding_point_; - uint16_t constant_id_ = 0; + tint::OverrideId override_id_; }; /// Parameter is a function parameter diff --git a/src/tint/transform/single_entry_point.cc b/src/tint/transform/single_entry_point.cc index ab7bed3da1..133c836e07 100644 --- a/src/tint/transform/single_entry_point.cc +++ b/src/tint/transform/single_entry_point.cc @@ -76,11 +76,11 @@ void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) c [&](const ast::Override* override) { if (referenced_vars.count(override)) { if (!ast::HasAttribute(override->attributes)) { - // If the constant doesn't already have an @id() attribute, add one + // If the override doesn't already have an @id() attribute, add one // so that its allocated ID so that it won't be affected by other - // stripped away constants + // stripped away overrides auto* global = sem.Get(override); - const auto* id = ctx.dst->Id(global->ConstantId()); + const auto* id = ctx.dst->Id(global->OverrideId()); ctx.InsertFront(override->attributes, id); } ctx.dst->AST().AddGlobalVariable(ctx.Clone(override)); diff --git a/src/tint/writer/glsl/generator_impl.cc b/src/tint/writer/glsl/generator_impl.cc index f4fa5ad9c7..03b69140d1 100644 --- a/src/tint/writer/glsl/generator_impl.cc +++ b/src/tint/writer/glsl/generator_impl.cc @@ -2159,7 +2159,7 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) { TINT_ICE(Writer, builder_.Diagnostics()) << "expected a pipeline-overridable constant"; } - out << kSpecConstantPrefix << global->ConstantId(); + out << kSpecConstantPrefix << global->OverrideId().value; } else { out << std::to_string(wgsize[i].value); } @@ -3053,18 +3053,18 @@ bool GeneratorImpl::EmitOverride(const ast::Override* override) { auto* type = sem->Type(); auto* global = sem->As(); - auto const_id = global->ConstantId(); + auto override_id = global->OverrideId(); - line() << "#ifndef " << kSpecConstantPrefix << const_id; + line() << "#ifndef " << kSpecConstantPrefix << override_id.value; if (override->constructor != nullptr) { auto out = line(); - out << "#define " << kSpecConstantPrefix << const_id << " "; + out << "#define " << kSpecConstantPrefix << override_id.value << " "; if (!EmitExpression(out, override->constructor)) { return false; } } else { - line() << "#error spec constant required for constant id " << const_id; + line() << "#error spec constant required for constant id " << override_id.value; } line() << "#endif"; { @@ -3074,7 +3074,7 @@ bool GeneratorImpl::EmitOverride(const ast::Override* override) { builder_.Symbols().NameFor(override->symbol))) { return false; } - out << " = " << kSpecConstantPrefix << const_id << ";"; + out << " = " << kSpecConstantPrefix << override_id.value << ";"; } return true; diff --git a/src/tint/writer/hlsl/generator_impl.cc b/src/tint/writer/hlsl/generator_impl.cc index 7bf1eb91c8..d106804ab6 100644 --- a/src/tint/writer/hlsl/generator_impl.cc +++ b/src/tint/writer/hlsl/generator_impl.cc @@ -3061,7 +3061,7 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) { TINT_ICE(Writer, diagnostics_) << "expected a pipeline-overridable constant"; } - out << kSpecConstantPrefix << global->ConstantId(); + out << kSpecConstantPrefix << global->OverrideId().value; } else { out << std::to_string(wgsize[i].value); } @@ -4101,18 +4101,18 @@ bool GeneratorImpl::EmitOverride(const ast::Override* override) { auto* sem = builder_.Sem().Get(override); auto* type = sem->Type(); - auto const_id = sem->ConstantId(); + auto override_id = sem->OverrideId(); - line() << "#ifndef " << kSpecConstantPrefix << const_id; + line() << "#ifndef " << kSpecConstantPrefix << override_id.value; if (override->constructor != nullptr) { auto out = line(); - out << "#define " << kSpecConstantPrefix << const_id << " "; + out << "#define " << kSpecConstantPrefix << override_id.value << " "; if (!EmitExpression(out, override->constructor)) { return false; } } else { - line() << "#error spec constant required for constant id " << const_id; + line() << "#error spec constant required for constant id " << override_id.value; } line() << "#endif"; { @@ -4122,7 +4122,7 @@ bool GeneratorImpl::EmitOverride(const ast::Override* override) { builder_.Symbols().NameFor(override->symbol))) { return false; } - out << " = " << kSpecConstantPrefix << const_id << ";"; + out << " = " << kSpecConstantPrefix << override_id.value << ";"; } return true; } diff --git a/src/tint/writer/hlsl/generator_impl_module_constant_test.cc b/src/tint/writer/hlsl/generator_impl_module_constant_test.cc index d5dee6faf4..9dc9acb81e 100644 --- a/src/tint/writer/hlsl/generator_impl_module_constant_test.cc +++ b/src/tint/writer/hlsl/generator_impl_module_constant_test.cc @@ -244,10 +244,7 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_GlobalConst_arr_vec2_bool) { } TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_Override) { - auto* var = Override("pos", ty.f32(), Expr(3_f), - ast::AttributeList{ - Id(23), - }); + auto* var = Override("pos", ty.f32(), Expr(3_f), {Id(23)}); GeneratorImpl& gen = Build(); @@ -260,10 +257,7 @@ static const float pos = WGSL_SPEC_CONSTANT_23; } TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_Override_NoConstructor) { - auto* var = Override("pos", ty.f32(), nullptr, - ast::AttributeList{ - Id(23), - }); + auto* var = Override("pos", ty.f32(), nullptr, {Id(23)}); GeneratorImpl& gen = Build(); @@ -276,10 +270,7 @@ static const float pos = WGSL_SPEC_CONSTANT_23; } TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_Override_NoId) { - auto* a = Override("a", ty.f32(), Expr(3_f), - ast::AttributeList{ - Id(0), - }); + auto* a = Override("a", ty.f32(), Expr(3_f), {Id(0)}); auto* b = Override("b", ty.f32(), Expr(2_f)); GeneratorImpl& gen = Build(); diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc index d02468692a..3c1525f8fc 100644 --- a/src/tint/writer/msl/generator_impl.cc +++ b/src/tint/writer/msl/generator_impl.cc @@ -3063,7 +3063,7 @@ bool GeneratorImpl::EmitOverride(const ast::Override* override) { } out << " " << program_->Symbols().NameFor(override->symbol); - out << " [[function_constant(" << global->ConstantId() << ")]];"; + out << " [[function_constant(" << global->OverrideId().value << ")]];"; return true; } diff --git a/src/tint/writer/spirv/builder.cc b/src/tint/writer/spirv/builder.cc index e41096cfdb..0cd3f7cf60 100644 --- a/src/tint/writer/spirv/builder.cc +++ b/src/tint/writer/spirv/builder.cc @@ -540,7 +540,7 @@ bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) { << "expected a pipeline-overridable constant"; } constant.is_spec_op = true; - constant.constant_id = sem_const->ConstantId(); + constant.constant_id = sem_const->OverrideId().value; } auto result = GenerateConstantIfNeeded(constant); @@ -1340,7 +1340,7 @@ uint32_t Builder::GenerateTypeConstructorOrConversion(const sem::Call* call, // Generate the zero initializer if there are no values provided. if (args.IsEmpty()) { if (global_var && global_var->Declaration()->Is()) { - auto constant_id = global_var->ConstantId(); + auto constant_id = global_var->OverrideId().value; if (result_type->Is()) { return GenerateConstantIfNeeded(ScalarConstant::I32(0).AsSpecOp(constant_id)); } @@ -1664,7 +1664,7 @@ uint32_t Builder::GenerateLiteralIfNeeded(const ast::Variable* var, auto* global = builder_.Sem().Get(var); if (global && global->Declaration()->Is()) { constant.is_spec_op = true; - constant.constant_id = global->ConstantId(); + constant.constant_id = global->OverrideId().value; } Switch(