diff --git a/src/ast/block_statement.h b/src/ast/block_statement.h index f7df6140b5..a2cb102b22 100644 --- a/src/ast/block_statement.h +++ b/src/ast/block_statement.h @@ -35,6 +35,9 @@ class BlockStatement : public Castable { BlockStatement(BlockStatement&&); ~BlockStatement() override; + /// @returns the StatementList + const StatementList& list() const { return statements_; } + /// @returns true if the block is empty bool empty() const { return statements_.empty(); } /// @returns the number of statements directly in the block diff --git a/src/program_builder.h b/src/program_builder.h index d4b8ff6a71..d7f0f2ab0f 100644 --- a/src/program_builder.h +++ b/src/program_builder.h @@ -19,12 +19,15 @@ #include #include "src/ast/array_accessor_expression.h" +#include "src/ast/assignment_statement.h" #include "src/ast/binary_expression.h" #include "src/ast/bool_literal.h" #include "src/ast/call_expression.h" #include "src/ast/expression.h" #include "src/ast/float_literal.h" #include "src/ast/identifier_expression.h" +#include "src/ast/if_statement.h" +#include "src/ast/loop_statement.h" #include "src/ast/member_accessor_expression.h" #include "src/ast/module.h" #include "src/ast/scalar_constructor_expression.h" @@ -36,6 +39,7 @@ #include "src/ast/type_constructor_expression.h" #include "src/ast/uint_literal.h" #include "src/ast/variable.h" +#include "src/ast/variable_decl_statement.h" #include "src/diagnostic/diagnostic.h" #include "src/program.h" #include "src/semantic/info.h" @@ -1051,6 +1055,64 @@ class ProgramBuilder { }); } + /// Creates a ast::BlockStatement with input statements + /// @param statements statements of block + /// @returns the block statement pointer + template + ast::BlockStatement* Block(Statements&&... statements) { + return create( + ast::StatementList{std::forward(statements)...}); + } + + /// Creates a ast::ElseStatement with input condition and body + /// @param condition the else condition expression + /// @param body the else body + /// @returns the else statement pointer + ast::ElseStatement* Else(ast::Expression* condition, + ast::BlockStatement* body) { + return create(condition, body); + } + + /// Creates a ast::IfStatement with input condition, body, and optional + /// variadic else statements + /// @param condition the if statement condition expression + /// @param body the if statement body + /// @param elseStatements optional variadic else statements + /// @returns the if statement pointer + template + ast::IfStatement* If(ast::Expression* condition, + ast::BlockStatement* body, + ElseStatements&&... elseStatements) { + return create( + condition, body, + ast::ElseStatementList{ + std::forward(elseStatements)...}); + } + + /// Creates a ast::AssignmentStatement with input lhs and rhs expressions + /// @param lhs the left hand side expression + /// @param rhs the right hand side expression + /// @returns the assignment statement pointer + ast::AssignmentStatement* Assign(ast::Expression* lhs, ast::Expression* rhs) { + return create(lhs, rhs); + } + + /// Creates a ast::LoopStatement with input body and optional continuing + /// @param body the loop body + /// @param continuing the optional continuing block + /// @returns the loop statement pointer + ast::LoopStatement* Loop(ast::BlockStatement* body, + ast::BlockStatement* continuing = nullptr) { + return create(body, continuing); + } + + /// Creates a ast::VariableDeclStatement for the input variable + /// @param var the variable to wrap in a decl statement + /// @returns the variable decl statement pointer + ast::VariableDeclStatement* Decl(ast::Variable* var) { + return create(var); + } + /// Sets the current builder source to `src` /// @param src the Source used for future create() calls void SetSource(const Source& src) { diff --git a/src/type_determiner.cc b/src/type_determiner.cc index 22ad3350ae..6f71979997 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -96,6 +96,13 @@ TypeDeterminer::TypeDeterminer(ProgramBuilder* builder) TypeDeterminer::~TypeDeterminer() = default; +TypeDeterminer::BlockInfo::BlockInfo(TypeDeterminer::BlockInfo::Type type, + TypeDeterminer::BlockInfo* parent, + const ast::BlockStatement* block) + : type(type), parent(parent), block(block) {} + +TypeDeterminer::BlockInfo::~BlockInfo() = default; + void TypeDeterminer::set_referenced_from_function_if_needed(VariableInfo* var, bool local) { if (current_function_ == nullptr) { @@ -159,7 +166,7 @@ bool TypeDeterminer::DetermineFunction(ast::Function* func) { variable_stack_.set(param->symbol(), CreateVariableInfo(param)); } - if (!DetermineStatements(func->body())) { + if (!DetermineBlockStatement(func->body())) { return false; } variable_stack_.pop_scope(); @@ -173,8 +180,17 @@ bool TypeDeterminer::DetermineFunction(ast::Function* func) { return true; } -bool TypeDeterminer::DetermineStatements(const ast::BlockStatement* stmts) { - for (auto* stmt : *stmts) { +bool TypeDeterminer::DetermineBlockStatement(const ast::BlockStatement* stmt) { + auto* block = + block_infos_.Create(BlockInfo::Type::Generic, current_block_, stmt); + block_to_info_[stmt] = block; + ScopedAssignment scope_sa(current_block_, block); + + return DetermineStatements(stmt->list()); +} + +bool TypeDeterminer::DetermineStatements(const ast::StatementList& stmts) { + for (auto* stmt : stmts) { if (auto* decl = stmt->As()) { if (!ValidateVariableDeclStatement(decl)) { return false; @@ -231,7 +247,7 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) { return DetermineResultType(a->lhs()) && DetermineResultType(a->rhs()); } if (auto* b = stmt->As()) { - return DetermineStatements(b); + return DetermineBlockStatement(b); } if (stmt->Is()) { return true; @@ -240,9 +256,21 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) { return DetermineResultType(c->expr()); } if (auto* c = stmt->As()) { - return DetermineStatements(c->body()); + return DetermineBlockStatement(c->body()); } if (stmt->Is()) { + // Set if we've hit the first continue statement in our parent loop + if (auto* loop_block = + current_block_->FindFirstParent(BlockInfo::Type::Loop)) { + if (loop_block->first_continue == size_t(~0)) { + loop_block->first_continue = loop_block->decls.size(); + } + } else { + diagnostics_.add_error("continue statement must be in a loop", + stmt->source()); + return false; + } + return true; } if (stmt->Is()) { @@ -250,14 +278,14 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) { } if (auto* e = stmt->As()) { return DetermineResultType(e->condition()) && - DetermineStatements(e->body()); + DetermineBlockStatement(e->body()); } if (stmt->Is()) { return true; } if (auto* i = stmt->As()) { if (!DetermineResultType(i->condition()) || - !DetermineStatements(i->body())) { + !DetermineBlockStatement(i->body())) { return false; } @@ -269,8 +297,30 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) { return true; } if (auto* l = stmt->As()) { - return DetermineStatements(l->body()) && - DetermineStatements(l->continuing()); + // We don't call DetermineBlockStatement on the body and continuing block as + // these would make their BlockInfo siblings as in the AST, but we want the + // body BlockInfo to parent the continuing BlockInfo for semantics and + // validation. Also, we need to set their types differently. + auto* block = + block_infos_.Create(BlockInfo::Type::Loop, current_block_, l->body()); + block_to_info_[l->body()] = block; + ScopedAssignment scope_sa(current_block_, block); + + if (!DetermineStatements(l->body()->list())) { + return false; + } + + if (l->has_continuing()) { + auto* block = block_infos_.Create(BlockInfo::Type::LoopContinuing, + current_block_, l->continuing()); + block_to_info_[l->continuing()] = block; + ScopedAssignment scope_sa(current_block_, block); + + if (!DetermineStatements(l->continuing()->list())) { + return false; + } + } + return true; } if (auto* r = stmt->As()) { return DetermineResultType(r->value()); @@ -289,6 +339,7 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) { if (auto* v = stmt->As()) { variable_stack_.set(v->variable()->symbol(), variable_to_info_.at(v->variable())); + current_block_->decls.push_back(v->variable()); return DetermineResultType(v->variable()->constructor()); } @@ -552,6 +603,36 @@ bool TypeDeterminer::DetermineIdentifier(ast::IdentifierExpression* expr) { var->users.push_back(expr); set_referenced_from_function_if_needed(var, true); + + // If identifier is part of a loop continuing block, make sure it doesn't + // refer to a variable that is bypassed by a continue statement in the + // loop's body block. + if (auto* continuing_block = + current_block_->FindFirstParent(BlockInfo::Type::LoopContinuing)) { + auto* loop_block = + continuing_block->FindFirstParent(BlockInfo::Type::Loop); + if (loop_block->first_continue != size_t(~0)) { + auto& decls = loop_block->decls; + // If our identifier is in loop_block->decls, make sure its index is + // less than first_continue + auto iter = std::find_if( + decls.begin(), decls.end(), + [&symbol](auto* var) { return var->symbol() == symbol; }); + if (iter != decls.end()) { + auto var_decl_index = + static_cast(std::distance(decls.begin(), iter)); + if (var_decl_index >= loop_block->first_continue) { + diagnostics_.add_error( + "continue statement bypasses declaration of '" + + builder_->Symbols().NameFor(symbol) + + "' in continuing block", + expr->source()); + return false; + } + } + } + } + return true; } diff --git a/src/type_determiner.h b/src/type_determiner.h index 2c3fd2a71a..b32f9e3d3f 100644 --- a/src/type_determiner.h +++ b/src/type_determiner.h @@ -106,6 +106,43 @@ class TypeDeterminer { semantic::Statement* statement; }; + /// Structure holding semantic information about a block (i.e. scope), such as + /// parent block and variables declared in the block. + /// Used to validate variable scoping rules. + struct BlockInfo { + enum class Type { Generic, Loop, LoopContinuing }; + + BlockInfo(Type type, BlockInfo* parent, const ast::BlockStatement* block); + ~BlockInfo(); + + template + BlockInfo* FindFirstParent(Pred&& pred) { + BlockInfo* curr = this; + while (curr && !pred(curr)) { + curr = curr->parent; + } + return curr; + } + + BlockInfo* FindFirstParent(BlockInfo::Type type) { + return FindFirstParent( + [type](auto* block_info) { return block_info->type == type; }); + } + + const Type type; + BlockInfo* parent; + const ast::BlockStatement* block; + std::vector decls; + + // first_continue is set to the index of the first variable in decls + // declared after the first continue statement in a loop block, if any. + constexpr static size_t kNoContinue = size_t(~0); + size_t first_continue = kNoContinue; + }; + std::unordered_map block_to_info_; + BlockAllocator block_infos_; + BlockInfo* current_block_ = nullptr; + /// Determines type information for the program, without creating final the /// semantic nodes. /// @returns true if the determination was successful @@ -120,10 +157,14 @@ class TypeDeterminer { /// @param func the function to check /// @returns true if the determination was successful bool DetermineFunction(ast::Function* func); + /// Determines the type information for a block statement + /// @param stmt the block statement + /// @returns true if determination was successful + bool DetermineBlockStatement(const ast::BlockStatement* stmt); /// Determines type information for a set of statements /// @param stmts the statements to check /// @returns true if the determination was successful - bool DetermineStatements(const ast::BlockStatement* stmts); + bool DetermineStatements(const ast::StatementList& stmts); /// Determines type information for a statement /// @param stmt the statement to check /// @returns true if the determination was successful diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index 30ea38d959..dd66bbf5f9 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -312,6 +312,218 @@ TEST_F(TypeDeterminerTest, Stmt_Loop) { EXPECT_TRUE(TypeOf(continuing_rhs)->Is()); } +TEST_F(TypeDeterminerTest, + Stmt_Loop_ContinueInLoopBodyBeforeDecl_UsageInContinuing) { + // loop { + // continue; // Bypasses z decl + // var z : i32; + // + // continuing { + // z = 2; + // } + // } + + auto error_loc = Source{Source::Location{12, 34}}; + auto* body = Block(create(), + Decl(Var("z", ty.i32(), ast::StorageClass::kNone))); + auto* continuing = Block(Assign(Expr(error_loc, "z"), Expr(2))); + auto* loop_stmt = Loop(body, continuing); + WrapInFunction(loop_stmt); + + EXPECT_FALSE(td()->Determine()) << td()->error(); + EXPECT_EQ(td()->error(), + "12:34 error: continue statement bypasses declaration of 'z' in " + "continuing block"); +} + +TEST_F(TypeDeterminerTest, + Stmt_Loop_ContinueInLoopBodyBeforeDeclAndAfterDecl_UsageInContinuing) { + // loop { + // continue; // Bypasses z decl + // var z : i32; + // continue; // Ok + // + // continuing { + // z = 2; + // } + // } + + auto error_loc = Source{Source::Location{12, 34}}; + auto* body = Block(create(), + Decl(Var("z", ty.i32(), ast::StorageClass::kNone)), + create()); + auto* continuing = Block(Assign(Expr(error_loc, "z"), Expr(2))); + auto* loop_stmt = Loop(body, continuing); + WrapInFunction(loop_stmt); + + EXPECT_FALSE(td()->Determine()) << td()->error(); + EXPECT_EQ(td()->error(), + "12:34 error: continue statement bypasses declaration of 'z' in " + "continuing block"); +} + +TEST_F(TypeDeterminerTest, + Stmt_Loop_ContinueInLoopBodySubscopeBeforeDecl_UsageInContinuing) { + // loop { + // if (true) { + // continue; // Still bypasses z decl (if we reach here) + // } + // var z : i32; + // continuing { + // z = 2; + // } + // } + + auto error_loc = Source{Source::Location{12, 34}}; + auto* body = Block(If(Expr(true), Block(create())), + Decl(Var("z", ty.i32(), ast::StorageClass::kNone))); + auto* continuing = Block(Assign(Expr(error_loc, "z"), Expr(2))); + auto* loop_stmt = Loop(body, continuing); + WrapInFunction(loop_stmt); + + EXPECT_FALSE(td()->Determine()) << td()->error(); + EXPECT_EQ(td()->error(), + "12:34 error: continue statement bypasses declaration of 'z' in " + "continuing block"); +} + +TEST_F( + TypeDeterminerTest, + Stmt_Loop_ContinueInLoopBodySubscopeBeforeDecl_UsageInContinuingSubscope) { + // loop { + // if (true) { + // continue; // Still bypasses z decl (if we reach here) + // } + // var z : i32; + // continuing { + // if (true) { + // z = 2; // Must fail even if z is in a sub-scope + // } + // } + // } + + auto error_loc = Source{Source::Location{12, 34}}; + auto* body = Block(If(Expr(true), Block(create())), + Decl(Var("z", ty.i32(), ast::StorageClass::kNone))); + + auto* continuing = + Block(If(Expr(true), Block(Assign(Expr(error_loc, "z"), Expr(2))))); + auto* loop_stmt = Loop(body, continuing); + WrapInFunction(loop_stmt); + + EXPECT_FALSE(td()->Determine()) << td()->error(); + EXPECT_EQ(td()->error(), + "12:34 error: continue statement bypasses declaration of 'z' in " + "continuing block"); +} + +TEST_F(TypeDeterminerTest, + Stmt_Loop_ContinueInLoopBodySubscopeBeforeDecl_UsageInContinuingLoop) { + // loop { + // if (true) { + // continue; // Still bypasses z decl (if we reach here) + // } + // var z : i32; + // continuing { + // loop { + // z = 2; // Must fail even if z is in a sub-scope + // } + // } + // } + + auto error_loc = Source{Source::Location{12, 34}}; + auto* body = Block(If(Expr(true), Block(create())), + Decl(Var("z", ty.i32(), ast::StorageClass::kNone))); + + auto* continuing = Block(Loop(Block(Assign(Expr(error_loc, "z"), Expr(2))))); + auto* loop_stmt = Loop(body, continuing); + WrapInFunction(loop_stmt); + + EXPECT_FALSE(td()->Determine()) << td()->error(); + EXPECT_EQ(td()->error(), + "12:34 error: continue statement bypasses declaration of 'z' in " + "continuing block"); +} + +TEST_F(TypeDeterminerTest, + Stmt_Loop_ContinueInNestedLoopBodyBeforeDecl_UsageInContinuing) { + // loop { + // loop { + // continue; // OK: not part of the outer loop + // } + // var z : i32; + // + // continuing { + // z = 2; + // } + // } + + auto* inner_loop = Loop(Block(create())); + auto* body = + Block(inner_loop, Decl(Var("z", ty.i32(), ast::StorageClass::kNone))); + auto* continuing = Block(Assign(Expr("z"), Expr(2))); + auto* loop_stmt = Loop(body, continuing); + WrapInFunction(loop_stmt); + + EXPECT_TRUE(td()->Determine()) << td()->error(); +} + +TEST_F(TypeDeterminerTest, + Stmt_Loop_ContinueInNestedLoopBodyBeforeDecl_UsageInContinuingSubscope) { + // loop { + // loop { + // continue; // OK: not part of the outer loop + // } + // var z : i32; + // + // continuing { + // if (true) { + // z = 2; + // } + // } + // } + + auto* inner_loop = Loop(Block(create())); + auto* body = + Block(inner_loop, Decl(Var("z", ty.i32(), ast::StorageClass::kNone))); + auto* continuing = Block(If(Expr(true), Block(Assign(Expr("z"), Expr(2))))); + auto* loop_stmt = Loop(body, continuing); + WrapInFunction(loop_stmt); + + EXPECT_TRUE(td()->Determine()) << td()->error(); +} + +TEST_F(TypeDeterminerTest, + Stmt_Loop_ContinueInNestedLoopBodyBeforeDecl_UsageInContinuingLoop) { + // loop { + // loop { + // continue; // OK: not part of the outer loop + // } + // var z : i32; + // + // continuing { + // loop { + // z = 2; + // } + // } + // } + + auto* inner_loop = Loop(Block(create())); + auto* body = + Block(inner_loop, Decl(Var("z", ty.i32(), ast::StorageClass::kNone))); + auto* continuing = Block(Loop(Block(Assign(Expr("z"), Expr(2))))); + auto* loop_stmt = Loop(body, continuing); + WrapInFunction(loop_stmt); + + EXPECT_TRUE(td()->Determine()) << td()->error(); +} + +TEST_F(TypeDeterminerTest, Stmt_ContinueNotInLoop) { + WrapInFunction(create()); + EXPECT_FALSE(td()->Determine()); + EXPECT_EQ(td()->error(), "error: continue statement must be in a loop"); +} + TEST_F(TypeDeterminerTest, Stmt_Return) { auto* cond = Expr(2);