tint/sem: Return vector for Type::ElementOf(matrix)

Return the column vector type, instead of the column vector element
type. This matches what you'd get if you were to index the matrix.

DeepestElementOf() can be used to easily obtain the matrix column
element type.

Change-Id: I5293f4cca205c9e378253ac67880bf9d998814aa
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/94327
Reviewed-by: Dan Sinclair <dsinclair@chromium.org>
Commit-Queue: Ben Clayton <bclayton@google.com>
This commit is contained in:
Ben Clayton 2022-06-22 00:02:03 +00:00 committed by Dawn LUCI CQ
parent 9be6b99a4b
commit 797f0f82e0
8 changed files with 222 additions and 111 deletions

View File

@ -325,7 +325,7 @@ TEST_P(MaterializeAbstractNumericToConcreteType, Test) {
EXPECT_TYPE(expr->ConstantValue().Type(), target_sem_ty);
uint32_t num_elems = 0;
const sem::Type* target_sem_el_ty = sem::Type::ElementOf(target_sem_ty, &num_elems);
const sem::Type* target_sem_el_ty = sem::Type::DeepestElementOf(target_sem_ty, &num_elems);
EXPECT_TYPE(expr->ConstantValue().ElementType(), target_sem_el_ty);
expr->ConstantValue().WithElements([&](auto&& vec) {
using VEC_TY = std::decay_t<decltype(vec)>;
@ -738,7 +738,8 @@ TEST_P(MaterializeAbstractNumericToDefaultType, Test) {
EXPECT_TYPE(expr->ConstantValue().Type(), expected_sem_ty);
uint32_t num_elems = 0;
const sem::Type* expected_sem_el_ty = sem::Type::ElementOf(expected_sem_ty, &num_elems);
const sem::Type* expected_sem_el_ty =
sem::Type::DeepestElementOf(expected_sem_ty, &num_elems);
EXPECT_TYPE(expr->ConstantValue().ElementType(), expected_sem_el_ty);
expr->ConstantValue().WithElements([&](auto&& vec) {
using VEC_TY = std::decay_t<decltype(vec)>;

View File

@ -152,13 +152,10 @@ utils::Result<sem::Constant::Elements> MaterializeElements(const sem::Constant::
} // namespace
sem::Constant Resolver::EvaluateConstantValue(const ast::Expression* expr, const sem::Type* type) {
if (auto* e = expr->As<ast::LiteralExpression>()) {
return EvaluateConstantValue(e, type);
}
if (auto* e = expr->As<ast::CallExpression>()) {
return EvaluateConstantValue(e, type);
}
return {};
return Switch(
expr, //
[&](const ast::LiteralExpression* e) { return EvaluateConstantValue(e, type); },
[&](const ast::CallExpression* e) { return EvaluateConstantValue(e, type); });
}
sem::Constant Resolver::EvaluateConstantValue(const ast::LiteralExpression* literal,
@ -178,10 +175,10 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::LiteralExpression* lite
sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
const sem::Type* ty) {
uint32_t result_size = 0;
auto* el_ty = sem::Type::ElementOf(ty, &result_size);
uint32_t num_elems = 0;
auto* el_ty = sem::Type::DeepestElementOf(ty, &num_elems);
if (!el_ty) {
return sem::Constant{};
return {};
}
// ElementOf() will also return the element type of array, which we do not support.
@ -194,16 +191,16 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
return Switch(
el_ty,
[&](const sem::AbstractInt*) {
return sem::Constant(ty, std::vector(result_size, AInt(0)));
return sem::Constant(ty, std::vector(num_elems, AInt(0)));
},
[&](const sem::AbstractFloat*) {
return sem::Constant(ty, std::vector(result_size, AFloat(0)));
return sem::Constant(ty, std::vector(num_elems, AFloat(0)));
},
[&](const sem::I32*) { return sem::Constant(ty, std::vector(result_size, AInt(0))); },
[&](const sem::U32*) { return sem::Constant(ty, std::vector(result_size, AInt(0))); },
[&](const sem::F32*) { return sem::Constant(ty, std::vector(result_size, AFloat(0))); },
[&](const sem::F16*) { return sem::Constant(ty, std::vector(result_size, AFloat(0))); },
[&](const sem::Bool*) { return sem::Constant(ty, std::vector(result_size, AInt(0))); });
[&](const sem::I32*) { return sem::Constant(ty, std::vector(num_elems, AInt(0))); },
[&](const sem::U32*) { return sem::Constant(ty, std::vector(num_elems, AInt(0))); },
[&](const sem::F32*) { return sem::Constant(ty, std::vector(num_elems, AFloat(0))); },
[&](const sem::F16*) { return sem::Constant(ty, std::vector(num_elems, AFloat(0))); },
[&](const sem::Bool*) { return sem::Constant(ty, std::vector(num_elems, AInt(0))); });
}
// Build value for type_ctor from each child value by converting to type_ctor's type.
@ -235,18 +232,27 @@ sem::Constant Resolver::EvaluateConstantValue(const ast::CallExpression* call,
}
}
// Splat single-value initializers
std::visit(
if (!elements) {
return {};
}
return std::visit(
[&](auto&& v) {
if (v.size() == 1) {
for (uint32_t i = 0; i < result_size - 1; ++i) {
v.emplace_back(v[0]);
if (num_elems != v.size()) {
if (v.size() == 1) {
// Splat single-value initializers
for (uint32_t i = 0; i < num_elems - 1; ++i) {
v.emplace_back(v[0]);
}
} else {
// Provided number of arguments does not match the required number of elements.
// Validation should error here.
return sem::Constant{};
}
}
return sem::Constant(ty, std::move(elements.value()));
},
elements.value());
return sem::Constant(ty, std::move(elements.value()));
}
utils::Result<sem::Constant> Resolver::ConvertValue(const sem::Constant& value,
@ -256,7 +262,7 @@ utils::Result<sem::Constant> Resolver::ConvertValue(const sem::Constant& value,
return value;
}
auto* el_ty = sem::Type::ElementOf(ty);
auto* el_ty = sem::Type::DeepestElementOf(ty);
if (el_ty == nullptr) {
return sem::Constant{};
}

View File

@ -93,9 +93,10 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_i32) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::I32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -112,9 +113,10 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_u32) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::U32>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::U32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -131,9 +133,10 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_f32) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::F32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -150,9 +153,10 @@ TEST_F(ResolverConstantsTest, Vec3_ZeroInit_bool) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::Bool>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::Bool>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -169,9 +173,10 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_i32) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::I32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -188,9 +193,10 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_u32) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::U32>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::U32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -207,9 +213,10 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_f32) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::F32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -226,9 +233,10 @@ TEST_F(ResolverConstantsTest, Vec3_Splat_bool) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::Bool>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::Bool>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -245,9 +253,10 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_i32) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::I32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -264,9 +273,10 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_u32) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::U32>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::U32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -283,9 +293,10 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_f32) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::F32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -302,9 +313,10 @@ TEST_F(ResolverConstantsTest, Vec3_FullConstruct_bool) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::Bool>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::Bool>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -321,9 +333,10 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_i32) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::I32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -340,9 +353,10 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_u32) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::U32>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::U32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -359,9 +373,10 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_f32) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::F32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -378,9 +393,10 @@ TEST_F(ResolverConstantsTest, Vec3_MixConstruct_bool) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::Bool>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::Bool>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::Bool>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -397,9 +413,10 @@ TEST_F(ResolverConstantsTest, Vec3_Convert_f32_to_i32) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::I32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -416,9 +433,10 @@ TEST_F(ResolverConstantsTest, Vec3_Convert_u32_to_f32) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::F32>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::F32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -435,9 +453,10 @@ TEST_F(ResolverConstantsTest, Vec3_Convert_Large_f32_to_i32) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::I32>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::I32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::I32>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -454,9 +473,10 @@ TEST_F(ResolverConstantsTest, Vec3_Convert_Large_f32_to_u32) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::U32>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::U32>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::U32>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -478,9 +498,10 @@ TEST_F(ResolverConstantsTest, DISABLED_Vec3_Convert_Large_f32_to_f16) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::F16>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::F16>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F16>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -500,9 +521,10 @@ TEST_F(ResolverConstantsTest, DISABLED_Vec3_Convert_Small_f32_to_f16) {
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
ASSERT_TRUE(sem->Type()->Is<sem::Vector>());
EXPECT_TRUE(sem->Type()->As<sem::Vector>()->type()->Is<sem::F16>());
EXPECT_EQ(sem->Type()->As<sem::Vector>()->Width(), 3u);
auto* vec = sem->Type()->As<sem::Vector>();
ASSERT_NE(vec, nullptr);
EXPECT_TRUE(vec->type()->Is<sem::F16>());
EXPECT_EQ(vec->Width(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F16>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 3u);
@ -511,5 +533,80 @@ TEST_F(ResolverConstantsTest, DISABLED_Vec3_Convert_Small_f32_to_f16) {
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 0.0);
}
TEST_F(ResolverConstantsTest, Mat2x3_ZeroInit_f32) {
auto* expr = mat2x3<f32>();
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
auto* mat = sem->Type()->As<sem::Matrix>();
ASSERT_NE(mat, nullptr);
EXPECT_TRUE(mat->type()->Is<sem::F32>());
EXPECT_EQ(mat->columns(), 2u);
EXPECT_EQ(mat->rows(), 3u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 6u);
EXPECT_EQ(sem->ConstantValue().Element<f32>(0).value, 0._f);
EXPECT_EQ(sem->ConstantValue().Element<f32>(1).value, 0._f);
EXPECT_EQ(sem->ConstantValue().Element<f32>(2).value, 0._f);
EXPECT_EQ(sem->ConstantValue().Element<f32>(3).value, 0._f);
EXPECT_EQ(sem->ConstantValue().Element<f32>(4).value, 0._f);
EXPECT_EQ(sem->ConstantValue().Element<f32>(5).value, 0._f);
}
TEST_F(ResolverConstantsTest, Mat3x2_Construct_Scalars_af) {
auto* expr = Construct(ty.mat(nullptr, 3, 2), 1.0_a, 2.0_a, 3.0_a, 4.0_a, 5.0_a, 6.0_a);
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
auto* mat = sem->Type()->As<sem::Matrix>();
ASSERT_NE(mat, nullptr);
EXPECT_TRUE(mat->type()->Is<sem::F32>());
EXPECT_EQ(mat->columns(), 3u);
EXPECT_EQ(mat->rows(), 2u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 6u);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 1._a);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(1).value, 2._a);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 3._a);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(3).value, 4._a);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(4).value, 5._a);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(5).value, 6._a);
}
TEST_F(ResolverConstantsTest, Mat3x2_Construct_Columns_af) {
auto* expr = Construct(ty.mat(nullptr, 3, 2), //
vec(nullptr, 2u, 1.0_a, 2.0_a), //
vec(nullptr, 2u, 3.0_a, 4.0_a), //
vec(nullptr, 2u, 5.0_a, 6.0_a));
WrapInFunction(expr);
EXPECT_TRUE(r()->Resolve()) << r()->error();
auto* sem = Sem().Get(expr);
EXPECT_NE(sem, nullptr);
auto* mat = sem->Type()->As<sem::Matrix>();
ASSERT_NE(mat, nullptr);
EXPECT_TRUE(mat->type()->Is<sem::F32>());
EXPECT_EQ(mat->columns(), 3u);
EXPECT_EQ(mat->rows(), 2u);
EXPECT_EQ(sem->ConstantValue().Type(), sem->Type());
EXPECT_TRUE(sem->ConstantValue().ElementType()->Is<sem::F32>());
ASSERT_EQ(sem->ConstantValue().ElementCount(), 6u);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(0).value, 1._a);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(1).value, 2._a);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(2).value, 3._a);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(3).value, 4._a);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(4).value, 5._a);
EXPECT_EQ(sem->ConstantValue().Element<AFloat>(5).value, 6._a);
}
} // namespace
} // namespace tint::resolver

View File

@ -98,7 +98,7 @@ const Type* Constant::CheckElemType(const sem::Type* ty, size_t num_elements) {
diag::List diag;
if (ty->is_abstract_or_scalar() || ty->IsAnyOf<Vector, Matrix>()) {
uint32_t count = 0;
auto* el_ty = Type::ElementOf(ty, &count);
auto* el_ty = Type::DeepestElementOf(ty, &count);
if (num_elements != count) {
TINT_ICE(Semantic, diag) << "sem::Constant() type <-> element mismatch. type: '"
<< ty->TypeInfo().name << "' element: " << num_elements;

View File

@ -104,10 +104,10 @@ class Constant {
return std::visit([](auto&& v) { return v.size(); }, elems_);
}
/// @returns the element type of the Constant
/// @returns the flattened element type of the Constant
const sem::Type* ElementType() const { return elem_type_; }
/// @returns the constant's elements
/// @returns the constant's flattened elements
const Elements& GetElements() const { return elems_; }
/// WithElements calls the function `f` with the vector of elements as either AFloats or AInts

View File

@ -220,9 +220,9 @@ const Type* Type::ElementOf(const Type* ty, uint32_t* count /* = nullptr */) {
},
[&](const Matrix* m) {
if (count) {
*count = m->columns() * m->rows();
*count = m->columns();
}
return m->type();
return m->ColumnType();
},
[&](const Array* a) {
if (count) {

View File

@ -132,15 +132,22 @@ class Type : public Castable<Type, Node> {
/// @param ty the type to obtain the element type from
/// @param count if not null, then this is assigned the number of child elements in the type.
/// For example, the count of an `array<vec3<f32>, 5>` type would be 5.
/// @returns `ty` if `ty` is an abstract or scalar, or the element type if ty is a vector,
/// matrix or array, otherwise nullptr.
/// @returns
/// * `ty` if `ty` is an abstract or scalar
/// * the element type if `ty` is a vector or array
/// * the column type if `ty` is a matrix
/// * `nullptr` if `ty` is none of the above
static const Type* ElementOf(const Type* ty, uint32_t* count = nullptr);
/// @param ty the type to obtain the deepest element type from
/// @param count if not null, then this is assigned the full number of most deeply nested
/// elements in the type. For example, the count of an `array<vec3<f32>, 5>` type would be 15.
/// @returns `ty` if `ty` is an abstract or scalar, or the element type if ty is a vector,
/// matrix, or the deepest element type if ty is an array, otherwise nullptr.
/// @returns
/// * `ty` if `ty` is an abstract or scalar
/// * the element type if `ty` is a vector
/// * the matrix element type if `ty` is a matrix
/// * the deepest element type if `ty` is an array
/// * `nullptr` if `ty` is none of the above
static const Type* DeepestElementOf(const Type* ty, uint32_t* count = nullptr);
/// @param types a pointer to a list of `const Type*`.

View File

@ -157,9 +157,9 @@ TEST_F(TypeTest, ElementOf) {
EXPECT_TYPE(Type::ElementOf(vec4_f32), f32);
EXPECT_TYPE(Type::ElementOf(vec3_u32), u32);
EXPECT_TYPE(Type::ElementOf(vec3_i32), i32);
EXPECT_TYPE(Type::ElementOf(mat2x4_f32), f32);
EXPECT_TYPE(Type::ElementOf(mat4x2_f32), f32);
EXPECT_TYPE(Type::ElementOf(mat4x3_f16), f16);
EXPECT_TYPE(Type::ElementOf(mat2x4_f32), vec4_f32);
EXPECT_TYPE(Type::ElementOf(mat4x2_f32), vec2_f32);
EXPECT_TYPE(Type::ElementOf(mat4x3_f16), vec3_f16);
EXPECT_TYPE(Type::ElementOf(str), nullptr);
EXPECT_TYPE(Type::ElementOf(arr_i32), i32);
EXPECT_TYPE(Type::ElementOf(arr_vec3_i32), vec3_i32);
@ -195,14 +195,14 @@ TEST_F(TypeTest, ElementOf) {
EXPECT_TYPE(Type::ElementOf(vec3_i32, &count), i32);
EXPECT_EQ(count, 3u);
count = 42;
EXPECT_TYPE(Type::ElementOf(mat2x4_f32, &count), f32);
EXPECT_EQ(count, 8u);
EXPECT_TYPE(Type::ElementOf(mat2x4_f32, &count), vec4_f32);
EXPECT_EQ(count, 2u);
count = 42;
EXPECT_TYPE(Type::ElementOf(mat4x2_f32, &count), f32);
EXPECT_EQ(count, 8u);
EXPECT_TYPE(Type::ElementOf(mat4x2_f32, &count), vec2_f32);
EXPECT_EQ(count, 4u);
count = 42;
EXPECT_TYPE(Type::ElementOf(mat4x3_f16, &count), f16);
EXPECT_EQ(count, 12u);
EXPECT_TYPE(Type::ElementOf(mat4x3_f16, &count), vec3_f16);
EXPECT_EQ(count, 4u);
count = 42;
EXPECT_TYPE(Type::ElementOf(str, &count), nullptr);
EXPECT_EQ(count, 0u);