diff --git a/src/transform/single_entry_point.cc b/src/transform/single_entry_point.cc index e03c96204e..da9c273e68 100644 --- a/src/transform/single_entry_point.cc +++ b/src/transform/single_entry_point.cc @@ -74,7 +74,22 @@ void SingleEntryPoint::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) { // TODO(jrprice): Strip unused types. ctx.dst->AST().AddTypeDecl(ctx.Clone(ty)); } else if (auto* var = decl->As()) { - if (var->is_const || referenced_vars.count(var)) { + if (referenced_vars.count(var)) { + if (var->is_const) { + if (auto* deco = ast::GetDecoration( + var->decorations)) { + // It is an overridable constant + if (!deco->has_value) { + // If the decoration doesn't have numeric ID specified explicitly + // Make their ids explicitly assigned in the decoration so that + // they won't be affected by other stripped away constants + auto* global = sem.Get(var)->As(); + const auto* new_deco = + ctx.dst->Override(deco->source, global->ConstantId()); + ctx.Replace(deco, new_deco); + } + } + } ctx.dst->AST().AddGlobalVariable(ctx.Clone(var)); } } else if (auto* func = decl->As()) { diff --git a/src/transform/single_entry_point_test.cc b/src/transform/single_entry_point_test.cc index 677f4114e7..cdab09423a 100644 --- a/src/transform/single_entry_point_test.cc +++ b/src/transform/single_entry_point_test.cc @@ -219,14 +219,8 @@ fn comp_main2() { )"; auto* expect = R"( -let a : f32 = 1.0; - -let b : f32 = 1.0; - let c : f32 = 1.0; -let d : f32 = 1.0; - [[stage(compute), workgroup_size(1)]] fn comp_main1() { let local_c : f32 = c; @@ -242,6 +236,120 @@ fn comp_main1() { EXPECT_EQ(expect, str(got)); } +TEST_F(SingleEntryPointTest, OverridableConstants) { + auto* src = R"( +[[override(1001)]] let c1 : u32 = 1u; +[[override]] let c2 : u32 = 1u; +[[override(0)]] let c3 : u32 = 1u; +[[override(9999)]] let c4 : u32 = 1u; + +[[stage(compute), workgroup_size(1)]] +fn comp_main1() { + let local_d = c1; +} + +[[stage(compute), workgroup_size(1)]] +fn comp_main2() { + let local_d = c2; +} + +[[stage(compute), workgroup_size(1)]] +fn comp_main3() { + let local_d = c3; +} + +[[stage(compute), workgroup_size(1)]] +fn comp_main4() { + let local_d = c4; +} + +[[stage(compute), workgroup_size(1)]] +fn comp_main5() { + let local_d = 1u; +} +)"; + + { + SingleEntryPoint::Config cfg("comp_main1"); + auto* expect = R"( +[[override(1001)]] let c1 : u32 = 1u; + +[[stage(compute), workgroup_size(1)]] +fn comp_main1() { + let local_d = c1; +} +)"; + DataMap data; + data.Add(cfg); + auto got = Run(src, data); + EXPECT_EQ(expect, str(got)); + } + + { + SingleEntryPoint::Config cfg("comp_main2"); + // The decorator is replaced with the one with explicit id + // And should not be affected by other constants stripped away + auto* expect = R"( +[[override(1)]] let c2 : u32 = 1u; + +[[stage(compute), workgroup_size(1)]] +fn comp_main2() { + let local_d = c2; +} +)"; + DataMap data; + data.Add(cfg); + auto got = Run(src, data); + EXPECT_EQ(expect, str(got)); + } + + { + SingleEntryPoint::Config cfg("comp_main3"); + auto* expect = R"( +[[override(0)]] let c3 : u32 = 1u; + +[[stage(compute), workgroup_size(1)]] +fn comp_main3() { + let local_d = c3; +} +)"; + DataMap data; + data.Add(cfg); + auto got = Run(src, data); + EXPECT_EQ(expect, str(got)); + } + + { + SingleEntryPoint::Config cfg("comp_main4"); + auto* expect = R"( +[[override(9999)]] let c4 : u32 = 1u; + +[[stage(compute), workgroup_size(1)]] +fn comp_main4() { + let local_d = c4; +} +)"; + DataMap data; + data.Add(cfg); + auto got = Run(src, data); + EXPECT_EQ(expect, str(got)); + } + + { + SingleEntryPoint::Config cfg("comp_main5"); + auto* expect = R"( +[[stage(compute), workgroup_size(1)]] +fn comp_main5() { + let local_d = 1u; +} +)"; + DataMap data; + data.Add(cfg); + auto got = Run(src, data); + EXPECT_EQ(expect, str(got)); + } +} + TEST_F(SingleEntryPointTest, CalledFunctions) { auto* src = R"( fn inner1() {