tint/resolver: Split constant checking to utility
And add basic support for builtins returning structures. Bug tint:1581 Change-Id: I67f987339b9a344e1915c69c9991803f0665305d Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111242 Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Ben Clayton <bclayton@google.com> Reviewed-by: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
parent
d6d30f4256
commit
6559903234
|
@ -106,19 +106,7 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) {
|
||||||
ASSERT_NE(value, nullptr);
|
ASSERT_NE(value, nullptr);
|
||||||
EXPECT_TYPE(value->Type(), sem->Type());
|
EXPECT_TYPE(value->Type(), sem->Type());
|
||||||
|
|
||||||
auto values_flat = ScalarArgsFrom(value);
|
CheckConstant(value, expected);
|
||||||
auto expected_values_flat = expected->Args();
|
|
||||||
ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
|
|
||||||
for (size_t i = 0; i < values_flat.values.Length(); ++i) {
|
|
||||||
auto& a = values_flat.values[i];
|
|
||||||
auto& b = expected_values_flat.values[i];
|
|
||||||
EXPECT_EQ(a, b);
|
|
||||||
if (expected->IsIntegral()) {
|
|
||||||
// Check that the constant's integer doesn't contain unexpected
|
|
||||||
// data in the MSBs that are outside of the bit-width of T.
|
|
||||||
EXPECT_EQ(builder::As<AInt>(a), builder::As<AInt>(b));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
ASSERT_FALSE(r()->Resolve());
|
ASSERT_FALSE(r()->Resolve());
|
||||||
EXPECT_EQ(r()->error(), c.expected.Failure().error);
|
EXPECT_EQ(r()->error(), c.expected.Failure().error);
|
||||||
|
|
|
@ -26,8 +26,9 @@ namespace {
|
||||||
using resolver::operator<<;
|
using resolver::operator<<;
|
||||||
|
|
||||||
struct Case {
|
struct Case {
|
||||||
Case(utils::VectorRef<Types> in_args, Types expected_value)
|
Case(utils::VectorRef<Types> in_args, utils::VectorRef<Types> expected_values)
|
||||||
: args(std::move(in_args)), expected(Success{std::move(expected_value), false, false}) {}
|
: args(std::move(in_args)),
|
||||||
|
expected(Success{std::move(expected_values), CheckConstantFlags{}}) {}
|
||||||
|
|
||||||
Case(utils::VectorRef<Types> in_args, std::string expected_err)
|
Case(utils::VectorRef<Types> in_args, std::string expected_err)
|
||||||
: args(std::move(in_args)), expected(Failure{std::move(expected_err)}) {}
|
: args(std::move(in_args)), expected(Failure{std::move(expected_err)}) {}
|
||||||
|
@ -35,7 +36,7 @@ struct Case {
|
||||||
/// Expected value may be positive or negative
|
/// Expected value may be positive or negative
|
||||||
Case& PosOrNeg() {
|
Case& PosOrNeg() {
|
||||||
Success s = expected.Get();
|
Success s = expected.Get();
|
||||||
s.pos_or_neg = true;
|
s.flags.pos_or_neg = true;
|
||||||
expected = s;
|
expected = s;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
@ -43,15 +44,14 @@ struct Case {
|
||||||
/// Expected value should be compared using FLOAT_EQ instead of EQ
|
/// Expected value should be compared using FLOAT_EQ instead of EQ
|
||||||
Case& FloatComp() {
|
Case& FloatComp() {
|
||||||
Success s = expected.Get();
|
Success s = expected.Get();
|
||||||
s.float_compare = true;
|
s.flags.float_compare = true;
|
||||||
expected = s;
|
expected = s;
|
||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Success {
|
struct Success {
|
||||||
Types value;
|
utils::Vector<Types, 2> values;
|
||||||
bool pos_or_neg = false;
|
CheckConstantFlags flags;
|
||||||
bool float_compare = false;
|
|
||||||
};
|
};
|
||||||
struct Failure {
|
struct Failure {
|
||||||
std::string error;
|
std::string error;
|
||||||
|
@ -69,7 +69,20 @@ static std::ostream& operator<<(std::ostream& o, const Case& c) {
|
||||||
o << "expected: ";
|
o << "expected: ";
|
||||||
if (c.expected) {
|
if (c.expected) {
|
||||||
auto s = c.expected.Get();
|
auto s = c.expected.Get();
|
||||||
o << s.value << ", pos_or_neg: " << s.pos_or_neg;
|
if (s.values.Length() == 1) {
|
||||||
|
o << s.values[0];
|
||||||
|
} else {
|
||||||
|
o << "[";
|
||||||
|
for (auto& v : s.values) {
|
||||||
|
if (&v != &s.values[0]) {
|
||||||
|
o << ", ";
|
||||||
|
}
|
||||||
|
o << v;
|
||||||
|
}
|
||||||
|
o << "]";
|
||||||
|
}
|
||||||
|
o << ", pos_or_neg: " << s.flags.pos_or_neg;
|
||||||
|
o << ", float_compare: " << s.flags.float_compare;
|
||||||
} else {
|
} else {
|
||||||
o << "[ERROR: " << c.expected.Failure().error << "]";
|
o << "[ERROR: " << c.expected.Failure().error << "]";
|
||||||
}
|
}
|
||||||
|
@ -80,7 +93,7 @@ using ScalarTypes = std::variant<AInt, AFloat, u32, i32, f32, f16>;
|
||||||
|
|
||||||
/// Creates a Case with Values for args and result
|
/// Creates a Case with Values for args and result
|
||||||
static Case C(std::initializer_list<Types> args, Types result) {
|
static Case C(std::initializer_list<Types> args, Types result) {
|
||||||
return Case{utils::Vector<Types, 8>{args}, std::move(result)};
|
return Case{utils::Vector<Types, 8>{args}, utils::Vector<Types, 2>{std::move(result)}};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convenience overload that creates a Case with just scalars
|
/// Convenience overload that creates a Case with just scalars
|
||||||
|
@ -91,7 +104,7 @@ static Case C(std::initializer_list<ScalarTypes> sargs, ScalarTypes sresult) {
|
||||||
}
|
}
|
||||||
Types result = Val(0_a);
|
Types result = Val(0_a);
|
||||||
std::visit([&](auto&& v) { result = Val(v); }, sresult);
|
std::visit([&](auto&& v) { result = Val(v); }, sresult);
|
||||||
return Case{std::move(args), std::move(result)};
|
return Case{std::move(args), utils::Vector<Types, 2>{std::move(result)}};
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a Case with Values for args and expected error
|
/// Creates a Case with Values for args and expected error
|
||||||
|
@ -127,9 +140,6 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) {
|
||||||
if (c.expected) {
|
if (c.expected) {
|
||||||
auto expected_case = c.expected.Get();
|
auto expected_case = c.expected.Get();
|
||||||
|
|
||||||
auto* expected_expr = ToValueBase(expected_case.value)->Expr(*this);
|
|
||||||
GlobalConst("E", expected_expr);
|
|
||||||
|
|
||||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||||
|
|
||||||
auto* sem = Sem().Get(expr);
|
auto* sem = Sem().Get(expr);
|
||||||
|
@ -138,43 +148,19 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) {
|
||||||
ASSERT_NE(value, nullptr);
|
ASSERT_NE(value, nullptr);
|
||||||
EXPECT_TYPE(value->Type(), sem->Type());
|
EXPECT_TYPE(value->Type(), sem->Type());
|
||||||
|
|
||||||
auto* expected_sem = Sem().Get(expected_expr);
|
if (value->Type()->Is<sem::Struct>()) {
|
||||||
const sem::Constant* expected_value = expected_sem->ConstantValue();
|
// The result type of the constant-evaluated expression is a structure.
|
||||||
ASSERT_NE(expected_value, nullptr);
|
// Compare each of the fields individually.
|
||||||
EXPECT_TYPE(expected_value->Type(), expected_sem->Type());
|
for (size_t i = 0; i < expected_case.values.Length(); i++) {
|
||||||
|
CheckConstant(value->Index(i), ToValueBase(expected_case.values[i]),
|
||||||
// @TODO(amaiorano): Rewrite using ScalarArgsFrom()
|
expected_case.flags);
|
||||||
ForEachElemPair(value, expected_value, [&](const sem::Constant* a, const sem::Constant* b) {
|
}
|
||||||
std::visit(
|
} else {
|
||||||
[&](auto&& ct_expected) {
|
// Return type is not a structure. Just compare the single value
|
||||||
using T = typename std::decay_t<decltype(ct_expected)>::ElementType;
|
ASSERT_EQ(expected_case.values.Length(), 1u)
|
||||||
|
<< "const-eval returned non-struct, but Case expected multiple values";
|
||||||
auto v = a->As<T>();
|
CheckConstant(value, ToValueBase(expected_case.values[0]), expected_case.flags);
|
||||||
auto e = b->As<T>();
|
}
|
||||||
if constexpr (std::is_same_v<bool, T>) {
|
|
||||||
EXPECT_EQ(v, e);
|
|
||||||
} else if constexpr (IsFloatingPoint<T>) {
|
|
||||||
if (std::isnan(e)) {
|
|
||||||
EXPECT_TRUE(std::isnan(v));
|
|
||||||
} else {
|
|
||||||
auto vf = (expected_case.pos_or_neg ? Abs(v) : v);
|
|
||||||
if (expected_case.float_compare) {
|
|
||||||
EXPECT_FLOAT_EQ(vf, e);
|
|
||||||
} else {
|
|
||||||
EXPECT_EQ(vf, e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
EXPECT_EQ((expected_case.pos_or_neg ? Abs(v) : v), e);
|
|
||||||
// Check that the constant's integer doesn't contain unexpected
|
|
||||||
// data in the MSBs that are outside of the bit-width of T.
|
|
||||||
EXPECT_EQ(a->As<AInt>(), b->As<AInt>());
|
|
||||||
}
|
|
||||||
},
|
|
||||||
expected_case.value);
|
|
||||||
|
|
||||||
return HasFailure() ? Action::kStop : Action::kContinue;
|
|
||||||
});
|
|
||||||
} else {
|
} else {
|
||||||
EXPECT_FALSE(r()->Resolve());
|
EXPECT_FALSE(r()->Resolve());
|
||||||
EXPECT_EQ(r()->error(), c.expected.Failure().error);
|
EXPECT_EQ(r()->error(), c.expected.Failure().error);
|
||||||
|
|
|
@ -61,6 +61,81 @@ inline builder::ScalarArgs ScalarArgsFrom(const sem::Constant* c) {
|
||||||
return out;
|
return out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline auto Abs(const Number<T>& v) {
|
||||||
|
if constexpr (std::is_integral_v<T> && std::is_unsigned_v<T>) {
|
||||||
|
return v;
|
||||||
|
} else {
|
||||||
|
return Number<T>(std::abs(v));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Flags that can be passed to CheckConstant()
|
||||||
|
struct CheckConstantFlags {
|
||||||
|
/// Expected value may be positive or negative
|
||||||
|
bool pos_or_neg = false;
|
||||||
|
/// Expected value should be compared using FLOAT_EQ instead of EQ
|
||||||
|
bool float_compare = false;
|
||||||
|
};
|
||||||
|
|
||||||
|
/// CheckConstant checks that @p got_constant, the result value of
|
||||||
|
/// constant-evaluation is equal to @p expected_value.
|
||||||
|
/// @param got_constant the constant value evaluated by the resolver
|
||||||
|
/// @param expected_value the expected value for the test
|
||||||
|
/// @param flags optional flags for controlling the comparisons
|
||||||
|
inline void CheckConstant(const sem::Constant* got_constant,
|
||||||
|
const builder::ValueBase* expected_value,
|
||||||
|
CheckConstantFlags flags = {}) {
|
||||||
|
auto values_flat = ScalarArgsFrom(got_constant);
|
||||||
|
auto expected_values_flat = expected_value->Args();
|
||||||
|
ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
|
||||||
|
for (size_t i = 0; i < values_flat.values.Length(); ++i) {
|
||||||
|
auto& got_scalar = values_flat.values[i];
|
||||||
|
auto& expected_scalar = expected_values_flat.values[i];
|
||||||
|
std::visit(
|
||||||
|
[&](auto&& expected) {
|
||||||
|
using T = std::decay_t<decltype(expected)>;
|
||||||
|
|
||||||
|
ASSERT_TRUE(std::holds_alternative<T>(got_scalar));
|
||||||
|
auto got = std::get<T>(got_scalar);
|
||||||
|
|
||||||
|
if constexpr (std::is_same_v<bool, T>) {
|
||||||
|
EXPECT_EQ(got, expected);
|
||||||
|
} else if constexpr (IsFloatingPoint<T>) {
|
||||||
|
if (std::isnan(expected)) {
|
||||||
|
EXPECT_TRUE(std::isnan(got));
|
||||||
|
} else {
|
||||||
|
if (flags.pos_or_neg) {
|
||||||
|
auto got_abs = Abs(got);
|
||||||
|
if (flags.float_compare) {
|
||||||
|
EXPECT_FLOAT_EQ(got_abs, expected);
|
||||||
|
} else {
|
||||||
|
EXPECT_EQ(got_abs, expected);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (flags.float_compare) {
|
||||||
|
EXPECT_FLOAT_EQ(got, expected);
|
||||||
|
} else {
|
||||||
|
EXPECT_EQ(got, expected);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (flags.pos_or_neg) {
|
||||||
|
auto got_abs = Abs(got);
|
||||||
|
EXPECT_EQ(got_abs, expected);
|
||||||
|
} else {
|
||||||
|
EXPECT_EQ(got, expected);
|
||||||
|
}
|
||||||
|
// Check that the constant's integer doesn't contain unexpected
|
||||||
|
// data in the MSBs that are outside of the bit-width of T.
|
||||||
|
EXPECT_EQ(AInt(got), AInt(expected));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
expected_scalar);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline constexpr auto Negate(const Number<T>& v) {
|
inline constexpr auto Negate(const Number<T>& v) {
|
||||||
if constexpr (std::is_integral_v<T>) {
|
if constexpr (std::is_integral_v<T>) {
|
||||||
|
@ -85,15 +160,6 @@ inline constexpr auto Negate(const Number<T>& v) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
inline auto Abs(const Number<T>& v) {
|
|
||||||
if constexpr (std::is_integral_v<T> && std::is_unsigned_v<T>) {
|
|
||||||
return v;
|
|
||||||
} else {
|
|
||||||
return Number<T>(std::abs(v));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
TINT_BEGIN_DISABLE_WARNING(CONSTANT_OVERFLOW);
|
TINT_BEGIN_DISABLE_WARNING(CONSTANT_OVERFLOW);
|
||||||
template <typename T>
|
template <typename T>
|
||||||
inline constexpr Number<T> Mul(Number<T> v1, Number<T> v2) {
|
inline constexpr Number<T> Mul(Number<T> v1, Number<T> v2) {
|
||||||
|
|
Loading…
Reference in New Issue