tint: add const eval array constructor tests

Bug: tint:1581
Change-Id: Ia6c4ba974b40cdff8dc28ddbd510189355ed27cb
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/115400
Reviewed-by: James Price <jrprice@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Antonio Maiorano 2022-12-22 16:27:43 +00:00 committed by Dawn LUCI CQ
parent 906fc9df20
commit 6b4622fb07
3 changed files with 108 additions and 1 deletions

View File

@ -1623,6 +1623,102 @@ TEST_F(ResolverConstEvalTest, Array_i32_Elements) {
EXPECT_EQ(sem->ConstantValue()->Index(3)->ValueAs<i32>(), 40_i);
}
namespace ArrayInit {
struct Case {
Value input;
};
static Case C(Value input) {
return Case{std::move(input)};
}
static std::ostream& operator<<(std::ostream& o, const Case& c) {
return o << "input: " << c.input;
}
using ResolverConstEvalArrayInitTest = ResolverTestWithParam<Case>;
TEST_P(ResolverConstEvalArrayInitTest, Test) {
Enable(ast::Extension::kF16);
auto& param = GetParam();
auto* expr = param.input.Expr(*this);
auto* a = Const("a", expr);
WrapInFunction(a);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
ASSERT_NE(sem, nullptr);
auto* arr = sem->Type()->As<type::Array>();
ASSERT_NE(arr, nullptr);
EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type());
// Constant values should match input values
CheckConstant(sem->ConstantValue(), param.input);
}
template <typename T>
std::vector<Case> ArrayInitCases() {
return {
C(Array(T(0))), //
C(Array(T(0))), //
C(Array(T(0), T(1))), //
C(Array(T(0), T(1), T(2))), //
C(Array(T(2), T(1), T(0))), //
C(Array(T(2), T(0), T(1))), //
};
}
INSTANTIATE_TEST_SUITE_P( //
ArrayInit,
ResolverConstEvalArrayInitTest,
testing::ValuesIn(Concat(ArrayInitCases<AInt>(), //
ArrayInitCases<AFloat>(), //
ArrayInitCases<i32>(), //
ArrayInitCases<u32>(), //
ArrayInitCases<f32>(), //
ArrayInitCases<f16>(), //
ArrayInitCases<bool>())));
} // namespace ArrayInit
TEST_F(ResolverConstEvalTest, ArrayInit_Nested_f32) {
auto inner_ty = [&] { return ty.array<f32, 2>(); };
auto outer_ty = ty.array(inner_ty(), Expr(3_i));
auto* expr = Construct(outer_ty, //
Construct(inner_ty(), 1_f, 2_f), //
Construct(inner_ty(), 3_f, 4_f), //
Construct(inner_ty(), 5_f, 6_f));
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
ASSERT_NE(sem, nullptr);
auto* outer_arr = sem->Type()->As<type::Array>();
ASSERT_NE(outer_arr, nullptr);
EXPECT_TRUE(outer_arr->ElemType()->Is<type::Array>());
EXPECT_TRUE(outer_arr->ElemType()->As<type::Array>()->ElemType()->Is<type::F32>());
auto* arr = sem->ConstantValue();
EXPECT_FALSE(arr->AllEqual());
EXPECT_FALSE(arr->AnyZero());
EXPECT_FALSE(arr->AllZero());
EXPECT_FALSE(arr->Index(0)->AllEqual());
EXPECT_FALSE(arr->Index(0)->AnyZero());
EXPECT_FALSE(arr->Index(0)->AllZero());
EXPECT_FALSE(arr->Index(1)->AllEqual());
EXPECT_FALSE(arr->Index(1)->AnyZero());
EXPECT_FALSE(arr->Index(1)->AllZero());
EXPECT_FALSE(arr->Index(2)->AllEqual());
EXPECT_FALSE(arr->Index(2)->AnyZero());
EXPECT_FALSE(arr->Index(2)->AllZero());
EXPECT_EQ(arr->Index(0)->Index(0)->ValueAs<f32>(), 1.0f);
EXPECT_EQ(arr->Index(0)->Index(1)->ValueAs<f32>(), 2.0f);
EXPECT_EQ(arr->Index(1)->Index(0)->ValueAs<f32>(), 3.0f);
EXPECT_EQ(arr->Index(1)->Index(1)->ValueAs<f32>(), 4.0f);
EXPECT_EQ(arr->Index(2)->Index(0)->ValueAs<f32>(), 5.0f);
EXPECT_EQ(arr->Index(2)->Index(1)->ValueAs<f32>(), 6.0f);
}
TEST_F(ResolverConstEvalTest, Array_f32_Elements) {
auto* expr = Construct(ty.array<f32, 4>(), 10_f, 20_f, 30_f, 40_f);
WrapInFunction(expr);

View File

@ -245,6 +245,7 @@ std::string OverflowExpErrorMessage(std::string_view base, NumberT exp) {
return ss.str();
}
using builder::Array;
using builder::IsValue;
using builder::Mat;
using builder::Val;

View File

@ -822,6 +822,17 @@ Value Vec(Ts... args) {
return Value::Create<vec<N, FirstT>>(std::move(v));
}
/// Creates a Value of DataType<array<N, T>> from N scalar `args`
template <typename... Ts>
Value Array(Ts... args) {
using FirstT = std::tuple_element_t<0, std::tuple<Ts...>>;
static_assert(std::conjunction_v<std::is_same<FirstT, Ts>...>,
"Array args must all be the same type");
constexpr size_t N = sizeof...(args);
utils::Vector<Scalar, sizeof...(args)> v{args...};
return Value::Create<array<N, FirstT>>(std::move(v));
}
/// Creates a Value of DataType<mat<C,R,T> from C*R scalar `args`
template <size_t C, size_t R, typename T>
Value Mat(const T (&m_in)[C][R]) {
@ -884,7 +895,6 @@ Value Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R], const T (&c3)[R]
}
return Value::Create<mat<C, R, T>>(std::move(m));
}
} // namespace builder
} // namespace tint::resolver