tint/resolver: Split constant checking to utility
And add basic support for builtins returning structures. Bug tint:1581 Change-Id: I67f987339b9a344e1915c69c9991803f0665305d Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111242 Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Ben Clayton <bclayton@google.com> Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
parent
d6d30f4256
commit
6559903234
|
@ -106,19 +106,7 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) {
|
|||
ASSERT_NE(value, nullptr);
|
||||
EXPECT_TYPE(value->Type(), sem->Type());
|
||||
|
||||
auto values_flat = ScalarArgsFrom(value);
|
||||
auto expected_values_flat = expected->Args();
|
||||
ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
|
||||
for (size_t i = 0; i < values_flat.values.Length(); ++i) {
|
||||
auto& a = values_flat.values[i];
|
||||
auto& b = expected_values_flat.values[i];
|
||||
EXPECT_EQ(a, b);
|
||||
if (expected->IsIntegral()) {
|
||||
// Check that the constant's integer doesn't contain unexpected
|
||||
// data in the MSBs that are outside of the bit-width of T.
|
||||
EXPECT_EQ(builder::As<AInt>(a), builder::As<AInt>(b));
|
||||
}
|
||||
}
|
||||
CheckConstant(value, expected);
|
||||
} else {
|
||||
ASSERT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), c.expected.Failure().error);
|
||||
|
|
|
@ -26,8 +26,9 @@ namespace {
|
|||
using resolver::operator<<;
|
||||
|
||||
struct Case {
|
||||
Case(utils::VectorRef<Types> in_args, Types expected_value)
|
||||
: args(std::move(in_args)), expected(Success{std::move(expected_value), false, false}) {}
|
||||
Case(utils::VectorRef<Types> in_args, utils::VectorRef<Types> expected_values)
|
||||
: args(std::move(in_args)),
|
||||
expected(Success{std::move(expected_values), CheckConstantFlags{}}) {}
|
||||
|
||||
Case(utils::VectorRef<Types> in_args, std::string expected_err)
|
||||
: args(std::move(in_args)), expected(Failure{std::move(expected_err)}) {}
|
||||
|
@ -35,7 +36,7 @@ struct Case {
|
|||
/// Expected value may be positive or negative
|
||||
Case& PosOrNeg() {
|
||||
Success s = expected.Get();
|
||||
s.pos_or_neg = true;
|
||||
s.flags.pos_or_neg = true;
|
||||
expected = s;
|
||||
return *this;
|
||||
}
|
||||
|
@ -43,15 +44,14 @@ struct Case {
|
|||
/// Expected value should be compared using FLOAT_EQ instead of EQ
|
||||
Case& FloatComp() {
|
||||
Success s = expected.Get();
|
||||
s.float_compare = true;
|
||||
s.flags.float_compare = true;
|
||||
expected = s;
|
||||
return *this;
|
||||
}
|
||||
|
||||
struct Success {
|
||||
Types value;
|
||||
bool pos_or_neg = false;
|
||||
bool float_compare = false;
|
||||
utils::Vector<Types, 2> values;
|
||||
CheckConstantFlags flags;
|
||||
};
|
||||
struct Failure {
|
||||
std::string error;
|
||||
|
@ -69,7 +69,20 @@ static std::ostream& operator<<(std::ostream& o, const Case& c) {
|
|||
o << "expected: ";
|
||||
if (c.expected) {
|
||||
auto s = c.expected.Get();
|
||||
o << s.value << ", pos_or_neg: " << s.pos_or_neg;
|
||||
if (s.values.Length() == 1) {
|
||||
o << s.values[0];
|
||||
} else {
|
||||
o << "[";
|
||||
for (auto& v : s.values) {
|
||||
if (&v != &s.values[0]) {
|
||||
o << ", ";
|
||||
}
|
||||
o << v;
|
||||
}
|
||||
o << "]";
|
||||
}
|
||||
o << ", pos_or_neg: " << s.flags.pos_or_neg;
|
||||
o << ", float_compare: " << s.flags.float_compare;
|
||||
} else {
|
||||
o << "[ERROR: " << c.expected.Failure().error << "]";
|
||||
}
|
||||
|
@ -80,7 +93,7 @@ using ScalarTypes = std::variant<AInt, AFloat, u32, i32, f32, f16>;
|
|||
|
||||
/// Creates a Case with Values for args and result
|
||||
static Case C(std::initializer_list<Types> args, Types result) {
|
||||
return Case{utils::Vector<Types, 8>{args}, std::move(result)};
|
||||
return Case{utils::Vector<Types, 8>{args}, utils::Vector<Types, 2>{std::move(result)}};
|
||||
}
|
||||
|
||||
/// Convenience overload that creates a Case with just scalars
|
||||
|
@ -91,7 +104,7 @@ static Case C(std::initializer_list<ScalarTypes> sargs, ScalarTypes sresult) {
|
|||
}
|
||||
Types result = Val(0_a);
|
||||
std::visit([&](auto&& v) { result = Val(v); }, sresult);
|
||||
return Case{std::move(args), std::move(result)};
|
||||
return Case{std::move(args), utils::Vector<Types, 2>{std::move(result)}};
|
||||
}
|
||||
|
||||
/// Creates a Case with Values for args and expected error
|
||||
|
@ -127,9 +140,6 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) {
|
|||
if (c.expected) {
|
||||
auto expected_case = c.expected.Get();
|
||||
|
||||
auto* expected_expr = ToValueBase(expected_case.value)->Expr(*this);
|
||||
GlobalConst("E", expected_expr);
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
|
||||
auto* sem = Sem().Get(expr);
|
||||
|
@ -138,43 +148,19 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) {
|
|||
ASSERT_NE(value, nullptr);
|
||||
EXPECT_TYPE(value->Type(), sem->Type());
|
||||
|
||||
auto* expected_sem = Sem().Get(expected_expr);
|
||||
const sem::Constant* expected_value = expected_sem->ConstantValue();
|
||||
ASSERT_NE(expected_value, nullptr);
|
||||
EXPECT_TYPE(expected_value->Type(), expected_sem->Type());
|
||||
|
||||
// @TODO(amaiorano): Rewrite using ScalarArgsFrom()
|
||||
ForEachElemPair(value, expected_value, [&](const sem::Constant* a, const sem::Constant* b) {
|
||||
std::visit(
|
||||
[&](auto&& ct_expected) {
|
||||
using T = typename std::decay_t<decltype(ct_expected)>::ElementType;
|
||||
|
||||
auto v = a->As<T>();
|
||||
auto e = b->As<T>();
|
||||
if constexpr (std::is_same_v<bool, T>) {
|
||||
EXPECT_EQ(v, e);
|
||||
} else if constexpr (IsFloatingPoint<T>) {
|
||||
if (std::isnan(e)) {
|
||||
EXPECT_TRUE(std::isnan(v));
|
||||
} else {
|
||||
auto vf = (expected_case.pos_or_neg ? Abs(v) : v);
|
||||
if (expected_case.float_compare) {
|
||||
EXPECT_FLOAT_EQ(vf, e);
|
||||
} else {
|
||||
EXPECT_EQ(vf, e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
EXPECT_EQ((expected_case.pos_or_neg ? Abs(v) : v), e);
|
||||
// Check that the constant's integer doesn't contain unexpected
|
||||
// data in the MSBs that are outside of the bit-width of T.
|
||||
EXPECT_EQ(a->As<AInt>(), b->As<AInt>());
|
||||
}
|
||||
},
|
||||
expected_case.value);
|
||||
|
||||
return HasFailure() ? Action::kStop : Action::kContinue;
|
||||
});
|
||||
if (value->Type()->Is<sem::Struct>()) {
|
||||
// The result type of the constant-evaluated expression is a structure.
|
||||
// Compare each of the fields individually.
|
||||
for (size_t i = 0; i < expected_case.values.Length(); i++) {
|
||||
CheckConstant(value->Index(i), ToValueBase(expected_case.values[i]),
|
||||
expected_case.flags);
|
||||
}
|
||||
} else {
|
||||
// Return type is not a structure. Just compare the single value
|
||||
ASSERT_EQ(expected_case.values.Length(), 1u)
|
||||
<< "const-eval returned non-struct, but Case expected multiple values";
|
||||
CheckConstant(value, ToValueBase(expected_case.values[0]), expected_case.flags);
|
||||
}
|
||||
} else {
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), c.expected.Failure().error);
|
||||
|
|
|
@ -61,6 +61,81 @@ inline builder::ScalarArgs ScalarArgsFrom(const sem::Constant* c) {
|
|||
return out;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline auto Abs(const Number<T>& v) {
|
||||
if constexpr (std::is_integral_v<T> && std::is_unsigned_v<T>) {
|
||||
return v;
|
||||
} else {
|
||||
return Number<T>(std::abs(v));
|
||||
}
|
||||
}
|
||||
|
||||
/// Flags that can be passed to CheckConstant()
|
||||
struct CheckConstantFlags {
|
||||
/// Expected value may be positive or negative
|
||||
bool pos_or_neg = false;
|
||||
/// Expected value should be compared using FLOAT_EQ instead of EQ
|
||||
bool float_compare = false;
|
||||
};
|
||||
|
||||
/// CheckConstant checks that @p got_constant, the result value of
|
||||
/// constant-evaluation is equal to @p expected_value.
|
||||
/// @param got_constant the constant value evaluated by the resolver
|
||||
/// @param expected_value the expected value for the test
|
||||
/// @param flags optional flags for controlling the comparisons
|
||||
inline void CheckConstant(const sem::Constant* got_constant,
|
||||
const builder::ValueBase* expected_value,
|
||||
CheckConstantFlags flags = {}) {
|
||||
auto values_flat = ScalarArgsFrom(got_constant);
|
||||
auto expected_values_flat = expected_value->Args();
|
||||
ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
|
||||
for (size_t i = 0; i < values_flat.values.Length(); ++i) {
|
||||
auto& got_scalar = values_flat.values[i];
|
||||
auto& expected_scalar = expected_values_flat.values[i];
|
||||
std::visit(
|
||||
[&](auto&& expected) {
|
||||
using T = std::decay_t<decltype(expected)>;
|
||||
|
||||
ASSERT_TRUE(std::holds_alternative<T>(got_scalar));
|
||||
auto got = std::get<T>(got_scalar);
|
||||
|
||||
if constexpr (std::is_same_v<bool, T>) {
|
||||
EXPECT_EQ(got, expected);
|
||||
} else if constexpr (IsFloatingPoint<T>) {
|
||||
if (std::isnan(expected)) {
|
||||
EXPECT_TRUE(std::isnan(got));
|
||||
} else {
|
||||
if (flags.pos_or_neg) {
|
||||
auto got_abs = Abs(got);
|
||||
if (flags.float_compare) {
|
||||
EXPECT_FLOAT_EQ(got_abs, expected);
|
||||
} else {
|
||||
EXPECT_EQ(got_abs, expected);
|
||||
}
|
||||
} else {
|
||||
if (flags.float_compare) {
|
||||
EXPECT_FLOAT_EQ(got, expected);
|
||||
} else {
|
||||
EXPECT_EQ(got, expected);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if (flags.pos_or_neg) {
|
||||
auto got_abs = Abs(got);
|
||||
EXPECT_EQ(got_abs, expected);
|
||||
} else {
|
||||
EXPECT_EQ(got, expected);
|
||||
}
|
||||
// Check that the constant's integer doesn't contain unexpected
|
||||
// data in the MSBs that are outside of the bit-width of T.
|
||||
EXPECT_EQ(AInt(got), AInt(expected));
|
||||
}
|
||||
},
|
||||
expected_scalar);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline constexpr auto Negate(const Number<T>& v) {
|
||||
if constexpr (std::is_integral_v<T>) {
|
||||
|
@ -85,15 +160,6 @@ inline constexpr auto Negate(const Number<T>& v) {
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline auto Abs(const Number<T>& v) {
|
||||
if constexpr (std::is_integral_v<T> && std::is_unsigned_v<T>) {
|
||||
return v;
|
||||
} else {
|
||||
return Number<T>(std::abs(v));
|
||||
}
|
||||
}
|
||||
|
||||
TINT_BEGIN_DISABLE_WARNING(CONSTANT_OVERFLOW);
|
||||
template <typename T>
|
||||
inline constexpr Number<T> Mul(Number<T> v1, Number<T> v2) {
|
||||
|
|
Loading…
Reference in New Issue