diff --git a/src/tint/sem/type.cc b/src/tint/sem/type.cc index 23f91645b0..8ace823a4b 100644 --- a/src/tint/sem/type.cc +++ b/src/tint/sem/type.cc @@ -254,9 +254,9 @@ const Type* Type::ElementOf(const Type* ty, uint32_t* count /* = nullptr */) { }, [&](Default) { if (count) { - *count = 0; + *count = 1; } - return nullptr; + return ty; }); } diff --git a/src/tint/sem/type.h b/src/tint/sem/type.h index 8bac821ca4..434b9135af 100644 --- a/src/tint/sem/type.h +++ b/src/tint/sem/type.h @@ -140,21 +140,19 @@ class Type : public Castable { /// @param count if not null, then this is assigned the number of child elements in the type. /// For example, the count of an `array, 5>` type would be 5. /// @returns - /// * `ty` if `ty` is an abstract or scalar /// * the element type if `ty` is a vector or array /// * the column type if `ty` is a matrix - /// * `nullptr` if `ty` is none of the above + /// * `ty` if `ty` is none of the above static const Type* ElementOf(const Type* ty, uint32_t* count = nullptr); /// @param ty the type to obtain the deepest element type from /// @param count if not null, then this is assigned the full number of most deeply nested /// elements in the type. For example, the count of an `array, 5>` type would be 15. /// @returns - /// * `ty` if `ty` is an abstract or scalar /// * the element type if `ty` is a vector /// * the matrix element type if `ty` is a matrix /// * the deepest element type if `ty` is an array - /// * `nullptr` if `ty` is none of the above + /// * `ty` if `ty` is none of the above static const Type* DeepestElementOf(const Type* ty, uint32_t* count = nullptr); /// @param types the list of types diff --git a/src/tint/sem/type_test.cc b/src/tint/sem/type_test.cc index 5c7e97b76b..6ea97d345c 100644 --- a/src/tint/sem/type_test.cc +++ b/src/tint/sem/type_test.cc @@ -192,7 +192,7 @@ TEST_F(TypeTest, ElementOf) { EXPECT_TYPE(Type::ElementOf(mat2x4_f32), vec4_f32); EXPECT_TYPE(Type::ElementOf(mat4x2_f32), vec2_f32); EXPECT_TYPE(Type::ElementOf(mat4x3_f16), vec3_f16); - EXPECT_TYPE(Type::ElementOf(str), nullptr); + EXPECT_TYPE(Type::ElementOf(str), str); EXPECT_TYPE(Type::ElementOf(arr_i32), i32); EXPECT_TYPE(Type::ElementOf(arr_vec3_i32), vec3_i32); EXPECT_TYPE(Type::ElementOf(arr_mat4x3_f16), mat4x3_f16); @@ -237,8 +237,8 @@ TEST_F(TypeTest, ElementOf) { EXPECT_TYPE(Type::ElementOf(mat4x3_f16, &count), vec3_f16); EXPECT_EQ(count, 4u); count = 42; - EXPECT_TYPE(Type::ElementOf(str, &count), nullptr); - EXPECT_EQ(count, 0u); + EXPECT_TYPE(Type::ElementOf(str, &count), str); + EXPECT_EQ(count, 1u); count = 42; EXPECT_TYPE(Type::ElementOf(arr_i32, &count), i32); EXPECT_EQ(count, 5u); @@ -270,12 +270,12 @@ TEST_F(TypeTest, DeepestElementOf) { EXPECT_TYPE(Type::DeepestElementOf(mat2x4_f32), f32); EXPECT_TYPE(Type::DeepestElementOf(mat4x2_f32), f32); EXPECT_TYPE(Type::DeepestElementOf(mat4x3_f16), f16); - EXPECT_TYPE(Type::DeepestElementOf(str), nullptr); + EXPECT_TYPE(Type::DeepestElementOf(str), str); EXPECT_TYPE(Type::DeepestElementOf(arr_i32), i32); EXPECT_TYPE(Type::DeepestElementOf(arr_vec3_i32), i32); EXPECT_TYPE(Type::DeepestElementOf(arr_mat4x3_f16), f16); EXPECT_TYPE(Type::DeepestElementOf(arr_mat4x3_af), af); - EXPECT_TYPE(Type::DeepestElementOf(arr_str), nullptr); + EXPECT_TYPE(Type::DeepestElementOf(arr_str), str); // With count uint32_t count = 42; @@ -315,8 +315,8 @@ TEST_F(TypeTest, DeepestElementOf) { EXPECT_TYPE(Type::DeepestElementOf(mat4x3_f16, &count), f16); EXPECT_EQ(count, 12u); count = 42; - EXPECT_TYPE(Type::DeepestElementOf(str, &count), nullptr); - EXPECT_EQ(count, 0u); + EXPECT_TYPE(Type::DeepestElementOf(str, &count), str); + EXPECT_EQ(count, 1u); count = 42; EXPECT_TYPE(Type::DeepestElementOf(arr_i32, &count), i32); EXPECT_EQ(count, 5u); @@ -330,8 +330,8 @@ TEST_F(TypeTest, DeepestElementOf) { EXPECT_TYPE(Type::DeepestElementOf(arr_mat4x3_af, &count), af); EXPECT_EQ(count, 60u); count = 42; - EXPECT_TYPE(Type::DeepestElementOf(arr_str, &count), nullptr); - EXPECT_EQ(count, 0u); + EXPECT_TYPE(Type::DeepestElementOf(arr_str, &count), str); + EXPECT_EQ(count, 5u); } TEST_F(TypeTest, Common2) {