CloneContext: Add an overload of Replace() that takes a function

Replace(T* what, T* with) is bug-prone, as more complex transforms may want to clone `what` multiple times, or not at all. In both cases, this will likely result in an ICE as either the replacement will be reachable multiple times, or not at all.

This is the cause of some of the CTS failures reported in crbug.com/tint/993.

Bug: tint:993
Change-Id: I880ece45faab0e7f07230a1b4436f4e9846edc84
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58221
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2021-07-15 19:10:15 +00:00 committed by Tint LUCI CQ
parent 60dae2490d
commit f4e213cf2d
2 changed files with 97 additions and 8 deletions

View File

@ -109,7 +109,9 @@ class CloneContext {
// Was Replace() called for this object?
auto it = replacements_.find(a);
if (it != replacements_.end()) {
return CheckedCast<T>(it->second);
auto* replacement = it->second();
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, replacement);
return CheckedCast<T>(replacement);
}
Cloneable* cloned = nullptr;
@ -342,8 +344,10 @@ class CloneContext {
return *this;
}
/// Replace replaces all occurrences of `what` in #src with `with` in #dst
/// when calling Clone().
/// Replace replaces all occurrences of `what` in #src with the pointer `with`
/// in #dst when calling Clone().
/// [DEPRECATED]: This function cannot handle nested replacements. Use the
/// overload of Replace() that take a function for the `WITH` argument.
/// @param what a pointer to the object in #src that will be replaced with
/// `with`
/// @param with a pointer to the replacement object owned by #dst that will be
@ -352,10 +356,32 @@ class CloneContext {
/// references of the original object. A type mismatch will result in an
/// assertion in debug builds, and undefined behavior in release builds.
/// @returns this CloneContext so calls can be chained
template <typename WHAT, typename WITH>
template <typename WHAT,
typename WITH,
typename = traits::EnableIfIsType<WITH, Cloneable>>
CloneContext& Replace(WHAT* what, WITH* with) {
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, what);
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, dst, with);
replacements_[what] = [with]() -> Cloneable* { return with; };
return *this;
}
/// Replace replaces all occurrences of `what` in #src with the result of the
/// function `with` in #dst when calling Clone(). `with` will be called each
/// time `what` is cloned by this context. If `what` is not cloned, then
/// `with` may never be called.
/// @param what a pointer to the object in #src that will be replaced with
/// `with`
/// @param with a function that takes no arguments and returns a pointer to
/// the replacement object owned by #dst. The returned pointer will be used as
/// a replacement for `what`.
/// @warning The replacement object must be of the correct type for all
/// references of the original object. A type mismatch will result in an
/// assertion in debug builds, and undefined behavior in release builds.
/// @returns this CloneContext so calls can be chained
template <typename WHAT, typename WITH, typename = std::result_of_t<WITH()>>
CloneContext& Replace(WHAT* what, WITH&& with) {
TINT_ASSERT_PROGRAM_IDS_EQUAL_IF_VALID(Clone, src, what);
replacements_[what] = with;
return *this;
}
@ -532,8 +558,10 @@ class CloneContext {
std::unordered_map<const Cloneable*, CloneableList> insert_after_;
};
/// A map of object in #src to their replacement in #dst
std::unordered_map<const Cloneable*, Cloneable*> replacements_;
/// A map of object in #src to functions that create their replacement in
/// #dst
std::unordered_map<const Cloneable*, std::function<Cloneable*()>>
replacements_;
/// A map of symbol in #src to their cloned equivalent in #dst
std::unordered_map<Symbol, Symbol> cloned_symbols_;

View File

@ -291,7 +291,7 @@ TEST_F(CloneContextNodeTest, CloneWithoutTransform) {
EXPECT_EQ(cloned_node->name, cloned.Symbols().Get("root"));
}
TEST_F(CloneContextNodeTest, CloneWithReplace) {
TEST_F(CloneContextNodeTest, CloneWithReplacePointer) {
Allocator a;
ProgramBuilder builder;
@ -323,6 +323,39 @@ TEST_F(CloneContextNodeTest, CloneWithReplace) {
EXPECT_EQ(cloned_root->c->name, cloned.Symbols().Get("c"));
}
TEST_F(CloneContextNodeTest, CloneWithReplaceFunction) {
Allocator a;
ProgramBuilder builder;
auto* original_root = a.Create<Node>(builder.Symbols().New("root"));
original_root->a = a.Create<Node>(builder.Symbols().New("a"));
original_root->b = a.Create<Node>(builder.Symbols().New("b"));
original_root->c = a.Create<Node>(builder.Symbols().New("c"));
Program original(std::move(builder));
// root
// ╭──────────────────┼──────────────────╮
// (a) (b) (c)
// Replaced
ProgramBuilder cloned;
auto* replacement = a.Create<Node>(cloned.Symbols().New("replacement"));
auto* cloned_root =
CloneContext(&cloned, &original)
.Replace(original_root->b, [=] { return replacement; })
.Clone(original_root);
EXPECT_NE(cloned_root->a, replacement);
EXPECT_EQ(cloned_root->b, replacement);
EXPECT_NE(cloned_root->c, replacement);
EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root"));
EXPECT_EQ(cloned_root->a->name, cloned.Symbols().Get("a"));
EXPECT_EQ(cloned_root->b->name, cloned.Symbols().Get("replacement"));
EXPECT_EQ(cloned_root->c->name, cloned.Symbols().Get("c"));
}
TEST_F(CloneContextNodeTest, CloneWithRemove) {
Allocator a;
@ -638,7 +671,7 @@ TEST_F(CloneContextNodeTest, CloneWithReplaceAll_DerivedThenBase) {
replaceable_name);
}
TEST_F(CloneContextNodeTest, CloneWithReplace_WithNotANode) {
TEST_F(CloneContextNodeTest, CloneWithReplacePointer_WithNotANode) {
EXPECT_FATAL_FAILURE(
{
Allocator allocator;
@ -666,6 +699,34 @@ TEST_F(CloneContextNodeTest, CloneWithReplace_WithNotANode) {
"internal compiler error");
}
TEST_F(CloneContextNodeTest, CloneWithReplaceFunction_WithNotANode) {
EXPECT_FATAL_FAILURE(
{
Allocator allocator;
ProgramBuilder builder;
auto* original_root =
allocator.Create<Node>(builder.Symbols().New("root"));
original_root->a = allocator.Create<Node>(builder.Symbols().New("a"));
original_root->b = allocator.Create<Node>(builder.Symbols().New("b"));
original_root->c = allocator.Create<Node>(builder.Symbols().New("c"));
Program original(std::move(builder));
// root
// ╭──────────────────┼──────────────────╮
// (a) (b) (c)
// Replaced
ProgramBuilder cloned;
auto* replacement = allocator.Create<NotANode>();
CloneContext ctx(&cloned, &original);
ctx.Replace(original_root->b, [=] { return replacement; });
ctx.Clone(original_root);
},
"internal compiler error");
}
using CloneContextTest = ::testing::Test;
TEST_F(CloneContextTest, CloneWithReplaceAll_SymbolsTwice) {