diff --git a/src/ast/clone_context.cc b/src/ast/clone_context.cc index 9ec4bfe894..fe9e849425 100644 --- a/src/ast/clone_context.cc +++ b/src/ast/clone_context.cc @@ -14,11 +14,22 @@ #include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { -CloneContext::CloneContext(Module* m) : mod(m) {} +CloneContext::CloneContext(Module* to, Module const* from) + : mod(to), src(from) {} CloneContext::~CloneContext() = default; +Symbol CloneContext::Clone(const Symbol& s) const { + return mod->RegisterSymbol(src->SymbolToName(s)); +} + +void CloneContext::Clone() { + src->Clone(this); +} + } // namespace ast } // namespace tint diff --git a/src/ast/clone_context.h b/src/ast/clone_context.h index a9e303beb1..64a7b11b88 100644 --- a/src/ast/clone_context.h +++ b/src/ast/clone_context.h @@ -22,6 +22,7 @@ #include "src/ast/traits.h" #include "src/castable.h" #include "src/source.h" +#include "src/symbol.h" namespace tint { namespace ast { @@ -32,19 +33,22 @@ class Module; class CloneContext { public: /// Constructor - /// @param m the target module to clone into - explicit CloneContext(Module* m); + /// @param to the target module to clone into + /// @param from the source module to clone from + CloneContext(Module* to, Module const* from); /// Destructor ~CloneContext(); - /// Clones the `Node` or `type::Type` `a` into the module #mod if `a` is not + /// Clones the Node or type::Type `a` into the module #mod if `a` is not /// null. If `a` is null, then Clone() returns null. If `a` has been cloned /// already by this CloneContext then the same cloned pointer is returned. /// /// Clone() may use a function registered with ReplaceAll() to create a /// transformed version of the object. See ReplaceAll() for more information. /// + /// The Node or type::Type `a` must be owned by the module #src. + /// /// @note Semantic information such as resolved expression type and intrinsic /// information is not cloned. /// @param a the `Node` or `type::Type` to clone @@ -70,15 +74,26 @@ class CloneContext { return static_cast(c); } - /// Clones the `Source` `s` into `mod` + /// Clones the Source `s` into `mod` /// TODO(bclayton) - Currently this 'clone' is a shallow copy. If/when /// `Source.File`s are owned by the `Module` this should make a copy of the /// file. /// @param s the `Source` to clone /// @return the cloned source - Source Clone(const Source& s) { return s; } + Source Clone(const Source& s) const { return s; } + + /// Clones the Symbol `s` into `mod` + /// + /// The Symbol `s` must be owned by the module #src. + /// + /// @param s the Symbol to clone + /// @return the cloned source + Symbol Clone(const Symbol& s) const; /// Clones each of the elements of the vector `v` into the module #mod-> + /// + /// All the elements of the vector `v` must be owned by the module #src. + /// /// @param v the vector to clone /// @return the cloned vector template @@ -92,11 +107,12 @@ class CloneContext { } /// ReplaceAll() registers `replacer` to be called whenever the Clone() method - /// is called with a type that matches (or derives from) the type of the first - /// parameter of `replacer`. + /// is called with a type that matches (or derives from) the type of the + /// second parameter of `replacer`. /// - /// `replacer` must be function-like with the signature: `T* (T*)`, where `T` - /// is a type deriving from CastableBase. + /// `replacer` must be function-like with the signature: + /// `T* (CloneContext*, T*)` + /// where `T` is a type deriving from CastableBase. /// /// If `replacer` returns a nullptr then Clone() will attempt the next /// registered replacer function that matches the object type. If no replacers @@ -107,32 +123,43 @@ class CloneContext { /// /// ``` /// // Replace all ast::UintLiterals with the number 42 - /// CloneCtx ctx(mod); - /// ctx.ReplaceAll([&] (ast::UintLiteral* in) { - /// return ctx.mod->create(ctx.Clone(in->source()), - /// ctx.Clone(in->type()), 42); - /// }); - /// auto* out = ctx.Clone(tree); + /// CloneCtx ctx(&out, in) + /// .ReplaceAll([&] (CloneContext* ctx, ast::UintLiteral* l) { + /// return ctx->mod->create(ctx->Clone(l->source()), + /// ctx->Clone(l->type()), + /// 42); + /// }).Clone(); /// ``` /// /// @param replacer a function or function-like object with the signature - /// `T* (T*)`, where `T` derives from CastableBase + /// `T* (CloneContext*, T*)`, where `T` derives from CastableBase + /// @returns this CloneContext so calls can be chained template - void ReplaceAll(F replacer) { - using TPtr = traits::ParamTypeT; + CloneContext& ReplaceAll(F replacer) { + using TPtr = traits::ParamTypeT; using T = typename std::remove_pointer::type; transforms_.emplace_back([=](CastableBase* in) { auto* in_as_t = in->As(); - return in_as_t != nullptr ? replacer(in_as_t) : nullptr; + return in_as_t != nullptr ? replacer(this, in_as_t) : nullptr; }); + return *this; } + /// Clone performs the clone of the entire module #src to #mod. + void Clone(); + /// The target module to clone into. Module* const mod; + /// The source module to clone from. + Module const* const src; + private: using Transform = std::function; + CloneContext(const CloneContext&) = delete; + CloneContext& operator=(const CloneContext&) = delete; + /// LookupOrTransform is the template-independent logic of Clone(). /// This is outside of Clone() to reduce the amount of template-instantiated /// code. diff --git a/src/ast/clone_context_test.cc b/src/ast/clone_context_test.cc index 386a1ae2f5..d8890a2c95 100644 --- a/src/ast/clone_context_test.cc +++ b/src/ast/clone_context_test.cc @@ -65,8 +65,7 @@ TEST(CloneContext, Clone) { // C: Clonable ast::Module cloned; - CloneContext ctx(&cloned); - auto* cloned_root = original_root->Clone(&ctx); + auto* cloned_root = CloneContext(&cloned, &original).Clone(original_root); EXPECT_NE(cloned_root->a, nullptr); EXPECT_EQ(cloned_root->a->a, nullptr); @@ -110,14 +109,14 @@ TEST(CloneContext, CloneWithReplacements) { // R: Replaceable ast::Module cloned; - CloneContext ctx(&cloned); - ctx.ReplaceAll([&](Replaceable* in) { - auto* out = cloned.create(); - out->b = cloned.create(); - out->c = ctx.Clone(in->a); - return out; - }); - auto* cloned_root = original_root->Clone(&ctx); + auto* cloned_root = CloneContext(&cloned, &original) + .ReplaceAll([&](CloneContext* ctx, Replaceable* in) { + auto* out = cloned.create(); + out->b = cloned.create(); + out->c = ctx->Clone(in->a); + return out; + }) + .Clone(original_root); // root // ╭─────────────────┼──────────────────╮ diff --git a/src/ast/function.cc b/src/ast/function.cc index ae692875d8..1c22b9a1e9 100644 --- a/src/ast/function.cc +++ b/src/ast/function.cc @@ -228,7 +228,7 @@ const Statement* Function::get_last_statement() const { Function* Function::Clone(CloneContext* ctx) const { return ctx->mod->create( - ctx->Clone(source()), symbol_, name_, ctx->Clone(params_), + ctx->Clone(source()), ctx->Clone(symbol()), name_, ctx->Clone(params_), ctx->Clone(return_type_), ctx->Clone(body_), ctx->Clone(decorations_)); } diff --git a/src/ast/identifier_expression.cc b/src/ast/identifier_expression.cc index a38ea3f6db..b18ca10ff0 100644 --- a/src/ast/identifier_expression.cc +++ b/src/ast/identifier_expression.cc @@ -32,8 +32,8 @@ IdentifierExpression::IdentifierExpression(IdentifierExpression&&) = default; IdentifierExpression::~IdentifierExpression() = default; IdentifierExpression* IdentifierExpression::Clone(CloneContext* ctx) const { - return ctx->mod->create(ctx->Clone(source()), sym_, - name_); + return ctx->mod->create(ctx->Clone(source()), + ctx->Clone(symbol()), name_); } bool IdentifierExpression::IsValid() const { diff --git a/src/ast/module.cc b/src/ast/module.cc index a357554e18..c180639f64 100644 --- a/src/ast/module.cc +++ b/src/ast/module.cc @@ -30,33 +30,13 @@ Module& Module::operator=(Module&& rhs) = default; Module::~Module() = default; -Module Module::Clone() { +Module Module::Clone() const { Module out; - CloneContext ctx(&out); - - // Symbol table must be cloned first so that the resulting module has the - // symbols before we start the tree mutations. - ctx.mod->symbol_table_ = symbol_table_; - - CloneUsing(&ctx); + CloneContext(&out, this).Clone(); return out; } -Module Module::Clone(const std::function& init) { - Module out; - CloneContext ctx(&out); - - // Symbol table must be cloned first so that the resulting module has the - // symbols before we start the tree mutations. - ctx.mod->symbol_table_ = symbol_table_; - - init(&ctx); - - CloneUsing(&ctx); - return out; -} - -void Module::CloneUsing(CloneContext* ctx) { +void Module::Clone(CloneContext* ctx) const { for (auto* ty : constructed_types_) { ctx->mod->constructed_types_.emplace_back(ctx->Clone(ty)); } diff --git a/src/ast/module.h b/src/ast/module.h index 30319647e8..a42bd7ff9b 100644 --- a/src/ast/module.h +++ b/src/ast/module.h @@ -51,13 +51,15 @@ class Module { ~Module(); /// @return a deep copy of this module - Module Clone(); + Module Clone() const; - /// @param init a callback function to configure the CloneContex before - /// cloning any of the module's state - /// @return a deep copy of this module, calling `init` to first initialize the - /// context. - Module Clone(const std::function& init); + /// Clone this module into `ctx->mod` using the provided CloneContext + /// The module will be cloned in this order: + /// * Constructed types + /// * Global variables + /// * Functions + /// @param ctx the clone context + void Clone(CloneContext* ctx) const; /// Add a global variable to the module /// @param var the variable to add @@ -177,14 +179,6 @@ class Module { private: Module(const Module&) = delete; - /// Clone this module into `ctx->mod` using the provided CloneContext - /// The module will be cloned in this order: - /// * Constructed types - /// * Global variables - /// * Functions - /// @param ctx the clone context - void CloneUsing(CloneContext* ctx); - SymbolTable symbol_table_; VariableList global_variables_; // The constructed types are owned by the type manager diff --git a/src/ast/module_clone_test.cc b/src/ast/module_clone_test.cc index 23be2ada22..8405dc1a2c 100644 --- a/src/ast/module_clone_test.cc +++ b/src/ast/module_clone_test.cc @@ -16,6 +16,7 @@ #include "gtest/gtest.h" #include "src/ast/case_statement.h" +#include "src/demangler.h" #include "src/reader/wgsl/parser.h" #include "src/writer/wgsl/generator.h" @@ -116,7 +117,9 @@ fn main() -> void { auto dst = src.Clone(); // Expect the AST printed with to_str() to match - EXPECT_EQ(src.to_str(), dst.to_str()); + Demangler demanger; + EXPECT_EQ(demanger.Demangle(src, src.to_str()), + demanger.Demangle(dst, dst.to_str())); // Check that none of the AST nodes or type pointers in dst are found in src std::unordered_set src_nodes; diff --git a/src/ast/type/alias_type.cc b/src/ast/type/alias_type.cc index 723bca222c..bef935b9c9 100644 --- a/src/ast/type/alias_type.cc +++ b/src/ast/type/alias_type.cc @@ -47,7 +47,8 @@ uint64_t Alias::BaseAlignment(MemoryLayout mem_layout) const { } Alias* Alias::Clone(CloneContext* ctx) const { - return ctx->mod->create(symbol_, name_, ctx->Clone(subtype_)); + return ctx->mod->create(ctx->Clone(symbol()), name_, + ctx->Clone(subtype_)); } } // namespace type diff --git a/src/ast/type/struct_type.cc b/src/ast/type/struct_type.cc index ddf163bc65..7a127a67eb 100644 --- a/src/ast/type/struct_type.cc +++ b/src/ast/type/struct_type.cc @@ -84,7 +84,8 @@ uint64_t Struct::BaseAlignment(MemoryLayout mem_layout) const { } Struct* Struct::Clone(CloneContext* ctx) const { - return ctx->mod->create(symbol_, name_, ctx->Clone(struct_)); + return ctx->mod->create(ctx->Clone(symbol()), name_, + ctx->Clone(struct_)); } } // namespace type diff --git a/src/transform/bound_array_accessors.cc b/src/transform/bound_array_accessors.cc index 3adf6b7d26..4543a17983 100644 --- a/src/transform/bound_array_accessors.cc +++ b/src/transform/bound_array_accessors.cc @@ -53,13 +53,14 @@ namespace transform { BoundArrayAccessors::BoundArrayAccessors() = default; BoundArrayAccessors::~BoundArrayAccessors() = default; -Transform::Output BoundArrayAccessors::Run(ast::Module* mod) { +Transform::Output BoundArrayAccessors::Run(ast::Module* in) { Output out; - out.module = mod->Clone([&](ast::CloneContext* ctx) { - ctx->ReplaceAll([&, ctx](ast::ArrayAccessorExpression* expr) { - return Transform(expr, ctx, &out.diagnostics); - }); - }); + ast::CloneContext(&out.module, in) + .ReplaceAll( + [&](ast::CloneContext* ctx, ast::ArrayAccessorExpression* expr) { + return Transform(expr, ctx, &out.diagnostics); + }) + .Clone(); return out; } diff --git a/src/transform/emit_vertex_point_size.cc b/src/transform/emit_vertex_point_size.cc index c80fab76e2..374bcd1326 100644 --- a/src/transform/emit_vertex_point_size.cc +++ b/src/transform/emit_vertex_point_size.cc @@ -47,49 +47,41 @@ Transform::Output EmitVertexPointSize::Run(ast::Module* in) { return out; } - tint::ast::AssignmentStatement* pointsize_assign = nullptr; - auto get_pointsize_assign = [&pointsize_assign](ast::Module* mod) { - if (pointsize_assign != nullptr) { - return pointsize_assign; - } + auto* f32 = out.module.create(); - auto* f32 = mod->create(); + // Declare the pointsize builtin output variable. + auto* pointsize_var = out.module.create( + Source{}, // source + kPointSizeVar, // name + ast::StorageClass::kOutput, // storage_class + f32, // type + false, // is_const + nullptr, // constructor + ast::VariableDecorationList{ + // decorations + out.module.create(Source{}, + ast::Builtin::kPointSize), + }); + out.module.AddGlobalVariable(pointsize_var); - // Declare the pointsize builtin output variable. - auto* pointsize_var = - mod->create(Source{}, // source - kPointSizeVar, // name - ast::StorageClass::kOutput, // storage_class - f32, // type - false, // is_const - nullptr, // constructor - ast::VariableDecorationList{ - // decorations - mod->create( - Source{}, ast::Builtin::kPointSize), - }); - mod->AddGlobalVariable(pointsize_var); - - // Build the AST expression & statement for assigning pointsize one. - auto* one = mod->create( - Source{}, mod->create(Source{}, f32, 1.0f)); - auto* pointsize_ident = mod->create( - Source{}, mod->RegisterSymbol(kPointSizeVar), kPointSizeVar); - pointsize_assign = - mod->create(Source{}, pointsize_ident, one); - return pointsize_assign; - }; + // Build the AST expression & statement for assigning pointsize one. + auto* one = out.module.create( + Source{}, out.module.create(Source{}, f32, 1.0f)); + auto* pointsize_ident = out.module.create( + Source{}, out.module.RegisterSymbol(kPointSizeVar), kPointSizeVar); + auto* pointsize_assign = out.module.create( + Source{}, pointsize_ident, one); // Add the pointsize assignment statement to the front of all vertex stages. - out.module = in->Clone([&](ast::CloneContext* ctx) { - ctx->ReplaceAll([&, ctx](ast::Function* func) -> ast::Function* { - if (func->pipeline_stage() != ast::PipelineStage::kVertex) { - return nullptr; // Just clone func - } - return CloneWithStatementsAtStart(ctx, func, - {get_pointsize_assign(ctx->mod)}); - }); - }); + ast::CloneContext(&out.module, in) + .ReplaceAll( + [&](ast::CloneContext* ctx, ast::Function* func) -> ast::Function* { + if (func->pipeline_stage() != ast::PipelineStage::kVertex) { + return nullptr; // Just clone func + } + return CloneWithStatementsAtStart(ctx, func, {pointsize_assign}); + }) + .Clone(); return out; } diff --git a/src/transform/first_index_offset.cc b/src/transform/first_index_offset.cc index b89b94f11e..671076c037 100644 --- a/src/transform/first_index_offset.cc +++ b/src/transform/first_index_offset.cc @@ -126,48 +126,50 @@ Transform::Output FirstIndexOffset::Run(ast::Module* in) { // these builtins. Output out; - out.module = in->Clone([&](ast::CloneContext* ctx) { - ctx->ReplaceAll([&, ctx](ast::Variable* var) -> ast::Variable* { - for (ast::VariableDecoration* dec : var->decorations()) { - if (auto* blt_dec = dec->As()) { - ast::Builtin blt_type = blt_dec->value(); - if (blt_type == ast::Builtin::kVertexIdx) { - vertex_index_name = var->name(); - has_vertex_index_ = true; - return clone_variable_with_new_name( - ctx, var, kIndexOffsetPrefix + var->name()); - } else if (blt_type == ast::Builtin::kInstanceIdx) { - instance_index_name = var->name(); - has_instance_index_ = true; - return clone_variable_with_new_name( - ctx, var, kIndexOffsetPrefix + var->name()); - } - } - } - return nullptr; // Just clone var - }); - ctx->ReplaceAll( // Note: This happens in the same pass as the rename above - // which determines the original builtin variable names, - // but this should be fine, as variables are cloned first. - [&, ctx](ast::Function* func) -> ast::Function* { - maybe_create_buffer_var(ctx->mod); - if (buffer_var == nullptr) { - return nullptr; // no transform need, just clone func - } - ast::StatementList statements; - for (const auto& data : func->local_referenced_builtin_variables()) { - if (data.second->value() == ast::Builtin::kVertexIdx) { - statements.emplace_back(CreateFirstIndexOffset( - vertex_index_name, kFirstVertexName, buffer_var, ctx->mod)); - } else if (data.second->value() == ast::Builtin::kInstanceIdx) { - statements.emplace_back(CreateFirstIndexOffset( - instance_index_name, kFirstInstanceName, buffer_var, - ctx->mod)); + ast::CloneContext(&out.module, in) + .ReplaceAll( + [&](ast::CloneContext* ctx, ast::Variable* var) -> ast::Variable* { + for (ast::VariableDecoration* dec : var->decorations()) { + if (auto* blt_dec = dec->As()) { + ast::Builtin blt_type = blt_dec->value(); + if (blt_type == ast::Builtin::kVertexIdx) { + vertex_index_name = var->name(); + has_vertex_index_ = true; + return clone_variable_with_new_name( + ctx, var, kIndexOffsetPrefix + var->name()); + } else if (blt_type == ast::Builtin::kInstanceIdx) { + instance_index_name = var->name(); + has_instance_index_ = true; + return clone_variable_with_new_name( + ctx, var, kIndexOffsetPrefix + var->name()); + } + } } - } - return CloneWithStatementsAtStart(ctx, func, statements); - }); - }); + return nullptr; // Just clone var + }) + .ReplaceAll( // Note: This happens in the same pass as the rename above + // which determines the original builtin variable names, + // but this should be fine, as variables are cloned first. + [&](ast::CloneContext* ctx, ast::Function* func) -> ast::Function* { + maybe_create_buffer_var(ctx->mod); + if (buffer_var == nullptr) { + return nullptr; // no transform need, just clone func + } + ast::StatementList statements; + for (const auto& data : + func->local_referenced_builtin_variables()) { + if (data.second->value() == ast::Builtin::kVertexIdx) { + statements.emplace_back(CreateFirstIndexOffset( + vertex_index_name, kFirstVertexName, buffer_var, ctx->mod)); + } else if (data.second->value() == ast::Builtin::kInstanceIdx) { + statements.emplace_back(CreateFirstIndexOffset( + instance_index_name, kFirstInstanceName, buffer_var, + ctx->mod)); + } + } + return CloneWithStatementsAtStart(ctx, func, statements); + }) + .Clone(); return out; } diff --git a/src/transform/transform.cc b/src/transform/transform.cc index a03b94394d..51fe645aa9 100644 --- a/src/transform/transform.cc +++ b/src/transform/transform.cc @@ -32,7 +32,7 @@ ast::Function* Transform::CloneWithStatementsAtStart( statements.emplace_back(ctx->Clone(s)); } return ctx->mod->create( - ctx->Clone(in->source()), in->symbol(), in->name(), + ctx->Clone(in->source()), ctx->Clone(in->symbol()), in->name(), ctx->Clone(in->params()), ctx->Clone(in->return_type()), ctx->mod->create(ctx->Clone(in->body()->source()), statements), diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc index b1fa0449da..5bbbcb02f0 100644 --- a/src/transform/vertex_pulling.cc +++ b/src/transform/vertex_pulling.cc @@ -98,21 +98,23 @@ Transform::Output VertexPulling::Run(ast::Module* in) { // TODO(idanr): Make sure we covered all error cases, to guarantee the // following stages will pass Output out; - out.module = in->Clone([&](ast::CloneContext* ctx) { - State state{in, ctx->mod, cfg}; - state.FindOrInsertVertexIndexIfUsed(); - state.FindOrInsertInstanceIndexIfUsed(); - state.ConvertVertexInputVariablesToPrivate(); - state.AddVertexStorageBuffers(); - ctx->ReplaceAll([func, ctx, state](ast::Function* f) -> ast::Function* { - if (f == func) { - return CloneWithStatementsAtStart( - ctx, f, {state.CreateVertexPullingPreamble()}); - } - return nullptr; // Just clone func - }); - }); + State state{in, &out.module, cfg}; + state.FindOrInsertVertexIndexIfUsed(); + state.FindOrInsertInstanceIndexIfUsed(); + state.ConvertVertexInputVariablesToPrivate(); + state.AddVertexStorageBuffers(); + + ast::CloneContext(&out.module, in) + .ReplaceAll( + [&](ast::CloneContext* ctx, ast::Function* f) -> ast::Function* { + if (f == func) { + return CloneWithStatementsAtStart( + ctx, f, {state.CreateVertexPullingPreamble()}); + } + return nullptr; // Just clone func + }) + .Clone(); return out; }