Add type determination for call expressions.
This CL adds the type determination for call expressions. Bug: tint:5 Change-Id: Ibe08f90ec3905dd1e2169f6e69d1d74943720819 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/18844 Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
parent
a01777c2d9
commit
3ca8746ebd
|
@ -20,6 +20,7 @@
|
||||||
#include "src/ast/as_expression.h"
|
#include "src/ast/as_expression.h"
|
||||||
#include "src/ast/assignment_statement.h"
|
#include "src/ast/assignment_statement.h"
|
||||||
#include "src/ast/break_statement.h"
|
#include "src/ast/break_statement.h"
|
||||||
|
#include "src/ast/call_expression.h"
|
||||||
#include "src/ast/case_statement.h"
|
#include "src/ast/case_statement.h"
|
||||||
#include "src/ast/continue_statement.h"
|
#include "src/ast/continue_statement.h"
|
||||||
#include "src/ast/else_statement.h"
|
#include "src/ast/else_statement.h"
|
||||||
|
@ -39,10 +40,7 @@
|
||||||
|
|
||||||
namespace tint {
|
namespace tint {
|
||||||
|
|
||||||
TypeDeterminer::TypeDeterminer(Context* ctx) : ctx_(*ctx) {
|
TypeDeterminer::TypeDeterminer(Context* ctx) : ctx_(*ctx) {}
|
||||||
// TODO(dsinclair): Temporary usage to avoid compiler warning
|
|
||||||
static_cast<void>(ctx_.type_mgr());
|
|
||||||
}
|
|
||||||
|
|
||||||
TypeDeterminer::~TypeDeterminer() = default;
|
TypeDeterminer::~TypeDeterminer() = default;
|
||||||
|
|
||||||
|
@ -175,6 +173,15 @@ bool TypeDeterminer::DetermineResultType(ast::Statement* stmt) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool TypeDeterminer::DetermineResultType(const ast::ExpressionList& exprs) {
|
||||||
|
for (const auto& expr : exprs) {
|
||||||
|
if (!DetermineResultType(expr.get())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool TypeDeterminer::DetermineResultType(ast::Expression* expr) {
|
bool TypeDeterminer::DetermineResultType(ast::Expression* expr) {
|
||||||
// This is blindly called above, so in some cases the expression won't exist.
|
// This is blindly called above, so in some cases the expression won't exist.
|
||||||
if (!expr) {
|
if (!expr) {
|
||||||
|
@ -187,6 +194,9 @@ bool TypeDeterminer::DetermineResultType(ast::Expression* expr) {
|
||||||
if (expr->IsAs()) {
|
if (expr->IsAs()) {
|
||||||
return DetermineAs(expr->AsAs());
|
return DetermineAs(expr->AsAs());
|
||||||
}
|
}
|
||||||
|
if (expr->IsCall()) {
|
||||||
|
return DetermineCall(expr->AsCall());
|
||||||
|
}
|
||||||
if (expr->IsConstructor()) {
|
if (expr->IsConstructor()) {
|
||||||
return DetermineConstructor(expr->AsConstructor());
|
return DetermineConstructor(expr->AsConstructor());
|
||||||
}
|
}
|
||||||
|
@ -224,6 +234,17 @@ bool TypeDeterminer::DetermineAs(ast::AsExpression* expr) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool TypeDeterminer::DetermineCall(ast::CallExpression* expr) {
|
||||||
|
if (!DetermineResultType(expr->func())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (!DetermineResultType(expr->params())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
expr->set_result_type(expr->func()->result_type());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool TypeDeterminer::DetermineConstructor(ast::ConstructorExpression* expr) {
|
bool TypeDeterminer::DetermineConstructor(ast::ConstructorExpression* expr) {
|
||||||
if (expr->IsTypeConstructor()) {
|
if (expr->IsTypeConstructor()) {
|
||||||
expr->set_result_type(expr->AsTypeConstructor()->type());
|
expr->set_result_type(expr->AsTypeConstructor()->type());
|
||||||
|
|
|
@ -27,6 +27,7 @@ namespace ast {
|
||||||
|
|
||||||
class ArrayAccessorExpression;
|
class ArrayAccessorExpression;
|
||||||
class AsExpression;
|
class AsExpression;
|
||||||
|
class CallExpression;
|
||||||
class ConstructorExpression;
|
class ConstructorExpression;
|
||||||
class IdentifierExpression;
|
class IdentifierExpression;
|
||||||
class Function;
|
class Function;
|
||||||
|
@ -65,6 +66,10 @@ class TypeDeterminer {
|
||||||
/// @param stmt the statement to check
|
/// @param stmt the statement to check
|
||||||
/// @returns true if the determination was successful
|
/// @returns true if the determination was successful
|
||||||
bool DetermineResultType(ast::Statement* stmt);
|
bool DetermineResultType(ast::Statement* stmt);
|
||||||
|
/// Determines type information for a list of expressions
|
||||||
|
/// @param exprs the expressions to check
|
||||||
|
/// @returns true if the determination was successful
|
||||||
|
bool DetermineResultType(const ast::ExpressionList& exprs);
|
||||||
/// Determines type information for an expression
|
/// Determines type information for an expression
|
||||||
/// @param expr the expression to check
|
/// @param expr the expression to check
|
||||||
/// @returns true if the determination was successful
|
/// @returns true if the determination was successful
|
||||||
|
@ -73,6 +78,7 @@ class TypeDeterminer {
|
||||||
private:
|
private:
|
||||||
bool DetermineArrayAccessor(ast::ArrayAccessorExpression* expr);
|
bool DetermineArrayAccessor(ast::ArrayAccessorExpression* expr);
|
||||||
bool DetermineAs(ast::AsExpression* expr);
|
bool DetermineAs(ast::AsExpression* expr);
|
||||||
|
bool DetermineCall(ast::CallExpression* expr);
|
||||||
bool DetermineConstructor(ast::ConstructorExpression* expr);
|
bool DetermineConstructor(ast::ConstructorExpression* expr);
|
||||||
bool DetermineIdentifier(ast::IdentifierExpression* expr);
|
bool DetermineIdentifier(ast::IdentifierExpression* expr);
|
||||||
Context& ctx_;
|
Context& ctx_;
|
||||||
|
|
|
@ -22,6 +22,7 @@
|
||||||
#include "src/ast/as_expression.h"
|
#include "src/ast/as_expression.h"
|
||||||
#include "src/ast/assignment_statement.h"
|
#include "src/ast/assignment_statement.h"
|
||||||
#include "src/ast/break_statement.h"
|
#include "src/ast/break_statement.h"
|
||||||
|
#include "src/ast/call_expression.h"
|
||||||
#include "src/ast/case_statement.h"
|
#include "src/ast/case_statement.h"
|
||||||
#include "src/ast/continue_statement.h"
|
#include "src/ast/continue_statement.h"
|
||||||
#include "src/ast/else_statement.h"
|
#include "src/ast/else_statement.h"
|
||||||
|
@ -489,6 +490,66 @@ TEST_F(TypeDeterminerTest, Expr_As) {
|
||||||
EXPECT_TRUE(as.result_type()->IsF32());
|
EXPECT_TRUE(as.result_type()->IsF32());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TypeDeterminerTest, Expr_Call) {
|
||||||
|
ast::type::F32Type f32;
|
||||||
|
|
||||||
|
ast::VariableList params;
|
||||||
|
auto func =
|
||||||
|
std::make_unique<ast::Function>("my_func", std::move(params), &f32);
|
||||||
|
ast::Module m;
|
||||||
|
m.AddFunction(std::move(func));
|
||||||
|
|
||||||
|
// Register the function
|
||||||
|
EXPECT_TRUE(td()->Determine(&m));
|
||||||
|
|
||||||
|
ast::ExpressionList call_params;
|
||||||
|
ast::CallExpression call(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("my_func"),
|
||||||
|
std::move(call_params));
|
||||||
|
EXPECT_TRUE(td()->DetermineResultType(&call));
|
||||||
|
ASSERT_NE(call.result_type(), nullptr);
|
||||||
|
EXPECT_TRUE(call.result_type()->IsF32());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TypeDeterminerTest, Expr_Call_WithParams) {
|
||||||
|
ast::type::F32Type f32;
|
||||||
|
ast::type::I32Type i32;
|
||||||
|
|
||||||
|
ast::VariableList params;
|
||||||
|
params.push_back(
|
||||||
|
std::make_unique<ast::Variable>("a", ast::StorageClass::kNone, &f32));
|
||||||
|
params.push_back(
|
||||||
|
std::make_unique<ast::Variable>("b", ast::StorageClass::kNone, &i32));
|
||||||
|
|
||||||
|
auto func =
|
||||||
|
std::make_unique<ast::Function>("my_func", std::move(params), &f32);
|
||||||
|
ast::Module m;
|
||||||
|
m.AddFunction(std::move(func));
|
||||||
|
|
||||||
|
// Register the function
|
||||||
|
EXPECT_TRUE(td()->Determine(&m));
|
||||||
|
|
||||||
|
ast::ExpressionList call_params;
|
||||||
|
call_params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::FloatLiteral>(&f32, 2.5f)));
|
||||||
|
auto a_ptr = call_params.back().get();
|
||||||
|
call_params.push_back(std::make_unique<ast::ScalarConstructorExpression>(
|
||||||
|
std::make_unique<ast::IntLiteral>(&i32, 1)));
|
||||||
|
auto b_ptr = call_params.back().get();
|
||||||
|
|
||||||
|
ast::CallExpression call(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("my_func"),
|
||||||
|
std::move(call_params));
|
||||||
|
EXPECT_TRUE(td()->DetermineResultType(&call));
|
||||||
|
ASSERT_NE(call.result_type(), nullptr);
|
||||||
|
EXPECT_TRUE(call.result_type()->IsF32());
|
||||||
|
|
||||||
|
ASSERT_NE(a_ptr->result_type(), nullptr);
|
||||||
|
EXPECT_TRUE(a_ptr->result_type()->IsF32());
|
||||||
|
ASSERT_NE(b_ptr->result_type(), nullptr);
|
||||||
|
EXPECT_TRUE(b_ptr->result_type()->IsI32());
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(TypeDeterminerTest, Expr_Constructor_Scalar) {
|
TEST_F(TypeDeterminerTest, Expr_Constructor_Scalar) {
|
||||||
ast::type::F32Type f32;
|
ast::type::F32Type f32;
|
||||||
ast::ScalarConstructorExpression s(
|
ast::ScalarConstructorExpression s(
|
||||||
|
|
Loading…
Reference in New Issue