From 95d174a1182dd43fa41dfb84173f0a14d3a2b5d6 Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Fri, 25 Nov 2022 15:30:51 +0000 Subject: [PATCH] tint/resolver: Further simplify test const eval framework Replace ScalarArgs struct with Scalar variant and vector. Fold ValueBase and ConcreteValue into Value. Change-Id: I5cc5811a87f1aae162feb65fb6b1ecdac033d0fe Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111761 Reviewed-by: Antonio Maiorano Commit-Queue: Ben Clayton Kokoro: Kokoro --- .../resolver/const_eval_conversion_test.cc | 8 +- src/tint/resolver/const_eval_test.h | 39 +-- src/tint/resolver/const_eval_unary_op_test.cc | 14 +- src/tint/resolver/resolver_test_helper.h | 318 ++++++------------ 4 files changed, 127 insertions(+), 252 deletions(-) diff --git a/src/tint/resolver/const_eval_conversion_test.cc b/src/tint/resolver/const_eval_conversion_test.cc index da37f3bc82..8640a1d2e6 100644 --- a/src/tint/resolver/const_eval_conversion_test.cc +++ b/src/tint/resolver/const_eval_conversion_test.cc @@ -94,12 +94,12 @@ TEST_P(ResolverConstEvalConvTest, Test) { ASSERT_NE(sem->ConstantValue(), nullptr); EXPECT_TYPE(sem->ConstantValue()->Type(), target_sem_ty); - auto expected_values = expected.Args(); + auto expected_values = expected.args; if (kind == Kind::kVector) { - expected_values.values.Push(expected_values.values[0]); - expected_values.values.Push(expected_values.values[0]); + expected_values.Push(expected_values[0]); + expected_values.Push(expected_values[0]); } - auto got_values = ScalarArgsFrom(sem->ConstantValue()); + auto got_values = ScalarsFrom(sem->ConstantValue()); EXPECT_EQ(expected_values, got_values); } } diff --git a/src/tint/resolver/const_eval_test.h b/src/tint/resolver/const_eval_test.h index dcb91d0914..0d27805ce9 100644 --- a/src/tint/resolver/const_eval_test.h +++ b/src/tint/resolver/const_eval_test.h @@ -37,28 +37,29 @@ template inline const auto k3PiOver4 = T(UnwrapNumber(2.356194490192344928846)); /// Walks the sem::Constant @p c, accumulating all the inner-most scalar values into @p args -inline void CollectScalarArgs(const sem::Constant* c, builder::ScalarArgs& args) { +template +inline void CollectScalars(const sem::Constant* c, utils::Vector& scalars) { Switch( c->Type(), // - [&](const sem::AbstractInt*) { args.values.Push(c->As()); }, - [&](const sem::AbstractFloat*) { args.values.Push(c->As()); }, - [&](const sem::Bool*) { args.values.Push(c->As()); }, - [&](const sem::I32*) { args.values.Push(c->As()); }, - [&](const sem::U32*) { args.values.Push(c->As()); }, - [&](const sem::F32*) { args.values.Push(c->As()); }, - [&](const sem::F16*) { args.values.Push(c->As()); }, + [&](const sem::AbstractInt*) { scalars.Push(c->As()); }, + [&](const sem::AbstractFloat*) { scalars.Push(c->As()); }, + [&](const sem::Bool*) { scalars.Push(c->As()); }, + [&](const sem::I32*) { scalars.Push(c->As()); }, + [&](const sem::U32*) { scalars.Push(c->As()); }, + [&](const sem::F32*) { scalars.Push(c->As()); }, + [&](const sem::F16*) { scalars.Push(c->As()); }, [&](Default) { size_t i = 0; while (auto* child = c->Index(i++)) { - CollectScalarArgs(child, args); + CollectScalars(child, scalars); } }); } /// Walks the sem::Constant @p c, returning all the inner-most scalar values. -inline builder::ScalarArgs ScalarArgsFrom(const sem::Constant* c) { - builder::ScalarArgs out; - CollectScalarArgs(c, out); +inline utils::Vector ScalarsFrom(const sem::Constant* c) { + utils::Vector out; + CollectScalars(c, out); return out; } @@ -90,14 +91,14 @@ struct CheckConstantFlags { inline void CheckConstant(const sem::Constant* got_constant, const builder::Value& expected_value, CheckConstantFlags flags = {}) { - auto values_flat = ScalarArgsFrom(got_constant); - auto expected_values_flat = expected_value.Args(); - ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length()); - for (size_t i = 0; i < values_flat.values.Length(); ++i) { - auto& got_scalar = values_flat.values[i]; - auto& expected_scalar = expected_values_flat.values[i]; + auto values_flat = ScalarsFrom(got_constant); + auto expected_values_flat = expected_value.args; + ASSERT_EQ(values_flat.Length(), expected_values_flat.Length()); + for (size_t i = 0; i < values_flat.Length(); ++i) { + auto& got_scalar = values_flat[i]; + auto& expected_scalar = expected_values_flat[i]; std::visit( - [&](auto&& expected) { + [&](const auto& expected) { using T = std::decay_t; ASSERT_TRUE(std::holds_alternative(got_scalar)); diff --git a/src/tint/resolver/const_eval_unary_op_test.cc b/src/tint/resolver/const_eval_unary_op_test.cc index d24c27b60e..f7f592855a 100644 --- a/src/tint/resolver/const_eval_unary_op_test.cc +++ b/src/tint/resolver/const_eval_unary_op_test.cc @@ -61,14 +61,14 @@ TEST_P(ResolverConstEvalUnaryOpTest, Test) { ASSERT_NE(value, nullptr); EXPECT_TYPE(value->Type(), sem->Type()); - auto values_flat = ScalarArgsFrom(value); - auto expected_values_flat = expected.Args(); - ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length()); - for (size_t i = 0; i < values_flat.values.Length(); ++i) { - auto& a = values_flat.values[i]; - auto& b = expected_values_flat.values[i]; + auto values_flat = ScalarsFrom(value); + auto expected_values_flat = expected.args; + ASSERT_EQ(values_flat.Length(), expected_values_flat.Length()); + for (size_t i = 0; i < values_flat.Length(); ++i) { + auto& a = values_flat[i]; + auto& b = expected_values_flat[i]; EXPECT_EQ(a, b); - if (expected.IsIntegral()) { + if (expected.is_integral) { // Check that the constant's integer doesn't contain unexpected // data in the MSBs that are outside of the bit-width of T. EXPECT_EQ(builder::As(a), builder::As(b)); diff --git a/src/tint/resolver/resolver_test_helper.h b/src/tint/resolver/resolver_test_helper.h index edbc456590..cf17b616df 100644 --- a/src/tint/resolver/resolver_test_helper.h +++ b/src/tint/resolver/resolver_test_helper.h @@ -180,63 +180,18 @@ using alias3 = alias; template struct ptr {}; -/// Type used to accept scalars as arguments. Can be either a single value that gets splatted for -/// composite types, or all values required by the composite type. -struct ScalarArgs { - /// Constructor - ScalarArgs() = default; - - /// Constructor - /// @param single_value single value to initialize with - template - explicit ScalarArgs(T single_value) : values(utils::Vector{single_value}) {} - - /// Constructor - /// @param all_values all values to initialize the composite type with - template - ScalarArgs(utils::VectorRef all_values) // NOLINT: implicit on purpose - { - for (auto& v : all_values) { - values.Push(v); - } - } - - /// @param other the other ScalarArgs to compare against - /// @returns true if all values are equal to the values in @p other - bool operator==(const ScalarArgs& other) const { return values == other.values; } - - /// Valid scalar types for args - using Storage = std::variant; - - /// The vector of values - utils::Vector values; -}; +/// A scalar value +using Scalar = std::variant; /// Returns current variant value in `s` cast to type `T` template -T As(ScalarArgs::Storage& s) { +T As(Scalar& s) { return std::visit([](auto&& v) { return static_cast(v); }, s); } -/// @param o the std::ostream to write to -/// @param args the ScalarArgs -/// @return the std::ostream so calls can be chained -inline std::ostream& operator<<(std::ostream& o, const ScalarArgs& args) { - o << "["; - bool first = true; - for (auto& val : args.values) { - if (!first) { - o << ", "; - } - first = false; - std::visit([&](auto&& v) { o << v; }, val); - } - o << "]"; - return o; -} - using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b); -using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, ScalarArgs args); +using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, + utils::VectorRef args); using ast_expr_from_double_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, double v); using sem_type_func_ptr = const sem::Type* (*)(ProgramBuilder& b); using type_name_func_ptr = std::string (*)(); @@ -280,14 +235,14 @@ struct DataType { /// @param b the ProgramBuilder /// @param args args of size 1 with the boolean value to init with /// @return a new AST expression of the bool type - static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) { - return b.Expr(std::get(args.values[0])); + static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef args) { + return b.Expr(std::get(args[0])); } /// @param b the ProgramBuilder /// @param v arg of type double that will be cast to bool. /// @return a new AST expression of the bool type static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, ScalarArgs{static_cast(v)}); + return Expr(b, utils::Vector{static_cast(v)}); } /// @returns the WGSL name for the type static inline std::string Name() { return "bool"; } @@ -311,14 +266,14 @@ struct DataType { /// @param b the ProgramBuilder /// @param args args of size 1 with the i32 value to init with /// @return a new AST i32 literal value expression - static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) { - return b.Expr(std::get(args.values[0])); + static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef args) { + return b.Expr(std::get(args[0])); } /// @param b the ProgramBuilder /// @param v arg of type double that will be cast to i32. /// @return a new AST i32 literal value expression static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, ScalarArgs{static_cast(v)}); + return Expr(b, utils::Vector{static_cast(v)}); } /// @returns the WGSL name for the type static inline std::string Name() { return "i32"; } @@ -342,14 +297,14 @@ struct DataType { /// @param b the ProgramBuilder /// @param args args of size 1 with the u32 value to init with /// @return a new AST u32 literal value expression - static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) { - return b.Expr(std::get(args.values[0])); + static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef args) { + return b.Expr(std::get(args[0])); } /// @param b the ProgramBuilder /// @param v arg of type double that will be cast to u32. /// @return a new AST u32 literal value expression static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, ScalarArgs{static_cast(v)}); + return Expr(b, utils::Vector{static_cast(v)}); } /// @returns the WGSL name for the type static inline std::string Name() { return "u32"; } @@ -373,14 +328,14 @@ struct DataType { /// @param b the ProgramBuilder /// @param args args of size 1 with the f32 value to init with /// @return a new AST f32 literal value expression - static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) { - return b.Expr(std::get(args.values[0])); + static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef args) { + return b.Expr(std::get(args[0])); } /// @param b the ProgramBuilder /// @param v arg of type double that will be cast to f32. /// @return a new AST f32 literal value expression static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, ScalarArgs{static_cast(v)}); + return Expr(b, utils::Vector{static_cast(v)}); } /// @returns the WGSL name for the type static inline std::string Name() { return "f32"; } @@ -404,14 +359,14 @@ struct DataType { /// @param b the ProgramBuilder /// @param args args of size 1 with the f16 value to init with /// @return a new AST f16 literal value expression - static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) { - return b.Expr(std::get(args.values[0])); + static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef args) { + return b.Expr(std::get(args[0])); } /// @param b the ProgramBuilder /// @param v arg of type double that will be cast to f16. /// @return a new AST f16 literal value expression static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, ScalarArgs{static_cast(v)}); + return Expr(b, utils::Vector{static_cast(v)}); } /// @returns the WGSL name for the type static inline std::string Name() { return "f16"; } @@ -434,14 +389,14 @@ struct DataType { /// @param b the ProgramBuilder /// @param args args of size 1 with the abstract-float value to init with /// @return a new AST abstract-float literal value expression - static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) { - return b.Expr(std::get(args.values[0])); + static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef args) { + return b.Expr(std::get(args[0])); } /// @param b the ProgramBuilder /// @param v arg of type double that will be cast to AFloat. /// @return a new AST abstract-float literal value expression static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, ScalarArgs{static_cast(v)}); + return Expr(b, utils::Vector{static_cast(v)}); } /// @returns the WGSL name for the type static inline std::string Name() { return "abstract-float"; } @@ -464,14 +419,14 @@ struct DataType { /// @param b the ProgramBuilder /// @param args args of size 1 with the abstract-int value to init with /// @return a new AST abstract-int literal value expression - static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) { - return b.Expr(std::get(args.values[0])); + static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef args) { + return b.Expr(std::get(args[0])); } /// @param b the ProgramBuilder /// @param v arg of type double that will be cast to AInt. /// @return a new AST abstract-int literal value expression static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, ScalarArgs{static_cast(v)}); + return Expr(b, utils::Vector{static_cast(v)}); } /// @returns the WGSL name for the type static inline std::string Name() { return "abstract-int"; } @@ -499,17 +454,17 @@ struct DataType> { /// @param b the ProgramBuilder /// @param args args of size 1 or N with values of type T to initialize with /// @return a new AST vector value expression - static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) { + static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef args) { return b.Construct(AST(b), ExprArgs(b, std::move(args))); } /// @param b the ProgramBuilder /// @param args args of size 1 or N with values of type T to initialize with /// @return the list of expressions that are used to construct the vector - static inline auto ExprArgs(ProgramBuilder& b, ScalarArgs args) { - const bool one_value = args.values.Length() == 1; + static inline auto ExprArgs(ProgramBuilder& b, utils::VectorRef args) { + const bool one_value = args.Length() == 1; utils::Vector r; for (size_t i = 0; i < N; ++i) { - r.Push(DataType::Expr(b, ScalarArgs{one_value ? args.values[0] : args.values[i]})); + r.Push(DataType::Expr(b, utils::Vector{one_value ? args[0] : args[i]})); } return r; } @@ -517,7 +472,7 @@ struct DataType> { /// @param v arg of type double that will be cast to ElementType /// @return a new AST vector value expression static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, ScalarArgs{static_cast(v)}); + return Expr(b, utils::Vector{static_cast(v)}); } /// @returns the WGSL name for the type static inline std::string Name() { @@ -548,25 +503,25 @@ struct DataType> { /// @param b the ProgramBuilder /// @param args args of size 1 or N*M with values of type T to initialize with /// @return a new AST matrix value expression - static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) { + static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef args) { return b.Construct(AST(b), ExprArgs(b, std::move(args))); } /// @param b the ProgramBuilder /// @param args args of size 1 or N*M with values of type T to initialize with /// @return a new AST matrix value expression - static inline auto ExprArgs(ProgramBuilder& b, ScalarArgs args) { - const bool one_value = args.values.Length() == 1; + static inline auto ExprArgs(ProgramBuilder& b, utils::VectorRef args) { + const bool one_value = args.Length() == 1; size_t next = 0; utils::Vector r; for (uint32_t i = 0; i < N; ++i) { if (one_value) { - r.Push(DataType>::Expr(b, ScalarArgs{args.values[0]})); + r.Push(DataType>::Expr(b, utils::Vector{args[0]})); } else { - utils::Vector v; + utils::Vector v; for (size_t j = 0; j < M; ++j) { - v.Push(std::get(args.values[next++])); + v.Push(args[next++]); } - r.Push(DataType>::Expr(b, utils::VectorRef{v})); + r.Push(DataType>::Expr(b, std::move(v))); } } return r; @@ -575,7 +530,7 @@ struct DataType> { /// @param v arg of type double that will be cast to ElementType /// @return a new AST matrix value expression static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, ScalarArgs{static_cast(v)}); + return Expr(b, utils::Vector{static_cast(v)}); } /// @returns the WGSL name for the type static inline std::string Name() { @@ -611,8 +566,9 @@ struct DataType> { /// @param args the value nested elements will be initialized with /// @return a new AST expression of the alias type template - static inline traits::EnableIf Expr(ProgramBuilder& b, - ScalarArgs args) { + static inline traits::EnableIf Expr( + ProgramBuilder& b, + utils::VectorRef args) { // Cast return b.Construct(AST(b), DataType::Expr(b, std::move(args))); } @@ -621,8 +577,9 @@ struct DataType> { /// @param args the value nested elements will be initialized with /// @return a new AST expression of the alias type template - static inline traits::EnableIf Expr(ProgramBuilder& b, - ScalarArgs args) { + static inline traits::EnableIf Expr( + ProgramBuilder& b, + utils::VectorRef args) { // Construct return b.Construct(AST(b), DataType::ExprArgs(b, std::move(args))); } @@ -631,7 +588,7 @@ struct DataType> { /// @param v arg of type double that will be cast to ElementType /// @return a new AST expression of the alias type static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, ScalarArgs{static_cast(v)}); + return Expr(b, utils::Vector{static_cast(v)}); } /// @returns the WGSL name for the type @@ -662,7 +619,8 @@ struct DataType> { /// @param b the ProgramBuilder /// @return a new AST expression of the pointer type - static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs /*unused*/) { + static inline const ast::Expression* Expr(ProgramBuilder& b, + utils::VectorRef /*unused*/) { auto sym = b.Symbols().New("global_for_ptr"); b.GlobalVar(sym, DataType::AST(b), ast::AddressSpace::kPrivate); return b.AddressOf(sym); @@ -672,7 +630,7 @@ struct DataType> { /// @param v arg of type double that will be cast to ElementType /// @return a new AST expression of the pointer type static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, ScalarArgs{static_cast(v)}); + return Expr(b, utils::Vector{static_cast(v)}); } /// @returns the WGSL name for the type @@ -716,17 +674,17 @@ struct DataType> { /// @param args args of size 1 or N with values of type T to initialize with /// with /// @return a new AST array value expression - static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) { + static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef args) { return b.Construct(AST(b), ExprArgs(b, std::move(args))); } /// @param b the ProgramBuilder /// @param args args of size 1 or N with values of type T to initialize with /// @return the list of expressions that are used to construct the array - static inline auto ExprArgs(ProgramBuilder& b, ScalarArgs args) { - const bool one_value = args.values.Length() == 1; + static inline auto ExprArgs(ProgramBuilder& b, utils::VectorRef args) { + const bool one_value = args.Length() == 1; utils::Vector r; for (uint32_t i = 0; i < N; i++) { - r.Push(DataType::Expr(b, ScalarArgs{one_value ? args.values[0] : args.values[i]})); + r.Push(DataType::Expr(b, utils::Vector{one_value ? args[0] : args[i]})); } return r; } @@ -734,7 +692,7 @@ struct DataType> { /// @param v arg of type double that will be cast to ElementType /// @return a new AST array value expression static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) { - return Expr(b, ScalarArgs{static_cast(v)}); + return Expr(b, utils::Vector{static_cast(v)}); } /// @returns the WGSL name for the type static inline std::string Name() { @@ -776,80 +734,34 @@ template const bool IsDataTypeSpecializedFor = !std::is_same_v::ElementType, UnspecializedElementType>; -namespace detail { -/// ValueBase is a base class of ConcreteValue -struct ValueBase { - /// Constructor - ValueBase() = default; - /// Destructor - virtual ~ValueBase() = default; - /// Move constructor - ValueBase(ValueBase&&) = default; - /// Copy constructor - ValueBase(const ValueBase&) = default; - /// Copy assignment operator - /// @returns this instance - ValueBase& operator=(const ValueBase&) = default; - /// Creates an `ast::Expression` for the type T passing in previously stored args - /// @param b the ProgramBuilder - /// @returns an expression node - virtual const ast::Expression* Expr(ProgramBuilder& b) const = 0; - /// @returns args used to create expression via `Expr` - virtual const ScalarArgs& Args() const = 0; - /// @returns true if element type is abstract - virtual bool IsAbstract() const = 0; - /// @returns true if element type is an integral - virtual bool IsIntegral() const = 0; - /// @returns element type name - virtual std::string TypeName() const = 0; - /// Prints this value to the output stream - /// @param o the output stream - /// @returns input argument `o` - virtual std::ostream& Print(std::ostream& o) const = 0; -}; - -/// ConcreteValue is used to create Values of type DataType with a ScalarArgs initializer. -template -struct ConcreteValue : ValueBase { - /// Constructor +/// Value is used to create Values with a Scalar vector initializer. +struct Value { + /// Creates a Value for type T initialized with `args` /// @param args the scalar args - explicit ConcreteValue(ScalarArgs args) : args_(std::move(args)) {} - - /// Alias to T - using Type = T; - /// Alias to DataType - using DataType = builder::DataType; - /// Alias to DataType::ElementType - using ElementType = typename DataType::ElementType; - - /// Creates an `ast::Expression` for the type T passing in previously stored args - /// @param b the ProgramBuilder - /// @returns an expression node - const ast::Expression* Expr(ProgramBuilder& b) const override { - auto create = CreatePtrsFor(); - return (*create.expr)(b, args_); + /// @returns Value + template + static Value Create(utils::VectorRef args) { + static_assert(IsDataTypeSpecializedFor, "No DataType specialization exists"); + using EL_TY = typename builder::DataType::ElementType; + return Value{ + std::move(args), CreatePtrsFor().expr, tint::IsAbstract, + tint::IsIntegral, tint::FriendlyName(), + }; } - /// @returns args used to create expression via `Expr` - const ScalarArgs& Args() const override { return args_; } - - /// @returns true if element type is abstract - bool IsAbstract() const override { return tint::IsAbstract; } - - /// @returns true if element type is an integral - bool IsIntegral() const override { return tint::IsIntegral; } - - /// @returns element type name - std::string TypeName() const override { return tint::FriendlyName(); } + /// Creates an `ast::Expression` for the type T passing in previously stored args + /// @param b the ProgramBuilder + /// @returns an expression node + const ast::Expression* Expr(ProgramBuilder& b) const { return (*create)(b, args); } /// Prints this value to the output stream /// @param o the output stream /// @returns input argument `o` - std::ostream& Print(std::ostream& o) const override { - o << TypeName() << "("; - for (auto& a : args_.values) { - o << std::get(a); - if (&a != &args_.values.Back()) { + std::ostream& Print(std::ostream& o) const { + o << type_name << "("; + for (auto& a : args) { + std::visit([&](auto& v) { o << v; }, a); + if (&a != &args.Back()) { o << ", "; } } @@ -857,54 +769,16 @@ struct ConcreteValue : ValueBase { return o; } - private: - /// args to create expression with - ScalarArgs args_; -}; -} // namespace detail - -/// A Value represents a value of type DataType created with ScalarArgs. Useful for storing -/// values for unit tests. -class Value { - public: - /// Creates a Value for type T initialized with `args` - /// @param args the scalar args - /// @returns Value - template - static Value Create(ScalarArgs args) { - static_assert(IsDataTypeSpecializedFor, "No DataType specialization exists"); - return Value{std::make_shared>(std::move(args))}; - } - - /// Creates an `ast::Expression` for the type T passing in previously stored args - /// @param b the ProgramBuilder - /// @returns an expression node - const ast::Expression* Expr(ProgramBuilder& b) const { return value_->Expr(b); } - - /// @returns args used to create expression via `Expr` - const ScalarArgs& Args() const { return value_->Args(); } - - /// @returns true if element type is abstract - bool IsAbstract() const { return value_->IsAbstract(); } - - /// @returns true if element type is an integral - bool IsIntegral() const { return value_->IsIntegral(); } - - /// @returns element type name - std::string TypeName() const { return value_->TypeName(); } - - /// Prints this value to the output stream - /// @param o the output stream - /// @returns input argument `o` - std::ostream& Print(std::ostream& o) const { return value_->Print(o); } - - private: - /// Private constructor - explicit Value(std::shared_ptr value) : value_(std::move(value)) {} - - /// Shared pointer to an immutable value. This type-erasure pattern allows Value to wrap a - /// polymorphic type, while being used like a value-type (i.e. copyable). - std::shared_ptr value_; + /// The arguments used to construct the value + utils::Vector args; + /// Function used to construct an expression with the given value + builder::ast_expr_func_ptr create; + /// True if the element type is abstract + bool is_abstract = false; + /// True if the element type is an integer + bool is_integral = false; + /// The name of the type. + const char* type_name = ""; }; /// Prints Value to ostream @@ -919,7 +793,7 @@ constexpr bool IsValue = std::is_same_v; /// Creates a Value of DataType from a scalar `v` template Value Val(T v) { - return Value::Create(ScalarArgs{v}); + return Value::Create(utils::Vector{v}); } /// Creates a Value of DataType> from N scalar `args` @@ -927,41 +801,41 @@ template Value Vec(T... args) { using FirstT = std::tuple_element_t<0, std::tuple>; constexpr size_t N = sizeof...(args); - utils::Vector v{args...}; - return Value::Create>(utils::VectorRef{v}); + utils::Vector v{args...}; + return Value::Create>(std::move(v)); } /// Creates a Value of DataType from C*R scalar `args` template Value Mat(const T (&m_in)[C][R]) { - utils::Vector m; + utils::Vector m; for (uint32_t i = 0; i < C; ++i) { for (size_t j = 0; j < R; ++j) { m.Push(m_in[i][j]); } } - return Value::Create>(utils::VectorRef{m}); + return Value::Create>(std::move(m)); } /// Creates a Value of DataType from column vectors `c0` and `c1` template Value Mat(const T (&c0)[R], const T (&c1)[R]) { constexpr size_t C = 2; - utils::Vector m; + utils::Vector m; for (auto v : c0) { m.Push(v); } for (auto v : c1) { m.Push(v); } - return Value::Create>(utils::VectorRef{m}); + return Value::Create>(std::move(m)); } /// Creates a Value of DataType from column vectors `c0`, `c1`, and `c2` template Value Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R]) { constexpr size_t C = 3; - utils::Vector m; + utils::Vector m; for (auto v : c0) { m.Push(v); } @@ -971,14 +845,14 @@ Value Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R]) { for (auto v : c2) { m.Push(v); } - return Value::Create>(utils::VectorRef{m}); + return Value::Create>(std::move(m)); } /// Creates a Value of DataType from column vectors `c0`, `c1`, `c2`, and `c3` template Value Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R], const T (&c3)[R]) { constexpr size_t C = 4; - utils::Vector m; + utils::Vector m; for (auto v : c0) { m.Push(v); } @@ -991,7 +865,7 @@ Value Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R], const T (&c3)[R] for (auto v : c3) { m.Push(v); } - return Value::Create>(utils::VectorRef{m}); + return Value::Create>(std::move(m)); } } // namespace builder