diff --git a/src/reader/wgsl/lexer.cc b/src/reader/wgsl/lexer.cc index c4eed5f480..9ddcdb82b4 100644 --- a/src/reader/wgsl/lexer.cc +++ b/src/reader/wgsl/lexer.cc @@ -14,8 +14,12 @@ #include "src/reader/wgsl/lexer.h" +#include +#include #include +#include "src/debug.h" + namespace tint { namespace reader { namespace wgsl { @@ -25,6 +29,26 @@ bool is_whitespace(char c) { return std::isspace(c); } +uint32_t dec_value(char c) { + if (c >= '0' && c <= '9') { + return static_cast(c - '0'); + } + return 0; +} + +uint32_t hex_value(char c) { + if (c >= '0' && c <= '9') { + return static_cast(c - '0'); + } + if (c >= 'a' && c <= 'f') { + return 0xA + static_cast(c - 'a'); + } + if (c >= 'A' && c <= 'F') { + return 0xA + static_cast(c - 'A'); + } + return 0; +} + } // namespace Lexer::Lexer(const std::string& file_path, const Source::FileContent* content) @@ -43,7 +67,12 @@ Token Lexer::next() { return {Token::Type::kEOF, begin_source()}; } - auto t = try_hex_integer(); + auto t = try_hex_float(); + if (!t.IsUninitialized()) { + return t; + } + + t = try_hex_integer(); if (!t.IsUninitialized()) { return t; } @@ -239,6 +268,225 @@ Token Lexer::try_float() { return {source, static_cast(res)}; } +Token Lexer::try_hex_float() { + constexpr uint32_t kTotalBits = 32; + constexpr uint32_t kTotalMsb = kTotalBits - 1; + constexpr uint32_t kMantissaBits = 23; + constexpr uint32_t kMantissaMsb = kMantissaBits - 1; + constexpr uint32_t kMantissaShiftRight = kTotalBits - kMantissaBits; + constexpr int32_t kExponentBias = 127; + constexpr int32_t kExponentMax = 255; + constexpr uint32_t kExponentBits = 8; + constexpr uint32_t kExponentMask = (1 << kExponentBits) - 1; + constexpr uint32_t kExponentLeftShift = kMantissaBits; + constexpr uint32_t kSignBit = 31; + + auto start = pos_; + auto end = pos_; + + auto source = begin_source(); + + // clang-format off + // -?0x([0-9a-fA-F]*.?[0-9a-fA-F]+ | [0-9a-fA-F]+.[0-9a-fA-F]*)(p|P)(+|-)?[0-9]+ // NOLINT + // clang-format on + + // -? + int32_t sign_bit = 0; + if (matches(end, "-")) { + sign_bit = 1; + end++; + } + // 0x + if (matches(end, "0x")) { + end += 2; + } else { + return {}; + } + + uint32_t mantissa = 0; + int32_t exponent = 0; + + // `set_next_mantissa_bit_to` sets next `mantissa` bit starting from msb to + // lsb to value 1 if `set` is true, 0 otherwise + uint32_t mantissa_next_bit = kTotalMsb; + auto set_next_mantissa_bit_to = [&](bool set) -> bool { + if (mantissa_next_bit > kTotalMsb) { + return false; // Overflowed mantissa + } + if (set) { + mantissa |= (1 << mantissa_next_bit); + } + --mantissa_next_bit; + return true; + }; + + // Parse integer part + // [0-9a-fA-F]* + bool has_integer = false; + bool has_zero_integer = true; + bool leading_bit_seen = false; + while (end < len_ && is_hex(content_->data[end])) { + has_integer = true; + + const auto nibble = hex_value(content_->data[end]); + if (nibble != 0) { + has_zero_integer = false; + } + + for (int32_t i = 3; i >= 0; --i) { + auto v = 1 & (nibble >> i); + + // Skip leading 0s and the first 1 + if (leading_bit_seen) { + if (!set_next_mantissa_bit_to(v != 0)) { + return {}; + } + ++exponent; + } else { + if (v == 1) { + leading_bit_seen = true; + } + } + } + + end++; + } + + // .? + if (matches(end, ".")) { + end++; + } + + // Parse fractional part + // [0-9a-fA-F]* + bool has_fractional = false; + leading_bit_seen = false; + while (end < len_ && is_hex(content_->data[end])) { + has_fractional = true; + auto nibble = hex_value(content_->data[end]); + for (int32_t i = 3; i >= 0; --i) { + auto v = 1 & (nibble >> i); + + if (v == 1) { + leading_bit_seen = true; + } + + // If integer part is 0 (denorm), we only start writing bits to the + // mantissa once we have a non-zero fractional bit. While the fractional + // values are 0, we adjust the exponent to avoid overflowing `mantissa`. + if (has_zero_integer && !leading_bit_seen) { + --exponent; + } else { + if (!set_next_mantissa_bit_to(v != 0)) { + return {}; + } + } + } + + end++; + } + + if (!(has_integer || has_fractional)) { + return {}; + } + + // (p|P) + if (matches(end, "p") || matches(end, "P")) { + end++; + } else { + return {}; + } + + // (+|-)? + int32_t exponent_sign = 1; + if (matches(end, "+")) { + end++; + } else if (matches(end, "-")) { + exponent_sign = -1; + end++; + } + + // Parse exponent from input + // [0-9]+ + bool has_exponent = false; + int32_t input_exponent = 0; + while (end < len_ && isdigit(content_->data[end])) { + has_exponent = true; + input_exponent = (input_exponent * 10) + dec_value(content_->data[end]); + end++; + } + if (!has_exponent) { + return {}; + } + + pos_ = end; + location_.column += (end - start); + end_source(source); + + // Compute exponent so far + exponent = exponent + (input_exponent * exponent_sign); + + // Determine if value is zero + // Note: it's not enough to check mantissa == 0 as we drop initial bit from + // integer part. + bool is_zero = has_zero_integer && mantissa == 0; + TINT_ASSERT(Reader, !is_zero || (exponent == 0 && mantissa == 0)); + + if (!is_zero) { + // Bias exponent if non-zero + // After this, if exponent is <= 0, our value is a denormal + exponent += kExponentBias; + + // Denormal uses biased exponent of -126, not -127 + if (has_zero_integer) { + mantissa <<= 1; + --exponent; + } + } + + // Shift mantissa to occupy the low 23 bits + mantissa >>= kMantissaShiftRight; + + // If denormal, shift mantissa until our exponent is zero + if (!is_zero) { + // Denorm has exponent 0 and non-zero mantissa. We set the top bit here, + // then shift the mantissa to make exponent zero. + if (exponent <= 0) { + mantissa >>= 1; + mantissa |= (1 << kMantissaMsb); + } + + while (exponent < 0) { + mantissa >>= 1; + ++exponent; + + // If underflow, clamp to zero + if (mantissa == 0) { + exponent = 0; + } + } + } + + if (exponent > kExponentMax) { + // Overflow: set to infinity + exponent = kExponentMax; + mantissa = 0; + } else if (exponent == kExponentMax && mantissa != 0) { + // NaN: set to infinity + mantissa = 0; + } + + // Combine sign, mantissa, and exponent + uint32_t result_u32 = sign_bit << kSignBit; + result_u32 |= mantissa; + result_u32 |= (exponent & kExponentMask) << kExponentLeftShift; + + // Reinterpret as float and return + float result; + std::memcpy(&result, &result_u32, sizeof(result)); + return {source, static_cast(result)}; +} + Token Lexer::build_token_from_int_if_possible(Source source, size_t start, size_t end, diff --git a/src/reader/wgsl/lexer.h b/src/reader/wgsl/lexer.h index b1774e9785..9c96bb5af3 100644 --- a/src/reader/wgsl/lexer.h +++ b/src/reader/wgsl/lexer.h @@ -46,6 +46,7 @@ class Lexer { int32_t base); Token check_keyword(const Source&, const std::string&); Token try_float(); + Token try_hex_float(); Token try_hex_integer(); Token try_ident(); Token try_integer(); diff --git a/src/reader/wgsl/parser_impl_const_literal_test.cc b/src/reader/wgsl/parser_impl_const_literal_test.cc index d73abaa25e..4dee6ce219 100644 --- a/src/reader/wgsl/parser_impl_const_literal_test.cc +++ b/src/reader/wgsl/parser_impl_const_literal_test.cc @@ -14,17 +14,39 @@ #include "src/reader/wgsl/parser_impl_test_helper.h" +#include +#include + namespace tint { namespace reader { namespace wgsl { namespace { +// Makes an IEEE 754 binary32 floating point number with +// - 0 sign if sign is 0, 1 otherwise +// - 'exponent_bits' is placed in the exponent space. +// So, the exponent bias must already be included. +float MakeFloat(int sign, int biased_exponent, int mantissa) { + const uint32_t sign_bit = sign ? 0x80000000u : 0u; + // The binary32 exponent is 8 bits, just below the sign. + const uint32_t exponent_bits = (biased_exponent & 0xffu) << 23; + // The mantissa is the bottom 23 bits. + const uint32_t mantissa_bits = (mantissa & 0x7fffffu); + + uint32_t bits = sign_bit | exponent_bits | mantissa_bits; + float result = 0.0f; + static_assert(sizeof(result) == sizeof(bits), + "expected float and uint32_t to be the same size"); + std::memcpy(&result, &bits, sizeof(bits)); + return result; +} + TEST_F(ParserImplTest, ConstLiteral_Int) { auto p = parser("-234"); auto c = p->const_literal(); EXPECT_TRUE(c.matched); EXPECT_FALSE(c.errored); - EXPECT_FALSE(p->has_error()); + EXPECT_FALSE(p->has_error()) << p->error(); ASSERT_NE(c.value, nullptr); ASSERT_TRUE(c->Is()); EXPECT_EQ(c->As()->value(), -234); @@ -36,7 +58,7 @@ TEST_F(ParserImplTest, ConstLiteral_Uint) { auto c = p->const_literal(); EXPECT_TRUE(c.matched); EXPECT_FALSE(c.errored); - EXPECT_FALSE(p->has_error()); + EXPECT_FALSE(p->has_error()) << p->error(); ASSERT_NE(c.value, nullptr); ASSERT_TRUE(c->Is()); EXPECT_EQ(c->As()->value(), 234u); @@ -48,7 +70,7 @@ TEST_F(ParserImplTest, ConstLiteral_Float) { auto c = p->const_literal(); EXPECT_TRUE(c.matched); EXPECT_FALSE(c.errored); - EXPECT_FALSE(p->has_error()); + EXPECT_FALSE(p->has_error()) << p->error(); ASSERT_NE(c.value, nullptr); ASSERT_TRUE(c->Is()); EXPECT_FLOAT_EQ(c->As()->value(), 234e12f); @@ -63,12 +85,221 @@ TEST_F(ParserImplTest, ConstLiteral_InvalidFloat) { ASSERT_EQ(c.value, nullptr); } +struct FloatLiteralTestCase { + const char* input; + float expected; +}; + +inline std::ostream& operator<<(std::ostream& out, FloatLiteralTestCase data) { + out << data.input; + return out; +} + +class ParserImplFloatLiteralTest + : public ParserImplTestWithParam {}; +TEST_P(ParserImplFloatLiteralTest, Parse) { + auto params = GetParam(); + SCOPED_TRACE(params.input); + auto p = parser(params.input); + auto c = p->const_literal(); + EXPECT_TRUE(c.matched); + EXPECT_FALSE(c.errored); + EXPECT_FALSE(p->has_error()) << p->error(); + ASSERT_NE(c.value, nullptr); + ASSERT_TRUE(c->Is()); + EXPECT_FLOAT_EQ(c->As()->value(), params.expected); +} + +FloatLiteralTestCase float_literal_test_cases[] = { + {"0.0", 0.0f}, // Zero + {"1.0", 1.0f}, // One + {"-1.0", -1.0f}, // MinusOne + {"1000000000.0", 1e9f}, // Billion + {"-0.0", std::copysign(0.0f, -5.0f)}, // NegativeZero + {"0.0", MakeFloat(0, 0, 0)}, // Zero + {"-0.0", MakeFloat(1, 0, 0)}, // NegativeZero + {"1.0", MakeFloat(0, 127, 0)}, // One + {"-1.0", MakeFloat(1, 127, 0)}, // NegativeOne +}; +INSTANTIATE_TEST_SUITE_P(ParserImplFloatLiteralTest_Float, + ParserImplFloatLiteralTest, + testing::ValuesIn(float_literal_test_cases)); + +const float NegInf = MakeFloat(1, 255, 0); +const float PosInf = MakeFloat(0, 255, 0); +FloatLiteralTestCase hexfloat_literal_test_cases[] = { + // Regular numbers + {"0x0p+0", 0.f}, + {"0x1p+0", 1.f}, + {"0x1p+1", 2.f}, + {"0x1.8p+1", 3.f}, + {"0x1.99999ap-4", 0.1f}, + {"0x1p-1", 0.5f}, + {"0x1p-2", 0.25f}, + {"0x1.8p-1", 0.75f}, + {"-0x0p+0", -0.f}, + {"-0x1p+0", -1.f}, + {"-0x1p-1", -0.5f}, + {"-0x1p-2", -0.25f}, + {"-0x1.8p-1", -0.75f}, + + // Large numbers + {"0x1p+9", 512.f}, + {"0x1p+10", 1024.f}, + {"0x1.02p+10", 1024.f + 8.f}, + {"-0x1p+9", -512.f}, + {"-0x1p+10", -1024.f}, + {"-0x1.02p+10", -1024.f - 8.f}, + + // Small numbers + {"0x1p-9", 1.0f / 512.f}, + {"0x1p-10", 1.0f / 1024.f}, + {"0x1.02p-3", 1.0f / 1024.f + 1.0f / 8.f}, + {"-0x1p-9", 1.0f / -512.f}, + {"-0x1p-10", 1.0f / -1024.f}, + {"-0x1.02p-3", 1.0f / -1024.f - 1.0f / 8.f}, + + // Near lowest non-denorm + {"0x1p-124", std::ldexp(1.f * 8.f, -127)}, + {"0x1p-125", std::ldexp(1.f * 4.f, -127)}, + {"-0x1p-124", -std::ldexp(1.f * 8.f, -127)}, + {"-0x1p-125", -std::ldexp(1.f * 4.f, -127)}, + + // Lowest non-denorm + {"0x1p-126", std::ldexp(1.f * 2.f, -127)}, + {"-0x1p-126", -std::ldexp(1.f * 2.f, -127)}, + + // Denormalized values + {"0x1p-127", std::ldexp(1.f, -127)}, + {"0x1p-128", std::ldexp(1.f / 2.f, -127)}, + {"0x1p-129", std::ldexp(1.f / 4.f, -127)}, + {"0x1p-130", std::ldexp(1.f / 8.f, -127)}, + {"-0x1p-127", -std::ldexp(1.f, -127)}, + {"-0x1p-128", -std::ldexp(1.f / 2.f, -127)}, + {"-0x1p-129", -std::ldexp(1.f / 4.f, -127)}, + {"-0x1p-130", -std::ldexp(1.f / 8.f, -127)}, + + {"0x1.8p-127", std::ldexp(1.f, -127) + (std::ldexp(1.f, -127) / 2.f)}, + {"0x1.8p-128", std::ldexp(1.f, -127) / 2.f + (std::ldexp(1.f, -127) / 4.f)}, + + {"0x1p-149", MakeFloat(0, 0, 1)}, // +SmallestDenormal + {"0x1p-148", MakeFloat(0, 0, 2)}, // +BiggerDenormal + {"0x1.fffffcp-127", MakeFloat(0, 0, 0x7fffff)}, // +LargestDenormal + {"-0x1p-149", MakeFloat(1, 0, 1)}, // -SmallestDenormal + {"-0x1p-148", MakeFloat(1, 0, 2)}, // -BiggerDenormal + {"-0x1.fffffcp-127", MakeFloat(1, 0, 0x7fffff)}, // -LargestDenormal + + {"0x1.2bfaf8p-127", MakeFloat(0, 0, 0xcafebe)}, // +Subnormal + {"-0x1.2bfaf8p-127", MakeFloat(1, 0, 0xcafebe)}, // -Subnormal + {"0x1.55554p-130", MakeFloat(0, 0, 0xaaaaa)}, // +Subnormal + {"-0x1.55554p-130", MakeFloat(1, 0, 0xaaaaa)}, // -Subnormal + + // Nan -> Infinity + {"0x1.8p+128", PosInf}, + {"0x1.0002p+128", PosInf}, + {"0x1.0018p+128", PosInf}, + {"0x1.01ep+128", PosInf}, + {"0x1.fffffep+128", PosInf}, + {"-0x1.8p+128", NegInf}, + {"-0x1.0002p+128", NegInf}, + {"-0x1.0018p+128", NegInf}, + {"-0x1.01ep+128", NegInf}, + {"-0x1.fffffep+128", NegInf}, + + // Infinity + {"0x1p+128", PosInf}, + {"-0x1p+128", NegInf}, + {"0x32p+127", PosInf}, + {"0x32p+500", PosInf}, + {"-0x32p+127", NegInf}, + {"-0x32p+500", NegInf}, + + // Overflow -> Infinity + {"0x1p+129", PosInf}, + {"0x1.1p+128", PosInf}, + {"-0x1p+129", NegInf}, + {"-0x1.1p+128", NegInf}, + + // Underflow -> Zero + {"0x1p-500", 0.f}, // Exponent underflows + {"-0x1p-500", -0.f}, + {"0x0.00000000001p-126", 0.f}, // Fraction causes underflow + {"-0x0.0000000001p-127", -0.f}, + {"0x0.01p-142", 0.f}, + {"-0x0.01p-142", -0.f}, // Fraction causes additional underflow + + // Test parsing + {"0x0p0", 0.f}, + {"0x0p-0", 0.f}, + {"0x0p+000", 0.f}, + {"0x00000000000000p+000000000000000", 0.f}, + {"0x00000000000000p-000000000000000", 0.f}, + {"0x00000000000001p+000000000000000", 1.f}, + {"0x00000000000001p-000000000000000", 1.f}, + {"0x0000000000000000000001.99999ap-000000000000000004", 0.1f}, + {"0x2p+0", 2.f}, + {"0xFFp+0", 255.f}, + {"0x0.8p+0", 0.5f}, + {"0x0.4p+0", 0.25f}, + {"0x0.4p+1", 2 * 0.25f}, + {"0x0.4p+2", 4 * 0.25f}, + {"0x123Ep+1", 9340.f}, + {"-0x123Ep+1", -9340.f}, + {"0x1a2b3cP12", 7.024656e+09f}, + {"-0x1a2b3cP12", -7.024656e+09f}, +}; +INSTANTIATE_TEST_SUITE_P(ParserImplFloatLiteralTest_HexFloat, + ParserImplFloatLiteralTest, + testing::ValuesIn(hexfloat_literal_test_cases)); + +TEST_F(ParserImplTest, ConstLiteral_FloatHighest) { + const auto highest = std::numeric_limits::max(); + const auto expected_highest = 340282346638528859811704183484516925440.0f; + if (highest < expected_highest || highest > expected_highest) { + GTEST_SKIP() << "std::numeric_limits::max() is not as expected for " + "this target"; + } + auto p = parser("340282346638528859811704183484516925440.0"); + auto c = p->const_literal(); + EXPECT_TRUE(c.matched); + EXPECT_FALSE(c.errored); + EXPECT_FALSE(p->has_error()) << p->error(); + ASSERT_NE(c.value, nullptr); + ASSERT_TRUE(c->Is()); + EXPECT_FLOAT_EQ(c->As()->value(), + std::numeric_limits::max()); + EXPECT_EQ(c->source().range, (Source::Range{{1u, 1u}, {1u, 42u}})); +} + +TEST_F(ParserImplTest, ConstLiteral_FloatLowest) { + // Some compilers complain if you test floating point numbers for equality. + // So say it via two inequalities. + const auto lowest = std::numeric_limits::lowest(); + const auto expected_lowest = -340282346638528859811704183484516925440.0f; + if (lowest < expected_lowest || lowest > expected_lowest) { + GTEST_SKIP() + << "std::numeric_limits::lowest() is not as expected for " + "this target"; + } + + auto p = parser("-340282346638528859811704183484516925440.0"); + auto c = p->const_literal(); + EXPECT_TRUE(c.matched); + EXPECT_FALSE(c.errored); + EXPECT_FALSE(p->has_error()) << p->error(); + ASSERT_NE(c.value, nullptr); + ASSERT_TRUE(c->Is()); + EXPECT_FLOAT_EQ(c->As()->value(), + std::numeric_limits::lowest()); + EXPECT_EQ(c->source().range, (Source::Range{{1u, 1u}, {1u, 43u}})); +} + TEST_F(ParserImplTest, ConstLiteral_True) { auto p = parser("true"); auto c = p->const_literal(); EXPECT_TRUE(c.matched); EXPECT_FALSE(c.errored); - EXPECT_FALSE(p->has_error()); + EXPECT_FALSE(p->has_error()) << p->error(); ASSERT_NE(c.value, nullptr); ASSERT_TRUE(c->Is()); EXPECT_TRUE(c->As()->IsTrue()); @@ -80,7 +311,7 @@ TEST_F(ParserImplTest, ConstLiteral_False) { auto c = p->const_literal(); EXPECT_TRUE(c.matched); EXPECT_FALSE(c.errored); - EXPECT_FALSE(p->has_error()); + EXPECT_FALSE(p->has_error()) << p->error(); ASSERT_NE(c.value, nullptr); ASSERT_TRUE(c->Is()); EXPECT_TRUE(c->As()->IsFalse()); @@ -92,7 +323,7 @@ TEST_F(ParserImplTest, ConstLiteral_NoMatch) { auto c = p->const_literal(); EXPECT_FALSE(c.matched); EXPECT_FALSE(c.errored); - EXPECT_FALSE(p->has_error()); + EXPECT_FALSE(p->has_error()) << p->error(); ASSERT_EQ(c.value, nullptr); }