Add relational expression type determination.
This CL adds the type determination for each of the relation types in the relational expression. Bug: tint:5 Change-Id: I15e8dae2f90cc4a0f720692f5addb944b26811ec Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/18847 Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
parent
8ee1d22882
commit
9b97802d99
|
@ -30,10 +30,12 @@
|
|||
#include "src/ast/loop_statement.h"
|
||||
#include "src/ast/member_accessor_expression.h"
|
||||
#include "src/ast/regardless_statement.h"
|
||||
#include "src/ast/relational_expression.h"
|
||||
#include "src/ast/return_statement.h"
|
||||
#include "src/ast/scalar_constructor_expression.h"
|
||||
#include "src/ast/switch_statement.h"
|
||||
#include "src/ast/type/array_type.h"
|
||||
#include "src/ast/type/bool_type.h"
|
||||
#include "src/ast/type/matrix_type.h"
|
||||
#include "src/ast/type/struct_type.h"
|
||||
#include "src/ast/type/vector_type.h"
|
||||
|
@ -206,6 +208,9 @@ bool TypeDeterminer::DetermineResultType(ast::Expression* expr) {
|
|||
if (expr->IsMemberAccessor()) {
|
||||
return DetermineMemberAccessor(expr->AsMemberAccessor());
|
||||
}
|
||||
if (expr->IsRelational()) {
|
||||
return DetermineRelational(expr->AsRelational());
|
||||
}
|
||||
|
||||
error_ = "unknown expression for type determination";
|
||||
return false;
|
||||
|
@ -321,4 +326,79 @@ bool TypeDeterminer::DetermineMemberAccessor(
|
|||
return false;
|
||||
}
|
||||
|
||||
bool TypeDeterminer::DetermineRelational(ast::RelationalExpression* expr) {
|
||||
if (!DetermineResultType(expr->lhs()) || !DetermineResultType(expr->rhs())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Result type matches first parameter type
|
||||
if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() ||
|
||||
expr->IsShiftRight() || expr->IsShiftRightArith() || expr->IsAdd() ||
|
||||
expr->IsSubtract() || expr->IsDivide() || expr->IsModulo()) {
|
||||
expr->set_result_type(expr->lhs()->result_type());
|
||||
return true;
|
||||
}
|
||||
// Result type is a scalar or vector of boolean type
|
||||
if (expr->IsLogicalAnd() || expr->IsLogicalOr() || expr->IsEqual() ||
|
||||
expr->IsNotEqual() || expr->IsLessThan() || expr->IsGreaterThan() ||
|
||||
expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) {
|
||||
auto bool_type =
|
||||
ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
|
||||
auto param_type = expr->lhs()->result_type();
|
||||
if (param_type->IsVector()) {
|
||||
expr->set_result_type(
|
||||
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
|
||||
bool_type, param_type->AsVector()->size())));
|
||||
} else {
|
||||
expr->set_result_type(bool_type);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
if (expr->IsMultiply()) {
|
||||
auto lhs_type = expr->lhs()->result_type();
|
||||
auto rhs_type = expr->rhs()->result_type();
|
||||
|
||||
// Note, the ordering here matters. The later checks depend on the prior
|
||||
// checks having been done.
|
||||
if (lhs_type->IsMatrix() && rhs_type->IsMatrix()) {
|
||||
expr->set_result_type(
|
||||
ctx_.type_mgr().Get(std::make_unique<ast::type::MatrixType>(
|
||||
lhs_type->AsMatrix()->type(), lhs_type->AsMatrix()->rows(),
|
||||
rhs_type->AsMatrix()->columns())));
|
||||
|
||||
} else if (lhs_type->IsMatrix() && rhs_type->IsVector()) {
|
||||
auto mat = lhs_type->AsMatrix();
|
||||
expr->set_result_type(
|
||||
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
|
||||
mat->type(), mat->rows())));
|
||||
} else if (lhs_type->IsVector() && rhs_type->IsMatrix()) {
|
||||
auto mat = rhs_type->AsMatrix();
|
||||
expr->set_result_type(
|
||||
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
|
||||
mat->type(), mat->columns())));
|
||||
} else if (lhs_type->IsMatrix()) {
|
||||
// matrix * scalar
|
||||
expr->set_result_type(lhs_type);
|
||||
} else if (rhs_type->IsMatrix()) {
|
||||
// scalar * matrix
|
||||
expr->set_result_type(rhs_type);
|
||||
} else if (lhs_type->IsVector() && rhs_type->IsVector()) {
|
||||
expr->set_result_type(lhs_type);
|
||||
} else if (lhs_type->IsVector()) {
|
||||
// Vector * scalar
|
||||
expr->set_result_type(lhs_type);
|
||||
} else if (rhs_type->IsVector()) {
|
||||
// Scalar * vector
|
||||
expr->set_result_type(rhs_type);
|
||||
} else {
|
||||
// Scalar * Scalar
|
||||
expr->set_result_type(lhs_type);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace tint
|
||||
|
|
|
@ -30,9 +30,10 @@ class AsExpression;
|
|||
class CallExpression;
|
||||
class CastExpression;
|
||||
class ConstructorExpression;
|
||||
class Function;
|
||||
class IdentifierExpression;
|
||||
class MemberAccessorExpression;
|
||||
class Function;
|
||||
class RelationalExpression;
|
||||
class Variable;
|
||||
|
||||
} // namespace ast
|
||||
|
@ -81,6 +82,7 @@ class TypeDeterminer {
|
|||
bool DetermineConstructor(ast::ConstructorExpression* expr);
|
||||
bool DetermineIdentifier(ast::IdentifierExpression* expr);
|
||||
bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr);
|
||||
bool DetermineRelational(ast::RelationalExpression* expr);
|
||||
|
||||
Context& ctx_;
|
||||
std::string error_;
|
||||
|
|
|
@ -35,12 +35,14 @@
|
|||
#include "src/ast/loop_statement.h"
|
||||
#include "src/ast/member_accessor_expression.h"
|
||||
#include "src/ast/regardless_statement.h"
|
||||
#include "src/ast/relational_expression.h"
|
||||
#include "src/ast/return_statement.h"
|
||||
#include "src/ast/scalar_constructor_expression.h"
|
||||
#include "src/ast/struct.h"
|
||||
#include "src/ast/struct_member.h"
|
||||
#include "src/ast/switch_statement.h"
|
||||
#include "src/ast/type/array_type.h"
|
||||
#include "src/ast/type/bool_type.h"
|
||||
#include "src/ast/type/f32_type.h"
|
||||
#include "src/ast/type/i32_type.h"
|
||||
#include "src/ast/type/matrix_type.h"
|
||||
|
@ -755,5 +757,479 @@ TEST_F(TypeDeterminerTest, Expr_MultiLevel) {
|
|||
EXPECT_EQ(mem.result_type()->AsVector()->size(), 2);
|
||||
}
|
||||
|
||||
using Expr_Relational_BitwiseTest = testing::TestWithParam<ast::Relation>;
|
||||
TEST_P(Expr_Relational_BitwiseTest, Scalar) {
|
||||
auto op = GetParam();
|
||||
|
||||
ast::type::I32Type i32;
|
||||
|
||||
auto var =
|
||||
std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &i32);
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
|
||||
ast::Module m;
|
||||
m.AddGlobalVariable(std::move(var));
|
||||
|
||||
// Register the global
|
||||
ASSERT_TRUE(td.Determine(&m)) << td.error();
|
||||
|
||||
ast::RelationalExpression expr(
|
||||
op, std::make_unique<ast::IdentifierExpression>("val"),
|
||||
std::make_unique<ast::IdentifierExpression>("val"));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
ASSERT_NE(expr.result_type(), nullptr);
|
||||
EXPECT_TRUE(expr.result_type()->IsI32());
|
||||
}
|
||||
|
||||
TEST_P(Expr_Relational_BitwiseTest, Vector) {
|
||||
auto op = GetParam();
|
||||
|
||||
ast::type::I32Type i32;
|
||||
ast::type::VectorType vec3(&i32, 3);
|
||||
|
||||
auto var =
|
||||
std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &vec3);
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
|
||||
ast::Module m;
|
||||
m.AddGlobalVariable(std::move(var));
|
||||
|
||||
// Register the global
|
||||
ASSERT_TRUE(td.Determine(&m)) << td.error();
|
||||
|
||||
ast::RelationalExpression expr(
|
||||
op, std::make_unique<ast::IdentifierExpression>("val"),
|
||||
std::make_unique<ast::IdentifierExpression>("val"));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
ASSERT_NE(expr.result_type(), nullptr);
|
||||
ASSERT_TRUE(expr.result_type()->IsVector());
|
||||
EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsI32());
|
||||
EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
|
||||
Expr_Relational_BitwiseTest,
|
||||
testing::Values(ast::Relation::kAnd,
|
||||
ast::Relation::kOr,
|
||||
ast::Relation::kXor,
|
||||
ast::Relation::kShiftLeft,
|
||||
ast::Relation::kShiftRight,
|
||||
ast::Relation::kShiftRightArith,
|
||||
ast::Relation::kAdd,
|
||||
ast::Relation::kSubtract,
|
||||
ast::Relation::kDivide,
|
||||
ast::Relation::kModulo));
|
||||
|
||||
using Expr_Relational_LogicalTest = testing::TestWithParam<ast::Relation>;
|
||||
TEST_P(Expr_Relational_LogicalTest, Scalar) {
|
||||
auto op = GetParam();
|
||||
|
||||
ast::type::BoolType bool_type;
|
||||
|
||||
auto var = std::make_unique<ast::Variable>("val", ast::StorageClass::kNone,
|
||||
&bool_type);
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
|
||||
ast::Module m;
|
||||
m.AddGlobalVariable(std::move(var));
|
||||
|
||||
// Register the global
|
||||
ASSERT_TRUE(td.Determine(&m)) << td.error();
|
||||
|
||||
ast::RelationalExpression expr(
|
||||
op, std::make_unique<ast::IdentifierExpression>("val"),
|
||||
std::make_unique<ast::IdentifierExpression>("val"));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
ASSERT_NE(expr.result_type(), nullptr);
|
||||
EXPECT_TRUE(expr.result_type()->IsBool());
|
||||
}
|
||||
|
||||
TEST_P(Expr_Relational_LogicalTest, Vector) {
|
||||
auto op = GetParam();
|
||||
|
||||
ast::type::BoolType bool_type;
|
||||
ast::type::VectorType vec3(&bool_type, 3);
|
||||
|
||||
auto var =
|
||||
std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &vec3);
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
|
||||
ast::Module m;
|
||||
m.AddGlobalVariable(std::move(var));
|
||||
|
||||
// Register the global
|
||||
ASSERT_TRUE(td.Determine(&m)) << td.error();
|
||||
|
||||
ast::RelationalExpression expr(
|
||||
op, std::make_unique<ast::IdentifierExpression>("val"),
|
||||
std::make_unique<ast::IdentifierExpression>("val"));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
ASSERT_NE(expr.result_type(), nullptr);
|
||||
ASSERT_TRUE(expr.result_type()->IsVector());
|
||||
EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsBool());
|
||||
EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
|
||||
Expr_Relational_LogicalTest,
|
||||
testing::Values(ast::Relation::kLogicalAnd,
|
||||
ast::Relation::kLogicalOr));
|
||||
|
||||
using Expr_Relational_CompareTest = testing::TestWithParam<ast::Relation>;
|
||||
TEST_P(Expr_Relational_CompareTest, Scalar) {
|
||||
auto op = GetParam();
|
||||
|
||||
ast::type::I32Type i32;
|
||||
|
||||
auto var =
|
||||
std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &i32);
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
|
||||
ast::Module m;
|
||||
m.AddGlobalVariable(std::move(var));
|
||||
|
||||
// Register the global
|
||||
ASSERT_TRUE(td.Determine(&m)) << td.error();
|
||||
|
||||
ast::RelationalExpression expr(
|
||||
op, std::make_unique<ast::IdentifierExpression>("val"),
|
||||
std::make_unique<ast::IdentifierExpression>("val"));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
ASSERT_NE(expr.result_type(), nullptr);
|
||||
EXPECT_TRUE(expr.result_type()->IsBool());
|
||||
}
|
||||
|
||||
TEST_P(Expr_Relational_CompareTest, Vector) {
|
||||
auto op = GetParam();
|
||||
|
||||
ast::type::I32Type i32;
|
||||
ast::type::VectorType vec3(&i32, 3);
|
||||
|
||||
auto var =
|
||||
std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &vec3);
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
|
||||
ast::Module m;
|
||||
m.AddGlobalVariable(std::move(var));
|
||||
|
||||
// Register the global
|
||||
ASSERT_TRUE(td.Determine(&m)) << td.error();
|
||||
|
||||
ast::RelationalExpression expr(
|
||||
op, std::make_unique<ast::IdentifierExpression>("val"),
|
||||
std::make_unique<ast::IdentifierExpression>("val"));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
ASSERT_NE(expr.result_type(), nullptr);
|
||||
ASSERT_TRUE(expr.result_type()->IsVector());
|
||||
EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsBool());
|
||||
EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
|
||||
Expr_Relational_CompareTest,
|
||||
testing::Values(ast::Relation::kEqual,
|
||||
ast::Relation::kNotEqual,
|
||||
ast::Relation::kLessThan,
|
||||
ast::Relation::kGreaterThan,
|
||||
ast::Relation::kLessThanEqual,
|
||||
ast::Relation::kGreaterThanEqual));
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Scalar_Scalar) {
|
||||
ast::type::I32Type i32;
|
||||
|
||||
auto var =
|
||||
std::make_unique<ast::Variable>("val", ast::StorageClass::kNone, &i32);
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
|
||||
ast::Module m;
|
||||
m.AddGlobalVariable(std::move(var));
|
||||
|
||||
// Register the global
|
||||
ASSERT_TRUE(td.Determine(&m)) << td.error();
|
||||
|
||||
ast::RelationalExpression expr(
|
||||
ast::Relation::kMultiply,
|
||||
std::make_unique<ast::IdentifierExpression>("val"),
|
||||
std::make_unique<ast::IdentifierExpression>("val"));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
ASSERT_NE(expr.result_type(), nullptr);
|
||||
EXPECT_TRUE(expr.result_type()->IsI32());
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Vector_Scalar) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::VectorType vec3(&f32, 3);
|
||||
|
||||
auto scalar =
|
||||
std::make_unique<ast::Variable>("scalar", ast::StorageClass::kNone, &f32);
|
||||
auto vector = std::make_unique<ast::Variable>(
|
||||
"vector", ast::StorageClass::kNone, &vec3);
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
|
||||
ast::Module m;
|
||||
m.AddGlobalVariable(std::move(scalar));
|
||||
m.AddGlobalVariable(std::move(vector));
|
||||
|
||||
// Register the global
|
||||
ASSERT_TRUE(td.Determine(&m)) << td.error();
|
||||
|
||||
ast::RelationalExpression expr(
|
||||
ast::Relation::kMultiply,
|
||||
std::make_unique<ast::IdentifierExpression>("vector"),
|
||||
std::make_unique<ast::IdentifierExpression>("scalar"));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
ASSERT_NE(expr.result_type(), nullptr);
|
||||
ASSERT_TRUE(expr.result_type()->IsVector());
|
||||
EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
|
||||
EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Scalar_Vector) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::VectorType vec3(&f32, 3);
|
||||
|
||||
auto scalar =
|
||||
std::make_unique<ast::Variable>("scalar", ast::StorageClass::kNone, &f32);
|
||||
auto vector = std::make_unique<ast::Variable>(
|
||||
"vector", ast::StorageClass::kNone, &vec3);
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
|
||||
ast::Module m;
|
||||
m.AddGlobalVariable(std::move(scalar));
|
||||
m.AddGlobalVariable(std::move(vector));
|
||||
|
||||
// Register the global
|
||||
ASSERT_TRUE(td.Determine(&m)) << td.error();
|
||||
|
||||
ast::RelationalExpression expr(
|
||||
ast::Relation::kMultiply,
|
||||
std::make_unique<ast::IdentifierExpression>("scalar"),
|
||||
std::make_unique<ast::IdentifierExpression>("vector"));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
ASSERT_NE(expr.result_type(), nullptr);
|
||||
ASSERT_TRUE(expr.result_type()->IsVector());
|
||||
EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
|
||||
EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Vector_Vector) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::VectorType vec3(&f32, 3);
|
||||
|
||||
auto vector = std::make_unique<ast::Variable>(
|
||||
"vector", ast::StorageClass::kNone, &vec3);
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
|
||||
ast::Module m;
|
||||
m.AddGlobalVariable(std::move(vector));
|
||||
|
||||
// Register the global
|
||||
ASSERT_TRUE(td.Determine(&m)) << td.error();
|
||||
|
||||
ast::RelationalExpression expr(
|
||||
ast::Relation::kMultiply,
|
||||
std::make_unique<ast::IdentifierExpression>("vector"),
|
||||
std::make_unique<ast::IdentifierExpression>("vector"));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
ASSERT_NE(expr.result_type(), nullptr);
|
||||
ASSERT_TRUE(expr.result_type()->IsVector());
|
||||
EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
|
||||
EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Matrix_Scalar) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::MatrixType mat3x2(&f32, 3, 2);
|
||||
|
||||
auto scalar =
|
||||
std::make_unique<ast::Variable>("scalar", ast::StorageClass::kNone, &f32);
|
||||
auto matrix = std::make_unique<ast::Variable>(
|
||||
"matrix", ast::StorageClass::kNone, &mat3x2);
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
|
||||
ast::Module m;
|
||||
m.AddGlobalVariable(std::move(scalar));
|
||||
m.AddGlobalVariable(std::move(matrix));
|
||||
|
||||
// Register the global
|
||||
ASSERT_TRUE(td.Determine(&m)) << td.error();
|
||||
|
||||
ast::RelationalExpression expr(
|
||||
ast::Relation::kMultiply,
|
||||
std::make_unique<ast::IdentifierExpression>("matrix"),
|
||||
std::make_unique<ast::IdentifierExpression>("scalar"));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
ASSERT_NE(expr.result_type(), nullptr);
|
||||
ASSERT_TRUE(expr.result_type()->IsMatrix());
|
||||
|
||||
auto mat = expr.result_type()->AsMatrix();
|
||||
EXPECT_TRUE(mat->type()->IsF32());
|
||||
EXPECT_EQ(mat->rows(), 3);
|
||||
EXPECT_EQ(mat->columns(), 2);
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Scalar_Matrix) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::MatrixType mat3x2(&f32, 3, 2);
|
||||
|
||||
auto scalar =
|
||||
std::make_unique<ast::Variable>("scalar", ast::StorageClass::kNone, &f32);
|
||||
auto matrix = std::make_unique<ast::Variable>(
|
||||
"matrix", ast::StorageClass::kNone, &mat3x2);
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
|
||||
ast::Module m;
|
||||
m.AddGlobalVariable(std::move(scalar));
|
||||
m.AddGlobalVariable(std::move(matrix));
|
||||
|
||||
// Register the global
|
||||
ASSERT_TRUE(td.Determine(&m)) << td.error();
|
||||
|
||||
ast::RelationalExpression expr(
|
||||
ast::Relation::kMultiply,
|
||||
std::make_unique<ast::IdentifierExpression>("scalar"),
|
||||
std::make_unique<ast::IdentifierExpression>("matrix"));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
ASSERT_NE(expr.result_type(), nullptr);
|
||||
ASSERT_TRUE(expr.result_type()->IsMatrix());
|
||||
|
||||
auto mat = expr.result_type()->AsMatrix();
|
||||
EXPECT_TRUE(mat->type()->IsF32());
|
||||
EXPECT_EQ(mat->rows(), 3);
|
||||
EXPECT_EQ(mat->columns(), 2);
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Matrix_Vector) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::VectorType vec3(&f32, 2);
|
||||
ast::type::MatrixType mat3x2(&f32, 3, 2);
|
||||
|
||||
auto vector = std::make_unique<ast::Variable>(
|
||||
"vector", ast::StorageClass::kNone, &vec3);
|
||||
auto matrix = std::make_unique<ast::Variable>(
|
||||
"matrix", ast::StorageClass::kNone, &mat3x2);
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
|
||||
ast::Module m;
|
||||
m.AddGlobalVariable(std::move(vector));
|
||||
m.AddGlobalVariable(std::move(matrix));
|
||||
|
||||
// Register the global
|
||||
ASSERT_TRUE(td.Determine(&m)) << td.error();
|
||||
|
||||
ast::RelationalExpression expr(
|
||||
ast::Relation::kMultiply,
|
||||
std::make_unique<ast::IdentifierExpression>("matrix"),
|
||||
std::make_unique<ast::IdentifierExpression>("vector"));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
ASSERT_NE(expr.result_type(), nullptr);
|
||||
ASSERT_TRUE(expr.result_type()->IsVector());
|
||||
EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
|
||||
EXPECT_EQ(expr.result_type()->AsVector()->size(), 3);
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Vector_Matrix) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::VectorType vec3(&f32, 3);
|
||||
ast::type::MatrixType mat3x2(&f32, 3, 2);
|
||||
|
||||
auto vector = std::make_unique<ast::Variable>(
|
||||
"vector", ast::StorageClass::kNone, &vec3);
|
||||
auto matrix = std::make_unique<ast::Variable>(
|
||||
"matrix", ast::StorageClass::kNone, &mat3x2);
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
|
||||
ast::Module m;
|
||||
m.AddGlobalVariable(std::move(vector));
|
||||
m.AddGlobalVariable(std::move(matrix));
|
||||
|
||||
// Register the global
|
||||
ASSERT_TRUE(td.Determine(&m)) << td.error();
|
||||
|
||||
ast::RelationalExpression expr(
|
||||
ast::Relation::kMultiply,
|
||||
std::make_unique<ast::IdentifierExpression>("vector"),
|
||||
std::make_unique<ast::IdentifierExpression>("matrix"));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
ASSERT_NE(expr.result_type(), nullptr);
|
||||
ASSERT_TRUE(expr.result_type()->IsVector());
|
||||
EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
|
||||
EXPECT_EQ(expr.result_type()->AsVector()->size(), 2);
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Expr_Relational_Multiply_Matrix_Matrix) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::MatrixType mat4x3(&f32, 4, 3);
|
||||
ast::type::MatrixType mat3x4(&f32, 3, 4);
|
||||
|
||||
auto matrix1 = std::make_unique<ast::Variable>(
|
||||
"mat4x3", ast::StorageClass::kNone, &mat4x3);
|
||||
auto matrix2 = std::make_unique<ast::Variable>(
|
||||
"mat3x4", ast::StorageClass::kNone, &mat3x4);
|
||||
|
||||
Context ctx;
|
||||
TypeDeterminer td(&ctx);
|
||||
|
||||
ast::Module m;
|
||||
m.AddGlobalVariable(std::move(matrix1));
|
||||
m.AddGlobalVariable(std::move(matrix2));
|
||||
|
||||
// Register the global
|
||||
ASSERT_TRUE(td.Determine(&m)) << td.error();
|
||||
|
||||
ast::RelationalExpression expr(
|
||||
ast::Relation::kMultiply,
|
||||
std::make_unique<ast::IdentifierExpression>("mat4x3"),
|
||||
std::make_unique<ast::IdentifierExpression>("mat3x4"));
|
||||
|
||||
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
|
||||
ASSERT_NE(expr.result_type(), nullptr);
|
||||
ASSERT_TRUE(expr.result_type()->IsMatrix());
|
||||
|
||||
auto mat = expr.result_type()->AsMatrix();
|
||||
EXPECT_TRUE(mat->type()->IsF32());
|
||||
EXPECT_EQ(mat->rows(), 4);
|
||||
EXPECT_EQ(mat->columns(), 4);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tint
|
||||
|
|
Loading…
Reference in New Issue