diff --git a/src/tint/resolver/call_validation_test.cc b/src/tint/resolver/call_validation_test.cc index 6c64fade93..82037af7c3 100644 --- a/src/tint/resolver/call_validation_test.cc +++ b/src/tint/resolver/call_validation_test.cc @@ -114,25 +114,7 @@ TEST_F(ResolverCallValidationTest, PointerArgument_VariableIdentExpr) { EXPECT_TRUE(r()->Resolve()) << r()->error(); } -TEST_F(ResolverCallValidationTest, PointerArgument_ConstIdentExpr) { - // fn foo(p: ptr) {} - // fn main() { - // let z: i32 = 1i; - // foo(&z); - // } - auto* param = Param("p", ty.pointer(ast::AddressSpace::kFunction)); - Func("foo", utils::Vector{param}, ty.void_(), utils::Empty); - Func("main", utils::Empty, ty.void_(), - utils::Vector{ - Decl(Let("z", ty.i32(), Expr(1_i))), - CallStmt(Call("foo", AddressOf(Expr(Source{{12, 34}}, "z")))), - }); - - EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression"); -} - -TEST_F(ResolverCallValidationTest, PointerArgument_NotIdentExprVar) { +TEST_F(ResolverCallValidationTest, PointerArgument_NotWholeVar) { // struct S { m: i32; }; // fn foo(p: ptr) {} // fn main() { @@ -152,30 +134,8 @@ TEST_F(ResolverCallValidationTest, PointerArgument_NotIdentExprVar) { EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), - "12:34 error: expected an address-of expression of a variable " - "identifier expression or a function parameter"); -} - -TEST_F(ResolverCallValidationTest, PointerArgument_AddressOfMemberAccessor) { - // struct S { m: i32; }; - // fn foo(p: ptr) {} - // fn main() { - // let v: S = S(); - // foo(&v.m); - // } - auto* S = Structure("S", utils::Vector{ - Member("m", ty.i32()), - }); - auto* param = Param("p", ty.pointer(ast::AddressSpace::kFunction)); - Func("foo", utils::Vector{param}, ty.void_(), utils::Empty); - Func("main", utils::Empty, ty.void_(), - utils::Vector{ - Decl(Let("v", ty.Of(S), Construct(ty.Of(S)))), - CallStmt(Call("foo", AddressOf(MemberAccessor(Source{{12, 34}}, "v", "m")))), - }); - - EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression"); + "12:34 error: arguments of pointer type must not point to a subset of the " + "originating variable"); } TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParam) { @@ -235,65 +195,169 @@ TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParamWithMain) { EXPECT_TRUE(r()->Resolve()) << r()->error(); } -TEST_F(ResolverCallValidationTest, LetPointer) { - // fn x(p : ptr) -> i32 {} - // @fragment - // fn main() { - // var v: i32; - // let p: ptr = &v; - // var c: i32 = x(p); +TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParam_NotWholeVar) { + // fn foo(p: ptr) {} + // fn bar(p: ptr>) { + // foo(&(*p)[0]); // } - Func("x", + Func("foo", + utils::Vector{ + Param("p", ty.pointer(ast::AddressSpace::kFunction)), + }, + ty.void_(), utils::Empty); + Func("bar", + utils::Vector{ + Param("p", ty.pointer(ty.array(), ast::AddressSpace::kFunction)), + }, + ty.void_(), + utils::Vector{ + CallStmt(Call("foo", AddressOf(Source{{12, 34}}, IndexAccessor(Deref("p"), 0_a)))), + }); + + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: arguments of pointer type must not point to a subset of the " + "originating variable"); +} + +TEST_F(ResolverCallValidationTest, LetPointer) { + // fn foo(p : ptr) {} + // @fragment + // fn main() { + // var v: i32; + // let p: ptr = &v; + // foo(p); + // } + Func("foo", utils::Vector{ Param("p", ty.pointer(ast::AddressSpace::kFunction)), }, ty.void_(), utils::Empty); - auto* v = Var("v", ty.i32()); - auto* p = Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kFunction), AddressOf(v)); - auto* c = Var("c", ty.i32(), Call("x", Expr(Source{{12, 34}}, p))); Func("main", utils::Empty, ty.void_(), utils::Vector{ - Decl(v), - Decl(p), - Decl(c), + Decl(Var("v", ty.i32())), + Decl(Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kFunction), AddressOf("v"))), + CallStmt(Call("foo", Expr(Source{{12, 34}}, "p"))), }, utils::Vector{ Stage(ast::PipelineStage::kFragment), }); - EXPECT_FALSE(r()->Resolve()); - EXPECT_EQ(r()->error(), - "12:34 error: expected an address-of expression of a variable " - "identifier expression or a function parameter"); + EXPECT_TRUE(r()->Resolve()) << r()->error(); } TEST_F(ResolverCallValidationTest, LetPointerPrivate) { - // let p: ptr = &v; - // fn foo(p : ptr) -> i32 {} - // var v: i32; + // fn foo(p : ptr) {} + // var v: i32; // @fragment // fn main() { - // var c: i32 = foo(p); + // let p: ptr = &v; + // foo(p); // } Func("foo", utils::Vector{ Param("p", ty.pointer(ast::AddressSpace::kPrivate)), }, ty.void_(), utils::Empty); - auto* v = GlobalVar("v", ty.i32(), ast::AddressSpace::kPrivate); - auto* p = Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kPrivate), AddressOf(v)); - auto* c = Var("c", ty.i32(), Call("foo", Expr(Source{{12, 34}}, p))); + GlobalVar("v", ty.i32(), ast::AddressSpace::kPrivate); Func("main", utils::Empty, ty.void_(), utils::Vector{ - Decl(p), - Decl(c), + Decl(Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kPrivate), AddressOf("v"))), + CallStmt(Call("foo", Expr(Source{{12, 34}}, "p"))), + }, + utils::Vector{ + Stage(ast::PipelineStage::kFragment), + }); + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverCallValidationTest, LetPointer_NotWholeVar) { + // fn foo(p : ptr) {} + // @fragment + // fn main() { + // var v: array; + // let p: ptr = &(v[0]); + // x(p); + // } + Func("foo", + utils::Vector{ + Param("p", ty.pointer(ast::AddressSpace::kFunction)), + }, + ty.void_(), utils::Empty); + Func("main", utils::Empty, ty.void_(), + utils::Vector{ + Decl(Var("v", ty.array())), + Decl(Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kFunction), + AddressOf(IndexAccessor("v", 0_a)))), + CallStmt(Call("foo", Expr(Source{{12, 34}}, "p"))), }, utils::Vector{ Stage(ast::PipelineStage::kFragment), }); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), - "12:34 error: expected an address-of expression of a variable " - "identifier expression or a function parameter"); + "12:34 error: arguments of pointer type must not point to a subset of the " + "originating variable"); +} + +TEST_F(ResolverCallValidationTest, ComplexPointerChain) { + // fn foo(p : ptr>) {} + // @fragment + // fn main() { + // var v: array; + // let p1 = &v; + // let p2 = p1; + // let p3 = &*p2; + // foo(&*p); + // } + Func("foo", + utils::Vector{ + Param("p", ty.pointer(ty.array(), ast::AddressSpace::kFunction)), + }, + ty.void_(), utils::Empty); + Func("main", utils::Empty, ty.void_(), + utils::Vector{ + Decl(Var("v", ty.array())), + Decl(Let("p1", AddressOf("v"))), + Decl(Let("p2", Expr("p1"))), + Decl(Let("p3", AddressOf(Deref("p2")))), + CallStmt(Call("foo", AddressOf(Source{{12, 34}}, Deref("p3")))), + }, + utils::Vector{ + Stage(ast::PipelineStage::kFragment), + }); + EXPECT_TRUE(r()->Resolve()) << r()->error(); +} + +TEST_F(ResolverCallValidationTest, ComplexPointerChain_NotWholeVar) { + // fn foo(p : ptr) {} + // @fragment + // fn main() { + // var v: array; + // let p1 = &v; + // let p2 = p1; + // let p3 = &(*p2)[0]; + // foo(&*p); + // } + Func("foo", + utils::Vector{ + Param("p", ty.pointer(ast::AddressSpace::kFunction)), + }, + ty.void_(), utils::Empty); + Func("main", utils::Empty, ty.void_(), + utils::Vector{ + Decl(Var("v", ty.array())), + Decl(Let("p1", AddressOf("v"))), + Decl(Let("p2", Expr("p1"))), + Decl(Let("p3", AddressOf(IndexAccessor(Deref("p2"), 0_a)))), + CallStmt(Call("foo", AddressOf(Source{{12, 34}}, Deref("p3")))), + }, + utils::Vector{ + Stage(ast::PipelineStage::kFragment), + }); + EXPECT_FALSE(r()->Resolve()); + EXPECT_EQ(r()->error(), + "12:34 error: arguments of pointer type must not point to a subset of the " + "originating variable"); } TEST_F(ResolverCallValidationTest, CallVariable) { diff --git a/src/tint/resolver/validator.cc b/src/tint/resolver/validator.cc index f96af05ebc..f285b5f42a 100644 --- a/src/tint/resolver/validator.cc +++ b/src/tint/resolver/validator.cc @@ -1799,35 +1799,28 @@ bool Validator::FunctionCall(const sem::Call* call, sem::Statement* current_stat } if (param_type->Is()) { - auto is_valid = false; - if (auto* ident_expr = arg_expr->As()) { - auto* var = sem_.ResolvedSymbol(ident_expr); - if (!var) { - TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier"; - return false; - } - if (var->Is()) { - is_valid = true; - } - } else if (auto* unary = arg_expr->As()) { - if (unary->op == ast::UnaryOp::kAddressOf) { - if (auto* ident_unary = unary->expr->As()) { - auto* var = sem_.ResolvedSymbol(ident_unary); - if (!var) { - TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier"; - return false; - } - is_valid = true; - } - } + // https://gpuweb.github.io/gpuweb/wgsl/#function-restriction + // Each argument of pointer type to a user-defined function must have the same memory + // view as its root identifier. + // We can validate this by just comparing the store type of the argument with that of + // its root identifier, as these will match iff the memory view is the same. + auto* arg_store_type = arg_type->As()->StoreType(); + auto* root = call->Arguments()[i]->RootIdentifier(); + auto* root_ptr_ty = root->Type()->As(); + auto* root_ref_ty = root->Type()->As(); + TINT_ASSERT(Resolver, root_ptr_ty || root_ref_ty); + const sem::Type* root_store_type; + if (root_ptr_ty) { + root_store_type = root_ptr_ty->StoreType(); + } else { + root_store_type = root_ref_ty->StoreType(); } - - if (!is_valid && + if (root_store_type != arg_store_type && IsValidationEnabled(param->Declaration()->attributes, ast::DisabledValidation::kIgnoreInvalidPointerArgument)) { AddError( - "expected an address-of expression of a variable identifier expression or a " - "function parameter", + "arguments of pointer type must not point to a subset of the originating " + "variable", arg_expr->source); return false; }