tint/writer/msl: Support for F16 type, constructor, and convertor

This patch make MSL writer support emitting f16 types, f16 literals,
f16 constructor and convertor. Unittests are also implemented.

The MSL writer will emit f16 literal as `1.23h`, and map f16 types as
follow:
WGSL type   -> MSL type
f16         -> half
vec2<f16>   -> half2
vec3<f16>   -> half3
vec4<f16>   -> half4
mat2x2<f16> -> half2x2
mat2x3<f16> -> half2x3
mat2x4<f16> -> half2x4
mat3x2<f16> -> half3x2
mat3x3<f16> -> half3x3
mat3x4<f16> -> half3x4
mat4x2<f16> -> half4x2
mat4x3<f16> -> half4x3
mat4x4<f16> -> half4x4

Bug: tint:1473, tint:1502
Change-Id: Id91821e1a32d48c80bad9a0753faa5247835b0f7
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/95686
Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
Zhaoming Jiang 2022-07-11 03:03:11 +00:00 committed by Dawn LUCI CQ
parent 0d3d3210d4
commit e9c5070348
5 changed files with 548 additions and 40 deletions

View File

@ -39,6 +39,7 @@
#include "src/tint/sem/constant.h"
#include "src/tint/sem/depth_multisampled_texture.h"
#include "src/tint/sem/depth_texture.h"
#include "src/tint/sem/f16.h"
#include "src/tint/sem/f32.h"
#include "src/tint/sem/function.h"
#include "src/tint/sem/i32.h"
@ -97,6 +98,20 @@ void PrintF32(std::ostream& out, float value) {
}
}
void PrintF16(std::ostream& out, float value) {
// Note: Currently inf and nan should not be constructable, but this is implemented for the day
// we support them.
if (std::isinf(value)) {
// HUGE_VALH evaluates to +infinity.
out << (value >= 0 ? "HUGE_VALH" : "-HUGE_VALH");
} else if (std::isnan(value)) {
// There is no NaN expr for half in MSL, "NAN" is of float type.
out << "NAN";
} else {
out << FloatToString(value) << "h";
}
}
void PrintI32(std::ostream& out, int32_t value) {
// MSL (and C++) parse `-2147483648` as a `long` because it parses unary minus and `2147483648`
// as separate tokens, and the latter doesn't fit into an (32-bit) `int`.
@ -1551,10 +1566,8 @@ bool GeneratorImpl::EmitZeroValue(std::ostream& out, const sem::Type* type) {
return true;
},
[&](const sem::F16*) {
// Placeholder for emitting f16 zero value
diagnostics_.add_error(diag::System::Writer,
"Type f16 is not completely implemented yet");
return false;
out << "0.0h";
return true;
},
[&](const sem::F32*) {
out << "0.0f";
@ -1605,6 +1618,10 @@ bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant* constan
PrintF32(out, constant->As<float>());
return true;
},
[&](const sem::F16*) {
PrintF16(out, constant->As<float>());
return true;
},
[&](const sem::I32*) {
PrintI32(out, constant->As<int32_t>());
return true;
@ -1694,7 +1711,11 @@ bool GeneratorImpl::EmitLiteral(std::ostream& out, const ast::LiteralExpression*
return true;
},
[&](const ast::FloatLiteralExpression* l) {
PrintF32(out, static_cast<float>(l->value));
if (l->suffix == ast::FloatLiteralExpression::Suffix::kH) {
PrintF16(out, static_cast<float>(l->value));
} else {
PrintF32(out, static_cast<float>(l->value));
}
return true;
},
[&](const ast::IntLiteralExpression* i) {
@ -2459,9 +2480,8 @@ bool GeneratorImpl::EmitType(std::ostream& out,
return true;
},
[&](const sem::F16*) {
diagnostics_.add_error(diag::System::Writer,
"Type f16 is not completely implemented yet");
return false;
out << "half";
return true;
},
[&](const sem::F32*) {
out << "float";
@ -3043,20 +3063,25 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem::
[&](const sem::F32*) {
return SizeAndAlign{4, 4};
},
[&](const sem::F16*) {
return SizeAndAlign{2, 2};
},
[&](const sem::Vector* vec) {
auto num_els = vec->Width();
auto* el_ty = vec->type();
if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
SizeAndAlign el_size_align = MslPackedTypeSizeAndAlign(el_ty);
if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32, sem::F16>()) {
// Use a packed_vec type for 3-element vectors only.
if (num_els == 3) {
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
// 2.2.3 Packed Vector Types
return SizeAndAlign{num_els * 4, 4};
return SizeAndAlign{num_els * el_size_align.size, el_size_align.align};
} else {
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
// 2.2 Vector Data Types
return SizeAndAlign{num_els * 4, num_els * 4};
// Vector data types are aligned to their size.
return SizeAndAlign{num_els * el_size_align.size, num_els * el_size_align.size};
}
}
TINT_UNREACHABLE(Writer, diagnostics_)
@ -3070,8 +3095,9 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem::
auto cols = mat->columns();
auto rows = mat->rows();
auto* el_ty = mat->type();
if (el_ty->IsAnyOf<sem::U32, sem::I32, sem::F32>()) {
static constexpr SizeAndAlign table[] = {
// Metal only support half and float matrix.
if (el_ty->IsAnyOf<sem::F32, sem::F16>()) {
static constexpr SizeAndAlign table_f32[] = {
/* float2x2 */ {16, 8},
/* float2x3 */ {32, 16},
/* float2x4 */ {32, 16},
@ -3082,8 +3108,23 @@ GeneratorImpl::SizeAndAlign GeneratorImpl::MslPackedTypeSizeAndAlign(const sem::
/* float4x3 */ {64, 16},
/* float4x4 */ {64, 16},
};
static constexpr SizeAndAlign table_f16[] = {
/* half2x2 */ {8, 4},
/* half2x3 */ {16, 8},
/* half2x4 */ {16, 8},
/* half3x2 */ {12, 4},
/* half3x3 */ {24, 8},
/* half3x4 */ {24, 8},
/* half4x2 */ {16, 4},
/* half4x3 */ {32, 8},
/* half4x4 */ {32, 8},
};
if (cols >= 2 && cols <= 4 && rows >= 2 && rows <= 4) {
return table[(3 * (cols - 2)) + (rows - 2)];
if (el_ty->Is<sem::F32>()) {
return table_f32[(3 * (cols - 2)) + (rows - 2)];
} else {
return table_f16[(3 * (cols - 2)) + (rows - 2)];
}
}
}

View File

@ -61,6 +61,18 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Float) {
EXPECT_THAT(gen.result(), HasSubstr("1073741824.0f"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_F16) {
Enable(ast::Extension::kF16);
// Use a number close to 1<<16 but whose decimal representation ends in 0.
WrapInFunction(Expr(f16((1 << 15) - 8)));
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("32752.0h"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Float) {
WrapInFunction(Construct<f32>(-1.2e-5_f));
@ -70,6 +82,17 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Float) {
EXPECT_THAT(gen.result(), HasSubstr("-0.000012f"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_F16) {
Enable(ast::Extension::kF16);
WrapInFunction(Construct<f16>(-1.2e-3_h));
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("-0.00119972229h"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Bool) {
WrapInFunction(Construct<bool>(true));
@ -97,7 +120,7 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Uint) {
EXPECT_THAT(gen.result(), HasSubstr("12345u"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec) {
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_F32) {
WrapInFunction(vec3<f32>(1_f, 2_f, 3_f));
GeneratorImpl& gen = Build();
@ -106,7 +129,18 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec) {
EXPECT_THAT(gen.result(), HasSubstr("float3(1.0f, 2.0f, 3.0f)"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty) {
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_F16) {
Enable(ast::Extension::kF16);
WrapInFunction(vec3<f16>(1_h, 2_h, 3_h));
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("half3(1.0h, 2.0h, 3.0h)"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty_F32) {
WrapInFunction(vec3<f32>());
GeneratorImpl& gen = Build();
@ -115,26 +149,230 @@ TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty) {
EXPECT_THAT(gen.result(), HasSubstr("float3(0.0f)"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat) {
WrapInFunction(Construct(ty.mat2x3<f32>(), vec3<f32>(1_f, 2_f, 3_f), vec3<f32>(3_f, 4_f, 5_f)));
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_Empty_F16) {
Enable(ast::Extension::kF16);
WrapInFunction(vec3<f16>());
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("half3(0.0h)"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_F32_Literal) {
WrapInFunction(vec3<f32>(2_f));
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("float3(2.0f)"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_F16_Literal) {
Enable(ast::Extension::kF16);
WrapInFunction(vec3<f16>(2_h));
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("half3(2.0h)"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_F32_Var) {
auto* var = Var("v", nullptr, Expr(2_f));
auto* cast = vec3<f32>(var);
WrapInFunction(var, cast);
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr(R"(float v = 2.0f;
float3 const tint_symbol = float3(v);)"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_F16_Var) {
Enable(ast::Extension::kF16);
auto* var = Var("v", nullptr, Expr(2_h));
auto* cast = vec3<f16>(var);
WrapInFunction(var, cast);
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr(R"(half v = 2.0h;
half3 const tint_symbol = half3(v);)"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_Bool) {
WrapInFunction(vec3<bool>(true));
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("bool3(true)"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_Int) {
WrapInFunction(vec3<i32>(2_i));
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("int3(2)"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Vec_SingleScalar_UInt) {
WrapInFunction(vec3<u32>(2_u));
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("uint3(2u)"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_F32) {
WrapInFunction(mat2x3<f32>(vec3<f32>(1_f, 2_f, 3_f), vec3<f32>(3_f, 4_f, 5_f)));
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
// A matrix of type T with n columns and m rows can also be constructed from
// n vectors of type T with m components.
EXPECT_THAT(gen.result(),
HasSubstr("float2x3(float3(1.0f, 2.0f, 3.0f), float3(3.0f, 4.0f, 5.0f))"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Empty) {
WrapInFunction(mat4x4<f32>());
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_F16) {
Enable(ast::Extension::kF16);
WrapInFunction(mat2x3<f16>(vec3<f16>(1_h, 2_h, 3_h), vec3<f16>(3_h, 4_h, 5_h)));
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("float4x4(float4(0.0f), float4(0.0f)"));
EXPECT_THAT(gen.result(),
HasSubstr("half2x3(half3(1.0h, 2.0h, 3.0h), half3(3.0h, 4.0h, 5.0h))"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Complex_F32) {
// mat4x4<f32>(
// vec4<f32>(2.0f, 3.0f, 4.0f, 8.0f),
// vec4<f32>(),
// vec4<f32>(7.0f),
// vec4<f32>(vec4<f32>(42.0f, 21.0f, 6.0f, -5.0f)),
// );
auto* vector_literal =
vec4<f32>(Expr(f32(2.0)), Expr(f32(3.0)), Expr(f32(4.0)), Expr(f32(8.0)));
auto* vector_zero_ctor = vec4<f32>();
auto* vector_single_scalar_ctor = vec4<f32>(Expr(f32(7.0)));
auto* vector_identical_ctor =
vec4<f32>(vec4<f32>(Expr(f32(42.0)), Expr(f32(21.0)), Expr(f32(6.0)), Expr(f32(-5.0))));
auto* constructor = mat4x4<f32>(vector_literal, vector_zero_ctor, vector_single_scalar_ctor,
vector_identical_ctor);
WrapInFunction(constructor);
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("float4x4(float4(2.0f, 3.0f, 4.0f, 8.0f), float4(0.0f), "
"float4(7.0f), float4(42.0f, 21.0f, 6.0f, -5.0f))"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Complex_F16) {
// mat4x4<f16>(
// vec4<f16>(2.0h, 3.0h, 4.0h, 8.0h),
// vec4<f16>(),
// vec4<f16>(7.0h),
// vec4<f16>(vec4<f16>(42.0h, 21.0h, 6.0h, -5.0h)),
// );
Enable(ast::Extension::kF16);
auto* vector_literal =
vec4<f16>(Expr(f16(2.0)), Expr(f16(3.0)), Expr(f16(4.0)), Expr(f16(8.0)));
auto* vector_zero_ctor = vec4<f16>();
auto* vector_single_scalar_ctor = vec4<f16>(Expr(f16(7.0)));
auto* vector_identical_ctor =
vec4<f16>(vec4<f16>(Expr(f16(42.0)), Expr(f16(21.0)), Expr(f16(6.0)), Expr(f16(-5.0))));
auto* constructor = mat4x4<f16>(vector_literal, vector_zero_ctor, vector_single_scalar_ctor,
vector_identical_ctor);
WrapInFunction(constructor);
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("half4x4(half4(2.0h, 3.0h, 4.0h, 8.0h), half4(0.0h), "
"half4(7.0h), half4(42.0h, 21.0h, 6.0h, -5.0h))"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Empty_F32) {
WrapInFunction(mat2x3<f32>());
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(),
HasSubstr("float2x3 const tint_symbol = float2x3(float3(0.0f), float3(0.0f))"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Empty_F16) {
Enable(ast::Extension::kF16);
WrapInFunction(mat2x3<f16>());
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(),
HasSubstr("half2x3 const tint_symbol = half2x3(half3(0.0h), half3(0.0h))"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Identity_F32) {
// fn f() {
// var m_1: mat4x4<f32> = mat4x4<f32>();
// var m_2: mat4x4<f32> = mat4x4<f32>(m_1);
// }
auto* m_1 = Var("m_1", ty.mat4x4(ty.f32()), mat4x4<f32>());
auto* m_2 = Var("m_2", ty.mat4x4(ty.f32()), mat4x4<f32>(m_1));
WrapInFunction(m_1, m_2);
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("float4x4 m_2 = float4x4(m_1);"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Mat_Identity_F16) {
// fn f() {
// var m_1: mat4x4<f16> = mat4x4<f16>();
// var m_2: mat4x4<f16> = mat4x4<f16>(m_1);
// }
Enable(ast::Extension::kF16);
auto* m_1 = Var("m_1", ty.mat4x4(ty.f16()), mat4x4<f16>());
auto* m_2 = Var("m_2", ty.mat4x4(ty.f16()), mat4x4<f16>(m_1));
WrapInFunction(m_1, m_2);
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("half4x4 m_2 = half4x4(m_1);"));
}
TEST_F(MslGeneratorImplTest, EmitConstructor_Type_Array) {

View File

@ -112,6 +112,26 @@ void f() {
)");
}
TEST_F(MslGeneratorImplTest, Emit_GlobalConst_f16) {
Enable(ast::Extension::kF16);
auto* var = GlobalConst("G", nullptr, Expr(1_h));
Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))});
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
using namespace metal;
void f() {
half const l = 1.0h;
}
)");
}
TEST_F(MslGeneratorImplTest, Emit_GlobalConst_vec3_AInt) {
auto* var = GlobalConst("G", nullptr, Construct(ty.vec3(nullptr), 1_a, 2_a, 3_a));
Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))});
@ -166,6 +186,26 @@ void f() {
)");
}
TEST_F(MslGeneratorImplTest, Emit_GlobalConst_vec3_f16) {
Enable(ast::Extension::kF16);
auto* var = GlobalConst("G", nullptr, vec3<f16>(1_h, 2_h, 3_h));
Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))});
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
using namespace metal;
void f() {
half3 const l = half3(1.0h, 2.0h, 3.0h);
}
)");
}
TEST_F(MslGeneratorImplTest, Emit_GlobalConst_mat2x3_AFloat) {
auto* var = GlobalConst("G", nullptr,
Construct(ty.mat(nullptr, 2, 3), 1._a, 2._a, 3._a, 4._a, 5._a, 6._a));
@ -203,6 +243,26 @@ void f() {
)");
}
TEST_F(MslGeneratorImplTest, Emit_GlobalConst_mat2x3_f16) {
Enable(ast::Extension::kF16);
auto* var = GlobalConst("G", nullptr, mat2x3<f16>(1_h, 2_h, 3_h, 4_h, 5_h, 6_h));
Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))});
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
using namespace metal;
void f() {
half2x3 const l = half2x3(half3(1.0h, 2.0h, 3.0h), half3(4.0h, 5.0h, 6.0h));
}
)");
}
TEST_F(MslGeneratorImplTest, Emit_GlobalConst_arr_f32) {
auto* var = GlobalConst("G", nullptr, Construct(ty.array<f32, 3>(), 1_f, 2_f, 3_f));
Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))});

View File

@ -71,6 +71,18 @@ DECLARE_TYPE(float3x4, 48, 16);
DECLARE_TYPE(float4x2, 32, 8);
DECLARE_TYPE(float4x3, 64, 16);
DECLARE_TYPE(float4x4, 64, 16);
DECLARE_TYPE(half2, 4, 4);
DECLARE_TYPE(packed_half3, 6, 2);
DECLARE_TYPE(half4, 8, 8);
DECLARE_TYPE(half2x2, 8, 4);
DECLARE_TYPE(half2x3, 16, 8);
DECLARE_TYPE(half2x4, 16, 8);
DECLARE_TYPE(half3x2, 12, 4);
DECLARE_TYPE(half3x3, 24, 8);
DECLARE_TYPE(half3x4, 24, 8);
DECLARE_TYPE(half4x2, 16, 4);
DECLARE_TYPE(half4x3, 32, 8);
DECLARE_TYPE(half4x4, 32, 8);
using uint = unsigned int;
using MslGeneratorImplTest = TestHelper;
@ -153,6 +165,16 @@ TEST_F(MslGeneratorImplTest, EmitType_F32) {
EXPECT_EQ(out.str(), "float");
}
TEST_F(MslGeneratorImplTest, EmitType_F16) {
auto* f16 = create<sem::F16>();
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitType(out, f16, "")) << gen.error();
EXPECT_EQ(out.str(), "half");
}
TEST_F(MslGeneratorImplTest, EmitType_I32) {
auto* i32 = create<sem::I32>();
@ -163,7 +185,7 @@ TEST_F(MslGeneratorImplTest, EmitType_I32) {
EXPECT_EQ(out.str(), "int");
}
TEST_F(MslGeneratorImplTest, EmitType_Matrix) {
TEST_F(MslGeneratorImplTest, EmitType_Matrix_F32) {
auto* f32 = create<sem::F32>();
auto* vec3 = create<sem::Vector>(f32, 3u);
auto* mat2x3 = create<sem::Matrix>(vec3, 2u);
@ -175,6 +197,18 @@ TEST_F(MslGeneratorImplTest, EmitType_Matrix) {
EXPECT_EQ(out.str(), "float2x3");
}
TEST_F(MslGeneratorImplTest, EmitType_Matrix_F16) {
auto* f16 = create<sem::F16>();
auto* vec3 = create<sem::Vector>(f16, 3u);
auto* mat2x3 = create<sem::Matrix>(vec3, 2u);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitType(out, mat2x3, "")) << gen.error();
EXPECT_EQ(out.str(), "half2x3");
}
TEST_F(MslGeneratorImplTest, EmitType_Pointer) {
auto* f32 = create<sem::F32>();
auto* p = create<sem::Pointer>(f32, ast::StorageClass::kWorkgroup, ast::Access::kReadWrite);

View File

@ -154,6 +154,26 @@ void f() {
)");
}
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_f16) {
Enable(ast::Extension::kF16);
auto* C = Const("C", nullptr, Expr(1_h));
Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))});
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
using namespace metal;
void f() {
half const l = 1.0h;
}
)");
}
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_vec3_AInt) {
auto* C = Const("C", nullptr, Construct(ty.vec3(nullptr), 1_a, 2_a, 3_a));
Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))});
@ -208,6 +228,26 @@ void f() {
)");
}
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_vec3_f16) {
Enable(ast::Extension::kF16);
auto* C = Const("C", nullptr, vec3<f16>(1_h, 2_h, 3_h));
Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))});
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
using namespace metal;
void f() {
half3 const l = half3(1.0h, 2.0h, 3.0h);
}
)");
}
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_mat2x3_AFloat) {
auto* C =
Const("C", nullptr, Construct(ty.mat(nullptr, 2, 3), 1._a, 2._a, 3._a, 4._a, 5._a, 6._a));
@ -245,6 +285,26 @@ void f() {
)");
}
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_mat2x3_f16) {
Enable(ast::Extension::kF16);
auto* C = Const("C", nullptr, mat2x3<f16>(1_h, 2_h, 3_h, 4_h, 5_h, 6_h));
Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))});
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(#include <metal_stdlib>
using namespace metal;
void f() {
half2x3 const l = half2x3(half3(1.0h, 2.0h, 3.0h), half3(4.0h, 5.0h, 6.0h));
}
)");
}
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Const_arr_f32) {
auto* C = Const("C", nullptr, Construct(ty.array<f32, 3>(), 1_f, 2_f, 3_f));
Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))});
@ -343,7 +403,7 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Struct) {
)");
}
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Vector) {
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Vector_f32) {
auto* var = Var("a", ty.vec2<f32>());
auto* stmt = Decl(var);
WrapInFunction(stmt);
@ -356,7 +416,22 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Vector) {
EXPECT_EQ(gen.result(), " float2 a = 0.0f;\n");
}
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Matrix) {
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Vector_f16) {
Enable(ast::Extension::kF16);
auto* var = Var("a", ty.vec2<f16>());
auto* stmt = Decl(var);
WrapInFunction(stmt);
GeneratorImpl& gen = Build();
gen.increment_indent();
ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error();
EXPECT_EQ(gen.result(), " half2 a = 0.0h;\n");
}
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Matrix_f32) {
auto* var = Var("a", ty.mat3x2<f32>());
auto* stmt = Decl(var);
@ -370,6 +445,80 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Matrix) {
EXPECT_EQ(gen.result(), " float3x2 a = float3x2(0.0f);\n");
}
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Matrix_f16) {
Enable(ast::Extension::kF16);
auto* var = Var("a", ty.mat3x2<f16>());
auto* stmt = Decl(var);
WrapInFunction(stmt);
GeneratorImpl& gen = Build();
gen.increment_indent();
ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error();
EXPECT_EQ(gen.result(), " half3x2 a = half3x2(0.0h);\n");
}
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroVec_f32) {
auto* var = Var("a", ty.vec3<f32>(), ast::StorageClass::kNone, vec3<f32>());
auto* stmt = Decl(var);
WrapInFunction(stmt);
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error();
EXPECT_EQ(gen.result(), R"(float3 a = float3(0.0f);
)");
}
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroVec_f16) {
Enable(ast::Extension::kF16);
auto* var = Var("a", ty.vec3<f16>(), ast::StorageClass::kNone, vec3<f16>());
auto* stmt = Decl(var);
WrapInFunction(stmt);
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error();
EXPECT_EQ(gen.result(), R"(half3 a = half3(0.0h);
)");
}
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroMat_f32) {
auto* var = Var("a", ty.mat2x3<f32>(), ast::StorageClass::kNone, mat2x3<f32>());
auto* stmt = Decl(var);
WrapInFunction(stmt);
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error();
EXPECT_EQ(gen.result(),
R"(float2x3 a = float2x3(float3(0.0f), float3(0.0f));
)");
}
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroMat_f16) {
Enable(ast::Extension::kF16);
auto* var = Var("a", ty.mat2x3<f16>(), ast::StorageClass::kNone, mat2x3<f16>());
auto* stmt = Decl(var);
WrapInFunction(stmt);
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error();
EXPECT_EQ(gen.result(),
R"(half2x3 a = half2x3(half3(0.0h), half3(0.0h));
)");
}
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Private) {
GlobalVar("a", ty.f32(), ast::StorageClass::kPrivate);
@ -396,19 +545,5 @@ TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Workgroup) {
EXPECT_THAT(gen.result(), HasSubstr("threadgroup float tint_symbol_2;\n"));
}
TEST_F(MslGeneratorImplTest, Emit_VariableDeclStatement_Initializer_ZeroVec) {
auto* zero_vec = vec3<f32>();
auto* var = Var("a", ty.vec3<f32>(), ast::StorageClass::kNone, zero_vec);
auto* stmt = Decl(var);
WrapInFunction(stmt);
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error();
EXPECT_EQ(gen.result(), R"(float3 a = float3(0.0f);
)");
}
} // namespace
} // namespace tint::writer::msl