sem::Matrix: Pass the column type to the constructor

It's common to want this when indexing matrices.

Change-Id: Ic60a3a8d05873119d78a3cb0860d129e33ac3525
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49880
Reviewed-by: James Price <jrprice@google.com>
Commit-Queue: Ben Clayton <bclayton@chromium.org>
This commit is contained in:
Ben Clayton 2021-05-05 16:48:32 +00:00 committed by Commit Bot service account
parent b432f232b5
commit 6c1cf6569e
8 changed files with 53 additions and 30 deletions

View File

@ -356,8 +356,9 @@ class OpenSizeMatBuilder : public Builder {
auto* el = element_builder_->Build(state);
auto columns = state.open_numbers.at(columns_);
auto rows = state.open_numbers.at(rows_);
return state.ty_mgr.Get<sem::Matrix>(const_cast<sem::Type*>(el), rows,
columns);
auto* column_type =
state.ty_mgr.Get<sem::Vector>(const_cast<sem::Type*>(el), rows);
return state.ty_mgr.Get<sem::Matrix>(column_type, columns);
}
std::string str() const override {

View File

@ -470,7 +470,7 @@ class ProgramBuilder {
type = MaybeCreateTypename(type);
return {type.ast ? builder->create<ast::Matrix>(type, rows, columns)
: nullptr,
type.sem ? builder->create<sem::Matrix>(type, rows, columns)
type.sem ? builder->create<sem::Matrix>(vec(type, rows), columns)
: nullptr};
}
@ -486,7 +486,7 @@ class ProgramBuilder {
return {type.ast
? builder->create<ast::Matrix>(source, type, rows, columns)
: nullptr,
type.sem ? builder->create<sem::Matrix>(type, rows, columns)
type.sem ? builder->create<sem::Matrix>(vec(type, rows), columns)
: nullptr};
}

View File

@ -347,8 +347,9 @@ const sem::Type* Resolver::Type(const ast::Type* ty) {
}
if (auto* t = ty->As<ast::Matrix>()) {
if (auto* el = Type(t->type())) {
return builder_->create<sem::Matrix>(const_cast<sem::Type*>(el),
t->rows(), t->columns());
auto* column_type = builder_->create<sem::Vector>(
const_cast<sem::Type*>(el), t->rows());
return builder_->create<sem::Matrix>(column_type, t->columns());
}
return nullptr;
}
@ -1958,8 +1959,10 @@ bool Resolver::Binary(ast::BinaryExpression* expr) {
auto* rhs_vec = rhs_type->As<sem::Vector>();
const sem::Type* result_type = nullptr;
if (lhs_mat && rhs_mat) {
result_type = builder_->create<sem::Matrix>(
lhs_mat->type(), lhs_mat->rows(), rhs_mat->columns());
auto* column_type =
builder_->create<sem::Vector>(lhs_mat->type(), lhs_mat->rows());
result_type =
builder_->create<sem::Matrix>(column_type, rhs_mat->columns());
} else if (lhs_mat && rhs_vec) {
result_type =
builder_->create<sem::Vector>(lhs_mat->type(), lhs_mat->rows());
@ -2886,9 +2889,9 @@ const sem::Type* Resolver::Canonical(const sem::Type* type) {
const_cast<sem::Type*>(make_canonical(v->type())), v->size());
}
if (auto* m = ct->As<Matrix>()) {
return builder_->create<Matrix>(
const_cast<sem::Type*>(make_canonical(m->type())), m->rows(),
m->columns());
auto* column_type =
builder_->create<sem::Vector>(make_canonical(m->type()), m->rows());
return builder_->create<Matrix>(column_type, m->columns());
}
if (auto* ac = ct->As<AccessControl>()) {
return builder_->create<AccessControl>(ac->access_control(),

View File

@ -239,17 +239,20 @@ sem::Type* sem_vec4(const ProgramBuilder::TypesBuilder& ty) {
template <create_sem_type_func_ptr create_type>
sem::Type* sem_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::Matrix>(create_type(ty), 2, 2);
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 2u);
return ty.builder->create<sem::Matrix>(column_type, 2u);
}
template <create_sem_type_func_ptr create_type>
sem::Type* sem_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::Matrix>(create_type(ty), 3, 3);
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 3u);
return ty.builder->create<sem::Matrix>(column_type, 3u);
}
template <create_sem_type_func_ptr create_type>
sem::Type* sem_mat4x4(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::Matrix>(create_type(ty), 4, 4);
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 4u);
return ty.builder->create<sem::Matrix>(column_type, 4u);
}
template <create_sem_type_func_ptr create_type>

View File

@ -15,18 +15,22 @@
#include "src/sem/matrix_type.h"
#include "src/program_builder.h"
#include "src/sem/vector_type.h"
TINT_INSTANTIATE_TYPEINFO(tint::sem::Matrix);
namespace tint {
namespace sem {
Matrix::Matrix(Type* subtype, uint32_t rows, uint32_t columns)
: subtype_(subtype), rows_(rows), columns_(columns) {
TINT_ASSERT(rows > 1);
TINT_ASSERT(rows < 5);
TINT_ASSERT(columns > 1);
TINT_ASSERT(columns < 5);
Matrix::Matrix(Vector* column_type, uint32_t columns)
: subtype_(column_type->type()),
column_type_(column_type),
rows_(column_type->size()),
columns_(columns) {
TINT_ASSERT(rows_ > 1);
TINT_ASSERT(rows_ < 5);
TINT_ASSERT(columns_ > 1);
TINT_ASSERT(columns_ < 5);
}
Matrix::Matrix(Matrix&&) = default;
@ -47,8 +51,8 @@ std::string Matrix::FriendlyName(const SymbolTable& symbols) const {
Matrix* Matrix::Clone(CloneContext* ctx) const {
// Clone arguments outside of create() call to have deterministic ordering
auto* ty = ctx->Clone(type());
return ctx->dst->create<Matrix>(ty, rows_, columns_);
auto* column_type = ctx->Clone(ColumnType());
return ctx->dst->create<Matrix>(column_type, columns_);
}
} // namespace sem

View File

@ -22,14 +22,16 @@
namespace tint {
namespace sem {
// Forward declaration
class Vector;
/// A matrix type
class Matrix : public Castable<Matrix, Type> {
public:
/// Constructor
/// @param subtype type matrix type
/// @param rows the number of rows in the matrix
/// @param column_type the type of a column of the matrix
/// @param columns the number of columns in the matrix
Matrix(Type* subtype, uint32_t rows, uint32_t columns);
Matrix(Vector* column_type, uint32_t columns);
/// Move constructor
Matrix(Matrix&&);
~Matrix() override;
@ -41,6 +43,9 @@ class Matrix : public Castable<Matrix, Type> {
/// @returns the number of columns in the matrix
uint32_t columns() const { return columns_; }
/// @returns the column-vector type of the matrix
Vector* ColumnType() const { return column_type_; }
/// @returns the name for this type
std::string type_name() const override;
@ -56,6 +61,7 @@ class Matrix : public Castable<Matrix, Type> {
private:
Type* const subtype_;
Vector* const column_type_;
uint32_t const rows_;
uint32_t const columns_;
};

View File

@ -24,7 +24,8 @@ using MatrixTest = TestHelper;
TEST_F(MatrixTest, Creation) {
I32 i32;
Matrix m{&i32, 2, 4};
Vector c{&i32, 2};
Matrix m{&c, 4};
EXPECT_EQ(m.type(), &i32);
EXPECT_EQ(m.rows(), 2u);
EXPECT_EQ(m.columns(), 4u);
@ -32,7 +33,8 @@ TEST_F(MatrixTest, Creation) {
TEST_F(MatrixTest, Is) {
I32 i32;
Matrix m{&i32, 2, 3};
Vector c{&i32, 2};
Matrix m{&c, 4};
Type* ty = &m;
EXPECT_FALSE(ty->Is<AccessControl>());
EXPECT_FALSE(ty->Is<Alias>());
@ -51,12 +53,15 @@ TEST_F(MatrixTest, Is) {
TEST_F(MatrixTest, TypeName) {
I32 i32;
Matrix m{&i32, 2, 3};
Vector c{&i32, 2};
Matrix m{&c, 3};
EXPECT_EQ(m.type_name(), "__mat_2_3__i32");
}
TEST_F(MatrixTest, FriendlyName) {
Matrix m{ty.i32(), 3, 2};
I32 i32;
Vector c{&i32, 3};
Matrix m{&c, 2};
EXPECT_EQ(m.FriendlyName(Symbols()), "mat2x3<i32>");
}

View File

@ -58,7 +58,8 @@ TEST_F(CreateASTTypeForTest, Basic) {
TEST_F(CreateASTTypeForTest, Matrix) {
auto* mat = create([](ProgramBuilder& b) {
return b.create<sem::Matrix>(b.create<sem::F32>(), 2, 3);
auto* column_type = b.create<sem::Vector>(b.create<sem::F32>(), 2u);
return b.create<sem::Matrix>(column_type, 3u);
});
ASSERT_TRUE(mat->Is<ast::Matrix>());
ASSERT_TRUE(mat->As<ast::Matrix>()->type()->Is<ast::F32>());