Validate function call arguments

- Add resolver/call_test.cc for new unit tests, and move a couple that
were in resolver/validation_test.cc to it

- Fix CalculateArrayLength transform so that it passes the address of
the u32 it creates to the internal function

- Fix tests broken as a result of this change

Bug: tint:664
Change-Id: If713f9828790cd51224d2392d42c01c0057cb652
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/53920
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Antonio Maiorano
2021-06-09 20:17:59 +00:00
committed by Tint LUCI CQ
parent 1987fd80f4
commit 14b3403148
19 changed files with 371 additions and 139 deletions

194
src/resolver/call_test.cc Normal file
View File

@@ -0,0 +1,194 @@
// Copyright 2021 The Tint Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "src/resolver/resolver.h"
#include "gmock/gmock.h"
#include "src/ast/call_statement.h"
#include "src/resolver/resolver_test_helper.h"
namespace tint {
namespace resolver {
namespace {
// Helpers and typedefs
template <typename T>
using DataType = builder::DataType<T>;
template <int N, typename T>
using vec = builder::vec<N, T>;
template <typename T>
using vec2 = builder::vec2<T>;
template <typename T>
using vec3 = builder::vec3<T>;
template <typename T>
using vec4 = builder::vec4<T>;
template <int N, int M, typename T>
using mat = builder::mat<N, M, T>;
template <typename T>
using mat2x2 = builder::mat2x2<T>;
template <typename T>
using mat2x3 = builder::mat2x3<T>;
template <typename T>
using mat3x2 = builder::mat3x2<T>;
template <typename T>
using mat3x3 = builder::mat3x3<T>;
template <typename T>
using mat4x4 = builder::mat4x4<T>;
template <typename T, int ID = 0>
using alias = builder::alias<T, ID>;
template <typename T>
using alias1 = builder::alias1<T>;
template <typename T>
using alias2 = builder::alias2<T>;
template <typename T>
using alias3 = builder::alias3<T>;
using f32 = builder::f32;
using i32 = builder::i32;
using u32 = builder::u32;
using ResolverCallTest = ResolverTest;
TEST_F(ResolverCallTest, Recursive_Invalid) {
// fn main() {main(); }
SetSource(Source::Location{12, 34});
auto* call_expr = Call("main");
ast::VariableList params0;
Func("main", params0, ty.void_(),
ast::StatementList{
create<ast::CallStatement>(call_expr),
},
ast::DecorationList{
Stage(ast::PipelineStage::kVertex),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error v-0004: recursion is not permitted. 'main' attempted "
"to call "
"itself.");
}
TEST_F(ResolverCallTest, Undeclared_Invalid) {
// fn main() {func(); return; }
// fn func() { return; }
SetSource(Source::Location{12, 34});
auto* call_expr = Call("func");
ast::VariableList params0;
Func("main", params0, ty.f32(),
ast::StatementList{
create<ast::CallStatement>(call_expr),
Return(),
},
ast::DecorationList{});
Func("func", params0, ty.f32(),
ast::StatementList{
Return(),
},
ast::DecorationList{});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: v-0006: unable to find called function: func");
}
struct Params {
builder::ast_expr_func_ptr create_value;
builder::ast_type_func_ptr create_type;
};
template <typename T>
constexpr Params ParamsFor() {
return Params{DataType<T>::Expr, DataType<T>::AST};
}
static constexpr Params all_param_types[] = {
ParamsFor<bool>(), //
ParamsFor<u32>(), //
ParamsFor<i32>(), //
ParamsFor<f32>(), //
ParamsFor<vec3<bool>>(), //
ParamsFor<vec3<i32>>(), //
ParamsFor<vec3<u32>>(), //
ParamsFor<vec3<f32>>(), //
ParamsFor<mat3x3<i32>>(), //
ParamsFor<mat3x3<u32>>(), //
ParamsFor<mat3x3<f32>>(), //
ParamsFor<mat2x3<i32>>(), //
ParamsFor<mat2x3<u32>>(), //
ParamsFor<mat2x3<f32>>(), //
ParamsFor<mat3x2<i32>>(), //
ParamsFor<mat3x2<u32>>(), //
ParamsFor<mat3x2<f32>>() //
};
TEST_F(ResolverCallTest, Valid) {
ast::VariableList params;
ast::ExpressionList args;
for (auto& p : all_param_types) {
params.push_back(Param(Sym(), p.create_type(*this)));
args.push_back(p.create_value(*this, 0));
}
Func("foo", std::move(params), ty.void_(), {Return()});
auto* call = Call("foo", std::move(args));
WrapInFunction(call);
EXPECT_TRUE(r()->Resolve()) << r()->error();
}
TEST_F(ResolverCallTest, TooFewArgs) {
Func("foo", {Param(Sym(), ty.i32()), Param(Sym(), ty.f32())}, ty.void_(),
{Return()});
auto* call = Call(Source{{12, 34}}, "foo", 1);
WrapInFunction(call);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
"12:34 error: too few arguments in call to 'foo', expected 2, got 1");
}
TEST_F(ResolverCallTest, TooManyArgs) {
Func("foo", {Param(Sym(), ty.i32()), Param(Sym(), ty.f32())}, ty.void_(),
{Return()});
auto* call = Call(Source{{12, 34}}, "foo", 1, 1.0f, 1.0f);
WrapInFunction(call);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(
r()->error(),
"12:34 error: too many arguments in call to 'foo', expected 2, got 3");
}
TEST_F(ResolverCallTest, MismatchedArgs) {
Func("foo", {Param(Sym(), ty.i32()), Param(Sym(), ty.f32())}, ty.void_(),
{Return()});
auto* call = Call("foo", Expr(Source{{12, 34}}, true), 1.0f);
WrapInFunction(call);
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: type mismatch for argument 1 in call to 'foo', "
"expected 'i32', got 'bool'");
}
} // namespace
} // namespace resolver
} // namespace tint

View File

@@ -1701,51 +1701,9 @@ bool Resolver::Call(ast::CallExpression* call) {
return false;
}
} else {
if (current_function_) {
auto callee_func_it = symbol_to_function_.find(ident->symbol());
if (callee_func_it == symbol_to_function_.end()) {
if (current_function_->declaration->symbol() == ident->symbol()) {
diagnostics_.add_error("v-0004",
"recursion is not permitted. '" + name +
"' attempted to call itself.",
call->source());
} else {
diagnostics_.add_error(
"v-0006: unable to find called function: " + name,
call->source());
}
return false;
}
auto* callee_func = callee_func_it->second;
callee_func->callsites.push_back(call);
// Note: Requires called functions to be resolved first.
// This is currently guaranteed as functions must be declared before
// use.
current_function_->transitive_calls.add(callee_func);
for (auto* transitive_call : callee_func->transitive_calls) {
current_function_->transitive_calls.add(transitive_call);
}
// We inherit any referenced variables from the callee.
for (auto* var : callee_func->referenced_module_vars) {
set_referenced_from_function_if_needed(var, false);
}
}
auto iter = symbol_to_function_.find(ident->symbol());
if (iter == symbol_to_function_.end()) {
diagnostics_.add_error(
"v-0005: function must be declared before use: '" + name + "'",
call->source());
if (!FunctionCall(call)) {
return false;
}
auto* function = iter->second;
function_calls_.emplace(call,
FunctionCallInfo{function, current_statement_});
SetType(call, function->return_type, function->return_type_name);
}
return true;
@@ -1775,6 +1733,79 @@ bool Resolver::IntrinsicCall(ast::CallExpression* call,
return true;
}
bool Resolver::FunctionCall(const ast::CallExpression* call) {
auto* ident = call->func();
auto name = builder_->Symbols().NameFor(ident->symbol());
auto callee_func_it = symbol_to_function_.find(ident->symbol());
if (callee_func_it == symbol_to_function_.end()) {
if (current_function_ &&
current_function_->declaration->symbol() == ident->symbol()) {
diagnostics_.add_error("v-0004",
"recursion is not permitted. '" + name +
"' attempted to call itself.",
call->source());
} else {
diagnostics_.add_error("v-0006: unable to find called function: " + name,
call->source());
}
return false;
}
auto* callee_func = callee_func_it->second;
if (current_function_) {
callee_func->callsites.push_back(call);
// Note: Requires called functions to be resolved first.
// This is currently guaranteed as functions must be declared before
// use.
current_function_->transitive_calls.add(callee_func);
for (auto* transitive_call : callee_func->transitive_calls) {
current_function_->transitive_calls.add(transitive_call);
}
// We inherit any referenced variables from the callee.
for (auto* var : callee_func->referenced_module_vars) {
set_referenced_from_function_if_needed(var, false);
}
}
// Validate number of arguments match number of parameters
if (call->params().size() != callee_func->parameters.size()) {
bool more = call->params().size() > callee_func->parameters.size();
diagnostics_.add_error(
"too " + (more ? std::string("many") : std::string("few")) +
" arguments in call to '" + name + "', expected " +
std::to_string(callee_func->parameters.size()) + ", got " +
std::to_string(call->params().size()),
call->source());
return false;
}
// Validate arguments match parameter types
for (size_t i = 0; i < call->params().size(); ++i) {
const VariableInfo* param = callee_func->parameters[i];
const ast::Expression* arg_expr = call->params()[i];
auto* arg_type = TypeOf(arg_expr)->UnwrapRef();
if (param->type != arg_type) {
diagnostics_.add_error(
"type mismatch for argument " + std::to_string(i + 1) +
" in call to '" + name + "', expected '" +
param->type->FriendlyName(builder_->Symbols()) + "', got '" +
arg_type->FriendlyName(builder_->Symbols()) + "'",
arg_expr->source());
return false;
}
}
function_calls_.emplace(call,
FunctionCallInfo{callee_func, current_statement_});
SetType(call, callee_func->return_type, callee_func->return_type_name);
return true;
}
bool Resolver::Constructor(ast::ConstructorExpression* expr) {
if (auto* type_ctor = expr->As<ast::TypeConstructorExpression>()) {
for (auto* value : type_ctor->values()) {
@@ -2514,11 +2545,11 @@ sem::Type* Resolver::TypeOf(const ast::Literal* lit) {
return nullptr;
}
void Resolver::SetType(ast::Expression* expr, const sem::Type* type) {
void Resolver::SetType(const ast::Expression* expr, const sem::Type* type) {
SetType(expr, type, type->FriendlyName(builder_->Symbols()));
}
void Resolver::SetType(ast::Expression* expr,
void Resolver::SetType(const ast::Expression* expr,
const sem::Type* type,
const std::string& type_name) {
if (expr_info_.count(expr)) {

View File

@@ -234,6 +234,7 @@ class Resolver {
bool Identifier(ast::IdentifierExpression*);
bool IfStatement(ast::IfStatement*);
bool IntrinsicCall(ast::CallExpression*, sem::IntrinsicType);
bool FunctionCall(const ast::CallExpression* call);
bool LoopStatement(ast::LoopStatement*);
bool MemberAccessor(ast::MemberAccessorExpression*);
bool Parameter(ast::Variable* param);
@@ -345,7 +346,7 @@ class Resolver {
/// assigns this semantic node to the expression `expr`.
/// @param expr the expression
/// @param type the resolved type
void SetType(ast::Expression* expr, const sem::Type* type);
void SetType(const ast::Expression* expr, const sem::Type* type);
/// Creates a sem::Expression node with the resolved type `type`, the declared
/// type name `type_name` and assigns this semantic node to the expression
@@ -353,7 +354,7 @@ class Resolver {
/// @param expr the expression
/// @param type the resolved type
/// @param type_name the declared type name
void SetType(ast::Expression* expr,
void SetType(const ast::Expression* expr,
const sem::Type* type,
const std::string& type_name);
@@ -396,7 +397,8 @@ class Resolver {
std::vector<FunctionInfo*> entry_points_;
std::unordered_map<const ast::Function*, FunctionInfo*> function_to_info_;
std::unordered_map<const ast::Variable*, VariableInfo*> variable_to_info_;
std::unordered_map<ast::CallExpression*, FunctionCallInfo> function_calls_;
std::unordered_map<const ast::CallExpression*, FunctionCallInfo>
function_calls_;
std::unordered_map<const ast::Expression*, ExpressionInfo> expr_info_;
std::unordered_map<Symbol, TypeDeclInfo> named_type_info_;

View File

@@ -560,8 +560,7 @@ TEST_F(ResolverTest, Expr_Call_InBinaryOp) {
}
TEST_F(ResolverTest, Expr_Call_WithParams) {
ast::VariableList params;
Func("my_func", params, ty.f32(),
Func("my_func", {Param(Sym(), ty.f32())}, ty.f32(),
{
Return(1.2f),
});

View File

@@ -83,56 +83,6 @@ TEST_F(ResolverValidationTest, Stmt_Error_Unknown) {
"2:30 error: unknown statement type for type determination: Fake");
}
TEST_F(ResolverValidationTest, Stmt_Call_undeclared) {
// fn main() {func(); return; }
// fn func() { return; }
SetSource(Source::Location{12, 34});
auto* call_expr = Call("func");
ast::VariableList params0;
Func("main", params0, ty.f32(),
ast::StatementList{
create<ast::CallStatement>(call_expr),
Return(),
},
ast::DecorationList{});
Func("func", params0, ty.f32(),
ast::StatementList{
Return(),
},
ast::DecorationList{});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error: v-0006: unable to find called function: func");
}
TEST_F(ResolverValidationTest, Stmt_Call_recursive) {
// fn main() {main(); }
SetSource(Source::Location{12, 34});
auto* call_expr = Call("main");
ast::VariableList params0;
Func("main", params0, ty.void_(),
ast::StatementList{
create<ast::CallStatement>(call_expr),
},
ast::DecorationList{
Stage(ast::PipelineStage::kVertex),
});
EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(),
"12:34 error v-0004: recursion is not permitted. 'main' attempted "
"to call "
"itself.");
}
TEST_F(ResolverValidationTest, Stmt_If_NonBool) {
// if (1.23f) {}