From c2e9bb785a172aad2d7e1c51df086995b870c8ad Mon Sep 17 00:00:00 2001 From: Antonio Maiorano Date: Wed, 30 Mar 2022 20:11:35 +0000 Subject: [PATCH] Factor out GetInsertionPoint to transform/utils This function was copy-pasted in two transforms, and will be used in the next one I'm writing. Bug: tint:1080 Change-Id: Ic5ffe68a7e9d00b37722e8f5faff01e9e15fa6b1 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/85262 Reviewed-by: Ben Clayton Reviewed-by: James Price Kokoro: Kokoro Commit-Queue: Antonio Maiorano --- src/tint/BUILD.gn | 2 + src/tint/CMakeLists.txt | 3 + .../transform/promote_side_effects_to_decl.cc | 25 +---- .../transform/unwind_discard_functions.cc | 47 +--------- .../transform/utils/get_insertion_point.cc | 58 ++++++++++++ .../transform/utils/get_insertion_point.h | 40 ++++++++ .../utils/get_insertion_point_test.cc | 94 +++++++++++++++++++ test/tint/BUILD.gn | 1 + 8 files changed, 204 insertions(+), 66 deletions(-) create mode 100644 src/tint/transform/utils/get_insertion_point.cc create mode 100644 src/tint/transform/utils/get_insertion_point.h create mode 100644 src/tint/transform/utils/get_insertion_point_test.cc diff --git a/src/tint/BUILD.gn b/src/tint/BUILD.gn index 1646c85d4b..a2541bdb32 100644 --- a/src/tint/BUILD.gn +++ b/src/tint/BUILD.gn @@ -489,6 +489,8 @@ libtint_source_set("libtint_core_all_src") { "transform/unshadow.h", "transform/unwind_discard_functions.cc", "transform/unwind_discard_functions.h", + "transform/utils/get_insertion_point.cc", + "transform/utils/get_insertion_point.h", "transform/utils/hoist_to_decl_before.cc", "transform/utils/hoist_to_decl_before.h", "transform/var_for_dynamic_index.cc", diff --git a/src/tint/CMakeLists.txt b/src/tint/CMakeLists.txt index 1f5e90c36f..0a3ce1cb1a 100644 --- a/src/tint/CMakeLists.txt +++ b/src/tint/CMakeLists.txt @@ -377,6 +377,8 @@ set(TINT_LIB_SRCS transform/wrap_arrays_in_structs.h transform/zero_init_workgroup_memory.cc transform/zero_init_workgroup_memory.h + transform/utils/get_insertion_point.cc + transform/utils/get_insertion_point.h transform/utils/hoist_to_decl_before.cc transform/utils/hoist_to_decl_before.h sem/bool_type.cc @@ -1039,6 +1041,7 @@ if(TINT_BUILD_TESTS) transform/vertex_pulling_test.cc transform/wrap_arrays_in_structs_test.cc transform/zero_init_workgroup_memory_test.cc + transform/utils/get_insertion_point_test.cc transform/utils/hoist_to_decl_before_test.cc ) endif() diff --git a/src/tint/transform/promote_side_effects_to_decl.cc b/src/tint/transform/promote_side_effects_to_decl.cc index e1d5ab2e2f..9fd19db3b4 100644 --- a/src/tint/transform/promote_side_effects_to_decl.cc +++ b/src/tint/transform/promote_side_effects_to_decl.cc @@ -28,6 +28,7 @@ #include "src/tint/sem/member_accessor_expression.h" #include "src/tint/sem/variable.h" #include "src/tint/transform/manager.h" +#include "src/tint/transform/utils/get_insertion_point.h" #include "src/tint/transform/utils/hoist_to_decl_before.h" #include "src/tint/utils/scoped_assignment.h" @@ -556,33 +557,11 @@ class DecomposeSideEffects::DecomposeState : public StateBase { }); } - // For the input statement, returns the block and statement within that block - // to insert before/after. - std::pair - GetInsertionPoint(const ast::Statement* stmt) { - auto* sem_stmt = sem.Get(stmt); - if (sem_stmt) { - auto* parent = sem_stmt->Parent(); - if (auto* block = parent->As()) { - // Common case, just insert in the current block above the input - // statement. - return {block, stmt}; - } - if (auto* fl = parent->As()) { - if (fl->Declaration()->initializer == stmt) { - // For loop init, insert above the for loop itself. - return {fl->Block(), fl->Declaration()}; - } - } - } - return {}; - } - // Inserts statements in `stmts` before `stmt` void InsertBefore(const ast::StatementList& stmts, const ast::Statement* stmt) { if (!stmts.empty()) { - auto ip = GetInsertionPoint(stmt); + auto ip = utils::GetInsertionPoint(ctx, stmt); for (auto* s : stmts) { ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, s); } diff --git a/src/tint/transform/unwind_discard_functions.cc b/src/tint/transform/unwind_discard_functions.cc index 80a47b133a..b15f3a093c 100644 --- a/src/tint/transform/unwind_discard_functions.cc +++ b/src/tint/transform/unwind_discard_functions.cc @@ -28,6 +28,7 @@ #include "src/tint/sem/for_loop_statement.h" #include "src/tint/sem/function.h" #include "src/tint/sem/if_statement.h" +#include "src/tint/transform/utils/get_insertion_point.h" TINT_INSTANTIATE_TYPEINFO(tint::transform::UnwindDiscardFunctions); @@ -42,46 +43,6 @@ class State { Symbol module_discard_var_name; // Use ModuleDiscardVarName() to read Symbol module_discard_func_name; // Use ModuleDiscardFuncName() to read - // For the input statement, returns the block and statement within that - // block to insert before/after. - std::pair - GetInsertionPoint(const ast::Statement* stmt) { - using RetType = - std::pair; - - if (auto* sem_stmt = sem.Get(stmt)) { - auto* parent = sem_stmt->Parent(); - return Switch( - parent, - [&](const sem::BlockStatement* block) -> RetType { - // Common case, just insert in the current block above the input - // statement. - return {block, stmt}; - }, - [&](const sem::ForLoopStatement* fl) -> RetType { - // `stmt` is either the for loop initializer or the continuing - // statement of a for-loop. - if (fl->Declaration()->initializer == stmt) { - // For loop init, insert above the for loop itself. - return {fl->Block(), fl->Declaration()}; - } - - TINT_ICE(Transform, b.Diagnostics()) - << "cannot insert before or after continuing statement of a " - "for-loop"; - return {}; - }, - [&](Default) -> RetType { - TINT_ICE(Transform, b.Diagnostics()) - << "expected parent of statement to be either a block or for " - "loop"; - return {}; - }); - } - - return {}; - } - // If `block`'s parent is of type TO, returns pointer to it. template const TO* ParentAs(const ast::BlockStatement* block) { @@ -186,7 +147,7 @@ class State { const sem::Expression* sem_expr) { auto* expr = sem_expr->Declaration(); - auto ip = GetInsertionPoint(stmt); + auto ip = utils::GetInsertionPoint(ctx, stmt); auto var_name = b.Sym(); auto* decl = b.Decl(b.Var(var_name, nullptr, ctx.Clone(expr))); ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, decl); @@ -239,7 +200,7 @@ class State { return HoistAndInsertBefore(stmt, sem_expr); } - auto ip = GetInsertionPoint(stmt); + auto ip = utils::GetInsertionPoint(ctx, stmt); ctx.InsertAfter(ip.first->Declaration()->statements, ip.second, IfDiscardReturn(stmt)); return nullptr; // Don't replace current statement @@ -269,7 +230,7 @@ class State { to_insert = b.Assign(var_name, true); } - auto ip = GetInsertionPoint(stmt); + auto ip = utils::GetInsertionPoint(ctx, stmt); ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, to_insert); return Return(stmt); } diff --git a/src/tint/transform/utils/get_insertion_point.cc b/src/tint/transform/utils/get_insertion_point.cc new file mode 100644 index 0000000000..0f00e0c7cb --- /dev/null +++ b/src/tint/transform/utils/get_insertion_point.cc @@ -0,0 +1,58 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/tint/transform/utils/get_insertion_point.h" +#include "src/tint/debug.h" +#include "src/tint/diagnostic/diagnostic.h" +#include "src/tint/sem/for_loop_statement.h" + +namespace tint::transform::utils { + +InsertionPoint GetInsertionPoint(CloneContext& ctx, + const ast::Statement* stmt) { + auto& sem = ctx.src->Sem(); + auto& diag = ctx.dst->Diagnostics(); + using RetType = std::pair; + + if (auto* sem_stmt = sem.Get(stmt)) { + auto* parent = sem_stmt->Parent(); + return Switch( + parent, + [&](const sem::BlockStatement* block) -> RetType { + // Common case, can insert in the current block above/below the input + // statement. + return {block, stmt}; + }, + [&](const sem::ForLoopStatement* fl) -> RetType { + // `stmt` is either the for loop initializer or the continuing + // statement of a for-loop. + if (fl->Declaration()->initializer == stmt) { + // For loop init, can insert above the for loop itself. + return {fl->Block(), fl->Declaration()}; + } + + // Cannot insert before or after continuing statement of a for-loop + return {}; + }, + [&](Default) -> RetType { + TINT_ICE(Transform, diag) << "expected parent of statement to be " + "either a block or for loop"; + return {}; + }); + } + + return {}; +} + +} // namespace tint::transform::utils diff --git a/src/tint/transform/utils/get_insertion_point.h b/src/tint/transform/utils/get_insertion_point.h new file mode 100644 index 0000000000..85abcea870 --- /dev/null +++ b/src/tint/transform/utils/get_insertion_point.h @@ -0,0 +1,40 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_TINT_TRANSFORM_UTILS_GET_INSERTION_POINT_H_ +#define SRC_TINT_TRANSFORM_UTILS_GET_INSERTION_POINT_H_ + +#include + +#include "src/tint/program_builder.h" +#include "src/tint/sem/block_statement.h" + +namespace tint::transform::utils { + +/// InsertionPoint is a pair of the block (`first`) within which, and the +/// statement (`second`) before or after which to insert. +using InsertionPoint = + std::pair; + +/// For the input statement, returns the block and statement within that +/// block to insert before/after. If `stmt` is a for-loop continue statement, +/// the function returns {nullptr, nullptr} as we cannot insert before/after it. +/// @param ctx the clone context +/// @param stmt the statement to insert before or after +/// @return the insertion point +InsertionPoint GetInsertionPoint(CloneContext& ctx, const ast::Statement* stmt); + +} // namespace tint::transform::utils + +#endif // SRC_TINT_TRANSFORM_UTILS_GET_INSERTION_POINT_H_ diff --git a/src/tint/transform/utils/get_insertion_point_test.cc b/src/tint/transform/utils/get_insertion_point_test.cc new file mode 100644 index 0000000000..48e358ece8 --- /dev/null +++ b/src/tint/transform/utils/get_insertion_point_test.cc @@ -0,0 +1,94 @@ +// Copyright 2022 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "gtest/gtest-spi.h" +#include "src/tint/debug.h" +#include "src/tint/program_builder.h" +#include "src/tint/transform/test_helper.h" +#include "src/tint/transform/utils/get_insertion_point.h" + +namespace tint::transform { +namespace { + +using GetInsertionPointTest = ::testing::Test; + +TEST_F(GetInsertionPointTest, Block) { + // fn f() { + // var a = 1; + // } + ProgramBuilder b; + auto* expr = b.Expr(1); + auto* var = b.Decl(b.Var("a", nullptr, expr)); + auto* block = b.Block(var); + b.Func("f", {}, b.ty.void_(), {block}); + + Program original(std::move(b)); + ProgramBuilder cloned_b; + CloneContext ctx(&cloned_b, &original); + + // Can insert in block containing the variable, above or below the input + // statement. + auto ip = utils::GetInsertionPoint(ctx, var); + ASSERT_EQ(ip.first->Declaration(), block); + ASSERT_EQ(ip.second, var); +} + +TEST_F(GetInsertionPointTest, ForLoopInit) { + // fn f() { + // for(var a = 1; true; ) { + // } + // } + ProgramBuilder b; + auto* expr = b.Expr(1); + auto* var = b.Decl(b.Var("a", nullptr, expr)); + auto* fl = b.For(var, b.Expr(true), {}, b.Block()); + auto* func_block = b.Block(fl); + b.Func("f", {}, b.ty.void_(), {func_block}); + + Program original(std::move(b)); + ProgramBuilder cloned_b; + CloneContext ctx(&cloned_b, &original); + + // Can insert in block containing for-loop above the for-loop itself. + auto ip = utils::GetInsertionPoint(ctx, var); + ASSERT_EQ(ip.first->Declaration(), func_block); + ASSERT_EQ(ip.second, fl); +} + +TEST_F(GetInsertionPointTest, ForLoopCont_Invalid) { + // fn f() { + // for(; true; var a = 1) { + // } + // } + ProgramBuilder b; + auto* expr = b.Expr(1); + auto* var = b.Decl(b.Var("a", nullptr, expr)); + auto* s = b.For({}, b.Expr(true), var, b.Block()); + b.Func("f", {}, b.ty.void_(), {s}); + + Program original(std::move(b)); + ProgramBuilder cloned_b; + CloneContext ctx(&cloned_b, &original); + + // Can't insert before/after for loop continue statement (would ned to be + // converted to loop). + auto ip = utils::GetInsertionPoint(ctx, var); + ASSERT_EQ(ip.first, nullptr); + ASSERT_EQ(ip.second, nullptr); +} + +} // namespace +} // namespace tint::transform diff --git a/test/tint/BUILD.gn b/test/tint/BUILD.gn index 7f6ec7edc1..6bdd5272ae 100644 --- a/test/tint/BUILD.gn +++ b/test/tint/BUILD.gn @@ -339,6 +339,7 @@ tint_unittests_source_set("tint_unittests_transform_src") { "../../src/tint/transform/transform_test.cc", "../../src/tint/transform/unshadow_test.cc", "../../src/tint/transform/unwind_discard_functions_test.cc", + "../../src/tint/transform/utils/get_insertion_point_test.cc", "../../src/tint/transform/utils/hoist_to_decl_before_test.cc", "../../src/tint/transform/var_for_dynamic_index_test.cc", "../../src/tint/transform/vectorize_scalar_matrix_constructors_test.cc",