diff --git a/src/tint/resolver/const_eval_construction_test.cc b/src/tint/resolver/const_eval_construction_test.cc index bb231212d7..1462f7bf7d 100644 --- a/src/tint/resolver/const_eval_construction_test.cc +++ b/src/tint/resolver/const_eval_construction_test.cc @@ -1623,6 +1623,102 @@ TEST_F(ResolverConstEvalTest, Array_i32_Elements) { EXPECT_EQ(sem->ConstantValue()->Index(3)->ValueAs(), 40_i); } +namespace ArrayInit { +struct Case { + Value input; +}; +static Case C(Value input) { + return Case{std::move(input)}; +} +static std::ostream& operator<<(std::ostream& o, const Case& c) { + return o << "input: " << c.input; +} + +using ResolverConstEvalArrayInitTest = ResolverTestWithParam; +TEST_P(ResolverConstEvalArrayInitTest, Test) { + Enable(ast::Extension::kF16); + auto& param = GetParam(); + auto* expr = param.input.Expr(*this); + auto* a = Const("a", expr); + WrapInFunction(a); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(expr); + ASSERT_NE(sem, nullptr); + auto* arr = sem->Type()->As(); + ASSERT_NE(arr, nullptr); + + EXPECT_TYPE(sem->ConstantValue()->Type(), sem->Type()); + // Constant values should match input values + CheckConstant(sem->ConstantValue(), param.input); +} +template +std::vector ArrayInitCases() { + return { + C(Array(T(0))), // + C(Array(T(0))), // + C(Array(T(0), T(1))), // + C(Array(T(0), T(1), T(2))), // + C(Array(T(2), T(1), T(0))), // + C(Array(T(2), T(0), T(1))), // + }; +} +INSTANTIATE_TEST_SUITE_P( // + ArrayInit, + ResolverConstEvalArrayInitTest, + testing::ValuesIn(Concat(ArrayInitCases(), // + ArrayInitCases(), // + ArrayInitCases(), // + ArrayInitCases(), // + ArrayInitCases(), // + ArrayInitCases(), // + ArrayInitCases()))); +} // namespace ArrayInit + +TEST_F(ResolverConstEvalTest, ArrayInit_Nested_f32) { + auto inner_ty = [&] { return ty.array(); }; + auto outer_ty = ty.array(inner_ty(), Expr(3_i)); + + auto* expr = Construct(outer_ty, // + Construct(inner_ty(), 1_f, 2_f), // + Construct(inner_ty(), 3_f, 4_f), // + Construct(inner_ty(), 5_f, 6_f)); + + WrapInFunction(expr); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(expr); + ASSERT_NE(sem, nullptr); + auto* outer_arr = sem->Type()->As(); + ASSERT_NE(outer_arr, nullptr); + EXPECT_TRUE(outer_arr->ElemType()->Is()); + EXPECT_TRUE(outer_arr->ElemType()->As()->ElemType()->Is()); + + auto* arr = sem->ConstantValue(); + EXPECT_FALSE(arr->AllEqual()); + EXPECT_FALSE(arr->AnyZero()); + EXPECT_FALSE(arr->AllZero()); + + EXPECT_FALSE(arr->Index(0)->AllEqual()); + EXPECT_FALSE(arr->Index(0)->AnyZero()); + EXPECT_FALSE(arr->Index(0)->AllZero()); + EXPECT_FALSE(arr->Index(1)->AllEqual()); + EXPECT_FALSE(arr->Index(1)->AnyZero()); + EXPECT_FALSE(arr->Index(1)->AllZero()); + EXPECT_FALSE(arr->Index(2)->AllEqual()); + EXPECT_FALSE(arr->Index(2)->AnyZero()); + EXPECT_FALSE(arr->Index(2)->AllZero()); + + EXPECT_EQ(arr->Index(0)->Index(0)->ValueAs(), 1.0f); + EXPECT_EQ(arr->Index(0)->Index(1)->ValueAs(), 2.0f); + EXPECT_EQ(arr->Index(1)->Index(0)->ValueAs(), 3.0f); + EXPECT_EQ(arr->Index(1)->Index(1)->ValueAs(), 4.0f); + EXPECT_EQ(arr->Index(2)->Index(0)->ValueAs(), 5.0f); + EXPECT_EQ(arr->Index(2)->Index(1)->ValueAs(), 6.0f); +} + TEST_F(ResolverConstEvalTest, Array_f32_Elements) { auto* expr = Construct(ty.array(), 10_f, 20_f, 30_f, 40_f); WrapInFunction(expr); diff --git a/src/tint/resolver/const_eval_test.h b/src/tint/resolver/const_eval_test.h index c31d89f7a8..001d3d1186 100644 --- a/src/tint/resolver/const_eval_test.h +++ b/src/tint/resolver/const_eval_test.h @@ -245,6 +245,7 @@ std::string OverflowExpErrorMessage(std::string_view base, NumberT exp) { return ss.str(); } +using builder::Array; using builder::IsValue; using builder::Mat; using builder::Val; diff --git a/src/tint/resolver/resolver_test_helper.h b/src/tint/resolver/resolver_test_helper.h index 33818f28b5..24aea11033 100644 --- a/src/tint/resolver/resolver_test_helper.h +++ b/src/tint/resolver/resolver_test_helper.h @@ -822,6 +822,17 @@ Value Vec(Ts... args) { return Value::Create>(std::move(v)); } +/// Creates a Value of DataType> from N scalar `args` +template +Value Array(Ts... args) { + using FirstT = std::tuple_element_t<0, std::tuple>; + static_assert(std::conjunction_v...>, + "Array args must all be the same type"); + constexpr size_t N = sizeof...(args); + utils::Vector v{args...}; + return Value::Create>(std::move(v)); +} + /// Creates a Value of DataType from C*R scalar `args` template Value Mat(const T (&m_in)[C][R]) { @@ -884,7 +895,6 @@ Value Mat(const T (&c0)[R], const T (&c1)[R], const T (&c2)[R], const T (&c3)[R] } return Value::Create>(std::move(m)); } - } // namespace builder } // namespace tint::resolver