diff --git a/src/tint/resolver/const_eval_binary_op_test.cc b/src/tint/resolver/const_eval_binary_op_test.cc index 18b28da2a2..5f43499bf5 100644 --- a/src/tint/resolver/const_eval_binary_op_test.cc +++ b/src/tint/resolver/const_eval_binary_op_test.cc @@ -54,47 +54,39 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) { auto op = std::get<0>(GetParam()); auto& c = std::get<1>(GetParam()); - std::visit( - [&](auto&& expected) { - using T = typename std::decay_t::ElementType; - if constexpr (std::is_same_v || std::is_same_v) { - if (c.overflow) { - // Overflow is not allowed for abstract types. This is tested separately. - return; - } - } + auto* expected = ToValueBase(c.expected); + if (expected->IsAbstract() && c.overflow) { + // Overflow is not allowed for abstract types. This is tested separately. + return; + } - auto* lhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.lhs); - auto* rhs_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.rhs); - auto* expr = create(op, lhs_expr, rhs_expr); + auto* lhs = ToValueBase(c.lhs); + auto* rhs = ToValueBase(c.rhs); - GlobalConst("C", expr); - auto* expected_expr = expected.Expr(*this); - GlobalConst("E", expected_expr); - ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* lhs_expr = lhs->Expr(*this); + auto* rhs_expr = rhs->Expr(*this); + auto* expr = create(op, lhs_expr, rhs_expr); + GlobalConst("C", expr); + ASSERT_TRUE(r()->Resolve()) << r()->error(); - auto* sem = Sem().Get(expr); - const sem::Constant* value = sem->ConstantValue(); - ASSERT_NE(value, nullptr); - EXPECT_TYPE(value->Type(), sem->Type()); + auto* sem = Sem().Get(expr); + const sem::Constant* value = sem->ConstantValue(); + ASSERT_NE(value, nullptr); + EXPECT_TYPE(value->Type(), sem->Type()); - auto* expected_sem = Sem().Get(expected_expr); - const sem::Constant* expected_value = expected_sem->ConstantValue(); - ASSERT_NE(expected_value, nullptr); - EXPECT_TYPE(expected_value->Type(), expected_sem->Type()); - - ForEachElemPair(value, expected_value, - [&](const sem::Constant* a, const sem::Constant* b) { - EXPECT_EQ(a->As(), b->As()); - if constexpr (IsIntegral) { - // Check that the constant's integer doesn't contain unexpected - // data in the MSBs that are outside of the bit-width of T. - EXPECT_EQ(a->As(), b->As()); - } - return HasFailure() ? Action::kStop : Action::kContinue; - }); - }, - c.expected); + auto values_flat = ScalarArgsFrom(value); + auto expected_values_flat = expected->Args(); + ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length()); + for (size_t i = 0; i < values_flat.values.Length(); ++i) { + auto& a = values_flat.values[i]; + auto& b = expected_values_flat.values[i]; + EXPECT_EQ(a, b); + if (expected->IsIntegral()) { + // Check that the constant's integer doesn't contain unexpected + // data in the MSBs that are outside of the bit-width of T. + EXPECT_EQ(builder::As(a), builder::As(b)); + } + } } INSTANTIATE_TEST_SUITE_P(MixedAbstractArgs, @@ -658,21 +650,15 @@ using ResolverConstEvalBinaryOpTest_Overflow = ResolverTestWithParamExpr(*this); + auto* rhs_expr = rhs->Expr(*this); auto* expr = create(Source{{1, 1}}, c.op, lhs_expr, rhs_expr); GlobalConst("C", expr); ASSERT_FALSE(r()->Resolve()); - - std::string type_name = std::visit( - [&](auto&& value) { - using ValueType = std::decay_t; - return builder::FriendlyName(); - }, - c.lhs); - EXPECT_THAT(r()->error(), HasSubstr("1:1 error: '")); - EXPECT_THAT(r()->error(), HasSubstr("' cannot be represented as '" + type_name + "'")); + EXPECT_THAT(r()->error(), HasSubstr("' cannot be represented as '" + lhs->TypeName() + "'")); } INSTANTIATE_TEST_SUITE_P( Test, @@ -854,10 +840,8 @@ TEST_F(ResolverConstEvalTest, BinaryAbstractShiftLeftByNegativeValue_Error) { using ResolverConstEvalShiftLeftConcreteGeqBitWidthError = ResolverTestWithParam>; TEST_P(ResolverConstEvalShiftLeftConcreteGeqBitWidthError, Test) { - auto* lhs_expr = - std::visit([&](auto&& value) { return value.Expr(*this); }, std::get<0>(GetParam())); - auto* rhs_expr = - std::visit([&](auto&& value) { return value.Expr(*this); }, std::get<1>(GetParam())); + auto* lhs_expr = ToValueBase(std::get<0>(GetParam()))->Expr(*this); + auto* rhs_expr = ToValueBase(std::get<1>(GetParam()))->Expr(*this); GlobalConst("c", Shl(Source{{1, 1}}, lhs_expr, rhs_expr)); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ( @@ -880,10 +864,8 @@ INSTANTIATE_TEST_SUITE_P(Test, // AInt left shift results in sign change error using ResolverConstEvalShiftLeftSignChangeError = ResolverTestWithParam>; TEST_P(ResolverConstEvalShiftLeftSignChangeError, Test) { - auto* lhs_expr = - std::visit([&](auto&& value) { return value.Expr(*this); }, std::get<0>(GetParam())); - auto* rhs_expr = - std::visit([&](auto&& value) { return value.Expr(*this); }, std::get<1>(GetParam())); + auto* lhs_expr = ToValueBase(std::get<0>(GetParam()))->Expr(*this); + auto* rhs_expr = ToValueBase(std::get<1>(GetParam()))->Expr(*this); GlobalConst("c", Shl(Source{{1, 1}}, lhs_expr, rhs_expr)); EXPECT_FALSE(r()->Resolve()); EXPECT_EQ(r()->error(), "1:1 error: shift left operation results in sign change"); diff --git a/src/tint/resolver/const_eval_builtin_test.cc b/src/tint/resolver/const_eval_builtin_test.cc index 3936fba242..86223ba3d8 100644 --- a/src/tint/resolver/const_eval_builtin_test.cc +++ b/src/tint/resolver/const_eval_builtin_test.cc @@ -83,54 +83,57 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) { std::visit([&](auto&& v) { args.Push(v.Expr(*this)); }, a); } - std::visit( - [&](auto&& expected) { - using T = typename std::decay_t::ElementType; - auto* expr = Call(sem::str(builtin), std::move(args)); + auto* expected = ToValueBase(c.expected); + auto* expr = Call(sem::str(builtin), std::move(args)); - GlobalConst("C", expr); - auto* expected_expr = expected.Expr(*this); - GlobalConst("E", expected_expr); + GlobalConst("C", expr); + auto* expected_expr = expected->Expr(*this); + GlobalConst("E", expected_expr); - EXPECT_TRUE(r()->Resolve()) << r()->error(); + EXPECT_TRUE(r()->Resolve()) << r()->error(); - auto* sem = Sem().Get(expr); - const sem::Constant* value = sem->ConstantValue(); - ASSERT_NE(value, nullptr); - EXPECT_TYPE(value->Type(), sem->Type()); + auto* sem = Sem().Get(expr); + const sem::Constant* value = sem->ConstantValue(); + ASSERT_NE(value, nullptr); + EXPECT_TYPE(value->Type(), sem->Type()); - auto* expected_sem = Sem().Get(expected_expr); - const sem::Constant* expected_value = expected_sem->ConstantValue(); - ASSERT_NE(expected_value, nullptr); - EXPECT_TYPE(expected_value->Type(), expected_sem->Type()); + auto* expected_sem = Sem().Get(expected_expr); + const sem::Constant* expected_value = expected_sem->ConstantValue(); + ASSERT_NE(expected_value, nullptr); + EXPECT_TYPE(expected_value->Type(), expected_sem->Type()); - ForEachElemPair(value, expected_value, - [&](const sem::Constant* a, const sem::Constant* b) { - auto v = a->As(); - auto e = b->As(); - if constexpr (std::is_same_v) { - EXPECT_EQ(v, e); - } else if constexpr (IsFloatingPoint) { - if (std::isnan(e)) { - EXPECT_TRUE(std::isnan(v)); - } else { - auto vf = (c.expected_pos_or_neg ? Abs(v) : v); - if (c.float_compare) { - EXPECT_FLOAT_EQ(vf, e); - } else { - EXPECT_EQ(vf, e); - } - } - } else { - EXPECT_EQ((c.expected_pos_or_neg ? Abs(v) : v), e); - // Check that the constant's integer doesn't contain unexpected - // data in the MSBs that are outside of the bit-width of T. - EXPECT_EQ(a->As(), b->As()); - } - return HasFailure() ? Action::kStop : Action::kContinue; - }); - }, - c.expected); + // @TODO(amaiorano): Rewrite using ScalarArgsFrom() + ForEachElemPair(value, expected_value, [&](const sem::Constant* a, const sem::Constant* b) { + std::visit( + [&](auto&& ct_expected) { + using T = typename std::decay_t::ElementType; + + auto v = a->As(); + auto e = b->As(); + if constexpr (std::is_same_v) { + EXPECT_EQ(v, e); + } else if constexpr (IsFloatingPoint) { + if (std::isnan(e)) { + EXPECT_TRUE(std::isnan(v)); + } else { + auto vf = (c.expected_pos_or_neg ? Abs(v) : v); + if (c.float_compare) { + EXPECT_FLOAT_EQ(vf, e); + } else { + EXPECT_EQ(vf, e); + } + } + } else { + EXPECT_EQ((c.expected_pos_or_neg ? Abs(v) : v), e); + // Check that the constant's integer doesn't contain unexpected + // data in the MSBs that are outside of the bit-width of T. + EXPECT_EQ(a->As(), b->As()); + } + }, + c.expected); + + return HasFailure() ? Action::kStop : Action::kContinue; + }); } INSTANTIATE_TEST_SUITE_P( // diff --git a/src/tint/resolver/const_eval_conversion_test.cc b/src/tint/resolver/const_eval_conversion_test.cc index 35657eeaa5..7e6a6fc1a1 100644 --- a/src/tint/resolver/const_eval_conversion_test.cc +++ b/src/tint/resolver/const_eval_conversion_test.cc @@ -29,20 +29,7 @@ using Scalar = std::variant< // builder::Value>; static std::ostream& operator<<(std::ostream& o, const Scalar& scalar) { - std::visit( - [&](auto&& v) { - using ValueType = std::decay_t; - o << ValueType::DataType::Name() << "("; - for (auto& a : v.args.values) { - o << std::get(a); - if (&a != &v.args.values.Back()) { - o << ", "; - } - } - o << ")"; - }, - scalar); - return o; + return ToValueBase(scalar)->Print(o); } enum class Kind { @@ -96,7 +83,7 @@ TEST_P(ResolverConstEvalConvTest, Test) { const auto& type = std::get<1>(GetParam()).type; const auto unrepresentable = std::get<1>(GetParam()).unrepresentable; - auto* input_val = std::visit([&](auto val) { return val.Expr(*this); }, input); + auto* input_val = ToValueBase(input)->Expr(*this); auto* expr = Construct(type.ast(*this), input_val); if (kind == Kind::kVector) { expr = Construct(ty.vec(nullptr, 3), expr); @@ -120,7 +107,7 @@ TEST_P(ResolverConstEvalConvTest, Test) { ASSERT_NE(sem->ConstantValue(), nullptr); EXPECT_TYPE(sem->ConstantValue()->Type(), target_sem_ty); - auto expected_values = std::visit([&](auto&& val) { return val.args; }, expected); + auto expected_values = ToValueBase(expected)->Args(); if (kind == Kind::kVector) { expected_values.values.Push(expected_values.values[0]); expected_values.values.Push(expected_values.values[0]); diff --git a/src/tint/resolver/const_eval_test.h b/src/tint/resolver/const_eval_test.h index 38405401f2..2761daff0c 100644 --- a/src/tint/resolver/const_eval_test.h +++ b/src/tint/resolver/const_eval_test.h @@ -41,6 +41,8 @@ inline const auto k3PiOver4 = T(UnwrapNumber(2.356194490192344928846)); inline void CollectScalarArgs(const sem::Constant* c, builder::ScalarArgs& args) { Switch( c->Type(), // + [&](const sem::AbstractInt*) { args.values.Push(c->As()); }, + [&](const sem::AbstractFloat*) { args.values.Push(c->As()); }, [&](const sem::Bool*) { args.values.Push(c->As()); }, [&](const sem::I32*) { args.values.Push(c->As()); }, [&](const sem::U32*) { args.values.Push(c->As()); }, @@ -136,6 +138,7 @@ using builder::IsValue; using builder::Mat; using builder::Val; using builder::Value; +using builder::ValueBase; using builder::Vec; using Types = std::variant< // @@ -188,21 +191,18 @@ using Types = std::variant< // // >; +/// Returns the current Value in the `types` variant as a `ValueBase` pointer to use the +/// polymorphic API. This trades longer compile times using std::variant for longer runtime via +/// virtual function calls. +template +inline const ValueBase* ToValueBase(const ValueVariant& types) { + return std::visit( + [](auto&& t) -> const ValueBase* { return static_cast(&t); }, types); +} + +/// Prints Types to ostream inline std::ostream& operator<<(std::ostream& o, const Types& types) { - std::visit( - [&](auto&& v) { - using ValueType = std::decay_t; - o << ValueType::DataType::Name() << "("; - for (auto& a : v.args.values) { - o << std::get(a); - if (&a != &v.args.values.Back()) { - o << ", "; - } - } - o << ")"; - }, - types); - return o; + return ToValueBase(types)->Print(o); } // Calls `f` on deepest elements of both `a` and `b`. If function returns Action::kStop, it stops diff --git a/src/tint/resolver/const_eval_unary_op_test.cc b/src/tint/resolver/const_eval_unary_op_test.cc index 80b8caa6fa..fced490907 100644 --- a/src/tint/resolver/const_eval_unary_op_test.cc +++ b/src/tint/resolver/const_eval_unary_op_test.cc @@ -51,40 +51,34 @@ TEST_P(ResolverConstEvalUnaryOpTest, Test) { auto op = std::get<0>(GetParam()); auto& c = std::get<1>(GetParam()); - std::visit( - [&](auto&& expected) { - using T = typename std::decay_t::ElementType; - auto* input_expr = std::visit([&](auto&& value) { return value.Expr(*this); }, c.input); - auto* expr = create(op, input_expr); + auto* expected = ToValueBase(c.expected); + auto* input = ToValueBase(c.input); - GlobalConst("C", expr); - auto* expected_expr = expected.Expr(*this); - GlobalConst("E", expected_expr); - ASSERT_TRUE(r()->Resolve()) << r()->error(); + auto* input_expr = input->Expr(*this); + auto* expr = create(op, input_expr); - auto* sem = Sem().Get(expr); - const sem::Constant* value = sem->ConstantValue(); - ASSERT_NE(value, nullptr); - EXPECT_TYPE(value->Type(), sem->Type()); + GlobalConst("C", expr); + ASSERT_TRUE(r()->Resolve()) << r()->error(); - auto* expected_sem = Sem().Get(expected_expr); - const sem::Constant* expected_value = expected_sem->ConstantValue(); - ASSERT_NE(expected_value, nullptr); - EXPECT_TYPE(expected_value->Type(), expected_sem->Type()); + auto* sem = Sem().Get(expr); + const sem::Constant* value = sem->ConstantValue(); + ASSERT_NE(value, nullptr); + EXPECT_TYPE(value->Type(), sem->Type()); - ForEachElemPair(value, expected_value, - [&](const sem::Constant* a, const sem::Constant* b) { - EXPECT_EQ(a->As(), b->As()); - if constexpr (IsIntegral) { - // Check that the constant's integer doesn't contain unexpected - // data in the MSBs that are outside of the bit-width of T. - EXPECT_EQ(a->As(), b->As()); - } - return HasFailure() ? Action::kStop : Action::kContinue; - }); - }, - c.expected); + auto values_flat = ScalarArgsFrom(value); + auto expected_values_flat = expected->Args(); + ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length()); + for (size_t i = 0; i < values_flat.values.Length(); ++i) { + auto& a = values_flat.values[i]; + auto& b = expected_values_flat.values[i]; + EXPECT_EQ(a, b); + if (expected->IsIntegral()) { + // Check that the constant's integer doesn't contain unexpected + // data in the MSBs that are outside of the bit-width of T. + EXPECT_EQ(builder::As(a), builder::As(b)); + } + } } INSTANTIATE_TEST_SUITE_P(Complement, ResolverConstEvalUnaryOpTest, diff --git a/src/tint/resolver/resolver_test_helper.h b/src/tint/resolver/resolver_test_helper.h index 923f3ff461..57fe14a50b 100644 --- a/src/tint/resolver/resolver_test_helper.h +++ b/src/tint/resolver/resolver_test_helper.h @@ -206,6 +206,12 @@ struct ScalarArgs { utils::Vector values; }; +/// Returns current variant value in `s` cast to type `T` +template +T As(ScalarArgs::Storage& s) { + return std::visit([](auto&& v) { return static_cast(v); }, s); +} + /// @param o the std::ostream to write to /// @param args the ScalarArgs /// @return the std::ostream so calls can be chained @@ -750,10 +756,45 @@ constexpr CreatePtrs CreatePtrsFor() { DataType::Name}; } +/// Base class for Value +struct ValueBase { + /// Constructor + ValueBase() = default; + /// Destructor + virtual ~ValueBase() = default; + /// Move constructor + ValueBase(ValueBase&&) = default; + /// Copy constructor + ValueBase(const ValueBase&) = default; + /// Copy assignment operator + /// @returns this instance + ValueBase& operator=(const ValueBase&) = default; + /// Creates an `ast::Expression` for the type T passing in previously stored args + /// @param b the ProgramBuilder + /// @returns an expression node + virtual const ast::Expression* Expr(ProgramBuilder& b) const = 0; + /// @returns args used to create expression via `Expr` + virtual const ScalarArgs& Args() const = 0; + /// @returns true if element type is abstract + virtual bool IsAbstract() const = 0; + /// @returns true if element type is an integral + virtual bool IsIntegral() const = 0; + /// @returns element type name + virtual std::string TypeName() const = 0; + /// Prints this value to the output stream + /// @param o the output stream + /// @returns input argument `o` + virtual std::ostream& Print(std::ostream& o) const = 0; +}; + /// Value is an instance of a value of type DataType. Useful for storing values to create /// expressions with. template -struct Value { +struct Value : ValueBase { + /// Constructor + /// @param a the scalar args + explicit Value(ScalarArgs a) : args(std::move(a)) {} + /// Alias to T using Type = T; /// Alias to DataType @@ -764,15 +805,43 @@ struct Value { /// Creates a Value with `args` /// @param args the args that will be passed to the expression /// @returns a Value - static Value Create(ScalarArgs args) { return Value{CreatePtrsFor(), std::move(args)}; } + static Value Create(ScalarArgs args) { return Value{std::move(args)}; } /// Creates an `ast::Expression` for the type T passing in previously stored args /// @param b the ProgramBuilder /// @returns an expression node - const ast::Expression* Expr(ProgramBuilder& b) const { return (*create.expr)(b, args); } + const ast::Expression* Expr(ProgramBuilder& b) const override { + auto create = CreatePtrsFor(); + return (*create.expr)(b, args); + } + + /// @returns args used to create expression via `Expr` + const ScalarArgs& Args() const override { return args; } + + /// @returns true if element type is abstract + bool IsAbstract() const override { return tint::IsAbstract; } + + /// @returns true if element type is an integral + bool IsIntegral() const override { return tint::IsIntegral; } + + /// @returns element type name + std::string TypeName() const override { return tint::FriendlyName(); } + + /// Prints this value to the output stream + /// @param o the output stream + /// @returns input argument `o` + std::ostream& Print(std::ostream& o) const override { + o << TypeName() << "("; + for (auto& a : args.values) { + o << std::get(a); + if (&a != &args.values.Back()) { + o << ", "; + } + } + o << ")"; + return o; + } - /// functions to create values / types of the value - CreatePtrs create; /// args to create expression with ScalarArgs args; };