diff --git a/src/type_determiner.cc b/src/type_determiner.cc index f214e46a10..86e35570b4 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -20,6 +20,7 @@ #include "src/ast/continue_statement.h" #include "src/ast/else_statement.h" #include "src/ast/if_statement.h" +#include "src/ast/loop_statement.h" #include "src/ast/scalar_constructor_expression.h" #include "src/ast/type_constructor_expression.h" @@ -117,6 +118,11 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) { if (stmt->IsKill()) { return true; } + if (stmt->IsLoop()) { + auto l = stmt->AsLoop(); + return DetermineResultType(l->body()) && + DetermineResultType(l->continuing()); + } if (stmt->IsNop()) { return true; } diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index e66fc7852b..75fb75ac34 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -26,6 +26,7 @@ #include "src/ast/float_literal.h" #include "src/ast/if_statement.h" #include "src/ast/int_literal.h" +#include "src/ast/loop_statement.h" #include "src/ast/scalar_constructor_expression.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/i32_type.h" @@ -205,6 +206,47 @@ TEST_F(TypeDeterminerTest, Stmt_If) { EXPECT_TRUE(rhs_ptr->result_type()->IsF32()); } +TEST_F(TypeDeterminerTest, Stmt_Loop) { + ast::type::I32Type i32; + ast::type::F32Type f32; + + auto body_lhs = std::make_unique( + std::make_unique(&i32, 2)); + auto body_lhs_ptr = body_lhs.get(); + + auto body_rhs = std::make_unique( + std::make_unique(&f32, 2.3f)); + auto body_rhs_ptr = body_rhs.get(); + + ast::StatementList body; + body.push_back(std::make_unique( + std::move(body_lhs), std::move(body_rhs))); + + auto continuing_lhs = std::make_unique( + std::make_unique(&i32, 2)); + auto continuing_lhs_ptr = continuing_lhs.get(); + + auto continuing_rhs = std::make_unique( + std::make_unique(&f32, 2.3f)); + auto continuing_rhs_ptr = continuing_rhs.get(); + + ast::StatementList continuing; + continuing.push_back(std::make_unique( + std::move(continuing_lhs), std::move(continuing_rhs))); + + ast::LoopStatement stmt(std::move(body), std::move(continuing)); + + EXPECT_TRUE(td()->DetermineResultType(&stmt)); + ASSERT_NE(body_lhs_ptr->result_type(), nullptr); + ASSERT_NE(body_rhs_ptr->result_type(), nullptr); + ASSERT_NE(continuing_lhs_ptr->result_type(), nullptr); + ASSERT_NE(continuing_rhs_ptr->result_type(), nullptr); + EXPECT_TRUE(body_lhs_ptr->result_type()->IsI32()); + EXPECT_TRUE(body_rhs_ptr->result_type()->IsF32()); + EXPECT_TRUE(continuing_lhs_ptr->result_type()->IsI32()); + EXPECT_TRUE(continuing_rhs_ptr->result_type()->IsF32()); +} + TEST_F(TypeDeterminerTest, Expr_Constructor_Scalar) { ast::type::F32Type f32; ast::ScalarConstructorExpression s(