From f534e9e692f38332f7b0661de0dca96dab1da04e Mon Sep 17 00:00:00 2001 From: David Neto Date: Mon, 31 May 2021 22:43:00 +0000 Subject: [PATCH] spirv-writer: fix bool equality, inequality Fixed: tint:743 Change-Id: I03b5d50d2bf3cd17b672401f1922bde35cbf2640 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/52740 Auto-Submit: David Neto Kokoro: Kokoro Commit-Queue: James Price Reviewed-by: James Price --- src/writer/spirv/builder.cc | 24 +++++++- .../spirv/builder_binary_expression_test.cc | 55 +++++++++++++++++++ 2 files changed, 77 insertions(+), 2 deletions(-) diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 7f51152a53..bb45f2d084 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -1885,6 +1885,8 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { } bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector(); + bool lhs_is_bool_or_vec = lhs_type->is_bool_scalar_or_vector(); + bool lhs_is_integer_or_vec = lhs_type->is_integer_scalar_or_vector(); bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector(); spv::Op op = spv::Op::OpNop; @@ -1901,7 +1903,16 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { op = spv::Op::OpSDiv; } } else if (expr->IsEqual()) { - op = lhs_is_float_or_vec ? spv::Op::OpFOrdEqual : spv::Op::OpIEqual; + if (lhs_is_float_or_vec) { + op = spv::Op::OpFOrdEqual; + } else if (lhs_is_bool_or_vec) { + op = spv::Op::OpLogicalEqual; + } else if (lhs_is_integer_or_vec) { + op = spv::Op::OpIEqual; + } else { + error_ = "invalid equal expression"; + return 0; + } } else if (expr->IsGreaterThan()) { if (lhs_is_float_or_vec) { op = spv::Op::OpFOrdGreaterThan; @@ -1983,7 +1994,16 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { return 0; } } else if (expr->IsNotEqual()) { - op = lhs_is_float_or_vec ? spv::Op::OpFOrdNotEqual : spv::Op::OpINotEqual; + if (lhs_is_float_or_vec) { + op = spv::Op::OpFOrdNotEqual; + } else if (lhs_is_bool_or_vec) { + op = spv::Op::OpLogicalNotEqual; + } else if (lhs_is_integer_or_vec) { + op = spv::Op::OpINotEqual; + } else { + error_ = "invalid not-equal expression"; + return 0; + } } else if (expr->IsOr()) { op = spv::Op::OpBitwiseOr; } else if (expr->IsShiftLeft()) { diff --git a/src/writer/spirv/builder_binary_expression_test.cc b/src/writer/spirv/builder_binary_expression_test.cc index 3b8391b491..f18f460361 100644 --- a/src/writer/spirv/builder_binary_expression_test.cc +++ b/src/writer/spirv/builder_binary_expression_test.cc @@ -250,6 +250,61 @@ INSTANTIATE_TEST_SUITE_P( BinaryData{ast::BinaryOp::kMultiply, "OpFMul"}, BinaryData{ast::BinaryOp::kSubtract, "OpFSub"})); +using BinaryCompareBoolTest = TestParamHelper; +TEST_P(BinaryCompareBoolTest, Scalar) { + auto param = GetParam(); + + auto* lhs = Expr(true); + auto* rhs = Expr(false); + + auto* expr = create(param.op, lhs, rhs); + + WrapInFunction(expr); + + spirv::Builder& b = Build(); + + b.push_function(Function{}); + + EXPECT_EQ(b.GenerateBinaryExpression(expr), 4u) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeBool +%2 = OpConstantTrue %1 +%3 = OpConstantFalse %1 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + "%4 = " + param.name + " %1 %2 %3\n"); +} + +TEST_P(BinaryCompareBoolTest, Vector) { + auto param = GetParam(); + + auto* lhs = vec3(false, true, false); + auto* rhs = vec3(true, false, true); + + auto* expr = create(param.op, lhs, rhs); + + WrapInFunction(expr); + + spirv::Builder& b = Build(); + + b.push_function(Function{}); + + EXPECT_EQ(b.GenerateBinaryExpression(expr), 7u) << b.error(); + EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeBool +%1 = OpTypeVector %2 3 +%3 = OpConstantFalse %2 +%4 = OpConstantTrue %2 +%5 = OpConstantComposite %1 %3 %4 %3 +%6 = OpConstantComposite %1 %4 %3 %4 +)"); + EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()), + "%7 = " + param.name + " %1 %5 %6\n"); +} +INSTANTIATE_TEST_SUITE_P( + BuilderTest, + BinaryCompareBoolTest, + testing::Values(BinaryData{ast::BinaryOp::kEqual, "OpLogicalEqual"}, + BinaryData{ast::BinaryOp::kNotEqual, "OpLogicalNotEqual"})); + using BinaryCompareUnsignedIntegerTest = TestParamHelper; TEST_P(BinaryCompareUnsignedIntegerTest, Scalar) { auto param = GetParam();