diff --git a/src/tint/writer/msl/generator_impl.cc b/src/tint/writer/msl/generator_impl.cc index 43324bb7a3..e2b14a6563 100644 --- a/src/tint/writer/msl/generator_impl.cc +++ b/src/tint/writer/msl/generator_impl.cc @@ -39,6 +39,7 @@ #include "src/tint/sem/constant.h" #include "src/tint/sem/depth_multisampled_texture.h" #include "src/tint/sem/depth_texture.h" +#include "src/tint/sem/f16.h" #include "src/tint/sem/f32.h" #include "src/tint/sem/function.h" #include "src/tint/sem/i32.h" @@ -97,6 +98,20 @@ void PrintF32(std::ostream& out, float value) { } } +void PrintF16(std::ostream& out, float value) { + // Note: Currently inf and nan should not be constructable, but this is implemented for the day + // we support them. + if (std::isinf(value)) { + // HUGE_VALH evaluates to +infinity. + out << (value >= 0 ? "HUGE_VALH" : "-HUGE_VALH"); + } else if (std::isnan(value)) { + // There is no NaN expr for half in MSL, "NAN" is of float type. + out << "NAN"; + } else { + out << FloatToString(value) << "h"; + } +} + void PrintI32(std::ostream& out, int32_t value) { // MSL (and C++) parse `-2147483648` as a `long` because it parses unary minus and `2147483648` // as separate tokens, and the latter doesn't fit into an (32-bit) `int`. @@ -1551,10 +1566,8 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) { return true; }, [&](const sem::F16*) { - // Placeholder for emitting f16 zero value - diagnostics_.add_error(diag::System::Writer, - "Type f16 is not completely implemented yet"); - return false; + out << "0.0h"; + return true; }, [&](const sem::F32*) { out << "0.0f"; @@ -1605,6 +1618,10 @@ bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant* constan PrintF32(out, constant->As()); return true; }, + [&](const sem::F16*) { + PrintF16(out, constant->As()); + return true; + }, [&](const sem::I32*) { PrintI32(out, constant->As()); return true; @@ -1694,7 +1711,11 @@ bool GeneratorImpl::EmitLiteral(std::ostream& out, const ast::LiteralExpression* return true; }, [&](const ast::FloatLiteralExpression* l) { - PrintF32(out, static_cast(l->value)); + if (l->suffix == ast::FloatLiteralExpression::Suffix::kH) { + PrintF16(out, static_cast(l->value)); + } else { + PrintF32(out, static_cast(l->value)); + } return true; }, [&](const ast::IntLiteralExpression* i) { @@ -2459,9 +2480,8 @@ bool GeneratorImpl::EmitType(std::ostream& out, return true; }, [&](const sem::F16*) { - diagnostics_.add_error(diag::System::Writer, - "Type f16 is not completely implemented yet"); - return false; + out << "half"; + return true; }, [&](const sem::F32*) { out << "float"; @@ -3043,20 +3063,25 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem:: [&](const sem::F32*) { return SizeAndAlign{4, 4}; }, + [&](const sem::F16*) { + return SizeAndAlign{2, 2}; + }, [&](const sem::Vector* vec) { auto num_els = vec->Width(); auto* el_ty = vec->type(); - if (el_ty->IsAnyOf()) { + SizeAndAlign el_size_align = MslPackedTypeSizeAndAlign(el_ty); + if (el_ty->IsAnyOf()) { // Use a packed_vec type for 3-element vectors only. if (num_els == 3) { // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf // 2.2.3 Packed Vector Types - return SizeAndAlign{num_els * 4, 4}; + return SizeAndAlign{num_els * el_size_align.size, el_size_align.align}; } else { // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf // 2.2 Vector Data Types - return SizeAndAlign{num_els * 4, num_els * 4}; + // Vector data types are aligned to their size. + return SizeAndAlign{num_els * el_size_align.size, num_els * el_size_align.size}; } } TINT_UNREACHABLE(Writer, diagnostics_) @@ -3070,8 +3095,9 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem:: auto cols = mat->columns(); auto rows = mat->rows(); auto* el_ty = mat->type(); - if (el_ty->IsAnyOf()) { - static constexpr SizeAndAlign table[] = { + // Metal only support half and float matrix. + if (el_ty->IsAnyOf()) { + static constexpr SizeAndAlign table_f32[] = { /* float2x2 */ {16, 8}, /* float2x3 */ {32, 16}, /* float2x4 */ {32, 16}, @@ -3082,8 +3108,23 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem:: /* float4x3 */ {64, 16}, /* float4x4 */ {64, 16}, }; + static constexpr SizeAndAlign table_f16[] = { + /* half2x2 */ {8, 4}, + /* half2x3 */ {16, 8}, + /* half2x4 */ {16, 8}, + /* half3x2 */ {12, 4}, + /* half3x3 */ {24, 8}, + /* half3x4 */ {24, 8}, + /* half4x2 */ {16, 4}, + /* half4x3 */ {32, 8}, + /* half4x4 */ {32, 8}, + }; if (cols >= 2 && cols <= 4 && rows >= 2 && rows <= 4) { - return table[(3 * (cols - 2)) + (rows - 2)]; + if (el_ty->Is()) { + return table_f32[(3 * (cols - 2)) + (rows - 2)]; + } else { + return table_f16[(3 * (cols - 2)) + (rows - 2)]; + } } } diff --git a/src/tint/writer/msl/generator_impl_constructor_test.cc b/src/tint/writer/msl/generator_impl_constructor_test.cc index 2fa85f01ce..14fc4a5a15 100644 --- a/src/tint/writer/msl/generator_impl_constructor_test.cc +++ b/src/tint/writer/msl/generator_impl_constructor_test.cc @@ -61,6 +61,18 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Float) { EXPECT_THAT(gen.result(), HasSubstr("1073741824.0f")); } +TEST_F(MslGeneratorImplTest, EmitConstructor_F16) { + Enable(ast::Extension::kF16); + + // Use a number close to 1<<16 but whose decimal representation ends in 0. + WrapInFunction(Expr(f16((1 << 15) - 8))); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("32752.0h")); +} + TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Float) { WrapInFunction(Construct(-1.2e-5_f)); @@ -70,6 +82,17 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Float) { EXPECT_THAT(gen.result(), HasSubstr("-0.000012f")); } +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_F16) { + Enable(ast::Extension::kF16); + + WrapInFunction(Construct(-1.2e-3_h)); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("-0.00119972229h")); +} + TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Bool) { WrapInFunction(Construct(true)); @@ -97,7 +120,7 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Uint) { EXPECT_THAT(gen.result(), HasSubstr("12345u")); } -TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec) { +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_F32) { WrapInFunction(vec3(1_f, 2_f, 3_f)); GeneratorImpl& gen = Build(); @@ -106,7 +129,18 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec) { EXPECT_THAT(gen.result(), HasSubstr("float3(1.0f, 2.0f, 3.0f)")); } -TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty) { +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_F16) { + Enable(ast::Extension::kF16); + + WrapInFunction(vec3(1_h, 2_h, 3_h)); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("half3(1.0h, 2.0h, 3.0h)")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty_F32) { WrapInFunction(vec3()); GeneratorImpl& gen = Build(); @@ -115,26 +149,230 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty) { EXPECT_THAT(gen.result(), HasSubstr("float3(0.0f)")); } -TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat) { - WrapInFunction(Construct(ty.mat2x3(), vec3(1_f, 2_f, 3_f), vec3(3_f, 4_f, 5_f))); +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty_F16) { + Enable(ast::Extension::kF16); + + WrapInFunction(vec3()); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("half3(0.0h)")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_F32_Literal) { + WrapInFunction(vec3(2_f)); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("float3(2.0f)")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_F16_Literal) { + Enable(ast::Extension::kF16); + + WrapInFunction(vec3(2_h)); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("half3(2.0h)")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_F32_Var) { + auto* var = Var("v", nullptr, Expr(2_f)); + auto* cast = vec3(var); + WrapInFunction(var, cast); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr(R"(float v = 2.0f; + float3 const tint_symbol = float3(v);)")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_F16_Var) { + Enable(ast::Extension::kF16); + + auto* var = Var("v", nullptr, Expr(2_h)); + auto* cast = vec3(var); + WrapInFunction(var, cast); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr(R"(half v = 2.0h; + half3 const tint_symbol = half3(v);)")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_Bool) { + WrapInFunction(vec3(true)); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("bool3(true)")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_Int) { + WrapInFunction(vec3(2_i)); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("int3(2)")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_UInt) { + WrapInFunction(vec3(2_u)); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("uint3(2u)")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_F32) { + WrapInFunction(mat2x3(vec3(1_f, 2_f, 3_f), vec3(3_f, 4_f, 5_f))); GeneratorImpl& gen = Build(); ASSERT_TRUE(gen.Generate()) << gen.error(); - // A matrix of type T with n columns and m rows can also be constructed from - // n vectors of type T with m components. EXPECT_THAT(gen.result(), HasSubstr("float2x3(float3(1.0f, 2.0f, 3.0f), float3(3.0f, 4.0f, 5.0f))")); } -TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Empty) { - WrapInFunction(mat4x4()); +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_F16) { + Enable(ast::Extension::kF16); + + WrapInFunction(mat2x3(vec3(1_h, 2_h, 3_h), vec3(3_h, 4_h, 5_h))); GeneratorImpl& gen = Build(); ASSERT_TRUE(gen.Generate()) << gen.error(); - EXPECT_THAT(gen.result(), HasSubstr("float4x4(float4(0.0f), float4(0.0f)")); + + EXPECT_THAT(gen.result(), + HasSubstr("half2x3(half3(1.0h, 2.0h, 3.0h), half3(3.0h, 4.0h, 5.0h))")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Complex_F32) { + // mat4x4( + // vec4(2.0f, 3.0f, 4.0f, 8.0f), + // vec4(), + // vec4(7.0f), + // vec4(vec4(42.0f, 21.0f, 6.0f, -5.0f)), + // ); + auto* vector_literal = + vec4(Expr(f32(2.0)), Expr(f32(3.0)), Expr(f32(4.0)), Expr(f32(8.0))); + auto* vector_zero_ctor = vec4(); + auto* vector_single_scalar_ctor = vec4(Expr(f32(7.0))); + auto* vector_identical_ctor = + vec4(vec4(Expr(f32(42.0)), Expr(f32(21.0)), Expr(f32(6.0)), Expr(f32(-5.0)))); + + auto* constructor = mat4x4(vector_literal, vector_zero_ctor, vector_single_scalar_ctor, + vector_identical_ctor); + + WrapInFunction(constructor); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_THAT(gen.result(), HasSubstr("float4x4(float4(2.0f, 3.0f, 4.0f, 8.0f), float4(0.0f), " + "float4(7.0f), float4(42.0f, 21.0f, 6.0f, -5.0f))")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Complex_F16) { + // mat4x4( + // vec4(2.0h, 3.0h, 4.0h, 8.0h), + // vec4(), + // vec4(7.0h), + // vec4(vec4(42.0h, 21.0h, 6.0h, -5.0h)), + // ); + Enable(ast::Extension::kF16); + + auto* vector_literal = + vec4(Expr(f16(2.0)), Expr(f16(3.0)), Expr(f16(4.0)), Expr(f16(8.0))); + auto* vector_zero_ctor = vec4(); + auto* vector_single_scalar_ctor = vec4(Expr(f16(7.0))); + auto* vector_identical_ctor = + vec4(vec4(Expr(f16(42.0)), Expr(f16(21.0)), Expr(f16(6.0)), Expr(f16(-5.0)))); + + auto* constructor = mat4x4(vector_literal, vector_zero_ctor, vector_single_scalar_ctor, + vector_identical_ctor); + + WrapInFunction(constructor); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_THAT(gen.result(), HasSubstr("half4x4(half4(2.0h, 3.0h, 4.0h, 8.0h), half4(0.0h), " + "half4(7.0h), half4(42.0h, 21.0h, 6.0h, -5.0h))")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Empty_F32) { + WrapInFunction(mat2x3()); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_THAT(gen.result(), + HasSubstr("float2x3 const tint_symbol = float2x3(float3(0.0f), float3(0.0f))")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Empty_F16) { + Enable(ast::Extension::kF16); + + WrapInFunction(mat2x3()); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_THAT(gen.result(), + HasSubstr("half2x3 const tint_symbol = half2x3(half3(0.0h), half3(0.0h))")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Identity_F32) { + // fn f() { + // var m_1: mat4x4 = mat4x4(); + // var m_2: mat4x4 = mat4x4(m_1); + // } + + auto* m_1 = Var("m_1", ty.mat4x4(ty.f32()), mat4x4()); + auto* m_2 = Var("m_2", ty.mat4x4(ty.f32()), mat4x4(m_1)); + + WrapInFunction(m_1, m_2); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_THAT(gen.result(), HasSubstr("float4x4 m_2 = float4x4(m_1);")); +} + +TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Identity_F16) { + // fn f() { + // var m_1: mat4x4 = mat4x4(); + // var m_2: mat4x4 = mat4x4(m_1); + // } + + Enable(ast::Extension::kF16); + + auto* m_1 = Var("m_1", ty.mat4x4(ty.f16()), mat4x4()); + auto* m_2 = Var("m_2", ty.mat4x4(ty.f16()), mat4x4(m_1)); + + WrapInFunction(m_1, m_2); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_THAT(gen.result(), HasSubstr("half4x4 m_2 = half4x4(m_1);")); } TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Array) { diff --git a/src/tint/writer/msl/generator_impl_module_constant_test.cc b/src/tint/writer/msl/generator_impl_module_constant_test.cc index c2046861d1..4834b3dcf1 100644 --- a/src/tint/writer/msl/generator_impl_module_constant_test.cc +++ b/src/tint/writer/msl/generator_impl_module_constant_test.cc @@ -112,6 +112,26 @@ void f() { )"); } +TEST_F(MslGeneratorImplTest, Emit_GlobalConst_f16) { + Enable(ast::Extension::kF16); + + auto* var = GlobalConst("G", nullptr, Expr(1_h)); + Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))}); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_EQ(gen.result(), R"(#include + +using namespace metal; +void f() { + half const l = 1.0h; +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_GlobalConst_vec3_AInt) { auto* var = GlobalConst("G", nullptr, Construct(ty.vec3(nullptr), 1_a, 2_a, 3_a)); Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))}); @@ -166,6 +186,26 @@ void f() { )"); } +TEST_F(MslGeneratorImplTest, Emit_GlobalConst_vec3_f16) { + Enable(ast::Extension::kF16); + + auto* var = GlobalConst("G", nullptr, vec3(1_h, 2_h, 3_h)); + Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))}); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_EQ(gen.result(), R"(#include + +using namespace metal; +void f() { + half3 const l = half3(1.0h, 2.0h, 3.0h); +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_GlobalConst_mat2x3_AFloat) { auto* var = GlobalConst("G", nullptr, Construct(ty.mat(nullptr, 2, 3), 1._a, 2._a, 3._a, 4._a, 5._a, 6._a)); @@ -203,6 +243,26 @@ void f() { )"); } +TEST_F(MslGeneratorImplTest, Emit_GlobalConst_mat2x3_f16) { + Enable(ast::Extension::kF16); + + auto* var = GlobalConst("G", nullptr, mat2x3(1_h, 2_h, 3_h, 4_h, 5_h, 6_h)); + Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))}); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_EQ(gen.result(), R"(#include + +using namespace metal; +void f() { + half2x3 const l = half2x3(half3(1.0h, 2.0h, 3.0h), half3(4.0h, 5.0h, 6.0h)); +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_GlobalConst_arr_f32) { auto* var = GlobalConst("G", nullptr, Construct(ty.array(), 1_f, 2_f, 3_f)); Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))}); diff --git a/src/tint/writer/msl/generator_impl_type_test.cc b/src/tint/writer/msl/generator_impl_type_test.cc index 070b1bb51a..4ed5a76017 100644 --- a/src/tint/writer/msl/generator_impl_type_test.cc +++ b/src/tint/writer/msl/generator_impl_type_test.cc @@ -71,6 +71,18 @@ DECLARE_TYPE(float3x4, 48, 16); DECLARE_TYPE(float4x2, 32, 8); DECLARE_TYPE(float4x3, 64, 16); DECLARE_TYPE(float4x4, 64, 16); +DECLARE_TYPE(half2, 4, 4); +DECLARE_TYPE(packed_half3, 6, 2); +DECLARE_TYPE(half4, 8, 8); +DECLARE_TYPE(half2x2, 8, 4); +DECLARE_TYPE(half2x3, 16, 8); +DECLARE_TYPE(half2x4, 16, 8); +DECLARE_TYPE(half3x2, 12, 4); +DECLARE_TYPE(half3x3, 24, 8); +DECLARE_TYPE(half3x4, 24, 8); +DECLARE_TYPE(half4x2, 16, 4); +DECLARE_TYPE(half4x3, 32, 8); +DECLARE_TYPE(half4x4, 32, 8); using uint = unsigned int; using MslGeneratorImplTest = TestHelper; @@ -153,6 +165,16 @@ TEST_F(MslGeneratorImplTest, EmitType_F32) { EXPECT_EQ(out.str(), "float"); } +TEST_F(MslGeneratorImplTest, EmitType_F16) { + auto* f16 = create(); + + GeneratorImpl& gen = Build(); + + std::stringstream out; + ASSERT_TRUE(gen.EmitType(out, f16, "")) << gen.error(); + EXPECT_EQ(out.str(), "half"); +} + TEST_F(MslGeneratorImplTest, EmitType_I32) { auto* i32 = create(); @@ -163,7 +185,7 @@ TEST_F(MslGeneratorImplTest, EmitType_I32) { EXPECT_EQ(out.str(), "int"); } -TEST_F(MslGeneratorImplTest, EmitType_Matrix) { +TEST_F(MslGeneratorImplTest, EmitType_Matrix_F32) { auto* f32 = create(); auto* vec3 = create(f32, 3u); auto* mat2x3 = create(vec3, 2u); @@ -175,6 +197,18 @@ TEST_F(MslGeneratorImplTest, EmitType_Matrix) { EXPECT_EQ(out.str(), "float2x3"); } +TEST_F(MslGeneratorImplTest, EmitType_Matrix_F16) { + auto* f16 = create(); + auto* vec3 = create(f16, 3u); + auto* mat2x3 = create(vec3, 2u); + + GeneratorImpl& gen = Build(); + + std::stringstream out; + ASSERT_TRUE(gen.EmitType(out, mat2x3, "")) << gen.error(); + EXPECT_EQ(out.str(), "half2x3"); +} + TEST_F(MslGeneratorImplTest, EmitType_Pointer) { auto* f32 = create(); auto* p = create(f32, ast::StorageClass::kWorkgroup, ast::Access::kReadWrite); diff --git a/src/tint/writer/msl/generator_impl_variable_decl_statement_test.cc b/src/tint/writer/msl/generator_impl_variable_decl_statement_test.cc index 46becb0dc0..10c5b3d599 100644 --- a/src/tint/writer/msl/generator_impl_variable_decl_statement_test.cc +++ b/src/tint/writer/msl/generator_impl_variable_decl_statement_test.cc @@ -154,6 +154,26 @@ void f() { )"); } +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_f16) { + Enable(ast::Extension::kF16); + + auto* C = Const("C", nullptr, Expr(1_h)); + Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))}); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_EQ(gen.result(), R"(#include + +using namespace metal; +void f() { + half const l = 1.0h; +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_vec3_AInt) { auto* C = Const("C", nullptr, Construct(ty.vec3(nullptr), 1_a, 2_a, 3_a)); Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))}); @@ -208,6 +228,26 @@ void f() { )"); } +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_vec3_f16) { + Enable(ast::Extension::kF16); + + auto* C = Const("C", nullptr, vec3(1_h, 2_h, 3_h)); + Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))}); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_EQ(gen.result(), R"(#include + +using namespace metal; +void f() { + half3 const l = half3(1.0h, 2.0h, 3.0h); +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_mat2x3_AFloat) { auto* C = Const("C", nullptr, Construct(ty.mat(nullptr, 2, 3), 1._a, 2._a, 3._a, 4._a, 5._a, 6._a)); @@ -245,6 +285,26 @@ void f() { )"); } +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_mat2x3_f16) { + Enable(ast::Extension::kF16); + + auto* C = Const("C", nullptr, mat2x3(1_h, 2_h, 3_h, 4_h, 5_h, 6_h)); + Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))}); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_EQ(gen.result(), R"(#include + +using namespace metal; +void f() { + half2x3 const l = half2x3(half3(1.0h, 2.0h, 3.0h), half3(4.0h, 5.0h, 6.0h)); +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_arr_f32) { auto* C = Const("C", nullptr, Construct(ty.array(), 1_f, 2_f, 3_f)); Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))}); @@ -343,7 +403,7 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Struct) { )"); } -TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Vector) { +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Vector_f32) { auto* var = Var("a", ty.vec2()); auto* stmt = Decl(var); WrapInFunction(stmt); @@ -356,7 +416,22 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Vector) { EXPECT_EQ(gen.result(), " float2 a = 0.0f;\n"); } -TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Matrix) { +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Vector_f16) { + Enable(ast::Extension::kF16); + + auto* var = Var("a", ty.vec2()); + auto* stmt = Decl(var); + WrapInFunction(stmt); + + GeneratorImpl& gen = Build(); + + gen.increment_indent(); + + ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error(); + EXPECT_EQ(gen.result(), " half2 a = 0.0h;\n"); +} + +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Matrix_f32) { auto* var = Var("a", ty.mat3x2()); auto* stmt = Decl(var); @@ -370,6 +445,80 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Matrix) { EXPECT_EQ(gen.result(), " float3x2 a = float3x2(0.0f);\n"); } +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Matrix_f16) { + Enable(ast::Extension::kF16); + + auto* var = Var("a", ty.mat3x2()); + + auto* stmt = Decl(var); + WrapInFunction(stmt); + + GeneratorImpl& gen = Build(); + + gen.increment_indent(); + + ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error(); + EXPECT_EQ(gen.result(), " half3x2 a = half3x2(0.0h);\n"); +} + +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroVec_f32) { + auto* var = Var("a", ty.vec3(), ast::StorageClass::kNone, vec3()); + + auto* stmt = Decl(var); + WrapInFunction(stmt); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error(); + EXPECT_EQ(gen.result(), R"(float3 a = float3(0.0f); +)"); +} + +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroVec_f16) { + Enable(ast::Extension::kF16); + + auto* var = Var("a", ty.vec3(), ast::StorageClass::kNone, vec3()); + + auto* stmt = Decl(var); + WrapInFunction(stmt); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error(); + EXPECT_EQ(gen.result(), R"(half3 a = half3(0.0h); +)"); +} + +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroMat_f32) { + auto* var = Var("a", ty.mat2x3(), ast::StorageClass::kNone, mat2x3()); + + auto* stmt = Decl(var); + WrapInFunction(stmt); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error(); + EXPECT_EQ(gen.result(), + R"(float2x3 a = float2x3(float3(0.0f), float3(0.0f)); +)"); +} + +TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroMat_f16) { + Enable(ast::Extension::kF16); + + auto* var = Var("a", ty.mat2x3(), ast::StorageClass::kNone, mat2x3()); + + auto* stmt = Decl(var); + WrapInFunction(stmt); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error(); + EXPECT_EQ(gen.result(), + R"(half2x3 a = half2x3(half3(0.0h), half3(0.0h)); +)"); +} + TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Private) { GlobalVar("a", ty.f32(), ast::StorageClass::kPrivate); @@ -396,19 +545,5 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Workgroup) { EXPECT_THAT(gen.result(), HasSubstr("threadgroup float tint_symbol_2;\n")); } -TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroVec) { - auto* zero_vec = vec3(); - - auto* var = Var("a", ty.vec3(), ast::StorageClass::kNone, zero_vec); - auto* stmt = Decl(var); - WrapInFunction(stmt); - - GeneratorImpl& gen = Build(); - - ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error(); - EXPECT_EQ(gen.result(), R"(float3 a = float3(0.0f); -)"); -} - } // namespace } // namespace tint::writer::msl