tint/resolver: Split constant checking to utility

And add basic support for builtins returning structures.

Bug tint:1581

Change-Id: I67f987339b9a344e1915c69c9991803f0665305d
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111242
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Ben Clayton 2022-11-22 19:53:21 +00:00 committed by Dawn LUCI CQ
parent d6d30f4256
commit 6559903234
3 changed files with 112 additions and 72 deletions

View File

@ -106,19 +106,7 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) {
ASSERT_NE(value, nullptr); ASSERT_NE(value, nullptr);
EXPECT_TYPE(value->Type(), sem->Type()); EXPECT_TYPE(value->Type(), sem->Type());
auto values_flat = ScalarArgsFrom(value); CheckConstant(value, expected);
auto expected_values_flat = expected->Args();
ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
for (size_t i = 0; i < values_flat.values.Length(); ++i) {
auto& a = values_flat.values[i];
auto& b = expected_values_flat.values[i];
EXPECT_EQ(a, b);
if (expected->IsIntegral()) {
// Check that the constant's integer doesn't contain unexpected
// data in the MSBs that are outside of the bit-width of T.
EXPECT_EQ(builder::As<AInt>(a), builder::As<AInt>(b));
}
}
} else { } else {
ASSERT_FALSE(r()->Resolve()); ASSERT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), c.expected.Failure().error); EXPECT_EQ(r()->error(), c.expected.Failure().error);

View File

@ -26,8 +26,9 @@ namespace {
using resolver::operator<<; using resolver::operator<<;
struct Case { struct Case {
Case(utils::VectorRef<Types> in_args, Types expected_value) Case(utils::VectorRef<Types> in_args, utils::VectorRef<Types> expected_values)
: args(std::move(in_args)), expected(Success{std::move(expected_value), false, false}) {} : args(std::move(in_args)),
expected(Success{std::move(expected_values), CheckConstantFlags{}}) {}
Case(utils::VectorRef<Types> in_args, std::string expected_err) Case(utils::VectorRef<Types> in_args, std::string expected_err)
: args(std::move(in_args)), expected(Failure{std::move(expected_err)}) {} : args(std::move(in_args)), expected(Failure{std::move(expected_err)}) {}
@ -35,7 +36,7 @@ struct Case {
/// Expected value may be positive or negative /// Expected value may be positive or negative
Case& PosOrNeg() { Case& PosOrNeg() {
Success s = expected.Get(); Success s = expected.Get();
s.pos_or_neg = true; s.flags.pos_or_neg = true;
expected = s; expected = s;
return *this; return *this;
} }
@ -43,15 +44,14 @@ struct Case {
/// Expected value should be compared using FLOAT_EQ instead of EQ /// Expected value should be compared using FLOAT_EQ instead of EQ
Case& FloatComp() { Case& FloatComp() {
Success s = expected.Get(); Success s = expected.Get();
s.float_compare = true; s.flags.float_compare = true;
expected = s; expected = s;
return *this; return *this;
} }
struct Success { struct Success {
Types value; utils::Vector<Types, 2> values;
bool pos_or_neg = false; CheckConstantFlags flags;
bool float_compare = false;
}; };
struct Failure { struct Failure {
std::string error; std::string error;
@ -69,7 +69,20 @@ static std::ostream& operator<<(std::ostream& o, const Case& c) {
o << "expected: "; o << "expected: ";
if (c.expected) { if (c.expected) {
auto s = c.expected.Get(); auto s = c.expected.Get();
o << s.value << ", pos_or_neg: " << s.pos_or_neg; if (s.values.Length() == 1) {
o << s.values[0];
} else {
o << "[";
for (auto& v : s.values) {
if (&v != &s.values[0]) {
o << ", ";
}
o << v;
}
o << "]";
}
o << ", pos_or_neg: " << s.flags.pos_or_neg;
o << ", float_compare: " << s.flags.float_compare;
} else { } else {
o << "[ERROR: " << c.expected.Failure().error << "]"; o << "[ERROR: " << c.expected.Failure().error << "]";
} }
@ -80,7 +93,7 @@ using ScalarTypes = std::variant<AInt, AFloat, u32, i32, f32, f16>;
/// Creates a Case with Values for args and result /// Creates a Case with Values for args and result
static Case C(std::initializer_list<Types> args, Types result) { static Case C(std::initializer_list<Types> args, Types result) {
return Case{utils::Vector<Types, 8>{args}, std::move(result)}; return Case{utils::Vector<Types, 8>{args}, utils::Vector<Types, 2>{std::move(result)}};
} }
/// Convenience overload that creates a Case with just scalars /// Convenience overload that creates a Case with just scalars
@ -91,7 +104,7 @@ static Case C(std::initializer_list<ScalarTypes> sargs, ScalarTypes sresult) {
} }
Types result = Val(0_a); Types result = Val(0_a);
std::visit([&](auto&& v) { result = Val(v); }, sresult); std::visit([&](auto&& v) { result = Val(v); }, sresult);
return Case{std::move(args), std::move(result)}; return Case{std::move(args), utils::Vector<Types, 2>{std::move(result)}};
} }
/// Creates a Case with Values for args and expected error /// Creates a Case with Values for args and expected error
@ -127,9 +140,6 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) {
if (c.expected) { if (c.expected) {
auto expected_case = c.expected.Get(); auto expected_case = c.expected.Get();
auto* expected_expr = ToValueBase(expected_case.value)->Expr(*this);
GlobalConst("E", expected_expr);
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr); auto* sem = Sem().Get(expr);
@ -138,43 +148,19 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) {
ASSERT_NE(value, nullptr); ASSERT_NE(value, nullptr);
EXPECT_TYPE(value->Type(), sem->Type()); EXPECT_TYPE(value->Type(), sem->Type());
auto* expected_sem = Sem().Get(expected_expr); if (value->Type()->Is<sem::Struct>()) {
const sem::Constant* expected_value = expected_sem->ConstantValue(); // The result type of the constant-evaluated expression is a structure.
ASSERT_NE(expected_value, nullptr); // Compare each of the fields individually.
EXPECT_TYPE(expected_value->Type(), expected_sem->Type()); for (size_t i = 0; i < expected_case.values.Length(); i++) {
CheckConstant(value->Index(i), ToValueBase(expected_case.values[i]),
// @TODO(amaiorano): Rewrite using ScalarArgsFrom() expected_case.flags);
ForEachElemPair(value, expected_value, [&](const sem::Constant* a, const sem::Constant* b) { }
std::visit( } else {
[&](auto&& ct_expected) { // Return type is not a structure. Just compare the single value
using T = typename std::decay_t<decltype(ct_expected)>::ElementType; ASSERT_EQ(expected_case.values.Length(), 1u)
<< "const-eval returned non-struct, but Case expected multiple values";
auto v = a->As<T>(); CheckConstant(value, ToValueBase(expected_case.values[0]), expected_case.flags);
auto e = b->As<T>(); }
if constexpr (std::is_same_v<bool, T>) {
EXPECT_EQ(v, e);
} else if constexpr (IsFloatingPoint<T>) {
if (std::isnan(e)) {
EXPECT_TRUE(std::isnan(v));
} else {
auto vf = (expected_case.pos_or_neg ? Abs(v) : v);
if (expected_case.float_compare) {
EXPECT_FLOAT_EQ(vf, e);
} else {
EXPECT_EQ(vf, e);
}
}
} else {
EXPECT_EQ((expected_case.pos_or_neg ? Abs(v) : v), e);
// Check that the constant's integer doesn't contain unexpected
// data in the MSBs that are outside of the bit-width of T.
EXPECT_EQ(a->As<AInt>(), b->As<AInt>());
}
},
expected_case.value);
return HasFailure() ? Action::kStop : Action::kContinue;
});
} else { } else {
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), c.expected.Failure().error); EXPECT_EQ(r()->error(), c.expected.Failure().error);

View File

@ -61,6 +61,81 @@ inline builder::ScalarArgs ScalarArgsFrom(const sem::Constant* c) {
return out; return out;
} }
template <typename T>
inline auto Abs(const Number<T>& v) {
if constexpr (std::is_integral_v<T> && std::is_unsigned_v<T>) {
return v;
} else {
return Number<T>(std::abs(v));
}
}
/// Flags that can be passed to CheckConstant()
struct CheckConstantFlags {
/// Expected value may be positive or negative
bool pos_or_neg = false;
/// Expected value should be compared using FLOAT_EQ instead of EQ
bool float_compare = false;
};
/// CheckConstant checks that @p got_constant, the result value of
/// constant-evaluation is equal to @p expected_value.
/// @param got_constant the constant value evaluated by the resolver
/// @param expected_value the expected value for the test
/// @param flags optional flags for controlling the comparisons
inline void CheckConstant(const sem::Constant* got_constant,
const builder::ValueBase* expected_value,
CheckConstantFlags flags = {}) {
auto values_flat = ScalarArgsFrom(got_constant);
auto expected_values_flat = expected_value->Args();
ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
for (size_t i = 0; i < values_flat.values.Length(); ++i) {
auto& got_scalar = values_flat.values[i];
auto& expected_scalar = expected_values_flat.values[i];
std::visit(
[&](auto&& expected) {
using T = std::decay_t<decltype(expected)>;
ASSERT_TRUE(std::holds_alternative<T>(got_scalar));
auto got = std::get<T>(got_scalar);
if constexpr (std::is_same_v<bool, T>) {
EXPECT_EQ(got, expected);
} else if constexpr (IsFloatingPoint<T>) {
if (std::isnan(expected)) {
EXPECT_TRUE(std::isnan(got));
} else {
if (flags.pos_or_neg) {
auto got_abs = Abs(got);
if (flags.float_compare) {
EXPECT_FLOAT_EQ(got_abs, expected);
} else {
EXPECT_EQ(got_abs, expected);
}
} else {
if (flags.float_compare) {
EXPECT_FLOAT_EQ(got, expected);
} else {
EXPECT_EQ(got, expected);
}
}
}
} else {
if (flags.pos_or_neg) {
auto got_abs = Abs(got);
EXPECT_EQ(got_abs, expected);
} else {
EXPECT_EQ(got, expected);
}
// Check that the constant's integer doesn't contain unexpected
// data in the MSBs that are outside of the bit-width of T.
EXPECT_EQ(AInt(got), AInt(expected));
}
},
expected_scalar);
}
}
template <typename T> template <typename T>
inline constexpr auto Negate(const Number<T>& v) { inline constexpr auto Negate(const Number<T>& v) {
if constexpr (std::is_integral_v<T>) { if constexpr (std::is_integral_v<T>) {
@ -85,15 +160,6 @@ inline constexpr auto Negate(const Number<T>& v) {
} }
} }
template <typename T>
inline auto Abs(const Number<T>& v) {
if constexpr (std::is_integral_v<T> && std::is_unsigned_v<T>) {
return v;
} else {
return Number<T>(std::abs(v));
}
}
TINT_BEGIN_DISABLE_WARNING(CONSTANT_OVERFLOW); TINT_BEGIN_DISABLE_WARNING(CONSTANT_OVERFLOW);
template <typename T> template <typename T>
inline constexpr Number<T> Mul(Number<T> v1, Number<T> v2) { inline constexpr Number<T> Mul(Number<T> v1, Number<T> v2) {