From e9c5070348abd0f33529b225dfcc2cc4703bc3ea Mon Sep 17 00:00:00 2001 From: Zhaoming Jiang Date: Mon, 11 Jul 2022 03:03:11 +0000 Subject: [PATCH] tint/writer/msl: Support for F16 type, constructor, and convertor This patch make MSL writer support emitting f16 types, f16 literals, f16 constructor and convertor. Unittests are also implemented. The MSL writer will emit f16 literal as `1.23h`, and map f16 types as follow: WGSL type -> MSL type f16 -> half vec2 -> half2 vec3 -> half3 vec4 -> half4 mat2x2 -> half2x2 mat2x3 -> half2x3 mat2x4 -> half2x4 mat3x2 -> half3x2 mat3x3 -> half3x3 mat3x4 -> half3x4 mat4x2 -> half4x2 mat4x3 -> half4x3 mat4x4 -> half4x4 Bug: tint:1473, tint:1502 Change-Id: Id91821e1a32d48c80bad9a0753faa5247835b0f7 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/95686 Commit-Queue: Zhaoming Jiang Reviewed-by: Ben Clayton --- src/tint/writer/msl/generator_impl.cc | 69 ++++- .../msl/generator_impl_constructor_test.cc | 256 +++++++++++++++++- .../generator_impl_module_constant_test.cc | 60 ++++ .../writer/msl/generator_impl_type_test.cc | 36 ++- ...rator_impl_variable_decl_statement_test.cc | 167 ++++++++++-- 5 files changed, 548 insertions(+), 40 deletions(-) diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc index 43324bb7a3..e2b14a6563 100644 --- a/src/tint/writer/msl/generator_impl.cc +++ b/src/tint/writer/msl/generator_impl.cc @@ -39,6 +39,7 @@ #include "src/tint/sem/constant.h" #include "src/tint/sem/depth_multisampled_texture.h" #include "src/tint/sem/depth_texture.h" +#include "src/tint/sem/f16.h" #include "src/tint/sem/f32.h" #include "src/tint/sem/function.h" #include "src/tint/sem/i32.h" @@ -97,6 +98,20 @@ void PrintF32(std::ostream& out, float value) { } } +void PrintF16(std::ostream& out, float value) { + // Note: Currently inf and nan should not be constructable, but this is implemented for the day + // we support them. + if (std::isinf(value)) { + // HUGE_VALH evaluates to +infinity. + out << (value >= 0 ? "HUGE_VALH" : "-HUGE_VALH"); + } else if (std::isnan(value)) { + // There is no NaN expr for half in MSL, "NAN" is of float type. + out << "NAN"; + } else { + out << FloatToString(value) << "h"; + } +} + void PrintI32(std::ostream& out, int32_t value) { // MSL (and C++) parse `-2147483648` as a `long` because it parses unary minus and `2147483648` // as separate tokens, and the latter doesn't fit into an (32-bit) `int`. @@ -1551,10 +1566,8 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) { return true; }, [&](const sem::F16*) { - // Placeholder for emitting f16 zero value - diagnostics_.add_error(diag::System::Writer, - "Type f16 is not completely implemented yet"); - return false; + out << "0.0h"; + return true; }, [&](const sem::F32*) { out << "0.0f"; @@ -1605,6 +1618,10 @@ bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant* constan PrintF32(out, constant->As()); return true; }, + [&](const sem::F16*) { + PrintF16(out, constant->As()); + return true; + }, [&](const sem::I32*) { PrintI32(out, constant->As()); return true; @@ -1694,7 +1711,11 @@ bool GeneratorImpl::EmitLiteral(std::ostream& out, const ast::LiteralExpression* return true; }, [&](const ast::FloatLiteralExpression* l) { - PrintF32(out, static_cast(l->value)); + if (l->suffix == ast::FloatLiteralExpression::Suffix::kH) { + PrintF16(out, static_cast(l->value)); + } else { + PrintF32(out, static_cast(l->value)); + } return true; }, [&](const ast::IntLiteralExpression* i) { @@ -2459,9 +2480,8 @@ bool GeneratorImpl::EmitType(std::ostream& out, return true; }, [&](const sem::F16*) { - diagnostics_.add_error(diag::System::Writer, - "Type f16 is not completely implemented yet"); - return false; + out << "half"; + return true; }, [&](const sem::F32*) { out << "float"; @@ -3043,20 +3063,25 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem:: [&](const sem::F32*) { return SizeAndAlign{4, 4}; }, + [&](const sem::F16*) { + return SizeAndAlign{2, 2}; + }, [&](const sem::Vector* vec) { auto num_els = vec->Width(); auto* el_ty = vec->type(); - if (el_ty->IsAnyOf()) { + SizeAndAlign el_size_align = MslPackedTypeSizeAndAlign(el_ty); + if (el_ty->IsAnyOf()) { // Use a packed_vec type for 3-element vectors only. if (num_els == 3) { // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf // 2.2.3 Packed Vector Types - return SizeAndAlign{num_els * 4, 4}; + return SizeAndAlign{num_els * el_size_align.size, el_size_align.align}; } else { // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf // 2.2 Vector Data Types - return SizeAndAlign{num_els * 4, num_els * 4}; + // Vector data types are aligned to their size. + return SizeAndAlign{num_els * el_size_align.size, num_els * el_size_align.size}; } } TINT_UNREACHABLE(Writer, diagnostics_) @@ -3070,8 +3095,9 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem:: auto cols = mat->columns(); auto rows = mat->rows(); auto* el_ty = mat->type(); - if (el_ty->IsAnyOf()) { - static constexpr SizeAndAlign table[] = { + // Metal only support half and float matrix. + if (el_ty->IsAnyOf()) { + static constexpr SizeAndAlign table_f32[] = { /* float2x2 */ {16, 8}, /* float2x3 */ {32, 16}, /* float2x4 */ {32, 16}, @@ -3082,8 +3108,23 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem:: /* float4x3 */ {64, 16}, /* float4x4 */ {64, 16}, }; + static constexpr SizeAndAlign table_f16[] = { + /* half2x2 */ {8, 4}, + /* half2x3 */ {16, 8}, + /* half2x4 */ {16, 8}, + /* half3x2 */ {12, 4}, + /* half3x3 */ {24, 8}, + /* half3x4 */ {24, 8}, + /* half4x2 */ {16, 4}, + /* half4x3 */ {32, 8}, + /* half4x4 */ {32, 8}, + }; if (cols >= 2 && cols <= 4 && rows >= 2 && rows <= 4) { - return table[(3 * (cols - 2)) + (rows - 2)]; + if (el_ty->Is()) { + return table_f32[(3 * (cols - 2)) + (rows - 2)]; + } else { + return table_f16[(3 * (cols - 2)) + (rows - 2)]; + } } } diff --git a/src/tint/writer/msl/generator_impl_constructor_test.cc b/src/tint/writer/msl/generator_impl_constructor_test.cc index 2fa85f01ce..14fc4a5a15 100644 --- a/src/tint/writer/msl/generator_impl_constructor_test.cc +++ b/src/tint/writer/msl/generator_impl_constructor_test.cc @@ -61,6 +61,18 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Float) { EXPECT_THAT(gen.result(), HasSubstr("1073741824.0f")); } +TEST_F(MslGeneratorImplTest, EmitConstructor_F16) { + Enable(ast::Extension::kF16); + + // Use a number close to 1<<16 but whose decimal representation ends in 0. + WrapInFunction(Expr(f16((1 << 15) - 8))); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("32752.0h")); +} + TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Float) { WrapInFunction(Construct(-1.2e-5_f)); @@ -70,6 +82,17 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Float) { EXPECT_THAT(gen.result(), HasSubstr("-0.000012f")); } +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_F16) { + Enable(ast::Extension::kF16); + + WrapInFunction(Construct(-1.2e-3_h)); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("-0.00119972229h")); +} + TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Bool) { WrapInFunction(Construct(true)); @@ -97,7 +120,7 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Uint) { EXPECT_THAT(gen.result(), HasSubstr("12345u")); } -TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec) { +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_F32) { WrapInFunction(vec3(1_f, 2_f, 3_f)); GeneratorImpl& gen = Build(); @@ -106,7 +129,18 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec) { EXPECT_THAT(gen.result(), HasSubstr("float3(1.0f, 2.0f, 3.0f)")); } -TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty) { +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_F16) { + Enable(ast::Extension::kF16); + + WrapInFunction(vec3(1_h, 2_h, 3_h)); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("half3(1.0h, 2.0h, 3.0h)")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty_F32) { WrapInFunction(vec3()); GeneratorImpl& gen = Build(); @@ -115,26 +149,230 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty) { EXPECT_THAT(gen.result(), HasSubstr("float3(0.0f)")); } -TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat) { - WrapInFunction(Construct(ty.mat2x3(), vec3(1_f, 2_f, 3_f), vec3(3_f, 4_f, 5_f))); +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty_F16) { + Enable(ast::Extension::kF16); + + WrapInFunction(vec3()); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("half3(0.0h)")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_F32_Literal) { + WrapInFunction(vec3(2_f)); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("float3(2.0f)")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_F16_Literal) { + Enable(ast::Extension::kF16); + + WrapInFunction(vec3(2_h)); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("half3(2.0h)")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_F32_Var) { + auto* var = Var("v", nullptr, Expr(2_f)); + auto* cast = vec3(var); + WrapInFunction(var, cast); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr(R"(float v = 2.0f; + float3 const tint_symbol = float3(v);)")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_F16_Var) { + Enable(ast::Extension::kF16); + + auto* var = Var("v", nullptr, Expr(2_h)); + auto* cast = vec3(var); + WrapInFunction(var, cast); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr(R"(half v = 2.0h; + half3 const tint_symbol = half3(v);)")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_Bool) { + WrapInFunction(vec3(true)); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("bool3(true)")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_Int) { + WrapInFunction(vec3(2_i)); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("int3(2)")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_UInt) { + WrapInFunction(vec3(2_u)); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("uint3(2u)")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_F32) { + WrapInFunction(mat2x3(vec3(1_f, 2_f, 3_f), vec3(3_f, 4_f, 5_f))); GeneratorImpl& gen = Build(); ASSERT_TRUE(gen.Generate()) << gen.error(); - // A matrix of type T with n columns and m rows can also be constructed from - // n vectors of type T with m components. EXPECT_THAT(gen.result(), HasSubstr("float2x3(float3(1.0f, 2.0f, 3.0f), float3(3.0f, 4.0f, 5.0f))")); } -TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Empty) { - WrapInFunction(mat4x4()); +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_F16) { + Enable(ast::Extension::kF16); + + WrapInFunction(mat2x3(vec3(1_h, 2_h, 3_h), vec3(3_h, 4_h, 5_h))); GeneratorImpl& gen = Build(); ASSERT_TRUE(gen.Generate()) << gen.error(); - EXPECT_THAT(gen.result(), HasSubstr("float4x4(float4(0.0f), float4(0.0f)")); + + EXPECT_THAT(gen.result(), + HasSubstr("half2x3(half3(1.0h, 2.0h, 3.0h), half3(3.0h, 4.0h, 5.0h))")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Complex_F32) { + // mat4x4( + // vec4(2.0f, 3.0f, 4.0f, 8.0f), + // vec4(), + // vec4(7.0f), + // vec4(vec4(42.0f, 21.0f, 6.0f, -5.0f)), + // ); + auto* vector_literal = + vec4(Expr(f32(2.0)), Expr(f32(3.0)), Expr(f32(4.0)), Expr(f32(8.0))); + auto* vector_zero_ctor = vec4(); + auto* vector_single_scalar_ctor = vec4(Expr(f32(7.0))); + auto* vector_identical_ctor = + vec4(vec4(Expr(f32(42.0)), Expr(f32(21.0)), Expr(f32(6.0)), Expr(f32(-5.0)))); + + auto* constructor = mat4x4(vector_literal, vector_zero_ctor, vector_single_scalar_ctor, + vector_identical_ctor); + + WrapInFunction(constructor); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_THAT(gen.result(), HasSubstr("float4x4(float4(2.0f, 3.0f, 4.0f, 8.0f), float4(0.0f), " + "float4(7.0f), float4(42.0f, 21.0f, 6.0f, -5.0f))")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Complex_F16) { + // mat4x4( + // vec4(2.0h, 3.0h, 4.0h, 8.0h), + // vec4(), + // vec4(7.0h), + // vec4(vec4(42.0h, 21.0h, 6.0h, -5.0h)), + // ); + Enable(ast::Extension::kF16); + + auto* vector_literal = + vec4(Expr(f16(2.0)), Expr(f16(3.0)), Expr(f16(4.0)), Expr(f16(8.0))); + auto* vector_zero_ctor = vec4(); + auto* vector_single_scalar_ctor = vec4(Expr(f16(7.0))); + auto* vector_identical_ctor = + vec4(vec4(Expr(f16(42.0)), Expr(f16(21.0)), Expr(f16(6.0)), Expr(f16(-5.0)))); + + auto* constructor = mat4x4(vector_literal, vector_zero_ctor, vector_single_scalar_ctor, + vector_identical_ctor); + + WrapInFunction(constructor); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_THAT(gen.result(), HasSubstr("half4x4(half4(2.0h, 3.0h, 4.0h, 8.0h), half4(0.0h), " + "half4(7.0h), half4(42.0h, 21.0h, 6.0h, -5.0h))")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Empty_F32) { + WrapInFunction(mat2x3()); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_THAT(gen.result(), + HasSubstr("float2x3 const tint_symbol = float2x3(float3(0.0f), float3(0.0f))")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Empty_F16) { + Enable(ast::Extension::kF16); + + WrapInFunction(mat2x3()); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_THAT(gen.result(), + HasSubstr("half2x3 const tint_symbol = half2x3(half3(0.0h), half3(0.0h))")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Identity_F32) { + // fn f() { + // var m_1: mat4x4 = mat4x4(); + // var m_2: mat4x4 = mat4x4(m_1); + // } + + auto* m_1 = Var("m_1", ty.mat4x4(ty.f32()), mat4x4()); + auto* m_2 = Var("m_2", ty.mat4x4(ty.f32()), mat4x4(m_1)); + + WrapInFunction(m_1, m_2); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_THAT(gen.result(), HasSubstr("float4x4 m_2 = float4x4(m_1);")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Identity_F16) { + // fn f() { + // var m_1: mat4x4 = mat4x4(); + // var m_2: mat4x4 = mat4x4(m_1); + // } + + Enable(ast::Extension::kF16); + + auto* m_1 = Var("m_1", ty.mat4x4(ty.f16()), mat4x4()); + auto* m_2 = Var("m_2", ty.mat4x4(ty.f16()), mat4x4(m_1)); + + WrapInFunction(m_1, m_2); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_THAT(gen.result(), HasSubstr("half4x4 m_2 = half4x4(m_1);")); } TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Array) { diff --git a/src/tint/writer/msl/generator_impl_module_constant_test.cc b/src/tint/writer/msl/generator_impl_module_constant_test.cc index c2046861d1..4834b3dcf1 100644 --- a/src/tint/writer/msl/generator_impl_module_constant_test.cc +++ b/src/tint/writer/msl/generator_impl_module_constant_test.cc @@ -112,6 +112,26 @@ void f() { )"); } +TEST_F(MslGeneratorImplTest, Emit_GlobalConst_f16) { + Enable(ast::Extension::kF16); + + auto* var = GlobalConst("G", nullptr, Expr(1_h)); + Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))}); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_EQ(gen.result(), R"(#include + +using namespace metal; +void f() { + half const l = 1.0h; +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_GlobalConst_vec3_AInt) { auto* var = GlobalConst("G", nullptr, Construct(ty.vec3(nullptr), 1_a, 2_a, 3_a)); Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))}); @@ -166,6 +186,26 @@ void f() { )"); } +TEST_F(MslGeneratorImplTest, Emit_GlobalConst_vec3_f16) { + Enable(ast::Extension::kF16); + + auto* var = GlobalConst("G", nullptr, vec3(1_h, 2_h, 3_h)); + Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))}); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_EQ(gen.result(), R"(#include + +using namespace metal; +void f() { + half3 const l = half3(1.0h, 2.0h, 3.0h); +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_GlobalConst_mat2x3_AFloat) { auto* var = GlobalConst("G", nullptr, Construct(ty.mat(nullptr, 2, 3), 1._a, 2._a, 3._a, 4._a, 5._a, 6._a)); @@ -203,6 +243,26 @@ void f() { )"); } +TEST_F(MslGeneratorImplTest, Emit_GlobalConst_mat2x3_f16) { + Enable(ast::Extension::kF16); + + auto* var = GlobalConst("G", nullptr, mat2x3(1_h, 2_h, 3_h, 4_h, 5_h, 6_h)); + Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))}); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_EQ(gen.result(), R"(#include + +using namespace metal; +void f() { + half2x3 const l = half2x3(half3(1.0h, 2.0h, 3.0h), half3(4.0h, 5.0h, 6.0h)); +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_GlobalConst_arr_f32) { auto* var = GlobalConst("G", nullptr, Construct(ty.array(), 1_f, 2_f, 3_f)); Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))}); diff --git a/src/tint/writer/msl/generator_impl_type_test.cc b/src/tint/writer/msl/generator_impl_type_test.cc index 070b1bb51a..4ed5a76017 100644 --- a/src/tint/writer/msl/generator_impl_type_test.cc +++ b/src/tint/writer/msl/generator_impl_type_test.cc @@ -71,6 +71,18 @@ DECLARE_TYPE(float3x4, 48, 16); DECLARE_TYPE(float4x2, 32, 8); DECLARE_TYPE(float4x3, 64, 16); DECLARE_TYPE(float4x4, 64, 16); +DECLARE_TYPE(half2, 4, 4); +DECLARE_TYPE(packed_half3, 6, 2); +DECLARE_TYPE(half4, 8, 8); +DECLARE_TYPE(half2x2, 8, 4); +DECLARE_TYPE(half2x3, 16, 8); +DECLARE_TYPE(half2x4, 16, 8); +DECLARE_TYPE(half3x2, 12, 4); +DECLARE_TYPE(half3x3, 24, 8); +DECLARE_TYPE(half3x4, 24, 8); +DECLARE_TYPE(half4x2, 16, 4); +DECLARE_TYPE(half4x3, 32, 8); +DECLARE_TYPE(half4x4, 32, 8); using uint = unsigned int; using MslGeneratorImplTest = TestHelper; @@ -153,6 +165,16 @@ TEST_F(MslGeneratorImplTest, EmitType_F32) { EXPECT_EQ(out.str(), "float"); } +TEST_F(MslGeneratorImplTest, EmitType_F16) { + auto* f16 = create(); + + GeneratorImpl& gen = Build(); + + std::stringstream out; + ASSERT_TRUE(gen.EmitType(out, f16, "")) << gen.error(); + EXPECT_EQ(out.str(), "half"); +} + TEST_F(MslGeneratorImplTest, EmitType_I32) { auto* i32 = create(); @@ -163,7 +185,7 @@ TEST_F(MslGeneratorImplTest, EmitType_I32) { EXPECT_EQ(out.str(), "int"); } -TEST_F(MslGeneratorImplTest, EmitType_Matrix) { +TEST_F(MslGeneratorImplTest, EmitType_Matrix_F32) { auto* f32 = create(); auto* vec3 = create(f32, 3u); auto* mat2x3 = create(vec3, 2u); @@ -175,6 +197,18 @@ TEST_F(MslGeneratorImplTest, EmitType_Matrix) { EXPECT_EQ(out.str(), "float2x3"); } +TEST_F(MslGeneratorImplTest, EmitType_Matrix_F16) { + auto* f16 = create(); + auto* vec3 = create(f16, 3u); + auto* mat2x3 = create(vec3, 2u); + + GeneratorImpl& gen = Build(); + + std::stringstream out; + ASSERT_TRUE(gen.EmitType(out, mat2x3, "")) << gen.error(); + EXPECT_EQ(out.str(), "half2x3"); +} + TEST_F(MslGeneratorImplTest, EmitType_Pointer) { auto* f32 = create(); auto* p = create(f32, ast::StorageClass::kWorkgroup, ast::Access::kReadWrite); diff --git a/src/tint/writer/msl/generator_impl_variable_decl_statement_test.cc b/src/tint/writer/msl/generator_impl_variable_decl_statement_test.cc index 46becb0dc0..10c5b3d599 100644 --- a/src/tint/writer/msl/generator_impl_variable_decl_statement_test.cc +++ b/src/tint/writer/msl/generator_impl_variable_decl_statement_test.cc @@ -154,6 +154,26 @@ void f() { )"); } +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_f16) { + Enable(ast::Extension::kF16); + + auto* C = Const("C", nullptr, Expr(1_h)); + Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))}); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_EQ(gen.result(), R"(#include + +using namespace metal; +void f() { + half const l = 1.0h; +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_vec3_AInt) { auto* C = Const("C", nullptr, Construct(ty.vec3(nullptr), 1_a, 2_a, 3_a)); Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))}); @@ -208,6 +228,26 @@ void f() { )"); } +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_vec3_f16) { + Enable(ast::Extension::kF16); + + auto* C = Const("C", nullptr, vec3(1_h, 2_h, 3_h)); + Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))}); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_EQ(gen.result(), R"(#include + +using namespace metal; +void f() { + half3 const l = half3(1.0h, 2.0h, 3.0h); +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_mat2x3_AFloat) { auto* C = Const("C", nullptr, Construct(ty.mat(nullptr, 2, 3), 1._a, 2._a, 3._a, 4._a, 5._a, 6._a)); @@ -245,6 +285,26 @@ void f() { )"); } +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_mat2x3_f16) { + Enable(ast::Extension::kF16); + + auto* C = Const("C", nullptr, mat2x3(1_h, 2_h, 3_h, 4_h, 5_h, 6_h)); + Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))}); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_EQ(gen.result(), R"(#include + +using namespace metal; +void f() { + half2x3 const l = half2x3(half3(1.0h, 2.0h, 3.0h), half3(4.0h, 5.0h, 6.0h)); +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_arr_f32) { auto* C = Const("C", nullptr, Construct(ty.array(), 1_f, 2_f, 3_f)); Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))}); @@ -343,7 +403,7 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Struct) { )"); } -TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Vector) { +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Vector_f32) { auto* var = Var("a", ty.vec2()); auto* stmt = Decl(var); WrapInFunction(stmt); @@ -356,7 +416,22 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Vector) { EXPECT_EQ(gen.result(), " float2 a = 0.0f;\n"); } -TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Matrix) { +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Vector_f16) { + Enable(ast::Extension::kF16); + + auto* var = Var("a", ty.vec2()); + auto* stmt = Decl(var); + WrapInFunction(stmt); + + GeneratorImpl& gen = Build(); + + gen.increment_indent(); + + ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error(); + EXPECT_EQ(gen.result(), " half2 a = 0.0h;\n"); +} + +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Matrix_f32) { auto* var = Var("a", ty.mat3x2()); auto* stmt = Decl(var); @@ -370,6 +445,80 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Matrix) { EXPECT_EQ(gen.result(), " float3x2 a = float3x2(0.0f);\n"); } +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Matrix_f16) { + Enable(ast::Extension::kF16); + + auto* var = Var("a", ty.mat3x2()); + + auto* stmt = Decl(var); + WrapInFunction(stmt); + + GeneratorImpl& gen = Build(); + + gen.increment_indent(); + + ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error(); + EXPECT_EQ(gen.result(), " half3x2 a = half3x2(0.0h);\n"); +} + +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroVec_f32) { + auto* var = Var("a", ty.vec3(), ast::StorageClass::kNone, vec3()); + + auto* stmt = Decl(var); + WrapInFunction(stmt); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error(); + EXPECT_EQ(gen.result(), R"(float3 a = float3(0.0f); +)"); +} + +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroVec_f16) { + Enable(ast::Extension::kF16); + + auto* var = Var("a", ty.vec3(), ast::StorageClass::kNone, vec3()); + + auto* stmt = Decl(var); + WrapInFunction(stmt); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error(); + EXPECT_EQ(gen.result(), R"(half3 a = half3(0.0h); +)"); +} + +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroMat_f32) { + auto* var = Var("a", ty.mat2x3(), ast::StorageClass::kNone, mat2x3()); + + auto* stmt = Decl(var); + WrapInFunction(stmt); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error(); + EXPECT_EQ(gen.result(), + R"(float2x3 a = float2x3(float3(0.0f), float3(0.0f)); +)"); +} + +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroMat_f16) { + Enable(ast::Extension::kF16); + + auto* var = Var("a", ty.mat2x3(), ast::StorageClass::kNone, mat2x3()); + + auto* stmt = Decl(var); + WrapInFunction(stmt); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error(); + EXPECT_EQ(gen.result(), + R"(half2x3 a = half2x3(half3(0.0h), half3(0.0h)); +)"); +} + TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Private) { GlobalVar("a", ty.f32(), ast::StorageClass::kPrivate); @@ -396,19 +545,5 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Workgroup) { EXPECT_THAT(gen.result(), HasSubstr("threadgroup float tint_symbol_2;\n")); } -TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroVec) { - auto* zero_vec = vec3(); - - auto* var = Var("a", ty.vec3(), ast::StorageClass::kNone, zero_vec); - auto* stmt = Decl(var); - WrapInFunction(stmt); - - GeneratorImpl& gen = Build(); - - ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error(); - EXPECT_EQ(gen.result(), R"(float3 a = float3(0.0f); -)"); -} - } // namespace } // namespace tint::writer::msl