tint: Allow captured pointers as function args

The updated WGSL validation rule now requires that the memory view of
the argument matches its root identifier.

This allows for code like this:
   let p = &v;
   foo(p);

Fixed: tint:1754, tint:1734
Change-Id: I3239ec84e1c06398a6ce5bebb1e0b28986764bc6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/109221
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
This commit is contained in:
James Price 2022-11-09 16:47:19 +00:00 committed by Dawn LUCI CQ
parent 6251598ad2
commit 97519c2cd3
2 changed files with 154 additions and 97 deletions

View File

@ -114,25 +114,7 @@ TEST_F(ResolverCallValidationTest, PointerArgument_VariableIdentExpr) {
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverCallValidationTest, PointerArgument_ConstIdentExpr) {
// fn foo(p: ptr<function, i32>) {}
// fn main() {
// let z: i32 = 1i;
// foo(&z);
// }
auto* param = Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction));
Func("foo", utils::Vector{param}, ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
Decl(Let("z", ty.i32(), Expr(1_i))),
CallStmt(Call("foo", AddressOf(Expr(Source{{12, 34}}, "z")))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
}
TEST_F(ResolverCallValidationTest, PointerArgument_NotIdentExprVar) {
TEST_F(ResolverCallValidationTest, PointerArgument_NotWholeVar) {
// struct S { m: i32; };
// fn foo(p: ptr<function, i32>) {}
// fn main() {
@ -152,30 +134,8 @@ TEST_F(ResolverCallValidationTest, PointerArgument_NotIdentExprVar) {
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: expected an address-of expression of a variable "
"identifier expression or a function parameter");
}
TEST_F(ResolverCallValidationTest, PointerArgument_AddressOfMemberAccessor) {
// struct S { m: i32; };
// fn foo(p: ptr<function, i32>) {}
// fn main() {
// let v: S = S();
// foo(&v.m);
// }
auto* S = Structure("S", utils::Vector{
Member("m", ty.i32()),
});
auto* param = Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction));
Func("foo", utils::Vector{param}, ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
Decl(Let("v", ty.Of(S), Construct(ty.Of(S)))),
CallStmt(Call("foo", AddressOf(MemberAccessor(Source{{12, 34}}, "v", "m")))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
"12:34 error: arguments of pointer type must not point to a subset of the "
"originating variable");
}
TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParam) {
@ -235,65 +195,169 @@ TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParamWithMain) {
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverCallValidationTest, LetPointer) {
// fn x(p : ptr<function, i32>) -> i32 {}
// @fragment
// fn main() {
// var v: i32;
// let p: ptr<function, i32> = &v;
// var c: i32 = x(p);
TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParam_NotWholeVar) {
// fn foo(p: ptr<function, i32>) {}
// fn bar(p: ptr<function, array<i32, 4>>) {
// foo(&(*p)[0]);
// }
Func("x",
Func("foo",
utils::Vector{
Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction)),
},
ty.void_(), utils::Empty);
Func("bar",
utils::Vector{
Param("p", ty.pointer(ty.array<i32, 4>(), ast::AddressSpace::kFunction)),
},
ty.void_(),
utils::Vector{
CallStmt(Call("foo", AddressOf(Source{{12, 34}}, IndexAccessor(Deref("p"), 0_a)))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: arguments of pointer type must not point to a subset of the "
"originating variable");
}
TEST_F(ResolverCallValidationTest, LetPointer) {
// fn foo(p : ptr<function, i32>) {}
// @fragment
// fn main() {
// var v: i32;
// let p: ptr<function, i32> = &v;
// foo(p);
// }
Func("foo",
utils::Vector{
Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction)),
},
ty.void_(), utils::Empty);
auto* v = Var("v", ty.i32());
auto* p = Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kFunction), AddressOf(v));
auto* c = Var("c", ty.i32(), Call("x", Expr(Source{{12, 34}}, p)));
Func("main", utils::Empty, ty.void_(),
utils::Vector{
Decl(v),
Decl(p),
Decl(c),
Decl(Var("v", ty.i32())),
Decl(Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kFunction), AddressOf("v"))),
CallStmt(Call("foo", Expr(Source{{12, 34}}, "p"))),
},
utils::Vector{
Stage(ast::PipelineStage::kFragment),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: expected an address-of expression of a variable "
"identifier expression or a function parameter");
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverCallValidationTest, LetPointerPrivate) {
// let p: ptr<private, i32> = &v;
// fn foo(p : ptr<private, i32>) -> i32 {}
// var v: i32;
// fn foo(p : ptr<private, i32>) {}
// var<private> v: i32;
// @fragment
// fn main() {
// var c: i32 = foo(p);
// let p: ptr<private, i32> = &v;
// foo(p);
// }
Func("foo",
utils::Vector{
Param("p", ty.pointer<i32>(ast::AddressSpace::kPrivate)),
},
ty.void_(), utils::Empty);
auto* v = GlobalVar("v", ty.i32(), ast::AddressSpace::kPrivate);
auto* p = Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kPrivate), AddressOf(v));
auto* c = Var("c", ty.i32(), Call("foo", Expr(Source{{12, 34}}, p)));
GlobalVar("v", ty.i32(), ast::AddressSpace::kPrivate);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
Decl(p),
Decl(c),
Decl(Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kPrivate), AddressOf("v"))),
CallStmt(Call("foo", Expr(Source{{12, 34}}, "p"))),
},
utils::Vector{
Stage(ast::PipelineStage::kFragment),
});
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverCallValidationTest, LetPointer_NotWholeVar) {
// fn foo(p : ptr<function, i32>) {}
// @fragment
// fn main() {
// var v: array<i32, 4>;
// let p: ptr<function, i32> = &(v[0]);
// x(p);
// }
Func("foo",
utils::Vector{
Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction)),
},
ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
Decl(Var("v", ty.array<i32, 4>())),
Decl(Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kFunction),
AddressOf(IndexAccessor("v", 0_a)))),
CallStmt(Call("foo", Expr(Source{{12, 34}}, "p"))),
},
utils::Vector{
Stage(ast::PipelineStage::kFragment),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: expected an address-of expression of a variable "
"identifier expression or a function parameter");
"12:34 error: arguments of pointer type must not point to a subset of the "
"originating variable");
}
TEST_F(ResolverCallValidationTest, ComplexPointerChain) {
// fn foo(p : ptr<function, array<i32, 4>>) {}
// @fragment
// fn main() {
// var v: array<i32, 4>;
// let p1 = &v;
// let p2 = p1;
// let p3 = &*p2;
// foo(&*p);
// }
Func("foo",
utils::Vector{
Param("p", ty.pointer(ty.array<i32, 4>(), ast::AddressSpace::kFunction)),
},
ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
Decl(Var("v", ty.array<i32, 4>())),
Decl(Let("p1", AddressOf("v"))),
Decl(Let("p2", Expr("p1"))),
Decl(Let("p3", AddressOf(Deref("p2")))),
CallStmt(Call("foo", AddressOf(Source{{12, 34}}, Deref("p3")))),
},
utils::Vector{
Stage(ast::PipelineStage::kFragment),
});
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverCallValidationTest, ComplexPointerChain_NotWholeVar) {
// fn foo(p : ptr<function, i32>) {}
// @fragment
// fn main() {
// var v: array<i32, 4>;
// let p1 = &v;
// let p2 = p1;
// let p3 = &(*p2)[0];
// foo(&*p);
// }
Func("foo",
utils::Vector{
Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction)),
},
ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
Decl(Var("v", ty.array<i32, 4>())),
Decl(Let("p1", AddressOf("v"))),
Decl(Let("p2", Expr("p1"))),
Decl(Let("p3", AddressOf(IndexAccessor(Deref("p2"), 0_a)))),
CallStmt(Call("foo", AddressOf(Source{{12, 34}}, Deref("p3")))),
},
utils::Vector{
Stage(ast::PipelineStage::kFragment),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: arguments of pointer type must not point to a subset of the "
"originating variable");
}
TEST_F(ResolverCallValidationTest, CallVariable) {

View File

@ -1799,35 +1799,28 @@ bool Validator::FunctionCall(const sem::Call* call, sem::Statement* current_stat
}
if (param_type->Is<sem::Pointer>()) {
auto is_valid = false;
if (auto* ident_expr = arg_expr->As<ast::IdentifierExpression>()) {
auto* var = sem_.ResolvedSymbol<sem::Variable>(ident_expr);
if (!var) {
TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
return false;
}
if (var->Is<sem::Parameter>()) {
is_valid = true;
}
} else if (auto* unary = arg_expr->As<ast::UnaryOpExpression>()) {
if (unary->op == ast::UnaryOp::kAddressOf) {
if (auto* ident_unary = unary->expr->As<ast::IdentifierExpression>()) {
auto* var = sem_.ResolvedSymbol<sem::Variable>(ident_unary);
if (!var) {
TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
return false;
}
is_valid = true;
}
}
// https://gpuweb.github.io/gpuweb/wgsl/#function-restriction
// Each argument of pointer type to a user-defined function must have the same memory
// view as its root identifier.
// We can validate this by just comparing the store type of the argument with that of
// its root identifier, as these will match iff the memory view is the same.
auto* arg_store_type = arg_type->As<sem::Pointer>()->StoreType();
auto* root = call->Arguments()[i]->RootIdentifier();
auto* root_ptr_ty = root->Type()->As<sem::Pointer>();
auto* root_ref_ty = root->Type()->As<sem::Reference>();
TINT_ASSERT(Resolver, root_ptr_ty || root_ref_ty);
const sem::Type* root_store_type;
if (root_ptr_ty) {
root_store_type = root_ptr_ty->StoreType();
} else {
root_store_type = root_ref_ty->StoreType();
}
if (!is_valid &&
if (root_store_type != arg_store_type &&
IsValidationEnabled(param->Declaration()->attributes,
ast::DisabledValidation::kIgnoreInvalidPointerArgument)) {
AddError(
"expected an address-of expression of a variable identifier expression or a "
"function parameter",
"arguments of pointer type must not point to a subset of the originating "
"variable",
arg_expr->source);
return false;
}