diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 87322b60e3..7011e468bf 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -1631,20 +1631,30 @@ void Resolver::CreateSemanticNodes() const { for (auto it : variable_to_info_) { auto* var = it.first; auto* info = it.second; - std::vector users; + auto* sem_var = builder_->create(var, info->type, + info->storage_class); + std::vector users; for (auto* user : info->users) { // Create semantic node for the identifier expression if necessary auto* sem_expr = sem.Get(user); if (sem_expr == nullptr) { auto* type = expr_info_.at(user).type; auto* stmt = expr_info_.at(user).statement; - sem_expr = builder_->create(user, type, stmt); - sem.Add(user, sem_expr); + auto* sem_user = + builder_->create(user, type, stmt, sem_var); + sem_var->AddUser(sem_user); + sem.Add(user, sem_user); + } else { + auto* sem_user = sem_expr->As(); + if (!sem_user) { + TINT_ICE(builder_->Diagnostics()) + << "expected semantic::VariableUser, got " + << sem_expr->TypeInfo().name; + } + sem_var->AddUser(sem_user); } - users.push_back(sem_expr); } - sem.Add(var, builder_->create( - var, info->type, info->storage_class, std::move(users))); + sem.Add(var, sem_var); } auto remap_vars = [&sem](const std::vector& in) { diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index e3d7edf2b1..76c99c3d93 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -135,7 +135,7 @@ class Resolver { StructInfo(); ~StructInfo(); - std::vector members; + std::vector members; uint32_t align = 0; uint32_t size = 0; uint32_t size_no_padding = 0; diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc index fd2c15185f..8b02b933b2 100644 --- a/src/resolver/resolver_test.cc +++ b/src/resolver/resolver_test.cc @@ -338,6 +338,10 @@ TEST_F(ResolverTest, Stmt_VariableDecl_OuterScopeAfterInnerScope) { EXPECT_EQ(StmtOf(bar_f32_init), bar_f32_decl); EXPECT_TRUE(CheckVarUsers(foo_i32, {bar_i32->constructor()})); EXPECT_TRUE(CheckVarUsers(foo_f32, {bar_f32->constructor()})); + ASSERT_NE(VarOf(bar_i32->constructor()), nullptr); + EXPECT_EQ(VarOf(bar_i32->constructor())->Declaration(), foo_i32); + ASSERT_NE(VarOf(bar_f32->constructor()), nullptr); + EXPECT_EQ(VarOf(bar_f32->constructor())->Declaration(), foo_f32); } TEST_F(ResolverTest, Stmt_VariableDecl_ModuleScopeAfterFunctionScope) { @@ -383,6 +387,8 @@ TEST_F(ResolverTest, Stmt_VariableDecl_ModuleScopeAfterFunctionScope) { EXPECT_EQ(StmtOf(fn_f32_init), fn_f32_decl); EXPECT_TRUE(CheckVarUsers(fn_i32, {})); EXPECT_TRUE(CheckVarUsers(mod_f32, {fn_f32->constructor()})); + ASSERT_NE(VarOf(fn_f32->constructor()), nullptr); + EXPECT_EQ(VarOf(fn_f32->constructor())->Declaration(), mod_f32); } TEST_F(ResolverTest, Expr_ArrayAccessor_Array) { @@ -612,6 +618,8 @@ TEST_F(ResolverTest, Expr_Identifier_GlobalVariable) { EXPECT_TRUE(TypeOf(ident)->Is()); EXPECT_TRUE(TypeOf(ident)->As()->type()->Is()); EXPECT_TRUE(CheckVarUsers(my_var, {ident})); + ASSERT_NE(VarOf(ident), nullptr); + EXPECT_EQ(VarOf(ident)->Declaration(), my_var); } TEST_F(ResolverTest, Expr_Identifier_GlobalConstant) { @@ -625,6 +633,8 @@ TEST_F(ResolverTest, Expr_Identifier_GlobalConstant) { ASSERT_NE(TypeOf(ident), nullptr); EXPECT_TRUE(TypeOf(ident)->Is()); EXPECT_TRUE(CheckVarUsers(my_var, {ident})); + ASSERT_NE(VarOf(ident), nullptr); + EXPECT_EQ(VarOf(ident)->Declaration(), my_var); } TEST_F(ResolverTest, Expr_Identifier_FunctionVariable_Const) { @@ -645,6 +655,8 @@ TEST_F(ResolverTest, Expr_Identifier_FunctionVariable_Const) { EXPECT_TRUE(TypeOf(my_var_a)->Is()); EXPECT_EQ(StmtOf(my_var_a), decl); EXPECT_TRUE(CheckVarUsers(var, {my_var_a})); + ASSERT_NE(VarOf(my_var_a), nullptr); + EXPECT_EQ(VarOf(my_var_a)->Declaration(), var); } TEST_F(ResolverTest, Expr_Identifier_FunctionVariable) { @@ -672,6 +684,10 @@ TEST_F(ResolverTest, Expr_Identifier_FunctionVariable) { EXPECT_TRUE(TypeOf(my_var_b)->As()->type()->Is()); EXPECT_EQ(StmtOf(my_var_b), assign); EXPECT_TRUE(CheckVarUsers(var, {my_var_a, my_var_b})); + ASSERT_NE(VarOf(my_var_a), nullptr); + EXPECT_EQ(VarOf(my_var_a)->Declaration(), var); + ASSERT_NE(VarOf(my_var_b), nullptr); + EXPECT_EQ(VarOf(my_var_b)->Declaration(), var); } TEST_F(ResolverTest, Expr_Identifier_Function_Ptr) { diff --git a/src/resolver/resolver_test_helper.h b/src/resolver/resolver_test_helper.h index 79afa790f0..cbcb2a63a1 100644 --- a/src/resolver/resolver_test_helper.h +++ b/src/resolver/resolver_test_helper.h @@ -67,6 +67,17 @@ class TestHelper : public ProgramBuilder { return sem_stmt ? sem_stmt->Block() : nullptr; } + /// Returns the semantic variable for the given identifier expression. + /// @param expr the identifier expression + /// @return the resolved semantic::Variable of the identifier, or nullptr if + /// the expression did not resolve to a variable. + const semantic::Variable* VarOf(ast::Expression* expr) { + auto* sem_ident = Sem().Get(expr); + auto* var_user = + sem_ident ? sem_ident->As() : nullptr; + return var_user ? var_user->Variable() : nullptr; + } + /// Checks that all the users of the given variable are as expected /// @param var the variable to check /// @param expected_users the expected users of the variable diff --git a/src/semantic/expression.h b/src/semantic/expression.h index 6b9b4b182e..e361658d1a 100644 --- a/src/semantic/expression.h +++ b/src/semantic/expression.h @@ -37,9 +37,9 @@ class Expression : public Castable { /// @param declaration the AST node /// @param type the resolved type of the expression /// @param statement the statement that owns this expression - explicit Expression(ast::Expression* declaration, - type::Type* type, - Statement* statement); + Expression(ast::Expression* declaration, + type::Type* type, + Statement* statement); /// @return the resolved type of the expression type::Type* Type() const { return type_; } diff --git a/src/semantic/sem_variable.cc b/src/semantic/sem_variable.cc index 03dc0344bf..13b3820bef 100644 --- a/src/semantic/sem_variable.cc +++ b/src/semantic/sem_variable.cc @@ -14,21 +14,26 @@ #include "src/semantic/variable.h" +#include "src/ast/identifier_expression.h" + TINT_INSTANTIATE_TYPEINFO(tint::semantic::Variable); +TINT_INSTANTIATE_TYPEINFO(tint::semantic::VariableUser); namespace tint { namespace semantic { Variable::Variable(const ast::Variable* declaration, type::Type* type, - ast::StorageClass storage_class, - std::vector users) - : declaration_(declaration), - type_(type), - storage_class_(storage_class), - users_(std::move(users)) {} + ast::StorageClass storage_class) + : declaration_(declaration), type_(type), storage_class_(storage_class) {} Variable::~Variable() = default; +VariableUser::VariableUser(ast::IdentifierExpression* declaration, + type::Type* type, + Statement* statement, + semantic::Variable* variable) + : Base(declaration, type, statement), variable_(variable) {} + } // namespace semantic } // namespace tint diff --git a/src/semantic/struct.h b/src/semantic/struct.h index 58cd91edfd..3cd16964a3 100644 --- a/src/semantic/struct.h +++ b/src/semantic/struct.h @@ -38,7 +38,7 @@ namespace semantic { class StructMember; /// A vector of StructMember pointers. -using StructMemberList = std::vector; +using StructMemberList = std::vector; /// Metadata to capture how a structure is used in a shader module. enum class PipelineStageUsage { diff --git a/src/semantic/variable.h b/src/semantic/variable.h index 0e606a27c7..7cae490897 100644 --- a/src/semantic/variable.h +++ b/src/semantic/variable.h @@ -24,6 +24,7 @@ namespace tint { // Forward declarations namespace ast { +class IdentifierExpression; class Variable; } // namespace ast namespace type { @@ -32,6 +33,8 @@ class Type; namespace semantic { +class VariableUser; + /// Variable holds the semantic information for variables. class Variable : public Castable { public: @@ -39,11 +42,9 @@ class Variable : public Castable { /// @param declaration the AST declaration node /// @param type the variable type /// @param storage_class the variable storage class - /// @param users the expressions that use the variable - explicit Variable(const ast::Variable* declaration, - type::Type* type, - ast::StorageClass storage_class, - std::vector users); + Variable(const ast::Variable* declaration, + type::Type* type, + ast::StorageClass storage_class); /// Destructor ~Variable() override; @@ -58,13 +59,37 @@ class Variable : public Castable { ast::StorageClass StorageClass() const { return storage_class_; } /// @returns the expressions that use the variable - const std::vector& Users() const { return users_; } + const std::vector& Users() const { return users_; } + + /// @param user the user to add + void AddUser(const VariableUser* user) { users_.emplace_back(user); } private: const ast::Variable* const declaration_; type::Type* const type_; ast::StorageClass const storage_class_; - std::vector const users_; + std::vector users_; +}; + +/// VariableUser holds the semantic information for an identifier expression +/// node that resolves to a variable. +class VariableUser : public Castable { + public: + /// Constructor + /// @param declaration the AST identifier node + /// @param type the resolved type of the expression + /// @param statement the statement that owns this expression + /// @param variable the semantic variable + VariableUser(ast::IdentifierExpression* declaration, + type::Type* type, + Statement* statement, + semantic::Variable* variable); + + /// @returns the variable that this expression refers to + const semantic::Variable* Variable() const { return variable_; } + + private: + semantic::Variable const* const variable_; }; } // namespace semantic