tint: Add f16 support for operator

This patch adds support for using f16 types in the unary operator `-` and
the binary operators `+`, `-`, `*`, `/`, `%`, `<`, `>`, `<=`, and `>=`.
`==` is already supported. Unit tests are also implemented.

Bug: tint:1473, tint:1502
Change-Id: I1123fa5e9e586ec0d8522b0f6bacafb4ad53ffcf
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/96380
Auto-Submit: Zhaoming Jiang <zhaoming.jiang@intel.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Zhaoming Jiang 2022-07-18 14:38:32 +00:00 committed by Dawn LUCI CQ
parent cd74244614
commit 59e640b208
7 changed files with 2819 additions and 1841 deletions

View File

@ -129,7 +129,9 @@ type __atomic_compare_exchange_result<T>
match f32f16: f32 | f16
match fiu32: f32 | i32 | u32
match fiu32f16: f32 | f16 | i32 | u32
match fi32: f32 | i32
match fi32f16: f32 | f16 | i32
match iu32: i32 | u32
match aiu32: ai | i32 | u32
match scalar: f32 | f16 | i32 | u32 | bool
@ -820,43 +822,43 @@ op ! <N: num> (vec<N, bool>) -> vec<N, bool>
@const op ~ <T: aiu32>(T) -> T
@const op ~ <T: aiu32, N: num> (vec<N, T>) -> vec<N, T>
op - <T: fi32>(T) -> T
op - <T: fi32, N: num> (vec<N, T>) -> vec<N, T>
op - <T: fi32f16>(T) -> T
op - <T: fi32f16, N: num> (vec<N, T>) -> vec<N, T>
////////////////////////////////////////////////////////////////////////////////
// Binary Operators //
////////////////////////////////////////////////////////////////////////////////
op + <T: fiu32>(T, T) -> T
op + <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op + <T: fiu32, N: num> (vec<N, T>, T) -> vec<N, T>
op + <T: fiu32, N: num> (T, vec<N, T>) -> vec<N, T>
op + <N: num, M: num> (mat<N, M, f32>, mat<N, M, f32>) -> mat<N, M, f32>
op + <T: fiu32f16>(T, T) -> T
op + <T: fiu32f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op + <T: fiu32f16, N: num> (vec<N, T>, T) -> vec<N, T>
op + <T: fiu32f16, N: num> (T, vec<N, T>) -> vec<N, T>
op + <T: f32f16, N: num, M: num> (mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T>
op - <T: fiu32>(T, T) -> T
op - <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op - <T: fiu32, N: num> (vec<N, T>, T) -> vec<N, T>
op - <T: fiu32, N: num> (T, vec<N, T>) -> vec<N, T>
op - <N: num, M: num> (mat<N, M, f32>, mat<N, M, f32>) -> mat<N, M, f32>
op - <T: fiu32f16>(T, T) -> T
op - <T: fiu32f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op - <T: fiu32f16, N: num> (vec<N, T>, T) -> vec<N, T>
op - <T: fiu32f16, N: num> (T, vec<N, T>) -> vec<N, T>
op - <T: f32f16, N: num, M: num> (mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T>
op * <T: fiu32>(T, T) -> T
op * <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op * <T: fiu32, N: num> (vec<N, T>, T) -> vec<N, T>
op * <T: fiu32, N: num> (T, vec<N, T>) -> vec<N, T>
op * <N: num, M: num> (f32, mat<N, M, f32>) -> mat<N, M, f32>
op * <N: num, M: num> (mat<N, M, f32>, f32) -> mat<N, M, f32>
op * <C: num, R: num> (mat<C, R, f32>, vec<C, f32>) -> vec<R, f32>
op * <C: num, R: num> (vec<R, f32>, mat<C, R, f32>) -> vec<C, f32>
op * <K: num, C: num, R: num> (mat<K, R, f32>, mat<C, K, f32>) -> mat<C, R, f32>
op * <T: fiu32f16>(T, T) -> T
op * <T: fiu32f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op * <T: fiu32f16, N: num> (vec<N, T>, T) -> vec<N, T>
op * <T: fiu32f16, N: num> (T, vec<N, T>) -> vec<N, T>
op * <T: f32f16, N: num, M: num> (T, mat<N, M, T>) -> mat<N, M, T>
op * <T: f32f16, N: num, M: num> (mat<N, M, T>, T) -> mat<N, M, T>
op * <T: f32f16, C: num, R: num> (mat<C, R, T>, vec<C, T>) -> vec<R, T>
op * <T: f32f16, C: num, R: num> (vec<R, T>, mat<C, R, T>) -> vec<C, T>
op * <T: f32f16, K: num, C: num, R: num> (mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T>
op / <T: fiu32>(T, T) -> T
op / <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op / <T: fiu32, N: num> (vec<N, T>, T) -> vec<N, T>
op / <T: fiu32, N: num> (T, vec<N, T>) -> vec<N, T>
op / <T: fiu32f16>(T, T) -> T
op / <T: fiu32f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op / <T: fiu32f16, N: num> (vec<N, T>, T) -> vec<N, T>
op / <T: fiu32f16, N: num> (T, vec<N, T>) -> vec<N, T>
op % <T: fiu32>(T, T) -> T
op % <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op % <T: fiu32, N: num> (vec<N, T>, T) -> vec<N, T>
op % <T: fiu32, N: num> (T, vec<N, T>) -> vec<N, T>
op % <T: fiu32f16>(T, T) -> T
op % <T: fiu32f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
op % <T: fiu32f16, N: num> (vec<N, T>, T) -> vec<N, T>
op % <T: fiu32f16, N: num> (T, vec<N, T>) -> vec<N, T>
op ^ <T: iu32>(T, T) -> T
op ^ <T: iu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
@ -880,17 +882,17 @@ op == <T: scalar, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
op != <T: scalar>(T, T) -> bool
op != <T: scalar, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
op < <T: fiu32>(T, T) -> bool
op < <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
op < <T: fiu32f16>(T, T) -> bool
op < <T: fiu32f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
op > <T: fiu32>(T, T) -> bool
op > <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
op > <T: fiu32f16>(T, T) -> bool
op > <T: fiu32f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
op <= <T: fiu32>(T, T) -> bool
op <= <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
op <= <T: fiu32f16>(T, T) -> bool
op <= <T: fiu32f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
op >= <T: fiu32>(T, T) -> bool
op >= <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
op >= <T: fiu32f16>(T, T) -> bool
op >= <T: fiu32f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
op << <T: iu32>(T, u32) -> T
op << <T: iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T>

File diff suppressed because it is too large Load Diff

View File

@ -604,8 +604,8 @@ TEST_F(IntrinsicTableTest, MismatchUnaryOp) {
EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator - (bool)
2 candidate operators:
operator - (T) -> T where: T is f32 or i32
operator - (vecN<T>) -> vecN<T> where: T is f32 or i32
operator - (T) -> T where: T is f32, f16 or i32
operator - (vecN<T>) -> vecN<T> where: T is f32, f16 or i32
)");
}
@ -629,15 +629,15 @@ TEST_F(IntrinsicTableTest, MismatchBinaryOp) {
EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator * (f32, bool)
9 candidate operators:
operator * (T, T) -> T where: T is f32, i32 or u32
operator * (vecN<T>, T) -> vecN<T> where: T is f32, i32 or u32
operator * (T, vecN<T>) -> vecN<T> where: T is f32, i32 or u32
operator * (f32, matNxM<f32>) -> matNxM<f32>
operator * (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, i32 or u32
operator * (matNxM<f32>, f32) -> matNxM<f32>
operator * (matCxR<f32>, vecC<f32>) -> vecR<f32>
operator * (vecR<f32>, matCxR<f32>) -> vecC<f32>
operator * (matKxR<f32>, matCxK<f32>) -> matCxR<f32>
operator * (T, T) -> T where: T is f32, f16, i32 or u32
operator * (vecN<T>, T) -> vecN<T> where: T is f32, f16, i32 or u32
operator * (T, vecN<T>) -> vecN<T> where: T is f32, f16, i32 or u32
operator * (T, matNxM<T>) -> matNxM<T> where: T is f32 or f16
operator * (matNxM<T>, T) -> matNxM<T> where: T is f32 or f16
operator * (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, f16, i32 or u32
operator * (matCxR<T>, vecC<T>) -> vecR<T> where: T is f32 or f16
operator * (vecR<T>, matCxR<T>) -> vecC<T> where: T is f32 or f16
operator * (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is f32 or f16
)");
}
@ -661,15 +661,15 @@ TEST_F(IntrinsicTableTest, MismatchCompoundOp) {
EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator *= (f32, bool)
9 candidate operators:
operator *= (T, T) -> T where: T is f32, i32 or u32
operator *= (vecN<T>, T) -> vecN<T> where: T is f32, i32 or u32
operator *= (T, vecN<T>) -> vecN<T> where: T is f32, i32 or u32
operator *= (f32, matNxM<f32>) -> matNxM<f32>
operator *= (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, i32 or u32
operator *= (matNxM<f32>, f32) -> matNxM<f32>
operator *= (matCxR<f32>, vecC<f32>) -> vecR<f32>
operator *= (vecR<f32>, matCxR<f32>) -> vecC<f32>
operator *= (matKxR<f32>, matCxK<f32>) -> matCxR<f32>
operator *= (T, T) -> T where: T is f32, f16, i32 or u32
operator *= (vecN<T>, T) -> vecN<T> where: T is f32, f16, i32 or u32
operator *= (T, vecN<T>) -> vecN<T> where: T is f32, f16, i32 or u32
operator *= (T, matNxM<T>) -> matNxM<T> where: T is f32 or f16
operator *= (matNxM<T>, T) -> matNxM<T> where: T is f32 or f16
operator *= (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, f16, i32 or u32
operator *= (matCxR<T>, vecC<T>) -> vecR<T> where: T is f32 or f16
operator *= (vecR<T>, matCxR<T>) -> vecC<T> where: T is f32 or f16
operator *= (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is f32 or f16
)");
}

View File

@ -59,6 +59,34 @@ TEST_P(GlslBinaryTest, Emit_f32) {
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), params.result);
}
TEST_P(GlslBinaryTest, Emit_f16) {
auto params = GetParam();
// Skip ops that are illegal for this type
if (params.op == ast::BinaryOp::kAnd || params.op == ast::BinaryOp::kOr ||
params.op == ast::BinaryOp::kXor || params.op == ast::BinaryOp::kShiftLeft ||
params.op == ast::BinaryOp::kShiftRight || params.op == ast::BinaryOp::kModulo) {
return;
}
Enable(ast::Extension::kF16);
GlobalVar("left", ty.f16(), ast::StorageClass::kPrivate);
GlobalVar("right", ty.f16(), ast::StorageClass::kPrivate);
auto* left = Expr("left");
auto* right = Expr("right");
auto* expr = create<ast::BinaryExpression>(params.op, left, right);
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), params.result);
}
TEST_P(GlslBinaryTest, Emit_u32) {
auto params = GetParam();
@ -122,7 +150,7 @@ INSTANTIATE_TEST_SUITE_P(
BinaryData{"(left / right)", ast::BinaryOp::kDivide},
BinaryData{"(left % right)", ast::BinaryOp::kModulo}));
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar) {
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f32) {
auto* lhs = vec3<f32>(1_f, 1_f, 1_f);
auto* rhs = Expr(1_f);
@ -137,7 +165,24 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar) {
EXPECT_EQ(out.str(), "(vec3(1.0f) * 1.0f)");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector) {
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
Enable(ast::Extension::kF16);
auto* lhs = vec3<f16>(1_h, 1_h, 1_h);
auto* rhs = Expr(1_h);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(f16vec3(1.0hf) * 1.0hf)");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
auto* lhs = Expr(1_f);
auto* rhs = vec3<f32>(1_f, 1_f, 1_f);
@ -152,7 +197,24 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector) {
EXPECT_EQ(out.str(), "(1.0f * vec3(1.0f))");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixScalar) {
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
Enable(ast::Extension::kF16);
auto* lhs = Expr(1_h);
auto* rhs = vec3<f16>(1_h, 1_h, 1_h);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(1.0hf * f16vec3(1.0hf))");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixScalar_f32) {
GlobalVar("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
auto* lhs = Expr("mat");
auto* rhs = Expr(1_f);
@ -167,7 +229,24 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixScalar) {
EXPECT_EQ(out.str(), "(mat * 1.0f)");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarMatrix) {
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixScalar_f16) {
Enable(ast::Extension::kF16);
GlobalVar("mat", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
auto* lhs = Expr("mat");
auto* rhs = Expr(1_h);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(mat * 1.0hf)");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarMatrix_f32) {
GlobalVar("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
auto* lhs = Expr(1_f);
auto* rhs = Expr("mat");
@ -182,7 +261,24 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarMatrix) {
EXPECT_EQ(out.str(), "(1.0f * mat)");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixVector) {
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarMatrix_f16) {
Enable(ast::Extension::kF16);
GlobalVar("mat", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
auto* lhs = Expr(1_h);
auto* rhs = Expr("mat");
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(1.0hf * mat)");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixVector_f32) {
GlobalVar("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
auto* lhs = Expr("mat");
auto* rhs = vec3<f32>(1_f, 1_f, 1_f);
@ -197,7 +293,24 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixVector) {
EXPECT_EQ(out.str(), "(mat * vec3(1.0f))");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorMatrix) {
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixVector_f16) {
Enable(ast::Extension::kF16);
GlobalVar("mat", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
auto* lhs = Expr("mat");
auto* rhs = vec3<f16>(1_h, 1_h, 1_h);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(mat * f16vec3(1.0hf))");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorMatrix_f32) {
GlobalVar("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
auto* lhs = vec3<f32>(1_f, 1_f, 1_f);
auto* rhs = Expr("mat");
@ -212,7 +325,24 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorMatrix) {
EXPECT_EQ(out.str(), "(vec3(1.0f) * mat)");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixMatrix) {
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorMatrix_f16) {
Enable(ast::Extension::kF16);
GlobalVar("mat", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
auto* lhs = vec3<f16>(1_h, 1_h, 1_h);
auto* rhs = Expr("mat");
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(f16vec3(1.0hf) * mat)");
}
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixMatrix_f32) {
GlobalVar("lhs", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
GlobalVar("rhs", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
@ -226,28 +356,41 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixMatrix) {
EXPECT_EQ(out.str(), "(lhs * rhs)");
}
TEST_F(GlslGeneratorImplTest_Binary, Logical_And) {
GlobalVar("a", ty.bool_(), ast::StorageClass::kPrivate);
GlobalVar("b", ty.bool_(), ast::StorageClass::kPrivate);
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixMatrix_f16) {
Enable(ast::Extension::kF16);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, Expr("a"), Expr("b"));
GlobalVar("lhs", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
GlobalVar("rhs", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, Expr("lhs"), Expr("rhs"));
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(lhs * rhs)");
}
TEST_F(GlslGeneratorImplTest_Binary, ModF32) {
GlobalVar("a", ty.f32(), ast::StorageClass::kPrivate);
GlobalVar("b", ty.f32(), ast::StorageClass::kPrivate);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr("a"), Expr("b"));
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(tint_tmp)");
EXPECT_EQ(gen.result(), R"(bool tint_tmp = a;
if (tint_tmp) {
tint_tmp = b;
}
)");
EXPECT_EQ(out.str(), "tint_float_modulo(a, b)");
}
TEST_F(GlslGeneratorImplTest_Binary, ModF32) {
GlobalVar("a", ty.f32(), ast::StorageClass::kPrivate);
GlobalVar("b", ty.f32(), ast::StorageClass::kPrivate);
TEST_F(GlslGeneratorImplTest_Binary, ModF16) {
Enable(ast::Extension::kF16);
GlobalVar("a", ty.f16(), ast::StorageClass::kPrivate);
GlobalVar("b", ty.f16(), ast::StorageClass::kPrivate);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr("a"), Expr("b"));
WrapInFunction(expr);
@ -273,6 +416,22 @@ TEST_F(GlslGeneratorImplTest_Binary, ModVec3F32) {
EXPECT_EQ(out.str(), "tint_float_modulo(a, b)");
}
TEST_F(GlslGeneratorImplTest_Binary, ModVec3F16) {
Enable(ast::Extension::kF16);
GlobalVar("a", ty.vec3<f16>(), ast::StorageClass::kPrivate);
GlobalVar("b", ty.vec3<f16>(), ast::StorageClass::kPrivate);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr("a"), Expr("b"));
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "tint_float_modulo(a, b)");
}
TEST_F(GlslGeneratorImplTest_Binary, ModVec3F32ScalarF32) {
GlobalVar("a", ty.vec3<f32>(), ast::StorageClass::kPrivate);
GlobalVar("b", ty.f32(), ast::StorageClass::kPrivate);
@ -287,6 +446,22 @@ TEST_F(GlslGeneratorImplTest_Binary, ModVec3F32ScalarF32) {
EXPECT_EQ(out.str(), "tint_float_modulo(a, b)");
}
TEST_F(GlslGeneratorImplTest_Binary, ModVec3F16ScalarF16) {
Enable(ast::Extension::kF16);
GlobalVar("a", ty.vec3<f16>(), ast::StorageClass::kPrivate);
GlobalVar("b", ty.f16(), ast::StorageClass::kPrivate);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr("a"), Expr("b"));
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "tint_float_modulo(a, b)");
}
TEST_F(GlslGeneratorImplTest_Binary, ModScalarF32Vec3F32) {
GlobalVar("a", ty.f32(), ast::StorageClass::kPrivate);
GlobalVar("b", ty.vec3<f32>(), ast::StorageClass::kPrivate);
@ -301,6 +476,22 @@ TEST_F(GlslGeneratorImplTest_Binary, ModScalarF32Vec3F32) {
EXPECT_EQ(out.str(), "tint_float_modulo(a, b)");
}
TEST_F(GlslGeneratorImplTest_Binary, ModScalarF16Vec3F16) {
Enable(ast::Extension::kF16);
GlobalVar("a", ty.f16(), ast::StorageClass::kPrivate);
GlobalVar("b", ty.vec3<f16>(), ast::StorageClass::kPrivate);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr("a"), Expr("b"));
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "tint_float_modulo(a, b)");
}
TEST_F(GlslGeneratorImplTest_Binary, ModMixedVec3ScalarF32) {
GlobalVar("a", ty.vec3<f32>(), ast::StorageClass::kPrivate);
GlobalVar("b", ty.f32(), ast::StorageClass::kPrivate);
@ -343,6 +534,70 @@ void test_function() {
)");
}
TEST_F(GlslGeneratorImplTest_Binary, ModMixedVec3ScalarF16) {
Enable(ast::Extension::kF16);
GlobalVar("a", ty.vec3<f16>(), ast::StorageClass::kPrivate);
GlobalVar("b", ty.f16(), ast::StorageClass::kPrivate);
auto* expr_vec_mod_vec =
create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr("a"), Expr("a"));
auto* expr_vec_mod_scalar =
create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr("a"), Expr("b"));
auto* expr_scalar_mod_vec =
create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr("b"), Expr("a"));
WrapInFunction(expr_vec_mod_vec, expr_vec_mod_scalar, expr_scalar_mod_vec);
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#version 310 es
#extension GL_AMD_gpu_shader_half_float : require
f16vec3 tint_float_modulo(f16vec3 lhs, f16vec3 rhs) {
return (lhs - rhs * trunc(lhs / rhs));
}
f16vec3 tint_float_modulo_1(f16vec3 lhs, float16_t rhs) {
return (lhs - rhs * trunc(lhs / rhs));
}
f16vec3 tint_float_modulo_2(float16_t lhs, f16vec3 rhs) {
return (lhs - rhs * trunc(lhs / rhs));
}
f16vec3 a = f16vec3(0.0hf, 0.0hf, 0.0hf);
float16_t b = 0.0hf;
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
void test_function() {
f16vec3 tint_symbol = tint_float_modulo(a, a);
f16vec3 tint_symbol_1 = tint_float_modulo_1(a, b);
f16vec3 tint_symbol_2 = tint_float_modulo_2(b, a);
return;
}
)");
}
TEST_F(GlslGeneratorImplTest_Binary, Logical_And) {
GlobalVar("a", ty.bool_(), ast::StorageClass::kPrivate);
GlobalVar("b", ty.bool_(), ast::StorageClass::kPrivate);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, Expr("a"), Expr("b"));
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(tint_tmp)");
EXPECT_EQ(gen.result(), R"(bool tint_tmp = a;
if (tint_tmp) {
tint_tmp = b;
}
)");
}
TEST_F(GlslGeneratorImplTest_Binary, Logical_Multi) {
// (a && b) || (c || d)
GlobalVar("a", ty.bool_(), ast::StorageClass::kPrivate);

View File

@ -66,6 +66,38 @@ TEST_P(HlslBinaryTest, Emit_f32) {
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), params.result);
}
TEST_P(HlslBinaryTest, Emit_f16) {
auto params = GetParam();
if ((params.valid_for & BinaryData::Types::Float) == 0) {
return;
}
// Skip ops that are illegal for this type
if (params.op == ast::BinaryOp::kAnd || params.op == ast::BinaryOp::kOr ||
params.op == ast::BinaryOp::kXor || params.op == ast::BinaryOp::kShiftLeft ||
params.op == ast::BinaryOp::kShiftRight) {
return;
}
Enable(ast::Extension::kF16);
GlobalVar("left", ty.f16(), ast::StorageClass::kPrivate);
GlobalVar("right", ty.f16(), ast::StorageClass::kPrivate);
auto* left = Expr("left");
auto* right = Expr("right");
auto* expr = create<ast::BinaryExpression>(params.op, left, right);
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), params.result);
}
TEST_P(HlslBinaryTest, Emit_u32) {
auto params = GetParam();
@ -140,7 +172,7 @@ INSTANTIATE_TEST_SUITE_P(
BinaryData{"(left % right)", ast::BinaryOp::kModulo,
BinaryData::Types::Float}));
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar) {
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar_f32) {
auto* lhs = vec3<f32>(1_f, 1_f, 1_f);
auto* rhs = Expr(1_f);
@ -155,7 +187,24 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar) {
EXPECT_EQ(out.str(), "((1.0f).xxx * 1.0f)");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector) {
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
Enable(ast::Extension::kF16);
auto* lhs = vec3<f16>(1_h, 1_h, 1_h);
auto* rhs = Expr(1_h);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "((float16_t(1.0h)).xxx * float16_t(1.0h))");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
auto* lhs = Expr(1_f);
auto* rhs = vec3<f32>(1_f, 1_f, 1_f);
@ -170,7 +219,24 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector) {
EXPECT_EQ(out.str(), "(1.0f * (1.0f).xxx)");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixScalar) {
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
Enable(ast::Extension::kF16);
auto* lhs = Expr(1_h);
auto* rhs = vec3<f16>(1_h, 1_h, 1_h);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(float16_t(1.0h) * (float16_t(1.0h)).xxx)");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixScalar_f32) {
GlobalVar("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
auto* lhs = Expr("mat");
auto* rhs = Expr(1_f);
@ -185,7 +251,24 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixScalar) {
EXPECT_EQ(out.str(), "(mat * 1.0f)");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarMatrix) {
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixScalar_f16) {
Enable(ast::Extension::kF16);
GlobalVar("mat", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
auto* lhs = Expr("mat");
auto* rhs = Expr(1_h);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(mat * float16_t(1.0h))");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarMatrix_f32) {
GlobalVar("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
auto* lhs = Expr(1_f);
auto* rhs = Expr("mat");
@ -200,7 +283,24 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarMatrix) {
EXPECT_EQ(out.str(), "(1.0f * mat)");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixVector) {
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarMatrix_f16) {
Enable(ast::Extension::kF16);
GlobalVar("mat", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
auto* lhs = Expr(1_h);
auto* rhs = Expr("mat");
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "(float16_t(1.0h) * mat)");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixVector_f32) {
GlobalVar("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
auto* lhs = Expr("mat");
auto* rhs = vec3<f32>(1_f, 1_f, 1_f);
@ -215,7 +315,24 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixVector) {
EXPECT_EQ(out.str(), "mul((1.0f).xxx, mat)");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorMatrix) {
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixVector_f16) {
Enable(ast::Extension::kF16);
GlobalVar("mat", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
auto* lhs = Expr("mat");
auto* rhs = vec3<f16>(1_h, 1_h, 1_h);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "mul((float16_t(1.0h)).xxx, mat)");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorMatrix_f32) {
GlobalVar("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
auto* lhs = vec3<f32>(1_f, 1_f, 1_f);
auto* rhs = Expr("mat");
@ -230,7 +347,24 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorMatrix) {
EXPECT_EQ(out.str(), "mul(mat, (1.0f).xxx)");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixMatrix) {
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorMatrix_f16) {
Enable(ast::Extension::kF16);
GlobalVar("mat", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
auto* lhs = vec3<f16>(1_h, 1_h, 1_h);
auto* rhs = Expr("mat");
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "mul(mat, (float16_t(1.0h)).xxx)");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixMatrix_f32) {
GlobalVar("lhs", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
GlobalVar("rhs", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
@ -244,6 +378,22 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixMatrix) {
EXPECT_EQ(out.str(), "mul(rhs, lhs)");
}
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixMatrix_f16) {
Enable(ast::Extension::kF16);
GlobalVar("lhs", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
GlobalVar("rhs", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, Expr("lhs"), Expr("rhs"));
WrapInFunction(expr);
GeneratorImpl& gen = Build();
std::stringstream out;
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "mul(rhs, lhs)");
}
TEST_F(HlslGeneratorImplTest_Binary, Logical_And) {
GlobalVar("a", ty.bool_(), ast::StorageClass::kPrivate);
GlobalVar("b", ty.bool_(), ast::StorageClass::kPrivate);
@ -559,7 +709,7 @@ TEST_P(HlslGeneratorDivModTest, DivOrModByLiteralZero_u32) {
R"( 1u);
}
)");
} // namespace HlslGeneratorDivMod
}
TEST_P(HlslGeneratorDivModTest, DivOrModByLiteralZero_vec_by_vec_i32) {
Func("fn", {}, ty.void_(),
@ -577,7 +727,7 @@ TEST_P(HlslGeneratorDivModTest, DivOrModByLiteralZero_vec_by_vec_i32) {
R"( int4(50, 1, 25, 1));
}
)");
} // namespace
}
TEST_P(HlslGeneratorDivModTest, DivOrModByLiteralZero_vec_by_scalar_i32) {
Func("fn", {}, ty.void_(),
@ -595,7 +745,7 @@ TEST_P(HlslGeneratorDivModTest, DivOrModByLiteralZero_vec_by_scalar_i32) {
R"( 1);
}
)");
} // namespace hlsl
}
TEST_P(HlslGeneratorDivModTest, DivOrModByIdentifier_i32) {
Func("fn", {Param("b", ty.i32())}, ty.void_(),
@ -613,7 +763,7 @@ TEST_P(HlslGeneratorDivModTest, DivOrModByIdentifier_i32) {
R"( (b == 0 ? 1 : b));
}
)");
} // namespace writer
}
TEST_P(HlslGeneratorDivModTest, DivOrModByIdentifier_u32) {
Func("fn", {Param("b", ty.u32())}, ty.void_(),
@ -631,7 +781,7 @@ TEST_P(HlslGeneratorDivModTest, DivOrModByIdentifier_u32) {
R"( (b == 0u ? 1u : b));
}
)");
} // namespace tint
}
TEST_P(HlslGeneratorDivModTest, DivOrModByIdentifier_vec_by_vec_i32) {
Func("fn", {Param("b", ty.vec3<i32>())}, ty.void_(),

View File

@ -158,6 +158,21 @@ TEST_F(MslBinaryTest, ModF32) {
EXPECT_EQ(out.str(), "fmod(left, right)");
}
TEST_F(MslBinaryTest, ModF16) {
Enable(ast::Extension::kF16);
auto* left = Var("left", ty.f16());
auto* right = Var("right", ty.f16());
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr(left), Expr(right));
WrapInFunction(left, right, expr);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "fmod(left, right)");
}
TEST_F(MslBinaryTest, ModVec3F32) {
auto* left = Var("left", ty.vec3<f32>());
auto* right = Var("right", ty.vec3<f32>());
@ -171,6 +186,21 @@ TEST_F(MslBinaryTest, ModVec3F32) {
EXPECT_EQ(out.str(), "fmod(left, right)");
}
TEST_F(MslBinaryTest, ModVec3F16) {
Enable(ast::Extension::kF16);
auto* left = Var("left", ty.vec3<f16>());
auto* right = Var("right", ty.vec3<f16>());
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr(left), Expr(right));
WrapInFunction(left, right, expr);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
EXPECT_EQ(out.str(), "fmod(left, right)");
}
TEST_F(MslBinaryTest, BoolAnd) {
auto* left = Var("left", nullptr, Expr(true));
auto* right = Var("right", nullptr, Expr(false));

View File

@ -191,8 +191,8 @@ INSTANTIATE_TEST_SUITE_P(
BinaryData{ast::BinaryOp::kSubtract, "OpISub"},
BinaryData{ast::BinaryOp::kXor, "OpBitwiseXor"}));
using BinaryArithFloatTest = TestParamHelper<BinaryData>;
TEST_P(BinaryArithFloatTest, Scalar) {
using BinaryArithF32Test = TestParamHelper<BinaryData>;
TEST_P(BinaryArithF32Test, Scalar) {
auto param = GetParam();
auto* lhs = Expr(3.2_f);
@ -215,7 +215,7 @@ TEST_P(BinaryArithFloatTest, Scalar) {
"%4 = " + param.name + " %1 %2 %3\n");
}
TEST_P(BinaryArithFloatTest, Vector) {
TEST_P(BinaryArithF32Test, Vector) {
auto param = GetParam();
auto* lhs = vec3<f32>(1_f, 1_f, 1_f);
@ -239,7 +239,66 @@ TEST_P(BinaryArithFloatTest, Vector) {
"%5 = " + param.name + " %1 %4 %4\n");
}
INSTANTIATE_TEST_SUITE_P(BuilderTest,
BinaryArithFloatTest,
BinaryArithF32Test,
testing::Values(BinaryData{ast::BinaryOp::kAdd, "OpFAdd"},
BinaryData{ast::BinaryOp::kDivide, "OpFDiv"},
BinaryData{ast::BinaryOp::kModulo, "OpFRem"},
BinaryData{ast::BinaryOp::kMultiply, "OpFMul"},
BinaryData{ast::BinaryOp::kSubtract, "OpFSub"}));
using BinaryArithF16Test = TestParamHelper<BinaryData>;

// Arithmetic on two f16 scalars: expects the parameterized float opcode with
// an OpTypeFloat 16 result. Note the disassembler prints f16 constants in
// hex-float form (3.2h -> 0x1.998p+1, 4.5h -> 0x1.2p+2).
TEST_P(BinaryArithF16Test, Scalar) {
    // f16 types are only valid once the f16 extension is enabled.
    Enable(ast::Extension::kF16);

    auto param = GetParam();

    auto* lhs = Expr(3.2_h);
    auto* rhs = Expr(4.5_h);

    auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
    WrapInFunction(expr);

    spirv::Builder& b = Build();
    b.push_function(Function{});

    EXPECT_EQ(b.GenerateBinaryExpression(expr), 4u) << b.error();
    EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 16
%2 = OpConstant %1 0x1.998p+1
%3 = OpConstant %1 0x1.2p+2
)");
    EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
              "%4 = " + param.name + " %1 %2 %3\n");
}
// Arithmetic on two vec3<f16> values: the result type (%1) is a 3-component
// vector of OpTypeFloat 16, and both operands reuse the same splat constant.
TEST_P(BinaryArithF16Test, Vector) {
    // f16 types are only valid once the f16 extension is enabled.
    Enable(ast::Extension::kF16);

    auto param = GetParam();

    auto* lhs = vec3<f16>(1_h, 1_h, 1_h);
    auto* rhs = vec3<f16>(1_h, 1_h, 1_h);

    auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
    WrapInFunction(expr);

    spirv::Builder& b = Build();
    b.push_function(Function{});

    EXPECT_EQ(b.GenerateBinaryExpression(expr), 5u) << b.error();
    EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 16
%1 = OpTypeVector %2 3
%3 = OpConstant %2 0x1p+0
%4 = OpConstantComposite %1 %3 %3 %3
)");
    EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
              "%5 = " + param.name + " %1 %4 %4\n");
}
INSTANTIATE_TEST_SUITE_P(BuilderTest,
BinaryArithF16Test,
testing::Values(BinaryData{ast::BinaryOp::kAdd, "OpFAdd"},
BinaryData{ast::BinaryOp::kDivide, "OpFDiv"},
BinaryData{ast::BinaryOp::kModulo, "OpFRem"},
@ -422,8 +481,8 @@ INSTANTIATE_TEST_SUITE_P(
BinaryData{ast::BinaryOp::kLessThanEqual, "OpSLessThanEqual"},
BinaryData{ast::BinaryOp::kNotEqual, "OpINotEqual"}));
using BinaryCompareFloatTest = TestParamHelper<BinaryData>;
TEST_P(BinaryCompareFloatTest, Scalar) {
using BinaryCompareF32Test = TestParamHelper<BinaryData>;
TEST_P(BinaryCompareF32Test, Scalar) {
auto param = GetParam();
auto* lhs = Expr(3.2_f);
@ -447,7 +506,7 @@ TEST_P(BinaryCompareFloatTest, Scalar) {
"%4 = " + param.name + " %5 %2 %3\n");
}
TEST_P(BinaryCompareFloatTest, Vector) {
TEST_P(BinaryCompareF32Test, Vector) {
auto param = GetParam();
auto* lhs = vec3<f32>(1_f, 1_f, 1_f);
@ -474,7 +533,7 @@ TEST_P(BinaryCompareFloatTest, Vector) {
}
INSTANTIATE_TEST_SUITE_P(
BuilderTest,
BinaryCompareFloatTest,
BinaryCompareF32Test,
testing::Values(BinaryData{ast::BinaryOp::kEqual, "OpFOrdEqual"},
BinaryData{ast::BinaryOp::kGreaterThan, "OpFOrdGreaterThan"},
BinaryData{ast::BinaryOp::kGreaterThanEqual, "OpFOrdGreaterThanEqual"},
@ -482,7 +541,71 @@ INSTANTIATE_TEST_SUITE_P(
BinaryData{ast::BinaryOp::kLessThanEqual, "OpFOrdLessThanEqual"},
BinaryData{ast::BinaryOp::kNotEqual, "OpFOrdNotEqual"}));
TEST_F(BuilderTest, Binary_Multiply_VectorScalar) {
using BinaryCompareF16Test = TestParamHelper<BinaryData>;

// Comparison of two f16 scalars: operands are f16 but the result type (%5)
// is OpTypeBool.
TEST_P(BinaryCompareF16Test, Scalar) {
    // f16 types are only valid once the f16 extension is enabled.
    Enable(ast::Extension::kF16);

    auto param = GetParam();

    auto* lhs = Expr(3.2_h);
    auto* rhs = Expr(4.5_h);

    auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
    WrapInFunction(expr);

    spirv::Builder& b = Build();
    b.push_function(Function{});

    EXPECT_EQ(b.GenerateBinaryExpression(expr), 4u) << b.error();
    EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 16
%2 = OpConstant %1 0x1.998p+1
%3 = OpConstant %1 0x1.2p+2
%5 = OpTypeBool
)");
    EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
              "%4 = " + param.name + " %5 %2 %3\n");
}
// Component-wise comparison of two vec3<f16> values: the result type (%6) is
// a vec3 of bool (%7).
TEST_P(BinaryCompareF16Test, Vector) {
    // f16 types are only valid once the f16 extension is enabled.
    Enable(ast::Extension::kF16);

    auto param = GetParam();

    auto* lhs = vec3<f16>(1_h, 1_h, 1_h);
    auto* rhs = vec3<f16>(1_h, 1_h, 1_h);

    auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
    WrapInFunction(expr);

    spirv::Builder& b = Build();
    b.push_function(Function{});

    EXPECT_EQ(b.GenerateBinaryExpression(expr), 5u) << b.error();
    EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 16
%1 = OpTypeVector %2 3
%3 = OpConstant %2 0x1p+0
%4 = OpConstantComposite %1 %3 %3 %3
%7 = OpTypeBool
%6 = OpTypeVector %7 3
)");
    EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
              "%5 = " + param.name + " %6 %4 %4\n");
}
// f16 comparisons use the same ordered-float (FOrd*) opcodes as f32.
INSTANTIATE_TEST_SUITE_P(
    BuilderTest,
    BinaryCompareF16Test,
    testing::Values(BinaryData{ast::BinaryOp::kEqual, "OpFOrdEqual"},
                    BinaryData{ast::BinaryOp::kGreaterThan, "OpFOrdGreaterThan"},
                    BinaryData{ast::BinaryOp::kGreaterThanEqual, "OpFOrdGreaterThanEqual"},
                    BinaryData{ast::BinaryOp::kLessThan, "OpFOrdLessThan"},
                    BinaryData{ast::BinaryOp::kLessThanEqual, "OpFOrdLessThanEqual"},
                    BinaryData{ast::BinaryOp::kNotEqual, "OpFOrdNotEqual"}));
TEST_F(BuilderTest, Binary_Multiply_VectorScalar_F32) {
auto* lhs = vec3<f32>(1_f, 1_f, 1_f);
auto* rhs = Expr(1_f);
@ -505,7 +628,32 @@ TEST_F(BuilderTest, Binary_Multiply_VectorScalar) {
"%5 = OpVectorTimesScalar %1 %4 %3\n");
}
TEST_F(BuilderTest, Binary_Multiply_ScalarVector) {
// vec3<f16> * f16 should use the dedicated OpVectorTimesScalar instruction
// rather than splatting the scalar into a vector.
TEST_F(BuilderTest, Binary_Multiply_VectorScalar_F16) {
    // f16 types are only valid once the f16 extension is enabled.
    Enable(ast::Extension::kF16);

    auto* lhs = vec3<f16>(1_h, 1_h, 1_h);
    auto* rhs = Expr(1_h);

    auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
    WrapInFunction(expr);

    spirv::Builder& b = Build();
    b.push_function(Function{});

    EXPECT_EQ(b.GenerateBinaryExpression(expr), 5u) << b.error();
    EXPECT_EQ(DumpInstructions(b.types()),
              R"(%2 = OpTypeFloat 16
%1 = OpTypeVector %2 3
%3 = OpConstant %2 0x1p+0
%4 = OpConstantComposite %1 %3 %3 %3
)");
    EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
              "%5 = OpVectorTimesScalar %1 %4 %3\n");
}
TEST_F(BuilderTest, Binary_Multiply_ScalarVector_F32) {
auto* lhs = Expr(1_f);
auto* rhs = vec3<f32>(1_f, 1_f, 1_f);
@ -528,7 +676,32 @@ TEST_F(BuilderTest, Binary_Multiply_ScalarVector) {
"%5 = OpVectorTimesScalar %3 %4 %2\n");
}
TEST_F(BuilderTest, Binary_Multiply_MatrixScalar) {
// f16 * vec3<f16>: OpVectorTimesScalar takes (vector, scalar), so the
// generated instruction has the operands swapped (%4 is the vector, %2 the
// scalar).
TEST_F(BuilderTest, Binary_Multiply_ScalarVector_F16) {
    // f16 types are only valid once the f16 extension is enabled.
    Enable(ast::Extension::kF16);

    auto* lhs = Expr(1_h);
    auto* rhs = vec3<f16>(1_h, 1_h, 1_h);

    auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
    WrapInFunction(expr);

    spirv::Builder& b = Build();
    b.push_function(Function{});

    EXPECT_EQ(b.GenerateBinaryExpression(expr), 5u) << b.error();
    EXPECT_EQ(DumpInstructions(b.types()),
              R"(%1 = OpTypeFloat 16
%2 = OpConstant %1 0x1p+0
%3 = OpTypeVector %1 3
%4 = OpConstantComposite %3 %2 %2 %2
)");
    EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
              "%5 = OpVectorTimesScalar %3 %4 %2\n");
}
TEST_F(BuilderTest, Binary_Multiply_MatrixScalar_F32) {
auto* var = Var("mat", ty.mat3x3<f32>());
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, Expr("mat"), Expr(1_f));
@ -555,7 +728,36 @@ TEST_F(BuilderTest, Binary_Multiply_MatrixScalar) {
)");
}
TEST_F(BuilderTest, Binary_Multiply_ScalarMatrix) {
// mat3x3<f16> * f16 should emit OpMatrixTimesScalar on the loaded matrix.
TEST_F(BuilderTest, Binary_Multiply_MatrixScalar_F16) {
    // f16 types are only valid once the f16 extension is enabled.
    Enable(ast::Extension::kF16);

    auto* var = Var("mat", ty.mat3x3<f16>());
    auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, Expr("mat"), Expr(1_h));
    WrapInFunction(var, expr);

    spirv::Builder& b = Build();
    b.push_function(Function{});

    ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
    EXPECT_EQ(b.GenerateBinaryExpression(expr), 8u) << b.error();
    EXPECT_EQ(DumpInstructions(b.types()),
              R"(%5 = OpTypeFloat 16
%4 = OpTypeVector %5 3
%3 = OpTypeMatrix %4 3
%2 = OpTypePointer Function %3
%1 = OpVariable %2 Function
%7 = OpConstant %5 0x1p+0
)");
    EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
              R"(%6 = OpLoad %3 %1
%8 = OpMatrixTimesScalar %3 %6 %7
)");
}
TEST_F(BuilderTest, Binary_Multiply_ScalarMatrix_F32) {
auto* var = Var("mat", ty.mat3x3<f32>());
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, Expr(1_f), Expr("mat"));
@ -582,7 +784,36 @@ TEST_F(BuilderTest, Binary_Multiply_ScalarMatrix) {
)");
}
TEST_F(BuilderTest, Binary_Multiply_MatrixVector) {
// f16 * mat3x3<f16>: OpMatrixTimesScalar takes (matrix, scalar), so the
// operands appear swapped in the emitted instruction (%7 load, %6 scalar).
TEST_F(BuilderTest, Binary_Multiply_ScalarMatrix_F16) {
    // f16 types are only valid once the f16 extension is enabled.
    Enable(ast::Extension::kF16);

    auto* var = Var("mat", ty.mat3x3<f16>());
    auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, Expr(1_h), Expr("mat"));
    WrapInFunction(var, expr);

    spirv::Builder& b = Build();
    b.push_function(Function{});

    ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
    EXPECT_EQ(b.GenerateBinaryExpression(expr), 8u) << b.error();
    EXPECT_EQ(DumpInstructions(b.types()),
              R"(%5 = OpTypeFloat 16
%4 = OpTypeVector %5 3
%3 = OpTypeMatrix %4 3
%2 = OpTypePointer Function %3
%1 = OpVariable %2 Function
%6 = OpConstant %5 0x1p+0
)");
    EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
              R"(%7 = OpLoad %3 %1
%8 = OpMatrixTimesScalar %3 %7 %6
)");
}
TEST_F(BuilderTest, Binary_Multiply_MatrixVector_F32) {
auto* var = Var("mat", ty.mat3x3<f32>());
auto* rhs = vec3<f32>(1_f, 1_f, 1_f);
@ -611,7 +842,38 @@ TEST_F(BuilderTest, Binary_Multiply_MatrixVector) {
)");
}
TEST_F(BuilderTest, Binary_Multiply_VectorMatrix) {
// mat3x3<f16> * vec3<f16> should emit OpMatrixTimesVector; the result type
// (%4) is the column vector type.
TEST_F(BuilderTest, Binary_Multiply_MatrixVector_F16) {
    // f16 types are only valid once the f16 extension is enabled.
    Enable(ast::Extension::kF16);

    auto* var = Var("mat", ty.mat3x3<f16>());
    auto* rhs = vec3<f16>(1_h, 1_h, 1_h);

    auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, Expr("mat"), rhs);
    WrapInFunction(var, expr);

    spirv::Builder& b = Build();
    b.push_function(Function{});

    ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
    EXPECT_EQ(b.GenerateBinaryExpression(expr), 9u) << b.error();
    EXPECT_EQ(DumpInstructions(b.types()),
              R"(%5 = OpTypeFloat 16
%4 = OpTypeVector %5 3
%3 = OpTypeMatrix %4 3
%2 = OpTypePointer Function %3
%1 = OpVariable %2 Function
%7 = OpConstant %5 0x1p+0
%8 = OpConstantComposite %4 %7 %7 %7
)");
    EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
              R"(%6 = OpLoad %3 %1
%9 = OpMatrixTimesVector %4 %6 %8
)");
}
TEST_F(BuilderTest, Binary_Multiply_VectorMatrix_F32) {
auto* var = Var("mat", ty.mat3x3<f32>());
auto* lhs = vec3<f32>(1_f, 1_f, 1_f);
@ -640,7 +902,38 @@ TEST_F(BuilderTest, Binary_Multiply_VectorMatrix) {
)");
}
TEST_F(BuilderTest, Binary_Multiply_MatrixMatrix) {
// vec3<f16> * mat3x3<f16> should emit OpVectorTimesMatrix (vector operand
// first, then the loaded matrix).
TEST_F(BuilderTest, Binary_Multiply_VectorMatrix_F16) {
    // f16 types are only valid once the f16 extension is enabled.
    Enable(ast::Extension::kF16);

    auto* var = Var("mat", ty.mat3x3<f16>());
    auto* lhs = vec3<f16>(1_h, 1_h, 1_h);

    auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, Expr("mat"));
    WrapInFunction(var, expr);

    spirv::Builder& b = Build();
    b.push_function(Function{});

    ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
    EXPECT_EQ(b.GenerateBinaryExpression(expr), 9u) << b.error();
    EXPECT_EQ(DumpInstructions(b.types()),
              R"(%5 = OpTypeFloat 16
%4 = OpTypeVector %5 3
%3 = OpTypeMatrix %4 3
%2 = OpTypePointer Function %3
%1 = OpVariable %2 Function
%6 = OpConstant %5 0x1p+0
%7 = OpConstantComposite %4 %6 %6 %6
)");
    EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
              R"(%8 = OpLoad %3 %1
%9 = OpVectorTimesMatrix %4 %7 %8
)");
}
TEST_F(BuilderTest, Binary_Multiply_MatrixMatrix_F32) {
auto* var = Var("mat", ty.mat3x3<f32>());
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, Expr("mat"), Expr("mat"));
@ -667,6 +960,35 @@ TEST_F(BuilderTest, Binary_Multiply_MatrixMatrix) {
)");
}
// mat3x3<f16> * mat3x3<f16> should emit OpMatrixTimesMatrix; the variable is
// loaded once per operand (%6 and %7).
TEST_F(BuilderTest, Binary_Multiply_MatrixMatrix_F16) {
    // f16 types are only valid once the f16 extension is enabled.
    Enable(ast::Extension::kF16);

    auto* var = Var("mat", ty.mat3x3<f16>());
    auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, Expr("mat"), Expr("mat"));
    WrapInFunction(var, expr);

    spirv::Builder& b = Build();
    b.push_function(Function{});

    ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
    EXPECT_EQ(b.GenerateBinaryExpression(expr), 8u) << b.error();
    EXPECT_EQ(DumpInstructions(b.types()),
              R"(%5 = OpTypeFloat 16
%4 = OpTypeVector %5 3
%3 = OpTypeMatrix %4 3
%2 = OpTypePointer Function %3
%1 = OpVariable %2 Function
)");
    EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
              R"(%6 = OpLoad %3 %1
%7 = OpLoad %3 %1
%8 = OpMatrixTimesMatrix %3 %6 %7
)");
}
TEST_F(BuilderTest, Binary_LogicalAnd) {
auto* lhs = create<ast::BinaryExpression>(ast::BinaryOp::kEqual, Expr(1_i), Expr(2_i));
@ -895,11 +1217,13 @@ OpBranch %9
namespace BinaryArithVectorScalar {
enum class Type { f32, i32, u32 };
enum class Type { f32, f16, i32, u32 };
static const ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) {
switch (type) {
case Type::f32:
return builder->vec3<f32>(1_f, 1_f, 1_f);
case Type::f16:
return builder->vec3<f16>(1_h, 1_h, 1_h);
case Type::i32:
return builder->vec3<i32>(1_i, 1_i, 1_i);
case Type::u32:
@ -911,6 +1235,8 @@ static const ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type)
switch (type) {
case Type::f32:
return builder->Expr(1_f);
case Type::f16:
return builder->Expr(1_h);
case Type::i32:
return builder->Expr(1_i);
case Type::u32:
@ -922,6 +1248,8 @@ static std::string OpTypeDecl(Type type) {
switch (type) {
case Type::f32:
return "OpTypeFloat 32";
case Type::f16:
return "OpTypeFloat 16";
case Type::i32:
return "OpTypeInt 32 1";
case Type::u32:
@ -929,6 +1257,32 @@ static std::string OpTypeDecl(Type type) {
}
return {};
}
// Returns how the disassembler spells the constant `1` for the given element
// type: the 32-bit types print as "1", while f16 constants are printed in
// hex-float form ("0x1p+0").
static std::string ConstantValue(Type type) {
    if (type == Type::f16) {
        return "0x1p+0";
    }
    if (type == Type::f32 || type == Type::i32 || type == Type::u32) {
        return "1";
    }
    return {};
}
// Returns the expected OpCapability declarations for a module using the given
// element type. f16 needs the 16-bit capabilities on top of the base Shader
// capability; the 32-bit types need only Shader.
static std::string CapabilityDecl(Type type) {
    if (type == Type::f16) {
        return R"(OpCapability Shader
OpCapability Float16
OpCapability UniformAndStorageBuffer16BitAccess
OpCapability StorageBuffer16BitAccess
OpCapability StorageInputOutput16)";
    }
    if (type == Type::f32 || type == Type::i32 || type == Type::u32) {
        return "OpCapability Shader";
    }
    return {};
}
struct Param {
Type type;
@ -940,9 +1294,15 @@ using BinaryArithVectorScalarTest = TestParamHelper<Param>;
TEST_P(BinaryArithVectorScalarTest, VectorScalar) {
auto& param = GetParam();
if (param.type == Type::f16) {
Enable(ast::Extension::kF16);
}
const ast::Expression* lhs = MakeVectorExpr(this, param.type);
const ast::Expression* rhs = MakeScalarExpr(this, param.type);
std::string op_type_decl = OpTypeDecl(param.type);
std::string constant_value = ConstantValue(param.type);
std::string capability_decl = CapabilityDecl(param.type);
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
@ -951,7 +1311,7 @@ TEST_P(BinaryArithVectorScalarTest, VectorScalar) {
spirv::Builder& b = Build();
ASSERT_TRUE(b.Build()) << b.error();
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
EXPECT_EQ(DumpBuilder(b), capability_decl + R"(
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %3 "test_function"
OpExecutionMode %3 LocalSize 1 1 1
@ -960,7 +1320,8 @@ OpName %3 "test_function"
%1 = OpTypeFunction %2
%6 = )" + op_type_decl + R"(
%5 = OpTypeVector %6 3
%7 = OpConstant %6 1
%7 = OpConstant %6 )" + constant_value +
R"(
%8 = OpConstantComposite %5 %7 %7 %7
%11 = OpTypePointer Function %5
%12 = OpConstantNull %5
@ -978,9 +1339,15 @@ OpFunctionEnd
TEST_P(BinaryArithVectorScalarTest, ScalarVector) {
auto& param = GetParam();
if (param.type == Type::f16) {
Enable(ast::Extension::kF16);
}
const ast::Expression* lhs = MakeScalarExpr(this, param.type);
const ast::Expression* rhs = MakeVectorExpr(this, param.type);
std::string op_type_decl = OpTypeDecl(param.type);
std::string constant_value = ConstantValue(param.type);
std::string capability_decl = CapabilityDecl(param.type);
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
@ -989,7 +1356,7 @@ TEST_P(BinaryArithVectorScalarTest, ScalarVector) {
spirv::Builder& b = Build();
ASSERT_TRUE(b.Build()) << b.error();
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
EXPECT_EQ(DumpBuilder(b), capability_decl + R"(
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %3 "test_function"
OpExecutionMode %3 LocalSize 1 1 1
@ -997,7 +1364,8 @@ OpName %3 "test_function"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%5 = )" + op_type_decl + R"(
%6 = OpConstant %5 1
%6 = OpConstant %5 )" + constant_value +
R"(
%7 = OpTypeVector %5 3
%8 = OpConstantComposite %7 %6 %6 %6
%11 = OpTypePointer Function %7
@ -1024,6 +1392,10 @@ INSTANTIATE_TEST_SUITE_P(BuilderTest,
// Param{Type::i32, ast::BinaryOp::kMultiply, "OpIMul"},
Param{Type::f32, ast::BinaryOp::kSubtract, "OpFSub"},
Param{Type::f16, ast::BinaryOp::kAdd, "OpFAdd"},
Param{Type::f16, ast::BinaryOp::kDivide, "OpFDiv"},
Param{Type::f16, ast::BinaryOp::kSubtract, "OpFSub"},
Param{Type::i32, ast::BinaryOp::kAdd, "OpIAdd"},
Param{Type::i32, ast::BinaryOp::kDivide, "OpSDiv"},
Param{Type::i32, ast::BinaryOp::kModulo, "OpSMod"},
@ -1040,9 +1412,15 @@ using BinaryArithVectorScalarMultiplyTest = TestParamHelper<Param>;
TEST_P(BinaryArithVectorScalarMultiplyTest, VectorScalar) {
auto& param = GetParam();
if (param.type == Type::f16) {
Enable(ast::Extension::kF16);
}
const ast::Expression* lhs = MakeVectorExpr(this, param.type);
const ast::Expression* rhs = MakeScalarExpr(this, param.type);
std::string op_type_decl = OpTypeDecl(param.type);
std::string constant_value = ConstantValue(param.type);
std::string capability_decl = CapabilityDecl(param.type);
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
@ -1051,7 +1429,7 @@ TEST_P(BinaryArithVectorScalarMultiplyTest, VectorScalar) {
spirv::Builder& b = Build();
ASSERT_TRUE(b.Build()) << b.error();
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
EXPECT_EQ(DumpBuilder(b), capability_decl + R"(
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %3 "test_function"
OpExecutionMode %3 LocalSize 1 1 1
@ -1060,7 +1438,8 @@ OpName %3 "test_function"
%1 = OpTypeFunction %2
%6 = )" + op_type_decl + R"(
%5 = OpTypeVector %6 3
%7 = OpConstant %6 1
%7 = OpConstant %6 )" + constant_value +
R"(
%8 = OpConstantComposite %5 %7 %7 %7
%3 = OpFunction %2 None %1
%4 = OpLabel
@ -1074,9 +1453,15 @@ OpFunctionEnd
TEST_P(BinaryArithVectorScalarMultiplyTest, ScalarVector) {
auto& param = GetParam();
if (param.type == Type::f16) {
Enable(ast::Extension::kF16);
}
const ast::Expression* lhs = MakeScalarExpr(this, param.type);
const ast::Expression* rhs = MakeVectorExpr(this, param.type);
std::string op_type_decl = OpTypeDecl(param.type);
std::string constant_value = ConstantValue(param.type);
std::string capability_decl = CapabilityDecl(param.type);
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
@ -1085,7 +1470,7 @@ TEST_P(BinaryArithVectorScalarMultiplyTest, ScalarVector) {
spirv::Builder& b = Build();
ASSERT_TRUE(b.Build()) << b.error();
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
EXPECT_EQ(DumpBuilder(b), capability_decl + R"(
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %3 "test_function"
OpExecutionMode %3 LocalSize 1 1 1
@ -1093,7 +1478,8 @@ OpName %3 "test_function"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%5 = )" + op_type_decl + R"(
%6 = OpConstant %5 1
%6 = OpConstant %5 )" + constant_value +
R"(
%7 = OpTypeVector %5 3
%8 = OpConstantComposite %7 %6 %6 %6
%3 = OpFunction %2 None %1
@ -1107,13 +1493,57 @@ OpFunctionEnd
}
INSTANTIATE_TEST_SUITE_P(BuilderTest,
BinaryArithVectorScalarMultiplyTest,
testing::Values(Param{Type::f32, ast::BinaryOp::kMultiply, "OpFMul"}));
testing::Values(Param{Type::f32, ast::BinaryOp::kMultiply, "OpFMul"},
Param{Type::f16, ast::BinaryOp::kMultiply, "OpFMul"}));
} // namespace BinaryArithVectorScalar
namespace BinaryArithMatrixMatrix {
enum class Type { f32, f16 };
// Builds a zero-valued mat3x4 AST expression with the requested element type.
static const ast::Expression* MakeMat3x4Expr(ProgramBuilder* builder, Type type) {
    if (type == Type::f32) {
        return builder->mat3x4<f32>();
    }
    if (type == Type::f16) {
        return builder->mat3x4<f16>();
    }
    return nullptr;
}
// Builds a zero-valued mat4x3 AST expression with the requested element type.
static const ast::Expression* MakeMat4x3Expr(ProgramBuilder* builder, Type type) {
    if (type == Type::f32) {
        return builder->mat4x3<f32>();
    }
    if (type == Type::f16) {
        return builder->mat4x3<f16>();
    }
    return nullptr;
}
// Returns the SPIR-V scalar type declaration text matching the element type.
static std::string OpTypeDecl(Type type) {
    if (type == Type::f16) {
        return "OpTypeFloat 16";
    }
    if (type == Type::f32) {
        return "OpTypeFloat 32";
    }
    return {};
}
// Returns the expected OpCapability declarations for the given element type:
// f16 requires the 16-bit capabilities in addition to the base Shader
// capability.
static std::string CapabilityDecl(Type type) {
    if (type == Type::f16) {
        return R"(OpCapability Shader
OpCapability Float16
OpCapability UniformAndStorageBuffer16BitAccess
OpCapability StorageBuffer16BitAccess
OpCapability StorageInputOutput16)";
    }
    if (type == Type::f32) {
        return "OpCapability Shader";
    }
    return {};
}
// One matrix-arithmetic test case.
struct Param {
    Type type;         // element type of the matrices under test
    ast::BinaryOp op;  // binary operator to generate
    std::string name;  // expected SPIR-V opcode name (the multiply suite passes "" here)
};
@ -1122,8 +1552,14 @@ using BinaryArithMatrixMatrix = TestParamHelper<Param>;
TEST_P(BinaryArithMatrixMatrix, AddOrSubtract) {
auto& param = GetParam();
const ast::Expression* lhs = mat3x4<f32>();
const ast::Expression* rhs = mat3x4<f32>();
if (param.type == Type::f16) {
Enable(ast::Extension::kF16);
}
const ast::Expression* lhs = MakeMat3x4Expr(this, param.type);
const ast::Expression* rhs = MakeMat3x4Expr(this, param.type);
std::string op_type_decl = OpTypeDecl(param.type);
std::string capability_decl = CapabilityDecl(param.type);
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
@ -1132,14 +1568,14 @@ TEST_P(BinaryArithMatrixMatrix, AddOrSubtract) {
spirv::Builder& b = Build();
ASSERT_TRUE(b.Build()) << b.error();
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
EXPECT_EQ(DumpBuilder(b), capability_decl + R"(
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %3 "test_function"
OpExecutionMode %3 LocalSize 1 1 1
OpName %3 "test_function"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%7 = OpTypeFloat 32
%7 = )" + op_type_decl + R"(
%6 = OpTypeVector %7 4
%5 = OpTypeMatrix %6 3
%8 = OpConstantNull %5
@ -1164,15 +1600,23 @@ OpFunctionEnd
INSTANTIATE_TEST_SUITE_P( //
BuilderTest,
BinaryArithMatrixMatrix,
testing::Values(Param{ast::BinaryOp::kAdd, "OpFAdd"},
Param{ast::BinaryOp::kSubtract, "OpFSub"}));
testing::Values(Param{Type::f32, ast::BinaryOp::kAdd, "OpFAdd"},
Param{Type::f32, ast::BinaryOp::kSubtract, "OpFSub"},
Param{Type::f16, ast::BinaryOp::kAdd, "OpFAdd"},
Param{Type::f16, ast::BinaryOp::kSubtract, "OpFSub"}));
using BinaryArithMatrixMatrixMultiply = TestParamHelper<Param>;
TEST_P(BinaryArithMatrixMatrixMultiply, Multiply) {
auto& param = GetParam();
const ast::Expression* lhs = mat3x4<f32>();
const ast::Expression* rhs = mat4x3<f32>();
if (param.type == Type::f16) {
Enable(ast::Extension::kF16);
}
const ast::Expression* lhs = MakeMat3x4Expr(this, param.type);
const ast::Expression* rhs = MakeMat4x3Expr(this, param.type);
std::string op_type_decl = OpTypeDecl(param.type);
std::string capability_decl = CapabilityDecl(param.type);
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
@ -1181,14 +1625,14 @@ TEST_P(BinaryArithMatrixMatrixMultiply, Multiply) {
spirv::Builder& b = Build();
ASSERT_TRUE(b.Build()) << b.error();
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
EXPECT_EQ(DumpBuilder(b), capability_decl + R"(
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %3 "test_function"
OpExecutionMode %3 LocalSize 1 1 1
OpName %3 "test_function"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%7 = OpTypeFloat 32
%7 = )" + op_type_decl + R"(
%6 = OpTypeVector %7 4
%5 = OpTypeMatrix %6 3
%8 = OpConstantNull %5
@ -1208,7 +1652,8 @@ OpFunctionEnd
INSTANTIATE_TEST_SUITE_P( //
BuilderTest,
BinaryArithMatrixMatrixMultiply,
testing::Values(Param{ast::BinaryOp::kMultiply, "OpFMul"}));
testing::Values(Param{Type::f32, ast::BinaryOp::kMultiply, ""},
Param{Type::f16, ast::BinaryOp::kMultiply, ""}));
} // namespace BinaryArithMatrixMatrix