diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 1ad79b90eb..e224ebd317 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -1250,6 +1250,8 @@ uint32_t Builder::GenerateIntrinsic(const std::string& name, op = spv::Op::OpAny; } else if (name == "all") { op = spv::Op::OpAll; + } else if (name == "is_nan") { + op = spv::Op::OpIsNan; } if (op == spv::Op::OpNop) { error_ = "unable to determine operator for: " + name; diff --git a/src/writer/spirv/builder_intrinsic_test.cc b/src/writer/spirv/builder_intrinsic_test.cc index d819ad554f..ba40b10096 100644 --- a/src/writer/spirv/builder_intrinsic_test.cc +++ b/src/writer/spirv/builder_intrinsic_test.cc @@ -18,6 +18,7 @@ #include "src/ast/call_expression.h" #include "src/ast/identifier_expression.h" #include "src/ast/type/bool_type.h" +#include "src/ast/type/f32_type.h" #include "src/ast/type/vector_type.h" #include "src/ast/variable.h" #include "src/context.h" @@ -32,7 +33,19 @@ namespace { using BuilderTest = testing::Test; -TEST_F(BuilderTest, Call_Any) { +struct IntrinsicData { + std::string name; + std::string op; +}; +inline std::ostream& operator<<(std::ostream& out, IntrinsicData data) { + out << data.name; + return out; +} + +using IntrinsicBoolTest = testing::TestWithParam; +TEST_P(IntrinsicBoolTest, Call_Bool) { + auto param = GetParam(); + ast::type::BoolType bool_type; ast::type::VectorType vec3(&bool_type, 3); @@ -41,8 +54,9 @@ TEST_F(BuilderTest, Call_Any) { ast::ExpressionList params; params.push_back(std::make_unique("v")); - ast::CallExpression expr(std::make_unique("any"), - std::move(params)); + ast::CallExpression expr( + std::make_unique(param.name), + std::move(params)); Context ctx; ast::Module mod; @@ -64,21 +78,67 @@ TEST_F(BuilderTest, Call_Any) { )"); EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), R"(%7 = OpLoad %3 %1 -%6 = OpAny %4 %7 +%6 = )" + param.op + + " %4 %7\n"); +} +INSTANTIATE_TEST_SUITE_P(BuilderTest, + IntrinsicBoolTest, + testing::Values(IntrinsicData{"any", "OpAny"}, + IntrinsicData{"all", "OpAll"})); + +using IntrinsicFloatTest = testing::TestWithParam; +TEST_P(IntrinsicFloatTest, Call_Float_Scalar) { + auto param = GetParam(); + + ast::type::F32Type f32; + + auto var = + std::make_unique("v", ast::StorageClass::kPrivate, &f32); + + ast::ExpressionList params; + params.push_back(std::make_unique("v")); + ast::CallExpression expr( + std::make_unique(param.name), + std::move(params)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(var.get()); + + ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); + + Builder b(&mod); + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); + + EXPECT_EQ(b.GenerateCallExpression(&expr), 5u) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32 +%2 = OpTypePointer Private %3 +%4 = OpConstantNull %3 +%1 = OpVariable %2 Private %4 +%6 = OpTypeBool )"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%7 = OpLoad %3 %1 +%5 = )" + param.op + + " %6 %7\n"); } -TEST_F(BuilderTest, Call_All) { - ast::type::BoolType bool_type; - ast::type::VectorType vec3(&bool_type, 3); +TEST_P(IntrinsicFloatTest, Call_Float_Vector) { + auto param = GetParam(); + + ast::type::F32Type f32; + ast::type::VectorType vec3(&f32, 3); auto var = std::make_unique("v", ast::StorageClass::kPrivate, &vec3); ast::ExpressionList params; params.push_back(std::make_unique("v")); - ast::CallExpression expr(std::make_unique("all"), - std::move(params)); + ast::CallExpression expr( + std::make_unique(param.name), + std::move(params)); Context ctx; ast::Module mod; @@ -92,17 +152,23 @@ TEST_F(BuilderTest, Call_All) { ASSERT_TRUE(b.GenerateGlobalVariable(var.get())) << b.error(); EXPECT_EQ(b.GenerateCallExpression(&expr), 6u) << b.error(); - EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeBool + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32 %3 = OpTypeVector %4 3 %2 = OpTypePointer Private %3 %5 = OpConstantNull %3 %1 = OpVariable %2 Private %5 +%8 = OpTypeBool +%7 = OpTypeVector %8 3 )"); EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), - R"(%7 = OpLoad %3 %1 -%6 = OpAll %4 %7 -)"); + R"(%9 = OpLoad %3 %1 +%6 = )" + param.op + + " %7 %9\n"); } +INSTANTIATE_TEST_SUITE_P(BuilderTest, + IntrinsicFloatTest, + testing::Values(IntrinsicData{"is_nan", "OpIsNan"})); + } // namespace } // namespace spirv } // namespace writer