CloneContext: Pass the vector to InsertBefore()

There's usually only ever one vector we want to insert into.
Inserting into *all* vectors that happen to contain the reference object is likely unintended, and is a foot-gun waiting to go off.

Change-Id: I533084ccad102fc998b851fd238fd6bea9299450
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46445
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton 2021-03-31 21:00:26 +00:00 committed by Commit Bot service account
parent 90f43cf87f
commit b4275c870e
6 changed files with 133 additions and 29 deletions

View File

@ -20,6 +20,9 @@ TINT_INSTANTIATE_TYPEINFO(tint::Cloneable);
namespace tint {
CloneContext::ListTransforms::ListTransforms() = default;
CloneContext::ListTransforms::~ListTransforms() = default;
CloneContext::CloneContext(ProgramBuilder* to, Program const* from)
: dst(to), src(from) {}
CloneContext::~CloneContext() = default;

View File

@ -15,6 +15,7 @@
#ifndef SRC_CLONE_CONTEXT_H_
#define SRC_CLONE_CONTEXT_H_
#include <algorithm>
#include <functional>
#include <unordered_map>
#include <utility>
@ -178,14 +179,29 @@ class CloneContext {
std::vector<T*> Clone(const std::vector<T*>& v) {
std::vector<T*> out;
out.reserve(v.size());
for (auto& el : v) {
auto it = insert_before_.find(el);
if (it != insert_before_.end()) {
for (auto insert : it->second) {
out.emplace_back(CheckedCast<T>(insert));
auto list_transform_it = list_transforms_.find(&v);
if (list_transform_it != list_transforms_.end()) {
const auto& transforms = list_transform_it->second;
for (auto& el : v) {
auto insert_before_it = transforms.insert_before_.find(el);
if (insert_before_it != transforms.insert_before_.end()) {
for (auto insert : insert_before_it->second) {
out.emplace_back(CheckedCast<T>(insert));
}
}
out.emplace_back(Clone(el));
auto insert_after_it = transforms.insert_after_.find(el);
if (insert_after_it != transforms.insert_after_.end()) {
for (auto insert : insert_after_it->second) {
out.emplace_back(CheckedCast<T>(insert));
}
}
}
out.emplace_back(Clone(el));
} else {
for (auto& el : v) {
out.emplace_back(Clone(el));
}
}
return out;
}
@ -293,15 +309,46 @@ class CloneContext {
return *this;
}
/// Inserts `object` before `before` whenever a vector containing `object` is
/// cloned.
/// Inserts `object` before `before` whenever `vector` is cloned.
/// @param vector the vector in #src
/// @param before a pointer to the object in #src
/// @param object a pointer to the object in #dst that will be inserted before
/// any occurrence of the clone of `before`
/// @returns this CloneContext so calls can be chained
template <typename BEFORE, typename OBJECT>
CloneContext& InsertBefore(BEFORE* before, OBJECT* object) {
auto& list = insert_before_[before];
template <typename T, typename BEFORE, typename OBJECT>
CloneContext& InsertBefore(const std::vector<T>& vector,
BEFORE* before,
OBJECT* object) {
if (std::find(vector.begin(), vector.end(), before) == vector.end()) {
TINT_ICE(Diagnostics())
<< "CloneContext::InsertBefore() vector does not contain before";
return *this;
}
auto& transforms = list_transforms_[&vector];
auto& list = transforms.insert_before_[before];
list.emplace_back(object);
return *this;
}
/// Inserts `object` after `after` whenever `vector` is cloned.
/// @param vector the vector in #src
/// @param after a pointer to the object in #src
/// @param object a pointer to the object in #dst that will be inserted after
/// any occurrence of the clone of `after`
/// @returns this CloneContext so calls can be chained
template <typename T, typename AFTER, typename OBJECT>
CloneContext& InsertAfter(const std::vector<T>& vector,
AFTER* after,
OBJECT* object) {
if (std::find(vector.begin(), vector.end(), after) == vector.end()) {
TINT_ICE(Diagnostics())
<< "CloneContext::InsertAfter() vector does not contain after";
return *this;
}
auto& transforms = list_transforms_[&vector];
auto& list = transforms.insert_after_[after];
list.emplace_back(object);
return *this;
}
@ -380,17 +427,33 @@ class CloneContext {
/// A vector of Cloneable*
using CloneableList = std::vector<Cloneable*>;
// Transformations to be applied to a list (vector)
struct ListTransforms {
/// Constructor
ListTransforms();
/// Destructor
~ListTransforms();
/// A map of object in #src to the list of cloned objects in #dst.
/// Clone(const std::vector<T*>& v) will use this to insert the map-value
/// list into the target vector before cloning and inserting the map-key.
std::unordered_map<const Cloneable*, CloneableList> insert_before_;
/// A map of object in #src to the list of cloned objects in #dst.
/// Clone(const std::vector<T*>& v) will use this to insert the map-value
/// list into the target vector after cloning and inserting the map-key.
std::unordered_map<const Cloneable*, CloneableList> insert_after_;
};
/// A map of object in #src to their cloned equivalent in #dst
std::unordered_map<const Cloneable*, Cloneable*> cloned_;
/// A map of object in #src to the list of cloned objects in #dst.
/// Clone(const std::vector<T*>& v) will use this to insert the map-value list
/// into the target vector/ before cloning and inserting the map-key.
std::unordered_map<const Cloneable*, CloneableList> insert_before_;
/// Cloneable transform functions registered with ReplaceAll()
std::vector<CloneableTransform> transforms_;
/// Map of std::vector pointer to transforms for that list
std::unordered_map<const void*, ListTransforms> list_transforms_;
/// Symbol transform registered with ReplaceAll()
SymbolTransform symbol_transform_;
};

View File

@ -287,9 +287,10 @@ TEST(CloneContext, CloneWithInsertBefore) {
ProgramBuilder cloned;
auto* insertion = cloned.create<Node>(cloned.Symbols().Register("insertion"));
auto* cloned_root = CloneContext(&cloned, &original)
.InsertBefore(original_root->b, insertion)
.Clone(original_root);
auto* cloned_root =
CloneContext(&cloned, &original)
.InsertBefore(original_root->vec, original_root->b, insertion)
.Clone(original_root);
EXPECT_EQ(cloned_root->vec.size(), 4u);
EXPECT_EQ(cloned_root->vec[0], cloned_root->a);
@ -303,6 +304,36 @@ TEST(CloneContext, CloneWithInsertBefore) {
EXPECT_EQ(cloned_root->vec[3]->name, cloned.Symbols().Get("c"));
}
TEST(CloneContext, CloneWithInsertAfter) {
ProgramBuilder builder;
auto* original_root =
builder.create<Node>(builder.Symbols().Register("root"));
original_root->a = builder.create<Node>(builder.Symbols().Register("a"));
original_root->b = builder.create<Node>(builder.Symbols().Register("b"));
original_root->c = builder.create<Node>(builder.Symbols().Register("c"));
original_root->vec = {original_root->a, original_root->b, original_root->c};
Program original(std::move(builder));
ProgramBuilder cloned;
auto* insertion = cloned.create<Node>(cloned.Symbols().Register("insertion"));
auto* cloned_root =
CloneContext(&cloned, &original)
.InsertAfter(original_root->vec, original_root->b, insertion)
.Clone(original_root);
EXPECT_EQ(cloned_root->vec.size(), 4u);
EXPECT_EQ(cloned_root->vec[0], cloned_root->a);
EXPECT_EQ(cloned_root->vec[1], cloned_root->b);
EXPECT_EQ(cloned_root->vec[3], cloned_root->c);
EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root"));
EXPECT_EQ(cloned_root->vec[0]->name, cloned.Symbols().Get("a"));
EXPECT_EQ(cloned_root->vec[1]->name, cloned.Symbols().Get("b"));
EXPECT_EQ(cloned_root->vec[2]->name, cloned.Symbols().Get("insertion"));
EXPECT_EQ(cloned_root->vec[3]->name, cloned.Symbols().Get("c"));
}
TEST(CloneContext, CloneWithReplaceAll_SameTypeTwice) {
EXPECT_FATAL_FAILURE(
{

View File

@ -18,6 +18,7 @@
#include "src/program_builder.h"
#include "src/semantic/function.h"
#include "src/semantic/statement.h"
#include "src/semantic/variable.h"
namespace tint {
@ -119,7 +120,7 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
// Initialize it with the value extracted from the new struct parameter.
auto* func_const = ctx.dst->Const(
func_const_symbol, ctx.Clone(param_ty), func_const_initializer);
ctx.InsertBefore(*func->body()->begin(),
ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
ctx.dst->WrapInStatement(func_const));
// Replace all uses of the function parameter with the function const.
@ -134,7 +135,7 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
ctx.dst->Symbols().New(),
ctx.dst->create<ast::Struct>(new_struct_members,
ast::DecorationList{}));
ctx.InsertBefore(func, in_struct);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, in_struct);
// Create a new function parameter using this struct type.
auto* struct_param = ctx.dst->Var(new_struct_param_symbol, in_struct,
@ -177,12 +178,13 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
ctx.dst->Symbols().New(),
ctx.dst->create<ast::Struct>(new_struct_members,
ast::DecorationList{}));
ctx.InsertBefore(func, out_struct);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, out_struct);
new_ret_type = out_struct;
// Replace all return statements.
auto* sem_func = ctx.src->Sem().Get(func);
for (auto* ret : sem_func->ReturnStatements()) {
auto* ret_sem = ctx.src->Sem().Get(ret);
// Reconstruct the return value using the newly created struct.
auto* new_ret_value = ctx.Clone(ret->value());
ast::ExpressionList ret_values;
@ -193,7 +195,7 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in,
auto temp = ctx.dst->Symbols().New();
auto* temp_var = ctx.dst->Decl(
ctx.dst->Const(temp, ctx.Clone(ret_type), new_ret_value));
ctx.InsertBefore(ret, temp_var);
ctx.InsertBefore(ret_sem->Block()->statements(), ret, temp_var);
new_ret_value = ctx.dst->Expr(temp);
}

View File

@ -96,7 +96,8 @@ void Hlsl::PromoteArrayInitializerToConstVar(CloneContext& ctx) const {
auto* dst_ident = ctx.dst->Expr(dst_symbol);
// Insert the constant before the usage
ctx.InsertBefore(src_stmt, dst_var_decl);
ctx.InsertBefore(src_sem_stmt->Block()->statements(), src_stmt,
dst_var_decl);
// Replace the inlined array with a reference to the constant
ctx.Replace(src_init, dst_ident);
}

View File

@ -21,6 +21,7 @@
#include "src/ast/return_statement.h"
#include "src/program_builder.h"
#include "src/semantic/function.h"
#include "src/semantic/statement.h"
#include "src/semantic/variable.h"
namespace tint {
@ -162,13 +163,15 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const {
return_func_symbol, ast::VariableList{store_value},
ctx.dst->ty.void_(), ctx.dst->create<ast::BlockStatement>(stores),
ast::DecorationList{}, ast::DecorationList{});
ctx.InsertBefore(func, return_func);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, return_func);
// Replace all return statements with calls to the output function.
auto* sem_func = ctx.src->Sem().Get(func);
for (auto* ret : sem_func->ReturnStatements()) {
auto* ret_sem = ctx.src->Sem().Get(ret);
auto* call = ctx.dst->Call(return_func_symbol, ctx.Clone(ret->value()));
ctx.InsertBefore(ret, ctx.dst->create<ast::CallStatement>(call));
ctx.InsertBefore(ret_sem->Block()->statements(), ret,
ctx.dst->create<ast::CallStatement>(call));
ctx.Replace(ret, ctx.dst->create<ast::ReturnStatement>());
}
}
@ -247,7 +250,7 @@ Symbol Spirv::HoistToInputVariables(
auto* global_var =
ctx.dst->Var(global_var_symbol, ctx.Clone(ty),
ast::StorageClass::kInput, nullptr, new_decorations);
ctx.InsertBefore(func, global_var);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, global_var);
return global_var_symbol;
}
@ -269,7 +272,8 @@ Symbol Spirv::HoistToInputVariables(
// Create a function-scope variable for the struct.
auto* initializer = ctx.dst->Construct(ctx.Clone(ty), init_values);
auto* func_var = ctx.dst->Const(func_var_symbol, ctx.Clone(ty), initializer);
ctx.InsertBefore(*func->body()->begin(), ctx.dst->WrapInStatement(func_var));
ctx.InsertBefore(func->body()->statements(), *func->body()->begin(),
ctx.dst->WrapInStatement(func_var));
return func_var_symbol;
}
@ -292,7 +296,7 @@ void Spirv::HoistToOutputVariables(CloneContext& ctx,
auto* global_var =
ctx.dst->Var(global_var_symbol, ctx.Clone(ty),
ast::StorageClass::kOutput, nullptr, new_decorations);
ctx.InsertBefore(func, global_var);
ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, global_var);
// Create the assignment instruction.
ast::Expression* rhs = ctx.dst->Expr(store_value);