diff --git a/src/tint/transform/single_entry_point.cc b/src/tint/transform/single_entry_point.cc index 87787ae0cc..386631ed0b 100644 --- a/src/tint/transform/single_entry_point.cc +++ b/src/tint/transform/single_entry_point.cc @@ -61,12 +61,7 @@ Transform::ApplyResult SingleEntryPoint::Apply(const Program* src, } auto& sem = src->Sem(); - - // Build set of referenced module-scope variables for faster lookups later. - std::unordered_set referenced_vars; - for (auto* var : sem.Get(entry_point)->TransitivelyReferencedGlobals()) { - referenced_vars.emplace(var->Declaration()); - } + auto& referenced_vars = sem.Get(entry_point)->TransitivelyReferencedGlobals(); // Clone any module-scope variables, types, and functions that are statically referenced by the // target entry point. @@ -74,11 +69,20 @@ Transform::ApplyResult SingleEntryPoint::Apply(const Program* src, Switch( decl, // [&](const ast::TypeDecl* ty) { - // TODO(jrprice): Strip unused types. + // Strip aliases that reference unused override declarations. + if (auto* arr = sem.Get(ty)->As()) { + for (auto* o : arr->TransitivelyReferencedOverrides()) { + if (!referenced_vars.Contains(o)) { + return; + } + } + } + + // TODO(jrprice): Strip other unused types. b.AST().AddTypeDecl(ctx.Clone(ty)); }, [&](const ast::Override* override) { - if (referenced_vars.count(override)) { + if (referenced_vars.Contains(sem.Get(override))) { if (!ast::HasAttribute(override->attributes)) { // If the override doesn't already have an @id() attribute, add one // so that its allocated ID so that it won't be affected by other @@ -91,7 +95,7 @@ Transform::ApplyResult SingleEntryPoint::Apply(const Program* src, } }, [&](const ast::Var* var) { - if (referenced_vars.count(var)) { + if (referenced_vars.Contains(sem.Get(var))) { b.AST().AddGlobalVariable(ctx.Clone(var)); } }, diff --git a/src/tint/transform/single_entry_point_test.cc b/src/tint/transform/single_entry_point_test.cc index 81009ebd22..b225008784 100644 --- a/src/tint/transform/single_entry_point_test.cc +++ b/src/tint/transform/single_entry_point_test.cc @@ -409,6 +409,43 @@ fn main() { EXPECT_EQ(expect, str(got)); } +TEST_F(SingleEntryPointTest, OverridableConstants_UnusedAliasForOverrideSizedArray) { + // Make sure we strip away aliases that reference unused overridable constants. + auto* src = R"( +@id(0) override c0 : u32; + +// This is all unused by the target entry point. +@id(1) override c1 : u32; +type arr_ty = array; +var arr : arr_ty; + +@compute @workgroup_size(64) +fn unused() { + arr[0] = 42; +} + +@compute @workgroup_size(64) +fn main() { + let local_d = c0; +} +)"; + + auto* expect = R"( +@id(0) override c0 : u32; + +@compute @workgroup_size(64) +fn main() { + let local_d = c0; +} +)"; + + SingleEntryPoint::Config cfg("main"); + DataMap data; + data.Add(cfg); + auto got = Run(src, data); + EXPECT_EQ(expect, str(got)); +} + TEST_F(SingleEntryPointTest, CalledFunctions) { auto* src = R"( fn inner1() {