From 2e97435ba613df5dd4ccb500cca0a07998d8ff74 Mon Sep 17 00:00:00 2001 From: Antonio Maiorano Date: Mon, 22 Mar 2021 23:20:17 +0000 Subject: [PATCH] Move return validation from Validator to Resolver Improved error message to use friendly names. Fixed tests that broke as a result of this change. Bug: tint:642 Change-Id: I9a1e819e1a6110a89c826936b96ab84f7f79a084 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/45582 Reviewed-by: Ben Clayton Kokoro: Kokoro Commit-Queue: Antonio Maiorano --- src/resolver/function_validation_test.cc | 107 ++++++++++++++++++++ src/resolver/resolver.cc | 33 +++++- src/resolver/resolver.h | 2 + src/resolver/resolver_test.cc | 2 +- src/transform/renamer_test.cc | 8 +- src/validator/validator_function_test.cc | 117 ---------------------- src/validator/validator_impl.cc | 25 ----- src/validator/validator_impl.h | 4 - src/writer/spirv/builder_call_test.cc | 38 +++---- src/writer/spirv/builder_function_test.cc | 15 ++- src/writer/spirv/builder_if_test.cc | 12 +-- src/writer/spirv/builder_return_test.cc | 4 +- 12 files changed, 177 insertions(+), 190 deletions(-) diff --git a/src/resolver/function_validation_test.cc b/src/resolver/function_validation_test.cc index 0da990a4b8..ba839defb7 100644 --- a/src/resolver/function_validation_test.cc +++ b/src/resolver/function_validation_test.cc @@ -13,6 +13,7 @@ // limitations under the License. #include "src/ast/return_statement.h" +#include "src/ast/stage_decoration.h" #include "src/resolver/resolver.h" #include "src/resolver/resolver_test_helper.h" @@ -74,5 +75,111 @@ TEST_F(ResolverFunctionValidationTest, "12:34 error v-0002: non-void function must end with a return statement"); } +TEST_F(ResolverFunctionValidationTest, + FunctionTypeMustMatchReturnStatementType_Pass) { + // [[stage(vertex)]] + // fn func -> void { return; } + + Func("func", ast::VariableList{}, ty.void_(), + ast::StatementList{ + create(), + }, + ast::DecorationList{ + create(ast::PipelineStage::kVertex), + }); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverFunctionValidationTest, + FunctionTypeMustMatchReturnStatementType_fail) { + // fn func -> void { return 2; } + Func("func", ast::VariableList{}, ty.void_(), + ast::StatementList{ + create(Source{Source::Location{12, 34}}, + Expr(2)), + }, + ast::DecorationList{}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error v-000y: return statement type must match its function " + "return type, returned 'i32', expected 'void'"); +} + +TEST_F(ResolverFunctionValidationTest, + FunctionTypeMustMatchReturnStatementTypeF32_pass) { + // fn func -> f32 { return 2.0; } + Func("func", ast::VariableList{}, ty.f32(), + ast::StatementList{ + create(Source{Source::Location{12, 34}}, + Expr(2.f)), + }, + ast::DecorationList{}); + Func("main", ast::VariableList{}, ty.void_(), ast::StatementList{}, + ast::DecorationList{ + create(ast::PipelineStage::kVertex), + }); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverFunctionValidationTest, + FunctionTypeMustMatchReturnStatementTypeF32_fail) { + // fn func -> f32 { return 2; } + Func("func", ast::VariableList{}, ty.f32(), + ast::StatementList{ + create(Source{Source::Location{12, 34}}, + Expr(2)), + }, + ast::DecorationList{}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error v-000y: return statement type must match its function " + "return type, returned 'i32', expected 'f32'"); +} + +TEST_F(ResolverFunctionValidationTest, + FunctionTypeMustMatchReturnStatementTypeF32Alias_pass) { + // type myf32 = f32; + // fn func -> myf32 { return 2.0; } + auto* myf32 = ty.alias("myf32", ty.f32()); + Func("func", ast::VariableList{}, myf32, + ast::StatementList{ + create(Source{Source::Location{12, 34}}, + Expr(2.f)), + }, + ast::DecorationList{}); + Func("main", ast::VariableList{}, ty.void_(), ast::StatementList{}, + ast::DecorationList{ + create(ast::PipelineStage::kVertex), + }); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverFunctionValidationTest, + FunctionTypeMustMatchReturnStatementTypeF32Alias_fail) { + // type myf32 = f32; + // fn func -> myf32 { return 2; } + auto* myf32 = ty.alias("myf32", ty.f32()); + Func("func", ast::VariableList{}, myf32, + ast::StatementList{ + create(Source{Source::Location{12, 34}}, + Expr(2u)), + }, + ast::DecorationList{}); + Func("main", ast::VariableList{}, ty.void_(), ast::StatementList{}, + ast::DecorationList{ + create(ast::PipelineStage::kVertex), + }); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error v-000y: return statement type must match its function " + "return type, returned 'u32', expected 'myf32'"); +} + } // namespace } // namespace tint diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index fa5a25f5fc..b4915cfa0a 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -406,8 +406,7 @@ bool Resolver::Statement(ast::Statement* stmt) { }); } if (auto* r = stmt->As()) { - current_function_->return_statements.push_back(r); - return Expression(r->value()); + return Return(r); } if (auto* s = stmt->As()) { if (!Expression(s->condition())) { @@ -1647,6 +1646,36 @@ Resolver::StructInfo* Resolver::Structure(type::Struct* str) { return info; } +bool Resolver::ValidateReturn(const ast::ReturnStatement* ret) { + type::Type* func_type = current_function_->declaration->return_type(); + + auto* ret_type = ret->has_value() ? TypeOf(ret->value())->UnwrapAll() + : builder_->ty.void_(); + + if (func_type->UnwrapAll() != ret_type) { + diagnostics_.add_error( + "v-000y", + "return statement type must match its function " + "return type, returned '" + + ret_type->FriendlyName(builder_->Symbols()) + "', expected '" + + func_type->FriendlyName(builder_->Symbols()) + "'", + ret->source()); + return false; + } + + return true; +} + +bool Resolver::Return(ast::ReturnStatement* ret) { + current_function_->return_statements.push_back(ret); + + auto result = Expression(ret->value()); + + // Validate after processing the return value expression so that its type is + // available for validation + return result && ValidateReturn(ret); +} + bool Resolver::ApplyStorageClassUsageToType(ast::StorageClass sc, type::Type* ty, Source usage) { diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index 97938c117c..13a0453f3c 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -219,6 +219,7 @@ class Resolver { bool Statements(const ast::StatementList&); bool UnaryOp(ast::UnaryOpExpression*); bool VariableDeclStatement(const ast::VariableDeclStatement*); + bool Return(ast::ReturnStatement* ret); // AST and Type validation methods // Each return true on success, false on failure. @@ -226,6 +227,7 @@ class Resolver { bool ValidateParameter(const ast::Variable* param); bool ValidateFunction(const ast::Function* func); bool ValidateStructure(const type::Struct* st); + bool ValidateReturn(const ast::ReturnStatement* ret); /// @returns the semantic information for the array `arr`, building it if it /// hasn't been constructed already. If an error is raised, nullptr is diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc index 57b0820f13..53e35b0721 100644 --- a/src/resolver/resolver_test.cc +++ b/src/resolver/resolver_test.cc @@ -228,7 +228,7 @@ TEST_F(ResolverTest, Stmt_Return) { auto* cond = Expr(2); auto* ret = create(cond); - WrapInFunction(ret); + Func("test", {}, ty.i32(), {ret}, {}); EXPECT_TRUE(r()->Resolve()) << r()->error(); diff --git a/src/transform/renamer_test.cc b/src/transform/renamer_test.cc index 887cfaafa9..cfd3290175 100644 --- a/src/transform/renamer_test.cc +++ b/src/transform/renamer_test.cc @@ -83,7 +83,7 @@ fn _tint_3() -> void { TEST_F(RenamerTest, PreserveSwizzles) { auto* src = R"( [[stage(vertex)]] -fn entry() -> void { +fn entry() -> vec4 { var v : vec4; var rgba : f32; var xyzw : f32; @@ -93,7 +93,7 @@ fn entry() -> void { auto* expect = R"( [[stage(vertex)]] -fn _tint_1() -> void { +fn _tint_1() -> vec4 { var _tint_2 : vec4; var _tint_3 : f32; var _tint_4 : f32; @@ -120,7 +120,7 @@ fn _tint_1() -> void { TEST_F(RenamerTest, PreserveIntrinsics) { auto* src = R"( [[stage(vertex)]] -fn entry() -> void { +fn entry() -> vec4 { var blah : vec4; return abs(blah); } @@ -128,7 +128,7 @@ fn entry() -> void { auto* expect = R"( [[stage(vertex)]] -fn _tint_1() -> void { +fn _tint_1() -> vec4 { var _tint_2 : vec4; return abs(_tint_2); } diff --git a/src/validator/validator_function_test.cc b/src/validator/validator_function_test.cc index e5cc9bcdc8..d74e10377f 100644 --- a/src/validator/validator_function_test.cc +++ b/src/validator/validator_function_test.cc @@ -56,123 +56,6 @@ TEST_F(ValidateFunctionTest, EXPECT_TRUE(v.Validate()); } -TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementType_Pass) { - // [[stage(vertex)]] - // fn func -> void { return; } - - Func("func", ast::VariableList{}, ty.void_(), - ast::StatementList{ - create(), - }, - ast::DecorationList{ - create(ast::PipelineStage::kVertex), - }); - - ValidatorImpl& v = Build(); - - EXPECT_TRUE(v.Validate()) << v.error(); -} - -TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementType_fail) { - // fn func -> void { return 2; } - Func("func", ast::VariableList{}, ty.void_(), - ast::StatementList{ - create(Source{Source::Location{12, 34}}, - Expr(2)), - }, - ast::DecorationList{}); - - ValidatorImpl& v = Build(); - - EXPECT_FALSE(v.Validate()); - // TODO(sarahM0): replace 000y with a rule number - EXPECT_EQ(v.error(), - "12:34 v-000y: return statement type must match its function " - "return type, returned '__i32', expected '__void'"); -} - -TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementTypeF32_pass) { - // fn func -> f32 { return 2.0; } - Func("func", ast::VariableList{}, ty.f32(), - ast::StatementList{ - create(Source{Source::Location{12, 34}}, - Expr(2.f)), - }, - ast::DecorationList{}); - Func("main", ast::VariableList{}, ty.void_(), ast::StatementList{}, - ast::DecorationList{ - create(ast::PipelineStage::kVertex), - }); - - ValidatorImpl& v = Build(); - - EXPECT_TRUE(v.Validate()); -} - -TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementTypeF32_fail) { - // fn func -> f32 { return 2; } - Func("func", ast::VariableList{}, ty.f32(), - ast::StatementList{ - create(Source{Source::Location{12, 34}}, - Expr(2)), - }, - ast::DecorationList{}); - - ValidatorImpl& v = Build(); - - EXPECT_FALSE(v.Validate()); - // TODO(sarahM0): replace 000y with a rule number - EXPECT_EQ(v.error(), - "12:34 v-000y: return statement type must match its function " - "return type, returned '__i32', expected '__f32'"); -} - -TEST_F(ValidateFunctionTest, - FunctionTypeMustMatchReturnStatementTypeF32Alias_pass) { - // type myf32 = f32; - // fn func -> myf32 { return 2.0; } - auto* myf32 = ty.alias("myf32", ty.f32()); - Func("func", ast::VariableList{}, myf32, - ast::StatementList{ - create(Source{Source::Location{12, 34}}, - Expr(2.f)), - }, - ast::DecorationList{}); - Func("main", ast::VariableList{}, ty.void_(), ast::StatementList{}, - ast::DecorationList{ - create(ast::PipelineStage::kVertex), - }); - - ValidatorImpl& v = Build(); - - EXPECT_TRUE(v.Validate()); -} - -TEST_F(ValidateFunctionTest, - FunctionTypeMustMatchReturnStatementTypeF32Alias_fail) { - // type myf32 = f32; - // fn func -> myf32 { return 2; } - auto* myf32 = ty.alias("myf32", ty.f32()); - Func("func", ast::VariableList{}, myf32, - ast::StatementList{ - create(Source{Source::Location{12, 34}}, - Expr(2u)), - }, - ast::DecorationList{}); - Func("main", ast::VariableList{}, ty.void_(), ast::StatementList{}, - ast::DecorationList{ - create(ast::PipelineStage::kVertex), - }); - - ValidatorImpl& v = Build(); - - EXPECT_FALSE(v.Validate()); - EXPECT_EQ( - v.error(), - "12:34 v-000y: return statement type must match its function " - "return type, returned '__u32', expected '__alias_tint_symbol_1__f32'"); -} - TEST_F(ValidateFunctionTest, PipelineStage_MustBeUnique_Fail) { // [[stage(fragment)]] // [[stage(vertex)]] diff --git a/src/validator/validator_impl.cc b/src/validator/validator_impl.cc index 65e2a33400..16187d9db1 100644 --- a/src/validator/validator_impl.cc +++ b/src/validator/validator_impl.cc @@ -182,28 +182,6 @@ bool ValidatorImpl::ValidateFunction(const ast::Function* func) { return true; } -bool ValidatorImpl::ValidateReturnStatement(const ast::ReturnStatement* ret) { - // TODO(sarahM0): update this when this issue resolves: - // https://github.com/gpuweb/gpuweb/issues/996 - type::Type* func_type = current_function_->return_type(); - - type::Void void_type; - auto* ret_type = ret->has_value() - ? program_->Sem().Get(ret->value())->Type()->UnwrapAll() - : &void_type; - - if (func_type->UnwrapAll()->type_name() != ret_type->type_name()) { - add_error(ret->source(), "v-000y", - "return statement type must match its function return " - "type, returned '" + - ret_type->type_name() + "', expected '" + - func_type->type_name() + "'"); - return false; - } - - return true; -} - bool ValidatorImpl::ValidateStatements(const ast::BlockStatement* block) { if (!block) { return false; @@ -262,9 +240,6 @@ bool ValidatorImpl::ValidateStatement(const ast::Statement* stmt) { if (auto* a = stmt->As()) { return ValidateAssign(a); } - if (auto* r = stmt->As()) { - return ValidateReturnStatement(r); - } if (auto* s = stmt->As()) { return ValidateSwitch(s); } diff --git a/src/validator/validator_impl.h b/src/validator/validator_impl.h index 2ec4c22e26..04edf9578e 100644 --- a/src/validator/validator_impl.h +++ b/src/validator/validator_impl.h @@ -99,10 +99,6 @@ class ValidatorImpl { /// @returns true if no previous declaration with the `decl` 's name /// exist in the variable stack bool ValidateDeclStatement(const ast::VariableDeclStatement* decl); - /// Validates return statement - /// @param ret the return statement to check - /// @returns true if function return type matches the return statement type - bool ValidateReturnStatement(const ast::ReturnStatement* ret); /// Validates switch statements /// @param s the switch statement to check /// @returns true if the valdiation was successful diff --git a/src/writer/spirv/builder_call_test.cc b/src/writer/spirv/builder_call_test.cc index de6fc29726..3f5311003e 100644 --- a/src/writer/spirv/builder_call_test.cc +++ b/src/writer/spirv/builder_call_test.cc @@ -79,7 +79,7 @@ TEST_F(BuilderTest, Statement_Call) { func_params.push_back(Var("b", ty.f32(), ast::StorageClass::kFunction)); auto* a_func = - Func("a_func", func_params, ty.void_(), + Func("a_func", func_params, ty.f32(), ast::StatementList{create(Add("a", "b"))}, ast::DecorationList{}); @@ -96,27 +96,27 @@ TEST_F(BuilderTest, Statement_Call) { ASSERT_TRUE(b.GenerateFunction(func)) << b.error(); EXPECT_TRUE(b.GenerateStatement(expr)) << b.error(); - EXPECT_EQ(DumpBuilder(b), R"(OpName %4 "a_func" -OpName %5 "a" -OpName %6 "b" + EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func" +OpName %4 "a" +OpName %5 "b" OpName %12 "main" -%2 = OpTypeVoid -%3 = OpTypeFloat 32 -%1 = OpTypeFunction %2 %3 %3 -%11 = OpTypeFunction %2 -%15 = OpConstant %3 1 -%4 = OpFunction %2 None %1 -%5 = OpFunctionParameter %3 -%6 = OpFunctionParameter %3 -%7 = OpLabel -%8 = OpLoad %3 %5 -%9 = OpLoad %3 %6 -%10 = OpFAdd %3 %8 %9 -OpReturnValue %10 +%2 = OpTypeFloat 32 +%1 = OpTypeFunction %2 %2 %2 +%11 = OpTypeVoid +%10 = OpTypeFunction %11 +%15 = OpConstant %2 1 +%3 = OpFunction %2 None %1 +%4 = OpFunctionParameter %2 +%5 = OpFunctionParameter %2 +%6 = OpLabel +%7 = OpLoad %2 %4 +%8 = OpLoad %2 %5 +%9 = OpFAdd %2 %7 %8 +OpReturnValue %9 OpFunctionEnd -%12 = OpFunction %2 None %11 +%12 = OpFunction %11 None %10 %13 = OpLabel -%14 = OpFunctionCall %2 %4 %15 %15 +%14 = OpFunctionCall %2 %3 %15 %15 OpReturn OpFunctionEnd )"); diff --git a/src/writer/spirv/builder_function_test.cc b/src/writer/spirv/builder_function_test.cc index 392a05ed98..e92680db4b 100644 --- a/src/writer/spirv/builder_function_test.cc +++ b/src/writer/spirv/builder_function_test.cc @@ -65,7 +65,7 @@ OpFunctionEnd TEST_F(BuilderTest, Function_Terminator_ReturnValue) { Global("a", ty.f32(), ast::StorageClass::kPrivate); - Func("a_func", {}, ty.void_(), + Func("a_func", {}, ty.f32(), ast::StatementList{create(Expr("a"))}, ast::DecorationList{}); @@ -77,17 +77,16 @@ TEST_F(BuilderTest, Function_Terminator_ReturnValue) { ASSERT_TRUE(b.GenerateGlobalVariable(var_a)) << b.error(); ASSERT_TRUE(b.GenerateFunction(func)) << b.error(); EXPECT_EQ(DumpBuilder(b), R"(OpName %1 "a" -OpName %7 "a_func" +OpName %6 "a_func" %3 = OpTypeFloat 32 %2 = OpTypePointer Private %3 %4 = OpConstantNull %3 %1 = OpVariable %2 Private %4 -%6 = OpTypeVoid -%5 = OpTypeFunction %6 -%7 = OpFunction %6 None %5 -%8 = OpLabel -%9 = OpLoad %3 %1 -OpReturnValue %9 +%5 = OpTypeFunction %3 +%6 = OpFunction %3 None %5 +%7 = OpLabel +%8 = OpLoad %3 %1 +OpReturnValue %8 OpFunctionEnd )"); } diff --git a/src/writer/spirv/builder_if_test.cc b/src/writer/spirv/builder_if_test.cc index 6918e8670b..90215a5283 100644 --- a/src/writer/spirv/builder_if_test.cc +++ b/src/writer/spirv/builder_if_test.cc @@ -511,14 +511,10 @@ TEST_F(BuilderTest, If_WithReturnValue) { // if (true) { // return false; // } - auto* if_body = create(ast::StatementList{ - create(Expr(false)), - }); - - auto* expr = - create(Expr(true), if_body, ast::ElseStatementList{}); - WrapInFunction(expr); - + // return true; + auto* if_body = Block(Return(Expr(false))); + auto* expr = If(Expr(true), if_body); + Func("test", {}, ty.bool_(), {expr, Return(Expr(true))}, {}); spirv::Builder& b = Build(); b.push_function(Function{}); diff --git a/src/writer/spirv/builder_return_test.cc b/src/writer/spirv/builder_return_test.cc index 17265b8195..3169a2ff74 100644 --- a/src/writer/spirv/builder_return_test.cc +++ b/src/writer/spirv/builder_return_test.cc @@ -39,7 +39,7 @@ TEST_F(BuilderTest, Return_WithValue) { auto* val = vec3(1.f, 1.f, 3.f); auto* ret = create(val); - WrapInFunction(ret); + Func("test", {}, ty.vec3(), {ret}, {}); spirv::Builder& b = Build(); @@ -62,7 +62,7 @@ TEST_F(BuilderTest, Return_WithValue_GeneratesLoad) { auto* var = Global("param", ty.f32(), ast::StorageClass::kFunction); auto* ret = create(Expr("param")); - WrapInFunction(ret); + Func("test", {}, ty.f32(), {ret}, {}); spirv::Builder& b = Build();