diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index a7af47f138..7b51c944d8 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -1971,6 +1971,22 @@ uint32_t Builder::GenerateIntrinsic(ast::CallExpression* call, return result_id; } + // Generates the SPIR-V ID for the expression for the indexed call parameter, + // and loads it if necessary. Returns 0 on error. + auto get_param_as_value_id = [&](size_t i) -> uint32_t { + auto* arg = call->params()[i]; + auto& param = intrinsic->Parameters()[i]; + auto val_id = GenerateExpression(arg); + if (val_id == 0) { + return 0; + } + + if (!param.type->Is()) { + val_id = GenerateLoadIfNeeded(TypeOf(arg), val_id); + } + return val_id; + }; + OperandList params = {Operand::Int(result_type_id), result}; spv::Op op = spv::Op::OpNop; @@ -2054,6 +2070,32 @@ uint32_t Builder::GenerateIntrinsic(ast::CallExpression* call, case IntrinsicType::kIsNan: op = spv::Op::OpIsNan; break; + case IntrinsicType::kIsFinite: { + // Implemented as: not(IsInf or IsNan) + auto val_id = get_param_as_value_id(0); + if (!val_id) { + return 0; + } + auto inf_result = result_op(); + auto nan_result = result_op(); + auto or_result = result_op(); + if (push_function_inst(spv::Op::OpIsInf, + {Operand::Int(result_type_id), inf_result, + Operand::Int(val_id)}) && + push_function_inst(spv::Op::OpIsNan, + {Operand::Int(result_type_id), nan_result, + Operand::Int(val_id)}) && + push_function_inst(spv::Op::OpLogicalOr, + {Operand::Int(result_type_id), or_result, + Operand::Int(inf_result.to_i()), + Operand::Int(nan_result.to_i())}) && + push_function_inst(spv::Op::OpLogicalNot, + {Operand::Int(result_type_id), result, + Operand::Int(or_result.to_i())})) { + return result_id; + } + return 0; + } case IntrinsicType::kReverseBits: op = spv::Op::OpBitReverse; break; @@ -2090,18 +2132,11 @@ uint32_t Builder::GenerateIntrinsic(ast::CallExpression* call, } for (size_t i = 0; i < call->params().size(); i++) { - auto* arg = call->params()[i]; - auto& param = intrinsic->Parameters()[i]; - auto val_id = GenerateExpression(arg); - if (val_id == 0) { - return false; + if (auto val_id = get_param_as_value_id(i)) { + params.emplace_back(Operand::Int(val_id)); + } else { + return 0; } - - if (!param.type->Is()) { - val_id = GenerateLoadIfNeeded(TypeOf(arg), val_id); - } - - params.emplace_back(Operand::Int(val_id)); } if (!push_function_inst(op, params)) { diff --git a/src/writer/spirv/builder_intrinsic_test.cc b/src/writer/spirv/builder_intrinsic_test.cc index 3c40829063..8599f1b470 100644 --- a/src/writer/spirv/builder_intrinsic_test.cc +++ b/src/writer/spirv/builder_intrinsic_test.cc @@ -128,6 +128,62 @@ INSTANTIATE_TEST_SUITE_P(IntrinsicBuilderTest, testing::Values(IntrinsicData{"isNan", "OpIsNan"}, IntrinsicData{"isInf", "OpIsInf"})); +TEST_F(IntrinsicBuilderTest, IsFinite_Scalar) { + auto* var = Global("v", ty.f32(), ast::StorageClass::kPrivate); + + auto* expr = Call("isFinite", "v"); + WrapInFunction(expr); + + spirv::Builder& b = Build(); + + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error(); + + EXPECT_EQ(b.GenerateCallExpression(expr), 5u) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%3 = OpTypeFloat 32 +%2 = OpTypePointer Private %3 +%4 = OpConstantNull %3 +%1 = OpVariable %2 Private %4 +%6 = OpTypeBool +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%7 = OpLoad %3 %1 +%8 = OpIsInf %6 %7 +%9 = OpIsNan %6 %7 +%10 = OpLogicalOr %6 %8 %9 +%5 = OpLogicalNot %6 %10 +)"); +} + +TEST_F(IntrinsicBuilderTest, IsFinite_Vector) { + auto* var = Global("v", ty.vec3(), ast::StorageClass::kPrivate); + + auto* expr = Call("isFinite", "v"); + WrapInFunction(expr); + + spirv::Builder& b = Build(); + + b.push_function(Function{}); + ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error(); + + EXPECT_EQ(b.GenerateCallExpression(expr), 6u) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%4 = OpTypeFloat 32 +%3 = OpTypeVector %4 3 +%2 = OpTypePointer Private %3 +%5 = OpConstantNull %3 +%1 = OpVariable %2 Private %5 +%8 = OpTypeBool +%7 = OpTypeVector %8 3 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + R"(%9 = OpLoad %3 %1 +%10 = OpIsInf %7 %9 +%11 = OpIsNan %7 %9 +%12 = OpLogicalOr %7 %10 %11 +%6 = OpLogicalNot %7 %12 +)"); +} + using IntrinsicIntTest = IntrinsicBuilderTestWithParam; TEST_P(IntrinsicIntTest, Call_SInt_Scalar) { auto param = GetParam();