diff --git a/src/type_determiner.cc b/src/type_determiner.cc index 2e8ffbd3c3..8cc2b8b41d 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -18,6 +18,7 @@ #include "src/ast/break_statement.h" #include "src/ast/case_statement.h" #include "src/ast/continue_statement.h" +#include "src/ast/else_statement.h" #include "src/ast/scalar_constructor_expression.h" #include "src/ast/type_constructor_expression.h" @@ -90,6 +91,11 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) { auto c = stmt->AsContinue(); return DetermineResultType(c->conditional()); } + if (stmt->IsElse()) { + auto e = stmt->AsElse(); + return DetermineResultType(e->condition()) && + DetermineResultType(e->body()); + } error_ = "unknown statement type for type determination"; return false; diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index f5c54f1026..9fe15dd8c6 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -22,6 +22,7 @@ #include "src/ast/break_statement.h" #include "src/ast/case_statement.h" #include "src/ast/continue_statement.h" +#include "src/ast/else_statement.h" #include "src/ast/float_literal.h" #include "src/ast/int_literal.h" #include "src/ast/scalar_constructor_expression.h" @@ -120,6 +121,35 @@ TEST_F(TypeDeterminerTest, Stmt_Continue) { EXPECT_TRUE(cond_ptr->result_type()->IsI32()); } +TEST_F(TypeDeterminerTest, Stmt_Else) { + ast::type::I32Type i32; + ast::type::F32Type f32; + + auto lhs = std::make_unique( + std::make_unique(&i32, 2)); + auto lhs_ptr = lhs.get(); + + auto rhs = std::make_unique( + std::make_unique(&f32, 2.3f)); + auto rhs_ptr = rhs.get(); + + ast::StatementList body; + body.push_back(std::make_unique(std::move(lhs), + std::move(rhs))); + + ast::ElseStatement stmt(std::make_unique( + std::make_unique(&i32, 3)), + std::move(body)); + + EXPECT_TRUE(td()->DetermineResultType(&stmt)); + ASSERT_NE(stmt.condition()->result_type(), nullptr); + ASSERT_NE(lhs_ptr->result_type(), nullptr); + ASSERT_NE(rhs_ptr->result_type(), nullptr); + EXPECT_TRUE(stmt.condition()->result_type()->IsI32()); + EXPECT_TRUE(lhs_ptr->result_type()->IsI32()); + EXPECT_TRUE(rhs_ptr->result_type()->IsF32()); +} + TEST_F(TypeDeterminerTest, Expr_Constructor_Scalar) { ast::type::F32Type f32; ast::ScalarConstructorExpression s(