tint: optimize compile time for const_eval_*_test files
The reason for slow compile times is because the very large variants of builder::Value<T>s combined with the many std::visits over these variants result in many combinatorial instantiations of the visit callbacks. To address this, I added a polymorphic base class ValueBase to Value<T>, and replaced most of the std::visit-based compile time code with runtime virtual calls. For the two heaviest users of std::visit over the large variants, compiles times dropped more than half (clang-10, debug): const_eval_binary_op_test.cc: 19.079s to 7.736s const_eval_unary_op_test.cc: 10.021s to 4.789s Bug: tint:1711 Change-Id: Iba05e6ae1004ef0814250e2a8ea50aa2b26b85f2 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/105782 Reviewed-by: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
parent
3fd42ae042
commit
29fb8f8eef
|
@ -54,47 +54,39 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) {
|
||||||
auto op = std::get<0>(GetParam());
|
auto op = std::get<0>(GetParam());
|
||||||
auto& c = std::get<1>(GetParam());
|
auto& c = std::get<1>(GetParam());
|
||||||
|
|
||||||
std::visit(
|
auto* expected = ToValueBase(c.expected);
|
||||||
[&](auto&& expected) {
|
if (expected->IsAbstract() && c.overflow) {
|
||||||
using T = typename std::decay_t<decltype(expected)>::ElementType;
|
// Overflow is not allowed for abstract types. This is tested separately.
|
||||||
if constexpr (std::is_same_v<T, AInt> || std::is_same_v<T, AFloat>) {
|
return;
|
||||||
if (c.overflow) {
|
}
|
||||||
// Overflow is not allowed for abstract types. This is tested separately.
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
auto* lhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.lhs);
|
auto* lhs = ToValueBase(c.lhs);
|
||||||
auto* rhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.rhs);
|
auto* rhs = ToValueBase(c.rhs);
|
||||||
auto* expr = create<ast::BinaryExpression>(op, lhs_expr, rhs_expr);
|
|
||||||
|
|
||||||
GlobalConst("C", expr);
|
auto* lhs_expr = lhs->Expr(*this);
|
||||||
auto* expected_expr = expected.Expr(*this);
|
auto* rhs_expr = rhs->Expr(*this);
|
||||||
GlobalConst("E", expected_expr);
|
auto* expr = create<ast::BinaryExpression>(op, lhs_expr, rhs_expr);
|
||||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
GlobalConst("C", expr);
|
||||||
|
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||||
|
|
||||||
auto* sem = Sem().Get(expr);
|
auto* sem = Sem().Get(expr);
|
||||||
const sem::Constant* value = sem->ConstantValue();
|
const sem::Constant* value = sem->ConstantValue();
|
||||||
ASSERT_NE(value, nullptr);
|
ASSERT_NE(value, nullptr);
|
||||||
EXPECT_TYPE(value->Type(), sem->Type());
|
EXPECT_TYPE(value->Type(), sem->Type());
|
||||||
|
|
||||||
auto* expected_sem = Sem().Get(expected_expr);
|
auto values_flat = ScalarArgsFrom(value);
|
||||||
const sem::Constant* expected_value = expected_sem->ConstantValue();
|
auto expected_values_flat = expected->Args();
|
||||||
ASSERT_NE(expected_value, nullptr);
|
ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
|
||||||
EXPECT_TYPE(expected_value->Type(), expected_sem->Type());
|
for (size_t i = 0; i < values_flat.values.Length(); ++i) {
|
||||||
|
auto& a = values_flat.values[i];
|
||||||
ForEachElemPair(value, expected_value,
|
auto& b = expected_values_flat.values[i];
|
||||||
[&](const sem::Constant* a, const sem::Constant* b) {
|
EXPECT_EQ(a, b);
|
||||||
EXPECT_EQ(a->As<T>(), b->As<T>());
|
if (expected->IsIntegral()) {
|
||||||
if constexpr (IsIntegral<T>) {
|
// Check that the constant's integer doesn't contain unexpected
|
||||||
// Check that the constant's integer doesn't contain unexpected
|
// data in the MSBs that are outside of the bit-width of T.
|
||||||
// data in the MSBs that are outside of the bit-width of T.
|
EXPECT_EQ(builder::As<AInt>(a), builder::As<AInt>(b));
|
||||||
EXPECT_EQ(a->As<AInt>(), b->As<AInt>());
|
}
|
||||||
}
|
}
|
||||||
return HasFailure() ? Action::kStop : Action::kContinue;
|
|
||||||
});
|
|
||||||
},
|
|
||||||
c.expected);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(MixedAbstractArgs,
|
INSTANTIATE_TEST_SUITE_P(MixedAbstractArgs,
|
||||||
|
@ -658,21 +650,15 @@ using ResolverConstEvalBinaryOpTest_Overflow = ResolverTestWithParam<OverflowCas
|
||||||
TEST_P(ResolverConstEvalBinaryOpTest_Overflow, Test) {
|
TEST_P(ResolverConstEvalBinaryOpTest_Overflow, Test) {
|
||||||
Enable(ast::Extension::kF16);
|
Enable(ast::Extension::kF16);
|
||||||
auto& c = GetParam();
|
auto& c = GetParam();
|
||||||
auto* lhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.lhs);
|
auto* lhs = ToValueBase(c.lhs);
|
||||||
auto* rhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.rhs);
|
auto* rhs = ToValueBase(c.rhs);
|
||||||
|
auto* lhs_expr = lhs->Expr(*this);
|
||||||
|
auto* rhs_expr = rhs->Expr(*this);
|
||||||
auto* expr = create<ast::BinaryExpression>(Source{{1, 1}}, c.op, lhs_expr, rhs_expr);
|
auto* expr = create<ast::BinaryExpression>(Source{{1, 1}}, c.op, lhs_expr, rhs_expr);
|
||||||
GlobalConst("C", expr);
|
GlobalConst("C", expr);
|
||||||
ASSERT_FALSE(r()->Resolve());
|
ASSERT_FALSE(r()->Resolve());
|
||||||
|
|
||||||
std::string type_name = std::visit(
|
|
||||||
[&](auto&& value) {
|
|
||||||
using ValueType = std::decay_t<decltype(value)>;
|
|
||||||
return builder::FriendlyName<ValueType>();
|
|
||||||
},
|
|
||||||
c.lhs);
|
|
||||||
|
|
||||||
EXPECT_THAT(r()->error(), HasSubstr("1:1 error: '"));
|
EXPECT_THAT(r()->error(), HasSubstr("1:1 error: '"));
|
||||||
EXPECT_THAT(r()->error(), HasSubstr("' cannot be represented as '" + type_name + "'"));
|
EXPECT_THAT(r()->error(), HasSubstr("' cannot be represented as '" + lhs->TypeName() + "'"));
|
||||||
}
|
}
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
Test,
|
Test,
|
||||||
|
@ -854,10 +840,8 @@ TEST_F(ResolverConstEvalTest, BinaryAbstractShiftLeftByNegativeValue_Error) {
|
||||||
using ResolverConstEvalShiftLeftConcreteGeqBitWidthError =
|
using ResolverConstEvalShiftLeftConcreteGeqBitWidthError =
|
||||||
ResolverTestWithParam<std::tuple<Types, Types>>;
|
ResolverTestWithParam<std::tuple<Types, Types>>;
|
||||||
TEST_P(ResolverConstEvalShiftLeftConcreteGeqBitWidthError, Test) {
|
TEST_P(ResolverConstEvalShiftLeftConcreteGeqBitWidthError, Test) {
|
||||||
auto* lhs_expr =
|
auto* lhs_expr = ToValueBase(std::get<0>(GetParam()))->Expr(*this);
|
||||||
std::visit([&](auto&& value) { return value.Expr(*this); }, std::get<0>(GetParam()));
|
auto* rhs_expr = ToValueBase(std::get<1>(GetParam()))->Expr(*this);
|
||||||
auto* rhs_expr =
|
|
||||||
std::visit([&](auto&& value) { return value.Expr(*this); }, std::get<1>(GetParam()));
|
|
||||||
GlobalConst("c", Shl(Source{{1, 1}}, lhs_expr, rhs_expr));
|
GlobalConst("c", Shl(Source{{1, 1}}, lhs_expr, rhs_expr));
|
||||||
EXPECT_FALSE(r()->Resolve());
|
EXPECT_FALSE(r()->Resolve());
|
||||||
EXPECT_EQ(
|
EXPECT_EQ(
|
||||||
|
@ -880,10 +864,8 @@ INSTANTIATE_TEST_SUITE_P(Test,
|
||||||
// AInt left shift results in sign change error
|
// AInt left shift results in sign change error
|
||||||
using ResolverConstEvalShiftLeftSignChangeError = ResolverTestWithParam<std::tuple<Types, Types>>;
|
using ResolverConstEvalShiftLeftSignChangeError = ResolverTestWithParam<std::tuple<Types, Types>>;
|
||||||
TEST_P(ResolverConstEvalShiftLeftSignChangeError, Test) {
|
TEST_P(ResolverConstEvalShiftLeftSignChangeError, Test) {
|
||||||
auto* lhs_expr =
|
auto* lhs_expr = ToValueBase(std::get<0>(GetParam()))->Expr(*this);
|
||||||
std::visit([&](auto&& value) { return value.Expr(*this); }, std::get<0>(GetParam()));
|
auto* rhs_expr = ToValueBase(std::get<1>(GetParam()))->Expr(*this);
|
||||||
auto* rhs_expr =
|
|
||||||
std::visit([&](auto&& value) { return value.Expr(*this); }, std::get<1>(GetParam()));
|
|
||||||
GlobalConst("c", Shl(Source{{1, 1}}, lhs_expr, rhs_expr));
|
GlobalConst("c", Shl(Source{{1, 1}}, lhs_expr, rhs_expr));
|
||||||
EXPECT_FALSE(r()->Resolve());
|
EXPECT_FALSE(r()->Resolve());
|
||||||
EXPECT_EQ(r()->error(), "1:1 error: shift left operation results in sign change");
|
EXPECT_EQ(r()->error(), "1:1 error: shift left operation results in sign change");
|
||||||
|
|
|
@ -83,54 +83,57 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) {
|
||||||
std::visit([&](auto&& v) { args.Push(v.Expr(*this)); }, a);
|
std::visit([&](auto&& v) { args.Push(v.Expr(*this)); }, a);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::visit(
|
auto* expected = ToValueBase(c.expected);
|
||||||
[&](auto&& expected) {
|
auto* expr = Call(sem::str(builtin), std::move(args));
|
||||||
using T = typename std::decay_t<decltype(expected)>::ElementType;
|
|
||||||
auto* expr = Call(sem::str(builtin), std::move(args));
|
|
||||||
|
|
||||||
GlobalConst("C", expr);
|
GlobalConst("C", expr);
|
||||||
auto* expected_expr = expected.Expr(*this);
|
auto* expected_expr = expected->Expr(*this);
|
||||||
GlobalConst("E", expected_expr);
|
GlobalConst("E", expected_expr);
|
||||||
|
|
||||||
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
EXPECT_TRUE(r()->Resolve()) << r()->error();
|
||||||
|
|
||||||
auto* sem = Sem().Get(expr);
|
auto* sem = Sem().Get(expr);
|
||||||
const sem::Constant* value = sem->ConstantValue();
|
const sem::Constant* value = sem->ConstantValue();
|
||||||
ASSERT_NE(value, nullptr);
|
ASSERT_NE(value, nullptr);
|
||||||
EXPECT_TYPE(value->Type(), sem->Type());
|
EXPECT_TYPE(value->Type(), sem->Type());
|
||||||
|
|
||||||
auto* expected_sem = Sem().Get(expected_expr);
|
auto* expected_sem = Sem().Get(expected_expr);
|
||||||
const sem::Constant* expected_value = expected_sem->ConstantValue();
|
const sem::Constant* expected_value = expected_sem->ConstantValue();
|
||||||
ASSERT_NE(expected_value, nullptr);
|
ASSERT_NE(expected_value, nullptr);
|
||||||
EXPECT_TYPE(expected_value->Type(), expected_sem->Type());
|
EXPECT_TYPE(expected_value->Type(), expected_sem->Type());
|
||||||
|
|
||||||
ForEachElemPair(value, expected_value,
|
// @TODO(amaiorano): Rewrite using ScalarArgsFrom()
|
||||||
[&](const sem::Constant* a, const sem::Constant* b) {
|
ForEachElemPair(value, expected_value, [&](const sem::Constant* a, const sem::Constant* b) {
|
||||||
auto v = a->As<T>();
|
std::visit(
|
||||||
auto e = b->As<T>();
|
[&](auto&& ct_expected) {
|
||||||
if constexpr (std::is_same_v<bool, T>) {
|
using T = typename std::decay_t<decltype(ct_expected)>::ElementType;
|
||||||
EXPECT_EQ(v, e);
|
|
||||||
} else if constexpr (IsFloatingPoint<T>) {
|
auto v = a->As<T>();
|
||||||
if (std::isnan(e)) {
|
auto e = b->As<T>();
|
||||||
EXPECT_TRUE(std::isnan(v));
|
if constexpr (std::is_same_v<bool, T>) {
|
||||||
} else {
|
EXPECT_EQ(v, e);
|
||||||
auto vf = (c.expected_pos_or_neg ? Abs(v) : v);
|
} else if constexpr (IsFloatingPoint<T>) {
|
||||||
if (c.float_compare) {
|
if (std::isnan(e)) {
|
||||||
EXPECT_FLOAT_EQ(vf, e);
|
EXPECT_TRUE(std::isnan(v));
|
||||||
} else {
|
} else {
|
||||||
EXPECT_EQ(vf, e);
|
auto vf = (c.expected_pos_or_neg ? Abs(v) : v);
|
||||||
}
|
if (c.float_compare) {
|
||||||
}
|
EXPECT_FLOAT_EQ(vf, e);
|
||||||
} else {
|
} else {
|
||||||
EXPECT_EQ((c.expected_pos_or_neg ? Abs(v) : v), e);
|
EXPECT_EQ(vf, e);
|
||||||
// Check that the constant's integer doesn't contain unexpected
|
}
|
||||||
// data in the MSBs that are outside of the bit-width of T.
|
}
|
||||||
EXPECT_EQ(a->As<AInt>(), b->As<AInt>());
|
} else {
|
||||||
}
|
EXPECT_EQ((c.expected_pos_or_neg ? Abs(v) : v), e);
|
||||||
return HasFailure() ? Action::kStop : Action::kContinue;
|
// Check that the constant's integer doesn't contain unexpected
|
||||||
});
|
// data in the MSBs that are outside of the bit-width of T.
|
||||||
},
|
EXPECT_EQ(a->As<AInt>(), b->As<AInt>());
|
||||||
c.expected);
|
}
|
||||||
|
},
|
||||||
|
c.expected);
|
||||||
|
|
||||||
|
return HasFailure() ? Action::kStop : Action::kContinue;
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P( //
|
INSTANTIATE_TEST_SUITE_P( //
|
||||||
|
|
|
@ -29,20 +29,7 @@ using Scalar = std::variant< //
|
||||||
builder::Value<bool>>;
|
builder::Value<bool>>;
|
||||||
|
|
||||||
static std::ostream& operator<<(std::ostream& o, const Scalar& scalar) {
|
static std::ostream& operator<<(std::ostream& o, const Scalar& scalar) {
|
||||||
std::visit(
|
return ToValueBase(scalar)->Print(o);
|
||||||
[&](auto&& v) {
|
|
||||||
using ValueType = std::decay_t<decltype(v)>;
|
|
||||||
o << ValueType::DataType::Name() << "(";
|
|
||||||
for (auto& a : v.args.values) {
|
|
||||||
o << std::get<typename ValueType::ElementType>(a);
|
|
||||||
if (&a != &v.args.values.Back()) {
|
|
||||||
o << ", ";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
o << ")";
|
|
||||||
},
|
|
||||||
scalar);
|
|
||||||
return o;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
enum class Kind {
|
enum class Kind {
|
||||||
|
@ -96,7 +83,7 @@ TEST_P(ResolverConstEvalConvTest, Test) {
|
||||||
const auto& type = std::get<1>(GetParam()).type;
|
const auto& type = std::get<1>(GetParam()).type;
|
||||||
const auto unrepresentable = std::get<1>(GetParam()).unrepresentable;
|
const auto unrepresentable = std::get<1>(GetParam()).unrepresentable;
|
||||||
|
|
||||||
auto* input_val = std::visit([&](auto val) { return val.Expr(*this); }, input);
|
auto* input_val = ToValueBase(input)->Expr(*this);
|
||||||
auto* expr = Construct(type.ast(*this), input_val);
|
auto* expr = Construct(type.ast(*this), input_val);
|
||||||
if (kind == Kind::kVector) {
|
if (kind == Kind::kVector) {
|
||||||
expr = Construct(ty.vec(nullptr, 3), expr);
|
expr = Construct(ty.vec(nullptr, 3), expr);
|
||||||
|
@ -120,7 +107,7 @@ TEST_P(ResolverConstEvalConvTest, Test) {
|
||||||
ASSERT_NE(sem->ConstantValue(), nullptr);
|
ASSERT_NE(sem->ConstantValue(), nullptr);
|
||||||
EXPECT_TYPE(sem->ConstantValue()->Type(), target_sem_ty);
|
EXPECT_TYPE(sem->ConstantValue()->Type(), target_sem_ty);
|
||||||
|
|
||||||
auto expected_values = std::visit([&](auto&& val) { return val.args; }, expected);
|
auto expected_values = ToValueBase(expected)->Args();
|
||||||
if (kind == Kind::kVector) {
|
if (kind == Kind::kVector) {
|
||||||
expected_values.values.Push(expected_values.values[0]);
|
expected_values.values.Push(expected_values.values[0]);
|
||||||
expected_values.values.Push(expected_values.values[0]);
|
expected_values.values.Push(expected_values.values[0]);
|
||||||
|
|
|
@ -41,6 +41,8 @@ inline const auto k3PiOver4 = T(UnwrapNumber<T>(2.356194490192344928846));
|
||||||
inline void CollectScalarArgs(const sem::Constant* c, builder::ScalarArgs& args) {
|
inline void CollectScalarArgs(const sem::Constant* c, builder::ScalarArgs& args) {
|
||||||
Switch(
|
Switch(
|
||||||
c->Type(), //
|
c->Type(), //
|
||||||
|
[&](const sem::AbstractInt*) { args.values.Push(c->As<AInt>()); },
|
||||||
|
[&](const sem::AbstractFloat*) { args.values.Push(c->As<AFloat>()); },
|
||||||
[&](const sem::Bool*) { args.values.Push(c->As<bool>()); },
|
[&](const sem::Bool*) { args.values.Push(c->As<bool>()); },
|
||||||
[&](const sem::I32*) { args.values.Push(c->As<i32>()); },
|
[&](const sem::I32*) { args.values.Push(c->As<i32>()); },
|
||||||
[&](const sem::U32*) { args.values.Push(c->As<u32>()); },
|
[&](const sem::U32*) { args.values.Push(c->As<u32>()); },
|
||||||
|
@ -136,6 +138,7 @@ using builder::IsValue;
|
||||||
using builder::Mat;
|
using builder::Mat;
|
||||||
using builder::Val;
|
using builder::Val;
|
||||||
using builder::Value;
|
using builder::Value;
|
||||||
|
using builder::ValueBase;
|
||||||
using builder::Vec;
|
using builder::Vec;
|
||||||
|
|
||||||
using Types = std::variant< //
|
using Types = std::variant< //
|
||||||
|
@ -188,21 +191,18 @@ using Types = std::variant< //
|
||||||
//
|
//
|
||||||
>;
|
>;
|
||||||
|
|
||||||
|
/// Returns the current Value<T> in the `types` variant as a `ValueBase` pointer to use the
|
||||||
|
/// polymorphic API. This trades longer compile times using std::variant for longer runtime via
|
||||||
|
/// virtual function calls.
|
||||||
|
template <typename ValueVariant>
|
||||||
|
inline const ValueBase* ToValueBase(const ValueVariant& types) {
|
||||||
|
return std::visit(
|
||||||
|
[](auto&& t) -> const ValueBase* { return static_cast<const ValueBase*>(&t); }, types);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Prints Types to ostream
|
||||||
inline std::ostream& operator<<(std::ostream& o, const Types& types) {
|
inline std::ostream& operator<<(std::ostream& o, const Types& types) {
|
||||||
std::visit(
|
return ToValueBase(types)->Print(o);
|
||||||
[&](auto&& v) {
|
|
||||||
using ValueType = std::decay_t<decltype(v)>;
|
|
||||||
o << ValueType::DataType::Name() << "(";
|
|
||||||
for (auto& a : v.args.values) {
|
|
||||||
o << std::get<typename ValueType::ElementType>(a);
|
|
||||||
if (&a != &v.args.values.Back()) {
|
|
||||||
o << ", ";
|
|
||||||
}
|
|
||||||
}
|
|
||||||
o << ")";
|
|
||||||
},
|
|
||||||
types);
|
|
||||||
return o;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calls `f` on deepest elements of both `a` and `b`. If function returns Action::kStop, it stops
|
// Calls `f` on deepest elements of both `a` and `b`. If function returns Action::kStop, it stops
|
||||||
|
|
|
@ -51,40 +51,34 @@ TEST_P(ResolverConstEvalUnaryOpTest, Test) {
|
||||||
|
|
||||||
auto op = std::get<0>(GetParam());
|
auto op = std::get<0>(GetParam());
|
||||||
auto& c = std::get<1>(GetParam());
|
auto& c = std::get<1>(GetParam());
|
||||||
std::visit(
|
|
||||||
[&](auto&& expected) {
|
|
||||||
using T = typename std::decay_t<decltype(expected)>::ElementType;
|
|
||||||
|
|
||||||
auto* input_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.input);
|
auto* expected = ToValueBase(c.expected);
|
||||||
auto* expr = create<ast::UnaryOpExpression>(op, input_expr);
|
auto* input = ToValueBase(c.input);
|
||||||
|
|
||||||
GlobalConst("C", expr);
|
auto* input_expr = input->Expr(*this);
|
||||||
auto* expected_expr = expected.Expr(*this);
|
auto* expr = create<ast::UnaryOpExpression>(op, input_expr);
|
||||||
GlobalConst("E", expected_expr);
|
|
||||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
|
||||||
|
|
||||||
auto* sem = Sem().Get(expr);
|
GlobalConst("C", expr);
|
||||||
const sem::Constant* value = sem->ConstantValue();
|
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||||
ASSERT_NE(value, nullptr);
|
|
||||||
EXPECT_TYPE(value->Type(), sem->Type());
|
|
||||||
|
|
||||||
auto* expected_sem = Sem().Get(expected_expr);
|
auto* sem = Sem().Get(expr);
|
||||||
const sem::Constant* expected_value = expected_sem->ConstantValue();
|
const sem::Constant* value = sem->ConstantValue();
|
||||||
ASSERT_NE(expected_value, nullptr);
|
ASSERT_NE(value, nullptr);
|
||||||
EXPECT_TYPE(expected_value->Type(), expected_sem->Type());
|
EXPECT_TYPE(value->Type(), sem->Type());
|
||||||
|
|
||||||
ForEachElemPair(value, expected_value,
|
auto values_flat = ScalarArgsFrom(value);
|
||||||
[&](const sem::Constant* a, const sem::Constant* b) {
|
auto expected_values_flat = expected->Args();
|
||||||
EXPECT_EQ(a->As<T>(), b->As<T>());
|
ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
|
||||||
if constexpr (IsIntegral<T>) {
|
for (size_t i = 0; i < values_flat.values.Length(); ++i) {
|
||||||
// Check that the constant's integer doesn't contain unexpected
|
auto& a = values_flat.values[i];
|
||||||
// data in the MSBs that are outside of the bit-width of T.
|
auto& b = expected_values_flat.values[i];
|
||||||
EXPECT_EQ(a->As<AInt>(), b->As<AInt>());
|
EXPECT_EQ(a, b);
|
||||||
}
|
if (expected->IsIntegral()) {
|
||||||
return HasFailure() ? Action::kStop : Action::kContinue;
|
// Check that the constant's integer doesn't contain unexpected
|
||||||
});
|
// data in the MSBs that are outside of the bit-width of T.
|
||||||
},
|
EXPECT_EQ(builder::As<AInt>(a), builder::As<AInt>(b));
|
||||||
c.expected);
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
INSTANTIATE_TEST_SUITE_P(Complement,
|
INSTANTIATE_TEST_SUITE_P(Complement,
|
||||||
ResolverConstEvalUnaryOpTest,
|
ResolverConstEvalUnaryOpTest,
|
||||||
|
|
|
@ -206,6 +206,12 @@ struct ScalarArgs {
|
||||||
utils::Vector<Storage, 16> values;
|
utils::Vector<Storage, 16> values;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Returns current variant value in `s` cast to type `T`
|
||||||
|
template <typename T>
|
||||||
|
T As(ScalarArgs::Storage& s) {
|
||||||
|
return std::visit([](auto&& v) { return static_cast<T>(v); }, s);
|
||||||
|
}
|
||||||
|
|
||||||
/// @param o the std::ostream to write to
|
/// @param o the std::ostream to write to
|
||||||
/// @param args the ScalarArgs
|
/// @param args the ScalarArgs
|
||||||
/// @return the std::ostream so calls can be chained
|
/// @return the std::ostream so calls can be chained
|
||||||
|
@ -750,10 +756,45 @@ constexpr CreatePtrs CreatePtrsFor() {
|
||||||
DataType<T>::Name};
|
DataType<T>::Name};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Base class for Value<T>
|
||||||
|
struct ValueBase {
|
||||||
|
/// Constructor
|
||||||
|
ValueBase() = default;
|
||||||
|
/// Destructor
|
||||||
|
virtual ~ValueBase() = default;
|
||||||
|
/// Move constructor
|
||||||
|
ValueBase(ValueBase&&) = default;
|
||||||
|
/// Copy constructor
|
||||||
|
ValueBase(const ValueBase&) = default;
|
||||||
|
/// Copy assignment operator
|
||||||
|
/// @returns this instance
|
||||||
|
ValueBase& operator=(const ValueBase&) = default;
|
||||||
|
/// Creates an `ast::Expression` for the type T passing in previously stored args
|
||||||
|
/// @param b the ProgramBuilder
|
||||||
|
/// @returns an expression node
|
||||||
|
virtual const ast::Expression* Expr(ProgramBuilder& b) const = 0;
|
||||||
|
/// @returns args used to create expression via `Expr`
|
||||||
|
virtual const ScalarArgs& Args() const = 0;
|
||||||
|
/// @returns true if element type is abstract
|
||||||
|
virtual bool IsAbstract() const = 0;
|
||||||
|
/// @returns true if element type is an integral
|
||||||
|
virtual bool IsIntegral() const = 0;
|
||||||
|
/// @returns element type name
|
||||||
|
virtual std::string TypeName() const = 0;
|
||||||
|
/// Prints this value to the output stream
|
||||||
|
/// @param o the output stream
|
||||||
|
/// @returns input argument `o`
|
||||||
|
virtual std::ostream& Print(std::ostream& o) const = 0;
|
||||||
|
};
|
||||||
|
|
||||||
/// Value<T> is an instance of a value of type DataType<T>. Useful for storing values to create
|
/// Value<T> is an instance of a value of type DataType<T>. Useful for storing values to create
|
||||||
/// expressions with.
|
/// expressions with.
|
||||||
template <typename T>
|
template <typename T>
|
||||||
struct Value {
|
struct Value : ValueBase {
|
||||||
|
/// Constructor
|
||||||
|
/// @param a the scalar args
|
||||||
|
explicit Value(ScalarArgs a) : args(std::move(a)) {}
|
||||||
|
|
||||||
/// Alias to T
|
/// Alias to T
|
||||||
using Type = T;
|
using Type = T;
|
||||||
/// Alias to DataType<T>
|
/// Alias to DataType<T>
|
||||||
|
@ -764,15 +805,43 @@ struct Value {
|
||||||
/// Creates a Value<T> with `args`
|
/// Creates a Value<T> with `args`
|
||||||
/// @param args the args that will be passed to the expression
|
/// @param args the args that will be passed to the expression
|
||||||
/// @returns a Value<T>
|
/// @returns a Value<T>
|
||||||
static Value Create(ScalarArgs args) { return Value{CreatePtrsFor<T>(), std::move(args)}; }
|
static Value Create(ScalarArgs args) { return Value{std::move(args)}; }
|
||||||
|
|
||||||
/// Creates an `ast::Expression` for the type T passing in previously stored args
|
/// Creates an `ast::Expression` for the type T passing in previously stored args
|
||||||
/// @param b the ProgramBuilder
|
/// @param b the ProgramBuilder
|
||||||
/// @returns an expression node
|
/// @returns an expression node
|
||||||
const ast::Expression* Expr(ProgramBuilder& b) const { return (*create.expr)(b, args); }
|
const ast::Expression* Expr(ProgramBuilder& b) const override {
|
||||||
|
auto create = CreatePtrsFor<T>();
|
||||||
|
return (*create.expr)(b, args);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// @returns args used to create expression via `Expr`
|
||||||
|
const ScalarArgs& Args() const override { return args; }
|
||||||
|
|
||||||
|
/// @returns true if element type is abstract
|
||||||
|
bool IsAbstract() const override { return tint::IsAbstract<ElementType>; }
|
||||||
|
|
||||||
|
/// @returns true if element type is an integral
|
||||||
|
bool IsIntegral() const override { return tint::IsIntegral<ElementType>; }
|
||||||
|
|
||||||
|
/// @returns element type name
|
||||||
|
std::string TypeName() const override { return tint::FriendlyName<ElementType>(); }
|
||||||
|
|
||||||
|
/// Prints this value to the output stream
|
||||||
|
/// @param o the output stream
|
||||||
|
/// @returns input argument `o`
|
||||||
|
std::ostream& Print(std::ostream& o) const override {
|
||||||
|
o << TypeName() << "(";
|
||||||
|
for (auto& a : args.values) {
|
||||||
|
o << std::get<ElementType>(a);
|
||||||
|
if (&a != &args.values.Back()) {
|
||||||
|
o << ", ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
o << ")";
|
||||||
|
return o;
|
||||||
|
}
|
||||||
|
|
||||||
/// functions to create values / types of the value
|
|
||||||
CreatePtrs create;
|
|
||||||
/// args to create expression with
|
/// args to create expression with
|
||||||
ScalarArgs args;
|
ScalarArgs args;
|
||||||
};
|
};
|
||||||
|
|
Loading…
Reference in New Issue