Add type determination for unary method.

This CL adds the type determination for the unary method expression.

Bug: tint:5
Change-Id: I9f94a79b9715cf74e37c74eb1a612ca84b3c241f
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/18849
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair 2020-04-07 19:27:00 +00:00
parent b173056fca
commit 8dcfd108a9
3 changed files with 229 additions and 0 deletions

View File

@ -36,11 +36,13 @@
#include "src/ast/switch_statement.h" #include "src/ast/switch_statement.h"
#include "src/ast/type/array_type.h" #include "src/ast/type/array_type.h"
#include "src/ast/type/bool_type.h" #include "src/ast/type/bool_type.h"
#include "src/ast/type/f32_type.h"
#include "src/ast/type/matrix_type.h" #include "src/ast/type/matrix_type.h"
#include "src/ast/type/struct_type.h" #include "src/ast/type/struct_type.h"
#include "src/ast/type/vector_type.h" #include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h" #include "src/ast/type_constructor_expression.h"
#include "src/ast/unary_derivative_expression.h" #include "src/ast/unary_derivative_expression.h"
#include "src/ast/unary_method_expression.h"
#include "src/ast/unless_statement.h" #include "src/ast/unless_statement.h"
#include "src/ast/variable_decl_statement.h" #include "src/ast/variable_decl_statement.h"
@ -215,6 +217,9 @@ bool TypeDeterminer::DetermineResultType(ast::Expression* expr) {
if (expr->IsUnaryDerivative()) { if (expr->IsUnaryDerivative()) {
return DetermineUnaryDerivative(expr->AsUnaryDerivative()); return DetermineUnaryDerivative(expr->AsUnaryDerivative());
} }
if (expr->IsUnaryMethod()) {
return DetermineUnaryMethod(expr->AsUnaryMethod());
}
error_ = "unknown expression for type determination"; error_ = "unknown expression for type determination";
return false; return false;
@ -415,4 +420,66 @@ bool TypeDeterminer::DetermineUnaryDerivative(
return true; return true;
} }
bool TypeDeterminer::DetermineUnaryMethod(ast::UnaryMethodExpression* expr) {
for (const auto& param : expr->params()) {
if (!DetermineResultType(param.get())) {
return false;
}
}
switch (expr->op()) {
case ast::UnaryMethod::kAny:
case ast::UnaryMethod::kAll: {
expr->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>()));
break;
}
case ast::UnaryMethod::kIsNan:
case ast::UnaryMethod::kIsInf:
case ast::UnaryMethod::kIsFinite:
case ast::UnaryMethod::kIsNormal: {
if (expr->params().empty()) {
error_ = "incorrect number of parameters";
return false;
}
auto bool_type =
ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
auto param_type = expr->params()[0]->result_type();
if (param_type->IsVector()) {
expr->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
bool_type, param_type->AsVector()->size())));
} else {
expr->set_result_type(bool_type);
}
break;
}
case ast::UnaryMethod::kDot: {
expr->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>()));
break;
}
case ast::UnaryMethod::kOuterProduct: {
if (expr->params().size() != 2) {
error_ = "incorrect number of parameters for outer product";
return false;
}
auto param0_type = expr->params()[0]->result_type();
auto param1_type = expr->params()[1]->result_type();
if (!param0_type->IsVector() || !param1_type->IsVector()) {
error_ = "invalid parameter type for outer product";
return false;
}
expr->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::MatrixType>(
ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>()),
param0_type->AsVector()->size(),
param1_type->AsVector()->size())));
break;
}
}
return true;
}
} // namespace tint } // namespace tint

View File

@ -35,6 +35,7 @@ class IdentifierExpression;
class MemberAccessorExpression; class MemberAccessorExpression;
class RelationalExpression; class RelationalExpression;
class UnaryDerivativeExpression; class UnaryDerivativeExpression;
class UnaryMethodExpression;
class Variable; class Variable;
} // namespace ast } // namespace ast
@ -85,6 +86,7 @@ class TypeDeterminer {
bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr); bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr);
bool DetermineRelational(ast::RelationalExpression* expr); bool DetermineRelational(ast::RelationalExpression* expr);
bool DetermineUnaryDerivative(ast::UnaryDerivativeExpression* expr); bool DetermineUnaryDerivative(ast::UnaryDerivativeExpression* expr);
bool DetermineUnaryMethod(ast::UnaryMethodExpression* expr);
Context& ctx_; Context& ctx_;
std::string error_; std::string error_;

View File

@ -50,6 +50,7 @@
#include "src/ast/type/vector_type.h" #include "src/ast/type/vector_type.h"
#include "src/ast/type_constructor_expression.h" #include "src/ast/type_constructor_expression.h"
#include "src/ast/unary_derivative_expression.h" #include "src/ast/unary_derivative_expression.h"
#include "src/ast/unary_method_expression.h"
#include "src/ast/unless_statement.h" #include "src/ast/unless_statement.h"
#include "src/ast/variable_decl_statement.h" #include "src/ast/variable_decl_statement.h"
@ -1268,5 +1269,164 @@ INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
ast::UnaryDerivative::kDpdy, ast::UnaryDerivative::kDpdy,
ast::UnaryDerivative::kFwidth)); ast::UnaryDerivative::kFwidth));
using UnaryMethodExpressionBoolTest = testing::TestWithParam<ast::UnaryMethod>;
TEST_P(UnaryMethodExpressionBoolTest, Expr_UnaryMethod_Any) {
auto op = GetParam();
ast::type::BoolType bool_type;
ast::type::VectorType vec3(&bool_type, 3);
auto var = std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone,
&vec3);
ast::Module m;
m.AddGlobalVariable(std::move(var));
ast::ExpressionList params;
params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
ast::UnaryMethodExpression exp(op, std::move(params));
Context ctx;
TypeDeterminer td(&ctx);
// Register the variable
EXPECT_TRUE(td.Determine(&m));
EXPECT_TRUE(td.DetermineResultType(&exp));
ASSERT_NE(exp.result_type(), nullptr);
EXPECT_TRUE(exp.result_type()->IsBool());
}
INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
UnaryMethodExpressionBoolTest,
testing::Values(ast::UnaryMethod::kAny,
ast::UnaryMethod::kAll));
using UnaryMethodExpressionVecTest = testing::TestWithParam<ast::UnaryMethod>;
TEST_P(UnaryMethodExpressionVecTest, Expr_UnaryMethod_Bool) {
auto op = GetParam();
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
auto var = std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone,
&vec3);
ast::Module m;
m.AddGlobalVariable(std::move(var));
ast::ExpressionList params;
params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
ast::UnaryMethodExpression exp(op, std::move(params));
Context ctx;
TypeDeterminer td(&ctx);
// Register the variable
EXPECT_TRUE(td.Determine(&m));
EXPECT_TRUE(td.DetermineResultType(&exp));
ASSERT_NE(exp.result_type(), nullptr);
ASSERT_TRUE(exp.result_type()->IsVector());
EXPECT_TRUE(exp.result_type()->AsVector()->type()->IsBool());
EXPECT_EQ(exp.result_type()->AsVector()->size(), 3);
}
TEST_P(UnaryMethodExpressionVecTest, Expr_UnaryMethod_Vec) {
auto op = GetParam();
ast::type::F32Type f32;
auto var =
std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone, &f32);
ast::Module m;
m.AddGlobalVariable(std::move(var));
ast::ExpressionList params;
params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
ast::UnaryMethodExpression exp(op, std::move(params));
Context ctx;
TypeDeterminer td(&ctx);
// Register the variable
EXPECT_TRUE(td.Determine(&m));
EXPECT_TRUE(td.DetermineResultType(&exp));
ASSERT_NE(exp.result_type(), nullptr);
EXPECT_TRUE(exp.result_type()->IsBool());
}
INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
UnaryMethodExpressionVecTest,
testing::Values(ast::UnaryMethod::kIsInf,
ast::UnaryMethod::kIsNan,
ast::UnaryMethod::kIsFinite,
ast::UnaryMethod::kIsNormal));
TEST_F(TypeDeterminerTest, Expr_UnaryMethod_Dot) {
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
auto var = std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone,
&vec3);
ast::Module m;
m.AddGlobalVariable(std::move(var));
ast::ExpressionList params;
params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
ast::UnaryMethodExpression exp(ast::UnaryMethod::kDot, std::move(params));
Context ctx;
TypeDeterminer td(&ctx);
// Register the variable
EXPECT_TRUE(td.Determine(&m));
EXPECT_TRUE(td.DetermineResultType(&exp));
ASSERT_NE(exp.result_type(), nullptr);
EXPECT_TRUE(exp.result_type()->IsF32());
}
TEST_F(TypeDeterminerTest, Expr_UnaryMethod_OuterProduct) {
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
ast::type::VectorType vec2(&f32, 2);
auto var1 =
std::make_unique<ast::Variable>("v3", ast::StorageClass::kNone, &vec3);
auto var2 =
std::make_unique<ast::Variable>("v2", ast::StorageClass::kNone, &vec2);
ast::Module m;
m.AddGlobalVariable(std::move(var1));
m.AddGlobalVariable(std::move(var2));
ast::ExpressionList params;
params.push_back(std::make_unique<ast::IdentifierExpression>("v3"));
params.push_back(std::make_unique<ast::IdentifierExpression>("v2"));
ast::UnaryMethodExpression exp(ast::UnaryMethod::kOuterProduct,
std::move(params));
Context ctx;
TypeDeterminer td(&ctx);
// Register the variable
EXPECT_TRUE(td.Determine(&m));
EXPECT_TRUE(td.DetermineResultType(&exp));
ASSERT_NE(exp.result_type(), nullptr);
ASSERT_TRUE(exp.result_type()->IsMatrix());
auto mat = exp.result_type()->AsMatrix();
EXPECT_TRUE(mat->type()->IsF32());
EXPECT_EQ(mat->rows(), 3);
EXPECT_EQ(mat->columns(), 2);
}
} // namespace } // namespace
} // namespace tint } // namespace tint