diff --git a/src/tint/resolver/function_validation_test.cc b/src/tint/resolver/function_validation_test.cc index 317cadc265..3be06a9ae7 100644 --- a/src/tint/resolver/function_validation_test.cc +++ b/src/tint/resolver/function_validation_test.cc @@ -73,8 +73,8 @@ TEST_F(ResolverFunctionValidationTest, VoidFunctionEndWithoutReturnStatement_Pas // fn func { var a:i32 = 2i; } auto* var = Var("a", ty.i32(), Expr(2_i)); - Func(Source{{12, 34}}, "func", ast::VariableList{}, ty.void_(), - ast::StatementList{ + Func(Source{{12, 34}}, "func", {}, ty.void_(), + { Decl(var), }); @@ -88,12 +88,11 @@ TEST_F(ResolverFunctionValidationTest, FunctionUsingSameVariableName_Pass) { // } auto* var = Var("func", ty.i32(), Expr(0_i)); - Func("func", ast::VariableList{}, ty.i32(), - ast::StatementList{ + Func("func", {}, ty.i32(), + { Decl(var), Return(Source{{12, 34}}, Expr("func")), - }, - ast::AttributeList{}); + }); ASSERT_TRUE(r()->Resolve()) << r()->error(); } @@ -103,17 +102,15 @@ TEST_F(ResolverFunctionValidationTest, FunctionNameSameAsFunctionScopeVariableNa // fn b() -> i32 { return 2; } auto* var = Var("b", ty.i32(), Expr(0_i)); - Func("a", ast::VariableList{}, ty.void_(), - ast::StatementList{ + Func("a", {}, ty.void_(), + { Decl(var), - }, - ast::AttributeList{}); + }); - Func(Source{{12, 34}}, "b", ast::VariableList{}, ty.i32(), - ast::StatementList{ + Func(Source{{12, 34}}, "b", {}, ty.i32(), + { Return(2_i), - }, - ast::AttributeList{}); + }); ASSERT_TRUE(r()->Resolve()) << r()->error(); } @@ -129,7 +126,7 @@ TEST_F(ResolverFunctionValidationTest, UnreachableCode_return) { auto* ret = Return(); auto* assign_a = Assign(Source{{12, 34}}, "a", 2_i); - Func("func", ast::VariableList{}, ty.void_(), {decl_a, ret, assign_a}); + Func("func", {}, ty.void_(), {decl_a, ret, assign_a}); ASSERT_TRUE(r()->Resolve()); @@ -150,7 +147,7 @@ TEST_F(ResolverFunctionValidationTest, UnreachableCode_return_InBlocks) { auto* ret = Return(); auto* assign_a = Assign(Source{{12, 34}}, "a", 2_i); - Func("func", ast::VariableList{}, ty.void_(), {decl_a, Block(Block(Block(ret))), assign_a}); + Func("func", {}, ty.void_(), {decl_a, Block(Block(Block(ret))), assign_a}); ASSERT_TRUE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 warning: code is unreachable"); @@ -170,7 +167,7 @@ TEST_F(ResolverFunctionValidationTest, UnreachableCode_discard) { auto* discard = Discard(); auto* assign_a = Assign(Source{{12, 34}}, "a", 2_i); - Func("func", ast::VariableList{}, ty.void_(), {decl_a, discard, assign_a}); + Func("func", {}, ty.void_(), {decl_a, discard, assign_a}); ASSERT_TRUE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 warning: code is unreachable"); @@ -190,7 +187,7 @@ TEST_F(ResolverFunctionValidationTest, UnreachableCode_discard_InBlocks) { auto* discard = Discard(); auto* assign_a = Assign(Source{{12, 34}}, "a", 2_i); - Func("func", ast::VariableList{}, ty.void_(), {decl_a, Block(Block(Block(discard))), assign_a}); + Func("func", {}, ty.void_(), {decl_a, Block(Block(Block(discard))), assign_a}); ASSERT_TRUE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 warning: code is unreachable"); @@ -204,11 +201,10 @@ TEST_F(ResolverFunctionValidationTest, FunctionEndWithoutReturnStatement_Fail) { auto* var = Var("a", ty.i32(), Expr(2_i)); - Func(Source{{12, 34}}, "func", ast::VariableList{}, ty.i32(), - ast::StatementList{ + Func(Source{{12, 34}}, "func", {}, ty.i32(), + { Decl(var), - }, - ast::AttributeList{}); + }); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 error: missing return at end of function"); @@ -217,7 +213,7 @@ TEST_F(ResolverFunctionValidationTest, FunctionEndWithoutReturnStatement_Fail) { TEST_F(ResolverFunctionValidationTest, VoidFunctionEndWithoutReturnStatementEmptyBody_Pass) { // fn func {} - Func(Source{{12, 34}}, "func", ast::VariableList{}, ty.void_(), ast::StatementList{}); + Func(Source{{12, 34}}, "func", {}, ty.void_(), {}); ASSERT_TRUE(r()->Resolve()) << r()->error(); } @@ -225,8 +221,7 @@ TEST_F(ResolverFunctionValidationTest, VoidFunctionEndWithoutReturnStatementEmpt TEST_F(ResolverFunctionValidationTest, FunctionEndWithoutReturnStatementEmptyBody_Fail) { // fn func() -> int {} - Func(Source{{12, 34}}, "func", ast::VariableList{}, ty.i32(), ast::StatementList{}, - ast::AttributeList{}); + Func(Source{{12, 34}}, "func", {}, ty.i32(), {}); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "12:34 error: missing return at end of function"); @@ -235,21 +230,34 @@ TEST_F(ResolverFunctionValidationTest, FunctionEndWithoutReturnStatementEmptyBod TEST_F(ResolverFunctionValidationTest, FunctionTypeMustMatchReturnStatementType_Pass) { // fn func { return; } - Func("func", ast::VariableList{}, ty.void_(), - ast::StatementList{ - Return(), - }); + Func("func", {}, ty.void_(), {Return()}); ASSERT_TRUE(r()->Resolve()) << r()->error(); } -TEST_F(ResolverFunctionValidationTest, FunctionTypeMustMatchReturnStatementType_fail) { +TEST_F(ResolverFunctionValidationTest, VoidFunctionReturnsAInt) { + // fn func { return 2; } + Func("func", {}, ty.void_(), {Return(Source{{12, 34}}, Expr(2_a))}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: return statement type must match its function return " + "type, returned 'abstract-int', expected 'void'"); +} + +TEST_F(ResolverFunctionValidationTest, VoidFunctionReturnsAFloat) { + // fn func { return 2.0; } + Func("func", {}, ty.void_(), {Return(Source{{12, 34}}, Expr(2.0_a))}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: return statement type must match its function return " + "type, returned 'abstract-float', expected 'void'"); +} + +TEST_F(ResolverFunctionValidationTest, VoidFunctionReturnsI32) { // fn func { return 2i; } - Func("func", ast::VariableList{}, ty.void_(), - ast::StatementList{ - Return(Source{{12, 34}}, Expr(2_i)), - }, - ast::AttributeList{}); + Func("func", {}, ty.void_(), {Return(Source{{12, 34}}, Expr(2_i))}); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), @@ -272,11 +280,10 @@ TEST_F(ResolverFunctionValidationTest, FunctionTypeMustMatchReturnStatementType_ TEST_F(ResolverFunctionValidationTest, FunctionTypeMustMatchReturnStatementTypeMissing_fail) { // fn func() -> f32 { return; } - Func("func", ast::VariableList{}, ty.f32(), - ast::StatementList{ + Func("func", {}, ty.f32(), + { Return(Source{{12, 34}}, nullptr), - }, - ast::AttributeList{}); + }); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), @@ -286,22 +293,20 @@ TEST_F(ResolverFunctionValidationTest, FunctionTypeMustMatchReturnStatementTypeM TEST_F(ResolverFunctionValidationTest, FunctionTypeMustMatchReturnStatementTypeF32_pass) { // fn func() -> f32 { return 2.0; } - Func("func", ast::VariableList{}, ty.f32(), - ast::StatementList{ + Func("func", {}, ty.f32(), + { Return(Source{{12, 34}}, Expr(2_f)), - }, - ast::AttributeList{}); + }); ASSERT_TRUE(r()->Resolve()) << r()->error(); } TEST_F(ResolverFunctionValidationTest, FunctionTypeMustMatchReturnStatementTypeF32_fail) { // fn func() -> f32 { return 2i; } - Func("func", ast::VariableList{}, ty.f32(), - ast::StatementList{ + Func("func", {}, ty.f32(), + { Return(Source{{12, 34}}, Expr(2_i)), - }, - ast::AttributeList{}); + }); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), @@ -313,11 +318,10 @@ TEST_F(ResolverFunctionValidationTest, FunctionTypeMustMatchReturnStatementTypeF // type myf32 = f32; // fn func() -> myf32 { return 2.0; } auto* myf32 = Alias("myf32", ty.f32()); - Func("func", ast::VariableList{}, ty.Of(myf32), - ast::StatementList{ + Func("func", {}, ty.Of(myf32), + { Return(Source{{12, 34}}, Expr(2_f)), - }, - ast::AttributeList{}); + }); ASSERT_TRUE(r()->Resolve()) << r()->error(); } @@ -326,11 +330,10 @@ TEST_F(ResolverFunctionValidationTest, FunctionTypeMustMatchReturnStatementTypeF // type myf32 = f32; // fn func() -> myf32 { return 2u; } auto* myf32 = Alias("myf32", ty.f32()); - Func("func", ast::VariableList{}, ty.Of(myf32), - ast::StatementList{ + Func("func", {}, ty.Of(myf32), + { Return(Source{{12, 34}}, Expr(2_u)), - }, - ast::AttributeList{}); + }); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), @@ -341,10 +344,10 @@ TEST_F(ResolverFunctionValidationTest, FunctionTypeMustMatchReturnStatementTypeF TEST_F(ResolverFunctionValidationTest, CannotCallEntryPoint) { // @stage(compute) @workgroup_size(1) fn entrypoint() {} // fn func() { return entrypoint(); } - Func("entrypoint", ast::VariableList{}, ty.void_(), {}, + Func("entrypoint", {}, ty.void_(), {}, {Stage(ast::PipelineStage::kCompute), WorkgroupSize(1_i)}); - Func("func", ast::VariableList{}, ty.void_(), + Func("func", {}, ty.void_(), { CallStmt(Call(Source{{12, 34}}, "entrypoint")), }); @@ -359,8 +362,8 @@ TEST_F(ResolverFunctionValidationTest, PipelineStage_MustBeUnique_Fail) { // @stage(fragment) // @stage(vertex) // fn main() { return; } - Func(Source{{12, 34}}, "main", ast::VariableList{}, ty.void_(), - ast::StatementList{ + Func(Source{{12, 34}}, "main", {}, ty.void_(), + { Return(), }, ast::AttributeList{ @@ -375,11 +378,10 @@ TEST_F(ResolverFunctionValidationTest, PipelineStage_MustBeUnique_Fail) { } TEST_F(ResolverFunctionValidationTest, NoPipelineEntryPoints) { - Func("vtx_func", ast::VariableList{}, ty.void_(), - ast::StatementList{ + Func("vtx_func", {}, ty.void_(), + { Return(), - }, - ast::AttributeList{}); + }); ASSERT_TRUE(r()->Resolve()) << r()->error(); } @@ -392,8 +394,7 @@ TEST_F(ResolverFunctionValidationTest, FunctionVarInitWithParam) { auto* bar = Param("bar", ty.f32()); auto* baz = Var("baz", ty.f32(), Expr("bar")); - Func("foo", ast::VariableList{bar}, ty.void_(), ast::StatementList{Decl(baz)}, - ast::AttributeList{}); + Func("foo", ast::VariableList{bar}, ty.void_(), {Decl(baz)}); ASSERT_TRUE(r()->Resolve()) << r()->error(); } @@ -406,8 +407,7 @@ TEST_F(ResolverFunctionValidationTest, FunctionConstInitWithParam) { auto* bar = Param("bar", ty.f32()); auto* baz = Let("baz", ty.f32(), Expr("bar")); - Func("foo", ast::VariableList{bar}, ty.void_(), ast::StatementList{Decl(baz)}, - ast::AttributeList{}); + Func("foo", ast::VariableList{bar}, ty.void_(), {Decl(baz)}); ASSERT_TRUE(r()->Resolve()) << r()->error(); } diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 294b3fb499..90ae4b0822 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -2302,10 +2302,16 @@ sem::Statement* Resolver::ReturnStatement(const ast::ReturnStatement* stmt) { const sem::Type* value_ty = nullptr; if (auto* value = stmt->value) { - const auto* expr = Materialize(Expression(value), current_function_->ReturnType()); + const auto* expr = Expression(value); if (!expr) { return false; } + if (auto* ret_ty = current_function_->ReturnType(); !ret_ty->Is()) { + expr = Materialize(expr, ret_ty); + if (!expr) { + return false; + } + } behaviors.Add(expr->Behaviors() - sem::Behavior::kNext); value_ty = expr->Type()->UnwrapRef(); } else {