diff --git a/src/tint/resolver/override_test.cc b/src/tint/resolver/override_test.cc index 132bd55431..a97c92d529 100644 --- a/src/tint/resolver/override_test.cc +++ b/src/tint/resolver/override_test.cc @@ -142,7 +142,9 @@ TEST_F(ResolverOverrideTest, TransitiveReferences_ViaOverrideInit) { EXPECT_TRUE(r()->Resolve()) << r()->error(); { - auto& refs = Sem().Get(b)->TransitivelyReferencedOverrides(); + auto* r = Sem().TransitivelyReferencedOverrides(Sem().Get(b)); + ASSERT_NE(r, nullptr); + auto& refs = *r; ASSERT_EQ(refs.Length(), 1u); EXPECT_EQ(refs[0], Sem().Get(a)); } @@ -167,7 +169,9 @@ TEST_F(ResolverOverrideTest, TransitiveReferences_ViaPrivateInit) { EXPECT_TRUE(r()->Resolve()) << r()->error(); { - auto& refs = Sem().Get(b)->TransitivelyReferencedOverrides(); + auto* r = Sem().TransitivelyReferencedOverrides(Sem().Get(b)); + ASSERT_NE(r, nullptr); + auto& refs = *r; ASSERT_EQ(refs.Length(), 1u); EXPECT_EQ(refs[0], Sem().Get(a)); } @@ -215,14 +219,18 @@ TEST_F(ResolverOverrideTest, TransitiveReferences_ViaArraySize) { EXPECT_TRUE(r()->Resolve()) << r()->error(); { - auto& refs = Sem().Get(arr_ty)->TransitivelyReferencedOverrides(); + auto* r = Sem().TransitivelyReferencedOverrides(Sem().Get(arr_ty)); + ASSERT_NE(r, nullptr); + auto& refs = *r; ASSERT_EQ(refs.Length(), 2u); EXPECT_EQ(refs[0], Sem().Get(b)); EXPECT_EQ(refs[1], Sem().Get(a)); } { - auto& refs = Sem().Get(arr)->TransitivelyReferencedOverrides(); + auto* r = Sem().TransitivelyReferencedOverrides(Sem().Get(arr)); + ASSERT_NE(r, nullptr); + auto& refs = *r; ASSERT_EQ(refs.Length(), 2u); EXPECT_EQ(refs[0], Sem().Get(b)); EXPECT_EQ(refs[1], Sem().Get(a)); @@ -251,14 +259,18 @@ TEST_F(ResolverOverrideTest, TransitiveReferences_ViaArraySize_Alias) { EXPECT_TRUE(r()->Resolve()) << r()->error(); { - auto& refs = Sem().Get(arr_ty->type)->TransitivelyReferencedOverrides(); + auto* r = Sem().TransitivelyReferencedOverrides(Sem().Get(arr_ty->type)); + ASSERT_NE(r, nullptr); + auto& refs = *r; ASSERT_EQ(refs.Length(), 2u); EXPECT_EQ(refs[0], Sem().Get(b)); EXPECT_EQ(refs[1], Sem().Get(a)); } { - auto& refs = Sem().Get(arr)->TransitivelyReferencedOverrides(); + auto* r = Sem().TransitivelyReferencedOverrides(Sem().Get(arr)); + ASSERT_NE(r, nullptr); + auto& refs = *r; ASSERT_EQ(refs.Length(), 2u); EXPECT_EQ(refs[0], Sem().Get(b)); EXPECT_EQ(refs[1], Sem().Get(a)); diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 9fbc7f0560..e5b188c880 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -917,11 +917,14 @@ sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* v) { // Track the pipeline-overridable constants that are transitively referenced by this variable. for (auto* var : transitively_referenced_overrides) { - sem->AddTransitivelyReferencedOverride(var); + builder_->Sem().AddTransitivelyReferencedOverride(sem, var); } if (auto* arr = sem->Type()->UnwrapRef()->As()) { - for (auto* var : arr->TransitivelyReferencedOverrides()) { - sem->AddTransitivelyReferencedOverride(var); + auto* refs = builder_->Sem().TransitivelyReferencedOverrides(arr); + if (refs) { + for (auto* var : *refs) { + builder_->Sem().AddTransitivelyReferencedOverride(sem, var); + } } } @@ -2553,8 +2556,11 @@ sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) { if (current_function_) { if (global) { current_function_->AddDirectlyReferencedGlobal(global); - for (auto* var : global->TransitivelyReferencedOverrides()) { - current_function_->AddTransitivelyReferencedGlobal(var); + auto* refs = builder_->Sem().TransitivelyReferencedOverrides(global); + if (refs) { + for (auto* var : *refs) { + current_function_->AddTransitivelyReferencedGlobal(var); + } } } } else if (variable->Declaration()->Is()) { @@ -2562,8 +2568,11 @@ sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) { // Track the reference to this pipeline-overridable constant and any other // pipeline-overridable constants that it references. resolved_overrides_->Add(global); - for (auto* var : global->TransitivelyReferencedOverrides()) { - resolved_overrides_->Add(var); + auto* refs = builder_->Sem().TransitivelyReferencedOverrides(global); + if (refs) { + for (auto* var : *refs) { + resolved_overrides_->Add(var); + } } } } else if (variable->Declaration()->Is()) { @@ -2956,7 +2965,7 @@ sem::Array* Resolver::Array(const ast::Array* arr) { // Track the pipeline-overridable constants that are transitively referenced by this array // type. for (auto* var : transitively_referenced_overrides) { - out->AddTransitivelyReferencedOverride(var); + builder_->Sem().AddTransitivelyReferencedOverride(out, var); } return out; diff --git a/src/tint/sem/array.h b/src/tint/sem/array.h index 4047ae428f..4d1ed7dce6 100644 --- a/src/tint/sem/array.h +++ b/src/tint/sem/array.h @@ -230,17 +230,6 @@ class Array final : public Castable { /// @returns true if this array is runtime sized bool IsRuntimeSized() const { return std::holds_alternative(count_); } - /// Records that this array type (transitively) references the given override variable. - /// @param var the module-scope override variable - void AddTransitivelyReferencedOverride(const GlobalVariable* var) { - referenced_overrides_.Add(var); - } - - /// @returns all transitively referenced override variables - const utils::UniqueVector& TransitivelyReferencedOverrides() const { - return referenced_overrides_; - } - /// @param symbols the program's symbol table /// @returns the name for this type that closely resembles how it would be /// declared in WGSL. @@ -253,7 +242,6 @@ class Array final : public Castable { const uint32_t size_; const uint32_t stride_; const uint32_t implicit_stride_; - utils::UniqueVector referenced_overrides_; }; } // namespace tint::sem diff --git a/src/tint/sem/info.h b/src/tint/sem/info.h index 894b408fe0..28b41a6c47 100644 --- a/src/tint/sem/info.h +++ b/src/tint/sem/info.h @@ -24,6 +24,7 @@ #include "src/tint/debug.h" #include "src/tint/sem/node.h" #include "src/tint/sem/type_mappings.h" +#include "src/tint/utils/unique_vector.h" // Forward declarations namespace tint::sem { @@ -44,6 +45,9 @@ class Info { using GetResultType = std::conditional_t::value, SemanticNodeTypeFor, SEM>; + /// Alias to a unique vector of transitively referenced global variables + using TransitivelyReferenced = utils::UniqueVector; + /// Constructor Info(); @@ -117,9 +121,30 @@ class Info { /// @returns the semantic module. const sem::Module* Module() const { return module_; } + /// Records that this variable (transitively) references the given override variable. + /// @param from the item the variable is referenced from + /// @param var the module-scope override variable + void AddTransitivelyReferencedOverride(const CastableBase* from, const GlobalVariable* var) { + if (referenced_overrides_.count(from) == 0) { + referenced_overrides_.insert({from, TransitivelyReferenced{}}); + } + referenced_overrides_[from].Add(var); + } + + /// @param from the key to look up + /// @returns all transitively referenced override variables or nullptr if none set + const TransitivelyReferenced* TransitivelyReferencedOverrides(const CastableBase* from) const { + if (referenced_overrides_.count(from) == 0) { + return nullptr; + } + return &referenced_overrides_.at(from); + } + private: // AST node index to semantic node std::vector nodes_; + // Lists transitively referenced overrides for the given item + std::unordered_map referenced_overrides_; // The semantic module sem::Module* module_ = nullptr; }; diff --git a/src/tint/sem/variable.h b/src/tint/sem/variable.h index 08dfa98928..b9777973c2 100644 --- a/src/tint/sem/variable.h +++ b/src/tint/sem/variable.h @@ -183,23 +183,11 @@ class GlobalVariable final : public Castable { /// @returns the location value for the parameter, if set std::optional Location() const { return location_; } - /// Records that this variable (transitively) references the given override variable. - /// @param var the module-scope override variable - void AddTransitivelyReferencedOverride(const GlobalVariable* var) { - referenced_overrides_.Add(var); - } - - /// @returns all transitively referenced override variables - const utils::UniqueVector& TransitivelyReferencedOverrides() const { - return referenced_overrides_; - } - private: const sem::BindingPoint binding_point_; tint::OverrideId override_id_; std::optional location_; - utils::UniqueVector referenced_overrides_; }; /// Parameter is a function parameter diff --git a/src/tint/transform/single_entry_point.cc b/src/tint/transform/single_entry_point.cc index 386631ed0b..694cd2ddec 100644 --- a/src/tint/transform/single_entry_point.cc +++ b/src/tint/transform/single_entry_point.cc @@ -71,9 +71,12 @@ Transform::ApplyResult SingleEntryPoint::Apply(const Program* src, [&](const ast::TypeDecl* ty) { // Strip aliases that reference unused override declarations. if (auto* arr = sem.Get(ty)->As()) { - for (auto* o : arr->TransitivelyReferencedOverrides()) { - if (!referenced_vars.Contains(o)) { - return; + auto* refs = sem.TransitivelyReferencedOverrides(arr); + if (refs) { + for (auto* o : *refs) { + if (!referenced_vars.Contains(o)) { + return; + } } } }