tint/writer/hlsl: Support for F16 type, constructor, and convertor

This patch make HLSL writer support emitting f16 types, f16 literals,
f16 constructor and convertor. Unittests are also implemented.

The HLSL writer will emit f16 literal as `float16_t(1.23h)`, making the
type explicit, and map f16 types as follow. The generated code require
DXC with SM6.0 or higher, and `-enable-16bit-types`.
WGSL type   -> HLSL type
f16         -> float16_t
vec2<f16>   -> vector<float16_t, 2>
vec3<f16>   -> vector<float16_t, 3>
vec4<f16>   -> vector<float16_t, 4>
mat2x2<f16> -> matrix<float16_t, 2, 2>
mat2x3<f16> -> matrix<float16_t, 2, 3>
mat2x4<f16> -> matrix<float16_t, 2, 4>
mat3x2<f16> -> matrix<float16_t, 3, 2>
mat3x3<f16> -> matrix<float16_t, 3, 3>
mat3x4<f16> -> matrix<float16_t, 3, 4>
mat4x2<f16> -> matrix<float16_t, 4, 2>
mat4x3<f16> -> matrix<float16_t, 4, 3>
mat4x4<f16> -> matrix<float16_t, 4, 4>

Bug: tint:1473, tint:1502
Change-Id: Iaf564f3ce29ace2984cef19d7df5a7dfb0fab2ef
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/95685
Commit-Queue: Zhaoming Jiang <zhaoming.jiang@intel.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
Zhaoming Jiang 2022-07-11 15:43:38 +00:00 committed by Dawn LUCI CQ
parent 94c6495ad2
commit a5988a3058
5 changed files with 357 additions and 14 deletions

View File

@ -122,6 +122,16 @@ void PrintF32(std::ostream& out, float value) {
} }
} }
bool PrintF16(std::ostream& out, float value) {
// Note: Currently inf and nan should not be constructable, don't emit them.
if (std::isinf(value) || std::isnan(value)) {
return false;
} else {
out << FloatToString(value) << "h";
return true;
}
}
// Helper for writing " : register(RX, spaceY)", where R is the register, X is // Helper for writing " : register(RX, spaceY)", where R is the register, X is
// the binding point binding value, and Y is the binding point group value. // the binding point binding value, and Y is the binding point group value.
struct RegisterAndSpace { struct RegisterAndSpace {
@ -3122,6 +3132,13 @@ bool GeneratorImpl::EmitConstant(std::ostream& out, const sem::Constant* constan
PrintF32(out, constant->As<float>()); PrintF32(out, constant->As<float>());
return true; return true;
}, },
[&](const sem::F16*) {
// emit a f16 scalar with explicit float16_t type declaration.
out << "float16_t(";
bool valid = PrintF16(out, constant->As<float>());
out << ")";
return valid;
},
[&](const sem::I32*) { [&](const sem::I32*) {
out << constant->As<AInt>(); out << constant->As<AInt>();
return true; return true;
@ -3218,6 +3235,13 @@ bool GeneratorImpl::EmitLiteral(std::ostream& out, const ast::LiteralExpression*
return true; return true;
}, },
[&](const ast::FloatLiteralExpression* l) { [&](const ast::FloatLiteralExpression* l) {
if (l->suffix == ast::FloatLiteralExpression::Suffix::kH) {
// Emit f16 literal with explicit float16_t type declaration.
out << "float16_t(";
bool valid = PrintF16(out, static_cast<float>(l->value));
out << ")";
return valid;
}
PrintF32(out, static_cast<float>(l->value)); PrintF32(out, static_cast<float>(l->value));
return true; return true;
}, },
@ -3251,6 +3275,10 @@ bool GeneratorImpl::EmitValue(std::ostream& out, const sem::Type* type, int valu
out << value << ".0f"; out << value << ".0f";
return true; return true;
}, },
[&](const sem::F16*) {
out << "float16_t(" << value << ".0h)";
return true;
},
[&](const sem::I32*) { [&](const sem::I32*) {
out << value; out << value;
return true; return true;
@ -3723,15 +3751,23 @@ bool GeneratorImpl::EmitType(std::ostream& out,
return true; return true;
}, },
[&](const sem::F16*) { [&](const sem::F16*) {
diagnostics_.add_error(diag::System::Writer, out << "float16_t";
"Type f16 is not completely implemented yet."); return true;
return false;
}, },
[&](const sem::I32*) { [&](const sem::I32*) {
out << "int"; out << "int";
return true; return true;
}, },
[&](const sem::Matrix* mat) { [&](const sem::Matrix* mat) {
if (mat->type()->Is<sem::F16>()) {
// Use matrix<type, N, M> for f16 matrix
out << "matrix<";
if (!EmitType(out, mat->type(), storage_class, access, "")) {
return false;
}
out << ", " << mat->columns() << ", " << mat->rows() << ">";
return true;
}
if (!EmitType(out, mat->type(), storage_class, access, "")) { if (!EmitType(out, mat->type(), storage_class, access, "")) {
return false; return false;
} }
@ -3847,6 +3883,7 @@ bool GeneratorImpl::EmitType(std::ostream& out,
} else if (vec->type()->Is<sem::Bool>() && width >= 1 && width <= 4) { } else if (vec->type()->Is<sem::Bool>() && width >= 1 && width <= 4) {
out << "bool" << width; out << "bool" << width;
} else { } else {
// For example, use "vector<float16_t, N>" for f16 vector.
out << "vector<"; out << "vector<";
if (!EmitType(out, vec->type(), storage_class, access, "")) { if (!EmitType(out, vec->type(), storage_class, access, "")) {
return false; return false;

View File

@ -61,6 +61,18 @@ TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Float) {
EXPECT_THAT(gen.result(), HasSubstr("1073741824.0f")); EXPECT_THAT(gen.result(), HasSubstr("1073741824.0f"));
} }
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_F16) {
Enable(ast::Extension::kF16);
// Use a number close to 1<<16 but whose decimal representation ends in 0.
WrapInFunction(Expr(f16((1 << 15) - 8)));
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("float16_t(32752.0h)"));
}
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Float) { TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Float) {
WrapInFunction(Construct<f32>(-1.2e-5_f)); WrapInFunction(Construct<f32>(-1.2e-5_f));
@ -70,6 +82,17 @@ TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Float) {
EXPECT_THAT(gen.result(), HasSubstr("-0.000012f")); EXPECT_THAT(gen.result(), HasSubstr("-0.000012f"));
} }
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_F16) {
Enable(ast::Extension::kF16);
WrapInFunction(Construct<f16>(-1.2e-3_h));
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("float16_t(-0.00119972229h)"));
}
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Bool) { TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Bool) {
WrapInFunction(Construct<bool>(true)); WrapInFunction(Construct<bool>(true));
@ -97,7 +120,7 @@ TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Uint) {
EXPECT_THAT(gen.result(), HasSubstr("12345u")); EXPECT_THAT(gen.result(), HasSubstr("12345u"));
} }
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Vec) { TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Vec_F32) {
WrapInFunction(vec3<f32>(1_f, 2_f, 3_f)); WrapInFunction(vec3<f32>(1_f, 2_f, 3_f));
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -106,7 +129,20 @@ TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Vec) {
EXPECT_THAT(gen.result(), HasSubstr("float3(1.0f, 2.0f, 3.0f)")); EXPECT_THAT(gen.result(), HasSubstr("float3(1.0f, 2.0f, 3.0f)"));
} }
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Vec_Empty) { TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Vec_F16) {
Enable(ast::Extension::kF16);
WrapInFunction(vec3<f16>(1_h, 2_h, 3_h));
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(
gen.result(),
HasSubstr("vector<float16_t, 3>(float16_t(1.0h), float16_t(2.0h), float16_t(3.0h))"));
}
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Vec_Empty_F32) {
WrapInFunction(vec3<f32>()); WrapInFunction(vec3<f32>());
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -115,7 +151,18 @@ TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Vec_Empty) {
EXPECT_THAT(gen.result(), HasSubstr("0.0f).xxx")); EXPECT_THAT(gen.result(), HasSubstr("0.0f).xxx"));
} }
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Vec_SingleScalar_Float_Literal) { TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Vec_Empty_F16) {
Enable(ast::Extension::kF16);
WrapInFunction(vec3<f16>());
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("(float16_t(0.0h)).xxx"));
}
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Vec_SingleScalar_F32_Literal) {
WrapInFunction(vec3<f32>(2_f)); WrapInFunction(vec3<f32>(2_f));
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -124,7 +171,18 @@ TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Vec_SingleScalar_
EXPECT_THAT(gen.result(), HasSubstr("2.0f).xxx")); EXPECT_THAT(gen.result(), HasSubstr("2.0f).xxx"));
} }
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Vec_SingleScalar_Float_Var) { TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Vec_SingleScalar_F16_Literal) {
Enable(ast::Extension::kF16);
WrapInFunction(vec3<f16>(2_h));
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("(float16_t(2.0h)).xxx"));
}
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Vec_SingleScalar_F32_Var) {
auto* var = Var("v", nullptr, Expr(2_f)); auto* var = Var("v", nullptr, Expr(2_f));
auto* cast = vec3<f32>(var); auto* cast = vec3<f32>(var);
WrapInFunction(var, cast); WrapInFunction(var, cast);
@ -136,6 +194,20 @@ TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Vec_SingleScalar_
const float3 tint_symbol = float3((v).xxx);)")); const float3 tint_symbol = float3((v).xxx);)"));
} }
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Vec_SingleScalar_F16_Var) {
Enable(ast::Extension::kF16);
auto* var = Var("v", nullptr, Expr(2_h));
auto* cast = vec3<f16>(var);
WrapInFunction(var, cast);
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr(R"(float16_t v = float16_t(2.0h);
const vector<float16_t, 3> tint_symbol = vector<float16_t, 3>((v).xxx);)"));
}
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Vec_SingleScalar_Bool_Literal) { TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Vec_SingleScalar_Bool_Literal) {
WrapInFunction(vec3<bool>(true)); WrapInFunction(vec3<bool>(true));
@ -175,7 +247,7 @@ TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Vec_SingleScalar_
EXPECT_THAT(gen.result(), HasSubstr("2u).xxx")); EXPECT_THAT(gen.result(), HasSubstr("2u).xxx"));
} }
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Mat) { TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Mat_F32) {
WrapInFunction(mat2x3<f32>(vec3<f32>(1_f, 2_f, 3_f), vec3<f32>(3_f, 4_f, 5_f))); WrapInFunction(mat2x3<f32>(vec3<f32>(1_f, 2_f, 3_f), vec3<f32>(3_f, 4_f, 5_f)));
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -186,7 +258,22 @@ TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Mat) {
HasSubstr("float2x3(float3(1.0f, 2.0f, 3.0f), float3(3.0f, 4.0f, 5.0f))")); HasSubstr("float2x3(float3(1.0f, 2.0f, 3.0f), float3(3.0f, 4.0f, 5.0f))"));
} }
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Mat_Complex) { TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Mat_F16) {
Enable(ast::Extension::kF16);
WrapInFunction(mat2x3<f16>(vec3<f16>(1_h, 2_h, 3_h), vec3<f16>(3_h, 4_h, 5_h)));
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(),
HasSubstr("matrix<float16_t, 2, 3>(vector<float16_t, 3>(float16_t(1.0h), "
"float16_t(2.0h), float16_t(3.0h)), vector<float16_t, "
"3>(float16_t(3.0h), float16_t(4.0h), float16_t(5.0h)))"));
}
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Mat_Complex_F32) {
// mat4x4<f32>( // mat4x4<f32>(
// vec4<f32>(2.0f, 3.0f, 4.0f, 8.0f), // vec4<f32>(2.0f, 3.0f, 4.0f, 8.0f),
// vec4<f32>(), // vec4<f32>(),
@ -213,7 +300,40 @@ TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Mat_Complex) {
"(7.0f).xxxx, float4(42.0f, 21.0f, 6.0f, -5.0f))")); "(7.0f).xxxx, float4(42.0f, 21.0f, 6.0f, -5.0f))"));
} }
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Mat_Empty) { TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Mat_Complex_F16) {
// mat4x4<f16>(
// vec4<f16>(2.0h, 3.0h, 4.0h, 8.0h),
// vec4<f16>(),
// vec4<f16>(7.0h),
// vec4<f16>(vec4<f16>(42.0h, 21.0h, 6.0h, -5.0h)),
// );
Enable(ast::Extension::kF16);
auto* vector_literal =
vec4<f16>(Expr(f16(2.0)), Expr(f16(3.0)), Expr(f16(4.0)), Expr(f16(8.0)));
auto* vector_zero_ctor = vec4<f16>();
auto* vector_single_scalar_ctor = vec4<f16>(Expr(f16(7.0)));
auto* vector_identical_ctor =
vec4<f16>(vec4<f16>(Expr(f16(42.0)), Expr(f16(21.0)), Expr(f16(6.0)), Expr(f16(-5.0))));
auto* constructor = mat4x4<f16>(vector_literal, vector_zero_ctor, vector_single_scalar_ctor,
vector_identical_ctor);
WrapInFunction(constructor);
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(), HasSubstr("matrix<float16_t, 4, 4>("
"vector<float16_t, 4>(float16_t(2.0h), float16_t(3.0h), "
"float16_t(4.0h), float16_t(8.0h)), "
"(float16_t(0.0h)).xxxx, (float16_t(7.0h)).xxxx, "
"vector<float16_t, 4>(float16_t(42.0h), float16_t(21.0h), "
"float16_t(6.0h), float16_t(-5.0h)))"));
}
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Mat_Empty_F32) {
WrapInFunction(mat2x3<f32>()); WrapInFunction(mat2x3<f32>());
GeneratorImpl& gen = Build(); GeneratorImpl& gen = Build();
@ -223,7 +343,20 @@ TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Mat_Empty) {
EXPECT_THAT(gen.result(), HasSubstr("float2x3 tint_symbol = float2x3((0.0f).xxx, (0.0f).xxx)")); EXPECT_THAT(gen.result(), HasSubstr("float2x3 tint_symbol = float2x3((0.0f).xxx, (0.0f).xxx)"));
} }
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Mat_Identity) { TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Mat_Empty_F16) {
Enable(ast::Extension::kF16);
WrapInFunction(mat2x3<f16>());
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(),
HasSubstr("matrix<float16_t, 2, 3>((float16_t(0.0h)).xxx, (float16_t(0.0h)).xxx)"));
}
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Mat_Identity_F32) {
// fn f() { // fn f() {
// var m_1: mat4x4<f32> = mat4x4<f32>(); // var m_1: mat4x4<f32> = mat4x4<f32>();
// var m_2: mat4x4<f32> = mat4x4<f32>(m_1); // var m_2: mat4x4<f32> = mat4x4<f32>(m_1);
@ -241,6 +374,27 @@ TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Mat_Identity) {
EXPECT_THAT(gen.result(), HasSubstr("float4x4 m_2 = float4x4(m_1);")); EXPECT_THAT(gen.result(), HasSubstr("float4x4 m_2 = float4x4(m_1);"));
} }
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Mat_Identity_F16) {
// fn f() {
// var m_1: mat4x4<f16> = mat4x4<f16>();
// var m_2: mat4x4<f16> = mat4x4<f16>(m_1);
// }
Enable(ast::Extension::kF16);
auto* m_1 = Var("m_1", ty.mat4x4(ty.f16()), mat4x4<f16>());
auto* m_2 = Var("m_2", ty.mat4x4(ty.f16()), mat4x4<f16>(m_1));
WrapInFunction(m_1, m_2);
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_THAT(gen.result(),
HasSubstr("matrix<float16_t, 4, 4> m_2 = matrix<float16_t, 4, 4>(m_1);"));
}
TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Array) { TEST_F(HlslGeneratorImplTest_Constructor, EmitConstructor_Type_Array) {
WrapInFunction(Construct(ty.array(ty.vec3<f32>(), 3_u), vec3<f32>(1_f, 2_f, 3_f), WrapInFunction(Construct(ty.array(ty.vec3<f32>(), 3_u), vec3<f32>(1_f, 2_f, 3_f),
vec3<f32>(4_f, 5_f, 6_f), vec3<f32>(7_f, 8_f, 9_f))); vec3<f32>(4_f, 5_f, 6_f), vec3<f32>(7_f, 8_f, 9_f)));

View File

@ -92,6 +92,22 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_GlobalConst_f32) {
)"); )");
} }
TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_GlobalConst_f16) {
Enable(ast::Extension::kF16);
auto* var = GlobalConst("G", nullptr, Expr(1_h));
Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))});
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(void f() {
const float16_t l = float16_t(1.0h);
}
)");
}
TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_GlobalConst_vec3_AInt) { TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_GlobalConst_vec3_AInt) {
auto* var = GlobalConst("G", nullptr, Construct(ty.vec3(nullptr), 1_a, 2_a, 3_a)); auto* var = GlobalConst("G", nullptr, Construct(ty.vec3(nullptr), 1_a, 2_a, 3_a));
Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))}); Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))});
@ -134,6 +150,22 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_GlobalConst_vec3_f32) {
)"); )");
} }
TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_GlobalConst_vec3_f16) {
Enable(ast::Extension::kF16);
auto* var = GlobalConst("G", nullptr, vec3<f16>(1_h, 2_h, 3_h));
Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))});
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(void f() {
const vector<float16_t, 3> l = vector<float16_t, 3>(float16_t(1.0h), float16_t(2.0h), float16_t(3.0h));
}
)");
}
TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_GlobalConst_mat2x3_AFloat) { TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_GlobalConst_mat2x3_AFloat) {
auto* var = GlobalConst("G", nullptr, auto* var = GlobalConst("G", nullptr,
Construct(ty.mat(nullptr, 2, 3), 1._a, 2._a, 3._a, 4._a, 5._a, 6._a)); Construct(ty.mat(nullptr, 2, 3), 1._a, 2._a, 3._a, 4._a, 5._a, 6._a));
@ -163,6 +195,22 @@ TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_GlobalConst_mat2x3_f32) {
)"); )");
} }
TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_GlobalConst_mat2x3_f16) {
Enable(ast::Extension::kF16);
auto* var = GlobalConst("G", nullptr, mat2x3<f16>(1_h, 2_h, 3_h, 4_h, 5_h, 6_h));
Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))});
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(void f() {
const matrix<float16_t, 2, 3> l = matrix<float16_t, 2, 3>(vector<float16_t, 3>(float16_t(1.0h), float16_t(2.0h), float16_t(3.0h)), vector<float16_t, 3>(float16_t(4.0h), float16_t(5.0h), float16_t(6.0h)));
}
)");
}
TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_GlobalConst_arr_f32) { TEST_F(HlslGeneratorImplTest_ModuleConstant, Emit_GlobalConst_arr_f32) {
auto* var = GlobalConst("G", nullptr, Construct(ty.array<f32, 3>(), 1_f, 2_f, 3_f)); auto* var = GlobalConst("G", nullptr, Construct(ty.array<f32, 3>(), 1_f, 2_f, 3_f));
Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))}); Func("f", {}, ty.void_(), {Decl(Let("l", nullptr, Expr(var)))});

View File

@ -94,6 +94,17 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_Bool) {
EXPECT_EQ(out.str(), "bool"); EXPECT_EQ(out.str(), "bool");
} }
TEST_F(HlslGeneratorImplTest_Type, EmitType_F16) {
auto* f16 = create<sem::F16>();
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitType(out, f16, ast::StorageClass::kNone, ast::Access::kReadWrite, ""))
<< gen.error();
EXPECT_EQ(out.str(), "float16_t");
}
TEST_F(HlslGeneratorImplTest_Type, EmitType_F32) { TEST_F(HlslGeneratorImplTest_Type, EmitType_F32) {
auto* f32 = create<sem::F32>(); auto* f32 = create<sem::F32>();
@ -116,7 +127,20 @@ TEST_F(HlslGeneratorImplTest_Type, EmitType_I32) {
EXPECT_EQ(out.str(), "int"); EXPECT_EQ(out.str(), "int");
} }
TEST_F(HlslGeneratorImplTest_Type, EmitType_Matrix) { TEST_F(HlslGeneratorImplTest_Type, EmitType_Matrix_F16) {
auto* f16 = create<sem::F16>();
auto* vec3 = create<sem::Vector>(f16, 3u);
auto* mat2x3 = create<sem::Matrix>(vec3, 2u);
GeneratorImpl& gen = Build();
std::stringstream out;
ASSERT_TRUE(gen.EmitType(out, mat2x3, ast::StorageClass::kNone, ast::Access::kReadWrite, ""))
<< gen.error();
EXPECT_EQ(out.str(), "matrix<float16_t, 2, 3>");
}
TEST_F(HlslGeneratorImplTest_Type, EmitType_Matrix_F32) {
auto* f32 = create<sem::F32>(); auto* f32 = create<sem::F32>();
auto* vec3 = create<sem::Vector>(f32, 3u); auto* vec3 = create<sem::Vector>(f32, 3u);
auto* mat2x3 = create<sem::Matrix>(vec3, 2u); auto* mat2x3 = create<sem::Matrix>(vec3, 2u);

View File

@ -134,6 +134,22 @@ TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Const_f32)
)"); )");
} }
TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Const_f16) {
Enable(ast::Extension::kF16);
auto* C = Const("C", nullptr, Expr(1_h));
Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))});
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(void f() {
const float16_t l = float16_t(1.0h);
}
)");
}
TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Const_vec3_AInt) { TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Const_vec3_AInt) {
auto* C = Const("C", nullptr, Construct(ty.vec3(nullptr), 1_a, 2_a, 3_a)); auto* C = Const("C", nullptr, Construct(ty.vec3(nullptr), 1_a, 2_a, 3_a));
Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))}); Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))});
@ -176,6 +192,22 @@ TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Const_vec3
)"); )");
} }
TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Const_vec3_f16) {
Enable(ast::Extension::kF16);
auto* C = Const("C", nullptr, vec3<f16>(1_h, 2_h, 3_h));
Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))});
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(void f() {
const vector<float16_t, 3> l = vector<float16_t, 3>(float16_t(1.0h), float16_t(2.0h), float16_t(3.0h));
}
)");
}
TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Const_mat2x3_AFloat) { TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Const_mat2x3_AFloat) {
auto* C = auto* C =
Const("C", nullptr, Construct(ty.mat(nullptr, 2, 3), 1._a, 2._a, 3._a, 4._a, 5._a, 6._a)); Const("C", nullptr, Construct(ty.mat(nullptr, 2, 3), 1._a, 2._a, 3._a, 4._a, 5._a, 6._a));
@ -205,6 +237,22 @@ TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Const_mat2
)"); )");
} }
TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Const_mat2x3_f16) {
Enable(ast::Extension::kF16);
auto* C = Const("C", nullptr, mat2x3<f16>(1_h, 2_h, 3_h, 4_h, 5_h, 6_h));
Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))});
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.Generate()) << gen.error();
EXPECT_EQ(gen.result(), R"(void f() {
const matrix<float16_t, 2, 3> l = matrix<float16_t, 2, 3>(vector<float16_t, 3>(float16_t(1.0h), float16_t(2.0h), float16_t(3.0h)), vector<float16_t, 3>(float16_t(4.0h), float16_t(5.0h), float16_t(6.0h)));
}
)");
}
TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Const_arr_f32) { TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Const_arr_f32) {
auto* C = Const("C", nullptr, Construct(ty.array<f32, 3>(), 1_f, 2_f, 3_f)); auto* C = Const("C", nullptr, Construct(ty.array<f32, 3>(), 1_f, 2_f, 3_f));
Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))}); Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))});
@ -263,7 +311,7 @@ TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Private) {
EXPECT_THAT(gen.result(), HasSubstr(" static float a = 0.0f;\n")); EXPECT_THAT(gen.result(), HasSubstr(" static float a = 0.0f;\n"));
} }
TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Initializer_ZeroVec) { TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Initializer_ZeroVec_F32) {
auto* var = Var("a", ty.vec3<f32>(), ast::StorageClass::kNone, vec3<f32>()); auto* var = Var("a", ty.vec3<f32>(), ast::StorageClass::kNone, vec3<f32>());
auto* stmt = Decl(var); auto* stmt = Decl(var);
@ -276,7 +324,22 @@ TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Initialize
)"); )");
} }
TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Initializer_ZeroMat) { TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Initializer_ZeroVec_F16) {
Enable(ast::Extension::kF16);
auto* var = Var("a", ty.vec3<f16>(), ast::StorageClass::kNone, vec3<f16>());
auto* stmt = Decl(var);
WrapInFunction(stmt);
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error();
EXPECT_EQ(gen.result(), R"(vector<float16_t, 3> a = (float16_t(0.0h)).xxx;
)");
}
TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Initializer_ZeroMat_F32) {
auto* var = Var("a", ty.mat2x3<f32>(), ast::StorageClass::kNone, mat2x3<f32>()); auto* var = Var("a", ty.mat2x3<f32>(), ast::StorageClass::kNone, mat2x3<f32>());
auto* stmt = Decl(var); auto* stmt = Decl(var);
@ -290,5 +353,22 @@ TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Initialize
)"); )");
} }
TEST_F(HlslGeneratorImplTest_VariableDecl, Emit_VariableDeclStatement_Initializer_ZeroMat_F16) {
Enable(ast::Extension::kF16);
auto* var = Var("a", ty.mat2x3<f16>(), ast::StorageClass::kNone, mat2x3<f16>());
auto* stmt = Decl(var);
WrapInFunction(stmt);
GeneratorImpl& gen = Build();
ASSERT_TRUE(gen.EmitStatement(stmt)) << gen.error();
EXPECT_EQ(
gen.result(),
R"(matrix<float16_t, 2, 3> a = matrix<float16_t, 2, 3>((float16_t(0.0h)).xxx, (float16_t(0.0h)).xxx);
)");
}
} // namespace } // namespace
} // namespace tint::writer::hlsl } // namespace tint::writer::hlsl