Factor out GetInsertionPoint to transform/utils

This function was copy-pasted in two transforms, and will be used in the
next one I'm writing.

Bug: tint:1080
Change-Id: Ic5ffe68a7e9d00b37722e8f5faff01e9e15fa6b1
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/85262
Reviewed-by: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Antonio Maiorano 2022-03-30 20:11:35 +00:00 committed by Tint LUCI CQ
parent e6c76095fc
commit c2e9bb785a
8 changed files with 204 additions and 66 deletions

View File

@ -489,6 +489,8 @@ libtint_source_set("libtint_core_all_src") {
"transform/unshadow.h",
"transform/unwind_discard_functions.cc",
"transform/unwind_discard_functions.h",
"transform/utils/get_insertion_point.cc",
"transform/utils/get_insertion_point.h",
"transform/utils/hoist_to_decl_before.cc",
"transform/utils/hoist_to_decl_before.h",
"transform/var_for_dynamic_index.cc",

View File

@ -377,6 +377,8 @@ set(TINT_LIB_SRCS
transform/wrap_arrays_in_structs.h
transform/zero_init_workgroup_memory.cc
transform/zero_init_workgroup_memory.h
transform/utils/get_insertion_point.cc
transform/utils/get_insertion_point.h
transform/utils/hoist_to_decl_before.cc
transform/utils/hoist_to_decl_before.h
sem/bool_type.cc
@ -1039,6 +1041,7 @@ if(TINT_BUILD_TESTS)
transform/vertex_pulling_test.cc
transform/wrap_arrays_in_structs_test.cc
transform/zero_init_workgroup_memory_test.cc
transform/utils/get_insertion_point_test.cc
transform/utils/hoist_to_decl_before_test.cc
)
endif()

View File

@ -28,6 +28,7 @@
#include "src/tint/sem/member_accessor_expression.h"
#include "src/tint/sem/variable.h"
#include "src/tint/transform/manager.h"
#include "src/tint/transform/utils/get_insertion_point.h"
#include "src/tint/transform/utils/hoist_to_decl_before.h"
#include "src/tint/utils/scoped_assignment.h"
@ -556,33 +557,11 @@ class DecomposeSideEffects::DecomposeState : public StateBase {
});
}
// For the input statement, returns the block and statement within that block
// to insert before/after.
std::pair<const sem::BlockStatement*, const ast::Statement*>
GetInsertionPoint(const ast::Statement* stmt) {
auto* sem_stmt = sem.Get(stmt);
if (sem_stmt) {
auto* parent = sem_stmt->Parent();
if (auto* block = parent->As<sem::BlockStatement>()) {
// Common case, just insert in the current block above the input
// statement.
return {block, stmt};
}
if (auto* fl = parent->As<sem::ForLoopStatement>()) {
if (fl->Declaration()->initializer == stmt) {
// For loop init, insert above the for loop itself.
return {fl->Block(), fl->Declaration()};
}
}
}
return {};
}
// Inserts statements in `stmts` before `stmt`
void InsertBefore(const ast::StatementList& stmts,
const ast::Statement* stmt) {
if (!stmts.empty()) {
auto ip = GetInsertionPoint(stmt);
auto ip = utils::GetInsertionPoint(ctx, stmt);
for (auto* s : stmts) {
ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, s);
}

View File

@ -28,6 +28,7 @@
#include "src/tint/sem/for_loop_statement.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/if_statement.h"
#include "src/tint/transform/utils/get_insertion_point.h"
TINT_INSTANTIATE_TYPEINFO(tint::transform::UnwindDiscardFunctions);
@ -42,46 +43,6 @@ class State {
Symbol module_discard_var_name; // Use ModuleDiscardVarName() to read
Symbol module_discard_func_name; // Use ModuleDiscardFuncName() to read
// For the input statement, returns the block and statement within that
// block to insert before/after.
std::pair<const sem::BlockStatement*, const ast::Statement*>
GetInsertionPoint(const ast::Statement* stmt) {
using RetType =
std::pair<const sem::BlockStatement*, const ast::Statement*>;
if (auto* sem_stmt = sem.Get(stmt)) {
auto* parent = sem_stmt->Parent();
return Switch(
parent,
[&](const sem::BlockStatement* block) -> RetType {
// Common case, just insert in the current block above the input
// statement.
return {block, stmt};
},
[&](const sem::ForLoopStatement* fl) -> RetType {
// `stmt` is either the for loop initializer or the continuing
// statement of a for-loop.
if (fl->Declaration()->initializer == stmt) {
// For loop init, insert above the for loop itself.
return {fl->Block(), fl->Declaration()};
}
TINT_ICE(Transform, b.Diagnostics())
<< "cannot insert before or after continuing statement of a "
"for-loop";
return {};
},
[&](Default) -> RetType {
TINT_ICE(Transform, b.Diagnostics())
<< "expected parent of statement to be either a block or for "
"loop";
return {};
});
}
return {};
}
// If `block`'s parent is of type TO, returns pointer to it.
template <typename TO>
const TO* ParentAs(const ast::BlockStatement* block) {
@ -186,7 +147,7 @@ class State {
const sem::Expression* sem_expr) {
auto* expr = sem_expr->Declaration();
auto ip = GetInsertionPoint(stmt);
auto ip = utils::GetInsertionPoint(ctx, stmt);
auto var_name = b.Sym();
auto* decl = b.Decl(b.Var(var_name, nullptr, ctx.Clone(expr)));
ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, decl);
@ -239,7 +200,7 @@ class State {
return HoistAndInsertBefore(stmt, sem_expr);
}
auto ip = GetInsertionPoint(stmt);
auto ip = utils::GetInsertionPoint(ctx, stmt);
ctx.InsertAfter(ip.first->Declaration()->statements, ip.second,
IfDiscardReturn(stmt));
return nullptr; // Don't replace current statement
@ -269,7 +230,7 @@ class State {
to_insert = b.Assign(var_name, true);
}
auto ip = GetInsertionPoint(stmt);
auto ip = utils::GetInsertionPoint(ctx, stmt);
ctx.InsertBefore(ip.first->Declaration()->statements, ip.second, to_insert);
return Return(stmt);
}

View File

@ -0,0 +1,58 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/tint/transform/utils/get_insertion_point.h"
#include "src/tint/debug.h"
#include "src/tint/diagnostic/diagnostic.h"
#include "src/tint/sem/for_loop_statement.h"
namespace tint::transform::utils {
InsertionPoint GetInsertionPoint(CloneContext& ctx,
const ast::Statement* stmt) {
auto& sem = ctx.src->Sem();
auto& diag = ctx.dst->Diagnostics();
using RetType = std::pair<const sem::BlockStatement*, const ast::Statement*>;
if (auto* sem_stmt = sem.Get(stmt)) {
auto* parent = sem_stmt->Parent();
return Switch(
parent,
[&](const sem::BlockStatement* block) -> RetType {
// Common case, can insert in the current block above/below the input
// statement.
return {block, stmt};
},
[&](const sem::ForLoopStatement* fl) -> RetType {
// `stmt` is either the for loop initializer or the continuing
// statement of a for-loop.
if (fl->Declaration()->initializer == stmt) {
// For loop init, can insert above the for loop itself.
return {fl->Block(), fl->Declaration()};
}
// Cannot insert before or after continuing statement of a for-loop
return {};
},
[&](Default) -> RetType {
TINT_ICE(Transform, diag) << "expected parent of statement to be "
"either a block or for loop";
return {};
});
}
return {};
}
} // namespace tint::transform::utils

View File

@ -0,0 +1,40 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SRC_TINT_TRANSFORM_UTILS_GET_INSERTION_POINT_H_
#define SRC_TINT_TRANSFORM_UTILS_GET_INSERTION_POINT_H_
#include <utility>
#include "src/tint/program_builder.h"
#include "src/tint/sem/block_statement.h"
namespace tint::transform::utils {
/// InsertionPoint is a pair of the block (`first`) within which, and the
/// statement (`second`) before or after which to insert.
using InsertionPoint =
std::pair<const sem::BlockStatement*, const ast::Statement*>;
/// For the input statement, returns the block and statement within that
/// block to insert before/after. If `stmt` is a for-loop continue statement,
/// the function returns {nullptr, nullptr} as we cannot insert before/after it.
/// @param ctx the clone context
/// @param stmt the statement to insert before or after
/// @return the insertion point
InsertionPoint GetInsertionPoint(CloneContext& ctx, const ast::Statement* stmt);
} // namespace tint::transform::utils
#endif // SRC_TINT_TRANSFORM_UTILS_GET_INSERTION_POINT_H_

View File

@ -0,0 +1,94 @@
// Copyright 2022 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <utility>
#include "gtest/gtest-spi.h"
#include "src/tint/debug.h"
#include "src/tint/program_builder.h"
#include "src/tint/transform/test_helper.h"
#include "src/tint/transform/utils/get_insertion_point.h"
namespace tint::transform {
namespace {
using GetInsertionPointTest = ::testing::Test;
TEST_F(GetInsertionPointTest, Block) {
// fn f() {
// var a = 1;
// }
ProgramBuilder b;
auto* expr = b.Expr(1);
auto* var = b.Decl(b.Var("a", nullptr, expr));
auto* block = b.Block(var);
b.Func("f", {}, b.ty.void_(), {block});
Program original(std::move(b));
ProgramBuilder cloned_b;
CloneContext ctx(&cloned_b, &original);
// Can insert in block containing the variable, above or below the input
// statement.
auto ip = utils::GetInsertionPoint(ctx, var);
ASSERT_EQ(ip.first->Declaration(), block);
ASSERT_EQ(ip.second, var);
}
TEST_F(GetInsertionPointTest, ForLoopInit) {
// fn f() {
// for(var a = 1; true; ) {
// }
// }
ProgramBuilder b;
auto* expr = b.Expr(1);
auto* var = b.Decl(b.Var("a", nullptr, expr));
auto* fl = b.For(var, b.Expr(true), {}, b.Block());
auto* func_block = b.Block(fl);
b.Func("f", {}, b.ty.void_(), {func_block});
Program original(std::move(b));
ProgramBuilder cloned_b;
CloneContext ctx(&cloned_b, &original);
// Can insert in block containing for-loop above the for-loop itself.
auto ip = utils::GetInsertionPoint(ctx, var);
ASSERT_EQ(ip.first->Declaration(), func_block);
ASSERT_EQ(ip.second, fl);
}
TEST_F(GetInsertionPointTest, ForLoopCont_Invalid) {
// fn f() {
// for(; true; var a = 1) {
// }
// }
ProgramBuilder b;
auto* expr = b.Expr(1);
auto* var = b.Decl(b.Var("a", nullptr, expr));
auto* s = b.For({}, b.Expr(true), var, b.Block());
b.Func("f", {}, b.ty.void_(), {s});
Program original(std::move(b));
ProgramBuilder cloned_b;
CloneContext ctx(&cloned_b, &original);
// Can't insert before/after for loop continue statement (would ned to be
// converted to loop).
auto ip = utils::GetInsertionPoint(ctx, var);
ASSERT_EQ(ip.first, nullptr);
ASSERT_EQ(ip.second, nullptr);
}
} // namespace
} // namespace tint::transform

View File

@ -339,6 +339,7 @@ tint_unittests_source_set("tint_unittests_transform_src") {
"../../src/tint/transform/transform_test.cc",
"../../src/tint/transform/unshadow_test.cc",
"../../src/tint/transform/unwind_discard_functions_test.cc",
"../../src/tint/transform/utils/get_insertion_point_test.cc",
"../../src/tint/transform/utils/hoist_to_decl_before_test.cc",
"../../src/tint/transform/var_for_dynamic_index_test.cc",
"../../src/tint/transform/vectorize_scalar_matrix_constructors_test.cc",