diff --git a/src/common/Math.cpp b/src/common/Math.cpp index d9217c8e44..edb68f951f 100644 --- a/src/common/Math.cpp +++ b/src/common/Math.cpp @@ -85,8 +85,10 @@ uint16_t Float32ToFloat16(float fp32) { uint32_t sign16 = (fp32i & 0x80000000) >> 16; uint32_t mantissaAndExponent = fp32i & 0x7FFFFFFF; - if (mantissaAndExponent > 0x47FFEFFF) { // Infinity - return static_cast(sign16 | 0x7FFF); + if (mantissaAndExponent > 0x7F800000) { // NaN + return 0x7FFF; + } else if (mantissaAndExponent > 0x47FFEFFF) { // Infinity + return static_cast(sign16 | 0x7C00); } else if (mantissaAndExponent < 0x38800000) { // Denormal uint32_t mantissa = (mantissaAndExponent & 0x007FFFFF) | 0x00800000; int32_t exponent = 113 - (mantissaAndExponent >> 23); diff --git a/src/tests/unittests/MathTests.cpp b/src/tests/unittests/MathTests.cpp index 4caf9e7d0a..2063a41283 100644 --- a/src/tests/unittests/MathTests.cpp +++ b/src/tests/unittests/MathTests.cpp @@ -16,6 +16,8 @@ #include "common/Math.h" +#include + // Tests for ScanForward TEST(Math, ScanForward) { // Test extrema @@ -133,3 +135,18 @@ TEST(Math, IsAligned) { ASSERT_FALSE(IsAligned(64 + i, 64)); } } + +// Tests for float32 to float16 conversion +TEST(Math, Float32ToFloat16) { + ASSERT_EQ(Float32ToFloat16(0.0f), 0x0000); + ASSERT_EQ(Float32ToFloat16(-0.0f), 0x8000); + + ASSERT_EQ(Float32ToFloat16(INFINITY), 0x7C00); + ASSERT_EQ(Float32ToFloat16(-INFINITY), 0xFC00); + + // Check that NaN is converted to a value in one of the float16 NaN ranges + uint16_t nan16 = Float32ToFloat16(NAN); + ASSERT_TRUE(nan16 > 0xFC00 || (nan16 < 0x8000 && nan16 > 0x7C00)); + + ASSERT_EQ(Float32ToFloat16(1.0f), 0x3C00); +}