Implement addition and subtraction of float matrices

Bug: tint:316
Change-Id: I3a1082c41c47daacf0220d029cb2a5f118684959
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/52580
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Commit-Queue: David Neto <dneto@google.com>
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
Antonio Maiorano 2021-05-31 15:53:20 +00:00 committed by Tint LUCI CQ
parent 34c58a932c
commit c91f8f2822
11 changed files with 399 additions and 28 deletions

View File

@ -2124,13 +2124,20 @@ bool Resolver::Binary(ast::BinaryExpression* expr) {
} }
// Matrix arithmetic // Matrix arithmetic
// TODO(amaiorano): matrix-matrix addition and subtraction auto* lhs_mat = lhs_type->As<Matrix>();
auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
auto* rhs_mat = rhs_type->As<Matrix>();
auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
// Addition and subtraction of float matrices
if ((expr->IsAdd() || expr->IsSubtract()) && lhs_mat_elem_type &&
lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type &&
rhs_mat_elem_type->Is<F32>() &&
(lhs_mat->columns() == rhs_mat->columns()) &&
(lhs_mat->rows() == rhs_mat->rows())) {
SetType(expr, rhs_type);
return true;
}
if (expr->IsMultiply()) { if (expr->IsMultiply()) {
auto* lhs_mat = lhs_type->As<Matrix>();
auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
auto* rhs_mat = rhs_type->As<Matrix>();
auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
// Multiplication of a matrix and a scalar // Multiplication of a matrix and a scalar
if (lhs_type->Is<F32>() && rhs_mat_elem_type && if (lhs_type->Is<F32>() && rhs_mat_elem_type &&
rhs_mat_elem_type->Is<F32>()) { rhs_mat_elem_type->Is<F32>()) {

View File

@ -1238,9 +1238,12 @@ static constexpr ast::BinaryOp all_ops[] = {
}; };
static constexpr create_ast_type_func_ptr all_create_type_funcs[] = { static constexpr create_ast_type_func_ptr all_create_type_funcs[] = {
ast_bool, ast_u32, ast_i32, ast_f32, ast_bool, ast_u32, ast_i32, ast_f32,
ast_vec3<bool>, ast_vec3<i32>, ast_vec3<u32>, ast_vec3<f32>, ast_vec3<bool>, ast_vec3<i32>, ast_vec3<u32>, ast_vec3<f32>,
ast_mat3x3<i32>, ast_mat3x3<u32>, ast_mat3x3<f32>}; ast_mat3x3<i32>, ast_mat3x3<u32>, ast_mat3x3<f32>, //
ast_mat2x3<i32>, ast_mat2x3<u32>, ast_mat2x3<f32>, //
ast_mat3x2<i32>, ast_mat3x2<u32>, ast_mat3x2<f32> //
};
// A list of all valid test cases for 'lhs op rhs', except that for vecN and // A list of all valid test cases for 'lhs op rhs', except that for vecN and
// matNxN, we only test N=3. // matNxN, we only test N=3.
@ -1338,14 +1341,43 @@ static constexpr Params all_valid_cases[] = {
// Params{Op::kModulo, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>}, // Params{Op::kModulo, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
// Matrix arithmetic // Matrix arithmetic
Params{Op::kMultiply, ast_mat2x3<f32>, ast_f32, sem_mat2x3<sem_f32>},
Params{Op::kMultiply, ast_mat3x2<f32>, ast_f32, sem_mat3x2<sem_f32>},
Params{Op::kMultiply, ast_mat3x3<f32>, ast_f32, sem_mat3x3<sem_f32>}, Params{Op::kMultiply, ast_mat3x3<f32>, ast_f32, sem_mat3x3<sem_f32>},
Params{Op::kMultiply, ast_f32, ast_mat2x3<f32>, sem_mat2x3<sem_f32>},
Params{Op::kMultiply, ast_f32, ast_mat3x2<f32>, sem_mat3x2<sem_f32>},
Params{Op::kMultiply, ast_f32, ast_mat3x3<f32>, sem_mat3x3<sem_f32>}, Params{Op::kMultiply, ast_f32, ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
Params{Op::kMultiply, ast_vec3<f32>, ast_mat2x3<f32>, sem_vec2<sem_f32>},
Params{Op::kMultiply, ast_vec2<f32>, ast_mat3x2<f32>, sem_vec3<sem_f32>},
Params{Op::kMultiply, ast_vec3<f32>, ast_mat3x3<f32>, sem_vec3<sem_f32>}, Params{Op::kMultiply, ast_vec3<f32>, ast_mat3x3<f32>, sem_vec3<sem_f32>},
Params{Op::kMultiply, ast_mat3x2<f32>, ast_vec3<f32>, sem_vec2<sem_f32>},
Params{Op::kMultiply, ast_mat2x3<f32>, ast_vec2<f32>, sem_vec3<sem_f32>},
Params{Op::kMultiply, ast_mat3x3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>}, Params{Op::kMultiply, ast_mat3x3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
// TODO(amaiorano): add mat+mat and mat-mat
Params{Op::kMultiply, ast_mat2x3<f32>, ast_mat3x2<f32>,
sem_mat3x3<sem_f32>},
Params{Op::kMultiply, ast_mat3x2<f32>, ast_mat2x3<f32>,
sem_mat2x2<sem_f32>},
Params{Op::kMultiply, ast_mat3x2<f32>, ast_mat3x3<f32>,
sem_mat3x2<sem_f32>},
Params{Op::kMultiply, ast_mat3x3<f32>, ast_mat3x3<f32>, Params{Op::kMultiply, ast_mat3x3<f32>, ast_mat3x3<f32>,
sem_mat3x3<sem_f32>}, sem_mat3x3<sem_f32>},
Params{Op::kMultiply, ast_mat3x3<f32>, ast_mat2x3<f32>,
sem_mat2x3<sem_f32>},
Params{Op::kAdd, ast_mat2x3<f32>, ast_mat2x3<f32>, sem_mat2x3<sem_f32>},
Params{Op::kAdd, ast_mat3x2<f32>, ast_mat3x2<f32>, sem_mat3x2<sem_f32>},
Params{Op::kAdd, ast_mat3x3<f32>, ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
Params{Op::kSubtract, ast_mat2x3<f32>, ast_mat2x3<f32>,
sem_mat2x3<sem_f32>},
Params{Op::kSubtract, ast_mat3x2<f32>, ast_mat3x2<f32>,
sem_mat3x2<sem_f32>},
Params{Op::kSubtract, ast_mat3x3<f32>, ast_mat3x3<f32>,
sem_mat3x3<sem_f32>},
// Comparison expressions // Comparison expressions
// https://gpuweb.github.io/gpuweb/wgsl.html#comparison-expr // https://gpuweb.github.io/gpuweb/wgsl.html#comparison-expr

View File

@ -176,6 +176,26 @@ ast::Type* ast_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat2x2(create_type(ty)); return ty.mat2x2(create_type(ty));
} }
template <typename T>
ast::Type* ast_mat2x3(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat2x3<T>();
}
template <create_ast_type_func_ptr create_type>
ast::Type* ast_mat2x3(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat2x3(create_type(ty));
}
template <typename T>
ast::Type* ast_mat3x2(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat3x2<T>();
}
template <create_ast_type_func_ptr create_type>
ast::Type* ast_mat3x2(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat3x2(create_type(ty));
}
template <typename T> template <typename T>
ast::Type* ast_mat3x3(const ProgramBuilder::TypesBuilder& ty) { ast::Type* ast_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat3x3<T>(); return ty.mat3x3<T>();
@ -249,6 +269,18 @@ sem::Type* sem_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
return ty.builder->create<sem::Matrix>(column_type, 2u); return ty.builder->create<sem::Matrix>(column_type, 2u);
} }
template <create_sem_type_func_ptr create_type>
sem::Type* sem_mat2x3(const ProgramBuilder::TypesBuilder& ty) {
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 3u);
return ty.builder->create<sem::Matrix>(column_type, 2u);
}
template <create_sem_type_func_ptr create_type>
sem::Type* sem_mat3x2(const ProgramBuilder::TypesBuilder& ty) {
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 2u);
return ty.builder->create<sem::Matrix>(column_type, 3u);
}
template <create_sem_type_func_ptr create_type> template <create_sem_type_func_ptr create_type>
sem::Type* sem_mat3x3(const ProgramBuilder::TypesBuilder& ty) { sem::Type* sem_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 3u); auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 3u);

View File

@ -1748,6 +1748,67 @@ uint32_t Builder::GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type) {
return splat_result.to_i(); return splat_result.to_i();
} }
uint32_t Builder::GenerateMatrixAddOrSub(uint32_t lhs_id,
uint32_t rhs_id,
const sem::Matrix* type,
spv::Op op) {
// Example addition of two matrices:
// %31 = OpLoad %mat3v4float %m34
// %32 = OpLoad %mat3v4float %m34
// %33 = OpCompositeExtract %v4float %31 0
// %34 = OpCompositeExtract %v4float %32 0
// %35 = OpFAdd %v4float %33 %34
// %36 = OpCompositeExtract %v4float %31 1
// %37 = OpCompositeExtract %v4float %32 1
// %38 = OpFAdd %v4float %36 %37
// %39 = OpCompositeExtract %v4float %31 2
// %40 = OpCompositeExtract %v4float %32 2
// %41 = OpFAdd %v4float %39 %40
// %42 = OpCompositeConstruct %mat3v4float %35 %38 %41
auto* column_type = builder_.create<sem::Vector>(type->type(), type->rows());
auto column_type_id = GenerateTypeIfNeeded(column_type);
OperandList ops;
for (uint32_t i = 0; i < type->columns(); ++i) {
// Extract column `i` from lhs mat
auto lhs_column_id = result_op();
if (!push_function_inst(spv::Op::OpCompositeExtract,
{Operand::Int(column_type_id), lhs_column_id,
Operand::Int(lhs_id), Operand::Int(i)})) {
return 0;
}
// Extract column `i` from rhs mat
auto rhs_column_id = result_op();
if (!push_function_inst(spv::Op::OpCompositeExtract,
{Operand::Int(column_type_id), rhs_column_id,
Operand::Int(rhs_id), Operand::Int(i)})) {
return 0;
}
// Add or subtract the two columns
auto result = result_op();
if (!push_function_inst(op, {Operand::Int(column_type_id), result,
lhs_column_id, rhs_column_id})) {
return 0;
}
ops.push_back(result);
}
// Create the result matrix from the added/subtracted column vectors
auto result_mat_id = result_op();
ops.insert(ops.begin(), result_mat_id);
ops.insert(ops.begin(), Operand::Int(GenerateTypeIfNeeded(type)));
if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) {
return 0;
}
return result_mat_id.to_i();
}
uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
// There is special logic for short circuiting operators. // There is special logic for short circuiting operators.
if (expr->IsLogicalAnd() || expr->IsLogicalOr()) { if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
@ -1779,6 +1840,24 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
auto* lhs_type = TypeOf(expr->lhs())->UnwrapRef(); auto* lhs_type = TypeOf(expr->lhs())->UnwrapRef();
auto* rhs_type = TypeOf(expr->rhs())->UnwrapRef(); auto* rhs_type = TypeOf(expr->rhs())->UnwrapRef();
// Handle matrix-matrix addition and subtraction
if ((expr->IsAdd() || expr->IsSubtract()) && lhs_type->is_float_matrix() &&
rhs_type->is_float_matrix()) {
auto* lhs_mat = lhs_type->As<sem::Matrix>();
auto* rhs_mat = rhs_type->As<sem::Matrix>();
// This should already have been validated by resolver
if (lhs_mat->rows() != rhs_mat->rows() ||
lhs_mat->columns() != rhs_mat->columns()) {
error_ = "matrices must have same dimensionality for add or subtract";
return 0;
}
return GenerateMatrixAddOrSub(
lhs_id, rhs_id, lhs_mat,
expr->IsAdd() ? spv::Op::OpFAdd : spv::Op::OpFSub);
}
// For vector-scalar arithmetic operations, splat scalar into a vector. We // For vector-scalar arithmetic operations, splat scalar into a vector. We
// skip this for multiply as we can use OpVectorTimesScalar. // skip this for multiply as we can use OpVectorTimesScalar.
const bool is_float_scalar_vector_multiply = const bool is_float_scalar_vector_multiply =

View File

@ -482,6 +482,17 @@ class Builder {
/// @returns id of the new vector /// @returns id of the new vector
uint32_t GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type); uint32_t GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type);
/// Generates instructions to add or subtract two matrices
/// @param lhs_id id of multiplicand
/// @param rhs_id id of multiplier
/// @param type type of both matrices and of result
/// @param op one of `spv::Op::OpFAdd` or `spv::Op::OpFSub`
/// @returns id of the result matrix
uint32_t GenerateMatrixAddOrSub(uint32_t lhs_id,
uint32_t rhs_id,
const sem::Matrix* type,
spv::Op op);
/// Converts AST image format to SPIR-V and pushes an appropriate capability. /// Converts AST image format to SPIR-V and pushes an appropriate capability.
/// @param format AST image format type /// @param format AST image format type
/// @returns SPIR-V image format type /// @returns SPIR-V image format type

View File

@ -863,8 +863,10 @@ OpBranch %9
)"); )");
} }
namespace BinaryArithVectorScalar {
enum class Type { f32, i32, u32 }; enum class Type { f32, i32, u32 };
ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) { static ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) {
switch (type) { switch (type) {
case Type::f32: case Type::f32:
return builder->vec3<ProgramBuilder::f32>(1.f, 1.f, 1.f); return builder->vec3<ProgramBuilder::f32>(1.f, 1.f, 1.f);
@ -875,7 +877,7 @@ ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) {
} }
return nullptr; return nullptr;
} }
ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type) { static ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type) {
switch (type) { switch (type) {
case Type::f32: case Type::f32:
return builder->Expr(1.f); return builder->Expr(1.f);
@ -886,7 +888,7 @@ ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type) {
} }
return nullptr; return nullptr;
} }
std::string OpTypeDecl(Type type) { static std::string OpTypeDecl(Type type) {
switch (type) { switch (type) {
case Type::f32: case Type::f32:
return "OpTypeFloat 32"; return "OpTypeFloat 32";
@ -904,8 +906,8 @@ struct Param {
std::string name; std::string name;
}; };
using MixedBinaryArithTest = TestParamHelper<Param>; using BinaryArithVectorScalarTest = TestParamHelper<Param>;
TEST_P(MixedBinaryArithTest, VectorScalar) { TEST_P(BinaryArithVectorScalarTest, VectorScalar) {
auto& param = GetParam(); auto& param = GetParam();
ast::Expression* lhs = MakeVectorExpr(this, param.type); ast::Expression* lhs = MakeVectorExpr(this, param.type);
@ -943,7 +945,7 @@ OpFunctionEnd
Validate(b); Validate(b);
} }
TEST_P(MixedBinaryArithTest, ScalarVector) { TEST_P(BinaryArithVectorScalarTest, ScalarVector) {
auto& param = GetParam(); auto& param = GetParam();
ast::Expression* lhs = MakeScalarExpr(this, param.type); ast::Expression* lhs = MakeScalarExpr(this, param.type);
@ -983,7 +985,7 @@ OpFunctionEnd
} }
INSTANTIATE_TEST_SUITE_P( INSTANTIATE_TEST_SUITE_P(
BuilderTest, BuilderTest,
MixedBinaryArithTest, BinaryArithVectorScalarTest,
testing::Values(Param{Type::f32, ast::BinaryOp::kAdd, "OpFAdd"}, testing::Values(Param{Type::f32, ast::BinaryOp::kAdd, "OpFAdd"},
Param{Type::f32, ast::BinaryOp::kDivide, "OpFDiv"}, Param{Type::f32, ast::BinaryOp::kDivide, "OpFDiv"},
// NOTE: Modulo not allowed on mixed float scalar-vector // NOTE: Modulo not allowed on mixed float scalar-vector
@ -1005,8 +1007,8 @@ INSTANTIATE_TEST_SUITE_P(
Param{Type::u32, ast::BinaryOp::kMultiply, "OpIMul"}, Param{Type::u32, ast::BinaryOp::kMultiply, "OpIMul"},
Param{Type::u32, ast::BinaryOp::kSubtract, "OpISub"})); Param{Type::u32, ast::BinaryOp::kSubtract, "OpISub"}));
using MixedBinaryArithMultiplyTest = TestParamHelper<Param>; using BinaryArithVectorScalarMultiplyTest = TestParamHelper<Param>;
TEST_P(MixedBinaryArithMultiplyTest, VectorScalar) { TEST_P(BinaryArithVectorScalarMultiplyTest, VectorScalar) {
auto& param = GetParam(); auto& param = GetParam();
ast::Expression* lhs = MakeVectorExpr(this, param.type); ast::Expression* lhs = MakeVectorExpr(this, param.type);
@ -1040,7 +1042,7 @@ OpFunctionEnd
Validate(b); Validate(b);
} }
TEST_P(MixedBinaryArithMultiplyTest, ScalarVector) { TEST_P(BinaryArithVectorScalarMultiplyTest, ScalarVector) {
auto& param = GetParam(); auto& param = GetParam();
ast::Expression* lhs = MakeScalarExpr(this, param.type); ast::Expression* lhs = MakeScalarExpr(this, param.type);
@ -1075,10 +1077,113 @@ OpFunctionEnd
Validate(b); Validate(b);
} }
INSTANTIATE_TEST_SUITE_P(BuilderTest, INSTANTIATE_TEST_SUITE_P(BuilderTest,
MixedBinaryArithMultiplyTest, BinaryArithVectorScalarMultiplyTest,
testing::Values(Param{ testing::Values(Param{
Type::f32, ast::BinaryOp::kMultiply, "OpFMul"})); Type::f32, ast::BinaryOp::kMultiply, "OpFMul"}));
} // namespace BinaryArithVectorScalar
namespace BinaryArithMatrixMatrix {
struct Param {
ast::BinaryOp op;
std::string name;
};
using BinaryArithMatrixMatrix = TestParamHelper<Param>;
TEST_P(BinaryArithMatrixMatrix, AddOrSubtract) {
auto& param = GetParam();
ast::Expression* lhs = mat3x4<f32>();
ast::Expression* rhs = mat3x4<f32>();
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
WrapInFunction(expr);
spirv::Builder& b = Build();
ASSERT_TRUE(b.Build()) << b.error();
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %3 "test_function"
OpExecutionMode %3 LocalSize 1 1 1
OpName %3 "test_function"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%7 = OpTypeFloat 32
%6 = OpTypeVector %7 4
%5 = OpTypeMatrix %6 3
%8 = OpConstantNull %5
%3 = OpFunction %2 None %1
%4 = OpLabel
%10 = OpCompositeExtract %6 %8 0
%11 = OpCompositeExtract %6 %8 0
%12 = )" + param.name + R"( %6 %10 %11
%13 = OpCompositeExtract %6 %8 1
%14 = OpCompositeExtract %6 %8 1
%15 = )" + param.name + R"( %6 %13 %14
%16 = OpCompositeExtract %6 %8 2
%17 = OpCompositeExtract %6 %8 2
%18 = )" + param.name + R"( %6 %16 %17
%19 = OpCompositeConstruct %5 %12 %15 %18
OpReturn
OpFunctionEnd
)");
Validate(b);
}
INSTANTIATE_TEST_SUITE_P( //
BuilderTest,
BinaryArithMatrixMatrix,
testing::Values(Param{ast::BinaryOp::kAdd, "OpFAdd"},
Param{ast::BinaryOp::kSubtract, "OpFSub"}));
using BinaryArithMatrixMatrixMultiply = TestParamHelper<Param>;
TEST_P(BinaryArithMatrixMatrixMultiply, Multiply) {
auto& param = GetParam();
ast::Expression* lhs = mat3x4<f32>();
ast::Expression* rhs = mat4x3<f32>();
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
WrapInFunction(expr);
spirv::Builder& b = Build();
ASSERT_TRUE(b.Build()) << b.error();
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %3 "test_function"
OpExecutionMode %3 LocalSize 1 1 1
OpName %3 "test_function"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%7 = OpTypeFloat 32
%6 = OpTypeVector %7 4
%5 = OpTypeMatrix %6 3
%8 = OpConstantNull %5
%10 = OpTypeVector %7 3
%9 = OpTypeMatrix %10 4
%11 = OpConstantNull %9
%13 = OpTypeMatrix %6 4
%3 = OpFunction %2 None %1
%4 = OpLabel
%12 = OpMatrixTimesMatrix %13 %8 %11
OpReturn
OpFunctionEnd
)");
Validate(b);
}
INSTANTIATE_TEST_SUITE_P( //
BuilderTest,
BinaryArithMatrixMatrixMultiply,
testing::Values(Param{ast::BinaryOp::kMultiply, "OpFMul"}));
} // namespace BinaryArithMatrixMatrix
} // namespace } // namespace
} // namespace spirv } // namespace spirv
} // namespace writer } // namespace writer

View File

@ -64,6 +64,19 @@ fn scalar_vector_u32() {
r = s % v; r = s % v;
} }
fn matrix_matrix_f32() {
var m34 : mat3x4<f32>;
var m43 : mat4x3<f32>;
var m33 : mat3x3<f32>;
var m44 : mat4x4<f32>;
m34 = m34 + m34;
m34 = m34 - m34;
m33 = m43 * m34;
m44 = m34 * m43;
}
[[stage(fragment)]] [[stage(fragment)]]
fn main() -> [[location(0)]] vec4<f32> { fn main() -> [[location(0)]] vec4<f32> {
return vec4<f32>(0.0,0.0,0.0,0.0); return vec4<f32>(0.0,0.0,0.0,0.0);

View File

@ -66,6 +66,17 @@ void scalar_vector_u32() {
r = (s % v); r = (s % v);
} }
void matrix_matrix_f32() {
float3x4 m34 = float3x4(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
float4x3 m43 = float4x3(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
float3x3 m33 = float3x3(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
float4x4 m44 = float4x4(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
m34 = (m34 + m34);
m34 = (m34 - m34);
m33 = mul(m34, m43);
m44 = mul(m43, m34);
}
tint_symbol main() { tint_symbol main() {
const tint_symbol tint_symbol_1 = {float4(0.0f, 0.0f, 0.0f, 0.0f)}; const tint_symbol tint_symbol_1 = {float4(0.0f, 0.0f, 0.0f, 0.0f)};
return tint_symbol_1; return tint_symbol_1;

View File

@ -69,6 +69,17 @@ void scalar_vector_u32() {
r = (s % v); r = (s % v);
} }
void matrix_matrix_f32() {
float3x4 m34 = float3x4(0.0f);
float4x3 m43 = float4x3(0.0f);
float3x3 m33 = float3x3(0.0f);
float4x4 m44 = float4x4(0.0f);
m34 = (m34 + m34);
m34 = (m34 - m34);
m33 = (m43 * m34);
m44 = (m34 * m43);
}
fragment tint_symbol_1 tint_symbol() { fragment tint_symbol_1 tint_symbol() {
return {float4(0.0f, 0.0f, 0.0f, 0.0f)}; return {float4(0.0f, 0.0f, 0.0f, 0.0f)};
} }

View File

@ -1,7 +1,7 @@
; SPIR-V ; SPIR-V
; Version: 1.3 ; Version: 1.3
; Generator: Google Tint Compiler; 0 ; Generator: Google Tint Compiler; 0
; Bound: 200 ; Bound: 250
; Schema: 0 ; Schema: 0
OpCapability Shader OpCapability Shader
OpMemoryModel Logical GLSL450 OpMemoryModel Logical GLSL450
@ -32,6 +32,11 @@
OpName %v_4 "v" OpName %v_4 "v"
OpName %s_4 "s" OpName %s_4 "s"
OpName %r_4 "r" OpName %r_4 "r"
OpName %matrix_matrix_f32 "matrix_matrix_f32"
OpName %m34 "m34"
OpName %m43 "m43"
OpName %m33 "m33"
OpName %m44 "m44"
OpName %tint_symbol_2 "tint_symbol_2" OpName %tint_symbol_2 "tint_symbol_2"
OpName %tint_symbol "tint_symbol" OpName %tint_symbol "tint_symbol"
OpName %main "main" OpName %main "main"
@ -60,9 +65,21 @@
%78 = OpConstantNull %v3uint %78 = OpConstantNull %v3uint
%_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Function_uint = OpTypePointer Function %uint
%81 = OpConstantNull %uint %81 = OpConstantNull %uint
%191 = OpTypeFunction %void %v4float %mat3v4float = OpTypeMatrix %v4float 3
%_ptr_Function_mat3v4float = OpTypePointer Function %mat3v4float
%196 = OpConstantNull %mat3v4float
%mat4v3float = OpTypeMatrix %v3float 4
%_ptr_Function_mat4v3float = OpTypePointer Function %mat4v3float
%200 = OpConstantNull %mat4v3float
%mat3v3float = OpTypeMatrix %v3float 3
%_ptr_Function_mat3v3float = OpTypePointer Function %mat3v3float
%204 = OpConstantNull %mat3v3float
%mat4v4float = OpTypeMatrix %v4float 4
%_ptr_Function_mat4v4float = OpTypePointer Function %mat4v4float
%208 = OpConstantNull %mat4v4float
%241 = OpTypeFunction %void %v4float
%float_0 = OpConstant %float 0 %float_0 = OpConstant %float 0
%199 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 %249 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
%vector_scalar_f32 = OpFunction %void None %6 %vector_scalar_f32 = OpFunction %void None %6
%9 = OpLabel %9 = OpLabel
%v = OpVariable %_ptr_Function_v3float Function %13 %v = OpVariable %_ptr_Function_v3float Function %13
@ -269,14 +286,56 @@
OpStore %r_4 %188 OpStore %r_4 %188
OpReturn OpReturn
OpFunctionEnd OpFunctionEnd
%tint_symbol_2 = OpFunction %void None %191 %matrix_matrix_f32 = OpFunction %void None %6
%192 = OpLabel
%m34 = OpVariable %_ptr_Function_mat3v4float Function %196
%m43 = OpVariable %_ptr_Function_mat4v3float Function %200
%m33 = OpVariable %_ptr_Function_mat3v3float Function %204
%m44 = OpVariable %_ptr_Function_mat4v4float Function %208
%209 = OpLoad %mat3v4float %m34
%210 = OpLoad %mat3v4float %m34
%212 = OpCompositeExtract %v4float %209 0
%213 = OpCompositeExtract %v4float %210 0
%214 = OpFAdd %v4float %212 %213
%215 = OpCompositeExtract %v4float %209 1
%216 = OpCompositeExtract %v4float %210 1
%217 = OpFAdd %v4float %215 %216
%218 = OpCompositeExtract %v4float %209 2
%219 = OpCompositeExtract %v4float %210 2
%220 = OpFAdd %v4float %218 %219
%221 = OpCompositeConstruct %mat3v4float %214 %217 %220
OpStore %m34 %221
%222 = OpLoad %mat3v4float %m34
%223 = OpLoad %mat3v4float %m34
%225 = OpCompositeExtract %v4float %222 0
%226 = OpCompositeExtract %v4float %223 0
%227 = OpFSub %v4float %225 %226
%228 = OpCompositeExtract %v4float %222 1
%229 = OpCompositeExtract %v4float %223 1
%230 = OpFSub %v4float %228 %229
%231 = OpCompositeExtract %v4float %222 2
%232 = OpCompositeExtract %v4float %223 2
%233 = OpFSub %v4float %231 %232
%234 = OpCompositeConstruct %mat3v4float %227 %230 %233
OpStore %m34 %234
%235 = OpLoad %mat4v3float %m43
%236 = OpLoad %mat3v4float %m34
%237 = OpMatrixTimesMatrix %mat3v3float %235 %236
OpStore %m33 %237
%238 = OpLoad %mat3v4float %m34
%239 = OpLoad %mat4v3float %m43
%240 = OpMatrixTimesMatrix %mat4v4float %238 %239
OpStore %m44 %240
OpReturn
OpFunctionEnd
%tint_symbol_2 = OpFunction %void None %241
%tint_symbol = OpFunctionParameter %v4float %tint_symbol = OpFunctionParameter %v4float
%194 = OpLabel %244 = OpLabel
OpStore %tint_symbol_1 %tint_symbol OpStore %tint_symbol_1 %tint_symbol
OpReturn OpReturn
OpFunctionEnd OpFunctionEnd
%main = OpFunction %void None %6 %main = OpFunction %void None %6
%196 = OpLabel %246 = OpLabel
%197 = OpFunctionCall %void %tint_symbol_2 %199 %247 = OpFunctionCall %void %tint_symbol_2 %249
OpReturn OpReturn
OpFunctionEnd OpFunctionEnd

View File

@ -62,6 +62,17 @@ fn scalar_vector_u32() {
r = (s % v); r = (s % v);
} }
fn matrix_matrix_f32() {
var m34 : mat3x4<f32>;
var m43 : mat4x3<f32>;
var m33 : mat3x3<f32>;
var m44 : mat4x4<f32>;
m34 = (m34 + m34);
m34 = (m34 - m34);
m33 = (m43 * m34);
m44 = (m34 * m43);
}
[[stage(fragment)]] [[stage(fragment)]]
fn main() -> [[location(0)]] vec4<f32> { fn main() -> [[location(0)]] vec4<f32> {
return vec4<f32>(0.0, 0.0, 0.0, 0.0); return vec4<f32>(0.0, 0.0, 0.0, 0.0);