diff --git a/src/clone_context.h b/src/clone_context.h index e1b4da0205..31c18c7fdc 100644 --- a/src/clone_context.h +++ b/src/clone_context.h @@ -190,11 +190,13 @@ class CloneContext { ast::FunctionList Clone(const ast::FunctionList& v); /// ReplaceAll() registers `replacer` to be called whenever the Clone() method - /// is called with a type that matches (or derives from) the type of the - /// second parameter of `replacer`. + /// is called with a Cloneable type that matches (or derives from) the type of + /// the single parameter of `replacer`. + /// The returned Cloneable of `replacer` will be used as the replacement for + /// all references to the object that's being cloned. This returned Cloneable + /// must be owned by the Program #dst. /// - /// `replacer` must be function-like with the signature: - /// `T* (CloneContext*, T*)` + /// `replacer` must be function-like with the signature: `T* (T*)` /// where `T` is a type deriving from Cloneable. /// /// If `replacer` returns a nullptr then Clone() will attempt the next @@ -206,28 +208,29 @@ class CloneContext { /// /// ``` /// // Replace all ast::UintLiterals with the number 42 - /// CloneCtx ctx(&out, in) - /// .ReplaceAll([&] (CloneContext* ctx, ast::UintLiteral* l) { + /// CloneCtx ctx(&out, in); + /// ctx.ReplaceAll([&] (ast::UintLiteral* l) { /// return ctx->dst->create( /// ctx->Clone(l->source()), /// ctx->Clone(l->type()), /// 42); - /// }).Clone(); + /// }); + /// ctx.Clone(); /// ``` /// /// @warning The replacement object must be of the correct type for all /// references of the original object. A type mismatch will result in an /// assertion in debug builds, and undefined behavior in release builds. /// @param replacer a function or function-like object with the signature - /// `T* (CloneContext*, T*)`, where `T` derives from Cloneable + /// `T* (T*)`, where `T` derives from Cloneable /// @returns this CloneContext so calls can be chained template - CloneContext& ReplaceAll(F replacer) { - using TPtr = traits::ParamTypeT; + CloneContext& ReplaceAll(F&& replacer) { + using TPtr = traits::ParamTypeT; using T = typename std::remove_pointer::type; transforms_.emplace_back([=](Cloneable* in) { auto* in_as_t = in->As(); - return in_as_t != nullptr ? replacer(this, in_as_t) : nullptr; + return in_as_t != nullptr ? replacer(in_as_t) : nullptr; }); return *this; } diff --git a/src/clone_context_test.cc b/src/clone_context_test.cc index 1a97fb827d..bb13dc39e2 100644 --- a/src/clone_context_test.cc +++ b/src/clone_context_test.cc @@ -138,15 +138,15 @@ TEST(CloneContext, CloneWithReplacements) { // R: Replaceable ProgramBuilder cloned; - auto* cloned_root = - CloneContext(&cloned, &original) - .ReplaceAll([&](CloneContext* ctx, Replaceable* in) { - auto* out = cloned.create("replacement:" + in->name); - out->b = cloned.create("replacement-child:" + in->name); - out->c = ctx->Clone(in->a); - return out; - }) - .Clone(original_root); + + CloneContext ctx(&cloned, &original); + ctx.ReplaceAll([&](Replaceable* in) { + auto* out = cloned.create("replacement:" + in->name); + out->b = cloned.create("replacement-child:" + in->name); + out->c = ctx.Clone(in->a); + return out; + }); + auto* cloned_root = ctx.Clone(original_root); // root // ╭─────────────────┼──────────────────╮ diff --git a/src/transform/bound_array_accessors.cc b/src/transform/bound_array_accessors.cc index 4cb38227d0..ec0a768fdc 100644 --- a/src/transform/bound_array_accessors.cc +++ b/src/transform/bound_array_accessors.cc @@ -58,11 +58,11 @@ BoundArrayAccessors::~BoundArrayAccessors() = default; Transform::Output BoundArrayAccessors::Run(const Program* in) { ProgramBuilder out; - CloneContext(&out, in) - .ReplaceAll([&](CloneContext* ctx, ast::ArrayAccessorExpression* expr) { - return Transform(expr, ctx); - }) - .Clone(); + CloneContext ctx(&out, in); + ctx.ReplaceAll([&](ast::ArrayAccessorExpression* expr) { + return Transform(expr, &ctx); + }); + ctx.Clone(); return Output(Program(std::move(out))); } diff --git a/src/transform/emit_vertex_point_size.cc b/src/transform/emit_vertex_point_size.cc index 748021cd49..0dcc50a4b2 100644 --- a/src/transform/emit_vertex_point_size.cc +++ b/src/transform/emit_vertex_point_size.cc @@ -64,24 +64,23 @@ Transform::Output EmitVertexPointSize::Run(const Program* in) { out.AST().AddGlobalVariable(pointsize_var); // Add the pointsize assignment statement to the front of all vertex stages. - CloneContext(&out, in) - .ReplaceAll( - [&](CloneContext* ctx, ast::Function* func) -> ast::Function* { - if (func->pipeline_stage() != ast::PipelineStage::kVertex) { - return nullptr; // Just clone func - } + CloneContext ctx(&out, in); + ctx.ReplaceAll([&](ast::Function* func) -> ast::Function* { + if (func->pipeline_stage() != ast::PipelineStage::kVertex) { + return nullptr; // Just clone func + } - // Build the AST expression & statement for assigning pointsize one. - auto* one = out.create( - Source{}, out.create(Source{}, f32, 1.0f)); - auto* pointsize_ident = out.create( - Source{}, out.Symbols().Register(kPointSizeVar)); - auto* pointsize_assign = out.create( - Source{}, pointsize_ident, one); + // Build the AST expression & statement for assigning pointsize one. + auto* one = out.create( + Source{}, out.create(Source{}, f32, 1.0f)); + auto* pointsize_ident = out.create( + Source{}, out.Symbols().Register(kPointSizeVar)); + auto* pointsize_assign = + out.create(Source{}, pointsize_ident, one); - return CloneWithStatementsAtStart(ctx, func, {pointsize_assign}); - }) - .Clone(); + return CloneWithStatementsAtStart(&ctx, func, {pointsize_assign}); + }); + ctx.Clone(); return Output(Program(std::move(out))); } diff --git a/src/transform/first_index_offset.cc b/src/transform/first_index_offset.cc index f5f3026fce..1690d4a8ed 100644 --- a/src/transform/first_index_offset.cc +++ b/src/transform/first_index_offset.cc @@ -131,53 +131,52 @@ Transform::Output FirstIndexOffset::Run(const Program* in) { // add a CreateFirstIndexOffset() statement to each function that uses one of // these builtins. - CloneContext(&out, in) - .ReplaceAll([&](CloneContext* ctx, ast::Variable* var) -> ast::Variable* { - for (ast::VariableDecoration* dec : var->decorations()) { - if (auto* blt_dec = dec->As()) { - ast::Builtin blt_type = blt_dec->value(); - if (blt_type == ast::Builtin::kVertexIndex) { - vertex_index_sym = var->symbol(); - has_vertex_index_ = true; - return clone_variable_with_new_name( - ctx, var, - kIndexOffsetPrefix + in->Symbols().NameFor(var->symbol())); - } else if (blt_type == ast::Builtin::kInstanceIndex) { - instance_index_sym = var->symbol(); - has_instance_index_ = true; - return clone_variable_with_new_name( - ctx, var, - kIndexOffsetPrefix + in->Symbols().NameFor(var->symbol())); - } + CloneContext ctx(&out, in); + ctx.ReplaceAll([&](ast::Variable* var) -> ast::Variable* { + for (ast::VariableDecoration* dec : var->decorations()) { + if (auto* blt_dec = dec->As()) { + ast::Builtin blt_type = blt_dec->value(); + if (blt_type == ast::Builtin::kVertexIndex) { + vertex_index_sym = var->symbol(); + has_vertex_index_ = true; + return clone_variable_with_new_name( + &ctx, var, + kIndexOffsetPrefix + in->Symbols().NameFor(var->symbol())); + } else if (blt_type == ast::Builtin::kInstanceIndex) { + instance_index_sym = var->symbol(); + has_instance_index_ = true; + return clone_variable_with_new_name( + &ctx, var, + kIndexOffsetPrefix + in->Symbols().NameFor(var->symbol())); + } + } + } + return nullptr; // Just clone var + }); + ctx.ReplaceAll( // Note: This happens in the same pass as the rename above + // which determines the original builtin variable names, + // but this should be fine, as variables are cloned first. + [&](ast::Function* func) -> ast::Function* { + maybe_create_buffer_var(ctx.dst); + if (buffer_var == nullptr) { + return nullptr; // no transform need, just clone func + } + auto* func_sem = in->Sem().Get(func); + ast::StatementList statements; + for (const auto& data : func_sem->LocalReferencedBuiltinVariables()) { + if (data.second->value() == ast::Builtin::kVertexIndex) { + statements.emplace_back( + CreateFirstIndexOffset(in->Symbols().NameFor(vertex_index_sym), + kFirstVertexName, buffer_var, ctx.dst)); + } else if (data.second->value() == ast::Builtin::kInstanceIndex) { + statements.emplace_back(CreateFirstIndexOffset( + in->Symbols().NameFor(instance_index_sym), kFirstInstanceName, + buffer_var, ctx.dst)); } } - return nullptr; // Just clone var - }) - .ReplaceAll( // Note: This happens in the same pass as the rename above - // which determines the original builtin variable names, - // but this should be fine, as variables are cloned first. - [&](CloneContext* ctx, ast::Function* func) -> ast::Function* { - maybe_create_buffer_var(ctx->dst); - if (buffer_var == nullptr) { - return nullptr; // no transform need, just clone func - } - auto* func_sem = in->Sem().Get(func); - ast::StatementList statements; - for (const auto& data : - func_sem->LocalReferencedBuiltinVariables()) { - if (data.second->value() == ast::Builtin::kVertexIndex) { - statements.emplace_back(CreateFirstIndexOffset( - in->Symbols().NameFor(vertex_index_sym), kFirstVertexName, - buffer_var, ctx->dst)); - } else if (data.second->value() == ast::Builtin::kInstanceIndex) { - statements.emplace_back(CreateFirstIndexOffset( - in->Symbols().NameFor(instance_index_sym), - kFirstInstanceName, buffer_var, ctx->dst)); - } - } - return CloneWithStatementsAtStart(ctx, func, statements); - }) - .Clone(); + return CloneWithStatementsAtStart(&ctx, func, statements); + }); + ctx.Clone(); return Output( Program(std::move(out)), diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc index 273806f7ce..3a7adb4e59 100644 --- a/src/transform/vertex_pulling.cc +++ b/src/transform/vertex_pulling.cc @@ -113,7 +113,7 @@ Transform::Output VertexPulling::Run(const Program* in) { for (auto& replacement : state.location_replacements) { ctx.Replace(replacement.from, replacement.to); } - ctx.ReplaceAll([&](CloneContext*, ast::Function* f) -> ast::Function* { + ctx.ReplaceAll([&](ast::Function* f) -> ast::Function* { if (f == func) { return CloneWithStatementsAtStart(&ctx, f, {state.CreateVertexPullingPreamble()});