validation: validate function call pointer parameter
Each argument of a function call of pointer type must be one of: - An address-of expression of a variable identifier expression - A function parameter Also added source location to duplicate struct member name unittest Bug: tint:983 Change-Id: Ic5ab010b2ed76207a1d8d3ef9f66140ea95f7e72 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58480 Auto-Submit: Sarah Mashayekhi <sarahmashay@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@chromium.org> Commit-Queue: Sarah Mashayekhi <sarahmashay@google.com>
This commit is contained in:
parent
6d27f23451
commit
3d441d48bb
|
@ -128,6 +128,188 @@ TEST_F(ResolverCallValidationTest, UnusedRetval) {
|
|||
"intentional wrap the function call in ignore()");
|
||||
}
|
||||
|
||||
TEST_F(ResolverCallValidationTest, PointerArgument_VariableIdentExpr) {
|
||||
// fn foo(p: ptr<function, i32>) {}
|
||||
// fn main() {
|
||||
// var z: i32 = 1;
|
||||
// foo(&z);
|
||||
// }
|
||||
auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
|
||||
Func("foo", {param}, ty.void_(), {});
|
||||
Func("main", {}, ty.void_(),
|
||||
ast::StatementList{
|
||||
Decl(Var("z", ty.i32(), Expr(1))),
|
||||
create<ast::CallStatement>(
|
||||
Call("foo", AddressOf(Source{{12, 34}}, Expr("z")))),
|
||||
});
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverCallValidationTest, PointerArgument_ConstIdentExpr) {
|
||||
// fn foo(p: ptr<function, i32>) {}
|
||||
// fn main() {
|
||||
// let z: i32 = 1;
|
||||
// foo(&z);
|
||||
// }
|
||||
auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
|
||||
Func("foo", {param}, ty.void_(), {});
|
||||
Func("main", {}, ty.void_(),
|
||||
ast::StatementList{
|
||||
Decl(Const("z", ty.i32(), Expr(1))),
|
||||
create<ast::CallStatement>(
|
||||
Call("foo", AddressOf(Expr(Source{{12, 34}}, "z")))),
|
||||
});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
|
||||
}
|
||||
|
||||
TEST_F(ResolverCallValidationTest, PointerArgument_NotIdentExprVar) {
|
||||
// struct S { m: i32; };
|
||||
// fn foo(p: ptr<function, i32>) {}
|
||||
// fn main() {
|
||||
// var v: S;
|
||||
// foo(&v.m);
|
||||
// }
|
||||
auto* S = Structure("S", {Member("m", ty.i32())});
|
||||
auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
|
||||
Func("foo", {param}, ty.void_(), {});
|
||||
Func("main", {}, ty.void_(),
|
||||
ast::StatementList{
|
||||
Decl(Var("v", ty.Of(S))),
|
||||
create<ast::CallStatement>(Call(
|
||||
"foo", AddressOf(Source{{12, 34}}, MemberAccessor("v", "m")))),
|
||||
});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(),
|
||||
"12:34 error: expected an address-of expression of a variable "
|
||||
"identifier expression or a function parameter");
|
||||
}
|
||||
|
||||
TEST_F(ResolverCallValidationTest, PointerArgument_AddressOfMemberAccessor) {
|
||||
// struct S { m: i32; };
|
||||
// fn foo(p: ptr<function, i32>) {}
|
||||
// fn main() {
|
||||
// let v: S = S();
|
||||
// foo(&v.m);
|
||||
// }
|
||||
auto* S = Structure("S", {Member("m", ty.i32())});
|
||||
auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
|
||||
Func("foo", {param}, ty.void_(), {});
|
||||
Func("main", {}, ty.void_(),
|
||||
ast::StatementList{
|
||||
Decl(Const("v", ty.Of(S), Construct(ty.Of(S)))),
|
||||
create<ast::CallStatement>(Call(
|
||||
"foo",
|
||||
AddressOf(Expr(Source{{12, 34}}, MemberAccessor("v", "m"))))),
|
||||
});
|
||||
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
|
||||
}
|
||||
|
||||
TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParam) {
|
||||
// fn foo(p: ptr<function, i32>) {}
|
||||
// fn bar(p: ptr<function, i32>) {
|
||||
// foo(p);
|
||||
// }
|
||||
Func("foo", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
|
||||
ty.void_(), {});
|
||||
Func("bar", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
|
||||
ty.void_(),
|
||||
ast::StatementList{create<ast::CallStatement>(Call("foo", Expr("p")))});
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParamWithMain) {
|
||||
// fn foo(p: ptr<function, i32>) {}
|
||||
// fn bar(p: ptr<function, i32>) {
|
||||
// foo(p);
|
||||
// }
|
||||
// [[stage(fragment)]]
|
||||
// fn main() {
|
||||
// var v: i32;
|
||||
// bar(&v);
|
||||
// }
|
||||
Func("foo", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
|
||||
ty.void_(), {});
|
||||
Func("bar", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
|
||||
ty.void_(),
|
||||
ast::StatementList{create<ast::CallStatement>(Call("foo", Expr("p")))});
|
||||
Func("main", ast::VariableList{}, ty.void_(),
|
||||
{
|
||||
Decl(Var("v", ty.i32(), Expr(1))),
|
||||
create<ast::CallStatement>(Call("foo", AddressOf(Expr("v")))),
|
||||
},
|
||||
{
|
||||
Stage(ast::PipelineStage::kFragment),
|
||||
});
|
||||
|
||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||
}
|
||||
|
||||
TEST_F(ResolverCallValidationTest, LetPointer) {
|
||||
// fn x(p : ptr<function, i32>) -> i32 {}
|
||||
// [[stage(fragment)]]
|
||||
// fn main() {
|
||||
// var v: i32;
|
||||
// let p: ptr<function, i32> = &v;
|
||||
// var c: i32 = x(p);
|
||||
// }
|
||||
Func("x", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
|
||||
ty.void_(), {});
|
||||
auto* v = Var("v", ty.i32());
|
||||
auto* p = Const("p", ty.pointer(ty.i32(), ast::StorageClass::kFunction),
|
||||
AddressOf(v));
|
||||
auto* c = Var("c", ty.i32(), ast::StorageClass::kNone,
|
||||
Call("x", Expr(Source{{12, 34}}, p)));
|
||||
Func("main", ast::VariableList{}, ty.void_(),
|
||||
{
|
||||
Decl(v),
|
||||
Decl(p),
|
||||
Decl(c),
|
||||
},
|
||||
{
|
||||
Stage(ast::PipelineStage::kFragment),
|
||||
});
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(),
|
||||
"12:34 error: expected an address-of expression of a variable "
|
||||
"identifier expression or a function parameter");
|
||||
}
|
||||
|
||||
TEST_F(ResolverCallValidationTest, LetPointerPrivate) {
|
||||
// let p: ptr<private, i32> = &v;
|
||||
// fn foo(p : ptr<private, i32>) -> i32 {}
|
||||
// var v: i32;
|
||||
// [[stage(fragment)]]
|
||||
// fn main() {
|
||||
// var c: i32 = foo(p);
|
||||
// }
|
||||
Func("foo", {Param("p", ty.pointer<i32>(ast::StorageClass::kPrivate))},
|
||||
ty.void_(), {});
|
||||
auto* v = Global("v", ty.i32(), ast::StorageClass::kPrivate);
|
||||
auto* p = Const("p", ty.pointer(ty.i32(), ast::StorageClass::kPrivate),
|
||||
AddressOf(v));
|
||||
auto* c = Var("c", ty.i32(), ast::StorageClass::kNone,
|
||||
Call("foo", Expr(Source{{12, 34}}, p)));
|
||||
Func("main", ast::VariableList{}, ty.void_(),
|
||||
{
|
||||
Decl(p),
|
||||
Decl(c),
|
||||
},
|
||||
{
|
||||
Stage(ast::PipelineStage::kFragment),
|
||||
});
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(),
|
||||
"12:34 error: expected an address-of expression of a variable "
|
||||
"identifier expression or a function parameter");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace resolver
|
||||
} // namespace tint
|
||||
|
|
|
@ -2578,7 +2578,21 @@ bool Resolver::FunctionCall(const ast::CallExpression* call) {
|
|||
}
|
||||
}
|
||||
|
||||
// Validate number of arguments match number of parameters
|
||||
function_calls_.emplace(call,
|
||||
FunctionCallInfo{callee_func, current_statement_});
|
||||
SetExprInfo(call, callee_func->return_type, callee_func->return_type_name);
|
||||
|
||||
if (!ValidateFunctionCall(call, callee_func)) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Resolver::ValidateFunctionCall(const ast::CallExpression* call,
|
||||
const FunctionInfo* callee_func) {
|
||||
auto* ident = call->func();
|
||||
auto name = builder_->Symbols().NameFor(ident->symbol());
|
||||
|
||||
if (call->params().size() != callee_func->parameters.size()) {
|
||||
bool more = call->params().size() > callee_func->parameters.size();
|
||||
AddError("too " + (more ? std::string("many") : std::string("few")) +
|
||||
|
@ -2589,7 +2603,6 @@ bool Resolver::FunctionCall(const ast::CallExpression* call) {
|
|||
return false;
|
||||
}
|
||||
|
||||
// Validate arguments match parameter types
|
||||
for (size_t i = 0; i < call->params().size(); ++i) {
|
||||
const VariableInfo* param = callee_func->parameters[i];
|
||||
const ast::Expression* arg_expr = call->params()[i];
|
||||
|
@ -2603,12 +2616,48 @@ bool Resolver::FunctionCall(const ast::CallExpression* call) {
|
|||
arg_expr->source());
|
||||
return false;
|
||||
}
|
||||
|
||||
if (param->declaration->type()->Is<ast::Pointer>()) {
|
||||
auto is_valid = false;
|
||||
if (auto* ident_expr = arg_expr->As<ast::IdentifierExpression>()) {
|
||||
VariableInfo* var;
|
||||
if (!variable_stack_.get(ident_expr->symbol(), &var)) {
|
||||
TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
|
||||
return false;
|
||||
}
|
||||
if (var->kind == VariableKind::kParameter) {
|
||||
is_valid = true;
|
||||
}
|
||||
} else if (auto* unary = arg_expr->As<ast::UnaryOpExpression>()) {
|
||||
if (unary->op() == ast::UnaryOp::kAddressOf) {
|
||||
if (auto* ident_unary =
|
||||
unary->expr()->As<ast::IdentifierExpression>()) {
|
||||
VariableInfo* var;
|
||||
if (!variable_stack_.get(ident_unary->symbol(), &var)) {
|
||||
TINT_ICE(Resolver, diagnostics_)
|
||||
<< "failed to resolve identifier";
|
||||
return false;
|
||||
}
|
||||
if (var->declaration->is_const()) {
|
||||
TINT_ICE(Resolver, diagnostics_)
|
||||
<< "Resolver::FunctionCall() encountered an address-of "
|
||||
"expression of a constant identifier expression";
|
||||
return false;
|
||||
}
|
||||
is_valid = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function_calls_.emplace(call,
|
||||
FunctionCallInfo{callee_func, current_statement_});
|
||||
SetExprInfo(call, callee_func->return_type, callee_func->return_type_name);
|
||||
|
||||
if (!is_valid) {
|
||||
AddError(
|
||||
"expected an address-of expression of a variable identifier "
|
||||
"expression or a function parameter",
|
||||
arg_expr->source());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
|
|
@ -287,6 +287,8 @@ class Resolver {
|
|||
bool ValidateCallStatement(ast::CallStatement* stmt);
|
||||
bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info);
|
||||
bool ValidateFunction(const ast::Function* func, const FunctionInfo* info);
|
||||
bool ValidateFunctionCall(const ast::CallExpression* call,
|
||||
const FunctionInfo* info);
|
||||
bool ValidateGlobalVariable(const VariableInfo* var);
|
||||
bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco,
|
||||
const sem::Type* storage_type);
|
||||
|
|
|
@ -811,20 +811,20 @@ TEST_F(ResolverValidationTest, Stmt_BreakNotInLoopOrSwitch) {
|
|||
}
|
||||
|
||||
TEST_F(ResolverValidationTest, StructMemberDuplicateName) {
|
||||
Structure("S",
|
||||
{Member("a", ty.i32()), Member(Source{{12, 34}}, "a", ty.i32())});
|
||||
Structure("S", {Member(Source{{12, 34}}, "a", ty.i32()),
|
||||
Member(Source{{56, 78}}, "a", ty.i32())});
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(
|
||||
r()->error(),
|
||||
"12:34 error: redefinition of 'a'\nnote: previous definition is here");
|
||||
EXPECT_EQ(r()->error(),
|
||||
"56:78 error: redefinition of 'a'\n12:34 note: previous definition "
|
||||
"is here");
|
||||
}
|
||||
TEST_F(ResolverValidationTest, StructMemberDuplicateNameDifferentTypes) {
|
||||
Structure("S", {Member("a", ty.bool_()),
|
||||
Structure("S", {Member(Source{{12, 34}}, "a", ty.bool_()),
|
||||
Member(Source{{12, 34}}, "a", ty.vec3<f32>())});
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(
|
||||
r()->error(),
|
||||
"12:34 error: redefinition of 'a'\nnote: previous definition is here");
|
||||
EXPECT_EQ(r()->error(),
|
||||
"12:34 error: redefinition of 'a'\n12:34 note: previous definition "
|
||||
"is here");
|
||||
}
|
||||
TEST_F(ResolverValidationTest, StructMemberDuplicateNamePass) {
|
||||
Structure("S", {Member("a", ty.i32()), Member("b", ty.f32())});
|
||||
|
|
|
@ -79,35 +79,6 @@ fn f() {
|
|||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(InlinePointerLetsTest, Param) {
|
||||
auto* src = R"(
|
||||
fn x(p : ptr<function, i32>) -> i32 {
|
||||
return *p;
|
||||
}
|
||||
|
||||
fn f() {
|
||||
var v : i32;
|
||||
let p : ptr<function, i32> = &v;
|
||||
var r : i32 = x(p);
|
||||
}
|
||||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
fn x(p : ptr<function, i32>) -> i32 {
|
||||
return *(p);
|
||||
}
|
||||
|
||||
fn f() {
|
||||
var v : i32;
|
||||
var r : i32 = x(&(v));
|
||||
}
|
||||
)";
|
||||
|
||||
auto got = Run<InlinePointerLets>(src);
|
||||
|
||||
EXPECT_EQ(expect, str(got));
|
||||
}
|
||||
|
||||
TEST_F(InlinePointerLetsTest, SavedVars) {
|
||||
auto* src = R"(
|
||||
struct S {
|
||||
|
|
|
@ -286,54 +286,6 @@ TEST_F(HlslSanitizerTest, InlinePtrLetsComplexChain) {
|
|||
EXPECT_EQ(expect, got);
|
||||
}
|
||||
|
||||
TEST_F(HlslSanitizerTest, InlineParam) {
|
||||
// fn x(p : ptr<function, i32>) -> i32 {
|
||||
// return *p;
|
||||
// }
|
||||
//
|
||||
// [[stage(fragment)]]
|
||||
// fn main() {
|
||||
// var v : i32;
|
||||
// let p : ptr<function, i32> = &v;
|
||||
// var r : i32 = x(p);
|
||||
// }
|
||||
|
||||
Func("x", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
|
||||
ty.i32(), {Return(Deref("p"))});
|
||||
|
||||
auto* v = Var("v", ty.i32());
|
||||
auto* p = Const("p", ty.pointer(ty.i32(), ast::StorageClass::kFunction),
|
||||
AddressOf(v));
|
||||
auto* r = Var("r", ty.i32(), ast::StorageClass::kNone, Call("x", p));
|
||||
|
||||
Func("main", ast::VariableList{}, ty.void_(),
|
||||
{
|
||||
Decl(v),
|
||||
Decl(p),
|
||||
Decl(r),
|
||||
},
|
||||
{
|
||||
Stage(ast::PipelineStage::kFragment),
|
||||
});
|
||||
|
||||
GeneratorImpl& gen = SanitizeAndBuild();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
|
||||
auto got = gen.result();
|
||||
auto* expect = R"(int x(inout int p) {
|
||||
return p;
|
||||
}
|
||||
|
||||
void main() {
|
||||
int v = 0;
|
||||
int r = x(v);
|
||||
return;
|
||||
}
|
||||
)";
|
||||
EXPECT_EQ(expect, got);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace hlsl
|
||||
} // namespace writer
|
||||
|
|
Loading…
Reference in New Issue