diff --git a/src/tint/number.h b/src/tint/number.h index 69e598309e..821e486e63 100644 --- a/src/tint/number.h +++ b/src/tint/number.h @@ -73,9 +73,9 @@ template using UnwrapNumber = typename detail::NumberUnwrapper::type; /// Evaluates to true iff T or Number is a floating-point type or is NumberKindF16. -template , UnwrapNumber, T>> -constexpr bool IsFloatingPoint = - std::is_floating_point_v || std::is_same_v; +template +constexpr bool IsFloatingPoint = std::is_floating_point_v> || + std::is_same_v, detail::NumberKindF16>; /// Evaluates to true iff T or Number is an integral type. template diff --git a/src/tint/resolver/const_eval_test.cc b/src/tint/resolver/const_eval_test.cc index 11a7315a52..ef7891aa88 100644 --- a/src/tint/resolver/const_eval_test.cc +++ b/src/tint/resolver/const_eval_test.cc @@ -2981,6 +2981,7 @@ struct BitValues { // Unary op //////////////////////////////////////////////////////////////////////////////////////////////////// namespace unary_op { +// Bring in std::ostream& operator<<(std::ostream& o, const Types& types) using resolver::operator<<; struct Case { @@ -3035,7 +3036,7 @@ TEST_P(ResolverConstEvalUnaryOpTest, Test) { ForEachElemPair(value, expected_value, [&](const sem::Constant* a, const sem::Constant* b) { EXPECT_EQ(a->As(), b->As()); - if constexpr (IsIntegral>) { + if constexpr (IsIntegral) { // Check that the constant's integer doesn't contain unexpected // data in the MSBs that are outside of the bit-width of T. EXPECT_EQ(a->As(), b->As()); @@ -3151,6 +3152,7 @@ INSTANTIATE_TEST_SUITE_P(Not, //////////////////////////////////////////////////////////////////////////////////////////////////// namespace binary_op { +// Bring in std::ostream& operator<<(std::ostream& o, const Types& types) using resolver::operator<<; struct Case { @@ -3166,7 +3168,7 @@ Case C(Value lhs, Value rhs, Value expected, bool overflow = false) { return Case{std::move(lhs), std::move(rhs), std::move(expected), overflow}; } -/// Convenience overload to creates a Case with just scalars +/// Convenience overload that creates a Case with just scalars template >> Case C(T lhs, U rhs, V expected, bool overflow = false) { return Case{Val(lhs), Val(rhs), Val(expected), overflow}; @@ -3216,7 +3218,7 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) { ForEachElemPair(value, expected_value, [&](const sem::Constant* a, const sem::Constant* b) { EXPECT_EQ(a->As(), b->As()); - if constexpr (IsIntegral>) { + if constexpr (IsIntegral) { // Check that the constant's integer doesn't contain unexpected // data in the MSBs that are outside of the bit-width of T. EXPECT_EQ(a->As(), b->As()); @@ -3238,7 +3240,7 @@ INSTANTIATE_TEST_SUITE_P(MixedAbstractArgs, template std::vector OpAddIntCases() { - static_assert(IsIntegral>); + static_assert(IsIntegral); return { C(T{0}, T{0}, T{0}), C(T{1}, T{2}, T{3}), @@ -3251,7 +3253,7 @@ std::vector OpAddIntCases() { } template std::vector OpAddFloatCases() { - static_assert(IsFloatingPoint>); + static_assert(IsFloatingPoint); return { C(T{0}, T{0}, T{0}), C(T{1}, T{2}, T{3}), @@ -3275,7 +3277,7 @@ INSTANTIATE_TEST_SUITE_P(Add, template std::vector OpSubIntCases() { - static_assert(IsIntegral>); + static_assert(IsIntegral); return { C(T{0}, T{0}, T{0}), C(T{3}, T{2}, T{1}), @@ -3288,7 +3290,7 @@ std::vector OpSubIntCases() { } template std::vector OpSubFloatCases() { - static_assert(IsFloatingPoint>); + static_assert(IsFloatingPoint); return { C(T{0}, T{0}, T{0}), C(T{3}, T{2}, T{1}), @@ -4051,25 +4053,56 @@ INSTANTIATE_TEST_SUITE_P(Test, //////////////////////////////////////////////////////////////////////////////////////////////////// namespace builtin { - -using Types = std::variant; +// Bring in std::ostream& operator<<(std::ostream& o, const Types& types) +using resolver::operator<<; struct Case { + Case(utils::VectorRef in_args, Types in_expected) + : args(std::move(in_args)), expected(std::move(in_expected)) {} + + /// Expected value may be positive or negative + Case& PosOrNeg() { + expected_pos_or_neg = true; + return *this; + } + + /// Expected value should be compared using FLOAT_EQ instead of EQ + Case& FloatComp() { + float_compare = true; + return *this; + } + utils::Vector args; - Types result; - bool result_pos_or_neg; + Types expected; + bool expected_pos_or_neg = false; + bool float_compare = false; }; static std::ostream& operator<<(std::ostream& o, const Case& c) { + o << "args: "; for (auto& a : c.args) { - std::visit([&](auto&& v) { o << v << ((&a != &c.args.Back()) ? " " : ""); }, a); + o << a << ", "; } + o << "expected: " << c.expected << ", expected_pos_or_neg: " << c.expected_pos_or_neg; return o; } -template -Case C(std::initializer_list args, T result, bool result_pos_or_neg = false) { - return Case{std::move(args), std::move(result), result_pos_or_neg}; +/// Creates a Case with Values for args and result +// template +static Case C(std::initializer_list args, Types result) { + return Case{utils::Vector{args}, std::move(result)}; +} + +/// Convenience overload that creates a Case with just scalars +using ScalarTypes = std::variant; +static Case C(std::initializer_list sargs, ScalarTypes sresult) { + utils::Vector args; + for (auto& sa : sargs) { + std::visit([&](auto&& v) { return args.Push(Val(v)); }, sa); + } + Types result = Val(0_a); + std::visit([&](auto&& v) { result = Val(v); }, sresult); + return Case{std::move(args), std::move(result)}; } using ResolverConstEvalBuiltinTest = ResolverTestWithParam>; @@ -4078,19 +4111,21 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) { Enable(ast::Extension::kF16); auto builtin = std::get<0>(GetParam()); - auto c = std::get<1>(GetParam()); + auto& c = std::get<1>(GetParam()); utils::Vector args; for (auto& a : c.args) { - std::visit([&](auto&& v) { args.Push(Expr(v)); }, a); + std::visit([&](auto&& v) { args.Push(v.Expr(*this)); }, a); } std::visit( - [&](auto&& result) { - using T = std::decay_t; + [&](auto&& expected) { + using T = typename std::decay_t::ElementType; auto* expr = Call(sem::str(builtin), std::move(args)); GlobalConst("C", expr); + auto* expected_expr = expected.Expr(*this); + GlobalConst("E", expected_expr); EXPECT_TRUE(r()->Resolve()) << r()->error(); @@ -4099,64 +4134,92 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) { ASSERT_NE(value, nullptr); EXPECT_TYPE(value->Type(), sem->Type()); - auto actual = value->As(); + auto* expected_sem = Sem().Get(expected_expr); + const sem::Constant* expected_value = expected_sem->ConstantValue(); + ASSERT_NE(expected_value, nullptr); + EXPECT_TYPE(expected_value->Type(), expected_sem->Type()); - if constexpr (IsFloatingPoint>) { - if (std::isnan(result)) { - EXPECT_TRUE(std::isnan(actual)); - } else { - EXPECT_FLOAT_EQ(c.result_pos_or_neg ? Abs(actual) : actual, result); - } - } else { - EXPECT_EQ(c.result_pos_or_neg ? Abs(actual) : actual, result); - } - - if constexpr (IsIntegral>) { - // Check that the constant's integer doesn't contain unexpected data in the MSBs - // that are outside of the bit-width of T. - EXPECT_EQ(value->As(), AInt(result)); - } + ForEachElemPair(value, expected_value, + [&](const sem::Constant* a, const sem::Constant* b) { + auto v = a->As(); + auto e = b->As(); + if constexpr (std::is_same_v) { + EXPECT_EQ(v, e); + } else if constexpr (IsFloatingPoint) { + if (std::isnan(e)) { + EXPECT_TRUE(std::isnan(v)); + } else { + auto vf = (c.expected_pos_or_neg ? Abs(v) : v); + if (c.float_compare) { + EXPECT_FLOAT_EQ(vf, e); + } else { + EXPECT_EQ(vf, e); + } + } + } else { + EXPECT_EQ((c.expected_pos_or_neg ? Abs(v) : v), e); + // Check that the constant's integer doesn't contain unexpected + // data in the MSBs that are outside of the bit-width of T. + EXPECT_EQ(a->As(), b->As()); + } + return HasFailure() ? Action::kStop : Action::kContinue; + }); }, - c.result); + c.expected); } +INSTANTIATE_TEST_SUITE_P( // + MixedAbstractArgs, + ResolverConstEvalBuiltinTest, + testing::Combine(testing::Values(sem::BuiltinType::kAtan2), + testing::ValuesIn(std::vector{ + C({0_a, -0.0_a}, kPi), + C({1.0_a, 0_a}, kPiOver2), + }))); + template std::vector Atan2Cases() { std::vector cases = { // If y is +/-0 and x is negative or -0, +/-PI is returned - C({T(0.0), -T(0.0)}, kPi, true), + C({T(0.0), -T(0.0)}, kPi).PosOrNeg().FloatComp(), // If y is +/-0 and x is positive or +0, +/-0 is returned - C({T(0.0), T(0.0)}, T(0.0), true), + C({T(0.0), T(0.0)}, T(0.0)).PosOrNeg(), // If x is +/-0 and y is negative, -PI/2 is returned - C({-T(1.0), T(0.0)}, -kPiOver2), - C({-T(1.0), -T(0.0)}, -kPiOver2), + C({-T(1.0), T(0.0)}, -kPiOver2).FloatComp(), // + C({-T(1.0), -T(0.0)}, -kPiOver2).FloatComp(), // If x is +/-0 and y is positive, +PI/2 is returned - C({T(1.0), T(0.0)}, kPiOver2), - C({T(1.0), -T(0.0)}, kPiOver2), + C({T(1.0), T(0.0)}, kPiOver2).FloatComp(), // + C({T(1.0), -T(0.0)}, kPiOver2).FloatComp(), + + // Vector tests + C({Vec(T(0.0), T(0.0)), Vec(-T(0.0), T(0.0))}, Vec(kPi, T(0.0))).PosOrNeg().FloatComp(), + C({Vec(-T(1.0), -T(1.0)), Vec(T(0.0), -T(0.0))}, Vec(-kPiOver2, -kPiOver2)) + .FloatComp(), + C({Vec(T(1.0), T(1.0)), Vec(T(0.0), -T(0.0))}, Vec(kPiOver2, kPiOver2)).FloatComp(), }; if constexpr (!finite_only) { std::vector non_finite_cases = { // If y is +/-INF and x is finite, +/-PI/2 is returned - C({T::Inf(), T(0.0)}, kPiOver2, true), - C({-T::Inf(), T(0.0)}, kPiOver2, true), + C({T::Inf(), T(0.0)}, kPiOver2).PosOrNeg().FloatComp(), + C({-T::Inf(), T(0.0)}, kPiOver2).PosOrNeg().FloatComp(), // If y is +/-INF and x is -INF, +/-3PI/4 is returned - C({T::Inf(), -T::Inf()}, k3PiOver4, true), - C({-T::Inf(), -T::Inf()}, k3PiOver4, true), + C({T::Inf(), -T::Inf()}, k3PiOver4).PosOrNeg().FloatComp(), + C({-T::Inf(), -T::Inf()}, k3PiOver4).PosOrNeg().FloatComp(), // If y is +/-INF and x is +INF, +/-PI/4 is returned - C({T::Inf(), T::Inf()}, kPiOver4, true), - C({-T::Inf(), T::Inf()}, kPiOver4, true), + C({T::Inf(), T::Inf()}, kPiOver4).PosOrNeg().FloatComp(), + C({-T::Inf(), T::Inf()}, kPiOver4).PosOrNeg().FloatComp(), // If x is -INF and y is finite and positive, +PI is returned - C({T(0.0), -T::Inf()}, kPi), + C({T(0.0), -T::Inf()}, kPi).FloatComp(), // If x is -INF and y is finite and negative, -PI is returned - C({-T(0.0), -T::Inf()}, -kPi), + C({-T(0.0), -T::Inf()}, -kPi).FloatComp(), // If x is +INF and y is finite and positive, +0 is returned C({T(0.0), T::Inf()}, T(0.0)), @@ -4168,23 +4231,19 @@ std::vector Atan2Cases() { C({T::NaN(), T(0.0)}, T::NaN()), C({T(0.0), T::NaN()}, T::NaN()), C({T::NaN(), T::NaN()}, T::NaN()), - }; + // Vector tests + C({Vec(T::Inf(), -T::Inf(), T::Inf(), -T::Inf()), // + Vec(T(0.0), T(0.0), -T::Inf(), -T::Inf())}, // + Vec(kPiOver2, kPiOver2, k3PiOver4, k3PiOver4)) + .PosOrNeg() + .FloatComp(), + }; cases = Concat(cases, non_finite_cases); } return cases; } - -INSTANTIATE_TEST_SUITE_P( // - MixedAbstractArgs, - ResolverConstEvalBuiltinTest, - testing::Combine(testing::Values(sem::BuiltinType::kAtan2), - testing::ValuesIn(std::vector{ - C({1_a, 1.0_a}, 0.78539819_a), - C({1.0_a, 1_a}, 0.78539819_a), - }))); - INSTANTIATE_TEST_SUITE_P( // Atan2, ResolverConstEvalBuiltinTest, @@ -4205,9 +4264,18 @@ std::vector ClampCases() { C({T::Lowest(), T::Lowest(), T::Lowest()}, T::Lowest()), C({T::Highest(), T::Lowest(), T::Highest()}, T::Highest()), C({T::Lowest(), T::Lowest(), T::Highest()}, T::Lowest()), + + // Vector tests + C({Vec(T(0), T(0)), // + Vec(T(0), T(42)), // + Vec(T(0), T::Highest())}, // + Vec(T(0), T(42))), // + C({Vec(T::Lowest(), T(0), T(0)), // + Vec(T(0), T::Lowest(), T::Highest()), // + Vec(T(42), T::Highest(), T::Lowest())}, // + Vec(T(0), T(0), T::Lowest())), }; } - INSTANTIATE_TEST_SUITE_P( // Clamp, ResolverConstEvalBuiltinTest,