diff --git a/src/ast/function.cc b/src/ast/function.cc index adacdad5c9..d6823b917c 100644 --- a/src/ast/function.cc +++ b/src/ast/function.cc @@ -156,6 +156,10 @@ void Function::add_ancestor_entry_point(const std::string& ep) { ancestor_entry_points_.push_back(ep); } +const Statement* Function::get_last_statement() const { + return body_->last(); +} + bool Function::IsValid() const { for (const auto& param : params_) { if (param == nullptr || !param->IsValid()) diff --git a/src/ast/function.h b/src/ast/function.h index a112e20aa7..09ddd0e4c2 100644 --- a/src/ast/function.h +++ b/src/ast/function.h @@ -121,6 +121,9 @@ class Function : public Node { void set_return_type(type::Type* type) { return_type_ = type; } /// @returns the function return type. type::Type* return_type() const { return return_type_; } + /// @returns a pointer to the last statement of the function or nullptr if + // function is empty + const Statement* get_last_statement() const; /// Sets the body of the function /// @param body the function body diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc index 7988388c05..f374363ca2 100644 --- a/src/ast/function_test.cc +++ b/src/ast/function_test.cc @@ -349,6 +349,30 @@ TEST_F(FunctionTest, TypeName_WithParams) { EXPECT_EQ(f.type_name(), "__func__void__i32__f32"); } +TEST_F(FunctionTest, GetLastStatement) { + type::VoidType void_type; + + VariableList params; + auto body = std::make_unique(); + auto stmt = std::make_unique(); + auto* stmt_ptr = stmt.get(); + body->append(std::move(stmt)); + Function f("func", std::move(params), &void_type); + f.set_body(std::move(body)); + + EXPECT_EQ(f.get_last_statement(), stmt_ptr); +} + +TEST_F(FunctionTest, GetLastStatement_nullptr) { + type::VoidType void_type; + + VariableList params; + auto body = std::make_unique(); + Function f("func", std::move(params), &void_type); + f.set_body(std::move(body)); + + EXPECT_EQ(f.get_last_statement(), nullptr); +} } // namespace } // namespace ast } // namespace tint