diff --git a/src/resolver/call_validation_test.cc b/src/resolver/call_validation_test.cc index d94b70d838..0fc2212a75 100644 --- a/src/resolver/call_validation_test.cc +++ b/src/resolver/call_validation_test.cc @@ -128,6 +128,188 @@ TEST_F(ResolverCallValidationTest, UnusedRetval) { "intentional wrap the function call in ignore()"); } +TEST_F(ResolverCallValidationTest, PointerArgument_VariableIdentExpr) { + // fn foo(p: ptr) {} + // fn main() { + // var z: i32 = 1; + // foo(&z); + // } + auto* param = Param("p", ty.pointer(ast::StorageClass::kFunction)); + Func("foo", {param}, ty.void_(), {}); + Func("main", {}, ty.void_(), + ast::StatementList{ + Decl(Var("z", ty.i32(), Expr(1))), + create( + Call("foo", AddressOf(Source{{12, 34}}, Expr("z")))), + }); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverCallValidationTest, PointerArgument_ConstIdentExpr) { + // fn foo(p: ptr) {} + // fn main() { + // let z: i32 = 1; + // foo(&z); + // } + auto* param = Param("p", ty.pointer(ast::StorageClass::kFunction)); + Func("foo", {param}, ty.void_(), {}); + Func("main", {}, ty.void_(), + ast::StatementList{ + Decl(Const("z", ty.i32(), Expr(1))), + create( + Call("foo", AddressOf(Expr(Source{{12, 34}}, "z")))), + }); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression"); +} + +TEST_F(ResolverCallValidationTest, PointerArgument_NotIdentExprVar) { + // struct S { m: i32; }; + // fn foo(p: ptr) {} + // fn main() { + // var v: S; + // foo(&v.m); + // } + auto* S = Structure("S", {Member("m", ty.i32())}); + auto* param = Param("p", ty.pointer(ast::StorageClass::kFunction)); + Func("foo", {param}, ty.void_(), {}); + Func("main", {}, ty.void_(), + ast::StatementList{ + Decl(Var("v", ty.Of(S))), + create(Call( + "foo", AddressOf(Source{{12, 34}}, MemberAccessor("v", "m")))), + }); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: expected an address-of expression of a variable " + "identifier expression or a function parameter"); +} + +TEST_F(ResolverCallValidationTest, PointerArgument_AddressOfMemberAccessor) { + // struct S { m: i32; }; + // fn foo(p: ptr) {} + // fn main() { + // let v: S = S(); + // foo(&v.m); + // } + auto* S = Structure("S", {Member("m", ty.i32())}); + auto* param = Param("p", ty.pointer(ast::StorageClass::kFunction)); + Func("foo", {param}, ty.void_(), {}); + Func("main", {}, ty.void_(), + ast::StatementList{ + Decl(Const("v", ty.Of(S), Construct(ty.Of(S)))), + create(Call( + "foo", + AddressOf(Expr(Source{{12, 34}}, MemberAccessor("v", "m"))))), + }); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression"); +} + +TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParam) { + // fn foo(p: ptr) {} + // fn bar(p: ptr) { + // foo(p); + // } + Func("foo", {Param("p", ty.pointer(ast::StorageClass::kFunction))}, + ty.void_(), {}); + Func("bar", {Param("p", ty.pointer(ast::StorageClass::kFunction))}, + ty.void_(), + ast::StatementList{create(Call("foo", Expr("p")))}); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParamWithMain) { + // fn foo(p: ptr) {} + // fn bar(p: ptr) { + // foo(p); + // } + // [[stage(fragment)]] + // fn main() { + // var v: i32; + // bar(&v); + // } + Func("foo", {Param("p", ty.pointer(ast::StorageClass::kFunction))}, + ty.void_(), {}); + Func("bar", {Param("p", ty.pointer(ast::StorageClass::kFunction))}, + ty.void_(), + ast::StatementList{create(Call("foo", Expr("p")))}); + Func("main", ast::VariableList{}, ty.void_(), + { + Decl(Var("v", ty.i32(), Expr(1))), + create(Call("foo", AddressOf(Expr("v")))), + }, + { + Stage(ast::PipelineStage::kFragment), + }); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverCallValidationTest, LetPointer) { + // fn x(p : ptr) -> i32 {} + // [[stage(fragment)]] + // fn main() { + // var v: i32; + // let p: ptr = &v; + // var c: i32 = x(p); + // } + Func("x", {Param("p", ty.pointer(ast::StorageClass::kFunction))}, + ty.void_(), {}); + auto* v = Var("v", ty.i32()); + auto* p = Const("p", ty.pointer(ty.i32(), ast::StorageClass::kFunction), + AddressOf(v)); + auto* c = Var("c", ty.i32(), ast::StorageClass::kNone, + Call("x", Expr(Source{{12, 34}}, p))); + Func("main", ast::VariableList{}, ty.void_(), + { + Decl(v), + Decl(p), + Decl(c), + }, + { + Stage(ast::PipelineStage::kFragment), + }); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: expected an address-of expression of a variable " + "identifier expression or a function parameter"); +} + +TEST_F(ResolverCallValidationTest, LetPointerPrivate) { + // let p: ptr = &v; + // fn foo(p : ptr) -> i32 {} + // var v: i32; + // [[stage(fragment)]] + // fn main() { + // var c: i32 = foo(p); + // } + Func("foo", {Param("p", ty.pointer(ast::StorageClass::kPrivate))}, + ty.void_(), {}); + auto* v = Global("v", ty.i32(), ast::StorageClass::kPrivate); + auto* p = Const("p", ty.pointer(ty.i32(), ast::StorageClass::kPrivate), + AddressOf(v)); + auto* c = Var("c", ty.i32(), ast::StorageClass::kNone, + Call("foo", Expr(Source{{12, 34}}, p))); + Func("main", ast::VariableList{}, ty.void_(), + { + Decl(p), + Decl(c), + }, + { + Stage(ast::PipelineStage::kFragment), + }); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: expected an address-of expression of a variable " + "identifier expression or a function parameter"); +} + } // namespace } // namespace resolver } // namespace tint diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 374f99babb..e485c05380 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -2578,7 +2578,21 @@ bool Resolver::FunctionCall(const ast::CallExpression* call) { } } - // Validate number of arguments match number of parameters + function_calls_.emplace(call, + FunctionCallInfo{callee_func, current_statement_}); + SetExprInfo(call, callee_func->return_type, callee_func->return_type_name); + + if (!ValidateFunctionCall(call, callee_func)) { + return false; + } + return true; +} + +bool Resolver::ValidateFunctionCall(const ast::CallExpression* call, + const FunctionInfo* callee_func) { + auto* ident = call->func(); + auto name = builder_->Symbols().NameFor(ident->symbol()); + if (call->params().size() != callee_func->parameters.size()) { bool more = call->params().size() > callee_func->parameters.size(); AddError("too " + (more ? std::string("many") : std::string("few")) + @@ -2589,7 +2603,6 @@ bool Resolver::FunctionCall(const ast::CallExpression* call) { return false; } - // Validate arguments match parameter types for (size_t i = 0; i < call->params().size(); ++i) { const VariableInfo* param = callee_func->parameters[i]; const ast::Expression* arg_expr = call->params()[i]; @@ -2603,12 +2616,48 @@ bool Resolver::FunctionCall(const ast::CallExpression* call) { arg_expr->source()); return false; } + + if (param->declaration->type()->Is()) { + auto is_valid = false; + if (auto* ident_expr = arg_expr->As()) { + VariableInfo* var; + if (!variable_stack_.get(ident_expr->symbol(), &var)) { + TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier"; + return false; + } + if (var->kind == VariableKind::kParameter) { + is_valid = true; + } + } else if (auto* unary = arg_expr->As()) { + if (unary->op() == ast::UnaryOp::kAddressOf) { + if (auto* ident_unary = + unary->expr()->As()) { + VariableInfo* var; + if (!variable_stack_.get(ident_unary->symbol(), &var)) { + TINT_ICE(Resolver, diagnostics_) + << "failed to resolve identifier"; + return false; + } + if (var->declaration->is_const()) { + TINT_ICE(Resolver, diagnostics_) + << "Resolver::FunctionCall() encountered an address-of " + "expression of a constant identifier expression"; + return false; + } + is_valid = true; + } + } + } + + if (!is_valid) { + AddError( + "expected an address-of expression of a variable identifier " + "expression or a function parameter", + arg_expr->source()); + return false; + } + } } - - function_calls_.emplace(call, - FunctionCallInfo{callee_func, current_statement_}); - SetExprInfo(call, callee_func->return_type, callee_func->return_type_name); - return true; } diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index ad7e56d78e..0cbcc6f39d 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -287,6 +287,8 @@ class Resolver { bool ValidateCallStatement(ast::CallStatement* stmt); bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info); bool ValidateFunction(const ast::Function* func, const FunctionInfo* info); + bool ValidateFunctionCall(const ast::CallExpression* call, + const FunctionInfo* info); bool ValidateGlobalVariable(const VariableInfo* var); bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco, const sem::Type* storage_type); diff --git a/src/resolver/validation_test.cc b/src/resolver/validation_test.cc index 92bd17c8fb..c34fd7d6aa 100644 --- a/src/resolver/validation_test.cc +++ b/src/resolver/validation_test.cc @@ -811,20 +811,20 @@ TEST_F(ResolverValidationTest, Stmt_BreakNotInLoopOrSwitch) { } TEST_F(ResolverValidationTest, StructMemberDuplicateName) { - Structure("S", - {Member("a", ty.i32()), Member(Source{{12, 34}}, "a", ty.i32())}); + Structure("S", {Member(Source{{12, 34}}, "a", ty.i32()), + Member(Source{{56, 78}}, "a", ty.i32())}); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ( - r()->error(), - "12:34 error: redefinition of 'a'\nnote: previous definition is here"); + EXPECT_EQ(r()->error(), + "56:78 error: redefinition of 'a'\n12:34 note: previous definition " + "is here"); } TEST_F(ResolverValidationTest, StructMemberDuplicateNameDifferentTypes) { - Structure("S", {Member("a", ty.bool_()), + Structure("S", {Member(Source{{12, 34}}, "a", ty.bool_()), Member(Source{{12, 34}}, "a", ty.vec3())}); EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ( - r()->error(), - "12:34 error: redefinition of 'a'\nnote: previous definition is here"); + EXPECT_EQ(r()->error(), + "12:34 error: redefinition of 'a'\n12:34 note: previous definition " + "is here"); } TEST_F(ResolverValidationTest, StructMemberDuplicateNamePass) { Structure("S", {Member("a", ty.i32()), Member("b", ty.f32())}); diff --git a/src/transform/inline_pointer_lets_test.cc b/src/transform/inline_pointer_lets_test.cc index 91818b8cc7..7682f8796f 100644 --- a/src/transform/inline_pointer_lets_test.cc +++ b/src/transform/inline_pointer_lets_test.cc @@ -79,35 +79,6 @@ fn f() { EXPECT_EQ(expect, str(got)); } -TEST_F(InlinePointerLetsTest, Param) { - auto* src = R"( -fn x(p : ptr) -> i32 { - return *p; -} - -fn f() { - var v : i32; - let p : ptr = &v; - var r : i32 = x(p); -} -)"; - - auto* expect = R"( -fn x(p : ptr) -> i32 { - return *(p); -} - -fn f() { - var v : i32; - var r : i32 = x(&(v)); -} -)"; - - auto got = Run(src); - - EXPECT_EQ(expect, str(got)); -} - TEST_F(InlinePointerLetsTest, SavedVars) { auto* src = R"( struct S { diff --git a/src/writer/hlsl/generator_impl_sanitizer_test.cc b/src/writer/hlsl/generator_impl_sanitizer_test.cc index 33c2cce167..3fb79159e2 100644 --- a/src/writer/hlsl/generator_impl_sanitizer_test.cc +++ b/src/writer/hlsl/generator_impl_sanitizer_test.cc @@ -286,54 +286,6 @@ TEST_F(HlslSanitizerTest, InlinePtrLetsComplexChain) { EXPECT_EQ(expect, got); } -TEST_F(HlslSanitizerTest, InlineParam) { - // fn x(p : ptr) -> i32 { - // return *p; - // } - // - // [[stage(fragment)]] - // fn main() { - // var v : i32; - // let p : ptr = &v; - // var r : i32 = x(p); - // } - - Func("x", {Param("p", ty.pointer(ast::StorageClass::kFunction))}, - ty.i32(), {Return(Deref("p"))}); - - auto* v = Var("v", ty.i32()); - auto* p = Const("p", ty.pointer(ty.i32(), ast::StorageClass::kFunction), - AddressOf(v)); - auto* r = Var("r", ty.i32(), ast::StorageClass::kNone, Call("x", p)); - - Func("main", ast::VariableList{}, ty.void_(), - { - Decl(v), - Decl(p), - Decl(r), - }, - { - Stage(ast::PipelineStage::kFragment), - }); - - GeneratorImpl& gen = SanitizeAndBuild(); - - ASSERT_TRUE(gen.Generate()) << gen.error(); - - auto got = gen.result(); - auto* expect = R"(int x(inout int p) { - return p; -} - -void main() { - int v = 0; - int r = x(v); - return; -} -)"; - EXPECT_EQ(expect, got); -} - } // namespace } // namespace hlsl } // namespace writer