diff --git a/src/type_determiner.cc b/src/type_determiner.cc index 6b1d139890..5d2c983b33 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -24,6 +24,7 @@ #include "src/ast/binary_expression.h" #include "src/ast/break_statement.h" #include "src/ast/call_expression.h" +#include "src/ast/call_statement.h" #include "src/ast/case_statement.h" #include "src/ast/cast_expression.h" #include "src/ast/continue_statement.h" @@ -284,6 +285,9 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) { if (stmt->IsBreak()) { return true; } + if (stmt->IsCall()) { + return DetermineResultType(stmt->AsCall()->expr()); + } if (stmt->IsCase()) { auto* c = stmt->AsCase(); return DetermineStatements(c->body()); diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index 983ac36d19..b46cfb7d72 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -26,6 +26,7 @@ #include "src/ast/binary_expression.h" #include "src/ast/break_statement.h" #include "src/ast/call_expression.h" +#include "src/ast/call_statement.h" #include "src/ast/case_statement.h" #include "src/ast/cast_expression.h" #include "src/ast/continue_statement.h" @@ -338,6 +339,29 @@ TEST_F(TypeDeterminerTest, Stmt_Switch) { EXPECT_TRUE(rhs_ptr->result_type()->IsF32()); } +TEST_F(TypeDeterminerTest, Stmt_Call) { + ast::type::F32Type f32; + + ast::VariableList params; + auto func = + std::make_unique("my_func", std::move(params), &f32); + mod()->AddFunction(std::move(func)); + + // Register the function + EXPECT_TRUE(td()->Determine()); + + ast::ExpressionList call_params; + auto expr = std::make_unique( + std::make_unique("my_func"), + std::move(call_params)); + auto* expr_ptr = expr.get(); + + ast::CallStatement call(std::move(expr)); + EXPECT_TRUE(td()->DetermineResultType(&call)); + ASSERT_NE(expr_ptr->result_type(), nullptr); + EXPECT_TRUE(expr_ptr->result_type()->IsF32()); +} + TEST_F(TypeDeterminerTest, Stmt_VariableDecl) { ast::type::I32Type i32; auto var =