Add relational expression type determination.

This CL adds the type determination for each of the relation types in
the relational expression.

Bug: tint:5
Change-Id: I15e8dae2f90cc4a0f720692f5addb944b26811ec
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/18847
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair 2020-04-07 19:26:39 +00:00
parent 8ee1d22882
commit 9b97802d99
3 changed files with 559 additions and 1 deletions

View File

@ -30,10 +30,12 @@
#include "src/ast/loop_statement.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/regardless_statement.h"
#include "src/ast/relational_expression.h"
#include "src/ast/return_statement.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/switch_statement.h"
#include "src/ast/type/array_type.h"
#include "src/ast/type/bool_type.h"
#include "src/ast/type/matrix_type.h"
#include "src/ast/type/struct_type.h"
#include "src/ast/type/vector_type.h"
@ -206,6 +208,9 @@ bool TypeDeterminer::DetermineResultType(ast::Expression* expr) {
if (expr->IsMemberAccessor()) {
return DetermineMemberAccessor(expr->AsMemberAccessor());
}
if (expr->IsRelational()) {
return DetermineRelational(expr->AsRelational());
}
error_ = "unknown expression for type determination";
return false;
@ -321,4 +326,79 @@ bool TypeDeterminer::DetermineMemberAccessor(
return false;
}
bool TypeDeterminer::DetermineRelational(ast::RelationalExpression* expr) {
if (!DetermineResultType(expr->lhs()) || !DetermineResultType(expr->rhs())) {
return false;
}
// Result type matches first parameter type
if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() ||
expr->IsShiftRight() || expr->IsShiftRightArith() || expr->IsAdd() ||
expr->IsSubtract() || expr->IsDivide() || expr->IsModulo()) {
expr->set_result_type(expr->lhs()->result_type());
return true;
}
// Result type is a scalar or vector of boolean type
if (expr->IsLogicalAnd() || expr->IsLogicalOr() || expr->IsEqual() ||
expr->IsNotEqual() || expr->IsLessThan() || expr->IsGreaterThan() ||
expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) {
auto bool_type =
ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
auto param_type = expr->lhs()->result_type();
if (param_type->IsVector()) {
expr->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
bool_type, param_type->AsVector()->size())));
} else {
expr->set_result_type(bool_type);
}
return true;
}
if (expr->IsMultiply()) {
auto lhs_type = expr->lhs()->result_type();
auto rhs_type = expr->rhs()->result_type();
// Note, the ordering here matters. The later checks depend on the prior
// checks having been done.
if (lhs_type->IsMatrix() && rhs_type->IsMatrix()) {
expr->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::MatrixType>(
lhs_type->AsMatrix()->type(), lhs_type->AsMatrix()->rows(),
rhs_type->AsMatrix()->columns())));
} else if (lhs_type->IsMatrix() && rhs_type->IsVector()) {
auto mat = lhs_type->AsMatrix();
expr->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
mat->type(), mat->rows())));
} else if (lhs_type->IsVector() && rhs_type->IsMatrix()) {
auto mat = rhs_type->AsMatrix();
expr->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
mat->type(), mat->columns())));
} else if (lhs_type->IsMatrix()) {
// matrix * scalar
expr->set_result_type(lhs_type);
} else if (rhs_type->IsMatrix()) {
// scalar * matrix
expr->set_result_type(rhs_type);
} else if (lhs_type->IsVector() && rhs_type->IsVector()) {
expr->set_result_type(lhs_type);
} else if (lhs_type->IsVector()) {
// Vector * scalar
expr->set_result_type(lhs_type);
} else if (rhs_type->IsVector()) {
// Scalar * vector
expr->set_result_type(rhs_type);
} else {
// Scalar * Scalar
expr->set_result_type(lhs_type);
}
return true;
}
return false;
}
} // namespace tint

View File

@ -30,9 +30,10 @@ class AsExpression;
class CallExpression;
class CastExpression;
class ConstructorExpression;
class Function;
class IdentifierExpression;
class MemberAccessorExpression;
class Function;
class RelationalExpression;
class Variable;
} // namespace ast
@ -81,6 +82,7 @@ class TypeDeterminer {
bool DetermineConstructor(ast::ConstructorExpression* expr);
bool DetermineIdentifier(ast::IdentifierExpression* expr);
bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr);
bool DetermineRelational(ast::RelationalExpression* expr);
Context& ctx_;
std::string error_;

View File

@ -35,12 +35,14 @@
#include "src/ast/loop_statement.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/regardless_statement.h"
#include "src/ast/relational_expression.h"
#include "src/ast/return_statement.h"
#include "src/ast/scalar_constructor_expression.h"
#include "src/ast/struct.h"
#include "src/ast/struct_member.h"
#include "src/ast/switch_statement.h"
#include "src/ast/type/array_type.h"
#include "src/ast/type/bool_type.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h"
#include "src/ast/type/matrix_type.h"
@ -755,5 +757,479 @@ TEST_F(TypeDeterminerTest, Expr_MultiLevel) {
EXPECT_EQ(mem.result_type()->AsVector()->size(), 2);
}
using Expr_Relational_BitwiseTest = testing::TestWithParam<ast::Relation>;
TEST_P(Expr_Relational_BitwiseTest, Scalar) {
auto op = GetParam();
ast::type::I32Type i32;
auto var =
std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &i32);
Context ctx;
TypeDeterminer td(&ctx);
ast::Module m;
m.AddGlobalVariable(std::move(var));
// Register the global
ASSERT_TRUE(td.Determine(&m)) << td.error();
ast::RelationalExpression expr(
op, std::make_unique<ast::IdentifierExpression>("val"),
std::make_unique<ast::IdentifierExpression>("val"));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ASSERT_NE(expr.result_type(), nullptr);
EXPECT_TRUE(expr.result_type()->IsI32());
}
TEST_P(Expr_Relational_BitwiseTest, Vector) {
auto op = GetParam();
ast::type::I32Type i32;
ast::type::VectorType vec3(&i32, 3);
auto var =
std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &vec3);
Context ctx;
TypeDeterminer td(&ctx);
ast::Module m;
m.AddGlobalVariable(std::move(var));
// Register the global
ASSERT_TRUE(td.Determine(&m)) << td.error();
ast::RelationalExpression expr(
op, std::make_unique<ast::IdentifierExpression>("val"),
std::make_unique<ast::IdentifierExpression>("val"));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ASSERT_NE(expr.result_type(), nullptr);
ASSERT_TRUE(expr.result_type()->IsVector());
EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsI32());
EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
}
INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
Expr_Relational_BitwiseTest,
testing::Values(ast::Relation::kAnd,
ast::Relation::kOr,
ast::Relation::kXor,
ast::Relation::kShiftLeft,
ast::Relation::kShiftRight,
ast::Relation::kShiftRightArith,
ast::Relation::kAdd,
ast::Relation::kSubtract,
ast::Relation::kDivide,
ast::Relation::kModulo));
using Expr_Relational_LogicalTest = testing::TestWithParam<ast::Relation>;
TEST_P(Expr_Relational_LogicalTest, Scalar) {
auto op = GetParam();
ast::type::BoolType bool_type;
auto var = std::make_unique<ast::Variable>("val", ast::StorageClass::kNone,
&bool_type);
Context ctx;
TypeDeterminer td(&ctx);
ast::Module m;
m.AddGlobalVariable(std::move(var));
// Register the global
ASSERT_TRUE(td.Determine(&m)) << td.error();
ast::RelationalExpression expr(
op, std::make_unique<ast::IdentifierExpression>("val"),
std::make_unique<ast::IdentifierExpression>("val"));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ASSERT_NE(expr.result_type(), nullptr);
EXPECT_TRUE(expr.result_type()->IsBool());
}
TEST_P(Expr_Relational_LogicalTest, Vector) {
auto op = GetParam();
ast::type::BoolType bool_type;
ast::type::VectorType vec3(&bool_type, 3);
auto var =
std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &vec3);
Context ctx;
TypeDeterminer td(&ctx);
ast::Module m;
m.AddGlobalVariable(std::move(var));
// Register the global
ASSERT_TRUE(td.Determine(&m)) << td.error();
ast::RelationalExpression expr(
op, std::make_unique<ast::IdentifierExpression>("val"),
std::make_unique<ast::IdentifierExpression>("val"));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ASSERT_NE(expr.result_type(), nullptr);
ASSERT_TRUE(expr.result_type()->IsVector());
EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsBool());
EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
}
INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
Expr_Relational_LogicalTest,
testing::Values(ast::Relation::kLogicalAnd,
ast::Relation::kLogicalOr));
using Expr_Relational_CompareTest = testing::TestWithParam<ast::Relation>;
TEST_P(Expr_Relational_CompareTest, Scalar) {
auto op = GetParam();
ast::type::I32Type i32;
auto var =
std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &i32);
Context ctx;
TypeDeterminer td(&ctx);
ast::Module m;
m.AddGlobalVariable(std::move(var));
// Register the global
ASSERT_TRUE(td.Determine(&m)) << td.error();
ast::RelationalExpression expr(
op, std::make_unique<ast::IdentifierExpression>("val"),
std::make_unique<ast::IdentifierExpression>("val"));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ASSERT_NE(expr.result_type(), nullptr);
EXPECT_TRUE(expr.result_type()->IsBool());
}
TEST_P(Expr_Relational_CompareTest, Vector) {
auto op = GetParam();
ast::type::I32Type i32;
ast::type::VectorType vec3(&i32, 3);
auto var =
std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &vec3);
Context ctx;
TypeDeterminer td(&ctx);
ast::Module m;
m.AddGlobalVariable(std::move(var));
// Register the global
ASSERT_TRUE(td.Determine(&m)) << td.error();
ast::RelationalExpression expr(
op, std::make_unique<ast::IdentifierExpression>("val"),
std::make_unique<ast::IdentifierExpression>("val"));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ASSERT_NE(expr.result_type(), nullptr);
ASSERT_TRUE(expr.result_type()->IsVector());
EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsBool());
EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
}
INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
Expr_Relational_CompareTest,
testing::Values(ast::Relation::kEqual,
ast::Relation::kNotEqual,
ast::Relation::kLessThan,
ast::Relation::kGreaterThan,
ast::Relation::kLessThanEqual,
ast::Relation::kGreaterThanEqual));
TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Scalar_Scalar) {
ast::type::I32Type i32;
auto var =
std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &i32);
Context ctx;
TypeDeterminer td(&ctx);
ast::Module m;
m.AddGlobalVariable(std::move(var));
// Register the global
ASSERT_TRUE(td.Determine(&m)) << td.error();
ast::RelationalExpression expr(
ast::Relation::kMultiply,
std::make_unique<ast::IdentifierExpression>("val"),
std::make_unique<ast::IdentifierExpression>("val"));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ASSERT_NE(expr.result_type(), nullptr);
EXPECT_TRUE(expr.result_type()->IsI32());
}
TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Vector_Scalar) {
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
auto scalar =
std::make_unique<ast::Variable>("scalar", ast::StorageClass::kNone, &f32);
auto vector = std::make_unique<ast::Variable>(
"vector", ast::StorageClass::kNone, &vec3);
Context ctx;
TypeDeterminer td(&ctx);
ast::Module m;
m.AddGlobalVariable(std::move(scalar));
m.AddGlobalVariable(std::move(vector));
// Register the global
ASSERT_TRUE(td.Determine(&m)) << td.error();
ast::RelationalExpression expr(
ast::Relation::kMultiply,
std::make_unique<ast::IdentifierExpression>("vector"),
std::make_unique<ast::IdentifierExpression>("scalar"));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ASSERT_NE(expr.result_type(), nullptr);
ASSERT_TRUE(expr.result_type()->IsVector());
EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
}
TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Scalar_Vector) {
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
auto scalar =
std::make_unique<ast::Variable>("scalar", ast::StorageClass::kNone, &f32);
auto vector = std::make_unique<ast::Variable>(
"vector", ast::StorageClass::kNone, &vec3);
Context ctx;
TypeDeterminer td(&ctx);
ast::Module m;
m.AddGlobalVariable(std::move(scalar));
m.AddGlobalVariable(std::move(vector));
// Register the global
ASSERT_TRUE(td.Determine(&m)) << td.error();
ast::RelationalExpression expr(
ast::Relation::kMultiply,
std::make_unique<ast::IdentifierExpression>("scalar"),
std::make_unique<ast::IdentifierExpression>("vector"));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ASSERT_NE(expr.result_type(), nullptr);
ASSERT_TRUE(expr.result_type()->IsVector());
EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
}
TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Vector_Vector) {
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
auto vector = std::make_unique<ast::Variable>(
"vector", ast::StorageClass::kNone, &vec3);
Context ctx;
TypeDeterminer td(&ctx);
ast::Module m;
m.AddGlobalVariable(std::move(vector));
// Register the global
ASSERT_TRUE(td.Determine(&m)) << td.error();
ast::RelationalExpression expr(
ast::Relation::kMultiply,
std::make_unique<ast::IdentifierExpression>("vector"),
std::make_unique<ast::IdentifierExpression>("vector"));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ASSERT_NE(expr.result_type(), nullptr);
ASSERT_TRUE(expr.result_type()->IsVector());
EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
}
TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Matrix_Scalar) {
ast::type::F32Type f32;
ast::type::MatrixType mat3x2(&f32, 3, 2);
auto scalar =
std::make_unique<ast::Variable>("scalar", ast::StorageClass::kNone, &f32);
auto matrix = std::make_unique<ast::Variable>(
"matrix", ast::StorageClass::kNone, &mat3x2);
Context ctx;
TypeDeterminer td(&ctx);
ast::Module m;
m.AddGlobalVariable(std::move(scalar));
m.AddGlobalVariable(std::move(matrix));
// Register the global
ASSERT_TRUE(td.Determine(&m)) << td.error();
ast::RelationalExpression expr(
ast::Relation::kMultiply,
std::make_unique<ast::IdentifierExpression>("matrix"),
std::make_unique<ast::IdentifierExpression>("scalar"));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ASSERT_NE(expr.result_type(), nullptr);
ASSERT_TRUE(expr.result_type()->IsMatrix());
auto mat = expr.result_type()->AsMatrix();
EXPECT_TRUE(mat->type()->IsF32());
EXPECT_EQ(mat->rows(), 3);
EXPECT_EQ(mat->columns(), 2);
}
TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Scalar_Matrix) {
ast::type::F32Type f32;
ast::type::MatrixType mat3x2(&f32, 3, 2);
auto scalar =
std::make_unique<ast::Variable>("scalar", ast::StorageClass::kNone, &f32);
auto matrix = std::make_unique<ast::Variable>(
"matrix", ast::StorageClass::kNone, &mat3x2);
Context ctx;
TypeDeterminer td(&ctx);
ast::Module m;
m.AddGlobalVariable(std::move(scalar));
m.AddGlobalVariable(std::move(matrix));
// Register the global
ASSERT_TRUE(td.Determine(&m)) << td.error();
ast::RelationalExpression expr(
ast::Relation::kMultiply,
std::make_unique<ast::IdentifierExpression>("scalar"),
std::make_unique<ast::IdentifierExpression>("matrix"));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ASSERT_NE(expr.result_type(), nullptr);
ASSERT_TRUE(expr.result_type()->IsMatrix());
auto mat = expr.result_type()->AsMatrix();
EXPECT_TRUE(mat->type()->IsF32());
EXPECT_EQ(mat->rows(), 3);
EXPECT_EQ(mat->columns(), 2);
}
TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Matrix_Vector) {
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 2);
ast::type::MatrixType mat3x2(&f32, 3, 2);
auto vector = std::make_unique<ast::Variable>(
"vector", ast::StorageClass::kNone, &vec3);
auto matrix = std::make_unique<ast::Variable>(
"matrix", ast::StorageClass::kNone, &mat3x2);
Context ctx;
TypeDeterminer td(&ctx);
ast::Module m;
m.AddGlobalVariable(std::move(vector));
m.AddGlobalVariable(std::move(matrix));
// Register the global
ASSERT_TRUE(td.Determine(&m)) << td.error();
ast::RelationalExpression expr(
ast::Relation::kMultiply,
std::make_unique<ast::IdentifierExpression>("matrix"),
std::make_unique<ast::IdentifierExpression>("vector"));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ASSERT_NE(expr.result_type(), nullptr);
ASSERT_TRUE(expr.result_type()->IsVector());
EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
}
TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Vector_Matrix) {
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
ast::type::MatrixType mat3x2(&f32, 3, 2);
auto vector = std::make_unique<ast::Variable>(
"vector", ast::StorageClass::kNone, &vec3);
auto matrix = std::make_unique<ast::Variable>(
"matrix", ast::StorageClass::kNone, &mat3x2);
Context ctx;
TypeDeterminer td(&ctx);
ast::Module m;
m.AddGlobalVariable(std::move(vector));
m.AddGlobalVariable(std::move(matrix));
// Register the global
ASSERT_TRUE(td.Determine(&m)) << td.error();
ast::RelationalExpression expr(
ast::Relation::kMultiply,
std::make_unique<ast::IdentifierExpression>("vector"),
std::make_unique<ast::IdentifierExpression>("matrix"));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ASSERT_NE(expr.result_type(), nullptr);
ASSERT_TRUE(expr.result_type()->IsVector());
EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
EXPECT_EQ(expr.result_type()->AsVector()->size(), 2);
}
TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Matrix_Matrix) {
ast::type::F32Type f32;
ast::type::MatrixType mat4x3(&f32, 4, 3);
ast::type::MatrixType mat3x4(&f32, 3, 4);
auto matrix1 = std::make_unique<ast::Variable>(
"mat4x3", ast::StorageClass::kNone, &mat4x3);
auto matrix2 = std::make_unique<ast::Variable>(
"mat3x4", ast::StorageClass::kNone, &mat3x4);
Context ctx;
TypeDeterminer td(&ctx);
ast::Module m;
m.AddGlobalVariable(std::move(matrix1));
m.AddGlobalVariable(std::move(matrix2));
// Register the global
ASSERT_TRUE(td.Determine(&m)) << td.error();
ast::RelationalExpression expr(
ast::Relation::kMultiply,
std::make_unique<ast::IdentifierExpression>("mat4x3"),
std::make_unique<ast::IdentifierExpression>("mat3x4"));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ASSERT_NE(expr.result_type(), nullptr);
ASSERT_TRUE(expr.result_type()->IsMatrix());
auto mat = expr.result_type()->AsMatrix();
EXPECT_TRUE(mat->type()->IsF32());
EXPECT_EQ(mat->rows(), 4);
EXPECT_EQ(mat->columns(), 4);
}
} // namespace
} // namespace tint