diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index cfb3e19d9d..da144c7b30 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -66,9 +66,8 @@ Resolver::Resolver(ProgramBuilder* builder) Resolver::~Resolver() = default; Resolver::BlockInfo::BlockInfo(Resolver::BlockInfo::Type ty, - Resolver::BlockInfo* p, - const ast::BlockStatement* b) - : type(ty), parent(p), block(b) {} + Resolver::BlockInfo* p) + : type(ty), parent(p) {} Resolver::BlockInfo::~BlockInfo() = default; @@ -150,12 +149,8 @@ bool Resolver::Function(ast::Function* func) { } bool Resolver::BlockStatement(const ast::BlockStatement* stmt) { - auto* block = - block_infos_.Create(BlockInfo::Type::Generic, current_block_, stmt); - block_to_info_[stmt] = block; - ScopedAssignment scope_sa(current_block_, block); - - return Statements(stmt->list()); + return BlockScope(BlockInfo::Type::kGeneric, + [&] { return Statements(stmt->list()); }); } bool Resolver::Statements(const ast::StatementList& stmts) { @@ -219,18 +214,24 @@ bool Resolver::Statement(ast::Statement* stmt) { return BlockStatement(b); } if (stmt->Is()) { + if (!current_block_->FindFirstParent(BlockInfo::Type::kLoop) && + !current_block_->FindFirstParent(BlockInfo::Type::kSwitchCase)) { + diagnostics_.add_error("break statement must be in a loop or switch case", + stmt->source()); + return false; + } return true; } if (auto* c = stmt->As()) { return Expression(c->expr()); } if (auto* c = stmt->As()) { - return BlockStatement(c->body()); + return CaseStatement(c); } if (stmt->Is()) { // Set if we've hit the first continue statement in our parent loop if (auto* loop_block = - current_block_->FindFirstParent(BlockInfo::Type::Loop)) { + current_block_->FindFirstParent(BlockInfo::Type::kLoop)) { if (loop_block->first_continue == size_t(~0)) { loop_block->first_continue = loop_block->decls.size(); } @@ -268,26 +269,20 @@ bool Resolver::Statement(ast::Statement* stmt) { // these would make their BlockInfo siblings as in the AST, but we want the // body BlockInfo to parent the continuing BlockInfo for semantics and // validation. Also, we need to set their types differently. - auto* block = - block_infos_.Create(BlockInfo::Type::Loop, current_block_, l->body()); - block_to_info_[l->body()] = block; - ScopedAssignment scope_sa(current_block_, block); - - if (!Statements(l->body()->list())) { - return false; - } - - if (l->has_continuing()) { - auto* cont_block = block_infos_.Create(BlockInfo::Type::LoopContinuing, - current_block_, l->continuing()); - block_to_info_[l->continuing()] = cont_block; - ScopedAssignment scope_sa2(current_block_, cont_block); - - if (!Statements(l->continuing()->list())) { + return BlockScope(BlockInfo::Type::kLoop, [&] { + if (!Statements(l->body()->list())) { return false; } - } - return true; + + if (l->has_continuing()) { + if (!BlockScope(BlockInfo::Type::kLoopContinuing, + [&] { return Statements(l->continuing()->list()); })) { + return false; + } + } + + return true; + }); } if (auto* r = stmt->As()) { return Expression(r->value()); @@ -297,7 +292,7 @@ bool Resolver::Statement(ast::Statement* stmt) { return false; } for (auto* case_stmt : s->body()) { - if (!Statement(case_stmt)) { + if (!CaseStatement(case_stmt)) { return false; } } @@ -316,6 +311,11 @@ bool Resolver::Statement(ast::Statement* stmt) { return false; } +bool Resolver::CaseStatement(ast::CaseStatement* stmt) { + return BlockScope(BlockInfo::Type::kSwitchCase, + [&] { return Statements(stmt->body()->list()); }); +} + bool Resolver::Expressions(const ast::ExpressionList& list) { for (auto* expr : list) { if (!Expression(expr)) { @@ -395,8 +395,7 @@ bool Resolver::ArrayAccessor(ast::ArrayAccessorExpression* expr) { } else if (auto* arr = parent_type->As()) { if (!arr->type()->is_scalar()) { // If we extract a non-scalar from an array then we also get a pointer. We - // will generate a Function storage class variable to store this - // into. + // will generate a Function storage class variable to store this into. ret = builder_->create(ret, ast::StorageClass::kFunction); } } @@ -573,9 +572,9 @@ bool Resolver::Identifier(ast::IdentifierExpression* expr) { // refer to a variable that is bypassed by a continue statement in the // loop's body block. if (auto* continuing_block = - current_block_->FindFirstParent(BlockInfo::Type::LoopContinuing)) { + current_block_->FindFirstParent(BlockInfo::Type::kLoopContinuing)) { auto* loop_block = - continuing_block->FindFirstParent(BlockInfo::Type::Loop); + continuing_block->FindFirstParent(BlockInfo::Type::kLoop); if (loop_block->first_continue != size_t(~0)) { auto& decls = loop_block->decls; // If our identifier is in loop_block->decls, make sure its index is @@ -946,6 +945,13 @@ void Resolver::CreateSemanticNodes() const { } } +template +bool Resolver::BlockScope(BlockInfo::Type type, F&& callback) { + BlockInfo block_info(type, current_block_); + ScopedAssignment sa(current_block_, &block_info); + return callback(); +} + Resolver::VariableInfo::VariableInfo(ast::Variable* decl) : declaration(decl), storage_class(decl->declared_storage_class()) {} diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index f07fab5d5d..4a095be7b8 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -33,6 +33,7 @@ class ArrayAccessorExpression; class BinaryExpression; class BitcastExpression; class CallExpression; +class CaseStatement; class ConstructorExpression; class Function; class IdentifierExpression; @@ -105,9 +106,9 @@ class Resolver { /// parent block and variables declared in the block. /// Used to validate variable scoping rules. struct BlockInfo { - enum class Type { Generic, Loop, LoopContinuing }; + enum class Type { kGeneric, kLoop, kLoopContinuing, kSwitchCase }; - BlockInfo(Type type, BlockInfo* parent, const ast::BlockStatement* block); + BlockInfo(Type type, BlockInfo* parent); ~BlockInfo(); template @@ -124,9 +125,8 @@ class Resolver { [ty](auto* block_info) { return block_info->type == ty; }); } - const Type type; - BlockInfo* parent; - const ast::BlockStatement* block; + Type const type; + BlockInfo* const parent; std::vector decls; // first_continue is set to the index of the first variable in decls @@ -134,9 +134,6 @@ class Resolver { constexpr static size_t kNoContinue = size_t(~0); size_t first_continue = kNoContinue; }; - std::unordered_map block_to_info_; - BlockAllocator block_infos_; - BlockInfo* current_block_ = nullptr; /// Resolves the program, without creating final the semantic nodes. /// @returns true on success, false on error @@ -200,6 +197,7 @@ class Resolver { bool Binary(ast::BinaryExpression* expr); bool Bitcast(ast::BitcastExpression* expr); bool Call(ast::CallExpression* expr); + bool CaseStatement(ast::CaseStatement* stmt); bool Constructor(ast::ConstructorExpression* expr); bool Identifier(ast::IdentifierExpression* expr); bool IntrinsicCall(ast::CallExpression* call, @@ -221,9 +219,16 @@ class Resolver { /// @param type the resolved type void SetType(ast::Expression* expr, type::Type* type); + /// Constructs a new BlockInfo with the given type and with #current_block_ as + /// its parent, assigns this to #current_block_, and then calls `callback`. + /// The original #current_block_ is restored on exit. + template + bool BlockScope(BlockInfo::Type type, F&& callback); + ProgramBuilder* const builder_; std::unique_ptr const intrinsic_table_; diag::List diagnostics_; + BlockInfo* current_block_ = nullptr; ScopeStack variable_stack_; std::unordered_map symbol_to_function_; std::unordered_map function_to_info_; diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc index d7c971d9a9..d3613cd147 100644 --- a/src/resolver/resolver_test.cc +++ b/src/resolver/resolver_test.cc @@ -17,6 +17,7 @@ #include "gmock/gmock.h" #include "src/ast/assignment_statement.h" #include "src/ast/bitcast_expression.h" +#include "src/ast/break_statement.h" #include "src/ast/call_statement.h" #include "src/ast/continue_statement.h" #include "src/ast/if_statement.h" @@ -476,10 +477,37 @@ TEST_F(ResolverTest, EXPECT_TRUE(r()->Resolve()) << r()->error(); } +TEST_F(ResolverTest, Stmt_ContinueInLoop) { + WrapInFunction(Loop(Block(create(Source{{12, 34}})))); + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + TEST_F(ResolverTest, Stmt_ContinueNotInLoop) { - WrapInFunction(create()); + WrapInFunction(create(Source{{12, 34}})); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "error: continue statement must be in a loop"); + EXPECT_EQ(r()->error(), "12:34 error: continue statement must be in a loop"); +} + +TEST_F(ResolverTest, Stmt_BreakInLoop) { + WrapInFunction(Loop(Block(create(Source{{12, 34}})))); + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverTest, Stmt_BreakInSwitch) { + WrapInFunction(Loop(Block(create( + Expr(1), ast::CaseStatementList{ + create( + ast::CaseSelectorList{Literal(1)}, + Block(create(Source{{12, 34}}))), + })))); + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverTest, Stmt_BreakNotInLoopOrSwitch) { + WrapInFunction(create(Source{{12, 34}})); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: break statement must be in a loop or switch case"); } TEST_F(ResolverTest, Stmt_Return) {