CloneContext: Add support for transforming symbols

Will be used by a Renamer transform

Change-Id: Ic0e9b69874f51103f0beec7745d32a9f8419e93a
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/42841
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: James Price <jrprice@google.com>
This commit is contained in:
Ben Clayton 2021-03-04 10:42:55 +00:00 committed by Commit Bot service account
parent ec44eef965
commit 5ca9741b69
3 changed files with 248 additions and 65 deletions

View File

@ -28,6 +28,9 @@ CloneContext::CloneContext(ProgramBuilder* to, Program const* from)
CloneContext::~CloneContext() = default;
Symbol CloneContext::Clone(const Symbol& s) const {
if (symbol_transform_) {
return symbol_transform_(s);
}
return dst->Symbols().Register(src->Symbols().NameFor(s));
}
@ -48,4 +51,9 @@ diag::List& CloneContext::Diagnostics() const {
return dst->Diagnostics();
}
CloneContext::CloneableTransform::CloneableTransform() = default;
CloneContext::CloneableTransform::CloneableTransform(
const CloneableTransform&) = default;
CloneContext::CloneableTransform::~CloneableTransform() = default;
} // namespace tint

View File

@ -18,6 +18,7 @@
#include <cassert>
#include <functional>
#include <unordered_map>
#include <utility>
#include <vector>
#include "src/castable.h"
@ -48,7 +49,18 @@ class Cloneable : public Castable<Cloneable> {
/// CloneContext holds the state used while cloning AST nodes and types.
class CloneContext {
/// ParamTypeIsPtrOf<F, T>::value is true iff the first parameter of
/// F is a pointer of (or derives from) type T.
template <typename F, typename T>
using ParamTypeIsPtrOf = traits::IsTypeOrDerived<
typename std::remove_pointer<traits::ParamTypeT<F, 0>>::type,
T>;
public:
/// SymbolTransform is a function that takes a symbol and returns a new
/// symbol.
using SymbolTransform = std::function<Symbol(Symbol)>;
/// Constructor
/// @param to the target ProgramBuilder to clone into
/// @param from the source Program to clone from
@ -199,10 +211,8 @@ class CloneContext {
/// `replacer` must be function-like with the signature: `T* (T*)`
/// where `T` is a type deriving from Cloneable.
///
/// If `replacer` returns a nullptr then Clone() will attempt the next
/// registered replacer function that matches the object type. If no replacers
/// match the object type, or all returned nullptr then Clone() will call
/// `T::Clone()` to clone the object.
/// If `replacer` returns a nullptr then Clone() will call `T::Clone()` to
/// clone the object.
///
/// Example:
///
@ -218,6 +228,9 @@ class CloneContext {
/// ctx.Clone();
/// ```
///
/// @warning a single handler can only be registered for any given type.
/// Attempting to register two handlers for the same type will result in an
/// ICE.
/// @warning The replacement object must be of the correct type for all
/// references of the original object. A type mismatch will result in an
/// assertion in debug builds, and undefined behavior in release builds.
@ -225,13 +238,44 @@ class CloneContext {
/// `T* (T*)`, where `T` derives from Cloneable
/// @returns this CloneContext so calls can be chained
template <typename F>
CloneContext& ReplaceAll(F&& replacer) {
traits::EnableIf<ParamTypeIsPtrOf<F, Cloneable>::value, CloneContext>&
ReplaceAll(F&& replacer) {
using TPtr = traits::ParamTypeT<F, 0>;
using T = typename std::remove_pointer<TPtr>::type;
transforms_.emplace_back([=](Cloneable* in) {
auto* in_as_t = in->As<T>();
return in_as_t != nullptr ? replacer(in_as_t) : nullptr;
});
for (auto& transform : transforms_) {
if (transform.typeinfo->Is(TypeInfo::Of<T>()) ||
TypeInfo::Of<T>().Is(*transform.typeinfo)) {
TINT_ICE(Diagnostics())
<< "ReplaceAll() called with a handler for type "
<< TypeInfo::Of<T>().name
<< " that is already handled by a handler for type "
<< transform.typeinfo->name;
return *this;
}
}
CloneableTransform transform;
transform.typeinfo = &TypeInfo::Of<T>();
transform.function = [=](Cloneable* in) { return replacer(in->As<T>()); };
transforms_.emplace_back(std::move(transform));
return *this;
}
/// ReplaceAll() registers `replacer` to be called whenever the Clone() method
/// is called with a Symbol.
/// The returned symbol of `replacer` will be used as the replacement for
/// all references to the symbol that's being cloned. This returned Symbol
/// must be owned by the Program #dst.
/// @param replacer a function the signature `Symbol(Symbol)`.
/// @warning a SymbolTransform can only be registered once. Attempting to
/// register a SymbolTransform more than once will result in an ICE.
/// @returns this CloneContext so calls can be chained
CloneContext& ReplaceAll(const SymbolTransform& replacer) {
if (symbol_transform_) {
TINT_ICE(Diagnostics()) << "ReplaceAll(const SymbolTransform&) called "
"multiple times on the same CloneContext";
return *this;
}
symbol_transform_ = replacer;
return *this;
}
@ -276,7 +320,19 @@ class CloneContext {
Program const* const src;
private:
using Transform = std::function<Cloneable*(Cloneable*)>;
struct CloneableTransform {
/// Constructor
CloneableTransform();
/// Copy constructor
/// @param other the CloneableTransform to copy
CloneableTransform(const CloneableTransform& other);
/// Destructor
~CloneableTransform();
// TypeInfo of the Cloneable that the transform operates on
const TypeInfo* typeinfo;
std::function<Cloneable*(Cloneable*)> function;
};
CloneContext(const CloneContext&) = delete;
CloneContext& operator=(const CloneContext&) = delete;
@ -293,11 +349,16 @@ class CloneContext {
}
// Attempt to clone using the registered replacer functions.
for (auto& f : transforms_) {
if (Cloneable* c = f(a)) {
auto& typeinfo = a->TypeInfo();
for (auto& transform : transforms_) {
if (!typeinfo.Is(*transform.typeinfo)) {
continue;
}
if (Cloneable* c = transform.function(a)) {
cloned_.emplace(a, c);
return c;
}
break;
}
// No luck, Clone() will have to call T::Clone().
@ -329,8 +390,11 @@ class CloneContext {
/// into the target vector/ before cloning and inserting the map-key.
std::unordered_map<Cloneable*, CloneableList> insert_before_;
/// Transform functions registered with ReplaceAll()
std::vector<Transform> transforms_;
/// Cloneable transform functions registered with ReplaceAll()
std::vector<CloneableTransform> transforms_;
/// Symbol transform registered with ReplaceAll()
SymbolTransform symbol_transform_;
};
} // namespace tint

View File

@ -26,16 +26,16 @@ namespace tint {
namespace {
struct Node : public Castable<Node, ast::Node> {
explicit Node(const Source& source, std::string n) : Base(source), name(n) {}
explicit Node(const Source& source, Symbol n) : Base(source), name(n) {}
std::string name;
Symbol name;
Node* a = nullptr;
Node* b = nullptr;
Node* c = nullptr;
std::vector<Node*> vec;
Node* Clone(CloneContext* ctx) const override {
auto* out = ctx->dst->create<Node>(name);
auto* out = ctx->dst->create<Node>(ctx->Clone(name));
out->a = ctx->Clone(a);
out->b = ctx->Clone(b);
out->c = ctx->Clone(c);
@ -48,10 +48,10 @@ struct Node : public Castable<Node, ast::Node> {
};
struct Replaceable : public Castable<Replaceable, Node> {
explicit Replaceable(const Source& source, std::string n) : Base(source, n) {}
explicit Replaceable(const Source& source, Symbol n) : Base(source, n) {}
};
struct Replacement : public Castable<Replacement, Replaceable> {
explicit Replacement(const Source& source, std::string n) : Base(source, n) {}
explicit Replacement(const Source& source, Symbol n) : Base(source, n) {}
};
struct NotANode : public Castable<NotANode, ast::Node> {
@ -67,12 +67,15 @@ struct NotANode : public Castable<NotANode, ast::Node> {
TEST(CloneContext, Clone) {
ProgramBuilder builder;
auto* original_root = builder.create<Node>("root");
original_root->a = builder.create<Node>("a");
original_root->a->b = builder.create<Node>("a->b");
original_root->b = builder.create<Node>("b");
auto* original_root =
builder.create<Node>(builder.Symbols().Register("root"));
original_root->a = builder.create<Node>(builder.Symbols().Register("a"));
original_root->a->b =
builder.create<Node>(builder.Symbols().Register("a->b"));
original_root->b = builder.create<Node>(builder.Symbols().Register("b"));
original_root->b->a = original_root->a; // Aliased
original_root->b->b = builder.create<Node>("b->b");
original_root->b->b =
builder.create<Node>(builder.Symbols().Register("b->b"));
original_root->c = original_root->b; // Aliased
Program original(std::move(builder));
@ -106,22 +109,25 @@ TEST(CloneContext, Clone) {
EXPECT_NE(cloned_root->b->b, original_root->b->b);
EXPECT_NE(cloned_root->c, original_root->c);
EXPECT_EQ(cloned_root->name, "root");
EXPECT_EQ(cloned_root->a->name, "a");
EXPECT_EQ(cloned_root->a->b->name, "a->b");
EXPECT_EQ(cloned_root->b->name, "b");
EXPECT_EQ(cloned_root->b->b->name, "b->b");
EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root"));
EXPECT_EQ(cloned_root->a->name, cloned.Symbols().Get("a"));
EXPECT_EQ(cloned_root->a->b->name, cloned.Symbols().Get("a->b"));
EXPECT_EQ(cloned_root->b->name, cloned.Symbols().Get("b"));
EXPECT_EQ(cloned_root->b->b->name, cloned.Symbols().Get("b->b"));
EXPECT_EQ(cloned_root->b->a, cloned_root->a); // Aliased
EXPECT_EQ(cloned_root->c, cloned_root->b); // Aliased
}
TEST(CloneContext, CloneWithReplacements) {
TEST(CloneContext, CloneWithReplaceAll_Cloneable) {
ProgramBuilder builder;
auto* original_root = builder.create<Node>("root");
original_root->a = builder.create<Node>("a");
original_root->a->b = builder.create<Replaceable>("a->b");
original_root->b = builder.create<Replaceable>("b");
auto* original_root =
builder.create<Node>(builder.Symbols().Register("root"));
original_root->a = builder.create<Node>(builder.Symbols().Register("a"));
original_root->a->b =
builder.create<Replaceable>(builder.Symbols().Register("a->b"));
original_root->b =
builder.create<Replaceable>(builder.Symbols().Register("b"));
original_root->b->a = original_root->a; // Aliased
original_root->c = original_root->b; // Aliased
Program original(std::move(builder));
@ -141,8 +147,12 @@ TEST(CloneContext, CloneWithReplacements) {
CloneContext ctx(&cloned, &original);
ctx.ReplaceAll([&](Replaceable* in) {
auto* out = cloned.create<Replacement>("replacement:" + in->name);
out->b = cloned.create<Node>("replacement-child:" + in->name);
auto out_name = cloned.Symbols().Register(
"replacement:" + original.Symbols().NameFor(in->name));
auto b_name = cloned.Symbols().Register(
"replacement-child:" + original.Symbols().NameFor(in->name));
auto* out = cloned.create<Replacement>(out_name);
out->b = cloned.create<Node>(b_name);
out->c = ctx.Clone(in->a);
return out;
});
@ -181,12 +191,14 @@ TEST(CloneContext, CloneWithReplacements) {
EXPECT_NE(cloned_root->b->a, original_root->b->a);
EXPECT_NE(cloned_root->c, original_root->c);
EXPECT_EQ(cloned_root->name, "root");
EXPECT_EQ(cloned_root->a->name, "a");
EXPECT_EQ(cloned_root->a->b->name, "replacement:a->b");
EXPECT_EQ(cloned_root->a->b->b->name, "replacement-child:a->b");
EXPECT_EQ(cloned_root->b->name, "replacement:b");
EXPECT_EQ(cloned_root->b->b->name, "replacement-child:b");
EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root"));
EXPECT_EQ(cloned_root->a->name, cloned.Symbols().Get("a"));
EXPECT_EQ(cloned_root->a->b->name, cloned.Symbols().Get("replacement:a->b"));
EXPECT_EQ(cloned_root->a->b->b->name,
cloned.Symbols().Get("replacement-child:a->b"));
EXPECT_EQ(cloned_root->b->name, cloned.Symbols().Get("replacement:b"));
EXPECT_EQ(cloned_root->b->b->name,
cloned.Symbols().Get("replacement-child:b"));
EXPECT_EQ(cloned_root->b->c, cloned_root->a); // Aliased
EXPECT_EQ(cloned_root->c, cloned_root->b); // Aliased
@ -198,12 +210,53 @@ TEST(CloneContext, CloneWithReplacements) {
EXPECT_FALSE(cloned_root->b->b->Is<Replacement>());
}
TEST(CloneContext, CloneWithReplaceAll_Symbols) {
ProgramBuilder builder;
auto* original_root =
builder.create<Node>(builder.Symbols().Register("root"));
original_root->a = builder.create<Node>(builder.Symbols().Register("a"));
original_root->a->b =
builder.create<Node>(builder.Symbols().Register("a->b"));
original_root->b = builder.create<Node>(builder.Symbols().Register("b"));
original_root->b->a = original_root->a; // Aliased
original_root->b->b =
builder.create<Node>(builder.Symbols().Register("b->b"));
original_root->c = original_root->b; // Aliased
Program original(std::move(builder));
// root
// ╭──────────────────┼──────────────────╮
// (a) (b) (c)
// N <──────┐ N <───────────────┘
// ╭────┼────╮ │ ╭────┼────╮
// (a) (b) (c) │ (a) (b) (c)
// N └───┘ N
//
// N: Node
ProgramBuilder cloned;
auto* cloned_root = CloneContext(&cloned, &original)
.ReplaceAll([&](Symbol sym) {
auto in = original.Symbols().NameFor(sym);
auto out = "transformed<" + in + ">";
return cloned.Symbols().Register(out);
})
.Clone(original_root);
EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("transformed<root>"));
EXPECT_EQ(cloned_root->a->name, cloned.Symbols().Get("transformed<a>"));
EXPECT_EQ(cloned_root->a->b->name, cloned.Symbols().Get("transformed<a->b>"));
EXPECT_EQ(cloned_root->b->name, cloned.Symbols().Get("transformed<b>"));
EXPECT_EQ(cloned_root->b->b->name, cloned.Symbols().Get("transformed<b->b>"));
}
TEST(CloneContext, CloneWithReplace) {
ProgramBuilder builder;
auto* original_root = builder.create<Node>("root");
original_root->a = builder.create<Node>("a");
original_root->b = builder.create<Node>("b");
original_root->c = builder.create<Node>("c");
auto* original_root =
builder.create<Node>(builder.Symbols().Register("root"));
original_root->a = builder.create<Node>(builder.Symbols().Register("a"));
original_root->b = builder.create<Node>(builder.Symbols().Register("b"));
original_root->c = builder.create<Node>(builder.Symbols().Register("c"));
Program original(std::move(builder));
// root
@ -212,7 +265,8 @@ TEST(CloneContext, CloneWithReplace) {
// Replaced
ProgramBuilder cloned;
auto* replacement = cloned.create<Node>("replacement");
auto* replacement =
cloned.create<Node>(cloned.Symbols().Register("replacement"));
auto* cloned_root = CloneContext(&cloned, &original)
.Replace(original_root->b, replacement)
@ -222,23 +276,24 @@ TEST(CloneContext, CloneWithReplace) {
EXPECT_EQ(cloned_root->b, replacement);
EXPECT_NE(cloned_root->c, replacement);
EXPECT_EQ(cloned_root->name, "root");
EXPECT_EQ(cloned_root->a->name, "a");
EXPECT_EQ(cloned_root->b->name, "replacement");
EXPECT_EQ(cloned_root->c->name, "c");
EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root"));
EXPECT_EQ(cloned_root->a->name, cloned.Symbols().Get("a"));
EXPECT_EQ(cloned_root->b->name, cloned.Symbols().Get("replacement"));
EXPECT_EQ(cloned_root->c->name, cloned.Symbols().Get("c"));
}
TEST(CloneContext, CloneWithInsertBefore) {
ProgramBuilder builder;
auto* original_root = builder.create<Node>("root");
original_root->a = builder.create<Node>("a");
original_root->b = builder.create<Node>("b");
original_root->c = builder.create<Node>("c");
auto* original_root =
builder.create<Node>(builder.Symbols().Register("root"));
original_root->a = builder.create<Node>(builder.Symbols().Register("a"));
original_root->b = builder.create<Node>(builder.Symbols().Register("b"));
original_root->c = builder.create<Node>(builder.Symbols().Register("c"));
original_root->vec = {original_root->a, original_root->b, original_root->c};
Program original(std::move(builder));
ProgramBuilder cloned;
auto* insertion = cloned.create<Node>("insertion");
auto* insertion = cloned.create<Node>(cloned.Symbols().Register("insertion"));
auto* cloned_root = CloneContext(&cloned, &original)
.InsertBefore(original_root->b, insertion)
@ -249,21 +304,77 @@ TEST(CloneContext, CloneWithInsertBefore) {
EXPECT_EQ(cloned_root->vec[2], cloned_root->b);
EXPECT_EQ(cloned_root->vec[3], cloned_root->c);
EXPECT_EQ(cloned_root->name, "root");
EXPECT_EQ(cloned_root->vec[0]->name, "a");
EXPECT_EQ(cloned_root->vec[1]->name, "insertion");
EXPECT_EQ(cloned_root->vec[2]->name, "b");
EXPECT_EQ(cloned_root->vec[3]->name, "c");
EXPECT_EQ(cloned_root->name, cloned.Symbols().Get("root"));
EXPECT_EQ(cloned_root->vec[0]->name, cloned.Symbols().Get("a"));
EXPECT_EQ(cloned_root->vec[1]->name, cloned.Symbols().Get("insertion"));
EXPECT_EQ(cloned_root->vec[2]->name, cloned.Symbols().Get("b"));
EXPECT_EQ(cloned_root->vec[3]->name, cloned.Symbols().Get("c"));
}
TEST(CloneContext, CloneWithReplaceAll_SameTypeTwice) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder cloned;
Program original;
CloneContext ctx(&cloned, &original);
ctx.ReplaceAll([](Node*) { return nullptr; });
ctx.ReplaceAll([](Node*) { return nullptr; });
},
"internal compiler error: ReplaceAll() called with a handler for type "
"Node that is already handled by a handler for type Node");
}
TEST(CloneContext, CloneWithReplaceAll_BaseThenDerived) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder cloned;
Program original;
CloneContext ctx(&cloned, &original);
ctx.ReplaceAll([](Node*) { return nullptr; });
ctx.ReplaceAll([](Replaceable*) { return nullptr; });
},
"internal compiler error: ReplaceAll() called with a handler for type "
"Replaceable that is already handled by a handler for type Node");
}
TEST(CloneContext, CloneWithReplaceAll_DerivedThenBase) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder cloned;
Program original;
CloneContext ctx(&cloned, &original);
ctx.ReplaceAll([](Replaceable*) { return nullptr; });
ctx.ReplaceAll([](Node*) { return nullptr; });
},
"internal compiler error: ReplaceAll() called with a handler for type "
"Node that is already handled by a handler for type Replaceable");
}
TEST(CloneContext, CloneWithReplaceAll_SymbolsTwice) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder cloned;
Program original;
CloneContext ctx(&cloned, &original);
ctx.ReplaceAll([](Symbol s) { return s; });
ctx.ReplaceAll([](Symbol s) { return s; });
},
"internal compiler error: ReplaceAll(const SymbolTransform&) called "
"multiple times on the same CloneContext");
}
TEST(CloneContext, CloneWithReplace_WithNotANode) {
EXPECT_FATAL_FAILURE(
{
ProgramBuilder builder;
auto* original_root = builder.create<Node>("root");
original_root->a = builder.create<Node>("a");
original_root->b = builder.create<Node>("b");
original_root->c = builder.create<Node>("c");
auto* original_root =
builder.create<Node>(builder.Symbols().Register("root"));
original_root->a =
builder.create<Node>(builder.Symbols().Register("a"));
original_root->b =
builder.create<Node>(builder.Symbols().Register("b"));
original_root->c =
builder.create<Node>(builder.Symbols().Register("c"));
Program original(std::move(builder));
// root