diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index a3e634deab..37cdbdaef3 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -559,31 +559,25 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { auto result = result_op(); auto result_id = result.to_i(); - auto lhs_type = expr->lhs()->result_type(); - - auto expr_type = expr->result_type(); - auto type_id = GenerateTypeIfNeeded(expr_type); + auto type_id = GenerateTypeIfNeeded(expr->result_type()); if (type_id == 0) { return 0; } + // Handle int and float and the vectors of those types. Other types + // should have been rejected by validation. + auto lhs_type = expr->lhs()->result_type(); + bool lhs_is_float_or_vec = + lhs_type->IsF32() || + (lhs_type->IsVector() && lhs_type->AsVector()->type()->IsF32()); + spv::Op op = spv::Op::OpNop; if (expr->IsAdd()) { - // This handles int and float and the vectors of those types. Other types - // should have been rejected by validation. - op = spv::Op::OpIAdd; - if (expr_type->IsF32() || - (expr_type->IsVector() && expr_type->AsVector()->type()->IsF32())) { - op = spv::Op::OpFAdd; - } + op = lhs_is_float_or_vec ? spv::Op::OpFAdd : spv::Op::OpIAdd; } else if (expr->IsEqual()) { - // This handles int and float and the vectors of those types. Other types - // should have been rejected by validation. - op = spv::Op::OpIEqual; - if (lhs_type->IsF32() || - (lhs_type->IsVector() && lhs_type->AsVector()->type()->IsF32())) { - op = spv::Op::OpFOrdEqual; - } + op = lhs_is_float_or_vec ? spv::Op::OpFOrdEqual : spv::Op::OpIEqual; + } else if (expr->IsNotEqual()) { + op = lhs_is_float_or_vec ? spv::Op::OpFOrdNotEqual : spv::Op::OpINotEqual; } else { return 0; } diff --git a/src/writer/spirv/builder_binary_expression_test.cc b/src/writer/spirv/builder_binary_expression_test.cc index 6b0848f3be..41fa16ca56 100644 --- a/src/writer/spirv/builder_binary_expression_test.cc +++ b/src/writer/spirv/builder_binary_expression_test.cc @@ -276,10 +276,11 @@ TEST_P(BinaryCompareIntegerTest, Vector) { EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), "%5 = " + param.name + " %6 %4 %4\n"); } -INSTANTIATE_TEST_SUITE_P(BuilderTest, - BinaryCompareIntegerTest, - testing::Values(BinaryData{ast::BinaryOp::kEqual, - "OpIEqual"})); +INSTANTIATE_TEST_SUITE_P( + BuilderTest, + BinaryCompareIntegerTest, + testing::Values(BinaryData{ast::BinaryOp::kEqual, "OpIEqual"}, + BinaryData{ast::BinaryOp::kNotEqual, "OpINotEqual"})); using BinaryCompareFloatTest = testing::TestWithParam; TEST_P(BinaryCompareFloatTest, Scalar) { @@ -357,10 +358,11 @@ TEST_P(BinaryCompareFloatTest, Vector) { EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), "%5 = " + param.name + " %6 %4 %4\n"); } -INSTANTIATE_TEST_SUITE_P(BuilderTest, - BinaryCompareFloatTest, - testing::Values(BinaryData{ast::BinaryOp::kEqual, - "OpFOrdEqual"})); +INSTANTIATE_TEST_SUITE_P( + BuilderTest, + BinaryCompareFloatTest, + testing::Values(BinaryData{ast::BinaryOp::kEqual, "OpFOrdEqual"}, + BinaryData{ast::BinaryOp::kNotEqual, "OpFOrdNotEqual"})); } // namespace } // namespace spirv