tint: clean up const eval test framework
- Remove Types variant, and replace with a type-erasing Value class instead. This is not only better for compile times, but makes the code much easier to understand. - Value wraps an internal shared_ptr to a const detail::ValueBase, allowing it to be used as a value-type (i.e. copyable), while behaving polymorphically. - Add static_asserts to Val, Vec, and Mat creation helpers to emit a more useful error message when the wrong type is passed in. Bug: tint:1581 Change-Id: Icd0d08522bedb3eab12c44efa0d1555ed6e96458 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111700 Commit-Queue: Antonio Maiorano <amaiorano@google.com> Reviewed-by: Dan Sinclair <dsinclair@chromium.org> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
parent
8392a82a40
commit
d4908670e1
|
@ -22,30 +22,26 @@ using ::testing::HasSubstr;
|
|||
namespace tint::resolver {
|
||||
namespace {
|
||||
|
||||
// Bring in std::ostream& operator<<(std::ostream& o, const Types& types)
|
||||
using resolver::operator<<;
|
||||
|
||||
struct Case {
|
||||
struct Success {
|
||||
Types value;
|
||||
Value value;
|
||||
};
|
||||
struct Failure {
|
||||
std::string error;
|
||||
};
|
||||
|
||||
Types lhs;
|
||||
Types rhs;
|
||||
Value lhs;
|
||||
Value rhs;
|
||||
utils::Result<Success, Failure> expected;
|
||||
};
|
||||
|
||||
struct ErrorCase {
|
||||
Types lhs;
|
||||
Types rhs;
|
||||
Value lhs;
|
||||
Value rhs;
|
||||
};
|
||||
|
||||
/// Creates a Case with Values of any type
|
||||
template <typename T, typename U, typename V>
|
||||
Case C(Value<T> lhs, Value<U> rhs, Value<V> expected) {
|
||||
Case C(Value lhs, Value rhs, Value expected) {
|
||||
return Case{std::move(lhs), std::move(rhs), Case::Success{std::move(expected)}};
|
||||
}
|
||||
|
||||
|
@ -56,8 +52,7 @@ Case C(T lhs, U rhs, V expected) {
|
|||
}
|
||||
|
||||
/// Creates an failure Case with Values of any type
|
||||
template <typename T, typename U>
|
||||
Case E(Value<T> lhs, Value<U> rhs, std::string error) {
|
||||
Case E(Value lhs, Value rhs, std::string error) {
|
||||
return Case{std::move(lhs), std::move(rhs), Case::Failure{std::move(error)}};
|
||||
}
|
||||
|
||||
|
@ -71,7 +66,7 @@ Case E(T lhs, U rhs, std::string error) {
|
|||
static std::ostream& operator<<(std::ostream& o, const Case& c) {
|
||||
o << "lhs: " << c.lhs << ", rhs: " << c.rhs << ", expected: ";
|
||||
if (c.expected) {
|
||||
auto s = c.expected.Get();
|
||||
auto& s = c.expected.Get();
|
||||
o << s.value;
|
||||
} else {
|
||||
o << "[ERROR: " << c.expected.Failure().error << "]";
|
||||
|
@ -91,15 +86,16 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) {
|
|||
auto op = std::get<0>(GetParam());
|
||||
auto& c = std::get<1>(GetParam());
|
||||
|
||||
auto* lhs_expr = ToValueBase(c.lhs)->Expr(*this);
|
||||
auto* rhs_expr = ToValueBase(c.rhs)->Expr(*this);
|
||||
auto* lhs_expr = c.lhs.Expr(*this);
|
||||
auto* rhs_expr = c.rhs.Expr(*this);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(Source{{12, 34}}, op, lhs_expr, rhs_expr);
|
||||
GlobalConst("C", expr);
|
||||
|
||||
if (c.expected) {
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
auto expected_case = c.expected.Get();
|
||||
auto* expected = ToValueBase(expected_case.value);
|
||||
auto& expected = expected_case.value;
|
||||
|
||||
auto* sem = Sem().Get(expr);
|
||||
const sem::Constant* value = sem->ConstantValue();
|
||||
|
@ -707,7 +703,6 @@ INSTANTIATE_TEST_SUITE_P(Or,
|
|||
OpOrIntCases<u32>()))));
|
||||
|
||||
TEST_F(ResolverConstEvalTest, NotAndOrOfVecs) {
|
||||
// const C = !((vec2(true, true) & vec2(true, false)) | vec2(false, true));
|
||||
auto v1 = Vec(true, true).Expr(*this);
|
||||
auto v2 = Vec(true, false).Expr(*this);
|
||||
auto v3 = Vec(false, true).Expr(*this);
|
||||
|
@ -978,8 +973,8 @@ TEST_F(ResolverConstEvalTest, BinaryAbstractShiftLeftRemainsAbstract) {
|
|||
// i32/u32 left shift by >= 32 -> error
|
||||
using ResolverConstEvalShiftLeftConcreteGeqBitWidthError = ResolverTestWithParam<ErrorCase>;
|
||||
TEST_P(ResolverConstEvalShiftLeftConcreteGeqBitWidthError, Test) {
|
||||
auto* lhs_expr = ToValueBase(GetParam().lhs)->Expr(*this);
|
||||
auto* rhs_expr = ToValueBase(GetParam().rhs)->Expr(*this);
|
||||
auto* lhs_expr = GetParam().lhs.Expr(*this);
|
||||
auto* rhs_expr = GetParam().rhs.Expr(*this);
|
||||
GlobalConst("c", Shl(Source{{1, 1}}, lhs_expr, rhs_expr));
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(
|
||||
|
@ -1024,8 +1019,8 @@ INSTANTIATE_TEST_SUITE_P(Test,
|
|||
// AInt left shift results in sign change error
|
||||
using ResolverConstEvalShiftLeftSignChangeError = ResolverTestWithParam<ErrorCase>;
|
||||
TEST_P(ResolverConstEvalShiftLeftSignChangeError, Test) {
|
||||
auto* lhs_expr = ToValueBase(GetParam().lhs)->Expr(*this);
|
||||
auto* rhs_expr = ToValueBase(GetParam().rhs)->Expr(*this);
|
||||
auto* lhs_expr = GetParam().lhs.Expr(*this);
|
||||
auto* rhs_expr = GetParam().rhs.Expr(*this);
|
||||
GlobalConst("c", Shl(Source{{1, 1}}, lhs_expr, rhs_expr));
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
EXPECT_EQ(r()->error(), "1:1 error: shift left operation results in sign change");
|
||||
|
|
|
@ -22,15 +22,12 @@ using ::testing::HasSubstr;
|
|||
namespace tint::resolver {
|
||||
namespace {
|
||||
|
||||
// Bring in std::ostream& operator<<(std::ostream& o, const Types& types)
|
||||
using resolver::operator<<;
|
||||
|
||||
struct Case {
|
||||
Case(utils::VectorRef<Types> in_args, utils::VectorRef<Types> expected_values)
|
||||
Case(utils::VectorRef<Value> in_args, utils::VectorRef<Value> expected_values)
|
||||
: args(std::move(in_args)),
|
||||
expected(Success{std::move(expected_values), CheckConstantFlags{}}) {}
|
||||
|
||||
Case(utils::VectorRef<Types> in_args, std::string expected_err)
|
||||
Case(utils::VectorRef<Value> in_args, std::string expected_err)
|
||||
: args(std::move(in_args)), expected(Failure{std::move(expected_err)}) {}
|
||||
|
||||
/// Expected value may be positive or negative
|
||||
|
@ -52,14 +49,14 @@ struct Case {
|
|||
}
|
||||
|
||||
struct Success {
|
||||
utils::Vector<Types, 2> values;
|
||||
utils::Vector<Value, 2> values;
|
||||
CheckConstantFlags flags;
|
||||
};
|
||||
struct Failure {
|
||||
std::string error;
|
||||
};
|
||||
|
||||
utils::Vector<Types, 8> args;
|
||||
utils::Vector<Value, 8> args;
|
||||
utils::Result<Success, Failure> expected;
|
||||
};
|
||||
|
||||
|
@ -94,34 +91,34 @@ static std::ostream& operator<<(std::ostream& o, const Case& c) {
|
|||
using ScalarTypes = std::variant<AInt, AFloat, u32, i32, f32, f16>;
|
||||
|
||||
/// Creates a Case with Values for args and result
|
||||
static Case C(std::initializer_list<Types> args, Types result) {
|
||||
return Case{utils::Vector<Types, 8>{args}, utils::Vector<Types, 2>{std::move(result)}};
|
||||
static Case C(std::initializer_list<Value> args, Value result) {
|
||||
return Case{utils::Vector<Value, 8>{args}, utils::Vector<Value, 2>{std::move(result)}};
|
||||
}
|
||||
|
||||
/// Creates a Case with Values for args and result
|
||||
static Case C(std::initializer_list<Types> args, std::initializer_list<Types> results) {
|
||||
return Case{utils::Vector<Types, 8>{args}, utils::Vector<Types, 2>{results}};
|
||||
static Case C(std::initializer_list<Value> args, std::initializer_list<Value> results) {
|
||||
return Case{utils::Vector<Value, 8>{args}, utils::Vector<Value, 2>{results}};
|
||||
}
|
||||
|
||||
/// Convenience overload that creates a Case with just scalars
|
||||
static Case C(std::initializer_list<ScalarTypes> sargs, ScalarTypes sresult) {
|
||||
utils::Vector<Types, 8> args;
|
||||
utils::Vector<Value, 8> args;
|
||||
for (auto& sa : sargs) {
|
||||
std::visit([&](auto&& v) { return args.Push(Val(v)); }, sa);
|
||||
}
|
||||
Types result = Val(0_a);
|
||||
Value result = Val(0_a);
|
||||
std::visit([&](auto&& v) { result = Val(v); }, sresult);
|
||||
return Case{std::move(args), utils::Vector<Types, 2>{std::move(result)}};
|
||||
return Case{std::move(args), utils::Vector<Value, 2>{std::move(result)}};
|
||||
}
|
||||
|
||||
/// Creates a Case with Values for args and result
|
||||
static Case C(std::initializer_list<ScalarTypes> sargs,
|
||||
std::initializer_list<ScalarTypes> sresults) {
|
||||
utils::Vector<Types, 8> args;
|
||||
utils::Vector<Value, 8> args;
|
||||
for (auto& sa : sargs) {
|
||||
std::visit([&](auto&& v) { return args.Push(Val(v)); }, sa);
|
||||
}
|
||||
utils::Vector<Types, 2> results;
|
||||
utils::Vector<Value, 2> results;
|
||||
for (auto& sa : sresults) {
|
||||
std::visit([&](auto&& v) { return results.Push(Val(v)); }, sa);
|
||||
}
|
||||
|
@ -129,13 +126,13 @@ static Case C(std::initializer_list<ScalarTypes> sargs,
|
|||
}
|
||||
|
||||
/// Creates a Case with Values for args and expected error
|
||||
static Case E(std::initializer_list<Types> args, std::string err) {
|
||||
return Case{utils::Vector<Types, 8>{args}, std::move(err)};
|
||||
static Case E(std::initializer_list<Value> args, std::string err) {
|
||||
return Case{utils::Vector<Value, 8>{args}, std::move(err)};
|
||||
}
|
||||
|
||||
/// Convenience overload that creates an expected-error Case with just scalars
|
||||
static Case E(std::initializer_list<ScalarTypes> sargs, std::string err) {
|
||||
utils::Vector<Types, 8> args;
|
||||
utils::Vector<Value, 8> args;
|
||||
for (auto& sa : sargs) {
|
||||
std::visit([&](auto&& v) { return args.Push(Val(v)); }, sa);
|
||||
}
|
||||
|
@ -152,7 +149,7 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) {
|
|||
|
||||
utils::Vector<const ast::Expression*, 8> args;
|
||||
for (auto& a : c.args) {
|
||||
std::visit([&](auto&& v) { args.Push(v.Expr(*this)); }, a);
|
||||
args.Push(a.Expr(*this));
|
||||
}
|
||||
|
||||
auto* expr = Call(Source{{12, 34}}, sem::str(builtin), std::move(args));
|
||||
|
@ -173,14 +170,13 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) {
|
|||
// The result type of the constant-evaluated expression is a structure.
|
||||
// Compare each of the fields individually.
|
||||
for (size_t i = 0; i < expected_case.values.Length(); i++) {
|
||||
CheckConstant(value->Index(i), ToValueBase(expected_case.values[i]),
|
||||
expected_case.flags);
|
||||
CheckConstant(value->Index(i), expected_case.values[i], expected_case.flags);
|
||||
}
|
||||
} else {
|
||||
// Return type is not a structure. Just compare the single value
|
||||
ASSERT_EQ(expected_case.values.Length(), 1u)
|
||||
<< "const-eval returned non-struct, but Case expected multiple values";
|
||||
CheckConstant(value, ToValueBase(expected_case.values[0]), expected_case.flags);
|
||||
CheckConstant(value, expected_case.values[0], expected_case.flags);
|
||||
}
|
||||
} else {
|
||||
EXPECT_FALSE(r()->Resolve());
|
||||
|
|
|
@ -19,19 +19,6 @@ using namespace tint::number_suffixes; // NOLINT
|
|||
namespace tint::resolver {
|
||||
namespace {
|
||||
|
||||
using Scalar = std::variant< //
|
||||
builder::Value<AInt>,
|
||||
builder::Value<AFloat>,
|
||||
builder::Value<u32>,
|
||||
builder::Value<i32>,
|
||||
builder::Value<f32>,
|
||||
builder::Value<f16>,
|
||||
builder::Value<bool>>;
|
||||
|
||||
static std::ostream& operator<<(std::ostream& o, const Scalar& scalar) {
|
||||
return ToValueBase(scalar)->Print(o);
|
||||
}
|
||||
|
||||
enum class Kind {
|
||||
kScalar,
|
||||
kVector,
|
||||
|
@ -48,8 +35,8 @@ static std::ostream& operator<<(std::ostream& o, const Kind& k) {
|
|||
}
|
||||
|
||||
struct Case {
|
||||
Scalar input;
|
||||
Scalar expected;
|
||||
Value input;
|
||||
Value expected;
|
||||
builder::CreatePtrs type;
|
||||
bool unrepresentable = false;
|
||||
};
|
||||
|
@ -65,7 +52,7 @@ static std::ostream& operator<<(std::ostream& o, const Case& c) {
|
|||
|
||||
template <typename TO, typename FROM>
|
||||
Case Success(FROM input, TO expected) {
|
||||
return {builder::Val(input), builder::Val(expected), builder::CreatePtrsFor<TO>()};
|
||||
return {Val(input), Val(expected), builder::CreatePtrsFor<TO>()};
|
||||
}
|
||||
|
||||
template <typename TO, typename FROM>
|
||||
|
@ -83,7 +70,7 @@ TEST_P(ResolverConstEvalConvTest, Test) {
|
|||
const auto& type = std::get<1>(GetParam()).type;
|
||||
const auto unrepresentable = std::get<1>(GetParam()).unrepresentable;
|
||||
|
||||
auto* input_val = ToValueBase(input)->Expr(*this);
|
||||
auto* input_val = input.Expr(*this);
|
||||
auto* expr = Construct(type.ast(*this), input_val);
|
||||
if (kind == Kind::kVector) {
|
||||
expr = Construct(ty.vec(nullptr, 3), expr);
|
||||
|
@ -107,7 +94,7 @@ TEST_P(ResolverConstEvalConvTest, Test) {
|
|||
ASSERT_NE(sem->ConstantValue(), nullptr);
|
||||
EXPECT_TYPE(sem->ConstantValue()->Type(), target_sem_ty);
|
||||
|
||||
auto expected_values = ToValueBase(expected)->Args();
|
||||
auto expected_values = expected.Args();
|
||||
if (kind == Kind::kVector) {
|
||||
expected_values.values.Push(expected_values.values[0]);
|
||||
expected_values.values.Push(expected_values.values[0]);
|
||||
|
|
|
@ -88,10 +88,10 @@ struct CheckConstantFlags {
|
|||
/// @param expected_value the expected value for the test
|
||||
/// @param flags optional flags for controlling the comparisons
|
||||
inline void CheckConstant(const sem::Constant* got_constant,
|
||||
const builder::ValueBase* expected_value,
|
||||
const builder::Value& expected_value,
|
||||
CheckConstantFlags flags = {}) {
|
||||
auto values_flat = ScalarArgsFrom(got_constant);
|
||||
auto expected_values_flat = expected_value->Args();
|
||||
auto expected_values_flat = expected_value.Args();
|
||||
ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
|
||||
for (size_t i = 0; i < values_flat.values.Length(); ++i) {
|
||||
auto& got_scalar = values_flat.values[i];
|
||||
|
@ -247,93 +247,8 @@ using builder::IsValue;
|
|||
using builder::Mat;
|
||||
using builder::Val;
|
||||
using builder::Value;
|
||||
using builder::ValueBase;
|
||||
using builder::Vec;
|
||||
|
||||
using Types = std::variant< //
|
||||
Value<AInt>,
|
||||
Value<AFloat>,
|
||||
Value<u32>,
|
||||
Value<i32>,
|
||||
Value<f32>,
|
||||
Value<f16>,
|
||||
Value<bool>,
|
||||
|
||||
Value<builder::vec2<AInt>>,
|
||||
Value<builder::vec2<AFloat>>,
|
||||
Value<builder::vec2<u32>>,
|
||||
Value<builder::vec2<i32>>,
|
||||
Value<builder::vec2<f32>>,
|
||||
Value<builder::vec2<f16>>,
|
||||
Value<builder::vec2<bool>>,
|
||||
|
||||
Value<builder::vec3<AInt>>,
|
||||
Value<builder::vec3<AFloat>>,
|
||||
Value<builder::vec3<u32>>,
|
||||
Value<builder::vec3<i32>>,
|
||||
Value<builder::vec3<f32>>,
|
||||
Value<builder::vec3<f16>>,
|
||||
Value<builder::vec3<bool>>,
|
||||
|
||||
Value<builder::vec4<AInt>>,
|
||||
Value<builder::vec4<AFloat>>,
|
||||
Value<builder::vec4<u32>>,
|
||||
Value<builder::vec4<i32>>,
|
||||
Value<builder::vec4<f32>>,
|
||||
Value<builder::vec4<f16>>,
|
||||
Value<builder::vec4<bool>>,
|
||||
|
||||
Value<builder::mat2x2<AInt>>,
|
||||
Value<builder::mat2x2<AFloat>>,
|
||||
Value<builder::mat2x2<f32>>,
|
||||
Value<builder::mat2x2<f16>>,
|
||||
|
||||
Value<builder::mat3x3<AInt>>,
|
||||
Value<builder::mat3x3<AFloat>>,
|
||||
Value<builder::mat3x3<f32>>,
|
||||
Value<builder::mat3x3<f16>>,
|
||||
|
||||
Value<builder::mat4x4<AInt>>,
|
||||
Value<builder::mat4x4<AFloat>>,
|
||||
Value<builder::mat4x4<f32>>,
|
||||
Value<builder::mat4x4<f16>>,
|
||||
|
||||
Value<builder::mat2x3<AInt>>,
|
||||
Value<builder::mat2x3<AFloat>>,
|
||||
Value<builder::mat2x3<f32>>,
|
||||
Value<builder::mat2x3<f16>>,
|
||||
|
||||
Value<builder::mat3x2<AInt>>,
|
||||
Value<builder::mat3x2<AFloat>>,
|
||||
Value<builder::mat3x2<f32>>,
|
||||
Value<builder::mat3x2<f16>>,
|
||||
|
||||
Value<builder::mat2x4<AInt>>,
|
||||
Value<builder::mat2x4<AFloat>>,
|
||||
Value<builder::mat2x4<f32>>,
|
||||
Value<builder::mat2x4<f16>>,
|
||||
|
||||
Value<builder::mat4x2<AInt>>,
|
||||
Value<builder::mat4x2<AFloat>>,
|
||||
Value<builder::mat4x2<f32>>,
|
||||
Value<builder::mat4x2<f16>>
|
||||
//
|
||||
>;
|
||||
|
||||
/// Returns the current Value<T> in the `types` variant as a `ValueBase` pointer to use the
|
||||
/// polymorphic API. This trades longer compile times using std::variant for longer runtime via
|
||||
/// virtual function calls.
|
||||
template <typename ValueVariant>
|
||||
inline const ValueBase* ToValueBase(const ValueVariant& types) {
|
||||
return std::visit(
|
||||
[](auto&& t) -> const ValueBase* { return static_cast<const ValueBase*>(&t); }, types);
|
||||
}
|
||||
|
||||
/// Prints Types to ostream
|
||||
inline std::ostream& operator<<(std::ostream& o, const Types& types) {
|
||||
return ToValueBase(types)->Print(o);
|
||||
}
|
||||
|
||||
// Calls `f` on deepest elements of both `a` and `b`. If function returns Action::kStop, it stops
|
||||
// traversing, and return Action::kStop; if the function returns Action::kContinue, it continues and
|
||||
// returns Action::kContinue when done.
|
||||
|
|
|
@ -19,22 +19,17 @@ using namespace tint::number_suffixes; // NOLINT
|
|||
namespace tint::resolver {
|
||||
namespace {
|
||||
|
||||
// Bring in std::ostream& operator<<(std::ostream& o, const Types& types)
|
||||
using resolver::operator<<;
|
||||
|
||||
struct Case {
|
||||
Types input;
|
||||
Types expected;
|
||||
Value input;
|
||||
Value expected;
|
||||
};
|
||||
|
||||
static std::ostream& operator<<(std::ostream& o, const Case& c) {
|
||||
o << "input: " << c.input << ", expected: " << c.expected;
|
||||
return o;
|
||||
}
|
||||
|
||||
/// Creates a Case with Values of any type
|
||||
template <typename T, typename U>
|
||||
Case C(Value<T> input, Value<U> expected) {
|
||||
// Creates a Case with Values of any type
|
||||
Case C(Value input, Value expected) {
|
||||
return Case{std::move(input), std::move(expected)};
|
||||
}
|
||||
|
||||
|
@ -52,10 +47,10 @@ TEST_P(ResolverConstEvalUnaryOpTest, Test) {
|
|||
auto op = std::get<0>(GetParam());
|
||||
auto& c = std::get<1>(GetParam());
|
||||
|
||||
auto* expected = ToValueBase(c.expected);
|
||||
auto* input = ToValueBase(c.input);
|
||||
auto& expected = c.expected;
|
||||
auto& input = c.input;
|
||||
|
||||
auto* input_expr = input->Expr(*this);
|
||||
auto* input_expr = input.Expr(*this);
|
||||
auto* expr = create<ast::UnaryOpExpression>(op, input_expr);
|
||||
|
||||
GlobalConst("C", expr);
|
||||
|
@ -67,13 +62,13 @@ TEST_P(ResolverConstEvalUnaryOpTest, Test) {
|
|||
EXPECT_TYPE(value->Type(), sem->Type());
|
||||
|
||||
auto values_flat = ScalarArgsFrom(value);
|
||||
auto expected_values_flat = expected->Args();
|
||||
auto expected_values_flat = expected.Args();
|
||||
ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
|
||||
for (size_t i = 0; i < values_flat.values.Length(); ++i) {
|
||||
auto& a = values_flat.values[i];
|
||||
auto& b = expected_values_flat.values[i];
|
||||
EXPECT_EQ(a, b);
|
||||
if (expected->IsIntegral()) {
|
||||
if (expected.IsIntegral()) {
|
||||
// Check that the constant's integer doesn't contain unexpected
|
||||
// data in the MSBs that are outside of the bit-width of T.
|
||||
EXPECT_EQ(builder::As<AInt>(a), builder::As<AInt>(b));
|
||||
|
|
|
@ -241,12 +241,21 @@ using ast_expr_from_double_func_ptr = const ast::Expression* (*)(ProgramBuilder&
|
|||
using sem_type_func_ptr = const sem::Type* (*)(ProgramBuilder& b);
|
||||
using type_name_func_ptr = std::string (*)();
|
||||
|
||||
struct UnspecializedElementType {};
|
||||
|
||||
/// Base template for DataType, specialized below.
|
||||
template <typename T>
|
||||
struct DataType {};
|
||||
struct DataType {
|
||||
/// The element type
|
||||
using ElementType = UnspecializedElementType;
|
||||
};
|
||||
|
||||
/// Helper that represents no-type. Returns nullptr for all static methods.
|
||||
template <>
|
||||
struct DataType<void> {
|
||||
/// The element type
|
||||
using ElementType = void;
|
||||
|
||||
/// @return nullptr
|
||||
static inline const ast::Type* AST(ProgramBuilder&) { return nullptr; }
|
||||
/// @return nullptr
|
||||
|
@ -762,7 +771,13 @@ constexpr CreatePtrs CreatePtrsFor() {
|
|||
DataType<T>::Name};
|
||||
}
|
||||
|
||||
/// Base class for Value<T>
|
||||
/// True if DataType<T> is specialized for T, false otherwise.
|
||||
template <typename T>
|
||||
const bool IsDataTypeSpecializedFor =
|
||||
!std::is_same_v<typename DataType<T>::ElementType, UnspecializedElementType>;
|
||||
|
||||
namespace detail {
|
||||
/// ValueBase is a base class of ConcreteValue<T>
|
||||
struct ValueBase {
|
||||
/// Constructor
|
||||
ValueBase() = default;
|
||||
|
@ -793,13 +808,12 @@ struct ValueBase {
|
|||
virtual std::ostream& Print(std::ostream& o) const = 0;
|
||||
};
|
||||
|
||||
/// Value<T> is an instance of a value of type DataType<T>. Useful for storing values to create
|
||||
/// expressions with.
|
||||
/// ConcreteValue<T> is used to create Values of type DataType<T> with a ScalarArgs initializer.
|
||||
template <typename T>
|
||||
struct Value : ValueBase {
|
||||
struct ConcreteValue : ValueBase {
|
||||
/// Constructor
|
||||
/// @param a the scalar args
|
||||
explicit Value(ScalarArgs a) : args(std::move(a)) {}
|
||||
/// @param args the scalar args
|
||||
explicit ConcreteValue(ScalarArgs args) : args_(std::move(args)) {}
|
||||
|
||||
/// Alias to T
|
||||
using Type = T;
|
||||
|
@ -808,21 +822,16 @@ struct Value : ValueBase {
|
|||
/// Alias to DataType::ElementType
|
||||
using ElementType = typename DataType::ElementType;
|
||||
|
||||
/// Creates a Value<T> with `args`
|
||||
/// @param args the args that will be passed to the expression
|
||||
/// @returns a Value<T>
|
||||
static Value Create(ScalarArgs args) { return Value{std::move(args)}; }
|
||||
|
||||
/// Creates an `ast::Expression` for the type T passing in previously stored args
|
||||
/// @param b the ProgramBuilder
|
||||
/// @returns an expression node
|
||||
const ast::Expression* Expr(ProgramBuilder& b) const override {
|
||||
auto create = CreatePtrsFor<T>();
|
||||
return (*create.expr)(b, args);
|
||||
return (*create.expr)(b, args_);
|
||||
}
|
||||
|
||||
/// @returns args used to create expression via `Expr`
|
||||
const ScalarArgs& Args() const override { return args; }
|
||||
const ScalarArgs& Args() const override { return args_; }
|
||||
|
||||
/// @returns true if element type is abstract
|
||||
bool IsAbstract() const override { return tint::IsAbstract<ElementType>; }
|
||||
|
@ -838,9 +847,9 @@ struct Value : ValueBase {
|
|||
/// @returns input argument `o`
|
||||
std::ostream& Print(std::ostream& o) const override {
|
||||
o << TypeName() << "(";
|
||||
for (auto& a : args.values) {
|
||||
for (auto& a : args_.values) {
|
||||
o << std::get<ElementType>(a);
|
||||
if (&a != &args.values.Back()) {
|
||||
if (&a != &args_.values.Back()) {
|
||||
o << ", ";
|
||||
}
|
||||
}
|
||||
|
@ -848,60 +857,95 @@ struct Value : ValueBase {
|
|||
return o;
|
||||
}
|
||||
|
||||
private:
|
||||
/// args to create expression with
|
||||
ScalarArgs args;
|
||||
ScalarArgs args_;
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
/// Base template for IsValue
|
||||
template <typename T>
|
||||
struct IsValue : std::false_type {};
|
||||
/// Specialization for IsValue
|
||||
template <typename T>
|
||||
struct IsValue<Value<T>> : std::true_type {};
|
||||
} // namespace detail
|
||||
|
||||
/// True if T is of type Value
|
||||
template <typename T>
|
||||
constexpr bool IsValue = detail::IsValue<T>::value;
|
||||
/// A Value represents a value of type DataType<T> created with ScalarArgs. Useful for storing
|
||||
/// values for unit tests.
|
||||
class Value {
|
||||
public:
|
||||
/// Creates a Value for type T initialized with `args`
|
||||
/// @param args the scalar args
|
||||
/// @returns Value
|
||||
template <typename T>
|
||||
static Value Create(ScalarArgs args) {
|
||||
static_assert(IsDataTypeSpecializedFor<T>, "No DataType<T> specialization exists");
|
||||
return Value{std::make_shared<detail::ConcreteValue<T>>(std::move(args))};
|
||||
}
|
||||
|
||||
/// Returns the friendly name of ValueT
|
||||
template <typename ValueT, typename = traits::EnableIf<IsValue<ValueT>>>
|
||||
const char* FriendlyName() {
|
||||
return tint::FriendlyName<typename ValueT::ElementType>();
|
||||
/// Creates an `ast::Expression` for the type T passing in previously stored args
|
||||
/// @param b the ProgramBuilder
|
||||
/// @returns an expression node
|
||||
const ast::Expression* Expr(ProgramBuilder& b) const { return value_->Expr(b); }
|
||||
|
||||
/// @returns args used to create expression via `Expr`
|
||||
const ScalarArgs& Args() const { return value_->Args(); }
|
||||
|
||||
/// @returns true if element type is abstract
|
||||
bool IsAbstract() const { return value_->IsAbstract(); }
|
||||
|
||||
/// @returns true if element type is an integral
|
||||
bool IsIntegral() const { return value_->IsIntegral(); }
|
||||
|
||||
/// @returns element type name
|
||||
std::string TypeName() const { return value_->TypeName(); }
|
||||
|
||||
/// Prints this value to the output stream
|
||||
/// @param o the output stream
|
||||
/// @returns input argument `o`
|
||||
std::ostream& Print(std::ostream& o) const { return value_->Print(o); }
|
||||
|
||||
private:
|
||||
/// Private constructor
|
||||
explicit Value(std::shared_ptr<const detail::ValueBase> value) : value_(std::move(value)) {}
|
||||
|
||||
/// Shared pointer to an immutable value. This type-erasure pattern allows Value to wrap a
|
||||
/// polymorphic type, while being used like a value-type (i.e. copyable).
|
||||
std::shared_ptr<const detail::ValueBase> value_;
|
||||
};
|
||||
|
||||
/// Prints Value to ostream
|
||||
inline std::ostream& operator<<(std::ostream& o, const Value& value) {
|
||||
return value.Print(o);
|
||||
}
|
||||
|
||||
/// Creates a `Value<T>` from a scalar `v`
|
||||
/// True if T is Value, false otherwise
|
||||
template <typename T>
|
||||
auto Val(T v) {
|
||||
return Value<T>::Create(ScalarArgs{v});
|
||||
constexpr bool IsValue = std::is_same_v<T, Value>;
|
||||
|
||||
/// Creates a Value of DataType<T> from a scalar `v`
|
||||
template <typename T>
|
||||
Value Val(T v) {
|
||||
return Value::Create<T>(ScalarArgs{v});
|
||||
}
|
||||
|
||||
/// Creates a `Value<vec<N, T>>` from N scalar `args`
|
||||
/// Creates a Value of DataType<vec<N, T>> from N scalar `args`
|
||||
template <typename... T>
|
||||
auto Vec(T... args) {
|
||||
constexpr size_t N = sizeof...(args);
|
||||
Value Vec(T... args) {
|
||||
using FirstT = std::tuple_element_t<0, std::tuple<T...>>;
|
||||
constexpr size_t N = sizeof...(args);
|
||||
utils::Vector v{args...};
|
||||
using VT = vec<N, FirstT>;
|
||||
return Value<VT>::Create(utils::VectorRef<FirstT>{v});
|
||||
return Value::Create<vec<N, FirstT>>(utils::VectorRef<FirstT>{v});
|
||||
}
|
||||
|
||||
/// Creates a `Value<mat<C,R,T>` from C*R scalar `args`
|
||||
/// Creates a Value of DataType<mat<C,R,T> from C*R scalar `args`
|
||||
template <size_t C, size_t R, typename T>
|
||||
auto Mat(const T (&m_in)[C][R]) {
|
||||
Value Mat(const T (&m_in)[C][R]) {
|
||||
utils::Vector<T, C * R> m;
|
||||
for (uint32_t i = 0; i < C; ++i) {
|
||||
for (size_t j = 0; j < R; ++j) {
|
||||
m.Push(m_in[i][j]);
|
||||
}
|
||||
}
|
||||
return Value<mat<C, R, T>>::Create(utils::VectorRef<T>{m});
|
||||
return Value::Create<mat<C, R, T>>(utils::VectorRef<T>{m});
|
||||
}
|
||||
|
||||
/// Creates a `Value<mat<2,R,T>` from column vectors `c0` and `c1`
|
||||
/// Creates a Value of DataType<mat<2,R,T> from column vectors `c0` and `c1`
|
||||
template <typename T, size_t R>
|
||||
auto Mat(const T (&c0)[R], const T (&c1)[R]) {
|
||||
Value Mat(const T (&c0)[R], const T (&c1)[R]) {
|
||||
constexpr size_t C = 2;
|
||||
utils::Vector<T, C * R> m;
|
||||
for (auto v : c0) {
|
||||
|
@ -910,12 +954,12 @@ auto Mat(const T (&c0)[R], const T (&c1)[R]) {
|
|||
for (auto v : c1) {
|
||||
m.Push(v);
|
||||
}
|
||||
return Value<mat<C, R, T>>::Create(utils::VectorRef<T>{m});
|
||||
return Value::Create<mat<C, R, T>>(utils::VectorRef<T>{m});
|
||||
}
|
||||
|
||||
/// Creates a `Value<mat<3,R,T>` from column vectors `c0`, `c1`, and `c2`
|
||||
/// Creates a Value of DataType<mat<3,R,T> from column vectors `c0`, `c1`, and `c2`
|
||||
template <typename T, size_t R>
|
||||
auto Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R]) {
|
||||
Value Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R]) {
|
||||
constexpr size_t C = 3;
|
||||
utils::Vector<T, C * R> m;
|
||||
for (auto v : c0) {
|
||||
|
@ -927,12 +971,12 @@ auto Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R]) {
|
|||
for (auto v : c2) {
|
||||
m.Push(v);
|
||||
}
|
||||
return Value<mat<C, R, T>>::Create(utils::VectorRef<T>{m});
|
||||
return Value::Create<mat<C, R, T>>(utils::VectorRef<T>{m});
|
||||
}
|
||||
|
||||
/// Creates a `Value<mat<4,R,T>` from column vectors `c0`, `c1`, `c2`, and `c3`
|
||||
/// Creates a Value of DataType<mat<4,R,T> from column vectors `c0`, `c1`, `c2`, and `c3`
|
||||
template <typename T, size_t R>
|
||||
auto Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R], const T (&c3)[R]) {
|
||||
Value Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R], const T (&c3)[R]) {
|
||||
constexpr size_t C = 4;
|
||||
utils::Vector<T, C * R> m;
|
||||
for (auto v : c0) {
|
||||
|
@ -947,7 +991,7 @@ auto Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R], const T (&c3)[R])
|
|||
for (auto v : c3) {
|
||||
m.Push(v);
|
||||
}
|
||||
return Value<mat<C, R, T>>::Create(utils::VectorRef<T>{m});
|
||||
return Value::Create<mat<C, R, T>>(utils::VectorRef<T>{m});
|
||||
}
|
||||
|
||||
} // namespace builder
|
||||
|
|
Loading…
Reference in New Issue