validation: validate function call pointer parameter

Each argument of a function call of pointer type must be one of:
- An address-of expression of a variable identifier expression
- A function parameter
Also added source location to duplicate struct member name unittest

Bug: tint:983
Change-Id: Ic5ab010b2ed76207a1d8d3ef9f66140ea95f7e72
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/58480
Auto-Submit: Sarah Mashayekhi <sarahmashay@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@chromium.org>
Commit-Queue: Sarah Mashayekhi <sarahmashay@google.com>
This commit is contained in:
Sarah 2021-07-20 18:14:02 +00:00 committed by Tint LUCI CQ
parent 6d27f23451
commit 3d441d48bb
6 changed files with 249 additions and 93 deletions

View File

@ -128,6 +128,188 @@ TEST_F(ResolverCallValidationTest, UnusedRetval) {
"intentional wrap the function call in ignore()"); "intentional wrap the function call in ignore()");
} }
TEST_F(ResolverCallValidationTest, PointerArgument_VariableIdentExpr) {
// fn foo(p: ptr<function, i32>) {}
// fn main() {
// var z: i32 = 1;
// foo(&z);
// }
auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
Func("foo", {param}, ty.void_(), {});
Func("main", {}, ty.void_(),
ast::StatementList{
Decl(Var("z", ty.i32(), Expr(1))),
create<ast::CallStatement>(
Call("foo", AddressOf(Source{{12, 34}}, Expr("z")))),
});
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverCallValidationTest, PointerArgument_ConstIdentExpr) {
// fn foo(p: ptr<function, i32>) {}
// fn main() {
// let z: i32 = 1;
// foo(&z);
// }
auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
Func("foo", {param}, ty.void_(), {});
Func("main", {}, ty.void_(),
ast::StatementList{
Decl(Const("z", ty.i32(), Expr(1))),
create<ast::CallStatement>(
Call("foo", AddressOf(Expr(Source{{12, 34}}, "z")))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
}
TEST_F(ResolverCallValidationTest, PointerArgument_NotIdentExprVar) {
// struct S { m: i32; };
// fn foo(p: ptr<function, i32>) {}
// fn main() {
// var v: S;
// foo(&v.m);
// }
auto* S = Structure("S", {Member("m", ty.i32())});
auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
Func("foo", {param}, ty.void_(), {});
Func("main", {}, ty.void_(),
ast::StatementList{
Decl(Var("v", ty.Of(S))),
create<ast::CallStatement>(Call(
"foo", AddressOf(Source{{12, 34}}, MemberAccessor("v", "m")))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: expected an address-of expression of a variable "
"identifier expression or a function parameter");
}
TEST_F(ResolverCallValidationTest, PointerArgument_AddressOfMemberAccessor) {
// struct S { m: i32; };
// fn foo(p: ptr<function, i32>) {}
// fn main() {
// let v: S = S();
// foo(&v.m);
// }
auto* S = Structure("S", {Member("m", ty.i32())});
auto* param = Param("p", ty.pointer<i32>(ast::StorageClass::kFunction));
Func("foo", {param}, ty.void_(), {});
Func("main", {}, ty.void_(),
ast::StatementList{
Decl(Const("v", ty.Of(S), Construct(ty.Of(S)))),
create<ast::CallStatement>(Call(
"foo",
AddressOf(Expr(Source{{12, 34}}, MemberAccessor("v", "m"))))),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "12:34 error: cannot take the address of expression");
}
TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParam) {
// fn foo(p: ptr<function, i32>) {}
// fn bar(p: ptr<function, i32>) {
// foo(p);
// }
Func("foo", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
ty.void_(), {});
Func("bar", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
ty.void_(),
ast::StatementList{create<ast::CallStatement>(Call("foo", Expr("p")))});
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverCallValidationTest, PointerArgument_FunctionParamWithMain) {
// fn foo(p: ptr<function, i32>) {}
// fn bar(p: ptr<function, i32>) {
// foo(p);
// }
// [[stage(fragment)]]
// fn main() {
// var v: i32;
// bar(&v);
// }
Func("foo", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
ty.void_(), {});
Func("bar", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
ty.void_(),
ast::StatementList{create<ast::CallStatement>(Call("foo", Expr("p")))});
Func("main", ast::VariableList{}, ty.void_(),
{
Decl(Var("v", ty.i32(), Expr(1))),
create<ast::CallStatement>(Call("foo", AddressOf(Expr("v")))),
},
{
Stage(ast::PipelineStage::kFragment),
});
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverCallValidationTest, LetPointer) {
// fn x(p : ptr<function, i32>) -> i32 {}
// [[stage(fragment)]]
// fn main() {
// var v: i32;
// let p: ptr<function, i32> = &v;
// var c: i32 = x(p);
// }
Func("x", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
ty.void_(), {});
auto* v = Var("v", ty.i32());
auto* p = Const("p", ty.pointer(ty.i32(), ast::StorageClass::kFunction),
AddressOf(v));
auto* c = Var("c", ty.i32(), ast::StorageClass::kNone,
Call("x", Expr(Source{{12, 34}}, p)));
Func("main", ast::VariableList{}, ty.void_(),
{
Decl(v),
Decl(p),
Decl(c),
},
{
Stage(ast::PipelineStage::kFragment),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: expected an address-of expression of a variable "
"identifier expression or a function parameter");
}
TEST_F(ResolverCallValidationTest, LetPointerPrivate) {
// let p: ptr<private, i32> = &v;
// fn foo(p : ptr<private, i32>) -> i32 {}
// var v: i32;
// [[stage(fragment)]]
// fn main() {
// var c: i32 = foo(p);
// }
Func("foo", {Param("p", ty.pointer<i32>(ast::StorageClass::kPrivate))},
ty.void_(), {});
auto* v = Global("v", ty.i32(), ast::StorageClass::kPrivate);
auto* p = Const("p", ty.pointer(ty.i32(), ast::StorageClass::kPrivate),
AddressOf(v));
auto* c = Var("c", ty.i32(), ast::StorageClass::kNone,
Call("foo", Expr(Source{{12, 34}}, p)));
Func("main", ast::VariableList{}, ty.void_(),
{
Decl(p),
Decl(c),
},
{
Stage(ast::PipelineStage::kFragment),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: expected an address-of expression of a variable "
"identifier expression or a function parameter");
}
} // namespace } // namespace
} // namespace resolver } // namespace resolver
} // namespace tint } // namespace tint

View File

@ -2578,7 +2578,21 @@ bool Resolver::FunctionCall(const ast::CallExpression* call) {
} }
} }
// Validate number of arguments match number of parameters function_calls_.emplace(call,
FunctionCallInfo{callee_func, current_statement_});
SetExprInfo(call, callee_func->return_type, callee_func->return_type_name);
if (!ValidateFunctionCall(call, callee_func)) {
return false;
}
return true;
}
bool Resolver::ValidateFunctionCall(const ast::CallExpression* call,
const FunctionInfo* callee_func) {
auto* ident = call->func();
auto name = builder_->Symbols().NameFor(ident->symbol());
if (call->params().size() != callee_func->parameters.size()) { if (call->params().size() != callee_func->parameters.size()) {
bool more = call->params().size() > callee_func->parameters.size(); bool more = call->params().size() > callee_func->parameters.size();
AddError("too " + (more ? std::string("many") : std::string("few")) + AddError("too " + (more ? std::string("many") : std::string("few")) +
@ -2589,7 +2603,6 @@ bool Resolver::FunctionCall(const ast::CallExpression* call) {
return false; return false;
} }
// Validate arguments match parameter types
for (size_t i = 0; i < call->params().size(); ++i) { for (size_t i = 0; i < call->params().size(); ++i) {
const VariableInfo* param = callee_func->parameters[i]; const VariableInfo* param = callee_func->parameters[i];
const ast::Expression* arg_expr = call->params()[i]; const ast::Expression* arg_expr = call->params()[i];
@ -2603,12 +2616,48 @@ bool Resolver::FunctionCall(const ast::CallExpression* call) {
arg_expr->source()); arg_expr->source());
return false; return false;
} }
if (param->declaration->type()->Is<ast::Pointer>()) {
auto is_valid = false;
if (auto* ident_expr = arg_expr->As<ast::IdentifierExpression>()) {
VariableInfo* var;
if (!variable_stack_.get(ident_expr->symbol(), &var)) {
TINT_ICE(Resolver, diagnostics_) << "failed to resolve identifier";
return false;
}
if (var->kind == VariableKind::kParameter) {
is_valid = true;
}
} else if (auto* unary = arg_expr->As<ast::UnaryOpExpression>()) {
if (unary->op() == ast::UnaryOp::kAddressOf) {
if (auto* ident_unary =
unary->expr()->As<ast::IdentifierExpression>()) {
VariableInfo* var;
if (!variable_stack_.get(ident_unary->symbol(), &var)) {
TINT_ICE(Resolver, diagnostics_)
<< "failed to resolve identifier";
return false;
}
if (var->declaration->is_const()) {
TINT_ICE(Resolver, diagnostics_)
<< "Resolver::FunctionCall() encountered an address-of "
"expression of a constant identifier expression";
return false;
}
is_valid = true;
}
}
}
if (!is_valid) {
AddError(
"expected an address-of expression of a variable identifier "
"expression or a function parameter",
arg_expr->source());
return false;
}
}
} }
function_calls_.emplace(call,
FunctionCallInfo{callee_func, current_statement_});
SetExprInfo(call, callee_func->return_type, callee_func->return_type_name);
return true; return true;
} }

View File

@ -287,6 +287,8 @@ class Resolver {
bool ValidateCallStatement(ast::CallStatement* stmt); bool ValidateCallStatement(ast::CallStatement* stmt);
bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info); bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info);
bool ValidateFunction(const ast::Function* func, const FunctionInfo* info); bool ValidateFunction(const ast::Function* func, const FunctionInfo* info);
bool ValidateFunctionCall(const ast::CallExpression* call,
const FunctionInfo* info);
bool ValidateGlobalVariable(const VariableInfo* var); bool ValidateGlobalVariable(const VariableInfo* var);
bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco, bool ValidateInterpolateDecoration(const ast::InterpolateDecoration* deco,
const sem::Type* storage_type); const sem::Type* storage_type);

View File

@ -811,20 +811,20 @@ TEST_F(ResolverValidationTest, Stmt_BreakNotInLoopOrSwitch) {
} }
TEST_F(ResolverValidationTest, StructMemberDuplicateName) { TEST_F(ResolverValidationTest, StructMemberDuplicateName) {
Structure("S", Structure("S", {Member(Source{{12, 34}}, "a", ty.i32()),
{Member("a", ty.i32()), Member(Source{{12, 34}}, "a", ty.i32())}); Member(Source{{56, 78}}, "a", ty.i32())});
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ( EXPECT_EQ(r()->error(),
r()->error(), "56:78 error: redefinition of 'a'\n12:34 note: previous definition "
"12:34 error: redefinition of 'a'\nnote: previous definition is here"); "is here");
} }
TEST_F(ResolverValidationTest, StructMemberDuplicateNameDifferentTypes) { TEST_F(ResolverValidationTest, StructMemberDuplicateNameDifferentTypes) {
Structure("S", {Member("a", ty.bool_()), Structure("S", {Member(Source{{12, 34}}, "a", ty.bool_()),
Member(Source{{12, 34}}, "a", ty.vec3<f32>())}); Member(Source{{12, 34}}, "a", ty.vec3<f32>())});
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ( EXPECT_EQ(r()->error(),
r()->error(), "12:34 error: redefinition of 'a'\n12:34 note: previous definition "
"12:34 error: redefinition of 'a'\nnote: previous definition is here"); "is here");
} }
TEST_F(ResolverValidationTest, StructMemberDuplicateNamePass) { TEST_F(ResolverValidationTest, StructMemberDuplicateNamePass) {
Structure("S", {Member("a", ty.i32()), Member("b", ty.f32())}); Structure("S", {Member("a", ty.i32()), Member("b", ty.f32())});

View File

@ -79,35 +79,6 @@ fn f() {
EXPECT_EQ(expect, str(got)); EXPECT_EQ(expect, str(got));
} }
TEST_F(InlinePointerLetsTest, Param) {
auto* src = R"(
fn x(p : ptr<function, i32>) -> i32 {
return *p;
}
fn f() {
var v : i32;
let p : ptr<function, i32> = &v;
var r : i32 = x(p);
}
)";
auto* expect = R"(
fn x(p : ptr<function, i32>) -> i32 {
return *(p);
}
fn f() {
var v : i32;
var r : i32 = x(&(v));
}
)";
auto got = Run<InlinePointerLets>(src);
EXPECT_EQ(expect, str(got));
}
TEST_F(InlinePointerLetsTest, SavedVars) { TEST_F(InlinePointerLetsTest, SavedVars) {
auto* src = R"( auto* src = R"(
struct S { struct S {

View File

@ -286,54 +286,6 @@ TEST_F(HlslSanitizerTest, InlinePtrLetsComplexChain) {
EXPECT_EQ(expect, got); EXPECT_EQ(expect, got);
} }
TEST_F(HlslSanitizerTest, InlineParam) {
// fn x(p : ptr<function, i32>) -> i32 {
// return *p;
// }
//
// [[stage(fragment)]]
// fn main() {
// var v : i32;
// let p : ptr<function, i32> = &v;
// var r : i32 = x(p);
// }
Func("x", {Param("p", ty.pointer<i32>(ast::StorageClass::kFunction))},
ty.i32(), {Return(Deref("p"))});
auto* v = Var("v", ty.i32());
auto* p = Const("p", ty.pointer(ty.i32(), ast::StorageClass::kFunction),
AddressOf(v));
auto* r = Var("r", ty.i32(), ast::StorageClass::kNone, Call("x", p));
Func("main", ast::VariableList{}, ty.void_(),
{
Decl(v),
Decl(p),
Decl(r),
},
{
Stage(ast::PipelineStage::kFragment),
});
GeneratorImpl& gen = SanitizeAndBuild();
ASSERT_TRUE(gen.Generate()) << gen.error();
auto got = gen.result();
auto* expect = R"(int x(inout int p) {
return p;
}
void main() {
int v = 0;
int r = x(v);
return;
}
)";
EXPECT_EQ(expect, got);
}
} // namespace } // namespace
} // namespace hlsl } // namespace hlsl
} // namespace writer } // namespace writer