diff --git a/src/clone_context.h b/src/clone_context.h index 40017a5e3f..b00e5474b8 100644 --- a/src/clone_context.h +++ b/src/clone_context.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -248,6 +249,9 @@ class CloneContext { if (list_transform_it != list_transforms_.end()) { const auto& transforms = list_transform_it->second; for (auto& el : v) { + if (transforms.remove_.count(el)) { + continue; + } auto insert_before_it = transforms.insert_before_.find(el); if (insert_before_it != transforms.insert_before_.end()) { for (auto insert : insert_before_it->second) { @@ -369,10 +373,30 @@ class CloneContext { /// @returns this CloneContext so calls can be chained template CloneContext& Replace(WHAT* what, WITH* with) { + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(src, what); + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(dst, with); cloned_[what] = with; return *this; } + /// Removes `object` from the cloned copy of `vector`. + /// @param vector the vector in #src + /// @param object a pointer to the object in #src that will be omitted from + /// the cloned vector. + /// @returns this CloneContext so calls can be chained + template + CloneContext& Remove(const std::vector& vector, OBJECT* object) { + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(src, object); + if (std::find(vector.begin(), vector.end(), object) == vector.end()) { + TINT_ICE(Diagnostics()) + << "CloneContext::Remove() vector does not contain object"; + return *this; + } + + list_transforms_[&vector].remove_.emplace(object); + return *this; + } + /// Inserts `object` before `before` whenever `vector` is cloned. /// @param vector the vector in #src /// @param before a pointer to the object in #src @@ -383,6 +407,8 @@ class CloneContext { CloneContext& InsertBefore(const std::vector& vector, const BEFORE* before, OBJECT* object) { + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(src, before); + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(dst, object); if (std::find(vector.begin(), vector.end(), before) == vector.end()) { TINT_ICE(Diagnostics()) << "CloneContext::InsertBefore() vector does not contain before"; @@ -405,6 +431,8 @@ class CloneContext { CloneContext& InsertAfter(const std::vector& vector, const AFTER* after, OBJECT* object) { + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(src, after); + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(dst, object); if (std::find(vector.begin(), vector.end(), after) == vector.end()) { TINT_ICE(Diagnostics()) << "CloneContext::InsertAfter() vector does not contain after"; @@ -453,7 +481,10 @@ class CloneContext { if (TO* cast = As(obj)) { return cast; } - TINT_ICE(Diagnostics()) << "Cloned object was not of the expected type"; + TINT_ICE(Diagnostics()) + << "Cloned object was not of the expected type\n" + << "got: " << (obj ? obj->TypeInfo().name : "") << "\n" + << "expected: " << TypeInfo::Of().name; return nullptr; } @@ -470,6 +501,9 @@ class CloneContext { /// Destructor ~ListTransforms(); + /// A map of object in #src to omit when cloned into #dst. + std::unordered_set remove_; + /// A map of object in #src to the list of cloned objects in #dst. /// Clone(const std::vector& v) will use this to insert the map-value /// list into the target vector before cloning and inserting the map-key. diff --git a/src/clone_context_test.cc b/src/clone_context_test.cc index d5086c57b5..4ec9585e31 100644 --- a/src/clone_context_test.cc +++ b/src/clone_context_test.cc @@ -397,6 +397,39 @@ TYPED_TEST(CloneContextNodeTest, CloneWithReplace) { EXPECT_EQ(cloned_root->c->name, cloned.Symbols().Get("c")); } +TYPED_TEST(CloneContextNodeTest, CloneWithRemove) { + using Node = typename TestFixture::Node; + constexpr bool is_unique = TestFixture::is_unique; + + Allocator a; + + ProgramBuilder builder; + auto* original_root = a.Create(builder.Symbols().Register("root")); + original_root->a = a.Create(builder.Symbols().Register("a")); + original_root->b = a.Create(builder.Symbols().Register("b")); + original_root->c = a.Create(builder.Symbols().Register("c")); + original_root->vec = {original_root->a, original_root->b, original_root->c}; + Program original(std::move(builder)); + + ProgramBuilder cloned; + auto* cloned_root = CloneContext(&cloned, &original) + .Remove(original_root->vec, original_root->b) + .Clone(original_root); + + EXPECT_EQ(cloned_root->vec.size(), 2u); + if (is_unique) { + EXPECT_NE(cloned_root->vec[0], cloned_root->a); + EXPECT_NE(cloned_root->vec[1], cloned_root->c); + } else { + EXPECT_EQ(cloned_root->vec[0], cloned_root->a); + EXPECT_EQ(cloned_root->vec[1], cloned_root->c); + } + + EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root")); + EXPECT_EQ(cloned_root->vec[0]->name, cloned.Symbols().Get("a")); + EXPECT_EQ(cloned_root->vec[1]->name, cloned.Symbols().Get("c")); +} + TYPED_TEST(CloneContextNodeTest, CloneWithInsertBefore) { using Node = typename TestFixture::Node; constexpr bool is_unique = TestFixture::is_unique; @@ -691,7 +724,7 @@ TEST_F(CloneContextTest, ProgramIDs) { EXPECT_EQ(cloned->program_id, dst.ID()); } -TEST_F(CloneContextTest, ProgramIDs_ObjectNotOwnedBySrc) { +TEST_F(CloneContextTest, ProgramIDs_Clone_ObjectNotOwnedBySrc) { EXPECT_FATAL_FAILURE( { ProgramBuilder dst; @@ -703,7 +736,7 @@ TEST_F(CloneContextTest, ProgramIDs_ObjectNotOwnedBySrc) { R"(internal compiler error: TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(src, a))"); } -TEST_F(CloneContextTest, ProgramIDs_ObjectNotOwnedByDst) { +TEST_F(CloneContextTest, ProgramIDs_Clone_ObjectNotOwnedByDst) { EXPECT_FATAL_FAILURE( { ProgramBuilder dst;