diff --git a/src/validator.cc b/src/validator.cc index 5af8414e52..3186d9e07d 100644 --- a/src/validator.cc +++ b/src/validator.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "src/validator.h" - #include "src/validator_impl.h" namespace tint { diff --git a/src/validator_function_test.cc b/src/validator_function_test.cc index b220e45e60..e23216a040 100644 --- a/src/validator_function_test.cc +++ b/src/validator_function_test.cc @@ -88,10 +88,7 @@ TEST_F(ValidateFunctionTest, FunctionEndWithoutReturnStatementEmptyBody_Fail) { "12:34: v-0002: function must end with a return statement"); } -TEST_F(ValidateFunctionTest, - DISABLED_FunctionTypeMustMatchReturnStatementType_pass) { - // TODO(sarahM0): remove DISABLED after implementing function type must match - // return type +TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementType_Pass) { // fn func -> void { return; } ast::type::VoidType void_type; ast::VariableList params; @@ -107,10 +104,7 @@ TEST_F(ValidateFunctionTest, EXPECT_TRUE(v.Validate(mod())) << v.error(); } -TEST_F(ValidateFunctionTest, - DISABLED_FunctionTypeMustMatchReturnStatementType_fail) { - // TODO(sarahM0): remove DISABLED after implementing function type must match - // return type +TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementType_fail) { // fn func -> void { return 2; } ast::type::VoidType void_type; ast::type::I32Type i32; @@ -130,17 +124,13 @@ TEST_F(ValidateFunctionTest, tint::ValidatorImpl v; EXPECT_FALSE(v.Validate(mod())); // TODO(sarahM0): replace 000y with a rule number - EXPECT_EQ( - v.error(), - "12:34: v-000y: function type must match its return statement type"); + EXPECT_EQ(v.error(), + "12:34: v-000y: return statement type must match its function " + "return type, returned '__i32', expected '__void'"); } -TEST_F(ValidateFunctionTest, - DISABLED_FunctionTypeMustMatchReturnStatementTypeF32_fail) { - // TODO(sarahM0): remove DISABLED after implementing function type must match - // return type +TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementTypeF32_fail) { // fn func -> f32 { return 2; } - ast::type::VoidType void_type; ast::type::I32Type i32; ast::type::F32Type f32; ast::VariableList params; @@ -158,9 +148,9 @@ TEST_F(ValidateFunctionTest, tint::ValidatorImpl v; EXPECT_FALSE(v.Validate(mod())); // TODO(sarahM0): replace 000y with a rule number - EXPECT_EQ( - v.error(), - "12:34: v-000y: function type must match its return statement type"); + EXPECT_EQ(v.error(), + "12:34: v-000y: return statement type must match its function " + "return type, returned '__i32', expected '__f32'"); } TEST_F(ValidateFunctionTest, FunctionNamesMustBeUnique_fail) { diff --git a/src/validator_impl.cc b/src/validator_impl.cc index 0bbda6c177..672948d9be 100644 --- a/src/validator_impl.cc +++ b/src/validator_impl.cc @@ -14,6 +14,7 @@ #include "src/validator_impl.h" #include "src/ast/function.h" +#include "src/ast/type/void_type.h" #include "src/ast/variable_decl_statement.h" namespace tint { @@ -60,10 +61,11 @@ bool ValidatorImpl::ValidateFunctions(const ast::FunctionList& funcs) { } function_stack_.set(func_ptr->name(), func_ptr); - + current_function_ = func_ptr; if (!ValidateFunction(func_ptr)) { return false; } + current_function_ = nullptr; } return true; } @@ -87,6 +89,28 @@ bool ValidatorImpl::ValidateFunction(const ast::Function* func) { return true; } +bool ValidatorImpl::ValidateReturnStatement(const ast::ReturnStatement* ret) { + // TODO(sarahM0): update this when this issue resolves: + // https://github.com/gpuweb/gpuweb/issues/996 + ast::type::Type* func_type = current_function_->return_type(); + + ast::type::VoidType void_type; + auto* ret_type = ret->has_value() + ? ret->value()->result_type()->UnwrapAliasPtrAlias() + : &void_type; + + if (func_type->type_name() != ret_type->type_name()) { + set_error(ret->source(), + "v-000y: return statement type must match its function return " + "type, returned '" + + ret_type->type_name() + "', expected '" + + func_type->type_name() + "'"); + return false; + } + + return true; +} + bool ValidatorImpl::ValidateStatements(const ast::BlockStatement* block) { if (!block) { return false; @@ -126,6 +150,9 @@ bool ValidatorImpl::ValidateStatement(const ast::Statement* stmt) { if (stmt->IsAssign()) { return ValidateAssign(stmt->AsAssign()); } + if (stmt->IsReturn()) { + return ValidateReturnStatement(stmt->AsReturn()); + } return true; } diff --git a/src/validator_impl.h b/src/validator_impl.h index 8cf7b96265..6f0488ee7f 100644 --- a/src/validator_impl.h +++ b/src/validator_impl.h @@ -22,8 +22,8 @@ #include "src/ast/expression.h" #include "src/ast/identifier_expression.h" #include "src/ast/module.h" +#include "src/ast/return_statement.h" #include "src/ast/statement.h" -#include "src/ast/type/type.h" #include "src/ast/variable.h" #include "src/scope_stack.h" @@ -96,11 +96,16 @@ class ValidatorImpl { /// @returns true if no previous decleration with the |decl|'s name /// exist in the variable stack bool ValidateDeclStatement(const ast::VariableDeclStatement* decl); + /// Validates return statement + /// @param ret the return statement to check + /// @returns true if function return type matches the return statement type + bool ValidateReturnStatement(const ast::ReturnStatement* ret); private: std::string error_; ScopeStack variable_stack_; ScopeStack function_stack_; + ast::Function* current_function_ = nullptr; }; } // namespace tint diff --git a/src/validator_test.cc b/src/validator_test.cc index 9f4ada0f7e..aaed618b82 100644 --- a/src/validator_test.cc +++ b/src/validator_test.cc @@ -674,7 +674,7 @@ TEST_F(ValidatorTest, RedeclaredIdentifierDifferentFunctions_Pass) { ast::VariableList params1; auto func1 = - std::make_unique("func1", std::move(params1), &f32); + std::make_unique("func1", std::move(params1), &void_type); auto body1 = std::make_unique(); body1->append(std::make_unique(Source{13, 34}, std::move(var1)));