diff --git a/src/ast/module.cc b/src/ast/module.cc index d6cc39a08c..463b2dd1e3 100644 --- a/src/ast/module.cc +++ b/src/ast/module.cc @@ -16,6 +16,7 @@ #include +#include "src/ast/named_type.h" #include "src/program_builder.h" TINT_INSTANTIATE_TYPEINFO(tint::ast::Module); @@ -50,6 +51,17 @@ Module::Module(ProgramID program_id, Module::~Module() = default; +const ast::NamedType* Module::LookupType(Symbol name) const { + for (auto ct : ConstructedTypes()) { + if (auto* ty = ct.ast->As()) { + if (ty->name() == name) { + return ty; + } + } + } + return nullptr; +} + Module* Module::Clone(CloneContext* ctx) const { auto* out = ctx->dst->create(); out->Copy(ctx, this); @@ -80,7 +92,7 @@ void Module::to_str(const sem::Info& sem, make_indent(out, indent); out << "Module{" << std::endl; indent += 2; - for (auto* const ty : constructed_types_) { + for (auto const ty : constructed_types_) { make_indent(out, indent); if (auto* alias = ty->As()) { out << alias->symbol().to_str() << " -> " << alias->type()->type_name() diff --git a/src/ast/module.h b/src/ast/module.h index fe319bc2a6..2b6b4ddc78 100644 --- a/src/ast/module.h +++ b/src/ast/module.h @@ -19,10 +19,13 @@ #include #include "src/ast/function.h" +#include "src/ast/type.h" namespace tint { namespace ast { +class NamedType; + /// Module holds the top-level AST types, functions and global variables used by /// a Program. class Module : public Castable { @@ -58,6 +61,17 @@ class Module : public Castable { global_declarations_.push_back(var); } + /// @returns true if the module has the global declaration `decl` + /// @param decl the declaration to check + bool HasGlobalDeclaration(const Cloneable* decl) const { + for (auto* d : global_declarations_) { + if (d == decl) { + return true; + } + } + return false; + } + /// @returns the global variables for the translation unit const VariableList& GlobalVariables() const { return global_variables_; } @@ -67,14 +81,18 @@ class Module : public Castable { /// Adds a constructed type to the Builder. /// The type must be an alias or a struct. /// @param type the constructed type to add - void AddConstructedType(sem::Type* type) { + void AddConstructedType(typ::Type type) { TINT_ASSERT(type); constructed_types_.push_back(type); - global_declarations_.push_back(type); + global_declarations_.push_back(const_cast(type.sem)); } + /// @returns the NamedType registered as a ConstructedType() + /// @param name the name of the type to search for + const ast::NamedType* LookupType(Symbol name) const; + /// @returns the constructed types in the translation unit - const std::vector& ConstructedTypes() const { + const std::vector& ConstructedTypes() const { return constructed_types_; } @@ -115,7 +133,7 @@ class Module : public Castable { private: std::vector global_declarations_; - std::vector constructed_types_; + std::vector constructed_types_; FunctionList functions_; VariableList global_variables_; }; diff --git a/src/reader/wgsl/parser_impl_global_decl_test.cc b/src/reader/wgsl/parser_impl_global_decl_test.cc index 79fddcbb96..72a0d49fdf 100644 --- a/src/reader/wgsl/parser_impl_global_decl_test.cc +++ b/src/reader/wgsl/parser_impl_global_decl_test.cc @@ -163,7 +163,7 @@ TEST_F(ParserImplTest, GlobalDecl_ParsesStruct) { auto program = p->program(); ASSERT_EQ(program.AST().ConstructedTypes().size(), 1u); - auto* t = program.AST().ConstructedTypes()[0]; + auto t = program.AST().ConstructedTypes()[0]; ASSERT_NE(t, nullptr); ASSERT_TRUE(t->Is()); @@ -181,7 +181,7 @@ TEST_F(ParserImplTest, GlobalDecl_Struct_WithStride) { auto program = p->program(); ASSERT_EQ(program.AST().ConstructedTypes().size(), 1u); - auto* t = program.AST().ConstructedTypes()[0]; + auto t = program.AST().ConstructedTypes()[0]; ASSERT_NE(t, nullptr); ASSERT_TRUE(t->Is()); @@ -208,7 +208,7 @@ TEST_F(ParserImplTest, GlobalDecl_Struct_WithDecoration) { auto program = p->program(); ASSERT_EQ(program.AST().ConstructedTypes().size(), 1u); - auto* t = program.AST().ConstructedTypes()[0]; + auto t = program.AST().ConstructedTypes()[0]; ASSERT_NE(t, nullptr); ASSERT_TRUE(t->Is()); diff --git a/src/transform/canonicalize_entry_point_io.cc b/src/transform/canonicalize_entry_point_io.cc index d31b2d05c6..ab090151fa 100644 --- a/src/transform/canonicalize_entry_point_io.cc +++ b/src/transform/canonicalize_entry_point_io.cc @@ -65,7 +65,7 @@ Output CanonicalizeEntryPointIO::Run(const Program* in, const DataMap&) { // Strip entry point IO decorations from struct declarations. // TODO(jrprice): This code is duplicated with the SPIR-V transform. - for (auto* ty : ctx.src->AST().ConstructedTypes()) { + for (auto ty : ctx.src->AST().ConstructedTypes()) { if (auto* struct_ty = ty->As()) { // Build new list of struct members without entry point IO decorations. ast::StructMemberList new_struct_members; diff --git a/src/transform/spirv.cc b/src/transform/spirv.cc index 47b8aa627e..29c099ed63 100644 --- a/src/transform/spirv.cc +++ b/src/transform/spirv.cc @@ -110,7 +110,7 @@ void Spirv::HandleEntryPointIOTypes(CloneContext& ctx) const { // ``` // Strip entry point IO decorations from struct declarations. - for (auto* ty : ctx.src->AST().ConstructedTypes()) { + for (auto ty : ctx.src->AST().ConstructedTypes()) { if (auto* struct_ty = ty->As()) { // Build new list of struct members without entry point IO decorations. ast::StructMemberList new_struct_members; diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index e4ed9fd874..defa30fd73 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -119,7 +119,7 @@ bool GeneratorImpl::Generate(std::ostream& out) { register_global(global); } - for (auto* const ty : builder_.AST().ConstructedTypes()) { + for (auto const ty : builder_.AST().ConstructedTypes()) { if (!EmitConstructedType(out, ty)) { return false; } diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index 1c239a77a6..f3170bc0cf 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc @@ -86,7 +86,7 @@ bool GeneratorImpl::Generate() { global_variables_.set(global->symbol(), sem); } - for (auto* const ty : program_->AST().ConstructedTypes()) { + for (auto const ty : program_->AST().ConstructedTypes()) { if (!EmitConstructedType(ty)) { return false; }