mirror of
				https://github.com/encounter/dawn-cmake.git
				synced 2025-10-26 19:50:30 +00:00 
			
		
		
		
	tint/writer/msl: Support for F16 type, constructor, and convertor
This patch make MSL writer support emitting f16 types, f16 literals, f16 constructor and convertor. Unittests are also implemented. The MSL writer will emit f16 literal as `1.23h`, and map f16 types as follow: WGSL type -> MSL type f16 -> half vec2<f16> -> half2 vec3<f16> -> half3 vec4<f16> -> half4 mat2x2<f16> -> half2x2 mat2x3<f16> -> half2x3 mat2x4<f16> -> half2x4 mat3x2<f16> -> half3x2 mat3x3<f16> -> half3x3 mat3x4<f16> -> half3x4 mat4x2<f16> -> half4x2 mat4x3<f16> -> half4x3 mat4x4<f16> -> half4x4 Bug: tint:1473, tint:1502 Change-Id: Id91821e1a32d48c80bad9a0753faa5247835b0f7 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/95686 Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com> Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
		
							parent
							
								
									0d3d3210d4
								
							
						
					
					
						commit
						e9c5070348
					
				| @ -39,6 +39,7 @@ | ||||
| #include "src/tint/sem/constant.h" | ||||
| #include "src/tint/sem/depth_multisampled_texture.h" | ||||
| #include "src/tint/sem/depth_texture.h" | ||||
| #include "src/tint/sem/f16.h" | ||||
| #include "src/tint/sem/f32.h" | ||||
| #include "src/tint/sem/function.h" | ||||
| #include "src/tint/sem/i32.h" | ||||
| @ -97,6 +98,20 @@ void PrintF32(std::ostream& out, float value) { | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| void PrintF16(std::ostream& out, float value) { | ||||
|     // Note: Currently inf and nan should not be constructable, but this is implemented for the day
 | ||||
|     // we support them.
 | ||||
|     if (std::isinf(value)) { | ||||
|         // HUGE_VALH evaluates to +infinity.
 | ||||
|         out << (value >= 0 ? "HUGE_VALH" : "-HUGE_VALH"); | ||||
|     } else if (std::isnan(value)) { | ||||
|         // There is no NaN expr for half in MSL, "NAN" is of float type.
 | ||||
|         out << "NAN"; | ||||
|     } else { | ||||
|         out << FloatToString(value) << "h"; | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| void PrintI32(std::ostream& out, int32_t value) { | ||||
|     // MSL (and C++) parse `-2147483648` as a `long` because it parses unary minus and `2147483648`
 | ||||
|     // as separate tokens, and the latter doesn't fit into an (32-bit) `int`.
 | ||||
| @ -1551,10 +1566,8 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) { | ||||
|             return true; | ||||
|         }, | ||||
|         [&](const sem::F16*) { | ||||
|             // Placeholder for emitting f16 zero value
 | ||||
|             diagnostics_.add_error(diag::System::Writer, | ||||
|                                    "Type f16 is not completely implemented yet"); | ||||
|             return false; | ||||
|             out << "0.0h"; | ||||
|             return true; | ||||
|         }, | ||||
|         [&](const sem::F32*) { | ||||
|             out << "0.0f"; | ||||
| @ -1605,6 +1618,10 @@ bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant* constan | ||||
|             PrintF32(out, constant->As<float>()); | ||||
|             return true; | ||||
|         }, | ||||
|         [&](const sem::F16*) { | ||||
|             PrintF16(out, constant->As<float>()); | ||||
|             return true; | ||||
|         }, | ||||
|         [&](const sem::I32*) { | ||||
|             PrintI32(out, constant->As<int32_t>()); | ||||
|             return true; | ||||
| @ -1694,7 +1711,11 @@ bool GeneratorImpl::EmitLiteral(std::ostream& out, const ast::LiteralExpression* | ||||
|             return true; | ||||
|         }, | ||||
|         [&](const ast::FloatLiteralExpression* l) { | ||||
|             PrintF32(out, static_cast<float>(l->value)); | ||||
|             if (l->suffix == ast::FloatLiteralExpression::Suffix::kH) { | ||||
|                 PrintF16(out, static_cast<float>(l->value)); | ||||
|             } else { | ||||
|                 PrintF32(out, static_cast<float>(l->value)); | ||||
|             } | ||||
|             return true; | ||||
|         }, | ||||
|         [&](const ast::IntLiteralExpression* i) { | ||||
| @ -2459,9 +2480,8 @@ bool GeneratorImpl::EmitType(std::ostream& out, | ||||
|             return true; | ||||
|         }, | ||||
|         [&](const sem::F16*) { | ||||
|             diagnostics_.add_error(diag::System::Writer, | ||||
|                                    "Type f16 is not completely implemented yet"); | ||||
|             return false; | ||||
|             out << "half"; | ||||
|             return true; | ||||
|         }, | ||||
|         [&](const sem::F32*) { | ||||
|             out << "float"; | ||||
| @ -3043,20 +3063,25 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem:: | ||||
|         [&](const sem::F32*) { | ||||
|             return SizeAndAlign{4, 4}; | ||||
|         }, | ||||
|         [&](const sem::F16*) { | ||||
|             return SizeAndAlign{2, 2}; | ||||
|         }, | ||||
| 
 | ||||
|         [&](const sem::Vector* vec) { | ||||
|             auto num_els = vec->Width(); | ||||
|             auto* el_ty = vec->type(); | ||||
|             if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) { | ||||
|             SizeAndAlign el_size_align = MslPackedTypeSizeAndAlign(el_ty); | ||||
|             if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32, sem::F16>()) { | ||||
|                 // Use a packed_vec type for 3-element vectors only.
 | ||||
|                 if (num_els == 3) { | ||||
|                     // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
 | ||||
|                     // 2.2.3 Packed Vector Types
 | ||||
|                     return SizeAndAlign{num_els * 4, 4}; | ||||
|                     return SizeAndAlign{num_els * el_size_align.size, el_size_align.align}; | ||||
|                 } else { | ||||
|                     // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
 | ||||
|                     // 2.2 Vector Data Types
 | ||||
|                     return SizeAndAlign{num_els * 4, num_els * 4}; | ||||
|                     // Vector data types are aligned to their size.
 | ||||
|                     return SizeAndAlign{num_els * el_size_align.size, num_els * el_size_align.size}; | ||||
|                 } | ||||
|             } | ||||
|             TINT_UNREACHABLE(Writer, diagnostics_) | ||||
| @ -3070,8 +3095,9 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem:: | ||||
|             auto cols = mat->columns(); | ||||
|             auto rows = mat->rows(); | ||||
|             auto* el_ty = mat->type(); | ||||
|             if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) { | ||||
|                 static constexpr SizeAndAlign table[] = { | ||||
|             // Metal only support half and float matrix.
 | ||||
|             if (el_ty->IsAnyOf<sem::F32, sem::F16>()) { | ||||
|                 static constexpr SizeAndAlign table_f32[] = { | ||||
|                     /* float2x2 */ {16, 8}, | ||||
|                     /* float2x3 */ {32, 16}, | ||||
|                     /* float2x4 */ {32, 16}, | ||||
| @ -3082,8 +3108,23 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem:: | ||||
|                     /* float4x3 */ {64, 16}, | ||||
|                     /* float4x4 */ {64, 16}, | ||||
|                 }; | ||||
|                 static constexpr SizeAndAlign table_f16[] = { | ||||
|                     /* half2x2 */ {8, 4}, | ||||
|                     /* half2x3 */ {16, 8}, | ||||
|                     /* half2x4 */ {16, 8}, | ||||
|                     /* half3x2 */ {12, 4}, | ||||
|                     /* half3x3 */ {24, 8}, | ||||
|                     /* half3x4 */ {24, 8}, | ||||
|                     /* half4x2 */ {16, 4}, | ||||
|                     /* half4x3 */ {32, 8}, | ||||
|                     /* half4x4 */ {32, 8}, | ||||
|                 }; | ||||
|                 if (cols >= 2 && cols <= 4 && rows >= 2 && rows <= 4) { | ||||
|                     return table[(3 * (cols - 2)) + (rows - 2)]; | ||||
|                     if (el_ty->Is<sem::F32>()) { | ||||
|                         return table_f32[(3 * (cols - 2)) + (rows - 2)]; | ||||
|                     } else { | ||||
|                         return table_f16[(3 * (cols - 2)) + (rows - 2)]; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|  | ||||
| @ -61,6 +61,18 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Float) { | ||||
|     EXPECT_THAT(gen.result(), HasSubstr("1073741824.0f")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_F16) { | ||||
|     Enable(ast::Extension::kF16); | ||||
| 
 | ||||
|     // Use a number close to 1<<16 but whose decimal representation ends in 0.
 | ||||
|     WrapInFunction(Expr(f16((1 << 15) - 8))); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
|     EXPECT_THAT(gen.result(), HasSubstr("32752.0h")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Float) { | ||||
|     WrapInFunction(Construct<f32>(-1.2e-5_f)); | ||||
| 
 | ||||
| @ -70,6 +82,17 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Float) { | ||||
|     EXPECT_THAT(gen.result(), HasSubstr("-0.000012f")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_F16) { | ||||
|     Enable(ast::Extension::kF16); | ||||
| 
 | ||||
|     WrapInFunction(Construct<f16>(-1.2e-3_h)); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
|     EXPECT_THAT(gen.result(), HasSubstr("-0.00119972229h")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Bool) { | ||||
|     WrapInFunction(Construct<bool>(true)); | ||||
| 
 | ||||
| @ -97,7 +120,7 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Uint) { | ||||
|     EXPECT_THAT(gen.result(), HasSubstr("12345u")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec) { | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_F32) { | ||||
|     WrapInFunction(vec3<f32>(1_f, 2_f, 3_f)); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| @ -106,7 +129,18 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec) { | ||||
|     EXPECT_THAT(gen.result(), HasSubstr("float3(1.0f, 2.0f, 3.0f)")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty) { | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_F16) { | ||||
|     Enable(ast::Extension::kF16); | ||||
| 
 | ||||
|     WrapInFunction(vec3<f16>(1_h, 2_h, 3_h)); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
|     EXPECT_THAT(gen.result(), HasSubstr("half3(1.0h, 2.0h, 3.0h)")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty_F32) { | ||||
|     WrapInFunction(vec3<f32>()); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| @ -115,26 +149,230 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty) { | ||||
|     EXPECT_THAT(gen.result(), HasSubstr("float3(0.0f)")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat) { | ||||
|     WrapInFunction(Construct(ty.mat2x3<f32>(), vec3<f32>(1_f, 2_f, 3_f), vec3<f32>(3_f, 4_f, 5_f))); | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty_F16) { | ||||
|     Enable(ast::Extension::kF16); | ||||
| 
 | ||||
|     WrapInFunction(vec3<f16>()); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
|     EXPECT_THAT(gen.result(), HasSubstr("half3(0.0h)")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_F32_Literal) { | ||||
|     WrapInFunction(vec3<f32>(2_f)); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
|     EXPECT_THAT(gen.result(), HasSubstr("float3(2.0f)")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_F16_Literal) { | ||||
|     Enable(ast::Extension::kF16); | ||||
| 
 | ||||
|     WrapInFunction(vec3<f16>(2_h)); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
|     EXPECT_THAT(gen.result(), HasSubstr("half3(2.0h)")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_F32_Var) { | ||||
|     auto* var = Var("v", nullptr, Expr(2_f)); | ||||
|     auto* cast = vec3<f32>(var); | ||||
|     WrapInFunction(var, cast); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
|     EXPECT_THAT(gen.result(), HasSubstr(R"(float v = 2.0f; | ||||
|   float3 const tint_symbol = float3(v);)")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_F16_Var) { | ||||
|     Enable(ast::Extension::kF16); | ||||
| 
 | ||||
|     auto* var = Var("v", nullptr, Expr(2_h)); | ||||
|     auto* cast = vec3<f16>(var); | ||||
|     WrapInFunction(var, cast); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
|     EXPECT_THAT(gen.result(), HasSubstr(R"(half v = 2.0h; | ||||
|   half3 const tint_symbol = half3(v);)")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_Bool) { | ||||
|     WrapInFunction(vec3<bool>(true)); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
|     EXPECT_THAT(gen.result(), HasSubstr("bool3(true)")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_Int) { | ||||
|     WrapInFunction(vec3<i32>(2_i)); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
|     EXPECT_THAT(gen.result(), HasSubstr("int3(2)")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_UInt) { | ||||
|     WrapInFunction(vec3<u32>(2_u)); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
|     EXPECT_THAT(gen.result(), HasSubstr("uint3(2u)")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_F32) { | ||||
|     WrapInFunction(mat2x3<f32>(vec3<f32>(1_f, 2_f, 3_f), vec3<f32>(3_f, 4_f, 5_f))); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
| 
 | ||||
|     // A matrix of type T with n columns and m rows can also be constructed from
 | ||||
|     // n vectors of type T with m components.
 | ||||
|     EXPECT_THAT(gen.result(), | ||||
|                 HasSubstr("float2x3(float3(1.0f, 2.0f, 3.0f), float3(3.0f, 4.0f, 5.0f))")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Empty) { | ||||
|     WrapInFunction(mat4x4<f32>()); | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_F16) { | ||||
|     Enable(ast::Extension::kF16); | ||||
| 
 | ||||
|     WrapInFunction(mat2x3<f16>(vec3<f16>(1_h, 2_h, 3_h), vec3<f16>(3_h, 4_h, 5_h))); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
|     EXPECT_THAT(gen.result(), HasSubstr("float4x4(float4(0.0f), float4(0.0f)")); | ||||
| 
 | ||||
|     EXPECT_THAT(gen.result(), | ||||
|                 HasSubstr("half2x3(half3(1.0h, 2.0h, 3.0h), half3(3.0h, 4.0h, 5.0h))")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Complex_F32) { | ||||
|     // mat4x4<f32>(
 | ||||
|     //     vec4<f32>(2.0f, 3.0f, 4.0f, 8.0f),
 | ||||
|     //     vec4<f32>(),
 | ||||
|     //     vec4<f32>(7.0f),
 | ||||
|     //     vec4<f32>(vec4<f32>(42.0f, 21.0f, 6.0f, -5.0f)),
 | ||||
|     //   );
 | ||||
|     auto* vector_literal = | ||||
|         vec4<f32>(Expr(f32(2.0)), Expr(f32(3.0)), Expr(f32(4.0)), Expr(f32(8.0))); | ||||
|     auto* vector_zero_ctor = vec4<f32>(); | ||||
|     auto* vector_single_scalar_ctor = vec4<f32>(Expr(f32(7.0))); | ||||
|     auto* vector_identical_ctor = | ||||
|         vec4<f32>(vec4<f32>(Expr(f32(42.0)), Expr(f32(21.0)), Expr(f32(6.0)), Expr(f32(-5.0)))); | ||||
| 
 | ||||
|     auto* constructor = mat4x4<f32>(vector_literal, vector_zero_ctor, vector_single_scalar_ctor, | ||||
|                                     vector_identical_ctor); | ||||
| 
 | ||||
|     WrapInFunction(constructor); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
| 
 | ||||
|     EXPECT_THAT(gen.result(), HasSubstr("float4x4(float4(2.0f, 3.0f, 4.0f, 8.0f), float4(0.0f), " | ||||
|                                         "float4(7.0f), float4(42.0f, 21.0f, 6.0f, -5.0f))")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Complex_F16) { | ||||
|     // mat4x4<f16>(
 | ||||
|     //     vec4<f16>(2.0h, 3.0h, 4.0h, 8.0h),
 | ||||
|     //     vec4<f16>(),
 | ||||
|     //     vec4<f16>(7.0h),
 | ||||
|     //     vec4<f16>(vec4<f16>(42.0h, 21.0h, 6.0h, -5.0h)),
 | ||||
|     //   );
 | ||||
|     Enable(ast::Extension::kF16); | ||||
| 
 | ||||
|     auto* vector_literal = | ||||
|         vec4<f16>(Expr(f16(2.0)), Expr(f16(3.0)), Expr(f16(4.0)), Expr(f16(8.0))); | ||||
|     auto* vector_zero_ctor = vec4<f16>(); | ||||
|     auto* vector_single_scalar_ctor = vec4<f16>(Expr(f16(7.0))); | ||||
|     auto* vector_identical_ctor = | ||||
|         vec4<f16>(vec4<f16>(Expr(f16(42.0)), Expr(f16(21.0)), Expr(f16(6.0)), Expr(f16(-5.0)))); | ||||
| 
 | ||||
|     auto* constructor = mat4x4<f16>(vector_literal, vector_zero_ctor, vector_single_scalar_ctor, | ||||
|                                     vector_identical_ctor); | ||||
| 
 | ||||
|     WrapInFunction(constructor); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
| 
 | ||||
|     EXPECT_THAT(gen.result(), HasSubstr("half4x4(half4(2.0h, 3.0h, 4.0h, 8.0h), half4(0.0h), " | ||||
|                                         "half4(7.0h), half4(42.0h, 21.0h, 6.0h, -5.0h))")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Empty_F32) { | ||||
|     WrapInFunction(mat2x3<f32>()); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
| 
 | ||||
|     EXPECT_THAT(gen.result(), | ||||
|                 HasSubstr("float2x3 const tint_symbol = float2x3(float3(0.0f), float3(0.0f))")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Empty_F16) { | ||||
|     Enable(ast::Extension::kF16); | ||||
| 
 | ||||
|     WrapInFunction(mat2x3<f16>()); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
| 
 | ||||
|     EXPECT_THAT(gen.result(), | ||||
|                 HasSubstr("half2x3 const tint_symbol = half2x3(half3(0.0h), half3(0.0h))")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Identity_F32) { | ||||
|     // fn f() {
 | ||||
|     //     var m_1: mat4x4<f32> = mat4x4<f32>();
 | ||||
|     //     var m_2: mat4x4<f32> = mat4x4<f32>(m_1);
 | ||||
|     // }
 | ||||
| 
 | ||||
|     auto* m_1 = Var("m_1", ty.mat4x4(ty.f32()), mat4x4<f32>()); | ||||
|     auto* m_2 = Var("m_2", ty.mat4x4(ty.f32()), mat4x4<f32>(m_1)); | ||||
| 
 | ||||
|     WrapInFunction(m_1, m_2); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
| 
 | ||||
|     EXPECT_THAT(gen.result(), HasSubstr("float4x4 m_2 = float4x4(m_1);")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Identity_F16) { | ||||
|     // fn f() {
 | ||||
|     //     var m_1: mat4x4<f16> = mat4x4<f16>();
 | ||||
|     //     var m_2: mat4x4<f16> = mat4x4<f16>(m_1);
 | ||||
|     // }
 | ||||
| 
 | ||||
|     Enable(ast::Extension::kF16); | ||||
| 
 | ||||
|     auto* m_1 = Var("m_1", ty.mat4x4(ty.f16()), mat4x4<f16>()); | ||||
|     auto* m_2 = Var("m_2", ty.mat4x4(ty.f16()), mat4x4<f16>(m_1)); | ||||
| 
 | ||||
|     WrapInFunction(m_1, m_2); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
| 
 | ||||
|     EXPECT_THAT(gen.result(), HasSubstr("half4x4 m_2 = half4x4(m_1);")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Array) { | ||||
|  | ||||
| @ -112,6 +112,26 @@ void f() { | ||||
| )"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_GlobalConst_f16) { | ||||
|     Enable(ast::Extension::kF16); | ||||
| 
 | ||||
|     auto* var = GlobalConst("G", nullptr, Expr(1_h)); | ||||
|     Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))}); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
| 
 | ||||
|     EXPECT_EQ(gen.result(), R"(#include <metal_stdlib> | ||||
| 
 | ||||
| using namespace metal; | ||||
| void f() { | ||||
|   half const l = 1.0h; | ||||
| } | ||||
| 
 | ||||
| )"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_GlobalConst_vec3_AInt) { | ||||
|     auto* var = GlobalConst("G", nullptr, Construct(ty.vec3(nullptr), 1_a, 2_a, 3_a)); | ||||
|     Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))}); | ||||
| @ -166,6 +186,26 @@ void f() { | ||||
| )"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_GlobalConst_vec3_f16) { | ||||
|     Enable(ast::Extension::kF16); | ||||
| 
 | ||||
|     auto* var = GlobalConst("G", nullptr, vec3<f16>(1_h, 2_h, 3_h)); | ||||
|     Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))}); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
| 
 | ||||
|     EXPECT_EQ(gen.result(), R"(#include <metal_stdlib> | ||||
| 
 | ||||
| using namespace metal; | ||||
| void f() { | ||||
|   half3 const l = half3(1.0h, 2.0h, 3.0h); | ||||
| } | ||||
| 
 | ||||
| )"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_GlobalConst_mat2x3_AFloat) { | ||||
|     auto* var = GlobalConst("G", nullptr, | ||||
|                             Construct(ty.mat(nullptr, 2, 3), 1._a, 2._a, 3._a, 4._a, 5._a, 6._a)); | ||||
| @ -203,6 +243,26 @@ void f() { | ||||
| )"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_GlobalConst_mat2x3_f16) { | ||||
|     Enable(ast::Extension::kF16); | ||||
| 
 | ||||
|     auto* var = GlobalConst("G", nullptr, mat2x3<f16>(1_h, 2_h, 3_h, 4_h, 5_h, 6_h)); | ||||
|     Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))}); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
| 
 | ||||
|     EXPECT_EQ(gen.result(), R"(#include <metal_stdlib> | ||||
| 
 | ||||
| using namespace metal; | ||||
| void f() { | ||||
|   half2x3 const l = half2x3(half3(1.0h, 2.0h, 3.0h), half3(4.0h, 5.0h, 6.0h)); | ||||
| } | ||||
| 
 | ||||
| )"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_GlobalConst_arr_f32) { | ||||
|     auto* var = GlobalConst("G", nullptr, Construct(ty.array<f32, 3>(), 1_f, 2_f, 3_f)); | ||||
|     Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))}); | ||||
|  | ||||
| @ -71,6 +71,18 @@ DECLARE_TYPE(float3x4, 48, 16); | ||||
| DECLARE_TYPE(float4x2, 32, 8); | ||||
| DECLARE_TYPE(float4x3, 64, 16); | ||||
| DECLARE_TYPE(float4x4, 64, 16); | ||||
| DECLARE_TYPE(half2, 4, 4); | ||||
| DECLARE_TYPE(packed_half3, 6, 2); | ||||
| DECLARE_TYPE(half4, 8, 8); | ||||
| DECLARE_TYPE(half2x2, 8, 4); | ||||
| DECLARE_TYPE(half2x3, 16, 8); | ||||
| DECLARE_TYPE(half2x4, 16, 8); | ||||
| DECLARE_TYPE(half3x2, 12, 4); | ||||
| DECLARE_TYPE(half3x3, 24, 8); | ||||
| DECLARE_TYPE(half3x4, 24, 8); | ||||
| DECLARE_TYPE(half4x2, 16, 4); | ||||
| DECLARE_TYPE(half4x3, 32, 8); | ||||
| DECLARE_TYPE(half4x4, 32, 8); | ||||
| using uint = unsigned int; | ||||
| 
 | ||||
| using MslGeneratorImplTest = TestHelper; | ||||
| @ -153,6 +165,16 @@ TEST_F(MslGeneratorImplTest, EmitType_F32) { | ||||
|     EXPECT_EQ(out.str(), "float"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitType_F16) { | ||||
|     auto* f16 = create<sem::F16>(); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     std::stringstream out; | ||||
|     ASSERT_TRUE(gen.EmitType(out, f16, "")) << gen.error(); | ||||
|     EXPECT_EQ(out.str(), "half"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitType_I32) { | ||||
|     auto* i32 = create<sem::I32>(); | ||||
| 
 | ||||
| @ -163,7 +185,7 @@ TEST_F(MslGeneratorImplTest, EmitType_I32) { | ||||
|     EXPECT_EQ(out.str(), "int"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitType_Matrix) { | ||||
| TEST_F(MslGeneratorImplTest, EmitType_Matrix_F32) { | ||||
|     auto* f32 = create<sem::F32>(); | ||||
|     auto* vec3 = create<sem::Vector>(f32, 3u); | ||||
|     auto* mat2x3 = create<sem::Matrix>(vec3, 2u); | ||||
| @ -175,6 +197,18 @@ TEST_F(MslGeneratorImplTest, EmitType_Matrix) { | ||||
|     EXPECT_EQ(out.str(), "float2x3"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitType_Matrix_F16) { | ||||
|     auto* f16 = create<sem::F16>(); | ||||
|     auto* vec3 = create<sem::Vector>(f16, 3u); | ||||
|     auto* mat2x3 = create<sem::Matrix>(vec3, 2u); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     std::stringstream out; | ||||
|     ASSERT_TRUE(gen.EmitType(out, mat2x3, "")) << gen.error(); | ||||
|     EXPECT_EQ(out.str(), "half2x3"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, EmitType_Pointer) { | ||||
|     auto* f32 = create<sem::F32>(); | ||||
|     auto* p = create<sem::Pointer>(f32, ast::StorageClass::kWorkgroup, ast::Access::kReadWrite); | ||||
|  | ||||
| @ -154,6 +154,26 @@ void f() { | ||||
| )"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_f16) { | ||||
|     Enable(ast::Extension::kF16); | ||||
| 
 | ||||
|     auto* C = Const("C", nullptr, Expr(1_h)); | ||||
|     Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))}); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
| 
 | ||||
|     EXPECT_EQ(gen.result(), R"(#include <metal_stdlib> | ||||
| 
 | ||||
| using namespace metal; | ||||
| void f() { | ||||
|   half const l = 1.0h; | ||||
| } | ||||
| 
 | ||||
| )"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_vec3_AInt) { | ||||
|     auto* C = Const("C", nullptr, Construct(ty.vec3(nullptr), 1_a, 2_a, 3_a)); | ||||
|     Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))}); | ||||
| @ -208,6 +228,26 @@ void f() { | ||||
| )"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_vec3_f16) { | ||||
|     Enable(ast::Extension::kF16); | ||||
| 
 | ||||
|     auto* C = Const("C", nullptr, vec3<f16>(1_h, 2_h, 3_h)); | ||||
|     Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))}); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
| 
 | ||||
|     EXPECT_EQ(gen.result(), R"(#include <metal_stdlib> | ||||
| 
 | ||||
| using namespace metal; | ||||
| void f() { | ||||
|   half3 const l = half3(1.0h, 2.0h, 3.0h); | ||||
| } | ||||
| 
 | ||||
| )"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_mat2x3_AFloat) { | ||||
|     auto* C = | ||||
|         Const("C", nullptr, Construct(ty.mat(nullptr, 2, 3), 1._a, 2._a, 3._a, 4._a, 5._a, 6._a)); | ||||
| @ -245,6 +285,26 @@ void f() { | ||||
| )"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_mat2x3_f16) { | ||||
|     Enable(ast::Extension::kF16); | ||||
| 
 | ||||
|     auto* C = Const("C", nullptr, mat2x3<f16>(1_h, 2_h, 3_h, 4_h, 5_h, 6_h)); | ||||
|     Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))}); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.Generate()) << gen.error(); | ||||
| 
 | ||||
|     EXPECT_EQ(gen.result(), R"(#include <metal_stdlib> | ||||
| 
 | ||||
| using namespace metal; | ||||
| void f() { | ||||
|   half2x3 const l = half2x3(half3(1.0h, 2.0h, 3.0h), half3(4.0h, 5.0h, 6.0h)); | ||||
| } | ||||
| 
 | ||||
| )"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_arr_f32) { | ||||
|     auto* C = Const("C", nullptr, Construct(ty.array<f32, 3>(), 1_f, 2_f, 3_f)); | ||||
|     Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))}); | ||||
| @ -343,7 +403,7 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Struct) { | ||||
| )"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Vector) { | ||||
| TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Vector_f32) { | ||||
|     auto* var = Var("a", ty.vec2<f32>()); | ||||
|     auto* stmt = Decl(var); | ||||
|     WrapInFunction(stmt); | ||||
| @ -356,7 +416,22 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Vector) { | ||||
|     EXPECT_EQ(gen.result(), "  float2 a = 0.0f;\n"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Matrix) { | ||||
| TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Vector_f16) { | ||||
|     Enable(ast::Extension::kF16); | ||||
| 
 | ||||
|     auto* var = Var("a", ty.vec2<f16>()); | ||||
|     auto* stmt = Decl(var); | ||||
|     WrapInFunction(stmt); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     gen.increment_indent(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error(); | ||||
|     EXPECT_EQ(gen.result(), "  half2 a = 0.0h;\n"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Matrix_f32) { | ||||
|     auto* var = Var("a", ty.mat3x2<f32>()); | ||||
| 
 | ||||
|     auto* stmt = Decl(var); | ||||
| @ -370,6 +445,80 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Matrix) { | ||||
|     EXPECT_EQ(gen.result(), "  float3x2 a = float3x2(0.0f);\n"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Matrix_f16) { | ||||
|     Enable(ast::Extension::kF16); | ||||
| 
 | ||||
|     auto* var = Var("a", ty.mat3x2<f16>()); | ||||
| 
 | ||||
|     auto* stmt = Decl(var); | ||||
|     WrapInFunction(stmt); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     gen.increment_indent(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error(); | ||||
|     EXPECT_EQ(gen.result(), "  half3x2 a = half3x2(0.0h);\n"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroVec_f32) { | ||||
|     auto* var = Var("a", ty.vec3<f32>(), ast::StorageClass::kNone, vec3<f32>()); | ||||
| 
 | ||||
|     auto* stmt = Decl(var); | ||||
|     WrapInFunction(stmt); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error(); | ||||
|     EXPECT_EQ(gen.result(), R"(float3 a = float3(0.0f); | ||||
| )"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroVec_f16) { | ||||
|     Enable(ast::Extension::kF16); | ||||
| 
 | ||||
|     auto* var = Var("a", ty.vec3<f16>(), ast::StorageClass::kNone, vec3<f16>()); | ||||
| 
 | ||||
|     auto* stmt = Decl(var); | ||||
|     WrapInFunction(stmt); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error(); | ||||
|     EXPECT_EQ(gen.result(), R"(half3 a = half3(0.0h); | ||||
| )"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroMat_f32) { | ||||
|     auto* var = Var("a", ty.mat2x3<f32>(), ast::StorageClass::kNone, mat2x3<f32>()); | ||||
| 
 | ||||
|     auto* stmt = Decl(var); | ||||
|     WrapInFunction(stmt); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error(); | ||||
|     EXPECT_EQ(gen.result(), | ||||
|               R"(float2x3 a = float2x3(float3(0.0f), float3(0.0f)); | ||||
| )"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroMat_f16) { | ||||
|     Enable(ast::Extension::kF16); | ||||
| 
 | ||||
|     auto* var = Var("a", ty.mat2x3<f16>(), ast::StorageClass::kNone, mat2x3<f16>()); | ||||
| 
 | ||||
|     auto* stmt = Decl(var); | ||||
|     WrapInFunction(stmt); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error(); | ||||
|     EXPECT_EQ(gen.result(), | ||||
|               R"(half2x3 a = half2x3(half3(0.0h), half3(0.0h)); | ||||
| )"); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Private) { | ||||
|     GlobalVar("a", ty.f32(), ast::StorageClass::kPrivate); | ||||
| 
 | ||||
| @ -396,19 +545,5 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Workgroup) { | ||||
|     EXPECT_THAT(gen.result(), HasSubstr("threadgroup float tint_symbol_2;\n")); | ||||
| } | ||||
| 
 | ||||
| TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroVec) { | ||||
|     auto* zero_vec = vec3<f32>(); | ||||
| 
 | ||||
|     auto* var = Var("a", ty.vec3<f32>(), ast::StorageClass::kNone, zero_vec); | ||||
|     auto* stmt = Decl(var); | ||||
|     WrapInFunction(stmt); | ||||
| 
 | ||||
|     GeneratorImpl& gen = Build(); | ||||
| 
 | ||||
|     ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error(); | ||||
|     EXPECT_EQ(gen.result(), R"(float3 a = float3(0.0f); | ||||
| )"); | ||||
| } | ||||
| 
 | ||||
| }  // namespace
 | ||||
| }  // namespace tint::writer::msl
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user