tint: Allow captured pointers as function args

The updated WGSL validation rule now requires that the memory view of
the argument matches its root identifier.

This allows for code like this:
   let p = &v;
   foo(p);

Fixed: tint:1754, tint:1734
Change-Id: I3239ec84e1c06398a6ce5bebb1e0b28986764bc6
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/109221
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
This commit is contained in:
James Price 2022-11-09 16:47:19 +00:00 committed by Dawn LUCI CQ
parent 6251598ad2
commit 97519c2cd3
2 changed files with 154 additions and 97 deletions

View File

@ -114,25 +114,7 @@ TEST_F(ResolverCallValidationTest, PointerArgument_VariableIdentExpr) {
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
} }
TEST_F(ResolverCallValidationTest, PointerArgument_ConstIdentExpr) { TEST_F(ResolverCallValidationTest, PointerArgument_NotWholeVar) {
// fn foo(p: ptr<function, i32>) {}
// fn main() {
// let z: i32 = 1i;
// foo(&z);
// }
auto* param = Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction));
Func("foo", utils::Vector{param}, ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
Decl(Let("z", ty.i32(), Expr(1_i))),
CallStmt(Call("foo", AddressOf(Expr(Source{{12, 34}}, "z")))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
}
TEST_F(ResolverCallValidationTest, PointerArgument_NotIdentExprVar) {
// struct S { m: i32; }; // struct S { m: i32; };
// fn foo(p: ptr<function, i32>) {} // fn foo(p: ptr<function, i32>) {}
// fn main() { // fn main() {
@ -152,30 +134,8 @@ TEST_F(ResolverCallValidationTest, PointerArgument_NotIdentExprVar) {
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
"12:34 error: expected an address-of expression of a variable " "12:34 error: arguments of pointer type must not point to a subset of the "
"identifier expression or a function parameter"); "originating variable");
}
TEST_F(ResolverCallValidationTest, PointerArgument_AddressOfMemberAccessor) {
// struct S { m: i32; };
// fn foo(p: ptr<function, i32>) {}
// fn main() {
// let v: S = S();
// foo(&v.m);
// }
auto* S = Structure("S", utils::Vector{
Member("m", ty.i32()),
});
auto* param = Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction));
Func("foo", utils::Vector{param}, ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
Decl(Let("v", ty.Of(S), Construct(ty.Of(S)))),
CallStmt(Call("foo", AddressOf(MemberAccessor(Source{{12, 34}}, "v", "m")))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
} }
TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParam) { TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParam) {
@ -235,65 +195,169 @@ TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParamWithMain) {
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
} }
TEST_F(ResolverCallValidationTest, LetPointer) { TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParam_NotWholeVar) {
// fn x(p : ptr<function, i32>) -> i32 {} // fn foo(p: ptr<function, i32>) {}
// @fragment // fn bar(p: ptr<function, array<i32, 4>>) {
// fn main() { // foo(&(*p)[0]);
// var v: i32;
// let p: ptr<function, i32> = &v;
// var c: i32 = x(p);
// } // }
Func("x", Func("foo",
utils::Vector{
Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction)),
},
ty.void_(), utils::Empty);
Func("bar",
utils::Vector{
Param("p", ty.pointer(ty.array<i32, 4>(), ast::AddressSpace::kFunction)),
},
ty.void_(),
utils::Vector{
CallStmt(Call("foo", AddressOf(Source{{12, 34}}, IndexAccessor(Deref("p"), 0_a)))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: arguments of pointer type must not point to a subset of the "
"originating variable");
}
TEST_F(ResolverCallValidationTest, LetPointer) {
// fn foo(p : ptr<function, i32>) {}
// @fragment
// fn main() {
// var v: i32;
// let p: ptr<function, i32> = &v;
// foo(p);
// }
Func("foo",
utils::Vector{ utils::Vector{
Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction)), Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction)),
}, },
ty.void_(), utils::Empty); ty.void_(), utils::Empty);
auto* v = Var("v", ty.i32());
auto* p = Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kFunction), AddressOf(v));
auto* c = Var("c", ty.i32(), Call("x", Expr(Source{{12, 34}}, p)));
Func("main", utils::Empty, ty.void_(), Func("main", utils::Empty, ty.void_(),
utils::Vector{ utils::Vector{
Decl(v), Decl(Var("v", ty.i32())),
Decl(p), Decl(Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kFunction), AddressOf("v"))),
Decl(c), CallStmt(Call("foo", Expr(Source{{12, 34}}, "p"))),
}, },
utils::Vector{ utils::Vector{
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}); });
EXPECT_FALSE(r()->Resolve()); EXPECT_TRUE(r()->Resolve()) << r()->error();
EXPECT_EQ(r()->error(),
"12:34 error: expected an address-of expression of a variable "
"identifier expression or a function parameter");
} }
TEST_F(ResolverCallValidationTest, LetPointerPrivate) { TEST_F(ResolverCallValidationTest, LetPointerPrivate) {
// let p: ptr<private, i32> = &v; // fn foo(p : ptr<private, i32>) {}
// fn foo(p : ptr<private, i32>) -> i32 {} // var<private> v: i32;
// var v: i32;
// @fragment // @fragment
// fn main() { // fn main() {
// var c: i32 = foo(p); // let p: ptr<private, i32> = &v;
// foo(p);
// } // }
Func("foo", Func("foo",
utils::Vector{ utils::Vector{
Param("p", ty.pointer<i32>(ast::AddressSpace::kPrivate)), Param("p", ty.pointer<i32>(ast::AddressSpace::kPrivate)),
}, },
ty.void_(), utils::Empty); ty.void_(), utils::Empty);
auto* v = GlobalVar("v", ty.i32(), ast::AddressSpace::kPrivate); GlobalVar("v", ty.i32(), ast::AddressSpace::kPrivate);
auto* p = Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kPrivate), AddressOf(v));
auto* c = Var("c", ty.i32(), Call("foo", Expr(Source{{12, 34}}, p)));
Func("main", utils::Empty, ty.void_(), Func("main", utils::Empty, ty.void_(),
utils::Vector{ utils::Vector{
Decl(p), Decl(Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kPrivate), AddressOf("v"))),
Decl(c), CallStmt(Call("foo", Expr(Source{{12, 34}}, "p"))),
},
utils::Vector{
Stage(ast::PipelineStage::kFragment),
});
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverCallValidationTest, LetPointer_NotWholeVar) {
// fn foo(p : ptr<function, i32>) {}
// @fragment
// fn main() {
// var v: array<i32, 4>;
// let p: ptr<function, i32> = &(v[0]);
// x(p);
// }
Func("foo",
utils::Vector{
Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction)),
},
ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
Decl(Var("v", ty.array<i32, 4>())),
Decl(Let("p", ty.pointer(ty.i32(), ast::AddressSpace::kFunction),
AddressOf(IndexAccessor("v", 0_a)))),
CallStmt(Call("foo", Expr(Source{{12, 34}}, "p"))),
}, },
utils::Vector{ utils::Vector{
Stage(ast::PipelineStage::kFragment), Stage(ast::PipelineStage::kFragment),
}); });
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), EXPECT_EQ(r()->error(),
"12:34 error: expected an address-of expression of a variable " "12:34 error: arguments of pointer type must not point to a subset of the "
"identifier expression or a function parameter"); "originating variable");
}
TEST_F(ResolverCallValidationTest, ComplexPointerChain) {
// fn foo(p : ptr<function, array<i32, 4>>) {}
// @fragment
// fn main() {
// var v: array<i32, 4>;
// let p1 = &v;
// let p2 = p1;
// let p3 = &*p2;
// foo(&*p);
// }
Func("foo",
utils::Vector{
Param("p", ty.pointer(ty.array<i32, 4>(), ast::AddressSpace::kFunction)),
},
ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
Decl(Var("v", ty.array<i32, 4>())),
Decl(Let("p1", AddressOf("v"))),
Decl(Let("p2", Expr("p1"))),
Decl(Let("p3", AddressOf(Deref("p2")))),
CallStmt(Call("foo", AddressOf(Source{{12, 34}}, Deref("p3")))),
},
utils::Vector{
Stage(ast::PipelineStage::kFragment),
});
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverCallValidationTest, ComplexPointerChain_NotWholeVar) {
// fn foo(p : ptr<function, i32>) {}
// @fragment
// fn main() {
// var v: array<i32, 4>;
// let p1 = &v;
// let p2 = p1;
// let p3 = &(*p2)[0];
// foo(&*p);
// }
Func("foo",
utils::Vector{
Param("p", ty.pointer<i32>(ast::AddressSpace::kFunction)),
},
ty.void_(), utils::Empty);
Func("main", utils::Empty, ty.void_(),
utils::Vector{
Decl(Var("v", ty.array<i32, 4>())),
Decl(Let("p1", AddressOf("v"))),
Decl(Let("p2", Expr("p1"))),
Decl(Let("p3", AddressOf(IndexAccessor(Deref("p2"), 0_a)))),
CallStmt(Call("foo", AddressOf(Source{{12, 34}}, Deref("p3")))),
},
utils::Vector{
Stage(ast::PipelineStage::kFragment),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: arguments of pointer type must not point to a subset of the "
"originating variable");
} }
TEST_F(ResolverCallValidationTest, CallVariable) { TEST_F(ResolverCallValidationTest, CallVariable) {

View File

@ -1799,35 +1799,28 @@ bool Validator::FunctionCall(const sem::Call* call, sem::Statement* current_stat
} }
if (param_type->Is<sem::Pointer>()) { if (param_type->Is<sem::Pointer>()) {
auto is_valid = false; // https://gpuweb.github.io/gpuweb/wgsl/#function-restriction
if (auto* ident_expr = arg_expr->As<ast::IdentifierExpression>()) { // Each argument of pointer type to a user-defined function must have the same memory
auto* var = sem_.ResolvedSymbol<sem::Variable>(ident_expr); // view as its root identifier.
if (!var) { // We can validate this by just comparing the store type of the argument with that of
TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier"; // its root identifier, as these will match iff the memory view is the same.
return false; auto* arg_store_type = arg_type->As<sem::Pointer>()->StoreType();
} auto* root = call->Arguments()[i]->RootIdentifier();
if (var->Is<sem::Parameter>()) { auto* root_ptr_ty = root->Type()->As<sem::Pointer>();
is_valid = true; auto* root_ref_ty = root->Type()->As<sem::Reference>();
} TINT_ASSERT(Resolver, root_ptr_ty || root_ref_ty);
} else if (auto* unary = arg_expr->As<ast::UnaryOpExpression>()) { const sem::Type* root_store_type;
if (unary->op == ast::UnaryOp::kAddressOf) { if (root_ptr_ty) {
if (auto* ident_unary = unary->expr->As<ast::IdentifierExpression>()) { root_store_type = root_ptr_ty->StoreType();
auto* var = sem_.ResolvedSymbol<sem::Variable>(ident_unary); } else {
if (!var) { root_store_type = root_ref_ty->StoreType();
TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
return false;
}
is_valid = true;
}
}
} }
if (root_store_type != arg_store_type &&
if (!is_valid &&
IsValidationEnabled(param->Declaration()->attributes, IsValidationEnabled(param->Declaration()->attributes,
ast::DisabledValidation::kIgnoreInvalidPointerArgument)) { ast::DisabledValidation::kIgnoreInvalidPointerArgument)) {
AddError( AddError(
"expected an address-of expression of a variable identifier expression or a " "arguments of pointer type must not point to a subset of the originating "
"function parameter", "variable",
arg_expr->source); arg_expr->source);
return false; return false;
} }