spirv-writer: fix bool equality, inequality

Fixed: tint:743
Change-Id: I03b5d50d2bf3cd17b672401f1922bde35cbf2640
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/52740
Auto-Submit: David Neto <dneto@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: James Price <jrprice@google.com>
Reviewed-by: James Price <jrprice@google.com>
This commit is contained in:
David Neto 2021-05-31 22:43:00 +00:00 committed by Tint LUCI CQ
parent 4e6744f954
commit f534e9e692
2 changed files with 77 additions and 2 deletions

View File

@ -1885,6 +1885,8 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
} }
bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector(); bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector();
bool lhs_is_bool_or_vec = lhs_type->is_bool_scalar_or_vector();
bool lhs_is_integer_or_vec = lhs_type->is_integer_scalar_or_vector();
bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector(); bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector();
spv::Op op = spv::Op::OpNop; spv::Op op = spv::Op::OpNop;
@ -1901,7 +1903,16 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
op = spv::Op::OpSDiv; op = spv::Op::OpSDiv;
} }
} else if (expr->IsEqual()) { } else if (expr->IsEqual()) {
op = lhs_is_float_or_vec ? spv::Op::OpFOrdEqual : spv::Op::OpIEqual; if (lhs_is_float_or_vec) {
op = spv::Op::OpFOrdEqual;
} else if (lhs_is_bool_or_vec) {
op = spv::Op::OpLogicalEqual;
} else if (lhs_is_integer_or_vec) {
op = spv::Op::OpIEqual;
} else {
error_ = "invalid equal expression";
return 0;
}
} else if (expr->IsGreaterThan()) { } else if (expr->IsGreaterThan()) {
if (lhs_is_float_or_vec) { if (lhs_is_float_or_vec) {
op = spv::Op::OpFOrdGreaterThan; op = spv::Op::OpFOrdGreaterThan;
@ -1983,7 +1994,16 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
return 0; return 0;
} }
} else if (expr->IsNotEqual()) { } else if (expr->IsNotEqual()) {
op = lhs_is_float_or_vec ? spv::Op::OpFOrdNotEqual : spv::Op::OpINotEqual; if (lhs_is_float_or_vec) {
op = spv::Op::OpFOrdNotEqual;
} else if (lhs_is_bool_or_vec) {
op = spv::Op::OpLogicalNotEqual;
} else if (lhs_is_integer_or_vec) {
op = spv::Op::OpINotEqual;
} else {
error_ = "invalid not-equal expression";
return 0;
}
} else if (expr->IsOr()) { } else if (expr->IsOr()) {
op = spv::Op::OpBitwiseOr; op = spv::Op::OpBitwiseOr;
} else if (expr->IsShiftLeft()) { } else if (expr->IsShiftLeft()) {

View File

@ -250,6 +250,61 @@ INSTANTIATE_TEST_SUITE_P(
BinaryData{ast::BinaryOp::kMultiply, "OpFMul"}, BinaryData{ast::BinaryOp::kMultiply, "OpFMul"},
BinaryData{ast::BinaryOp::kSubtract, "OpFSub"})); BinaryData{ast::BinaryOp::kSubtract, "OpFSub"}));
using BinaryCompareBoolTest = TestParamHelper<BinaryData>;
TEST_P(BinaryCompareBoolTest, Scalar) {
auto param = GetParam();
auto* lhs = Expr(true);
auto* rhs = Expr(false);
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
WrapInFunction(expr);
spirv::Builder& b = Build();
b.push_function(Function{});
EXPECT_EQ(b.GenerateBinaryExpression(expr), 4u) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeBool
%2 = OpConstantTrue %1
%3 = OpConstantFalse %1
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
"%4 = " + param.name + " %1 %2 %3\n");
}
TEST_P(BinaryCompareBoolTest, Vector) {
auto param = GetParam();
auto* lhs = vec3<bool>(false, true, false);
auto* rhs = vec3<bool>(true, false, true);
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
WrapInFunction(expr);
spirv::Builder& b = Build();
b.push_function(Function{});
EXPECT_EQ(b.GenerateBinaryExpression(expr), 7u) << b.error();
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeBool
%1 = OpTypeVector %2 3
%3 = OpConstantFalse %2
%4 = OpConstantTrue %2
%5 = OpConstantComposite %1 %3 %4 %3
%6 = OpConstantComposite %1 %4 %3 %4
)");
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
"%7 = " + param.name + " %1 %5 %6\n");
}
INSTANTIATE_TEST_SUITE_P(
BuilderTest,
BinaryCompareBoolTest,
testing::Values(BinaryData{ast::BinaryOp::kEqual, "OpLogicalEqual"},
BinaryData{ast::BinaryOp::kNotEqual, "OpLogicalNotEqual"}));
using BinaryCompareUnsignedIntegerTest = TestParamHelper<BinaryData>; using BinaryCompareUnsignedIntegerTest = TestParamHelper<BinaryData>;
TEST_P(BinaryCompareUnsignedIntegerTest, Scalar) { TEST_P(BinaryCompareUnsignedIntegerTest, Scalar) {
auto param = GetParam(); auto param = GetParam();