diff --git a/src/transform/utils/hoist_to_decl_before.cc b/src/transform/utils/hoist_to_decl_before.cc index 2448506cc0..aeff023b65 100644 --- a/src/transform/utils/hoist_to_decl_before.cc +++ b/src/transform/utils/hoist_to_decl_before.cc @@ -20,6 +20,8 @@ #include "src/sem/block_statement.h" #include "src/sem/for_loop_statement.h" #include "src/sem/if_statement.h" +#include "src/sem/reference_type.h" +#include "src/sem/variable.h" #include "src/utils/reverse.h" namespace tint::transform { @@ -265,10 +267,22 @@ class HoistToDeclBefore::State { const ast::Expression* expr, bool as_const, const char* decl_name = "") { - // Construct the let/var that holds the hoisted expr auto name = b.Symbols().New(decl_name); - auto* v = as_const ? b.Const(name, nullptr, ctx.Clone(expr)) - : b.Var(name, nullptr, ctx.Clone(expr)); + + auto* sem_expr = ctx.src->Sem().Get(expr); + bool is_ref = + sem_expr && + !sem_expr->Is() // Don't need to take a ref to a var + && sem_expr->Type()->Is(); + + auto* expr_clone = ctx.Clone(expr); + if (is_ref) { + expr_clone = b.AddressOf(expr_clone); + } + + // Construct the let/var that holds the hoisted expr + auto* v = as_const ? b.Const(name, nullptr, expr_clone) + : b.Var(name, nullptr, expr_clone); auto* decl = b.Decl(v); if (!InsertBefore(before_expr, decl)) { @@ -276,7 +290,11 @@ class HoistToDeclBefore::State { } // Replace the initializer expression with a reference to the let - ctx.Replace(expr, b.Expr(name)); + const ast::Expression* new_expr = b.Expr(name); + if (is_ref) { + new_expr = b.Deref(new_expr); + } + ctx.Replace(expr, new_expr); return true; } diff --git a/src/transform/utils/hoist_to_decl_before_test.cc b/src/transform/utils/hoist_to_decl_before_test.cc index 5dddde4ba1..8dd902d510 100644 --- a/src/transform/utils/hoist_to_decl_before_test.cc +++ b/src/transform/utils/hoist_to_decl_before_test.cc @@ -217,5 +217,76 @@ fn f() { EXPECT_EQ(expect, str(cloned)); } +TEST_F(HoistToDeclBeforeTest, Array1D) { + // fn f() { + // var a : array; + // var b = a[0]; + // } + ProgramBuilder b; + auto* var1 = b.Decl(b.Var("a", b.ty.array())); + auto* expr = b.IndexAccessor("a", 0); + auto* var2 = b.Decl(b.Var("b", nullptr, expr)); + b.Func("f", {}, b.ty.void_(), {var1, var2}); + + Program original(std::move(b)); + ProgramBuilder cloned_b; + CloneContext ctx(&cloned_b, &original); + + HoistToDeclBefore hoistToDeclBefore(ctx); + auto* sem_expr = ctx.src->Sem().Get(expr); + hoistToDeclBefore.Add(sem_expr, expr, true); + hoistToDeclBefore.Apply(); + + ctx.Clone(); + Program cloned(std::move(cloned_b)); + + auto* expect = R"( +fn f() { + var a : array; + let tint_symbol = &(a[0]); + var b = *(tint_symbol); +} +)"; + + EXPECT_EQ(expect, str(cloned)); +} + +TEST_F(HoistToDeclBeforeTest, Array2D) { + // fn f() { + // var a : array, 10>; + // var b = a[0][0]; + // } + ProgramBuilder b; + + auto* var1 = + b.Decl(b.Var("a", b.ty.array(b.ty.array(), 10))); + auto* expr = b.IndexAccessor(b.IndexAccessor("a", 0), 0); + auto* var2 = b.Decl(b.Var("b", nullptr, expr)); + b.Func("f", {}, b.ty.void_(), {var1, var2}); + + Program original(std::move(b)); + ProgramBuilder cloned_b; + CloneContext ctx(&cloned_b, &original); + std::cout << str(original) << std::endl; + + HoistToDeclBefore hoistToDeclBefore(ctx); + auto* sem_expr = ctx.src->Sem().Get(expr); + hoistToDeclBefore.Add(sem_expr, expr, true); + hoistToDeclBefore.Apply(); + + ctx.Clone(); + Program cloned(std::move(cloned_b)); + + auto* expect = R"( +fn f() { + var a : array, 10>; + let tint_symbol = &(a[0][0]); + var b = *(tint_symbol); +} +)"; + + EXPECT_EQ(expect, str(cloned)); +} + } // namespace } // namespace tint::transform