diff --git a/src/tint/resolver/materialize_test.cc b/src/tint/resolver/materialize_test.cc index 76dd62e7cb..5f4a41f50f 100644 --- a/src/tint/resolver/materialize_test.cc +++ b/src/tint/resolver/materialize_test.cc @@ -325,7 +325,7 @@ TEST_P(MaterializeAbstractNumericToConcreteType, Test) { EXPECT_TYPE(expr->ConstantValue().Type(), target_sem_ty); uint32_t num_elems = 0; - const sem::Type* target_sem_el_ty = sem::Type::ElementOf(target_sem_ty, &num_elems); + const sem::Type* target_sem_el_ty = sem::Type::DeepestElementOf(target_sem_ty, &num_elems); EXPECT_TYPE(expr->ConstantValue().ElementType(), target_sem_el_ty); expr->ConstantValue().WithElements([&](auto&& vec) { using VEC_TY = std::decay_t; @@ -738,7 +738,8 @@ TEST_P(MaterializeAbstractNumericToDefaultType, Test) { EXPECT_TYPE(expr->ConstantValue().Type(), expected_sem_ty); uint32_t num_elems = 0; - const sem::Type* expected_sem_el_ty = sem::Type::ElementOf(expected_sem_ty, &num_elems); + const sem::Type* expected_sem_el_ty = + sem::Type::DeepestElementOf(expected_sem_ty, &num_elems); EXPECT_TYPE(expr->ConstantValue().ElementType(), expected_sem_el_ty); expr->ConstantValue().WithElements([&](auto&& vec) { using VEC_TY = std::decay_t; diff --git a/src/tint/resolver/resolver_constants.cc b/src/tint/resolver/resolver_constants.cc index 95b494ed36..ad89f3e334 100644 --- a/src/tint/resolver/resolver_constants.cc +++ b/src/tint/resolver/resolver_constants.cc @@ -152,13 +152,10 @@ utils::Result MaterializeElements(const sem::Constant:: } // namespace sem::Constant Resolver::EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type) { - if (auto* e = expr->As()) { - return EvaluateConstantValue(e, type); - } - if (auto* e = expr->As()) { - return EvaluateConstantValue(e, type); - } - return {}; + return Switch( + expr, // + [&](const ast::LiteralExpression* e) { return EvaluateConstantValue(e, type); }, + [&](const ast::CallExpression* e) { return EvaluateConstantValue(e, type); }); } sem::Constant Resolver::EvaluateConstantValue(const ast::LiteralExpression* literal, @@ -178,10 +175,10 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::LiteralExpression* lite sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call, const sem::Type* ty) { - uint32_t result_size = 0; - auto* el_ty = sem::Type::ElementOf(ty, &result_size); + uint32_t num_elems = 0; + auto* el_ty = sem::Type::DeepestElementOf(ty, &num_elems); if (!el_ty) { - return sem::Constant{}; + return {}; } // ElementOf() will also return the element type of array, which we do not support. @@ -194,16 +191,16 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call, return Switch( el_ty, [&](const sem::AbstractInt*) { - return sem::Constant(ty, std::vector(result_size, AInt(0))); + return sem::Constant(ty, std::vector(num_elems, AInt(0))); }, [&](const sem::AbstractFloat*) { - return sem::Constant(ty, std::vector(result_size, AFloat(0))); + return sem::Constant(ty, std::vector(num_elems, AFloat(0))); }, - [&](const sem::I32*) { return sem::Constant(ty, std::vector(result_size, AInt(0))); }, - [&](const sem::U32*) { return sem::Constant(ty, std::vector(result_size, AInt(0))); }, - [&](const sem::F32*) { return sem::Constant(ty, std::vector(result_size, AFloat(0))); }, - [&](const sem::F16*) { return sem::Constant(ty, std::vector(result_size, AFloat(0))); }, - [&](const sem::Bool*) { return sem::Constant(ty, std::vector(result_size, AInt(0))); }); + [&](const sem::I32*) { return sem::Constant(ty, std::vector(num_elems, AInt(0))); }, + [&](const sem::U32*) { return sem::Constant(ty, std::vector(num_elems, AInt(0))); }, + [&](const sem::F32*) { return sem::Constant(ty, std::vector(num_elems, AFloat(0))); }, + [&](const sem::F16*) { return sem::Constant(ty, std::vector(num_elems, AFloat(0))); }, + [&](const sem::Bool*) { return sem::Constant(ty, std::vector(num_elems, AInt(0))); }); } // Build value for type_ctor from each child value by converting to type_ctor's type. @@ -235,18 +232,27 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call, } } - // Splat single-value initializers - std::visit( + if (!elements) { + return {}; + } + + return std::visit( [&](auto&& v) { - if (v.size() == 1) { - for (uint32_t i = 0; i < result_size - 1; ++i) { - v.emplace_back(v[0]); + if (num_elems != v.size()) { + if (v.size() == 1) { + // Splat single-value initializers + for (uint32_t i = 0; i < num_elems - 1; ++i) { + v.emplace_back(v[0]); + } + } else { + // Provided number of arguments does not match the required number of elements. + // Validation should error here. + return sem::Constant{}; } } + return sem::Constant(ty, std::move(elements.value())); }, elements.value()); - - return sem::Constant(ty, std::move(elements.value())); } utils::Result Resolver::ConvertValue(const sem::Constant& value, @@ -256,7 +262,7 @@ utils::Result Resolver::ConvertValue(const sem::Constant& value, return value; } - auto* el_ty = sem::Type::ElementOf(ty); + auto* el_ty = sem::Type::DeepestElementOf(ty); if (el_ty == nullptr) { return sem::Constant{}; } diff --git a/src/tint/resolver/resolver_constants_test.cc b/src/tint/resolver/resolver_constants_test.cc index bbdbfeafd4..c937fa1ff7 100644 --- a/src/tint/resolver/resolver_constants_test.cc +++ b/src/tint/resolver/resolver_constants_test.cc @@ -93,9 +93,10 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_i32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -112,9 +113,10 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_u32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -131,9 +133,10 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_f32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -150,9 +153,10 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_bool) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -169,9 +173,10 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_i32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -188,9 +193,10 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_u32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -207,9 +213,10 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_f32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -226,9 +233,10 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_bool) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -245,9 +253,10 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_i32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -264,9 +273,10 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_u32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -283,9 +293,10 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_f32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -302,9 +313,10 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_bool) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -321,9 +333,10 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_i32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -340,9 +353,10 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_u32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -359,9 +373,10 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_f32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -378,9 +393,10 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_bool) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -397,9 +413,10 @@ TEST_F(ResolverConstantsTest, Vec3_Convert_f32_to_i32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -416,9 +433,10 @@ TEST_F(ResolverConstantsTest, Vec3_Convert_u32_to_f32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -435,9 +453,10 @@ TEST_F(ResolverConstantsTest, Vec3_Convert_Large_f32_to_i32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -454,9 +473,10 @@ TEST_F(ResolverConstantsTest, Vec3_Convert_Large_f32_to_u32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -478,9 +498,10 @@ TEST_F(ResolverConstantsTest, DISABLED_Vec3_Convert_Large_f32_to_f16) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -500,9 +521,10 @@ TEST_F(ResolverConstantsTest, DISABLED_Vec3_Convert_Small_f32_to_f16) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -511,5 +533,80 @@ TEST_F(ResolverConstantsTest, DISABLED_Vec3_Convert_Small_f32_to_f16) { EXPECT_EQ(sem->ConstantValue().Element(2).value, 0.0); } +TEST_F(ResolverConstantsTest, Mat2x3_ZeroInit_f32) { + auto* expr = mat2x3(); + WrapInFunction(expr); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(expr); + EXPECT_NE(sem, nullptr); + auto* mat = sem->Type()->As(); + ASSERT_NE(mat, nullptr); + EXPECT_TRUE(mat->type()->Is()); + EXPECT_EQ(mat->columns(), 2u); + EXPECT_EQ(mat->rows(), 3u); + EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); + EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 6u); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 0._f); + EXPECT_EQ(sem->ConstantValue().Element(1).value, 0._f); + EXPECT_EQ(sem->ConstantValue().Element(2).value, 0._f); + EXPECT_EQ(sem->ConstantValue().Element(3).value, 0._f); + EXPECT_EQ(sem->ConstantValue().Element(4).value, 0._f); + EXPECT_EQ(sem->ConstantValue().Element(5).value, 0._f); +} + +TEST_F(ResolverConstantsTest, Mat3x2_Construct_Scalars_af) { + auto* expr = Construct(ty.mat(nullptr, 3, 2), 1.0_a, 2.0_a, 3.0_a, 4.0_a, 5.0_a, 6.0_a); + WrapInFunction(expr); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(expr); + EXPECT_NE(sem, nullptr); + auto* mat = sem->Type()->As(); + ASSERT_NE(mat, nullptr); + EXPECT_TRUE(mat->type()->Is()); + EXPECT_EQ(mat->columns(), 3u); + EXPECT_EQ(mat->rows(), 2u); + EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); + EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 6u); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 1._a); + EXPECT_EQ(sem->ConstantValue().Element(1).value, 2._a); + EXPECT_EQ(sem->ConstantValue().Element(2).value, 3._a); + EXPECT_EQ(sem->ConstantValue().Element(3).value, 4._a); + EXPECT_EQ(sem->ConstantValue().Element(4).value, 5._a); + EXPECT_EQ(sem->ConstantValue().Element(5).value, 6._a); +} + +TEST_F(ResolverConstantsTest, Mat3x2_Construct_Columns_af) { + auto* expr = Construct(ty.mat(nullptr, 3, 2), // + vec(nullptr, 2u, 1.0_a, 2.0_a), // + vec(nullptr, 2u, 3.0_a, 4.0_a), // + vec(nullptr, 2u, 5.0_a, 6.0_a)); + WrapInFunction(expr); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(expr); + EXPECT_NE(sem, nullptr); + auto* mat = sem->Type()->As(); + ASSERT_NE(mat, nullptr); + EXPECT_TRUE(mat->type()->Is()); + EXPECT_EQ(mat->columns(), 3u); + EXPECT_EQ(mat->rows(), 2u); + EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); + EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 6u); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 1._a); + EXPECT_EQ(sem->ConstantValue().Element(1).value, 2._a); + EXPECT_EQ(sem->ConstantValue().Element(2).value, 3._a); + EXPECT_EQ(sem->ConstantValue().Element(3).value, 4._a); + EXPECT_EQ(sem->ConstantValue().Element(4).value, 5._a); + EXPECT_EQ(sem->ConstantValue().Element(5).value, 6._a); +} + } // namespace } // namespace tint::resolver diff --git a/src/tint/sem/constant.cc b/src/tint/sem/constant.cc index 80869920cc..921e9a7f91 100644 --- a/src/tint/sem/constant.cc +++ b/src/tint/sem/constant.cc @@ -98,7 +98,7 @@ const Type* Constant::CheckElemType(const sem::Type* ty, size_t num_elements) { diag::List diag; if (ty->is_abstract_or_scalar() || ty->IsAnyOf()) { uint32_t count = 0; - auto* el_ty = Type::ElementOf(ty, &count); + auto* el_ty = Type::DeepestElementOf(ty, &count); if (num_elements != count) { TINT_ICE(Semantic, diag) << "sem::Constant() type <-> element mismatch. type: '" << ty->TypeInfo().name << "' element: " << num_elements; diff --git a/src/tint/sem/constant.h b/src/tint/sem/constant.h index 4f7132a5ef..42683cf7ec 100644 --- a/src/tint/sem/constant.h +++ b/src/tint/sem/constant.h @@ -104,10 +104,10 @@ class Constant { return std::visit([](auto&& v) { return v.size(); }, elems_); } - /// @returns the element type of the Constant + /// @returns the flattened element type of the Constant const sem::Type* ElementType() const { return elem_type_; } - /// @returns the constant's elements + /// @returns the constant's flattened elements const Elements& GetElements() const { return elems_; } /// WithElements calls the function `f` with the vector of elements as either AFloats or AInts diff --git a/src/tint/sem/type.cc b/src/tint/sem/type.cc index d92b236d41..9d4c4699db 100644 --- a/src/tint/sem/type.cc +++ b/src/tint/sem/type.cc @@ -220,9 +220,9 @@ const Type* Type::ElementOf(const Type* ty, uint32_t* count /* = nullptr */) { }, [&](const Matrix* m) { if (count) { - *count = m->columns() * m->rows(); + *count = m->columns(); } - return m->type(); + return m->ColumnType(); }, [&](const Array* a) { if (count) { diff --git a/src/tint/sem/type.h b/src/tint/sem/type.h index ac86320742..25f3a438a7 100644 --- a/src/tint/sem/type.h +++ b/src/tint/sem/type.h @@ -132,15 +132,22 @@ class Type : public Castable { /// @param ty the type to obtain the element type from /// @param count if not null, then this is assigned the number of child elements in the type. /// For example, the count of an `array, 5>` type would be 5. - /// @returns `ty` if `ty` is an abstract or scalar, or the element type if ty is a vector, - /// matrix or array, otherwise nullptr. + /// @returns + /// * `ty` if `ty` is an abstract or scalar + /// * the element type if `ty` is a vector or array + /// * the column type if `ty` is a matrix + /// * `nullptr` if `ty` is none of the above static const Type* ElementOf(const Type* ty, uint32_t* count = nullptr); /// @param ty the type to obtain the deepest element type from /// @param count if not null, then this is assigned the full number of most deeply nested /// elements in the type. For example, the count of an `array, 5>` type would be 15. - /// @returns `ty` if `ty` is an abstract or scalar, or the element type if ty is a vector, - /// matrix, or the deepest element type if ty is an array, otherwise nullptr. + /// @returns + /// * `ty` if `ty` is an abstract or scalar + /// * the element type if `ty` is a vector + /// * the matrix element type if `ty` is a matrix + /// * the deepest element type if `ty` is an array + /// * `nullptr` if `ty` is none of the above static const Type* DeepestElementOf(const Type* ty, uint32_t* count = nullptr); /// @param types a pointer to a list of `const Type*`. diff --git a/src/tint/sem/type_test.cc b/src/tint/sem/type_test.cc index e6044533f6..8458880534 100644 --- a/src/tint/sem/type_test.cc +++ b/src/tint/sem/type_test.cc @@ -157,9 +157,9 @@ TEST_F(TypeTest, ElementOf) { EXPECT_TYPE(Type::ElementOf(vec4_f32), f32); EXPECT_TYPE(Type::ElementOf(vec3_u32), u32); EXPECT_TYPE(Type::ElementOf(vec3_i32), i32); - EXPECT_TYPE(Type::ElementOf(mat2x4_f32), f32); - EXPECT_TYPE(Type::ElementOf(mat4x2_f32), f32); - EXPECT_TYPE(Type::ElementOf(mat4x3_f16), f16); + EXPECT_TYPE(Type::ElementOf(mat2x4_f32), vec4_f32); + EXPECT_TYPE(Type::ElementOf(mat4x2_f32), vec2_f32); + EXPECT_TYPE(Type::ElementOf(mat4x3_f16), vec3_f16); EXPECT_TYPE(Type::ElementOf(str), nullptr); EXPECT_TYPE(Type::ElementOf(arr_i32), i32); EXPECT_TYPE(Type::ElementOf(arr_vec3_i32), vec3_i32); @@ -195,14 +195,14 @@ TEST_F(TypeTest, ElementOf) { EXPECT_TYPE(Type::ElementOf(vec3_i32, &count), i32); EXPECT_EQ(count, 3u); count = 42; - EXPECT_TYPE(Type::ElementOf(mat2x4_f32, &count), f32); - EXPECT_EQ(count, 8u); + EXPECT_TYPE(Type::ElementOf(mat2x4_f32, &count), vec4_f32); + EXPECT_EQ(count, 2u); count = 42; - EXPECT_TYPE(Type::ElementOf(mat4x2_f32, &count), f32); - EXPECT_EQ(count, 8u); + EXPECT_TYPE(Type::ElementOf(mat4x2_f32, &count), vec2_f32); + EXPECT_EQ(count, 4u); count = 42; - EXPECT_TYPE(Type::ElementOf(mat4x3_f16, &count), f16); - EXPECT_EQ(count, 12u); + EXPECT_TYPE(Type::ElementOf(mat4x3_f16, &count), vec3_f16); + EXPECT_EQ(count, 4u); count = 42; EXPECT_TYPE(Type::ElementOf(str, &count), nullptr); EXPECT_EQ(count, 0u);