[ir][spirv-writer] Emit matrix types

Bug: tint:1906
Change-Id: Ief75da483f89595d56094611ecb037da73a396e5
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/134204
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: James Price <jrprice@google.com>
This commit is contained in:
James Price 2023-05-24 17:43:56 +00:00 committed by Dawn LUCI CQ
parent b54b58d57d
commit a8c528052d
3 changed files with 74 additions and 4 deletions

View File

@ -21,6 +21,7 @@
#include "src/tint/type/f16.h" #include "src/tint/type/f16.h"
#include "src/tint/type/f32.h" #include "src/tint/type/f32.h"
#include "src/tint/type/i32.h" #include "src/tint/type/i32.h"
#include "src/tint/type/matrix.h"
#include "src/tint/type/type.h" #include "src/tint/type/type.h"
#include "src/tint/type/u32.h" #include "src/tint/type/u32.h"
#include "src/tint/type/vector.h" #include "src/tint/type/vector.h"
@ -112,21 +113,65 @@ class Manager final {
/// @param inner the inner type /// @param inner the inner type
/// @param size the vector size /// @param size the vector size
/// @returns the vector type /// @returns the vector type
const type::Type* vec(const type::Type* inner, uint32_t size) { const type::Vector* vec(const type::Type* inner, uint32_t size) {
return Get<type::Vector>(inner, size); return Get<type::Vector>(inner, size);
} }
/// @param inner the inner type /// @param inner the inner type
/// @returns the vector type /// @returns the vector type
const type::Type* vec2(const type::Type* inner) { return vec(inner, 2); } const type::Vector* vec2(const type::Type* inner) { return vec(inner, 2); }
/// @param inner the inner type /// @param inner the inner type
/// @returns the vector type /// @returns the vector type
const type::Type* vec3(const type::Type* inner) { return vec(inner, 3); } const type::Vector* vec3(const type::Type* inner) { return vec(inner, 3); }
/// @param inner the inner type /// @param inner the inner type
/// @returns the vector type /// @returns the vector type
const type::Type* vec4(const type::Type* inner) { return vec(inner, 4); } const type::Vector* vec4(const type::Type* inner) { return vec(inner, 4); }
/// @param inner the inner type
/// @param cols the number of columns
/// @param rows the number of rows
/// @returns the matrix type
const type::Matrix* mat(const type::Type* inner, uint32_t cols, uint32_t rows) {
return Get<type::Matrix>(vec(inner, rows), cols);
}
/// @param inner the inner type
/// @returns the matrix type
const type::Matrix* mat2x2(const type::Type* inner) { return mat(inner, 2, 2); }
/// @param inner the inner type
/// @returns the matrix type
const type::Matrix* mat2x3(const type::Type* inner) { return mat(inner, 2, 3); }
/// @param inner the inner type
/// @returns the matrix type
const type::Matrix* mat2x4(const type::Type* inner) { return mat(inner, 2, 4); }
/// @param inner the inner type
/// @returns the matrix type
const type::Matrix* mat3x2(const type::Type* inner) { return mat(inner, 3, 2); }
/// @param inner the inner type
/// @returns the matrix type
const type::Matrix* mat3x3(const type::Type* inner) { return mat(inner, 3, 3); }
/// @param inner the inner type
/// @returns the matrix type
const type::Matrix* mat3x4(const type::Type* inner) { return mat(inner, 3, 4); }
/// @param inner the inner type
/// @returns the matrix type
const type::Matrix* mat4x2(const type::Type* inner) { return mat(inner, 4, 2); }
/// @param inner the inner type
/// @returns the matrix type
const type::Matrix* mat4x3(const type::Type* inner) { return mat(inner, 4, 3); }
/// @param inner the inner type
/// @returns the matrix type
const type::Matrix* mat4x4(const type::Type* inner) { return mat(inner, 4, 4); }
/// @returns an iterator to the beginning of the types /// @returns an iterator to the beginning of the types
TypeIterator begin() const { return types_.begin(); } TypeIterator begin() const { return types_.begin(); }

View File

@ -33,6 +33,7 @@
#include "src/tint/type/f16.h" #include "src/tint/type/f16.h"
#include "src/tint/type/f32.h" #include "src/tint/type/f32.h"
#include "src/tint/type/i32.h" #include "src/tint/type/i32.h"
#include "src/tint/type/matrix.h"
#include "src/tint/type/pointer.h" #include "src/tint/type/pointer.h"
#include "src/tint/type/type.h" #include "src/tint/type/type.h"
#include "src/tint/type/u32.h" #include "src/tint/type/u32.h"
@ -178,6 +179,10 @@ uint32_t GeneratorImplIr::Type(const type::Type* ty) {
[&](const type::Vector* vec) { [&](const type::Vector* vec) {
module_.PushType(spv::Op::OpTypeVector, {id, Type(vec->type()), vec->Width()}); module_.PushType(spv::Op::OpTypeVector, {id, Type(vec->type()), vec->Width()});
}, },
[&](const type::Matrix* mat) {
module_.PushType(spv::Op::OpTypeMatrix,
{id, Type(mat->ColumnType()), mat->columns()});
},
[&](const type::Pointer* ptr) { [&](const type::Pointer* ptr) {
module_.PushType( module_.PushType(
spv::Op::OpTypePointer, spv::Op::OpTypePointer,

View File

@ -105,6 +105,26 @@ TEST_F(SpvGeneratorImplTest, Type_Vec4Bool) {
"%1 = OpTypeVector %2 4\n"); "%1 = OpTypeVector %2 4\n");
} }
TEST_F(SpvGeneratorImplTest, Type_Mat2x3f) {
auto* vec = b.ir.types.mat2x3(b.ir.types.f32());
auto id = generator_.Type(vec);
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(),
"%3 = OpTypeFloat 32\n"
"%2 = OpTypeVector %3 3\n"
"%1 = OpTypeMatrix %2 2\n");
}
TEST_F(SpvGeneratorImplTest, Type_Mat4x2h) {
auto* vec = b.ir.types.mat4x2(b.ir.types.f16());
auto id = generator_.Type(vec);
EXPECT_EQ(id, 1u);
EXPECT_EQ(DumpTypes(),
"%3 = OpTypeFloat 16\n"
"%2 = OpTypeVector %3 2\n"
"%1 = OpTypeMatrix %2 4\n");
}
// Test that we can emit multiple types. // Test that we can emit multiple types.
// Includes types with the same opcode but different parameters. // Includes types with the same opcode but different parameters.
TEST_F(SpvGeneratorImplTest, Type_Multiple) { TEST_F(SpvGeneratorImplTest, Type_Multiple) {