diff --git a/src/tint/writer/wgsl/generator_impl.cc b/src/tint/writer/wgsl/generator_impl.cc index 26e0a294df..50f23fc23b 100644 --- a/src/tint/writer/wgsl/generator_impl.cc +++ b/src/tint/writer/wgsl/generator_impl.cc @@ -258,6 +258,10 @@ bool GeneratorImpl::EmitLiteral(std::ostream& out, const ast::LiteralExpression* return true; }, [&](const ast::FloatLiteralExpression* l) { // + // f16 literals are also emitted as float value with suffix "h". + // Note that all normal and subnormal f16 values are normal f32 values, and since NaN + // and Inf are not allowed to be spelled in literal, it should be fine to emit f16 + // literals in this way. out << FloatToBitPreservingString(static_cast(l->value)) << l->suffix; return true; }, @@ -402,9 +406,8 @@ bool GeneratorImpl::EmitType(std::ostream& out, const ast::Type* ty) { return true; }, [&](const ast::F16*) { - diagnostics_.add_error(diag::System::Writer, - "Type f16 is not completely implemented yet."); - return false; + out << "f16"; + return true; }, [&](const ast::I32*) { out << "i32"; diff --git a/src/tint/writer/wgsl/generator_impl_cast_test.cc b/src/tint/writer/wgsl/generator_impl_cast_test.cc index c423943e40..9b9379b34e 100644 --- a/src/tint/writer/wgsl/generator_impl_cast_test.cc +++ b/src/tint/writer/wgsl/generator_impl_cast_test.cc @@ -21,7 +21,7 @@ namespace { using WgslGeneratorImplTest = TestHelper; -TEST_F(WgslGeneratorImplTest, EmitExpression_Cast_Scalar) { +TEST_F(WgslGeneratorImplTest, EmitExpression_Cast_Scalar_F32_From_I32) { auto* cast = Construct(1_i); WrapInFunction(cast); @@ -32,7 +32,20 @@ TEST_F(WgslGeneratorImplTest, EmitExpression_Cast_Scalar) { EXPECT_EQ(out.str(), "f32(1i)"); } -TEST_F(WgslGeneratorImplTest, EmitExpression_Cast_Vector) { +TEST_F(WgslGeneratorImplTest, EmitExpression_Cast_Scalar_F16_From_I32) { + Enable(ast::Extension::kF16); + + auto* cast = Construct(1_i); + WrapInFunction(cast); + + GeneratorImpl& gen = Build(); + + std::stringstream out; + ASSERT_TRUE(gen.EmitExpression(out, cast)) << gen.error(); + EXPECT_EQ(out.str(), "f16(1i)"); +} + +TEST_F(WgslGeneratorImplTest, EmitExpression_Cast_Vector_F32_From_I32) { auto* cast = vec3(vec3(1_i, 2_i, 3_i)); WrapInFunction(cast); @@ -43,5 +56,18 @@ TEST_F(WgslGeneratorImplTest, EmitExpression_Cast_Vector) { EXPECT_EQ(out.str(), "vec3(vec3(1i, 2i, 3i))"); } +TEST_F(WgslGeneratorImplTest, EmitExpression_Cast_Vector_F16_From_I32) { + Enable(ast::Extension::kF16); + + auto* cast = vec3(vec3(1_i, 2_i, 3_i)); + WrapInFunction(cast); + + GeneratorImpl& gen = Build(); + + std::stringstream out; + ASSERT_TRUE(gen.EmitExpression(out, cast)) << gen.error(); + EXPECT_EQ(out.str(), "vec3(vec3(1i, 2i, 3i))"); +} + } // namespace } // namespace tint::writer::wgsl diff --git a/src/tint/writer/wgsl/generator_impl_constructor_test.cc b/src/tint/writer/wgsl/generator_impl_constructor_test.cc index ae9f2b74f6..07b10ad2ac 100644 --- a/src/tint/writer/wgsl/generator_impl_constructor_test.cc +++ b/src/tint/writer/wgsl/generator_impl_constructor_test.cc @@ -51,7 +51,7 @@ TEST_F(WgslGeneratorImplTest, EmitConstructor_UInt) { EXPECT_THAT(gen.result(), HasSubstr("56779u")); } -TEST_F(WgslGeneratorImplTest, EmitConstructor_Float) { +TEST_F(WgslGeneratorImplTest, EmitConstructor_F32) { // Use a number close to 1<<30 but whose decimal representation ends in 0. WrapInFunction(Expr(f32((1 << 30) - 4))); @@ -61,7 +61,19 @@ TEST_F(WgslGeneratorImplTest, EmitConstructor_Float) { EXPECT_THAT(gen.result(), HasSubstr("1073741824.0f")); } -TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Float) { +TEST_F(WgslGeneratorImplTest, EmitConstructor_F16) { + Enable(ast::Extension::kF16); + + // Use a number close to 1<<16 but whose decimal representation ends in 0. + WrapInFunction(Expr(f16((1 << 15) - 8))); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("32752.0h")); +} + +TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_F32) { WrapInFunction(Construct(Expr(-1.2e-5_f))); GeneratorImpl& gen = Build(); @@ -70,6 +82,17 @@ TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Float) { EXPECT_THAT(gen.result(), HasSubstr("f32(-0.000012f)")); } +TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_F16) { + Enable(ast::Extension::kF16); + + WrapInFunction(Construct(Expr(-1.2e-5_h))); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("f16(-1.19805336e-05h)")); +} + TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Bool) { WrapInFunction(Construct(true)); @@ -97,7 +120,7 @@ TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Uint) { EXPECT_THAT(gen.result(), HasSubstr("u32(12345u)")); } -TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Vec) { +TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Vec_F32) { WrapInFunction(vec3(1_f, 2_f, 3_f)); GeneratorImpl& gen = Build(); @@ -106,7 +129,18 @@ TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Vec) { EXPECT_THAT(gen.result(), HasSubstr("vec3(1.0f, 2.0f, 3.0f)")); } -TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Mat) { +TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Vec_F16) { + Enable(ast::Extension::kF16); + + WrapInFunction(vec3(1_h, 2_h, 3_h)); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("vec3(1.0h, 2.0h, 3.0h)")); +} + +TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Mat_F32) { WrapInFunction(mat2x3(vec3(1_f, 2_f, 3_f), vec3(3_f, 4_f, 5_f))); GeneratorImpl& gen = Build(); @@ -116,6 +150,18 @@ TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Mat) { "vec3(3.0f, 4.0f, 5.0f))")); } +TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Mat_F16) { + Enable(ast::Extension::kF16); + + WrapInFunction(mat2x3(vec3(1_h, 2_h, 3_h), vec3(3_h, 4_h, 5_h))); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_THAT(gen.result(), HasSubstr("mat2x3(vec3(1.0h, 2.0h, 3.0h), " + "vec3(3.0h, 4.0h, 5.0h))")); +} + TEST_F(WgslGeneratorImplTest, EmitConstructor_Type_Array) { WrapInFunction(Construct(ty.array(ty.vec3(), 3_u), vec3(1_f, 2_f, 3_f), vec3(4_f, 5_f, 6_f), vec3(7_f, 8_f, 9_f))); diff --git a/src/tint/writer/wgsl/generator_impl_literal_test.cc b/src/tint/writer/wgsl/generator_impl_literal_test.cc index 78d70f6be7..f33bd0ec69 100644 --- a/src/tint/writer/wgsl/generator_impl_literal_test.cc +++ b/src/tint/writer/wgsl/generator_impl_literal_test.cc @@ -25,7 +25,7 @@ namespace { // - 0 sign if sign is 0, 1 otherwise // - 'exponent_bits' is placed in the exponent space. // So, the exponent bias must already be included. -f32 MakeFloat(uint32_t sign, uint32_t biased_exponent, uint32_t mantissa) { +f32 MakeF32(uint32_t sign, uint32_t biased_exponent, uint32_t mantissa) { const uint32_t sign_bit = sign ? 0x80000000u : 0u; // The binary32 exponent is 8 bits, just below the sign. const uint32_t exponent_bits = (biased_exponent & 0xffu) << 23; @@ -40,18 +40,75 @@ f32 MakeFloat(uint32_t sign, uint32_t biased_exponent, uint32_t mantissa) { return f32(result); } -struct FloatData { +// Get the representation of an IEEE 754 binary16 floating point number with +// - 0 sign if sign is 0, 1 otherwise +// - 'exponent_bits' is placed in the exponent space. +// - the exponent bias (15) already be included. +f16 MakeF16(uint32_t sign, uint32_t f16_biased_exponent, uint16_t f16_mantissa) { + assert((f16_biased_exponent & 0xffffffe0u) == 0); + assert((f16_mantissa & 0xfc00u) == 0); + + const uint32_t sign_bit = sign ? 0x80000000u : 0u; + + // F16 has a exponent bias of 15, and f32 bias 127. Adding 127-15=112 to the f16-biased exponent + // to get f32-biased exponent. + uint32_t f32_biased_exponent = (f16_biased_exponent & 0x1fu) + 112; + assert((f32_biased_exponent & 0xffffff00u) == 0); + + if (f16_biased_exponent == 0) { + // +/- zero, or subnormal + if (f16_mantissa == 0) { + // +/- zero + return sign ? f16(-0.0f) : f16(0.0f); + } + // Subnormal f16, calc the corresponding exponent and mantissa of normal f32. + f32_biased_exponent += 1; + // There must be at least one of the 10 mantissa bits being 1, left-shift the mantissa bits + // until the most significant 1 bit is left-shifted to 10th bit (count from zero), which + // will be omitted in the resulting f32 mantissa part. + assert(f16_mantissa & 0x03ffu); + while ((f16_mantissa & 0x0400u) == 0) { + f16_mantissa = static_cast(f16_mantissa << 1); + f32_biased_exponent--; + } + } + + // The binary32 exponent is 8 bits, just below the sign. + const uint32_t f32_exponent_bits = (f32_biased_exponent & 0xffu) << 23; + // The mantissa is the bottom 23 bits. + const uint32_t f32_mantissa_bits = (f16_mantissa & 0x03ffu) << 13; + + uint32_t bits = sign_bit | f32_exponent_bits | f32_mantissa_bits; + float result = 0.0f; + static_assert(sizeof(result) == sizeof(bits), + "expected float and uint32_t to be the same size"); + std::memcpy(&result, &bits, sizeof(bits)); + return f16(result); +} + +struct F32Data { f32 value; std::string expected; }; -inline std::ostream& operator<<(std::ostream& out, FloatData data) { + +struct F16Data { + f16 value; + std::string expected; +}; + +inline std::ostream& operator<<(std::ostream& out, F32Data data) { out << "{" << data.value << "," << data.expected << "}"; return out; } -using WgslGenerator_FloatLiteralTest = TestParamHelper; +inline std::ostream& operator<<(std::ostream& out, F16Data data) { + out << "{" << data.value << "," << data.expected << "}"; + return out; +} -TEST_P(WgslGenerator_FloatLiteralTest, Emit) { +using WgslGenerator_F32LiteralTest = TestParamHelper; + +TEST_P(WgslGenerator_F32LiteralTest, Emit) { auto* v = Expr(GetParam().value); SetResolveOnBuild(false); @@ -63,38 +120,37 @@ TEST_P(WgslGenerator_FloatLiteralTest, Emit) { } INSTANTIATE_TEST_SUITE_P(Zero, - WgslGenerator_FloatLiteralTest, - ::testing::ValuesIn(std::vector{ - {0_f, "0.0f"}, - {MakeFloat(0, 0, 0), "0.0f"}, - {MakeFloat(1, 0, 0), "-0.0f"}})); + WgslGenerator_F32LiteralTest, + ::testing::ValuesIn(std::vector{{0_f, "0.0f"}, + {MakeF32(0, 0, 0), "0.0f"}, + {MakeF32(1, 0, 0), "-0.0f"}})); INSTANTIATE_TEST_SUITE_P(Normal, - WgslGenerator_FloatLiteralTest, - ::testing::ValuesIn(std::vector{{1_f, "1.0f"}, - {-1_f, "-1.0f"}, - {101.375_f, "101.375f"}})); + WgslGenerator_F32LiteralTest, + ::testing::ValuesIn(std::vector{{1_f, "1.0f"}, + {-1_f, "-1.0f"}, + {101.375_f, "101.375f"}})); INSTANTIATE_TEST_SUITE_P(Subnormal, - WgslGenerator_FloatLiteralTest, - ::testing::ValuesIn(std::vector{ - {MakeFloat(0, 0, 1), "0x1p-149f"}, // Smallest - {MakeFloat(1, 0, 1), "-0x1p-149f"}, - {MakeFloat(0, 0, 2), "0x1p-148f"}, - {MakeFloat(1, 0, 2), "-0x1p-148f"}, - {MakeFloat(0, 0, 0x7fffff), "0x1.fffffcp-127f"}, // Largest - {MakeFloat(1, 0, 0x7fffff), "-0x1.fffffcp-127f"}, // Largest - {MakeFloat(0, 0, 0xcafebe), "0x1.2bfaf8p-127f"}, // Scattered bits - {MakeFloat(1, 0, 0xcafebe), "-0x1.2bfaf8p-127f"}, // Scattered bits - {MakeFloat(0, 0, 0xaaaaa), "0x1.55554p-130f"}, // Scattered bits - {MakeFloat(1, 0, 0xaaaaa), "-0x1.55554p-130f"}, // Scattered bits + WgslGenerator_F32LiteralTest, + ::testing::ValuesIn(std::vector{ + {MakeF32(0, 0, 1), "0x1p-149f"}, // Smallest + {MakeF32(1, 0, 1), "-0x1p-149f"}, + {MakeF32(0, 0, 2), "0x1p-148f"}, + {MakeF32(1, 0, 2), "-0x1p-148f"}, + {MakeF32(0, 0, 0x7fffff), "0x1.fffffcp-127f"}, // Largest + {MakeF32(1, 0, 0x7fffff), "-0x1.fffffcp-127f"}, // Largest + {MakeF32(0, 0, 0xcafebe), "0x1.2bfaf8p-127f"}, // Scattered bits + {MakeF32(1, 0, 0xcafebe), "-0x1.2bfaf8p-127f"}, // Scattered bits + {MakeF32(0, 0, 0xaaaaa), "0x1.55554p-130f"}, // Scattered bits + {MakeF32(1, 0, 0xaaaaa), "-0x1.55554p-130f"}, // Scattered bits })); INSTANTIATE_TEST_SUITE_P(Infinity, - WgslGenerator_FloatLiteralTest, - ::testing::ValuesIn(std::vector{ - {MakeFloat(0, 255, 0), "0x1p+128f"}, - {MakeFloat(1, 255, 0), "-0x1p+128f"}})); + WgslGenerator_F32LiteralTest, + ::testing::ValuesIn(std::vector{ + {MakeF32(0, 255, 0), "0x1p+128f"}, + {MakeF32(1, 255, 0), "-0x1p+128f"}})); INSTANTIATE_TEST_SUITE_P( // TODO(dneto): It's unclear how Infinity and NaN should be handled. @@ -106,23 +162,95 @@ INSTANTIATE_TEST_SUITE_P( // whether the NaN is signalling or quiet, but no agreement between // different machine architectures on whether 1 means signalling or // if 1 means quiet. - WgslGenerator_FloatLiteralTest, - ::testing::ValuesIn(std::vector{ + WgslGenerator_F32LiteralTest, + ::testing::ValuesIn(std::vector{ // LSB only. Smallest mantissa. - {MakeFloat(0, 255, 1), "0x1.000002p+128f"}, // Smallest mantissa - {MakeFloat(1, 255, 1), "-0x1.000002p+128f"}, + {MakeF32(0, 255, 1), "0x1.000002p+128f"}, // Smallest mantissa + {MakeF32(1, 255, 1), "-0x1.000002p+128f"}, // MSB only. - {MakeFloat(0, 255, 0x400000), "0x1.8p+128f"}, - {MakeFloat(1, 255, 0x400000), "-0x1.8p+128f"}, + {MakeF32(0, 255, 0x400000), "0x1.8p+128f"}, + {MakeF32(1, 255, 0x400000), "-0x1.8p+128f"}, // All 1s in the mantissa. - {MakeFloat(0, 255, 0x7fffff), "0x1.fffffep+128f"}, - {MakeFloat(1, 255, 0x7fffff), "-0x1.fffffep+128f"}, + {MakeF32(0, 255, 0x7fffff), "0x1.fffffep+128f"}, + {MakeF32(1, 255, 0x7fffff), "-0x1.fffffep+128f"}, // Scattered bits, with 0 in top mantissa bit. - {MakeFloat(0, 255, 0x20101f), "0x1.40203ep+128f"}, - {MakeFloat(1, 255, 0x20101f), "-0x1.40203ep+128f"}, + {MakeF32(0, 255, 0x20101f), "0x1.40203ep+128f"}, + {MakeF32(1, 255, 0x20101f), "-0x1.40203ep+128f"}, // Scattered bits, with 1 in top mantissa bit. - {MakeFloat(0, 255, 0x40101f), "0x1.80203ep+128f"}, - {MakeFloat(1, 255, 0x40101f), "-0x1.80203ep+128f"}})); + {MakeF32(0, 255, 0x40101f), "0x1.80203ep+128f"}, + {MakeF32(1, 255, 0x40101f), "-0x1.80203ep+128f"}})); + +using WgslGenerator_F16LiteralTest = TestParamHelper; + +TEST_P(WgslGenerator_F16LiteralTest, Emit) { + Enable(ast::Extension::kF16); + + auto* v = Expr(GetParam().value); + + SetResolveOnBuild(false); + GeneratorImpl& gen = Build(); + + std::stringstream out; + ASSERT_TRUE(gen.EmitLiteral(out, v)) << gen.error(); + EXPECT_EQ(out.str(), GetParam().expected); +} + +INSTANTIATE_TEST_SUITE_P(Zero, + WgslGenerator_F16LiteralTest, + ::testing::ValuesIn(std::vector{{0_h, "0.0h"}, + {MakeF16(0, 0, 0), "0.0h"}, + {MakeF16(1, 0, 0), "-0.0h"}})); + +INSTANTIATE_TEST_SUITE_P(Normal, + WgslGenerator_F16LiteralTest, + ::testing::ValuesIn(std::vector{{1_h, "1.0h"}, + {-1_h, "-1.0h"}, + {101.375_h, "101.375h"}})); + +INSTANTIATE_TEST_SUITE_P(Subnormal, + WgslGenerator_F16LiteralTest, + ::testing::ValuesIn(std::vector{ + {MakeF16(0, 0, 1), "5.96046448e-08h"}, // Smallest + {MakeF16(1, 0, 1), "-5.96046448e-08h"}, + {MakeF16(0, 0, 2), "1.1920929e-07h"}, + {MakeF16(1, 0, 2), "-1.1920929e-07h"}, + {MakeF16(0, 0, 0x3ffu), "6.09755516e-05h"}, // Largest + {MakeF16(1, 0, 0x3ffu), "-6.09755516e-05h"}, // Largest + {MakeF16(0, 0, 0x3afu), "5.620718e-05h"}, // Scattered bits + {MakeF16(1, 0, 0x3afu), "-5.620718e-05h"}, // Scattered bits + {MakeF16(0, 0, 0x2c7u), "4.23789024e-05h"}, // Scattered bits + {MakeF16(1, 0, 0x2c7u), "-4.23789024e-05h"}, // Scattered bits + })); + +INSTANTIATE_TEST_SUITE_P( + // Currently Inf is impossible to be spelled out in literal. + // https://github.com/gpuweb/gpuweb/issues/1769 + DISABLED_Infinity, + WgslGenerator_F16LiteralTest, + ::testing::ValuesIn(std::vector{{MakeF16(0, 31, 0), "0x1p+128h"}, + {MakeF16(1, 31, 0), "-0x1p+128h"}})); + +INSTANTIATE_TEST_SUITE_P( + // Currently NaN is impossible to be spelled out in literal. + // https://github.com/gpuweb/gpuweb/issues/1769 + DISABLED_NaN, + WgslGenerator_F16LiteralTest, + ::testing::ValuesIn(std::vector{ + // LSB only. Smallest mantissa. + {MakeF16(0, 31, 1), "0x1.004p+128h"}, // Smallest mantissa + {MakeF16(1, 31, 1), "-0x1.004p+128h"}, + // MSB only. + {MakeF16(0, 31, 0x200u), "0x1.8p+128h"}, + {MakeF16(1, 31, 0x200u), "-0x1.8p+128h"}, + // All 1s in the mantissa. + {MakeF16(0, 31, 0x3ffu), "0x1.ffcp+128h"}, + {MakeF16(1, 31, 0x3ffu), "-0x1.ffcp+128h"}, + // Scattered bits, with 0 in top mantissa bit. + {MakeF16(0, 31, 0x11fu), "0x1.47cp+128h"}, + {MakeF16(1, 31, 0x11fu), "-0x1.47cp+128h"}, + // Scattered bits, with 1 in top mantissa bit. + {MakeF16(0, 31, 0x23fu), "0x1.8fcp+128h"}, + {MakeF16(1, 31, 0x23fu), "-0x1.8fcp+128h"}})); } // namespace } // namespace tint::writer::wgsl diff --git a/src/tint/writer/wgsl/generator_impl_type_test.cc b/src/tint/writer/wgsl/generator_impl_type_test.cc index b6ae9c5660..2e546208a3 100644 --- a/src/tint/writer/wgsl/generator_impl_type_test.cc +++ b/src/tint/writer/wgsl/generator_impl_type_test.cc @@ -91,6 +91,19 @@ TEST_F(WgslGeneratorImplTest, EmitType_F32) { EXPECT_EQ(out.str(), "f32"); } +TEST_F(WgslGeneratorImplTest, EmitType_F16) { + Enable(ast::Extension::kF16); + + auto* f16 = ty.f16(); + Alias("make_type_reachable", f16); + + GeneratorImpl& gen = Build(); + + std::stringstream out; + ASSERT_TRUE(gen.EmitType(out, f16)) << gen.error(); + EXPECT_EQ(out.str(), "f16"); +} + TEST_F(WgslGeneratorImplTest, EmitType_I32) { auto* i32 = ty.i32(); Alias("make_type_reachable", i32); @@ -102,7 +115,7 @@ TEST_F(WgslGeneratorImplTest, EmitType_I32) { EXPECT_EQ(out.str(), "i32"); } -TEST_F(WgslGeneratorImplTest, EmitType_Matrix) { +TEST_F(WgslGeneratorImplTest, EmitType_Matrix_F32) { auto* mat2x3 = ty.mat2x3(); Alias("make_type_reachable", mat2x3); @@ -113,6 +126,19 @@ TEST_F(WgslGeneratorImplTest, EmitType_Matrix) { EXPECT_EQ(out.str(), "mat2x3"); } +TEST_F(WgslGeneratorImplTest, EmitType_Matrix_F16) { + Enable(ast::Extension::kF16); + + auto* mat2x3 = ty.mat2x3(); + Alias("make_type_reachable", mat2x3); + + GeneratorImpl& gen = Build(); + + std::stringstream out; + ASSERT_TRUE(gen.EmitType(out, mat2x3)) << gen.error(); + EXPECT_EQ(out.str(), "mat2x3"); +} + TEST_F(WgslGeneratorImplTest, EmitType_Pointer) { auto* p = ty.pointer(ast::StorageClass::kWorkgroup); Alias("make_type_reachable", p); @@ -271,7 +297,7 @@ TEST_F(WgslGeneratorImplTest, EmitType_U32) { EXPECT_EQ(out.str(), "u32"); } -TEST_F(WgslGeneratorImplTest, EmitType_Vector) { +TEST_F(WgslGeneratorImplTest, EmitType_Vector_F32) { auto* vec3 = ty.vec3(); Alias("make_type_reachable", vec3); @@ -282,6 +308,19 @@ TEST_F(WgslGeneratorImplTest, EmitType_Vector) { EXPECT_EQ(out.str(), "vec3"); } +TEST_F(WgslGeneratorImplTest, EmitType_Vector_F16) { + Enable(ast::Extension::kF16); + + auto* vec3 = ty.vec3(); + Alias("make_type_reachable", vec3); + + GeneratorImpl& gen = Build(); + + std::stringstream out; + ASSERT_TRUE(gen.EmitType(out, vec3)) << gen.error(); + EXPECT_EQ(out.str(), "vec3"); +} + struct TextureData { ast::TextureDimension dim; const char* name; diff --git a/src/tint/writer/wgsl/generator_impl_variable_decl_statement_test.cc b/src/tint/writer/wgsl/generator_impl_variable_decl_statement_test.cc index aeacd308cf..abfe4fec8b 100644 --- a/src/tint/writer/wgsl/generator_impl_variable_decl_statement_test.cc +++ b/src/tint/writer/wgsl/generator_impl_variable_decl_statement_test.cc @@ -125,6 +125,25 @@ TEST_F(WgslGeneratorImplTest, Emit_VariableDeclStatement_Const_f32) { )"); } +TEST_F(WgslGeneratorImplTest, Emit_VariableDeclStatement_Const_f16) { + Enable(ast::Extension::kF16); + + auto* C = Const("C", nullptr, Expr(1_h)); + Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))}); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_EQ(gen.result(), R"(enable f16; + +fn f() { + const C = 1.0h; + let l = C; +} +)"); +} + TEST_F(WgslGeneratorImplTest, Emit_VariableDeclStatement_Const_vec3_AInt) { auto* C = Const("C", nullptr, Construct(ty.vec3(nullptr), 1_a, 2_a, 3_a)); Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))}); @@ -170,6 +189,25 @@ TEST_F(WgslGeneratorImplTest, Emit_VariableDeclStatement_Const_vec3_f32) { )"); } +TEST_F(WgslGeneratorImplTest, Emit_VariableDeclStatement_Const_vec3_f16) { + Enable(ast::Extension::kF16); + + auto* C = Const("C", nullptr, vec3(1_h, 2_h, 3_h)); + Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))}); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_EQ(gen.result(), R"(enable f16; + +fn f() { + const C = vec3(1.0h, 2.0h, 3.0h); + let l = C; +} +)"); +} + TEST_F(WgslGeneratorImplTest, Emit_VariableDeclStatement_Const_mat2x3_AFloat) { auto* C = Const("C", nullptr, Construct(ty.mat(nullptr, 2, 3), 1._a, 2._a, 3._a, 4._a, 5._a, 6._a)); @@ -201,6 +239,25 @@ TEST_F(WgslGeneratorImplTest, Emit_VariableDeclStatement_Const_mat2x3_f32) { )"); } +TEST_F(WgslGeneratorImplTest, Emit_VariableDeclStatement_Const_mat2x3_f16) { + Enable(ast::Extension::kF16); + + auto* C = Const("C", nullptr, mat2x3(1_h, 2_h, 3_h, 4_h, 5_h, 6_h)); + Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))}); + + GeneratorImpl& gen = Build(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + + EXPECT_EQ(gen.result(), R"(enable f16; + +fn f() { + const C = mat2x3(1.0h, 2.0h, 3.0h, 4.0h, 5.0h, 6.0h); + let l = C; +} +)"); +} + TEST_F(WgslGeneratorImplTest, Emit_VariableDeclStatement_Const_arr_f32) { auto* C = Const("C", nullptr, Construct(ty.array(), 1_f, 2_f, 3_f)); Func("f", {}, ty.void_(), {Decl(C), Decl(Let("l", nullptr, Expr(C)))});