diff --git a/src/tint/writer/spirv/ir/generator_impl_ir.cc b/src/tint/writer/spirv/ir/generator_impl_ir.cc index f1e3410d2f..b9ce81793b 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir.cc +++ b/src/tint/writer/spirv/ir/generator_impl_ir.cc @@ -412,6 +412,7 @@ void GeneratorImplIr::EmitIf(const ir::If* i) { uint32_t GeneratorImplIr::EmitBinary(const ir::Binary* binary) { auto id = module_.NextId(); + auto* lhs_ty = binary->LHS()->Type(); // Determine the opcode. spv::Op op = spv::Op::Max; @@ -424,6 +425,68 @@ uint32_t GeneratorImplIr::EmitBinary(const ir::Binary* binary) { op = binary->Type()->is_integer_scalar_or_vector() ? spv::Op::OpISub : spv::Op::OpFSub; break; } + + case ir::Binary::Kind::kEqual: { + if (lhs_ty->is_bool_scalar_or_vector()) { + op = spv::Op::OpLogicalEqual; + } else if (lhs_ty->is_float_scalar_or_vector()) { + op = spv::Op::OpFOrdEqual; + } else if (lhs_ty->is_integer_scalar_or_vector()) { + op = spv::Op::OpIEqual; + } + break; + } + case ir::Binary::Kind::kNotEqual: { + if (lhs_ty->is_bool_scalar_or_vector()) { + op = spv::Op::OpLogicalNotEqual; + } else if (lhs_ty->is_float_scalar_or_vector()) { + op = spv::Op::OpFOrdNotEqual; + } else if (lhs_ty->is_integer_scalar_or_vector()) { + op = spv::Op::OpINotEqual; + } + break; + } + case ir::Binary::Kind::kGreaterThan: { + if (lhs_ty->is_float_scalar_or_vector()) { + op = spv::Op::OpFOrdGreaterThan; + } else if (lhs_ty->is_signed_integer_scalar_or_vector()) { + op = spv::Op::OpSGreaterThan; + } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) { + op = spv::Op::OpUGreaterThan; + } + break; + } + case ir::Binary::Kind::kGreaterThanEqual: { + if (lhs_ty->is_float_scalar_or_vector()) { + op = spv::Op::OpFOrdGreaterThanEqual; + } else if (lhs_ty->is_signed_integer_scalar_or_vector()) { + op = spv::Op::OpSGreaterThanEqual; + } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) { + op = spv::Op::OpUGreaterThanEqual; + } + break; + } + case ir::Binary::Kind::kLessThan: { + if (lhs_ty->is_float_scalar_or_vector()) { + op = spv::Op::OpFOrdLessThan; + } else if (lhs_ty->is_signed_integer_scalar_or_vector()) { + op = spv::Op::OpSLessThan; + } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) { + op = spv::Op::OpULessThan; + } + break; + } + case ir::Binary::Kind::kLessThanEqual: { + if (lhs_ty->is_float_scalar_or_vector()) { + op = spv::Op::OpFOrdLessThanEqual; + } else if (lhs_ty->is_signed_integer_scalar_or_vector()) { + op = spv::Op::OpSLessThanEqual; + } else if (lhs_ty->is_unsigned_integer_scalar_or_vector()) { + op = spv::Op::OpULessThanEqual; + } + break; + } + default: { TINT_ICE(Writer, diagnostics_) << "unimplemented binary instruction: " << static_cast(binary->Kind()); diff --git a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc index 264f619ad9..a139ad1a8b 100644 --- a/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc +++ b/src/tint/writer/spirv/ir/generator_impl_ir_binary_test.cc @@ -14,11 +14,109 @@ #include "src/tint/writer/spirv/ir/test_helper_ir.h" +#include "gmock/gmock.h" +#include "src/tint/ir/binary.h" + using namespace tint::number_suffixes; // NOLINT namespace tint::writer::spirv { namespace { +/// The element type of a test. +enum Type { + kBool, + kI32, + kU32, + kF32, + kF16, +}; + +/// A parameterized test case. +struct BinaryTestCase { + /// The element type to test. + Type type; + /// The binary operation. + enum ir::Binary::Kind kind; + /// The expected SPIR-V instruction. + std::string spirv_inst; +}; + +/// A helper class for parameterized binary instruction tests. +class BinaryInstructionTest : public SpvGeneratorImplTestWithParam { + protected: + /// Helper to make a scalar type corresponding to the element type `ty`. + /// @param ty the element type + /// @returns the scalar type + const type::Type* MakeScalarType(Type ty) { + switch (ty) { + case kBool: + return mod.Types().bool_(); + case kI32: + return mod.Types().i32(); + case kU32: + return mod.Types().u32(); + case kF32: + return mod.Types().f32(); + case kF16: + return mod.Types().f16(); + } + return nullptr; + } + + /// Helper to make a vector type corresponding to the element type `ty`. + /// @param ty the element type + /// @returns the vector type + const type::Type* MakeVectorType(Type ty) { return mod.Types().vec2(MakeScalarType(ty)); } + + /// Helper to make a scalar value with the scalar type `ty`. + /// @param ty the element type + /// @returns the scalar value + ir::Value* MakeScalarValue(Type ty) { + switch (ty) { + case kBool: + return b.Constant(true); + case kI32: + return b.Constant(1_i); + case kU32: + return b.Constant(1_u); + case kF32: + return b.Constant(1_f); + case kF16: + return b.Constant(1_h); + } + return nullptr; + } + + /// Helper to make a vector value with an element type of `ty`. + /// @param ty the element type + /// @returns the vector value + ir::Value* MakeVectorValue(Type ty) { + switch (ty) { + case kBool: + return b.Constant(b.ir.constant_values.Composite( + MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(true), + b.ir.constant_values.Get(false)})); + case kI32: + return b.Constant(b.ir.constant_values.Composite( + MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(42_i), + b.ir.constant_values.Get(-10_i)})); + case kU32: + return b.Constant(b.ir.constant_values.Composite( + MakeVectorType(ty), + utils::Vector{b.ir.constant_values.Get(42_u), b.ir.constant_values.Get(10_u)})); + case kF32: + return b.Constant(b.ir.constant_values.Composite( + MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(42_f), + b.ir.constant_values.Get(-0.5_f)})); + case kF16: + return b.Constant(b.ir.constant_values.Composite( + MakeVectorType(ty), utils::Vector{b.ir.constant_values.Get(42_h), + b.ir.constant_values.Get(-0.5_h)})); + } + return nullptr; + } +}; + TEST_F(SpvGeneratorImplTest, Binary_Add_I32) { auto* func = b.CreateFunction("foo", mod.Types().void_()); func->StartTarget()->SetInstructions(utils::Vector{ @@ -210,6 +308,78 @@ OpFunctionEnd )"); } +using Comparison = BinaryInstructionTest; +TEST_P(Comparison, Scalar) { + auto params = GetParam(); + + auto* func = b.CreateFunction("foo", mod.Types().void_()); + func->StartTarget()->SetInstructions( + utils::Vector{b.CreateBinary(params.kind, mod.Types().bool_(), MakeScalarValue(params.type), + MakeScalarValue(params.type)), + b.Branch(func->EndTarget())}); + + generator_.EmitFunction(func); + EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst)); +} +TEST_P(Comparison, Vector) { + auto params = GetParam(); + + auto* func = b.CreateFunction("foo", mod.Types().void_()); + func->StartTarget()->SetInstructions( + utils::Vector{b.CreateBinary(params.kind, mod.Types().vec2(mod.Types().bool_()), + MakeVectorValue(params.type), MakeVectorValue(params.type)), + + b.Branch(func->EndTarget())}); + + generator_.EmitFunction(func); + EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst)); +} +INSTANTIATE_TEST_SUITE_P( + SpvGeneratorImplTest_Binary_I32, + Comparison, + testing::Values(BinaryTestCase{kI32, ir::Binary::Kind::kEqual, "OpIEqual"}, + BinaryTestCase{kI32, ir::Binary::Kind::kNotEqual, "OpINotEqual"}, + BinaryTestCase{kI32, ir::Binary::Kind::kGreaterThan, "OpSGreaterThan"}, + BinaryTestCase{kI32, ir::Binary::Kind::kGreaterThanEqual, + "OpSGreaterThanEqual"}, + BinaryTestCase{kI32, ir::Binary::Kind::kLessThan, "OpSLessThan"}, + BinaryTestCase{kI32, ir::Binary::Kind::kLessThanEqual, "OpSLessThanEqual"})); +INSTANTIATE_TEST_SUITE_P( + SpvGeneratorImplTest_Binary_U32, + Comparison, + testing::Values(BinaryTestCase{kU32, ir::Binary::Kind::kEqual, "OpIEqual"}, + BinaryTestCase{kU32, ir::Binary::Kind::kNotEqual, "OpINotEqual"}, + BinaryTestCase{kU32, ir::Binary::Kind::kGreaterThan, "OpUGreaterThan"}, + BinaryTestCase{kU32, ir::Binary::Kind::kGreaterThanEqual, + "OpUGreaterThanEqual"}, + BinaryTestCase{kU32, ir::Binary::Kind::kLessThan, "OpULessThan"}, + BinaryTestCase{kU32, ir::Binary::Kind::kLessThanEqual, "OpULessThanEqual"})); +INSTANTIATE_TEST_SUITE_P( + SpvGeneratorImplTest_Binary_F32, + Comparison, + testing::Values(BinaryTestCase{kF32, ir::Binary::Kind::kEqual, "OpFOrdEqual"}, + BinaryTestCase{kF32, ir::Binary::Kind::kNotEqual, "OpFOrdNotEqual"}, + BinaryTestCase{kF32, ir::Binary::Kind::kGreaterThan, "OpFOrdGreaterThan"}, + BinaryTestCase{kF32, ir::Binary::Kind::kGreaterThanEqual, + "OpFOrdGreaterThanEqual"}, + BinaryTestCase{kF32, ir::Binary::Kind::kLessThan, "OpFOrdLessThan"}, + BinaryTestCase{kF32, ir::Binary::Kind::kLessThanEqual, "OpFOrdLessThanEqual"})); +INSTANTIATE_TEST_SUITE_P( + SpvGeneratorImplTest_Binary_F16, + Comparison, + testing::Values(BinaryTestCase{kF16, ir::Binary::Kind::kEqual, "OpFOrdEqual"}, + BinaryTestCase{kF16, ir::Binary::Kind::kNotEqual, "OpFOrdNotEqual"}, + BinaryTestCase{kF16, ir::Binary::Kind::kGreaterThan, "OpFOrdGreaterThan"}, + BinaryTestCase{kF16, ir::Binary::Kind::kGreaterThanEqual, + "OpFOrdGreaterThanEqual"}, + BinaryTestCase{kF16, ir::Binary::Kind::kLessThan, "OpFOrdLessThan"}, + BinaryTestCase{kF16, ir::Binary::Kind::kLessThanEqual, "OpFOrdLessThanEqual"})); +INSTANTIATE_TEST_SUITE_P( + SpvGeneratorImplTest_Binary_Bool, + Comparison, + testing::Values(BinaryTestCase{kBool, ir::Binary::Kind::kEqual, "OpLogicalEqual"}, + BinaryTestCase{kBool, ir::Binary::Kind::kNotEqual, "OpLogicalNotEqual"})); + TEST_F(SpvGeneratorImplTest, Binary_Chain) { auto* func = b.CreateFunction("foo", mod.Types().void_()); auto* a = b.Subtract(mod.Types().i32(), b.Constant(1_i), b.Constant(2_i));