validation: validate function call pointer parameter
Each argument of a function call of pointer type must be one of: - An address-of expression of a variable identifier expression - A function parameter Also added source location to duplicate struct member name unittest Bug: tint:983 Change-Id: Ic5ab010b2ed76207a1d8d3ef9f66140ea95f7e72 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58480 Auto-Submit: Sarah Mashayekhi <sarahmashay@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@chromium.org> Commit-Queue: Sarah Mashayekhi <sarahmashay@google.com>
This commit is contained in:
parent
6d27f23451
commit
3d441d48bb
|
@ -128,6 +128,188 @@ TEST_F(ResolverCallValidationTest, UnusedRetval) {
|
||||||
"intentional wrap the function call in ignore()");
|
"intentional wrap the function call in ignore()");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(ResolverCallValidationTest, PointerArgument_VariableIdentExpr) {
|
||||||
|
// fn foo(p: ptr<function, i32>) {}
|
||||||
|
// fn main() {
|
||||||
|
// var z: i32 = 1;
|
||||||
|
// foo(&z);
|
||||||
|
// }
|
||||||
|
auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
|
||||||
|
Func("foo", {param}, ty.void_(), {});
|
||||||
|
Func("main", {}, ty.void_(),
|
||||||
|
ast::StatementList{
|
||||||
|
Decl(Var("z", ty.i32(), Expr(1))),
|
||||||
|
create<ast::CallStatement>(
|
||||||
|
Call("foo", AddressOf(Source{{12, 34}}, Expr("z")))),
|
||||||
|
});
|
||||||
|
|
||||||
|
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ResolverCallValidationTest, PointerArgument_ConstIdentExpr) {
|
||||||
|
// fn foo(p: ptr<function, i32>) {}
|
||||||
|
// fn main() {
|
||||||
|
// let z: i32 = 1;
|
||||||
|
// foo(&z);
|
||||||
|
// }
|
||||||
|
auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
|
||||||
|
Func("foo", {param}, ty.void_(), {});
|
||||||
|
Func("main", {}, ty.void_(),
|
||||||
|
ast::StatementList{
|
||||||
|
Decl(Const("z", ty.i32(), Expr(1))),
|
||||||
|
create<ast::CallStatement>(
|
||||||
|
Call("foo", AddressOf(Expr(Source{{12, 34}}, "z")))),
|
||||||
|
});
|
||||||
|
|
||||||
|
EXPECT_FALSE(r()->Resolve());
|
||||||
|
EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ResolverCallValidationTest, PointerArgument_NotIdentExprVar) {
|
||||||
|
// struct S { m: i32; };
|
||||||
|
// fn foo(p: ptr<function, i32>) {}
|
||||||
|
// fn main() {
|
||||||
|
// var v: S;
|
||||||
|
// foo(&v.m);
|
||||||
|
// }
|
||||||
|
auto* S = Structure("S", {Member("m", ty.i32())});
|
||||||
|
auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
|
||||||
|
Func("foo", {param}, ty.void_(), {});
|
||||||
|
Func("main", {}, ty.void_(),
|
||||||
|
ast::StatementList{
|
||||||
|
Decl(Var("v", ty.Of(S))),
|
||||||
|
create<ast::CallStatement>(Call(
|
||||||
|
"foo", AddressOf(Source{{12, 34}}, MemberAccessor("v", "m")))),
|
||||||
|
});
|
||||||
|
|
||||||
|
EXPECT_FALSE(r()->Resolve());
|
||||||
|
EXPECT_EQ(r()->error(),
|
||||||
|
"12:34 error: expected an address-of expression of a variable "
|
||||||
|
"identifier expression or a function parameter");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ResolverCallValidationTest, PointerArgument_AddressOfMemberAccessor) {
|
||||||
|
// struct S { m: i32; };
|
||||||
|
// fn foo(p: ptr<function, i32>) {}
|
||||||
|
// fn main() {
|
||||||
|
// let v: S = S();
|
||||||
|
// foo(&v.m);
|
||||||
|
// }
|
||||||
|
auto* S = Structure("S", {Member("m", ty.i32())});
|
||||||
|
auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
|
||||||
|
Func("foo", {param}, ty.void_(), {});
|
||||||
|
Func("main", {}, ty.void_(),
|
||||||
|
ast::StatementList{
|
||||||
|
Decl(Const("v", ty.Of(S), Construct(ty.Of(S)))),
|
||||||
|
create<ast::CallStatement>(Call(
|
||||||
|
"foo",
|
||||||
|
AddressOf(Expr(Source{{12, 34}}, MemberAccessor("v", "m"))))),
|
||||||
|
});
|
||||||
|
|
||||||
|
EXPECT_FALSE(r()->Resolve());
|
||||||
|
EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParam) {
|
||||||
|
// fn foo(p: ptr<function, i32>) {}
|
||||||
|
// fn bar(p: ptr<function, i32>) {
|
||||||
|
// foo(p);
|
||||||
|
// }
|
||||||
|
Func("foo", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
|
||||||
|
ty.void_(), {});
|
||||||
|
Func("bar", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
|
||||||
|
ty.void_(),
|
||||||
|
ast::StatementList{create<ast::CallStatement>(Call("foo", Expr("p")))});
|
||||||
|
|
||||||
|
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParamWithMain) {
|
||||||
|
// fn foo(p: ptr<function, i32>) {}
|
||||||
|
// fn bar(p: ptr<function, i32>) {
|
||||||
|
// foo(p);
|
||||||
|
// }
|
||||||
|
// [[stage(fragment)]]
|
||||||
|
// fn main() {
|
||||||
|
// var v: i32;
|
||||||
|
// bar(&v);
|
||||||
|
// }
|
||||||
|
Func("foo", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
|
||||||
|
ty.void_(), {});
|
||||||
|
Func("bar", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
|
||||||
|
ty.void_(),
|
||||||
|
ast::StatementList{create<ast::CallStatement>(Call("foo", Expr("p")))});
|
||||||
|
Func("main", ast::VariableList{}, ty.void_(),
|
||||||
|
{
|
||||||
|
Decl(Var("v", ty.i32(), Expr(1))),
|
||||||
|
create<ast::CallStatement>(Call("foo", AddressOf(Expr("v")))),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Stage(ast::PipelineStage::kFragment),
|
||||||
|
});
|
||||||
|
|
||||||
|
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ResolverCallValidationTest, LetPointer) {
|
||||||
|
// fn x(p : ptr<function, i32>) -> i32 {}
|
||||||
|
// [[stage(fragment)]]
|
||||||
|
// fn main() {
|
||||||
|
// var v: i32;
|
||||||
|
// let p: ptr<function, i32> = &v;
|
||||||
|
// var c: i32 = x(p);
|
||||||
|
// }
|
||||||
|
Func("x", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
|
||||||
|
ty.void_(), {});
|
||||||
|
auto* v = Var("v", ty.i32());
|
||||||
|
auto* p = Const("p", ty.pointer(ty.i32(), ast::StorageClass::kFunction),
|
||||||
|
AddressOf(v));
|
||||||
|
auto* c = Var("c", ty.i32(), ast::StorageClass::kNone,
|
||||||
|
Call("x", Expr(Source{{12, 34}}, p)));
|
||||||
|
Func("main", ast::VariableList{}, ty.void_(),
|
||||||
|
{
|
||||||
|
Decl(v),
|
||||||
|
Decl(p),
|
||||||
|
Decl(c),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Stage(ast::PipelineStage::kFragment),
|
||||||
|
});
|
||||||
|
EXPECT_FALSE(r()->Resolve());
|
||||||
|
EXPECT_EQ(r()->error(),
|
||||||
|
"12:34 error: expected an address-of expression of a variable "
|
||||||
|
"identifier expression or a function parameter");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(ResolverCallValidationTest, LetPointerPrivate) {
|
||||||
|
// let p: ptr<private, i32> = &v;
|
||||||
|
// fn foo(p : ptr<private, i32>) -> i32 {}
|
||||||
|
// var v: i32;
|
||||||
|
// [[stage(fragment)]]
|
||||||
|
// fn main() {
|
||||||
|
// var c: i32 = foo(p);
|
||||||
|
// }
|
||||||
|
Func("foo", {Param("p", ty.pointer<i32>(ast::StorageClass::kPrivate))},
|
||||||
|
ty.void_(), {});
|
||||||
|
auto* v = Global("v", ty.i32(), ast::StorageClass::kPrivate);
|
||||||
|
auto* p = Const("p", ty.pointer(ty.i32(), ast::StorageClass::kPrivate),
|
||||||
|
AddressOf(v));
|
||||||
|
auto* c = Var("c", ty.i32(), ast::StorageClass::kNone,
|
||||||
|
Call("foo", Expr(Source{{12, 34}}, p)));
|
||||||
|
Func("main", ast::VariableList{}, ty.void_(),
|
||||||
|
{
|
||||||
|
Decl(p),
|
||||||
|
Decl(c),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Stage(ast::PipelineStage::kFragment),
|
||||||
|
});
|
||||||
|
EXPECT_FALSE(r()->Resolve());
|
||||||
|
EXPECT_EQ(r()->error(),
|
||||||
|
"12:34 error: expected an address-of expression of a variable "
|
||||||
|
"identifier expression or a function parameter");
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace resolver
|
} // namespace resolver
|
||||||
} // namespace tint
|
} // namespace tint
|
||||||
|
|
|
@ -2578,7 +2578,21 @@ bool Resolver::FunctionCall(const ast::CallExpression* call) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate number of arguments match number of parameters
|
function_calls_.emplace(call,
|
||||||
|
FunctionCallInfo{callee_func, current_statement_});
|
||||||
|
SetExprInfo(call, callee_func->return_type, callee_func->return_type_name);
|
||||||
|
|
||||||
|
if (!ValidateFunctionCall(call, callee_func)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool Resolver::ValidateFunctionCall(const ast::CallExpression* call,
|
||||||
|
const FunctionInfo* callee_func) {
|
||||||
|
auto* ident = call->func();
|
||||||
|
auto name = builder_->Symbols().NameFor(ident->symbol());
|
||||||
|
|
||||||
if (call->params().size() != callee_func->parameters.size()) {
|
if (call->params().size() != callee_func->parameters.size()) {
|
||||||
bool more = call->params().size() > callee_func->parameters.size();
|
bool more = call->params().size() > callee_func->parameters.size();
|
||||||
AddError("too " + (more ? std::string("many") : std::string("few")) +
|
AddError("too " + (more ? std::string("many") : std::string("few")) +
|
||||||
|
@ -2589,7 +2603,6 @@ bool Resolver::FunctionCall(const ast::CallExpression* call) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate arguments match parameter types
|
|
||||||
for (size_t i = 0; i < call->params().size(); ++i) {
|
for (size_t i = 0; i < call->params().size(); ++i) {
|
||||||
const VariableInfo* param = callee_func->parameters[i];
|
const VariableInfo* param = callee_func->parameters[i];
|
||||||
const ast::Expression* arg_expr = call->params()[i];
|
const ast::Expression* arg_expr = call->params()[i];
|
||||||
|
@ -2603,12 +2616,48 @@ bool Resolver::FunctionCall(const ast::CallExpression* call) {
|
||||||
arg_expr->source());
|
arg_expr->source());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (param->declaration->type()->Is<ast::Pointer>()) {
|
||||||
|
auto is_valid = false;
|
||||||
|
if (auto* ident_expr = arg_expr->As<ast::IdentifierExpression>()) {
|
||||||
|
VariableInfo* var;
|
||||||
|
if (!variable_stack_.get(ident_expr->symbol(), &var)) {
|
||||||
|
TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (var->kind == VariableKind::kParameter) {
|
||||||
|
is_valid = true;
|
||||||
|
}
|
||||||
|
} else if (auto* unary = arg_expr->As<ast::UnaryOpExpression>()) {
|
||||||
|
if (unary->op() == ast::UnaryOp::kAddressOf) {
|
||||||
|
if (auto* ident_unary =
|
||||||
|
unary->expr()->As<ast::IdentifierExpression>()) {
|
||||||
|
VariableInfo* var;
|
||||||
|
if (!variable_stack_.get(ident_unary->symbol(), &var)) {
|
||||||
|
TINT_ICE(Resolver, diagnostics_)
|
||||||
|
<< "failed to resolve identifier";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (var->declaration->is_const()) {
|
||||||
|
TINT_ICE(Resolver, diagnostics_)
|
||||||
|
<< "Resolver::FunctionCall() encountered an address-of "
|
||||||
|
"expression of a constant identifier expression";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
is_valid = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!is_valid) {
|
||||||
|
AddError(
|
||||||
|
"expected an address-of expression of a variable identifier "
|
||||||
|
"expression or a function parameter",
|
||||||
|
arg_expr->source());
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function_calls_.emplace(call,
|
|
||||||
FunctionCallInfo{callee_func, current_statement_});
|
|
||||||
SetExprInfo(call, callee_func->return_type, callee_func->return_type_name);
|
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -287,6 +287,8 @@ class Resolver {
|
||||||
bool ValidateCallStatement(ast::CallStatement* stmt);
|
bool ValidateCallStatement(ast::CallStatement* stmt);
|
||||||
bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info);
|
bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info);
|
||||||
bool ValidateFunction(const ast::Function* func, const FunctionInfo* info);
|
bool ValidateFunction(const ast::Function* func, const FunctionInfo* info);
|
||||||
|
bool ValidateFunctionCall(const ast::CallExpression* call,
|
||||||
|
const FunctionInfo* info);
|
||||||
bool ValidateGlobalVariable(const VariableInfo* var);
|
bool ValidateGlobalVariable(const VariableInfo* var);
|
||||||
bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco,
|
bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco,
|
||||||
const sem::Type* storage_type);
|
const sem::Type* storage_type);
|
||||||
|
|
|
@ -811,20 +811,20 @@ TEST_F(ResolverValidationTest, Stmt_BreakNotInLoopOrSwitch) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(ResolverValidationTest, StructMemberDuplicateName) {
|
TEST_F(ResolverValidationTest, StructMemberDuplicateName) {
|
||||||
Structure("S",
|
Structure("S", {Member(Source{{12, 34}}, "a", ty.i32()),
|
||||||
{Member("a", ty.i32()), Member(Source{{12, 34}}, "a", ty.i32())});
|
Member(Source{{56, 78}}, "a", ty.i32())});
|
||||||
EXPECT_FALSE(r()->Resolve());
|
EXPECT_FALSE(r()->Resolve());
|
||||||
EXPECT_EQ(
|
EXPECT_EQ(r()->error(),
|
||||||
r()->error(),
|
"56:78 error: redefinition of 'a'\n12:34 note: previous definition "
|
||||||
"12:34 error: redefinition of 'a'\nnote: previous definition is here");
|
"is here");
|
||||||
}
|
}
|
||||||
TEST_F(ResolverValidationTest, StructMemberDuplicateNameDifferentTypes) {
|
TEST_F(ResolverValidationTest, StructMemberDuplicateNameDifferentTypes) {
|
||||||
Structure("S", {Member("a", ty.bool_()),
|
Structure("S", {Member(Source{{12, 34}}, "a", ty.bool_()),
|
||||||
Member(Source{{12, 34}}, "a", ty.vec3<f32>())});
|
Member(Source{{12, 34}}, "a", ty.vec3<f32>())});
|
||||||
EXPECT_FALSE(r()->Resolve());
|
EXPECT_FALSE(r()->Resolve());
|
||||||
EXPECT_EQ(
|
EXPECT_EQ(r()->error(),
|
||||||
r()->error(),
|
"12:34 error: redefinition of 'a'\n12:34 note: previous definition "
|
||||||
"12:34 error: redefinition of 'a'\nnote: previous definition is here");
|
"is here");
|
||||||
}
|
}
|
||||||
TEST_F(ResolverValidationTest, StructMemberDuplicateNamePass) {
|
TEST_F(ResolverValidationTest, StructMemberDuplicateNamePass) {
|
||||||
Structure("S", {Member("a", ty.i32()), Member("b", ty.f32())});
|
Structure("S", {Member("a", ty.i32()), Member("b", ty.f32())});
|
||||||
|
|
|
@ -79,35 +79,6 @@ fn f() {
|
||||||
EXPECT_EQ(expect, str(got));
|
EXPECT_EQ(expect, str(got));
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(InlinePointerLetsTest, Param) {
|
|
||||||
auto* src = R"(
|
|
||||||
fn x(p : ptr<function, i32>) -> i32 {
|
|
||||||
return *p;
|
|
||||||
}
|
|
||||||
|
|
||||||
fn f() {
|
|
||||||
var v : i32;
|
|
||||||
let p : ptr<function, i32> = &v;
|
|
||||||
var r : i32 = x(p);
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
|
|
||||||
auto* expect = R"(
|
|
||||||
fn x(p : ptr<function, i32>) -> i32 {
|
|
||||||
return *(p);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn f() {
|
|
||||||
var v : i32;
|
|
||||||
var r : i32 = x(&(v));
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
|
|
||||||
auto got = Run<InlinePointerLets>(src);
|
|
||||||
|
|
||||||
EXPECT_EQ(expect, str(got));
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(InlinePointerLetsTest, SavedVars) {
|
TEST_F(InlinePointerLetsTest, SavedVars) {
|
||||||
auto* src = R"(
|
auto* src = R"(
|
||||||
struct S {
|
struct S {
|
||||||
|
|
|
@ -286,54 +286,6 @@ TEST_F(HlslSanitizerTest, InlinePtrLetsComplexChain) {
|
||||||
EXPECT_EQ(expect, got);
|
EXPECT_EQ(expect, got);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(HlslSanitizerTest, InlineParam) {
|
|
||||||
// fn x(p : ptr<function, i32>) -> i32 {
|
|
||||||
// return *p;
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// [[stage(fragment)]]
|
|
||||||
// fn main() {
|
|
||||||
// var v : i32;
|
|
||||||
// let p : ptr<function, i32> = &v;
|
|
||||||
// var r : i32 = x(p);
|
|
||||||
// }
|
|
||||||
|
|
||||||
Func("x", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
|
|
||||||
ty.i32(), {Return(Deref("p"))});
|
|
||||||
|
|
||||||
auto* v = Var("v", ty.i32());
|
|
||||||
auto* p = Const("p", ty.pointer(ty.i32(), ast::StorageClass::kFunction),
|
|
||||||
AddressOf(v));
|
|
||||||
auto* r = Var("r", ty.i32(), ast::StorageClass::kNone, Call("x", p));
|
|
||||||
|
|
||||||
Func("main", ast::VariableList{}, ty.void_(),
|
|
||||||
{
|
|
||||||
Decl(v),
|
|
||||||
Decl(p),
|
|
||||||
Decl(r),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Stage(ast::PipelineStage::kFragment),
|
|
||||||
});
|
|
||||||
|
|
||||||
GeneratorImpl& gen = SanitizeAndBuild();
|
|
||||||
|
|
||||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
|
||||||
|
|
||||||
auto got = gen.result();
|
|
||||||
auto* expect = R"(int x(inout int p) {
|
|
||||||
return p;
|
|
||||||
}
|
|
||||||
|
|
||||||
void main() {
|
|
||||||
int v = 0;
|
|
||||||
int r = x(v);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
)";
|
|
||||||
EXPECT_EQ(expect, got);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace hlsl
|
} // namespace hlsl
|
||||||
} // namespace writer
|
} // namespace writer
|
||||||
|
|
Loading…
Reference in New Issue