tint/resolver: Further simplify test const eval framework

Replace ScalarArgs struct with Scalar variant and vector.
Fold ValueBase and ConcreteValue into Value.

Change-Id: I5cc5811a87f1aae162feb65fb6b1ecdac033d0fe
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111761
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Ben Clayton 2022-11-25 15:30:51 +00:00 committed by Dawn LUCI CQ
parent c572df265d
commit 95d174a118
4 changed files with 127 additions and 252 deletions

View File

@ -94,12 +94,12 @@ TEST_P(ResolverConstEvalConvTest, Test) {
ASSERT_NE(sem->ConstantValue(), nullptr);
EXPECT_TYPE(sem->ConstantValue()->Type(), target_sem_ty);
auto expected_values = expected.Args();
auto expected_values = expected.args;
if (kind == Kind::kVector) {
expected_values.values.Push(expected_values.values[0]);
expected_values.values.Push(expected_values.values[0]);
expected_values.Push(expected_values[0]);
expected_values.Push(expected_values[0]);
}
auto got_values = ScalarArgsFrom(sem->ConstantValue());
auto got_values = ScalarsFrom(sem->ConstantValue());
EXPECT_EQ(expected_values, got_values);
}
}

View File

@ -37,28 +37,29 @@ template <typename T>
inline const auto k3PiOver4 = T(UnwrapNumber<T>(2.356194490192344928846));
/// Walks the sem::Constant @p c, accumulating all the inner-most scalar values into @p args
inline void CollectScalarArgs(const sem::Constant* c, builder::ScalarArgs& args) {
template <size_t N>
inline void CollectScalars(const sem::Constant* c, utils::Vector<builder::Scalar, N>& scalars) {
Switch(
c->Type(), //
[&](const sem::AbstractInt*) { args.values.Push(c->As<AInt>()); },
[&](const sem::AbstractFloat*) { args.values.Push(c->As<AFloat>()); },
[&](const sem::Bool*) { args.values.Push(c->As<bool>()); },
[&](const sem::I32*) { args.values.Push(c->As<i32>()); },
[&](const sem::U32*) { args.values.Push(c->As<u32>()); },
[&](const sem::F32*) { args.values.Push(c->As<f32>()); },
[&](const sem::F16*) { args.values.Push(c->As<f16>()); },
[&](const sem::AbstractInt*) { scalars.Push(c->As<AInt>()); },
[&](const sem::AbstractFloat*) { scalars.Push(c->As<AFloat>()); },
[&](const sem::Bool*) { scalars.Push(c->As<bool>()); },
[&](const sem::I32*) { scalars.Push(c->As<i32>()); },
[&](const sem::U32*) { scalars.Push(c->As<u32>()); },
[&](const sem::F32*) { scalars.Push(c->As<f32>()); },
[&](const sem::F16*) { scalars.Push(c->As<f16>()); },
[&](Default) {
size_t i = 0;
while (auto* child = c->Index(i++)) {
CollectScalarArgs(child, args);
CollectScalars(child, scalars);
}
});
}
/// Walks the sem::Constant @p c, returning all the inner-most scalar values.
inline builder::ScalarArgs ScalarArgsFrom(const sem::Constant* c) {
builder::ScalarArgs out;
CollectScalarArgs(c, out);
inline utils::Vector<builder::Scalar, 16> ScalarsFrom(const sem::Constant* c) {
utils::Vector<builder::Scalar, 16> out;
CollectScalars(c, out);
return out;
}
@ -90,14 +91,14 @@ struct CheckConstantFlags {
inline void CheckConstant(const sem::Constant* got_constant,
const builder::Value& expected_value,
CheckConstantFlags flags = {}) {
auto values_flat = ScalarArgsFrom(got_constant);
auto expected_values_flat = expected_value.Args();
ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
for (size_t i = 0; i < values_flat.values.Length(); ++i) {
auto& got_scalar = values_flat.values[i];
auto& expected_scalar = expected_values_flat.values[i];
auto values_flat = ScalarsFrom(got_constant);
auto expected_values_flat = expected_value.args;
ASSERT_EQ(values_flat.Length(), expected_values_flat.Length());
for (size_t i = 0; i < values_flat.Length(); ++i) {
auto& got_scalar = values_flat[i];
auto& expected_scalar = expected_values_flat[i];
std::visit(
[&](auto&& expected) {
[&](const auto& expected) {
using T = std::decay_t<decltype(expected)>;
ASSERT_TRUE(std::holds_alternative<T>(got_scalar));

View File

@ -61,14 +61,14 @@ TEST_P(ResolverConstEvalUnaryOpTest, Test) {
ASSERT_NE(value, nullptr);
EXPECT_TYPE(value->Type(), sem->Type());
auto values_flat = ScalarArgsFrom(value);
auto expected_values_flat = expected.Args();
ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
for (size_t i = 0; i < values_flat.values.Length(); ++i) {
auto& a = values_flat.values[i];
auto& b = expected_values_flat.values[i];
auto values_flat = ScalarsFrom(value);
auto expected_values_flat = expected.args;
ASSERT_EQ(values_flat.Length(), expected_values_flat.Length());
for (size_t i = 0; i < values_flat.Length(); ++i) {
auto& a = values_flat[i];
auto& b = expected_values_flat[i];
EXPECT_EQ(a, b);
if (expected.IsIntegral()) {
if (expected.is_integral) {
// Check that the constant's integer doesn't contain unexpected
// data in the MSBs that are outside of the bit-width of T.
EXPECT_EQ(builder::As<AInt>(a), builder::As<AInt>(b));

View File

@ -180,63 +180,18 @@ using alias3 = alias<TO, 3>;
template <typename TO>
struct ptr {};
/// Type used to accept scalars as arguments. Can be either a single value that gets splatted for
/// composite types, or all values required by the composite type.
struct ScalarArgs {
/// Constructor
ScalarArgs() = default;
/// Constructor
/// @param single_value single value to initialize with
template <typename T>
explicit ScalarArgs(T single_value) : values(utils::Vector<Storage, 1>{single_value}) {}
/// Constructor
/// @param all_values all values to initialize the composite type with
template <typename T>
ScalarArgs(utils::VectorRef<T> all_values) // NOLINT: implicit on purpose
{
for (auto& v : all_values) {
values.Push(v);
}
}
/// @param other the other ScalarArgs to compare against
/// @returns true if all values are equal to the values in @p other
bool operator==(const ScalarArgs& other) const { return values == other.values; }
/// Valid scalar types for args
using Storage = std::variant<i32, u32, f32, f16, AInt, AFloat, bool>;
/// The vector of values
utils::Vector<Storage, 16> values;
};
/// A scalar value
using Scalar = std::variant<i32, u32, f32, f16, AInt, AFloat, bool>;
/// Returns current variant value in `s` cast to type `T`
template <typename T>
T As(ScalarArgs::Storage& s) {
T As(Scalar& s) {
return std::visit([](auto&& v) { return static_cast<T>(v); }, s);
}
/// @param o the std::ostream to write to
/// @param args the ScalarArgs
/// @return the std::ostream so calls can be chained
inline std::ostream& operator<<(std::ostream& o, const ScalarArgs& args) {
o << "[";
bool first = true;
for (auto& val : args.values) {
if (!first) {
o << ", ";
}
first = false;
std::visit([&](auto&& v) { o << v; }, val);
}
o << "]";
return o;
}
using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b);
using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, ScalarArgs args);
using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b,
utils::VectorRef<Scalar> args);
using ast_expr_from_double_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, double v);
using sem_type_func_ptr = const sem::Type* (*)(ProgramBuilder& b);
using type_name_func_ptr = std::string (*)();
@ -280,14 +235,14 @@ struct DataType<bool> {
/// @param b the ProgramBuilder
/// @param args args of size 1 with the boolean value to init with
/// @return a new AST expression of the bool type
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
return b.Expr(std::get<bool>(args.values[0]));
static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
return b.Expr(std::get<bool>(args[0]));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to bool.
/// @return a new AST expression of the bool type
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "bool"; }
@ -311,14 +266,14 @@ struct DataType<i32> {
/// @param b the ProgramBuilder
/// @param args args of size 1 with the i32 value to init with
/// @return a new AST i32 literal value expression
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
return b.Expr(std::get<i32>(args.values[0]));
static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
return b.Expr(std::get<i32>(args[0]));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to i32.
/// @return a new AST i32 literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "i32"; }
@ -342,14 +297,14 @@ struct DataType<u32> {
/// @param b the ProgramBuilder
/// @param args args of size 1 with the u32 value to init with
/// @return a new AST u32 literal value expression
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
return b.Expr(std::get<u32>(args.values[0]));
static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
return b.Expr(std::get<u32>(args[0]));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to u32.
/// @return a new AST u32 literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "u32"; }
@ -373,14 +328,14 @@ struct DataType<f32> {
/// @param b the ProgramBuilder
/// @param args args of size 1 with the f32 value to init with
/// @return a new AST f32 literal value expression
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
return b.Expr(std::get<f32>(args.values[0]));
static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
return b.Expr(std::get<f32>(args[0]));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to f32.
/// @return a new AST f32 literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, ScalarArgs{static_cast<f32>(v)});
return Expr(b, utils::Vector<Scalar, 1>{static_cast<f32>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "f32"; }
@ -404,14 +359,14 @@ struct DataType<f16> {
/// @param b the ProgramBuilder
/// @param args args of size 1 with the f16 value to init with
/// @return a new AST f16 literal value expression
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
return b.Expr(std::get<f16>(args.values[0]));
static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
return b.Expr(std::get<f16>(args[0]));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to f16.
/// @return a new AST f16 literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "f16"; }
@ -434,14 +389,14 @@ struct DataType<AFloat> {
/// @param b the ProgramBuilder
/// @param args args of size 1 with the abstract-float value to init with
/// @return a new AST abstract-float literal value expression
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
return b.Expr(std::get<AFloat>(args.values[0]));
static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
return b.Expr(std::get<AFloat>(args[0]));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to AFloat.
/// @return a new AST abstract-float literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "abstract-float"; }
@ -464,14 +419,14 @@ struct DataType<AInt> {
/// @param b the ProgramBuilder
/// @param args args of size 1 with the abstract-int value to init with
/// @return a new AST abstract-int literal value expression
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
return b.Expr(std::get<AInt>(args.values[0]));
static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
return b.Expr(std::get<AInt>(args[0]));
}
/// @param b the ProgramBuilder
/// @param v arg of type double that will be cast to AInt.
/// @return a new AST abstract-int literal value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() { return "abstract-int"; }
@ -499,17 +454,17 @@ struct DataType<vec<N, T>> {
/// @param b the ProgramBuilder
/// @param args args of size 1 or N with values of type T to initialize with
/// @return a new AST vector value expression
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
return b.Construct(AST(b), ExprArgs(b, std::move(args)));
}
/// @param b the ProgramBuilder
/// @param args args of size 1 or N with values of type T to initialize with
/// @return the list of expressions that are used to construct the vector
static inline auto ExprArgs(ProgramBuilder& b, ScalarArgs args) {
const bool one_value = args.values.Length() == 1;
static inline auto ExprArgs(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
const bool one_value = args.Length() == 1;
utils::Vector<const ast::Expression*, N> r;
for (size_t i = 0; i < N; ++i) {
r.Push(DataType<T>::Expr(b, ScalarArgs{one_value ? args.values[0] : args.values[i]}));
r.Push(DataType<T>::Expr(b, utils::Vector<Scalar, 1>{one_value ? args[0] : args[i]}));
}
return r;
}
@ -517,7 +472,7 @@ struct DataType<vec<N, T>> {
/// @param v arg of type double that will be cast to ElementType
/// @return a new AST vector value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() {
@ -548,25 +503,25 @@ struct DataType<mat<N, M, T>> {
/// @param b the ProgramBuilder
/// @param args args of size 1 or N*M with values of type T to initialize with
/// @return a new AST matrix value expression
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
return b.Construct(AST(b), ExprArgs(b, std::move(args)));
}
/// @param b the ProgramBuilder
/// @param args args of size 1 or N*M with values of type T to initialize with
/// @return a new AST matrix value expression
static inline auto ExprArgs(ProgramBuilder& b, ScalarArgs args) {
const bool one_value = args.values.Length() == 1;
static inline auto ExprArgs(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
const bool one_value = args.Length() == 1;
size_t next = 0;
utils::Vector<const ast::Expression*, N> r;
for (uint32_t i = 0; i < N; ++i) {
if (one_value) {
r.Push(DataType<vec<M, T>>::Expr(b, ScalarArgs{args.values[0]}));
r.Push(DataType<vec<M, T>>::Expr(b, utils::Vector<Scalar, 1>{args[0]}));
} else {
utils::Vector<T, M> v;
utils::Vector<Scalar, M> v;
for (size_t j = 0; j < M; ++j) {
v.Push(std::get<T>(args.values[next++]));
v.Push(args[next++]);
}
r.Push(DataType<vec<M, T>>::Expr(b, utils::VectorRef<T>{v}));
r.Push(DataType<vec<M, T>>::Expr(b, std::move(v)));
}
}
return r;
@ -575,7 +530,7 @@ struct DataType<mat<N, M, T>> {
/// @param v arg of type double that will be cast to ElementType
/// @return a new AST matrix value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() {
@ -611,8 +566,9 @@ struct DataType<alias<T, ID>> {
/// @param args the value nested elements will be initialized with
/// @return a new AST expression of the alias type
template <bool IS_COMPOSITE = is_composite>
static inline traits::EnableIf<!IS_COMPOSITE, const ast::Expression*> Expr(ProgramBuilder& b,
ScalarArgs args) {
static inline traits::EnableIf<!IS_COMPOSITE, const ast::Expression*> Expr(
ProgramBuilder& b,
utils::VectorRef<Scalar> args) {
// Cast
return b.Construct(AST(b), DataType<T>::Expr(b, std::move(args)));
}
@ -621,8 +577,9 @@ struct DataType<alias<T, ID>> {
/// @param args the value nested elements will be initialized with
/// @return a new AST expression of the alias type
template <bool IS_COMPOSITE = is_composite>
static inline traits::EnableIf<IS_COMPOSITE, const ast::Expression*> Expr(ProgramBuilder& b,
ScalarArgs args) {
static inline traits::EnableIf<IS_COMPOSITE, const ast::Expression*> Expr(
ProgramBuilder& b,
utils::VectorRef<Scalar> args) {
// Construct
return b.Construct(AST(b), DataType<T>::ExprArgs(b, std::move(args)));
}
@ -631,7 +588,7 @@ struct DataType<alias<T, ID>> {
/// @param v arg of type double that will be cast to ElementType
/// @return a new AST expression of the alias type
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
@ -662,7 +619,8 @@ struct DataType<ptr<T>> {
/// @param b the ProgramBuilder
/// @return a new AST expression of the pointer type
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs /*unused*/) {
static inline const ast::Expression* Expr(ProgramBuilder& b,
utils::VectorRef<Scalar> /*unused*/) {
auto sym = b.Symbols().New("global_for_ptr");
b.GlobalVar(sym, DataType<T>::AST(b), ast::AddressSpace::kPrivate);
return b.AddressOf(sym);
@ -672,7 +630,7 @@ struct DataType<ptr<T>> {
/// @param v arg of type double that will be cast to ElementType
/// @return a new AST expression of the pointer type
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
@ -716,17 +674,17 @@ struct DataType<array<N, T>> {
/// @param args args of size 1 or N with values of type T to initialize with
/// with
/// @return a new AST array value expression
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
return b.Construct(AST(b), ExprArgs(b, std::move(args)));
}
/// @param b the ProgramBuilder
/// @param args args of size 1 or N with values of type T to initialize with
/// @return the list of expressions that are used to construct the array
static inline auto ExprArgs(ProgramBuilder& b, ScalarArgs args) {
const bool one_value = args.values.Length() == 1;
static inline auto ExprArgs(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
const bool one_value = args.Length() == 1;
utils::Vector<const ast::Expression*, N> r;
for (uint32_t i = 0; i < N; i++) {
r.Push(DataType<T>::Expr(b, ScalarArgs{one_value ? args.values[0] : args.values[i]}));
r.Push(DataType<T>::Expr(b, utils::Vector<Scalar, 1>{one_value ? args[0] : args[i]}));
}
return r;
}
@ -734,7 +692,7 @@ struct DataType<array<N, T>> {
/// @param v arg of type double that will be cast to ElementType
/// @return a new AST array value expression
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
}
/// @returns the WGSL name for the type
static inline std::string Name() {
@ -776,80 +734,34 @@ template <typename T>
const bool IsDataTypeSpecializedFor =
!std::is_same_v<typename DataType<T>::ElementType, UnspecializedElementType>;
namespace detail {
/// ValueBase is a base class of ConcreteValue<T>
struct ValueBase {
/// Constructor
ValueBase() = default;
/// Destructor
virtual ~ValueBase() = default;
/// Move constructor
ValueBase(ValueBase&&) = default;
/// Copy constructor
ValueBase(const ValueBase&) = default;
/// Copy assignment operator
/// @returns this instance
ValueBase& operator=(const ValueBase&) = default;
/// Creates an `ast::Expression` for the type T passing in previously stored args
/// @param b the ProgramBuilder
/// @returns an expression node
virtual const ast::Expression* Expr(ProgramBuilder& b) const = 0;
/// @returns args used to create expression via `Expr`
virtual const ScalarArgs& Args() const = 0;
/// @returns true if element type is abstract
virtual bool IsAbstract() const = 0;
/// @returns true if element type is an integral
virtual bool IsIntegral() const = 0;
/// @returns element type name
virtual std::string TypeName() const = 0;
/// Prints this value to the output stream
/// @param o the output stream
/// @returns input argument `o`
virtual std::ostream& Print(std::ostream& o) const = 0;
};
/// ConcreteValue<T> is used to create Values of type DataType<T> with a ScalarArgs initializer.
template <typename T>
struct ConcreteValue : ValueBase {
/// Constructor
/// Value is used to create Values with a Scalar vector initializer.
struct Value {
/// Creates a Value for type T initialized with `args`
/// @param args the scalar args
explicit ConcreteValue(ScalarArgs args) : args_(std::move(args)) {}
/// Alias to T
using Type = T;
/// Alias to DataType<T>
using DataType = builder::DataType<T>;
/// Alias to DataType::ElementType
using ElementType = typename DataType::ElementType;
/// Creates an `ast::Expression` for the type T passing in previously stored args
/// @param b the ProgramBuilder
/// @returns an expression node
const ast::Expression* Expr(ProgramBuilder& b) const override {
auto create = CreatePtrsFor<T>();
return (*create.expr)(b, args_);
/// @returns Value
template <typename T>
static Value Create(utils::VectorRef<Scalar> args) {
static_assert(IsDataTypeSpecializedFor<T>, "No DataType<T> specialization exists");
using EL_TY = typename builder::DataType<T>::ElementType;
return Value{
std::move(args), CreatePtrsFor<T>().expr, tint::IsAbstract<EL_TY>,
tint::IsIntegral<EL_TY>, tint::FriendlyName<EL_TY>(),
};
}
/// @returns args used to create expression via `Expr`
const ScalarArgs& Args() const override { return args_; }
/// @returns true if element type is abstract
bool IsAbstract() const override { return tint::IsAbstract<ElementType>; }
/// @returns true if element type is an integral
bool IsIntegral() const override { return tint::IsIntegral<ElementType>; }
/// @returns element type name
std::string TypeName() const override { return tint::FriendlyName<ElementType>(); }
/// Creates an `ast::Expression` for the type T passing in previously stored args
/// @param b the ProgramBuilder
/// @returns an expression node
const ast::Expression* Expr(ProgramBuilder& b) const { return (*create)(b, args); }
/// Prints this value to the output stream
/// @param o the output stream
/// @returns input argument `o`
std::ostream& Print(std::ostream& o) const override {
o << TypeName() << "(";
for (auto& a : args_.values) {
o << std::get<ElementType>(a);
if (&a != &args_.values.Back()) {
std::ostream& Print(std::ostream& o) const {
o << type_name << "(";
for (auto& a : args) {
std::visit([&](auto& v) { o << v; }, a);
if (&a != &args.Back()) {
o << ", ";
}
}
@ -857,54 +769,16 @@ struct ConcreteValue : ValueBase {
return o;
}
private:
/// args to create expression with
ScalarArgs args_;
};
} // namespace detail
/// A Value represents a value of type DataType<T> created with ScalarArgs. Useful for storing
/// values for unit tests.
class Value {
public:
/// Creates a Value for type T initialized with `args`
/// @param args the scalar args
/// @returns Value
template <typename T>
static Value Create(ScalarArgs args) {
static_assert(IsDataTypeSpecializedFor<T>, "No DataType<T> specialization exists");
return Value{std::make_shared<detail::ConcreteValue<T>>(std::move(args))};
}
/// Creates an `ast::Expression` for the type T passing in previously stored args
/// @param b the ProgramBuilder
/// @returns an expression node
const ast::Expression* Expr(ProgramBuilder& b) const { return value_->Expr(b); }
/// @returns args used to create expression via `Expr`
const ScalarArgs& Args() const { return value_->Args(); }
/// @returns true if element type is abstract
bool IsAbstract() const { return value_->IsAbstract(); }
/// @returns true if element type is an integral
bool IsIntegral() const { return value_->IsIntegral(); }
/// @returns element type name
std::string TypeName() const { return value_->TypeName(); }
/// Prints this value to the output stream
/// @param o the output stream
/// @returns input argument `o`
std::ostream& Print(std::ostream& o) const { return value_->Print(o); }
private:
/// Private constructor
explicit Value(std::shared_ptr<const detail::ValueBase> value) : value_(std::move(value)) {}
/// Shared pointer to an immutable value. This type-erasure pattern allows Value to wrap a
/// polymorphic type, while being used like a value-type (i.e. copyable).
std::shared_ptr<const detail::ValueBase> value_;
/// The arguments used to construct the value
utils::Vector<Scalar, 4> args;
/// Function used to construct an expression with the given value
builder::ast_expr_func_ptr create;
/// True if the element type is abstract
bool is_abstract = false;
/// True if the element type is an integer
bool is_integral = false;
/// The name of the type.
const char* type_name = "<invalid>";
};
/// Prints Value to ostream
@ -919,7 +793,7 @@ constexpr bool IsValue = std::is_same_v<T, Value>;
/// Creates a Value of DataType<T> from a scalar `v`
template <typename T>
Value Val(T v) {
return Value::Create<T>(ScalarArgs{v});
return Value::Create<T>(utils::Vector<Scalar, 1>{v});
}
/// Creates a Value of DataType<vec<N, T>> from N scalar `args`
@ -927,41 +801,41 @@ template <typename... T>
Value Vec(T... args) {
using FirstT = std::tuple_element_t<0, std::tuple<T...>>;
constexpr size_t N = sizeof...(args);
utils::Vector v{args...};
return Value::Create<vec<N, FirstT>>(utils::VectorRef<FirstT>{v});
utils::Vector<Scalar, sizeof...(args)> v{args...};
return Value::Create<vec<N, FirstT>>(std::move(v));
}
/// Creates a Value of DataType<mat<C,R,T> from C*R scalar `args`
template <size_t C, size_t R, typename T>
Value Mat(const T (&m_in)[C][R]) {
utils::Vector<T, C * R> m;
utils::Vector<Scalar, C * R> m;
for (uint32_t i = 0; i < C; ++i) {
for (size_t j = 0; j < R; ++j) {
m.Push(m_in[i][j]);
}
}
return Value::Create<mat<C, R, T>>(utils::VectorRef<T>{m});
return Value::Create<mat<C, R, T>>(std::move(m));
}
/// Creates a Value of DataType<mat<2,R,T> from column vectors `c0` and `c1`
template <typename T, size_t R>
Value Mat(const T (&c0)[R], const T (&c1)[R]) {
constexpr size_t C = 2;
utils::Vector<T, C * R> m;
utils::Vector<Scalar, C * R> m;
for (auto v : c0) {
m.Push(v);
}
for (auto v : c1) {
m.Push(v);
}
return Value::Create<mat<C, R, T>>(utils::VectorRef<T>{m});
return Value::Create<mat<C, R, T>>(std::move(m));
}
/// Creates a Value of DataType<mat<3,R,T> from column vectors `c0`, `c1`, and `c2`
template <typename T, size_t R>
Value Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R]) {
constexpr size_t C = 3;
utils::Vector<T, C * R> m;
utils::Vector<Scalar, C * R> m;
for (auto v : c0) {
m.Push(v);
}
@ -971,14 +845,14 @@ Value Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R]) {
for (auto v : c2) {
m.Push(v);
}
return Value::Create<mat<C, R, T>>(utils::VectorRef<T>{m});
return Value::Create<mat<C, R, T>>(std::move(m));
}
/// Creates a Value of DataType<mat<4,R,T> from column vectors `c0`, `c1`, `c2`, and `c3`
template <typename T, size_t R>
Value Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R], const T (&c3)[R]) {
constexpr size_t C = 4;
utils::Vector<T, C * R> m;
utils::Vector<Scalar, C * R> m;
for (auto v : c0) {
m.Push(v);
}
@ -991,7 +865,7 @@ Value Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R], const T (&c3)[R]
for (auto v : c3) {
m.Push(v);
}
return Value::Create<mat<C, R, T>>(utils::VectorRef<T>{m});
return Value::Create<mat<C, R, T>>(std::move(m));
}
} // namespace builder