sem::Matrix: Pass the column type to the constructor
It's common to want this when indexing matrices. Change-Id: Ic60a3a8d05873119d78a3cb0860d129e33ac3525 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/49880 Reviewed-by: James Price <jrprice@google.com> Commit-Queue: Ben Clayton <bclayton@chromium.org>
This commit is contained in:
parent
b432f232b5
commit
6c1cf6569e
|
@ -356,8 +356,9 @@ class OpenSizeMatBuilder : public Builder {
|
||||||
auto* el = element_builder_->Build(state);
|
auto* el = element_builder_->Build(state);
|
||||||
auto columns = state.open_numbers.at(columns_);
|
auto columns = state.open_numbers.at(columns_);
|
||||||
auto rows = state.open_numbers.at(rows_);
|
auto rows = state.open_numbers.at(rows_);
|
||||||
return state.ty_mgr.Get<sem::Matrix>(const_cast<sem::Type*>(el), rows,
|
auto* column_type =
|
||||||
columns);
|
state.ty_mgr.Get<sem::Vector>(const_cast<sem::Type*>(el), rows);
|
||||||
|
return state.ty_mgr.Get<sem::Matrix>(column_type, columns);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string str() const override {
|
std::string str() const override {
|
||||||
|
|
|
@ -470,7 +470,7 @@ class ProgramBuilder {
|
||||||
type = MaybeCreateTypename(type);
|
type = MaybeCreateTypename(type);
|
||||||
return {type.ast ? builder->create<ast::Matrix>(type, rows, columns)
|
return {type.ast ? builder->create<ast::Matrix>(type, rows, columns)
|
||||||
: nullptr,
|
: nullptr,
|
||||||
type.sem ? builder->create<sem::Matrix>(type, rows, columns)
|
type.sem ? builder->create<sem::Matrix>(vec(type, rows), columns)
|
||||||
: nullptr};
|
: nullptr};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -486,7 +486,7 @@ class ProgramBuilder {
|
||||||
return {type.ast
|
return {type.ast
|
||||||
? builder->create<ast::Matrix>(source, type, rows, columns)
|
? builder->create<ast::Matrix>(source, type, rows, columns)
|
||||||
: nullptr,
|
: nullptr,
|
||||||
type.sem ? builder->create<sem::Matrix>(type, rows, columns)
|
type.sem ? builder->create<sem::Matrix>(vec(type, rows), columns)
|
||||||
: nullptr};
|
: nullptr};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -347,8 +347,9 @@ const sem::Type* Resolver::Type(const ast::Type* ty) {
|
||||||
}
|
}
|
||||||
if (auto* t = ty->As<ast::Matrix>()) {
|
if (auto* t = ty->As<ast::Matrix>()) {
|
||||||
if (auto* el = Type(t->type())) {
|
if (auto* el = Type(t->type())) {
|
||||||
return builder_->create<sem::Matrix>(const_cast<sem::Type*>(el),
|
auto* column_type = builder_->create<sem::Vector>(
|
||||||
t->rows(), t->columns());
|
const_cast<sem::Type*>(el), t->rows());
|
||||||
|
return builder_->create<sem::Matrix>(column_type, t->columns());
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
@ -1958,8 +1959,10 @@ bool Resolver::Binary(ast::BinaryExpression* expr) {
|
||||||
auto* rhs_vec = rhs_type->As<sem::Vector>();
|
auto* rhs_vec = rhs_type->As<sem::Vector>();
|
||||||
const sem::Type* result_type = nullptr;
|
const sem::Type* result_type = nullptr;
|
||||||
if (lhs_mat && rhs_mat) {
|
if (lhs_mat && rhs_mat) {
|
||||||
result_type = builder_->create<sem::Matrix>(
|
auto* column_type =
|
||||||
lhs_mat->type(), lhs_mat->rows(), rhs_mat->columns());
|
builder_->create<sem::Vector>(lhs_mat->type(), lhs_mat->rows());
|
||||||
|
result_type =
|
||||||
|
builder_->create<sem::Matrix>(column_type, rhs_mat->columns());
|
||||||
} else if (lhs_mat && rhs_vec) {
|
} else if (lhs_mat && rhs_vec) {
|
||||||
result_type =
|
result_type =
|
||||||
builder_->create<sem::Vector>(lhs_mat->type(), lhs_mat->rows());
|
builder_->create<sem::Vector>(lhs_mat->type(), lhs_mat->rows());
|
||||||
|
@ -2886,9 +2889,9 @@ const sem::Type* Resolver::Canonical(const sem::Type* type) {
|
||||||
const_cast<sem::Type*>(make_canonical(v->type())), v->size());
|
const_cast<sem::Type*>(make_canonical(v->type())), v->size());
|
||||||
}
|
}
|
||||||
if (auto* m = ct->As<Matrix>()) {
|
if (auto* m = ct->As<Matrix>()) {
|
||||||
return builder_->create<Matrix>(
|
auto* column_type =
|
||||||
const_cast<sem::Type*>(make_canonical(m->type())), m->rows(),
|
builder_->create<sem::Vector>(make_canonical(m->type()), m->rows());
|
||||||
m->columns());
|
return builder_->create<Matrix>(column_type, m->columns());
|
||||||
}
|
}
|
||||||
if (auto* ac = ct->As<AccessControl>()) {
|
if (auto* ac = ct->As<AccessControl>()) {
|
||||||
return builder_->create<AccessControl>(ac->access_control(),
|
return builder_->create<AccessControl>(ac->access_control(),
|
||||||
|
|
|
@ -239,17 +239,20 @@ sem::Type* sem_vec4(const ProgramBuilder::TypesBuilder& ty) {
|
||||||
|
|
||||||
template <create_sem_type_func_ptr create_type>
|
template <create_sem_type_func_ptr create_type>
|
||||||
sem::Type* sem_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
|
sem::Type* sem_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
|
||||||
return ty.builder->create<sem::Matrix>(create_type(ty), 2, 2);
|
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 2u);
|
||||||
|
return ty.builder->create<sem::Matrix>(column_type, 2u);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <create_sem_type_func_ptr create_type>
|
template <create_sem_type_func_ptr create_type>
|
||||||
sem::Type* sem_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
|
sem::Type* sem_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
|
||||||
return ty.builder->create<sem::Matrix>(create_type(ty), 3, 3);
|
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 3u);
|
||||||
|
return ty.builder->create<sem::Matrix>(column_type, 3u);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <create_sem_type_func_ptr create_type>
|
template <create_sem_type_func_ptr create_type>
|
||||||
sem::Type* sem_mat4x4(const ProgramBuilder::TypesBuilder& ty) {
|
sem::Type* sem_mat4x4(const ProgramBuilder::TypesBuilder& ty) {
|
||||||
return ty.builder->create<sem::Matrix>(create_type(ty), 4, 4);
|
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 4u);
|
||||||
|
return ty.builder->create<sem::Matrix>(column_type, 4u);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <create_sem_type_func_ptr create_type>
|
template <create_sem_type_func_ptr create_type>
|
||||||
|
|
|
@ -15,18 +15,22 @@
|
||||||
#include "src/sem/matrix_type.h"
|
#include "src/sem/matrix_type.h"
|
||||||
|
|
||||||
#include "src/program_builder.h"
|
#include "src/program_builder.h"
|
||||||
|
#include "src/sem/vector_type.h"
|
||||||
|
|
||||||
TINT_INSTANTIATE_TYPEINFO(tint::sem::Matrix);
|
TINT_INSTANTIATE_TYPEINFO(tint::sem::Matrix);
|
||||||
|
|
||||||
namespace tint {
|
namespace tint {
|
||||||
namespace sem {
|
namespace sem {
|
||||||
|
|
||||||
Matrix::Matrix(Type* subtype, uint32_t rows, uint32_t columns)
|
Matrix::Matrix(Vector* column_type, uint32_t columns)
|
||||||
: subtype_(subtype), rows_(rows), columns_(columns) {
|
: subtype_(column_type->type()),
|
||||||
TINT_ASSERT(rows > 1);
|
column_type_(column_type),
|
||||||
TINT_ASSERT(rows < 5);
|
rows_(column_type->size()),
|
||||||
TINT_ASSERT(columns > 1);
|
columns_(columns) {
|
||||||
TINT_ASSERT(columns < 5);
|
TINT_ASSERT(rows_ > 1);
|
||||||
|
TINT_ASSERT(rows_ < 5);
|
||||||
|
TINT_ASSERT(columns_ > 1);
|
||||||
|
TINT_ASSERT(columns_ < 5);
|
||||||
}
|
}
|
||||||
|
|
||||||
Matrix::Matrix(Matrix&&) = default;
|
Matrix::Matrix(Matrix&&) = default;
|
||||||
|
@ -47,8 +51,8 @@ std::string Matrix::FriendlyName(const SymbolTable& symbols) const {
|
||||||
|
|
||||||
Matrix* Matrix::Clone(CloneContext* ctx) const {
|
Matrix* Matrix::Clone(CloneContext* ctx) const {
|
||||||
// Clone arguments outside of create() call to have deterministic ordering
|
// Clone arguments outside of create() call to have deterministic ordering
|
||||||
auto* ty = ctx->Clone(type());
|
auto* column_type = ctx->Clone(ColumnType());
|
||||||
return ctx->dst->create<Matrix>(ty, rows_, columns_);
|
return ctx->dst->create<Matrix>(column_type, columns_);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace sem
|
} // namespace sem
|
||||||
|
|
|
@ -22,14 +22,16 @@
|
||||||
namespace tint {
|
namespace tint {
|
||||||
namespace sem {
|
namespace sem {
|
||||||
|
|
||||||
|
// Forward declaration
|
||||||
|
class Vector;
|
||||||
|
|
||||||
/// A matrix type
|
/// A matrix type
|
||||||
class Matrix : public Castable<Matrix, Type> {
|
class Matrix : public Castable<Matrix, Type> {
|
||||||
public:
|
public:
|
||||||
/// Constructor
|
/// Constructor
|
||||||
/// @param subtype type matrix type
|
/// @param column_type the type of a column of the matrix
|
||||||
/// @param rows the number of rows in the matrix
|
|
||||||
/// @param columns the number of columns in the matrix
|
/// @param columns the number of columns in the matrix
|
||||||
Matrix(Type* subtype, uint32_t rows, uint32_t columns);
|
Matrix(Vector* column_type, uint32_t columns);
|
||||||
/// Move constructor
|
/// Move constructor
|
||||||
Matrix(Matrix&&);
|
Matrix(Matrix&&);
|
||||||
~Matrix() override;
|
~Matrix() override;
|
||||||
|
@ -41,6 +43,9 @@ class Matrix : public Castable<Matrix, Type> {
|
||||||
/// @returns the number of columns in the matrix
|
/// @returns the number of columns in the matrix
|
||||||
uint32_t columns() const { return columns_; }
|
uint32_t columns() const { return columns_; }
|
||||||
|
|
||||||
|
/// @returns the column-vector type of the matrix
|
||||||
|
Vector* ColumnType() const { return column_type_; }
|
||||||
|
|
||||||
/// @returns the name for this type
|
/// @returns the name for this type
|
||||||
std::string type_name() const override;
|
std::string type_name() const override;
|
||||||
|
|
||||||
|
@ -56,6 +61,7 @@ class Matrix : public Castable<Matrix, Type> {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Type* const subtype_;
|
Type* const subtype_;
|
||||||
|
Vector* const column_type_;
|
||||||
uint32_t const rows_;
|
uint32_t const rows_;
|
||||||
uint32_t const columns_;
|
uint32_t const columns_;
|
||||||
};
|
};
|
||||||
|
|
|
@ -24,7 +24,8 @@ using MatrixTest = TestHelper;
|
||||||
|
|
||||||
TEST_F(MatrixTest, Creation) {
|
TEST_F(MatrixTest, Creation) {
|
||||||
I32 i32;
|
I32 i32;
|
||||||
Matrix m{&i32, 2, 4};
|
Vector c{&i32, 2};
|
||||||
|
Matrix m{&c, 4};
|
||||||
EXPECT_EQ(m.type(), &i32);
|
EXPECT_EQ(m.type(), &i32);
|
||||||
EXPECT_EQ(m.rows(), 2u);
|
EXPECT_EQ(m.rows(), 2u);
|
||||||
EXPECT_EQ(m.columns(), 4u);
|
EXPECT_EQ(m.columns(), 4u);
|
||||||
|
@ -32,7 +33,8 @@ TEST_F(MatrixTest, Creation) {
|
||||||
|
|
||||||
TEST_F(MatrixTest, Is) {
|
TEST_F(MatrixTest, Is) {
|
||||||
I32 i32;
|
I32 i32;
|
||||||
Matrix m{&i32, 2, 3};
|
Vector c{&i32, 2};
|
||||||
|
Matrix m{&c, 4};
|
||||||
Type* ty = &m;
|
Type* ty = &m;
|
||||||
EXPECT_FALSE(ty->Is<AccessControl>());
|
EXPECT_FALSE(ty->Is<AccessControl>());
|
||||||
EXPECT_FALSE(ty->Is<Alias>());
|
EXPECT_FALSE(ty->Is<Alias>());
|
||||||
|
@ -51,12 +53,15 @@ TEST_F(MatrixTest, Is) {
|
||||||
|
|
||||||
TEST_F(MatrixTest, TypeName) {
|
TEST_F(MatrixTest, TypeName) {
|
||||||
I32 i32;
|
I32 i32;
|
||||||
Matrix m{&i32, 2, 3};
|
Vector c{&i32, 2};
|
||||||
|
Matrix m{&c, 3};
|
||||||
EXPECT_EQ(m.type_name(), "__mat_2_3__i32");
|
EXPECT_EQ(m.type_name(), "__mat_2_3__i32");
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(MatrixTest, FriendlyName) {
|
TEST_F(MatrixTest, FriendlyName) {
|
||||||
Matrix m{ty.i32(), 3, 2};
|
I32 i32;
|
||||||
|
Vector c{&i32, 3};
|
||||||
|
Matrix m{&c, 2};
|
||||||
EXPECT_EQ(m.FriendlyName(Symbols()), "mat2x3<i32>");
|
EXPECT_EQ(m.FriendlyName(Symbols()), "mat2x3<i32>");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -58,7 +58,8 @@ TEST_F(CreateASTTypeForTest, Basic) {
|
||||||
|
|
||||||
TEST_F(CreateASTTypeForTest, Matrix) {
|
TEST_F(CreateASTTypeForTest, Matrix) {
|
||||||
auto* mat = create([](ProgramBuilder& b) {
|
auto* mat = create([](ProgramBuilder& b) {
|
||||||
return b.create<sem::Matrix>(b.create<sem::F32>(), 2, 3);
|
auto* column_type = b.create<sem::Vector>(b.create<sem::F32>(), 2u);
|
||||||
|
return b.create<sem::Matrix>(column_type, 3u);
|
||||||
});
|
});
|
||||||
ASSERT_TRUE(mat->Is<ast::Matrix>());
|
ASSERT_TRUE(mat->Is<ast::Matrix>());
|
||||||
ASSERT_TRUE(mat->As<ast::Matrix>()->type()->Is<ast::F32>());
|
ASSERT_TRUE(mat->As<ast::Matrix>()->type()->Is<ast::F32>());
|
||||||
|
|
Loading…
Reference in New Issue