diff --git a/src/resolver/block_test.cc b/src/resolver/block_test.cc index 3701330e5c..ab213f6bed 100644 --- a/src/resolver/block_test.cc +++ b/src/resolver/block_test.cc @@ -59,7 +59,6 @@ TEST_F(ResolverBlockTest, Block) { auto* s = Sem().Get(stmt); ASSERT_NE(s, nullptr); ASSERT_NE(s->Block(), nullptr); - ASSERT_NE(s->Block()->Parent(), nullptr); EXPECT_EQ(s->Block(), s->Block()->FindFirstParent()); EXPECT_EQ(s->Block()->Parent(), s->Block()->FindFirstParent()); @@ -69,8 +68,98 @@ TEST_F(ResolverBlockTest, Block) { EXPECT_EQ(s->Block()->Parent()->Parent(), nullptr); } -// TODO(bclayton): Add tests for other block types (LoopBlockStatement, -// LoopContinuingBlockStatement, SwitchCaseBlockStatement) +TEST_F(ResolverBlockTest, LoopBlock) { + // fn F() { + // loop { + // var x : 32; + // } + // } + auto* stmt = Decl(Var("x", ty.i32())); + auto* loop = Loop(Block(stmt)); + auto* f = Func("F", {}, ty.void_(), {loop}); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + auto* s = Sem().Get(stmt); + ASSERT_NE(s, nullptr); + ASSERT_NE(s->Block(), nullptr); + EXPECT_EQ(s->Block(), s->Block()->FindFirstParent()); + ASSERT_TRUE(Is(s->Block()->Parent()->Parent())); + EXPECT_EQ(s->Block()->Parent()->Parent(), + s->Block()->FindFirstParent()); + EXPECT_EQ(s->Block() + ->Parent() + ->Parent() + ->As() + ->Function(), + f); + EXPECT_EQ(s->Block()->Parent()->Parent()->Parent(), nullptr); +} + +TEST_F(ResolverBlockTest, ForLoopBlock) { + // fn F() { + // for (var i : u32; true; i = i + 1u) { + // return; + // } + // } + auto* init = Decl(Var("i", ty.u32())); + auto* cond = Expr(true); + auto* cont = Assign("i", Add("i", 1u)); + auto* stmt = Return(); + auto* body = Block(stmt); + auto* for_ = For(init, cond, cont, body); + auto* f = Func("F", {}, ty.void_(), {for_}); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + + { + auto* s = Sem().Get(init); + ASSERT_NE(s, nullptr); + ASSERT_NE(s->Block(), nullptr); + EXPECT_EQ(s->Block(), + s->Block()->FindFirstParent()); + ASSERT_TRUE( + Is(s->Block()->Parent()->Parent())); + } + { // Condition expression's statement is the for-loop itself + auto* s = Sem().Get(cond); + ASSERT_NE(s, nullptr); + ASSERT_NE(s->Stmt()->Block(), nullptr); + EXPECT_EQ( + s->Stmt()->Block(), + s->Stmt()->Block()->FindFirstParent()); + ASSERT_TRUE(Is(s->Stmt()->Block())); + } + { + auto* s = Sem().Get(cont); + ASSERT_NE(s, nullptr); + ASSERT_NE(s->Block(), nullptr); + EXPECT_EQ(s->Block(), + s->Block()->FindFirstParent()); + ASSERT_TRUE( + Is(s->Block()->Parent()->Parent())); + } + { + auto* s = Sem().Get(stmt); + ASSERT_NE(s, nullptr); + ASSERT_NE(s->Block(), nullptr); + EXPECT_EQ(s->Block(), + s->Block()->FindFirstParent()); + ASSERT_TRUE( + Is(s->Block()->Parent()->Parent())); + EXPECT_EQ(s->Block()->Parent()->Parent(), + s->Block()->FindFirstParent()); + EXPECT_EQ(s->Block() + ->Parent() + ->Parent() + ->As() + ->Function(), + f); + EXPECT_EQ(s->Block()->Parent()->Parent()->Parent(), nullptr); + } +} +// TODO(bclayton): Add tests for other block types +// (LoopContinuingBlockStatement, SwitchCaseBlockStatement) } // namespace } // namespace resolver diff --git a/src/resolver/control_block_validation_test.cc b/src/resolver/control_block_validation_test.cc index b53bc0d988..bc8fc39721 100644 --- a/src/resolver/control_block_validation_test.cc +++ b/src/resolver/control_block_validation_test.cc @@ -109,7 +109,7 @@ TEST_F(ResolverControlBlockValidationTest, SwitchWithTwoDefault_Fail) { "12:34 error: switch statement must have exactly one default clause"); } -TEST_F(ResolverControlBlockValidationTest, UnreachableCode_continue) { +TEST_F(ResolverControlBlockValidationTest, UnreachableCode_Loop_continue) { // loop { // continue; // var z : i32; @@ -122,6 +122,20 @@ TEST_F(ResolverControlBlockValidationTest, UnreachableCode_continue) { EXPECT_EQ(r()->error(), "12:34 error: code is unreachable"); } +TEST_F(ResolverControlBlockValidationTest, UnreachableCode_ForLoop_continue) { + // for (;;;) { + // continue; + // var z : i32; + // } + WrapInFunction( + For(nullptr, nullptr, nullptr, + Block(create(), + Decl(Source{{12, 34}}, + Var("z", ty.i32(), ast::StorageClass::kNone))))); + + EXPECT_FALSE(r()->Resolve()) << r()->error(); + EXPECT_EQ(r()->error(), "12:34 error: code is unreachable"); +} TEST_F(ResolverControlBlockValidationTest, UnreachableCode_break) { // switch (a) { // case 1: { break; var a : u32 = 2;} diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 8bd7aa076f..307f807f6d 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -28,6 +28,7 @@ #include "src/ast/disable_validation_decoration.h" #include "src/ast/discard_statement.h" #include "src/ast/fallthrough_statement.h" +#include "src/ast/for_loop_statement.h" #include "src/ast/if_statement.h" #include "src/ast/internal_decoration.h" #include "src/ast/interpolate_decoration.h" @@ -1701,6 +1702,9 @@ bool Resolver::Statement(ast::Statement* stmt) { if (auto* l = stmt->As()) { return LoopStatement(l); } + if (auto* l = stmt->As()) { + return ForLoopStatement(l); + } if (auto* r = stmt->As()) { return Return(r); } @@ -1824,6 +1828,45 @@ bool Resolver::LoopStatement(ast::LoopStatement* stmt) { }); } +bool Resolver::ForLoopStatement(ast::ForLoopStatement* stmt) { + Mark(stmt->body()); + + auto* sem_block_body = builder_->create( + stmt->body(), current_statement_); + builder_->Sem().Add(stmt->body(), sem_block_body); + TINT_SCOPED_ASSIGNMENT(current_statement_, sem_block_body); + + if (auto* initializer = stmt->initializer()) { + Mark(initializer); + if (!Statement(initializer)) { + return false; + } + } + + if (auto* condition = stmt->condition()) { + Mark(condition); + if (!Expression(condition)) { + return false; + } + + if (!TypeOf(condition)->Is()) { + AddError("for-loop condition must be bool, got " + TypeNameOf(condition), + condition->source()); + return false; + } + } + + if (auto* continuing = stmt->continuing()) { + Mark(continuing); + if (!Statement(continuing)) { + return false; + } + } + + return BlockScope(stmt->body(), + [&] { return Statements(stmt->body()->list()); }); +} + bool Resolver::Expressions(const ast::ExpressionList& list) { for (auto* expr : list) { Mark(expr); diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index f1d663df51..cf26a39cd4 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -38,9 +38,10 @@ class ArrayAccessorExpression; class BinaryExpression; class BitcastExpression; class CallExpression; -class CaseStatement; class CallStatement; +class CaseStatement; class ConstructorExpression; +class ForLoopStatement; class Function; class IdentifierExpression; class LoopStatement; @@ -244,12 +245,13 @@ class Resolver { bool Constructor(ast::ConstructorExpression*); bool Expression(ast::Expression*); bool Expressions(const ast::ExpressionList&); + bool ForLoopStatement(ast::ForLoopStatement*); bool Function(ast::Function*); + bool FunctionCall(const ast::CallExpression* call); bool GlobalVariable(ast::Variable* var); bool Identifier(ast::IdentifierExpression*); bool IfStatement(ast::IfStatement*); bool IntrinsicCall(ast::CallExpression*, sem::IntrinsicType); - bool FunctionCall(const ast::CallExpression* call); bool LoopStatement(ast::LoopStatement*); bool MemberAccessor(ast::MemberAccessorExpression*); bool Parameter(ast::Variable* param); diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc index 33d4693444..c688ad732a 100644 --- a/src/resolver/validation_test.cc +++ b/src/resolver/validation_test.cc @@ -652,6 +652,17 @@ TEST_F(ResolverTest, Stmt_Loop_ContinueInLoopBodyAfterDecl_UsageInContinuing) { EXPECT_TRUE(r()->Resolve()); } +TEST_F(ResolverTest, Stmt_ForLoop_CondIsNotBool) { + // for (; 1.0f; ) { + // } + + WrapInFunction(For(nullptr, Expr(Source{{12, 34}}, 1.0f), nullptr, Block())); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: for-loop condition must be bool, got f32"); +} + TEST_F(ResolverValidationTest, Stmt_ContinueInLoop) { WrapInFunction(Loop(Block(create(Source{{12, 34}})))); EXPECT_TRUE(r()->Resolve()) << r()->error(); diff --git a/src/sem/block_statement.h b/src/sem/block_statement.h index 9cbd64d7ab..c61c3f888a 100644 --- a/src/sem/block_statement.h +++ b/src/sem/block_statement.h @@ -105,7 +105,7 @@ class FunctionBlockStatement ast::Function const* const function_; }; -/// Holds semantic information about a loop block +/// Holds semantic information about a loop block or a for-loop block class LoopBlockStatement : public Castable { public: /// Constructor diff --git a/src/sem/statement.cc b/src/sem/statement.cc index c7d46fb47a..85ece66871 100644 --- a/src/sem/statement.cc +++ b/src/sem/statement.cc @@ -17,7 +17,6 @@ #include "src/ast/block_statement.h" #include "src/ast/loop_statement.h" #include "src/ast/statement.h" -#include "src/debug.h" #include "src/sem/block_statement.h" #include "src/sem/statement.h" @@ -27,32 +26,7 @@ namespace tint { namespace sem { Statement::Statement(const ast::Statement* declaration, const Statement* parent) - : declaration_(declaration), parent_(parent) { -#ifndef NDEBUG - if (parent_) { - auto* block = Block(); - if (parent_ == block) { - // The parent of this statement is a block. We thus expect the statement - // to be an element of the block. There is one exception: a loop's - // continuing block has the loop's body as its parent, but the continuing - // block is not a statement in the body, so we rule out that case. - auto& stmts = block->Declaration()->statements(); - if (std::find(stmts.begin(), stmts.end(), declaration) == stmts.end()) { - bool statement_is_continuing_for_loop = false; - if (parent_->parent_ != nullptr) { - if (auto* loop = - parent_->parent_->Declaration()->As()) { - if (loop->has_continuing() && Declaration() == loop->continuing()) { - statement_is_continuing_for_loop = true; - } - } - } - TINT_ASSERT(Semantic, statement_is_continuing_for_loop); - } - } - } -#endif // NDEBUG -} + : declaration_(declaration), parent_(parent) {} const BlockStatement* Statement::Block() const { auto* stmt = parent_;