From f4e213cf2da291596cc30117f9bf29098900748d Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Thu, 15 Jul 2021 19:10:15 +0000 Subject: [PATCH] CloneContext: Add an overload of Replace() that takes a function Replace(T* what, T* with) is bug-prone, as more complex transforms may want to clone `what` multiple times, or not at all. In both cases, this will likely result in an ICE as either the replacement will be reachable multiple times, or not at all. This is the cause of some of the CTS failures reported in crbug.com/tint/993. Bug: tint:993 Change-Id: I880ece45faab0e7f07230a1b4436f4e9846edc84 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58221 Kokoro: Kokoro Reviewed-by: James Price Commit-Queue: Ben Clayton Commit-Queue: Ben Clayton --- src/clone_context.h | 40 ++++++++++++++++++++---- src/clone_context_test.cc | 65 +++++++++++++++++++++++++++++++++++++-- 2 files changed, 97 insertions(+), 8 deletions(-) diff --git a/src/clone_context.h b/src/clone_context.h index 6c89ef7fe3..8be0b9f196 100644 --- a/src/clone_context.h +++ b/src/clone_context.h @@ -109,7 +109,9 @@ class CloneContext { // Was Replace() called for this object? auto it = replacements_.find(a); if (it != replacements_.end()) { - return CheckedCast(it->second); + auto* replacement = it->second(); + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, replacement); + return CheckedCast(replacement); } Cloneable* cloned = nullptr; @@ -342,8 +344,10 @@ class CloneContext { return *this; } - /// Replace replaces all occurrences of `what` in #src with `with` in #dst - /// when calling Clone(). + /// Replace replaces all occurrences of `what` in #src with the pointer `with` + /// in #dst when calling Clone(). + /// [DEPRECATED]: This function cannot handle nested replacements. Use the + /// overload of Replace() that take a function for the `WITH` argument. /// @param what a pointer to the object in #src that will be replaced with /// `with` /// @param with a pointer to the replacement object owned by #dst that will be @@ -352,10 +356,32 @@ class CloneContext { /// references of the original object. A type mismatch will result in an /// assertion in debug builds, and undefined behavior in release builds. /// @returns this CloneContext so calls can be chained - template + template > CloneContext& Replace(WHAT* what, WITH* with) { TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, what); TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, with); + replacements_[what] = [with]() -> Cloneable* { return with; }; + return *this; + } + + /// Replace replaces all occurrences of `what` in #src with the result of the + /// function `with` in #dst when calling Clone(). `with` will be called each + /// time `what` is cloned by this context. If `what` is not cloned, then + /// `with` may never be called. + /// @param what a pointer to the object in #src that will be replaced with + /// `with` + /// @param with a function that takes no arguments and returns a pointer to + /// the replacement object owned by #dst. The returned pointer will be used as + /// a replacement for `what`. + /// @warning The replacement object must be of the correct type for all + /// references of the original object. A type mismatch will result in an + /// assertion in debug builds, and undefined behavior in release builds. + /// @returns this CloneContext so calls can be chained + template > + CloneContext& Replace(WHAT* what, WITH&& with) { + TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, what); replacements_[what] = with; return *this; } @@ -532,8 +558,10 @@ class CloneContext { std::unordered_map insert_after_; }; - /// A map of object in #src to their replacement in #dst - std::unordered_map replacements_; + /// A map of object in #src to functions that create their replacement in + /// #dst + std::unordered_map> + replacements_; /// A map of symbol in #src to their cloned equivalent in #dst std::unordered_map cloned_symbols_; diff --git a/src/clone_context_test.cc b/src/clone_context_test.cc index a9865a2f38..391c7c6026 100644 --- a/src/clone_context_test.cc +++ b/src/clone_context_test.cc @@ -291,7 +291,7 @@ TEST_F(CloneContextNodeTest, CloneWithoutTransform) { EXPECT_EQ(cloned_node->name, cloned.Symbols().Get("root")); } -TEST_F(CloneContextNodeTest, CloneWithReplace) { +TEST_F(CloneContextNodeTest, CloneWithReplacePointer) { Allocator a; ProgramBuilder builder; @@ -323,6 +323,39 @@ TEST_F(CloneContextNodeTest, CloneWithReplace) { EXPECT_EQ(cloned_root->c->name, cloned.Symbols().Get("c")); } +TEST_F(CloneContextNodeTest, CloneWithReplaceFunction) { + Allocator a; + + ProgramBuilder builder; + auto* original_root = a.Create(builder.Symbols().New("root")); + original_root->a = a.Create(builder.Symbols().New("a")); + original_root->b = a.Create(builder.Symbols().New("b")); + original_root->c = a.Create(builder.Symbols().New("c")); + Program original(std::move(builder)); + + // root + // ╭──────────────────┼──────────────────╮ + // (a) (b) (c) + // Replaced + + ProgramBuilder cloned; + auto* replacement = a.Create(cloned.Symbols().New("replacement")); + + auto* cloned_root = + CloneContext(&cloned, &original) + .Replace(original_root->b, [=] { return replacement; }) + .Clone(original_root); + + EXPECT_NE(cloned_root->a, replacement); + EXPECT_EQ(cloned_root->b, replacement); + EXPECT_NE(cloned_root->c, replacement); + + EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root")); + EXPECT_EQ(cloned_root->a->name, cloned.Symbols().Get("a")); + EXPECT_EQ(cloned_root->b->name, cloned.Symbols().Get("replacement")); + EXPECT_EQ(cloned_root->c->name, cloned.Symbols().Get("c")); +} + TEST_F(CloneContextNodeTest, CloneWithRemove) { Allocator a; @@ -638,7 +671,7 @@ TEST_F(CloneContextNodeTest, CloneWithReplaceAll_DerivedThenBase) { replaceable_name); } -TEST_F(CloneContextNodeTest, CloneWithReplace_WithNotANode) { +TEST_F(CloneContextNodeTest, CloneWithReplacePointer_WithNotANode) { EXPECT_FATAL_FAILURE( { Allocator allocator; @@ -666,6 +699,34 @@ TEST_F(CloneContextNodeTest, CloneWithReplace_WithNotANode) { "internal compiler error"); } +TEST_F(CloneContextNodeTest, CloneWithReplaceFunction_WithNotANode) { + EXPECT_FATAL_FAILURE( + { + Allocator allocator; + ProgramBuilder builder; + auto* original_root = + allocator.Create(builder.Symbols().New("root")); + original_root->a = allocator.Create(builder.Symbols().New("a")); + original_root->b = allocator.Create(builder.Symbols().New("b")); + original_root->c = allocator.Create(builder.Symbols().New("c")); + Program original(std::move(builder)); + + // root + // ╭──────────────────┼──────────────────╮ + // (a) (b) (c) + // Replaced + + ProgramBuilder cloned; + auto* replacement = allocator.Create(); + + CloneContext ctx(&cloned, &original); + ctx.Replace(original_root->b, [=] { return replacement; }); + + ctx.Clone(original_root); + }, + "internal compiler error"); +} + using CloneContextTest = ::testing::Test; TEST_F(CloneContextTest, CloneWithReplaceAll_SymbolsTwice) {