diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index baceffa7df..99dc700b8b 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -2124,13 +2124,20 @@ bool Resolver::Binary(ast::BinaryExpression* expr) { } // Matrix arithmetic - // TODO(amaiorano): matrix-matrix addition and subtraction + auto* lhs_mat = lhs_type->As(); + auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr; + auto* rhs_mat = rhs_type->As(); + auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr; + // Addition and subtraction of float matrices + if ((expr->IsAdd() || expr->IsSubtract()) && lhs_mat_elem_type && + lhs_mat_elem_type->Is() && rhs_mat_elem_type && + rhs_mat_elem_type->Is() && + (lhs_mat->columns() == rhs_mat->columns()) && + (lhs_mat->rows() == rhs_mat->rows())) { + SetType(expr, rhs_type); + return true; + } if (expr->IsMultiply()) { - auto* lhs_mat = lhs_type->As(); - auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr; - auto* rhs_mat = rhs_type->As(); - auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr; - // Multiplication of a matrix and a scalar if (lhs_type->Is() && rhs_mat_elem_type && rhs_mat_elem_type->Is()) { diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc index 06b6ce1936..7406204f8e 100644 --- a/src/resolver/resolver_test.cc +++ b/src/resolver/resolver_test.cc @@ -1238,9 +1238,12 @@ static constexpr ast::BinaryOp all_ops[] = { }; static constexpr create_ast_type_func_ptr all_create_type_funcs[] = { - ast_bool, ast_u32, ast_i32, ast_f32, - ast_vec3, ast_vec3, ast_vec3, ast_vec3, - ast_mat3x3, ast_mat3x3, ast_mat3x3}; + ast_bool, ast_u32, ast_i32, ast_f32, + ast_vec3, ast_vec3, ast_vec3, ast_vec3, + ast_mat3x3, ast_mat3x3, ast_mat3x3, // + ast_mat2x3, ast_mat2x3, ast_mat2x3, // + ast_mat3x2, ast_mat3x2, ast_mat3x2 // +}; // A list of all valid test cases for 'lhs op rhs', except that for vecN and // matNxN, we only test N=3. @@ -1338,14 +1341,43 @@ static constexpr Params all_valid_cases[] = { // Params{Op::kModulo, ast_f32, ast_vec3, sem_vec3}, // Matrix arithmetic + Params{Op::kMultiply, ast_mat2x3, ast_f32, sem_mat2x3}, + Params{Op::kMultiply, ast_mat3x2, ast_f32, sem_mat3x2}, Params{Op::kMultiply, ast_mat3x3, ast_f32, sem_mat3x3}, + + Params{Op::kMultiply, ast_f32, ast_mat2x3, sem_mat2x3}, + Params{Op::kMultiply, ast_f32, ast_mat3x2, sem_mat3x2}, Params{Op::kMultiply, ast_f32, ast_mat3x3, sem_mat3x3}, + Params{Op::kMultiply, ast_vec3, ast_mat2x3, sem_vec2}, + Params{Op::kMultiply, ast_vec2, ast_mat3x2, sem_vec3}, Params{Op::kMultiply, ast_vec3, ast_mat3x3, sem_vec3}, + + Params{Op::kMultiply, ast_mat3x2, ast_vec3, sem_vec2}, + Params{Op::kMultiply, ast_mat2x3, ast_vec2, sem_vec3}, Params{Op::kMultiply, ast_mat3x3, ast_vec3, sem_vec3}, - // TODO(amaiorano): add mat+mat and mat-mat + + Params{Op::kMultiply, ast_mat2x3, ast_mat3x2, + sem_mat3x3}, + Params{Op::kMultiply, ast_mat3x2, ast_mat2x3, + sem_mat2x2}, + Params{Op::kMultiply, ast_mat3x2, ast_mat3x3, + sem_mat3x2}, Params{Op::kMultiply, ast_mat3x3, ast_mat3x3, sem_mat3x3}, + Params{Op::kMultiply, ast_mat3x3, ast_mat2x3, + sem_mat2x3}, + + Params{Op::kAdd, ast_mat2x3, ast_mat2x3, sem_mat2x3}, + Params{Op::kAdd, ast_mat3x2, ast_mat3x2, sem_mat3x2}, + Params{Op::kAdd, ast_mat3x3, ast_mat3x3, sem_mat3x3}, + + Params{Op::kSubtract, ast_mat2x3, ast_mat2x3, + sem_mat2x3}, + Params{Op::kSubtract, ast_mat3x2, ast_mat3x2, + sem_mat3x2}, + Params{Op::kSubtract, ast_mat3x3, ast_mat3x3, + sem_mat3x3}, // Comparison expressions // https://gpuweb.github.io/gpuweb/wgsl.html#comparison-expr diff --git a/src/resolver/resolver_test_helper.h b/src/resolver/resolver_test_helper.h index b82aec0172..4a10110e9c 100644 --- a/src/resolver/resolver_test_helper.h +++ b/src/resolver/resolver_test_helper.h @@ -176,6 +176,26 @@ ast::Type* ast_mat2x2(const ProgramBuilder::TypesBuilder& ty) { return ty.mat2x2(create_type(ty)); } +template +ast::Type* ast_mat2x3(const ProgramBuilder::TypesBuilder& ty) { + return ty.mat2x3(); +} + +template +ast::Type* ast_mat2x3(const ProgramBuilder::TypesBuilder& ty) { + return ty.mat2x3(create_type(ty)); +} + +template +ast::Type* ast_mat3x2(const ProgramBuilder::TypesBuilder& ty) { + return ty.mat3x2(); +} + +template +ast::Type* ast_mat3x2(const ProgramBuilder::TypesBuilder& ty) { + return ty.mat3x2(create_type(ty)); +} + template ast::Type* ast_mat3x3(const ProgramBuilder::TypesBuilder& ty) { return ty.mat3x3(); @@ -249,6 +269,18 @@ sem::Type* sem_mat2x2(const ProgramBuilder::TypesBuilder& ty) { return ty.builder->create(column_type, 2u); } +template +sem::Type* sem_mat2x3(const ProgramBuilder::TypesBuilder& ty) { + auto* column_type = ty.builder->create(create_type(ty), 3u); + return ty.builder->create(column_type, 2u); +} + +template +sem::Type* sem_mat3x2(const ProgramBuilder::TypesBuilder& ty) { + auto* column_type = ty.builder->create(create_type(ty), 2u); + return ty.builder->create(column_type, 3u); +} + template sem::Type* sem_mat3x3(const ProgramBuilder::TypesBuilder& ty) { auto* column_type = ty.builder->create(create_type(ty), 3u); diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index dd8de52134..79e551ccd3 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -1748,6 +1748,67 @@ uint32_t Builder::GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type) { return splat_result.to_i(); } +uint32_t Builder::GenerateMatrixAddOrSub(uint32_t lhs_id, + uint32_t rhs_id, + const sem::Matrix* type, + spv::Op op) { + // Example addition of two matrices: + // %31 = OpLoad %mat3v4float %m34 + // %32 = OpLoad %mat3v4float %m34 + // %33 = OpCompositeExtract %v4float %31 0 + // %34 = OpCompositeExtract %v4float %32 0 + // %35 = OpFAdd %v4float %33 %34 + // %36 = OpCompositeExtract %v4float %31 1 + // %37 = OpCompositeExtract %v4float %32 1 + // %38 = OpFAdd %v4float %36 %37 + // %39 = OpCompositeExtract %v4float %31 2 + // %40 = OpCompositeExtract %v4float %32 2 + // %41 = OpFAdd %v4float %39 %40 + // %42 = OpCompositeConstruct %mat3v4float %35 %38 %41 + + auto* column_type = builder_.create(type->type(), type->rows()); + auto column_type_id = GenerateTypeIfNeeded(column_type); + + OperandList ops; + + for (uint32_t i = 0; i < type->columns(); ++i) { + // Extract column `i` from lhs mat + auto lhs_column_id = result_op(); + if (!push_function_inst(spv::Op::OpCompositeExtract, + {Operand::Int(column_type_id), lhs_column_id, + Operand::Int(lhs_id), Operand::Int(i)})) { + return 0; + } + + // Extract column `i` from rhs mat + auto rhs_column_id = result_op(); + if (!push_function_inst(spv::Op::OpCompositeExtract, + {Operand::Int(column_type_id), rhs_column_id, + Operand::Int(rhs_id), Operand::Int(i)})) { + return 0; + } + + // Add or subtract the two columns + auto result = result_op(); + if (!push_function_inst(op, {Operand::Int(column_type_id), result, + lhs_column_id, rhs_column_id})) { + return 0; + } + + ops.push_back(result); + } + + // Create the result matrix from the added/subtracted column vectors + auto result_mat_id = result_op(); + ops.insert(ops.begin(), result_mat_id); + ops.insert(ops.begin(), Operand::Int(GenerateTypeIfNeeded(type))); + if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) { + return 0; + } + + return result_mat_id.to_i(); +} + uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { // There is special logic for short circuiting operators. if (expr->IsLogicalAnd() || expr->IsLogicalOr()) { @@ -1779,6 +1840,24 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { auto* lhs_type = TypeOf(expr->lhs())->UnwrapRef(); auto* rhs_type = TypeOf(expr->rhs())->UnwrapRef(); + // Handle matrix-matrix addition and subtraction + if ((expr->IsAdd() || expr->IsSubtract()) && lhs_type->is_float_matrix() && + rhs_type->is_float_matrix()) { + auto* lhs_mat = lhs_type->As(); + auto* rhs_mat = rhs_type->As(); + + // This should already have been validated by resolver + if (lhs_mat->rows() != rhs_mat->rows() || + lhs_mat->columns() != rhs_mat->columns()) { + error_ = "matrices must have same dimensionality for add or subtract"; + return 0; + } + + return GenerateMatrixAddOrSub( + lhs_id, rhs_id, lhs_mat, + expr->IsAdd() ? spv::Op::OpFAdd : spv::Op::OpFSub); + } + // For vector-scalar arithmetic operations, splat scalar into a vector. We // skip this for multiply as we can use OpVectorTimesScalar. const bool is_float_scalar_vector_multiply = diff --git a/src/writer/spirv/builder.h b/src/writer/spirv/builder.h index 3dac6d29b8..1a7271d288 100644 --- a/src/writer/spirv/builder.h +++ b/src/writer/spirv/builder.h @@ -482,6 +482,17 @@ class Builder { /// @returns id of the new vector uint32_t GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type); + /// Generates instructions to add or subtract two matrices + /// @param lhs_id id of multiplicand + /// @param rhs_id id of multiplier + /// @param type type of both matrices and of result + /// @param op one of `spv::Op::OpFAdd` or `spv::Op::OpFSub` + /// @returns id of the result matrix + uint32_t GenerateMatrixAddOrSub(uint32_t lhs_id, + uint32_t rhs_id, + const sem::Matrix* type, + spv::Op op); + /// Converts AST image format to SPIR-V and pushes an appropriate capability. /// @param format AST image format type /// @returns SPIR-V image format type diff --git a/src/writer/spirv/builder_binary_expression_test.cc b/src/writer/spirv/builder_binary_expression_test.cc index e15d4ec93e..3b8391b491 100644 --- a/src/writer/spirv/builder_binary_expression_test.cc +++ b/src/writer/spirv/builder_binary_expression_test.cc @@ -863,8 +863,10 @@ OpBranch %9 )"); } +namespace BinaryArithVectorScalar { + enum class Type { f32, i32, u32 }; -ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) { +static ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) { switch (type) { case Type::f32: return builder->vec3(1.f, 1.f, 1.f); @@ -875,7 +877,7 @@ ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) { } return nullptr; } -ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type) { +static ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type) { switch (type) { case Type::f32: return builder->Expr(1.f); @@ -886,7 +888,7 @@ ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type) { } return nullptr; } -std::string OpTypeDecl(Type type) { +static std::string OpTypeDecl(Type type) { switch (type) { case Type::f32: return "OpTypeFloat 32"; @@ -904,8 +906,8 @@ struct Param { std::string name; }; -using MixedBinaryArithTest = TestParamHelper; -TEST_P(MixedBinaryArithTest, VectorScalar) { +using BinaryArithVectorScalarTest = TestParamHelper; +TEST_P(BinaryArithVectorScalarTest, VectorScalar) { auto& param = GetParam(); ast::Expression* lhs = MakeVectorExpr(this, param.type); @@ -943,7 +945,7 @@ OpFunctionEnd Validate(b); } -TEST_P(MixedBinaryArithTest, ScalarVector) { +TEST_P(BinaryArithVectorScalarTest, ScalarVector) { auto& param = GetParam(); ast::Expression* lhs = MakeScalarExpr(this, param.type); @@ -983,7 +985,7 @@ OpFunctionEnd } INSTANTIATE_TEST_SUITE_P( BuilderTest, - MixedBinaryArithTest, + BinaryArithVectorScalarTest, testing::Values(Param{Type::f32, ast::BinaryOp::kAdd, "OpFAdd"}, Param{Type::f32, ast::BinaryOp::kDivide, "OpFDiv"}, // NOTE: Modulo not allowed on mixed float scalar-vector @@ -1005,8 +1007,8 @@ INSTANTIATE_TEST_SUITE_P( Param{Type::u32, ast::BinaryOp::kMultiply, "OpIMul"}, Param{Type::u32, ast::BinaryOp::kSubtract, "OpISub"})); -using MixedBinaryArithMultiplyTest = TestParamHelper; -TEST_P(MixedBinaryArithMultiplyTest, VectorScalar) { +using BinaryArithVectorScalarMultiplyTest = TestParamHelper; +TEST_P(BinaryArithVectorScalarMultiplyTest, VectorScalar) { auto& param = GetParam(); ast::Expression* lhs = MakeVectorExpr(this, param.type); @@ -1040,7 +1042,7 @@ OpFunctionEnd Validate(b); } -TEST_P(MixedBinaryArithMultiplyTest, ScalarVector) { +TEST_P(BinaryArithVectorScalarMultiplyTest, ScalarVector) { auto& param = GetParam(); ast::Expression* lhs = MakeScalarExpr(this, param.type); @@ -1075,10 +1077,113 @@ OpFunctionEnd Validate(b); } INSTANTIATE_TEST_SUITE_P(BuilderTest, - MixedBinaryArithMultiplyTest, + BinaryArithVectorScalarMultiplyTest, testing::Values(Param{ Type::f32, ast::BinaryOp::kMultiply, "OpFMul"})); +} // namespace BinaryArithVectorScalar + +namespace BinaryArithMatrixMatrix { + +struct Param { + ast::BinaryOp op; + std::string name; +}; + +using BinaryArithMatrixMatrix = TestParamHelper; +TEST_P(BinaryArithMatrixMatrix, AddOrSubtract) { + auto& param = GetParam(); + + ast::Expression* lhs = mat3x4(); + ast::Expression* rhs = mat3x4(); + + auto* expr = create(param.op, lhs, rhs); + + WrapInFunction(expr); + + spirv::Builder& b = Build(); + ASSERT_TRUE(b.Build()) << b.error(); + + EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %3 "test_function" +OpExecutionMode %3 LocalSize 1 1 1 +OpName %3 "test_function" +%2 = OpTypeVoid +%1 = OpTypeFunction %2 +%7 = OpTypeFloat 32 +%6 = OpTypeVector %7 4 +%5 = OpTypeMatrix %6 3 +%8 = OpConstantNull %5 +%3 = OpFunction %2 None %1 +%4 = OpLabel +%10 = OpCompositeExtract %6 %8 0 +%11 = OpCompositeExtract %6 %8 0 +%12 = )" + param.name + R"( %6 %10 %11 +%13 = OpCompositeExtract %6 %8 1 +%14 = OpCompositeExtract %6 %8 1 +%15 = )" + param.name + R"( %6 %13 %14 +%16 = OpCompositeExtract %6 %8 2 +%17 = OpCompositeExtract %6 %8 2 +%18 = )" + param.name + R"( %6 %16 %17 +%19 = OpCompositeConstruct %5 %12 %15 %18 +OpReturn +OpFunctionEnd +)"); + + Validate(b); +} +INSTANTIATE_TEST_SUITE_P( // + BuilderTest, + BinaryArithMatrixMatrix, + testing::Values(Param{ast::BinaryOp::kAdd, "OpFAdd"}, + Param{ast::BinaryOp::kSubtract, "OpFSub"})); + +using BinaryArithMatrixMatrixMultiply = TestParamHelper; +TEST_P(BinaryArithMatrixMatrixMultiply, Multiply) { + auto& param = GetParam(); + + ast::Expression* lhs = mat3x4(); + ast::Expression* rhs = mat4x3(); + + auto* expr = create(param.op, lhs, rhs); + + WrapInFunction(expr); + + spirv::Builder& b = Build(); + ASSERT_TRUE(b.Build()) << b.error(); + + EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %3 "test_function" +OpExecutionMode %3 LocalSize 1 1 1 +OpName %3 "test_function" +%2 = OpTypeVoid +%1 = OpTypeFunction %2 +%7 = OpTypeFloat 32 +%6 = OpTypeVector %7 4 +%5 = OpTypeMatrix %6 3 +%8 = OpConstantNull %5 +%10 = OpTypeVector %7 3 +%9 = OpTypeMatrix %10 4 +%11 = OpConstantNull %9 +%13 = OpTypeMatrix %6 4 +%3 = OpFunction %2 None %1 +%4 = OpLabel +%12 = OpMatrixTimesMatrix %13 %8 %11 +OpReturn +OpFunctionEnd +)"); + + Validate(b); +} +INSTANTIATE_TEST_SUITE_P( // + BuilderTest, + BinaryArithMatrixMatrixMultiply, + testing::Values(Param{ast::BinaryOp::kMultiply, "OpFMul"})); + +} // namespace BinaryArithMatrixMatrix + } // namespace } // namespace spirv } // namespace writer diff --git a/test/expressions/binary_expressions.wgsl b/test/expressions/binary_expressions.wgsl index a58eecef70..b7444a0587 100644 --- a/test/expressions/binary_expressions.wgsl +++ b/test/expressions/binary_expressions.wgsl @@ -64,6 +64,19 @@ fn scalar_vector_u32() { r = s % v; } +fn matrix_matrix_f32() { + var m34 : mat3x4; + var m43 : mat4x3; + var m33 : mat3x3; + var m44 : mat4x4; + + m34 = m34 + m34; + m34 = m34 - m34; + + m33 = m43 * m34; + m44 = m34 * m43; +} + [[stage(fragment)]] fn main() -> [[location(0)]] vec4 { return vec4(0.0,0.0,0.0,0.0); diff --git a/test/expressions/binary_expressions.wgsl.expected.hlsl b/test/expressions/binary_expressions.wgsl.expected.hlsl index 01f2d2bcac..c234983590 100644 --- a/test/expressions/binary_expressions.wgsl.expected.hlsl +++ b/test/expressions/binary_expressions.wgsl.expected.hlsl @@ -66,6 +66,17 @@ void scalar_vector_u32() { r = (s % v); } +void matrix_matrix_f32() { + float3x4 m34 = float3x4(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); + float4x3 m43 = float4x3(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); + float3x3 m33 = float3x3(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); + float4x4 m44 = float4x4(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f); + m34 = (m34 + m34); + m34 = (m34 - m34); + m33 = mul(m34, m43); + m44 = mul(m43, m34); +} + tint_symbol main() { const tint_symbol tint_symbol_1 = {float4(0.0f, 0.0f, 0.0f, 0.0f)}; return tint_symbol_1; diff --git a/test/expressions/binary_expressions.wgsl.expected.msl b/test/expressions/binary_expressions.wgsl.expected.msl index 56309406cb..e46aa45564 100644 --- a/test/expressions/binary_expressions.wgsl.expected.msl +++ b/test/expressions/binary_expressions.wgsl.expected.msl @@ -69,6 +69,17 @@ void scalar_vector_u32() { r = (s % v); } +void matrix_matrix_f32() { + float3x4 m34 = float3x4(0.0f); + float4x3 m43 = float4x3(0.0f); + float3x3 m33 = float3x3(0.0f); + float4x4 m44 = float4x4(0.0f); + m34 = (m34 + m34); + m34 = (m34 - m34); + m33 = (m43 * m34); + m44 = (m34 * m43); +} + fragment tint_symbol_1 tint_symbol() { return {float4(0.0f, 0.0f, 0.0f, 0.0f)}; } diff --git a/test/expressions/binary_expressions.wgsl.expected.spvasm b/test/expressions/binary_expressions.wgsl.expected.spvasm index 445bb49672..8052168bfa 100644 --- a/test/expressions/binary_expressions.wgsl.expected.spvasm +++ b/test/expressions/binary_expressions.wgsl.expected.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.3 ; Generator: Google Tint Compiler; 0 -; Bound: 200 +; Bound: 250 ; Schema: 0 OpCapability Shader OpMemoryModel Logical GLSL450 @@ -32,6 +32,11 @@ OpName %v_4 "v" OpName %s_4 "s" OpName %r_4 "r" + OpName %matrix_matrix_f32 "matrix_matrix_f32" + OpName %m34 "m34" + OpName %m43 "m43" + OpName %m33 "m33" + OpName %m44 "m44" OpName %tint_symbol_2 "tint_symbol_2" OpName %tint_symbol "tint_symbol" OpName %main "main" @@ -60,9 +65,21 @@ %78 = OpConstantNull %v3uint %_ptr_Function_uint = OpTypePointer Function %uint %81 = OpConstantNull %uint - %191 = OpTypeFunction %void %v4float +%mat3v4float = OpTypeMatrix %v4float 3 +%_ptr_Function_mat3v4float = OpTypePointer Function %mat3v4float + %196 = OpConstantNull %mat3v4float +%mat4v3float = OpTypeMatrix %v3float 4 +%_ptr_Function_mat4v3float = OpTypePointer Function %mat4v3float + %200 = OpConstantNull %mat4v3float +%mat3v3float = OpTypeMatrix %v3float 3 +%_ptr_Function_mat3v3float = OpTypePointer Function %mat3v3float + %204 = OpConstantNull %mat3v3float +%mat4v4float = OpTypeMatrix %v4float 4 +%_ptr_Function_mat4v4float = OpTypePointer Function %mat4v4float + %208 = OpConstantNull %mat4v4float + %241 = OpTypeFunction %void %v4float %float_0 = OpConstant %float 0 - %199 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 + %249 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0 %vector_scalar_f32 = OpFunction %void None %6 %9 = OpLabel %v = OpVariable %_ptr_Function_v3float Function %13 @@ -269,14 +286,56 @@ OpStore %r_4 %188 OpReturn OpFunctionEnd -%tint_symbol_2 = OpFunction %void None %191 +%matrix_matrix_f32 = OpFunction %void None %6 + %192 = OpLabel + %m34 = OpVariable %_ptr_Function_mat3v4float Function %196 + %m43 = OpVariable %_ptr_Function_mat4v3float Function %200 + %m33 = OpVariable %_ptr_Function_mat3v3float Function %204 + %m44 = OpVariable %_ptr_Function_mat4v4float Function %208 + %209 = OpLoad %mat3v4float %m34 + %210 = OpLoad %mat3v4float %m34 + %212 = OpCompositeExtract %v4float %209 0 + %213 = OpCompositeExtract %v4float %210 0 + %214 = OpFAdd %v4float %212 %213 + %215 = OpCompositeExtract %v4float %209 1 + %216 = OpCompositeExtract %v4float %210 1 + %217 = OpFAdd %v4float %215 %216 + %218 = OpCompositeExtract %v4float %209 2 + %219 = OpCompositeExtract %v4float %210 2 + %220 = OpFAdd %v4float %218 %219 + %221 = OpCompositeConstruct %mat3v4float %214 %217 %220 + OpStore %m34 %221 + %222 = OpLoad %mat3v4float %m34 + %223 = OpLoad %mat3v4float %m34 + %225 = OpCompositeExtract %v4float %222 0 + %226 = OpCompositeExtract %v4float %223 0 + %227 = OpFSub %v4float %225 %226 + %228 = OpCompositeExtract %v4float %222 1 + %229 = OpCompositeExtract %v4float %223 1 + %230 = OpFSub %v4float %228 %229 + %231 = OpCompositeExtract %v4float %222 2 + %232 = OpCompositeExtract %v4float %223 2 + %233 = OpFSub %v4float %231 %232 + %234 = OpCompositeConstruct %mat3v4float %227 %230 %233 + OpStore %m34 %234 + %235 = OpLoad %mat4v3float %m43 + %236 = OpLoad %mat3v4float %m34 + %237 = OpMatrixTimesMatrix %mat3v3float %235 %236 + OpStore %m33 %237 + %238 = OpLoad %mat3v4float %m34 + %239 = OpLoad %mat4v3float %m43 + %240 = OpMatrixTimesMatrix %mat4v4float %238 %239 + OpStore %m44 %240 + OpReturn + OpFunctionEnd +%tint_symbol_2 = OpFunction %void None %241 %tint_symbol = OpFunctionParameter %v4float - %194 = OpLabel + %244 = OpLabel OpStore %tint_symbol_1 %tint_symbol OpReturn OpFunctionEnd %main = OpFunction %void None %6 - %196 = OpLabel - %197 = OpFunctionCall %void %tint_symbol_2 %199 + %246 = OpLabel + %247 = OpFunctionCall %void %tint_symbol_2 %249 OpReturn OpFunctionEnd diff --git a/test/expressions/binary_expressions.wgsl.expected.wgsl b/test/expressions/binary_expressions.wgsl.expected.wgsl index ec7e7d505e..cd0bf554b1 100644 --- a/test/expressions/binary_expressions.wgsl.expected.wgsl +++ b/test/expressions/binary_expressions.wgsl.expected.wgsl @@ -62,6 +62,17 @@ fn scalar_vector_u32() { r = (s % v); } +fn matrix_matrix_f32() { + var m34 : mat3x4; + var m43 : mat4x3; + var m33 : mat3x3; + var m44 : mat4x4; + m34 = (m34 + m34); + m34 = (m34 - m34); + m33 = (m43 * m34); + m44 = (m34 * m43); +} + [[stage(fragment)]] fn main() -> [[location(0)]] vec4 { return vec4(0.0, 0.0, 0.0, 0.0);