diff --git a/src/type_determiner.cc b/src/type_determiner.cc index 8f07dc4f6d..c8d495217b 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -36,11 +36,13 @@ #include "src/ast/switch_statement.h" #include "src/ast/type/array_type.h" #include "src/ast/type/bool_type.h" +#include "src/ast/type/f32_type.h" #include "src/ast/type/matrix_type.h" #include "src/ast/type/struct_type.h" #include "src/ast/type/vector_type.h" #include "src/ast/type_constructor_expression.h" #include "src/ast/unary_derivative_expression.h" +#include "src/ast/unary_method_expression.h" #include "src/ast/unless_statement.h" #include "src/ast/variable_decl_statement.h" @@ -215,6 +217,9 @@ bool TypeDeterminer::DetermineResultType(ast::Expression* expr) { if (expr->IsUnaryDerivative()) { return DetermineUnaryDerivative(expr->AsUnaryDerivative()); } + if (expr->IsUnaryMethod()) { + return DetermineUnaryMethod(expr->AsUnaryMethod()); + } error_ = "unknown expression for type determination"; return false; @@ -415,4 +420,66 @@ bool TypeDeterminer::DetermineUnaryDerivative( return true; } +bool TypeDeterminer::DetermineUnaryMethod(ast::UnaryMethodExpression* expr) { + for (const auto& param : expr->params()) { + if (!DetermineResultType(param.get())) { + return false; + } + } + + switch (expr->op()) { + case ast::UnaryMethod::kAny: + case ast::UnaryMethod::kAll: { + expr->set_result_type( + ctx_.type_mgr().Get(std::make_unique())); + break; + } + case ast::UnaryMethod::kIsNan: + case ast::UnaryMethod::kIsInf: + case ast::UnaryMethod::kIsFinite: + case ast::UnaryMethod::kIsNormal: { + if (expr->params().empty()) { + error_ = "incorrect number of parameters"; + return false; + } + + auto bool_type = + ctx_.type_mgr().Get(std::make_unique()); + auto param_type = expr->params()[0]->result_type(); + if (param_type->IsVector()) { + expr->set_result_type( + ctx_.type_mgr().Get(std::make_unique( + bool_type, param_type->AsVector()->size()))); + } else { + expr->set_result_type(bool_type); + } + break; + } + case ast::UnaryMethod::kDot: { + expr->set_result_type( + ctx_.type_mgr().Get(std::make_unique())); + break; + } + case ast::UnaryMethod::kOuterProduct: { + if (expr->params().size() != 2) { + error_ = "incorrect number of parameters for outer product"; + return false; + } + auto param0_type = expr->params()[0]->result_type(); + auto param1_type = expr->params()[1]->result_type(); + if (!param0_type->IsVector() || !param1_type->IsVector()) { + error_ = "invalid parameter type for outer product"; + return false; + } + expr->set_result_type( + ctx_.type_mgr().Get(std::make_unique( + ctx_.type_mgr().Get(std::make_unique()), + param0_type->AsVector()->size(), + param1_type->AsVector()->size()))); + break; + } + } + return true; +} + } // namespace tint diff --git a/src/type_determiner.h b/src/type_determiner.h index de6a18889e..126b755491 100644 --- a/src/type_determiner.h +++ b/src/type_determiner.h @@ -35,6 +35,7 @@ class IdentifierExpression; class MemberAccessorExpression; class RelationalExpression; class UnaryDerivativeExpression; +class UnaryMethodExpression; class Variable; } // namespace ast @@ -85,6 +86,7 @@ class TypeDeterminer { bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr); bool DetermineRelational(ast::RelationalExpression* expr); bool DetermineUnaryDerivative(ast::UnaryDerivativeExpression* expr); + bool DetermineUnaryMethod(ast::UnaryMethodExpression* expr); Context& ctx_; std::string error_; diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index a41691672d..1d63956b1e 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -50,6 +50,7 @@ #include "src/ast/type/vector_type.h" #include "src/ast/type_constructor_expression.h" #include "src/ast/unary_derivative_expression.h" +#include "src/ast/unary_method_expression.h" #include "src/ast/unless_statement.h" #include "src/ast/variable_decl_statement.h" @@ -1268,5 +1269,164 @@ INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest, ast::UnaryDerivative::kDpdy, ast::UnaryDerivative::kFwidth)); +using UnaryMethodExpressionBoolTest = testing::TestWithParam; +TEST_P(UnaryMethodExpressionBoolTest, Expr_UnaryMethod_Any) { + auto op = GetParam(); + + ast::type::BoolType bool_type; + ast::type::VectorType vec3(&bool_type, 3); + + auto var = std::make_unique("my_var", ast::StorageClass::kNone, + &vec3); + + ast::Module m; + m.AddGlobalVariable(std::move(var)); + + ast::ExpressionList params; + params.push_back(std::make_unique("my_var")); + + ast::UnaryMethodExpression exp(op, std::move(params)); + + Context ctx; + TypeDeterminer td(&ctx); + + // Register the variable + EXPECT_TRUE(td.Determine(&m)); + + EXPECT_TRUE(td.DetermineResultType(&exp)); + ASSERT_NE(exp.result_type(), nullptr); + EXPECT_TRUE(exp.result_type()->IsBool()); +} +INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest, + UnaryMethodExpressionBoolTest, + testing::Values(ast::UnaryMethod::kAny, + ast::UnaryMethod::kAll)); + +using UnaryMethodExpressionVecTest = testing::TestWithParam; +TEST_P(UnaryMethodExpressionVecTest, Expr_UnaryMethod_Bool) { + auto op = GetParam(); + + ast::type::F32Type f32; + ast::type::VectorType vec3(&f32, 3); + + auto var = std::make_unique("my_var", ast::StorageClass::kNone, + &vec3); + + ast::Module m; + m.AddGlobalVariable(std::move(var)); + + ast::ExpressionList params; + params.push_back(std::make_unique("my_var")); + + ast::UnaryMethodExpression exp(op, std::move(params)); + + Context ctx; + TypeDeterminer td(&ctx); + + // Register the variable + EXPECT_TRUE(td.Determine(&m)); + + EXPECT_TRUE(td.DetermineResultType(&exp)); + ASSERT_NE(exp.result_type(), nullptr); + ASSERT_TRUE(exp.result_type()->IsVector()); + EXPECT_TRUE(exp.result_type()->AsVector()->type()->IsBool()); + EXPECT_EQ(exp.result_type()->AsVector()->size(), 3); +} +TEST_P(UnaryMethodExpressionVecTest, Expr_UnaryMethod_Vec) { + auto op = GetParam(); + + ast::type::F32Type f32; + + auto var = + std::make_unique("my_var", ast::StorageClass::kNone, &f32); + + ast::Module m; + m.AddGlobalVariable(std::move(var)); + + ast::ExpressionList params; + params.push_back(std::make_unique("my_var")); + + ast::UnaryMethodExpression exp(op, std::move(params)); + + Context ctx; + TypeDeterminer td(&ctx); + + // Register the variable + EXPECT_TRUE(td.Determine(&m)); + + EXPECT_TRUE(td.DetermineResultType(&exp)); + ASSERT_NE(exp.result_type(), nullptr); + EXPECT_TRUE(exp.result_type()->IsBool()); +} +INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest, + UnaryMethodExpressionVecTest, + testing::Values(ast::UnaryMethod::kIsInf, + ast::UnaryMethod::kIsNan, + ast::UnaryMethod::kIsFinite, + ast::UnaryMethod::kIsNormal)); + +TEST_F(TypeDeterminerTest, Expr_UnaryMethod_Dot) { + ast::type::F32Type f32; + ast::type::VectorType vec3(&f32, 3); + + auto var = std::make_unique("my_var", ast::StorageClass::kNone, + &vec3); + + ast::Module m; + m.AddGlobalVariable(std::move(var)); + + ast::ExpressionList params; + params.push_back(std::make_unique("my_var")); + params.push_back(std::make_unique("my_var")); + + ast::UnaryMethodExpression exp(ast::UnaryMethod::kDot, std::move(params)); + + Context ctx; + TypeDeterminer td(&ctx); + + // Register the variable + EXPECT_TRUE(td.Determine(&m)); + + EXPECT_TRUE(td.DetermineResultType(&exp)); + ASSERT_NE(exp.result_type(), nullptr); + EXPECT_TRUE(exp.result_type()->IsF32()); +} + +TEST_F(TypeDeterminerTest, Expr_UnaryMethod_OuterProduct) { + ast::type::F32Type f32; + ast::type::VectorType vec3(&f32, 3); + ast::type::VectorType vec2(&f32, 2); + + auto var1 = + std::make_unique("v3", ast::StorageClass::kNone, &vec3); + auto var2 = + std::make_unique("v2", ast::StorageClass::kNone, &vec2); + + ast::Module m; + m.AddGlobalVariable(std::move(var1)); + m.AddGlobalVariable(std::move(var2)); + + ast::ExpressionList params; + params.push_back(std::make_unique("v3")); + params.push_back(std::make_unique("v2")); + + ast::UnaryMethodExpression exp(ast::UnaryMethod::kOuterProduct, + std::move(params)); + + Context ctx; + TypeDeterminer td(&ctx); + + // Register the variable + EXPECT_TRUE(td.Determine(&m)); + + EXPECT_TRUE(td.DetermineResultType(&exp)); + ASSERT_NE(exp.result_type(), nullptr); + ASSERT_TRUE(exp.result_type()->IsMatrix()); + auto mat = exp.result_type()->AsMatrix(); + EXPECT_TRUE(mat->type()->IsF32()); + EXPECT_EQ(mat->rows(), 3); + EXPECT_EQ(mat->columns(), 2); +} + } // namespace } // namespace tint