From 6559903234da5967dd4ac06b545a9e81dab91dbf Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Tue, 22 Nov 2022 19:53:21 +0000 Subject: [PATCH] tint/resolver: Split constant checking to utility And add basic support for builtins returning structures. Bug tint:1581 Change-Id: I67f987339b9a344e1915c69c9991803f0665305d Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111242 Kokoro: Kokoro Commit-Queue: Ben Clayton Reviewed-by: Antonio Maiorano --- .../resolver/const_eval_binary_op_test.cc | 14 +-- src/tint/resolver/const_eval_builtin_test.cc | 86 ++++++++----------- src/tint/resolver/const_eval_test.h | 84 ++++++++++++++++-- 3 files changed, 112 insertions(+), 72 deletions(-) diff --git a/src/tint/resolver/const_eval_binary_op_test.cc b/src/tint/resolver/const_eval_binary_op_test.cc index bbca0d52d5..70176c7d35 100644 --- a/src/tint/resolver/const_eval_binary_op_test.cc +++ b/src/tint/resolver/const_eval_binary_op_test.cc @@ -106,19 +106,7 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) { ASSERT_NE(value, nullptr); EXPECT_TYPE(value->Type(), sem->Type()); - auto values_flat = ScalarArgsFrom(value); - auto expected_values_flat = expected->Args(); - ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length()); - for (size_t i = 0; i < values_flat.values.Length(); ++i) { - auto& a = values_flat.values[i]; - auto& b = expected_values_flat.values[i]; - EXPECT_EQ(a, b); - if (expected->IsIntegral()) { - // Check that the constant's integer doesn't contain unexpected - // data in the MSBs that are outside of the bit-width of T. - EXPECT_EQ(builder::As(a), builder::As(b)); - } - } + CheckConstant(value, expected); } else { ASSERT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), c.expected.Failure().error); diff --git a/src/tint/resolver/const_eval_builtin_test.cc b/src/tint/resolver/const_eval_builtin_test.cc index 7fbdf30e38..6c6d193fb8 100644 --- a/src/tint/resolver/const_eval_builtin_test.cc +++ b/src/tint/resolver/const_eval_builtin_test.cc @@ -26,8 +26,9 @@ namespace { using resolver::operator<<; struct Case { - Case(utils::VectorRef in_args, Types expected_value) - : args(std::move(in_args)), expected(Success{std::move(expected_value), false, false}) {} + Case(utils::VectorRef in_args, utils::VectorRef expected_values) + : args(std::move(in_args)), + expected(Success{std::move(expected_values), CheckConstantFlags{}}) {} Case(utils::VectorRef in_args, std::string expected_err) : args(std::move(in_args)), expected(Failure{std::move(expected_err)}) {} @@ -35,7 +36,7 @@ struct Case { /// Expected value may be positive or negative Case& PosOrNeg() { Success s = expected.Get(); - s.pos_or_neg = true; + s.flags.pos_or_neg = true; expected = s; return *this; } @@ -43,15 +44,14 @@ struct Case { /// Expected value should be compared using FLOAT_EQ instead of EQ Case& FloatComp() { Success s = expected.Get(); - s.float_compare = true; + s.flags.float_compare = true; expected = s; return *this; } struct Success { - Types value; - bool pos_or_neg = false; - bool float_compare = false; + utils::Vector values; + CheckConstantFlags flags; }; struct Failure { std::string error; @@ -69,7 +69,20 @@ static std::ostream& operator<<(std::ostream& o, const Case& c) { o << "expected: "; if (c.expected) { auto s = c.expected.Get(); - o << s.value << ", pos_or_neg: " << s.pos_or_neg; + if (s.values.Length() == 1) { + o << s.values[0]; + } else { + o << "["; + for (auto& v : s.values) { + if (&v != &s.values[0]) { + o << ", "; + } + o << v; + } + o << "]"; + } + o << ", pos_or_neg: " << s.flags.pos_or_neg; + o << ", float_compare: " << s.flags.float_compare; } else { o << "[ERROR: " << c.expected.Failure().error << "]"; } @@ -80,7 +93,7 @@ using ScalarTypes = std::variant; /// Creates a Case with Values for args and result static Case C(std::initializer_list args, Types result) { - return Case{utils::Vector{args}, std::move(result)}; + return Case{utils::Vector{args}, utils::Vector{std::move(result)}}; } /// Convenience overload that creates a Case with just scalars @@ -91,7 +104,7 @@ static Case C(std::initializer_list sargs, ScalarTypes sresult) { } Types result = Val(0_a); std::visit([&](auto&& v) { result = Val(v); }, sresult); - return Case{std::move(args), std::move(result)}; + return Case{std::move(args), utils::Vector{std::move(result)}}; } /// Creates a Case with Values for args and expected error @@ -127,9 +140,6 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) { if (c.expected) { auto expected_case = c.expected.Get(); - auto* expected_expr = ToValueBase(expected_case.value)->Expr(*this); - GlobalConst("E", expected_expr); - ASSERT_TRUE(r()->Resolve()) << r()->error(); auto* sem = Sem().Get(expr); @@ -138,43 +148,19 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) { ASSERT_NE(value, nullptr); EXPECT_TYPE(value->Type(), sem->Type()); - auto* expected_sem = Sem().Get(expected_expr); - const sem::Constant* expected_value = expected_sem->ConstantValue(); - ASSERT_NE(expected_value, nullptr); - EXPECT_TYPE(expected_value->Type(), expected_sem->Type()); - - // @TODO(amaiorano): Rewrite using ScalarArgsFrom() - ForEachElemPair(value, expected_value, [&](const sem::Constant* a, const sem::Constant* b) { - std::visit( - [&](auto&& ct_expected) { - using T = typename std::decay_t::ElementType; - - auto v = a->As(); - auto e = b->As(); - if constexpr (std::is_same_v) { - EXPECT_EQ(v, e); - } else if constexpr (IsFloatingPoint) { - if (std::isnan(e)) { - EXPECT_TRUE(std::isnan(v)); - } else { - auto vf = (expected_case.pos_or_neg ? Abs(v) : v); - if (expected_case.float_compare) { - EXPECT_FLOAT_EQ(vf, e); - } else { - EXPECT_EQ(vf, e); - } - } - } else { - EXPECT_EQ((expected_case.pos_or_neg ? Abs(v) : v), e); - // Check that the constant's integer doesn't contain unexpected - // data in the MSBs that are outside of the bit-width of T. - EXPECT_EQ(a->As(), b->As()); - } - }, - expected_case.value); - - return HasFailure() ? Action::kStop : Action::kContinue; - }); + if (value->Type()->Is()) { + // The result type of the constant-evaluated expression is a structure. + // Compare each of the fields individually. + for (size_t i = 0; i < expected_case.values.Length(); i++) { + CheckConstant(value->Index(i), ToValueBase(expected_case.values[i]), + expected_case.flags); + } + } else { + // Return type is not a structure. Just compare the single value + ASSERT_EQ(expected_case.values.Length(), 1u) + << "const-eval returned non-struct, but Case expected multiple values"; + CheckConstant(value, ToValueBase(expected_case.values[0]), expected_case.flags); + } } else { EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), c.expected.Failure().error); diff --git a/src/tint/resolver/const_eval_test.h b/src/tint/resolver/const_eval_test.h index cccbf769e4..908d7e2773 100644 --- a/src/tint/resolver/const_eval_test.h +++ b/src/tint/resolver/const_eval_test.h @@ -61,6 +61,81 @@ inline builder::ScalarArgs ScalarArgsFrom(const sem::Constant* c) { return out; } +template +inline auto Abs(const Number& v) { + if constexpr (std::is_integral_v && std::is_unsigned_v) { + return v; + } else { + return Number(std::abs(v)); + } +} + +/// Flags that can be passed to CheckConstant() +struct CheckConstantFlags { + /// Expected value may be positive or negative + bool pos_or_neg = false; + /// Expected value should be compared using FLOAT_EQ instead of EQ + bool float_compare = false; +}; + +/// CheckConstant checks that @p got_constant, the result value of +/// constant-evaluation is equal to @p expected_value. +/// @param got_constant the constant value evaluated by the resolver +/// @param expected_value the expected value for the test +/// @param flags optional flags for controlling the comparisons +inline void CheckConstant(const sem::Constant* got_constant, + const builder::ValueBase* expected_value, + CheckConstantFlags flags = {}) { + auto values_flat = ScalarArgsFrom(got_constant); + auto expected_values_flat = expected_value->Args(); + ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length()); + for (size_t i = 0; i < values_flat.values.Length(); ++i) { + auto& got_scalar = values_flat.values[i]; + auto& expected_scalar = expected_values_flat.values[i]; + std::visit( + [&](auto&& expected) { + using T = std::decay_t; + + ASSERT_TRUE(std::holds_alternative(got_scalar)); + auto got = std::get(got_scalar); + + if constexpr (std::is_same_v) { + EXPECT_EQ(got, expected); + } else if constexpr (IsFloatingPoint) { + if (std::isnan(expected)) { + EXPECT_TRUE(std::isnan(got)); + } else { + if (flags.pos_or_neg) { + auto got_abs = Abs(got); + if (flags.float_compare) { + EXPECT_FLOAT_EQ(got_abs, expected); + } else { + EXPECT_EQ(got_abs, expected); + } + } else { + if (flags.float_compare) { + EXPECT_FLOAT_EQ(got, expected); + } else { + EXPECT_EQ(got, expected); + } + } + } + } else { + if (flags.pos_or_neg) { + auto got_abs = Abs(got); + EXPECT_EQ(got_abs, expected); + } else { + EXPECT_EQ(got, expected); + } + // Check that the constant's integer doesn't contain unexpected + // data in the MSBs that are outside of the bit-width of T. + EXPECT_EQ(AInt(got), AInt(expected)); + } + }, + expected_scalar); + } +} + template inline constexpr auto Negate(const Number& v) { if constexpr (std::is_integral_v) { @@ -85,15 +160,6 @@ inline constexpr auto Negate(const Number& v) { } } -template -inline auto Abs(const Number& v) { - if constexpr (std::is_integral_v && std::is_unsigned_v) { - return v; - } else { - return Number(std::abs(v)); - } -} - TINT_BEGIN_DISABLE_WARNING(CONSTANT_OVERFLOW); template inline constexpr Number Mul(Number v1, Number v2) {