tint: clean up const eval test framework

- Remove Types variant, and replace with a type-erasing Value class
  instead. This is not only better for compile times, but makes the code
  much easier to understand.
- Value wraps an internal shared_ptr to a const detail::ValueBase,
  allowing it to be used as a value-type (i.e. copyable), while behaving
  polymorphically.
- Add static_asserts to Val, Vec, and Mat creation helpers to emit a
  more useful error message when the wrong type is passed in.

Bug: tint:1581
Change-Id: Icd0d08522bedb3eab12c44efa0d1555ed6e96458
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/111700
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
Antonio Maiorano 2022-11-25 05:47:42 +00:00 committed by Dawn LUCI CQ
parent 8392a82a40
commit d4908670e1
6 changed files with 147 additions and 215 deletions

View File

@ -22,30 +22,26 @@ using ::testing::HasSubstr;
namespace tint::resolver { namespace tint::resolver {
namespace { namespace {
// Bring in std::ostream& operator<<(std::ostream& o, const Types& types)
using resolver::operator<<;
struct Case { struct Case {
struct Success { struct Success {
Types value; Value value;
}; };
struct Failure { struct Failure {
std::string error; std::string error;
}; };
Types lhs; Value lhs;
Types rhs; Value rhs;
utils::Result<Success, Failure> expected; utils::Result<Success, Failure> expected;
}; };
struct ErrorCase { struct ErrorCase {
Types lhs; Value lhs;
Types rhs; Value rhs;
}; };
/// Creates a Case with Values of any type /// Creates a Case with Values of any type
template <typename T, typename U, typename V> Case C(Value lhs, Value rhs, Value expected) {
Case C(Value<T> lhs, Value<U> rhs, Value<V> expected) {
return Case{std::move(lhs), std::move(rhs), Case::Success{std::move(expected)}}; return Case{std::move(lhs), std::move(rhs), Case::Success{std::move(expected)}};
} }
@ -56,8 +52,7 @@ Case C(T lhs, U rhs, V expected) {
} }
/// Creates an failure Case with Values of any type /// Creates an failure Case with Values of any type
template <typename T, typename U> Case E(Value lhs, Value rhs, std::string error) {
Case E(Value<T> lhs, Value<U> rhs, std::string error) {
return Case{std::move(lhs), std::move(rhs), Case::Failure{std::move(error)}}; return Case{std::move(lhs), std::move(rhs), Case::Failure{std::move(error)}};
} }
@ -71,7 +66,7 @@ Case E(T lhs, U rhs, std::string error) {
static std::ostream& operator<<(std::ostream& o, const Case& c) { static std::ostream& operator<<(std::ostream& o, const Case& c) {
o << "lhs: " << c.lhs << ", rhs: " << c.rhs << ", expected: "; o << "lhs: " << c.lhs << ", rhs: " << c.rhs << ", expected: ";
if (c.expected) { if (c.expected) {
auto s = c.expected.Get(); auto& s = c.expected.Get();
o << s.value; o << s.value;
} else { } else {
o << "[ERROR: " << c.expected.Failure().error << "]"; o << "[ERROR: " << c.expected.Failure().error << "]";
@ -91,15 +86,16 @@ TEST_P(ResolverConstEvalBinaryOpTest, Test) {
auto op = std::get<0>(GetParam()); auto op = std::get<0>(GetParam());
auto& c = std::get<1>(GetParam()); auto& c = std::get<1>(GetParam());
auto* lhs_expr = ToValueBase(c.lhs)->Expr(*this); auto* lhs_expr = c.lhs.Expr(*this);
auto* rhs_expr = ToValueBase(c.rhs)->Expr(*this); auto* rhs_expr = c.rhs.Expr(*this);
auto* expr = create<ast::BinaryExpression>(Source{{12, 34}}, op, lhs_expr, rhs_expr); auto* expr = create<ast::BinaryExpression>(Source{{12, 34}}, op, lhs_expr, rhs_expr);
GlobalConst("C", expr); GlobalConst("C", expr);
if (c.expected) { if (c.expected) {
ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_TRUE(r()->Resolve()) << r()->error();
auto expected_case = c.expected.Get(); auto expected_case = c.expected.Get();
auto* expected = ToValueBase(expected_case.value); auto& expected = expected_case.value;
auto* sem = Sem().Get(expr); auto* sem = Sem().Get(expr);
const sem::Constant* value = sem->ConstantValue(); const sem::Constant* value = sem->ConstantValue();
@ -707,7 +703,6 @@ INSTANTIATE_TEST_SUITE_P(Or,
OpOrIntCases<u32>())))); OpOrIntCases<u32>()))));
TEST_F(ResolverConstEvalTest, NotAndOrOfVecs) { TEST_F(ResolverConstEvalTest, NotAndOrOfVecs) {
// const C = !((vec2(true, true) & vec2(true, false)) | vec2(false, true));
auto v1 = Vec(true, true).Expr(*this); auto v1 = Vec(true, true).Expr(*this);
auto v2 = Vec(true, false).Expr(*this); auto v2 = Vec(true, false).Expr(*this);
auto v3 = Vec(false, true).Expr(*this); auto v3 = Vec(false, true).Expr(*this);
@ -978,8 +973,8 @@ TEST_F(ResolverConstEvalTest, BinaryAbstractShiftLeftRemainsAbstract) {
// i32/u32 left shift by >= 32 -> error // i32/u32 left shift by >= 32 -> error
using ResolverConstEvalShiftLeftConcreteGeqBitWidthError = ResolverTestWithParam<ErrorCase>; using ResolverConstEvalShiftLeftConcreteGeqBitWidthError = ResolverTestWithParam<ErrorCase>;
TEST_P(ResolverConstEvalShiftLeftConcreteGeqBitWidthError, Test) { TEST_P(ResolverConstEvalShiftLeftConcreteGeqBitWidthError, Test) {
auto* lhs_expr = ToValueBase(GetParam().lhs)->Expr(*this); auto* lhs_expr = GetParam().lhs.Expr(*this);
auto* rhs_expr = ToValueBase(GetParam().rhs)->Expr(*this); auto* rhs_expr = GetParam().rhs.Expr(*this);
GlobalConst("c", Shl(Source{{1, 1}}, lhs_expr, rhs_expr)); GlobalConst("c", Shl(Source{{1, 1}}, lhs_expr, rhs_expr));
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ( EXPECT_EQ(
@ -1024,8 +1019,8 @@ INSTANTIATE_TEST_SUITE_P(Test,
// AInt left shift results in sign change error // AInt left shift results in sign change error
using ResolverConstEvalShiftLeftSignChangeError = ResolverTestWithParam<ErrorCase>; using ResolverConstEvalShiftLeftSignChangeError = ResolverTestWithParam<ErrorCase>;
TEST_P(ResolverConstEvalShiftLeftSignChangeError, Test) { TEST_P(ResolverConstEvalShiftLeftSignChangeError, Test) {
auto* lhs_expr = ToValueBase(GetParam().lhs)->Expr(*this); auto* lhs_expr = GetParam().lhs.Expr(*this);
auto* rhs_expr = ToValueBase(GetParam().rhs)->Expr(*this); auto* rhs_expr = GetParam().rhs.Expr(*this);
GlobalConst("c", Shl(Source{{1, 1}}, lhs_expr, rhs_expr)); GlobalConst("c", Shl(Source{{1, 1}}, lhs_expr, rhs_expr));
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());
EXPECT_EQ(r()->error(), "1:1 error: shift left operation results in sign change"); EXPECT_EQ(r()->error(), "1:1 error: shift left operation results in sign change");

View File

@ -22,15 +22,12 @@ using ::testing::HasSubstr;
namespace tint::resolver { namespace tint::resolver {
namespace { namespace {
// Bring in std::ostream& operator<<(std::ostream& o, const Types& types)
using resolver::operator<<;
struct Case { struct Case {
Case(utils::VectorRef<Types> in_args, utils::VectorRef<Types> expected_values) Case(utils::VectorRef<Value> in_args, utils::VectorRef<Value> expected_values)
: args(std::move(in_args)), : args(std::move(in_args)),
expected(Success{std::move(expected_values), CheckConstantFlags{}}) {} expected(Success{std::move(expected_values), CheckConstantFlags{}}) {}
Case(utils::VectorRef<Types> in_args, std::string expected_err) Case(utils::VectorRef<Value> in_args, std::string expected_err)
: args(std::move(in_args)), expected(Failure{std::move(expected_err)}) {} : args(std::move(in_args)), expected(Failure{std::move(expected_err)}) {}
/// Expected value may be positive or negative /// Expected value may be positive or negative
@ -52,14 +49,14 @@ struct Case {
} }
struct Success { struct Success {
utils::Vector<Types, 2> values; utils::Vector<Value, 2> values;
CheckConstantFlags flags; CheckConstantFlags flags;
}; };
struct Failure { struct Failure {
std::string error; std::string error;
}; };
utils::Vector<Types, 8> args; utils::Vector<Value, 8> args;
utils::Result<Success, Failure> expected; utils::Result<Success, Failure> expected;
}; };
@ -94,34 +91,34 @@ static std::ostream& operator<<(std::ostream& o, const Case& c) {
using ScalarTypes = std::variant<AInt, AFloat, u32, i32, f32, f16>; using ScalarTypes = std::variant<AInt, AFloat, u32, i32, f32, f16>;
/// Creates a Case with Values for args and result /// Creates a Case with Values for args and result
static Case C(std::initializer_list<Types> args, Types result) { static Case C(std::initializer_list<Value> args, Value result) {
return Case{utils::Vector<Types, 8>{args}, utils::Vector<Types, 2>{std::move(result)}}; return Case{utils::Vector<Value, 8>{args}, utils::Vector<Value, 2>{std::move(result)}};
} }
/// Creates a Case with Values for args and result /// Creates a Case with Values for args and result
static Case C(std::initializer_list<Types> args, std::initializer_list<Types> results) { static Case C(std::initializer_list<Value> args, std::initializer_list<Value> results) {
return Case{utils::Vector<Types, 8>{args}, utils::Vector<Types, 2>{results}}; return Case{utils::Vector<Value, 8>{args}, utils::Vector<Value, 2>{results}};
} }
/// Convenience overload that creates a Case with just scalars /// Convenience overload that creates a Case with just scalars
static Case C(std::initializer_list<ScalarTypes> sargs, ScalarTypes sresult) { static Case C(std::initializer_list<ScalarTypes> sargs, ScalarTypes sresult) {
utils::Vector<Types, 8> args; utils::Vector<Value, 8> args;
for (auto& sa : sargs) { for (auto& sa : sargs) {
std::visit([&](auto&& v) { return args.Push(Val(v)); }, sa); std::visit([&](auto&& v) { return args.Push(Val(v)); }, sa);
} }
Types result = Val(0_a); Value result = Val(0_a);
std::visit([&](auto&& v) { result = Val(v); }, sresult); std::visit([&](auto&& v) { result = Val(v); }, sresult);
return Case{std::move(args), utils::Vector<Types, 2>{std::move(result)}}; return Case{std::move(args), utils::Vector<Value, 2>{std::move(result)}};
} }
/// Creates a Case with Values for args and result /// Creates a Case with Values for args and result
static Case C(std::initializer_list<ScalarTypes> sargs, static Case C(std::initializer_list<ScalarTypes> sargs,
std::initializer_list<ScalarTypes> sresults) { std::initializer_list<ScalarTypes> sresults) {
utils::Vector<Types, 8> args; utils::Vector<Value, 8> args;
for (auto& sa : sargs) { for (auto& sa : sargs) {
std::visit([&](auto&& v) { return args.Push(Val(v)); }, sa); std::visit([&](auto&& v) { return args.Push(Val(v)); }, sa);
} }
utils::Vector<Types, 2> results; utils::Vector<Value, 2> results;
for (auto& sa : sresults) { for (auto& sa : sresults) {
std::visit([&](auto&& v) { return results.Push(Val(v)); }, sa); std::visit([&](auto&& v) { return results.Push(Val(v)); }, sa);
} }
@ -129,13 +126,13 @@ static Case C(std::initializer_list<ScalarTypes> sargs,
} }
/// Creates a Case with Values for args and expected error /// Creates a Case with Values for args and expected error
static Case E(std::initializer_list<Types> args, std::string err) { static Case E(std::initializer_list<Value> args, std::string err) {
return Case{utils::Vector<Types, 8>{args}, std::move(err)}; return Case{utils::Vector<Value, 8>{args}, std::move(err)};
} }
/// Convenience overload that creates an expected-error Case with just scalars /// Convenience overload that creates an expected-error Case with just scalars
static Case E(std::initializer_list<ScalarTypes> sargs, std::string err) { static Case E(std::initializer_list<ScalarTypes> sargs, std::string err) {
utils::Vector<Types, 8> args; utils::Vector<Value, 8> args;
for (auto& sa : sargs) { for (auto& sa : sargs) {
std::visit([&](auto&& v) { return args.Push(Val(v)); }, sa); std::visit([&](auto&& v) { return args.Push(Val(v)); }, sa);
} }
@ -152,7 +149,7 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) {
utils::Vector<const ast::Expression*, 8> args; utils::Vector<const ast::Expression*, 8> args;
for (auto& a : c.args) { for (auto& a : c.args) {
std::visit([&](auto&& v) { args.Push(v.Expr(*this)); }, a); args.Push(a.Expr(*this));
} }
auto* expr = Call(Source{{12, 34}}, sem::str(builtin), std::move(args)); auto* expr = Call(Source{{12, 34}}, sem::str(builtin), std::move(args));
@ -173,14 +170,13 @@ TEST_P(ResolverConstEvalBuiltinTest, Test) {
// The result type of the constant-evaluated expression is a structure. // The result type of the constant-evaluated expression is a structure.
// Compare each of the fields individually. // Compare each of the fields individually.
for (size_t i = 0; i < expected_case.values.Length(); i++) { for (size_t i = 0; i < expected_case.values.Length(); i++) {
CheckConstant(value->Index(i), ToValueBase(expected_case.values[i]), CheckConstant(value->Index(i), expected_case.values[i], expected_case.flags);
expected_case.flags);
} }
} else { } else {
// Return type is not a structure. Just compare the single value // Return type is not a structure. Just compare the single value
ASSERT_EQ(expected_case.values.Length(), 1u) ASSERT_EQ(expected_case.values.Length(), 1u)
<< "const-eval returned non-struct, but Case expected multiple values"; << "const-eval returned non-struct, but Case expected multiple values";
CheckConstant(value, ToValueBase(expected_case.values[0]), expected_case.flags); CheckConstant(value, expected_case.values[0], expected_case.flags);
} }
} else { } else {
EXPECT_FALSE(r()->Resolve()); EXPECT_FALSE(r()->Resolve());

View File

@ -19,19 +19,6 @@ using namespace tint::number_suffixes; // NOLINT
namespace tint::resolver { namespace tint::resolver {
namespace { namespace {
using Scalar = std::variant< //
builder::Value<AInt>,
builder::Value<AFloat>,
builder::Value<u32>,
builder::Value<i32>,
builder::Value<f32>,
builder::Value<f16>,
builder::Value<bool>>;
static std::ostream& operator<<(std::ostream& o, const Scalar& scalar) {
return ToValueBase(scalar)->Print(o);
}
enum class Kind { enum class Kind {
kScalar, kScalar,
kVector, kVector,
@ -48,8 +35,8 @@ static std::ostream& operator<<(std::ostream& o, const Kind& k) {
} }
struct Case { struct Case {
Scalar input; Value input;
Scalar expected; Value expected;
builder::CreatePtrs type; builder::CreatePtrs type;
bool unrepresentable = false; bool unrepresentable = false;
}; };
@ -65,7 +52,7 @@ static std::ostream& operator<<(std::ostream& o, const Case& c) {
template <typename TO, typename FROM> template <typename TO, typename FROM>
Case Success(FROM input, TO expected) { Case Success(FROM input, TO expected) {
return {builder::Val(input), builder::Val(expected), builder::CreatePtrsFor<TO>()}; return {Val(input), Val(expected), builder::CreatePtrsFor<TO>()};
} }
template <typename TO, typename FROM> template <typename TO, typename FROM>
@ -83,7 +70,7 @@ TEST_P(ResolverConstEvalConvTest, Test) {
const auto& type = std::get<1>(GetParam()).type; const auto& type = std::get<1>(GetParam()).type;
const auto unrepresentable = std::get<1>(GetParam()).unrepresentable; const auto unrepresentable = std::get<1>(GetParam()).unrepresentable;
auto* input_val = ToValueBase(input)->Expr(*this); auto* input_val = input.Expr(*this);
auto* expr = Construct(type.ast(*this), input_val); auto* expr = Construct(type.ast(*this), input_val);
if (kind == Kind::kVector) { if (kind == Kind::kVector) {
expr = Construct(ty.vec(nullptr, 3), expr); expr = Construct(ty.vec(nullptr, 3), expr);
@ -107,7 +94,7 @@ TEST_P(ResolverConstEvalConvTest, Test) {
ASSERT_NE(sem->ConstantValue(), nullptr); ASSERT_NE(sem->ConstantValue(), nullptr);
EXPECT_TYPE(sem->ConstantValue()->Type(), target_sem_ty); EXPECT_TYPE(sem->ConstantValue()->Type(), target_sem_ty);
auto expected_values = ToValueBase(expected)->Args(); auto expected_values = expected.Args();
if (kind == Kind::kVector) { if (kind == Kind::kVector) {
expected_values.values.Push(expected_values.values[0]); expected_values.values.Push(expected_values.values[0]);
expected_values.values.Push(expected_values.values[0]); expected_values.values.Push(expected_values.values[0]);

View File

@ -88,10 +88,10 @@ struct CheckConstantFlags {
/// @param expected_value the expected value for the test /// @param expected_value the expected value for the test
/// @param flags optional flags for controlling the comparisons /// @param flags optional flags for controlling the comparisons
inline void CheckConstant(const sem::Constant* got_constant, inline void CheckConstant(const sem::Constant* got_constant,
const builder::ValueBase* expected_value, const builder::Value& expected_value,
CheckConstantFlags flags = {}) { CheckConstantFlags flags = {}) {
auto values_flat = ScalarArgsFrom(got_constant); auto values_flat = ScalarArgsFrom(got_constant);
auto expected_values_flat = expected_value->Args(); auto expected_values_flat = expected_value.Args();
ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length()); ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
for (size_t i = 0; i < values_flat.values.Length(); ++i) { for (size_t i = 0; i < values_flat.values.Length(); ++i) {
auto& got_scalar = values_flat.values[i]; auto& got_scalar = values_flat.values[i];
@ -247,93 +247,8 @@ using builder::IsValue;
using builder::Mat; using builder::Mat;
using builder::Val; using builder::Val;
using builder::Value; using builder::Value;
using builder::ValueBase;
using builder::Vec; using builder::Vec;
using Types = std::variant< //
Value<AInt>,
Value<AFloat>,
Value<u32>,
Value<i32>,
Value<f32>,
Value<f16>,
Value<bool>,
Value<builder::vec2<AInt>>,
Value<builder::vec2<AFloat>>,
Value<builder::vec2<u32>>,
Value<builder::vec2<i32>>,
Value<builder::vec2<f32>>,
Value<builder::vec2<f16>>,
Value<builder::vec2<bool>>,
Value<builder::vec3<AInt>>,
Value<builder::vec3<AFloat>>,
Value<builder::vec3<u32>>,
Value<builder::vec3<i32>>,
Value<builder::vec3<f32>>,
Value<builder::vec3<f16>>,
Value<builder::vec3<bool>>,
Value<builder::vec4<AInt>>,
Value<builder::vec4<AFloat>>,
Value<builder::vec4<u32>>,
Value<builder::vec4<i32>>,
Value<builder::vec4<f32>>,
Value<builder::vec4<f16>>,
Value<builder::vec4<bool>>,
Value<builder::mat2x2<AInt>>,
Value<builder::mat2x2<AFloat>>,
Value<builder::mat2x2<f32>>,
Value<builder::mat2x2<f16>>,
Value<builder::mat3x3<AInt>>,
Value<builder::mat3x3<AFloat>>,
Value<builder::mat3x3<f32>>,
Value<builder::mat3x3<f16>>,
Value<builder::mat4x4<AInt>>,
Value<builder::mat4x4<AFloat>>,
Value<builder::mat4x4<f32>>,
Value<builder::mat4x4<f16>>,
Value<builder::mat2x3<AInt>>,
Value<builder::mat2x3<AFloat>>,
Value<builder::mat2x3<f32>>,
Value<builder::mat2x3<f16>>,
Value<builder::mat3x2<AInt>>,
Value<builder::mat3x2<AFloat>>,
Value<builder::mat3x2<f32>>,
Value<builder::mat3x2<f16>>,
Value<builder::mat2x4<AInt>>,
Value<builder::mat2x4<AFloat>>,
Value<builder::mat2x4<f32>>,
Value<builder::mat2x4<f16>>,
Value<builder::mat4x2<AInt>>,
Value<builder::mat4x2<AFloat>>,
Value<builder::mat4x2<f32>>,
Value<builder::mat4x2<f16>>
//
>;
/// Returns the current Value<T> in the `types` variant as a `ValueBase` pointer to use the
/// polymorphic API. This trades longer compile times using std::variant for longer runtime via
/// virtual function calls.
template <typename ValueVariant>
inline const ValueBase* ToValueBase(const ValueVariant& types) {
return std::visit(
[](auto&& t) -> const ValueBase* { return static_cast<const ValueBase*>(&t); }, types);
}
/// Prints Types to ostream
inline std::ostream& operator<<(std::ostream& o, const Types& types) {
return ToValueBase(types)->Print(o);
}
// Calls `f` on deepest elements of both `a` and `b`. If function returns Action::kStop, it stops // Calls `f` on deepest elements of both `a` and `b`. If function returns Action::kStop, it stops
// traversing, and return Action::kStop; if the function returns Action::kContinue, it continues and // traversing, and return Action::kStop; if the function returns Action::kContinue, it continues and
// returns Action::kContinue when done. // returns Action::kContinue when done.

View File

@ -19,22 +19,17 @@ using namespace tint::number_suffixes; // NOLINT
namespace tint::resolver { namespace tint::resolver {
namespace { namespace {
// Bring in std::ostream& operator<<(std::ostream& o, const Types& types)
using resolver::operator<<;
struct Case { struct Case {
Types input; Value input;
Types expected; Value expected;
}; };
static std::ostream& operator<<(std::ostream& o, const Case& c) { static std::ostream& operator<<(std::ostream& o, const Case& c) {
o << "input: " << c.input << ", expected: " << c.expected; o << "input: " << c.input << ", expected: " << c.expected;
return o; return o;
} }
// Creates a Case with Values of any type
/// Creates a Case with Values of any type Case C(Value input, Value expected) {
template <typename T, typename U>
Case C(Value<T> input, Value<U> expected) {
return Case{std::move(input), std::move(expected)}; return Case{std::move(input), std::move(expected)};
} }
@ -52,10 +47,10 @@ TEST_P(ResolverConstEvalUnaryOpTest, Test) {
auto op = std::get<0>(GetParam()); auto op = std::get<0>(GetParam());
auto& c = std::get<1>(GetParam()); auto& c = std::get<1>(GetParam());
auto* expected = ToValueBase(c.expected); auto& expected = c.expected;
auto* input = ToValueBase(c.input); auto& input = c.input;
auto* input_expr = input->Expr(*this); auto* input_expr = input.Expr(*this);
auto* expr = create<ast::UnaryOpExpression>(op, input_expr); auto* expr = create<ast::UnaryOpExpression>(op, input_expr);
GlobalConst("C", expr); GlobalConst("C", expr);
@ -67,13 +62,13 @@ TEST_P(ResolverConstEvalUnaryOpTest, Test) {
EXPECT_TYPE(value->Type(), sem->Type()); EXPECT_TYPE(value->Type(), sem->Type());
auto values_flat = ScalarArgsFrom(value); auto values_flat = ScalarArgsFrom(value);
auto expected_values_flat = expected->Args(); auto expected_values_flat = expected.Args();
ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length()); ASSERT_EQ(values_flat.values.Length(), expected_values_flat.values.Length());
for (size_t i = 0; i < values_flat.values.Length(); ++i) { for (size_t i = 0; i < values_flat.values.Length(); ++i) {
auto& a = values_flat.values[i]; auto& a = values_flat.values[i];
auto& b = expected_values_flat.values[i]; auto& b = expected_values_flat.values[i];
EXPECT_EQ(a, b); EXPECT_EQ(a, b);
if (expected->IsIntegral()) { if (expected.IsIntegral()) {
// Check that the constant's integer doesn't contain unexpected // Check that the constant's integer doesn't contain unexpected
// data in the MSBs that are outside of the bit-width of T. // data in the MSBs that are outside of the bit-width of T.
EXPECT_EQ(builder::As<AInt>(a), builder::As<AInt>(b)); EXPECT_EQ(builder::As<AInt>(a), builder::As<AInt>(b));

View File

@ -241,12 +241,21 @@ using ast_expr_from_double_func_ptr = const ast::Expression* (*)(ProgramBuilder&
using sem_type_func_ptr = const sem::Type* (*)(ProgramBuilder& b); using sem_type_func_ptr = const sem::Type* (*)(ProgramBuilder& b);
using type_name_func_ptr = std::string (*)(); using type_name_func_ptr = std::string (*)();
struct UnspecializedElementType {};
/// Base template for DataType, specialized below.
template <typename T> template <typename T>
struct DataType {}; struct DataType {
/// The element type
using ElementType = UnspecializedElementType;
};
/// Helper that represents no-type. Returns nullptr for all static methods. /// Helper that represents no-type. Returns nullptr for all static methods.
template <> template <>
struct DataType<void> { struct DataType<void> {
/// The element type
using ElementType = void;
/// @return nullptr /// @return nullptr
static inline const ast::Type* AST(ProgramBuilder&) { return nullptr; } static inline const ast::Type* AST(ProgramBuilder&) { return nullptr; }
/// @return nullptr /// @return nullptr
@ -762,7 +771,13 @@ constexpr CreatePtrs CreatePtrsFor() {
DataType<T>::Name}; DataType<T>::Name};
} }
/// Base class for Value<T> /// True if DataType<T> is specialized for T, false otherwise.
template <typename T>
const bool IsDataTypeSpecializedFor =
!std::is_same_v<typename DataType<T>::ElementType, UnspecializedElementType>;
namespace detail {
/// ValueBase is a base class of ConcreteValue<T>
struct ValueBase { struct ValueBase {
/// Constructor /// Constructor
ValueBase() = default; ValueBase() = default;
@ -793,13 +808,12 @@ struct ValueBase {
virtual std::ostream& Print(std::ostream& o) const = 0; virtual std::ostream& Print(std::ostream& o) const = 0;
}; };
/// Value<T> is an instance of a value of type DataType<T>. Useful for storing values to create /// ConcreteValue<T> is used to create Values of type DataType<T> with a ScalarArgs initializer.
/// expressions with.
template <typename T> template <typename T>
struct Value : ValueBase { struct ConcreteValue : ValueBase {
/// Constructor /// Constructor
/// @param a the scalar args /// @param args the scalar args
explicit Value(ScalarArgs a) : args(std::move(a)) {} explicit ConcreteValue(ScalarArgs args) : args_(std::move(args)) {}
/// Alias to T /// Alias to T
using Type = T; using Type = T;
@ -808,21 +822,16 @@ struct Value : ValueBase {
/// Alias to DataType::ElementType /// Alias to DataType::ElementType
using ElementType = typename DataType::ElementType; using ElementType = typename DataType::ElementType;
/// Creates a Value<T> with `args`
/// @param args the args that will be passed to the expression
/// @returns a Value<T>
static Value Create(ScalarArgs args) { return Value{std::move(args)}; }
/// Creates an `ast::Expression` for the type T passing in previously stored args /// Creates an `ast::Expression` for the type T passing in previously stored args
/// @param b the ProgramBuilder /// @param b the ProgramBuilder
/// @returns an expression node /// @returns an expression node
const ast::Expression* Expr(ProgramBuilder& b) const override { const ast::Expression* Expr(ProgramBuilder& b) const override {
auto create = CreatePtrsFor<T>(); auto create = CreatePtrsFor<T>();
return (*create.expr)(b, args); return (*create.expr)(b, args_);
} }
/// @returns args used to create expression via `Expr` /// @returns args used to create expression via `Expr`
const ScalarArgs& Args() const override { return args; } const ScalarArgs& Args() const override { return args_; }
/// @returns true if element type is abstract /// @returns true if element type is abstract
bool IsAbstract() const override { return tint::IsAbstract<ElementType>; } bool IsAbstract() const override { return tint::IsAbstract<ElementType>; }
@ -838,9 +847,9 @@ struct Value : ValueBase {
/// @returns input argument `o` /// @returns input argument `o`
std::ostream& Print(std::ostream& o) const override { std::ostream& Print(std::ostream& o) const override {
o << TypeName() << "("; o << TypeName() << "(";
for (auto& a : args.values) { for (auto& a : args_.values) {
o << std::get<ElementType>(a); o << std::get<ElementType>(a);
if (&a != &args.values.Back()) { if (&a != &args_.values.Back()) {
o << ", "; o << ", ";
} }
} }
@ -848,60 +857,95 @@ struct Value : ValueBase {
return o; return o;
} }
private:
/// args to create expression with /// args to create expression with
ScalarArgs args; ScalarArgs args_;
}; };
namespace detail {
/// Base template for IsValue
template <typename T>
struct IsValue : std::false_type {};
/// Specialization for IsValue
template <typename T>
struct IsValue<Value<T>> : std::true_type {};
} // namespace detail } // namespace detail
/// True if T is of type Value /// A Value represents a value of type DataType<T> created with ScalarArgs. Useful for storing
/// values for unit tests.
class Value {
public:
/// Creates a Value for type T initialized with `args`
/// @param args the scalar args
/// @returns Value
template <typename T> template <typename T>
constexpr bool IsValue = detail::IsValue<T>::value; static Value Create(ScalarArgs args) {
static_assert(IsDataTypeSpecializedFor<T>, "No DataType<T> specialization exists");
/// Returns the friendly name of ValueT return Value{std::make_shared<detail::ConcreteValue<T>>(std::move(args))};
template <typename ValueT, typename = traits::EnableIf<IsValue<ValueT>>>
const char* FriendlyName() {
return tint::FriendlyName<typename ValueT::ElementType>();
} }
/// Creates a `Value<T>` from a scalar `v` /// Creates an `ast::Expression` for the type T passing in previously stored args
template <typename T> /// @param b the ProgramBuilder
auto Val(T v) { /// @returns an expression node
return Value<T>::Create(ScalarArgs{v}); const ast::Expression* Expr(ProgramBuilder& b) const { return value_->Expr(b); }
/// @returns args used to create expression via `Expr`
const ScalarArgs& Args() const { return value_->Args(); }
/// @returns true if element type is abstract
bool IsAbstract() const { return value_->IsAbstract(); }
/// @returns true if element type is an integral
bool IsIntegral() const { return value_->IsIntegral(); }
/// @returns element type name
std::string TypeName() const { return value_->TypeName(); }
/// Prints this value to the output stream
/// @param o the output stream
/// @returns input argument `o`
std::ostream& Print(std::ostream& o) const { return value_->Print(o); }
private:
/// Private constructor
explicit Value(std::shared_ptr<const detail::ValueBase> value) : value_(std::move(value)) {}
/// Shared pointer to an immutable value. This type-erasure pattern allows Value to wrap a
/// polymorphic type, while being used like a value-type (i.e. copyable).
std::shared_ptr<const detail::ValueBase> value_;
};
/// Prints Value to ostream
inline std::ostream& operator<<(std::ostream& o, const Value& value) {
return value.Print(o);
} }
/// Creates a `Value<vec<N, T>>` from N scalar `args` /// True if T is Value, false otherwise
template <typename T>
constexpr bool IsValue = std::is_same_v<T, Value>;
/// Creates a Value of DataType<T> from a scalar `v`
template <typename T>
Value Val(T v) {
return Value::Create<T>(ScalarArgs{v});
}
/// Creates a Value of DataType<vec<N, T>> from N scalar `args`
template <typename... T> template <typename... T>
auto Vec(T... args) { Value Vec(T... args) {
constexpr size_t N = sizeof...(args);
using FirstT = std::tuple_element_t<0, std::tuple<T...>>; using FirstT = std::tuple_element_t<0, std::tuple<T...>>;
constexpr size_t N = sizeof...(args);
utils::Vector v{args...}; utils::Vector v{args...};
using VT = vec<N, FirstT>; return Value::Create<vec<N, FirstT>>(utils::VectorRef<FirstT>{v});
return Value<VT>::Create(utils::VectorRef<FirstT>{v});
} }
/// Creates a `Value<mat<C,R,T>` from C*R scalar `args` /// Creates a Value of DataType<mat<C,R,T> from C*R scalar `args`
template <size_t C, size_t R, typename T> template <size_t C, size_t R, typename T>
auto Mat(const T (&m_in)[C][R]) { Value Mat(const T (&m_in)[C][R]) {
utils::Vector<T, C * R> m; utils::Vector<T, C * R> m;
for (uint32_t i = 0; i < C; ++i) { for (uint32_t i = 0; i < C; ++i) {
for (size_t j = 0; j < R; ++j) { for (size_t j = 0; j < R; ++j) {
m.Push(m_in[i][j]); m.Push(m_in[i][j]);
} }
} }
return Value<mat<C, R, T>>::Create(utils::VectorRef<T>{m}); return Value::Create<mat<C, R, T>>(utils::VectorRef<T>{m});
} }
/// Creates a `Value<mat<2,R,T>` from column vectors `c0` and `c1` /// Creates a Value of DataType<mat<2,R,T> from column vectors `c0` and `c1`
template <typename T, size_t R> template <typename T, size_t R>
auto Mat(const T (&c0)[R], const T (&c1)[R]) { Value Mat(const T (&c0)[R], const T (&c1)[R]) {
constexpr size_t C = 2; constexpr size_t C = 2;
utils::Vector<T, C * R> m; utils::Vector<T, C * R> m;
for (auto v : c0) { for (auto v : c0) {
@ -910,12 +954,12 @@ auto Mat(const T (&c0)[R], const T (&c1)[R]) {
for (auto v : c1) { for (auto v : c1) {
m.Push(v); m.Push(v);
} }
return Value<mat<C, R, T>>::Create(utils::VectorRef<T>{m}); return Value::Create<mat<C, R, T>>(utils::VectorRef<T>{m});
} }
/// Creates a `Value<mat<3,R,T>` from column vectors `c0`, `c1`, and `c2` /// Creates a Value of DataType<mat<3,R,T> from column vectors `c0`, `c1`, and `c2`
template <typename T, size_t R> template <typename T, size_t R>
auto Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R]) { Value Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R]) {
constexpr size_t C = 3; constexpr size_t C = 3;
utils::Vector<T, C * R> m; utils::Vector<T, C * R> m;
for (auto v : c0) { for (auto v : c0) {
@ -927,12 +971,12 @@ auto Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R]) {
for (auto v : c2) { for (auto v : c2) {
m.Push(v); m.Push(v);
} }
return Value<mat<C, R, T>>::Create(utils::VectorRef<T>{m}); return Value::Create<mat<C, R, T>>(utils::VectorRef<T>{m});
} }
/// Creates a `Value<mat<4,R,T>` from column vectors `c0`, `c1`, `c2`, and `c3` /// Creates a Value of DataType<mat<4,R,T> from column vectors `c0`, `c1`, `c2`, and `c3`
template <typename T, size_t R> template <typename T, size_t R>
auto Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R], const T (&c3)[R]) { Value Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R], const T (&c3)[R]) {
constexpr size_t C = 4; constexpr size_t C = 4;
utils::Vector<T, C * R> m; utils::Vector<T, C * R> m;
for (auto v : c0) { for (auto v : c0) {
@ -947,7 +991,7 @@ auto Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R], const T (&c3)[R])
for (auto v : c3) { for (auto v : c3) {
m.Push(v); m.Push(v);
} }
return Value<mat<C, R, T>>::Create(utils::VectorRef<T>{m}); return Value::Create<mat<C, R, T>>(utils::VectorRef<T>{m});
} }
} // namespace builder } // namespace builder