diff --git a/src/tint/transform/utils/hoist_to_decl_before.cc b/src/tint/transform/utils/hoist_to_decl_before.cc index 508b9745fa..ff345eb62d 100644 --- a/src/tint/transform/utils/hoist_to_decl_before.cc +++ b/src/tint/transform/utils/hoist_to_decl_before.cc @@ -56,29 +56,39 @@ class HoistToDeclBefore::State { /// For-loops that need to be decomposed to loops. std::unordered_map loops; - /// If statements with 'else if's that need to be decomposed to 'else { if - /// }' + /// If statements with 'else if's that need to be decomposed to 'else {if}' std::unordered_map ifs; - // Inserts `decl` before `sem_expr`, possibly marking a for-loop to be - // converted to a loop, or an else-if to an else { if }. - bool InsertBefore(const sem::Expression* sem_expr, + // Inserts `decl` before `before_expr`, possibly marking a for-loop to be + // converted to a loop, or an else-if to an else { if }. If `decl` is nullptr, + // for-loop and else-if conversions are marked, but no hoisting takes place. + bool InsertBefore(const sem::Expression* before_expr, const ast::VariableDeclStatement* decl) { - auto* sem_stmt = sem_expr->Stmt(); + auto* sem_stmt = before_expr->Stmt(); auto* stmt = sem_stmt->Declaration(); if (auto* else_if = sem_stmt->As()) { // Expression used in 'else if' condition. // Need to convert 'else if' to 'else { if }'. auto& if_info = ifs[else_if->Parent()->As()]; - if_info.else_ifs[else_if].cond_decls.push_back(decl); + + // Index the map to convert this else if, even if `decl` is nullptr. + auto& decls = if_info.else_ifs[else_if].cond_decls; + if (decl) { + decls.emplace_back(decl); + } return true; } if (auto* fl = sem_stmt->As()) { // Expression used in for-loop condition. // For-loop needs to be decomposed to a loop. - loops[fl].cond_decls.emplace_back(decl); + + // Index the map to convert this for-loop, even if `decl` is nullptr. + auto& decls = loops[fl].cond_decls; + if (decl) { + decls.emplace_back(decl); + } return true; } @@ -86,7 +96,9 @@ class HoistToDeclBefore::State { if (auto* block = parent->As()) { // Expression's statement sits in a block. Simple case. // Insert the decl before the parent statement - ctx.InsertBefore(block->Declaration()->statements, stmt, decl); + if (decl) { + ctx.InsertBefore(block->Declaration()->statements, stmt, decl); + } return true; } @@ -95,15 +107,22 @@ class HoistToDeclBefore::State { if (fl->Declaration()->initializer == stmt) { // Expression used in for-loop initializer. // Insert the let above the for-loop. - ctx.InsertBefore(fl->Block()->Declaration()->statements, - fl->Declaration(), decl); + if (decl) { + ctx.InsertBefore(fl->Block()->Declaration()->statements, + fl->Declaration(), decl); + } return true; } if (fl->Declaration()->continuing == stmt) { // Expression used in for-loop continuing. // For-loop needs to be decomposed to a loop. - loops[fl].cont_decls.emplace_back(decl); + + // Index the map to convert this for-loop, even if `decl` is nullptr. + auto& decls = loops[fl].cont_decls; + if (decl) { + decls.emplace_back(decl); + } return true; } @@ -263,10 +282,10 @@ class HoistToDeclBefore::State { /// @param as_const hoist to `let` if true, otherwise to `var` /// @param decl_name optional name to use for the variable/constant name /// @return true on success - bool HoistToDeclBefore(const sem::Expression* before_expr, - const ast::Expression* expr, - bool as_const, - const char* decl_name) { + bool Add(const sem::Expression* before_expr, + const ast::Expression* expr, + bool as_const, + const char* decl_name) { auto name = b.Symbols().New(decl_name); // Construct the let/var that holds the hoisted expr @@ -283,6 +302,15 @@ class HoistToDeclBefore::State { return true; } + /// Use to signal that we plan on hoisting a decl before `before_expr`. This + /// will convert 'for-loop's to 'loop's and 'else-if's to 'else {if}'s if + /// needed. + /// @param before_expr expression we would hoist a decl before + /// @return true on success + bool Prepare(const sem::Expression* before_expr) { + return InsertBefore(before_expr, nullptr); + } + /// Applies any scheduled insertions from previous calls to Add() to /// CloneContext. Call this once before ctx.Clone(). /// @return true on success @@ -302,7 +330,11 @@ bool HoistToDeclBefore::Add(const sem::Expression* before_expr, const ast::Expression* expr, bool as_const, const char* decl_name) { - return state_->HoistToDeclBefore(before_expr, expr, as_const, decl_name); + return state_->Add(before_expr, expr, as_const, decl_name); +} + +bool HoistToDeclBefore::Prepare(const sem::Expression* before_expr) { + return state_->Prepare(before_expr); } bool HoistToDeclBefore::Apply() { diff --git a/src/tint/transform/utils/hoist_to_decl_before.h b/src/tint/transform/utils/hoist_to_decl_before.h index 8f35a0910d..583896d86e 100644 --- a/src/tint/transform/utils/hoist_to_decl_before.h +++ b/src/tint/transform/utils/hoist_to_decl_before.h @@ -23,8 +23,8 @@ namespace tint::transform { /// Utility class that can be used to hoist expressions before other -/// expressions, possibly converting 'for' loops to 'loop's and 'else if to -// 'else if'. +/// expressions, possibly converting 'for-loop's to 'loop's and 'else-if's to +// 'else {if}'s. class HoistToDeclBefore { public: /// Constructor @@ -46,6 +46,13 @@ class HoistToDeclBefore { bool as_const, const char* decl_name = ""); + /// Use to signal that we plan on hoisting a decl before `before_expr`. This + /// will convert 'for-loop's to 'loop's and 'else-if's to 'else {if}'s if + /// needed. + /// @param before_expr expression we would hoist a decl before + /// @return true on success + bool Prepare(const sem::Expression* before_expr); + /// Applies any scheduled insertions from previous calls to Add() to /// CloneContext. Call this once before ctx.Clone(). /// @return true on success diff --git a/src/tint/transform/utils/hoist_to_decl_before_test.cc b/src/tint/transform/utils/hoist_to_decl_before_test.cc index 70dc8bc6f8..91e17a5770 100644 --- a/src/tint/transform/utils/hoist_to_decl_before_test.cc +++ b/src/tint/transform/utils/hoist_to_decl_before_test.cc @@ -287,5 +287,130 @@ fn f() { EXPECT_EQ(expect, str(cloned)); } +TEST_F(HoistToDeclBeforeTest, Prepare_ForLoopCond) { + // fn f() { + // var a : bool; + // for(; a; ) { + // } + // } + ProgramBuilder b; + auto* var = b.Decl(b.Var("a", b.ty.bool_())); + auto* expr = b.Expr("a"); + auto* s = b.For({}, expr, {}, b.Block()); + b.Func("f", {}, b.ty.void_(), {var, s}); + + Program original(std::move(b)); + ProgramBuilder cloned_b; + CloneContext ctx(&cloned_b, &original); + + HoistToDeclBefore hoistToDeclBefore(ctx); + auto* sem_expr = ctx.src->Sem().Get(expr); + hoistToDeclBefore.Prepare(sem_expr); + hoistToDeclBefore.Apply(); + + ctx.Clone(); + Program cloned(std::move(cloned_b)); + + auto* expect = R"( +fn f() { + var a : bool; + loop { + if (!(a)) { + break; + } + { + } + } +} +)"; + + EXPECT_EQ(expect, str(cloned)); +} + +TEST_F(HoistToDeclBeforeTest, Prepare_ForLoopCont) { + // fn f() { + // for(; true; var a = 1) { + // } + // } + ProgramBuilder b; + auto* expr = b.Expr(1); + auto* s = + b.For({}, b.Expr(true), b.Decl(b.Var("a", nullptr, expr)), b.Block()); + b.Func("f", {}, b.ty.void_(), {s}); + + Program original(std::move(b)); + ProgramBuilder cloned_b; + CloneContext ctx(&cloned_b, &original); + + HoistToDeclBefore hoistToDeclBefore(ctx); + auto* sem_expr = ctx.src->Sem().Get(expr); + hoistToDeclBefore.Prepare(sem_expr); + hoistToDeclBefore.Apply(); + + ctx.Clone(); + Program cloned(std::move(cloned_b)); + + auto* expect = R"( +fn f() { + loop { + if (!(true)) { + break; + } + { + } + + continuing { + var a = 1; + } + } +} +)"; + + EXPECT_EQ(expect, str(cloned)); +} + +TEST_F(HoistToDeclBeforeTest, Prepare_ElseIf) { + // fn f() { + // var a : bool; + // if (true) { + // } else if (a) { + // } else { + // } + // } + ProgramBuilder b; + auto* var = b.Decl(b.Var("a", b.ty.bool_())); + auto* expr = b.Expr("a"); + auto* s = b.If(b.Expr(true), b.Block(), // + b.Else(expr, b.Block()), // + b.Else(b.Block())); + b.Func("f", {}, b.ty.void_(), {var, s}); + + Program original(std::move(b)); + ProgramBuilder cloned_b; + CloneContext ctx(&cloned_b, &original); + + HoistToDeclBefore hoistToDeclBefore(ctx); + auto* sem_expr = ctx.src->Sem().Get(expr); + hoistToDeclBefore.Prepare(sem_expr); + hoistToDeclBefore.Apply(); + + ctx.Clone(); + Program cloned(std::move(cloned_b)); + + auto* expect = R"( +fn f() { + var a : bool; + if (true) { + } else { + if (a) { + } else { + } + } +} +)"; + + EXPECT_EQ(expect, str(cloned)); +} + } // namespace } // namespace tint::transform