Add type determination for the select intrinsic.
This CL adds type determination for a `select` intrinsic. Bug: tint:106 Change-Id: Ie5c051cb42c72ae732579e3064561a4544a90473 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/25380 Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
parent
d71e80b710
commit
16a2ea11d3
|
@ -40,7 +40,7 @@ bool IsFloatClassificationIntrinsic(const std::string& name) {
|
||||||
bool IsIntrinsic(const std::string& name) {
|
bool IsIntrinsic(const std::string& name) {
|
||||||
return IsDerivative(name) || name == "all" || name == "any" ||
|
return IsDerivative(name) || name == "all" || name == "any" ||
|
||||||
IsFloatClassificationIntrinsic(name) || name == "dot" ||
|
IsFloatClassificationIntrinsic(name) || name == "dot" ||
|
||||||
name == "outer_product";
|
name == "outer_product" || name == "select";
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace intrinsic
|
} // namespace intrinsic
|
||||||
|
|
|
@ -540,7 +540,11 @@ bool TypeDeterminer::DetermineIntrinsic(const std::string& name,
|
||||||
auto* bool_type =
|
auto* bool_type =
|
||||||
ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
|
ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
|
||||||
|
|
||||||
auto* param_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded();
|
auto& param = expr->params()[0];
|
||||||
|
if (!DetermineResultType(param.get())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
auto* param_type = param->result_type()->UnwrapPtrIfNeeded();
|
||||||
if (param_type->IsVector()) {
|
if (param_type->IsVector()) {
|
||||||
expr->func()->set_result_type(
|
expr->func()->set_result_type(
|
||||||
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
|
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
|
||||||
|
@ -562,8 +566,15 @@ bool TypeDeterminer::DetermineIntrinsic(const std::string& name,
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto* param0_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded();
|
auto& param0 = expr->params()[0];
|
||||||
auto* param1_type = expr->params()[1]->result_type()->UnwrapPtrIfNeeded();
|
auto& param1 = expr->params()[1];
|
||||||
|
if (!DetermineResultType(param0.get()) ||
|
||||||
|
!DetermineResultType(param1.get())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto* param0_type = param0->result_type()->UnwrapPtrIfNeeded();
|
||||||
|
auto* param1_type = param1->result_type()->UnwrapPtrIfNeeded();
|
||||||
if (!param0_type->IsVector() || !param1_type->IsVector()) {
|
if (!param0_type->IsVector() || !param1_type->IsVector()) {
|
||||||
set_error(expr->source(), "invalid parameter type for outer_product");
|
set_error(expr->source(), "invalid parameter type for outer_product");
|
||||||
return false;
|
return false;
|
||||||
|
@ -575,6 +586,22 @@ bool TypeDeterminer::DetermineIntrinsic(const std::string& name,
|
||||||
param0_type->AsVector()->size(), param1_type->AsVector()->size())));
|
param0_type->AsVector()->size(), param1_type->AsVector()->size())));
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (name == "select") {
|
||||||
|
if (expr->params().size() != 3) {
|
||||||
|
set_error(expr->source(),
|
||||||
|
"incorrect number of parameters for select expected 3 got " +
|
||||||
|
std::to_string(expr->params().size()));
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// The result type must be the same as the type of the parameter.
|
||||||
|
auto& param = expr->params()[0];
|
||||||
|
if (!DetermineResultType(param.get())) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
expr->func()->set_result_type(param->result_type()->UnwrapPtrIfNeeded());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1769,6 +1769,85 @@ TEST_F(TypeDeterminerTest, Intrinsic_Dot) {
|
||||||
EXPECT_TRUE(expr.result_type()->IsF32());
|
EXPECT_TRUE(expr.result_type()->IsF32());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(TypeDeterminerTest, Intrinsic_Select) {
|
||||||
|
ast::type::F32Type f32;
|
||||||
|
ast::type::BoolType bool_type;
|
||||||
|
ast::type::VectorType vec3(&f32, 3);
|
||||||
|
ast::type::VectorType bool_vec3(&bool_type, 3);
|
||||||
|
|
||||||
|
auto var = std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone,
|
||||||
|
&vec3);
|
||||||
|
auto bool_var = std::make_unique<ast::Variable>(
|
||||||
|
"bool_var", ast::StorageClass::kNone, &bool_vec3);
|
||||||
|
mod()->AddGlobalVariable(std::move(var));
|
||||||
|
mod()->AddGlobalVariable(std::move(bool_var));
|
||||||
|
|
||||||
|
ast::ExpressionList call_params;
|
||||||
|
call_params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
|
||||||
|
call_params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
|
||||||
|
call_params.push_back(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("bool_var"));
|
||||||
|
|
||||||
|
ast::CallExpression expr(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("select"),
|
||||||
|
std::move(call_params));
|
||||||
|
|
||||||
|
// Register the variable
|
||||||
|
EXPECT_TRUE(td()->Determine());
|
||||||
|
EXPECT_TRUE(td()->DetermineResultType(&expr)) << td()->error();
|
||||||
|
ASSERT_NE(expr.result_type(), nullptr);
|
||||||
|
EXPECT_TRUE(expr.result_type()->IsVector());
|
||||||
|
EXPECT_EQ(expr.result_type()->AsVector()->size(), 3u);
|
||||||
|
EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TypeDeterminerTest, Intrinsic_Select_TooFewParams) {
|
||||||
|
ast::type::F32Type f32;
|
||||||
|
ast::type::VectorType vec3(&f32, 3);
|
||||||
|
|
||||||
|
auto var =
|
||||||
|
std::make_unique<ast::Variable>("v", ast::StorageClass::kNone, &vec3);
|
||||||
|
mod()->AddGlobalVariable(std::move(var));
|
||||||
|
|
||||||
|
ast::ExpressionList call_params;
|
||||||
|
call_params.push_back(std::make_unique<ast::IdentifierExpression>("v"));
|
||||||
|
|
||||||
|
ast::CallExpression expr(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("select"),
|
||||||
|
std::move(call_params));
|
||||||
|
|
||||||
|
// Register the variable
|
||||||
|
EXPECT_TRUE(td()->Determine());
|
||||||
|
EXPECT_FALSE(td()->DetermineResultType(&expr));
|
||||||
|
EXPECT_EQ(td()->error(),
|
||||||
|
"incorrect number of parameters for select expected 3 got 1");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(TypeDeterminerTest, Intrinsic_Select_TooManyParams) {
|
||||||
|
ast::type::F32Type f32;
|
||||||
|
ast::type::VectorType vec3(&f32, 3);
|
||||||
|
|
||||||
|
auto var =
|
||||||
|
std::make_unique<ast::Variable>("v", ast::StorageClass::kNone, &vec3);
|
||||||
|
mod()->AddGlobalVariable(std::move(var));
|
||||||
|
|
||||||
|
ast::ExpressionList call_params;
|
||||||
|
call_params.push_back(std::make_unique<ast::IdentifierExpression>("v"));
|
||||||
|
call_params.push_back(std::make_unique<ast::IdentifierExpression>("v"));
|
||||||
|
call_params.push_back(std::make_unique<ast::IdentifierExpression>("v"));
|
||||||
|
call_params.push_back(std::make_unique<ast::IdentifierExpression>("v"));
|
||||||
|
|
||||||
|
ast::CallExpression expr(
|
||||||
|
std::make_unique<ast::IdentifierExpression>("select"),
|
||||||
|
std::move(call_params));
|
||||||
|
|
||||||
|
// Register the variable
|
||||||
|
EXPECT_TRUE(td()->Determine());
|
||||||
|
EXPECT_FALSE(td()->DetermineResultType(&expr));
|
||||||
|
EXPECT_EQ(td()->error(),
|
||||||
|
"incorrect number of parameters for select expected 3 got 4");
|
||||||
|
}
|
||||||
|
|
||||||
TEST_F(TypeDeterminerTest, Intrinsic_OuterProduct) {
|
TEST_F(TypeDeterminerTest, Intrinsic_OuterProduct) {
|
||||||
ast::type::F32Type f32;
|
ast::type::F32Type f32;
|
||||||
ast::type::VectorType vec3(&f32, 3);
|
ast::type::VectorType vec3(&f32, 3);
|
||||||
|
|
Loading…
Reference in New Issue