diff --git a/src/tint/sem/type.cc b/src/tint/sem/type.cc index 40666e3341..d92b236d41 100644 --- a/src/tint/sem/type.cc +++ b/src/tint/sem/type.cc @@ -229,9 +229,29 @@ const Type* Type::ElementOf(const Type* ty, uint32_t* count /* = nullptr */) { *count = a->Count(); } return a->ElemType(); + }, + [&](Default) { + if (count) { + *count = 0; + } + return nullptr; }); } +const Type* Type::DeepestElementOf(const Type* ty, uint32_t* count /* = nullptr */) { + auto el_ty = ElementOf(ty, count); + while (el_ty && ty != el_ty) { + ty = el_ty; + + uint32_t n = 0; + el_ty = ElementOf(ty, &n); + if (count) { + *count *= n; + } + } + return el_ty; +} + const sem::Type* Type::Common(Type const* const* types, size_t count) { if (count == 0) { return nullptr; diff --git a/src/tint/sem/type.h b/src/tint/sem/type.h index 99876375f8..ac86320742 100644 --- a/src/tint/sem/type.h +++ b/src/tint/sem/type.h @@ -130,11 +130,19 @@ class Type : public Castable { static uint32_t ConversionRank(const Type* from, const Type* to); /// @param ty the type to obtain the element type from - /// @param count if not null, then this is assigned the number of elements in the type - /// @returns `ty` if `ty` is an abstract or scalar, the element type if ty is a vector, matrix - /// or array, otherwise nullptr. + /// @param count if not null, then this is assigned the number of child elements in the type. + /// For example, the count of an `array, 5>` type would be 5. + /// @returns `ty` if `ty` is an abstract or scalar, or the element type if ty is a vector, + /// matrix or array, otherwise nullptr. static const Type* ElementOf(const Type* ty, uint32_t* count = nullptr); + /// @param ty the type to obtain the deepest element type from + /// @param count if not null, then this is assigned the full number of most deeply nested + /// elements in the type. For example, the count of an `array, 5>` type would be 15. + /// @returns `ty` if `ty` is an abstract or scalar, or the element type if ty is a vector, + /// matrix, or the deepest element type if ty is an array, otherwise nullptr. + static const Type* DeepestElementOf(const Type* ty, uint32_t* count = nullptr); + /// @param types a pointer to a list of `const Type*`. /// @param count the number of types in `types`. /// @returns the lowest-ranking type that all types in `types` can be implicitly converted to, diff --git a/src/tint/sem/type_test.cc b/src/tint/sem/type_test.cc index c11efea9f2..e6044533f6 100644 --- a/src/tint/sem/type_test.cc +++ b/src/tint/sem/type_test.cc @@ -104,6 +104,20 @@ TEST_F(TypeTest, ElementOf) { auto* mat2x4_f32 = create(vec4_f32, 2u); auto* mat4x2_f32 = create(vec2_f32, 4u); auto* mat4x3_f16 = create(vec3_f16, 4u); + auto* str = create(nullptr, Sym("s"), + StructMemberList{ + create( + /* declaration */ nullptr, + /* name */ Sym("x"), + /* type */ f16, + /* index */ 0u, + /* offset */ 0u, + /* align */ 4u, + /* size */ 4u), + }, + /* align*/ 4u, + /* size*/ 4u, + /* size_no_padding*/ 4u); auto* arr_i32 = create( /* element */ i32, /* count */ 5u, @@ -111,6 +125,27 @@ TEST_F(TypeTest, ElementOf) { /* size */ 5u * 4u, /* stride */ 5u * 4u, /* implicit_stride */ 5u * 4u); + auto* arr_vec3_i32 = create( + /* element */ vec3_i32, + /* count */ 5u, + /* align */ 16u, + /* size */ 5u * 16u, + /* stride */ 5u * 16u, + /* implicit_stride */ 5u * 16u); + auto* arr_mat4x3_f16 = create( + /* element */ mat4x3_f16, + /* count */ 5u, + /* align */ 64u, + /* size */ 5u * 64u, + /* stride */ 5u * 64u, + /* implicit_stride */ 5u * 64u); + auto* arr_str = create( + /* element */ str, + /* count */ 5u, + /* align */ 4u, + /* size */ 5u * 4u, + /* stride */ 5u * 4u, + /* implicit_stride */ 5u * 4u); // No count EXPECT_TYPE(Type::ElementOf(f32), f32); @@ -125,48 +160,193 @@ TEST_F(TypeTest, ElementOf) { EXPECT_TYPE(Type::ElementOf(mat2x4_f32), f32); EXPECT_TYPE(Type::ElementOf(mat4x2_f32), f32); EXPECT_TYPE(Type::ElementOf(mat4x3_f16), f16); + EXPECT_TYPE(Type::ElementOf(str), nullptr); EXPECT_TYPE(Type::ElementOf(arr_i32), i32); + EXPECT_TYPE(Type::ElementOf(arr_vec3_i32), vec3_i32); + EXPECT_TYPE(Type::ElementOf(arr_mat4x3_f16), mat4x3_f16); + EXPECT_TYPE(Type::ElementOf(arr_str), str); // With count - uint32_t count = 0; + uint32_t count = 42; EXPECT_TYPE(Type::ElementOf(f32, &count), f32); EXPECT_EQ(count, 1u); - count = 0; + count = 42; EXPECT_TYPE(Type::ElementOf(f16, &count), f16); EXPECT_EQ(count, 1u); - count = 0; + count = 42; EXPECT_TYPE(Type::ElementOf(i32, &count), i32); EXPECT_EQ(count, 1u); - count = 0; + count = 42; EXPECT_TYPE(Type::ElementOf(u32, &count), u32); EXPECT_EQ(count, 1u); - count = 0; + count = 42; EXPECT_TYPE(Type::ElementOf(vec2_f32, &count), f32); EXPECT_EQ(count, 2u); - count = 0; + count = 42; EXPECT_TYPE(Type::ElementOf(vec3_f16, &count), f16); EXPECT_EQ(count, 3u); - count = 0; + count = 42; EXPECT_TYPE(Type::ElementOf(vec4_f32, &count), f32); EXPECT_EQ(count, 4u); - count = 0; + count = 42; EXPECT_TYPE(Type::ElementOf(vec3_u32, &count), u32); EXPECT_EQ(count, 3u); - count = 0; + count = 42; EXPECT_TYPE(Type::ElementOf(vec3_i32, &count), i32); EXPECT_EQ(count, 3u); - count = 0; + count = 42; EXPECT_TYPE(Type::ElementOf(mat2x4_f32, &count), f32); EXPECT_EQ(count, 8u); - count = 0; + count = 42; EXPECT_TYPE(Type::ElementOf(mat4x2_f32, &count), f32); EXPECT_EQ(count, 8u); - count = 0; + count = 42; EXPECT_TYPE(Type::ElementOf(mat4x3_f16, &count), f16); EXPECT_EQ(count, 12u); - count = 0; + count = 42; + EXPECT_TYPE(Type::ElementOf(str, &count), nullptr); + EXPECT_EQ(count, 0u); + count = 42; EXPECT_TYPE(Type::ElementOf(arr_i32, &count), i32); EXPECT_EQ(count, 5u); + count = 42; + EXPECT_TYPE(Type::ElementOf(arr_vec3_i32, &count), vec3_i32); + EXPECT_EQ(count, 5u); + count = 42; + EXPECT_TYPE(Type::ElementOf(arr_mat4x3_f16, &count), mat4x3_f16); + EXPECT_EQ(count, 5u); + count = 42; + EXPECT_TYPE(Type::ElementOf(arr_str, &count), str); + EXPECT_EQ(count, 5u); +} + +TEST_F(TypeTest, DeepestElementOf) { + auto* f32 = create(); + auto* f16 = create(); + auto* i32 = create(); + auto* u32 = create(); + auto* vec2_f32 = create(f32, 2u); + auto* vec3_f16 = create(f16, 3u); + auto* vec4_f32 = create(f32, 4u); + auto* vec3_u32 = create(u32, 3u); + auto* vec3_i32 = create(i32, 3u); + auto* mat2x4_f32 = create(vec4_f32, 2u); + auto* mat4x2_f32 = create(vec2_f32, 4u); + auto* mat4x3_f16 = create(vec3_f16, 4u); + auto* str = create(nullptr, Sym("s"), + StructMemberList{ + create( + /* declaration */ nullptr, + /* name */ Sym("x"), + /* type */ f16, + /* index */ 0u, + /* offset */ 0u, + /* align */ 4u, + /* size */ 4u), + }, + /* align*/ 4u, + /* size*/ 4u, + /* size_no_padding*/ 4u); + auto* arr_i32 = create( + /* element */ i32, + /* count */ 5u, + /* align */ 4u, + /* size */ 5u * 4u, + /* stride */ 5u * 4u, + /* implicit_stride */ 5u * 4u); + auto* arr_vec3_i32 = create( + /* element */ vec3_i32, + /* count */ 5u, + /* align */ 16u, + /* size */ 5u * 16u, + /* stride */ 5u * 16u, + /* implicit_stride */ 5u * 16u); + auto* arr_mat4x3_f16 = create( + /* element */ mat4x3_f16, + /* count */ 5u, + /* align */ 64u, + /* size */ 5u * 64u, + /* stride */ 5u * 64u, + /* implicit_stride */ 5u * 64u); + auto* arr_str = create( + /* element */ str, + /* count */ 5u, + /* align */ 4u, + /* size */ 5u * 4u, + /* stride */ 5u * 4u, + /* implicit_stride */ 5u * 4u); + + // No count + EXPECT_TYPE(Type::DeepestElementOf(f32), f32); + EXPECT_TYPE(Type::DeepestElementOf(f16), f16); + EXPECT_TYPE(Type::DeepestElementOf(i32), i32); + EXPECT_TYPE(Type::DeepestElementOf(u32), u32); + EXPECT_TYPE(Type::DeepestElementOf(vec2_f32), f32); + EXPECT_TYPE(Type::DeepestElementOf(vec3_f16), f16); + EXPECT_TYPE(Type::DeepestElementOf(vec4_f32), f32); + EXPECT_TYPE(Type::DeepestElementOf(vec3_u32), u32); + EXPECT_TYPE(Type::DeepestElementOf(vec3_i32), i32); + EXPECT_TYPE(Type::DeepestElementOf(mat2x4_f32), f32); + EXPECT_TYPE(Type::DeepestElementOf(mat4x2_f32), f32); + EXPECT_TYPE(Type::DeepestElementOf(mat4x3_f16), f16); + EXPECT_TYPE(Type::DeepestElementOf(str), nullptr); + EXPECT_TYPE(Type::DeepestElementOf(arr_i32), i32); + EXPECT_TYPE(Type::DeepestElementOf(arr_vec3_i32), i32); + EXPECT_TYPE(Type::DeepestElementOf(arr_mat4x3_f16), f16); + EXPECT_TYPE(Type::DeepestElementOf(arr_str), nullptr); + + // With count + uint32_t count = 42; + EXPECT_TYPE(Type::DeepestElementOf(f32, &count), f32); + EXPECT_EQ(count, 1u); + count = 42; + EXPECT_TYPE(Type::DeepestElementOf(f16, &count), f16); + EXPECT_EQ(count, 1u); + count = 42; + EXPECT_TYPE(Type::DeepestElementOf(i32, &count), i32); + EXPECT_EQ(count, 1u); + count = 42; + EXPECT_TYPE(Type::DeepestElementOf(u32, &count), u32); + EXPECT_EQ(count, 1u); + count = 42; + EXPECT_TYPE(Type::DeepestElementOf(vec2_f32, &count), f32); + EXPECT_EQ(count, 2u); + count = 42; + EXPECT_TYPE(Type::DeepestElementOf(vec3_f16, &count), f16); + EXPECT_EQ(count, 3u); + count = 42; + EXPECT_TYPE(Type::DeepestElementOf(vec4_f32, &count), f32); + EXPECT_EQ(count, 4u); + count = 42; + EXPECT_TYPE(Type::DeepestElementOf(vec3_u32, &count), u32); + EXPECT_EQ(count, 3u); + count = 42; + EXPECT_TYPE(Type::DeepestElementOf(vec3_i32, &count), i32); + EXPECT_EQ(count, 3u); + count = 42; + EXPECT_TYPE(Type::DeepestElementOf(mat2x4_f32, &count), f32); + EXPECT_EQ(count, 8u); + count = 42; + EXPECT_TYPE(Type::DeepestElementOf(mat4x2_f32, &count), f32); + EXPECT_EQ(count, 8u); + count = 42; + EXPECT_TYPE(Type::DeepestElementOf(mat4x3_f16, &count), f16); + EXPECT_EQ(count, 12u); + count = 42; + EXPECT_TYPE(Type::DeepestElementOf(str, &count), nullptr); + EXPECT_EQ(count, 0u); + count = 42; + EXPECT_TYPE(Type::DeepestElementOf(arr_i32, &count), i32); + EXPECT_EQ(count, 5u); + count = 42; + EXPECT_TYPE(Type::DeepestElementOf(arr_vec3_i32, &count), i32); + EXPECT_EQ(count, 15u); + count = 42; + EXPECT_TYPE(Type::DeepestElementOf(arr_mat4x3_f16, &count), f16); + EXPECT_EQ(count, 60u); + count = 42; + EXPECT_TYPE(Type::DeepestElementOf(arr_str, &count), nullptr); + EXPECT_EQ(count, 0u); } TEST_F(TypeTest, Common2) {