tint: Add tint::OverrideId

This is a public API definition of a program-unique override identifier.

Bug: tint:1155
Change-Id: I6e55d43208e72a7a316557a89e2169d1b952f9bf
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/97006
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-07-27 20:50:40 +00:00 committed by Dawn LUCI CQ
parent b2306e27fc
commit 9a6acc419e
32 changed files with 446 additions and 359 deletions

View File

@ -0,0 +1,60 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_OVERRIDE_ID_H_
#define SRC_TINT_OVERRIDE_ID_H_
#include <stdint.h>
namespace tint {
/// OverrideId is a numerical identifier for an override variable, unique per program.
struct OverrideId {
uint16_t value = 0;
};
/// Equality operator for OverrideId
/// @param lhs the OverrideId on the left of the '=' operator
/// @param rhs the OverrideId on the right of the '=' operator
/// @returns true if `lhs` is equal to `rhs`
inline bool operator==(OverrideId lhs, OverrideId rhs) {
return lhs.value == rhs.value;
}
/// Less-than operator for OverrideId
/// @param lhs the OverrideId on the left of the '<' operator
/// @param rhs the OverrideId on the right of the '<' operator
/// @returns true if `lhs` comes before `rhs`
inline bool operator<(OverrideId lhs, OverrideId rhs) {
return lhs.value < rhs.value;
}
} // namespace tint
namespace std {
/// Custom std::hash specialization for tint::OverrideId.
template <>
class hash<tint::OverrideId> {
public:
/// @param id the override identifier
/// @return the hash of the override identifier
inline std::size_t operator()(tint::OverrideId id) const {
return std::hash<decltype(tint::OverrideId::value)>()(id.value);
}
};
} // namespace std
#endif // SRC_TINT_OVERRIDE_ID_H_

View File

@ -66,16 +66,16 @@ MaybeError ValidateProgrammableStage(DeviceBase* device,
// Validate if overridable constants exist in shader module // Validate if overridable constants exist in shader module
// pipelineBase is not yet constructed at this moment so iterate constants from descriptor // pipelineBase is not yet constructed at this moment so iterate constants from descriptor
size_t numUninitializedConstants = metadata.uninitializedOverridableConstants.size(); size_t numUninitializedConstants = metadata.uninitializedOverrides.size();
// Keep an initialized constants sets to handle duplicate initialization cases // Keep an initialized constants sets to handle duplicate initialization cases
std::unordered_set<std::string> stageInitializedConstantIdentifiers; std::unordered_set<std::string> stageInitializedConstantIdentifiers;
for (uint32_t i = 0; i < constantCount; i++) { for (uint32_t i = 0; i < constantCount; i++) {
DAWN_INVALID_IF(metadata.overridableConstants.count(constants[i].key) == 0, DAWN_INVALID_IF(metadata.overrides.count(constants[i].key) == 0,
"Pipeline overridable constant \"%s\" not found in %s.", constants[i].key, "Pipeline overridable constant \"%s\" not found in %s.", constants[i].key,
module); module);
if (stageInitializedConstantIdentifiers.count(constants[i].key) == 0) { if (stageInitializedConstantIdentifiers.count(constants[i].key) == 0) {
if (metadata.uninitializedOverridableConstants.count(constants[i].key) > 0) { if (metadata.uninitializedOverrides.count(constants[i].key) > 0) {
numUninitializedConstants--; numUninitializedConstants--;
} }
stageInitializedConstantIdentifiers.insert(constants[i].key); stageInitializedConstantIdentifiers.insert(constants[i].key);
@ -91,7 +91,7 @@ MaybeError ValidateProgrammableStage(DeviceBase* device,
if (DAWN_UNLIKELY(numUninitializedConstants > 0)) { if (DAWN_UNLIKELY(numUninitializedConstants > 0)) {
std::string uninitializedConstantsArray; std::string uninitializedConstantsArray;
bool isFirst = true; bool isFirst = true;
for (std::string identifier : metadata.uninitializedOverridableConstants) { for (std::string identifier : metadata.uninitializedOverrides) {
if (stageInitializedConstantIdentifiers.count(identifier) > 0) { if (stageInitializedConstantIdentifiers.count(identifier) > 0) {
continue; continue;
} }

View File

@ -358,17 +358,16 @@ ResultOrError<InterpolationSampling> TintInterpolationSamplingToInterpolationSam
UNREACHABLE(); UNREACHABLE();
} }
EntryPointMetadata::OverridableConstant::Type FromTintOverridableConstantType( EntryPointMetadata::Override::Type FromTintOverrideType(tint::inspector::Override::Type type) {
tint::inspector::OverridableConstant::Type type) {
switch (type) { switch (type) {
case tint::inspector::OverridableConstant::Type::kBool: case tint::inspector::Override::Type::kBool:
return EntryPointMetadata::OverridableConstant::Type::Boolean; return EntryPointMetadata::Override::Type::Boolean;
case tint::inspector::OverridableConstant::Type::kFloat32: case tint::inspector::Override::Type::kFloat32:
return EntryPointMetadata::OverridableConstant::Type::Float32; return EntryPointMetadata::Override::Type::Float32;
case tint::inspector::OverridableConstant::Type::kInt32: case tint::inspector::Override::Type::kInt32:
return EntryPointMetadata::OverridableConstant::Type::Int32; return EntryPointMetadata::Override::Type::Int32;
case tint::inspector::OverridableConstant::Type::kUint32: case tint::inspector::Override::Type::kUint32:
return EntryPointMetadata::OverridableConstant::Type::Uint32; return EntryPointMetadata::Override::Type::Uint32;
} }
UNREACHABLE(); UNREACHABLE();
} }
@ -610,17 +609,17 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
return invalid; \ return invalid; \
})() })()
if (!entryPoint.overridable_constants.empty()) { if (!entryPoint.overrides.empty()) {
DAWN_INVALID_IF(device->IsToggleEnabled(Toggle::DisallowUnsafeAPIs), DAWN_INVALID_IF(device->IsToggleEnabled(Toggle::DisallowUnsafeAPIs),
"Pipeline overridable constants are disallowed because they " "Pipeline overridable constants are disallowed because they "
"are partially implemented."); "are partially implemented.");
const auto& name2Id = inspector->GetConstantNameToIdMap(); const auto& name2Id = inspector->GetNamedOverrideIds();
const auto& id2Scalar = inspector->GetConstantIDs(); const auto& id2Scalar = inspector->GetOverrideDefaultValues();
for (auto& c : entryPoint.overridable_constants) { for (auto& c : entryPoint.overrides) {
uint32_t id = name2Id.at(c.name); auto id = name2Id.at(c.name);
OverridableConstantScalar defaultValue; OverrideScalar defaultValue;
if (c.is_initialized) { if (c.is_initialized) {
// if it is initialized, the scalar must exist // if it is initialized, the scalar must exist
const auto& scalar = id2Scalar.at(id); const auto& scalar = id2Scalar.at(id);
@ -636,21 +635,19 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
UNREACHABLE(); UNREACHABLE();
} }
} }
EntryPointMetadata::OverridableConstant constant = { EntryPointMetadata::Override override = {id.value, FromTintOverrideType(c.type),
id, FromTintOverridableConstantType(c.type), c.is_initialized, defaultValue}; c.is_initialized, defaultValue};
std::string identifier = std::string identifier = c.is_id_specified ? std::to_string(override.id) : c.name;
c.is_numeric_id_specified ? std::to_string(constant.id) : c.name; metadata->overrides[identifier] = override;
metadata->overridableConstants[identifier] = constant;
if (!c.is_initialized) { if (!c.is_initialized) {
auto [_, inserted] = auto [_, inserted] =
metadata->uninitializedOverridableConstants.emplace(std::move(identifier)); metadata->uninitializedOverrides.emplace(std::move(identifier));
// The insertion should have taken place // The insertion should have taken place
ASSERT(inserted); ASSERT(inserted);
} else { } else {
auto [_, inserted] = auto [_, inserted] = metadata->initializedOverrides.emplace(std::move(identifier));
metadata->initializedOverridableConstants.emplace(std::move(identifier));
// The insertion should have taken place // The insertion should have taken place
ASSERT(inserted); ASSERT(inserted);
} }

View File

@ -155,8 +155,8 @@ struct ShaderBindingInfo {
using BindingGroupInfoMap = std::map<BindingNumber, ShaderBindingInfo>; using BindingGroupInfoMap = std::map<BindingNumber, ShaderBindingInfo>;
using BindingInfoArray = ityp::array<BindGroupIndex, BindingGroupInfoMap, kMaxBindGroups>; using BindingInfoArray = ityp::array<BindGroupIndex, BindingGroupInfoMap, kMaxBindGroups>;
// The WebGPU overridable constants only support these scalar types // The WebGPU override variables only support these scalar types
union OverridableConstantScalar { union OverrideScalar {
// Use int32_t for boolean to initialize the full 32bit // Use int32_t for boolean to initialize the full 32bit
int32_t b; int32_t b;
float f32; float f32;
@ -216,9 +216,9 @@ struct EntryPointMetadata {
// The shader stage for this binding. // The shader stage for this binding.
SingleShaderStage stage; SingleShaderStage stage;
struct OverridableConstant { struct Override {
uint32_t id; uint32_t id;
// Match tint::inspector::OverridableConstant::Type // Match tint::inspector::Override::Type
// Bool is defined as a macro on linux X11 and cannot compile // Bool is defined as a macro on linux X11 and cannot compile
enum class Type { Boolean, Float32, Uint32, Int32 } type; enum class Type { Boolean, Float32, Uint32, Int32 } type;
@ -230,23 +230,23 @@ struct EntryPointMetadata {
// Store the default initialized value in shader // Store the default initialized value in shader
// This is used by metal backend as the function_constant does not have dafault values // This is used by metal backend as the function_constant does not have dafault values
// Initialized when isInitialized == true // Initialized when isInitialized == true
OverridableConstantScalar defaultValue; OverrideScalar defaultValue;
}; };
using OverridableConstantsMap = std::unordered_map<std::string, OverridableConstant>; using OverridesMap = std::unordered_map<std::string, Override>;
// Map identifier to overridable constant // Map identifier to override variable
// Identifier is unique: either the variable name or the numeric ID if specified // Identifier is unique: either the variable name or the numeric ID if specified
OverridableConstantsMap overridableConstants; OverridesMap overrides;
// Overridable constants that are not initialized in shaders // Override variables that are not initialized in shaders
// They need value initialization from pipeline stage or it is a validation error // They need value initialization from pipeline stage or it is a validation error
std::unordered_set<std::string> uninitializedOverridableConstants; std::unordered_set<std::string> uninitializedOverrides;
// Store constants with shader initialized values as well // Store constants with shader initialized values as well
// This is used by metal backend to set values with default initializers that are not // This is used by metal backend to set values with default initializers that are not
// overridden // overridden
std::unordered_set<std::string> initializedOverridableConstants; std::unordered_set<std::string> initializedOverrides;
bool usesNumWorkgroups = false; bool usesNumWorkgroups = false;
// Used at render pipeline validation. // Used at render pipeline validation.

View File

@ -111,17 +111,17 @@ std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) {
return out.str(); return out.str();
} }
std::string GetHLSLValueString(EntryPointMetadata::OverridableConstant::Type dawnType, std::string GetHLSLValueString(EntryPointMetadata::Override::Type dawnType,
const OverridableConstantScalar* entry, const OverrideScalar* entry,
double value = 0) { double value = 0) {
switch (dawnType) { switch (dawnType) {
case EntryPointMetadata::OverridableConstant::Type::Boolean: case EntryPointMetadata::Override::Type::Boolean:
return std::to_string(entry ? entry->b : static_cast<int32_t>(value)); return std::to_string(entry ? entry->b : static_cast<int32_t>(value));
case EntryPointMetadata::OverridableConstant::Type::Float32: case EntryPointMetadata::Override::Type::Float32:
return FloatToStringWithPrecision(entry ? entry->f32 : static_cast<float>(value)); return FloatToStringWithPrecision(entry ? entry->f32 : static_cast<float>(value));
case EntryPointMetadata::OverridableConstant::Type::Int32: case EntryPointMetadata::Override::Type::Int32:
return std::to_string(entry ? entry->i32 : static_cast<int32_t>(value)); return std::to_string(entry ? entry->i32 : static_cast<int32_t>(value));
case EntryPointMetadata::OverridableConstant::Type::Uint32: case EntryPointMetadata::Override::Type::Uint32:
return std::to_string(entry ? entry->u32 : static_cast<uint32_t>(value)); return std::to_string(entry ? entry->u32 : static_cast<uint32_t>(value));
default: default:
UNREACHABLE(); UNREACHABLE();
@ -133,7 +133,7 @@ constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
void GetOverridableConstantsDefines( void GetOverridableConstantsDefines(
std::vector<std::pair<std::string, std::string>>* defineStrings, std::vector<std::pair<std::string, std::string>>* defineStrings,
const PipelineConstantEntries* pipelineConstantEntries, const PipelineConstantEntries* pipelineConstantEntries,
const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants) { const EntryPointMetadata::OverridesMap* shaderEntryPointConstants) {
std::unordered_set<std::string> overriddenConstants; std::unordered_set<std::string> overriddenConstants;
// Set pipeline overridden values // Set pipeline overridden values
@ -305,8 +305,7 @@ struct ShaderCompilationRequest {
GetOverridableConstantsDefines( GetOverridableConstantsDefines(
&request.defineStrings, &programmableStage.constants, &request.defineStrings, &programmableStage.constants,
&programmableStage.module->GetEntryPoint(programmableStage.entryPoint) &programmableStage.module->GetEntryPoint(programmableStage.entryPoint).overrides);
.overridableConstants);
return std::move(request); return std::move(request);
} }

View File

@ -332,7 +332,7 @@ MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage,
ShaderModule* shaderModule = ToBackend(programmableStage.module.Get()); ShaderModule* shaderModule = ToBackend(programmableStage.module.Get());
const char* shaderEntryPoint = programmableStage.entryPoint.c_str(); const char* shaderEntryPoint = programmableStage.entryPoint.c_str();
const auto& entryPointMetadata = programmableStage.module->GetEntryPoint(shaderEntryPoint); const auto& entryPointMetadata = programmableStage.module->GetEntryPoint(shaderEntryPoint);
if (entryPointMetadata.overridableConstants.size() == 0) { if (entryPointMetadata.overrides.size() == 0) {
DAWN_TRY(shaderModule->CreateFunction(shaderEntryPoint, singleShaderStage, pipelineLayout, DAWN_TRY(shaderModule->CreateFunction(shaderEntryPoint, singleShaderStage, pipelineLayout,
functionData, nil, sampleMask, renderPipeline)); functionData, nil, sampleMask, renderPipeline));
return {}; return {};
@ -345,29 +345,29 @@ MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage,
std::unordered_set<std::string> overriddenConstants; std::unordered_set<std::string> overriddenConstants;
auto switchType = [&](EntryPointMetadata::OverridableConstant::Type dawnType, auto switchType = [&](EntryPointMetadata::Override::Type dawnType,
MTLDataType* type, OverridableConstantScalar* entry, MTLDataType* type, OverrideScalar* entry,
double value = 0) { double value = 0) {
switch (dawnType) { switch (dawnType) {
case EntryPointMetadata::OverridableConstant::Type::Boolean: case EntryPointMetadata::Override::Type::Boolean:
*type = MTLDataTypeBool; *type = MTLDataTypeBool;
if (entry) { if (entry) {
entry->b = static_cast<int32_t>(value); entry->b = static_cast<int32_t>(value);
} }
break; break;
case EntryPointMetadata::OverridableConstant::Type::Float32: case EntryPointMetadata::Override::Type::Float32:
*type = MTLDataTypeFloat; *type = MTLDataTypeFloat;
if (entry) { if (entry) {
entry->f32 = static_cast<float>(value); entry->f32 = static_cast<float>(value);
} }
break; break;
case EntryPointMetadata::OverridableConstant::Type::Int32: case EntryPointMetadata::Override::Type::Int32:
*type = MTLDataTypeInt; *type = MTLDataTypeInt;
if (entry) { if (entry) {
entry->i32 = static_cast<int32_t>(value); entry->i32 = static_cast<int32_t>(value);
} }
break; break;
case EntryPointMetadata::OverridableConstant::Type::Uint32: case EntryPointMetadata::Override::Type::Uint32:
*type = MTLDataTypeUInt; *type = MTLDataTypeUInt;
if (entry) { if (entry) {
entry->u32 = static_cast<uint32_t>(value); entry->u32 = static_cast<uint32_t>(value);
@ -382,10 +382,10 @@ MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage,
overriddenConstants.insert(name); overriddenConstants.insert(name);
// This is already validated so `name` must exist // This is already validated so `name` must exist
const auto& moduleConstant = entryPointMetadata.overridableConstants.at(name); const auto& moduleConstant = entryPointMetadata.overrides.at(name);
MTLDataType type; MTLDataType type;
OverridableConstantScalar entry{}; OverrideScalar entry{};
switchType(moduleConstant.type, &type, &entry, value); switchType(moduleConstant.type, &type, &entry, value);
@ -394,14 +394,14 @@ MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage,
// Set shader initialized default values because MSL function_constant // Set shader initialized default values because MSL function_constant
// has no default value // has no default value
for (const std::string& name : entryPointMetadata.initializedOverridableConstants) { for (const std::string& name : entryPointMetadata.initializedOverrides) {
if (overriddenConstants.count(name) != 0) { if (overriddenConstants.count(name) != 0) {
// This constant already has overridden value // This constant already has overridden value
continue; continue;
} }
// Must exist because it is validated // Must exist because it is validated
const auto& moduleConstant = entryPointMetadata.overridableConstants.at(name); const auto& moduleConstant = entryPointMetadata.overrides.at(name);
ASSERT(moduleConstant.isInitialized); ASSERT(moduleConstant.isInitialized);
MTLDataType type; MTLDataType type;

View File

@ -66,7 +66,7 @@ MaybeError ComputePipeline::Initialize() {
createInfo.stage.module = moduleAndSpirv.module; createInfo.stage.module = moduleAndSpirv.module;
createInfo.stage.pName = computeStage.entryPoint.c_str(); createInfo.stage.pName = computeStage.entryPoint.c_str();
std::vector<OverridableConstantScalar> specializationDataEntries; std::vector<OverrideScalar> specializationDataEntries;
std::vector<VkSpecializationMapEntry> specializationMapEntries; std::vector<VkSpecializationMapEntry> specializationMapEntries;
VkSpecializationInfo specializationInfo{}; VkSpecializationInfo specializationInfo{};
createInfo.stage.pSpecializationInfo = GetVkSpecializationInfo( createInfo.stage.pSpecializationInfo = GetVkSpecializationInfo(

View File

@ -341,7 +341,7 @@ MaybeError RenderPipeline::Initialize() {
// There are at most 2 shader stages in render pipeline, i.e. vertex and fragment // There are at most 2 shader stages in render pipeline, i.e. vertex and fragment
std::array<VkPipelineShaderStageCreateInfo, 2> shaderStages; std::array<VkPipelineShaderStageCreateInfo, 2> shaderStages;
std::array<std::vector<OverridableConstantScalar>, 2> specializationDataEntriesPerStages; std::array<std::vector<OverrideScalar>, 2> specializationDataEntriesPerStages;
std::array<std::vector<VkSpecializationMapEntry>, 2> specializationMapEntriesPerStages; std::array<std::vector<VkSpecializationMapEntry>, 2> specializationMapEntriesPerStages;
std::array<VkSpecializationInfo, 2> specializationInfoPerStages; std::array<VkSpecializationInfo, 2> specializationInfoPerStages;
uint32_t stageCount = 0; uint32_t stageCount = 0;

View File

@ -261,7 +261,7 @@ std::string GetDeviceDebugPrefixFromDebugName(const char* debugName) {
VkSpecializationInfo* GetVkSpecializationInfo( VkSpecializationInfo* GetVkSpecializationInfo(
const ProgrammableStage& programmableStage, const ProgrammableStage& programmableStage,
VkSpecializationInfo* specializationInfo, VkSpecializationInfo* specializationInfo,
std::vector<OverridableConstantScalar>* specializationDataEntries, std::vector<OverrideScalar>* specializationDataEntries,
std::vector<VkSpecializationMapEntry>* specializationMapEntries) { std::vector<VkSpecializationMapEntry>* specializationMapEntries) {
ASSERT(specializationInfo); ASSERT(specializationInfo);
ASSERT(specializationDataEntries); ASSERT(specializationDataEntries);
@ -279,26 +279,25 @@ VkSpecializationInfo* GetVkSpecializationInfo(
double value = pipelineConstant.second; double value = pipelineConstant.second;
// This is already validated so `identifier` must exist // This is already validated so `identifier` must exist
const auto& moduleConstant = entryPointMetaData.overridableConstants.at(identifier); const auto& moduleConstant = entryPointMetaData.overrides.at(identifier);
specializationMapEntries->push_back( specializationMapEntries->push_back(VkSpecializationMapEntry{
VkSpecializationMapEntry{moduleConstant.id, moduleConstant.id,
static_cast<uint32_t>(specializationDataEntries->size() * static_cast<uint32_t>(specializationDataEntries->size() * sizeof(OverrideScalar)),
sizeof(OverridableConstantScalar)), sizeof(OverrideScalar)});
sizeof(OverridableConstantScalar)});
OverridableConstantScalar entry{}; OverrideScalar entry{};
switch (moduleConstant.type) { switch (moduleConstant.type) {
case EntryPointMetadata::OverridableConstant::Type::Boolean: case EntryPointMetadata::Override::Type::Boolean:
entry.b = static_cast<int32_t>(value); entry.b = static_cast<int32_t>(value);
break; break;
case EntryPointMetadata::OverridableConstant::Type::Float32: case EntryPointMetadata::Override::Type::Float32:
entry.f32 = static_cast<float>(value); entry.f32 = static_cast<float>(value);
break; break;
case EntryPointMetadata::OverridableConstant::Type::Int32: case EntryPointMetadata::Override::Type::Int32:
entry.i32 = static_cast<int32_t>(value); entry.i32 = static_cast<int32_t>(value);
break; break;
case EntryPointMetadata::OverridableConstant::Type::Uint32: case EntryPointMetadata::Override::Type::Uint32:
entry.u32 = static_cast<uint32_t>(value); entry.u32 = static_cast<uint32_t>(value);
break; break;
default: default:
@ -309,8 +308,7 @@ VkSpecializationInfo* GetVkSpecializationInfo(
specializationInfo->mapEntryCount = static_cast<uint32_t>(specializationMapEntries->size()); specializationInfo->mapEntryCount = static_cast<uint32_t>(specializationMapEntries->size());
specializationInfo->pMapEntries = specializationMapEntries->data(); specializationInfo->pMapEntries = specializationMapEntries->data();
specializationInfo->dataSize = specializationInfo->dataSize = specializationDataEntries->size() * sizeof(OverrideScalar);
specializationDataEntries->size() * sizeof(OverridableConstantScalar);
specializationInfo->pData = specializationDataEntries->data(); specializationInfo->pData = specializationDataEntries->data();
return specializationInfo; return specializationInfo;

View File

@ -24,7 +24,7 @@
namespace dawn::native { namespace dawn::native {
struct ProgrammableStage; struct ProgrammableStage;
union OverridableConstantScalar; union OverrideScalar;
} // namespace dawn::native } // namespace dawn::native
namespace dawn::native::vulkan { namespace dawn::native::vulkan {
@ -150,7 +150,7 @@ std::string GetDeviceDebugPrefixFromDebugName(const char* debugName);
VkSpecializationInfo* GetVkSpecializationInfo( VkSpecializationInfo* GetVkSpecializationInfo(
const ProgrammableStage& programmableStage, const ProgrammableStage& programmableStage,
VkSpecializationInfo* specializationInfo, VkSpecializationInfo* specializationInfo,
std::vector<OverridableConstantScalar>* specializationDataEntries, std::vector<OverrideScalar>* specializationDataEntries,
std::vector<VkSpecializationMapEntry>* specializationMapEntries); std::vector<VkSpecializationMapEntry>* specializationMapEntries);
} // namespace dawn::native::vulkan } // namespace dawn::native::vulkan

View File

@ -1110,7 +1110,7 @@ if (tint_build_unittests) {
"resolver/is_host_shareable_test.cc", "resolver/is_host_shareable_test.cc",
"resolver/is_storeable_test.cc", "resolver/is_storeable_test.cc",
"resolver/materialize_test.cc", "resolver/materialize_test.cc",
"resolver/pipeline_overridable_constant_test.cc", "resolver/override_test.cc",
"resolver/ptr_ref_test.cc", "resolver/ptr_ref_test.cc",
"resolver/ptr_ref_validation_test.cc", "resolver/ptr_ref_validation_test.cc",
"resolver/resolver_behavior_test.cc", "resolver/resolver_behavior_test.cc",

View File

@ -794,7 +794,7 @@ if(TINT_BUILD_TESTS)
resolver/is_host_shareable_test.cc resolver/is_host_shareable_test.cc
resolver/is_storeable_test.cc resolver/is_storeable_test.cc
resolver/materialize_test.cc resolver/materialize_test.cc
resolver/pipeline_overridable_constant_test.cc resolver/override_test.cc
resolver/ptr_ref_test.cc resolver/ptr_ref_test.cc
resolver/ptr_ref_validation_test.cc resolver/ptr_ref_validation_test.cc
resolver/resolver_behavior_test.cc resolver/resolver_behavior_test.cc

View File

@ -268,10 +268,10 @@ void CommonFuzzer::RunInspector(Program* program) {
auto entry_points = inspector.GetEntryPoints(); auto entry_points = inspector.GetEntryPoints();
CHECK_INSPECTOR(program, inspector); CHECK_INSPECTOR(program, inspector);
auto constant_ids = inspector.GetConstantIDs(); auto override_ids = inspector.GetOverrideDefaultValues();
CHECK_INSPECTOR(program, inspector); CHECK_INSPECTOR(program, inspector);
auto constant_name_to_id = inspector.GetConstantNameToIdMap(); auto override_name_to_id = inspector.GetNamedOverrideIds();
CHECK_INSPECTOR(program, inspector); CHECK_INSPECTOR(program, inspector);
for (auto& ep : entry_points) { for (auto& ep : entry_points) {

View File

@ -20,6 +20,8 @@
#include <tuple> #include <tuple>
#include <vector> #include <vector>
#include "tint/override_id.h"
#include "src/tint/ast/interpolate_attribute.h" #include "src/tint/ast/interpolate_attribute.h"
#include "src/tint/ast/pipeline_stage.h" #include "src/tint/ast/pipeline_stage.h"
@ -93,14 +95,13 @@ InterpolationType ASTToInspectorInterpolationType(ast::InterpolationType ast_typ
/// @returns the publicly visible equivalent /// @returns the publicly visible equivalent
InterpolationSampling ASTToInspectorInterpolationSampling(ast::InterpolationSampling sampling); InterpolationSampling ASTToInspectorInterpolationSampling(ast::InterpolationSampling sampling);
/// Reflection data about a pipeline overridable constant referenced by an entry /// Reflection data about an override variable referenced by an entry point
/// point struct Override {
struct OverridableConstant { /// Name of the override
/// Name of the constant
std::string name; std::string name;
/// ID of the constant /// ID of the override
uint16_t numeric_id; OverrideId id;
/// Type of the scalar /// Type of the scalar
enum class Type { enum class Type {
@ -113,12 +114,11 @@ struct OverridableConstant {
/// Type of the scalar /// Type of the scalar
Type type; Type type;
/// Does this pipeline overridable constant have an initializer? /// Does this override have an initializer?
bool is_initialized = false; bool is_initialized = false;
/// Does this pipeline overridable constant have a numeric ID specified /// Does this override have a numeric ID specified explicitly?
/// explicitly? bool is_id_specified = false;
bool is_numeric_id_specified = false;
}; };
/// The pipeline stage /// The pipeline stage
@ -159,7 +159,7 @@ struct EntryPoint {
/// List of the output variable accessed via this entry point. /// List of the output variable accessed via this entry point.
std::vector<StageVariable> output_variables; std::vector<StageVariable> output_variables;
/// List of the pipeline overridable constants accessed via this entry point. /// List of the pipeline overridable constants accessed via this entry point.
std::vector<OverridableConstant> overridable_constants; std::vector<Override> overrides;
/// Does the entry point use the sample_mask builtin as an input builtin /// Does the entry point use the sample_mask builtin as an input builtin
/// variable. /// variable.
bool input_sample_mask_used = false; bool input_sample_mask_used = false;

View File

@ -206,28 +206,28 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
auto* global = var->As<sem::GlobalVariable>(); auto* global = var->As<sem::GlobalVariable>();
if (global && global->Declaration()->Is<ast::Override>()) { if (global && global->Declaration()->Is<ast::Override>()) {
OverridableConstant overridable_constant; Override override;
overridable_constant.name = name; override.name = name;
overridable_constant.numeric_id = global->ConstantId(); override.id = global->OverrideId();
auto* type = var->Type(); auto* type = var->Type();
TINT_ASSERT(Inspector, type->is_scalar()); TINT_ASSERT(Inspector, type->is_scalar());
if (type->is_bool_scalar_or_vector()) { if (type->is_bool_scalar_or_vector()) {
overridable_constant.type = OverridableConstant::Type::kBool; override.type = Override::Type::kBool;
} else if (type->is_float_scalar()) { } else if (type->is_float_scalar()) {
overridable_constant.type = OverridableConstant::Type::kFloat32; override.type = Override::Type::kFloat32;
} else if (type->is_signed_integer_scalar()) { } else if (type->is_signed_integer_scalar()) {
overridable_constant.type = OverridableConstant::Type::kInt32; override.type = Override::Type::kInt32;
} else if (type->is_unsigned_integer_scalar()) { } else if (type->is_unsigned_integer_scalar()) {
overridable_constant.type = OverridableConstant::Type::kUint32; override.type = Override::Type::kUint32;
} else { } else {
TINT_UNREACHABLE(Inspector, diagnostics_); TINT_UNREACHABLE(Inspector, diagnostics_);
} }
overridable_constant.is_initialized = global->Declaration()->constructor; override.is_initialized = global->Declaration()->constructor;
overridable_constant.is_numeric_id_specified = override.is_id_specified =
ast::HasAttribute<ast::IdAttribute>(global->Declaration()->attributes); ast::HasAttribute<ast::IdAttribute>(global->Declaration()->attributes);
entry_point.overridable_constants.push_back(overridable_constant); entry_point.overrides.push_back(override);
} }
} }
@ -237,37 +237,37 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
return result; return result;
} }
std::map<uint32_t, Scalar> Inspector::GetConstantIDs() { std::map<OverrideId, Scalar> Inspector::GetOverrideDefaultValues() {
std::map<uint32_t, Scalar> result; std::map<OverrideId, Scalar> result;
for (auto* var : program_->AST().GlobalVariables()) { for (auto* var : program_->AST().GlobalVariables()) {
auto* global = program_->Sem().Get<sem::GlobalVariable>(var); auto* global = program_->Sem().Get<sem::GlobalVariable>(var);
if (!global || !global->Declaration()->Is<ast::Override>()) { if (!global || !global->Declaration()->Is<ast::Override>()) {
continue; continue;
} }
// If there are conflicting defintions for a constant id, that is invalid // If there are conflicting defintions for an override id, that is invalid
// WGSL, so the resolver should catch it. Thus here the inspector just // WGSL, so the resolver should catch it. Thus here the inspector just
// assumes all definitions of the constant id are the same, so only needs // assumes all definitions of the override id are the same, so only needs
// to find the first reference to constant id. // to find the first reference to override id.
uint32_t constant_id = global->ConstantId(); OverrideId override_id = global->OverrideId();
if (result.find(constant_id) != result.end()) { if (result.find(override_id) != result.end()) {
continue; continue;
} }
if (!var->constructor) { if (!var->constructor) {
result[constant_id] = Scalar(); result[override_id] = Scalar();
continue; continue;
} }
auto* literal = var->constructor->As<ast::LiteralExpression>(); auto* literal = var->constructor->As<ast::LiteralExpression>();
if (!literal) { if (!literal) {
// This is invalid WGSL, but handling gracefully. // This is invalid WGSL, but handling gracefully.
result[constant_id] = Scalar(); result[override_id] = Scalar();
continue; continue;
} }
if (auto* l = literal->As<ast::BoolLiteralExpression>()) { if (auto* l = literal->As<ast::BoolLiteralExpression>()) {
result[constant_id] = Scalar(l->value); result[override_id] = Scalar(l->value);
continue; continue;
} }
@ -275,32 +275,32 @@ std::map<uint32_t, Scalar> Inspector::GetConstantIDs() {
switch (l->suffix) { switch (l->suffix) {
case ast::IntLiteralExpression::Suffix::kNone: case ast::IntLiteralExpression::Suffix::kNone:
case ast::IntLiteralExpression::Suffix::kI: case ast::IntLiteralExpression::Suffix::kI:
result[constant_id] = Scalar(static_cast<int32_t>(l->value)); result[override_id] = Scalar(static_cast<int32_t>(l->value));
continue; continue;
case ast::IntLiteralExpression::Suffix::kU: case ast::IntLiteralExpression::Suffix::kU:
result[constant_id] = Scalar(static_cast<uint32_t>(l->value)); result[override_id] = Scalar(static_cast<uint32_t>(l->value));
continue; continue;
} }
} }
if (auto* l = literal->As<ast::FloatLiteralExpression>()) { if (auto* l = literal->As<ast::FloatLiteralExpression>()) {
result[constant_id] = Scalar(static_cast<float>(l->value)); result[override_id] = Scalar(static_cast<float>(l->value));
continue; continue;
} }
result[constant_id] = Scalar(); result[override_id] = Scalar();
} }
return result; return result;
} }
std::map<std::string, uint32_t> Inspector::GetConstantNameToIdMap() { std::map<std::string, OverrideId> Inspector::GetNamedOverrideIds() {
std::map<std::string, uint32_t> result; std::map<std::string, OverrideId> result;
for (auto* var : program_->AST().GlobalVariables()) { for (auto* var : program_->AST().GlobalVariables()) {
auto* global = program_->Sem().Get<sem::GlobalVariable>(var); auto* global = program_->Sem().Get<sem::GlobalVariable>(var);
if (global && global->Declaration()->Is<ast::Override>()) { if (global && global->Declaration()->Is<ast::Override>()) {
auto name = program_->Symbols().NameFor(var->symbol); auto name = program_->Symbols().NameFor(var->symbol);
result[name] = global->ConstantId(); result[name] = global->OverrideId();
} }
} }
return result; return result;

View File

@ -23,6 +23,8 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "tint/override_id.h"
#include "src/tint/inspector/entry_point.h" #include "src/tint/inspector/entry_point.h"
#include "src/tint/inspector/resource_binding.h" #include "src/tint/inspector/resource_binding.h"
#include "src/tint/inspector/scalar.h" #include "src/tint/inspector/scalar.h"
@ -53,11 +55,11 @@ class Inspector {
/// @returns vector of entry point information /// @returns vector of entry point information
std::vector<EntryPoint> GetEntryPoints(); std::vector<EntryPoint> GetEntryPoints();
/// @returns map of const_id to initial value /// @returns map of override identifier to initial value
std::map<uint32_t, Scalar> GetConstantIDs(); std::map<OverrideId, Scalar> GetOverrideDefaultValues();
/// @returns map of module-constant name to pipeline constant ID /// @returns map of module-constant name to pipeline constant ID
std::map<std::string, uint32_t> GetConstantNameToIdMap(); std::map<std::string, OverrideId> GetNamedOverrideIds();
/// @param entry_point name of the entry point to get information about. /// @param entry_point name of the entry point to get information about.
/// @returns the total size of shared storage required by an entry point, /// @returns the total size of shared storage required by an entry point,

View File

@ -60,7 +60,7 @@ struct InspectorGetEntryPointInterpolateTestParams {
class InspectorGetEntryPointInterpolateTest class InspectorGetEntryPointInterpolateTest
: public InspectorBuilder, : public InspectorBuilder,
public testing::TestWithParam<InspectorGetEntryPointInterpolateTestParams> {}; public testing::TestWithParam<InspectorGetEntryPointInterpolateTestParams> {};
class InspectorGetConstantIDsTest : public InspectorBuilder, public testing::Test {}; class InspectorGetOverrideDefaultValuesTest : public InspectorBuilder, public testing::Test {};
class InspectorGetConstantNameToIdMapTest : public InspectorBuilder, public testing::Test {}; class InspectorGetConstantNameToIdMapTest : public InspectorBuilder, public testing::Test {};
class InspectorGetStorageSizeTest : public InspectorBuilder, public testing::Test {}; class InspectorGetStorageSizeTest : public InspectorBuilder, public testing::Test {};
class InspectorGetResourceBindingsTest : public InspectorBuilder, public testing::Test {}; class InspectorGetResourceBindingsTest : public InspectorBuilder, public testing::Test {};
@ -605,8 +605,8 @@ TEST_F(InspectorGetEntryPointTest, MixInOutVariablesAndStruct) {
EXPECT_EQ(ComponentType::kUInt, result[0].output_variables[1].component_type); EXPECT_EQ(ComponentType::kUInt, result[0].output_variables[1].component_type);
} }
TEST_F(InspectorGetEntryPointTest, OverridableConstantUnreferenced) { TEST_F(InspectorGetEntryPointTest, OverrideUnreferenced) {
AddOverridableConstantWithoutID("foo", ty.f32(), nullptr); Override("foo", ty.f32(), nullptr);
MakeEmptyBodyFunction("ep_func", { MakeEmptyBodyFunction("ep_func", {
Stage(ast::PipelineStage::kCompute), Stage(ast::PipelineStage::kCompute),
WorkgroupSize(1_i), WorkgroupSize(1_i),
@ -617,11 +617,11 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantUnreferenced) {
auto result = inspector.GetEntryPoints(); auto result = inspector.GetEntryPoints();
ASSERT_EQ(1u, result.size()); ASSERT_EQ(1u, result.size());
EXPECT_EQ(0u, result[0].overridable_constants.size()); EXPECT_EQ(0u, result[0].overrides.size());
} }
TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByEntryPoint) { TEST_F(InspectorGetEntryPointTest, OverrideReferencedByEntryPoint) {
AddOverridableConstantWithoutID("foo", ty.f32(), nullptr); Override("foo", ty.f32(), nullptr);
MakePlainGlobalReferenceBodyFunction("ep_func", "foo", ty.f32(), MakePlainGlobalReferenceBodyFunction("ep_func", "foo", ty.f32(),
{ {
Stage(ast::PipelineStage::kCompute), Stage(ast::PipelineStage::kCompute),
@ -633,12 +633,12 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByEntryPoint) {
auto result = inspector.GetEntryPoints(); auto result = inspector.GetEntryPoints();
ASSERT_EQ(1u, result.size()); ASSERT_EQ(1u, result.size());
ASSERT_EQ(1u, result[0].overridable_constants.size()); ASSERT_EQ(1u, result[0].overrides.size());
EXPECT_EQ("foo", result[0].overridable_constants[0].name); EXPECT_EQ("foo", result[0].overrides[0].name);
} }
TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByCallee) { TEST_F(InspectorGetEntryPointTest, OverrideReferencedByCallee) {
AddOverridableConstantWithoutID("foo", ty.f32(), nullptr); Override("foo", ty.f32(), nullptr);
MakePlainGlobalReferenceBodyFunction("callee_func", "foo", ty.f32(), {}); MakePlainGlobalReferenceBodyFunction("callee_func", "foo", ty.f32(), {});
MakeCallerBodyFunction("ep_func", {"callee_func"}, MakeCallerBodyFunction("ep_func", {"callee_func"},
{ {
@ -651,13 +651,13 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByCallee) {
auto result = inspector.GetEntryPoints(); auto result = inspector.GetEntryPoints();
ASSERT_EQ(1u, result.size()); ASSERT_EQ(1u, result.size());
ASSERT_EQ(1u, result[0].overridable_constants.size()); ASSERT_EQ(1u, result[0].overrides.size());
EXPECT_EQ("foo", result[0].overridable_constants[0].name); EXPECT_EQ("foo", result[0].overrides[0].name);
} }
TEST_F(InspectorGetEntryPointTest, OverridableConstantSomeReferenced) { TEST_F(InspectorGetEntryPointTest, OverrideSomeReferenced) {
AddOverridableConstantWithID("foo", 1, ty.f32(), nullptr); Override("foo", ty.f32(), nullptr, {Id(1)});
AddOverridableConstantWithID("bar", 2, ty.f32(), nullptr); Override("bar", ty.f32(), nullptr, {Id(2)});
MakePlainGlobalReferenceBodyFunction("callee_func", "foo", ty.f32(), {}); MakePlainGlobalReferenceBodyFunction("callee_func", "foo", ty.f32(), {});
MakeCallerBodyFunction("ep_func", {"callee_func"}, MakeCallerBodyFunction("ep_func", {"callee_func"},
{ {
@ -670,16 +670,16 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantSomeReferenced) {
auto result = inspector.GetEntryPoints(); auto result = inspector.GetEntryPoints();
ASSERT_EQ(1u, result.size()); ASSERT_EQ(1u, result.size());
ASSERT_EQ(1u, result[0].overridable_constants.size()); ASSERT_EQ(1u, result[0].overrides.size());
EXPECT_EQ("foo", result[0].overridable_constants[0].name); EXPECT_EQ("foo", result[0].overrides[0].name);
EXPECT_EQ(1, result[0].overridable_constants[0].numeric_id); EXPECT_EQ(1, result[0].overrides[0].id.value);
} }
TEST_F(InspectorGetEntryPointTest, OverridableConstantTypes) { TEST_F(InspectorGetEntryPointTest, OverrideTypes) {
AddOverridableConstantWithoutID("bool_var", ty.bool_(), nullptr); Override("bool_var", ty.bool_(), nullptr);
AddOverridableConstantWithoutID("float_var", ty.f32(), nullptr); Override("float_var", ty.f32(), nullptr);
AddOverridableConstantWithoutID("u32_var", ty.u32(), nullptr); Override("u32_var", ty.u32(), nullptr);
AddOverridableConstantWithoutID("i32_var", ty.i32(), nullptr); Override("i32_var", ty.i32(), nullptr);
MakePlainGlobalReferenceBodyFunction("bool_func", "bool_var", ty.bool_(), {}); MakePlainGlobalReferenceBodyFunction("bool_func", "bool_var", ty.bool_(), {});
MakePlainGlobalReferenceBodyFunction("float_func", "float_var", ty.f32(), {}); MakePlainGlobalReferenceBodyFunction("float_func", "float_var", ty.f32(), {});
@ -697,22 +697,19 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantTypes) {
auto result = inspector.GetEntryPoints(); auto result = inspector.GetEntryPoints();
ASSERT_EQ(1u, result.size()); ASSERT_EQ(1u, result.size());
ASSERT_EQ(4u, result[0].overridable_constants.size()); ASSERT_EQ(4u, result[0].overrides.size());
EXPECT_EQ("bool_var", result[0].overridable_constants[0].name); EXPECT_EQ("bool_var", result[0].overrides[0].name);
EXPECT_EQ(inspector::OverridableConstant::Type::kBool, result[0].overridable_constants[0].type); EXPECT_EQ(inspector::Override::Type::kBool, result[0].overrides[0].type);
EXPECT_EQ("float_var", result[0].overridable_constants[1].name); EXPECT_EQ("float_var", result[0].overrides[1].name);
EXPECT_EQ(inspector::OverridableConstant::Type::kFloat32, EXPECT_EQ(inspector::Override::Type::kFloat32, result[0].overrides[1].type);
result[0].overridable_constants[1].type); EXPECT_EQ("u32_var", result[0].overrides[2].name);
EXPECT_EQ("u32_var", result[0].overridable_constants[2].name); EXPECT_EQ(inspector::Override::Type::kUint32, result[0].overrides[2].type);
EXPECT_EQ(inspector::OverridableConstant::Type::kUint32, EXPECT_EQ("i32_var", result[0].overrides[3].name);
result[0].overridable_constants[2].type); EXPECT_EQ(inspector::Override::Type::kInt32, result[0].overrides[3].type);
EXPECT_EQ("i32_var", result[0].overridable_constants[3].name);
EXPECT_EQ(inspector::OverridableConstant::Type::kInt32,
result[0].overridable_constants[3].type);
} }
TEST_F(InspectorGetEntryPointTest, OverridableConstantInitialized) { TEST_F(InspectorGetEntryPointTest, OverrideInitialized) {
AddOverridableConstantWithoutID("foo", ty.f32(), Expr(0_f)); Override("foo", ty.f32(), Expr(0_f));
MakePlainGlobalReferenceBodyFunction("ep_func", "foo", ty.f32(), MakePlainGlobalReferenceBodyFunction("ep_func", "foo", ty.f32(),
{ {
Stage(ast::PipelineStage::kCompute), Stage(ast::PipelineStage::kCompute),
@ -724,13 +721,13 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantInitialized) {
auto result = inspector.GetEntryPoints(); auto result = inspector.GetEntryPoints();
ASSERT_EQ(1u, result.size()); ASSERT_EQ(1u, result.size());
ASSERT_EQ(1u, result[0].overridable_constants.size()); ASSERT_EQ(1u, result[0].overrides.size());
EXPECT_EQ("foo", result[0].overridable_constants[0].name); EXPECT_EQ("foo", result[0].overrides[0].name);
EXPECT_TRUE(result[0].overridable_constants[0].is_initialized); EXPECT_TRUE(result[0].overrides[0].is_initialized);
} }
TEST_F(InspectorGetEntryPointTest, OverridableConstantUninitialized) { TEST_F(InspectorGetEntryPointTest, OverrideUninitialized) {
AddOverridableConstantWithoutID("foo", ty.f32(), nullptr); Override("foo", ty.f32(), nullptr);
MakePlainGlobalReferenceBodyFunction("ep_func", "foo", ty.f32(), MakePlainGlobalReferenceBodyFunction("ep_func", "foo", ty.f32(),
{ {
Stage(ast::PipelineStage::kCompute), Stage(ast::PipelineStage::kCompute),
@ -742,15 +739,15 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantUninitialized) {
auto result = inspector.GetEntryPoints(); auto result = inspector.GetEntryPoints();
ASSERT_EQ(1u, result.size()); ASSERT_EQ(1u, result.size());
ASSERT_EQ(1u, result[0].overridable_constants.size()); ASSERT_EQ(1u, result[0].overrides.size());
EXPECT_EQ("foo", result[0].overridable_constants[0].name); EXPECT_EQ("foo", result[0].overrides[0].name);
EXPECT_FALSE(result[0].overridable_constants[0].is_initialized); EXPECT_FALSE(result[0].overrides[0].is_initialized);
} }
TEST_F(InspectorGetEntryPointTest, OverridableConstantNumericIDSpecified) { TEST_F(InspectorGetEntryPointTest, OverrideNumericIDSpecified) {
AddOverridableConstantWithoutID("foo_no_id", ty.f32(), nullptr); Override("foo_no_id", ty.f32(), nullptr);
AddOverridableConstantWithID("foo_id", 1234, ty.f32(), nullptr); Override("foo_id", ty.f32(), nullptr, {Id(1234)});
MakePlainGlobalReferenceBodyFunction("no_id_func", "foo_no_id", ty.f32(), {}); MakePlainGlobalReferenceBodyFunction("no_id_func", "foo_no_id", ty.f32(), {});
MakePlainGlobalReferenceBodyFunction("id_func", "foo_id", ty.f32(), {}); MakePlainGlobalReferenceBodyFunction("id_func", "foo_id", ty.f32(), {});
@ -766,16 +763,16 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantNumericIDSpecified) {
auto result = inspector.GetEntryPoints(); auto result = inspector.GetEntryPoints();
ASSERT_EQ(1u, result.size()); ASSERT_EQ(1u, result.size());
ASSERT_EQ(2u, result[0].overridable_constants.size()); ASSERT_EQ(2u, result[0].overrides.size());
EXPECT_EQ("foo_no_id", result[0].overridable_constants[0].name); EXPECT_EQ("foo_no_id", result[0].overrides[0].name);
EXPECT_EQ("foo_id", result[0].overridable_constants[1].name); EXPECT_EQ("foo_id", result[0].overrides[1].name);
EXPECT_EQ(1234, result[0].overridable_constants[1].numeric_id); EXPECT_EQ(1234, result[0].overrides[1].id.value);
EXPECT_FALSE(result[0].overridable_constants[0].is_numeric_id_specified); EXPECT_FALSE(result[0].overrides[0].is_id_specified);
EXPECT_TRUE(result[0].overridable_constants[1].is_numeric_id_specified); EXPECT_TRUE(result[0].overrides[1].is_id_specified);
} }
TEST_F(InspectorGetEntryPointTest, NonOverridableConstantSkipped) { TEST_F(InspectorGetEntryPointTest, NonOverrideSkipped) {
auto* foo_struct_type = MakeUniformBufferType("foo_type", {ty.i32()}); auto* foo_struct_type = MakeUniformBufferType("foo_type", {ty.i32()});
AddUniformBuffer("foo_ub", ty.Of(foo_struct_type), 0, 0); AddUniformBuffer("foo_ub", ty.Of(foo_struct_type), 0, 0);
MakeStructVariableReferenceBodyFunction("ub_func", "foo_ub", {{0, ty.i32()}}); MakeStructVariableReferenceBodyFunction("ub_func", "foo_ub", {{0, ty.i32()}});
@ -789,7 +786,7 @@ TEST_F(InspectorGetEntryPointTest, NonOverridableConstantSkipped) {
auto result = inspector.GetEntryPoints(); auto result = inspector.GetEntryPoints();
ASSERT_EQ(1u, result.size()); ASSERT_EQ(1u, result.size());
EXPECT_EQ(0u, result[0].overridable_constants.size()); EXPECT_EQ(0u, result[0].overrides.size());
} }
TEST_F(InspectorGetEntryPointTest, BuiltinNotReferenced) { TEST_F(InspectorGetEntryPointTest, BuiltinNotReferenced) {
@ -1172,127 +1169,127 @@ INSTANTIATE_TEST_SUITE_P(
ast::InterpolationType::kFlat, ast::InterpolationSampling::kNone, ast::InterpolationType::kFlat, ast::InterpolationSampling::kNone,
InterpolationType::kFlat, InterpolationSampling::kNone})); InterpolationType::kFlat, InterpolationSampling::kNone}));
TEST_F(InspectorGetConstantIDsTest, Bool) { TEST_F(InspectorGetOverrideDefaultValuesTest, Bool) {
AddOverridableConstantWithID("foo", 1, ty.bool_(), nullptr); Override("foo", ty.bool_(), nullptr, {Id(1)});
AddOverridableConstantWithID("bar", 20, ty.bool_(), Expr(true)); Override("bar", ty.bool_(), Expr(true), {Id(20)});
AddOverridableConstantWithID("baz", 300, ty.bool_(), Expr(false)); Override("baz", ty.bool_(), Expr(false), {Id(300)});
Inspector& inspector = Build(); Inspector& inspector = Build();
auto result = inspector.GetConstantIDs(); auto result = inspector.GetOverrideDefaultValues();
ASSERT_EQ(3u, result.size()); ASSERT_EQ(3u, result.size());
ASSERT_TRUE(result.find(1) != result.end()); ASSERT_TRUE(result.find(OverrideId{1}) != result.end());
EXPECT_TRUE(result[1].IsNull()); EXPECT_TRUE(result[OverrideId{1}].IsNull());
ASSERT_TRUE(result.find(20) != result.end()); ASSERT_TRUE(result.find(OverrideId{20}) != result.end());
EXPECT_TRUE(result[20].IsBool()); EXPECT_TRUE(result[OverrideId{20}].IsBool());
EXPECT_TRUE(result[20].AsBool()); EXPECT_TRUE(result[OverrideId{20}].AsBool());
ASSERT_TRUE(result.find(300) != result.end()); ASSERT_TRUE(result.find(OverrideId{300}) != result.end());
EXPECT_TRUE(result[300].IsBool()); EXPECT_TRUE(result[OverrideId{300}].IsBool());
EXPECT_FALSE(result[300].AsBool()); EXPECT_FALSE(result[OverrideId{300}].AsBool());
} }
TEST_F(InspectorGetConstantIDsTest, U32) { TEST_F(InspectorGetOverrideDefaultValuesTest, U32) {
AddOverridableConstantWithID("foo", 1, ty.u32(), nullptr); Override("foo", ty.u32(), nullptr, {Id(1)});
AddOverridableConstantWithID("bar", 20, ty.u32(), Expr(42_u)); Override("bar", ty.u32(), Expr(42_u), {Id(20)});
Inspector& inspector = Build(); Inspector& inspector = Build();
auto result = inspector.GetConstantIDs(); auto result = inspector.GetOverrideDefaultValues();
ASSERT_EQ(2u, result.size()); ASSERT_EQ(2u, result.size());
ASSERT_TRUE(result.find(1) != result.end()); ASSERT_TRUE(result.find(OverrideId{1}) != result.end());
EXPECT_TRUE(result[1].IsNull()); EXPECT_TRUE(result[OverrideId{1}].IsNull());
ASSERT_TRUE(result.find(20) != result.end()); ASSERT_TRUE(result.find(OverrideId{20}) != result.end());
EXPECT_TRUE(result[20].IsU32()); EXPECT_TRUE(result[OverrideId{20}].IsU32());
EXPECT_EQ(42u, result[20].AsU32()); EXPECT_EQ(42u, result[OverrideId{20}].AsU32());
} }
TEST_F(InspectorGetConstantIDsTest, I32) { TEST_F(InspectorGetOverrideDefaultValuesTest, I32) {
AddOverridableConstantWithID("foo", 1, ty.i32(), nullptr); Override("foo", ty.i32(), nullptr, {Id(1)});
AddOverridableConstantWithID("bar", 20, ty.i32(), Expr(-42_i)); Override("bar", ty.i32(), Expr(-42_i), {Id(20)});
AddOverridableConstantWithID("baz", 300, ty.i32(), Expr(42_i)); Override("baz", ty.i32(), Expr(42_i), {Id(300)});
Inspector& inspector = Build(); Inspector& inspector = Build();
auto result = inspector.GetConstantIDs(); auto result = inspector.GetOverrideDefaultValues();
ASSERT_EQ(3u, result.size()); ASSERT_EQ(3u, result.size());
ASSERT_TRUE(result.find(1) != result.end()); ASSERT_TRUE(result.find(OverrideId{1}) != result.end());
EXPECT_TRUE(result[1].IsNull()); EXPECT_TRUE(result[OverrideId{1}].IsNull());
ASSERT_TRUE(result.find(20) != result.end()); ASSERT_TRUE(result.find(OverrideId{20}) != result.end());
EXPECT_TRUE(result[20].IsI32()); EXPECT_TRUE(result[OverrideId{20}].IsI32());
EXPECT_EQ(-42, result[20].AsI32()); EXPECT_EQ(-42, result[OverrideId{20}].AsI32());
ASSERT_TRUE(result.find(300) != result.end()); ASSERT_TRUE(result.find(OverrideId{300}) != result.end());
EXPECT_TRUE(result[300].IsI32()); EXPECT_TRUE(result[OverrideId{300}].IsI32());
EXPECT_EQ(42, result[300].AsI32()); EXPECT_EQ(42, result[OverrideId{300}].AsI32());
} }
TEST_F(InspectorGetConstantIDsTest, Float) { TEST_F(InspectorGetOverrideDefaultValuesTest, Float) {
AddOverridableConstantWithID("foo", 1, ty.f32(), nullptr); Override("foo", ty.f32(), nullptr, {Id(1)});
AddOverridableConstantWithID("bar", 20, ty.f32(), Expr(0_f)); Override("bar", ty.f32(), Expr(0_f), {Id(20)});
AddOverridableConstantWithID("baz", 300, ty.f32(), Expr(-10_f)); Override("baz", ty.f32(), Expr(-10_f), {Id(300)});
AddOverridableConstantWithID("x", 4000, ty.f32(), Expr(15_f)); Override("x", ty.f32(), Expr(15_f), {Id(4000)});
Inspector& inspector = Build(); Inspector& inspector = Build();
auto result = inspector.GetConstantIDs(); auto result = inspector.GetOverrideDefaultValues();
ASSERT_EQ(4u, result.size()); ASSERT_EQ(4u, result.size());
ASSERT_TRUE(result.find(1) != result.end()); ASSERT_TRUE(result.find(OverrideId{1}) != result.end());
EXPECT_TRUE(result[1].IsNull()); EXPECT_TRUE(result[OverrideId{1}].IsNull());
ASSERT_TRUE(result.find(20) != result.end()); ASSERT_TRUE(result.find(OverrideId{20}) != result.end());
EXPECT_TRUE(result[20].IsFloat()); EXPECT_TRUE(result[OverrideId{20}].IsFloat());
EXPECT_FLOAT_EQ(0.0f, result[20].AsFloat()); EXPECT_FLOAT_EQ(0.0f, result[OverrideId{20}].AsFloat());
ASSERT_TRUE(result.find(300) != result.end()); ASSERT_TRUE(result.find(OverrideId{300}) != result.end());
EXPECT_TRUE(result[300].IsFloat()); EXPECT_TRUE(result[OverrideId{300}].IsFloat());
EXPECT_FLOAT_EQ(-10.0f, result[300].AsFloat()); EXPECT_FLOAT_EQ(-10.0f, result[OverrideId{300}].AsFloat());
ASSERT_TRUE(result.find(4000) != result.end()); ASSERT_TRUE(result.find(OverrideId{4000}) != result.end());
EXPECT_TRUE(result[4000].IsFloat()); EXPECT_TRUE(result[OverrideId{4000}].IsFloat());
EXPECT_FLOAT_EQ(15.0f, result[4000].AsFloat()); EXPECT_FLOAT_EQ(15.0f, result[OverrideId{4000}].AsFloat());
} }
TEST_F(InspectorGetConstantNameToIdMapTest, WithAndWithoutIds) { TEST_F(InspectorGetConstantNameToIdMapTest, WithAndWithoutIds) {
AddOverridableConstantWithID("v1", 1, ty.f32(), nullptr); Override("v1", ty.f32(), nullptr, {Id(1)});
AddOverridableConstantWithID("v20", 20, ty.f32(), nullptr); Override("v20", ty.f32(), nullptr, {Id(20)});
AddOverridableConstantWithID("v300", 300, ty.f32(), nullptr); Override("v300", ty.f32(), nullptr, {Id(300)});
auto* a = AddOverridableConstantWithoutID("a", ty.f32(), nullptr); auto* a = Override("a", ty.f32(), nullptr);
auto* b = AddOverridableConstantWithoutID("b", ty.f32(), nullptr); auto* b = Override("b", ty.f32(), nullptr);
auto* c = AddOverridableConstantWithoutID("c", ty.f32(), nullptr); auto* c = Override("c", ty.f32(), nullptr);
Inspector& inspector = Build(); Inspector& inspector = Build();
auto result = inspector.GetConstantNameToIdMap(); auto result = inspector.GetNamedOverrideIds();
ASSERT_EQ(6u, result.size()); ASSERT_EQ(6u, result.size());
ASSERT_TRUE(result.count("v1")); ASSERT_TRUE(result.count("v1"));
EXPECT_EQ(result["v1"], 1u); EXPECT_EQ(result["v1"].value, 1u);
ASSERT_TRUE(result.count("v20")); ASSERT_TRUE(result.count("v20"));
EXPECT_EQ(result["v20"], 20u); EXPECT_EQ(result["v20"].value, 20u);
ASSERT_TRUE(result.count("v300")); ASSERT_TRUE(result.count("v300"));
EXPECT_EQ(result["v300"], 300u); EXPECT_EQ(result["v300"].value, 300u);
ASSERT_TRUE(result.count("a")); ASSERT_TRUE(result.count("a"));
ASSERT_TRUE(program_->Sem().Get<sem::GlobalVariable>(a)); ASSERT_TRUE(program_->Sem().Get<sem::GlobalVariable>(a));
EXPECT_EQ(result["a"], program_->Sem().Get<sem::GlobalVariable>(a)->ConstantId()); EXPECT_EQ(result["a"], program_->Sem().Get<sem::GlobalVariable>(a)->OverrideId());
ASSERT_TRUE(result.count("b")); ASSERT_TRUE(result.count("b"));
ASSERT_TRUE(program_->Sem().Get<sem::GlobalVariable>(b)); ASSERT_TRUE(program_->Sem().Get<sem::GlobalVariable>(b));
EXPECT_EQ(result["b"], program_->Sem().Get<sem::GlobalVariable>(b)->ConstantId()); EXPECT_EQ(result["b"], program_->Sem().Get<sem::GlobalVariable>(b)->OverrideId());
ASSERT_TRUE(result.count("c")); ASSERT_TRUE(result.count("c"));
ASSERT_TRUE(program_->Sem().Get<sem::GlobalVariable>(c)); ASSERT_TRUE(program_->Sem().Get<sem::GlobalVariable>(c));
EXPECT_EQ(result["c"], program_->Sem().Get<sem::GlobalVariable>(c)->ConstantId()); EXPECT_EQ(result["c"], program_->Sem().Get<sem::GlobalVariable>(c)->OverrideId());
} }
TEST_F(InspectorGetStorageSizeTest, Empty) { TEST_F(InspectorGetStorageSizeTest, Empty) {

View File

@ -92,31 +92,6 @@ class InspectorBuilder : public ProgramBuilder {
std::vector<std::tuple<std::string, std::string>> inout_vars, std::vector<std::tuple<std::string, std::string>> inout_vars,
ast::AttributeList attributes); ast::AttributeList attributes);
/// Add a pipeline constant to the global variables, with a specific ID.
/// @param name name of the variable to add
/// @param id id number for the constant id
/// @param type type of the variable
/// @param constructor val to initialize the constant with, if NULL no
/// constructor will be added.
/// @returns the constant that was created
const ast::Variable* AddOverridableConstantWithID(std::string name,
uint32_t id,
const ast::Type* type,
const ast::Expression* constructor) {
return Override(name, type, constructor, {Id(id)});
}
/// Add a pipeline constant to the global variables, without a specific ID.
/// @param name name of the variable to add
/// @param type type of the variable
/// @param constructor val to initialize the constant with, if NULL no
/// constructor will be added.
/// @returns the constant that was created
const ast::Variable* AddOverridableConstantWithoutID(std::string name,
const ast::Type* type,
const ast::Expression* constructor) {
return Override(name, type, constructor);
}
/// Generates a function that references module-scoped, plain-typed constant /// Generates a function that references module-scoped, plain-typed constant
/// or variable. /// or variable.

View File

@ -19,6 +19,8 @@
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include "tint/override_id.h"
#include "src/tint/ast/alias.h" #include "src/tint/ast/alias.h"
#include "src/tint/ast/array.h" #include "src/tint/ast/array.h"
#include "src/tint/ast/assignment_statement.h" #include "src/tint/ast/assignment_statement.h"
@ -2588,6 +2590,19 @@ class ProgramBuilder {
return create<ast::LocationAttribute>(source_, location); return create<ast::LocationAttribute>(source_, location);
} }
/// Creates an ast::IdAttribute
/// @param source the source information
/// @param id the id value
/// @returns the override attribute pointer
const ast::IdAttribute* Id(const Source& source, OverrideId id) {
return create<ast::IdAttribute>(source, id.value);
}
/// Creates an ast::IdAttribute with an override identifier
/// @param id the optional id value
/// @returns the override attribute pointer
const ast::IdAttribute* Id(OverrideId id) { return Id(source_, id); }
/// Creates an ast::IdAttribute /// Creates an ast::IdAttribute
/// @param source the source information /// @param source the source information
/// @param id the id value /// @param id the id value
@ -2596,7 +2611,7 @@ class ProgramBuilder {
return create<ast::IdAttribute>(source, id); return create<ast::IdAttribute>(source, id);
} }
/// Creates an ast::IdAttribute with a constant ID /// Creates an ast::IdAttribute with an override identifier
/// @param id the optional id value /// @param id the optional id value
/// @returns the override attribute pointer /// @returns the override attribute pointer
const ast::IdAttribute* Id(uint32_t id) { return Id(source_, id); } const ast::IdAttribute* Id(uint32_t id) { return Id(source_, id); }

View File

@ -21,23 +21,23 @@ using namespace tint::number_suffixes; // NOLINT
namespace tint::resolver { namespace tint::resolver {
namespace { namespace {
class ResolverPipelineOverridableConstantTest : public ResolverTest { class ResolverOverrideTest : public ResolverTest {
protected: protected:
/// Verify that the AST node `var` was resolved to an overridable constant /// Verify that the AST node `var` was resolved to an overridable constant
/// with an ID equal to `id`. /// with an ID equal to `id`.
/// @param var the overridable constant AST node /// @param var the overridable constant AST node
/// @param id the expected constant ID /// @param id the expected constant ID
void ExpectConstantId(const ast::Variable* var, uint16_t id) { void ExpectOverrideId(const ast::Variable* var, uint16_t id) {
auto* sem = Sem().Get<sem::GlobalVariable>(var); auto* sem = Sem().Get<sem::GlobalVariable>(var);
ASSERT_NE(sem, nullptr); ASSERT_NE(sem, nullptr);
EXPECT_EQ(sem->Declaration(), var); EXPECT_EQ(sem->Declaration(), var);
EXPECT_TRUE(sem->Declaration()->Is<ast::Override>()); EXPECT_TRUE(sem->Declaration()->Is<ast::Override>());
EXPECT_EQ(sem->ConstantId(), id); EXPECT_EQ(sem->OverrideId().value, id);
EXPECT_FALSE(sem->ConstantValue()); EXPECT_FALSE(sem->ConstantValue());
} }
}; };
TEST_F(ResolverPipelineOverridableConstantTest, NonOverridable) { TEST_F(ResolverOverrideTest, NonOverridable) {
auto* a = GlobalConst("a", ty.f32(), Expr(1_f)); auto* a = GlobalConst("a", ty.f32(), Expr(1_f));
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
@ -49,23 +49,23 @@ TEST_F(ResolverPipelineOverridableConstantTest, NonOverridable) {
EXPECT_TRUE(sem_a->ConstantValue()); EXPECT_TRUE(sem_a->ConstantValue());
} }
TEST_F(ResolverPipelineOverridableConstantTest, WithId) { TEST_F(ResolverOverrideTest, WithId) {
auto* a = Override("a", ty.f32(), Expr(1_f), {Id(7u)}); auto* a = Override("a", ty.f32(), Expr(1_f), {Id(7u)});
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
ExpectConstantId(a, 7u); ExpectOverrideId(a, 7u);
} }
TEST_F(ResolverPipelineOverridableConstantTest, WithoutId) { TEST_F(ResolverOverrideTest, WithoutId) {
auto* a = Override("a", ty.f32(), Expr(1_f)); auto* a = Override("a", ty.f32(), Expr(1_f));
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
ExpectConstantId(a, 0u); ExpectOverrideId(a, 0u);
} }
TEST_F(ResolverPipelineOverridableConstantTest, WithAndWithoutIds) { TEST_F(ResolverOverrideTest, WithAndWithoutIds) {
std::vector<ast::Variable*> variables; std::vector<ast::Variable*> variables;
auto* a = Override("a", ty.f32(), Expr(1_f)); auto* a = Override("a", ty.f32(), Expr(1_f));
auto* b = Override("b", ty.f32(), Expr(1_f)); auto* b = Override("b", ty.f32(), Expr(1_f));
@ -77,33 +77,33 @@ TEST_F(ResolverPipelineOverridableConstantTest, WithAndWithoutIds) {
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
// Verify that constant id allocation order is deterministic. // Verify that constant id allocation order is deterministic.
ExpectConstantId(a, 0u); ExpectOverrideId(a, 0u);
ExpectConstantId(b, 3u); ExpectOverrideId(b, 3u);
ExpectConstantId(c, 2u); ExpectOverrideId(c, 2u);
ExpectConstantId(d, 4u); ExpectOverrideId(d, 4u);
ExpectConstantId(e, 5u); ExpectOverrideId(e, 5u);
ExpectConstantId(f, 1u); ExpectOverrideId(f, 1u);
} }
TEST_F(ResolverPipelineOverridableConstantTest, DuplicateIds) { TEST_F(ResolverOverrideTest, DuplicateIds) {
Override("a", ty.f32(), Expr(1_f), {Id(Source{{12, 34}}, 7u)}); Override("a", ty.f32(), Expr(1_f), {Id(Source{{12, 34}}, 7u)});
Override("b", ty.f32(), Expr(1_f), {Id(Source{{56, 78}}, 7u)}); Override("b", ty.f32(), Expr(1_f), {Id(Source{{56, 78}}, 7u)});
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), R"(56:78 error: pipeline constant IDs must be unique EXPECT_EQ(r()->error(), R"(56:78 error: override IDs must be unique
12:34 note: a pipeline constant with an ID of 7 was previously declared here:)"); 12:34 note: a override with an ID of 7 was previously declared here:)");
} }
TEST_F(ResolverPipelineOverridableConstantTest, IdTooLarge) { TEST_F(ResolverOverrideTest, IdTooLarge) {
Override("a", ty.f32(), Expr(1_f), {Id(Source{{12, 34}}, 65536u)}); Override("a", ty.f32(), Expr(1_f), {Id(Source{{12, 34}}, 65536u)});
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: pipeline constant IDs must be between 0 and 65535"); EXPECT_EQ(r()->error(), "12:34 error: override IDs must be between 0 and 65535");
} }
TEST_F(ResolverPipelineOverridableConstantTest, F16_TemporallyBan) { TEST_F(ResolverOverrideTest, F16_TemporallyBan) {
Enable(ast::Extension::kF16); Enable(ast::Extension::kF16);
Override(Source{{12, 34}}, "a", ty.f16(), Expr(1_h), {Id(1u)}); Override(Source{{12, 34}}, "a", ty.f16(), Expr(1_h), {Id(1u)});

View File

@ -149,7 +149,9 @@ bool Resolver::ResolveInternal() {
} }
} }
AllocateOverridableConstantIds(); if (!AllocateOverridableConstantIds()) {
return false;
}
SetShadows(); SetShadows();
@ -432,7 +434,7 @@ sem::Variable* Resolver::Override(const ast::Override* v) {
/* constant_value */ nullptr, sem::BindingPoint{}); /* constant_value */ nullptr, sem::BindingPoint{});
if (auto* id = ast::GetAttribute<ast::IdAttribute>(v->attributes)) { if (auto* id = ast::GetAttribute<ast::IdAttribute>(v->attributes)) {
sem->SetConstantId(static_cast<uint16_t>(id->value)); sem->SetOverrideId(OverrideId{static_cast<decltype(OverrideId::value)>(id->value)});
} }
sem->SetConstructor(rhs); sem->SetConstructor(rhs);
@ -641,9 +643,19 @@ ast::Access Resolver::DefaultAccessForStorageClass(ast::StorageClass storage_cla
return ast::Access::kReadWrite; return ast::Access::kReadWrite;
} }
void Resolver::AllocateOverridableConstantIds() { bool Resolver::AllocateOverridableConstantIds() {
constexpr size_t kLimit = std::numeric_limits<decltype(OverrideId::value)>::max();
// The next pipeline constant ID to try to allocate. // The next pipeline constant ID to try to allocate.
uint16_t next_constant_id = 0; OverrideId next_id;
bool ids_exhausted = false;
auto increment_next_id = [&] {
if (next_id.value == kLimit) {
ids_exhausted = true;
} else {
next_id.value = next_id.value + 1;
}
};
// Allocate constant IDs in global declaration order, so that they are // Allocate constant IDs in global declaration order, so that they are
// deterministic. // deterministic.
@ -655,26 +667,28 @@ void Resolver::AllocateOverridableConstantIds() {
continue; continue;
} }
uint16_t constant_id; OverrideId id;
if (auto* id_attr = ast::GetAttribute<ast::IdAttribute>(override->attributes)) { if (auto* id_attr = ast::GetAttribute<ast::IdAttribute>(override->attributes)) {
constant_id = static_cast<uint16_t>(id_attr->value); id = OverrideId{static_cast<decltype(OverrideId::value)>(id_attr->value)};
} else { } else {
// No ID was specified, so allocate the next available ID. // No ID was specified, so allocate the next available ID.
constant_id = next_constant_id; while (!ids_exhausted && override_ids_.count(next_id)) {
while (constant_ids_.count(constant_id)) { increment_next_id();
if (constant_id == UINT16_MAX) {
TINT_ICE(Resolver, builder_->Diagnostics())
<< "no more pipeline constant IDs available";
return;
}
constant_id++;
} }
next_constant_id = constant_id + 1; if (ids_exhausted) {
AddError(
"number of 'override' variables exceeded limit of " + std::to_string(kLimit),
decl->source);
return false;
}
id = next_id;
increment_next_id();
} }
auto* sem = sem_.Get<sem::GlobalVariable>(override); auto* sem = sem_.Get<sem::GlobalVariable>(override);
const_cast<sem::GlobalVariable*>(sem)->SetConstantId(constant_id); const_cast<sem::GlobalVariable*>(sem)->SetOverrideId(id);
} }
return true;
} }
void Resolver::SetShadows() { void Resolver::SetShadows() {
@ -697,7 +711,8 @@ sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* v) {
if (auto* id_attr = attr->As<ast::IdAttribute>()) { if (auto* id_attr = attr->As<ast::IdAttribute>()) {
// Track the constant IDs that are specified in the shader. // Track the constant IDs that are specified in the shader.
constant_ids_.emplace(id_attr->value, sem); override_ids_.emplace(
OverrideId{static_cast<decltype(OverrideId::value)>(id_attr->value)}, sem);
} }
} }
@ -705,7 +720,7 @@ sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* v) {
return nullptr; return nullptr;
} }
if (!validator_.GlobalVariable(sem, constant_ids_, atomic_composite_info_)) { if (!validator_.GlobalVariable(sem, override_ids_, atomic_composite_info_)) {
return nullptr; return nullptr;
} }

View File

@ -362,7 +362,8 @@ class Resolver {
ast::Access DefaultAccessForStorageClass(ast::StorageClass storage_class); ast::Access DefaultAccessForStorageClass(ast::StorageClass storage_class);
/// Allocate constant IDs for pipeline-overridable constants. /// Allocate constant IDs for pipeline-overridable constants.
void AllocateOverridableConstantIds(); /// @returns true on success, false on error
bool AllocateOverridableConstantIds();
/// Set the shadowing information on variable declarations. /// Set the shadowing information on variable declarations.
/// @note this method must only be called after all semantic nodes are built. /// @note this method must only be called after all semantic nodes are built.
@ -422,7 +423,7 @@ class Resolver {
std::vector<sem::Function*> entry_points_; std::vector<sem::Function*> entry_points_;
std::unordered_map<const sem::Type*, const Source&> atomic_composite_info_; std::unordered_map<const sem::Type*, const Source&> atomic_composite_info_;
utils::Bitset<0> marked_; utils::Bitset<0> marked_;
std::unordered_map<uint32_t, const sem::Variable*> constant_ids_; std::unordered_map<OverrideId, const sem::Variable*> override_ids_;
std::unordered_map<ArrayConstructorSig, sem::CallTarget*> array_ctors_; std::unordered_map<ArrayConstructorSig, sem::CallTarget*> array_ctors_;
std::unordered_map<StructConstructorSig, sem::CallTarget*> struct_ctors_; std::unordered_map<StructConstructorSig, sem::CallTarget*> struct_ctors_;

View File

@ -562,7 +562,7 @@ bool Validator::LocalVariable(const sem::Variable* v) const {
bool Validator::GlobalVariable( bool Validator::GlobalVariable(
const sem::GlobalVariable* global, const sem::GlobalVariable* global,
const std::unordered_map<uint32_t, const sem::Variable*>& constant_ids, const std::unordered_map<OverrideId, const sem::Variable*>& override_ids,
const std::unordered_map<const sem::Type*, const Source&>& atomic_composite_info) const { const std::unordered_map<const sem::Type*, const Source&>& atomic_composite_info) const {
auto* decl = global->Declaration(); auto* decl = global->Declaration();
bool ok = Switch( bool ok = Switch(
@ -627,7 +627,7 @@ bool Validator::GlobalVariable(
return Var(global); return Var(global);
}, },
[&](const ast::Override*) { return Override(global, constant_ids); }, [&](const ast::Override*) { return Override(global, override_ids); },
[&](const ast::Const*) { [&](const ast::Const*) {
if (!decl->attributes.empty()) { if (!decl->attributes.empty()) {
AddError("attribute is not valid for module-scope 'const' declaration", AddError("attribute is not valid for module-scope 'const' declaration",
@ -763,7 +763,7 @@ bool Validator::Let(const sem::Variable* v) const {
bool Validator::Override( bool Validator::Override(
const sem::Variable* v, const sem::Variable* v,
const std::unordered_map<uint32_t, const sem::Variable*>& constant_ids) const { const std::unordered_map<OverrideId, const sem::Variable*>& override_ids) const {
auto* decl = v->Declaration(); auto* decl = v->Declaration();
auto* storage_ty = v->Type()->UnwrapRef(); auto* storage_ty = v->Type()->UnwrapRef();
@ -776,19 +776,23 @@ bool Validator::Override(
for (auto* attr : decl->attributes) { for (auto* attr : decl->attributes) {
if (auto* id_attr = attr->As<ast::IdAttribute>()) { if (auto* id_attr = attr->As<ast::IdAttribute>()) {
uint32_t id = id_attr->value; uint32_t id = id_attr->value;
auto it = constant_ids.find(id); if (id > std::numeric_limits<decltype(OverrideId::value)>::max()) {
if (it != constant_ids.end() && it->second != v) { AddError(
AddError("pipeline constant IDs must be unique", attr->source); "override IDs must be between 0 and " +
AddNote("a pipeline constant with an ID of " + std::to_string(id) + std::to_string(std::numeric_limits<decltype(OverrideId::value)>::max()),
attr->source);
return false;
}
if (auto it =
override_ids.find(OverrideId{static_cast<decltype(OverrideId::value)>(id)});
it != override_ids.end() && it->second != v) {
AddError("override IDs must be unique", attr->source);
AddNote("a override with an ID of " + std::to_string(id) +
" was previously declared here:", " was previously declared here:",
ast::GetAttribute<ast::IdAttribute>(it->second->Declaration()->attributes) ast::GetAttribute<ast::IdAttribute>(it->second->Declaration()->attributes)
->source); ->source);
return false; return false;
} }
if (id > 65535) {
AddError("pipeline constant IDs must be between 0 and 65535", attr->source);
return false;
}
} else { } else {
AddError("attribute is not valid for 'override' declaration", attr->source); AddError("attribute is not valid for 'override' declaration", attr->source);
return false; return false;

View File

@ -234,12 +234,12 @@ class Validator {
/// Validates a global variable /// Validates a global variable
/// @param var the global variable to validate /// @param var the global variable to validate
/// @param constant_ids the set of constant ids in the module /// @param override_id the set of override ids in the module
/// @param atomic_composite_info atomic composite info in the module /// @param atomic_composite_info atomic composite info in the module
/// @returns true on success, false otherwise /// @returns true on success, false otherwise
bool GlobalVariable( bool GlobalVariable(
const sem::GlobalVariable* var, const sem::GlobalVariable* var,
const std::unordered_map<uint32_t, const sem::Variable*>& constant_ids, const std::unordered_map<OverrideId, const sem::Variable*>& override_id,
const std::unordered_map<const sem::Type*, const Source&>& atomic_composite_info) const; const std::unordered_map<const sem::Type*, const Source&>& atomic_composite_info) const;
/// Validates an if statement /// Validates an if statement
@ -371,10 +371,10 @@ class Validator {
/// Validates a 'override' variable declaration /// Validates a 'override' variable declaration
/// @param v the variable to validate /// @param v the variable to validate
/// @param constant_ids the set of constant ids in the module /// @param override_id the set of override ids in the module
/// @returns true on success, false otherwise. /// @returns true on success, false otherwise.
bool Override(const sem::Variable* v, bool Override(const sem::Variable* v,
const std::unordered_map<uint32_t, const sem::Variable*>& constant_ids) const; const std::unordered_map<OverrideId, const sem::Variable*>& override_id) const;
/// Validates a 'const' variable declaration /// Validates a 'const' variable declaration
/// @param v the variable to validate /// @param v the variable to validate

View File

@ -77,6 +77,37 @@ TEST_F(ResolverVariableValidationTest, OverrideNoInitializerNoType) {
EXPECT_EQ(r()->error(), "12:34 error: override declaration requires a type or initializer"); EXPECT_EQ(r()->error(), "12:34 error: override declaration requires a type or initializer");
} }
TEST_F(ResolverVariableValidationTest, OverrideExceedsIDLimit_LastUnreserved) {
// override o0 : i32;
// override o1 : i32;
// ...
// override bang : i32;
constexpr size_t kLimit = std::numeric_limits<decltype(OverrideId::value)>::max();
for (size_t i = 0; i <= kLimit; i++) {
Override("o" + std::to_string(i), ty.i32(), nullptr);
}
Override(Source{{12, 34}}, "bang", ty.i32(), nullptr);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: number of 'override' variables exceeded limit of 65535");
}
TEST_F(ResolverVariableValidationTest, OverrideExceedsIDLimit_LastReserved) {
// override o0 : i32;
// override o1 : i32;
// ...
// @id(N) override oN : i32;
constexpr size_t kLimit = std::numeric_limits<decltype(OverrideId::value)>::max();
Override("reserved", ty.i32(), nullptr, {Id(kLimit)});
for (size_t i = 0; i < kLimit; i++) {
Override("o" + std::to_string(i), ty.i32(), nullptr);
}
Override(Source{{12, 34}}, "bang", ty.i32(), nullptr);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: number of 'override' variables exceeded limit of 65535");
}
TEST_F(ResolverVariableValidationTest, VarTypeNotConstructible) { TEST_F(ResolverVariableValidationTest, VarTypeNotConstructible) {
// var i : i32; // var i : i32;
// var p : pointer<function, i32> = &v; // var p : pointer<function, i32> = &v;

View File

@ -18,6 +18,8 @@
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "tint/override_id.h"
#include "src/tint/ast/access.h" #include "src/tint/ast/access.h"
#include "src/tint/ast/storage_class.h" #include "src/tint/ast/storage_class.h"
#include "src/tint/sem/binding_point.h" #include "src/tint/sem/binding_point.h"
@ -165,15 +167,15 @@ class GlobalVariable final : public Castable<GlobalVariable, Variable> {
sem::BindingPoint BindingPoint() const { return binding_point_; } sem::BindingPoint BindingPoint() const { return binding_point_; }
/// @param id the constant identifier to assign to this variable /// @param id the constant identifier to assign to this variable
void SetConstantId(uint16_t id) { constant_id_ = id; } void SetOverrideId(OverrideId id) { override_id_ = id; }
/// @returns the pipeline constant ID associated with the variable /// @returns the pipeline constant ID associated with the variable
uint16_t ConstantId() const { return constant_id_; } tint::OverrideId OverrideId() const { return override_id_; }
private: private:
const sem::BindingPoint binding_point_; const sem::BindingPoint binding_point_;
uint16_t constant_id_ = 0; tint::OverrideId override_id_;
}; };
/// Parameter is a function parameter /// Parameter is a function parameter

View File

@ -76,11 +76,11 @@ void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) c
[&](const ast::Override* override) { [&](const ast::Override* override) {
if (referenced_vars.count(override)) { if (referenced_vars.count(override)) {
if (!ast::HasAttribute<ast::IdAttribute>(override->attributes)) { if (!ast::HasAttribute<ast::IdAttribute>(override->attributes)) {
// If the constant doesn't already have an @id() attribute, add one // If the override doesn't already have an @id() attribute, add one
// so that its allocated ID so that it won't be affected by other // so that its allocated ID so that it won't be affected by other
// stripped away constants // stripped away overrides
auto* global = sem.Get(override); auto* global = sem.Get(override);
const auto* id = ctx.dst->Id(global->ConstantId()); const auto* id = ctx.dst->Id(global->OverrideId());
ctx.InsertFront(override->attributes, id); ctx.InsertFront(override->attributes, id);
} }
ctx.dst->AST().AddGlobalVariable(ctx.Clone(override)); ctx.dst->AST().AddGlobalVariable(ctx.Clone(override));

View File

@ -2159,7 +2159,7 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
TINT_ICE(Writer, builder_.Diagnostics()) TINT_ICE(Writer, builder_.Diagnostics())
<< "expected a pipeline-overridable constant"; << "expected a pipeline-overridable constant";
} }
out << kSpecConstantPrefix << global->ConstantId(); out << kSpecConstantPrefix << global->OverrideId().value;
} else { } else {
out << std::to_string(wgsize[i].value); out << std::to_string(wgsize[i].value);
} }
@ -3053,18 +3053,18 @@ bool GeneratorImpl::EmitOverride(const ast::Override* override) {
auto* type = sem->Type(); auto* type = sem->Type();
auto* global = sem->As<sem::GlobalVariable>(); auto* global = sem->As<sem::GlobalVariable>();
auto const_id = global->ConstantId(); auto override_id = global->OverrideId();
line() << "#ifndef " << kSpecConstantPrefix << const_id; line() << "#ifndef " << kSpecConstantPrefix << override_id.value;
if (override->constructor != nullptr) { if (override->constructor != nullptr) {
auto out = line(); auto out = line();
out << "#define " << kSpecConstantPrefix << const_id << " "; out << "#define " << kSpecConstantPrefix << override_id.value << " ";
if (!EmitExpression(out, override->constructor)) { if (!EmitExpression(out, override->constructor)) {
return false; return false;
} }
} else { } else {
line() << "#error spec constant required for constant id " << const_id; line() << "#error spec constant required for constant id " << override_id.value;
} }
line() << "#endif"; line() << "#endif";
{ {
@ -3074,7 +3074,7 @@ bool GeneratorImpl::EmitOverride(const ast::Override* override) {
builder_.Symbols().NameFor(override->symbol))) { builder_.Symbols().NameFor(override->symbol))) {
return false; return false;
} }
out << " = " << kSpecConstantPrefix << const_id << ";"; out << " = " << kSpecConstantPrefix << override_id.value << ";";
} }
return true; return true;

View File

@ -3061,7 +3061,7 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
TINT_ICE(Writer, diagnostics_) TINT_ICE(Writer, diagnostics_)
<< "expected a pipeline-overridable constant"; << "expected a pipeline-overridable constant";
} }
out << kSpecConstantPrefix << global->ConstantId(); out << kSpecConstantPrefix << global->OverrideId().value;
} else { } else {
out << std::to_string(wgsize[i].value); out << std::to_string(wgsize[i].value);
} }
@ -4101,18 +4101,18 @@ bool GeneratorImpl::EmitOverride(const ast::Override* override) {
auto* sem = builder_.Sem().Get(override); auto* sem = builder_.Sem().Get(override);
auto* type = sem->Type(); auto* type = sem->Type();
auto const_id = sem->ConstantId(); auto override_id = sem->OverrideId();
line() << "#ifndef " << kSpecConstantPrefix << const_id; line() << "#ifndef " << kSpecConstantPrefix << override_id.value;
if (override->constructor != nullptr) { if (override->constructor != nullptr) {
auto out = line(); auto out = line();
out << "#define " << kSpecConstantPrefix << const_id << " "; out << "#define " << kSpecConstantPrefix << override_id.value << " ";
if (!EmitExpression(out, override->constructor)) { if (!EmitExpression(out, override->constructor)) {
return false; return false;
} }
} else { } else {
line() << "#error spec constant required for constant id " << const_id; line() << "#error spec constant required for constant id " << override_id.value;
} }
line() << "#endif"; line() << "#endif";
{ {
@ -4122,7 +4122,7 @@ bool GeneratorImpl::EmitOverride(const ast::Override* override) {
builder_.Symbols().NameFor(override->symbol))) { builder_.Symbols().NameFor(override->symbol))) {
return false; return false;
} }
out << " = " << kSpecConstantPrefix << const_id << ";"; out << " = " << kSpecConstantPrefix << override_id.value << ";";
} }
return true; return true;
} }

View File

@ -244,10 +244,7 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_GlobalConst_arr_vec2_bool) {
} }
TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_Override) { TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_Override) {
auto* var = Override("pos", ty.f32(), Expr(3_f), auto* var = Override("pos", ty.f32(), Expr(3_f), {Id(23)});
ast::AttributeList{
Id(23),
});
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -260,10 +257,7 @@ static const float pos = WGSL_SPEC_CONSTANT_23;
} }
TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_Override_NoConstructor) { TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_Override_NoConstructor) {
auto* var = Override("pos", ty.f32(), nullptr, auto* var = Override("pos", ty.f32(), nullptr, {Id(23)});
ast::AttributeList{
Id(23),
});
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -276,10 +270,7 @@ static const float pos = WGSL_SPEC_CONSTANT_23;
} }
TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_Override_NoId) { TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_Override_NoId) {
auto* a = Override("a", ty.f32(), Expr(3_f), auto* a = Override("a", ty.f32(), Expr(3_f), {Id(0)});
ast::AttributeList{
Id(0),
});
auto* b = Override("b", ty.f32(), Expr(2_f)); auto* b = Override("b", ty.f32(), Expr(2_f));
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();

View File

@ -3063,7 +3063,7 @@ bool GeneratorImpl::EmitOverride(const ast::Override* override) {
} }
out << " " << program_->Symbols().NameFor(override->symbol); out << " " << program_->Symbols().NameFor(override->symbol);
out << " [[function_constant(" << global->ConstantId() << ")]];"; out << " [[function_constant(" << global->OverrideId().value << ")]];";
return true; return true;
} }

View File

@ -540,7 +540,7 @@ bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) {
<< "expected a pipeline-overridable constant"; << "expected a pipeline-overridable constant";
} }
constant.is_spec_op = true; constant.is_spec_op = true;
constant.constant_id = sem_const->ConstantId(); constant.constant_id = sem_const->OverrideId().value;
} }
auto result = GenerateConstantIfNeeded(constant); auto result = GenerateConstantIfNeeded(constant);
@ -1340,7 +1340,7 @@ uint32_t Builder::GenerateTypeConstructorOrConversion(const sem::Call* call,
// Generate the zero initializer if there are no values provided. // Generate the zero initializer if there are no values provided.
if (args.IsEmpty()) { if (args.IsEmpty()) {
if (global_var && global_var->Declaration()->Is<ast::Override>()) { if (global_var && global_var->Declaration()->Is<ast::Override>()) {
auto constant_id = global_var->ConstantId(); auto constant_id = global_var->OverrideId().value;
if (result_type->Is<sem::I32>()) { if (result_type->Is<sem::I32>()) {
return GenerateConstantIfNeeded(ScalarConstant::I32(0).AsSpecOp(constant_id)); return GenerateConstantIfNeeded(ScalarConstant::I32(0).AsSpecOp(constant_id));
} }
@ -1664,7 +1664,7 @@ uint32_t Builder::GenerateLiteralIfNeeded(const ast::Variable* var,
auto* global = builder_.Sem().Get<sem::GlobalVariable>(var); auto* global = builder_.Sem().Get<sem::GlobalVariable>(var);
if (global && global->Declaration()->Is<ast::Override>()) { if (global && global->Declaration()->Is<ast::Override>()) {
constant.is_spec_op = true; constant.is_spec_op = true;
constant.constant_id = global->ConstantId(); constant.constant_id = global->OverrideId().value;
} }
Switch( Switch(