[ir][spirv-writer] Expand binary arithmetic tests

Use the parameterized test helper for binary expressions to more
comprehensively test scalar and vector values across different types.

Bug: tint:1906
Change-Id: I2be087d7889d0993125eb0e3f897acbdf56575b2
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/134323
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: James Price <jrprice@google.com>
Commit-Queue: James Price <jrprice@google.com>
This commit is contained in:
James Price 2023-05-25 03:24:19 +00:00 committed by Dawn LUCI CQ
parent 6663a97b74
commit 0eb4d04d83
1 changed files with 32 additions and 176 deletions

View File

@ -117,196 +117,52 @@ class BinaryInstructionTest : public SpvGeneratorImplTestWithParam<BinaryTestCas
}
};
TEST_F(SpvGeneratorImplTest, Binary_Add_I32) {
auto* func = b.CreateFunction("foo", mod.Types().void_());
func->StartTarget()->SetInstructions(utils::Vector{
b.Add(mod.Types().i32(), b.Constant(1_i), b.Constant(2_i)), b.Branch(func->EndTarget())});
using Arithmetic = BinaryInstructionTest;
TEST_P(Arithmetic, Scalar) {
auto params = GetParam();
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%6 = OpTypeInt 32 1
%7 = OpConstant %6 1
%8 = OpConstant %6 2
%1 = OpFunction %2 None %3
%4 = OpLabel
%5 = OpIAdd %6 %7 %8
OpReturn
OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Binary_Add_U32) {
auto* func = b.CreateFunction("foo", mod.Types().void_());
func->StartTarget()->SetInstructions(utils::Vector{
b.Add(mod.Types().u32(), b.Constant(1_u), b.Constant(2_u)), b.Branch(func->EndTarget())});
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%6 = OpTypeInt 32 0
%7 = OpConstant %6 1
%8 = OpConstant %6 2
%1 = OpFunction %2 None %3
%4 = OpLabel
%5 = OpIAdd %6 %7 %8
OpReturn
OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Binary_Add_F32) {
auto* func = b.CreateFunction("foo", mod.Types().void_());
func->StartTarget()->SetInstructions(utils::Vector{
b.Add(mod.Types().f32(), b.Constant(1_f), b.Constant(2_f)), b.Branch(func->EndTarget())});
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%6 = OpTypeFloat 32
%7 = OpConstant %6 1
%8 = OpConstant %6 2
%1 = OpFunction %2 None %3
%4 = OpLabel
%5 = OpFAdd %6 %7 %8
OpReturn
OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Binary_Sub_I32) {
auto* func = b.CreateFunction("foo", mod.Types().void_());
func->StartTarget()->SetInstructions(
utils::Vector{b.Subtract(mod.Types().i32(), b.Constant(1_i), b.Constant(2_i)),
utils::Vector{b.CreateBinary(params.kind, MakeScalarType(params.type),
MakeScalarValue(params.type), MakeScalarValue(params.type)),
b.Branch(func->EndTarget())});
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%6 = OpTypeInt 32 1
%7 = OpConstant %6 1
%8 = OpConstant %6 2
%1 = OpFunction %2 None %3
%4 = OpLabel
%5 = OpISub %6 %7 %8
OpReturn
OpFunctionEnd
)");
EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
TEST_P(Arithmetic, Vector) {
auto params = GetParam();
TEST_F(SpvGeneratorImplTest, Binary_Sub_U32) {
auto* func = b.CreateFunction("foo", mod.Types().void_());
func->StartTarget()->SetInstructions(
utils::Vector{b.Subtract(mod.Types().u32(), b.Constant(1_u), b.Constant(2_u)),
utils::Vector{b.CreateBinary(params.kind, MakeVectorType(params.type),
MakeVectorValue(params.type), MakeVectorValue(params.type)),
b.Branch(func->EndTarget())});
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%6 = OpTypeInt 32 0
%7 = OpConstant %6 1
%8 = OpConstant %6 2
%1 = OpFunction %2 None %3
%4 = OpLabel
%5 = OpISub %6 %7 %8
OpReturn
OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Binary_Sub_F32) {
auto* func = b.CreateFunction("foo", mod.Types().void_());
func->StartTarget()->SetInstructions(
utils::Vector{b.Subtract(mod.Types().f32(), b.Constant(1_f), b.Constant(2_f)),
b.Branch(func->EndTarget())});
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%6 = OpTypeFloat 32
%7 = OpConstant %6 1
%8 = OpConstant %6 2
%1 = OpFunction %2 None %3
%4 = OpLabel
%5 = OpFSub %6 %7 %8
OpReturn
OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Binary_Sub_Vec2i) {
auto const_i32 = [&](int val) { return b.ir.constant_values.Get(i32(val)); };
auto* func = b.CreateFunction("foo", mod.Types().void_());
auto* lhs = b.ir.constant_values.Composite(mod.Types().vec2(mod.Types().i32()),
utils::Vector{const_i32(42), const_i32(-1)});
auto* rhs = b.ir.constant_values.Composite(mod.Types().vec2(mod.Types().i32()),
utils::Vector{const_i32(0), const_i32(-43)});
func->StartTarget()->SetInstructions(
utils::Vector{b.Subtract(mod.Types().Get<type::Vector>(mod.Types().i32(), 2u),
b.Constant(lhs), b.Constant(rhs)),
b.Branch(func->EndTarget())});
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%7 = OpTypeInt 32 1
%6 = OpTypeVector %7 2
%9 = OpConstant %7 42
%10 = OpConstant %7 -1
%8 = OpConstantComposite %6 %9 %10
%12 = OpConstant %7 0
%13 = OpConstant %7 -43
%11 = OpConstantComposite %6 %12 %13
%1 = OpFunction %2 None %3
%4 = OpLabel
%5 = OpISub %6 %8 %11
OpReturn
OpFunctionEnd
)");
}
TEST_F(SpvGeneratorImplTest, Binary_Sub_Vec4f) {
auto const_f32 = [&](float val) { return b.ir.constant_values.Get(f32(val)); };
auto* func = b.CreateFunction("foo", mod.Types().void_());
auto* lhs = b.ir.constant_values.Composite(
mod.Types().vec4(mod.Types().f32()),
utils::Vector{const_f32(42), const_f32(-1), const_f32(0), const_f32(1.25)});
auto* rhs = b.ir.constant_values.Composite(
mod.Types().vec4(mod.Types().f32()),
utils::Vector{const_f32(0), const_f32(1.25), const_f32(-42), const_f32(1)});
func->StartTarget()->SetInstructions(
utils::Vector{b.Subtract(mod.Types().Get<type::Vector>(mod.Types().f32(), 4u),
b.Constant(lhs), b.Constant(rhs)),
b.Branch(func->EndTarget())});
generator_.EmitFunction(func);
EXPECT_EQ(DumpModule(generator_.Module()), R"(OpName %1 "foo"
%2 = OpTypeVoid
%3 = OpTypeFunction %2
%7 = OpTypeFloat 32
%6 = OpTypeVector %7 4
%9 = OpConstant %7 42
%10 = OpConstant %7 -1
%11 = OpConstant %7 0
%12 = OpConstant %7 1.25
%8 = OpConstantComposite %6 %9 %10 %11 %12
%14 = OpConstant %7 -42
%15 = OpConstant %7 1
%13 = OpConstantComposite %6 %11 %12 %14 %15
%1 = OpFunction %2 None %3
%4 = OpLabel
%5 = OpFSub %6 %8 %13
OpReturn
OpFunctionEnd
)");
EXPECT_THAT(DumpModule(generator_.Module()), ::testing::HasSubstr(params.spirv_inst));
}
INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest_Binary_I32,
Arithmetic,
testing::Values(BinaryTestCase{kI32, ir::Binary::Kind::kAdd, "OpIAdd"},
BinaryTestCase{kI32, ir::Binary::Kind::kSubtract,
"OpISub"}));
INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest_Binary_U32,
Arithmetic,
testing::Values(BinaryTestCase{kU32, ir::Binary::Kind::kAdd, "OpIAdd"},
BinaryTestCase{kU32, ir::Binary::Kind::kSubtract,
"OpISub"}));
INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest_Binary_F32,
Arithmetic,
testing::Values(BinaryTestCase{kF32, ir::Binary::Kind::kAdd, "OpFAdd"},
BinaryTestCase{kF32, ir::Binary::Kind::kSubtract,
"OpFSub"}));
INSTANTIATE_TEST_SUITE_P(SpvGeneratorImplTest_Binary_F16,
Arithmetic,
testing::Values(BinaryTestCase{kF16, ir::Binary::Kind::kAdd, "OpFAdd"},
BinaryTestCase{kF16, ir::Binary::Kind::kSubtract,
"OpFSub"}));
using Bitwise = BinaryInstructionTest;
TEST_P(Bitwise, Scalar) {