diff --git a/src/ast/block_statement.h b/src/ast/block_statement.h index 103c8c44ea..f50e17542d 100644 --- a/src/ast/block_statement.h +++ b/src/ast/block_statement.h @@ -65,6 +65,9 @@ class BlockStatement : public Castable { /// @returns the ending iterator StatementList::const_iterator end() const { return statements_.end(); } + /// @returns the statement list + const StatementList& statements() const { return statements_; } + /// Clones this node and all transitive child nodes using the `CloneContext` /// `ctx`. /// @param ctx the clone context diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index ae04b1597c..a6ef97e1e7 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -78,9 +78,10 @@ Resolver::Resolver(ProgramBuilder* builder) Resolver::~Resolver() = default; -Resolver::BlockInfo::BlockInfo(Resolver::BlockInfo::Type ty, +Resolver::BlockInfo::BlockInfo(const ast::BlockStatement* b, + Resolver::BlockInfo::Type ty, Resolver::BlockInfo* p) - : type(ty), parent(p) {} + : block(b), type(ty), parent(p) {} Resolver::BlockInfo::~BlockInfo() = default; @@ -370,7 +371,7 @@ bool Resolver::Function(ast::Function* func) { } bool Resolver::BlockStatement(const ast::BlockStatement* stmt) { - return BlockScope(BlockInfo::Type::kGeneric, + return BlockScope(stmt, BlockInfo::Type::kGeneric, [&] { return Statements(stmt->list()); }); } @@ -384,7 +385,9 @@ bool Resolver::Statements(const ast::StatementList& stmts) { } bool Resolver::Statement(ast::Statement* stmt) { - auto* sem_statement = builder_->create(stmt); + auto* sem_statement = + builder_->create(stmt, current_block_->block); + builder_->Sem().Add(stmt, sem_statement); ScopedAssignment sa(current_statement_, sem_statement); @@ -427,9 +430,6 @@ bool Resolver::Statement(ast::Statement* stmt) { if (stmt->Is()) { return true; } - if (auto* e = stmt->As()) { - return Expression(e->condition()) && BlockStatement(e->body()); - } if (stmt->Is()) { return true; } @@ -441,13 +441,13 @@ bool Resolver::Statement(ast::Statement* stmt) { // these would make their BlockInfo siblings as in the AST, but we want the // body BlockInfo to parent the continuing BlockInfo for semantics and // validation. Also, we need to set their types differently. - return BlockScope(BlockInfo::Type::kLoop, [&] { + return BlockScope(l->body(), BlockInfo::Type::kLoop, [&] { if (!Statements(l->body()->list())) { return false; } if (l->has_continuing()) { - if (!BlockScope(BlockInfo::Type::kLoopContinuing, + if (!BlockScope(l->continuing(), BlockInfo::Type::kLoopContinuing, [&] { return Statements(l->continuing()->list()); })) { return false; } @@ -473,7 +473,7 @@ bool Resolver::Statement(ast::Statement* stmt) { } bool Resolver::CaseStatement(ast::CaseStatement* stmt) { - return BlockScope(BlockInfo::Type::kSwitchCase, + return BlockScope(stmt->body(), BlockInfo::Type::kSwitchCase, [&] { return Statements(stmt->body()->list()); }); } @@ -495,7 +495,18 @@ bool Resolver::IfStatement(ast::IfStatement* stmt) { } for (auto* else_stmt : stmt->else_statements()) { - if (!Statement(else_stmt)) { + // Else statements are a bit unusual - they're owned by the if-statement, + // not a BlockStatement. + constexpr ast::BlockStatement* no_block_statement = nullptr; + auto* sem_else_stmt = + builder_->create(else_stmt, no_block_statement); + builder_->Sem().Add(else_stmt, sem_else_stmt); + ScopedAssignment sa(current_statement_, + sem_else_stmt); + if (!Expression(else_stmt->condition())) { + return false; + } + if (!BlockStatement(else_stmt->body())) { return false; } } @@ -1923,8 +1934,10 @@ bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc, } template -bool Resolver::BlockScope(BlockInfo::Type type, F&& callback) { - BlockInfo block_info(type, current_block_); +bool Resolver::BlockScope(const ast::BlockStatement* block, + BlockInfo::Type type, + F&& callback) { + BlockInfo block_info(block, type, current_block_); ScopedAssignment sa(current_block_, &block_info); variable_stack_.push_scope(); bool result = callback(); diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index be29997854..709b2be4d5 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -149,7 +149,7 @@ class Resolver { struct BlockInfo { enum class Type { kGeneric, kLoop, kLoopContinuing, kSwitchCase }; - BlockInfo(Type type, BlockInfo* parent); + BlockInfo(const ast::BlockStatement* block, Type type, BlockInfo* parent); ~BlockInfo(); template @@ -166,6 +166,7 @@ class Resolver { [ty](auto* block_info) { return block_info->type == ty; }); } + ast::BlockStatement const* const block; Type const type; BlockInfo* const parent; std::vector decls; @@ -279,7 +280,9 @@ class Resolver { /// its parent, assigns this to #current_block_, and then calls `callback`. /// The original #current_block_ is restored on exit. template - bool BlockScope(BlockInfo::Type type, F&& callback); + bool BlockScope(const ast::BlockStatement* block, + BlockInfo::Type type, + F&& callback); /// Returns a human-readable string representation of the vector type name /// with the given parameters. diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc index fa0e933169..eadc0bcc3e 100644 --- a/src/resolver/resolver_test.cc +++ b/src/resolver/resolver_test.cc @@ -77,10 +77,10 @@ TEST_F(ResolverTest, Stmt_Case) { auto* rhs = Expr(2.3f); auto* assign = create(lhs, rhs); - auto* body = Block(assign); + auto* block = Block(assign); ast::CaseSelectorList lit; lit.push_back(create(ty.i32(), 3)); - auto* cse = create(lit, body); + auto* cse = create(lit, block); WrapInFunction(v, cse); EXPECT_TRUE(r()->Resolve()) << r()->error(); @@ -91,6 +91,7 @@ TEST_F(ResolverTest, Stmt_Case) { EXPECT_TRUE(TypeOf(rhs)->Is()); EXPECT_EQ(StmtOf(lhs), assign); EXPECT_EQ(StmtOf(rhs), assign); + EXPECT_EQ(BlockOf(assign), block); } TEST_F(ResolverTest, Stmt_Block) { @@ -110,30 +111,9 @@ TEST_F(ResolverTest, Stmt_Block) { EXPECT_TRUE(TypeOf(rhs)->Is()); EXPECT_EQ(StmtOf(lhs), assign); EXPECT_EQ(StmtOf(rhs), assign); -} - -TEST_F(ResolverTest, Stmt_Else) { - auto* v = Var("v", ty.f32(), ast::StorageClass::kFunction); - auto* lhs = Expr("v"); - auto* rhs = Expr(2.3f); - - auto* assign = create(lhs, rhs); - auto* body = Block(assign); - auto* cond = Expr(3); - auto* stmt = create(cond, body); - WrapInFunction(v, stmt); - - EXPECT_TRUE(r()->Resolve()) << r()->error(); - - ASSERT_NE(TypeOf(stmt->condition()), nullptr); - ASSERT_NE(TypeOf(lhs), nullptr); - ASSERT_NE(TypeOf(rhs), nullptr); - EXPECT_TRUE(TypeOf(stmt->condition())->Is()); - EXPECT_TRUE(TypeOf(lhs)->UnwrapAll()->Is()); - EXPECT_TRUE(TypeOf(rhs)->Is()); - EXPECT_EQ(StmtOf(lhs), assign); - EXPECT_EQ(StmtOf(rhs), assign); - EXPECT_EQ(StmtOf(cond), stmt); + EXPECT_EQ(BlockOf(lhs), block); + EXPECT_EQ(BlockOf(rhs), block); + EXPECT_EQ(BlockOf(assign), block); } TEST_F(ResolverTest, Stmt_If) { @@ -172,6 +152,10 @@ TEST_F(ResolverTest, Stmt_If) { EXPECT_EQ(StmtOf(rhs), assign); EXPECT_EQ(StmtOf(cond), stmt); EXPECT_EQ(StmtOf(else_cond), else_stmt); + EXPECT_EQ(BlockOf(lhs), body); + EXPECT_EQ(BlockOf(rhs), body); + EXPECT_EQ(BlockOf(else_lhs), else_body); + EXPECT_EQ(BlockOf(else_rhs), else_body); } TEST_F(ResolverTest, Stmt_Loop) { @@ -199,6 +183,10 @@ TEST_F(ResolverTest, Stmt_Loop) { EXPECT_TRUE(TypeOf(body_rhs)->Is()); EXPECT_TRUE(TypeOf(continuing_lhs)->UnwrapAll()->Is()); EXPECT_TRUE(TypeOf(continuing_rhs)->Is()); + EXPECT_EQ(BlockOf(body_lhs), body); + EXPECT_EQ(BlockOf(body_rhs), body); + EXPECT_EQ(BlockOf(continuing_lhs), continuing); + EXPECT_EQ(BlockOf(continuing_rhs), continuing); } TEST_F(ResolverTest, Stmt_Return) { @@ -224,9 +212,8 @@ TEST_F(ResolverTest, Stmt_Switch) { auto* v = Var("v", ty.f32(), ast::StorageClass::kFunction); auto* lhs = Expr("v"); auto* rhs = Expr(2.3f); - - auto* stmt = - Switch(Expr(2), Case(Literal(3), Block(Assign(lhs, rhs))), DefaultCase()); + auto* case_block = Block(Assign(lhs, rhs)); + auto* stmt = Switch(Expr(2), Case(Literal(3), case_block), DefaultCase()); WrapInFunction(v, stmt); EXPECT_TRUE(r()->Resolve()) << r()->error(); @@ -238,6 +225,8 @@ TEST_F(ResolverTest, Stmt_Switch) { EXPECT_TRUE(TypeOf(stmt->condition())->Is()); EXPECT_TRUE(TypeOf(lhs)->UnwrapAll()->Is()); EXPECT_TRUE(TypeOf(rhs)->Is()); + EXPECT_EQ(BlockOf(lhs), case_block); + EXPECT_EQ(BlockOf(rhs), case_block); } TEST_F(ResolverTest, Stmt_Call) { diff --git a/src/resolver/resolver_test_helper.h b/src/resolver/resolver_test_helper.h index 57d820d838..79afa790f0 100644 --- a/src/resolver/resolver_test_helper.h +++ b/src/resolver/resolver_test_helper.h @@ -44,11 +44,29 @@ class TestHelper : public ProgramBuilder { /// @param expr the ast::Expression /// @return the ast::Statement of the ast::Expression, or nullptr if the /// expression is not owned by a statement. - ast::Statement* StmtOf(ast::Expression* expr) { + const ast::Statement* StmtOf(ast::Expression* expr) { auto* sem_stmt = Sem().Get(expr)->Stmt(); return sem_stmt ? sem_stmt->Declaration() : nullptr; } + /// Returns the BlockStatement that holds the given statement. + /// @param stmt the ast::Statment + /// @return the ast::BlockStatement that holds the ast::Statement, or nullptr + /// if the statement is not owned by a BlockStatement. + const ast::BlockStatement* BlockOf(ast::Statement* stmt) { + auto* sem_stmt = Sem().Get(stmt); + return sem_stmt ? sem_stmt->Block() : nullptr; + } + + /// Returns the BlockStatement that holds the given expression. + /// @param expr the ast::Expression + /// @return the ast::Statement of the ast::Expression, or nullptr if the + /// expression is not indirectly owned by a BlockStatement. + const ast::BlockStatement* BlockOf(ast::Expression* expr) { + auto* sem_stmt = Sem().Get(expr)->Stmt(); + return sem_stmt ? sem_stmt->Block() : nullptr; + } + /// Checks that all the users of the given variable are as expected /// @param var the variable to check /// @param expected_users the expected users of the variable diff --git a/src/semantic/sem_statement.cc b/src/semantic/sem_statement.cc index a614b42d4b..811842f08a 100644 --- a/src/semantic/sem_statement.cc +++ b/src/semantic/sem_statement.cc @@ -12,6 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include + +#include "src/ast/block_statement.h" +#include "src/debug.h" #include "src/semantic/statement.h" TINT_INSTANTIATE_TYPEINFO(tint::semantic::Statement); @@ -19,7 +23,17 @@ TINT_INSTANTIATE_TYPEINFO(tint::semantic::Statement); namespace tint { namespace semantic { -Statement::Statement(ast::Statement* declaration) : declaration_(declaration) {} +Statement::Statement(const ast::Statement* declaration, + const ast::BlockStatement* block) + : declaration_(declaration), block_(block) { +#ifndef NDEBUG + if (block) { + auto& stmts = block->statements(); + TINT_ASSERT(std::find(stmts.begin(), stmts.end(), declaration) != + stmts.end()); + } +#endif // NDEBUG +} } // namespace semantic } // namespace tint diff --git a/src/semantic/statement.h b/src/semantic/statement.h index 0782bcb888..ea4d5349a1 100644 --- a/src/semantic/statement.h +++ b/src/semantic/statement.h @@ -21,6 +21,7 @@ namespace tint { // Forward declarations namespace ast { +class BlockStatement; class Statement; } // namespace ast @@ -31,13 +32,19 @@ class Statement : public Castable { public: /// Constructor /// @param declaration the AST node for this statement - explicit Statement(ast::Statement* declaration); + /// @param block the owning AST block statement + Statement(const ast::Statement* declaration, + const ast::BlockStatement* block); /// @return the AST node for this statement - ast::Statement* Declaration() const { return declaration_; } + const ast::Statement* Declaration() const { return declaration_; } + + /// @return the owning AST block statement for this statement + const ast::BlockStatement* Block() const { return block_; } private: - ast::Statement* const declaration_; + ast::Statement const* const declaration_; + ast::BlockStatement const* const block_; }; } // namespace semantic diff --git a/src/semantic/type_mappings.h b/src/semantic/type_mappings.h index b0e69f4b02..47fdfe9e22 100644 --- a/src/semantic/type_mappings.h +++ b/src/semantic/type_mappings.h @@ -25,6 +25,7 @@ class CallExpression; class Expression; class Function; class MemberAccessorExpression; +class Statement; class StructMember; class Variable; } // namespace ast @@ -41,6 +42,7 @@ class Call; class Expression; class Function; class MemberAccessorExpression; +class Statement; class Struct; class StructMember; class Variable; @@ -56,6 +58,7 @@ struct TypeMappings { Expression* operator()(ast::Expression*); Function* operator()(ast::Function*); MemberAccessorExpression* operator()(ast::MemberAccessorExpression*); + Statement* operator()(ast::Statement*); Struct* operator()(type::Struct*); StructMember* operator()(ast::StructMember*); Variable* operator()(ast::Variable*);