CloneContext: Drop the first CloneContext* parameter from the ReplaceAll() callback

You have to have the CloneContext in order to call ReplaceAll() in the first place. The overhead of capturing the pointer in the closure is negligible.

Cleans up the callsites.

Change-Id: I3a0fd808517d69f19756f590f3426e5ba226c57e
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/42840
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2021-02-26 19:33:56 +00:00 committed by Commit Bot service account
parent d5638c93c5
commit a8b20bec7b
6 changed files with 88 additions and 87 deletions

View File

@ -190,11 +190,13 @@ class CloneContext {
ast::FunctionList Clone(const ast::FunctionList& v); ast::FunctionList Clone(const ast::FunctionList& v);
/// ReplaceAll() registers `replacer` to be called whenever the Clone() method /// ReplaceAll() registers `replacer` to be called whenever the Clone() method
/// is called with a type that matches (or derives from) the type of the /// is called with a Cloneable type that matches (or derives from) the type of
/// second parameter of `replacer`. /// the single parameter of `replacer`.
/// The returned Cloneable of `replacer` will be used as the replacement for
/// all references to the object that's being cloned. This returned Cloneable
/// must be owned by the Program #dst.
/// ///
/// `replacer` must be function-like with the signature: /// `replacer` must be function-like with the signature: `T* (T*)`
/// `T* (CloneContext*, T*)`
/// where `T` is a type deriving from Cloneable. /// where `T` is a type deriving from Cloneable.
/// ///
/// If `replacer` returns a nullptr then Clone() will attempt the next /// If `replacer` returns a nullptr then Clone() will attempt the next
@ -206,28 +208,29 @@ class CloneContext {
/// ///
/// ``` /// ```
/// // Replace all ast::UintLiterals with the number 42 /// // Replace all ast::UintLiterals with the number 42
/// CloneCtx ctx(&out, in) /// CloneCtx ctx(&out, in);
/// .ReplaceAll([&] (CloneContext* ctx, ast::UintLiteral* l) { /// ctx.ReplaceAll([&] (ast::UintLiteral* l) {
/// return ctx->dst->create<ast::UintLiteral>( /// return ctx->dst->create<ast::UintLiteral>(
/// ctx->Clone(l->source()), /// ctx->Clone(l->source()),
/// ctx->Clone(l->type()), /// ctx->Clone(l->type()),
/// 42); /// 42);
/// }).Clone(); /// });
/// ctx.Clone();
/// ``` /// ```
/// ///
/// @warning The replacement object must be of the correct type for all /// @warning The replacement object must be of the correct type for all
/// references of the original object. A type mismatch will result in an /// references of the original object. A type mismatch will result in an
/// assertion in debug builds, and undefined behavior in release builds. /// assertion in debug builds, and undefined behavior in release builds.
/// @param replacer a function or function-like object with the signature /// @param replacer a function or function-like object with the signature
/// `T* (CloneContext*, T*)`, where `T` derives from Cloneable /// `T* (T*)`, where `T` derives from Cloneable
/// @returns this CloneContext so calls can be chained /// @returns this CloneContext so calls can be chained
template <typename F> template <typename F>
CloneContext& ReplaceAll(F replacer) { CloneContext& ReplaceAll(F&& replacer) {
using TPtr = traits::ParamTypeT<F, 1>; using TPtr = traits::ParamTypeT<F, 0>;
using T = typename std::remove_pointer<TPtr>::type; using T = typename std::remove_pointer<TPtr>::type;
transforms_.emplace_back([=](Cloneable* in) { transforms_.emplace_back([=](Cloneable* in) {
auto* in_as_t = in->As<T>(); auto* in_as_t = in->As<T>();
return in_as_t != nullptr ? replacer(this, in_as_t) : nullptr; return in_as_t != nullptr ? replacer(in_as_t) : nullptr;
}); });
return *this; return *this;
} }

View File

@ -138,15 +138,15 @@ TEST(CloneContext, CloneWithReplacements) {
// R: Replaceable // R: Replaceable
ProgramBuilder cloned; ProgramBuilder cloned;
auto* cloned_root =
CloneContext(&cloned, &original) CloneContext ctx(&cloned, &original);
.ReplaceAll([&](CloneContext* ctx, Replaceable* in) { ctx.ReplaceAll([&](Replaceable* in) {
auto* out = cloned.create<Replacement>("replacement:" + in->name); auto* out = cloned.create<Replacement>("replacement:" + in->name);
out->b = cloned.create<Node>("replacement-child:" + in->name); out->b = cloned.create<Node>("replacement-child:" + in->name);
out->c = ctx->Clone(in->a); out->c = ctx.Clone(in->a);
return out; return out;
}) });
.Clone(original_root); auto* cloned_root = ctx.Clone(original_root);
// root // root
// ╭─────────────────┼──────────────────╮ // ╭─────────────────┼──────────────────╮

View File

@ -58,11 +58,11 @@ BoundArrayAccessors::~BoundArrayAccessors() = default;
Transform::Output BoundArrayAccessors::Run(const Program* in) { Transform::Output BoundArrayAccessors::Run(const Program* in) {
ProgramBuilder out; ProgramBuilder out;
CloneContext(&out, in) CloneContext ctx(&out, in);
.ReplaceAll([&](CloneContext* ctx, ast::ArrayAccessorExpression* expr) { ctx.ReplaceAll([&](ast::ArrayAccessorExpression* expr) {
return Transform(expr, ctx); return Transform(expr, &ctx);
}) });
.Clone(); ctx.Clone();
return Output(Program(std::move(out))); return Output(Program(std::move(out)));
} }

View File

@ -64,24 +64,23 @@ Transform::Output EmitVertexPointSize::Run(const Program* in) {
out.AST().AddGlobalVariable(pointsize_var); out.AST().AddGlobalVariable(pointsize_var);
// Add the pointsize assignment statement to the front of all vertex stages. // Add the pointsize assignment statement to the front of all vertex stages.
CloneContext(&out, in) CloneContext ctx(&out, in);
.ReplaceAll( ctx.ReplaceAll([&](ast::Function* func) -> ast::Function* {
[&](CloneContext* ctx, ast::Function* func) -> ast::Function* { if (func->pipeline_stage() != ast::PipelineStage::kVertex) {
if (func->pipeline_stage() != ast::PipelineStage::kVertex) { return nullptr; // Just clone func
return nullptr; // Just clone func }
}
// Build the AST expression & statement for assigning pointsize one. // Build the AST expression & statement for assigning pointsize one.
auto* one = out.create<ast::ScalarConstructorExpression>( auto* one = out.create<ast::ScalarConstructorExpression>(
Source{}, out.create<ast::FloatLiteral>(Source{}, f32, 1.0f)); Source{}, out.create<ast::FloatLiteral>(Source{}, f32, 1.0f));
auto* pointsize_ident = out.create<ast::IdentifierExpression>( auto* pointsize_ident = out.create<ast::IdentifierExpression>(
Source{}, out.Symbols().Register(kPointSizeVar)); Source{}, out.Symbols().Register(kPointSizeVar));
auto* pointsize_assign = out.create<ast::AssignmentStatement>( auto* pointsize_assign =
Source{}, pointsize_ident, one); out.create<ast::AssignmentStatement>(Source{}, pointsize_ident, one);
return CloneWithStatementsAtStart(ctx, func, {pointsize_assign}); return CloneWithStatementsAtStart(&ctx, func, {pointsize_assign});
}) });
.Clone(); ctx.Clone();
return Output(Program(std::move(out))); return Output(Program(std::move(out)));
} }

View File

@ -131,53 +131,52 @@ Transform::Output FirstIndexOffset::Run(const Program* in) {
// add a CreateFirstIndexOffset() statement to each function that uses one of // add a CreateFirstIndexOffset() statement to each function that uses one of
// these builtins. // these builtins.
CloneContext(&out, in) CloneContext ctx(&out, in);
.ReplaceAll([&](CloneContext* ctx, ast::Variable* var) -> ast::Variable* { ctx.ReplaceAll([&](ast::Variable* var) -> ast::Variable* {
for (ast::VariableDecoration* dec : var->decorations()) { for (ast::VariableDecoration* dec : var->decorations()) {
if (auto* blt_dec = dec->As<ast::BuiltinDecoration>()) { if (auto* blt_dec = dec->As<ast::BuiltinDecoration>()) {
ast::Builtin blt_type = blt_dec->value(); ast::Builtin blt_type = blt_dec->value();
if (blt_type == ast::Builtin::kVertexIndex) { if (blt_type == ast::Builtin::kVertexIndex) {
vertex_index_sym = var->symbol(); vertex_index_sym = var->symbol();
has_vertex_index_ = true; has_vertex_index_ = true;
return clone_variable_with_new_name( return clone_variable_with_new_name(
ctx, var, &ctx, var,
kIndexOffsetPrefix + in->Symbols().NameFor(var->symbol())); kIndexOffsetPrefix + in->Symbols().NameFor(var->symbol()));
} else if (blt_type == ast::Builtin::kInstanceIndex) { } else if (blt_type == ast::Builtin::kInstanceIndex) {
instance_index_sym = var->symbol(); instance_index_sym = var->symbol();
has_instance_index_ = true; has_instance_index_ = true;
return clone_variable_with_new_name( return clone_variable_with_new_name(
ctx, var, &ctx, var,
kIndexOffsetPrefix + in->Symbols().NameFor(var->symbol())); kIndexOffsetPrefix + in->Symbols().NameFor(var->symbol()));
} }
}
}
return nullptr; // Just clone var
});
ctx.ReplaceAll( // Note: This happens in the same pass as the rename above
// which determines the original builtin variable names,
// but this should be fine, as variables are cloned first.
[&](ast::Function* func) -> ast::Function* {
maybe_create_buffer_var(ctx.dst);
if (buffer_var == nullptr) {
return nullptr; // no transform need, just clone func
}
auto* func_sem = in->Sem().Get(func);
ast::StatementList statements;
for (const auto& data : func_sem->LocalReferencedBuiltinVariables()) {
if (data.second->value() == ast::Builtin::kVertexIndex) {
statements.emplace_back(
CreateFirstIndexOffset(in->Symbols().NameFor(vertex_index_sym),
kFirstVertexName, buffer_var, ctx.dst));
} else if (data.second->value() == ast::Builtin::kInstanceIndex) {
statements.emplace_back(CreateFirstIndexOffset(
in->Symbols().NameFor(instance_index_sym), kFirstInstanceName,
buffer_var, ctx.dst));
} }
} }
return nullptr; // Just clone var return CloneWithStatementsAtStart(&ctx, func, statements);
}) });
.ReplaceAll( // Note: This happens in the same pass as the rename above ctx.Clone();
// which determines the original builtin variable names,
// but this should be fine, as variables are cloned first.
[&](CloneContext* ctx, ast::Function* func) -> ast::Function* {
maybe_create_buffer_var(ctx->dst);
if (buffer_var == nullptr) {
return nullptr; // no transform need, just clone func
}
auto* func_sem = in->Sem().Get(func);
ast::StatementList statements;
for (const auto& data :
func_sem->LocalReferencedBuiltinVariables()) {
if (data.second->value() == ast::Builtin::kVertexIndex) {
statements.emplace_back(CreateFirstIndexOffset(
in->Symbols().NameFor(vertex_index_sym), kFirstVertexName,
buffer_var, ctx->dst));
} else if (data.second->value() == ast::Builtin::kInstanceIndex) {
statements.emplace_back(CreateFirstIndexOffset(
in->Symbols().NameFor(instance_index_sym),
kFirstInstanceName, buffer_var, ctx->dst));
}
}
return CloneWithStatementsAtStart(ctx, func, statements);
})
.Clone();
return Output( return Output(
Program(std::move(out)), Program(std::move(out)),

View File

@ -113,7 +113,7 @@ Transform::Output VertexPulling::Run(const Program* in) {
for (auto& replacement : state.location_replacements) { for (auto& replacement : state.location_replacements) {
ctx.Replace(replacement.from, replacement.to); ctx.Replace(replacement.from, replacement.to);
} }
ctx.ReplaceAll([&](CloneContext*, ast::Function* f) -> ast::Function* { ctx.ReplaceAll([&](ast::Function* f) -> ast::Function* {
if (f == func) { if (f == func) {
return CloneWithStatementsAtStart(&ctx, f, return CloneWithStatementsAtStart(&ctx, f,
{state.CreateVertexPullingPreamble()}); {state.CreateVertexPullingPreamble()});