From 9432c970701f68384a3f8b7feb02a1a57011ddb8 Mon Sep 17 00:00:00 2001 From: Sarah Date: Tue, 29 Jun 2021 15:24:04 +0000 Subject: [PATCH] validation: validate function parameters A function parameter of pointer type must be in one of the following storage classes: - function - private - workgroup A function parameter must one the following types: - atomic-free plain type - a pointer type - a texture type - a sampler type Bug: tint:896 tint:894 Change-Id: Id8cec1bdc8e5be2c8c18a8420cec8f13f6aeddd0 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/55940 Kokoro: Kokoro Reviewed-by: Antonio Maiorano --- src/resolver/function_validation_test.cc | 60 ++++++++++++++++++++++++ src/resolver/resolver.cc | 36 ++++++++++++-- src/resolver/resolver.h | 3 +- 3 files changed, 95 insertions(+), 4 deletions(-) diff --git a/src/resolver/function_validation_test.cc b/src/resolver/function_validation_test.cc index 2ed9ab7572..5d55377ffe 100644 --- a/src/resolver/function_validation_test.cc +++ b/src/resolver/function_validation_test.cc @@ -521,5 +521,65 @@ TEST_F(ResolverFunctionValidationTest, ReturnIsAtomicFreePlain_StructOfAtomic) { "12:34 error: function return type must be an atomic-free plain type"); } +TEST_F(ResolverFunctionValidationTest, ParameterSotreType_NonAtomicFree) { + Structure("S", {Member("m", ty.atomic(ty.i32()))}); + auto* ret_type = ty.type_name(Source{{12, 34}}, "S"); + auto* bar = Param(Source{{12, 34}}, "bar", ret_type); + Func("f", ast::VariableList{bar}, ty.void_(), {}); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: store type of function parameter must be an " + "atomic-free type"); +} + +TEST_F(ResolverFunctionValidationTest, ParameterSotreType_AtomicFree) { + Structure("S", {Member("m", ty.i32())}); + auto* ret_type = ty.type_name(Source{{12, 34}}, "S"); + auto* bar = Param(Source{{12, 34}}, "bar", ret_type); + Func("f", ast::VariableList{bar}, ty.void_(), {}); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +struct TestParams { + ast::StorageClass storage_class; + bool should_pass; +}; + +struct TestWithParams : resolver::ResolverTestWithParam {}; + +using ResolverFunctionParameterValidationTest = TestWithParams; +TEST_P(ResolverFunctionParameterValidationTest, SotrageClass) { + auto& param = GetParam(); + auto* ptr_type = ty.pointer(Source{{12, 34}}, ty.i32(), param.storage_class); + auto* arg = Param(Source{{12, 34}}, "p", ptr_type); + Func("f", ast::VariableList{arg}, ty.void_(), {}); + + if (param.should_pass) { + EXPECT_TRUE(r()->Resolve()) << r()->error(); + } else { + std::stringstream ss; + ss << param.storage_class; + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: function parameter of pointer type cannot be in '" + + ss.str() + "' storage class"); + } +} +INSTANTIATE_TEST_SUITE_P( + ResolverTest, + ResolverFunctionParameterValidationTest, + testing::Values(TestParams{ast::StorageClass::kNone, false}, + TestParams{ast::StorageClass::kInput, false}, + TestParams{ast::StorageClass::kOutput, false}, + TestParams{ast::StorageClass::kUniform, false}, + TestParams{ast::StorageClass::kWorkgroup, true}, + TestParams{ast::StorageClass::kUniformConstant, false}, + TestParams{ast::StorageClass::kStorage, false}, + TestParams{ast::StorageClass::kImage, false}, + TestParams{ast::StorageClass::kPrivate, true}, + TestParams{ast::StorageClass::kFunction, true})); + } // namespace } // namespace tint diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index ca64b25a04..b4c8a5d5a8 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -916,8 +916,8 @@ bool Resolver::ValidateVariable(const VariableInfo* info) { return true; } -bool Resolver::ValidateParameter(const ast::Function* func, - const VariableInfo* info) { +bool Resolver::ValidateFunctionParameter(const ast::Function* func, + const VariableInfo* info) { if (!ValidateVariable(info)) { return false; } @@ -953,6 +953,36 @@ bool Resolver::ValidateParameter(const ast::Function* func, return false; } } + + if (auto* ref = info->type->As()) { + auto sc = ref->StorageClass(); + if (!(sc == ast::StorageClass::kFunction || + sc == ast::StorageClass::kPrivate || + sc == ast::StorageClass::kWorkgroup)) { + std::stringstream ss; + ss << "function parameter of pointer type cannot be in '" << sc + << "' storage class"; + AddError(ss.str(), info->declaration->source()); + return false; + } + } + + if (IsPlain(info->type)) { + if (!IsAtomicFreePlain(info->type) && + !IsValidationDisabled( + info->declaration->decorations(), + ast::DisabledValidation::kIgnoreAtomicFunctionParameter)) { + AddError("store type of function parameter must be an atomic-free type", + info->declaration->source()); + return false; + } + } else if (!info->type->IsAnyOf()) { + AddError("store type of function parameter cannot be " + + info->type->FriendlyName(builder_->Symbols()), + info->declaration->source()); + return false; + } + return true; } @@ -1077,7 +1107,7 @@ bool Resolver::ValidateFunction(const ast::Function* func, } for (auto* param : func->params()) { - if (!ValidateParameter(func, variable_to_info_.at(param))) { + if (!ValidateFunctionParameter(func, variable_to_info_.at(param))) { return false; } } diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index 0506a814d1..8eb2ed28d1 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -281,7 +281,8 @@ class Resolver { bool ValidateMatrix(const sem::Matrix* matirx_type, const Source& source); bool ValidateMatrixConstructor(const ast::TypeConstructorExpression* ctor, const sem::Matrix* matrix_type); - bool ValidateParameter(const ast::Function* func, const VariableInfo* info); + bool ValidateFunctionParameter(const ast::Function* func, + const VariableInfo* info); bool ValidateReturn(const ast::ReturnStatement* ret); bool ValidateStatements(const ast::StatementList& stmts); bool ValidateStorageTexture(const ast::StorageTexture* t);