Add type determination for unary method.
This CL adds the type determination for the unary method expression. Bug: tint:5 Change-Id: I9f94a79b9715cf74e37c74eb1a612ca84b3c241f Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/18849 Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
parent
b173056fca
commit
8dcfd108a9
|
@ -36,11 +36,13 @@
|
||||||
#include "src/ast/switch_statement.h"
|
#include "src/ast/switch_statement.h"
|
||||||
#include "src/ast/type/array_type.h"
|
#include "src/ast/type/array_type.h"
|
||||||
#include "src/ast/type/bool_type.h"
|
#include "src/ast/type/bool_type.h"
|
||||||
|
#include "src/ast/type/f32_type.h"
|
||||||
#include "src/ast/type/matrix_type.h"
|
#include "src/ast/type/matrix_type.h"
|
||||||
#include "src/ast/type/struct_type.h"
|
#include "src/ast/type/struct_type.h"
|
||||||
#include "src/ast/type/vector_type.h"
|
#include "src/ast/type/vector_type.h"
|
||||||
#include "src/ast/type_constructor_expression.h"
|
#include "src/ast/type_constructor_expression.h"
|
||||||
#include "src/ast/unary_derivative_expression.h"
|
#include "src/ast/unary_derivative_expression.h"
|
||||||
|
#include "src/ast/unary_method_expression.h"
|
||||||
#include "src/ast/unless_statement.h"
|
#include "src/ast/unless_statement.h"
|
||||||
#include "src/ast/variable_decl_statement.h"
|
#include "src/ast/variable_decl_statement.h"
|
||||||
|
|
||||||
|
@ -215,6 +217,9 @@ bool TypeDeterminer::DetermineResultType(ast::Expression* expr) {
|
||||||
if (expr->IsUnaryDerivative()) {
|
if (expr->IsUnaryDerivative()) {
|
||||||
return DetermineUnaryDerivative(expr->AsUnaryDerivative());
|
return DetermineUnaryDerivative(expr->AsUnaryDerivative());
|
||||||
}
|
}
|
||||||
|
if (expr->IsUnaryMethod()) {
|
||||||
|
return DetermineUnaryMethod(expr->AsUnaryMethod());
|
||||||
|
}
|
||||||
|
|
||||||
error_ = "unknown expression for type determination";
|
error_ = "unknown expression for type determination";
|
||||||
return false;
|
return false;
|
||||||
|
@ -415,4 +420,66 @@ bool TypeDeterminer::DetermineUnaryDerivative(
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool TypeDeterminer::DetermineUnaryMethod(ast::UnaryMethodExpression* expr) {
|
||||||
|
for (const auto& param : expr->params()) {
|
||||||
|
if (!DetermineResultType(param.get())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (expr->op()) {
|
||||||
|
case ast::UnaryMethod::kAny:
|
||||||
|
case ast::UnaryMethod::kAll: {
|
||||||
|
expr->set_result_type(
|
||||||
|
ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>()));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case ast::UnaryMethod::kIsNan:
|
||||||
|
case ast::UnaryMethod::kIsInf:
|
||||||
|
case ast::UnaryMethod::kIsFinite:
|
||||||
|
case ast::UnaryMethod::kIsNormal: {
|
||||||
|
if (expr->params().empty()) {
|
||||||
|
error_ = "incorrect number of parameters";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto bool_type =
|
||||||
|
ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
|
||||||
|
auto param_type = expr->params()[0]->result_type();
|
||||||
|
if (param_type->IsVector()) {
|
||||||
|
expr->set_result_type(
|
||||||
|
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
|
||||||
|
bool_type, param_type->AsVector()->size())));
|
||||||
|
} else {
|
||||||
|
expr->set_result_type(bool_type);
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case ast::UnaryMethod::kDot: {
|
||||||
|
expr->set_result_type(
|
||||||
|
ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>()));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case ast::UnaryMethod::kOuterProduct: {
|
||||||
|
if (expr->params().size() != 2) {
|
||||||
|
error_ = "incorrect number of parameters for outer product";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto param0_type = expr->params()[0]->result_type();
|
||||||
|
auto param1_type = expr->params()[1]->result_type();
|
||||||
|
if (!param0_type->IsVector() || !param1_type->IsVector()) {
|
||||||
|
error_ = "invalid parameter type for outer product";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
expr->set_result_type(
|
||||||
|
ctx_.type_mgr().Get(std::make_unique<ast::type::MatrixType>(
|
||||||
|
ctx_.type_mgr().Get(std::make_unique<ast::type::F32Type>()),
|
||||||
|
param0_type->AsVector()->size(),
|
||||||
|
param1_type->AsVector()->size())));
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace tint
|
} // namespace tint
|
||||||
|
|
|
@ -35,6 +35,7 @@ class IdentifierExpression;
|
||||||
class MemberAccessorExpression;
|
class MemberAccessorExpression;
|
||||||
class RelationalExpression;
|
class RelationalExpression;
|
||||||
class UnaryDerivativeExpression;
|
class UnaryDerivativeExpression;
|
||||||
|
class UnaryMethodExpression;
|
||||||
class Variable;
|
class Variable;
|
||||||
|
|
||||||
} // namespace ast
|
} // namespace ast
|
||||||
|
@ -85,6 +86,7 @@ class TypeDeterminer {
|
||||||
bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr);
|
bool DetermineMemberAccessor(ast::MemberAccessorExpression* expr);
|
||||||
bool DetermineRelational(ast::RelationalExpression* expr);
|
bool DetermineRelational(ast::RelationalExpression* expr);
|
||||||
bool DetermineUnaryDerivative(ast::UnaryDerivativeExpression* expr);
|
bool DetermineUnaryDerivative(ast::UnaryDerivativeExpression* expr);
|
||||||
|
bool DetermineUnaryMethod(ast::UnaryMethodExpression* expr);
|
||||||
|
|
||||||
Context& ctx_;
|
Context& ctx_;
|
||||||
std::string error_;
|
std::string error_;
|
||||||
|
|
|
@ -50,6 +50,7 @@
|
||||||
#include "src/ast/type/vector_type.h"
|
#include "src/ast/type/vector_type.h"
|
||||||
#include "src/ast/type_constructor_expression.h"
|
#include "src/ast/type_constructor_expression.h"
|
||||||
#include "src/ast/unary_derivative_expression.h"
|
#include "src/ast/unary_derivative_expression.h"
|
||||||
|
#include "src/ast/unary_method_expression.h"
|
||||||
#include "src/ast/unless_statement.h"
|
#include "src/ast/unless_statement.h"
|
||||||
#include "src/ast/variable_decl_statement.h"
|
#include "src/ast/variable_decl_statement.h"
|
||||||
|
|
||||||
|
@ -1268,5 +1269,164 @@ INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
|
||||||
ast::UnaryDerivative::kDpdy,
|
ast::UnaryDerivative::kDpdy,
|
||||||
ast::UnaryDerivative::kFwidth));
|
ast::UnaryDerivative::kFwidth));
|
||||||
|
|
||||||
|
using UnaryMethodExpressionBoolTest = testing::TestWithParam<ast::UnaryMethod>;
|
||||||
|
TEST_P(UnaryMethodExpressionBoolTest, Expr_UnaryMethod_Any) {
|
||||||
|
auto op = GetParam();
|
||||||
|
|
||||||
|
ast::type::BoolType bool_type;
|
||||||
|
ast::type::VectorType vec3(&bool_type, 3);
|
||||||
|
|
||||||
|
auto var = std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone,
|
||||||
|
&vec3);
|
||||||
|
|
||||||
|
ast::Module m;
|
||||||
|
m.AddGlobalVariable(std::move(var));
|
||||||
|
|
||||||
|
ast::ExpressionList params;
|
||||||
|
params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
|
||||||
|
|
||||||
|
ast::UnaryMethodExpression exp(op, std::move(params));
|
||||||
|
|
||||||
|
Context ctx;
|
||||||
|
TypeDeterminer td(&ctx);
|
||||||
|
|
||||||
|
// Register the variable
|
||||||
|
EXPECT_TRUE(td.Determine(&m));
|
||||||
|
|
||||||
|
EXPECT_TRUE(td.DetermineResultType(&exp));
|
||||||
|
ASSERT_NE(exp.result_type(), nullptr);
|
||||||
|
EXPECT_TRUE(exp.result_type()->IsBool());
|
||||||
|
}
|
||||||
|
INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
|
||||||
|
UnaryMethodExpressionBoolTest,
|
||||||
|
testing::Values(ast::UnaryMethod::kAny,
|
||||||
|
ast::UnaryMethod::kAll));
|
||||||
|
|
||||||
|
using UnaryMethodExpressionVecTest = testing::TestWithParam<ast::UnaryMethod>;
|
||||||
|
TEST_P(UnaryMethodExpressionVecTest, Expr_UnaryMethod_Bool) {
|
||||||
|
auto op = GetParam();
|
||||||
|
|
||||||
|
ast::type::F32Type f32;
|
||||||
|
ast::type::VectorType vec3(&f32, 3);
|
||||||
|
|
||||||
|
auto var = std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone,
|
||||||
|
&vec3);
|
||||||
|
|
||||||
|
ast::Module m;
|
||||||
|
m.AddGlobalVariable(std::move(var));
|
||||||
|
|
||||||
|
ast::ExpressionList params;
|
||||||
|
params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
|
||||||
|
|
||||||
|
ast::UnaryMethodExpression exp(op, std::move(params));
|
||||||
|
|
||||||
|
Context ctx;
|
||||||
|
TypeDeterminer td(&ctx);
|
||||||
|
|
||||||
|
// Register the variable
|
||||||
|
EXPECT_TRUE(td.Determine(&m));
|
||||||
|
|
||||||
|
EXPECT_TRUE(td.DetermineResultType(&exp));
|
||||||
|
ASSERT_NE(exp.result_type(), nullptr);
|
||||||
|
ASSERT_TRUE(exp.result_type()->IsVector());
|
||||||
|
EXPECT_TRUE(exp.result_type()->AsVector()->type()->IsBool());
|
||||||
|
EXPECT_EQ(exp.result_type()->AsVector()->size(), 3);
|
||||||
|
}
|
||||||
|
TEST_P(UnaryMethodExpressionVecTest, Expr_UnaryMethod_Vec) {
|
||||||
|
auto op = GetParam();
|
||||||
|
|
||||||
|
ast::type::F32Type f32;
|
||||||
|
|
||||||
|
auto var =
|
||||||
|
std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone, &f32);
|
||||||
|
|
||||||
|
ast::Module m;
|
||||||
|
m.AddGlobalVariable(std::move(var));
|
||||||
|
|
||||||
|
ast::ExpressionList params;
|
||||||
|
params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
|
||||||
|
|
||||||
|
ast::UnaryMethodExpression exp(op, std::move(params));
|
||||||
|
|
||||||
|
Context ctx;
|
||||||
|
TypeDeterminer td(&ctx);
|
||||||
|
|
||||||
|
// Register the variable
|
||||||
|
EXPECT_TRUE(td.Determine(&m));
|
||||||
|
|
||||||
|
EXPECT_TRUE(td.DetermineResultType(&exp));
|
||||||
|
ASSERT_NE(exp.result_type(), nullptr);
|
||||||
|
EXPECT_TRUE(exp.result_type()->IsBool());
|
||||||
|
}
|
||||||
|
INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest,
|
||||||
|
UnaryMethodExpressionVecTest,
|
||||||
|
testing::Values(ast::UnaryMethod::kIsInf,
|
||||||
|
ast::UnaryMethod::kIsNan,
|
||||||
|
ast::UnaryMethod::kIsFinite,
|
||||||
|
ast::UnaryMethod::kIsNormal));
|
||||||
|
|
||||||
|
TEST_F(TypeDeterminerTest, Expr_UnaryMethod_Dot) {
|
||||||
|
ast::type::F32Type f32;
|
||||||
|
ast::type::VectorType vec3(&f32, 3);
|
||||||
|
|
||||||
|
auto var = std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone,
|
||||||
|
&vec3);
|
||||||
|
|
||||||
|
ast::Module m;
|
||||||
|
m.AddGlobalVariable(std::move(var));
|
||||||
|
|
||||||
|
ast::ExpressionList params;
|
||||||
|
params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
|
||||||
|
params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
|
||||||
|
|
||||||
|
ast::UnaryMethodExpression exp(ast::UnaryMethod::kDot, std::move(params));
|
||||||
|
|
||||||
|
Context ctx;
|
||||||
|
TypeDeterminer td(&ctx);
|
||||||
|
|
||||||
|
// Register the variable
|
||||||
|
EXPECT_TRUE(td.Determine(&m));
|
||||||
|
|
||||||
|
EXPECT_TRUE(td.DetermineResultType(&exp));
|
||||||
|
ASSERT_NE(exp.result_type(), nullptr);
|
||||||
|
EXPECT_TRUE(exp.result_type()->IsF32());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TypeDeterminerTest, Expr_UnaryMethod_OuterProduct) {
|
||||||
|
ast::type::F32Type f32;
|
||||||
|
ast::type::VectorType vec3(&f32, 3);
|
||||||
|
ast::type::VectorType vec2(&f32, 2);
|
||||||
|
|
||||||
|
auto var1 =
|
||||||
|
std::make_unique<ast::Variable>("v3", ast::StorageClass::kNone, &vec3);
|
||||||
|
auto var2 =
|
||||||
|
std::make_unique<ast::Variable>("v2", ast::StorageClass::kNone, &vec2);
|
||||||
|
|
||||||
|
ast::Module m;
|
||||||
|
m.AddGlobalVariable(std::move(var1));
|
||||||
|
m.AddGlobalVariable(std::move(var2));
|
||||||
|
|
||||||
|
ast::ExpressionList params;
|
||||||
|
params.push_back(std::make_unique<ast::IdentifierExpression>("v3"));
|
||||||
|
params.push_back(std::make_unique<ast::IdentifierExpression>("v2"));
|
||||||
|
|
||||||
|
ast::UnaryMethodExpression exp(ast::UnaryMethod::kOuterProduct,
|
||||||
|
std::move(params));
|
||||||
|
|
||||||
|
Context ctx;
|
||||||
|
TypeDeterminer td(&ctx);
|
||||||
|
|
||||||
|
// Register the variable
|
||||||
|
EXPECT_TRUE(td.Determine(&m));
|
||||||
|
|
||||||
|
EXPECT_TRUE(td.DetermineResultType(&exp));
|
||||||
|
ASSERT_NE(exp.result_type(), nullptr);
|
||||||
|
ASSERT_TRUE(exp.result_type()->IsMatrix());
|
||||||
|
auto mat = exp.result_type()->AsMatrix();
|
||||||
|
EXPECT_TRUE(mat->type()->IsF32());
|
||||||
|
EXPECT_EQ(mat->rows(), 3);
|
||||||
|
EXPECT_EQ(mat->columns(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tint
|
} // namespace tint
|
||||||
|
|
Loading…
Reference in New Issue