tint: Add f16 support for operator
This patch add support for using f16 types in unary operator `-` and binary operator `+`, `-`, `*`, `/`, `%`, `<`, `>`, `<=`, and `>=`. `==` is already supported. Unittests are also implemented. Bug: tint:1473, tint:1502 Change-Id: I1123fa5e9e586ec0d8522b0f6bacafb4ad53ffcf Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/96380 Auto-Submit: Zhaoming Jiang <zhaoming.jiang@intel.com> Reviewed-by: Ben Clayton <bclayton@google.com> Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
parent
cd74244614
commit
59e640b208
|
@ -129,7 +129,9 @@ type __atomic_compare_exchange_result<T>
|
|||
|
||||
match f32f16: f32 | f16
|
||||
match fiu32: f32 | i32 | u32
|
||||
match fiu32f16: f32 | f16 | i32 | u32
|
||||
match fi32: f32 | i32
|
||||
match fi32f16: f32 | f16 | i32
|
||||
match iu32: i32 | u32
|
||||
match aiu32: ai | i32 | u32
|
||||
match scalar: f32 | f16 | i32 | u32 | bool
|
||||
|
@ -820,43 +822,43 @@ op ! <N: num> (vec<N, bool>) -> vec<N, bool>
|
|||
@const op ~ <T: aiu32>(T) -> T
|
||||
@const op ~ <T: aiu32, N: num> (vec<N, T>) -> vec<N, T>
|
||||
|
||||
op - <T: fi32>(T) -> T
|
||||
op - <T: fi32, N: num> (vec<N, T>) -> vec<N, T>
|
||||
op - <T: fi32f16>(T) -> T
|
||||
op - <T: fi32f16, N: num> (vec<N, T>) -> vec<N, T>
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
// Binary Operators //
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
op + <T: fiu32>(T, T) -> T
|
||||
op + <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
op + <T: fiu32, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||
op + <T: fiu32, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
op + <N: num, M: num> (mat<N, M, f32>, mat<N, M, f32>) -> mat<N, M, f32>
|
||||
op + <T: fiu32f16>(T, T) -> T
|
||||
op + <T: fiu32f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
op + <T: fiu32f16, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||
op + <T: fiu32f16, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
op + <T: f32f16, N: num, M: num> (mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T>
|
||||
|
||||
op - <T: fiu32>(T, T) -> T
|
||||
op - <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
op - <T: fiu32, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||
op - <T: fiu32, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
op - <N: num, M: num> (mat<N, M, f32>, mat<N, M, f32>) -> mat<N, M, f32>
|
||||
op - <T: fiu32f16>(T, T) -> T
|
||||
op - <T: fiu32f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
op - <T: fiu32f16, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||
op - <T: fiu32f16, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
op - <T: f32f16, N: num, M: num> (mat<N, M, T>, mat<N, M, T>) -> mat<N, M, T>
|
||||
|
||||
op * <T: fiu32>(T, T) -> T
|
||||
op * <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
op * <T: fiu32, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||
op * <T: fiu32, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
op * <N: num, M: num> (f32, mat<N, M, f32>) -> mat<N, M, f32>
|
||||
op * <N: num, M: num> (mat<N, M, f32>, f32) -> mat<N, M, f32>
|
||||
op * <C: num, R: num> (mat<C, R, f32>, vec<C, f32>) -> vec<R, f32>
|
||||
op * <C: num, R: num> (vec<R, f32>, mat<C, R, f32>) -> vec<C, f32>
|
||||
op * <K: num, C: num, R: num> (mat<K, R, f32>, mat<C, K, f32>) -> mat<C, R, f32>
|
||||
op * <T: fiu32f16>(T, T) -> T
|
||||
op * <T: fiu32f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
op * <T: fiu32f16, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||
op * <T: fiu32f16, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
op * <T: f32f16, N: num, M: num> (T, mat<N, M, T>) -> mat<N, M, T>
|
||||
op * <T: f32f16, N: num, M: num> (mat<N, M, T>, T) -> mat<N, M, T>
|
||||
op * <T: f32f16, C: num, R: num> (mat<C, R, T>, vec<C, T>) -> vec<R, T>
|
||||
op * <T: f32f16, C: num, R: num> (vec<R, T>, mat<C, R, T>) -> vec<C, T>
|
||||
op * <T: f32f16, K: num, C: num, R: num> (mat<K, R, T>, mat<C, K, T>) -> mat<C, R, T>
|
||||
|
||||
op / <T: fiu32>(T, T) -> T
|
||||
op / <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
op / <T: fiu32, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||
op / <T: fiu32, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
op / <T: fiu32f16>(T, T) -> T
|
||||
op / <T: fiu32f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
op / <T: fiu32f16, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||
op / <T: fiu32f16, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
|
||||
op % <T: fiu32>(T, T) -> T
|
||||
op % <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
op % <T: fiu32, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||
op % <T: fiu32, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
op % <T: fiu32f16>(T, T) -> T
|
||||
op % <T: fiu32f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
op % <T: fiu32f16, N: num> (vec<N, T>, T) -> vec<N, T>
|
||||
op % <T: fiu32f16, N: num> (T, vec<N, T>) -> vec<N, T>
|
||||
|
||||
op ^ <T: iu32>(T, T) -> T
|
||||
op ^ <T: iu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, T>
|
||||
|
@ -880,17 +882,17 @@ op == <T: scalar, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
|
|||
op != <T: scalar>(T, T) -> bool
|
||||
op != <T: scalar, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
|
||||
|
||||
op < <T: fiu32>(T, T) -> bool
|
||||
op < <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
|
||||
op < <T: fiu32f16>(T, T) -> bool
|
||||
op < <T: fiu32f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
|
||||
|
||||
op > <T: fiu32>(T, T) -> bool
|
||||
op > <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
|
||||
op > <T: fiu32f16>(T, T) -> bool
|
||||
op > <T: fiu32f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
|
||||
|
||||
op <= <T: fiu32>(T, T) -> bool
|
||||
op <= <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
|
||||
op <= <T: fiu32f16>(T, T) -> bool
|
||||
op <= <T: fiu32f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
|
||||
|
||||
op >= <T: fiu32>(T, T) -> bool
|
||||
op >= <T: fiu32, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
|
||||
op >= <T: fiu32f16>(T, T) -> bool
|
||||
op >= <T: fiu32f16, N: num> (vec<N, T>, vec<N, T>) -> vec<N, bool>
|
||||
|
||||
op << <T: iu32>(T, u32) -> T
|
||||
op << <T: iu32, N: num> (vec<N, T>, vec<N, u32>) -> vec<N, T>
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -604,8 +604,8 @@ TEST_F(IntrinsicTableTest, MismatchUnaryOp) {
|
|||
EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator - (bool)
|
||||
|
||||
2 candidate operators:
|
||||
operator - (T) -> T where: T is f32 or i32
|
||||
operator - (vecN<T>) -> vecN<T> where: T is f32 or i32
|
||||
operator - (T) -> T where: T is f32, f16 or i32
|
||||
operator - (vecN<T>) -> vecN<T> where: T is f32, f16 or i32
|
||||
)");
|
||||
}
|
||||
|
||||
|
@ -629,15 +629,15 @@ TEST_F(IntrinsicTableTest, MismatchBinaryOp) {
|
|||
EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator * (f32, bool)
|
||||
|
||||
9 candidate operators:
|
||||
operator * (T, T) -> T where: T is f32, i32 or u32
|
||||
operator * (vecN<T>, T) -> vecN<T> where: T is f32, i32 or u32
|
||||
operator * (T, vecN<T>) -> vecN<T> where: T is f32, i32 or u32
|
||||
operator * (f32, matNxM<f32>) -> matNxM<f32>
|
||||
operator * (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, i32 or u32
|
||||
operator * (matNxM<f32>, f32) -> matNxM<f32>
|
||||
operator * (matCxR<f32>, vecC<f32>) -> vecR<f32>
|
||||
operator * (vecR<f32>, matCxR<f32>) -> vecC<f32>
|
||||
operator * (matKxR<f32>, matCxK<f32>) -> matCxR<f32>
|
||||
operator * (T, T) -> T where: T is f32, f16, i32 or u32
|
||||
operator * (vecN<T>, T) -> vecN<T> where: T is f32, f16, i32 or u32
|
||||
operator * (T, vecN<T>) -> vecN<T> where: T is f32, f16, i32 or u32
|
||||
operator * (T, matNxM<T>) -> matNxM<T> where: T is f32 or f16
|
||||
operator * (matNxM<T>, T) -> matNxM<T> where: T is f32 or f16
|
||||
operator * (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, f16, i32 or u32
|
||||
operator * (matCxR<T>, vecC<T>) -> vecR<T> where: T is f32 or f16
|
||||
operator * (vecR<T>, matCxR<T>) -> vecC<T> where: T is f32 or f16
|
||||
operator * (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is f32 or f16
|
||||
)");
|
||||
}
|
||||
|
||||
|
@ -661,15 +661,15 @@ TEST_F(IntrinsicTableTest, MismatchCompoundOp) {
|
|||
EXPECT_EQ(Diagnostics().str(), R"(12:34 error: no matching overload for operator *= (f32, bool)
|
||||
|
||||
9 candidate operators:
|
||||
operator *= (T, T) -> T where: T is f32, i32 or u32
|
||||
operator *= (vecN<T>, T) -> vecN<T> where: T is f32, i32 or u32
|
||||
operator *= (T, vecN<T>) -> vecN<T> where: T is f32, i32 or u32
|
||||
operator *= (f32, matNxM<f32>) -> matNxM<f32>
|
||||
operator *= (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, i32 or u32
|
||||
operator *= (matNxM<f32>, f32) -> matNxM<f32>
|
||||
operator *= (matCxR<f32>, vecC<f32>) -> vecR<f32>
|
||||
operator *= (vecR<f32>, matCxR<f32>) -> vecC<f32>
|
||||
operator *= (matKxR<f32>, matCxK<f32>) -> matCxR<f32>
|
||||
operator *= (T, T) -> T where: T is f32, f16, i32 or u32
|
||||
operator *= (vecN<T>, T) -> vecN<T> where: T is f32, f16, i32 or u32
|
||||
operator *= (T, vecN<T>) -> vecN<T> where: T is f32, f16, i32 or u32
|
||||
operator *= (T, matNxM<T>) -> matNxM<T> where: T is f32 or f16
|
||||
operator *= (matNxM<T>, T) -> matNxM<T> where: T is f32 or f16
|
||||
operator *= (vecN<T>, vecN<T>) -> vecN<T> where: T is f32, f16, i32 or u32
|
||||
operator *= (matCxR<T>, vecC<T>) -> vecR<T> where: T is f32 or f16
|
||||
operator *= (vecR<T>, matCxR<T>) -> vecC<T> where: T is f32 or f16
|
||||
operator *= (matKxR<T>, matCxK<T>) -> matCxR<T> where: T is f32 or f16
|
||||
)");
|
||||
}
|
||||
|
||||
|
|
|
@ -59,6 +59,34 @@ TEST_P(GlslBinaryTest, Emit_f32) {
|
|||
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), params.result);
|
||||
}
|
||||
TEST_P(GlslBinaryTest, Emit_f16) {
|
||||
auto params = GetParam();
|
||||
|
||||
// Skip ops that are illegal for this type
|
||||
if (params.op == ast::BinaryOp::kAnd || params.op == ast::BinaryOp::kOr ||
|
||||
params.op == ast::BinaryOp::kXor || params.op == ast::BinaryOp::kShiftLeft ||
|
||||
params.op == ast::BinaryOp::kShiftRight || params.op == ast::BinaryOp::kModulo) {
|
||||
return;
|
||||
}
|
||||
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("left", ty.f16(), ast::StorageClass::kPrivate);
|
||||
GlobalVar("right", ty.f16(), ast::StorageClass::kPrivate);
|
||||
|
||||
auto* left = Expr("left");
|
||||
auto* right = Expr("right");
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(params.op, left, right);
|
||||
|
||||
WrapInFunction(expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), params.result);
|
||||
}
|
||||
TEST_P(GlslBinaryTest, Emit_u32) {
|
||||
auto params = GetParam();
|
||||
|
||||
|
@ -122,7 +150,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
BinaryData{"(left / right)", ast::BinaryOp::kDivide},
|
||||
BinaryData{"(left % right)", ast::BinaryOp::kModulo}));
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar) {
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f32) {
|
||||
auto* lhs = vec3<f32>(1_f, 1_f, 1_f);
|
||||
auto* rhs = Expr(1_f);
|
||||
|
||||
|
@ -137,7 +165,24 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar) {
|
|||
EXPECT_EQ(out.str(), "(vec3(1.0f) * 1.0f)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector) {
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* lhs = vec3<f16>(1_h, 1_h, 1_h);
|
||||
auto* rhs = Expr(1_h);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
|
||||
WrapInFunction(expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(f16vec3(1.0hf) * 1.0hf)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
|
||||
auto* lhs = Expr(1_f);
|
||||
auto* rhs = vec3<f32>(1_f, 1_f, 1_f);
|
||||
|
||||
|
@ -152,7 +197,24 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector) {
|
|||
EXPECT_EQ(out.str(), "(1.0f * vec3(1.0f))");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixScalar) {
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* lhs = Expr(1_h);
|
||||
auto* rhs = vec3<f16>(1_h, 1_h, 1_h);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
|
||||
WrapInFunction(expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(1.0hf * f16vec3(1.0hf))");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixScalar_f32) {
|
||||
GlobalVar("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
|
||||
auto* lhs = Expr("mat");
|
||||
auto* rhs = Expr(1_f);
|
||||
|
@ -167,7 +229,24 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixScalar) {
|
|||
EXPECT_EQ(out.str(), "(mat * 1.0f)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarMatrix) {
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixScalar_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("mat", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
|
||||
auto* lhs = Expr("mat");
|
||||
auto* rhs = Expr(1_h);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
WrapInFunction(expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(mat * 1.0hf)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarMatrix_f32) {
|
||||
GlobalVar("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
|
||||
auto* lhs = Expr(1_f);
|
||||
auto* rhs = Expr("mat");
|
||||
|
@ -182,7 +261,24 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarMatrix) {
|
|||
EXPECT_EQ(out.str(), "(1.0f * mat)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixVector) {
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_ScalarMatrix_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("mat", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
|
||||
auto* lhs = Expr(1_h);
|
||||
auto* rhs = Expr("mat");
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
WrapInFunction(expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(1.0hf * mat)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixVector_f32) {
|
||||
GlobalVar("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
|
||||
auto* lhs = Expr("mat");
|
||||
auto* rhs = vec3<f32>(1_f, 1_f, 1_f);
|
||||
|
@ -197,7 +293,24 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixVector) {
|
|||
EXPECT_EQ(out.str(), "(mat * vec3(1.0f))");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorMatrix) {
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixVector_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("mat", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
|
||||
auto* lhs = Expr("mat");
|
||||
auto* rhs = vec3<f16>(1_h, 1_h, 1_h);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
WrapInFunction(expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(mat * f16vec3(1.0hf))");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorMatrix_f32) {
|
||||
GlobalVar("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
|
||||
auto* lhs = vec3<f32>(1_f, 1_f, 1_f);
|
||||
auto* rhs = Expr("mat");
|
||||
|
@ -212,7 +325,24 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorMatrix) {
|
|||
EXPECT_EQ(out.str(), "(vec3(1.0f) * mat)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixMatrix) {
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_VectorMatrix_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("mat", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
|
||||
auto* lhs = vec3<f16>(1_h, 1_h, 1_h);
|
||||
auto* rhs = Expr("mat");
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
WrapInFunction(expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(f16vec3(1.0hf) * mat)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixMatrix_f32) {
|
||||
GlobalVar("lhs", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
|
||||
GlobalVar("rhs", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
|
||||
|
||||
|
@ -226,28 +356,41 @@ TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixMatrix) {
|
|||
EXPECT_EQ(out.str(), "(lhs * rhs)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Logical_And) {
|
||||
GlobalVar("a", ty.bool_(), ast::StorageClass::kPrivate);
|
||||
GlobalVar("b", ty.bool_(), ast::StorageClass::kPrivate);
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Multiply_MatrixMatrix_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, Expr("a"), Expr("b"));
|
||||
GlobalVar("lhs", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
|
||||
GlobalVar("rhs", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, Expr("lhs"), Expr("rhs"));
|
||||
WrapInFunction(expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(lhs * rhs)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, ModF32) {
|
||||
GlobalVar("a", ty.f32(), ast::StorageClass::kPrivate);
|
||||
GlobalVar("b", ty.f32(), ast::StorageClass::kPrivate);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr("a"), Expr("b"));
|
||||
WrapInFunction(expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(tint_tmp)");
|
||||
EXPECT_EQ(gen.result(), R"(bool tint_tmp = a;
|
||||
if (tint_tmp) {
|
||||
tint_tmp = b;
|
||||
}
|
||||
)");
|
||||
EXPECT_EQ(out.str(), "tint_float_modulo(a, b)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, ModF32) {
|
||||
GlobalVar("a", ty.f32(), ast::StorageClass::kPrivate);
|
||||
GlobalVar("b", ty.f32(), ast::StorageClass::kPrivate);
|
||||
TEST_F(GlslGeneratorImplTest_Binary, ModF16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("a", ty.f16(), ast::StorageClass::kPrivate);
|
||||
GlobalVar("b", ty.f16(), ast::StorageClass::kPrivate);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr("a"), Expr("b"));
|
||||
WrapInFunction(expr);
|
||||
|
@ -273,6 +416,22 @@ TEST_F(GlslGeneratorImplTest_Binary, ModVec3F32) {
|
|||
EXPECT_EQ(out.str(), "tint_float_modulo(a, b)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, ModVec3F16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("a", ty.vec3<f16>(), ast::StorageClass::kPrivate);
|
||||
GlobalVar("b", ty.vec3<f16>(), ast::StorageClass::kPrivate);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr("a"), Expr("b"));
|
||||
WrapInFunction(expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "tint_float_modulo(a, b)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, ModVec3F32ScalarF32) {
|
||||
GlobalVar("a", ty.vec3<f32>(), ast::StorageClass::kPrivate);
|
||||
GlobalVar("b", ty.f32(), ast::StorageClass::kPrivate);
|
||||
|
@ -287,6 +446,22 @@ TEST_F(GlslGeneratorImplTest_Binary, ModVec3F32ScalarF32) {
|
|||
EXPECT_EQ(out.str(), "tint_float_modulo(a, b)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, ModVec3F16ScalarF16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("a", ty.vec3<f16>(), ast::StorageClass::kPrivate);
|
||||
GlobalVar("b", ty.f16(), ast::StorageClass::kPrivate);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr("a"), Expr("b"));
|
||||
WrapInFunction(expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "tint_float_modulo(a, b)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, ModScalarF32Vec3F32) {
|
||||
GlobalVar("a", ty.f32(), ast::StorageClass::kPrivate);
|
||||
GlobalVar("b", ty.vec3<f32>(), ast::StorageClass::kPrivate);
|
||||
|
@ -301,6 +476,22 @@ TEST_F(GlslGeneratorImplTest_Binary, ModScalarF32Vec3F32) {
|
|||
EXPECT_EQ(out.str(), "tint_float_modulo(a, b)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, ModScalarF16Vec3F16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("a", ty.f16(), ast::StorageClass::kPrivate);
|
||||
GlobalVar("b", ty.vec3<f16>(), ast::StorageClass::kPrivate);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr("a"), Expr("b"));
|
||||
WrapInFunction(expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "tint_float_modulo(a, b)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, ModMixedVec3ScalarF32) {
|
||||
GlobalVar("a", ty.vec3<f32>(), ast::StorageClass::kPrivate);
|
||||
GlobalVar("b", ty.f32(), ast::StorageClass::kPrivate);
|
||||
|
@ -343,6 +534,70 @@ void test_function() {
|
|||
)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, ModMixedVec3ScalarF16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("a", ty.vec3<f16>(), ast::StorageClass::kPrivate);
|
||||
GlobalVar("b", ty.f16(), ast::StorageClass::kPrivate);
|
||||
|
||||
auto* expr_vec_mod_vec =
|
||||
create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr("a"), Expr("a"));
|
||||
auto* expr_vec_mod_scalar =
|
||||
create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr("a"), Expr("b"));
|
||||
auto* expr_scalar_mod_vec =
|
||||
create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr("b"), Expr("a"));
|
||||
WrapInFunction(expr_vec_mod_vec, expr_vec_mod_scalar, expr_scalar_mod_vec);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
EXPECT_EQ(gen.result(), R"(#version 310 es
|
||||
#extension GL_AMD_gpu_shader_half_float : require
|
||||
|
||||
f16vec3 tint_float_modulo(f16vec3 lhs, f16vec3 rhs) {
|
||||
return (lhs - rhs * trunc(lhs / rhs));
|
||||
}
|
||||
|
||||
f16vec3 tint_float_modulo_1(f16vec3 lhs, float16_t rhs) {
|
||||
return (lhs - rhs * trunc(lhs / rhs));
|
||||
}
|
||||
|
||||
f16vec3 tint_float_modulo_2(float16_t lhs, f16vec3 rhs) {
|
||||
return (lhs - rhs * trunc(lhs / rhs));
|
||||
}
|
||||
|
||||
|
||||
f16vec3 a = f16vec3(0.0hf, 0.0hf, 0.0hf);
|
||||
float16_t b = 0.0hf;
|
||||
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
|
||||
void test_function() {
|
||||
f16vec3 tint_symbol = tint_float_modulo(a, a);
|
||||
f16vec3 tint_symbol_1 = tint_float_modulo_1(a, b);
|
||||
f16vec3 tint_symbol_2 = tint_float_modulo_2(b, a);
|
||||
return;
|
||||
}
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Logical_And) {
|
||||
GlobalVar("a", ty.bool_(), ast::StorageClass::kPrivate);
|
||||
GlobalVar("b", ty.bool_(), ast::StorageClass::kPrivate);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kLogicalAnd, Expr("a"), Expr("b"));
|
||||
WrapInFunction(expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(tint_tmp)");
|
||||
EXPECT_EQ(gen.result(), R"(bool tint_tmp = a;
|
||||
if (tint_tmp) {
|
||||
tint_tmp = b;
|
||||
}
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(GlslGeneratorImplTest_Binary, Logical_Multi) {
|
||||
// (a && b) || (c || d)
|
||||
GlobalVar("a", ty.bool_(), ast::StorageClass::kPrivate);
|
||||
|
|
|
@ -66,6 +66,38 @@ TEST_P(HlslBinaryTest, Emit_f32) {
|
|||
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), params.result);
|
||||
}
|
||||
TEST_P(HlslBinaryTest, Emit_f16) {
|
||||
auto params = GetParam();
|
||||
|
||||
if ((params.valid_for & BinaryData::Types::Float) == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Skip ops that are illegal for this type
|
||||
if (params.op == ast::BinaryOp::kAnd || params.op == ast::BinaryOp::kOr ||
|
||||
params.op == ast::BinaryOp::kXor || params.op == ast::BinaryOp::kShiftLeft ||
|
||||
params.op == ast::BinaryOp::kShiftRight) {
|
||||
return;
|
||||
}
|
||||
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("left", ty.f16(), ast::StorageClass::kPrivate);
|
||||
GlobalVar("right", ty.f16(), ast::StorageClass::kPrivate);
|
||||
|
||||
auto* left = Expr("left");
|
||||
auto* right = Expr("right");
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(params.op, left, right);
|
||||
|
||||
WrapInFunction(expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), params.result);
|
||||
}
|
||||
TEST_P(HlslBinaryTest, Emit_u32) {
|
||||
auto params = GetParam();
|
||||
|
||||
|
@ -140,7 +172,7 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
BinaryData{"(left % right)", ast::BinaryOp::kModulo,
|
||||
BinaryData::Types::Float}));
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar) {
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar_f32) {
|
||||
auto* lhs = vec3<f32>(1_f, 1_f, 1_f);
|
||||
auto* rhs = Expr(1_f);
|
||||
|
||||
|
@ -155,7 +187,24 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar) {
|
|||
EXPECT_EQ(out.str(), "((1.0f).xxx * 1.0f)");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector) {
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorScalar_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* lhs = vec3<f16>(1_h, 1_h, 1_h);
|
||||
auto* rhs = Expr(1_h);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
|
||||
WrapInFunction(expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "((float16_t(1.0h)).xxx * float16_t(1.0h))");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f32) {
|
||||
auto* lhs = Expr(1_f);
|
||||
auto* rhs = vec3<f32>(1_f, 1_f, 1_f);
|
||||
|
||||
|
@ -170,7 +219,24 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector) {
|
|||
EXPECT_EQ(out.str(), "(1.0f * (1.0f).xxx)");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixScalar) {
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarVector_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* lhs = Expr(1_h);
|
||||
auto* rhs = vec3<f16>(1_h, 1_h, 1_h);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
|
||||
WrapInFunction(expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(float16_t(1.0h) * (float16_t(1.0h)).xxx)");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixScalar_f32) {
|
||||
GlobalVar("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
|
||||
auto* lhs = Expr("mat");
|
||||
auto* rhs = Expr(1_f);
|
||||
|
@ -185,7 +251,24 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixScalar) {
|
|||
EXPECT_EQ(out.str(), "(mat * 1.0f)");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarMatrix) {
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixScalar_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("mat", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
|
||||
auto* lhs = Expr("mat");
|
||||
auto* rhs = Expr(1_h);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
WrapInFunction(expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(mat * float16_t(1.0h))");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarMatrix_f32) {
|
||||
GlobalVar("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
|
||||
auto* lhs = Expr(1_f);
|
||||
auto* rhs = Expr("mat");
|
||||
|
@ -200,7 +283,24 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarMatrix) {
|
|||
EXPECT_EQ(out.str(), "(1.0f * mat)");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixVector) {
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_ScalarMatrix_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("mat", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
|
||||
auto* lhs = Expr(1_h);
|
||||
auto* rhs = Expr("mat");
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
WrapInFunction(expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "(float16_t(1.0h) * mat)");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixVector_f32) {
|
||||
GlobalVar("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
|
||||
auto* lhs = Expr("mat");
|
||||
auto* rhs = vec3<f32>(1_f, 1_f, 1_f);
|
||||
|
@ -215,7 +315,24 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixVector) {
|
|||
EXPECT_EQ(out.str(), "mul((1.0f).xxx, mat)");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorMatrix) {
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixVector_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("mat", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
|
||||
auto* lhs = Expr("mat");
|
||||
auto* rhs = vec3<f16>(1_h, 1_h, 1_h);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
WrapInFunction(expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "mul((float16_t(1.0h)).xxx, mat)");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorMatrix_f32) {
|
||||
GlobalVar("mat", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
|
||||
auto* lhs = vec3<f32>(1_f, 1_f, 1_f);
|
||||
auto* rhs = Expr("mat");
|
||||
|
@ -230,7 +347,24 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorMatrix) {
|
|||
EXPECT_EQ(out.str(), "mul(mat, (1.0f).xxx)");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixMatrix) {
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_VectorMatrix_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("mat", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
|
||||
auto* lhs = vec3<f16>(1_h, 1_h, 1_h);
|
||||
auto* rhs = Expr("mat");
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
WrapInFunction(expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "mul(mat, (float16_t(1.0h)).xxx)");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixMatrix_f32) {
|
||||
GlobalVar("lhs", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
|
||||
GlobalVar("rhs", ty.mat3x3<f32>(), ast::StorageClass::kPrivate);
|
||||
|
||||
|
@ -244,6 +378,22 @@ TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixMatrix) {
|
|||
EXPECT_EQ(out.str(), "mul(rhs, lhs)");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Multiply_MatrixMatrix_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
GlobalVar("lhs", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
|
||||
GlobalVar("rhs", ty.mat3x3<f16>(), ast::StorageClass::kPrivate);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, Expr("lhs"), Expr("rhs"));
|
||||
WrapInFunction(expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
EXPECT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "mul(rhs, lhs)");
|
||||
}
|
||||
|
||||
TEST_F(HlslGeneratorImplTest_Binary, Logical_And) {
|
||||
GlobalVar("a", ty.bool_(), ast::StorageClass::kPrivate);
|
||||
GlobalVar("b", ty.bool_(), ast::StorageClass::kPrivate);
|
||||
|
@ -559,7 +709,7 @@ TEST_P(HlslGeneratorDivModTest, DivOrModByLiteralZero_u32) {
|
|||
R"( 1u);
|
||||
}
|
||||
)");
|
||||
} // namespace HlslGeneratorDivMod
|
||||
}
|
||||
|
||||
TEST_P(HlslGeneratorDivModTest, DivOrModByLiteralZero_vec_by_vec_i32) {
|
||||
Func("fn", {}, ty.void_(),
|
||||
|
@ -577,7 +727,7 @@ TEST_P(HlslGeneratorDivModTest, DivOrModByLiteralZero_vec_by_vec_i32) {
|
|||
R"( int4(50, 1, 25, 1));
|
||||
}
|
||||
)");
|
||||
} // namespace
|
||||
}
|
||||
|
||||
TEST_P(HlslGeneratorDivModTest, DivOrModByLiteralZero_vec_by_scalar_i32) {
|
||||
Func("fn", {}, ty.void_(),
|
||||
|
@ -595,7 +745,7 @@ TEST_P(HlslGeneratorDivModTest, DivOrModByLiteralZero_vec_by_scalar_i32) {
|
|||
R"( 1);
|
||||
}
|
||||
)");
|
||||
} // namespace hlsl
|
||||
}
|
||||
|
||||
TEST_P(HlslGeneratorDivModTest, DivOrModByIdentifier_i32) {
|
||||
Func("fn", {Param("b", ty.i32())}, ty.void_(),
|
||||
|
@ -613,7 +763,7 @@ TEST_P(HlslGeneratorDivModTest, DivOrModByIdentifier_i32) {
|
|||
R"( (b == 0 ? 1 : b));
|
||||
}
|
||||
)");
|
||||
} // namespace writer
|
||||
}
|
||||
|
||||
TEST_P(HlslGeneratorDivModTest, DivOrModByIdentifier_u32) {
|
||||
Func("fn", {Param("b", ty.u32())}, ty.void_(),
|
||||
|
@ -631,7 +781,7 @@ TEST_P(HlslGeneratorDivModTest, DivOrModByIdentifier_u32) {
|
|||
R"( (b == 0u ? 1u : b));
|
||||
}
|
||||
)");
|
||||
} // namespace tint
|
||||
}
|
||||
|
||||
TEST_P(HlslGeneratorDivModTest, DivOrModByIdentifier_vec_by_vec_i32) {
|
||||
Func("fn", {Param("b", ty.vec3<i32>())}, ty.void_(),
|
||||
|
|
|
@ -158,6 +158,21 @@ TEST_F(MslBinaryTest, ModF32) {
|
|||
EXPECT_EQ(out.str(), "fmod(left, right)");
|
||||
}
|
||||
|
||||
TEST_F(MslBinaryTest, ModF16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* left = Var("left", ty.f16());
|
||||
auto* right = Var("right", ty.f16());
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr(left), Expr(right));
|
||||
WrapInFunction(left, right, expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "fmod(left, right)");
|
||||
}
|
||||
|
||||
TEST_F(MslBinaryTest, ModVec3F32) {
|
||||
auto* left = Var("left", ty.vec3<f32>());
|
||||
auto* right = Var("right", ty.vec3<f32>());
|
||||
|
@ -171,6 +186,21 @@ TEST_F(MslBinaryTest, ModVec3F32) {
|
|||
EXPECT_EQ(out.str(), "fmod(left, right)");
|
||||
}
|
||||
|
||||
TEST_F(MslBinaryTest, ModVec3F16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* left = Var("left", ty.vec3<f16>());
|
||||
auto* right = Var("right", ty.vec3<f16>());
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kModulo, Expr(left), Expr(right));
|
||||
WrapInFunction(left, right, expr);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
ASSERT_TRUE(gen.EmitExpression(out, expr)) << gen.error();
|
||||
EXPECT_EQ(out.str(), "fmod(left, right)");
|
||||
}
|
||||
|
||||
TEST_F(MslBinaryTest, BoolAnd) {
|
||||
auto* left = Var("left", nullptr, Expr(true));
|
||||
auto* right = Var("right", nullptr, Expr(false));
|
||||
|
|
|
@ -191,8 +191,8 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
BinaryData{ast::BinaryOp::kSubtract, "OpISub"},
|
||||
BinaryData{ast::BinaryOp::kXor, "OpBitwiseXor"}));
|
||||
|
||||
using BinaryArithFloatTest = TestParamHelper<BinaryData>;
|
||||
TEST_P(BinaryArithFloatTest, Scalar) {
|
||||
using BinaryArithF32Test = TestParamHelper<BinaryData>;
|
||||
TEST_P(BinaryArithF32Test, Scalar) {
|
||||
auto param = GetParam();
|
||||
|
||||
auto* lhs = Expr(3.2_f);
|
||||
|
@ -215,7 +215,7 @@ TEST_P(BinaryArithFloatTest, Scalar) {
|
|||
"%4 = " + param.name + " %1 %2 %3\n");
|
||||
}
|
||||
|
||||
TEST_P(BinaryArithFloatTest, Vector) {
|
||||
TEST_P(BinaryArithF32Test, Vector) {
|
||||
auto param = GetParam();
|
||||
|
||||
auto* lhs = vec3<f32>(1_f, 1_f, 1_f);
|
||||
|
@ -239,7 +239,66 @@ TEST_P(BinaryArithFloatTest, Vector) {
|
|||
"%5 = " + param.name + " %1 %4 %4\n");
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(BuilderTest,
|
||||
BinaryArithFloatTest,
|
||||
BinaryArithF32Test,
|
||||
testing::Values(BinaryData{ast::BinaryOp::kAdd, "OpFAdd"},
|
||||
BinaryData{ast::BinaryOp::kDivide, "OpFDiv"},
|
||||
BinaryData{ast::BinaryOp::kModulo, "OpFRem"},
|
||||
BinaryData{ast::BinaryOp::kMultiply, "OpFMul"},
|
||||
BinaryData{ast::BinaryOp::kSubtract, "OpFSub"}));
|
||||
|
||||
using BinaryArithF16Test = TestParamHelper<BinaryData>;
|
||||
TEST_P(BinaryArithF16Test, Scalar) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto param = GetParam();
|
||||
|
||||
auto* lhs = Expr(3.2_h);
|
||||
auto* rhs = Expr(4.5_h);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
|
||||
|
||||
WrapInFunction(expr);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
|
||||
b.push_function(Function{});
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(expr), 4u) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 16
|
||||
%2 = OpConstant %1 0x1.998p+1
|
||||
%3 = OpConstant %1 0x1.2p+2
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
"%4 = " + param.name + " %1 %2 %3\n");
|
||||
}
|
||||
|
||||
TEST_P(BinaryArithF16Test, Vector) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto param = GetParam();
|
||||
|
||||
auto* lhs = vec3<f16>(1_h, 1_h, 1_h);
|
||||
auto* rhs = vec3<f16>(1_h, 1_h, 1_h);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
|
||||
|
||||
WrapInFunction(expr);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
|
||||
b.push_function(Function{});
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(expr), 5u) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 16
|
||||
%1 = OpTypeVector %2 3
|
||||
%3 = OpConstant %2 0x1p+0
|
||||
%4 = OpConstantComposite %1 %3 %3 %3
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
"%5 = " + param.name + " %1 %4 %4\n");
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(BuilderTest,
|
||||
BinaryArithF16Test,
|
||||
testing::Values(BinaryData{ast::BinaryOp::kAdd, "OpFAdd"},
|
||||
BinaryData{ast::BinaryOp::kDivide, "OpFDiv"},
|
||||
BinaryData{ast::BinaryOp::kModulo, "OpFRem"},
|
||||
|
@ -422,8 +481,8 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
BinaryData{ast::BinaryOp::kLessThanEqual, "OpSLessThanEqual"},
|
||||
BinaryData{ast::BinaryOp::kNotEqual, "OpINotEqual"}));
|
||||
|
||||
using BinaryCompareFloatTest = TestParamHelper<BinaryData>;
|
||||
TEST_P(BinaryCompareFloatTest, Scalar) {
|
||||
using BinaryCompareF32Test = TestParamHelper<BinaryData>;
|
||||
TEST_P(BinaryCompareF32Test, Scalar) {
|
||||
auto param = GetParam();
|
||||
|
||||
auto* lhs = Expr(3.2_f);
|
||||
|
@ -447,7 +506,7 @@ TEST_P(BinaryCompareFloatTest, Scalar) {
|
|||
"%4 = " + param.name + " %5 %2 %3\n");
|
||||
}
|
||||
|
||||
TEST_P(BinaryCompareFloatTest, Vector) {
|
||||
TEST_P(BinaryCompareF32Test, Vector) {
|
||||
auto param = GetParam();
|
||||
|
||||
auto* lhs = vec3<f32>(1_f, 1_f, 1_f);
|
||||
|
@ -474,7 +533,7 @@ TEST_P(BinaryCompareFloatTest, Vector) {
|
|||
}
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
BuilderTest,
|
||||
BinaryCompareFloatTest,
|
||||
BinaryCompareF32Test,
|
||||
testing::Values(BinaryData{ast::BinaryOp::kEqual, "OpFOrdEqual"},
|
||||
BinaryData{ast::BinaryOp::kGreaterThan, "OpFOrdGreaterThan"},
|
||||
BinaryData{ast::BinaryOp::kGreaterThanEqual, "OpFOrdGreaterThanEqual"},
|
||||
|
@ -482,7 +541,71 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
BinaryData{ast::BinaryOp::kLessThanEqual, "OpFOrdLessThanEqual"},
|
||||
BinaryData{ast::BinaryOp::kNotEqual, "OpFOrdNotEqual"}));
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_VectorScalar) {
|
||||
using BinaryCompareF16Test = TestParamHelper<BinaryData>;
|
||||
TEST_P(BinaryCompareF16Test, Scalar) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto param = GetParam();
|
||||
|
||||
auto* lhs = Expr(3.2_h);
|
||||
auto* rhs = Expr(4.5_h);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
|
||||
|
||||
WrapInFunction(expr);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
|
||||
b.push_function(Function{});
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(expr), 4u) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%1 = OpTypeFloat 16
|
||||
%2 = OpConstant %1 0x1.998p+1
|
||||
%3 = OpConstant %1 0x1.2p+2
|
||||
%5 = OpTypeBool
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
"%4 = " + param.name + " %5 %2 %3\n");
|
||||
}
|
||||
|
||||
TEST_P(BinaryCompareF16Test, Vector) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto param = GetParam();
|
||||
|
||||
auto* lhs = vec3<f16>(1_h, 1_h, 1_h);
|
||||
auto* rhs = vec3<f16>(1_h, 1_h, 1_h);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
|
||||
|
||||
WrapInFunction(expr);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
|
||||
b.push_function(Function{});
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(expr), 5u) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()), R"(%2 = OpTypeFloat 16
|
||||
%1 = OpTypeVector %2 3
|
||||
%3 = OpConstant %2 0x1p+0
|
||||
%4 = OpConstantComposite %1 %3 %3 %3
|
||||
%7 = OpTypeBool
|
||||
%6 = OpTypeVector %7 3
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
"%5 = " + param.name + " %6 %4 %4\n");
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
BuilderTest,
|
||||
BinaryCompareF16Test,
|
||||
testing::Values(BinaryData{ast::BinaryOp::kEqual, "OpFOrdEqual"},
|
||||
BinaryData{ast::BinaryOp::kGreaterThan, "OpFOrdGreaterThan"},
|
||||
BinaryData{ast::BinaryOp::kGreaterThanEqual, "OpFOrdGreaterThanEqual"},
|
||||
BinaryData{ast::BinaryOp::kLessThan, "OpFOrdLessThan"},
|
||||
BinaryData{ast::BinaryOp::kLessThanEqual, "OpFOrdLessThanEqual"},
|
||||
BinaryData{ast::BinaryOp::kNotEqual, "OpFOrdNotEqual"}));
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_VectorScalar_F32) {
|
||||
auto* lhs = vec3<f32>(1_f, 1_f, 1_f);
|
||||
auto* rhs = Expr(1_f);
|
||||
|
||||
|
@ -505,7 +628,32 @@ TEST_F(BuilderTest, Binary_Multiply_VectorScalar) {
|
|||
"%5 = OpVectorTimesScalar %1 %4 %3\n");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_ScalarVector) {
|
||||
TEST_F(BuilderTest, Binary_Multiply_VectorScalar_F16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* lhs = vec3<f16>(1_h, 1_h, 1_h);
|
||||
auto* rhs = Expr(1_h);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
|
||||
WrapInFunction(expr);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
|
||||
b.push_function(Function{});
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(expr), 5u) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()),
|
||||
R"(%2 = OpTypeFloat 16
|
||||
%1 = OpTypeVector %2 3
|
||||
%3 = OpConstant %2 0x1p+0
|
||||
%4 = OpConstantComposite %1 %3 %3 %3
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
"%5 = OpVectorTimesScalar %1 %4 %3\n");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_ScalarVector_F32) {
|
||||
auto* lhs = Expr(1_f);
|
||||
auto* rhs = vec3<f32>(1_f, 1_f, 1_f);
|
||||
|
||||
|
@ -528,7 +676,32 @@ TEST_F(BuilderTest, Binary_Multiply_ScalarVector) {
|
|||
"%5 = OpVectorTimesScalar %3 %4 %2\n");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_MatrixScalar) {
|
||||
TEST_F(BuilderTest, Binary_Multiply_ScalarVector_F16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* lhs = Expr(1_h);
|
||||
auto* rhs = vec3<f16>(1_h, 1_h, 1_h);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, rhs);
|
||||
|
||||
WrapInFunction(expr);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
|
||||
b.push_function(Function{});
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(expr), 5u) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()),
|
||||
R"(%1 = OpTypeFloat 16
|
||||
%2 = OpConstant %1 0x1p+0
|
||||
%3 = OpTypeVector %1 3
|
||||
%4 = OpConstantComposite %3 %2 %2 %2
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
"%5 = OpVectorTimesScalar %3 %4 %2\n");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_MatrixScalar_F32) {
|
||||
auto* var = Var("mat", ty.mat3x3<f32>());
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, Expr("mat"), Expr(1_f));
|
||||
|
@ -555,7 +728,36 @@ TEST_F(BuilderTest, Binary_Multiply_MatrixScalar) {
|
|||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_ScalarMatrix) {
|
||||
TEST_F(BuilderTest, Binary_Multiply_MatrixScalar_F16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* var = Var("mat", ty.mat3x3<f16>());
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, Expr("mat"), Expr(1_h));
|
||||
|
||||
WrapInFunction(var, expr);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
|
||||
b.push_function(Function{});
|
||||
ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(expr), 8u) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()),
|
||||
R"(%5 = OpTypeFloat 16
|
||||
%4 = OpTypeVector %5 3
|
||||
%3 = OpTypeMatrix %4 3
|
||||
%2 = OpTypePointer Function %3
|
||||
%1 = OpVariable %2 Function
|
||||
%7 = OpConstant %5 0x1p+0
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
R"(%6 = OpLoad %3 %1
|
||||
%8 = OpMatrixTimesScalar %3 %6 %7
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_ScalarMatrix_F32) {
|
||||
auto* var = Var("mat", ty.mat3x3<f32>());
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, Expr(1_f), Expr("mat"));
|
||||
|
@ -582,7 +784,36 @@ TEST_F(BuilderTest, Binary_Multiply_ScalarMatrix) {
|
|||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_MatrixVector) {
|
||||
TEST_F(BuilderTest, Binary_Multiply_ScalarMatrix_F16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* var = Var("mat", ty.mat3x3<f16>());
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, Expr(1_h), Expr("mat"));
|
||||
|
||||
WrapInFunction(var, expr);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
|
||||
b.push_function(Function{});
|
||||
ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(expr), 8u) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()),
|
||||
R"(%5 = OpTypeFloat 16
|
||||
%4 = OpTypeVector %5 3
|
||||
%3 = OpTypeMatrix %4 3
|
||||
%2 = OpTypePointer Function %3
|
||||
%1 = OpVariable %2 Function
|
||||
%6 = OpConstant %5 0x1p+0
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
R"(%7 = OpLoad %3 %1
|
||||
%8 = OpMatrixTimesScalar %3 %7 %6
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_MatrixVector_F32) {
|
||||
auto* var = Var("mat", ty.mat3x3<f32>());
|
||||
auto* rhs = vec3<f32>(1_f, 1_f, 1_f);
|
||||
|
||||
|
@ -611,7 +842,38 @@ TEST_F(BuilderTest, Binary_Multiply_MatrixVector) {
|
|||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_VectorMatrix) {
|
||||
TEST_F(BuilderTest, Binary_Multiply_MatrixVector_F16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* var = Var("mat", ty.mat3x3<f16>());
|
||||
auto* rhs = vec3<f16>(1_h, 1_h, 1_h);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, Expr("mat"), rhs);
|
||||
|
||||
WrapInFunction(var, expr);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
|
||||
b.push_function(Function{});
|
||||
ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(expr), 9u) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()),
|
||||
R"(%5 = OpTypeFloat 16
|
||||
%4 = OpTypeVector %5 3
|
||||
%3 = OpTypeMatrix %4 3
|
||||
%2 = OpTypePointer Function %3
|
||||
%1 = OpVariable %2 Function
|
||||
%7 = OpConstant %5 0x1p+0
|
||||
%8 = OpConstantComposite %4 %7 %7 %7
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
R"(%6 = OpLoad %3 %1
|
||||
%9 = OpMatrixTimesVector %4 %6 %8
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_VectorMatrix_F32) {
|
||||
auto* var = Var("mat", ty.mat3x3<f32>());
|
||||
auto* lhs = vec3<f32>(1_f, 1_f, 1_f);
|
||||
|
||||
|
@ -640,7 +902,38 @@ TEST_F(BuilderTest, Binary_Multiply_VectorMatrix) {
|
|||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_MatrixMatrix) {
|
||||
TEST_F(BuilderTest, Binary_Multiply_VectorMatrix_F16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* var = Var("mat", ty.mat3x3<f16>());
|
||||
auto* lhs = vec3<f16>(1_h, 1_h, 1_h);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, lhs, Expr("mat"));
|
||||
|
||||
WrapInFunction(var, expr);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
|
||||
b.push_function(Function{});
|
||||
ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(expr), 9u) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()),
|
||||
R"(%5 = OpTypeFloat 16
|
||||
%4 = OpTypeVector %5 3
|
||||
%3 = OpTypeMatrix %4 3
|
||||
%2 = OpTypePointer Function %3
|
||||
%1 = OpVariable %2 Function
|
||||
%6 = OpConstant %5 0x1p+0
|
||||
%7 = OpConstantComposite %4 %6 %6 %6
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
R"(%8 = OpLoad %3 %1
|
||||
%9 = OpVectorTimesMatrix %4 %7 %8
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_MatrixMatrix_F32) {
|
||||
auto* var = Var("mat", ty.mat3x3<f32>());
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, Expr("mat"), Expr("mat"));
|
||||
|
@ -667,6 +960,35 @@ TEST_F(BuilderTest, Binary_Multiply_MatrixMatrix) {
|
|||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_Multiply_MatrixMatrix_F16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* var = Var("mat", ty.mat3x3<f16>());
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, Expr("mat"), Expr("mat"));
|
||||
|
||||
WrapInFunction(var, expr);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
|
||||
b.push_function(Function{});
|
||||
ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
|
||||
|
||||
EXPECT_EQ(b.GenerateBinaryExpression(expr), 8u) << b.error();
|
||||
EXPECT_EQ(DumpInstructions(b.types()),
|
||||
R"(%5 = OpTypeFloat 16
|
||||
%4 = OpTypeVector %5 3
|
||||
%3 = OpTypeMatrix %4 3
|
||||
%2 = OpTypePointer Function %3
|
||||
%1 = OpVariable %2 Function
|
||||
)");
|
||||
EXPECT_EQ(DumpInstructions(b.functions()[0].instructions()),
|
||||
R"(%6 = OpLoad %3 %1
|
||||
%7 = OpLoad %3 %1
|
||||
%8 = OpMatrixTimesMatrix %3 %6 %7
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(BuilderTest, Binary_LogicalAnd) {
|
||||
auto* lhs = create<ast::BinaryExpression>(ast::BinaryOp::kEqual, Expr(1_i), Expr(2_i));
|
||||
|
||||
|
@ -895,11 +1217,13 @@ OpBranch %9
|
|||
|
||||
namespace BinaryArithVectorScalar {
|
||||
|
||||
enum class Type { f32, i32, u32 };
|
||||
enum class Type { f32, f16, i32, u32 };
|
||||
static const ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) {
|
||||
switch (type) {
|
||||
case Type::f32:
|
||||
return builder->vec3<f32>(1_f, 1_f, 1_f);
|
||||
case Type::f16:
|
||||
return builder->vec3<f16>(1_h, 1_h, 1_h);
|
||||
case Type::i32:
|
||||
return builder->vec3<i32>(1_i, 1_i, 1_i);
|
||||
case Type::u32:
|
||||
|
@ -911,6 +1235,8 @@ static const ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type)
|
|||
switch (type) {
|
||||
case Type::f32:
|
||||
return builder->Expr(1_f);
|
||||
case Type::f16:
|
||||
return builder->Expr(1_h);
|
||||
case Type::i32:
|
||||
return builder->Expr(1_i);
|
||||
case Type::u32:
|
||||
|
@ -922,6 +1248,8 @@ static std::string OpTypeDecl(Type type) {
|
|||
switch (type) {
|
||||
case Type::f32:
|
||||
return "OpTypeFloat 32";
|
||||
case Type::f16:
|
||||
return "OpTypeFloat 16";
|
||||
case Type::i32:
|
||||
return "OpTypeInt 32 1";
|
||||
case Type::u32:
|
||||
|
@ -929,6 +1257,32 @@ static std::string OpTypeDecl(Type type) {
|
|||
}
|
||||
return {};
|
||||
}
|
||||
static std::string ConstantValue(Type type) {
|
||||
switch (type) {
|
||||
case Type::f32:
|
||||
case Type::i32:
|
||||
case Type::u32:
|
||||
return "1";
|
||||
case Type::f16:
|
||||
return "0x1p+0";
|
||||
}
|
||||
return {};
|
||||
}
|
||||
static std::string CapabilityDecl(Type type) {
|
||||
switch (type) {
|
||||
case Type::f32:
|
||||
case Type::i32:
|
||||
case Type::u32:
|
||||
return "OpCapability Shader";
|
||||
case Type::f16:
|
||||
return R"(OpCapability Shader
|
||||
OpCapability Float16
|
||||
OpCapability UniformAndStorageBuffer16BitAccess
|
||||
OpCapability StorageBuffer16BitAccess
|
||||
OpCapability StorageInputOutput16)";
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
struct Param {
|
||||
Type type;
|
||||
|
@ -940,9 +1294,15 @@ using BinaryArithVectorScalarTest = TestParamHelper<Param>;
|
|||
TEST_P(BinaryArithVectorScalarTest, VectorScalar) {
|
||||
auto& param = GetParam();
|
||||
|
||||
if (param.type == Type::f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
}
|
||||
|
||||
const ast::Expression* lhs = MakeVectorExpr(this, param.type);
|
||||
const ast::Expression* rhs = MakeScalarExpr(this, param.type);
|
||||
std::string op_type_decl = OpTypeDecl(param.type);
|
||||
std::string constant_value = ConstantValue(param.type);
|
||||
std::string capability_decl = CapabilityDecl(param.type);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
|
||||
|
||||
|
@ -951,7 +1311,7 @@ TEST_P(BinaryArithVectorScalarTest, VectorScalar) {
|
|||
spirv::Builder& b = Build();
|
||||
ASSERT_TRUE(b.Build()) << b.error();
|
||||
|
||||
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
|
||||
EXPECT_EQ(DumpBuilder(b), capability_decl + R"(
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %3 "test_function"
|
||||
OpExecutionMode %3 LocalSize 1 1 1
|
||||
|
@ -960,7 +1320,8 @@ OpName %3 "test_function"
|
|||
%1 = OpTypeFunction %2
|
||||
%6 = )" + op_type_decl + R"(
|
||||
%5 = OpTypeVector %6 3
|
||||
%7 = OpConstant %6 1
|
||||
%7 = OpConstant %6 )" + constant_value +
|
||||
R"(
|
||||
%8 = OpConstantComposite %5 %7 %7 %7
|
||||
%11 = OpTypePointer Function %5
|
||||
%12 = OpConstantNull %5
|
||||
|
@ -978,9 +1339,15 @@ OpFunctionEnd
|
|||
TEST_P(BinaryArithVectorScalarTest, ScalarVector) {
|
||||
auto& param = GetParam();
|
||||
|
||||
if (param.type == Type::f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
}
|
||||
|
||||
const ast::Expression* lhs = MakeScalarExpr(this, param.type);
|
||||
const ast::Expression* rhs = MakeVectorExpr(this, param.type);
|
||||
std::string op_type_decl = OpTypeDecl(param.type);
|
||||
std::string constant_value = ConstantValue(param.type);
|
||||
std::string capability_decl = CapabilityDecl(param.type);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
|
||||
|
||||
|
@ -989,7 +1356,7 @@ TEST_P(BinaryArithVectorScalarTest, ScalarVector) {
|
|||
spirv::Builder& b = Build();
|
||||
ASSERT_TRUE(b.Build()) << b.error();
|
||||
|
||||
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
|
||||
EXPECT_EQ(DumpBuilder(b), capability_decl + R"(
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %3 "test_function"
|
||||
OpExecutionMode %3 LocalSize 1 1 1
|
||||
|
@ -997,7 +1364,8 @@ OpName %3 "test_function"
|
|||
%2 = OpTypeVoid
|
||||
%1 = OpTypeFunction %2
|
||||
%5 = )" + op_type_decl + R"(
|
||||
%6 = OpConstant %5 1
|
||||
%6 = OpConstant %5 )" + constant_value +
|
||||
R"(
|
||||
%7 = OpTypeVector %5 3
|
||||
%8 = OpConstantComposite %7 %6 %6 %6
|
||||
%11 = OpTypePointer Function %7
|
||||
|
@ -1024,6 +1392,10 @@ INSTANTIATE_TEST_SUITE_P(BuilderTest,
|
|||
// Param{Type::i32, ast::BinaryOp::kMultiply, "OpIMul"},
|
||||
Param{Type::f32, ast::BinaryOp::kSubtract, "OpFSub"},
|
||||
|
||||
Param{Type::f16, ast::BinaryOp::kAdd, "OpFAdd"},
|
||||
Param{Type::f16, ast::BinaryOp::kDivide, "OpFDiv"},
|
||||
Param{Type::f16, ast::BinaryOp::kSubtract, "OpFSub"},
|
||||
|
||||
Param{Type::i32, ast::BinaryOp::kAdd, "OpIAdd"},
|
||||
Param{Type::i32, ast::BinaryOp::kDivide, "OpSDiv"},
|
||||
Param{Type::i32, ast::BinaryOp::kModulo, "OpSMod"},
|
||||
|
@ -1040,9 +1412,15 @@ using BinaryArithVectorScalarMultiplyTest = TestParamHelper<Param>;
|
|||
TEST_P(BinaryArithVectorScalarMultiplyTest, VectorScalar) {
|
||||
auto& param = GetParam();
|
||||
|
||||
if (param.type == Type::f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
}
|
||||
|
||||
const ast::Expression* lhs = MakeVectorExpr(this, param.type);
|
||||
const ast::Expression* rhs = MakeScalarExpr(this, param.type);
|
||||
std::string op_type_decl = OpTypeDecl(param.type);
|
||||
std::string constant_value = ConstantValue(param.type);
|
||||
std::string capability_decl = CapabilityDecl(param.type);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
|
||||
|
||||
|
@ -1051,7 +1429,7 @@ TEST_P(BinaryArithVectorScalarMultiplyTest, VectorScalar) {
|
|||
spirv::Builder& b = Build();
|
||||
ASSERT_TRUE(b.Build()) << b.error();
|
||||
|
||||
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
|
||||
EXPECT_EQ(DumpBuilder(b), capability_decl + R"(
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %3 "test_function"
|
||||
OpExecutionMode %3 LocalSize 1 1 1
|
||||
|
@ -1060,7 +1438,8 @@ OpName %3 "test_function"
|
|||
%1 = OpTypeFunction %2
|
||||
%6 = )" + op_type_decl + R"(
|
||||
%5 = OpTypeVector %6 3
|
||||
%7 = OpConstant %6 1
|
||||
%7 = OpConstant %6 )" + constant_value +
|
||||
R"(
|
||||
%8 = OpConstantComposite %5 %7 %7 %7
|
||||
%3 = OpFunction %2 None %1
|
||||
%4 = OpLabel
|
||||
|
@ -1074,9 +1453,15 @@ OpFunctionEnd
|
|||
TEST_P(BinaryArithVectorScalarMultiplyTest, ScalarVector) {
|
||||
auto& param = GetParam();
|
||||
|
||||
if (param.type == Type::f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
}
|
||||
|
||||
const ast::Expression* lhs = MakeScalarExpr(this, param.type);
|
||||
const ast::Expression* rhs = MakeVectorExpr(this, param.type);
|
||||
std::string op_type_decl = OpTypeDecl(param.type);
|
||||
std::string constant_value = ConstantValue(param.type);
|
||||
std::string capability_decl = CapabilityDecl(param.type);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
|
||||
|
||||
|
@ -1085,7 +1470,7 @@ TEST_P(BinaryArithVectorScalarMultiplyTest, ScalarVector) {
|
|||
spirv::Builder& b = Build();
|
||||
ASSERT_TRUE(b.Build()) << b.error();
|
||||
|
||||
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
|
||||
EXPECT_EQ(DumpBuilder(b), capability_decl + R"(
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %3 "test_function"
|
||||
OpExecutionMode %3 LocalSize 1 1 1
|
||||
|
@ -1093,7 +1478,8 @@ OpName %3 "test_function"
|
|||
%2 = OpTypeVoid
|
||||
%1 = OpTypeFunction %2
|
||||
%5 = )" + op_type_decl + R"(
|
||||
%6 = OpConstant %5 1
|
||||
%6 = OpConstant %5 )" + constant_value +
|
||||
R"(
|
||||
%7 = OpTypeVector %5 3
|
||||
%8 = OpConstantComposite %7 %6 %6 %6
|
||||
%3 = OpFunction %2 None %1
|
||||
|
@ -1107,13 +1493,57 @@ OpFunctionEnd
|
|||
}
|
||||
INSTANTIATE_TEST_SUITE_P(BuilderTest,
|
||||
BinaryArithVectorScalarMultiplyTest,
|
||||
testing::Values(Param{Type::f32, ast::BinaryOp::kMultiply, "OpFMul"}));
|
||||
testing::Values(Param{Type::f32, ast::BinaryOp::kMultiply, "OpFMul"},
|
||||
Param{Type::f16, ast::BinaryOp::kMultiply, "OpFMul"}));
|
||||
|
||||
} // namespace BinaryArithVectorScalar
|
||||
|
||||
namespace BinaryArithMatrixMatrix {
|
||||
|
||||
enum class Type { f32, f16 };
|
||||
static const ast::Expression* MakeMat3x4Expr(ProgramBuilder* builder, Type type) {
|
||||
switch (type) {
|
||||
case Type::f32:
|
||||
return builder->mat3x4<f32>();
|
||||
case Type::f16:
|
||||
return builder->mat3x4<f16>();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
static const ast::Expression* MakeMat4x3Expr(ProgramBuilder* builder, Type type) {
|
||||
switch (type) {
|
||||
case Type::f32:
|
||||
return builder->mat4x3<f32>();
|
||||
case Type::f16:
|
||||
return builder->mat4x3<f16>();
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
static std::string OpTypeDecl(Type type) {
|
||||
switch (type) {
|
||||
case Type::f32:
|
||||
return "OpTypeFloat 32";
|
||||
case Type::f16:
|
||||
return "OpTypeFloat 16";
|
||||
}
|
||||
return {};
|
||||
}
|
||||
static std::string CapabilityDecl(Type type) {
|
||||
switch (type) {
|
||||
case Type::f32:
|
||||
return "OpCapability Shader";
|
||||
case Type::f16:
|
||||
return R"(OpCapability Shader
|
||||
OpCapability Float16
|
||||
OpCapability UniformAndStorageBuffer16BitAccess
|
||||
OpCapability StorageBuffer16BitAccess
|
||||
OpCapability StorageInputOutput16)";
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
struct Param {
|
||||
Type type;
|
||||
ast::BinaryOp op;
|
||||
std::string name;
|
||||
};
|
||||
|
@ -1122,8 +1552,14 @@ using BinaryArithMatrixMatrix = TestParamHelper<Param>;
|
|||
TEST_P(BinaryArithMatrixMatrix, AddOrSubtract) {
|
||||
auto& param = GetParam();
|
||||
|
||||
const ast::Expression* lhs = mat3x4<f32>();
|
||||
const ast::Expression* rhs = mat3x4<f32>();
|
||||
if (param.type == Type::f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
}
|
||||
|
||||
const ast::Expression* lhs = MakeMat3x4Expr(this, param.type);
|
||||
const ast::Expression* rhs = MakeMat3x4Expr(this, param.type);
|
||||
std::string op_type_decl = OpTypeDecl(param.type);
|
||||
std::string capability_decl = CapabilityDecl(param.type);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
|
||||
|
||||
|
@ -1132,14 +1568,14 @@ TEST_P(BinaryArithMatrixMatrix, AddOrSubtract) {
|
|||
spirv::Builder& b = Build();
|
||||
ASSERT_TRUE(b.Build()) << b.error();
|
||||
|
||||
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
|
||||
EXPECT_EQ(DumpBuilder(b), capability_decl + R"(
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %3 "test_function"
|
||||
OpExecutionMode %3 LocalSize 1 1 1
|
||||
OpName %3 "test_function"
|
||||
%2 = OpTypeVoid
|
||||
%1 = OpTypeFunction %2
|
||||
%7 = OpTypeFloat 32
|
||||
%7 = )" + op_type_decl + R"(
|
||||
%6 = OpTypeVector %7 4
|
||||
%5 = OpTypeMatrix %6 3
|
||||
%8 = OpConstantNull %5
|
||||
|
@ -1164,15 +1600,23 @@ OpFunctionEnd
|
|||
INSTANTIATE_TEST_SUITE_P( //
|
||||
BuilderTest,
|
||||
BinaryArithMatrixMatrix,
|
||||
testing::Values(Param{ast::BinaryOp::kAdd, "OpFAdd"},
|
||||
Param{ast::BinaryOp::kSubtract, "OpFSub"}));
|
||||
testing::Values(Param{Type::f32, ast::BinaryOp::kAdd, "OpFAdd"},
|
||||
Param{Type::f32, ast::BinaryOp::kSubtract, "OpFSub"},
|
||||
Param{Type::f16, ast::BinaryOp::kAdd, "OpFAdd"},
|
||||
Param{Type::f16, ast::BinaryOp::kSubtract, "OpFSub"}));
|
||||
|
||||
using BinaryArithMatrixMatrixMultiply = TestParamHelper<Param>;
|
||||
TEST_P(BinaryArithMatrixMatrixMultiply, Multiply) {
|
||||
auto& param = GetParam();
|
||||
|
||||
const ast::Expression* lhs = mat3x4<f32>();
|
||||
const ast::Expression* rhs = mat4x3<f32>();
|
||||
if (param.type == Type::f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
}
|
||||
|
||||
const ast::Expression* lhs = MakeMat3x4Expr(this, param.type);
|
||||
const ast::Expression* rhs = MakeMat4x3Expr(this, param.type);
|
||||
std::string op_type_decl = OpTypeDecl(param.type);
|
||||
std::string capability_decl = CapabilityDecl(param.type);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
|
||||
|
||||
|
@ -1181,14 +1625,14 @@ TEST_P(BinaryArithMatrixMatrixMultiply, Multiply) {
|
|||
spirv::Builder& b = Build();
|
||||
ASSERT_TRUE(b.Build()) << b.error();
|
||||
|
||||
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
|
||||
EXPECT_EQ(DumpBuilder(b), capability_decl + R"(
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %3 "test_function"
|
||||
OpExecutionMode %3 LocalSize 1 1 1
|
||||
OpName %3 "test_function"
|
||||
%2 = OpTypeVoid
|
||||
%1 = OpTypeFunction %2
|
||||
%7 = OpTypeFloat 32
|
||||
%7 = )" + op_type_decl + R"(
|
||||
%6 = OpTypeVector %7 4
|
||||
%5 = OpTypeMatrix %6 3
|
||||
%8 = OpConstantNull %5
|
||||
|
@ -1208,7 +1652,8 @@ OpFunctionEnd
|
|||
INSTANTIATE_TEST_SUITE_P( //
|
||||
BuilderTest,
|
||||
BinaryArithMatrixMatrixMultiply,
|
||||
testing::Values(Param{ast::BinaryOp::kMultiply, "OpFMul"}));
|
||||
testing::Values(Param{Type::f32, ast::BinaryOp::kMultiply, ""},
|
||||
Param{Type::f16, ast::BinaryOp::kMultiply, ""}));
|
||||
|
||||
} // namespace BinaryArithMatrixMatrix
|
||||
|
||||
|
|
Loading…
Reference in New Issue