tint: Add tint::OverrideId
This is a public API definition of a program-unique override identifier. Bug: tint:1155 Change-Id: I6e55d43208e72a7a316557a89e2169d1b952f9bf Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/97006 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Dan Sinclair <dsinclair@chromium.org> Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
b2306e27fc
commit
9a6acc419e
|
@ -0,0 +1,60 @@
|
|||
// Copyright 2022 The Tint Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#ifndef SRC_TINT_OVERRIDE_ID_H_
|
||||
#define SRC_TINT_OVERRIDE_ID_H_
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace tint {
|
||||
|
||||
/// OverrideId is a numerical identifier for an override variable, unique per program.
|
||||
struct OverrideId {
|
||||
uint16_t value = 0;
|
||||
};
|
||||
|
||||
/// Equality operator for OverrideId
|
||||
/// @param lhs the OverrideId on the left of the '=' operator
|
||||
/// @param rhs the OverrideId on the right of the '=' operator
|
||||
/// @returns true if `lhs` is equal to `rhs`
|
||||
inline bool operator==(OverrideId lhs, OverrideId rhs) {
|
||||
return lhs.value == rhs.value;
|
||||
}
|
||||
|
||||
/// Less-than operator for OverrideId
|
||||
/// @param lhs the OverrideId on the left of the '<' operator
|
||||
/// @param rhs the OverrideId on the right of the '<' operator
|
||||
/// @returns true if `lhs` comes before `rhs`
|
||||
inline bool operator<(OverrideId lhs, OverrideId rhs) {
|
||||
return lhs.value < rhs.value;
|
||||
}
|
||||
|
||||
} // namespace tint
|
||||
|
||||
namespace std {
|
||||
|
||||
/// Custom std::hash specialization for tint::OverrideId.
|
||||
template <>
|
||||
class hash<tint::OverrideId> {
|
||||
public:
|
||||
/// @param id the override identifier
|
||||
/// @return the hash of the override identifier
|
||||
inline std::size_t operator()(tint::OverrideId id) const {
|
||||
return std::hash<decltype(tint::OverrideId::value)>()(id.value);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace std
|
||||
|
||||
#endif // SRC_TINT_OVERRIDE_ID_H_
|
|
@ -66,16 +66,16 @@ MaybeError ValidateProgrammableStage(DeviceBase* device,
|
|||
|
||||
// Validate if overridable constants exist in shader module
|
||||
// pipelineBase is not yet constructed at this moment so iterate constants from descriptor
|
||||
size_t numUninitializedConstants = metadata.uninitializedOverridableConstants.size();
|
||||
size_t numUninitializedConstants = metadata.uninitializedOverrides.size();
|
||||
// Keep an initialized constants sets to handle duplicate initialization cases
|
||||
std::unordered_set<std::string> stageInitializedConstantIdentifiers;
|
||||
for (uint32_t i = 0; i < constantCount; i++) {
|
||||
DAWN_INVALID_IF(metadata.overridableConstants.count(constants[i].key) == 0,
|
||||
DAWN_INVALID_IF(metadata.overrides.count(constants[i].key) == 0,
|
||||
"Pipeline overridable constant \"%s\" not found in %s.", constants[i].key,
|
||||
module);
|
||||
|
||||
if (stageInitializedConstantIdentifiers.count(constants[i].key) == 0) {
|
||||
if (metadata.uninitializedOverridableConstants.count(constants[i].key) > 0) {
|
||||
if (metadata.uninitializedOverrides.count(constants[i].key) > 0) {
|
||||
numUninitializedConstants--;
|
||||
}
|
||||
stageInitializedConstantIdentifiers.insert(constants[i].key);
|
||||
|
@ -91,7 +91,7 @@ MaybeError ValidateProgrammableStage(DeviceBase* device,
|
|||
if (DAWN_UNLIKELY(numUninitializedConstants > 0)) {
|
||||
std::string uninitializedConstantsArray;
|
||||
bool isFirst = true;
|
||||
for (std::string identifier : metadata.uninitializedOverridableConstants) {
|
||||
for (std::string identifier : metadata.uninitializedOverrides) {
|
||||
if (stageInitializedConstantIdentifiers.count(identifier) > 0) {
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -358,17 +358,16 @@ ResultOrError<InterpolationSampling> TintInterpolationSamplingToInterpolationSam
|
|||
UNREACHABLE();
|
||||
}
|
||||
|
||||
EntryPointMetadata::OverridableConstant::Type FromTintOverridableConstantType(
|
||||
tint::inspector::OverridableConstant::Type type) {
|
||||
EntryPointMetadata::Override::Type FromTintOverrideType(tint::inspector::Override::Type type) {
|
||||
switch (type) {
|
||||
case tint::inspector::OverridableConstant::Type::kBool:
|
||||
return EntryPointMetadata::OverridableConstant::Type::Boolean;
|
||||
case tint::inspector::OverridableConstant::Type::kFloat32:
|
||||
return EntryPointMetadata::OverridableConstant::Type::Float32;
|
||||
case tint::inspector::OverridableConstant::Type::kInt32:
|
||||
return EntryPointMetadata::OverridableConstant::Type::Int32;
|
||||
case tint::inspector::OverridableConstant::Type::kUint32:
|
||||
return EntryPointMetadata::OverridableConstant::Type::Uint32;
|
||||
case tint::inspector::Override::Type::kBool:
|
||||
return EntryPointMetadata::Override::Type::Boolean;
|
||||
case tint::inspector::Override::Type::kFloat32:
|
||||
return EntryPointMetadata::Override::Type::Float32;
|
||||
case tint::inspector::Override::Type::kInt32:
|
||||
return EntryPointMetadata::Override::Type::Int32;
|
||||
case tint::inspector::Override::Type::kUint32:
|
||||
return EntryPointMetadata::Override::Type::Uint32;
|
||||
}
|
||||
UNREACHABLE();
|
||||
}
|
||||
|
@ -610,17 +609,17 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
|
|||
return invalid; \
|
||||
})()
|
||||
|
||||
if (!entryPoint.overridable_constants.empty()) {
|
||||
if (!entryPoint.overrides.empty()) {
|
||||
DAWN_INVALID_IF(device->IsToggleEnabled(Toggle::DisallowUnsafeAPIs),
|
||||
"Pipeline overridable constants are disallowed because they "
|
||||
"are partially implemented.");
|
||||
|
||||
const auto& name2Id = inspector->GetConstantNameToIdMap();
|
||||
const auto& id2Scalar = inspector->GetConstantIDs();
|
||||
const auto& name2Id = inspector->GetNamedOverrideIds();
|
||||
const auto& id2Scalar = inspector->GetOverrideDefaultValues();
|
||||
|
||||
for (auto& c : entryPoint.overridable_constants) {
|
||||
uint32_t id = name2Id.at(c.name);
|
||||
OverridableConstantScalar defaultValue;
|
||||
for (auto& c : entryPoint.overrides) {
|
||||
auto id = name2Id.at(c.name);
|
||||
OverrideScalar defaultValue;
|
||||
if (c.is_initialized) {
|
||||
// if it is initialized, the scalar must exist
|
||||
const auto& scalar = id2Scalar.at(id);
|
||||
|
@ -636,21 +635,19 @@ ResultOrError<std::unique_ptr<EntryPointMetadata>> ReflectEntryPointUsingTint(
|
|||
UNREACHABLE();
|
||||
}
|
||||
}
|
||||
EntryPointMetadata::OverridableConstant constant = {
|
||||
id, FromTintOverridableConstantType(c.type), c.is_initialized, defaultValue};
|
||||
EntryPointMetadata::Override override = {id.value, FromTintOverrideType(c.type),
|
||||
c.is_initialized, defaultValue};
|
||||
|
||||
std::string identifier =
|
||||
c.is_numeric_id_specified ? std::to_string(constant.id) : c.name;
|
||||
metadata->overridableConstants[identifier] = constant;
|
||||
std::string identifier = c.is_id_specified ? std::to_string(override.id) : c.name;
|
||||
metadata->overrides[identifier] = override;
|
||||
|
||||
if (!c.is_initialized) {
|
||||
auto [_, inserted] =
|
||||
metadata->uninitializedOverridableConstants.emplace(std::move(identifier));
|
||||
metadata->uninitializedOverrides.emplace(std::move(identifier));
|
||||
// The insertion should have taken place
|
||||
ASSERT(inserted);
|
||||
} else {
|
||||
auto [_, inserted] =
|
||||
metadata->initializedOverridableConstants.emplace(std::move(identifier));
|
||||
auto [_, inserted] = metadata->initializedOverrides.emplace(std::move(identifier));
|
||||
// The insertion should have taken place
|
||||
ASSERT(inserted);
|
||||
}
|
||||
|
|
|
@ -155,8 +155,8 @@ struct ShaderBindingInfo {
|
|||
using BindingGroupInfoMap = std::map<BindingNumber, ShaderBindingInfo>;
|
||||
using BindingInfoArray = ityp::array<BindGroupIndex, BindingGroupInfoMap, kMaxBindGroups>;
|
||||
|
||||
// The WebGPU overridable constants only support these scalar types
|
||||
union OverridableConstantScalar {
|
||||
// The WebGPU override variables only support these scalar types
|
||||
union OverrideScalar {
|
||||
// Use int32_t for boolean to initialize the full 32bit
|
||||
int32_t b;
|
||||
float f32;
|
||||
|
@ -216,9 +216,9 @@ struct EntryPointMetadata {
|
|||
// The shader stage for this binding.
|
||||
SingleShaderStage stage;
|
||||
|
||||
struct OverridableConstant {
|
||||
struct Override {
|
||||
uint32_t id;
|
||||
// Match tint::inspector::OverridableConstant::Type
|
||||
// Match tint::inspector::Override::Type
|
||||
// Bool is defined as a macro on linux X11 and cannot compile
|
||||
enum class Type { Boolean, Float32, Uint32, Int32 } type;
|
||||
|
||||
|
@ -230,23 +230,23 @@ struct EntryPointMetadata {
|
|||
// Store the default initialized value in shader
|
||||
// This is used by metal backend as the function_constant does not have dafault values
|
||||
// Initialized when isInitialized == true
|
||||
OverridableConstantScalar defaultValue;
|
||||
OverrideScalar defaultValue;
|
||||
};
|
||||
|
||||
using OverridableConstantsMap = std::unordered_map<std::string, OverridableConstant>;
|
||||
using OverridesMap = std::unordered_map<std::string, Override>;
|
||||
|
||||
// Map identifier to overridable constant
|
||||
// Map identifier to override variable
|
||||
// Identifier is unique: either the variable name or the numeric ID if specified
|
||||
OverridableConstantsMap overridableConstants;
|
||||
OverridesMap overrides;
|
||||
|
||||
// Overridable constants that are not initialized in shaders
|
||||
// Override variables that are not initialized in shaders
|
||||
// They need value initialization from pipeline stage or it is a validation error
|
||||
std::unordered_set<std::string> uninitializedOverridableConstants;
|
||||
std::unordered_set<std::string> uninitializedOverrides;
|
||||
|
||||
// Store constants with shader initialized values as well
|
||||
// This is used by metal backend to set values with default initializers that are not
|
||||
// overridden
|
||||
std::unordered_set<std::string> initializedOverridableConstants;
|
||||
std::unordered_set<std::string> initializedOverrides;
|
||||
|
||||
bool usesNumWorkgroups = false;
|
||||
// Used at render pipeline validation.
|
||||
|
|
|
@ -111,17 +111,17 @@ std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) {
|
|||
return out.str();
|
||||
}
|
||||
|
||||
std::string GetHLSLValueString(EntryPointMetadata::OverridableConstant::Type dawnType,
|
||||
const OverridableConstantScalar* entry,
|
||||
std::string GetHLSLValueString(EntryPointMetadata::Override::Type dawnType,
|
||||
const OverrideScalar* entry,
|
||||
double value = 0) {
|
||||
switch (dawnType) {
|
||||
case EntryPointMetadata::OverridableConstant::Type::Boolean:
|
||||
case EntryPointMetadata::Override::Type::Boolean:
|
||||
return std::to_string(entry ? entry->b : static_cast<int32_t>(value));
|
||||
case EntryPointMetadata::OverridableConstant::Type::Float32:
|
||||
case EntryPointMetadata::Override::Type::Float32:
|
||||
return FloatToStringWithPrecision(entry ? entry->f32 : static_cast<float>(value));
|
||||
case EntryPointMetadata::OverridableConstant::Type::Int32:
|
||||
case EntryPointMetadata::Override::Type::Int32:
|
||||
return std::to_string(entry ? entry->i32 : static_cast<int32_t>(value));
|
||||
case EntryPointMetadata::OverridableConstant::Type::Uint32:
|
||||
case EntryPointMetadata::Override::Type::Uint32:
|
||||
return std::to_string(entry ? entry->u32 : static_cast<uint32_t>(value));
|
||||
default:
|
||||
UNREACHABLE();
|
||||
|
@ -133,7 +133,7 @@ constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
|
|||
void GetOverridableConstantsDefines(
|
||||
std::vector<std::pair<std::string, std::string>>* defineStrings,
|
||||
const PipelineConstantEntries* pipelineConstantEntries,
|
||||
const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants) {
|
||||
const EntryPointMetadata::OverridesMap* shaderEntryPointConstants) {
|
||||
std::unordered_set<std::string> overriddenConstants;
|
||||
|
||||
// Set pipeline overridden values
|
||||
|
@ -305,8 +305,7 @@ struct ShaderCompilationRequest {
|
|||
|
||||
GetOverridableConstantsDefines(
|
||||
&request.defineStrings, &programmableStage.constants,
|
||||
&programmableStage.module->GetEntryPoint(programmableStage.entryPoint)
|
||||
.overridableConstants);
|
||||
&programmableStage.module->GetEntryPoint(programmableStage.entryPoint).overrides);
|
||||
|
||||
return std::move(request);
|
||||
}
|
||||
|
|
|
@ -332,7 +332,7 @@ MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage,
|
|||
ShaderModule* shaderModule = ToBackend(programmableStage.module.Get());
|
||||
const char* shaderEntryPoint = programmableStage.entryPoint.c_str();
|
||||
const auto& entryPointMetadata = programmableStage.module->GetEntryPoint(shaderEntryPoint);
|
||||
if (entryPointMetadata.overridableConstants.size() == 0) {
|
||||
if (entryPointMetadata.overrides.size() == 0) {
|
||||
DAWN_TRY(shaderModule->CreateFunction(shaderEntryPoint, singleShaderStage, pipelineLayout,
|
||||
functionData, nil, sampleMask, renderPipeline));
|
||||
return {};
|
||||
|
@ -345,29 +345,29 @@ MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage,
|
|||
|
||||
std::unordered_set<std::string> overriddenConstants;
|
||||
|
||||
auto switchType = [&](EntryPointMetadata::OverridableConstant::Type dawnType,
|
||||
MTLDataType* type, OverridableConstantScalar* entry,
|
||||
auto switchType = [&](EntryPointMetadata::Override::Type dawnType,
|
||||
MTLDataType* type, OverrideScalar* entry,
|
||||
double value = 0) {
|
||||
switch (dawnType) {
|
||||
case EntryPointMetadata::OverridableConstant::Type::Boolean:
|
||||
case EntryPointMetadata::Override::Type::Boolean:
|
||||
*type = MTLDataTypeBool;
|
||||
if (entry) {
|
||||
entry->b = static_cast<int32_t>(value);
|
||||
}
|
||||
break;
|
||||
case EntryPointMetadata::OverridableConstant::Type::Float32:
|
||||
case EntryPointMetadata::Override::Type::Float32:
|
||||
*type = MTLDataTypeFloat;
|
||||
if (entry) {
|
||||
entry->f32 = static_cast<float>(value);
|
||||
}
|
||||
break;
|
||||
case EntryPointMetadata::OverridableConstant::Type::Int32:
|
||||
case EntryPointMetadata::Override::Type::Int32:
|
||||
*type = MTLDataTypeInt;
|
||||
if (entry) {
|
||||
entry->i32 = static_cast<int32_t>(value);
|
||||
}
|
||||
break;
|
||||
case EntryPointMetadata::OverridableConstant::Type::Uint32:
|
||||
case EntryPointMetadata::Override::Type::Uint32:
|
||||
*type = MTLDataTypeUInt;
|
||||
if (entry) {
|
||||
entry->u32 = static_cast<uint32_t>(value);
|
||||
|
@ -382,10 +382,10 @@ MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage,
|
|||
overriddenConstants.insert(name);
|
||||
|
||||
// This is already validated so `name` must exist
|
||||
const auto& moduleConstant = entryPointMetadata.overridableConstants.at(name);
|
||||
const auto& moduleConstant = entryPointMetadata.overrides.at(name);
|
||||
|
||||
MTLDataType type;
|
||||
OverridableConstantScalar entry{};
|
||||
OverrideScalar entry{};
|
||||
|
||||
switchType(moduleConstant.type, &type, &entry, value);
|
||||
|
||||
|
@ -394,14 +394,14 @@ MaybeError CreateMTLFunction(const ProgrammableStage& programmableStage,
|
|||
|
||||
// Set shader initialized default values because MSL function_constant
|
||||
// has no default value
|
||||
for (const std::string& name : entryPointMetadata.initializedOverridableConstants) {
|
||||
for (const std::string& name : entryPointMetadata.initializedOverrides) {
|
||||
if (overriddenConstants.count(name) != 0) {
|
||||
// This constant already has overridden value
|
||||
continue;
|
||||
}
|
||||
|
||||
// Must exist because it is validated
|
||||
const auto& moduleConstant = entryPointMetadata.overridableConstants.at(name);
|
||||
const auto& moduleConstant = entryPointMetadata.overrides.at(name);
|
||||
ASSERT(moduleConstant.isInitialized);
|
||||
MTLDataType type;
|
||||
|
||||
|
|
|
@ -66,7 +66,7 @@ MaybeError ComputePipeline::Initialize() {
|
|||
createInfo.stage.module = moduleAndSpirv.module;
|
||||
createInfo.stage.pName = computeStage.entryPoint.c_str();
|
||||
|
||||
std::vector<OverridableConstantScalar> specializationDataEntries;
|
||||
std::vector<OverrideScalar> specializationDataEntries;
|
||||
std::vector<VkSpecializationMapEntry> specializationMapEntries;
|
||||
VkSpecializationInfo specializationInfo{};
|
||||
createInfo.stage.pSpecializationInfo = GetVkSpecializationInfo(
|
||||
|
|
|
@ -341,7 +341,7 @@ MaybeError RenderPipeline::Initialize() {
|
|||
|
||||
// There are at most 2 shader stages in render pipeline, i.e. vertex and fragment
|
||||
std::array<VkPipelineShaderStageCreateInfo, 2> shaderStages;
|
||||
std::array<std::vector<OverridableConstantScalar>, 2> specializationDataEntriesPerStages;
|
||||
std::array<std::vector<OverrideScalar>, 2> specializationDataEntriesPerStages;
|
||||
std::array<std::vector<VkSpecializationMapEntry>, 2> specializationMapEntriesPerStages;
|
||||
std::array<VkSpecializationInfo, 2> specializationInfoPerStages;
|
||||
uint32_t stageCount = 0;
|
||||
|
|
|
@ -261,7 +261,7 @@ std::string GetDeviceDebugPrefixFromDebugName(const char* debugName) {
|
|||
VkSpecializationInfo* GetVkSpecializationInfo(
|
||||
const ProgrammableStage& programmableStage,
|
||||
VkSpecializationInfo* specializationInfo,
|
||||
std::vector<OverridableConstantScalar>* specializationDataEntries,
|
||||
std::vector<OverrideScalar>* specializationDataEntries,
|
||||
std::vector<VkSpecializationMapEntry>* specializationMapEntries) {
|
||||
ASSERT(specializationInfo);
|
||||
ASSERT(specializationDataEntries);
|
||||
|
@ -279,26 +279,25 @@ VkSpecializationInfo* GetVkSpecializationInfo(
|
|||
double value = pipelineConstant.second;
|
||||
|
||||
// This is already validated so `identifier` must exist
|
||||
const auto& moduleConstant = entryPointMetaData.overridableConstants.at(identifier);
|
||||
const auto& moduleConstant = entryPointMetaData.overrides.at(identifier);
|
||||
|
||||
specializationMapEntries->push_back(
|
||||
VkSpecializationMapEntry{moduleConstant.id,
|
||||
static_cast<uint32_t>(specializationDataEntries->size() *
|
||||
sizeof(OverridableConstantScalar)),
|
||||
sizeof(OverridableConstantScalar)});
|
||||
specializationMapEntries->push_back(VkSpecializationMapEntry{
|
||||
moduleConstant.id,
|
||||
static_cast<uint32_t>(specializationDataEntries->size() * sizeof(OverrideScalar)),
|
||||
sizeof(OverrideScalar)});
|
||||
|
||||
OverridableConstantScalar entry{};
|
||||
OverrideScalar entry{};
|
||||
switch (moduleConstant.type) {
|
||||
case EntryPointMetadata::OverridableConstant::Type::Boolean:
|
||||
case EntryPointMetadata::Override::Type::Boolean:
|
||||
entry.b = static_cast<int32_t>(value);
|
||||
break;
|
||||
case EntryPointMetadata::OverridableConstant::Type::Float32:
|
||||
case EntryPointMetadata::Override::Type::Float32:
|
||||
entry.f32 = static_cast<float>(value);
|
||||
break;
|
||||
case EntryPointMetadata::OverridableConstant::Type::Int32:
|
||||
case EntryPointMetadata::Override::Type::Int32:
|
||||
entry.i32 = static_cast<int32_t>(value);
|
||||
break;
|
||||
case EntryPointMetadata::OverridableConstant::Type::Uint32:
|
||||
case EntryPointMetadata::Override::Type::Uint32:
|
||||
entry.u32 = static_cast<uint32_t>(value);
|
||||
break;
|
||||
default:
|
||||
|
@ -309,8 +308,7 @@ VkSpecializationInfo* GetVkSpecializationInfo(
|
|||
|
||||
specializationInfo->mapEntryCount = static_cast<uint32_t>(specializationMapEntries->size());
|
||||
specializationInfo->pMapEntries = specializationMapEntries->data();
|
||||
specializationInfo->dataSize =
|
||||
specializationDataEntries->size() * sizeof(OverridableConstantScalar);
|
||||
specializationInfo->dataSize = specializationDataEntries->size() * sizeof(OverrideScalar);
|
||||
specializationInfo->pData = specializationDataEntries->data();
|
||||
|
||||
return specializationInfo;
|
||||
|
|
|
@ -24,7 +24,7 @@
|
|||
|
||||
namespace dawn::native {
|
||||
struct ProgrammableStage;
|
||||
union OverridableConstantScalar;
|
||||
union OverrideScalar;
|
||||
} // namespace dawn::native
|
||||
|
||||
namespace dawn::native::vulkan {
|
||||
|
@ -150,7 +150,7 @@ std::string GetDeviceDebugPrefixFromDebugName(const char* debugName);
|
|||
VkSpecializationInfo* GetVkSpecializationInfo(
|
||||
const ProgrammableStage& programmableStage,
|
||||
VkSpecializationInfo* specializationInfo,
|
||||
std::vector<OverridableConstantScalar>* specializationDataEntries,
|
||||
std::vector<OverrideScalar>* specializationDataEntries,
|
||||
std::vector<VkSpecializationMapEntry>* specializationMapEntries);
|
||||
|
||||
} // namespace dawn::native::vulkan
|
||||
|
|
|
@ -1110,7 +1110,7 @@ if (tint_build_unittests) {
|
|||
"resolver/is_host_shareable_test.cc",
|
||||
"resolver/is_storeable_test.cc",
|
||||
"resolver/materialize_test.cc",
|
||||
"resolver/pipeline_overridable_constant_test.cc",
|
||||
"resolver/override_test.cc",
|
||||
"resolver/ptr_ref_test.cc",
|
||||
"resolver/ptr_ref_validation_test.cc",
|
||||
"resolver/resolver_behavior_test.cc",
|
||||
|
|
|
@ -794,7 +794,7 @@ if(TINT_BUILD_TESTS)
|
|||
resolver/is_host_shareable_test.cc
|
||||
resolver/is_storeable_test.cc
|
||||
resolver/materialize_test.cc
|
||||
resolver/pipeline_overridable_constant_test.cc
|
||||
resolver/override_test.cc
|
||||
resolver/ptr_ref_test.cc
|
||||
resolver/ptr_ref_validation_test.cc
|
||||
resolver/resolver_behavior_test.cc
|
||||
|
|
|
@ -268,10 +268,10 @@ void CommonFuzzer::RunInspector(Program* program) {
|
|||
auto entry_points = inspector.GetEntryPoints();
|
||||
CHECK_INSPECTOR(program, inspector);
|
||||
|
||||
auto constant_ids = inspector.GetConstantIDs();
|
||||
auto override_ids = inspector.GetOverrideDefaultValues();
|
||||
CHECK_INSPECTOR(program, inspector);
|
||||
|
||||
auto constant_name_to_id = inspector.GetConstantNameToIdMap();
|
||||
auto override_name_to_id = inspector.GetNamedOverrideIds();
|
||||
CHECK_INSPECTOR(program, inspector);
|
||||
|
||||
for (auto& ep : entry_points) {
|
||||
|
|
|
@ -20,6 +20,8 @@
|
|||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "tint/override_id.h"
|
||||
|
||||
#include "src/tint/ast/interpolate_attribute.h"
|
||||
#include "src/tint/ast/pipeline_stage.h"
|
||||
|
||||
|
@ -93,14 +95,13 @@ InterpolationType ASTToInspectorInterpolationType(ast::InterpolationType ast_typ
|
|||
/// @returns the publicly visible equivalent
|
||||
InterpolationSampling ASTToInspectorInterpolationSampling(ast::InterpolationSampling sampling);
|
||||
|
||||
/// Reflection data about a pipeline overridable constant referenced by an entry
|
||||
/// point
|
||||
struct OverridableConstant {
|
||||
/// Name of the constant
|
||||
/// Reflection data about an override variable referenced by an entry point
|
||||
struct Override {
|
||||
/// Name of the override
|
||||
std::string name;
|
||||
|
||||
/// ID of the constant
|
||||
uint16_t numeric_id;
|
||||
/// ID of the override
|
||||
OverrideId id;
|
||||
|
||||
/// Type of the scalar
|
||||
enum class Type {
|
||||
|
@ -113,12 +114,11 @@ struct OverridableConstant {
|
|||
/// Type of the scalar
|
||||
Type type;
|
||||
|
||||
/// Does this pipeline overridable constant have an initializer?
|
||||
/// Does this override have an initializer?
|
||||
bool is_initialized = false;
|
||||
|
||||
/// Does this pipeline overridable constant have a numeric ID specified
|
||||
/// explicitly?
|
||||
bool is_numeric_id_specified = false;
|
||||
/// Does this override have a numeric ID specified explicitly?
|
||||
bool is_id_specified = false;
|
||||
};
|
||||
|
||||
/// The pipeline stage
|
||||
|
@ -159,7 +159,7 @@ struct EntryPoint {
|
|||
/// List of the output variable accessed via this entry point.
|
||||
std::vector<StageVariable> output_variables;
|
||||
/// List of the pipeline overridable constants accessed via this entry point.
|
||||
std::vector<OverridableConstant> overridable_constants;
|
||||
std::vector<Override> overrides;
|
||||
/// Does the entry point use the sample_mask builtin as an input builtin
|
||||
/// variable.
|
||||
bool input_sample_mask_used = false;
|
||||
|
|
|
@ -206,28 +206,28 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
|
|||
|
||||
auto* global = var->As<sem::GlobalVariable>();
|
||||
if (global && global->Declaration()->Is<ast::Override>()) {
|
||||
OverridableConstant overridable_constant;
|
||||
overridable_constant.name = name;
|
||||
overridable_constant.numeric_id = global->ConstantId();
|
||||
Override override;
|
||||
override.name = name;
|
||||
override.id = global->OverrideId();
|
||||
auto* type = var->Type();
|
||||
TINT_ASSERT(Inspector, type->is_scalar());
|
||||
if (type->is_bool_scalar_or_vector()) {
|
||||
overridable_constant.type = OverridableConstant::Type::kBool;
|
||||
override.type = Override::Type::kBool;
|
||||
} else if (type->is_float_scalar()) {
|
||||
overridable_constant.type = OverridableConstant::Type::kFloat32;
|
||||
override.type = Override::Type::kFloat32;
|
||||
} else if (type->is_signed_integer_scalar()) {
|
||||
overridable_constant.type = OverridableConstant::Type::kInt32;
|
||||
override.type = Override::Type::kInt32;
|
||||
} else if (type->is_unsigned_integer_scalar()) {
|
||||
overridable_constant.type = OverridableConstant::Type::kUint32;
|
||||
override.type = Override::Type::kUint32;
|
||||
} else {
|
||||
TINT_UNREACHABLE(Inspector, diagnostics_);
|
||||
}
|
||||
|
||||
overridable_constant.is_initialized = global->Declaration()->constructor;
|
||||
overridable_constant.is_numeric_id_specified =
|
||||
override.is_initialized = global->Declaration()->constructor;
|
||||
override.is_id_specified =
|
||||
ast::HasAttribute<ast::IdAttribute>(global->Declaration()->attributes);
|
||||
|
||||
entry_point.overridable_constants.push_back(overridable_constant);
|
||||
entry_point.overrides.push_back(override);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -237,37 +237,37 @@ std::vector<EntryPoint> Inspector::GetEntryPoints() {
|
|||
return result;
|
||||
}
|
||||
|
||||
std::map<uint32_t, Scalar> Inspector::GetConstantIDs() {
|
||||
std::map<uint32_t, Scalar> result;
|
||||
std::map<OverrideId, Scalar> Inspector::GetOverrideDefaultValues() {
|
||||
std::map<OverrideId, Scalar> result;
|
||||
for (auto* var : program_->AST().GlobalVariables()) {
|
||||
auto* global = program_->Sem().Get<sem::GlobalVariable>(var);
|
||||
if (!global || !global->Declaration()->Is<ast::Override>()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// If there are conflicting defintions for a constant id, that is invalid
|
||||
// If there are conflicting defintions for an override id, that is invalid
|
||||
// WGSL, so the resolver should catch it. Thus here the inspector just
|
||||
// assumes all definitions of the constant id are the same, so only needs
|
||||
// to find the first reference to constant id.
|
||||
uint32_t constant_id = global->ConstantId();
|
||||
if (result.find(constant_id) != result.end()) {
|
||||
// assumes all definitions of the override id are the same, so only needs
|
||||
// to find the first reference to override id.
|
||||
OverrideId override_id = global->OverrideId();
|
||||
if (result.find(override_id) != result.end()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!var->constructor) {
|
||||
result[constant_id] = Scalar();
|
||||
result[override_id] = Scalar();
|
||||
continue;
|
||||
}
|
||||
|
||||
auto* literal = var->constructor->As<ast::LiteralExpression>();
|
||||
if (!literal) {
|
||||
// This is invalid WGSL, but handling gracefully.
|
||||
result[constant_id] = Scalar();
|
||||
result[override_id] = Scalar();
|
||||
continue;
|
||||
}
|
||||
|
||||
if (auto* l = literal->As<ast::BoolLiteralExpression>()) {
|
||||
result[constant_id] = Scalar(l->value);
|
||||
result[override_id] = Scalar(l->value);
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -275,32 +275,32 @@ std::map<uint32_t, Scalar> Inspector::GetConstantIDs() {
|
|||
switch (l->suffix) {
|
||||
case ast::IntLiteralExpression::Suffix::kNone:
|
||||
case ast::IntLiteralExpression::Suffix::kI:
|
||||
result[constant_id] = Scalar(static_cast<int32_t>(l->value));
|
||||
result[override_id] = Scalar(static_cast<int32_t>(l->value));
|
||||
continue;
|
||||
case ast::IntLiteralExpression::Suffix::kU:
|
||||
result[constant_id] = Scalar(static_cast<uint32_t>(l->value));
|
||||
result[override_id] = Scalar(static_cast<uint32_t>(l->value));
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (auto* l = literal->As<ast::FloatLiteralExpression>()) {
|
||||
result[constant_id] = Scalar(static_cast<float>(l->value));
|
||||
result[override_id] = Scalar(static_cast<float>(l->value));
|
||||
continue;
|
||||
}
|
||||
|
||||
result[constant_id] = Scalar();
|
||||
result[override_id] = Scalar();
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
std::map<std::string, uint32_t> Inspector::GetConstantNameToIdMap() {
|
||||
std::map<std::string, uint32_t> result;
|
||||
std::map<std::string, OverrideId> Inspector::GetNamedOverrideIds() {
|
||||
std::map<std::string, OverrideId> result;
|
||||
for (auto* var : program_->AST().GlobalVariables()) {
|
||||
auto* global = program_->Sem().Get<sem::GlobalVariable>(var);
|
||||
if (global && global->Declaration()->Is<ast::Override>()) {
|
||||
auto name = program_->Symbols().NameFor(var->symbol);
|
||||
result[name] = global->ConstantId();
|
||||
result[name] = global->OverrideId();
|
||||
}
|
||||
}
|
||||
return result;
|
||||
|
|
|
@ -23,6 +23,8 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tint/override_id.h"
|
||||
|
||||
#include "src/tint/inspector/entry_point.h"
|
||||
#include "src/tint/inspector/resource_binding.h"
|
||||
#include "src/tint/inspector/scalar.h"
|
||||
|
@ -53,11 +55,11 @@ class Inspector {
|
|||
/// @returns vector of entry point information
|
||||
std::vector<EntryPoint> GetEntryPoints();
|
||||
|
||||
/// @returns map of const_id to initial value
|
||||
std::map<uint32_t, Scalar> GetConstantIDs();
|
||||
/// @returns map of override identifier to initial value
|
||||
std::map<OverrideId, Scalar> GetOverrideDefaultValues();
|
||||
|
||||
/// @returns map of module-constant name to pipeline constant ID
|
||||
std::map<std::string, uint32_t> GetConstantNameToIdMap();
|
||||
std::map<std::string, OverrideId> GetNamedOverrideIds();
|
||||
|
||||
/// @param entry_point name of the entry point to get information about.
|
||||
/// @returns the total size of shared storage required by an entry point,
|
||||
|
|
|
@ -60,7 +60,7 @@ struct InspectorGetEntryPointInterpolateTestParams {
|
|||
class InspectorGetEntryPointInterpolateTest
|
||||
: public InspectorBuilder,
|
||||
public testing::TestWithParam<InspectorGetEntryPointInterpolateTestParams> {};
|
||||
class InspectorGetConstantIDsTest : public InspectorBuilder, public testing::Test {};
|
||||
class InspectorGetOverrideDefaultValuesTest : public InspectorBuilder, public testing::Test {};
|
||||
class InspectorGetConstantNameToIdMapTest : public InspectorBuilder, public testing::Test {};
|
||||
class InspectorGetStorageSizeTest : public InspectorBuilder, public testing::Test {};
|
||||
class InspectorGetResourceBindingsTest : public InspectorBuilder, public testing::Test {};
|
||||
|
@ -605,8 +605,8 @@ TEST_F(InspectorGetEntryPointTest, MixInOutVariablesAndStruct) {
|
|||
EXPECT_EQ(ComponentType::kUInt, result[0].output_variables[1].component_type);
|
||||
}
|
||||
|
||||
TEST_F(InspectorGetEntryPointTest, OverridableConstantUnreferenced) {
|
||||
AddOverridableConstantWithoutID("foo", ty.f32(), nullptr);
|
||||
TEST_F(InspectorGetEntryPointTest, OverrideUnreferenced) {
|
||||
Override("foo", ty.f32(), nullptr);
|
||||
MakeEmptyBodyFunction("ep_func", {
|
||||
Stage(ast::PipelineStage::kCompute),
|
||||
WorkgroupSize(1_i),
|
||||
|
@ -617,11 +617,11 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantUnreferenced) {
|
|||
auto result = inspector.GetEntryPoints();
|
||||
|
||||
ASSERT_EQ(1u, result.size());
|
||||
EXPECT_EQ(0u, result[0].overridable_constants.size());
|
||||
EXPECT_EQ(0u, result[0].overrides.size());
|
||||
}
|
||||
|
||||
TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByEntryPoint) {
|
||||
AddOverridableConstantWithoutID("foo", ty.f32(), nullptr);
|
||||
TEST_F(InspectorGetEntryPointTest, OverrideReferencedByEntryPoint) {
|
||||
Override("foo", ty.f32(), nullptr);
|
||||
MakePlainGlobalReferenceBodyFunction("ep_func", "foo", ty.f32(),
|
||||
{
|
||||
Stage(ast::PipelineStage::kCompute),
|
||||
|
@ -633,12 +633,12 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByEntryPoint) {
|
|||
auto result = inspector.GetEntryPoints();
|
||||
|
||||
ASSERT_EQ(1u, result.size());
|
||||
ASSERT_EQ(1u, result[0].overridable_constants.size());
|
||||
EXPECT_EQ("foo", result[0].overridable_constants[0].name);
|
||||
ASSERT_EQ(1u, result[0].overrides.size());
|
||||
EXPECT_EQ("foo", result[0].overrides[0].name);
|
||||
}
|
||||
|
||||
TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByCallee) {
|
||||
AddOverridableConstantWithoutID("foo", ty.f32(), nullptr);
|
||||
TEST_F(InspectorGetEntryPointTest, OverrideReferencedByCallee) {
|
||||
Override("foo", ty.f32(), nullptr);
|
||||
MakePlainGlobalReferenceBodyFunction("callee_func", "foo", ty.f32(), {});
|
||||
MakeCallerBodyFunction("ep_func", {"callee_func"},
|
||||
{
|
||||
|
@ -651,13 +651,13 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantReferencedByCallee) {
|
|||
auto result = inspector.GetEntryPoints();
|
||||
|
||||
ASSERT_EQ(1u, result.size());
|
||||
ASSERT_EQ(1u, result[0].overridable_constants.size());
|
||||
EXPECT_EQ("foo", result[0].overridable_constants[0].name);
|
||||
ASSERT_EQ(1u, result[0].overrides.size());
|
||||
EXPECT_EQ("foo", result[0].overrides[0].name);
|
||||
}
|
||||
|
||||
TEST_F(InspectorGetEntryPointTest, OverridableConstantSomeReferenced) {
|
||||
AddOverridableConstantWithID("foo", 1, ty.f32(), nullptr);
|
||||
AddOverridableConstantWithID("bar", 2, ty.f32(), nullptr);
|
||||
TEST_F(InspectorGetEntryPointTest, OverrideSomeReferenced) {
|
||||
Override("foo", ty.f32(), nullptr, {Id(1)});
|
||||
Override("bar", ty.f32(), nullptr, {Id(2)});
|
||||
MakePlainGlobalReferenceBodyFunction("callee_func", "foo", ty.f32(), {});
|
||||
MakeCallerBodyFunction("ep_func", {"callee_func"},
|
||||
{
|
||||
|
@ -670,16 +670,16 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantSomeReferenced) {
|
|||
auto result = inspector.GetEntryPoints();
|
||||
|
||||
ASSERT_EQ(1u, result.size());
|
||||
ASSERT_EQ(1u, result[0].overridable_constants.size());
|
||||
EXPECT_EQ("foo", result[0].overridable_constants[0].name);
|
||||
EXPECT_EQ(1, result[0].overridable_constants[0].numeric_id);
|
||||
ASSERT_EQ(1u, result[0].overrides.size());
|
||||
EXPECT_EQ("foo", result[0].overrides[0].name);
|
||||
EXPECT_EQ(1, result[0].overrides[0].id.value);
|
||||
}
|
||||
|
||||
TEST_F(InspectorGetEntryPointTest, OverridableConstantTypes) {
|
||||
AddOverridableConstantWithoutID("bool_var", ty.bool_(), nullptr);
|
||||
AddOverridableConstantWithoutID("float_var", ty.f32(), nullptr);
|
||||
AddOverridableConstantWithoutID("u32_var", ty.u32(), nullptr);
|
||||
AddOverridableConstantWithoutID("i32_var", ty.i32(), nullptr);
|
||||
TEST_F(InspectorGetEntryPointTest, OverrideTypes) {
|
||||
Override("bool_var", ty.bool_(), nullptr);
|
||||
Override("float_var", ty.f32(), nullptr);
|
||||
Override("u32_var", ty.u32(), nullptr);
|
||||
Override("i32_var", ty.i32(), nullptr);
|
||||
|
||||
MakePlainGlobalReferenceBodyFunction("bool_func", "bool_var", ty.bool_(), {});
|
||||
MakePlainGlobalReferenceBodyFunction("float_func", "float_var", ty.f32(), {});
|
||||
|
@ -697,22 +697,19 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantTypes) {
|
|||
auto result = inspector.GetEntryPoints();
|
||||
|
||||
ASSERT_EQ(1u, result.size());
|
||||
ASSERT_EQ(4u, result[0].overridable_constants.size());
|
||||
EXPECT_EQ("bool_var", result[0].overridable_constants[0].name);
|
||||
EXPECT_EQ(inspector::OverridableConstant::Type::kBool, result[0].overridable_constants[0].type);
|
||||
EXPECT_EQ("float_var", result[0].overridable_constants[1].name);
|
||||
EXPECT_EQ(inspector::OverridableConstant::Type::kFloat32,
|
||||
result[0].overridable_constants[1].type);
|
||||
EXPECT_EQ("u32_var", result[0].overridable_constants[2].name);
|
||||
EXPECT_EQ(inspector::OverridableConstant::Type::kUint32,
|
||||
result[0].overridable_constants[2].type);
|
||||
EXPECT_EQ("i32_var", result[0].overridable_constants[3].name);
|
||||
EXPECT_EQ(inspector::OverridableConstant::Type::kInt32,
|
||||
result[0].overridable_constants[3].type);
|
||||
ASSERT_EQ(4u, result[0].overrides.size());
|
||||
EXPECT_EQ("bool_var", result[0].overrides[0].name);
|
||||
EXPECT_EQ(inspector::Override::Type::kBool, result[0].overrides[0].type);
|
||||
EXPECT_EQ("float_var", result[0].overrides[1].name);
|
||||
EXPECT_EQ(inspector::Override::Type::kFloat32, result[0].overrides[1].type);
|
||||
EXPECT_EQ("u32_var", result[0].overrides[2].name);
|
||||
EXPECT_EQ(inspector::Override::Type::kUint32, result[0].overrides[2].type);
|
||||
EXPECT_EQ("i32_var", result[0].overrides[3].name);
|
||||
EXPECT_EQ(inspector::Override::Type::kInt32, result[0].overrides[3].type);
|
||||
}
|
||||
|
||||
TEST_F(InspectorGetEntryPointTest, OverridableConstantInitialized) {
|
||||
AddOverridableConstantWithoutID("foo", ty.f32(), Expr(0_f));
|
||||
TEST_F(InspectorGetEntryPointTest, OverrideInitialized) {
|
||||
Override("foo", ty.f32(), Expr(0_f));
|
||||
MakePlainGlobalReferenceBodyFunction("ep_func", "foo", ty.f32(),
|
||||
{
|
||||
Stage(ast::PipelineStage::kCompute),
|
||||
|
@ -724,13 +721,13 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantInitialized) {
|
|||
auto result = inspector.GetEntryPoints();
|
||||
|
||||
ASSERT_EQ(1u, result.size());
|
||||
ASSERT_EQ(1u, result[0].overridable_constants.size());
|
||||
EXPECT_EQ("foo", result[0].overridable_constants[0].name);
|
||||
EXPECT_TRUE(result[0].overridable_constants[0].is_initialized);
|
||||
ASSERT_EQ(1u, result[0].overrides.size());
|
||||
EXPECT_EQ("foo", result[0].overrides[0].name);
|
||||
EXPECT_TRUE(result[0].overrides[0].is_initialized);
|
||||
}
|
||||
|
||||
TEST_F(InspectorGetEntryPointTest, OverridableConstantUninitialized) {
|
||||
AddOverridableConstantWithoutID("foo", ty.f32(), nullptr);
|
||||
TEST_F(InspectorGetEntryPointTest, OverrideUninitialized) {
|
||||
Override("foo", ty.f32(), nullptr);
|
||||
MakePlainGlobalReferenceBodyFunction("ep_func", "foo", ty.f32(),
|
||||
{
|
||||
Stage(ast::PipelineStage::kCompute),
|
||||
|
@ -742,15 +739,15 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantUninitialized) {
|
|||
auto result = inspector.GetEntryPoints();
|
||||
|
||||
ASSERT_EQ(1u, result.size());
|
||||
ASSERT_EQ(1u, result[0].overridable_constants.size());
|
||||
EXPECT_EQ("foo", result[0].overridable_constants[0].name);
|
||||
ASSERT_EQ(1u, result[0].overrides.size());
|
||||
EXPECT_EQ("foo", result[0].overrides[0].name);
|
||||
|
||||
EXPECT_FALSE(result[0].overridable_constants[0].is_initialized);
|
||||
EXPECT_FALSE(result[0].overrides[0].is_initialized);
|
||||
}
|
||||
|
||||
TEST_F(InspectorGetEntryPointTest, OverridableConstantNumericIDSpecified) {
|
||||
AddOverridableConstantWithoutID("foo_no_id", ty.f32(), nullptr);
|
||||
AddOverridableConstantWithID("foo_id", 1234, ty.f32(), nullptr);
|
||||
TEST_F(InspectorGetEntryPointTest, OverrideNumericIDSpecified) {
|
||||
Override("foo_no_id", ty.f32(), nullptr);
|
||||
Override("foo_id", ty.f32(), nullptr, {Id(1234)});
|
||||
|
||||
MakePlainGlobalReferenceBodyFunction("no_id_func", "foo_no_id", ty.f32(), {});
|
||||
MakePlainGlobalReferenceBodyFunction("id_func", "foo_id", ty.f32(), {});
|
||||
|
@ -766,16 +763,16 @@ TEST_F(InspectorGetEntryPointTest, OverridableConstantNumericIDSpecified) {
|
|||
auto result = inspector.GetEntryPoints();
|
||||
|
||||
ASSERT_EQ(1u, result.size());
|
||||
ASSERT_EQ(2u, result[0].overridable_constants.size());
|
||||
EXPECT_EQ("foo_no_id", result[0].overridable_constants[0].name);
|
||||
EXPECT_EQ("foo_id", result[0].overridable_constants[1].name);
|
||||
EXPECT_EQ(1234, result[0].overridable_constants[1].numeric_id);
|
||||
ASSERT_EQ(2u, result[0].overrides.size());
|
||||
EXPECT_EQ("foo_no_id", result[0].overrides[0].name);
|
||||
EXPECT_EQ("foo_id", result[0].overrides[1].name);
|
||||
EXPECT_EQ(1234, result[0].overrides[1].id.value);
|
||||
|
||||
EXPECT_FALSE(result[0].overridable_constants[0].is_numeric_id_specified);
|
||||
EXPECT_TRUE(result[0].overridable_constants[1].is_numeric_id_specified);
|
||||
EXPECT_FALSE(result[0].overrides[0].is_id_specified);
|
||||
EXPECT_TRUE(result[0].overrides[1].is_id_specified);
|
||||
}
|
||||
|
||||
TEST_F(InspectorGetEntryPointTest, NonOverridableConstantSkipped) {
|
||||
TEST_F(InspectorGetEntryPointTest, NonOverrideSkipped) {
|
||||
auto* foo_struct_type = MakeUniformBufferType("foo_type", {ty.i32()});
|
||||
AddUniformBuffer("foo_ub", ty.Of(foo_struct_type), 0, 0);
|
||||
MakeStructVariableReferenceBodyFunction("ub_func", "foo_ub", {{0, ty.i32()}});
|
||||
|
@ -789,7 +786,7 @@ TEST_F(InspectorGetEntryPointTest, NonOverridableConstantSkipped) {
|
|||
auto result = inspector.GetEntryPoints();
|
||||
|
||||
ASSERT_EQ(1u, result.size());
|
||||
EXPECT_EQ(0u, result[0].overridable_constants.size());
|
||||
EXPECT_EQ(0u, result[0].overrides.size());
|
||||
}
|
||||
|
||||
TEST_F(InspectorGetEntryPointTest, BuiltinNotReferenced) {
|
||||
|
@ -1172,127 +1169,127 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
ast::InterpolationType::kFlat, ast::InterpolationSampling::kNone,
|
||||
InterpolationType::kFlat, InterpolationSampling::kNone}));
|
||||
|
||||
TEST_F(InspectorGetConstantIDsTest, Bool) {
|
||||
AddOverridableConstantWithID("foo", 1, ty.bool_(), nullptr);
|
||||
AddOverridableConstantWithID("bar", 20, ty.bool_(), Expr(true));
|
||||
AddOverridableConstantWithID("baz", 300, ty.bool_(), Expr(false));
|
||||
TEST_F(InspectorGetOverrideDefaultValuesTest, Bool) {
|
||||
Override("foo", ty.bool_(), nullptr, {Id(1)});
|
||||
Override("bar", ty.bool_(), Expr(true), {Id(20)});
|
||||
Override("baz", ty.bool_(), Expr(false), {Id(300)});
|
||||
|
||||
Inspector& inspector = Build();
|
||||
|
||||
auto result = inspector.GetConstantIDs();
|
||||
auto result = inspector.GetOverrideDefaultValues();
|
||||
ASSERT_EQ(3u, result.size());
|
||||
|
||||
ASSERT_TRUE(result.find(1) != result.end());
|
||||
EXPECT_TRUE(result[1].IsNull());
|
||||
ASSERT_TRUE(result.find(OverrideId{1}) != result.end());
|
||||
EXPECT_TRUE(result[OverrideId{1}].IsNull());
|
||||
|
||||
ASSERT_TRUE(result.find(20) != result.end());
|
||||
EXPECT_TRUE(result[20].IsBool());
|
||||
EXPECT_TRUE(result[20].AsBool());
|
||||
ASSERT_TRUE(result.find(OverrideId{20}) != result.end());
|
||||
EXPECT_TRUE(result[OverrideId{20}].IsBool());
|
||||
EXPECT_TRUE(result[OverrideId{20}].AsBool());
|
||||
|
||||
ASSERT_TRUE(result.find(300) != result.end());
|
||||
EXPECT_TRUE(result[300].IsBool());
|
||||
EXPECT_FALSE(result[300].AsBool());
|
||||
ASSERT_TRUE(result.find(OverrideId{300}) != result.end());
|
||||
EXPECT_TRUE(result[OverrideId{300}].IsBool());
|
||||
EXPECT_FALSE(result[OverrideId{300}].AsBool());
|
||||
}
|
||||
|
||||
TEST_F(InspectorGetConstantIDsTest, U32) {
|
||||
AddOverridableConstantWithID("foo", 1, ty.u32(), nullptr);
|
||||
AddOverridableConstantWithID("bar", 20, ty.u32(), Expr(42_u));
|
||||
TEST_F(InspectorGetOverrideDefaultValuesTest, U32) {
|
||||
Override("foo", ty.u32(), nullptr, {Id(1)});
|
||||
Override("bar", ty.u32(), Expr(42_u), {Id(20)});
|
||||
|
||||
Inspector& inspector = Build();
|
||||
|
||||
auto result = inspector.GetConstantIDs();
|
||||
auto result = inspector.GetOverrideDefaultValues();
|
||||
ASSERT_EQ(2u, result.size());
|
||||
|
||||
ASSERT_TRUE(result.find(1) != result.end());
|
||||
EXPECT_TRUE(result[1].IsNull());
|
||||
ASSERT_TRUE(result.find(OverrideId{1}) != result.end());
|
||||
EXPECT_TRUE(result[OverrideId{1}].IsNull());
|
||||
|
||||
ASSERT_TRUE(result.find(20) != result.end());
|
||||
EXPECT_TRUE(result[20].IsU32());
|
||||
EXPECT_EQ(42u, result[20].AsU32());
|
||||
ASSERT_TRUE(result.find(OverrideId{20}) != result.end());
|
||||
EXPECT_TRUE(result[OverrideId{20}].IsU32());
|
||||
EXPECT_EQ(42u, result[OverrideId{20}].AsU32());
|
||||
}
|
||||
|
||||
TEST_F(InspectorGetConstantIDsTest, I32) {
|
||||
AddOverridableConstantWithID("foo", 1, ty.i32(), nullptr);
|
||||
AddOverridableConstantWithID("bar", 20, ty.i32(), Expr(-42_i));
|
||||
AddOverridableConstantWithID("baz", 300, ty.i32(), Expr(42_i));
|
||||
TEST_F(InspectorGetOverrideDefaultValuesTest, I32) {
|
||||
Override("foo", ty.i32(), nullptr, {Id(1)});
|
||||
Override("bar", ty.i32(), Expr(-42_i), {Id(20)});
|
||||
Override("baz", ty.i32(), Expr(42_i), {Id(300)});
|
||||
|
||||
Inspector& inspector = Build();
|
||||
|
||||
auto result = inspector.GetConstantIDs();
|
||||
auto result = inspector.GetOverrideDefaultValues();
|
||||
ASSERT_EQ(3u, result.size());
|
||||
|
||||
ASSERT_TRUE(result.find(1) != result.end());
|
||||
EXPECT_TRUE(result[1].IsNull());
|
||||
ASSERT_TRUE(result.find(OverrideId{1}) != result.end());
|
||||
EXPECT_TRUE(result[OverrideId{1}].IsNull());
|
||||
|
||||
ASSERT_TRUE(result.find(20) != result.end());
|
||||
EXPECT_TRUE(result[20].IsI32());
|
||||
EXPECT_EQ(-42, result[20].AsI32());
|
||||
ASSERT_TRUE(result.find(OverrideId{20}) != result.end());
|
||||
EXPECT_TRUE(result[OverrideId{20}].IsI32());
|
||||
EXPECT_EQ(-42, result[OverrideId{20}].AsI32());
|
||||
|
||||
ASSERT_TRUE(result.find(300) != result.end());
|
||||
EXPECT_TRUE(result[300].IsI32());
|
||||
EXPECT_EQ(42, result[300].AsI32());
|
||||
ASSERT_TRUE(result.find(OverrideId{300}) != result.end());
|
||||
EXPECT_TRUE(result[OverrideId{300}].IsI32());
|
||||
EXPECT_EQ(42, result[OverrideId{300}].AsI32());
|
||||
}
|
||||
|
||||
TEST_F(InspectorGetConstantIDsTest, Float) {
|
||||
AddOverridableConstantWithID("foo", 1, ty.f32(), nullptr);
|
||||
AddOverridableConstantWithID("bar", 20, ty.f32(), Expr(0_f));
|
||||
AddOverridableConstantWithID("baz", 300, ty.f32(), Expr(-10_f));
|
||||
AddOverridableConstantWithID("x", 4000, ty.f32(), Expr(15_f));
|
||||
TEST_F(InspectorGetOverrideDefaultValuesTest, Float) {
|
||||
Override("foo", ty.f32(), nullptr, {Id(1)});
|
||||
Override("bar", ty.f32(), Expr(0_f), {Id(20)});
|
||||
Override("baz", ty.f32(), Expr(-10_f), {Id(300)});
|
||||
Override("x", ty.f32(), Expr(15_f), {Id(4000)});
|
||||
|
||||
Inspector& inspector = Build();
|
||||
|
||||
auto result = inspector.GetConstantIDs();
|
||||
auto result = inspector.GetOverrideDefaultValues();
|
||||
ASSERT_EQ(4u, result.size());
|
||||
|
||||
ASSERT_TRUE(result.find(1) != result.end());
|
||||
EXPECT_TRUE(result[1].IsNull());
|
||||
ASSERT_TRUE(result.find(OverrideId{1}) != result.end());
|
||||
EXPECT_TRUE(result[OverrideId{1}].IsNull());
|
||||
|
||||
ASSERT_TRUE(result.find(20) != result.end());
|
||||
EXPECT_TRUE(result[20].IsFloat());
|
||||
EXPECT_FLOAT_EQ(0.0f, result[20].AsFloat());
|
||||
ASSERT_TRUE(result.find(OverrideId{20}) != result.end());
|
||||
EXPECT_TRUE(result[OverrideId{20}].IsFloat());
|
||||
EXPECT_FLOAT_EQ(0.0f, result[OverrideId{20}].AsFloat());
|
||||
|
||||
ASSERT_TRUE(result.find(300) != result.end());
|
||||
EXPECT_TRUE(result[300].IsFloat());
|
||||
EXPECT_FLOAT_EQ(-10.0f, result[300].AsFloat());
|
||||
ASSERT_TRUE(result.find(OverrideId{300}) != result.end());
|
||||
EXPECT_TRUE(result[OverrideId{300}].IsFloat());
|
||||
EXPECT_FLOAT_EQ(-10.0f, result[OverrideId{300}].AsFloat());
|
||||
|
||||
ASSERT_TRUE(result.find(4000) != result.end());
|
||||
EXPECT_TRUE(result[4000].IsFloat());
|
||||
EXPECT_FLOAT_EQ(15.0f, result[4000].AsFloat());
|
||||
ASSERT_TRUE(result.find(OverrideId{4000}) != result.end());
|
||||
EXPECT_TRUE(result[OverrideId{4000}].IsFloat());
|
||||
EXPECT_FLOAT_EQ(15.0f, result[OverrideId{4000}].AsFloat());
|
||||
}
|
||||
|
||||
TEST_F(InspectorGetConstantNameToIdMapTest, WithAndWithoutIds) {
|
||||
AddOverridableConstantWithID("v1", 1, ty.f32(), nullptr);
|
||||
AddOverridableConstantWithID("v20", 20, ty.f32(), nullptr);
|
||||
AddOverridableConstantWithID("v300", 300, ty.f32(), nullptr);
|
||||
auto* a = AddOverridableConstantWithoutID("a", ty.f32(), nullptr);
|
||||
auto* b = AddOverridableConstantWithoutID("b", ty.f32(), nullptr);
|
||||
auto* c = AddOverridableConstantWithoutID("c", ty.f32(), nullptr);
|
||||
Override("v1", ty.f32(), nullptr, {Id(1)});
|
||||
Override("v20", ty.f32(), nullptr, {Id(20)});
|
||||
Override("v300", ty.f32(), nullptr, {Id(300)});
|
||||
auto* a = Override("a", ty.f32(), nullptr);
|
||||
auto* b = Override("b", ty.f32(), nullptr);
|
||||
auto* c = Override("c", ty.f32(), nullptr);
|
||||
|
||||
Inspector& inspector = Build();
|
||||
|
||||
auto result = inspector.GetConstantNameToIdMap();
|
||||
auto result = inspector.GetNamedOverrideIds();
|
||||
ASSERT_EQ(6u, result.size());
|
||||
|
||||
ASSERT_TRUE(result.count("v1"));
|
||||
EXPECT_EQ(result["v1"], 1u);
|
||||
EXPECT_EQ(result["v1"].value, 1u);
|
||||
|
||||
ASSERT_TRUE(result.count("v20"));
|
||||
EXPECT_EQ(result["v20"], 20u);
|
||||
EXPECT_EQ(result["v20"].value, 20u);
|
||||
|
||||
ASSERT_TRUE(result.count("v300"));
|
||||
EXPECT_EQ(result["v300"], 300u);
|
||||
EXPECT_EQ(result["v300"].value, 300u);
|
||||
|
||||
ASSERT_TRUE(result.count("a"));
|
||||
ASSERT_TRUE(program_->Sem().Get<sem::GlobalVariable>(a));
|
||||
EXPECT_EQ(result["a"], program_->Sem().Get<sem::GlobalVariable>(a)->ConstantId());
|
||||
EXPECT_EQ(result["a"], program_->Sem().Get<sem::GlobalVariable>(a)->OverrideId());
|
||||
|
||||
ASSERT_TRUE(result.count("b"));
|
||||
ASSERT_TRUE(program_->Sem().Get<sem::GlobalVariable>(b));
|
||||
EXPECT_EQ(result["b"], program_->Sem().Get<sem::GlobalVariable>(b)->ConstantId());
|
||||
EXPECT_EQ(result["b"], program_->Sem().Get<sem::GlobalVariable>(b)->OverrideId());
|
||||
|
||||
ASSERT_TRUE(result.count("c"));
|
||||
ASSERT_TRUE(program_->Sem().Get<sem::GlobalVariable>(c));
|
||||
EXPECT_EQ(result["c"], program_->Sem().Get<sem::GlobalVariable>(c)->ConstantId());
|
||||
EXPECT_EQ(result["c"], program_->Sem().Get<sem::GlobalVariable>(c)->OverrideId());
|
||||
}
|
||||
|
||||
TEST_F(InspectorGetStorageSizeTest, Empty) {
|
||||
|
|
|
@ -92,31 +92,6 @@ class InspectorBuilder : public ProgramBuilder {
|
|||
std::vector<std::tuple<std::string, std::string>> inout_vars,
|
||||
ast::AttributeList attributes);
|
||||
|
||||
/// Add a pipeline constant to the global variables, with a specific ID.
|
||||
/// @param name name of the variable to add
|
||||
/// @param id id number for the constant id
|
||||
/// @param type type of the variable
|
||||
/// @param constructor val to initialize the constant with, if NULL no
|
||||
/// constructor will be added.
|
||||
/// @returns the constant that was created
|
||||
const ast::Variable* AddOverridableConstantWithID(std::string name,
|
||||
uint32_t id,
|
||||
const ast::Type* type,
|
||||
const ast::Expression* constructor) {
|
||||
return Override(name, type, constructor, {Id(id)});
|
||||
}
|
||||
|
||||
/// Add a pipeline constant to the global variables, without a specific ID.
|
||||
/// @param name name of the variable to add
|
||||
/// @param type type of the variable
|
||||
/// @param constructor val to initialize the constant with, if NULL no
|
||||
/// constructor will be added.
|
||||
/// @returns the constant that was created
|
||||
const ast::Variable* AddOverridableConstantWithoutID(std::string name,
|
||||
const ast::Type* type,
|
||||
const ast::Expression* constructor) {
|
||||
return Override(name, type, constructor);
|
||||
}
|
||||
|
||||
/// Generates a function that references module-scoped, plain-typed constant
|
||||
/// or variable.
|
||||
|
|
|
@ -19,6 +19,8 @@
|
|||
#include <unordered_set>
|
||||
#include <utility>
|
||||
|
||||
#include "tint/override_id.h"
|
||||
|
||||
#include "src/tint/ast/alias.h"
|
||||
#include "src/tint/ast/array.h"
|
||||
#include "src/tint/ast/assignment_statement.h"
|
||||
|
@ -2588,6 +2590,19 @@ class ProgramBuilder {
|
|||
return create<ast::LocationAttribute>(source_, location);
|
||||
}
|
||||
|
||||
/// Creates an ast::IdAttribute
|
||||
/// @param source the source information
|
||||
/// @param id the id value
|
||||
/// @returns the override attribute pointer
|
||||
const ast::IdAttribute* Id(const Source& source, OverrideId id) {
|
||||
return create<ast::IdAttribute>(source, id.value);
|
||||
}
|
||||
|
||||
/// Creates an ast::IdAttribute with an override identifier
|
||||
/// @param id the optional id value
|
||||
/// @returns the override attribute pointer
|
||||
const ast::IdAttribute* Id(OverrideId id) { return Id(source_, id); }
|
||||
|
||||
/// Creates an ast::IdAttribute
|
||||
/// @param source the source information
|
||||
/// @param id the id value
|
||||
|
@ -2596,7 +2611,7 @@ class ProgramBuilder {
|
|||
return create<ast::IdAttribute>(source, id);
|
||||
}
|
||||
|
||||
/// Creates an ast::IdAttribute with a constant ID
|
||||
/// Creates an ast::IdAttribute with an override identifier
|
||||
/// @param id the optional id value
|
||||
/// @returns the override attribute pointer
|
||||
const ast::IdAttribute* Id(uint32_t id) { return Id(source_, id); }
|
||||
|
|
|
@ -21,23 +21,23 @@ using namespace tint::number_suffixes; // NOLINT
|
|||
namespace tint::resolver {
|
||||
namespace {
|
||||
|
||||
class ResolverPipelineOverridableConstantTest : public ResolverTest {
|
||||
class ResolverOverrideTest : public ResolverTest {
|
||||
protected:
|
||||
/// Verify that the AST node `var` was resolved to an overridable constant
|
||||
/// with an ID equal to `id`.
|
||||
/// @param var the overridable constant AST node
|
||||
/// @param id the expected constant ID
|
||||
void ExpectConstantId(const ast::Variable* var, uint16_t id) {
|
||||
void ExpectOverrideId(const ast::Variable* var, uint16_t id) {
|
||||
auto* sem = Sem().Get<sem::GlobalVariable>(var);
|
||||
ASSERT_NE(sem, nullptr);
|
||||
EXPECT_EQ(sem->Declaration(), var);
|
||||
EXPECT_TRUE(sem->Declaration()->Is<ast::Override>());
|
||||
EXPECT_EQ(sem->ConstantId(), id);
|
||||
EXPECT_EQ(sem->OverrideId().value, id);
|
||||
EXPECT_FALSE(sem->ConstantValue());
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(ResolverPipelineOverridableConstantTest, NonOverridable) {
|
||||
TEST_F(ResolverOverrideTest, NonOverridable) {
|
||||
auto* a = GlobalConst("a", ty.f32(), Expr(1_f));
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
|
@ -49,23 +49,23 @@ TEST_F(ResolverPipelineOverridableConstantTest, NonOverridable) {
|
|||
EXPECT_TRUE(sem_a->ConstantValue());
|
||||
}
|
||||
|
||||
TEST_F(ResolverPipelineOverridableConstantTest, WithId) {
|
||||
TEST_F(ResolverOverrideTest, WithId) {
|
||||
auto* a = Override("a", ty.f32(), Expr(1_f), {Id(7u)});
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
|
||||
ExpectConstantId(a, 7u);
|
||||
ExpectOverrideId(a, 7u);
|
||||
}
|
||||
|
||||
TEST_F(ResolverPipelineOverridableConstantTest, WithoutId) {
|
||||
TEST_F(ResolverOverrideTest, WithoutId) {
|
||||
auto* a = Override("a", ty.f32(), Expr(1_f));
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
|
||||
ExpectConstantId(a, 0u);
|
||||
ExpectOverrideId(a, 0u);
|
||||
}
|
||||
|
||||
TEST_F(ResolverPipelineOverridableConstantTest, WithAndWithoutIds) {
|
||||
TEST_F(ResolverOverrideTest, WithAndWithoutIds) {
|
||||
std::vector<ast::Variable*> variables;
|
||||
auto* a = Override("a", ty.f32(), Expr(1_f));
|
||||
auto* b = Override("b", ty.f32(), Expr(1_f));
|
||||
|
@ -77,33 +77,33 @@ TEST_F(ResolverPipelineOverridableConstantTest, WithAndWithoutIds) {
|
|||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
|
||||
// Verify that constant id allocation order is deterministic.
|
||||
ExpectConstantId(a, 0u);
|
||||
ExpectConstantId(b, 3u);
|
||||
ExpectConstantId(c, 2u);
|
||||
ExpectConstantId(d, 4u);
|
||||
ExpectConstantId(e, 5u);
|
||||
ExpectConstantId(f, 1u);
|
||||
ExpectOverrideId(a, 0u);
|
||||
ExpectOverrideId(b, 3u);
|
||||
ExpectOverrideId(c, 2u);
|
||||
ExpectOverrideId(d, 4u);
|
||||
ExpectOverrideId(e, 5u);
|
||||
ExpectOverrideId(f, 1u);
|
||||
}
|
||||
|
||||
TEST_F(ResolverPipelineOverridableConstantTest, DuplicateIds) {
|
||||
TEST_F(ResolverOverrideTest, DuplicateIds) {
|
||||
Override("a", ty.f32(), Expr(1_f), {Id(Source{{12, 34}}, 7u)});
|
||||
Override("b", ty.f32(), Expr(1_f), {Id(Source{{56, 78}}, 7u)});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
|
||||
EXPECT_EQ(r()->error(), R"(56:78 error: pipeline constant IDs must be unique
|
||||
12:34 note: a pipeline constant with an ID of 7 was previously declared here:)");
|
||||
EXPECT_EQ(r()->error(), R"(56:78 error: override IDs must be unique
|
||||
12:34 note: a override with an ID of 7 was previously declared here:)");
|
||||
}
|
||||
|
||||
TEST_F(ResolverPipelineOverridableConstantTest, IdTooLarge) {
|
||||
TEST_F(ResolverOverrideTest, IdTooLarge) {
|
||||
Override("a", ty.f32(), Expr(1_f), {Id(Source{{12, 34}}, 65536u)});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
|
||||
EXPECT_EQ(r()->error(), "12:34 error: pipeline constant IDs must be between 0 and 65535");
|
||||
EXPECT_EQ(r()->error(), "12:34 error: override IDs must be between 0 and 65535");
|
||||
}
|
||||
|
||||
TEST_F(ResolverPipelineOverridableConstantTest, F16_TemporallyBan) {
|
||||
TEST_F(ResolverOverrideTest, F16_TemporallyBan) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
Override(Source{{12, 34}}, "a", ty.f16(), Expr(1_h), {Id(1u)});
|
|
@ -149,7 +149,9 @@ bool Resolver::ResolveInternal() {
|
|||
}
|
||||
}
|
||||
|
||||
AllocateOverridableConstantIds();
|
||||
if (!AllocateOverridableConstantIds()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
SetShadows();
|
||||
|
||||
|
@ -432,7 +434,7 @@ sem::Variable* Resolver::Override(const ast::Override* v) {
|
|||
/* constant_value */ nullptr, sem::BindingPoint{});
|
||||
|
||||
if (auto* id = ast::GetAttribute<ast::IdAttribute>(v->attributes)) {
|
||||
sem->SetConstantId(static_cast<uint16_t>(id->value));
|
||||
sem->SetOverrideId(OverrideId{static_cast<decltype(OverrideId::value)>(id->value)});
|
||||
}
|
||||
|
||||
sem->SetConstructor(rhs);
|
||||
|
@ -641,9 +643,19 @@ ast::Access Resolver::DefaultAccessForStorageClass(ast::StorageClass storage_cla
|
|||
return ast::Access::kReadWrite;
|
||||
}
|
||||
|
||||
void Resolver::AllocateOverridableConstantIds() {
|
||||
bool Resolver::AllocateOverridableConstantIds() {
|
||||
constexpr size_t kLimit = std::numeric_limits<decltype(OverrideId::value)>::max();
|
||||
// The next pipeline constant ID to try to allocate.
|
||||
uint16_t next_constant_id = 0;
|
||||
OverrideId next_id;
|
||||
bool ids_exhausted = false;
|
||||
|
||||
auto increment_next_id = [&] {
|
||||
if (next_id.value == kLimit) {
|
||||
ids_exhausted = true;
|
||||
} else {
|
||||
next_id.value = next_id.value + 1;
|
||||
}
|
||||
};
|
||||
|
||||
// Allocate constant IDs in global declaration order, so that they are
|
||||
// deterministic.
|
||||
|
@ -655,26 +667,28 @@ void Resolver::AllocateOverridableConstantIds() {
|
|||
continue;
|
||||
}
|
||||
|
||||
uint16_t constant_id;
|
||||
OverrideId id;
|
||||
if (auto* id_attr = ast::GetAttribute<ast::IdAttribute>(override->attributes)) {
|
||||
constant_id = static_cast<uint16_t>(id_attr->value);
|
||||
id = OverrideId{static_cast<decltype(OverrideId::value)>(id_attr->value)};
|
||||
} else {
|
||||
// No ID was specified, so allocate the next available ID.
|
||||
constant_id = next_constant_id;
|
||||
while (constant_ids_.count(constant_id)) {
|
||||
if (constant_id == UINT16_MAX) {
|
||||
TINT_ICE(Resolver, builder_->Diagnostics())
|
||||
<< "no more pipeline constant IDs available";
|
||||
return;
|
||||
while (!ids_exhausted && override_ids_.count(next_id)) {
|
||||
increment_next_id();
|
||||
}
|
||||
constant_id++;
|
||||
if (ids_exhausted) {
|
||||
AddError(
|
||||
"number of 'override' variables exceeded limit of " + std::to_string(kLimit),
|
||||
decl->source);
|
||||
return false;
|
||||
}
|
||||
next_constant_id = constant_id + 1;
|
||||
id = next_id;
|
||||
increment_next_id();
|
||||
}
|
||||
|
||||
auto* sem = sem_.Get<sem::GlobalVariable>(override);
|
||||
const_cast<sem::GlobalVariable*>(sem)->SetConstantId(constant_id);
|
||||
const_cast<sem::GlobalVariable*>(sem)->SetOverrideId(id);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void Resolver::SetShadows() {
|
||||
|
@ -697,7 +711,8 @@ sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* v) {
|
|||
|
||||
if (auto* id_attr = attr->As<ast::IdAttribute>()) {
|
||||
// Track the constant IDs that are specified in the shader.
|
||||
constant_ids_.emplace(id_attr->value, sem);
|
||||
override_ids_.emplace(
|
||||
OverrideId{static_cast<decltype(OverrideId::value)>(id_attr->value)}, sem);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -705,7 +720,7 @@ sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* v) {
|
|||
return nullptr;
|
||||
}
|
||||
|
||||
if (!validator_.GlobalVariable(sem, constant_ids_, atomic_composite_info_)) {
|
||||
if (!validator_.GlobalVariable(sem, override_ids_, atomic_composite_info_)) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -362,7 +362,8 @@ class Resolver {
|
|||
ast::Access DefaultAccessForStorageClass(ast::StorageClass storage_class);
|
||||
|
||||
/// Allocate constant IDs for pipeline-overridable constants.
|
||||
void AllocateOverridableConstantIds();
|
||||
/// @returns true on success, false on error
|
||||
bool AllocateOverridableConstantIds();
|
||||
|
||||
/// Set the shadowing information on variable declarations.
|
||||
/// @note this method must only be called after all semantic nodes are built.
|
||||
|
@ -422,7 +423,7 @@ class Resolver {
|
|||
std::vector<sem::Function*> entry_points_;
|
||||
std::unordered_map<const sem::Type*, const Source&> atomic_composite_info_;
|
||||
utils::Bitset<0> marked_;
|
||||
std::unordered_map<uint32_t, const sem::Variable*> constant_ids_;
|
||||
std::unordered_map<OverrideId, const sem::Variable*> override_ids_;
|
||||
std::unordered_map<ArrayConstructorSig, sem::CallTarget*> array_ctors_;
|
||||
std::unordered_map<StructConstructorSig, sem::CallTarget*> struct_ctors_;
|
||||
|
||||
|
|
|
@ -562,7 +562,7 @@ bool Validator::LocalVariable(const sem::Variable* v) const {
|
|||
|
||||
bool Validator::GlobalVariable(
|
||||
const sem::GlobalVariable* global,
|
||||
const std::unordered_map<uint32_t, const sem::Variable*>& constant_ids,
|
||||
const std::unordered_map<OverrideId, const sem::Variable*>& override_ids,
|
||||
const std::unordered_map<const sem::Type*, const Source&>& atomic_composite_info) const {
|
||||
auto* decl = global->Declaration();
|
||||
bool ok = Switch(
|
||||
|
@ -627,7 +627,7 @@ bool Validator::GlobalVariable(
|
|||
|
||||
return Var(global);
|
||||
},
|
||||
[&](const ast::Override*) { return Override(global, constant_ids); },
|
||||
[&](const ast::Override*) { return Override(global, override_ids); },
|
||||
[&](const ast::Const*) {
|
||||
if (!decl->attributes.empty()) {
|
||||
AddError("attribute is not valid for module-scope 'const' declaration",
|
||||
|
@ -763,7 +763,7 @@ bool Validator::Let(const sem::Variable* v) const {
|
|||
|
||||
bool Validator::Override(
|
||||
const sem::Variable* v,
|
||||
const std::unordered_map<uint32_t, const sem::Variable*>& constant_ids) const {
|
||||
const std::unordered_map<OverrideId, const sem::Variable*>& override_ids) const {
|
||||
auto* decl = v->Declaration();
|
||||
auto* storage_ty = v->Type()->UnwrapRef();
|
||||
|
||||
|
@ -776,19 +776,23 @@ bool Validator::Override(
|
|||
for (auto* attr : decl->attributes) {
|
||||
if (auto* id_attr = attr->As<ast::IdAttribute>()) {
|
||||
uint32_t id = id_attr->value;
|
||||
auto it = constant_ids.find(id);
|
||||
if (it != constant_ids.end() && it->second != v) {
|
||||
AddError("pipeline constant IDs must be unique", attr->source);
|
||||
AddNote("a pipeline constant with an ID of " + std::to_string(id) +
|
||||
if (id > std::numeric_limits<decltype(OverrideId::value)>::max()) {
|
||||
AddError(
|
||||
"override IDs must be between 0 and " +
|
||||
std::to_string(std::numeric_limits<decltype(OverrideId::value)>::max()),
|
||||
attr->source);
|
||||
return false;
|
||||
}
|
||||
if (auto it =
|
||||
override_ids.find(OverrideId{static_cast<decltype(OverrideId::value)>(id)});
|
||||
it != override_ids.end() && it->second != v) {
|
||||
AddError("override IDs must be unique", attr->source);
|
||||
AddNote("a override with an ID of " + std::to_string(id) +
|
||||
" was previously declared here:",
|
||||
ast::GetAttribute<ast::IdAttribute>(it->second->Declaration()->attributes)
|
||||
->source);
|
||||
return false;
|
||||
}
|
||||
if (id > 65535) {
|
||||
AddError("pipeline constant IDs must be between 0 and 65535", attr->source);
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
AddError("attribute is not valid for 'override' declaration", attr->source);
|
||||
return false;
|
||||
|
|
|
@ -234,12 +234,12 @@ class Validator {
|
|||
|
||||
/// Validates a global variable
|
||||
/// @param var the global variable to validate
|
||||
/// @param constant_ids the set of constant ids in the module
|
||||
/// @param override_id the set of override ids in the module
|
||||
/// @param atomic_composite_info atomic composite info in the module
|
||||
/// @returns true on success, false otherwise
|
||||
bool GlobalVariable(
|
||||
const sem::GlobalVariable* var,
|
||||
const std::unordered_map<uint32_t, const sem::Variable*>& constant_ids,
|
||||
const std::unordered_map<OverrideId, const sem::Variable*>& override_id,
|
||||
const std::unordered_map<const sem::Type*, const Source&>& atomic_composite_info) const;
|
||||
|
||||
/// Validates an if statement
|
||||
|
@ -371,10 +371,10 @@ class Validator {
|
|||
|
||||
/// Validates a 'override' variable declaration
|
||||
/// @param v the variable to validate
|
||||
/// @param constant_ids the set of constant ids in the module
|
||||
/// @param override_id the set of override ids in the module
|
||||
/// @returns true on success, false otherwise.
|
||||
bool Override(const sem::Variable* v,
|
||||
const std::unordered_map<uint32_t, const sem::Variable*>& constant_ids) const;
|
||||
const std::unordered_map<OverrideId, const sem::Variable*>& override_id) const;
|
||||
|
||||
/// Validates a 'const' variable declaration
|
||||
/// @param v the variable to validate
|
||||
|
|
|
@ -77,6 +77,37 @@ TEST_F(ResolverVariableValidationTest, OverrideNoInitializerNoType) {
|
|||
EXPECT_EQ(r()->error(), "12:34 error: override declaration requires a type or initializer");
|
||||
}
|
||||
|
||||
TEST_F(ResolverVariableValidationTest, OverrideExceedsIDLimit_LastUnreserved) {
|
||||
// override o0 : i32;
|
||||
// override o1 : i32;
|
||||
// ...
|
||||
// override bang : i32;
|
||||
constexpr size_t kLimit = std::numeric_limits<decltype(OverrideId::value)>::max();
|
||||
for (size_t i = 0; i <= kLimit; i++) {
|
||||
Override("o" + std::to_string(i), ty.i32(), nullptr);
|
||||
}
|
||||
Override(Source{{12, 34}}, "bang", ty.i32(), nullptr);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:34 error: number of 'override' variables exceeded limit of 65535");
|
||||
}
|
||||
|
||||
TEST_F(ResolverVariableValidationTest, OverrideExceedsIDLimit_LastReserved) {
|
||||
// override o0 : i32;
|
||||
// override o1 : i32;
|
||||
// ...
|
||||
// @id(N) override oN : i32;
|
||||
constexpr size_t kLimit = std::numeric_limits<decltype(OverrideId::value)>::max();
|
||||
Override("reserved", ty.i32(), nullptr, {Id(kLimit)});
|
||||
for (size_t i = 0; i < kLimit; i++) {
|
||||
Override("o" + std::to_string(i), ty.i32(), nullptr);
|
||||
}
|
||||
Override(Source{{12, 34}}, "bang", ty.i32(), nullptr);
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:34 error: number of 'override' variables exceeded limit of 65535");
|
||||
}
|
||||
|
||||
TEST_F(ResolverVariableValidationTest, VarTypeNotConstructible) {
|
||||
// var i : i32;
|
||||
// var p : pointer<function, i32> = &v;
|
||||
|
|
|
@ -18,6 +18,8 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "tint/override_id.h"
|
||||
|
||||
#include "src/tint/ast/access.h"
|
||||
#include "src/tint/ast/storage_class.h"
|
||||
#include "src/tint/sem/binding_point.h"
|
||||
|
@ -165,15 +167,15 @@ class GlobalVariable final : public Castable<GlobalVariable, Variable> {
|
|||
sem::BindingPoint BindingPoint() const { return binding_point_; }
|
||||
|
||||
/// @param id the constant identifier to assign to this variable
|
||||
void SetConstantId(uint16_t id) { constant_id_ = id; }
|
||||
void SetOverrideId(OverrideId id) { override_id_ = id; }
|
||||
|
||||
/// @returns the pipeline constant ID associated with the variable
|
||||
uint16_t ConstantId() const { return constant_id_; }
|
||||
tint::OverrideId OverrideId() const { return override_id_; }
|
||||
|
||||
private:
|
||||
const sem::BindingPoint binding_point_;
|
||||
|
||||
uint16_t constant_id_ = 0;
|
||||
tint::OverrideId override_id_;
|
||||
};
|
||||
|
||||
/// Parameter is a function parameter
|
||||
|
|
|
@ -76,11 +76,11 @@ void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) c
|
|||
[&](const ast::Override* override) {
|
||||
if (referenced_vars.count(override)) {
|
||||
if (!ast::HasAttribute<ast::IdAttribute>(override->attributes)) {
|
||||
// If the constant doesn't already have an @id() attribute, add one
|
||||
// If the override doesn't already have an @id() attribute, add one
|
||||
// so that its allocated ID so that it won't be affected by other
|
||||
// stripped away constants
|
||||
// stripped away overrides
|
||||
auto* global = sem.Get(override);
|
||||
const auto* id = ctx.dst->Id(global->ConstantId());
|
||||
const auto* id = ctx.dst->Id(global->OverrideId());
|
||||
ctx.InsertFront(override->attributes, id);
|
||||
}
|
||||
ctx.dst->AST().AddGlobalVariable(ctx.Clone(override));
|
||||
|
|
|
@ -2159,7 +2159,7 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
|
|||
TINT_ICE(Writer, builder_.Diagnostics())
|
||||
<< "expected a pipeline-overridable constant";
|
||||
}
|
||||
out << kSpecConstantPrefix << global->ConstantId();
|
||||
out << kSpecConstantPrefix << global->OverrideId().value;
|
||||
} else {
|
||||
out << std::to_string(wgsize[i].value);
|
||||
}
|
||||
|
@ -3053,18 +3053,18 @@ bool GeneratorImpl::EmitOverride(const ast::Override* override) {
|
|||
auto* type = sem->Type();
|
||||
|
||||
auto* global = sem->As<sem::GlobalVariable>();
|
||||
auto const_id = global->ConstantId();
|
||||
auto override_id = global->OverrideId();
|
||||
|
||||
line() << "#ifndef " << kSpecConstantPrefix << const_id;
|
||||
line() << "#ifndef " << kSpecConstantPrefix << override_id.value;
|
||||
|
||||
if (override->constructor != nullptr) {
|
||||
auto out = line();
|
||||
out << "#define " << kSpecConstantPrefix << const_id << " ";
|
||||
out << "#define " << kSpecConstantPrefix << override_id.value << " ";
|
||||
if (!EmitExpression(out, override->constructor)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
line() << "#error spec constant required for constant id " << const_id;
|
||||
line() << "#error spec constant required for constant id " << override_id.value;
|
||||
}
|
||||
line() << "#endif";
|
||||
{
|
||||
|
@ -3074,7 +3074,7 @@ bool GeneratorImpl::EmitOverride(const ast::Override* override) {
|
|||
builder_.Symbols().NameFor(override->symbol))) {
|
||||
return false;
|
||||
}
|
||||
out << " = " << kSpecConstantPrefix << const_id << ";";
|
||||
out << " = " << kSpecConstantPrefix << override_id.value << ";";
|
||||
}
|
||||
|
||||
return true;
|
||||
|
|
|
@ -3061,7 +3061,7 @@ bool GeneratorImpl::EmitEntryPointFunction(const ast::Function* func) {
|
|||
TINT_ICE(Writer, diagnostics_)
|
||||
<< "expected a pipeline-overridable constant";
|
||||
}
|
||||
out << kSpecConstantPrefix << global->ConstantId();
|
||||
out << kSpecConstantPrefix << global->OverrideId().value;
|
||||
} else {
|
||||
out << std::to_string(wgsize[i].value);
|
||||
}
|
||||
|
@ -4101,18 +4101,18 @@ bool GeneratorImpl::EmitOverride(const ast::Override* override) {
|
|||
auto* sem = builder_.Sem().Get(override);
|
||||
auto* type = sem->Type();
|
||||
|
||||
auto const_id = sem->ConstantId();
|
||||
auto override_id = sem->OverrideId();
|
||||
|
||||
line() << "#ifndef " << kSpecConstantPrefix << const_id;
|
||||
line() << "#ifndef " << kSpecConstantPrefix << override_id.value;
|
||||
|
||||
if (override->constructor != nullptr) {
|
||||
auto out = line();
|
||||
out << "#define " << kSpecConstantPrefix << const_id << " ";
|
||||
out << "#define " << kSpecConstantPrefix << override_id.value << " ";
|
||||
if (!EmitExpression(out, override->constructor)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
line() << "#error spec constant required for constant id " << const_id;
|
||||
line() << "#error spec constant required for constant id " << override_id.value;
|
||||
}
|
||||
line() << "#endif";
|
||||
{
|
||||
|
@ -4122,7 +4122,7 @@ bool GeneratorImpl::EmitOverride(const ast::Override* override) {
|
|||
builder_.Symbols().NameFor(override->symbol))) {
|
||||
return false;
|
||||
}
|
||||
out << " = " << kSpecConstantPrefix << const_id << ";";
|
||||
out << " = " << kSpecConstantPrefix << override_id.value << ";";
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -244,10 +244,7 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_GlobalConst_arr_vec2_bool) {
|
|||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_Override) {
|
||||
auto* var = Override("pos", ty.f32(), Expr(3_f),
|
||||
ast::AttributeList{
|
||||
Id(23),
|
||||
});
|
||||
auto* var = Override("pos", ty.f32(), Expr(3_f), {Id(23)});
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
|
@ -260,10 +257,7 @@ static const float pos = WGSL_SPEC_CONSTANT_23;
|
|||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_Override_NoConstructor) {
|
||||
auto* var = Override("pos", ty.f32(), nullptr,
|
||||
ast::AttributeList{
|
||||
Id(23),
|
||||
});
|
||||
auto* var = Override("pos", ty.f32(), nullptr, {Id(23)});
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
|
@ -276,10 +270,7 @@ static const float pos = WGSL_SPEC_CONSTANT_23;
|
|||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_Override_NoId) {
|
||||
auto* a = Override("a", ty.f32(), Expr(3_f),
|
||||
ast::AttributeList{
|
||||
Id(0),
|
||||
});
|
||||
auto* a = Override("a", ty.f32(), Expr(3_f), {Id(0)});
|
||||
auto* b = Override("b", ty.f32(), Expr(2_f));
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
|
|
@ -3063,7 +3063,7 @@ bool GeneratorImpl::EmitOverride(const ast::Override* override) {
|
|||
}
|
||||
out << " " << program_->Symbols().NameFor(override->symbol);
|
||||
|
||||
out << " [[function_constant(" << global->ConstantId() << ")]];";
|
||||
out << " [[function_constant(" << global->OverrideId().value << ")]];";
|
||||
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -540,7 +540,7 @@ bool Builder::GenerateExecutionModes(const ast::Function* func, uint32_t id) {
|
|||
<< "expected a pipeline-overridable constant";
|
||||
}
|
||||
constant.is_spec_op = true;
|
||||
constant.constant_id = sem_const->ConstantId();
|
||||
constant.constant_id = sem_const->OverrideId().value;
|
||||
}
|
||||
|
||||
auto result = GenerateConstantIfNeeded(constant);
|
||||
|
@ -1340,7 +1340,7 @@ uint32_t Builder::GenerateTypeConstructorOrConversion(const sem::Call* call,
|
|||
// Generate the zero initializer if there are no values provided.
|
||||
if (args.IsEmpty()) {
|
||||
if (global_var && global_var->Declaration()->Is<ast::Override>()) {
|
||||
auto constant_id = global_var->ConstantId();
|
||||
auto constant_id = global_var->OverrideId().value;
|
||||
if (result_type->Is<sem::I32>()) {
|
||||
return GenerateConstantIfNeeded(ScalarConstant::I32(0).AsSpecOp(constant_id));
|
||||
}
|
||||
|
@ -1664,7 +1664,7 @@ uint32_t Builder::GenerateLiteralIfNeeded(const ast::Variable* var,
|
|||
auto* global = builder_.Sem().Get<sem::GlobalVariable>(var);
|
||||
if (global && global->Declaration()->Is<ast::Override>()) {
|
||||
constant.is_spec_op = true;
|
||||
constant.constant_id = global->ConstantId();
|
||||
constant.constant_id = global->OverrideId().value;
|
||||
}
|
||||
|
||||
Switch(
|
||||
|
|
Loading…
Reference in New Issue