From 26157557e86d8aa80e152660e4c0410a455d7a12 Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Thu, 2 Mar 2023 17:37:53 +0000 Subject: [PATCH] tint/transform/utils: Add HoistToDeclBefore::Replace() Handles statement replacement of for-loop initializer and continuing statements. Change-Id: I83ddf6fbd9b19f5022f7b02d7aebcbd95ab4c1f8 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/122302 Kokoro: Ben Clayton Reviewed-by: Dan Sinclair Commit-Queue: Ben Clayton --- .../transform/utils/hoist_to_decl_before.cc | 23 ++ .../transform/utils/hoist_to_decl_before.h | 14 + .../utils/hoist_to_decl_before_test.cc | 259 ++++++++++++++++++ 3 files changed, 296 insertions(+) diff --git a/src/tint/transform/utils/hoist_to_decl_before.cc b/src/tint/transform/utils/hoist_to_decl_before.cc index 4155d516bc..5fa00c4c98 100644 --- a/src/tint/transform/utils/hoist_to_decl_before.cc +++ b/src/tint/transform/utils/hoist_to_decl_before.cc @@ -99,6 +99,21 @@ struct HoistToDeclBefore::State { return InsertBeforeImpl(before_stmt, std::move(builder)); } + /// @copydoc HoistToDeclBefore::Replace(const sem::Statement* what, const ast::Statement* with) + bool Replace(const sem::Statement* what, const ast::Statement* with) { + auto builder = [with] { return with; }; + return Replace(what, std::move(builder)); + } + + /// @copydoc HoistToDeclBefore::Replace(const sem::Statement* what, const StmtBuilder& with) + bool Replace(const sem::Statement* what, const StmtBuilder& with) { + if (!InsertBeforeImpl(what, Decompose{})) { + return false; + } + ctx.Replace(what->Declaration(), with); + return true; + } + /// @copydoc HoistToDeclBefore::Prepare() bool Prepare(const sem::ValueExpression* before_expr) { return InsertBefore(before_expr->Stmt(), nullptr); @@ -413,6 +428,14 @@ bool HoistToDeclBefore::InsertBefore(const sem::Statement* before_stmt, return state_->InsertBefore(before_stmt, builder); } +bool HoistToDeclBefore::Replace(const sem::Statement* what, const ast::Statement* with) { + return state_->Replace(what, with); +} + +bool HoistToDeclBefore::Replace(const sem::Statement* what, const StmtBuilder& with) { + return state_->Replace(what, with); +} + bool HoistToDeclBefore::Prepare(const sem::ValueExpression* before_expr) { return state_->Prepare(before_expr); } diff --git a/src/tint/transform/utils/hoist_to_decl_before.h b/src/tint/transform/utils/hoist_to_decl_before.h index 81c255f468..c662b1eaf8 100644 --- a/src/tint/transform/utils/hoist_to_decl_before.h +++ b/src/tint/transform/utils/hoist_to_decl_before.h @@ -76,6 +76,20 @@ class HoistToDeclBefore { /// @return true on success bool InsertBefore(const sem::Statement* before_stmt, const StmtBuilder& builder); + /// Replaces the statement @p what with the statement @p stmt, possibly converting 'for-loop's + /// to 'loop's if necessary. + /// @param what the statement to replace + /// @param with the replacement statement + /// @return true on success + bool Replace(const sem::Statement* what, const ast::Statement* with); + + /// Replaces the statement @p what with the statement returned by @p stmt, possibly converting + /// 'for-loop's to 'loop's if necessary. + /// @param what the statement to replace + /// @param with the replacement statement builder + /// @return true on success + bool Replace(const sem::Statement* what, const StmtBuilder& with); + /// Use to signal that we plan on hoisting a decl before `before_expr`. This /// will convert 'for-loop's to 'loop's and 'else-if's to 'else {if}'s if /// needed. diff --git a/src/tint/transform/utils/hoist_to_decl_before_test.cc b/src/tint/transform/utils/hoist_to_decl_before_test.cc index 8d364e12d6..0abb809305 100644 --- a/src/tint/transform/utils/hoist_to_decl_before_test.cc +++ b/src/tint/transform/utils/hoist_to_decl_before_test.cc @@ -877,5 +877,264 @@ fn f() { EXPECT_EQ(expect, str(cloned)); } +TEST_F(HoistToDeclBeforeTest, Replace_Block) { + // fn foo() { + // } + // fn f() { + // var a = 1i; + // } + ProgramBuilder b; + b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty); + auto* var = b.Decl(b.Var("a", b.Expr(1_i))); + b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var}); + + Program original(std::move(b)); + ProgramBuilder cloned_b; + CloneContext ctx(&cloned_b, &original); + + HoistToDeclBefore hoistToDeclBefore(ctx); + auto* target_stmt = ctx.src->Sem().Get(var); + auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo")); + hoistToDeclBefore.Replace(target_stmt, new_stmt); + + ctx.Clone(); + Program cloned(std::move(cloned_b)); + + auto* expect = R"( +fn foo() { +} + +fn f() { + foo(); +} +)"; + + EXPECT_EQ(expect, str(cloned)); +} + +TEST_F(HoistToDeclBeforeTest, Replace_Block_Function) { + // fn foo() { + // } + // fn f() { + // var a = 1i; + // } + ProgramBuilder b; + b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty); + auto* var = b.Decl(b.Var("a", b.Expr(1_i))); + b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var}); + + Program original(std::move(b)); + ProgramBuilder cloned_b; + CloneContext ctx(&cloned_b, &original); + + HoistToDeclBefore hoistToDeclBefore(ctx); + auto* target_stmt = ctx.src->Sem().Get(var); + hoistToDeclBefore.Replace(target_stmt, [&] { return ctx.dst->CallStmt(ctx.dst->Call("foo")); }); + + ctx.Clone(); + Program cloned(std::move(cloned_b)); + + auto* expect = R"( +fn foo() { +} + +fn f() { + foo(); +} +)"; + + EXPECT_EQ(expect, str(cloned)); +} + +TEST_F(HoistToDeclBeforeTest, Replace_ForLoopInit) { + // fn foo() { + // } + // fn f() { + // for(var a = 1i; true;) { + // } + // } + ProgramBuilder b; + b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty); + auto* var = b.Decl(b.Var("a", b.Expr(1_i))); + auto* s = b.For(var, b.Expr(true), nullptr, b.Block()); + b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{s}); + + Program original(std::move(b)); + ProgramBuilder cloned_b; + CloneContext ctx(&cloned_b, &original); + + HoistToDeclBefore hoistToDeclBefore(ctx); + auto* target_stmt = ctx.src->Sem().Get(var); + auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo")); + hoistToDeclBefore.Replace(target_stmt, new_stmt); + + ctx.Clone(); + Program cloned(std::move(cloned_b)); + + auto* expect = R"( +fn foo() { +} + +fn f() { + { + foo(); + loop { + if (!(true)) { + break; + } + { + } + } + } +} +)"; + + EXPECT_EQ(expect, str(cloned)); +} + +TEST_F(HoistToDeclBeforeTest, Replace_ForLoopInit_Function) { + // fn foo() { + // } + // fn f() { + // for(var a = 1i; true;) { + // } + // } + ProgramBuilder b; + b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty); + auto* var = b.Decl(b.Var("a", b.Expr(1_i))); + auto* s = b.For(var, b.Expr(true), nullptr, b.Block()); + b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{s}); + + Program original(std::move(b)); + ProgramBuilder cloned_b; + CloneContext ctx(&cloned_b, &original); + + HoistToDeclBefore hoistToDeclBefore(ctx); + auto* target_stmt = ctx.src->Sem().Get(var); + hoistToDeclBefore.Replace(target_stmt, [&] { return ctx.dst->CallStmt(ctx.dst->Call("foo")); }); + + ctx.Clone(); + Program cloned(std::move(cloned_b)); + + auto* expect = R"( +fn foo() { +} + +fn f() { + { + foo(); + loop { + if (!(true)) { + break; + } + { + } + } + } +} +)"; + + EXPECT_EQ(expect, str(cloned)); +} + +TEST_F(HoistToDeclBeforeTest, Replace_ForLoopCont) { + // fn foo() { + // } + // fn f() { + // var a = 1i; + // for(; true; a+=1i) { + // } + // } + ProgramBuilder b; + b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty); + auto* var = b.Decl(b.Var("a", b.Expr(1_i))); + auto* cont = b.CompoundAssign("a", b.Expr(1_i), ast::BinaryOp::kAdd); + auto* s = b.For(nullptr, b.Expr(true), cont, b.Block()); + b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s}); + + Program original(std::move(b)); + ProgramBuilder cloned_b; + CloneContext ctx(&cloned_b, &original); + + HoistToDeclBefore hoistToDeclBefore(ctx); + auto* target_stmt = ctx.src->Sem().Get(cont->As()); + auto* new_stmt = ctx.dst->CallStmt(ctx.dst->Call("foo")); + hoistToDeclBefore.Replace(target_stmt, new_stmt); + + ctx.Clone(); + Program cloned(std::move(cloned_b)); + + auto* expect = R"( +fn foo() { +} + +fn f() { + var a = 1i; + loop { + if (!(true)) { + break; + } + { + } + + continuing { + foo(); + } + } +} +)"; + + EXPECT_EQ(expect, str(cloned)); +} + +TEST_F(HoistToDeclBeforeTest, Replace_ForLoopCont_Function) { + // fn foo() { + // } + // fn f() { + // var a = 1i; + // for(; true; a+=1i) { + // } + // } + ProgramBuilder b; + b.Func("foo", utils::Empty, b.ty.void_(), utils::Empty); + auto* var = b.Decl(b.Var("a", b.Expr(1_i))); + auto* cont = b.CompoundAssign("a", b.Expr(1_i), ast::BinaryOp::kAdd); + auto* s = b.For(nullptr, b.Expr(true), cont, b.Block()); + b.Func("f", utils::Empty, b.ty.void_(), utils::Vector{var, s}); + + Program original(std::move(b)); + ProgramBuilder cloned_b; + CloneContext ctx(&cloned_b, &original); + + HoistToDeclBefore hoistToDeclBefore(ctx); + auto* target_stmt = ctx.src->Sem().Get(cont->As()); + hoistToDeclBefore.Replace(target_stmt, [&] { return ctx.dst->CallStmt(ctx.dst->Call("foo")); }); + + ctx.Clone(); + Program cloned(std::move(cloned_b)); + + auto* expect = R"( +fn foo() { +} + +fn f() { + var a = 1i; + loop { + if (!(true)) { + break; + } + { + } + + continuing { + foo(); + } + } +} +)"; + + EXPECT_EQ(expect, str(cloned)); +} + } // namespace } // namespace tint::transform