From 797f0f82e0521a4aaaad3004dc281df6581c693e Mon Sep 17 00:00:00 2001 From: Ben Clayton Date: Wed, 22 Jun 2022 00:02:03 +0000 Subject: [PATCH] tint/sem: Return vector for Type::ElementOf(matrix) Return the column vector type, instead of the column vector element type. This matches what you'd get if you were to index the matrix. DeepestElementOf() can be used to easily obtain the matrix column element type. Change-Id: I5293f4cca205c9e378253ac67880bf9d998814aa Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/94327 Reviewed-by: Dan Sinclair Commit-Queue: Ben Clayton --- src/tint/resolver/materialize_test.cc | 5 +- src/tint/resolver/resolver_constants.cc | 56 +++-- src/tint/resolver/resolver_constants_test.cc | 229 +++++++++++++------ src/tint/sem/constant.cc | 2 +- src/tint/sem/constant.h | 4 +- src/tint/sem/type.cc | 4 +- src/tint/sem/type.h | 15 +- src/tint/sem/type_test.cc | 18 +- 8 files changed, 222 insertions(+), 111 deletions(-) diff --git a/src/tint/resolver/materialize_test.cc b/src/tint/resolver/materialize_test.cc index 76dd62e7cb..5f4a41f50f 100644 --- a/src/tint/resolver/materialize_test.cc +++ b/src/tint/resolver/materialize_test.cc @@ -325,7 +325,7 @@ TEST_P(MaterializeAbstractNumericToConcreteType, Test) { EXPECT_TYPE(expr->ConstantValue().Type(), target_sem_ty); uint32_t num_elems = 0; - const sem::Type* target_sem_el_ty = sem::Type::ElementOf(target_sem_ty, &num_elems); + const sem::Type* target_sem_el_ty = sem::Type::DeepestElementOf(target_sem_ty, &num_elems); EXPECT_TYPE(expr->ConstantValue().ElementType(), target_sem_el_ty); expr->ConstantValue().WithElements([&](auto&& vec) { using VEC_TY = std::decay_t; @@ -738,7 +738,8 @@ TEST_P(MaterializeAbstractNumericToDefaultType, Test) { EXPECT_TYPE(expr->ConstantValue().Type(), expected_sem_ty); uint32_t num_elems = 0; - const sem::Type* expected_sem_el_ty = sem::Type::ElementOf(expected_sem_ty, &num_elems); + const sem::Type* expected_sem_el_ty = + sem::Type::DeepestElementOf(expected_sem_ty, &num_elems); EXPECT_TYPE(expr->ConstantValue().ElementType(), expected_sem_el_ty); expr->ConstantValue().WithElements([&](auto&& vec) { using VEC_TY = std::decay_t; diff --git a/src/tint/resolver/resolver_constants.cc b/src/tint/resolver/resolver_constants.cc index 95b494ed36..ad89f3e334 100644 --- a/src/tint/resolver/resolver_constants.cc +++ b/src/tint/resolver/resolver_constants.cc @@ -152,13 +152,10 @@ utils::Result MaterializeElements(const sem::Constant:: } // namespace sem::Constant Resolver::EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type) { - if (auto* e = expr->As()) { - return EvaluateConstantValue(e, type); - } - if (auto* e = expr->As()) { - return EvaluateConstantValue(e, type); - } - return {}; + return Switch( + expr, // + [&](const ast::LiteralExpression* e) { return EvaluateConstantValue(e, type); }, + [&](const ast::CallExpression* e) { return EvaluateConstantValue(e, type); }); } sem::Constant Resolver::EvaluateConstantValue(const ast::LiteralExpression* literal, @@ -178,10 +175,10 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::LiteralExpression* lite sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call, const sem::Type* ty) { - uint32_t result_size = 0; - auto* el_ty = sem::Type::ElementOf(ty, &result_size); + uint32_t num_elems = 0; + auto* el_ty = sem::Type::DeepestElementOf(ty, &num_elems); if (!el_ty) { - return sem::Constant{}; + return {}; } // ElementOf() will also return the element type of array, which we do not support. @@ -194,16 +191,16 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call, return Switch( el_ty, [&](const sem::AbstractInt*) { - return sem::Constant(ty, std::vector(result_size, AInt(0))); + return sem::Constant(ty, std::vector(num_elems, AInt(0))); }, [&](const sem::AbstractFloat*) { - return sem::Constant(ty, std::vector(result_size, AFloat(0))); + return sem::Constant(ty, std::vector(num_elems, AFloat(0))); }, - [&](const sem::I32*) { return sem::Constant(ty, std::vector(result_size, AInt(0))); }, - [&](const sem::U32*) { return sem::Constant(ty, std::vector(result_size, AInt(0))); }, - [&](const sem::F32*) { return sem::Constant(ty, std::vector(result_size, AFloat(0))); }, - [&](const sem::F16*) { return sem::Constant(ty, std::vector(result_size, AFloat(0))); }, - [&](const sem::Bool*) { return sem::Constant(ty, std::vector(result_size, AInt(0))); }); + [&](const sem::I32*) { return sem::Constant(ty, std::vector(num_elems, AInt(0))); }, + [&](const sem::U32*) { return sem::Constant(ty, std::vector(num_elems, AInt(0))); }, + [&](const sem::F32*) { return sem::Constant(ty, std::vector(num_elems, AFloat(0))); }, + [&](const sem::F16*) { return sem::Constant(ty, std::vector(num_elems, AFloat(0))); }, + [&](const sem::Bool*) { return sem::Constant(ty, std::vector(num_elems, AInt(0))); }); } // Build value for type_ctor from each child value by converting to type_ctor's type. @@ -235,18 +232,27 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call, } } - // Splat single-value initializers - std::visit( + if (!elements) { + return {}; + } + + return std::visit( [&](auto&& v) { - if (v.size() == 1) { - for (uint32_t i = 0; i < result_size - 1; ++i) { - v.emplace_back(v[0]); + if (num_elems != v.size()) { + if (v.size() == 1) { + // Splat single-value initializers + for (uint32_t i = 0; i < num_elems - 1; ++i) { + v.emplace_back(v[0]); + } + } else { + // Provided number of arguments does not match the required number of elements. + // Validation should error here. + return sem::Constant{}; } } + return sem::Constant(ty, std::move(elements.value())); }, elements.value()); - - return sem::Constant(ty, std::move(elements.value())); } utils::Result Resolver::ConvertValue(const sem::Constant& value, @@ -256,7 +262,7 @@ utils::Result Resolver::ConvertValue(const sem::Constant& value, return value; } - auto* el_ty = sem::Type::ElementOf(ty); + auto* el_ty = sem::Type::DeepestElementOf(ty); if (el_ty == nullptr) { return sem::Constant{}; } diff --git a/src/tint/resolver/resolver_constants_test.cc b/src/tint/resolver/resolver_constants_test.cc index bbdbfeafd4..c937fa1ff7 100644 --- a/src/tint/resolver/resolver_constants_test.cc +++ b/src/tint/resolver/resolver_constants_test.cc @@ -93,9 +93,10 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_i32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -112,9 +113,10 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_u32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -131,9 +133,10 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_f32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -150,9 +153,10 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_bool) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -169,9 +173,10 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_i32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -188,9 +193,10 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_u32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -207,9 +213,10 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_f32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -226,9 +233,10 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_bool) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -245,9 +253,10 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_i32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -264,9 +273,10 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_u32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -283,9 +293,10 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_f32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -302,9 +313,10 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_bool) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -321,9 +333,10 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_i32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -340,9 +353,10 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_u32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -359,9 +373,10 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_f32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -378,9 +393,10 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_bool) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -397,9 +413,10 @@ TEST_F(ResolverConstantsTest, Vec3_Convert_f32_to_i32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -416,9 +433,10 @@ TEST_F(ResolverConstantsTest, Vec3_Convert_u32_to_f32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -435,9 +453,10 @@ TEST_F(ResolverConstantsTest, Vec3_Convert_Large_f32_to_i32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -454,9 +473,10 @@ TEST_F(ResolverConstantsTest, Vec3_Convert_Large_f32_to_u32) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -478,9 +498,10 @@ TEST_F(ResolverConstantsTest, DISABLED_Vec3_Convert_Large_f32_to_f16) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -500,9 +521,10 @@ TEST_F(ResolverConstantsTest, DISABLED_Vec3_Convert_Small_f32_to_f16) { auto* sem = Sem().Get(expr); EXPECT_NE(sem, nullptr); - ASSERT_TRUE(sem->Type()->Is()); - EXPECT_TRUE(sem->Type()->As()->type()->Is()); - EXPECT_EQ(sem->Type()->As()->Width(), 3u); + auto* vec = sem->Type()->As(); + ASSERT_NE(vec, nullptr); + EXPECT_TRUE(vec->type()->Is()); + EXPECT_EQ(vec->Width(), 3u); EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u); @@ -511,5 +533,80 @@ TEST_F(ResolverConstantsTest, DISABLED_Vec3_Convert_Small_f32_to_f16) { EXPECT_EQ(sem->ConstantValue().Element(2).value, 0.0); } +TEST_F(ResolverConstantsTest, Mat2x3_ZeroInit_f32) { + auto* expr = mat2x3(); + WrapInFunction(expr); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(expr); + EXPECT_NE(sem, nullptr); + auto* mat = sem->Type()->As(); + ASSERT_NE(mat, nullptr); + EXPECT_TRUE(mat->type()->Is()); + EXPECT_EQ(mat->columns(), 2u); + EXPECT_EQ(mat->rows(), 3u); + EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); + EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 6u); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 0._f); + EXPECT_EQ(sem->ConstantValue().Element(1).value, 0._f); + EXPECT_EQ(sem->ConstantValue().Element(2).value, 0._f); + EXPECT_EQ(sem->ConstantValue().Element(3).value, 0._f); + EXPECT_EQ(sem->ConstantValue().Element(4).value, 0._f); + EXPECT_EQ(sem->ConstantValue().Element(5).value, 0._f); +} + +TEST_F(ResolverConstantsTest, Mat3x2_Construct_Scalars_af) { + auto* expr = Construct(ty.mat(nullptr, 3, 2), 1.0_a, 2.0_a, 3.0_a, 4.0_a, 5.0_a, 6.0_a); + WrapInFunction(expr); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(expr); + EXPECT_NE(sem, nullptr); + auto* mat = sem->Type()->As(); + ASSERT_NE(mat, nullptr); + EXPECT_TRUE(mat->type()->Is()); + EXPECT_EQ(mat->columns(), 3u); + EXPECT_EQ(mat->rows(), 2u); + EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); + EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 6u); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 1._a); + EXPECT_EQ(sem->ConstantValue().Element(1).value, 2._a); + EXPECT_EQ(sem->ConstantValue().Element(2).value, 3._a); + EXPECT_EQ(sem->ConstantValue().Element(3).value, 4._a); + EXPECT_EQ(sem->ConstantValue().Element(4).value, 5._a); + EXPECT_EQ(sem->ConstantValue().Element(5).value, 6._a); +} + +TEST_F(ResolverConstantsTest, Mat3x2_Construct_Columns_af) { + auto* expr = Construct(ty.mat(nullptr, 3, 2), // + vec(nullptr, 2u, 1.0_a, 2.0_a), // + vec(nullptr, 2u, 3.0_a, 4.0_a), // + vec(nullptr, 2u, 5.0_a, 6.0_a)); + WrapInFunction(expr); + + EXPECT_TRUE(r()->Resolve()) << r()->error(); + + auto* sem = Sem().Get(expr); + EXPECT_NE(sem, nullptr); + auto* mat = sem->Type()->As(); + ASSERT_NE(mat, nullptr); + EXPECT_TRUE(mat->type()->Is()); + EXPECT_EQ(mat->columns(), 3u); + EXPECT_EQ(mat->rows(), 2u); + EXPECT_EQ(sem->ConstantValue().Type(), sem->Type()); + EXPECT_TRUE(sem->ConstantValue().ElementType()->Is()); + ASSERT_EQ(sem->ConstantValue().ElementCount(), 6u); + EXPECT_EQ(sem->ConstantValue().Element(0).value, 1._a); + EXPECT_EQ(sem->ConstantValue().Element(1).value, 2._a); + EXPECT_EQ(sem->ConstantValue().Element(2).value, 3._a); + EXPECT_EQ(sem->ConstantValue().Element(3).value, 4._a); + EXPECT_EQ(sem->ConstantValue().Element(4).value, 5._a); + EXPECT_EQ(sem->ConstantValue().Element(5).value, 6._a); +} + } // namespace } // namespace tint::resolver diff --git a/src/tint/sem/constant.cc b/src/tint/sem/constant.cc index 80869920cc..921e9a7f91 100644 --- a/src/tint/sem/constant.cc +++ b/src/tint/sem/constant.cc @@ -98,7 +98,7 @@ const Type* Constant::CheckElemType(const sem::Type* ty, size_t num_elements) { diag::List diag; if (ty->is_abstract_or_scalar() || ty->IsAnyOf()) { uint32_t count = 0; - auto* el_ty = Type::ElementOf(ty, &count); + auto* el_ty = Type::DeepestElementOf(ty, &count); if (num_elements != count) { TINT_ICE(Semantic, diag) << "sem::Constant() type <-> element mismatch. type: '" << ty->TypeInfo().name << "' element: " << num_elements; diff --git a/src/tint/sem/constant.h b/src/tint/sem/constant.h index 4f7132a5ef..42683cf7ec 100644 --- a/src/tint/sem/constant.h +++ b/src/tint/sem/constant.h @@ -104,10 +104,10 @@ class Constant { return std::visit([](auto&& v) { return v.size(); }, elems_); } - /// @returns the element type of the Constant + /// @returns the flattened element type of the Constant const sem::Type* ElementType() const { return elem_type_; } - /// @returns the constant's elements + /// @returns the constant's flattened elements const Elements& GetElements() const { return elems_; } /// WithElements calls the function `f` with the vector of elements as either AFloats or AInts diff --git a/src/tint/sem/type.cc b/src/tint/sem/type.cc index d92b236d41..9d4c4699db 100644 --- a/src/tint/sem/type.cc +++ b/src/tint/sem/type.cc @@ -220,9 +220,9 @@ const Type* Type::ElementOf(const Type* ty, uint32_t* count /* = nullptr */) { }, [&](const Matrix* m) { if (count) { - *count = m->columns() * m->rows(); + *count = m->columns(); } - return m->type(); + return m->ColumnType(); }, [&](const Array* a) { if (count) { diff --git a/src/tint/sem/type.h b/src/tint/sem/type.h index ac86320742..25f3a438a7 100644 --- a/src/tint/sem/type.h +++ b/src/tint/sem/type.h @@ -132,15 +132,22 @@ class Type : public Castable { /// @param ty the type to obtain the element type from /// @param count if not null, then this is assigned the number of child elements in the type. /// For example, the count of an `array, 5>` type would be 5. - /// @returns `ty` if `ty` is an abstract or scalar, or the element type if ty is a vector, - /// matrix or array, otherwise nullptr. + /// @returns + /// * `ty` if `ty` is an abstract or scalar + /// * the element type if `ty` is a vector or array + /// * the column type if `ty` is a matrix + /// * `nullptr` if `ty` is none of the above static const Type* ElementOf(const Type* ty, uint32_t* count = nullptr); /// @param ty the type to obtain the deepest element type from /// @param count if not null, then this is assigned the full number of most deeply nested /// elements in the type. For example, the count of an `array, 5>` type would be 15. - /// @returns `ty` if `ty` is an abstract or scalar, or the element type if ty is a vector, - /// matrix, or the deepest element type if ty is an array, otherwise nullptr. + /// @returns + /// * `ty` if `ty` is an abstract or scalar + /// * the element type if `ty` is a vector + /// * the matrix element type if `ty` is a matrix + /// * the deepest element type if `ty` is an array + /// * `nullptr` if `ty` is none of the above static const Type* DeepestElementOf(const Type* ty, uint32_t* count = nullptr); /// @param types a pointer to a list of `const Type*`. diff --git a/src/tint/sem/type_test.cc b/src/tint/sem/type_test.cc index e6044533f6..8458880534 100644 --- a/src/tint/sem/type_test.cc +++ b/src/tint/sem/type_test.cc @@ -157,9 +157,9 @@ TEST_F(TypeTest, ElementOf) { EXPECT_TYPE(Type::ElementOf(vec4_f32), f32); EXPECT_TYPE(Type::ElementOf(vec3_u32), u32); EXPECT_TYPE(Type::ElementOf(vec3_i32), i32); - EXPECT_TYPE(Type::ElementOf(mat2x4_f32), f32); - EXPECT_TYPE(Type::ElementOf(mat4x2_f32), f32); - EXPECT_TYPE(Type::ElementOf(mat4x3_f16), f16); + EXPECT_TYPE(Type::ElementOf(mat2x4_f32), vec4_f32); + EXPECT_TYPE(Type::ElementOf(mat4x2_f32), vec2_f32); + EXPECT_TYPE(Type::ElementOf(mat4x3_f16), vec3_f16); EXPECT_TYPE(Type::ElementOf(str), nullptr); EXPECT_TYPE(Type::ElementOf(arr_i32), i32); EXPECT_TYPE(Type::ElementOf(arr_vec3_i32), vec3_i32); @@ -195,14 +195,14 @@ TEST_F(TypeTest, ElementOf) { EXPECT_TYPE(Type::ElementOf(vec3_i32, &count), i32); EXPECT_EQ(count, 3u); count = 42; - EXPECT_TYPE(Type::ElementOf(mat2x4_f32, &count), f32); - EXPECT_EQ(count, 8u); + EXPECT_TYPE(Type::ElementOf(mat2x4_f32, &count), vec4_f32); + EXPECT_EQ(count, 2u); count = 42; - EXPECT_TYPE(Type::ElementOf(mat4x2_f32, &count), f32); - EXPECT_EQ(count, 8u); + EXPECT_TYPE(Type::ElementOf(mat4x2_f32, &count), vec2_f32); + EXPECT_EQ(count, 4u); count = 42; - EXPECT_TYPE(Type::ElementOf(mat4x3_f16, &count), f16); - EXPECT_EQ(count, 12u); + EXPECT_TYPE(Type::ElementOf(mat4x3_f16, &count), vec3_f16); + EXPECT_EQ(count, 4u); count = 42; EXPECT_TYPE(Type::ElementOf(str, &count), nullptr); EXPECT_EQ(count, 0u);