Add type determination for the select intrinsic.
This CL adds type determination for a `select` intrinsic. Bug: tint:106 Change-Id: Ie5c051cb42c72ae732579e3064561a4544a90473 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/25380 Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
parent
d71e80b710
commit
16a2ea11d3
|
@ -40,7 +40,7 @@ bool IsFloatClassificationIntrinsic(const std::string& name) {
|
|||
bool IsIntrinsic(const std::string& name) {
|
||||
return IsDerivative(name) || name == "all" || name == "any" ||
|
||||
IsFloatClassificationIntrinsic(name) || name == "dot" ||
|
||||
name == "outer_product";
|
||||
name == "outer_product" || name == "select";
|
||||
}
|
||||
|
||||
} // namespace intrinsic
|
||||
|
|
|
@ -540,7 +540,11 @@ bool TypeDeterminer::DetermineIntrinsic(const std::string& name,
|
|||
auto* bool_type =
|
||||
ctx_.type_mgr().Get(std::make_unique<ast::type::BoolType>());
|
||||
|
||||
auto* param_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded();
|
||||
auto& param = expr->params()[0];
|
||||
if (!DetermineResultType(param.get())) {
|
||||
return false;
|
||||
}
|
||||
auto* param_type = param->result_type()->UnwrapPtrIfNeeded();
|
||||
if (param_type->IsVector()) {
|
||||
expr->func()->set_result_type(
|
||||
ctx_.type_mgr().Get(std::make_unique<ast::type::VectorType>(
|
||||
|
@ -562,8 +566,15 @@ bool TypeDeterminer::DetermineIntrinsic(const std::string& name,
|
|||
return false;
|
||||
}
|
||||
|
||||
auto* param0_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded();
|
||||
auto* param1_type = expr->params()[1]->result_type()->UnwrapPtrIfNeeded();
|
||||
auto& param0 = expr->params()[0];
|
||||
auto& param1 = expr->params()[1];
|
||||
if (!DetermineResultType(param0.get()) ||
|
||||
!DetermineResultType(param1.get())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto* param0_type = param0->result_type()->UnwrapPtrIfNeeded();
|
||||
auto* param1_type = param1->result_type()->UnwrapPtrIfNeeded();
|
||||
if (!param0_type->IsVector() || !param1_type->IsVector()) {
|
||||
set_error(expr->source(), "invalid parameter type for outer_product");
|
||||
return false;
|
||||
|
@ -575,6 +586,22 @@ bool TypeDeterminer::DetermineIntrinsic(const std::string& name,
|
|||
param0_type->AsVector()->size(), param1_type->AsVector()->size())));
|
||||
return true;
|
||||
}
|
||||
if (name == "select") {
|
||||
if (expr->params().size() != 3) {
|
||||
set_error(expr->source(),
|
||||
"incorrect number of parameters for select expected 3 got " +
|
||||
std::to_string(expr->params().size()));
|
||||
return false;
|
||||
}
|
||||
|
||||
// The result type must be the same as the type of the parameter.
|
||||
auto& param = expr->params()[0];
|
||||
if (!DetermineResultType(param.get())) {
|
||||
return false;
|
||||
}
|
||||
expr->func()->set_result_type(param->result_type()->UnwrapPtrIfNeeded());
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
|
|
@ -1769,6 +1769,85 @@ TEST_F(TypeDeterminerTest, Intrinsic_Dot) {
|
|||
EXPECT_TRUE(expr.result_type()->IsF32());
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Intrinsic_Select) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::BoolType bool_type;
|
||||
ast::type::VectorType vec3(&f32, 3);
|
||||
ast::type::VectorType bool_vec3(&bool_type, 3);
|
||||
|
||||
auto var = std::make_unique<ast::Variable>("my_var", ast::StorageClass::kNone,
|
||||
&vec3);
|
||||
auto bool_var = std::make_unique<ast::Variable>(
|
||||
"bool_var", ast::StorageClass::kNone, &bool_vec3);
|
||||
mod()->AddGlobalVariable(std::move(var));
|
||||
mod()->AddGlobalVariable(std::move(bool_var));
|
||||
|
||||
ast::ExpressionList call_params;
|
||||
call_params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
|
||||
call_params.push_back(std::make_unique<ast::IdentifierExpression>("my_var"));
|
||||
call_params.push_back(
|
||||
std::make_unique<ast::IdentifierExpression>("bool_var"));
|
||||
|
||||
ast::CallExpression expr(
|
||||
std::make_unique<ast::IdentifierExpression>("select"),
|
||||
std::move(call_params));
|
||||
|
||||
// Register the variable
|
||||
EXPECT_TRUE(td()->Determine());
|
||||
EXPECT_TRUE(td()->DetermineResultType(&expr)) << td()->error();
|
||||
ASSERT_NE(expr.result_type(), nullptr);
|
||||
EXPECT_TRUE(expr.result_type()->IsVector());
|
||||
EXPECT_EQ(expr.result_type()->AsVector()->size(), 3u);
|
||||
EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32());
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Intrinsic_Select_TooFewParams) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::VectorType vec3(&f32, 3);
|
||||
|
||||
auto var =
|
||||
std::make_unique<ast::Variable>("v", ast::StorageClass::kNone, &vec3);
|
||||
mod()->AddGlobalVariable(std::move(var));
|
||||
|
||||
ast::ExpressionList call_params;
|
||||
call_params.push_back(std::make_unique<ast::IdentifierExpression>("v"));
|
||||
|
||||
ast::CallExpression expr(
|
||||
std::make_unique<ast::IdentifierExpression>("select"),
|
||||
std::move(call_params));
|
||||
|
||||
// Register the variable
|
||||
EXPECT_TRUE(td()->Determine());
|
||||
EXPECT_FALSE(td()->DetermineResultType(&expr));
|
||||
EXPECT_EQ(td()->error(),
|
||||
"incorrect number of parameters for select expected 3 got 1");
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Intrinsic_Select_TooManyParams) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::VectorType vec3(&f32, 3);
|
||||
|
||||
auto var =
|
||||
std::make_unique<ast::Variable>("v", ast::StorageClass::kNone, &vec3);
|
||||
mod()->AddGlobalVariable(std::move(var));
|
||||
|
||||
ast::ExpressionList call_params;
|
||||
call_params.push_back(std::make_unique<ast::IdentifierExpression>("v"));
|
||||
call_params.push_back(std::make_unique<ast::IdentifierExpression>("v"));
|
||||
call_params.push_back(std::make_unique<ast::IdentifierExpression>("v"));
|
||||
call_params.push_back(std::make_unique<ast::IdentifierExpression>("v"));
|
||||
|
||||
ast::CallExpression expr(
|
||||
std::make_unique<ast::IdentifierExpression>("select"),
|
||||
std::move(call_params));
|
||||
|
||||
// Register the variable
|
||||
EXPECT_TRUE(td()->Determine());
|
||||
EXPECT_FALSE(td()->DetermineResultType(&expr));
|
||||
EXPECT_EQ(td()->error(),
|
||||
"incorrect number of parameters for select expected 3 got 4");
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, Intrinsic_OuterProduct) {
|
||||
ast::type::F32Type f32;
|
||||
ast::type::VectorType vec3(&f32, 3);
|
||||
|
|
Loading…
Reference in New Issue