tint/resolver: Further simplify test const eval framework
Replace ScalarArgs struct with Scalar variant and vector. Fold ValueBase and ConcreteValue into Value. Change-Id: I5cc5811a87f1aae162feb65fb6b1ecdac033d0fe Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111761 Reviewed-by: Antonio Maiorano <amaiorano@google.com> Commit-Queue: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
parent
c572df265d
commit
95d174a118
|
@ -94,12 +94,12 @@ TEST_P(ResolverConstEvalConvTest, Test) {
|
|||
ASSERT_NE(sem->ConstantValue(), nullptr);
|
||||
EXPECT_TYPE(sem->ConstantValue()->Type(), target_sem_ty);
|
||||
|
||||
auto expected_values = expected.Args();
|
||||
auto expected_values = expected.args;
|
||||
if (kind == Kind::kVector) {
|
||||
expected_values.values.Push(expected_values.values[0]);
|
||||
expected_values.values.Push(expected_values.values[0]);
|
||||
expected_values.Push(expected_values[0]);
|
||||
expected_values.Push(expected_values[0]);
|
||||
}
|
||||
auto got_values = ScalarArgsFrom(sem->ConstantValue());
|
||||
auto got_values = ScalarsFrom(sem->ConstantValue());
|
||||
EXPECT_EQ(expected_values, got_values);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -37,28 +37,29 @@ template <typename T>
|
|||
inline const auto k3PiOver4 = T(UnwrapNumber<T>(2.356194490192344928846));
|
||||
|
||||
/// Walks the sem::Constant @p c, accumulating all the inner-most scalar values into @p args
|
||||
inline void CollectScalarArgs(const sem::Constant* c, builder::ScalarArgs& args) {
|
||||
template <size_t N>
|
||||
inline void CollectScalars(const sem::Constant* c, utils::Vector<builder::Scalar, N>& scalars) {
|
||||
Switch(
|
||||
c->Type(), //
|
||||
[&](const sem::AbstractInt*) { args.values.Push(c->As<AInt>()); },
|
||||
[&](const sem::AbstractFloat*) { args.values.Push(c->As<AFloat>()); },
|
||||
[&](const sem::Bool*) { args.values.Push(c->As<bool>()); },
|
||||
[&](const sem::I32*) { args.values.Push(c->As<i32>()); },
|
||||
[&](const sem::U32*) { args.values.Push(c->As<u32>()); },
|
||||
[&](const sem::F32*) { args.values.Push(c->As<f32>()); },
|
||||
[&](const sem::F16*) { args.values.Push(c->As<f16>()); },
|
||||
[&](const sem::AbstractInt*) { scalars.Push(c->As<AInt>()); },
|
||||
[&](const sem::AbstractFloat*) { scalars.Push(c->As<AFloat>()); },
|
||||
[&](const sem::Bool*) { scalars.Push(c->As<bool>()); },
|
||||
[&](const sem::I32*) { scalars.Push(c->As<i32>()); },
|
||||
[&](const sem::U32*) { scalars.Push(c->As<u32>()); },
|
||||
[&](const sem::F32*) { scalars.Push(c->As<f32>()); },
|
||||
[&](const sem::F16*) { scalars.Push(c->As<f16>()); },
|
||||
[&](Default) {
|
||||
size_t i = 0;
|
||||
while (auto* child = c->Index(i++)) {
|
||||
CollectScalarArgs(child, args);
|
||||
CollectScalars(child, scalars);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Walks the sem::Constant @p c, returning all the inner-most scalar values.
|
||||
inline builder::ScalarArgs ScalarArgsFrom(const sem::Constant* c) {
|
||||
builder::ScalarArgs out;
|
||||
CollectScalarArgs(c, out);
|
||||
inline utils::Vector<builder::Scalar, 16> ScalarsFrom(const sem::Constant* c) {
|
||||
utils::Vector<builder::Scalar, 16> out;
|
||||
CollectScalars(c, out);
|
||||
return out;
|
||||
}
|
||||
|
||||
|
@ -90,14 +91,14 @@ struct CheckConstantFlags {
|
|||
inline void CheckConstant(const sem::Constant* got_constant,
|
||||
const builder::Value& expected_value,
|
||||
CheckConstantFlags flags = {}) {
|
||||
auto values_flat = ScalarArgsFrom(got_constant);
|
||||
auto expected_values_flat = expected_value.Args();
|
||||
ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
|
||||
for (size_t i = 0; i < values_flat.values.Length(); ++i) {
|
||||
auto& got_scalar = values_flat.values[i];
|
||||
auto& expected_scalar = expected_values_flat.values[i];
|
||||
auto values_flat = ScalarsFrom(got_constant);
|
||||
auto expected_values_flat = expected_value.args;
|
||||
ASSERT_EQ(values_flat.Length(), expected_values_flat.Length());
|
||||
for (size_t i = 0; i < values_flat.Length(); ++i) {
|
||||
auto& got_scalar = values_flat[i];
|
||||
auto& expected_scalar = expected_values_flat[i];
|
||||
std::visit(
|
||||
[&](auto&& expected) {
|
||||
[&](const auto& expected) {
|
||||
using T = std::decay_t<decltype(expected)>;
|
||||
|
||||
ASSERT_TRUE(std::holds_alternative<T>(got_scalar));
|
||||
|
|
|
@ -61,14 +61,14 @@ TEST_P(ResolverConstEvalUnaryOpTest, Test) {
|
|||
ASSERT_NE(value, nullptr);
|
||||
EXPECT_TYPE(value->Type(), sem->Type());
|
||||
|
||||
auto values_flat = ScalarArgsFrom(value);
|
||||
auto expected_values_flat = expected.Args();
|
||||
ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
|
||||
for (size_t i = 0; i < values_flat.values.Length(); ++i) {
|
||||
auto& a = values_flat.values[i];
|
||||
auto& b = expected_values_flat.values[i];
|
||||
auto values_flat = ScalarsFrom(value);
|
||||
auto expected_values_flat = expected.args;
|
||||
ASSERT_EQ(values_flat.Length(), expected_values_flat.Length());
|
||||
for (size_t i = 0; i < values_flat.Length(); ++i) {
|
||||
auto& a = values_flat[i];
|
||||
auto& b = expected_values_flat[i];
|
||||
EXPECT_EQ(a, b);
|
||||
if (expected.IsIntegral()) {
|
||||
if (expected.is_integral) {
|
||||
// Check that the constant's integer doesn't contain unexpected
|
||||
// data in the MSBs that are outside of the bit-width of T.
|
||||
EXPECT_EQ(builder::As<AInt>(a), builder::As<AInt>(b));
|
||||
|
|
|
@ -180,63 +180,18 @@ using alias3 = alias<TO, 3>;
|
|||
template <typename TO>
|
||||
struct ptr {};
|
||||
|
||||
/// Type used to accept scalars as arguments. Can be either a single value that gets splatted for
|
||||
/// composite types, or all values required by the composite type.
|
||||
struct ScalarArgs {
|
||||
/// Constructor
|
||||
ScalarArgs() = default;
|
||||
|
||||
/// Constructor
|
||||
/// @param single_value single value to initialize with
|
||||
template <typename T>
|
||||
explicit ScalarArgs(T single_value) : values(utils::Vector<Storage, 1>{single_value}) {}
|
||||
|
||||
/// Constructor
|
||||
/// @param all_values all values to initialize the composite type with
|
||||
template <typename T>
|
||||
ScalarArgs(utils::VectorRef<T> all_values) // NOLINT: implicit on purpose
|
||||
{
|
||||
for (auto& v : all_values) {
|
||||
values.Push(v);
|
||||
}
|
||||
}
|
||||
|
||||
/// @param other the other ScalarArgs to compare against
|
||||
/// @returns true if all values are equal to the values in @p other
|
||||
bool operator==(const ScalarArgs& other) const { return values == other.values; }
|
||||
|
||||
/// Valid scalar types for args
|
||||
using Storage = std::variant<i32, u32, f32, f16, AInt, AFloat, bool>;
|
||||
|
||||
/// The vector of values
|
||||
utils::Vector<Storage, 16> values;
|
||||
};
|
||||
/// A scalar value
|
||||
using Scalar = std::variant<i32, u32, f32, f16, AInt, AFloat, bool>;
|
||||
|
||||
/// Returns current variant value in `s` cast to type `T`
|
||||
template <typename T>
|
||||
T As(ScalarArgs::Storage& s) {
|
||||
T As(Scalar& s) {
|
||||
return std::visit([](auto&& v) { return static_cast<T>(v); }, s);
|
||||
}
|
||||
|
||||
/// @param o the std::ostream to write to
|
||||
/// @param args the ScalarArgs
|
||||
/// @return the std::ostream so calls can be chained
|
||||
inline std::ostream& operator<<(std::ostream& o, const ScalarArgs& args) {
|
||||
o << "[";
|
||||
bool first = true;
|
||||
for (auto& val : args.values) {
|
||||
if (!first) {
|
||||
o << ", ";
|
||||
}
|
||||
first = false;
|
||||
std::visit([&](auto&& v) { o << v; }, val);
|
||||
}
|
||||
o << "]";
|
||||
return o;
|
||||
}
|
||||
|
||||
using ast_type_func_ptr = const ast::Type* (*)(ProgramBuilder& b);
|
||||
using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, ScalarArgs args);
|
||||
using ast_expr_func_ptr = const ast::Expression* (*)(ProgramBuilder& b,
|
||||
utils::VectorRef<Scalar> args);
|
||||
using ast_expr_from_double_func_ptr = const ast::Expression* (*)(ProgramBuilder& b, double v);
|
||||
using sem_type_func_ptr = const sem::Type* (*)(ProgramBuilder& b);
|
||||
using type_name_func_ptr = std::string (*)();
|
||||
|
@ -280,14 +235,14 @@ struct DataType<bool> {
|
|||
/// @param b the ProgramBuilder
|
||||
/// @param args args of size 1 with the boolean value to init with
|
||||
/// @return a new AST expression of the bool type
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
|
||||
return b.Expr(std::get<bool>(args.values[0]));
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
|
||||
return b.Expr(std::get<bool>(args[0]));
|
||||
}
|
||||
/// @param b the ProgramBuilder
|
||||
/// @param v arg of type double that will be cast to bool.
|
||||
/// @return a new AST expression of the bool type
|
||||
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
|
||||
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
|
||||
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
|
||||
}
|
||||
/// @returns the WGSL name for the type
|
||||
static inline std::string Name() { return "bool"; }
|
||||
|
@ -311,14 +266,14 @@ struct DataType<i32> {
|
|||
/// @param b the ProgramBuilder
|
||||
/// @param args args of size 1 with the i32 value to init with
|
||||
/// @return a new AST i32 literal value expression
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
|
||||
return b.Expr(std::get<i32>(args.values[0]));
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
|
||||
return b.Expr(std::get<i32>(args[0]));
|
||||
}
|
||||
/// @param b the ProgramBuilder
|
||||
/// @param v arg of type double that will be cast to i32.
|
||||
/// @return a new AST i32 literal value expression
|
||||
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
|
||||
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
|
||||
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
|
||||
}
|
||||
/// @returns the WGSL name for the type
|
||||
static inline std::string Name() { return "i32"; }
|
||||
|
@ -342,14 +297,14 @@ struct DataType<u32> {
|
|||
/// @param b the ProgramBuilder
|
||||
/// @param args args of size 1 with the u32 value to init with
|
||||
/// @return a new AST u32 literal value expression
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
|
||||
return b.Expr(std::get<u32>(args.values[0]));
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
|
||||
return b.Expr(std::get<u32>(args[0]));
|
||||
}
|
||||
/// @param b the ProgramBuilder
|
||||
/// @param v arg of type double that will be cast to u32.
|
||||
/// @return a new AST u32 literal value expression
|
||||
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
|
||||
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
|
||||
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
|
||||
}
|
||||
/// @returns the WGSL name for the type
|
||||
static inline std::string Name() { return "u32"; }
|
||||
|
@ -373,14 +328,14 @@ struct DataType<f32> {
|
|||
/// @param b the ProgramBuilder
|
||||
/// @param args args of size 1 with the f32 value to init with
|
||||
/// @return a new AST f32 literal value expression
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
|
||||
return b.Expr(std::get<f32>(args.values[0]));
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
|
||||
return b.Expr(std::get<f32>(args[0]));
|
||||
}
|
||||
/// @param b the ProgramBuilder
|
||||
/// @param v arg of type double that will be cast to f32.
|
||||
/// @return a new AST f32 literal value expression
|
||||
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
|
||||
return Expr(b, ScalarArgs{static_cast<f32>(v)});
|
||||
return Expr(b, utils::Vector<Scalar, 1>{static_cast<f32>(v)});
|
||||
}
|
||||
/// @returns the WGSL name for the type
|
||||
static inline std::string Name() { return "f32"; }
|
||||
|
@ -404,14 +359,14 @@ struct DataType<f16> {
|
|||
/// @param b the ProgramBuilder
|
||||
/// @param args args of size 1 with the f16 value to init with
|
||||
/// @return a new AST f16 literal value expression
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
|
||||
return b.Expr(std::get<f16>(args.values[0]));
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
|
||||
return b.Expr(std::get<f16>(args[0]));
|
||||
}
|
||||
/// @param b the ProgramBuilder
|
||||
/// @param v arg of type double that will be cast to f16.
|
||||
/// @return a new AST f16 literal value expression
|
||||
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
|
||||
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
|
||||
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
|
||||
}
|
||||
/// @returns the WGSL name for the type
|
||||
static inline std::string Name() { return "f16"; }
|
||||
|
@ -434,14 +389,14 @@ struct DataType<AFloat> {
|
|||
/// @param b the ProgramBuilder
|
||||
/// @param args args of size 1 with the abstract-float value to init with
|
||||
/// @return a new AST abstract-float literal value expression
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
|
||||
return b.Expr(std::get<AFloat>(args.values[0]));
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
|
||||
return b.Expr(std::get<AFloat>(args[0]));
|
||||
}
|
||||
/// @param b the ProgramBuilder
|
||||
/// @param v arg of type double that will be cast to AFloat.
|
||||
/// @return a new AST abstract-float literal value expression
|
||||
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
|
||||
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
|
||||
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
|
||||
}
|
||||
/// @returns the WGSL name for the type
|
||||
static inline std::string Name() { return "abstract-float"; }
|
||||
|
@ -464,14 +419,14 @@ struct DataType<AInt> {
|
|||
/// @param b the ProgramBuilder
|
||||
/// @param args args of size 1 with the abstract-int value to init with
|
||||
/// @return a new AST abstract-int literal value expression
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
|
||||
return b.Expr(std::get<AInt>(args.values[0]));
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
|
||||
return b.Expr(std::get<AInt>(args[0]));
|
||||
}
|
||||
/// @param b the ProgramBuilder
|
||||
/// @param v arg of type double that will be cast to AInt.
|
||||
/// @return a new AST abstract-int literal value expression
|
||||
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
|
||||
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
|
||||
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
|
||||
}
|
||||
/// @returns the WGSL name for the type
|
||||
static inline std::string Name() { return "abstract-int"; }
|
||||
|
@ -499,17 +454,17 @@ struct DataType<vec<N, T>> {
|
|||
/// @param b the ProgramBuilder
|
||||
/// @param args args of size 1 or N with values of type T to initialize with
|
||||
/// @return a new AST vector value expression
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
|
||||
return b.Construct(AST(b), ExprArgs(b, std::move(args)));
|
||||
}
|
||||
/// @param b the ProgramBuilder
|
||||
/// @param args args of size 1 or N with values of type T to initialize with
|
||||
/// @return the list of expressions that are used to construct the vector
|
||||
static inline auto ExprArgs(ProgramBuilder& b, ScalarArgs args) {
|
||||
const bool one_value = args.values.Length() == 1;
|
||||
static inline auto ExprArgs(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
|
||||
const bool one_value = args.Length() == 1;
|
||||
utils::Vector<const ast::Expression*, N> r;
|
||||
for (size_t i = 0; i < N; ++i) {
|
||||
r.Push(DataType<T>::Expr(b, ScalarArgs{one_value ? args.values[0] : args.values[i]}));
|
||||
r.Push(DataType<T>::Expr(b, utils::Vector<Scalar, 1>{one_value ? args[0] : args[i]}));
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
@ -517,7 +472,7 @@ struct DataType<vec<N, T>> {
|
|||
/// @param v arg of type double that will be cast to ElementType
|
||||
/// @return a new AST vector value expression
|
||||
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
|
||||
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
|
||||
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
|
||||
}
|
||||
/// @returns the WGSL name for the type
|
||||
static inline std::string Name() {
|
||||
|
@ -548,25 +503,25 @@ struct DataType<mat<N, M, T>> {
|
|||
/// @param b the ProgramBuilder
|
||||
/// @param args args of size 1 or N*M with values of type T to initialize with
|
||||
/// @return a new AST matrix value expression
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
|
||||
return b.Construct(AST(b), ExprArgs(b, std::move(args)));
|
||||
}
|
||||
/// @param b the ProgramBuilder
|
||||
/// @param args args of size 1 or N*M with values of type T to initialize with
|
||||
/// @return a new AST matrix value expression
|
||||
static inline auto ExprArgs(ProgramBuilder& b, ScalarArgs args) {
|
||||
const bool one_value = args.values.Length() == 1;
|
||||
static inline auto ExprArgs(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
|
||||
const bool one_value = args.Length() == 1;
|
||||
size_t next = 0;
|
||||
utils::Vector<const ast::Expression*, N> r;
|
||||
for (uint32_t i = 0; i < N; ++i) {
|
||||
if (one_value) {
|
||||
r.Push(DataType<vec<M, T>>::Expr(b, ScalarArgs{args.values[0]}));
|
||||
r.Push(DataType<vec<M, T>>::Expr(b, utils::Vector<Scalar, 1>{args[0]}));
|
||||
} else {
|
||||
utils::Vector<T, M> v;
|
||||
utils::Vector<Scalar, M> v;
|
||||
for (size_t j = 0; j < M; ++j) {
|
||||
v.Push(std::get<T>(args.values[next++]));
|
||||
v.Push(args[next++]);
|
||||
}
|
||||
r.Push(DataType<vec<M, T>>::Expr(b, utils::VectorRef<T>{v}));
|
||||
r.Push(DataType<vec<M, T>>::Expr(b, std::move(v)));
|
||||
}
|
||||
}
|
||||
return r;
|
||||
|
@ -575,7 +530,7 @@ struct DataType<mat<N, M, T>> {
|
|||
/// @param v arg of type double that will be cast to ElementType
|
||||
/// @return a new AST matrix value expression
|
||||
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
|
||||
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
|
||||
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
|
||||
}
|
||||
/// @returns the WGSL name for the type
|
||||
static inline std::string Name() {
|
||||
|
@ -611,8 +566,9 @@ struct DataType<alias<T, ID>> {
|
|||
/// @param args the value nested elements will be initialized with
|
||||
/// @return a new AST expression of the alias type
|
||||
template <bool IS_COMPOSITE = is_composite>
|
||||
static inline traits::EnableIf<!IS_COMPOSITE, const ast::Expression*> Expr(ProgramBuilder& b,
|
||||
ScalarArgs args) {
|
||||
static inline traits::EnableIf<!IS_COMPOSITE, const ast::Expression*> Expr(
|
||||
ProgramBuilder& b,
|
||||
utils::VectorRef<Scalar> args) {
|
||||
// Cast
|
||||
return b.Construct(AST(b), DataType<T>::Expr(b, std::move(args)));
|
||||
}
|
||||
|
@ -621,8 +577,9 @@ struct DataType<alias<T, ID>> {
|
|||
/// @param args the value nested elements will be initialized with
|
||||
/// @return a new AST expression of the alias type
|
||||
template <bool IS_COMPOSITE = is_composite>
|
||||
static inline traits::EnableIf<IS_COMPOSITE, const ast::Expression*> Expr(ProgramBuilder& b,
|
||||
ScalarArgs args) {
|
||||
static inline traits::EnableIf<IS_COMPOSITE, const ast::Expression*> Expr(
|
||||
ProgramBuilder& b,
|
||||
utils::VectorRef<Scalar> args) {
|
||||
// Construct
|
||||
return b.Construct(AST(b), DataType<T>::ExprArgs(b, std::move(args)));
|
||||
}
|
||||
|
@ -631,7 +588,7 @@ struct DataType<alias<T, ID>> {
|
|||
/// @param v arg of type double that will be cast to ElementType
|
||||
/// @return a new AST expression of the alias type
|
||||
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
|
||||
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
|
||||
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
|
||||
}
|
||||
|
||||
/// @returns the WGSL name for the type
|
||||
|
@ -662,7 +619,8 @@ struct DataType<ptr<T>> {
|
|||
|
||||
/// @param b the ProgramBuilder
|
||||
/// @return a new AST expression of the pointer type
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs /*unused*/) {
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b,
|
||||
utils::VectorRef<Scalar> /*unused*/) {
|
||||
auto sym = b.Symbols().New("global_for_ptr");
|
||||
b.GlobalVar(sym, DataType<T>::AST(b), ast::AddressSpace::kPrivate);
|
||||
return b.AddressOf(sym);
|
||||
|
@ -672,7 +630,7 @@ struct DataType<ptr<T>> {
|
|||
/// @param v arg of type double that will be cast to ElementType
|
||||
/// @return a new AST expression of the pointer type
|
||||
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
|
||||
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
|
||||
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
|
||||
}
|
||||
|
||||
/// @returns the WGSL name for the type
|
||||
|
@ -716,17 +674,17 @@ struct DataType<array<N, T>> {
|
|||
/// @param args args of size 1 or N with values of type T to initialize with
|
||||
/// with
|
||||
/// @return a new AST array value expression
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b, ScalarArgs args) {
|
||||
static inline const ast::Expression* Expr(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
|
||||
return b.Construct(AST(b), ExprArgs(b, std::move(args)));
|
||||
}
|
||||
/// @param b the ProgramBuilder
|
||||
/// @param args args of size 1 or N with values of type T to initialize with
|
||||
/// @return the list of expressions that are used to construct the array
|
||||
static inline auto ExprArgs(ProgramBuilder& b, ScalarArgs args) {
|
||||
const bool one_value = args.values.Length() == 1;
|
||||
static inline auto ExprArgs(ProgramBuilder& b, utils::VectorRef<Scalar> args) {
|
||||
const bool one_value = args.Length() == 1;
|
||||
utils::Vector<const ast::Expression*, N> r;
|
||||
for (uint32_t i = 0; i < N; i++) {
|
||||
r.Push(DataType<T>::Expr(b, ScalarArgs{one_value ? args.values[0] : args.values[i]}));
|
||||
r.Push(DataType<T>::Expr(b, utils::Vector<Scalar, 1>{one_value ? args[0] : args[i]}));
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
@ -734,7 +692,7 @@ struct DataType<array<N, T>> {
|
|||
/// @param v arg of type double that will be cast to ElementType
|
||||
/// @return a new AST array value expression
|
||||
static inline const ast::Expression* ExprFromDouble(ProgramBuilder& b, double v) {
|
||||
return Expr(b, ScalarArgs{static_cast<ElementType>(v)});
|
||||
return Expr(b, utils::Vector<Scalar, 1>{static_cast<ElementType>(v)});
|
||||
}
|
||||
/// @returns the WGSL name for the type
|
||||
static inline std::string Name() {
|
||||
|
@ -776,80 +734,34 @@ template <typename T>
|
|||
const bool IsDataTypeSpecializedFor =
|
||||
!std::is_same_v<typename DataType<T>::ElementType, UnspecializedElementType>;
|
||||
|
||||
namespace detail {
|
||||
/// ValueBase is a base class of ConcreteValue<T>
|
||||
struct ValueBase {
|
||||
/// Constructor
|
||||
ValueBase() = default;
|
||||
/// Destructor
|
||||
virtual ~ValueBase() = default;
|
||||
/// Move constructor
|
||||
ValueBase(ValueBase&&) = default;
|
||||
/// Copy constructor
|
||||
ValueBase(const ValueBase&) = default;
|
||||
/// Copy assignment operator
|
||||
/// @returns this instance
|
||||
ValueBase& operator=(const ValueBase&) = default;
|
||||
/// Creates an `ast::Expression` for the type T passing in previously stored args
|
||||
/// @param b the ProgramBuilder
|
||||
/// @returns an expression node
|
||||
virtual const ast::Expression* Expr(ProgramBuilder& b) const = 0;
|
||||
/// @returns args used to create expression via `Expr`
|
||||
virtual const ScalarArgs& Args() const = 0;
|
||||
/// @returns true if element type is abstract
|
||||
virtual bool IsAbstract() const = 0;
|
||||
/// @returns true if element type is an integral
|
||||
virtual bool IsIntegral() const = 0;
|
||||
/// @returns element type name
|
||||
virtual std::string TypeName() const = 0;
|
||||
/// Prints this value to the output stream
|
||||
/// @param o the output stream
|
||||
/// @returns input argument `o`
|
||||
virtual std::ostream& Print(std::ostream& o) const = 0;
|
||||
};
|
||||
|
||||
/// ConcreteValue<T> is used to create Values of type DataType<T> with a ScalarArgs initializer.
|
||||
template <typename T>
|
||||
struct ConcreteValue : ValueBase {
|
||||
/// Constructor
|
||||
/// Value is used to create Values with a Scalar vector initializer.
|
||||
struct Value {
|
||||
/// Creates a Value for type T initialized with `args`
|
||||
/// @param args the scalar args
|
||||
explicit ConcreteValue(ScalarArgs args) : args_(std::move(args)) {}
|
||||
|
||||
/// Alias to T
|
||||
using Type = T;
|
||||
/// Alias to DataType<T>
|
||||
using DataType = builder::DataType<T>;
|
||||
/// Alias to DataType::ElementType
|
||||
using ElementType = typename DataType::ElementType;
|
||||
|
||||
/// Creates an `ast::Expression` for the type T passing in previously stored args
|
||||
/// @param b the ProgramBuilder
|
||||
/// @returns an expression node
|
||||
const ast::Expression* Expr(ProgramBuilder& b) const override {
|
||||
auto create = CreatePtrsFor<T>();
|
||||
return (*create.expr)(b, args_);
|
||||
/// @returns Value
|
||||
template <typename T>
|
||||
static Value Create(utils::VectorRef<Scalar> args) {
|
||||
static_assert(IsDataTypeSpecializedFor<T>, "No DataType<T> specialization exists");
|
||||
using EL_TY = typename builder::DataType<T>::ElementType;
|
||||
return Value{
|
||||
std::move(args), CreatePtrsFor<T>().expr, tint::IsAbstract<EL_TY>,
|
||||
tint::IsIntegral<EL_TY>, tint::FriendlyName<EL_TY>(),
|
||||
};
|
||||
}
|
||||
|
||||
/// @returns args used to create expression via `Expr`
|
||||
const ScalarArgs& Args() const override { return args_; }
|
||||
|
||||
/// @returns true if element type is abstract
|
||||
bool IsAbstract() const override { return tint::IsAbstract<ElementType>; }
|
||||
|
||||
/// @returns true if element type is an integral
|
||||
bool IsIntegral() const override { return tint::IsIntegral<ElementType>; }
|
||||
|
||||
/// @returns element type name
|
||||
std::string TypeName() const override { return tint::FriendlyName<ElementType>(); }
|
||||
/// Creates an `ast::Expression` for the type T passing in previously stored args
|
||||
/// @param b the ProgramBuilder
|
||||
/// @returns an expression node
|
||||
const ast::Expression* Expr(ProgramBuilder& b) const { return (*create)(b, args); }
|
||||
|
||||
/// Prints this value to the output stream
|
||||
/// @param o the output stream
|
||||
/// @returns input argument `o`
|
||||
std::ostream& Print(std::ostream& o) const override {
|
||||
o << TypeName() << "(";
|
||||
for (auto& a : args_.values) {
|
||||
o << std::get<ElementType>(a);
|
||||
if (&a != &args_.values.Back()) {
|
||||
std::ostream& Print(std::ostream& o) const {
|
||||
o << type_name << "(";
|
||||
for (auto& a : args) {
|
||||
std::visit([&](auto& v) { o << v; }, a);
|
||||
if (&a != &args.Back()) {
|
||||
o << ", ";
|
||||
}
|
||||
}
|
||||
|
@ -857,54 +769,16 @@ struct ConcreteValue : ValueBase {
|
|||
return o;
|
||||
}
|
||||
|
||||
private:
|
||||
/// args to create expression with
|
||||
ScalarArgs args_;
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
/// A Value represents a value of type DataType<T> created with ScalarArgs. Useful for storing
|
||||
/// values for unit tests.
|
||||
class Value {
|
||||
public:
|
||||
/// Creates a Value for type T initialized with `args`
|
||||
/// @param args the scalar args
|
||||
/// @returns Value
|
||||
template <typename T>
|
||||
static Value Create(ScalarArgs args) {
|
||||
static_assert(IsDataTypeSpecializedFor<T>, "No DataType<T> specialization exists");
|
||||
return Value{std::make_shared<detail::ConcreteValue<T>>(std::move(args))};
|
||||
}
|
||||
|
||||
/// Creates an `ast::Expression` for the type T passing in previously stored args
|
||||
/// @param b the ProgramBuilder
|
||||
/// @returns an expression node
|
||||
const ast::Expression* Expr(ProgramBuilder& b) const { return value_->Expr(b); }
|
||||
|
||||
/// @returns args used to create expression via `Expr`
|
||||
const ScalarArgs& Args() const { return value_->Args(); }
|
||||
|
||||
/// @returns true if element type is abstract
|
||||
bool IsAbstract() const { return value_->IsAbstract(); }
|
||||
|
||||
/// @returns true if element type is an integral
|
||||
bool IsIntegral() const { return value_->IsIntegral(); }
|
||||
|
||||
/// @returns element type name
|
||||
std::string TypeName() const { return value_->TypeName(); }
|
||||
|
||||
/// Prints this value to the output stream
|
||||
/// @param o the output stream
|
||||
/// @returns input argument `o`
|
||||
std::ostream& Print(std::ostream& o) const { return value_->Print(o); }
|
||||
|
||||
private:
|
||||
/// Private constructor
|
||||
explicit Value(std::shared_ptr<const detail::ValueBase> value) : value_(std::move(value)) {}
|
||||
|
||||
/// Shared pointer to an immutable value. This type-erasure pattern allows Value to wrap a
|
||||
/// polymorphic type, while being used like a value-type (i.e. copyable).
|
||||
std::shared_ptr<const detail::ValueBase> value_;
|
||||
/// The arguments used to construct the value
|
||||
utils::Vector<Scalar, 4> args;
|
||||
/// Function used to construct an expression with the given value
|
||||
builder::ast_expr_func_ptr create;
|
||||
/// True if the element type is abstract
|
||||
bool is_abstract = false;
|
||||
/// True if the element type is an integer
|
||||
bool is_integral = false;
|
||||
/// The name of the type.
|
||||
const char* type_name = "<invalid>";
|
||||
};
|
||||
|
||||
/// Prints Value to ostream
|
||||
|
@ -919,7 +793,7 @@ constexpr bool IsValue = std::is_same_v<T, Value>;
|
|||
/// Creates a Value of DataType<T> from a scalar `v`
|
||||
template <typename T>
|
||||
Value Val(T v) {
|
||||
return Value::Create<T>(ScalarArgs{v});
|
||||
return Value::Create<T>(utils::Vector<Scalar, 1>{v});
|
||||
}
|
||||
|
||||
/// Creates a Value of DataType<vec<N, T>> from N scalar `args`
|
||||
|
@ -927,41 +801,41 @@ template <typename... T>
|
|||
Value Vec(T... args) {
|
||||
using FirstT = std::tuple_element_t<0, std::tuple<T...>>;
|
||||
constexpr size_t N = sizeof...(args);
|
||||
utils::Vector v{args...};
|
||||
return Value::Create<vec<N, FirstT>>(utils::VectorRef<FirstT>{v});
|
||||
utils::Vector<Scalar, sizeof...(args)> v{args...};
|
||||
return Value::Create<vec<N, FirstT>>(std::move(v));
|
||||
}
|
||||
|
||||
/// Creates a Value of DataType<mat<C,R,T> from C*R scalar `args`
|
||||
template <size_t C, size_t R, typename T>
|
||||
Value Mat(const T (&m_in)[C][R]) {
|
||||
utils::Vector<T, C * R> m;
|
||||
utils::Vector<Scalar, C * R> m;
|
||||
for (uint32_t i = 0; i < C; ++i) {
|
||||
for (size_t j = 0; j < R; ++j) {
|
||||
m.Push(m_in[i][j]);
|
||||
}
|
||||
}
|
||||
return Value::Create<mat<C, R, T>>(utils::VectorRef<T>{m});
|
||||
return Value::Create<mat<C, R, T>>(std::move(m));
|
||||
}
|
||||
|
||||
/// Creates a Value of DataType<mat<2,R,T> from column vectors `c0` and `c1`
|
||||
template <typename T, size_t R>
|
||||
Value Mat(const T (&c0)[R], const T (&c1)[R]) {
|
||||
constexpr size_t C = 2;
|
||||
utils::Vector<T, C * R> m;
|
||||
utils::Vector<Scalar, C * R> m;
|
||||
for (auto v : c0) {
|
||||
m.Push(v);
|
||||
}
|
||||
for (auto v : c1) {
|
||||
m.Push(v);
|
||||
}
|
||||
return Value::Create<mat<C, R, T>>(utils::VectorRef<T>{m});
|
||||
return Value::Create<mat<C, R, T>>(std::move(m));
|
||||
}
|
||||
|
||||
/// Creates a Value of DataType<mat<3,R,T> from column vectors `c0`, `c1`, and `c2`
|
||||
template <typename T, size_t R>
|
||||
Value Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R]) {
|
||||
constexpr size_t C = 3;
|
||||
utils::Vector<T, C * R> m;
|
||||
utils::Vector<Scalar, C * R> m;
|
||||
for (auto v : c0) {
|
||||
m.Push(v);
|
||||
}
|
||||
|
@ -971,14 +845,14 @@ Value Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R]) {
|
|||
for (auto v : c2) {
|
||||
m.Push(v);
|
||||
}
|
||||
return Value::Create<mat<C, R, T>>(utils::VectorRef<T>{m});
|
||||
return Value::Create<mat<C, R, T>>(std::move(m));
|
||||
}
|
||||
|
||||
/// Creates a Value of DataType<mat<4,R,T> from column vectors `c0`, `c1`, `c2`, and `c3`
|
||||
template <typename T, size_t R>
|
||||
Value Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R], const T (&c3)[R]) {
|
||||
constexpr size_t C = 4;
|
||||
utils::Vector<T, C * R> m;
|
||||
utils::Vector<Scalar, C * R> m;
|
||||
for (auto v : c0) {
|
||||
m.Push(v);
|
||||
}
|
||||
|
@ -991,7 +865,7 @@ Value Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R], const T (&c3)[R]
|
|||
for (auto v : c3) {
|
||||
m.Push(v);
|
||||
}
|
||||
return Value::Create<mat<C, R, T>>(utils::VectorRef<T>{m});
|
||||
return Value::Create<mat<C, R, T>>(std::move(m));
|
||||
}
|
||||
|
||||
} // namespace builder
|
||||
|
|
Loading…
Reference in New Issue