diff --git a/BUILD.gn b/BUILD.gn index af036e5bb2..5a7fb80342 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -251,6 +251,8 @@ source_set("libtint_core_src") { "src/ast/call_statement.h", "src/ast/case_statement.cc", "src/ast/case_statement.h", + "src/ast/clone_context.cc", + "src/ast/clone_context.h", "src/ast/constant_id_decoration.cc", "src/ast/constant_id_decoration.h", "src/ast/constructor_expression.cc", @@ -772,6 +774,7 @@ source_set("tint_unittests_core_src") { "src/ast/loop_statement_test.cc", "src/ast/member_accessor_expression_test.cc", "src/ast/module_test.cc", + "src/ast/module_clone_test.cc", "src/ast/null_literal_test.cc", "src/ast/return_statement_test.cc", "src/ast/scalar_constructor_expression_test.cc", @@ -1309,6 +1312,17 @@ if (build_with_chromium) { ] } } + + if (tint_build_wgsl_reader && tint_build_wgsl_writer) { + fuzzer_test("tint_spv_reader_fuzzer") { + sources = [ "fuzzers/tint_ast_clone_fuzzer.cc" ] + deps = [ + ":libtint_wgsl_reader_src", + ":libtint_wgsl_writer_src", + ":tint_fuzzer_common", + ] + } + } } ############################################################################### diff --git a/CMakeLists.txt b/CMakeLists.txt index 1d23434864..d30734b265 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,7 +28,7 @@ endif() option(TINT_BUILD_DOCS "Build documentation" ON) option(TINT_BUILD_SPV_READER "Build the SPIR-V input reader" ON) -option(TINT_BUILD_WGSL_READER "Builde the WGSL input reader" ON) +option(TINT_BUILD_WGSL_READER "Build the WGSL input reader" ON) option(TINT_BUILD_HLSL_WRITER "Build the HLSL output writer" ON) option(TINT_BUILD_MSL_WRITER "Build the MSL output writer" ON) option(TINT_BUILD_SPV_WRITER "Build the SPIR-V output writer" ON) diff --git a/fuzzers/CMakeLists.txt b/fuzzers/CMakeLists.txt index 3599d426ae..bfa50e7604 100644 --- a/fuzzers/CMakeLists.txt +++ b/fuzzers/CMakeLists.txt @@ -26,3 +26,7 @@ endif() if (${TINT_BUILD_SPV_READER}) add_tint_fuzzer(tint_spv_reader_fuzzer) endif() + +if (${TINT_BUILD_WGSL_READER} AND ${TINT_BUILD_WGSL_WRITER}) + add_tint_fuzzer(tint_ast_clone_fuzzer) +endif() diff --git a/fuzzers/tint_ast_clone_fuzzer.cc b/fuzzers/tint_ast_clone_fuzzer.cc new file mode 100644 index 0000000000..71c1e27c28 --- /dev/null +++ b/fuzzers/tint_ast_clone_fuzzer.cc @@ -0,0 +1,103 @@ +// Copyright 2020 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include + +#include "src/reader/wgsl/parser_impl.h" +#include "src/writer/wgsl/generator.h" + +#define ASSERT_EQ(A, B) \ + do { \ + decltype(A) assert_a = (A); \ + decltype(B) assert_b = (B); \ + if (assert_a != assert_b) { \ + std::cerr << "ASSERT_EQ(" #A ", " #B ") failed:\n" \ + << #A << " was: " << assert_a << "\n" \ + << #B << " was: " << assert_b << "\n"; \ + __builtin_trap(); \ + } \ + } while (false) + +#define ASSERT_TRUE(A) \ + do { \ + decltype(A) assert_a = (A); \ + if (!assert_a) { \ + std::cerr << "ASSERT_TRUE(" #A ") failed:\n" \ + << #A << " was: " << assert_a << "\n"; \ + __builtin_trap(); \ + } \ + } while (false) + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { + std::string str(reinterpret_cast(data), size); + + tint::Source::File file("test.wgsl", str); + + // Parse the wgsl, create the src module + tint::Context ctx; + tint::reader::wgsl::ParserImpl parser(&ctx, &file); + parser.set_max_errors(1); + if (!parser.Parse()) { + return 0; + } + auto src = parser.module(); + + // Clone the src module to dst + auto dst = src.Clone(); + + // Expect the AST printed with to_str() to match + ASSERT_EQ(src.to_str(), dst.to_str()); + + // Check that none of the AST nodes or type pointers in dst are found in src + std::unordered_set src_nodes; + for (auto& src_node : src.nodes()) { + src_nodes.emplace(src_node.get()); + } + std::unordered_set src_types; + for (auto& src_type : src.types()) { + src_types.emplace(src_type.second.get()); + } + for (auto& dst_node : dst.nodes()) { + ASSERT_EQ(src_nodes.count(dst_node.get()), 0u); + } + for (auto& dst_type : dst.types()) { + ASSERT_EQ(src_types.count(dst_type.second.get()), 0u); + } + + // Regenerate the wgsl for the src module. We use this instead of the original + // source so that reformatting doesn't impact the final wgsl comparision. + // Note that the src module is moved into the generator and this generator has + // a limited scope, so that the src module is released before we attempt to + // print the dst module. + // This guarantee that all the source module nodes and types are destructed + // and freed. + // ASAN should error if there's any remaining references in dst when we try to + // reconstruct the WGSL. + std::string src_wgsl; + { + tint::writer::wgsl::Generator src_gen(std::move(src)); + ASSERT_TRUE(src_gen.Generate()); + src_wgsl = src_gen.result(); + } + + // Print the dst module, check it matches the original source + tint::writer::wgsl::Generator dst_gen(std::move(dst)); + ASSERT_TRUE(dst_gen.Generate()); + auto dst_wgsl = dst_gen.result(); + ASSERT_EQ(src_wgsl, dst_wgsl); + + return 0; +} diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b5c44775cf..da3f67d2ce 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -72,6 +72,8 @@ set(TINT_LIB_SRCS ast/call_statement.h ast/case_statement.cc ast/case_statement.h + ast/clone_context.cc + ast/clone_context.h ast/constant_id_decoration.cc ast/constant_id_decoration.h ast/constructor_expression.cc @@ -381,6 +383,7 @@ set(TINT_TEST_SRCS ast/loop_statement_test.cc ast/member_accessor_expression_test.cc ast/module_test.cc + ast/module_clone_test.cc ast/null_literal_test.cc ast/binary_expression_test.cc ast/return_statement_test.cc diff --git a/src/ast/access_decoration.cc b/src/ast/access_decoration.cc index f77d78c8f3..45c8760ab4 100644 --- a/src/ast/access_decoration.cc +++ b/src/ast/access_decoration.cc @@ -14,6 +14,9 @@ #include "src/ast/access_decoration.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -27,5 +30,9 @@ void AccessDecoration::to_str(std::ostream& out, size_t indent) const { out << "AccessDecoration{" << value_ << "}" << std::endl; } +AccessDecoration* AccessDecoration::Clone(CloneContext* ctx) const { + return ctx->mod->create(value_, ctx->Clone(source())); +} + } // namespace ast } // namespace tint diff --git a/src/ast/access_decoration.h b/src/ast/access_decoration.h index 8e45062412..07d437e66f 100644 --- a/src/ast/access_decoration.h +++ b/src/ast/access_decoration.h @@ -40,6 +40,14 @@ class AccessDecoration : public Castable { /// @param indent number of spaces to indent the node when writing void to_str(std::ostream& out, size_t indent) const override; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + AccessDecoration* Clone(CloneContext* ctx) const override; + private: AccessControl value_ = ast::AccessControl::kReadWrite; }; diff --git a/src/ast/array_accessor_expression.cc b/src/ast/array_accessor_expression.cc index ded77166bf..75f3bf545c 100644 --- a/src/ast/array_accessor_expression.cc +++ b/src/ast/array_accessor_expression.cc @@ -14,6 +14,9 @@ #include "src/ast/array_accessor_expression.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -33,6 +36,12 @@ ArrayAccessorExpression::ArrayAccessorExpression(ArrayAccessorExpression&&) = ArrayAccessorExpression::~ArrayAccessorExpression() = default; +ArrayAccessorExpression* ArrayAccessorExpression::Clone( + CloneContext* ctx) const { + return ctx->mod->create(ctx->Clone(array_), + ctx->Clone(idx_expr_)); +} + bool ArrayAccessorExpression::IsValid() const { if (array_ == nullptr || !array_->IsValid()) return false; diff --git a/src/ast/array_accessor_expression.h b/src/ast/array_accessor_expression.h index 0f18d5e203..9cc98f75ea 100644 --- a/src/ast/array_accessor_expression.h +++ b/src/ast/array_accessor_expression.h @@ -57,6 +57,14 @@ class ArrayAccessorExpression /// @returns the index expression Expression* idx_expr() const { return idx_expr_; } + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + ArrayAccessorExpression* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/assignment_statement.cc b/src/ast/assignment_statement.cc index 3fe9a8b1f9..fe72839cfc 100644 --- a/src/ast/assignment_statement.cc +++ b/src/ast/assignment_statement.cc @@ -14,6 +14,9 @@ #include "src/ast/assignment_statement.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -31,6 +34,11 @@ AssignmentStatement::AssignmentStatement(AssignmentStatement&&) = default; AssignmentStatement::~AssignmentStatement() = default; +AssignmentStatement* AssignmentStatement::Clone(CloneContext* ctx) const { + return ctx->mod->create( + ctx->Clone(source()), ctx->Clone(lhs_), ctx->Clone(rhs_)); +} + bool AssignmentStatement::IsValid() const { if (lhs_ == nullptr || !lhs_->IsValid()) return false; diff --git a/src/ast/assignment_statement.h b/src/ast/assignment_statement.h index 08c972e422..2b61bd879b 100644 --- a/src/ast/assignment_statement.h +++ b/src/ast/assignment_statement.h @@ -55,6 +55,14 @@ class AssignmentStatement : public Castable { /// @returns the right side expression Expression* rhs() const { return rhs_; } + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + AssignmentStatement* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/binary_expression.cc b/src/ast/binary_expression.cc index c14f538e06..e6f3555b34 100644 --- a/src/ast/binary_expression.cc +++ b/src/ast/binary_expression.cc @@ -14,6 +14,9 @@ #include "src/ast/binary_expression.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -34,6 +37,11 @@ BinaryExpression::BinaryExpression(BinaryExpression&&) = default; BinaryExpression::~BinaryExpression() = default; +BinaryExpression* BinaryExpression::Clone(CloneContext* ctx) const { + return ctx->mod->create(ctx->Clone(source()), op_, + ctx->Clone(lhs_), ctx->Clone(rhs_)); +} + bool BinaryExpression::IsValid() const { if (lhs_ == nullptr || !lhs_->IsValid()) { return false; diff --git a/src/ast/binary_expression.h b/src/ast/binary_expression.h index 1431f18e37..783954f6cb 100644 --- a/src/ast/binary_expression.h +++ b/src/ast/binary_expression.h @@ -125,6 +125,14 @@ class BinaryExpression : public Castable { /// @returns the right side expression Expression* rhs() const { return rhs_; } + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + BinaryExpression* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/binding_decoration.cc b/src/ast/binding_decoration.cc index 34e6023504..bb64cc1e23 100644 --- a/src/ast/binding_decoration.cc +++ b/src/ast/binding_decoration.cc @@ -14,6 +14,9 @@ #include "src/ast/binding_decoration.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -27,5 +30,9 @@ void BindingDecoration::to_str(std::ostream& out, size_t indent) const { out << "BindingDecoration{" << value_ << "}" << std::endl; } +BindingDecoration* BindingDecoration::Clone(CloneContext* ctx) const { + return ctx->mod->create(value_, ctx->Clone(source())); +} + } // namespace ast } // namespace tint diff --git a/src/ast/binding_decoration.h b/src/ast/binding_decoration.h index ec2e756de3..0299ddcbf6 100644 --- a/src/ast/binding_decoration.h +++ b/src/ast/binding_decoration.h @@ -40,6 +40,14 @@ class BindingDecoration /// @param indent number of spaces to indent the node when writing void to_str(std::ostream& out, size_t indent) const override; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + BindingDecoration* Clone(CloneContext* ctx) const override; + private: uint32_t value_; }; diff --git a/src/ast/bitcast_expression.cc b/src/ast/bitcast_expression.cc index 4ebfc53e43..8d30b6aa7c 100644 --- a/src/ast/bitcast_expression.cc +++ b/src/ast/bitcast_expression.cc @@ -14,6 +14,9 @@ #include "src/ast/bitcast_expression.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -30,6 +33,11 @@ BitcastExpression::BitcastExpression(const Source& source, BitcastExpression::BitcastExpression(BitcastExpression&&) = default; BitcastExpression::~BitcastExpression() = default; +BitcastExpression* BitcastExpression::Clone(CloneContext* ctx) const { + return ctx->mod->create( + ctx->Clone(source()), ctx->Clone(type_), ctx->Clone(expr_)); +} + bool BitcastExpression::IsValid() const { if (expr_ == nullptr || !expr_->IsValid()) return false; diff --git a/src/ast/bitcast_expression.h b/src/ast/bitcast_expression.h index a898036907..f736071777 100644 --- a/src/ast/bitcast_expression.h +++ b/src/ast/bitcast_expression.h @@ -55,6 +55,14 @@ class BitcastExpression : public Castable { /// @returns the expression Expression* expr() const { return expr_; } + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + BitcastExpression* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/block_statement.cc b/src/ast/block_statement.cc index 7ab8d9bdb8..ee636a67d9 100644 --- a/src/ast/block_statement.cc +++ b/src/ast/block_statement.cc @@ -14,6 +14,9 @@ #include "src/ast/block_statement.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -25,6 +28,12 @@ BlockStatement::BlockStatement(BlockStatement&&) = default; BlockStatement::~BlockStatement() = default; +BlockStatement* BlockStatement::Clone(CloneContext* ctx) const { + auto* cloned = ctx->mod->create(ctx->Clone(source())); + cloned->statements_ = ctx->Clone(statements_); + return cloned; +} + bool BlockStatement::IsValid() const { for (auto* stmt : *this) { if (stmt == nullptr || !stmt->IsValid()) { diff --git a/src/ast/block_statement.h b/src/ast/block_statement.h index f5803c3bfd..4490606076 100644 --- a/src/ast/block_statement.h +++ b/src/ast/block_statement.h @@ -85,6 +85,14 @@ class BlockStatement : public Castable { return statements_.end(); } + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + BlockStatement* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/bool_literal.cc b/src/ast/bool_literal.cc index 7f909a6017..3e481d501f 100644 --- a/src/ast/bool_literal.cc +++ b/src/ast/bool_literal.cc @@ -14,6 +14,9 @@ #include "src/ast/bool_literal.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -30,5 +33,9 @@ std::string BoolLiteral::name() const { return value_ ? "__bool_true" : "__bool_false"; } +BoolLiteral* BoolLiteral::Clone(CloneContext* ctx) const { + return ctx->mod->create(ctx->Clone(type()), value_); +} + } // namespace ast } // namespace tint diff --git a/src/ast/bool_literal.h b/src/ast/bool_literal.h index 4b5acf36fe..c95ca827fb 100644 --- a/src/ast/bool_literal.h +++ b/src/ast/bool_literal.h @@ -42,6 +42,14 @@ class BoolLiteral : public Castable { /// @returns the literal as a string std::string to_str() const override; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + BoolLiteral* Clone(CloneContext* ctx) const override; + private: bool value_; }; diff --git a/src/ast/break_statement.cc b/src/ast/break_statement.cc index 1894ba4fc0..cbccb00066 100644 --- a/src/ast/break_statement.cc +++ b/src/ast/break_statement.cc @@ -14,6 +14,9 @@ #include "src/ast/break_statement.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -25,6 +28,10 @@ BreakStatement::BreakStatement(BreakStatement&&) = default; BreakStatement::~BreakStatement() = default; +BreakStatement* BreakStatement::Clone(CloneContext* ctx) const { + return ctx->mod->create(ctx->Clone(source())); +} + bool BreakStatement::IsValid() const { return true; } diff --git a/src/ast/break_statement.h b/src/ast/break_statement.h index 2a72c0384c..83d902eb4c 100644 --- a/src/ast/break_statement.h +++ b/src/ast/break_statement.h @@ -32,6 +32,14 @@ class BreakStatement : public Castable { BreakStatement(BreakStatement&&); ~BreakStatement() override; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + BreakStatement* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/builtin_decoration.cc b/src/ast/builtin_decoration.cc index 856fc9a91b..4070be8761 100644 --- a/src/ast/builtin_decoration.cc +++ b/src/ast/builtin_decoration.cc @@ -14,6 +14,9 @@ #include "src/ast/builtin_decoration.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -27,5 +30,9 @@ void BuiltinDecoration::to_str(std::ostream& out, size_t indent) const { out << "BuiltinDecoration{" << builtin_ << "}" << std::endl; } +BuiltinDecoration* BuiltinDecoration::Clone(CloneContext* ctx) const { + return ctx->mod->create(builtin_, ctx->Clone(source())); +} + } // namespace ast } // namespace tint diff --git a/src/ast/builtin_decoration.h b/src/ast/builtin_decoration.h index 9f4636d937..c907f6975e 100644 --- a/src/ast/builtin_decoration.h +++ b/src/ast/builtin_decoration.h @@ -39,6 +39,14 @@ class BuiltinDecoration /// @param indent number of spaces to indent the node when writing void to_str(std::ostream& out, size_t indent) const override; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + BuiltinDecoration* Clone(CloneContext* ctx) const override; + private: Builtin builtin_ = Builtin::kNone; }; diff --git a/src/ast/call_expression.cc b/src/ast/call_expression.cc index c17a1464a5..fdee36ef1a 100644 --- a/src/ast/call_expression.cc +++ b/src/ast/call_expression.cc @@ -14,6 +14,9 @@ #include "src/ast/call_expression.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -31,6 +34,11 @@ CallExpression::CallExpression(CallExpression&&) = default; CallExpression::~CallExpression() = default; +CallExpression* CallExpression::Clone(CloneContext* ctx) const { + return ctx->mod->create( + ctx->Clone(source()), ctx->Clone(func_), ctx->Clone(params_)); +} + bool CallExpression::IsValid() const { if (func_ == nullptr || !func_->IsValid()) return false; diff --git a/src/ast/call_expression.h b/src/ast/call_expression.h index 0e27008e80..3c12822bd6 100644 --- a/src/ast/call_expression.h +++ b/src/ast/call_expression.h @@ -54,6 +54,14 @@ class CallExpression : public Castable { /// @returns the parameters const ExpressionList& params() const { return params_; } + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + CallExpression* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/call_statement.cc b/src/ast/call_statement.cc index c37c99a827..9158fbb988 100644 --- a/src/ast/call_statement.cc +++ b/src/ast/call_statement.cc @@ -15,6 +15,8 @@ #include "src/ast/call_statement.h" #include "src/ast/call_expression.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" namespace tint { namespace ast { @@ -27,6 +29,10 @@ CallStatement::CallStatement(CallStatement&&) = default; CallStatement::~CallStatement() = default; +CallStatement* CallStatement::Clone(CloneContext* ctx) const { + return ctx->mod->create(ctx->Clone(call_)); +} + bool CallStatement::IsValid() const { return call_ != nullptr && call_->IsValid(); } diff --git a/src/ast/call_statement.h b/src/ast/call_statement.h index 1657388fe5..026e61f346 100644 --- a/src/ast/call_statement.h +++ b/src/ast/call_statement.h @@ -42,6 +42,14 @@ class CallStatement : public Castable { /// @returns the call expression CallExpression* expr() const { return call_; } + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + CallStatement* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/case_statement.cc b/src/ast/case_statement.cc index 2d5018568d..35227a8a1a 100644 --- a/src/ast/case_statement.cc +++ b/src/ast/case_statement.cc @@ -14,6 +14,9 @@ #include "src/ast/case_statement.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -31,6 +34,11 @@ CaseStatement::CaseStatement(CaseStatement&&) = default; CaseStatement::~CaseStatement() = default; +CaseStatement* CaseStatement::Clone(CloneContext* ctx) const { + return ctx->mod->create( + ctx->Clone(source()), ctx->Clone(selectors_), ctx->Clone(body_)); +} + bool CaseStatement::IsValid() const { return body_ != nullptr && body_->IsValid(); } diff --git a/src/ast/case_statement.h b/src/ast/case_statement.h index ef39af0e1e..9ed2deb30b 100644 --- a/src/ast/case_statement.h +++ b/src/ast/case_statement.h @@ -70,6 +70,14 @@ class CaseStatement : public Castable { /// @returns the case body BlockStatement* body() { return body_; } + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + CaseStatement* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/clone_context.cc b/src/ast/clone_context.cc new file mode 100644 index 0000000000..9ec4bfe894 --- /dev/null +++ b/src/ast/clone_context.cc @@ -0,0 +1,24 @@ +// Copyright 2020 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/ast/clone_context.h" + +namespace tint { +namespace ast { + +CloneContext::CloneContext(Module* m) : mod(m) {} +CloneContext::~CloneContext() = default; + +} // namespace ast +} // namespace tint diff --git a/src/ast/clone_context.h b/src/ast/clone_context.h new file mode 100644 index 0000000000..deb31818cb --- /dev/null +++ b/src/ast/clone_context.h @@ -0,0 +1,90 @@ +// Copyright 2020 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef SRC_AST_CLONE_CONTEXT_H_ +#define SRC_AST_CLONE_CONTEXT_H_ + +#include +#include + +#include "src/source.h" + +namespace tint { +namespace ast { + +class Module; + +/// CloneContext holds the state used while cloning AST nodes and types. +class CloneContext { + public: + /// Constructor + /// @param m the target module to clone into + explicit CloneContext(Module* m); + /// Destructor + ~CloneContext(); + + /// Clones the `Node` or `type::Type` @p a into the module #mod if @p a is not + /// null. If @p a is null, then Clone() returns null. If @p a has been cloned + /// already by this CloneContext then the same cloned pointer is returned. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param a the `Node` or `type::Type` to clone + /// @return the cloned node + template + T* Clone(T* a) { + if (a == nullptr) { + return nullptr; + } + + auto it = cloned_.find(a); + if (it != cloned_.end()) { + return static_cast(it->second); + } + auto* c = a->Clone(this); + cloned_.emplace(a, c); + return static_cast(c); + } + + /// Clones the `Source` @p s into @p mod + /// TODO(bclayton) - Currently this 'clone' is a shallow copy. If/when + /// `Source.File`s are owned by the `Module` this should make a copy of the + /// file. + /// @param s the `Source` to clone + /// @return the cloned source + Source Clone(const Source& s) { return s; } + + /// Clones each of the elements of the vector @p v into the module #mod. + /// @param v the vector to clone + /// @return the cloned vector + template + std::vector Clone(const std::vector& v) { + std::vector out; + out.reserve(v.size()); + for (auto& el : v) { + out.emplace_back(Clone(el)); + } + return out; + } + + /// The target module to clone into. + Module* const mod; + + private: + std::unordered_map cloned_; +}; + +} // namespace ast +} // namespace tint + +#endif // SRC_AST_CLONE_CONTEXT_H_ diff --git a/src/ast/constant_id_decoration.cc b/src/ast/constant_id_decoration.cc index 6b10150f50..f6de69cf25 100644 --- a/src/ast/constant_id_decoration.cc +++ b/src/ast/constant_id_decoration.cc @@ -14,6 +14,9 @@ #include "src/ast/constant_id_decoration.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -27,5 +30,9 @@ void ConstantIdDecoration::to_str(std::ostream& out, size_t indent) const { out << "ConstantIdDecoration{" << value_ << "}" << std::endl; } +ConstantIdDecoration* ConstantIdDecoration::Clone(CloneContext* ctx) const { + return ctx->mod->create(value_, ctx->Clone(source())); +} + } // namespace ast } // namespace tint diff --git a/src/ast/constant_id_decoration.h b/src/ast/constant_id_decoration.h index d7a02561b5..d683e2ad60 100644 --- a/src/ast/constant_id_decoration.h +++ b/src/ast/constant_id_decoration.h @@ -39,6 +39,14 @@ class ConstantIdDecoration /// @param indent number of spaces to indent the node when writing void to_str(std::ostream& out, size_t indent) const override; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + ConstantIdDecoration* Clone(CloneContext* ctx) const override; + private: uint32_t value_ = 0; }; diff --git a/src/ast/continue_statement.cc b/src/ast/continue_statement.cc index 1f9da2a9f0..0e5d825242 100644 --- a/src/ast/continue_statement.cc +++ b/src/ast/continue_statement.cc @@ -14,6 +14,9 @@ #include "src/ast/continue_statement.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -25,6 +28,10 @@ ContinueStatement::ContinueStatement(ContinueStatement&&) = default; ContinueStatement::~ContinueStatement() = default; +ContinueStatement* ContinueStatement::Clone(CloneContext* ctx) const { + return ctx->mod->create(ctx->Clone(source())); +} + bool ContinueStatement::IsValid() const { return true; } diff --git a/src/ast/continue_statement.h b/src/ast/continue_statement.h index b2a01f9b6e..e68f9e6137 100644 --- a/src/ast/continue_statement.h +++ b/src/ast/continue_statement.h @@ -35,6 +35,14 @@ class ContinueStatement : public Castable { ContinueStatement(ContinueStatement&&); ~ContinueStatement() override; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + ContinueStatement* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/decorated_variable.cc b/src/ast/decorated_variable.cc index 447269294a..63ddbbbd11 100644 --- a/src/ast/decorated_variable.cc +++ b/src/ast/decorated_variable.cc @@ -17,8 +17,10 @@ #include #include "src/ast/builtin_decoration.h" +#include "src/ast/clone_context.h" #include "src/ast/constant_id_decoration.h" #include "src/ast/location_decoration.h" +#include "src/ast/module.h" namespace tint { namespace ast { @@ -69,6 +71,18 @@ uint32_t DecoratedVariable::constant_id() const { return 0; } +DecoratedVariable* DecoratedVariable::Clone(CloneContext* ctx) const { + auto* cloned = ctx->mod->create(); + cloned->set_source(ctx->Clone(source())); + cloned->set_name(name()); + cloned->set_storage_class(storage_class()); + cloned->set_type(ctx->Clone(type())); + cloned->set_constructor(ctx->Clone(constructor())); + cloned->set_is_const(is_const()); + cloned->set_decorations(ctx->Clone(decorations())); + return cloned; +} + bool DecoratedVariable::IsValid() const { return Variable::IsValid(); } diff --git a/src/ast/decorated_variable.h b/src/ast/decorated_variable.h index b0067b5e45..b78abbc862 100644 --- a/src/ast/decorated_variable.h +++ b/src/ast/decorated_variable.h @@ -56,6 +56,14 @@ class DecoratedVariable : public Castable { /// |HasConstantIdDecoration| has been called first. uint32_t constant_id() const; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + DecoratedVariable* Clone(CloneContext* ctx) const override; + /// @returns true if the name and path are both present bool IsValid() const override; diff --git a/src/ast/discard_statement.cc b/src/ast/discard_statement.cc index b70856b92e..7db671d276 100644 --- a/src/ast/discard_statement.cc +++ b/src/ast/discard_statement.cc @@ -14,6 +14,9 @@ #include "src/ast/discard_statement.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -25,6 +28,10 @@ DiscardStatement::DiscardStatement(DiscardStatement&&) = default; DiscardStatement::~DiscardStatement() = default; +DiscardStatement* DiscardStatement::Clone(CloneContext* ctx) const { + return ctx->mod->create(ctx->Clone(source())); +} + bool DiscardStatement::IsValid() const { return true; } diff --git a/src/ast/discard_statement.h b/src/ast/discard_statement.h index ba7b3982bc..c5afd77ab9 100644 --- a/src/ast/discard_statement.h +++ b/src/ast/discard_statement.h @@ -32,6 +32,14 @@ class DiscardStatement : public Castable { DiscardStatement(DiscardStatement&&); ~DiscardStatement() override; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + DiscardStatement* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/else_statement.cc b/src/ast/else_statement.cc index 937755b25d..21d804b6a0 100644 --- a/src/ast/else_statement.cc +++ b/src/ast/else_statement.cc @@ -14,6 +14,9 @@ #include "src/ast/else_statement.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -34,6 +37,11 @@ ElseStatement::ElseStatement(ElseStatement&&) = default; ElseStatement::~ElseStatement() = default; +ElseStatement* ElseStatement::Clone(CloneContext* ctx) const { + return ctx->mod->create( + ctx->Clone(source()), ctx->Clone(condition_), ctx->Clone(body_)); +} + bool ElseStatement::IsValid() const { if (body_ == nullptr || !body_->IsValid()) { return false; diff --git a/src/ast/else_statement.h b/src/ast/else_statement.h index 60a675ab58..01be9a8dfd 100644 --- a/src/ast/else_statement.h +++ b/src/ast/else_statement.h @@ -67,6 +67,14 @@ class ElseStatement : public Castable { /// @returns the else body BlockStatement* body() { return body_; } + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + ElseStatement* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/expression_test.cc b/src/ast/expression_test.cc index 62538e8e83..01c177d050 100644 --- a/src/ast/expression_test.cc +++ b/src/ast/expression_test.cc @@ -26,6 +26,7 @@ class Expr : public Expression { public: Expr() : Expression() {} + Expr* Clone(CloneContext*) const override { return nullptr; } bool IsValid() const override { return true; } void to_str(std::ostream&, size_t) const override {} }; diff --git a/src/ast/fallthrough_statement.cc b/src/ast/fallthrough_statement.cc index cd5f00d874..f193a8d829 100644 --- a/src/ast/fallthrough_statement.cc +++ b/src/ast/fallthrough_statement.cc @@ -14,6 +14,9 @@ #include "src/ast/fallthrough_statement.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -26,6 +29,10 @@ FallthroughStatement::FallthroughStatement(FallthroughStatement&&) = default; FallthroughStatement::~FallthroughStatement() = default; +FallthroughStatement* FallthroughStatement::Clone(CloneContext* ctx) const { + return ctx->mod->create(ctx->Clone(source())); +} + bool FallthroughStatement::IsValid() const { return true; } diff --git a/src/ast/fallthrough_statement.h b/src/ast/fallthrough_statement.h index 5b0bc81c51..5f652bda98 100644 --- a/src/ast/fallthrough_statement.h +++ b/src/ast/fallthrough_statement.h @@ -32,6 +32,14 @@ class FallthroughStatement : public Castable { FallthroughStatement(FallthroughStatement&&); ~FallthroughStatement() override; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + FallthroughStatement* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/float_literal.cc b/src/ast/float_literal.cc index c781afa64c..6d182ba196 100644 --- a/src/ast/float_literal.cc +++ b/src/ast/float_literal.cc @@ -17,6 +17,9 @@ #include #include +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -37,5 +40,9 @@ std::string FloatLiteral::name() const { return out.str(); } +FloatLiteral* FloatLiteral::Clone(CloneContext* ctx) const { + return ctx->mod->create(ctx->Clone(type()), value_); +} + } // namespace ast } // namespace tint diff --git a/src/ast/float_literal.h b/src/ast/float_literal.h index f474c6773a..9e55e47e53 100644 --- a/src/ast/float_literal.h +++ b/src/ast/float_literal.h @@ -40,6 +40,14 @@ class FloatLiteral : public Castable { /// @returns the literal as a string std::string to_str() const override; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + FloatLiteral* Clone(CloneContext* ctx) const override; + private: float value_; }; diff --git a/src/ast/function.cc b/src/ast/function.cc index 6c9f5d425d..e3c5ca27c6 100644 --- a/src/ast/function.cc +++ b/src/ast/function.cc @@ -16,7 +16,9 @@ #include +#include "src/ast/clone_context.h" #include "src/ast/decorated_variable.h" +#include "src/ast/module.h" #include "src/ast/stage_decoration.h" #include "src/ast/type/multisampled_texture_type.h" #include "src/ast/type/sampled_texture_type.h" @@ -213,6 +215,14 @@ const Statement* Function::get_last_statement() const { return body_->last(); } +Function* Function::Clone(CloneContext* ctx) const { + auto* cloned = ctx->mod->create( + ctx->Clone(source()), name_, ctx->Clone(params_), + ctx->Clone(return_type_), ctx->Clone(body_)); + cloned->set_decorations(ctx->Clone(decorations_)); + return cloned; +} + bool Function::IsValid() const { for (auto* param : params_) { if (param == nullptr || !param->IsValid()) diff --git a/src/ast/function.h b/src/ast/function.h index f539224eed..a7afe1e753 100644 --- a/src/ast/function.h +++ b/src/ast/function.h @@ -190,6 +190,14 @@ class Function : public Castable { /// @returns the function body BlockStatement* body() { return body_; } + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + Function* Clone(CloneContext* ctx) const override; + /// @returns true if the name and type are both present bool IsValid() const override; diff --git a/src/ast/identifier_expression.cc b/src/ast/identifier_expression.cc index cf30ddca6b..6952d31f31 100644 --- a/src/ast/identifier_expression.cc +++ b/src/ast/identifier_expression.cc @@ -14,6 +14,9 @@ #include "src/ast/identifier_expression.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -28,6 +31,10 @@ IdentifierExpression::IdentifierExpression(IdentifierExpression&&) = default; IdentifierExpression::~IdentifierExpression() = default; +IdentifierExpression* IdentifierExpression::Clone(CloneContext* ctx) const { + return ctx->mod->create(ctx->Clone(source()), name_); +} + bool IdentifierExpression::IsValid() const { return !name_.empty(); } diff --git a/src/ast/identifier_expression.h b/src/ast/identifier_expression.h index 708648d642..703c041b11 100644 --- a/src/ast/identifier_expression.h +++ b/src/ast/identifier_expression.h @@ -61,6 +61,14 @@ class IdentifierExpression : public Castable { /// @returns true if this identifier is for an intrinsic bool IsIntrinsic() const { return intrinsic_ != Intrinsic::kNone; } + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + IdentifierExpression* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/if_statement.cc b/src/ast/if_statement.cc index c8406b74c3..5bdf9ce9c0 100644 --- a/src/ast/if_statement.cc +++ b/src/ast/if_statement.cc @@ -14,7 +14,9 @@ #include "src/ast/if_statement.h" +#include "src/ast/clone_context.h" #include "src/ast/else_statement.h" +#include "src/ast/module.h" namespace tint { namespace ast { @@ -31,6 +33,13 @@ IfStatement::IfStatement(IfStatement&&) = default; IfStatement::~IfStatement() = default; +IfStatement* IfStatement::Clone(CloneContext* ctx) const { + auto* cloned = ctx->mod->create( + ctx->Clone(source()), ctx->Clone(condition_), ctx->Clone(body_)); + cloned->else_statements_ = ctx->Clone(else_statements_); + return cloned; +} + bool IfStatement::IsValid() const { if (condition_ == nullptr || !condition_->IsValid()) { return false; diff --git a/src/ast/if_statement.h b/src/ast/if_statement.h index c6bb79f1cb..5087f28ad8 100644 --- a/src/ast/if_statement.h +++ b/src/ast/if_statement.h @@ -71,6 +71,14 @@ class IfStatement : public Castable { /// @returns true if there are else statements bool has_else_statements() const { return !else_statements_.empty(); } + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + IfStatement* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/location_decoration.cc b/src/ast/location_decoration.cc index de2beae327..3abffc958a 100644 --- a/src/ast/location_decoration.cc +++ b/src/ast/location_decoration.cc @@ -14,6 +14,9 @@ #include "src/ast/location_decoration.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -27,5 +30,9 @@ void LocationDecoration::to_str(std::ostream& out, size_t indent) const { out << "LocationDecoration{" << value_ << "}" << std::endl; } +LocationDecoration* LocationDecoration::Clone(CloneContext* ctx) const { + return ctx->mod->create(value_, ctx->Clone(source())); +} + } // namespace ast } // namespace tint diff --git a/src/ast/location_decoration.h b/src/ast/location_decoration.h index 6e75d63fe2..560c3d69a5 100644 --- a/src/ast/location_decoration.h +++ b/src/ast/location_decoration.h @@ -40,6 +40,14 @@ class LocationDecoration /// @param indent number of spaces to indent the node when writing void to_str(std::ostream& out, size_t indent) const override; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + LocationDecoration* Clone(CloneContext* ctx) const override; + private: uint32_t value_; }; diff --git a/src/ast/loop_statement.cc b/src/ast/loop_statement.cc index 220201f6b8..72b9a973e9 100644 --- a/src/ast/loop_statement.cc +++ b/src/ast/loop_statement.cc @@ -14,6 +14,9 @@ #include "src/ast/loop_statement.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -29,6 +32,11 @@ LoopStatement::LoopStatement(LoopStatement&&) = default; LoopStatement::~LoopStatement() = default; +LoopStatement* LoopStatement::Clone(CloneContext* ctx) const { + return ctx->mod->create( + ctx->Clone(source()), ctx->Clone(body_), ctx->Clone(continuing_)); +} + bool LoopStatement::IsValid() const { if (body_ == nullptr || !body_->IsValid()) { return false; diff --git a/src/ast/loop_statement.h b/src/ast/loop_statement.h index a107e08c49..374c669703 100644 --- a/src/ast/loop_statement.h +++ b/src/ast/loop_statement.h @@ -62,6 +62,14 @@ class LoopStatement : public Castable { return continuing_ != nullptr && !continuing_->empty(); } + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + LoopStatement* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/member_accessor_expression.cc b/src/ast/member_accessor_expression.cc index 64d8aad1d7..8f8184631d 100644 --- a/src/ast/member_accessor_expression.cc +++ b/src/ast/member_accessor_expression.cc @@ -14,6 +14,9 @@ #include "src/ast/member_accessor_expression.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -33,6 +36,12 @@ MemberAccessorExpression::MemberAccessorExpression(MemberAccessorExpression&&) = MemberAccessorExpression::~MemberAccessorExpression() = default; +MemberAccessorExpression* MemberAccessorExpression::Clone( + CloneContext* ctx) const { + return ctx->mod->create( + ctx->Clone(source()), ctx->Clone(struct_), ctx->Clone(member_)); +} + bool MemberAccessorExpression::IsValid() const { if (struct_ == nullptr || !struct_->IsValid()) { return false; diff --git a/src/ast/member_accessor_expression.h b/src/ast/member_accessor_expression.h index 2af58397c9..4b4f8590ca 100644 --- a/src/ast/member_accessor_expression.h +++ b/src/ast/member_accessor_expression.h @@ -59,6 +59,14 @@ class MemberAccessorExpression /// @returns the member expression IdentifierExpression* member() const { return member_; } + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + MemberAccessorExpression* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/module.cc b/src/ast/module.cc index 748a984cf3..e9fecf125d 100644 --- a/src/ast/module.cc +++ b/src/ast/module.cc @@ -16,6 +16,7 @@ #include +#include "src/ast/clone_context.h" #include "src/ast/type/struct_type.h" namespace tint { @@ -27,6 +28,23 @@ Module::Module(Module&&) = default; Module::~Module() = default; +Module Module::Clone() { + Module out; + + CloneContext ctx(&out); + for (auto* ty : constructed_types_) { + out.constructed_types_.emplace_back(ctx.Clone(ty)); + } + for (auto* var : global_variables_) { + out.global_variables_.emplace_back(ctx.Clone(var)); + } + for (auto* func : functions_) { + out.functions_.emplace_back(ctx.Clone(func)); + } + + return out; +} + Function* Module::FindFunctionByName(const std::string& name) const { for (auto* func : functions_) { if (func->name() == name) { diff --git a/src/ast/module.h b/src/ast/module.h index ded2225bc2..9caf6aa3cf 100644 --- a/src/ast/module.h +++ b/src/ast/module.h @@ -42,6 +42,9 @@ class Module { Module(Module&&); ~Module(); + /// @return a deep copy of this module + Module Clone(); + /// Add a global variable to the module /// @param var the variable to add void AddGlobalVariable(Variable* var) { global_variables_.push_back(var); } @@ -135,6 +138,9 @@ class Module { return type_mgr_.types(); } + /// @returns all the declared nodes in the module + const std::vector>& nodes() { return ast_nodes_; } + private: Module(const Module&) = delete; diff --git a/src/ast/module_clone_test.cc b/src/ast/module_clone_test.cc new file mode 100644 index 0000000000..93ec7ee8fb --- /dev/null +++ b/src/ast/module_clone_test.cc @@ -0,0 +1,168 @@ +// Copyright 2020 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "src/ast/case_statement.h" + +#include "gtest/gtest.h" +#include "src/reader/wgsl/parser.h" +#include "src/writer/wgsl/generator.h" + +namespace tint { +namespace ast { +namespace { + +TEST(ModuleCloneTest, Clone) { +#if TINT_BUILD_WGSL_READER && TINT_BUILD_WGSL_WRITER + // Shader that exercises the bulk of the AST nodes and types. + // See also fuzzers/tint_ast_clone_fuzzer.cc for further coverage of cloning. + Source::File file("test.wgsl", R"([[block]] +struct S { + [[offset(0)]] + m0 : u32; + [[offset(4)]] + m1 : array; +}; + +type t0 = [[stride(16)]] array>; +type t1 = [[stride(32)]] array>; + +const c0 : i32 = 10; +const c1 : bool = true; + +var g0 : u32 = 20u; +var g1 : f32 = 123.0; +var g2 : texture_2d; +var g3 : texture_storage_ro_2d; +var g4 : texture_storage_wo_2d; +var g5 : texture_storage_ro_2d; +var g6 : texture_storage_wo_2d; + +[[builtin(position)]] var g7 : vec3; +[[set(10), binding(20)]] var g7 : S; +[[set(10), binding(20)]] var g8 : [[access(read)]] +S; +[[set(10), binding(20)]] var g9 : [[access(read_write)]] +S; + +fn f0(p0 : bool) -> f32 { + if (p0) { + return 1.0; + } + return 0.0; +} + +fn f1(p0 : f32, p1 : i32) -> f32 { + var l0 : i32 = 3; + var l1 : f32 = 8; + var l2 : u32 = bitcast(4); + var l3 : vec2 = vec2(l0, l1); + var l4 : S; + var l5 : u32 = l4.m1[5]; + var l6 : ptr; + l6 = null; + loop { + l0 = (p1 + 2); + if (((l0 % 4) == 0)) { + continue; + } + + continuing { + if (1 == 2) { + l0 = l0 - 1; + } else { + l0 = l0 - 2; + } + } + } + switch(l2) { + case 0: { + break; + } + case 1: { + return f0(true); + } + default: { + discard; + } + } + return 1.0; +} + +[[stage(fragment)]] +fn main() -> void { + f1(1.0, 2); +} + +)"); + + // Parse the wgsl, create the src module + Context ctx; + reader::wgsl::Parser parser(&ctx, &file); + ASSERT_TRUE(parser.Parse()) << parser.error(); + auto src = parser.module(); + + // Clone the src module to dst + auto dst = src.Clone(); + + // Expect the AST printed with to_str() to match + EXPECT_EQ(src.to_str(), dst.to_str()); + + // Check that none of the AST nodes or type pointers in dst are found in src + std::unordered_set src_nodes; + for (auto& src_node : src.nodes()) { + src_nodes.emplace(src_node.get()); + } + std::unordered_set src_types; + for (auto& src_type : src.types()) { + src_types.emplace(src_type.second.get()); + } + for (auto& dst_node : dst.nodes()) { + ASSERT_EQ(src_nodes.count(dst_node.get()), 0u) << dst_node->str(); + } + for (auto& dst_type : dst.types()) { + ASSERT_EQ(src_types.count(dst_type.second.get()), 0u) + << dst_type.second->type_name(); + } + + // Regenerate the wgsl for the src module. We use this instead of the original + // source so that reformatting doesn't impact the final wgsl comparision. + // Note that the src module is moved into the generator and this generator has + // a limited scope, so that the src module is released before we attempt to + // print the dst module. + // This guarantee that all the source module nodes and types are destructed + // and freed. + // ASAN should error if there's any remaining references in dst when we try to + // reconstruct the WGSL. + std::string src_wgsl; + { + writer::wgsl::Generator src_gen(std::move(src)); + ASSERT_TRUE(src_gen.Generate()); + src_wgsl = src_gen.result(); + } + + // Print the dst module, check it matches the original source + writer::wgsl::Generator dst_gen(std::move(dst)); + ASSERT_TRUE(dst_gen.Generate()); + auto dst_wgsl = dst_gen.result(); + ASSERT_EQ(src_wgsl, dst_wgsl); + +#else // #if TINT_BUILD_WGSL_READER && TINT_BUILD_WGSL_WRITER + GTEST_SKIP() << "ModuleCloneTest requires TINT_BUILD_WGSL_READER and " + "TINT_BUILD_WGSL_WRITER to be enabled"; +#endif +} + +} // namespace +} // namespace ast +} // namespace tint diff --git a/src/ast/node.h b/src/ast/node.h index 9bfe8d6711..916ced842c 100644 --- a/src/ast/node.h +++ b/src/ast/node.h @@ -17,6 +17,7 @@ #include #include +#include #include "src/castable.h" #include "src/source.h" @@ -24,11 +25,26 @@ namespace tint { namespace ast { +class Module; +class CloneContext; + +namespace type { +class Type; +} + /// AST base class node class Node : public Castable { public: ~Node() override; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + virtual Node* Clone(CloneContext* ctx) const = 0; + /// @returns the node source data const Source& source() const { return source_; } /// Sets the source data diff --git a/src/ast/null_literal.cc b/src/ast/null_literal.cc index eb3d58942d..cb2bf9496c 100644 --- a/src/ast/null_literal.cc +++ b/src/ast/null_literal.cc @@ -14,6 +14,9 @@ #include "src/ast/null_literal.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -29,5 +32,9 @@ std::string NullLiteral::name() const { return "__null" + type()->type_name(); } +NullLiteral* NullLiteral::Clone(CloneContext* ctx) const { + return ctx->mod->create(ctx->Clone(type())); +} + } // namespace ast } // namespace tint diff --git a/src/ast/null_literal.h b/src/ast/null_literal.h index 7cffcadd0b..8ae3ff03de 100644 --- a/src/ast/null_literal.h +++ b/src/ast/null_literal.h @@ -35,6 +35,14 @@ class NullLiteral : public Castable { /// @returns the literal as a string std::string to_str() const override; + + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + NullLiteral* Clone(CloneContext* ctx) const override; }; } // namespace ast diff --git a/src/ast/return_statement.cc b/src/ast/return_statement.cc index 138618fe1e..698daf3fc2 100644 --- a/src/ast/return_statement.cc +++ b/src/ast/return_statement.cc @@ -14,6 +14,9 @@ #include "src/ast/return_statement.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -30,6 +33,11 @@ ReturnStatement::ReturnStatement(ReturnStatement&&) = default; ReturnStatement::~ReturnStatement() = default; +ReturnStatement* ReturnStatement::Clone(CloneContext* ctx) const { + return ctx->mod->create(ctx->Clone(source()), + ctx->Clone(value_)); +} + bool ReturnStatement::IsValid() const { if (value_ != nullptr) { return value_->IsValid(); diff --git a/src/ast/return_statement.h b/src/ast/return_statement.h index f2426fa6f9..c90079b891 100644 --- a/src/ast/return_statement.h +++ b/src/ast/return_statement.h @@ -51,6 +51,14 @@ class ReturnStatement : public Castable { /// @returns true if the return has a value bool has_value() const { return value_ != nullptr; } + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + ReturnStatement* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/scalar_constructor_expression.cc b/src/ast/scalar_constructor_expression.cc index 2c05b4d415..7a68042e6b 100644 --- a/src/ast/scalar_constructor_expression.cc +++ b/src/ast/scalar_constructor_expression.cc @@ -14,6 +14,9 @@ #include "src/ast/scalar_constructor_expression.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -31,6 +34,12 @@ ScalarConstructorExpression::ScalarConstructorExpression( ScalarConstructorExpression::~ScalarConstructorExpression() = default; +ScalarConstructorExpression* ScalarConstructorExpression::Clone( + CloneContext* ctx) const { + return ctx->mod->create(ctx->Clone(source()), + ctx->Clone(literal_)); +} + bool ScalarConstructorExpression::IsValid() const { return literal_ != nullptr; } diff --git a/src/ast/scalar_constructor_expression.h b/src/ast/scalar_constructor_expression.h index 273f7bc1a3..05ce2167d3 100644 --- a/src/ast/scalar_constructor_expression.h +++ b/src/ast/scalar_constructor_expression.h @@ -47,6 +47,14 @@ class ScalarConstructorExpression /// @returns the literal value Literal* literal() const { return literal_; } + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + ScalarConstructorExpression* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/set_decoration.cc b/src/ast/set_decoration.cc index 4d6a776a51..0c53618498 100644 --- a/src/ast/set_decoration.cc +++ b/src/ast/set_decoration.cc @@ -14,6 +14,9 @@ #include "src/ast/set_decoration.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -27,5 +30,9 @@ void SetDecoration::to_str(std::ostream& out, size_t indent) const { out << "SetDecoration{" << value_ << "}" << std::endl; } +SetDecoration* SetDecoration::Clone(CloneContext* ctx) const { + return ctx->mod->create(value_, ctx->Clone(source())); +} + } // namespace ast } // namespace tint diff --git a/src/ast/set_decoration.h b/src/ast/set_decoration.h index 9706f86322..a58e334564 100644 --- a/src/ast/set_decoration.h +++ b/src/ast/set_decoration.h @@ -39,6 +39,14 @@ class SetDecoration : public Castable { /// @param indent number of spaces to indent the node when writing void to_str(std::ostream& out, size_t indent) const override; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + SetDecoration* Clone(CloneContext* ctx) const override; + private: uint32_t value_; }; diff --git a/src/ast/sint_literal.cc b/src/ast/sint_literal.cc index c94e7c9879..6bded6f6bf 100644 --- a/src/ast/sint_literal.cc +++ b/src/ast/sint_literal.cc @@ -14,6 +14,9 @@ #include "src/ast/sint_literal.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -30,5 +33,9 @@ std::string SintLiteral::name() const { return "__sint" + type()->type_name() + "_" + std::to_string(value_); } +SintLiteral* SintLiteral::Clone(CloneContext* ctx) const { + return ctx->mod->create(ctx->Clone(type()), value_); +} + } // namespace ast } // namespace tint diff --git a/src/ast/sint_literal.h b/src/ast/sint_literal.h index dd56fb35c2..5a18406281 100644 --- a/src/ast/sint_literal.h +++ b/src/ast/sint_literal.h @@ -43,6 +43,14 @@ class SintLiteral : public Castable { /// @returns the literal as a string std::string to_str() const override; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + SintLiteral* Clone(CloneContext* ctx) const override; + private: int32_t value_; }; diff --git a/src/ast/stage_decoration.cc b/src/ast/stage_decoration.cc index af9676e5b6..8613167d86 100644 --- a/src/ast/stage_decoration.cc +++ b/src/ast/stage_decoration.cc @@ -14,6 +14,9 @@ #include "src/ast/stage_decoration.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -27,5 +30,9 @@ void StageDecoration::to_str(std::ostream& out, size_t indent) const { out << "StageDecoration{" << stage_ << "}" << std::endl; } +StageDecoration* StageDecoration::Clone(CloneContext* ctx) const { + return ctx->mod->create(stage_, ctx->Clone(source())); +} + } // namespace ast } // namespace tint diff --git a/src/ast/stage_decoration.h b/src/ast/stage_decoration.h index 1ae33727fb..4339f9899b 100644 --- a/src/ast/stage_decoration.h +++ b/src/ast/stage_decoration.h @@ -38,6 +38,14 @@ class StageDecoration : public Castable { /// @param indent number of spaces to indent the node when writing void to_str(std::ostream& out, size_t indent) const override; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + StageDecoration* Clone(CloneContext* ctx) const override; + private: PipelineStage stage_ = PipelineStage::kNone; }; diff --git a/src/ast/stride_decoration.cc b/src/ast/stride_decoration.cc index 6de4550e18..f41ac20224 100644 --- a/src/ast/stride_decoration.cc +++ b/src/ast/stride_decoration.cc @@ -14,6 +14,9 @@ #include "src/ast/stride_decoration.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -27,5 +30,9 @@ void StrideDecoration::to_str(std::ostream& out, size_t indent) const { out << "stride " << stride_; } +StrideDecoration* StrideDecoration::Clone(CloneContext* ctx) const { + return ctx->mod->create(stride_, ctx->Clone(source())); +} + } // namespace ast } // namespace tint diff --git a/src/ast/stride_decoration.h b/src/ast/stride_decoration.h index e3abd4bb09..f2113f7c00 100644 --- a/src/ast/stride_decoration.h +++ b/src/ast/stride_decoration.h @@ -41,6 +41,14 @@ class StrideDecoration : public Castable { /// @param indent number of spaces to indent the node when writing void to_str(std::ostream& out, size_t indent) const override; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + StrideDecoration* Clone(CloneContext* ctx) const override; + private: uint32_t stride_; }; diff --git a/src/ast/struct.cc b/src/ast/struct.cc index 5825dca1cb..bb076475af 100644 --- a/src/ast/struct.cc +++ b/src/ast/struct.cc @@ -14,6 +14,8 @@ #include "src/ast/struct.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" #include "src/ast/struct_block_decoration.h" namespace tint { @@ -61,6 +63,11 @@ bool Struct::IsBlockDecorated() const { return false; } +Struct* Struct::Clone(CloneContext* ctx) const { + return ctx->mod->create( + ctx->Clone(source()), ctx->Clone(decorations_), ctx->Clone(members_)); +} + bool Struct::IsValid() const { for (auto* mem : members_) { if (mem == nullptr || !mem->IsValid()) { diff --git a/src/ast/struct.h b/src/ast/struct.h index 419b2708f7..0b37d2824f 100644 --- a/src/ast/struct.h +++ b/src/ast/struct.h @@ -76,6 +76,14 @@ class Struct : public Castable { /// @returns true if the struct is block decorated bool IsBlockDecorated() const; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + Struct* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/struct_block_decoration.cc b/src/ast/struct_block_decoration.cc index 90fb248e41..22df5c443e 100644 --- a/src/ast/struct_block_decoration.cc +++ b/src/ast/struct_block_decoration.cc @@ -14,6 +14,9 @@ #include "src/ast/struct_block_decoration.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -27,5 +30,9 @@ void StructBlockDecoration::to_str(std::ostream& out, size_t indent) const { out << "block"; } +StructBlockDecoration* StructBlockDecoration::Clone(CloneContext* ctx) const { + return ctx->mod->create(ctx->Clone(source())); +} + } // namespace ast } // namespace tint diff --git a/src/ast/struct_block_decoration.h b/src/ast/struct_block_decoration.h index 6f55984777..732fa5ecbc 100644 --- a/src/ast/struct_block_decoration.h +++ b/src/ast/struct_block_decoration.h @@ -37,6 +37,14 @@ class StructBlockDecoration /// @param out the stream to write to /// @param indent number of spaces to indent the node when writing void to_str(std::ostream& out, size_t indent) const override; + + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + StructBlockDecoration* Clone(CloneContext* ctx) const override; }; /// List of struct decorations diff --git a/src/ast/struct_member.cc b/src/ast/struct_member.cc index 0905557c7c..cefbeded7e 100644 --- a/src/ast/struct_member.cc +++ b/src/ast/struct_member.cc @@ -14,6 +14,8 @@ #include "src/ast/struct_member.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" #include "src/ast/struct_member_offset_decoration.h" namespace tint { @@ -57,6 +59,11 @@ uint32_t StructMember::offset() const { return 0; } +StructMember* StructMember::Clone(CloneContext* ctx) const { + return ctx->mod->create( + ctx->Clone(source()), name_, ctx->Clone(type_), ctx->Clone(decorations_)); +} + bool StructMember::IsValid() const { if (name_.empty() || type_ == nullptr) { return false; diff --git a/src/ast/struct_member.h b/src/ast/struct_member.h index ea18502e8c..a659e8713f 100644 --- a/src/ast/struct_member.h +++ b/src/ast/struct_member.h @@ -77,6 +77,14 @@ class StructMember : public Castable { /// @returns the offset decoration value. uint32_t offset() const; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + StructMember* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/struct_member_offset_decoration.cc b/src/ast/struct_member_offset_decoration.cc index 91ed01b9d6..974b5fde7a 100644 --- a/src/ast/struct_member_offset_decoration.cc +++ b/src/ast/struct_member_offset_decoration.cc @@ -14,6 +14,9 @@ #include "src/ast/struct_member_offset_decoration.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -29,5 +32,11 @@ void StructMemberOffsetDecoration::to_str(std::ostream& out, out << "offset " << std::to_string(offset_); } +StructMemberOffsetDecoration* StructMemberOffsetDecoration::Clone( + CloneContext* ctx) const { + return ctx->mod->create(offset_, + ctx->Clone(source())); +} + } // namespace ast } // namespace tint diff --git a/src/ast/struct_member_offset_decoration.h b/src/ast/struct_member_offset_decoration.h index 1f442715ec..bfc8ca2342 100644 --- a/src/ast/struct_member_offset_decoration.h +++ b/src/ast/struct_member_offset_decoration.h @@ -42,6 +42,14 @@ class StructMemberOffsetDecoration /// @param indent number of spaces to indent the node when writing void to_str(std::ostream& out, size_t indent) const override; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + StructMemberOffsetDecoration* Clone(CloneContext* ctx) const override; + private: uint32_t offset_; }; diff --git a/src/ast/switch_statement.cc b/src/ast/switch_statement.cc index ed0d7306be..18140b6c32 100644 --- a/src/ast/switch_statement.cc +++ b/src/ast/switch_statement.cc @@ -15,6 +15,8 @@ #include "src/ast/switch_statement.h" #include "src/ast/case_statement.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" namespace tint { namespace ast { @@ -33,6 +35,11 @@ SwitchStatement::SwitchStatement(SwitchStatement&&) = default; SwitchStatement::~SwitchStatement() = default; +SwitchStatement* SwitchStatement::Clone(CloneContext* ctx) const { + return ctx->mod->create( + ctx->Clone(source()), ctx->Clone(condition_), ctx->Clone(body_)); +} + bool SwitchStatement::IsValid() const { if (condition_ == nullptr || !condition_->IsValid()) { return false; diff --git a/src/ast/switch_statement.h b/src/ast/switch_statement.h index df7d570944..53c80ca0dd 100644 --- a/src/ast/switch_statement.h +++ b/src/ast/switch_statement.h @@ -60,6 +60,14 @@ class SwitchStatement : public Castable { /// @returns the Switch body const CaseStatementList& body() const { return body_; } + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + SwitchStatement* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/type/access_control_type.cc b/src/ast/type/access_control_type.cc index a419efbdc4..7a7a6768b9 100644 --- a/src/ast/type/access_control_type.cc +++ b/src/ast/type/access_control_type.cc @@ -16,6 +16,9 @@ #include +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { namespace type { @@ -54,6 +57,10 @@ uint64_t AccessControl::BaseAlignment(MemoryLayout mem_layout) const { return subtype_->BaseAlignment(mem_layout); } +AccessControl* AccessControl::Clone(CloneContext* ctx) const { + return ctx->mod->create(access_, ctx->Clone(subtype_)); +} + } // namespace type } // namespace ast } // namespace tint diff --git a/src/ast/type/access_control_type.h b/src/ast/type/access_control_type.h index c5828a6899..bffd44e6f0 100644 --- a/src/ast/type/access_control_type.h +++ b/src/ast/type/access_control_type.h @@ -60,6 +60,11 @@ class AccessControl : public Castable { /// 0 for non-host shareable types. uint64_t BaseAlignment(MemoryLayout mem_layout) const override; + /// Clones this type and all transitive types using the `CloneContext` `ctx`. + /// @param ctx the clone context + /// @return the newly cloned type + AccessControl* Clone(CloneContext* ctx) const override; + private: ast::AccessControl access_ = ast::AccessControl::kReadOnly; Type* subtype_ = nullptr; diff --git a/src/ast/type/alias_type.cc b/src/ast/type/alias_type.cc index 955f12f4ac..e07ed100d1 100644 --- a/src/ast/type/alias_type.cc +++ b/src/ast/type/alias_type.cc @@ -16,6 +16,9 @@ #include +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { namespace type { @@ -41,6 +44,10 @@ uint64_t Alias::BaseAlignment(MemoryLayout mem_layout) const { return subtype_->BaseAlignment(mem_layout); } +Alias* Alias::Clone(CloneContext* ctx) const { + return ctx->mod->create(name_, ctx->Clone(subtype_)); +} + } // namespace type } // namespace ast } // namespace tint diff --git a/src/ast/type/alias_type.h b/src/ast/type/alias_type.h index 7d4840a8b9..260a089a1a 100644 --- a/src/ast/type/alias_type.h +++ b/src/ast/type/alias_type.h @@ -52,6 +52,11 @@ class Alias : public Castable { /// 0 for non-host shareable types. uint64_t BaseAlignment(MemoryLayout mem_layout) const override; + /// Clones this type and all transitive types using the `CloneContext` `ctx`. + /// @param ctx the clone context + /// @return the newly cloned type + Alias* Clone(CloneContext* ctx) const override; + private: std::string name_; Type* subtype_ = nullptr; diff --git a/src/ast/type/array_type.cc b/src/ast/type/array_type.cc index 3883e0e7da..34cec86a03 100644 --- a/src/ast/type/array_type.cc +++ b/src/ast/type/array_type.cc @@ -15,7 +15,10 @@ #include "src/ast/type/array_type.h" #include +#include +#include "src/ast/clone_context.h" +#include "src/ast/module.h" #include "src/ast/stride_decoration.h" #include "src/ast/type/vector_type.h" @@ -92,6 +95,12 @@ std::string Array::type_name() const { return type_name; } +Array* Array::Clone(CloneContext* ctx) const { + auto cloned = std::make_unique(ctx->Clone(subtype_), size_); + cloned->set_decorations(ctx->Clone(decorations())); + return ctx->mod->unique_type(std::move(cloned)); +} + } // namespace type } // namespace ast } // namespace tint diff --git a/src/ast/type/array_type.h b/src/ast/type/array_type.h index ef6f6c2fca..8c4bc9aece 100644 --- a/src/ast/type/array_type.h +++ b/src/ast/type/array_type.h @@ -74,6 +74,11 @@ class Array : public Castable { /// @returns the name for the type std::string type_name() const override; + /// Clones this type and all transitive types using the `CloneContext` `ctx`. + /// @param ctx the clone context + /// @return the newly cloned type + Array* Clone(CloneContext* ctx) const override; + private: Type* subtype_ = nullptr; uint32_t size_ = 0; diff --git a/src/ast/type/bool_type.cc b/src/ast/type/bool_type.cc index 55bbc3567b..da095dca9a 100644 --- a/src/ast/type/bool_type.cc +++ b/src/ast/type/bool_type.cc @@ -14,6 +14,9 @@ #include "src/ast/type/bool_type.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { namespace type { @@ -28,6 +31,10 @@ std::string Bool::type_name() const { return "__bool"; } +Bool* Bool::Clone(CloneContext* ctx) const { + return ctx->mod->create(); +} + } // namespace type } // namespace ast } // namespace tint diff --git a/src/ast/type/bool_type.h b/src/ast/type/bool_type.h index 92b08ed184..11f5338046 100644 --- a/src/ast/type/bool_type.h +++ b/src/ast/type/bool_type.h @@ -34,6 +34,11 @@ class Bool : public Castable { /// @returns the name for this type std::string type_name() const override; + + /// Clones this type and all transitive types using the `CloneContext` `ctx`. + /// @param ctx the clone context + /// @return the newly cloned type + Bool* Clone(CloneContext* ctx) const override; }; } // namespace type diff --git a/src/ast/type/depth_texture_type.cc b/src/ast/type/depth_texture_type.cc index c6153ee397..c4b1530cd1 100644 --- a/src/ast/type/depth_texture_type.cc +++ b/src/ast/type/depth_texture_type.cc @@ -17,6 +17,9 @@ #include #include +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { namespace type { @@ -47,6 +50,10 @@ std::string DepthTexture::type_name() const { return out.str(); } +DepthTexture* DepthTexture::Clone(CloneContext* ctx) const { + return ctx->mod->create(dim()); +} + } // namespace type } // namespace ast } // namespace tint diff --git a/src/ast/type/depth_texture_type.h b/src/ast/type/depth_texture_type.h index 49a0a1c67f..02f07d334c 100644 --- a/src/ast/type/depth_texture_type.h +++ b/src/ast/type/depth_texture_type.h @@ -35,6 +35,11 @@ class DepthTexture : public Castable { /// @returns the name for this type std::string type_name() const override; + + /// Clones this type and all transitive types using the `CloneContext` `ctx`. + /// @param ctx the clone context + /// @return the newly cloned type + DepthTexture* Clone(CloneContext* ctx) const override; }; } // namespace type diff --git a/src/ast/type/f32_type.cc b/src/ast/type/f32_type.cc index 7712aec7c5..50823b25c2 100644 --- a/src/ast/type/f32_type.cc +++ b/src/ast/type/f32_type.cc @@ -14,6 +14,9 @@ #include "src/ast/type/f32_type.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { namespace type { @@ -36,6 +39,10 @@ uint64_t F32::BaseAlignment(MemoryLayout) const { return 4; } +F32* F32::Clone(CloneContext* ctx) const { + return ctx->mod->create(); +} + } // namespace type } // namespace ast } // namespace tint diff --git a/src/ast/type/f32_type.h b/src/ast/type/f32_type.h index 92f25f058d..86187f3eaa 100644 --- a/src/ast/type/f32_type.h +++ b/src/ast/type/f32_type.h @@ -44,6 +44,11 @@ class F32 : public Castable { /// @returns base alignment for the type, in bytes. /// 0 for non-host shareable types. uint64_t BaseAlignment(MemoryLayout mem_layout) const override; + + /// Clones this type and all transitive types using the `CloneContext` `ctx`. + /// @param ctx the clone context + /// @return the newly cloned type + F32* Clone(CloneContext* ctx) const override; }; } // namespace type diff --git a/src/ast/type/i32_type.cc b/src/ast/type/i32_type.cc index eced6b691c..303b93eb9d 100644 --- a/src/ast/type/i32_type.cc +++ b/src/ast/type/i32_type.cc @@ -14,6 +14,9 @@ #include "src/ast/type/i32_type.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { namespace type { @@ -36,6 +39,10 @@ uint64_t I32::BaseAlignment(MemoryLayout) const { return 4; } +I32* I32::Clone(CloneContext* ctx) const { + return ctx->mod->create(); +} + } // namespace type } // namespace ast } // namespace tint diff --git a/src/ast/type/i32_type.h b/src/ast/type/i32_type.h index 4ad01158c6..104d593fb8 100644 --- a/src/ast/type/i32_type.h +++ b/src/ast/type/i32_type.h @@ -44,6 +44,11 @@ class I32 : public Castable { /// @returns base alignment for the type, in bytes. /// 0 for non-host shareable types. uint64_t BaseAlignment(MemoryLayout mem_layout) const override; + + /// Clones this type and all transitive types using the `CloneContext` `ctx`. + /// @param ctx the clone context + /// @return the newly cloned type + I32* Clone(CloneContext* ctx) const override; }; } // namespace type diff --git a/src/ast/type/matrix_type.cc b/src/ast/type/matrix_type.cc index 7065b7fcba..c0210bc84d 100644 --- a/src/ast/type/matrix_type.cc +++ b/src/ast/type/matrix_type.cc @@ -16,6 +16,8 @@ #include +#include "src/ast/clone_context.h" +#include "src/ast/module.h" #include "src/ast/type/array_type.h" #include "src/ast/type/vector_type.h" @@ -52,6 +54,10 @@ uint64_t Matrix::BaseAlignment(MemoryLayout mem_layout) const { return arr.BaseAlignment(mem_layout); } +Matrix* Matrix::Clone(CloneContext* ctx) const { + return ctx->mod->create(ctx->Clone(subtype_), rows_, columns_); +} + } // namespace type } // namespace ast } // namespace tint diff --git a/src/ast/type/matrix_type.h b/src/ast/type/matrix_type.h index 2943a03ac9..ef39c1ac89 100644 --- a/src/ast/type/matrix_type.h +++ b/src/ast/type/matrix_type.h @@ -55,6 +55,11 @@ class Matrix : public Castable { /// 0 for non-host shareable types. uint64_t BaseAlignment(MemoryLayout mem_layout) const override; + /// Clones this type and all transitive types using the `CloneContext` `ctx`. + /// @param ctx the clone context + /// @return the newly cloned type + Matrix* Clone(CloneContext* ctx) const override; + private: Type* subtype_ = nullptr; uint32_t rows_ = 2; diff --git a/src/ast/type/multisampled_texture_type.cc b/src/ast/type/multisampled_texture_type.cc index d73ad18649..a321684aba 100644 --- a/src/ast/type/multisampled_texture_type.cc +++ b/src/ast/type/multisampled_texture_type.cc @@ -17,6 +17,9 @@ #include #include +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { namespace type { @@ -36,6 +39,10 @@ std::string MultisampledTexture::type_name() const { return out.str(); } +MultisampledTexture* MultisampledTexture::Clone(CloneContext* ctx) const { + return ctx->mod->create(dim(), ctx->Clone(type_)); +} + } // namespace type } // namespace ast } // namespace tint diff --git a/src/ast/type/multisampled_texture_type.h b/src/ast/type/multisampled_texture_type.h index 351e9f83e7..cdbfdb4776 100644 --- a/src/ast/type/multisampled_texture_type.h +++ b/src/ast/type/multisampled_texture_type.h @@ -40,6 +40,11 @@ class MultisampledTexture : public Castable { /// @returns the name for this type std::string type_name() const override; + /// Clones this type and all transitive types using the `CloneContext` `ctx`. + /// @param ctx the clone context + /// @return the newly cloned type + MultisampledTexture* Clone(CloneContext* ctx) const override; + private: Type* type_ = nullptr; }; diff --git a/src/ast/type/pointer_type.cc b/src/ast/type/pointer_type.cc index 745349145d..604eb8b892 100644 --- a/src/ast/type/pointer_type.cc +++ b/src/ast/type/pointer_type.cc @@ -14,6 +14,9 @@ #include "src/ast/type/pointer_type.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { namespace type { @@ -31,6 +34,10 @@ Pointer::Pointer(Pointer&&) = default; Pointer::~Pointer() = default; +Pointer* Pointer::Clone(CloneContext* ctx) const { + return ctx->mod->create(ctx->Clone(subtype_), storage_class_); +} + } // namespace type } // namespace ast } // namespace tint diff --git a/src/ast/type/pointer_type.h b/src/ast/type/pointer_type.h index 07a3c6beab..b22b130521 100644 --- a/src/ast/type/pointer_type.h +++ b/src/ast/type/pointer_type.h @@ -44,6 +44,11 @@ class Pointer : public Castable { /// @returns the name for this type std::string type_name() const override; + /// Clones this type and all transitive types using the `CloneContext` `ctx`. + /// @param ctx the clone context + /// @return the newly cloned type + Pointer* Clone(CloneContext* ctx) const override; + private: Type* subtype_; StorageClass storage_class_; diff --git a/src/ast/type/sampled_texture_type.cc b/src/ast/type/sampled_texture_type.cc index 490a5863a2..0fa62d7f82 100644 --- a/src/ast/type/sampled_texture_type.cc +++ b/src/ast/type/sampled_texture_type.cc @@ -17,6 +17,9 @@ #include #include +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { namespace type { @@ -36,6 +39,10 @@ std::string SampledTexture::type_name() const { return out.str(); } +SampledTexture* SampledTexture::Clone(CloneContext* ctx) const { + return ctx->mod->create(dim(), ctx->Clone(type_)); +} + } // namespace type } // namespace ast } // namespace tint diff --git a/src/ast/type/sampled_texture_type.h b/src/ast/type/sampled_texture_type.h index 3c6abafda3..54d9da4b42 100644 --- a/src/ast/type/sampled_texture_type.h +++ b/src/ast/type/sampled_texture_type.h @@ -40,6 +40,11 @@ class SampledTexture : public Castable { /// @returns the name for this type std::string type_name() const override; + /// Clones this type and all transitive types using the `CloneContext` `ctx`. + /// @param ctx the clone context + /// @return the newly cloned type + SampledTexture* Clone(CloneContext* ctx) const override; + private: Type* type_ = nullptr; }; diff --git a/src/ast/type/sampler_type.cc b/src/ast/type/sampler_type.cc index 5d61348bb7..d3eaff19b9 100644 --- a/src/ast/type/sampler_type.cc +++ b/src/ast/type/sampler_type.cc @@ -14,6 +14,9 @@ #include "src/ast/type/sampler_type.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { namespace type { @@ -41,6 +44,10 @@ std::string Sampler::type_name() const { (kind_ == SamplerKind::kSampler ? "sampler" : "comparison"); } +Sampler* Sampler::Clone(CloneContext* ctx) const { + return ctx->mod->create(kind_); +} + } // namespace type } // namespace ast } // namespace tint diff --git a/src/ast/type/sampler_type.h b/src/ast/type/sampler_type.h index 0d2b914ae1..4e71fbbc34 100644 --- a/src/ast/type/sampler_type.h +++ b/src/ast/type/sampler_type.h @@ -52,6 +52,11 @@ class Sampler : public Castable { /// @returns the name for this type std::string type_name() const override; + /// Clones this type and all transitive types using the `CloneContext` `ctx`. + /// @param ctx the clone context + /// @return the newly cloned type + Sampler* Clone(CloneContext* ctx) const override; + private: SamplerKind kind_ = SamplerKind::kSampler; }; diff --git a/src/ast/type/storage_texture_type.cc b/src/ast/type/storage_texture_type.cc index 72fd648e0d..87c4657d3f 100644 --- a/src/ast/type/storage_texture_type.cc +++ b/src/ast/type/storage_texture_type.cc @@ -17,6 +17,9 @@ #include #include +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { namespace type { @@ -176,6 +179,10 @@ std::string StorageTexture::type_name() const { return out.str(); } +StorageTexture* StorageTexture::Clone(CloneContext* ctx) const { + return ctx->mod->create(dim(), access_, image_format_); +} + } // namespace type } // namespace ast } // namespace tint diff --git a/src/ast/type/storage_texture_type.h b/src/ast/type/storage_texture_type.h index d4545399bc..9f88b4d825 100644 --- a/src/ast/type/storage_texture_type.h +++ b/src/ast/type/storage_texture_type.h @@ -95,6 +95,11 @@ class StorageTexture : public Castable { /// @returns the name for this type std::string type_name() const override; + /// Clones this type and all transitive types using the `CloneContext` `ctx`. + /// @param ctx the clone context + /// @return the newly cloned type + StorageTexture* Clone(CloneContext* ctx) const override; + private: Type* type_ = nullptr; ast::AccessControl access_ = ast::AccessControl::kReadOnly; diff --git a/src/ast/type/struct_type.cc b/src/ast/type/struct_type.cc index fc8f1102d4..c4756ec79a 100644 --- a/src/ast/type/struct_type.cc +++ b/src/ast/type/struct_type.cc @@ -17,6 +17,8 @@ #include #include +#include "src/ast/clone_context.h" +#include "src/ast/module.h" #include "src/ast/type/alias_type.h" #include "src/ast/type/array_type.h" #include "src/ast/type/matrix_type.h" @@ -79,6 +81,10 @@ uint64_t Struct::BaseAlignment(MemoryLayout mem_layout) const { return 0; } +Struct* Struct::Clone(CloneContext* ctx) const { + return ctx->mod->create(name_, ctx->Clone(struct_)); +} + } // namespace type } // namespace ast } // namespace tint diff --git a/src/ast/type/struct_type.h b/src/ast/type/struct_type.h index 615468de63..91b60c9b08 100644 --- a/src/ast/type/struct_type.h +++ b/src/ast/type/struct_type.h @@ -58,6 +58,11 @@ class Struct : public Castable { /// 0 for non-host shareable types. uint64_t BaseAlignment(MemoryLayout mem_layout) const override; + /// Clones this type and all transitive types using the `CloneContext` `ctx`. + /// @param ctx the clone context + /// @return the newly cloned type + Struct* Clone(CloneContext* ctx) const override; + private: std::string name_; ast::Struct* struct_ = nullptr; diff --git a/src/ast/type/type.h b/src/ast/type/type.h index 87a5739785..8eb916eb99 100644 --- a/src/ast/type/type.h +++ b/src/ast/type/type.h @@ -21,6 +21,10 @@ namespace tint { namespace ast { + +class Module; +class CloneContext; + namespace type { /// Supported memory layouts for calculating sizes @@ -33,6 +37,11 @@ class Type : public Castable { Type(Type&&); ~Type() override; + /// Clones this type and all transitive types using the `CloneContext` `ctx`. + /// @param ctx the clone context + /// @return the newly cloned type + virtual Type* Clone(CloneContext* ctx) const = 0; + /// @returns the name for this type. The |type_name| is unique over all types. virtual std::string type_name() const = 0; @@ -89,6 +98,16 @@ class Type : public Castable { protected: Type(); + + /// A helper method for cloning the `Type` `t` if it is not null. + /// If `t` is null, then `Clone()` returns null. + /// @param m the module to clone `n` into + /// @param t the `Type` to clone (if not null) + /// @return the cloned type + template + static T* Clone(Module* m, const T* t) { + return (t != nullptr) ? static_cast(t->Clone(m)) : nullptr; + } }; } // namespace type diff --git a/src/ast/type/u32_type.cc b/src/ast/type/u32_type.cc index 97cf951d49..18bbb73654 100644 --- a/src/ast/type/u32_type.cc +++ b/src/ast/type/u32_type.cc @@ -14,6 +14,9 @@ #include "src/ast/type/u32_type.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { namespace type { @@ -36,6 +39,10 @@ uint64_t U32::BaseAlignment(MemoryLayout) const { return 4; } +U32* U32::Clone(CloneContext* ctx) const { + return ctx->mod->create(); +} + } // namespace type } // namespace ast } // namespace tint diff --git a/src/ast/type/u32_type.h b/src/ast/type/u32_type.h index 9decaa3602..9e041556cb 100644 --- a/src/ast/type/u32_type.h +++ b/src/ast/type/u32_type.h @@ -44,6 +44,11 @@ class U32 : public Castable { /// @returns base alignment for the type, in bytes. /// 0 for non-host shareable types. uint64_t BaseAlignment(MemoryLayout mem_layout) const override; + + /// Clones this type and all transitive types using the `CloneContext` `ctx`. + /// @param ctx the clone context + /// @return the newly cloned type + U32* Clone(CloneContext* ctx) const override; }; } // namespace type diff --git a/src/ast/type/vector_type.cc b/src/ast/type/vector_type.cc index c2dfaa5b1c..9f52bdc769 100644 --- a/src/ast/type/vector_type.cc +++ b/src/ast/type/vector_type.cc @@ -17,6 +17,9 @@ #include #include +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { namespace type { @@ -48,6 +51,10 @@ uint64_t Vector::BaseAlignment(MemoryLayout mem_layout) const { return 0; // vectors are only supposed to have 2, 3, or 4 elements. } +Vector* Vector::Clone(CloneContext* ctx) const { + return ctx->mod->create(ctx->Clone(subtype_), size_); +} + } // namespace type } // namespace ast } // namespace tint diff --git a/src/ast/type/vector_type.h b/src/ast/type/vector_type.h index 2186774274..9634f43232 100644 --- a/src/ast/type/vector_type.h +++ b/src/ast/type/vector_type.h @@ -52,6 +52,11 @@ class Vector : public Castable { /// 0 for non-host shareable types. uint64_t BaseAlignment(MemoryLayout mem_layout) const override; + /// Clones this type and all transitive types using the `CloneContext` `ctx`. + /// @param ctx the clone context + /// @return the newly cloned type + Vector* Clone(CloneContext* ctx) const override; + private: Type* subtype_ = nullptr; uint32_t size_ = 2; diff --git a/src/ast/type/void_type.cc b/src/ast/type/void_type.cc index 66fc7c6b4a..6eaaec97c7 100644 --- a/src/ast/type/void_type.cc +++ b/src/ast/type/void_type.cc @@ -14,6 +14,9 @@ #include "src/ast/type/void_type.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { namespace type { @@ -28,6 +31,10 @@ std::string Void::type_name() const { return "__void"; } +Void* Void::Clone(CloneContext* ctx) const { + return ctx->mod->create(); +} + } // namespace type } // namespace ast } // namespace tint diff --git a/src/ast/type/void_type.h b/src/ast/type/void_type.h index 16d9193a69..8631188156 100644 --- a/src/ast/type/void_type.h +++ b/src/ast/type/void_type.h @@ -34,6 +34,11 @@ class Void : public Castable { /// @returns the name for this type std::string type_name() const override; + + /// Clones this type and all transitive types using the `CloneContext` `ctx`. + /// @param ctx the clone context + /// @return the newly cloned type + Void* Clone(CloneContext* ctx) const override; }; } // namespace type diff --git a/src/ast/type_constructor_expression.cc b/src/ast/type_constructor_expression.cc index b0f111febd..8717833cfb 100644 --- a/src/ast/type_constructor_expression.cc +++ b/src/ast/type_constructor_expression.cc @@ -14,6 +14,9 @@ #include "src/ast/type_constructor_expression.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -33,6 +36,12 @@ TypeConstructorExpression::TypeConstructorExpression( TypeConstructorExpression::~TypeConstructorExpression() = default; +TypeConstructorExpression* TypeConstructorExpression::Clone( + CloneContext* ctx) const { + return ctx->mod->create( + ctx->Clone(source()), ctx->Clone(type_), ctx->Clone(values_)); +} + bool TypeConstructorExpression::IsValid() const { if (values_.empty()) { return true; diff --git a/src/ast/type_constructor_expression.h b/src/ast/type_constructor_expression.h index 4ed73eede7..12b0e97208 100644 --- a/src/ast/type_constructor_expression.h +++ b/src/ast/type_constructor_expression.h @@ -56,6 +56,14 @@ class TypeConstructorExpression /// @returns the values const ExpressionList& values() const { return values_; } + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + TypeConstructorExpression* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/uint_literal.cc b/src/ast/uint_literal.cc index 32bd0edcd1..bf00090aac 100644 --- a/src/ast/uint_literal.cc +++ b/src/ast/uint_literal.cc @@ -14,6 +14,9 @@ #include "src/ast/uint_literal.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -30,5 +33,9 @@ std::string UintLiteral::name() const { return "__uint" + std::to_string(value_); } +UintLiteral* UintLiteral::Clone(CloneContext* ctx) const { + return ctx->mod->create(ctx->Clone(type()), value_); +} + } // namespace ast } // namespace tint diff --git a/src/ast/uint_literal.h b/src/ast/uint_literal.h index 33b7faac57..662fdd61b7 100644 --- a/src/ast/uint_literal.h +++ b/src/ast/uint_literal.h @@ -43,6 +43,14 @@ class UintLiteral : public Castable { /// @returns the literal as a string std::string to_str() const override; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + UintLiteral* Clone(CloneContext* ctx) const override; + private: uint32_t value_; }; diff --git a/src/ast/unary_op_expression.cc b/src/ast/unary_op_expression.cc index 46d53c6617..c5bc38c520 100644 --- a/src/ast/unary_op_expression.cc +++ b/src/ast/unary_op_expression.cc @@ -14,6 +14,9 @@ #include "src/ast/unary_op_expression.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -31,6 +34,11 @@ UnaryOpExpression::UnaryOpExpression(UnaryOpExpression&&) = default; UnaryOpExpression::~UnaryOpExpression() = default; +UnaryOpExpression* UnaryOpExpression::Clone(CloneContext* ctx) const { + return ctx->mod->create(ctx->Clone(source()), op_, + ctx->Clone(expr_)); +} + bool UnaryOpExpression::IsValid() const { return expr_ != nullptr && expr_->IsValid(); } diff --git a/src/ast/unary_op_expression.h b/src/ast/unary_op_expression.h index 7d96fc4e08..d2c5e42d9d 100644 --- a/src/ast/unary_op_expression.h +++ b/src/ast/unary_op_expression.h @@ -55,6 +55,14 @@ class UnaryOpExpression : public Castable { /// @returns the expression Expression* expr() const { return expr_; } + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + UnaryOpExpression* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/variable.cc b/src/ast/variable.cc index 37c1f207cf..01ae068d41 100644 --- a/src/ast/variable.cc +++ b/src/ast/variable.cc @@ -16,7 +16,9 @@ #include +#include "src/ast/clone_context.h" #include "src/ast/decorated_variable.h" +#include "src/ast/module.h" namespace tint { namespace ast { @@ -36,6 +38,17 @@ Variable::Variable(Variable&&) = default; Variable::~Variable() = default; +Variable* Variable::Clone(CloneContext* ctx) const { + auto* cloned = ctx->mod->create(); + cloned->set_source(ctx->Clone(source())); + cloned->set_name(name()); + cloned->set_storage_class(storage_class()); + cloned->set_type(ctx->Clone(type())); + cloned->set_constructor(ctx->Clone(constructor())); + cloned->set_is_const(is_const()); + return cloned; +} + bool Variable::IsValid() const { if (name_.length() == 0) { return false; diff --git a/src/ast/variable.h b/src/ast/variable.h index 11c95d3847..43fe16c1e3 100644 --- a/src/ast/variable.h +++ b/src/ast/variable.h @@ -132,6 +132,14 @@ class Variable : public Castable { /// @returns true if this is a constant, false otherwise bool is_const() const { return is_const_; } + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + Variable* Clone(CloneContext* ctx) const override; + /// @returns true if the name and path are both present bool IsValid() const override; diff --git a/src/ast/variable_decl_statement.cc b/src/ast/variable_decl_statement.cc index 9b51bf4c4c..f1ec4d35af 100644 --- a/src/ast/variable_decl_statement.cc +++ b/src/ast/variable_decl_statement.cc @@ -14,6 +14,9 @@ #include "src/ast/variable_decl_statement.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -30,6 +33,11 @@ VariableDeclStatement::VariableDeclStatement(VariableDeclStatement&&) = default; VariableDeclStatement::~VariableDeclStatement() = default; +VariableDeclStatement* VariableDeclStatement::Clone(CloneContext* ctx) const { + return ctx->mod->create(ctx->Clone(source()), + ctx->Clone(variable_)); +} + bool VariableDeclStatement::IsValid() const { return variable_ != nullptr && variable_->IsValid(); } diff --git a/src/ast/variable_decl_statement.h b/src/ast/variable_decl_statement.h index d03d5ad5ad..bdf5ed24a9 100644 --- a/src/ast/variable_decl_statement.h +++ b/src/ast/variable_decl_statement.h @@ -48,6 +48,14 @@ class VariableDeclStatement /// @returns the variable Variable* variable() const { return variable_; } + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + VariableDeclStatement* Clone(CloneContext* ctx) const override; + /// @returns true if the node is valid bool IsValid() const override; diff --git a/src/ast/workgroup_decoration.cc b/src/ast/workgroup_decoration.cc index f32d09ac91..22a78f2d1f 100644 --- a/src/ast/workgroup_decoration.cc +++ b/src/ast/workgroup_decoration.cc @@ -14,6 +14,9 @@ #include "src/ast/workgroup_decoration.h" +#include "src/ast/clone_context.h" +#include "src/ast/module.h" + namespace tint { namespace ast { @@ -39,5 +42,10 @@ void WorkgroupDecoration::to_str(std::ostream& out, size_t indent) const { << std::endl; } +WorkgroupDecoration* WorkgroupDecoration::Clone(CloneContext* ctx) const { + return ctx->mod->create(x_, y_, z_, + ctx->Clone(source())); +} + } // namespace ast } // namespace tint diff --git a/src/ast/workgroup_decoration.h b/src/ast/workgroup_decoration.h index 2345b53e5c..ac16b416fe 100644 --- a/src/ast/workgroup_decoration.h +++ b/src/ast/workgroup_decoration.h @@ -55,6 +55,14 @@ class WorkgroupDecoration /// @param indent number of spaces to indent the node when writing void to_str(std::ostream& out, size_t indent) const override; + /// Clones this node and all transitive child nodes using the `CloneContext` + /// `ctx`. + /// @note Semantic information such as resolved expression type and intrinsic + /// information is not cloned. + /// @param ctx the clone context + /// @return the newly cloned node + WorkgroupDecoration* Clone(CloneContext* ctx) const override; + private: uint32_t x_ = 1; uint32_t y_ = 1; diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index e39736cd67..2d240f01e9 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -72,12 +72,14 @@ namespace { class FakeStmt : public ast::Statement { public: + FakeStmt* Clone(ast::CloneContext*) const override { return nullptr; } bool IsValid() const override { return true; } void to_str(std::ostream& out, size_t) const override { out << "Fake"; } }; class FakeExpr : public ast::Expression { public: + FakeExpr* Clone(ast::CloneContext*) const override { return nullptr; } bool IsValid() const override { return true; } void to_str(std::ostream&, size_t) const override {} };