diff --git a/src/intrinsic_table.cc b/src/intrinsic_table.cc index 8a59ab59d9..276f52cbdc 100644 --- a/src/intrinsic_table.cc +++ b/src/intrinsic_table.cc @@ -356,8 +356,9 @@ class OpenSizeMatBuilder : public Builder { auto* el = element_builder_->Build(state); auto columns = state.open_numbers.at(columns_); auto rows = state.open_numbers.at(rows_); - return state.ty_mgr.Get(const_cast(el), rows, - columns); + auto* column_type = + state.ty_mgr.Get(const_cast(el), rows); + return state.ty_mgr.Get(column_type, columns); } std::string str() const override { diff --git a/src/program_builder.h b/src/program_builder.h index f9e7a0f82d..9a13b7e1cc 100644 --- a/src/program_builder.h +++ b/src/program_builder.h @@ -470,7 +470,7 @@ class ProgramBuilder { type = MaybeCreateTypename(type); return {type.ast ? builder->create(type, rows, columns) : nullptr, - type.sem ? builder->create(type, rows, columns) + type.sem ? builder->create(vec(type, rows), columns) : nullptr}; } @@ -486,7 +486,7 @@ class ProgramBuilder { return {type.ast ? builder->create(source, type, rows, columns) : nullptr, - type.sem ? builder->create(type, rows, columns) + type.sem ? builder->create(vec(type, rows), columns) : nullptr}; } diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index cea5eb34bb..488bc6c870 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -347,8 +347,9 @@ const sem::Type* Resolver::Type(const ast::Type* ty) { } if (auto* t = ty->As()) { if (auto* el = Type(t->type())) { - return builder_->create(const_cast(el), - t->rows(), t->columns()); + auto* column_type = builder_->create( + const_cast(el), t->rows()); + return builder_->create(column_type, t->columns()); } return nullptr; } @@ -1958,8 +1959,10 @@ bool Resolver::Binary(ast::BinaryExpression* expr) { auto* rhs_vec = rhs_type->As(); const sem::Type* result_type = nullptr; if (lhs_mat && rhs_mat) { - result_type = builder_->create( - lhs_mat->type(), lhs_mat->rows(), rhs_mat->columns()); + auto* column_type = + builder_->create(lhs_mat->type(), lhs_mat->rows()); + result_type = + builder_->create(column_type, rhs_mat->columns()); } else if (lhs_mat && rhs_vec) { result_type = builder_->create(lhs_mat->type(), lhs_mat->rows()); @@ -2886,9 +2889,9 @@ const sem::Type* Resolver::Canonical(const sem::Type* type) { const_cast(make_canonical(v->type())), v->size()); } if (auto* m = ct->As()) { - return builder_->create( - const_cast(make_canonical(m->type())), m->rows(), - m->columns()); + auto* column_type = + builder_->create(make_canonical(m->type()), m->rows()); + return builder_->create(column_type, m->columns()); } if (auto* ac = ct->As()) { return builder_->create(ac->access_control(), diff --git a/src/resolver/resolver_test_helper.h b/src/resolver/resolver_test_helper.h index 910a23004d..cd9f5a64a3 100644 --- a/src/resolver/resolver_test_helper.h +++ b/src/resolver/resolver_test_helper.h @@ -239,17 +239,20 @@ sem::Type* sem_vec4(const ProgramBuilder::TypesBuilder& ty) { template sem::Type* sem_mat2x2(const ProgramBuilder::TypesBuilder& ty) { - return ty.builder->create(create_type(ty), 2, 2); + auto* column_type = ty.builder->create(create_type(ty), 2u); + return ty.builder->create(column_type, 2u); } template sem::Type* sem_mat3x3(const ProgramBuilder::TypesBuilder& ty) { - return ty.builder->create(create_type(ty), 3, 3); + auto* column_type = ty.builder->create(create_type(ty), 3u); + return ty.builder->create(column_type, 3u); } template sem::Type* sem_mat4x4(const ProgramBuilder::TypesBuilder& ty) { - return ty.builder->create(create_type(ty), 4, 4); + auto* column_type = ty.builder->create(create_type(ty), 4u); + return ty.builder->create(column_type, 4u); } template diff --git a/src/sem/matrix_type.cc b/src/sem/matrix_type.cc index 07c173f807..b8b965c556 100644 --- a/src/sem/matrix_type.cc +++ b/src/sem/matrix_type.cc @@ -15,18 +15,22 @@ #include "src/sem/matrix_type.h" #include "src/program_builder.h" +#include "src/sem/vector_type.h" TINT_INSTANTIATE_TYPEINFO(tint::sem::Matrix); namespace tint { namespace sem { -Matrix::Matrix(Type* subtype, uint32_t rows, uint32_t columns) - : subtype_(subtype), rows_(rows), columns_(columns) { - TINT_ASSERT(rows > 1); - TINT_ASSERT(rows < 5); - TINT_ASSERT(columns > 1); - TINT_ASSERT(columns < 5); +Matrix::Matrix(Vector* column_type, uint32_t columns) + : subtype_(column_type->type()), + column_type_(column_type), + rows_(column_type->size()), + columns_(columns) { + TINT_ASSERT(rows_ > 1); + TINT_ASSERT(rows_ < 5); + TINT_ASSERT(columns_ > 1); + TINT_ASSERT(columns_ < 5); } Matrix::Matrix(Matrix&&) = default; @@ -47,8 +51,8 @@ std::string Matrix::FriendlyName(const SymbolTable& symbols) const { Matrix* Matrix::Clone(CloneContext* ctx) const { // Clone arguments outside of create() call to have deterministic ordering - auto* ty = ctx->Clone(type()); - return ctx->dst->create(ty, rows_, columns_); + auto* column_type = ctx->Clone(ColumnType()); + return ctx->dst->create(column_type, columns_); } } // namespace sem diff --git a/src/sem/matrix_type.h b/src/sem/matrix_type.h index 2dd6ed2576..4a4be0ff1a 100644 --- a/src/sem/matrix_type.h +++ b/src/sem/matrix_type.h @@ -22,14 +22,16 @@ namespace tint { namespace sem { +// Forward declaration +class Vector; + /// A matrix type class Matrix : public Castable { public: /// Constructor - /// @param subtype type matrix type - /// @param rows the number of rows in the matrix + /// @param column_type the type of a column of the matrix /// @param columns the number of columns in the matrix - Matrix(Type* subtype, uint32_t rows, uint32_t columns); + Matrix(Vector* column_type, uint32_t columns); /// Move constructor Matrix(Matrix&&); ~Matrix() override; @@ -41,6 +43,9 @@ class Matrix : public Castable { /// @returns the number of columns in the matrix uint32_t columns() const { return columns_; } + /// @returns the column-vector type of the matrix + Vector* ColumnType() const { return column_type_; } + /// @returns the name for this type std::string type_name() const override; @@ -56,6 +61,7 @@ class Matrix : public Castable { private: Type* const subtype_; + Vector* const column_type_; uint32_t const rows_; uint32_t const columns_; }; diff --git a/src/sem/matrix_type_test.cc b/src/sem/matrix_type_test.cc index f94faa3cd1..9b3df83b2e 100644 --- a/src/sem/matrix_type_test.cc +++ b/src/sem/matrix_type_test.cc @@ -24,7 +24,8 @@ using MatrixTest = TestHelper; TEST_F(MatrixTest, Creation) { I32 i32; - Matrix m{&i32, 2, 4}; + Vector c{&i32, 2}; + Matrix m{&c, 4}; EXPECT_EQ(m.type(), &i32); EXPECT_EQ(m.rows(), 2u); EXPECT_EQ(m.columns(), 4u); @@ -32,7 +33,8 @@ TEST_F(MatrixTest, Creation) { TEST_F(MatrixTest, Is) { I32 i32; - Matrix m{&i32, 2, 3}; + Vector c{&i32, 2}; + Matrix m{&c, 4}; Type* ty = &m; EXPECT_FALSE(ty->Is()); EXPECT_FALSE(ty->Is()); @@ -51,12 +53,15 @@ TEST_F(MatrixTest, Is) { TEST_F(MatrixTest, TypeName) { I32 i32; - Matrix m{&i32, 2, 3}; + Vector c{&i32, 2}; + Matrix m{&c, 3}; EXPECT_EQ(m.type_name(), "__mat_2_3__i32"); } TEST_F(MatrixTest, FriendlyName) { - Matrix m{ty.i32(), 3, 2}; + I32 i32; + Vector c{&i32, 3}; + Matrix m{&c, 2}; EXPECT_EQ(m.FriendlyName(Symbols()), "mat2x3"); } diff --git a/src/transform/transform_test.cc b/src/transform/transform_test.cc index f287f57dd7..b83defc334 100644 --- a/src/transform/transform_test.cc +++ b/src/transform/transform_test.cc @@ -58,7 +58,8 @@ TEST_F(CreateASTTypeForTest, Basic) { TEST_F(CreateASTTypeForTest, Matrix) { auto* mat = create([](ProgramBuilder& b) { - return b.create(b.create(), 2, 3); + auto* column_type = b.create(b.create(), 2u); + return b.create(column_type, 3u); }); ASSERT_TRUE(mat->Is()); ASSERT_TRUE(mat->As()->type()->Is());