diff --git a/src/ast/intrinsic.cc b/src/ast/intrinsic.cc index 73cefac2ee..226581b596 100644 --- a/src/ast/intrinsic.cc +++ b/src/ast/intrinsic.cc @@ -40,7 +40,7 @@ bool IsFloatClassificationIntrinsic(const std::string& name) { bool IsIntrinsic(const std::string& name) { return IsDerivative(name) || name == "all" || name == "any" || IsFloatClassificationIntrinsic(name) || name == "dot" || - name == "outer_product"; + name == "outer_product" || name == "select"; } } // namespace intrinsic diff --git a/src/type_determiner.cc b/src/type_determiner.cc index 5d2c983b33..d4a597e45d 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -540,7 +540,11 @@ bool TypeDeterminer::DetermineIntrinsic(const std::string& name, auto* bool_type = ctx_.type_mgr().Get(std::make_unique()); - auto* param_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded(); + auto& param = expr->params()[0]; + if (!DetermineResultType(param.get())) { + return false; + } + auto* param_type = param->result_type()->UnwrapPtrIfNeeded(); if (param_type->IsVector()) { expr->func()->set_result_type( ctx_.type_mgr().Get(std::make_unique( @@ -562,8 +566,15 @@ bool TypeDeterminer::DetermineIntrinsic(const std::string& name, return false; } - auto* param0_type = expr->params()[0]->result_type()->UnwrapPtrIfNeeded(); - auto* param1_type = expr->params()[1]->result_type()->UnwrapPtrIfNeeded(); + auto& param0 = expr->params()[0]; + auto& param1 = expr->params()[1]; + if (!DetermineResultType(param0.get()) || + !DetermineResultType(param1.get())) { + return false; + } + + auto* param0_type = param0->result_type()->UnwrapPtrIfNeeded(); + auto* param1_type = param1->result_type()->UnwrapPtrIfNeeded(); if (!param0_type->IsVector() || !param1_type->IsVector()) { set_error(expr->source(), "invalid parameter type for outer_product"); return false; @@ -575,6 +586,22 @@ bool TypeDeterminer::DetermineIntrinsic(const std::string& name, param0_type->AsVector()->size(), param1_type->AsVector()->size()))); return true; } + if (name == "select") { + if (expr->params().size() != 3) { + set_error(expr->source(), + "incorrect number of parameters for select expected 3 got " + + std::to_string(expr->params().size())); + return false; + } + + // The result type must be the same as the type of the parameter. + auto& param = expr->params()[0]; + if (!DetermineResultType(param.get())) { + return false; + } + expr->func()->set_result_type(param->result_type()->UnwrapPtrIfNeeded()); + return true; + } return false; } diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index b46cfb7d72..d238ab6f74 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -1769,6 +1769,85 @@ TEST_F(TypeDeterminerTest, Intrinsic_Dot) { EXPECT_TRUE(expr.result_type()->IsF32()); } +TEST_F(TypeDeterminerTest, Intrinsic_Select) { + ast::type::F32Type f32; + ast::type::BoolType bool_type; + ast::type::VectorType vec3(&f32, 3); + ast::type::VectorType bool_vec3(&bool_type, 3); + + auto var = std::make_unique("my_var", ast::StorageClass::kNone, + &vec3); + auto bool_var = std::make_unique( + "bool_var", ast::StorageClass::kNone, &bool_vec3); + mod()->AddGlobalVariable(std::move(var)); + mod()->AddGlobalVariable(std::move(bool_var)); + + ast::ExpressionList call_params; + call_params.push_back(std::make_unique("my_var")); + call_params.push_back(std::make_unique("my_var")); + call_params.push_back( + std::make_unique("bool_var")); + + ast::CallExpression expr( + std::make_unique("select"), + std::move(call_params)); + + // Register the variable + EXPECT_TRUE(td()->Determine()); + EXPECT_TRUE(td()->DetermineResultType(&expr)) << td()->error(); + ASSERT_NE(expr.result_type(), nullptr); + EXPECT_TRUE(expr.result_type()->IsVector()); + EXPECT_EQ(expr.result_type()->AsVector()->size(), 3u); + EXPECT_TRUE(expr.result_type()->AsVector()->type()->IsF32()); +} + +TEST_F(TypeDeterminerTest, Intrinsic_Select_TooFewParams) { + ast::type::F32Type f32; + ast::type::VectorType vec3(&f32, 3); + + auto var = + std::make_unique("v", ast::StorageClass::kNone, &vec3); + mod()->AddGlobalVariable(std::move(var)); + + ast::ExpressionList call_params; + call_params.push_back(std::make_unique("v")); + + ast::CallExpression expr( + std::make_unique("select"), + std::move(call_params)); + + // Register the variable + EXPECT_TRUE(td()->Determine()); + EXPECT_FALSE(td()->DetermineResultType(&expr)); + EXPECT_EQ(td()->error(), + "incorrect number of parameters for select expected 3 got 1"); +} + +TEST_F(TypeDeterminerTest, Intrinsic_Select_TooManyParams) { + ast::type::F32Type f32; + ast::type::VectorType vec3(&f32, 3); + + auto var = + std::make_unique("v", ast::StorageClass::kNone, &vec3); + mod()->AddGlobalVariable(std::move(var)); + + ast::ExpressionList call_params; + call_params.push_back(std::make_unique("v")); + call_params.push_back(std::make_unique("v")); + call_params.push_back(std::make_unique("v")); + call_params.push_back(std::make_unique("v")); + + ast::CallExpression expr( + std::make_unique("select"), + std::move(call_params)); + + // Register the variable + EXPECT_TRUE(td()->Determine()); + EXPECT_FALSE(td()->DetermineResultType(&expr)); + EXPECT_EQ(td()->error(), + "incorrect number of parameters for select expected 3 got 4"); +} + TEST_F(TypeDeterminerTest, Intrinsic_OuterProduct) { ast::type::F32Type f32; ast::type::VectorType vec3(&f32, 3);