diff --git a/src/dawn/tests/end2end/ShaderValidationTests.cpp b/src/dawn/tests/end2end/ShaderValidationTests.cpp index 9f964c9208..923d70233a 100644 --- a/src/dawn/tests/end2end/ShaderValidationTests.cpp +++ b/src/dawn/tests/end2end/ShaderValidationTests.cpp @@ -340,6 +340,7 @@ TEST_P(WorkgroupSizeValidationTest, ValidationAfterOverrideStorageSize) { auto CheckPipelineWithWorkgroupStorage = [this](bool success, uint32_t vec4_count, uint32_t mat4_count) { + std::vector constants; std::ostringstream ss; std::ostringstream body; ss << "override a: u32;"; @@ -347,19 +348,18 @@ TEST_P(WorkgroupSizeValidationTest, ValidationAfterOverrideStorageSize) { if (vec4_count > 0) { ss << "var vec4_data: array, a>;"; body << "_ = vec4_data[0];"; + constants.push_back({nullptr, "a", static_cast(vec4_count)}); } if (mat4_count > 0) { ss << "var mat4_data: array, b>;"; body << "_ = mat4_data[0];"; + constants.push_back({nullptr, "b", static_cast(mat4_count)}); } ss << "@compute @workgroup_size(1) fn main() { " << body.str() << " }"; wgpu::ComputePipelineDescriptor desc; desc.compute.entryPoint = "main"; desc.compute.module = utils::CreateShaderModule(device, ss.str().c_str()); - - std::vector constants{{nullptr, "a", static_cast(vec4_count)}, - {nullptr, "b", static_cast(mat4_count)}}; desc.compute.constants = constants.data(); desc.compute.constantCount = constants.size(); diff --git a/src/tint/inspector/inspector_test.cc b/src/tint/inspector/inspector_test.cc index 6f3aa93665..05f3c196b0 100644 --- a/src/tint/inspector/inspector_test.cc +++ b/src/tint/inspector/inspector_test.cc @@ -716,6 +716,193 @@ TEST_F(InspectorGetEntryPointTest, OverrideSomeReferenced) { EXPECT_EQ(1, result[0].overrides[0].id.value); } +TEST_F(InspectorGetEntryPointTest, OverrideReferencedIndirectly) { + Override("foo", ty.f32()); + Override("bar", ty.f32(), Mul(2_a, "foo")); + MakePlainGlobalReferenceBodyFunction("ep_func", "bar", ty.f32(), + utils::Vector{ + Stage(ast::PipelineStage::kCompute), + WorkgroupSize(1_i), + }); + + Inspector& inspector = Build(); + + auto result = inspector.GetEntryPoints(); + + ASSERT_EQ(1u, result.size()); + ASSERT_EQ(2u, result[0].overrides.size()); + EXPECT_EQ("bar", result[0].overrides[0].name); + EXPECT_TRUE(result[0].overrides[0].is_initialized); + EXPECT_EQ("foo", result[0].overrides[1].name); + EXPECT_FALSE(result[0].overrides[1].is_initialized); +} + +TEST_F(InspectorGetEntryPointTest, OverrideReferencedIndirectly_ViaPrivateInitializer) { + Override("foo", ty.f32()); + GlobalVar("bar", ast::AddressSpace::kPrivate, ty.f32(), Mul(2_a, "foo")); + MakePlainGlobalReferenceBodyFunction("ep_func", "bar", ty.f32(), + utils::Vector{ + Stage(ast::PipelineStage::kCompute), + WorkgroupSize(1_i), + }); + + Inspector& inspector = Build(); + + auto result = inspector.GetEntryPoints(); + + ASSERT_EQ(1u, result.size()); + ASSERT_EQ(1u, result[0].overrides.size()); + EXPECT_EQ("foo", result[0].overrides[0].name); + EXPECT_FALSE(result[0].overrides[0].is_initialized); +} + +TEST_F(InspectorGetEntryPointTest, OverrideReferencedIndirectly_MultipleEntryPoints) { + Override("foo1", ty.f32()); + Override("bar1", ty.f32(), Mul(2_a, "foo1")); + MakePlainGlobalReferenceBodyFunction("ep_func1", "bar1", ty.f32(), + utils::Vector{ + Stage(ast::PipelineStage::kCompute), + WorkgroupSize(1_i), + }); + Override("foo2", ty.f32()); + Override("bar2", ty.f32(), Mul(2_a, "foo2")); + MakePlainGlobalReferenceBodyFunction("ep_func2", "bar2", ty.f32(), + utils::Vector{ + Stage(ast::PipelineStage::kCompute), + WorkgroupSize(1_i), + }); + + Inspector& inspector = Build(); + + auto result = inspector.GetEntryPoints(); + + ASSERT_EQ(2u, result.size()); + + ASSERT_EQ(2u, result[0].overrides.size()); + EXPECT_EQ("bar1", result[0].overrides[0].name); + EXPECT_TRUE(result[0].overrides[0].is_initialized); + EXPECT_EQ("foo1", result[0].overrides[1].name); + EXPECT_FALSE(result[0].overrides[1].is_initialized); + + ASSERT_EQ(2u, result[1].overrides.size()); + EXPECT_EQ("bar2", result[1].overrides[0].name); + EXPECT_TRUE(result[1].overrides[0].is_initialized); + EXPECT_EQ("foo2", result[1].overrides[1].name); + EXPECT_FALSE(result[1].overrides[1].is_initialized); +} + +TEST_F(InspectorGetEntryPointTest, OverrideReferencedByAttribute) { + Override("wgsize", ty.u32()); + MakeEmptyBodyFunction("ep_func", utils::Vector{ + Stage(ast::PipelineStage::kCompute), + WorkgroupSize("wgsize"), + }); + + Inspector& inspector = Build(); + + auto result = inspector.GetEntryPoints(); + + ASSERT_EQ(1u, result.size()); + ASSERT_EQ(1u, result[0].overrides.size()); + EXPECT_EQ("wgsize", result[0].overrides[0].name); + EXPECT_FALSE(result[0].overrides[0].is_initialized); +} + +TEST_F(InspectorGetEntryPointTest, OverrideReferencedByAttributeIndirectly) { + Override("foo", ty.u32()); + Override("bar", ty.u32(), Mul(2_a, "foo")); + MakeEmptyBodyFunction("ep_func", utils::Vector{ + Stage(ast::PipelineStage::kCompute), + WorkgroupSize(Mul(2_a, Expr("bar"))), + }); + + Inspector& inspector = Build(); + + auto result = inspector.GetEntryPoints(); + + ASSERT_EQ(1u, result.size()); + ASSERT_EQ(2u, result[0].overrides.size()); + EXPECT_EQ("bar", result[0].overrides[0].name); + EXPECT_TRUE(result[0].overrides[0].is_initialized); + EXPECT_EQ("foo", result[0].overrides[1].name); + EXPECT_FALSE(result[0].overrides[1].is_initialized); +} + +TEST_F(InspectorGetEntryPointTest, OverrideReferencedByArraySize) { + Override("size", ty.u32()); + GlobalVar("v", ast::AddressSpace::kWorkgroup, ty.array(ty.f32(), "size")); + Func("ep", utils::Empty, ty.void_(), + utils::Vector{ + Assign(Phony(), IndexAccessor("v", 0_a)), + }, + utils::Vector{ + Stage(ast::PipelineStage::kCompute), + WorkgroupSize(1_i), + }); + + Inspector& inspector = Build(); + + auto result = inspector.GetEntryPoints(); + + ASSERT_EQ(1u, result.size()); + ASSERT_EQ(1u, result[0].overrides.size()); + EXPECT_EQ("size", result[0].overrides[0].name); + EXPECT_FALSE(result[0].overrides[0].is_initialized); +} + +TEST_F(InspectorGetEntryPointTest, OverrideReferencedByArraySizeIndirectly) { + Override("foo", ty.u32()); + Override("bar", ty.u32(), Mul(2_a, "foo")); + GlobalVar("v", ast::AddressSpace::kWorkgroup, ty.array(ty.f32(), Mul(2_a, Expr("bar")))); + Func("ep", utils::Empty, ty.void_(), + utils::Vector{ + Assign(Phony(), IndexAccessor("v", 0_a)), + }, + utils::Vector{ + Stage(ast::PipelineStage::kCompute), + WorkgroupSize(1_i), + }); + + Inspector& inspector = Build(); + + auto result = inspector.GetEntryPoints(); + + ASSERT_EQ(1u, result.size()); + ASSERT_EQ(2u, result[0].overrides.size()); + EXPECT_EQ("bar", result[0].overrides[0].name); + EXPECT_TRUE(result[0].overrides[0].is_initialized); + EXPECT_EQ("foo", result[0].overrides[1].name); + EXPECT_FALSE(result[0].overrides[1].is_initialized); +} + +TEST_F(InspectorGetEntryPointTest, OverrideReferencedByArraySizeViaAlias) { + Override("foo", ty.u32()); + Override("bar", ty.u32(), Expr("foo")); + Alias("MyArray", ty.array(ty.f32(), Mul(2_a, Expr("bar")))); + Override("zoo", ty.u32()); + Alias("MyArrayUnused", ty.array(ty.f32(), Mul(2_a, Expr("zoo")))); + GlobalVar("v", ast::AddressSpace::kWorkgroup, ty.type_name("MyArray")); + Func("ep", utils::Empty, ty.void_(), + utils::Vector{ + Assign(Phony(), IndexAccessor("v", 0_a)), + }, + utils::Vector{ + Stage(ast::PipelineStage::kCompute), + WorkgroupSize(1_i), + }); + + Inspector& inspector = Build(); + + auto result = inspector.GetEntryPoints(); + + ASSERT_EQ(1u, result.size()); + ASSERT_EQ(2u, result[0].overrides.size()); + EXPECT_EQ("bar", result[0].overrides[0].name); + EXPECT_TRUE(result[0].overrides[0].is_initialized); + EXPECT_EQ("foo", result[0].overrides[1].name); + EXPECT_FALSE(result[0].overrides[1].is_initialized); +} + TEST_F(InspectorGetEntryPointTest, OverrideTypes) { Override("bool_var", ty.bool_()); Override("float_var", ty.f32()); diff --git a/src/tint/resolver/override_test.cc b/src/tint/resolver/override_test.cc index a12d3c148b..132bd55431 100644 --- a/src/tint/resolver/override_test.cc +++ b/src/tint/resolver/override_test.cc @@ -66,7 +66,6 @@ TEST_F(ResolverOverrideTest, WithoutId) { } TEST_F(ResolverOverrideTest, WithAndWithoutIds) { - std::vector variables; auto* a = Override("a", ty.f32(), Expr(1_f)); auto* b = Override("b", ty.f32(), Expr(1_f)); auto* c = Override("c", ty.f32(), Expr(1_f), Id(2_u)); @@ -113,5 +112,217 @@ TEST_F(ResolverOverrideTest, F16_TemporallyBan) { EXPECT_EQ(r()->error(), "12:34 error: 'override' of type f16 is not implemented yet"); } +TEST_F(ResolverOverrideTest, TransitiveReferences_DirectUse) { + auto* a = Override("a", ty.f32()); + auto* b = Override("b", ty.f32(), Expr(1_f)); + Override("unused", ty.f32(), Expr(1_f)); + auto* func = Func("foo", utils::Empty, ty.void_(), + utils::Vector{ + Assign(Phony(), "a"), + Assign(Phony(), "b"), + }); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto& refs = Sem().Get(func)->TransitivelyReferencedGlobals(); + ASSERT_EQ(refs.Length(), 2u); + EXPECT_EQ(refs[0], Sem().Get(a)); + EXPECT_EQ(refs[1], Sem().Get(b)); +} + +TEST_F(ResolverOverrideTest, TransitiveReferences_ViaOverrideInit) { + auto* a = Override("a", ty.f32()); + auto* b = Override("b", ty.f32(), Mul(2_a, "a")); + Override("unused", ty.f32(), Expr(1_f)); + auto* func = Func("foo", utils::Empty, ty.void_(), + utils::Vector{ + Assign(Phony(), "b"), + }); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + { + auto& refs = Sem().Get(b)->TransitivelyReferencedOverrides(); + ASSERT_EQ(refs.Length(), 1u); + EXPECT_EQ(refs[0], Sem().Get(a)); + } + + { + auto& refs = Sem().Get(func)->TransitivelyReferencedGlobals(); + ASSERT_EQ(refs.Length(), 2u); + EXPECT_EQ(refs[0], Sem().Get(b)); + EXPECT_EQ(refs[1], Sem().Get(a)); + } +} + +TEST_F(ResolverOverrideTest, TransitiveReferences_ViaPrivateInit) { + auto* a = Override("a", ty.f32()); + auto* b = GlobalVar("b", ast::AddressSpace::kPrivate, ty.f32(), Mul(2_a, "a")); + Override("unused", ty.f32(), Expr(1_f)); + auto* func = Func("foo", utils::Empty, ty.void_(), + utils::Vector{ + Assign(Phony(), "b"), + }); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + { + auto& refs = Sem().Get(b)->TransitivelyReferencedOverrides(); + ASSERT_EQ(refs.Length(), 1u); + EXPECT_EQ(refs[0], Sem().Get(a)); + } + + { + auto& refs = Sem().Get(func)->TransitivelyReferencedGlobals(); + ASSERT_EQ(refs.Length(), 2u); + EXPECT_EQ(refs[0], Sem().Get(b)); + EXPECT_EQ(refs[1], Sem().Get(a)); + } +} + +TEST_F(ResolverOverrideTest, TransitiveReferences_ViaAttribute) { + auto* a = Override("a", ty.i32()); + auto* b = Override("b", ty.i32(), Mul(2_a, "a")); + Override("unused", ty.i32(), Expr(1_a)); + auto* func = Func("foo", utils::Empty, ty.void_(), + utils::Vector{ + Assign(Phony(), "b"), + }, + utils::Vector{ + Stage(ast::PipelineStage::kCompute), + WorkgroupSize(Mul(2_a, "b")), + }); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto& refs = Sem().Get(func)->TransitivelyReferencedGlobals(); + ASSERT_EQ(refs.Length(), 2u); + EXPECT_EQ(refs[0], Sem().Get(b)); + EXPECT_EQ(refs[1], Sem().Get(a)); +} + +TEST_F(ResolverOverrideTest, TransitiveReferences_ViaArraySize) { + auto* a = Override("a", ty.i32()); + auto* b = Override("b", ty.i32(), Mul(2_a, "a")); + auto* arr_ty = ty.array(ty.i32(), Mul(2_a, "b")); + auto* arr = GlobalVar("arr", ast::AddressSpace::kWorkgroup, arr_ty); + Override("unused", ty.i32(), Expr(1_a)); + auto* func = Func("foo", utils::Empty, ty.void_(), + utils::Vector{ + Assign(IndexAccessor("arr", 0_a), 42_a), + }); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + { + auto& refs = Sem().Get(arr_ty)->TransitivelyReferencedOverrides(); + ASSERT_EQ(refs.Length(), 2u); + EXPECT_EQ(refs[0], Sem().Get(b)); + EXPECT_EQ(refs[1], Sem().Get(a)); + } + + { + auto& refs = Sem().Get(arr)->TransitivelyReferencedOverrides(); + ASSERT_EQ(refs.Length(), 2u); + EXPECT_EQ(refs[0], Sem().Get(b)); + EXPECT_EQ(refs[1], Sem().Get(a)); + } + + { + auto& refs = Sem().Get(func)->TransitivelyReferencedGlobals(); + ASSERT_EQ(refs.Length(), 3u); + EXPECT_EQ(refs[0], Sem().Get(arr)); + EXPECT_EQ(refs[1], Sem().Get(b)); + EXPECT_EQ(refs[2], Sem().Get(a)); + } +} + +TEST_F(ResolverOverrideTest, TransitiveReferences_ViaArraySize_Alias) { + auto* a = Override("a", ty.i32()); + auto* b = Override("b", ty.i32(), Mul(2_a, "a")); + auto* arr_ty = Alias("arr_ty", ty.array(ty.i32(), Mul(2_a, "b"))); + auto* arr = GlobalVar("arr", ast::AddressSpace::kWorkgroup, ty.type_name("arr_ty")); + Override("unused", ty.i32(), Expr(1_a)); + auto* func = Func("foo", utils::Empty, ty.void_(), + utils::Vector{ + Assign(IndexAccessor("arr", 0_a), 42_a), + }); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + { + auto& refs = Sem().Get(arr_ty->type)->TransitivelyReferencedOverrides(); + ASSERT_EQ(refs.Length(), 2u); + EXPECT_EQ(refs[0], Sem().Get(b)); + EXPECT_EQ(refs[1], Sem().Get(a)); + } + + { + auto& refs = Sem().Get(arr)->TransitivelyReferencedOverrides(); + ASSERT_EQ(refs.Length(), 2u); + EXPECT_EQ(refs[0], Sem().Get(b)); + EXPECT_EQ(refs[1], Sem().Get(a)); + } + + { + auto& refs = Sem().Get(func)->TransitivelyReferencedGlobals(); + ASSERT_EQ(refs.Length(), 3u); + EXPECT_EQ(refs[0], Sem().Get(arr)); + EXPECT_EQ(refs[1], Sem().Get(b)); + EXPECT_EQ(refs[2], Sem().Get(a)); + } +} + +TEST_F(ResolverOverrideTest, TransitiveReferences_MultipleEntryPoints) { + auto* a = Override("a", ty.i32()); + auto* b1 = Override("b1", ty.i32(), Mul(2_a, "a")); + auto* b2 = Override("b2", ty.i32(), Mul(2_a, "a")); + auto* c1 = Override("c1", ty.i32()); + auto* c2 = Override("c2", ty.i32()); + auto* d = Override("d", ty.i32()); + Alias("arr_ty1", ty.array(ty.i32(), Mul("b1", "c1"))); + Alias("arr_ty2", ty.array(ty.i32(), Mul("b2", "c2"))); + auto* arr1 = GlobalVar("arr1", ast::AddressSpace::kWorkgroup, ty.type_name("arr_ty1")); + auto* arr2 = GlobalVar("arr2", ast::AddressSpace::kWorkgroup, ty.type_name("arr_ty2")); + Override("unused", ty.i32(), Expr(1_a)); + auto* func1 = Func("foo1", utils::Empty, ty.void_(), + utils::Vector{ + Assign(IndexAccessor("arr1", 0_a), 42_a), + }, + utils::Vector{ + Stage(ast::PipelineStage::kCompute), + WorkgroupSize(Mul(2_a, "d")), + }); + auto* func2 = Func("foo2", utils::Empty, ty.void_(), + utils::Vector{ + Assign(IndexAccessor("arr2", 0_a), 42_a), + }, + utils::Vector{ + Stage(ast::PipelineStage::kCompute), + WorkgroupSize(64_a), + }); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + { + auto& refs = Sem().Get(func1)->TransitivelyReferencedGlobals(); + ASSERT_EQ(refs.Length(), 5u); + EXPECT_EQ(refs[0], Sem().Get(d)); + EXPECT_EQ(refs[1], Sem().Get(arr1)); + EXPECT_EQ(refs[2], Sem().Get(b1)); + EXPECT_EQ(refs[3], Sem().Get(a)); + EXPECT_EQ(refs[4], Sem().Get(c1)); + } + + { + auto& refs = Sem().Get(func2)->TransitivelyReferencedGlobals(); + ASSERT_EQ(refs.Length(), 4u); + EXPECT_EQ(refs[0], Sem().Get(arr2)); + EXPECT_EQ(refs[1], Sem().Get(b2)); + EXPECT_EQ(refs[2], Sem().Get(a)); + EXPECT_EQ(refs[3], Sem().Get(c2)); + } +} + } // namespace } // namespace tint::resolver diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 48b35fa190..143334e0ce 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -875,6 +875,9 @@ void Resolver::SetShadows() { } sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* v) { + utils::UniqueVector transitively_referenced_overrides; + TINT_SCOPED_ASSIGNMENT(resolved_overrides_, &transitively_referenced_overrides); + auto* sem = As(Variable(v, /* is_global */ true)); if (!sem) { return nullptr; @@ -898,6 +901,16 @@ sem::GlobalVariable* Resolver::GlobalVariable(const ast::Variable* v) { return nullptr; } + // Track the pipeline-overridable constants that are transitively referenced by this variable. + for (auto* var : transitively_referenced_overrides) { + sem->AddTransitivelyReferencedOverride(var); + } + if (auto* arr = sem->Type()->UnwrapRef()->As()) { + for (auto* var : arr->TransitivelyReferencedOverrides()) { + sem->AddTransitivelyReferencedOverride(var); + } + } + return sem; } @@ -2477,9 +2490,22 @@ sem::Expression* Resolver::Identifier(const ast::IdentifierExpression* expr) { } } + auto* global = variable->As(); if (current_function_) { - if (auto* global = variable->As()) { + if (global) { current_function_->AddDirectlyReferencedGlobal(global); + for (auto* var : global->TransitivelyReferencedOverrides()) { + current_function_->AddTransitivelyReferencedGlobal(var); + } + } + } else if (variable->Declaration()->Is()) { + if (resolved_overrides_) { + // Track the reference to this pipeline-overridable constant and any other + // pipeline-overridable constants that it references. + resolved_overrides_->Add(global); + for (auto* var : global->TransitivelyReferencedOverrides()) { + resolved_overrides_->Add(var); + } } } else if (variable->Declaration()->Is()) { // Use of a module-scope 'var' outside of a function. @@ -2828,6 +2854,9 @@ sem::Array* Resolver::Array(const ast::Array* arr) { return nullptr; } + utils::UniqueVector transitively_referenced_overrides; + TINT_SCOPED_ASSIGNMENT(resolved_overrides_, &transitively_referenced_overrides); + auto* el_ty = Type(arr->type); if (!el_ty) { return nullptr; @@ -2865,6 +2894,11 @@ sem::Array* Resolver::Array(const ast::Array* arr) { } } + // Track the pipeline-overridable constants that are transitively referenced by this array type. + for (auto* var : transitively_referenced_overrides) { + out->AddTransitivelyReferencedOverride(var); + } + return out; } diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index 46527ff86e..0b31fe1131 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -470,6 +470,7 @@ class Resolver { sem::Statement* current_statement_ = nullptr; sem::CompoundStatement* current_compound_statement_ = nullptr; uint32_t current_scoping_depth_ = 0; + utils::UniqueVector* resolved_overrides_ = nullptr; }; } // namespace tint::resolver diff --git a/src/tint/sem/array.h b/src/tint/sem/array.h index 88e63739ce..4047ae428f 100644 --- a/src/tint/sem/array.h +++ b/src/tint/sem/array.h @@ -23,6 +23,7 @@ #include "src/tint/sem/node.h" #include "src/tint/sem/type.h" #include "src/tint/utils/compiler_macros.h" +#include "src/tint/utils/unique_vector.h" // Forward declarations namespace tint::sem { @@ -229,6 +230,17 @@ class Array final : public Castable { /// @returns true if this array is runtime sized bool IsRuntimeSized() const { return std::holds_alternative(count_); } + /// Records that this array type (transitively) references the given override variable. + /// @param var the module-scope override variable + void AddTransitivelyReferencedOverride(const GlobalVariable* var) { + referenced_overrides_.Add(var); + } + + /// @returns all transitively referenced override variables + const utils::UniqueVector& TransitivelyReferencedOverrides() const { + return referenced_overrides_; + } + /// @param symbols the program's symbol table /// @returns the name for this type that closely resembles how it would be /// declared in WGSL. @@ -241,6 +253,7 @@ class Array final : public Castable { const uint32_t size_; const uint32_t stride_; const uint32_t implicit_stride_; + utils::UniqueVector referenced_overrides_; }; } // namespace tint::sem diff --git a/src/tint/sem/variable.h b/src/tint/sem/variable.h index 634f5da7ef..08ffe99114 100644 --- a/src/tint/sem/variable.h +++ b/src/tint/sem/variable.h @@ -27,6 +27,7 @@ #include "src/tint/sem/binding_point.h" #include "src/tint/sem/expression.h" #include "src/tint/sem/parameter_usage.h" +#include "src/tint/utils/unique_vector.h" // Forward declarations namespace tint::ast { @@ -182,11 +183,23 @@ class GlobalVariable final : public Castable { /// @returns the location value for the parameter, if set std::optional Location() const { return location_; } + /// Records that this variable (transitively) references the given override variable. + /// @param var the module-scope override variable + void AddTransitivelyReferencedOverride(const GlobalVariable* var) { + referenced_overrides_.Add(var); + } + + /// @returns all transitively referenced override variables + const utils::UniqueVector& TransitivelyReferencedOverrides() const { + return referenced_overrides_; + } + private: const sem::BindingPoint binding_point_; tint::OverrideId override_id_; std::optional location_; + utils::UniqueVector referenced_overrides_; }; /// Parameter is a function parameter diff --git a/src/tint/transform/single_entry_point_test.cc b/src/tint/transform/single_entry_point_test.cc index e9fa68f291..81009ebd22 100644 --- a/src/tint/transform/single_entry_point_test.cc +++ b/src/tint/transform/single_entry_point_test.cc @@ -374,6 +374,41 @@ fn comp_main5() { } } +TEST_F(SingleEntryPointTest, OverridableConstants_TransitiveUses) { + // Make sure we do not strip away transitive uses of overridable constants. + auto* src = R"( +@id(0) override c0 : u32; + +@id(1) override c1 : u32 = (2 * c0); + +@id(2) override c2 : u32; + +@id(3) override c3 : u32 = (2 * c2); + +@id(4) override c4 : u32; + +@id(5) override c5 : u32 = (2 * c4); + +type arr_ty = array; + +var arr : arr_ty; + +@compute @workgroup_size(1, 1, (2 * c3)) +fn main() { + let local_d = c1; + arr[0] = 42; +} +)"; + + auto* expect = src; + + SingleEntryPoint::Config cfg("main"); + DataMap data; + data.Add(cfg); + auto got = Run(src, data); + EXPECT_EQ(expect, str(got)); +} + TEST_F(SingleEntryPointTest, CalledFunctions) { auto* src = R"( fn inner1() {