diff --git a/src/ast/module.cc b/src/ast/module.cc index e85c033224..12c8ebaf6e 100644 --- a/src/ast/module.cc +++ b/src/ast/module.cc @@ -69,6 +69,7 @@ void Module::AddGlobalVariable(ast::Variable* var) { void Module::AddTypeDecl(ast::TypeDecl* type) { TINT_ASSERT(type); + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(type, program_id()); type_decls_.push_back(type); global_declarations_.push_back(type); } @@ -87,17 +88,21 @@ Module* Module::Clone(CloneContext* ctx) const { } void Module::Copy(CloneContext* ctx, const Module* src) { - for (auto* decl : ctx->Clone(src->global_declarations_)) { + ctx->Clone(global_declarations_, src->global_declarations_); + for (auto* decl : global_declarations_) { if (!decl) { TINT_ICE(ctx->dst->Diagnostics()) << "src global declaration was nullptr"; continue; } - if (auto* ty = decl->As()) { - AddTypeDecl(ty); + if (auto* type = decl->As()) { + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(type, program_id()); + type_decls_.push_back(type); } else if (auto* func = decl->As()) { - AddFunction(func); + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(func, program_id()); + functions_.push_back(func); } else if (auto* var = decl->As()) { - AddGlobalVariable(var); + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(var, program_id()); + global_variables_.push_back(var); } else { TINT_ICE(ctx->dst->Diagnostics()) << "Unknown global declaration type"; } diff --git a/src/ast/module_test.cc b/src/ast/module_test.cc index 28bf5487f1..510844a5fd 100644 --- a/src/ast/module_test.cc +++ b/src/ast/module_test.cc @@ -14,6 +14,7 @@ #include "gtest/gtest-spi.h" #include "src/ast/test_helper.h" +#include "src/clone_context.h" namespace tint { namespace ast { @@ -96,6 +97,53 @@ TEST_F(ModuleTest, Assert_Null_Function) { "internal compiler error"); } +TEST_F(ModuleTest, CloneOrder) { + // Create a program with a function, alias decl and var decl. + Program p = [] { + ProgramBuilder b; + b.Func("F", {}, b.ty.void_(), {}); + b.Alias("A", b.ty.u32()); + b.Global("V", b.ty.i32(), ast::StorageClass::kPrivate); + return Program(std::move(b)); + }(); + + // Clone the program, using ReplaceAll() to create new module-scope + // declarations. We want to test that these are added just before the + // declaration that triggered the ReplaceAll(). + ProgramBuilder cloned; + CloneContext ctx(&cloned, &p); + ctx.ReplaceAll([&](ast::Function*) -> ast::Function* { + ctx.dst->Alias("inserted_before_F", cloned.ty.u32()); + return nullptr; + }); + ctx.ReplaceAll([&](ast::Alias*) -> ast::Alias* { + ctx.dst->Alias("inserted_before_A", cloned.ty.u32()); + return nullptr; + }); + ctx.ReplaceAll([&](ast::Variable*) -> ast::Variable* { + ctx.dst->Alias("inserted_before_V", cloned.ty.u32()); + return nullptr; + }); + ctx.Clone(); + + auto& decls = cloned.AST().GlobalDeclarations(); + ASSERT_EQ(decls.size(), 6u); + EXPECT_TRUE(decls[1]->Is()); + EXPECT_TRUE(decls[3]->Is()); + EXPECT_TRUE(decls[5]->Is()); + + ASSERT_TRUE(decls[0]->Is()); + ASSERT_TRUE(decls[2]->Is()); + ASSERT_TRUE(decls[4]->Is()); + + ASSERT_EQ(cloned.Symbols().NameFor(decls[0]->As()->name()), + "inserted_before_F"); + ASSERT_EQ(cloned.Symbols().NameFor(decls[2]->As()->name()), + "inserted_before_A"); + ASSERT_EQ(cloned.Symbols().NameFor(decls[4]->As()->name()), + "inserted_before_V"); +} + } // namespace } // namespace ast } // namespace tint diff --git a/src/clone_context.h b/src/clone_context.h index 0b20d25eeb..78b47f6526 100644 --- a/src/clone_context.h +++ b/src/clone_context.h @@ -210,7 +210,7 @@ class CloneContext { return out; } - /// Clones each of the elements of the vector `v` into the ProgramBuilder + /// Clones each of the elements of the vector `v` using the ProgramBuilder /// #dst, inserting any additional elements into the list that were registered /// with calls to InsertBefore(). /// @@ -221,40 +221,53 @@ class CloneContext { template std::vector Clone(const std::vector& v) { std::vector out; - out.reserve(v.size()); + Clone(out, v); + return out; + } - auto list_transform_it = list_transforms_.find(&v); + /// Clones each of the elements of the vector `from` into the vector `to`, + /// inserting any additional elements into the list that were registered with + /// calls to InsertBefore(). + /// + /// All the elements of the vector `from` must be owned by the Program #src. + /// + /// @param from the vector to clone + /// @param to the cloned result + template + void Clone(std::vector& to, const std::vector& from) { + to.reserve(from.size()); + + auto list_transform_it = list_transforms_.find(&from); if (list_transform_it != list_transforms_.end()) { const auto& transforms = list_transform_it->second; for (auto* o : transforms.insert_front_) { - out.emplace_back(CheckedCast(o)); + to.emplace_back(CheckedCast(o)); } - for (auto& el : v) { + for (auto& el : from) { auto insert_before_it = transforms.insert_before_.find(el); if (insert_before_it != transforms.insert_before_.end()) { for (auto insert : insert_before_it->second) { - out.emplace_back(CheckedCast(insert)); + to.emplace_back(CheckedCast(insert)); } } if (transforms.remove_.count(el) == 0) { - out.emplace_back(Clone(el)); + to.emplace_back(Clone(el)); } auto insert_after_it = transforms.insert_after_.find(el); if (insert_after_it != transforms.insert_after_.end()) { for (auto insert : insert_after_it->second) { - out.emplace_back(CheckedCast(insert)); + to.emplace_back(CheckedCast(insert)); } } } for (auto* o : transforms.insert_back_) { - out.emplace_back(CheckedCast(o)); + to.emplace_back(CheckedCast(o)); } } else { - for (auto& el : v) { - out.emplace_back(Clone(el)); + for (auto& el : from) { + to.emplace_back(Clone(el)); } } - return out; } /// Clones each of the elements of the vector `v` into the ProgramBuilder