diff --git a/src/clone_context.cc b/src/clone_context.cc index 06a28e1502..f449b0b95e 100644 --- a/src/clone_context.cc +++ b/src/clone_context.cc @@ -20,6 +20,9 @@ TINT_INSTANTIATE_TYPEINFO(tint::Cloneable); namespace tint { +CloneContext::ListTransforms::ListTransforms() = default; +CloneContext::ListTransforms::~ListTransforms() = default; + CloneContext::CloneContext(ProgramBuilder* to, Program const* from) : dst(to), src(from) {} CloneContext::~CloneContext() = default; diff --git a/src/clone_context.h b/src/clone_context.h index 2853ad4fb3..fd7a0acb4f 100644 --- a/src/clone_context.h +++ b/src/clone_context.h @@ -15,6 +15,7 @@ #ifndef SRC_CLONE_CONTEXT_H_ #define SRC_CLONE_CONTEXT_H_ +#include #include #include #include @@ -178,14 +179,29 @@ class CloneContext { std::vector Clone(const std::vector& v) { std::vector out; out.reserve(v.size()); - for (auto& el : v) { - auto it = insert_before_.find(el); - if (it != insert_before_.end()) { - for (auto insert : it->second) { - out.emplace_back(CheckedCast(insert)); + + auto list_transform_it = list_transforms_.find(&v); + if (list_transform_it != list_transforms_.end()) { + const auto& transforms = list_transform_it->second; + for (auto& el : v) { + auto insert_before_it = transforms.insert_before_.find(el); + if (insert_before_it != transforms.insert_before_.end()) { + for (auto insert : insert_before_it->second) { + out.emplace_back(CheckedCast(insert)); + } + } + out.emplace_back(Clone(el)); + auto insert_after_it = transforms.insert_after_.find(el); + if (insert_after_it != transforms.insert_after_.end()) { + for (auto insert : insert_after_it->second) { + out.emplace_back(CheckedCast(insert)); + } } } - out.emplace_back(Clone(el)); + } else { + for (auto& el : v) { + out.emplace_back(Clone(el)); + } } return out; } @@ -293,15 +309,46 @@ class CloneContext { return *this; } - /// Inserts `object` before `before` whenever a vector containing `object` is - /// cloned. + /// Inserts `object` before `before` whenever `vector` is cloned. + /// @param vector the vector in #src /// @param before a pointer to the object in #src /// @param object a pointer to the object in #dst that will be inserted before /// any occurrence of the clone of `before` /// @returns this CloneContext so calls can be chained - template - CloneContext& InsertBefore(BEFORE* before, OBJECT* object) { - auto& list = insert_before_[before]; + template + CloneContext& InsertBefore(const std::vector& vector, + BEFORE* before, + OBJECT* object) { + if (std::find(vector.begin(), vector.end(), before) == vector.end()) { + TINT_ICE(Diagnostics()) + << "CloneContext::InsertBefore() vector does not contain before"; + return *this; + } + + auto& transforms = list_transforms_[&vector]; + auto& list = transforms.insert_before_[before]; + list.emplace_back(object); + return *this; + } + + /// Inserts `object` after `after` whenever `vector` is cloned. + /// @param vector the vector in #src + /// @param after a pointer to the object in #src + /// @param object a pointer to the object in #dst that will be inserted after + /// any occurrence of the clone of `after` + /// @returns this CloneContext so calls can be chained + template + CloneContext& InsertAfter(const std::vector& vector, + AFTER* after, + OBJECT* object) { + if (std::find(vector.begin(), vector.end(), after) == vector.end()) { + TINT_ICE(Diagnostics()) + << "CloneContext::InsertAfter() vector does not contain after"; + return *this; + } + + auto& transforms = list_transforms_[&vector]; + auto& list = transforms.insert_after_[after]; list.emplace_back(object); return *this; } @@ -380,17 +427,33 @@ class CloneContext { /// A vector of Cloneable* using CloneableList = std::vector; + // Transformations to be applied to a list (vector) + struct ListTransforms { + /// Constructor + ListTransforms(); + /// Destructor + ~ListTransforms(); + + /// A map of object in #src to the list of cloned objects in #dst. + /// Clone(const std::vector& v) will use this to insert the map-value + /// list into the target vector before cloning and inserting the map-key. + std::unordered_map insert_before_; + + /// A map of object in #src to the list of cloned objects in #dst. + /// Clone(const std::vector& v) will use this to insert the map-value + /// list into the target vector after cloning and inserting the map-key. + std::unordered_map insert_after_; + }; + /// A map of object in #src to their cloned equivalent in #dst std::unordered_map cloned_; - /// A map of object in #src to the list of cloned objects in #dst. - /// Clone(const std::vector& v) will use this to insert the map-value list - /// into the target vector/ before cloning and inserting the map-key. - std::unordered_map insert_before_; - /// Cloneable transform functions registered with ReplaceAll() std::vector transforms_; + /// Map of std::vector pointer to transforms for that list + std::unordered_map list_transforms_; + /// Symbol transform registered with ReplaceAll() SymbolTransform symbol_transform_; }; diff --git a/src/clone_context_test.cc b/src/clone_context_test.cc index 372da300a0..b09ccdcc56 100644 --- a/src/clone_context_test.cc +++ b/src/clone_context_test.cc @@ -287,9 +287,10 @@ TEST(CloneContext, CloneWithInsertBefore) { ProgramBuilder cloned; auto* insertion = cloned.create(cloned.Symbols().Register("insertion")); - auto* cloned_root = CloneContext(&cloned, &original) - .InsertBefore(original_root->b, insertion) - .Clone(original_root); + auto* cloned_root = + CloneContext(&cloned, &original) + .InsertBefore(original_root->vec, original_root->b, insertion) + .Clone(original_root); EXPECT_EQ(cloned_root->vec.size(), 4u); EXPECT_EQ(cloned_root->vec[0], cloned_root->a); @@ -303,6 +304,36 @@ TEST(CloneContext, CloneWithInsertBefore) { EXPECT_EQ(cloned_root->vec[3]->name, cloned.Symbols().Get("c")); } +TEST(CloneContext, CloneWithInsertAfter) { + ProgramBuilder builder; + auto* original_root = + builder.create(builder.Symbols().Register("root")); + original_root->a = builder.create(builder.Symbols().Register("a")); + original_root->b = builder.create(builder.Symbols().Register("b")); + original_root->c = builder.create(builder.Symbols().Register("c")); + original_root->vec = {original_root->a, original_root->b, original_root->c}; + Program original(std::move(builder)); + + ProgramBuilder cloned; + auto* insertion = cloned.create(cloned.Symbols().Register("insertion")); + + auto* cloned_root = + CloneContext(&cloned, &original) + .InsertAfter(original_root->vec, original_root->b, insertion) + .Clone(original_root); + + EXPECT_EQ(cloned_root->vec.size(), 4u); + EXPECT_EQ(cloned_root->vec[0], cloned_root->a); + EXPECT_EQ(cloned_root->vec[1], cloned_root->b); + EXPECT_EQ(cloned_root->vec[3], cloned_root->c); + + EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root")); + EXPECT_EQ(cloned_root->vec[0]->name, cloned.Symbols().Get("a")); + EXPECT_EQ(cloned_root->vec[1]->name, cloned.Symbols().Get("b")); + EXPECT_EQ(cloned_root->vec[2]->name, cloned.Symbols().Get("insertion")); + EXPECT_EQ(cloned_root->vec[3]->name, cloned.Symbols().Get("c")); +} + TEST(CloneContext, CloneWithReplaceAll_SameTypeTwice) { EXPECT_FATAL_FAILURE( { diff --git a/src/transform/canonicalize_entry_point_io.cc b/src/transform/canonicalize_entry_point_io.cc index ca4af82b97..1c52066ce1 100644 --- a/src/transform/canonicalize_entry_point_io.cc +++ b/src/transform/canonicalize_entry_point_io.cc @@ -18,6 +18,7 @@ #include "src/program_builder.h" #include "src/semantic/function.h" +#include "src/semantic/statement.h" #include "src/semantic/variable.h" namespace tint { @@ -119,7 +120,7 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in, // Initialize it with the value extracted from the new struct parameter. auto* func_const = ctx.dst->Const( func_const_symbol, ctx.Clone(param_ty), func_const_initializer); - ctx.InsertBefore(*func->body()->begin(), + ctx.InsertBefore(func->body()->statements(), *func->body()->begin(), ctx.dst->WrapInStatement(func_const)); // Replace all uses of the function parameter with the function const. @@ -134,7 +135,7 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in, ctx.dst->Symbols().New(), ctx.dst->create(new_struct_members, ast::DecorationList{})); - ctx.InsertBefore(func, in_struct); + ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, in_struct); // Create a new function parameter using this struct type. auto* struct_param = ctx.dst->Var(new_struct_param_symbol, in_struct, @@ -177,12 +178,13 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in, ctx.dst->Symbols().New(), ctx.dst->create(new_struct_members, ast::DecorationList{})); - ctx.InsertBefore(func, out_struct); + ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, out_struct); new_ret_type = out_struct; // Replace all return statements. auto* sem_func = ctx.src->Sem().Get(func); for (auto* ret : sem_func->ReturnStatements()) { + auto* ret_sem = ctx.src->Sem().Get(ret); // Reconstruct the return value using the newly created struct. auto* new_ret_value = ctx.Clone(ret->value()); ast::ExpressionList ret_values; @@ -193,7 +195,7 @@ Transform::Output CanonicalizeEntryPointIO::Run(const Program* in, auto temp = ctx.dst->Symbols().New(); auto* temp_var = ctx.dst->Decl( ctx.dst->Const(temp, ctx.Clone(ret_type), new_ret_value)); - ctx.InsertBefore(ret, temp_var); + ctx.InsertBefore(ret_sem->Block()->statements(), ret, temp_var); new_ret_value = ctx.dst->Expr(temp); } diff --git a/src/transform/hlsl.cc b/src/transform/hlsl.cc index 21b9bf7c94..0a08e24a30 100644 --- a/src/transform/hlsl.cc +++ b/src/transform/hlsl.cc @@ -96,7 +96,8 @@ void Hlsl::PromoteArrayInitializerToConstVar(CloneContext& ctx) const { auto* dst_ident = ctx.dst->Expr(dst_symbol); // Insert the constant before the usage - ctx.InsertBefore(src_stmt, dst_var_decl); + ctx.InsertBefore(src_sem_stmt->Block()->statements(), src_stmt, + dst_var_decl); // Replace the inlined array with a reference to the constant ctx.Replace(src_init, dst_ident); } diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc index 1eb7a6aadc..d29cfa4460 100644 --- a/src/transform/spirv.cc +++ b/src/transform/spirv.cc @@ -21,6 +21,7 @@ #include "src/ast/return_statement.h" #include "src/program_builder.h" #include "src/semantic/function.h" +#include "src/semantic/statement.h" #include "src/semantic/variable.h" namespace tint { @@ -162,13 +163,15 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const { return_func_symbol, ast::VariableList{store_value}, ctx.dst->ty.void_(), ctx.dst->create(stores), ast::DecorationList{}, ast::DecorationList{}); - ctx.InsertBefore(func, return_func); + ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, return_func); // Replace all return statements with calls to the output function. auto* sem_func = ctx.src->Sem().Get(func); for (auto* ret : sem_func->ReturnStatements()) { + auto* ret_sem = ctx.src->Sem().Get(ret); auto* call = ctx.dst->Call(return_func_symbol, ctx.Clone(ret->value())); - ctx.InsertBefore(ret, ctx.dst->create(call)); + ctx.InsertBefore(ret_sem->Block()->statements(), ret, + ctx.dst->create(call)); ctx.Replace(ret, ctx.dst->create()); } } @@ -247,7 +250,7 @@ Symbol Spirv::HoistToInputVariables( auto* global_var = ctx.dst->Var(global_var_symbol, ctx.Clone(ty), ast::StorageClass::kInput, nullptr, new_decorations); - ctx.InsertBefore(func, global_var); + ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, global_var); return global_var_symbol; } @@ -269,7 +272,8 @@ Symbol Spirv::HoistToInputVariables( // Create a function-scope variable for the struct. auto* initializer = ctx.dst->Construct(ctx.Clone(ty), init_values); auto* func_var = ctx.dst->Const(func_var_symbol, ctx.Clone(ty), initializer); - ctx.InsertBefore(*func->body()->begin(), ctx.dst->WrapInStatement(func_var)); + ctx.InsertBefore(func->body()->statements(), *func->body()->begin(), + ctx.dst->WrapInStatement(func_var)); return func_var_symbol; } @@ -292,7 +296,7 @@ void Spirv::HoistToOutputVariables(CloneContext& ctx, auto* global_var = ctx.dst->Var(global_var_symbol, ctx.Clone(ty), ast::StorageClass::kOutput, nullptr, new_decorations); - ctx.InsertBefore(func, global_var); + ctx.InsertBefore(ctx.src->AST().GlobalDeclarations(), func, global_var); // Create the assignment instruction. ast::Expression* rhs = ctx.dst->Expr(store_value);