tint/writer/msl: Support for F16 type, constructor, and convertor
This patch make MSL writer support emitting f16 types, f16 literals, f16 constructor and convertor. Unittests are also implemented. The MSL writer will emit f16 literal as `1.23h`, and map f16 types as follow: WGSL type -> MSL type f16 -> half vec2<f16> -> half2 vec3<f16> -> half3 vec4<f16> -> half4 mat2x2<f16> -> half2x2 mat2x3<f16> -> half2x3 mat2x4<f16> -> half2x4 mat3x2<f16> -> half3x2 mat3x3<f16> -> half3x3 mat3x4<f16> -> half3x4 mat4x2<f16> -> half4x2 mat4x3<f16> -> half4x3 mat4x4<f16> -> half4x4 Bug: tint:1473, tint:1502 Change-Id: Id91821e1a32d48c80bad9a0753faa5247835b0f7 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/95686 Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com> Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
parent
0d3d3210d4
commit
e9c5070348
|
@ -39,6 +39,7 @@
|
|||
#include "src/tint/sem/constant.h"
|
||||
#include "src/tint/sem/depth_multisampled_texture.h"
|
||||
#include "src/tint/sem/depth_texture.h"
|
||||
#include "src/tint/sem/f16.h"
|
||||
#include "src/tint/sem/f32.h"
|
||||
#include "src/tint/sem/function.h"
|
||||
#include "src/tint/sem/i32.h"
|
||||
|
@ -97,6 +98,20 @@ void PrintF32(std::ostream& out, float value) {
|
|||
}
|
||||
}
|
||||
|
||||
void PrintF16(std::ostream& out, float value) {
|
||||
// Note: Currently inf and nan should not be constructable, but this is implemented for the day
|
||||
// we support them.
|
||||
if (std::isinf(value)) {
|
||||
// HUGE_VALH evaluates to +infinity.
|
||||
out << (value >= 0 ? "HUGE_VALH" : "-HUGE_VALH");
|
||||
} else if (std::isnan(value)) {
|
||||
// There is no NaN expr for half in MSL, "NAN" is of float type.
|
||||
out << "NAN";
|
||||
} else {
|
||||
out << FloatToString(value) << "h";
|
||||
}
|
||||
}
|
||||
|
||||
void PrintI32(std::ostream& out, int32_t value) {
|
||||
// MSL (and C++) parse `-2147483648` as a `long` because it parses unary minus and `2147483648`
|
||||
// as separate tokens, and the latter doesn't fit into an (32-bit) `int`.
|
||||
|
@ -1551,10 +1566,8 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
|
|||
return true;
|
||||
},
|
||||
[&](const sem::F16*) {
|
||||
// Placeholder for emitting f16 zero value
|
||||
diagnostics_.add_error(diag::System::Writer,
|
||||
"Type f16 is not completely implemented yet");
|
||||
return false;
|
||||
out << "0.0h";
|
||||
return true;
|
||||
},
|
||||
[&](const sem::F32*) {
|
||||
out << "0.0f";
|
||||
|
@ -1605,6 +1618,10 @@ bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant* constan
|
|||
PrintF32(out, constant->As<float>());
|
||||
return true;
|
||||
},
|
||||
[&](const sem::F16*) {
|
||||
PrintF16(out, constant->As<float>());
|
||||
return true;
|
||||
},
|
||||
[&](const sem::I32*) {
|
||||
PrintI32(out, constant->As<int32_t>());
|
||||
return true;
|
||||
|
@ -1694,7 +1711,11 @@ bool GeneratorImpl::EmitLiteral(std::ostream& out, const ast::LiteralExpression*
|
|||
return true;
|
||||
},
|
||||
[&](const ast::FloatLiteralExpression* l) {
|
||||
PrintF32(out, static_cast<float>(l->value));
|
||||
if (l->suffix == ast::FloatLiteralExpression::Suffix::kH) {
|
||||
PrintF16(out, static_cast<float>(l->value));
|
||||
} else {
|
||||
PrintF32(out, static_cast<float>(l->value));
|
||||
}
|
||||
return true;
|
||||
},
|
||||
[&](const ast::IntLiteralExpression* i) {
|
||||
|
@ -2459,9 +2480,8 @@ bool GeneratorImpl::EmitType(std::ostream& out,
|
|||
return true;
|
||||
},
|
||||
[&](const sem::F16*) {
|
||||
diagnostics_.add_error(diag::System::Writer,
|
||||
"Type f16 is not completely implemented yet");
|
||||
return false;
|
||||
out << "half";
|
||||
return true;
|
||||
},
|
||||
[&](const sem::F32*) {
|
||||
out << "float";
|
||||
|
@ -3043,20 +3063,25 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem::
|
|||
[&](const sem::F32*) {
|
||||
return SizeAndAlign{4, 4};
|
||||
},
|
||||
[&](const sem::F16*) {
|
||||
return SizeAndAlign{2, 2};
|
||||
},
|
||||
|
||||
[&](const sem::Vector* vec) {
|
||||
auto num_els = vec->Width();
|
||||
auto* el_ty = vec->type();
|
||||
if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
|
||||
SizeAndAlign el_size_align = MslPackedTypeSizeAndAlign(el_ty);
|
||||
if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32, sem::F16>()) {
|
||||
// Use a packed_vec type for 3-element vectors only.
|
||||
if (num_els == 3) {
|
||||
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
||||
// 2.2.3 Packed Vector Types
|
||||
return SizeAndAlign{num_els * 4, 4};
|
||||
return SizeAndAlign{num_els * el_size_align.size, el_size_align.align};
|
||||
} else {
|
||||
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
|
||||
// 2.2 Vector Data Types
|
||||
return SizeAndAlign{num_els * 4, num_els * 4};
|
||||
// Vector data types are aligned to their size.
|
||||
return SizeAndAlign{num_els * el_size_align.size, num_els * el_size_align.size};
|
||||
}
|
||||
}
|
||||
TINT_UNREACHABLE(Writer, diagnostics_)
|
||||
|
@ -3070,8 +3095,9 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem::
|
|||
auto cols = mat->columns();
|
||||
auto rows = mat->rows();
|
||||
auto* el_ty = mat->type();
|
||||
if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
|
||||
static constexpr SizeAndAlign table[] = {
|
||||
// Metal only support half and float matrix.
|
||||
if (el_ty->IsAnyOf<sem::F32, sem::F16>()) {
|
||||
static constexpr SizeAndAlign table_f32[] = {
|
||||
/* float2x2 */ {16, 8},
|
||||
/* float2x3 */ {32, 16},
|
||||
/* float2x4 */ {32, 16},
|
||||
|
@ -3082,8 +3108,23 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem::
|
|||
/* float4x3 */ {64, 16},
|
||||
/* float4x4 */ {64, 16},
|
||||
};
|
||||
static constexpr SizeAndAlign table_f16[] = {
|
||||
/* half2x2 */ {8, 4},
|
||||
/* half2x3 */ {16, 8},
|
||||
/* half2x4 */ {16, 8},
|
||||
/* half3x2 */ {12, 4},
|
||||
/* half3x3 */ {24, 8},
|
||||
/* half3x4 */ {24, 8},
|
||||
/* half4x2 */ {16, 4},
|
||||
/* half4x3 */ {32, 8},
|
||||
/* half4x4 */ {32, 8},
|
||||
};
|
||||
if (cols >= 2 && cols <= 4 && rows >= 2 && rows <= 4) {
|
||||
return table[(3 * (cols - 2)) + (rows - 2)];
|
||||
if (el_ty->Is<sem::F32>()) {
|
||||
return table_f32[(3 * (cols - 2)) + (rows - 2)];
|
||||
} else {
|
||||
return table_f16[(3 * (cols - 2)) + (rows - 2)];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -61,6 +61,18 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Float) {
|
|||
EXPECT_THAT(gen.result(), HasSubstr("1073741824.0f"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_F16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
// Use a number close to 1<<16 but whose decimal representation ends in 0.
|
||||
WrapInFunction(Expr(f16((1 << 15) - 8)));
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
EXPECT_THAT(gen.result(), HasSubstr("32752.0h"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Float) {
|
||||
WrapInFunction(Construct<f32>(-1.2e-5_f));
|
||||
|
||||
|
@ -70,6 +82,17 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Float) {
|
|||
EXPECT_THAT(gen.result(), HasSubstr("-0.000012f"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_F16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
WrapInFunction(Construct<f16>(-1.2e-3_h));
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
EXPECT_THAT(gen.result(), HasSubstr("-0.00119972229h"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Bool) {
|
||||
WrapInFunction(Construct<bool>(true));
|
||||
|
||||
|
@ -97,7 +120,7 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Uint) {
|
|||
EXPECT_THAT(gen.result(), HasSubstr("12345u"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec) {
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_F32) {
|
||||
WrapInFunction(vec3<f32>(1_f, 2_f, 3_f));
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
@ -106,7 +129,18 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec) {
|
|||
EXPECT_THAT(gen.result(), HasSubstr("float3(1.0f, 2.0f, 3.0f)"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty) {
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_F16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
WrapInFunction(vec3<f16>(1_h, 2_h, 3_h));
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
EXPECT_THAT(gen.result(), HasSubstr("half3(1.0h, 2.0h, 3.0h)"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty_F32) {
|
||||
WrapInFunction(vec3<f32>());
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
@ -115,26 +149,230 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty) {
|
|||
EXPECT_THAT(gen.result(), HasSubstr("float3(0.0f)"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat) {
|
||||
WrapInFunction(Construct(ty.mat2x3<f32>(), vec3<f32>(1_f, 2_f, 3_f), vec3<f32>(3_f, 4_f, 5_f)));
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty_F16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
WrapInFunction(vec3<f16>());
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
EXPECT_THAT(gen.result(), HasSubstr("half3(0.0h)"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_F32_Literal) {
|
||||
WrapInFunction(vec3<f32>(2_f));
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
EXPECT_THAT(gen.result(), HasSubstr("float3(2.0f)"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_F16_Literal) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
WrapInFunction(vec3<f16>(2_h));
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
EXPECT_THAT(gen.result(), HasSubstr("half3(2.0h)"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_F32_Var) {
|
||||
auto* var = Var("v", nullptr, Expr(2_f));
|
||||
auto* cast = vec3<f32>(var);
|
||||
WrapInFunction(var, cast);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
EXPECT_THAT(gen.result(), HasSubstr(R"(float v = 2.0f;
|
||||
float3 const tint_symbol = float3(v);)"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_F16_Var) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* var = Var("v", nullptr, Expr(2_h));
|
||||
auto* cast = vec3<f16>(var);
|
||||
WrapInFunction(var, cast);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
EXPECT_THAT(gen.result(), HasSubstr(R"(half v = 2.0h;
|
||||
half3 const tint_symbol = half3(v);)"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_Bool) {
|
||||
WrapInFunction(vec3<bool>(true));
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
EXPECT_THAT(gen.result(), HasSubstr("bool3(true)"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_Int) {
|
||||
WrapInFunction(vec3<i32>(2_i));
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
EXPECT_THAT(gen.result(), HasSubstr("int3(2)"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_UInt) {
|
||||
WrapInFunction(vec3<u32>(2_u));
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
EXPECT_THAT(gen.result(), HasSubstr("uint3(2u)"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_F32) {
|
||||
WrapInFunction(mat2x3<f32>(vec3<f32>(1_f, 2_f, 3_f), vec3<f32>(3_f, 4_f, 5_f)));
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
|
||||
// A matrix of type T with n columns and m rows can also be constructed from
|
||||
// n vectors of type T with m components.
|
||||
EXPECT_THAT(gen.result(),
|
||||
HasSubstr("float2x3(float3(1.0f, 2.0f, 3.0f), float3(3.0f, 4.0f, 5.0f))"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Empty) {
|
||||
WrapInFunction(mat4x4<f32>());
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_F16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
WrapInFunction(mat2x3<f16>(vec3<f16>(1_h, 2_h, 3_h), vec3<f16>(3_h, 4_h, 5_h)));
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
EXPECT_THAT(gen.result(), HasSubstr("float4x4(float4(0.0f), float4(0.0f)"));
|
||||
|
||||
EXPECT_THAT(gen.result(),
|
||||
HasSubstr("half2x3(half3(1.0h, 2.0h, 3.0h), half3(3.0h, 4.0h, 5.0h))"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Complex_F32) {
|
||||
// mat4x4<f32>(
|
||||
// vec4<f32>(2.0f, 3.0f, 4.0f, 8.0f),
|
||||
// vec4<f32>(),
|
||||
// vec4<f32>(7.0f),
|
||||
// vec4<f32>(vec4<f32>(42.0f, 21.0f, 6.0f, -5.0f)),
|
||||
// );
|
||||
auto* vector_literal =
|
||||
vec4<f32>(Expr(f32(2.0)), Expr(f32(3.0)), Expr(f32(4.0)), Expr(f32(8.0)));
|
||||
auto* vector_zero_ctor = vec4<f32>();
|
||||
auto* vector_single_scalar_ctor = vec4<f32>(Expr(f32(7.0)));
|
||||
auto* vector_identical_ctor =
|
||||
vec4<f32>(vec4<f32>(Expr(f32(42.0)), Expr(f32(21.0)), Expr(f32(6.0)), Expr(f32(-5.0))));
|
||||
|
||||
auto* constructor = mat4x4<f32>(vector_literal, vector_zero_ctor, vector_single_scalar_ctor,
|
||||
vector_identical_ctor);
|
||||
|
||||
WrapInFunction(constructor);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
|
||||
EXPECT_THAT(gen.result(), HasSubstr("float4x4(float4(2.0f, 3.0f, 4.0f, 8.0f), float4(0.0f), "
|
||||
"float4(7.0f), float4(42.0f, 21.0f, 6.0f, -5.0f))"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Complex_F16) {
|
||||
// mat4x4<f16>(
|
||||
// vec4<f16>(2.0h, 3.0h, 4.0h, 8.0h),
|
||||
// vec4<f16>(),
|
||||
// vec4<f16>(7.0h),
|
||||
// vec4<f16>(vec4<f16>(42.0h, 21.0h, 6.0h, -5.0h)),
|
||||
// );
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* vector_literal =
|
||||
vec4<f16>(Expr(f16(2.0)), Expr(f16(3.0)), Expr(f16(4.0)), Expr(f16(8.0)));
|
||||
auto* vector_zero_ctor = vec4<f16>();
|
||||
auto* vector_single_scalar_ctor = vec4<f16>(Expr(f16(7.0)));
|
||||
auto* vector_identical_ctor =
|
||||
vec4<f16>(vec4<f16>(Expr(f16(42.0)), Expr(f16(21.0)), Expr(f16(6.0)), Expr(f16(-5.0))));
|
||||
|
||||
auto* constructor = mat4x4<f16>(vector_literal, vector_zero_ctor, vector_single_scalar_ctor,
|
||||
vector_identical_ctor);
|
||||
|
||||
WrapInFunction(constructor);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
|
||||
EXPECT_THAT(gen.result(), HasSubstr("half4x4(half4(2.0h, 3.0h, 4.0h, 8.0h), half4(0.0h), "
|
||||
"half4(7.0h), half4(42.0h, 21.0h, 6.0h, -5.0h))"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Empty_F32) {
|
||||
WrapInFunction(mat2x3<f32>());
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
|
||||
EXPECT_THAT(gen.result(),
|
||||
HasSubstr("float2x3 const tint_symbol = float2x3(float3(0.0f), float3(0.0f))"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Empty_F16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
WrapInFunction(mat2x3<f16>());
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
|
||||
EXPECT_THAT(gen.result(),
|
||||
HasSubstr("half2x3 const tint_symbol = half2x3(half3(0.0h), half3(0.0h))"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Identity_F32) {
|
||||
// fn f() {
|
||||
// var m_1: mat4x4<f32> = mat4x4<f32>();
|
||||
// var m_2: mat4x4<f32> = mat4x4<f32>(m_1);
|
||||
// }
|
||||
|
||||
auto* m_1 = Var("m_1", ty.mat4x4(ty.f32()), mat4x4<f32>());
|
||||
auto* m_2 = Var("m_2", ty.mat4x4(ty.f32()), mat4x4<f32>(m_1));
|
||||
|
||||
WrapInFunction(m_1, m_2);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
|
||||
EXPECT_THAT(gen.result(), HasSubstr("float4x4 m_2 = float4x4(m_1);"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Identity_F16) {
|
||||
// fn f() {
|
||||
// var m_1: mat4x4<f16> = mat4x4<f16>();
|
||||
// var m_2: mat4x4<f16> = mat4x4<f16>(m_1);
|
||||
// }
|
||||
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* m_1 = Var("m_1", ty.mat4x4(ty.f16()), mat4x4<f16>());
|
||||
auto* m_2 = Var("m_2", ty.mat4x4(ty.f16()), mat4x4<f16>(m_1));
|
||||
|
||||
WrapInFunction(m_1, m_2);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
|
||||
EXPECT_THAT(gen.result(), HasSubstr("half4x4 m_2 = half4x4(m_1);"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Array) {
|
||||
|
|
|
@ -112,6 +112,26 @@ void f() {
|
|||
)");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_GlobalConst_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* var = GlobalConst("G", nullptr, Expr(1_h));
|
||||
Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))});
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
|
||||
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
void f() {
|
||||
half const l = 1.0h;
|
||||
}
|
||||
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_GlobalConst_vec3_AInt) {
|
||||
auto* var = GlobalConst("G", nullptr, Construct(ty.vec3(nullptr), 1_a, 2_a, 3_a));
|
||||
Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))});
|
||||
|
@ -166,6 +186,26 @@ void f() {
|
|||
)");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_GlobalConst_vec3_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* var = GlobalConst("G", nullptr, vec3<f16>(1_h, 2_h, 3_h));
|
||||
Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))});
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
|
||||
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
void f() {
|
||||
half3 const l = half3(1.0h, 2.0h, 3.0h);
|
||||
}
|
||||
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_GlobalConst_mat2x3_AFloat) {
|
||||
auto* var = GlobalConst("G", nullptr,
|
||||
Construct(ty.mat(nullptr, 2, 3), 1._a, 2._a, 3._a, 4._a, 5._a, 6._a));
|
||||
|
@ -203,6 +243,26 @@ void f() {
|
|||
)");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_GlobalConst_mat2x3_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* var = GlobalConst("G", nullptr, mat2x3<f16>(1_h, 2_h, 3_h, 4_h, 5_h, 6_h));
|
||||
Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))});
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
|
||||
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
void f() {
|
||||
half2x3 const l = half2x3(half3(1.0h, 2.0h, 3.0h), half3(4.0h, 5.0h, 6.0h));
|
||||
}
|
||||
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_GlobalConst_arr_f32) {
|
||||
auto* var = GlobalConst("G", nullptr, Construct(ty.array<f32, 3>(), 1_f, 2_f, 3_f));
|
||||
Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))});
|
||||
|
|
|
@ -71,6 +71,18 @@ DECLARE_TYPE(float3x4, 48, 16);
|
|||
DECLARE_TYPE(float4x2, 32, 8);
|
||||
DECLARE_TYPE(float4x3, 64, 16);
|
||||
DECLARE_TYPE(float4x4, 64, 16);
|
||||
DECLARE_TYPE(half2, 4, 4);
|
||||
DECLARE_TYPE(packed_half3, 6, 2);
|
||||
DECLARE_TYPE(half4, 8, 8);
|
||||
DECLARE_TYPE(half2x2, 8, 4);
|
||||
DECLARE_TYPE(half2x3, 16, 8);
|
||||
DECLARE_TYPE(half2x4, 16, 8);
|
||||
DECLARE_TYPE(half3x2, 12, 4);
|
||||
DECLARE_TYPE(half3x3, 24, 8);
|
||||
DECLARE_TYPE(half3x4, 24, 8);
|
||||
DECLARE_TYPE(half4x2, 16, 4);
|
||||
DECLARE_TYPE(half4x3, 32, 8);
|
||||
DECLARE_TYPE(half4x4, 32, 8);
|
||||
using uint = unsigned int;
|
||||
|
||||
using MslGeneratorImplTest = TestHelper;
|
||||
|
@ -153,6 +165,16 @@ TEST_F(MslGeneratorImplTest, EmitType_F32) {
|
|||
EXPECT_EQ(out.str(), "float");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitType_F16) {
|
||||
auto* f16 = create<sem::F16>();
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
ASSERT_TRUE(gen.EmitType(out, f16, "")) << gen.error();
|
||||
EXPECT_EQ(out.str(), "half");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitType_I32) {
|
||||
auto* i32 = create<sem::I32>();
|
||||
|
||||
|
@ -163,7 +185,7 @@ TEST_F(MslGeneratorImplTest, EmitType_I32) {
|
|||
EXPECT_EQ(out.str(), "int");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitType_Matrix) {
|
||||
TEST_F(MslGeneratorImplTest, EmitType_Matrix_F32) {
|
||||
auto* f32 = create<sem::F32>();
|
||||
auto* vec3 = create<sem::Vector>(f32, 3u);
|
||||
auto* mat2x3 = create<sem::Matrix>(vec3, 2u);
|
||||
|
@ -175,6 +197,18 @@ TEST_F(MslGeneratorImplTest, EmitType_Matrix) {
|
|||
EXPECT_EQ(out.str(), "float2x3");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitType_Matrix_F16) {
|
||||
auto* f16 = create<sem::F16>();
|
||||
auto* vec3 = create<sem::Vector>(f16, 3u);
|
||||
auto* mat2x3 = create<sem::Matrix>(vec3, 2u);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
std::stringstream out;
|
||||
ASSERT_TRUE(gen.EmitType(out, mat2x3, "")) << gen.error();
|
||||
EXPECT_EQ(out.str(), "half2x3");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, EmitType_Pointer) {
|
||||
auto* f32 = create<sem::F32>();
|
||||
auto* p = create<sem::Pointer>(f32, ast::StorageClass::kWorkgroup, ast::Access::kReadWrite);
|
||||
|
|
|
@ -154,6 +154,26 @@ void f() {
|
|||
)");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* C = Const("C", nullptr, Expr(1_h));
|
||||
Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))});
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
|
||||
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
void f() {
|
||||
half const l = 1.0h;
|
||||
}
|
||||
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_vec3_AInt) {
|
||||
auto* C = Const("C", nullptr, Construct(ty.vec3(nullptr), 1_a, 2_a, 3_a));
|
||||
Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))});
|
||||
|
@ -208,6 +228,26 @@ void f() {
|
|||
)");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_vec3_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* C = Const("C", nullptr, vec3<f16>(1_h, 2_h, 3_h));
|
||||
Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))});
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
|
||||
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
void f() {
|
||||
half3 const l = half3(1.0h, 2.0h, 3.0h);
|
||||
}
|
||||
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_mat2x3_AFloat) {
|
||||
auto* C =
|
||||
Const("C", nullptr, Construct(ty.mat(nullptr, 2, 3), 1._a, 2._a, 3._a, 4._a, 5._a, 6._a));
|
||||
|
@ -245,6 +285,26 @@ void f() {
|
|||
)");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_mat2x3_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* C = Const("C", nullptr, mat2x3<f16>(1_h, 2_h, 3_h, 4_h, 5_h, 6_h));
|
||||
Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))});
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.Generate()) << gen.error();
|
||||
|
||||
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
|
||||
|
||||
using namespace metal;
|
||||
void f() {
|
||||
half2x3 const l = half2x3(half3(1.0h, 2.0h, 3.0h), half3(4.0h, 5.0h, 6.0h));
|
||||
}
|
||||
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_arr_f32) {
|
||||
auto* C = Const("C", nullptr, Construct(ty.array<f32, 3>(), 1_f, 2_f, 3_f));
|
||||
Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))});
|
||||
|
@ -343,7 +403,7 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Struct) {
|
|||
)");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Vector) {
|
||||
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Vector_f32) {
|
||||
auto* var = Var("a", ty.vec2<f32>());
|
||||
auto* stmt = Decl(var);
|
||||
WrapInFunction(stmt);
|
||||
|
@ -356,7 +416,22 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Vector) {
|
|||
EXPECT_EQ(gen.result(), " float2 a = 0.0f;\n");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Matrix) {
|
||||
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Vector_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* var = Var("a", ty.vec2<f16>());
|
||||
auto* stmt = Decl(var);
|
||||
WrapInFunction(stmt);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
gen.increment_indent();
|
||||
|
||||
ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error();
|
||||
EXPECT_EQ(gen.result(), " half2 a = 0.0h;\n");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Matrix_f32) {
|
||||
auto* var = Var("a", ty.mat3x2<f32>());
|
||||
|
||||
auto* stmt = Decl(var);
|
||||
|
@ -370,6 +445,80 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Matrix) {
|
|||
EXPECT_EQ(gen.result(), " float3x2 a = float3x2(0.0f);\n");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Matrix_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* var = Var("a", ty.mat3x2<f16>());
|
||||
|
||||
auto* stmt = Decl(var);
|
||||
WrapInFunction(stmt);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
gen.increment_indent();
|
||||
|
||||
ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error();
|
||||
EXPECT_EQ(gen.result(), " half3x2 a = half3x2(0.0h);\n");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroVec_f32) {
|
||||
auto* var = Var("a", ty.vec3<f32>(), ast::StorageClass::kNone, vec3<f32>());
|
||||
|
||||
auto* stmt = Decl(var);
|
||||
WrapInFunction(stmt);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error();
|
||||
EXPECT_EQ(gen.result(), R"(float3 a = float3(0.0f);
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroVec_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* var = Var("a", ty.vec3<f16>(), ast::StorageClass::kNone, vec3<f16>());
|
||||
|
||||
auto* stmt = Decl(var);
|
||||
WrapInFunction(stmt);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error();
|
||||
EXPECT_EQ(gen.result(), R"(half3 a = half3(0.0h);
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroMat_f32) {
|
||||
auto* var = Var("a", ty.mat2x3<f32>(), ast::StorageClass::kNone, mat2x3<f32>());
|
||||
|
||||
auto* stmt = Decl(var);
|
||||
WrapInFunction(stmt);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error();
|
||||
EXPECT_EQ(gen.result(),
|
||||
R"(float2x3 a = float2x3(float3(0.0f), float3(0.0f));
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroMat_f16) {
|
||||
Enable(ast::Extension::kF16);
|
||||
|
||||
auto* var = Var("a", ty.mat2x3<f16>(), ast::StorageClass::kNone, mat2x3<f16>());
|
||||
|
||||
auto* stmt = Decl(var);
|
||||
WrapInFunction(stmt);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error();
|
||||
EXPECT_EQ(gen.result(),
|
||||
R"(half2x3 a = half2x3(half3(0.0h), half3(0.0h));
|
||||
)");
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Private) {
|
||||
GlobalVar("a", ty.f32(), ast::StorageClass::kPrivate);
|
||||
|
||||
|
@ -396,19 +545,5 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Workgroup) {
|
|||
EXPECT_THAT(gen.result(), HasSubstr("threadgroup float tint_symbol_2;\n"));
|
||||
}
|
||||
|
||||
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroVec) {
|
||||
auto* zero_vec = vec3<f32>();
|
||||
|
||||
auto* var = Var("a", ty.vec3<f32>(), ast::StorageClass::kNone, zero_vec);
|
||||
auto* stmt = Decl(var);
|
||||
WrapInFunction(stmt);
|
||||
|
||||
GeneratorImpl& gen = Build();
|
||||
|
||||
ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error();
|
||||
EXPECT_EQ(gen.result(), R"(float3 a = float3(0.0f);
|
||||
)");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tint::writer::msl
|
||||
|
|
Loading…
Reference in New Issue