diff --git a/src/ast/module.cc b/src/ast/module.cc index e9e3a63c7d..326dd3a02f 100644 --- a/src/ast/module.cc +++ b/src/ast/module.cc @@ -29,18 +29,33 @@ namespace ast { Module::Module(const Source& source) : Base(source) {} -Module::Module(const Source& source, - std::vector constructed_types, - FunctionList functions, - VariableList global_variables) - : Base(source), - constructed_types_(std::move(constructed_types)), - functions_(std::move(functions)), - global_variables_(std::move(global_variables)) {} +Module::Module(const Source& source, std::vector global_decls) + : Base(source), global_declarations_(std::move(global_decls)) { + for (auto* decl : global_declarations_) { + if (decl == nullptr) { + continue; + } + + if (auto* ty = decl->As()) { + constructed_types_.push_back(ty); + } else if (auto* func = decl->As()) { + functions_.push_back(func); + } else if (auto* var = decl->As()) { + global_variables_.push_back(var); + } else { + assert(false /* unreachable */); + } + } +} Module::~Module() = default; bool Module::IsValid() const { + for (auto* decl : global_declarations_) { + if (decl == nullptr) { + return false; + } + } for (auto* var : global_variables_) { if (var == nullptr || !var->IsValid()) { return false; @@ -76,9 +91,20 @@ bool Module::IsValid() const { } Module* Module::Clone(CloneContext* ctx) const { - return ctx->dst->create(ctx->Clone(constructed_types_), - ctx->Clone(functions_), - ctx->Clone(global_variables_)); + std::vector global_decls; + for (auto* decl : global_declarations_) { + assert(decl); + if (auto* ty = decl->As()) { + global_decls.push_back(ctx->Clone(ty)); + } else if (auto* func = decl->As()) { + global_decls.push_back(ctx->Clone(func)); + } else if (auto* var = decl->As()) { + global_decls.push_back(ctx->Clone(var)); + } else { + assert(false /* unreachable */); + } + } + return ctx->dst->create(global_decls); } void Module::to_str(const semantic::Info& sem, diff --git a/src/ast/module.h b/src/ast/module.h index 8ecedf3ede..416eb80b3c 100644 --- a/src/ast/module.h +++ b/src/ast/module.h @@ -36,21 +36,23 @@ class Module : public Castable { /// Constructor /// @param source the source of the module - /// @param constructed_types the list of types explicitly declared in the AST - /// @param functions the list of program functions - /// @param global_variables the list of global variables - Module(const Source& source, - std::vector constructed_types, - FunctionList functions, - VariableList global_variables); + /// @param global_decls the list of global types, functions, and variables, in + /// the order they were declared in the source program + Module(const Source& source, std::vector global_decls); /// Destructor ~Module() override; + /// @returns the ordered global declarations for the translation unit + const std::vector& GlobalDeclarations() const { + return global_declarations_; + } + /// Add a global variable to the Builder /// @param var the variable to add void AddGlobalVariable(ast::Variable* var) { global_variables_.push_back(var); + global_declarations_.push_back(var); } /// @returns the global variables for the translation unit @@ -64,6 +66,7 @@ class Module : public Castable { /// @param type the constructed type to add void AddConstructedType(type::Type* type) { constructed_types_.push_back(type); + global_declarations_.push_back(type); } /// @returns the constructed types in the translation unit @@ -73,7 +76,10 @@ class Module : public Castable { /// Add a function to the Builder /// @param func the function to add - void AddFunction(ast::Function* func) { functions_.push_back(func); } + void AddFunction(ast::Function* func) { + functions_.push_back(func); + global_declarations_.push_back(func); + } /// @returns the functions declared in the translation unit const FunctionList& Functions() const { return functions_; } @@ -102,6 +108,7 @@ class Module : public Castable { std::string to_str(const semantic::Info& sem) const; private: + std::vector global_declarations_; std::vector constructed_types_; FunctionList functions_; VariableList global_variables_; diff --git a/src/program.cc b/src/program.cc index 09447aa496..a2b534bf58 100644 --- a/src/program.cc +++ b/src/program.cc @@ -57,8 +57,7 @@ Program::Program(ProgramBuilder&& builder) { ast_nodes_ = std::move(builder.ASTNodes()); sem_nodes_ = std::move(builder.SemNodes()); ast_ = ast_nodes_.Create( - Source{}, builder.AST().ConstructedTypes(), builder.AST().Functions(), - builder.AST().GlobalVariables()); + Source{}, std::move(builder.AST().GlobalDeclarations())); sem_ = std::move(builder.Sem()); symbols_ = std::move(builder.Symbols()); diagnostics_.add(std::move(builder.Diagnostics())); diff --git a/src/program_builder.cc b/src/program_builder.cc index bc3727c492..780da734d3 100644 --- a/src/program_builder.cc +++ b/src/program_builder.cc @@ -60,8 +60,7 @@ ProgramBuilder ProgramBuilder::Wrap(const Program* program) { ProgramBuilder builder; builder.types_ = type::Manager::Wrap(program->Types()); builder.ast_ = builder.create( - program->AST().source(), program->AST().ConstructedTypes(), - program->AST().Functions(), program->AST().GlobalVariables()); + program->AST().source(), program->AST().GlobalDeclarations()); builder.sem_ = semantic::Info::Wrap(program->Sem()); builder.symbols_ = program->Symbols(); builder.diagnostics_ = program->Diagnostics();