Add type determination for the select intrinsic.

This CL adds type determination for a `select` intrinsic.

Bug: tint:106
Change-Id: Ie5c051cb42c72ae732579e3064561a4544a90473
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/25380
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair 2020-07-21 17:44:44 +00:00 committed by David Neto
parent d71e80b710
commit 16a2ea11d3
3 changed files with 110 additions and 4 deletions

View File

@ -40,7 +40,7 @@ bool IsFloatClassificationIntrinsic(const std::string& name) {
bool IsIntrinsic(const std::string& name) { bool IsIntrinsic(const std::string& name) {
return IsDerivative(name) || name == "all" || name == "any" || return IsDerivative(name) || name == "all" || name == "any" ||
IsFloatClassificationIntrinsic(name) || name == "dot" || IsFloatClassificationIntrinsic(name) || name == "dot" ||
name == "outer_product"; name == "outer_product" || name == "select";
} }
} // namespace intrinsic } // namespace intrinsic

View File

@ -540,7 +540,11 @@ bool TypeDeterminer::DetermineIntrinsic(const std::string& name,
auto* bool_type = auto* bool_type =
ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>()); ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
auto* param_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded(); auto& param = expr->params()[0];
if (!DetermineResultType(param.get())) {
return false;
}
auto* param_type = param->result_type()->UnwrapPtrIfNeeded();
if (param_type->IsVector()) { if (param_type->IsVector()) {
expr->func()->set_result_type( expr->func()->set_result_type(
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>( ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
@ -562,8 +566,15 @@ bool TypeDeterminer::DetermineIntrinsic(const std::string& name,
return false; return false;
} }
auto* param0_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded(); auto& param0 = expr->params()[0];
auto* param1_type = expr->params()[1]->result_type()->UnwrapPtrIfNeeded(); auto& param1 = expr->params()[1];
if (!DetermineResultType(param0.get()) ||
!DetermineResultType(param1.get())) {
return false;
}
auto* param0_type = param0->result_type()->UnwrapPtrIfNeeded();
auto* param1_type = param1->result_type()->UnwrapPtrIfNeeded();
if (!param0_type->IsVector() || !param1_type->IsVector()) { if (!param0_type->IsVector() || !param1_type->IsVector()) {
set_error(expr->source(), "invalid parameter type for outer_product"); set_error(expr->source(), "invalid parameter type for outer_product");
return false; return false;
@ -575,6 +586,22 @@ bool TypeDeterminer::DetermineIntrinsic(const std::string& name,
param0_type->AsVector()->size(), param1_type->AsVector()->size()))); param0_type->AsVector()->size(), param1_type->AsVector()->size())));
return true; return true;
} }
if (name == "select") {
if (expr->params().size() != 3) {
set_error(expr->source(),
"incorrect number of parameters for select expected 3 got " +
std::to_string(expr->params().size()));
return false;
}
// The result type must be the same as the type of the parameter.
auto& param = expr->params()[0];
if (!DetermineResultType(param.get())) {
return false;
}
expr->func()->set_result_type(param->result_type()->UnwrapPtrIfNeeded());
return true;
}
return false; return false;
} }

View File

@ -1769,6 +1769,85 @@ TEST_F(TypeDeterminerTest, Intrinsic_Dot) {
EXPECT_TRUE(expr.result_type()->IsF32()); EXPECT_TRUE(expr.result_type()->IsF32());
} }
TEST_F(TypeDeterminerTest, Intrinsic_Select) {
ast::type::F32Type f32;
ast::type::BoolType bool_type;
ast::type::VectorType vec3(&f32, 3);
ast::type::VectorType bool_vec3(&bool_type, 3);
auto var = std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone,
&vec3);
auto bool_var = std::make_unique<ast::Variable>(
"bool_var", ast::StorageClass::kNone, &bool_vec3);
mod()->AddGlobalVariable(std::move(var));
mod()->AddGlobalVariable(std::move(bool_var));
ast::ExpressionList call_params;
call_params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
call_params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
call_params.push_back(
std::make_unique<ast::IdentifierExpression>("bool_var"));
ast::CallExpression expr(
std::make_unique<ast::IdentifierExpression>("select"),
std::move(call_params));
// Register the variable
EXPECT_TRUE(td()->Determine());
EXPECT_TRUE(td()->DetermineResultType(&expr)) << td()->error();
ASSERT_NE(expr.result_type(), nullptr);
EXPECT_TRUE(expr.result_type()->IsVector());
EXPECT_EQ(expr.result_type()->AsVector()->size(), 3u);
EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
}
TEST_F(TypeDeterminerTest, Intrinsic_Select_TooFewParams) {
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
auto var =
std::make_unique<ast::Variable>("v", ast::StorageClass::kNone, &vec3);
mod()->AddGlobalVariable(std::move(var));
ast::ExpressionList call_params;
call_params.push_back(std::make_unique<ast::IdentifierExpression>("v"));
ast::CallExpression expr(
std::make_unique<ast::IdentifierExpression>("select"),
std::move(call_params));
// Register the variable
EXPECT_TRUE(td()->Determine());
EXPECT_FALSE(td()->DetermineResultType(&expr));
EXPECT_EQ(td()->error(),
"incorrect number of parameters for select expected 3 got 1");
}
TEST_F(TypeDeterminerTest, Intrinsic_Select_TooManyParams) {
ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3);
auto var =
std::make_unique<ast::Variable>("v", ast::StorageClass::kNone, &vec3);
mod()->AddGlobalVariable(std::move(var));
ast::ExpressionList call_params;
call_params.push_back(std::make_unique<ast::IdentifierExpression>("v"));
call_params.push_back(std::make_unique<ast::IdentifierExpression>("v"));
call_params.push_back(std::make_unique<ast::IdentifierExpression>("v"));
call_params.push_back(std::make_unique<ast::IdentifierExpression>("v"));
ast::CallExpression expr(
std::make_unique<ast::IdentifierExpression>("select"),
std::move(call_params));
// Register the variable
EXPECT_TRUE(td()->Determine());
EXPECT_FALSE(td()->DetermineResultType(&expr));
EXPECT_EQ(td()->error(),
"incorrect number of parameters for select expected 3 got 4");
}
TEST_F(TypeDeterminerTest, Intrinsic_OuterProduct) { TEST_F(TypeDeterminerTest, Intrinsic_OuterProduct) {
ast::type::F32Type f32; ast::type::F32Type f32;
ast::type::VectorType vec3(&f32, 3); ast::type::VectorType vec3(&f32, 3);