tint: add vector cases for const eval builtin tests

- Make the const eval builtin tests use the same framework as the
unary/binary op tests, allowing for Vector cases.
- No longer always use float compare, instead enable it per case.
  Currently this is necessary because atan2 doesn't always return the
  same constant for PI on all platforms.
- Add vector cases for atan2 and clamp.

Bug: tint:1581
Change-Id: I7eaec10b4f9685c913a9d0d17b47c413f659be7a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/104424
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Antonio Maiorano 2022-10-04 22:40:32 +00:00 committed by Dawn LUCI CQ
parent fc7994ba4c
commit 6b3f4aaf26
2 changed files with 133 additions and 65 deletions

View File

@ -73,9 +73,9 @@ template <typename T>
using UnwrapNumber = typename detail::NumberUnwrapper<T>::type; using UnwrapNumber = typename detail::NumberUnwrapper<T>::type;
/// Evaluates to true iff T or Number<T> is a floating-point type or is NumberKindF16. /// Evaluates to true iff T or Number<T> is a floating-point type or is NumberKindF16.
template <typename T, typename U = std::conditional_t<IsNumber<T>, UnwrapNumber<T>, T>> template <typename T>
constexpr bool IsFloatingPoint = constexpr bool IsFloatingPoint = std::is_floating_point_v<UnwrapNumber<T>> ||
std::is_floating_point_v<U> || std::is_same_v<T, detail::NumberKindF16>; std::is_same_v<UnwrapNumber<T>, detail::NumberKindF16>;
/// Evaluates to true iff T or Number<T> is an integral type. /// Evaluates to true iff T or Number<T> is an integral type.
template <typename T> template <typename T>

View File

@ -2981,6 +2981,7 @@ struct BitValues {
// Unary op // Unary op
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
namespace unary_op { namespace unary_op {
// Bring in std::ostream& operator<<(std::ostream& o, const Types& types)
using resolver::operator<<; using resolver::operator<<;
struct Case { struct Case {
@ -3035,7 +3036,7 @@ TEST_P(ResolverConstEvalUnaryOpTest, Test) {
ForEachElemPair(value, expected_value, ForEachElemPair(value, expected_value,
[&](const sem::Constant* a, const sem::Constant* b) { [&](const sem::Constant* a, const sem::Constant* b) {
EXPECT_EQ(a->As<T>(), b->As<T>()); EXPECT_EQ(a->As<T>(), b->As<T>());
if constexpr (IsIntegral<UnwrapNumber<T>>) { if constexpr (IsIntegral<T>) {
// Check that the constant's integer doesn't contain unexpected // Check that the constant's integer doesn't contain unexpected
// data in the MSBs that are outside of the bit-width of T. // data in the MSBs that are outside of the bit-width of T.
EXPECT_EQ(a->As<AInt>(), b->As<AInt>()); EXPECT_EQ(a->As<AInt>(), b->As<AInt>());
@ -3151,6 +3152,7 @@ INSTANTIATE_TEST_SUITE_P(Not,
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
namespace binary_op { namespace binary_op {
// Bring in std::ostream& operator<<(std::ostream& o, const Types& types)
using resolver::operator<<; using resolver::operator<<;
struct Case { struct Case {
@ -3166,7 +3168,7 @@ Case C(Value<T> lhs, Value<U> rhs, Value<V> expected, bool overflow = false) {
return Case{std::move(lhs), std::move(rhs), std::move(expected), overflow}; return Case{std::move(lhs), std::move(rhs), std::move(expected), overflow};
} }
/// Convenience overload to creates a Case with just scalars /// Convenience overload that creates a Case with just scalars
template <typename T, typename U, typename V, typename = std::enable_if_t<!IsValue<T>>> template <typename T, typename U, typename V, typename = std::enable_if_t<!IsValue<T>>>
Case C(T lhs, U rhs, V expected, bool overflow = false) { Case C(T lhs, U rhs, V expected, bool overflow = false) {
return Case{Val(lhs), Val(rhs), Val(expected), overflow}; return Case{Val(lhs), Val(rhs), Val(expected), overflow};
@ -3216,7 +3218,7 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) {
ForEachElemPair(value, expected_value, ForEachElemPair(value, expected_value,
[&](const sem::Constant* a, const sem::Constant* b) { [&](const sem::Constant* a, const sem::Constant* b) {
EXPECT_EQ(a->As<T>(), b->As<T>()); EXPECT_EQ(a->As<T>(), b->As<T>());
if constexpr (IsIntegral<UnwrapNumber<T>>) { if constexpr (IsIntegral<T>) {
// Check that the constant's integer doesn't contain unexpected // Check that the constant's integer doesn't contain unexpected
// data in the MSBs that are outside of the bit-width of T. // data in the MSBs that are outside of the bit-width of T.
EXPECT_EQ(a->As<AInt>(), b->As<AInt>()); EXPECT_EQ(a->As<AInt>(), b->As<AInt>());
@ -3238,7 +3240,7 @@ INSTANTIATE_TEST_SUITE_P(MixedAbstractArgs,
template <typename T> template <typename T>
std::vector<Case> OpAddIntCases() { std::vector<Case> OpAddIntCases() {
static_assert(IsIntegral<UnwrapNumber<T>>); static_assert(IsIntegral<T>);
return { return {
C(T{0}, T{0}, T{0}), C(T{0}, T{0}, T{0}),
C(T{1}, T{2}, T{3}), C(T{1}, T{2}, T{3}),
@ -3251,7 +3253,7 @@ std::vector<Case> OpAddIntCases() {
} }
template <typename T> template <typename T>
std::vector<Case> OpAddFloatCases() { std::vector<Case> OpAddFloatCases() {
static_assert(IsFloatingPoint<UnwrapNumber<T>>); static_assert(IsFloatingPoint<T>);
return { return {
C(T{0}, T{0}, T{0}), C(T{0}, T{0}, T{0}),
C(T{1}, T{2}, T{3}), C(T{1}, T{2}, T{3}),
@ -3275,7 +3277,7 @@ INSTANTIATE_TEST_SUITE_P(Add,
template <typename T> template <typename T>
std::vector<Case> OpSubIntCases() { std::vector<Case> OpSubIntCases() {
static_assert(IsIntegral<UnwrapNumber<T>>); static_assert(IsIntegral<T>);
return { return {
C(T{0}, T{0}, T{0}), C(T{0}, T{0}, T{0}),
C(T{3}, T{2}, T{1}), C(T{3}, T{2}, T{1}),
@ -3288,7 +3290,7 @@ std::vector<Case> OpSubIntCases() {
} }
template <typename T> template <typename T>
std::vector<Case> OpSubFloatCases() { std::vector<Case> OpSubFloatCases() {
static_assert(IsFloatingPoint<UnwrapNumber<T>>); static_assert(IsFloatingPoint<T>);
return { return {
C(T{0}, T{0}, T{0}), C(T{0}, T{0}, T{0}),
C(T{3}, T{2}, T{1}), C(T{3}, T{2}, T{1}),
@ -4051,25 +4053,56 @@ INSTANTIATE_TEST_SUITE_P(Test,
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
namespace builtin { namespace builtin {
// Bring in std::ostream& operator<<(std::ostream& o, const Types& types)
using Types = std::variant<AInt, AFloat, u32, i32, f32, f16>; using resolver::operator<<;
struct Case { struct Case {
Case(utils::VectorRef<Types> in_args, Types in_expected)
: args(std::move(in_args)), expected(std::move(in_expected)) {}
/// Expected value may be positive or negative
Case& PosOrNeg() {
expected_pos_or_neg = true;
return *this;
}
/// Expected value should be compared using FLOAT_EQ instead of EQ
Case& FloatComp() {
float_compare = true;
return *this;
}
utils::Vector<Types, 8> args; utils::Vector<Types, 8> args;
Types result; Types expected;
bool result_pos_or_neg; bool expected_pos_or_neg = false;
bool float_compare = false;
}; };
static std::ostream& operator<<(std::ostream& o, const Case& c) { static std::ostream& operator<<(std::ostream& o, const Case& c) {
o << "args: ";
for (auto& a : c.args) { for (auto& a : c.args) {
std::visit([&](auto&& v) { o << v << ((&a != &c.args.Back()) ? " " : ""); }, a); o << a << ", ";
} }
o << "expected: " << c.expected << ", expected_pos_or_neg: " << c.expected_pos_or_neg;
return o; return o;
} }
template <typename T> /// Creates a Case with Values for args and result
Case C(std::initializer_list<Types> args, T result, bool result_pos_or_neg = false) { // template <typename T>
return Case{std::move(args), std::move(result), result_pos_or_neg}; static Case C(std::initializer_list<Types> args, Types result) {
return Case{utils::Vector<Types, 8>{args}, std::move(result)};
}
/// Convenience overload that creates a Case with just scalars
using ScalarTypes = std::variant<AInt, AFloat, u32, i32, f32, f16>;
static Case C(std::initializer_list<ScalarTypes> sargs, ScalarTypes sresult) {
utils::Vector<Types, 8> args;
for (auto& sa : sargs) {
std::visit([&](auto&& v) { return args.Push(Val(v)); }, sa);
}
Types result = Val(0_a);
std::visit([&](auto&& v) { result = Val(v); }, sresult);
return Case{std::move(args), std::move(result)};
} }
using ResolverConstEvalBuiltinTest = ResolverTestWithParam<std::tuple<sem::BuiltinType, Case>>; using ResolverConstEvalBuiltinTest = ResolverTestWithParam<std::tuple<sem::BuiltinType, Case>>;
@ -4078,19 +4111,21 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) {
Enable(ast::Extension::kF16); Enable(ast::Extension::kF16);
auto builtin = std::get<0>(GetParam()); auto builtin = std::get<0>(GetParam());
auto c = std::get<1>(GetParam()); auto& c = std::get<1>(GetParam());
utils::Vector<const ast::Expression*, 8> args; utils::Vector<const ast::Expression*, 8> args;
for (auto& a : c.args) { for (auto& a : c.args) {
std::visit([&](auto&& v) { args.Push(Expr(v)); }, a); std::visit([&](auto&& v) { args.Push(v.Expr(*this)); }, a);
} }
std::visit( std::visit(
[&](auto&& result) { [&](auto&& expected) {
using T = std::decay_t<decltype(result)>; using T = typename std::decay_t<decltype(expected)>::ElementType;
auto* expr = Call(sem::str(builtin), std::move(args)); auto* expr = Call(sem::str(builtin), std::move(args));
GlobalConst("C", expr); GlobalConst("C", expr);
auto* expected_expr = expected.Expr(*this);
GlobalConst("E", expected_expr);
EXPECT_TRUE(r()->Resolve()) << r()->error(); EXPECT_TRUE(r()->Resolve()) << r()->error();
@ -4099,64 +4134,92 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) {
ASSERT_NE(value, nullptr); ASSERT_NE(value, nullptr);
EXPECT_TYPE(value->Type(), sem->Type()); EXPECT_TYPE(value->Type(), sem->Type());
auto actual = value->As<T>(); auto* expected_sem = Sem().Get(expected_expr);
const sem::Constant* expected_value = expected_sem->ConstantValue();
ASSERT_NE(expected_value, nullptr);
EXPECT_TYPE(expected_value->Type(), expected_sem->Type());
if constexpr (IsFloatingPoint<UnwrapNumber<T>>) { ForEachElemPair(value, expected_value,
if (std::isnan(result)) { [&](const sem::Constant* a, const sem::Constant* b) {
EXPECT_TRUE(std::isnan(actual)); auto v = a->As<T>();
} else { auto e = b->As<T>();
EXPECT_FLOAT_EQ(c.result_pos_or_neg ? Abs(actual) : actual, result); if constexpr (std::is_same_v<bool, T>) {
} EXPECT_EQ(v, e);
} else { } else if constexpr (IsFloatingPoint<T>) {
EXPECT_EQ(c.result_pos_or_neg ? Abs(actual) : actual, result); if (std::isnan(e)) {
} EXPECT_TRUE(std::isnan(v));
} else {
if constexpr (IsIntegral<UnwrapNumber<T>>) { auto vf = (c.expected_pos_or_neg ? Abs(v) : v);
// Check that the constant's integer doesn't contain unexpected data in the MSBs if (c.float_compare) {
// that are outside of the bit-width of T. EXPECT_FLOAT_EQ(vf, e);
EXPECT_EQ(value->As<AInt>(), AInt(result)); } else {
} EXPECT_EQ(vf, e);
}
}
} else {
EXPECT_EQ((c.expected_pos_or_neg ? Abs(v) : v), e);
// Check that the constant's integer doesn't contain unexpected
// data in the MSBs that are outside of the bit-width of T.
EXPECT_EQ(a->As<AInt>(), b->As<AInt>());
}
return HasFailure() ? Action::kStop : Action::kContinue;
});
}, },
c.result); c.expected);
} }
INSTANTIATE_TEST_SUITE_P( //
MixedAbstractArgs,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kAtan2),
testing::ValuesIn(std::vector{
C({0_a, -0.0_a}, kPi<AFloat>),
C({1.0_a, 0_a}, kPiOver2<AFloat>),
})));
template <typename T, bool finite_only> template <typename T, bool finite_only>
std::vector<Case> Atan2Cases() { std::vector<Case> Atan2Cases() {
std::vector<Case> cases = { std::vector<Case> cases = {
// If y is +/-0 and x is negative or -0, +/-PI is returned // If y is +/-0 and x is negative or -0, +/-PI is returned
C({T(0.0), -T(0.0)}, kPi<T>, true), C({T(0.0), -T(0.0)}, kPi<T>).PosOrNeg().FloatComp(),
// If y is +/-0 and x is positive or +0, +/-0 is returned // If y is +/-0 and x is positive or +0, +/-0 is returned
C({T(0.0), T(0.0)}, T(0.0), true), C({T(0.0), T(0.0)}, T(0.0)).PosOrNeg(),
// If x is +/-0 and y is negative, -PI/2 is returned // If x is +/-0 and y is negative, -PI/2 is returned
C({-T(1.0), T(0.0)}, -kPiOver2<T>), C({-T(1.0), T(0.0)}, -kPiOver2<T>).FloatComp(), //
C({-T(1.0), -T(0.0)}, -kPiOver2<T>), C({-T(1.0), -T(0.0)}, -kPiOver2<T>).FloatComp(),
// If x is +/-0 and y is positive, +PI/2 is returned // If x is +/-0 and y is positive, +PI/2 is returned
C({T(1.0), T(0.0)}, kPiOver2<T>), C({T(1.0), T(0.0)}, kPiOver2<T>).FloatComp(), //
C({T(1.0), -T(0.0)}, kPiOver2<T>), C({T(1.0), -T(0.0)}, kPiOver2<T>).FloatComp(),
// Vector tests
C({Vec(T(0.0), T(0.0)), Vec(-T(0.0), T(0.0))}, Vec(kPi<T>, T(0.0))).PosOrNeg().FloatComp(),
C({Vec(-T(1.0), -T(1.0)), Vec(T(0.0), -T(0.0))}, Vec(-kPiOver2<T>, -kPiOver2<T>))
.FloatComp(),
C({Vec(T(1.0), T(1.0)), Vec(T(0.0), -T(0.0))}, Vec(kPiOver2<T>, kPiOver2<T>)).FloatComp(),
}; };
if constexpr (!finite_only) { if constexpr (!finite_only) {
std::vector<Case> non_finite_cases = { std::vector<Case> non_finite_cases = {
// If y is +/-INF and x is finite, +/-PI/2 is returned // If y is +/-INF and x is finite, +/-PI/2 is returned
C({T::Inf(), T(0.0)}, kPiOver2<T>, true), C({T::Inf(), T(0.0)}, kPiOver2<T>).PosOrNeg().FloatComp(),
C({-T::Inf(), T(0.0)}, kPiOver2<T>, true), C({-T::Inf(), T(0.0)}, kPiOver2<T>).PosOrNeg().FloatComp(),
// If y is +/-INF and x is -INF, +/-3PI/4 is returned // If y is +/-INF and x is -INF, +/-3PI/4 is returned
C({T::Inf(), -T::Inf()}, k3PiOver4<T>, true), C({T::Inf(), -T::Inf()}, k3PiOver4<T>).PosOrNeg().FloatComp(),
C({-T::Inf(), -T::Inf()}, k3PiOver4<T>, true), C({-T::Inf(), -T::Inf()}, k3PiOver4<T>).PosOrNeg().FloatComp(),
// If y is +/-INF and x is +INF, +/-PI/4 is returned // If y is +/-INF and x is +INF, +/-PI/4 is returned
C({T::Inf(), T::Inf()}, kPiOver4<T>, true), C({T::Inf(), T::Inf()}, kPiOver4<T>).PosOrNeg().FloatComp(),
C({-T::Inf(), T::Inf()}, kPiOver4<T>, true), C({-T::Inf(), T::Inf()}, kPiOver4<T>).PosOrNeg().FloatComp(),
// If x is -INF and y is finite and positive, +PI is returned // If x is -INF and y is finite and positive, +PI is returned
C({T(0.0), -T::Inf()}, kPi<T>), C({T(0.0), -T::Inf()}, kPi<T>).FloatComp(),
// If x is -INF and y is finite and negative, -PI is returned // If x is -INF and y is finite and negative, -PI is returned
C({-T(0.0), -T::Inf()}, -kPi<T>), C({-T(0.0), -T::Inf()}, -kPi<T>).FloatComp(),
// If x is +INF and y is finite and positive, +0 is returned // If x is +INF and y is finite and positive, +0 is returned
C({T(0.0), T::Inf()}, T(0.0)), C({T(0.0), T::Inf()}, T(0.0)),
@ -4168,23 +4231,19 @@ std::vector<Case> Atan2Cases() {
C({T::NaN(), T(0.0)}, T::NaN()), C({T::NaN(), T(0.0)}, T::NaN()),
C({T(0.0), T::NaN()}, T::NaN()), C({T(0.0), T::NaN()}, T::NaN()),
C({T::NaN(), T::NaN()}, T::NaN()), C({T::NaN(), T::NaN()}, T::NaN()),
};
// Vector tests
C({Vec(T::Inf(), -T::Inf(), T::Inf(), -T::Inf()), //
Vec(T(0.0), T(0.0), -T::Inf(), -T::Inf())}, //
Vec(kPiOver2<T>, kPiOver2<T>, k3PiOver4<T>, k3PiOver4<T>))
.PosOrNeg()
.FloatComp(),
};
cases = Concat(cases, non_finite_cases); cases = Concat(cases, non_finite_cases);
} }
return cases; return cases;
} }
INSTANTIATE_TEST_SUITE_P( //
MixedAbstractArgs,
ResolverConstEvalBuiltinTest,
testing::Combine(testing::Values(sem::BuiltinType::kAtan2),
testing::ValuesIn(std::vector{
C({1_a, 1.0_a}, 0.78539819_a),
C({1.0_a, 1_a}, 0.78539819_a),
})));
INSTANTIATE_TEST_SUITE_P( // INSTANTIATE_TEST_SUITE_P( //
Atan2, Atan2,
ResolverConstEvalBuiltinTest, ResolverConstEvalBuiltinTest,
@ -4205,9 +4264,18 @@ std::vector<Case> ClampCases() {
C({T::Lowest(), T::Lowest(), T::Lowest()}, T::Lowest()), C({T::Lowest(), T::Lowest(), T::Lowest()}, T::Lowest()),
C({T::Highest(), T::Lowest(), T::Highest()}, T::Highest()), C({T::Highest(), T::Lowest(), T::Highest()}, T::Highest()),
C({T::Lowest(), T::Lowest(), T::Highest()}, T::Lowest()), C({T::Lowest(), T::Lowest(), T::Highest()}, T::Lowest()),
// Vector tests
C({Vec(T(0), T(0)), //
Vec(T(0), T(42)), //
Vec(T(0), T::Highest())}, //
Vec(T(0), T(42))), //
C({Vec(T::Lowest(), T(0), T(0)), //
Vec(T(0), T::Lowest(), T::Highest()), //
Vec(T(42), T::Highest(), T::Lowest())}, //
Vec(T(0), T(0), T::Lowest())),
}; };
} }
INSTANTIATE_TEST_SUITE_P( // INSTANTIATE_TEST_SUITE_P( //
Clamp, Clamp,
ResolverConstEvalBuiltinTest, ResolverConstEvalBuiltinTest,