diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 932c5ca8ea..05c3b3017c 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -200,6 +200,11 @@ ast::BinaryOp ConvertBinaryOp(SpvOp opcode) { return ast::BinaryOp::kSubtract; case SpvOpIMul: case SpvOpFMul: + case SpvOpVectorTimesScalar: + case SpvOpMatrixTimesScalar: + case SpvOpVectorTimesMatrix: + case SpvOpMatrixTimesVector: + case SpvOpMatrixTimesMatrix: return ast::BinaryOp::kMultiply; case SpvOpUDiv: case SpvOpSDiv: diff --git a/src/reader/spirv/function_arithmetic_test.cc b/src/reader/spirv/function_arithmetic_test.cc index c92a84cb1d..ba9656d2de 100644 --- a/src/reader/spirv/function_arithmetic_test.cc +++ b/src/reader/spirv/function_arithmetic_test.cc @@ -58,6 +58,10 @@ std::string CommonTypes() { %v2int_40_30 = OpConstantComposite %v2int %int_40 %int_30 %v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60 %v2float_60_50 = OpConstantComposite %v2float %float_60 %float_50 + + %m2v2float = OpTypeMatrix %v2float 2 + %m2v2float_a = OpConstantComposite %m2v2float %v2float_50_60 %v2float_60_50 + %m2v2float_b = OpConstantComposite %m2v2float %v2float_60_50 %v2float_50_60 )"; } @@ -904,17 +908,157 @@ INSTANTIATE_TEST_SUITE_P( "__vec_2__f32", AstFor("v2float_50_60"), "modulo", AstFor("v2float_60_50")})); +TEST_F(SpvBinaryArithTestBasic, VectorTimesScalar) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpCopyObject %v2float %v2float_50_60 + %2 = OpCopyObject %float %float_50 + %10 = OpVectorTimesScalar %v2float %1 %2 + OpReturn + OpFunctionEnd +)"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{ + x_10 + none + __vec_2__f32 + { + Binary{ + Identifier{x_1} + multiply + Identifier{x_2} + } + } + })")) + << ToString(fe.ast_body()); +} + +TEST_F(SpvBinaryArithTestBasic, MatrixTimesScalar) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpCopyObject %m2v2float %m2v2float_a + %2 = OpCopyObject %float %float_50 + %10 = OpMatrixTimesScalar %m2v2float %1 %2 + OpReturn + OpFunctionEnd +)"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{ + x_10 + none + __mat_2_2__f32 + { + Binary{ + Identifier{x_1} + multiply + Identifier{x_2} + } + } + })")) + << ToString(fe.ast_body()); +} + +TEST_F(SpvBinaryArithTestBasic, VectorTimesMatrix) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpCopyObject %v2float %v2float_50_60 + %2 = OpCopyObject %m2v2float %m2v2float_a + %10 = OpMatrixTimesVector %m2v2float %1 %2 + OpReturn + OpFunctionEnd +)"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{ + x_10 + none + __mat_2_2__f32 + { + Binary{ + Identifier{x_1} + multiply + Identifier{x_2} + } + } + })")) + << ToString(fe.ast_body()); +} + +TEST_F(SpvBinaryArithTestBasic, MatrixTimesVector) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpCopyObject %m2v2float %m2v2float_a + %2 = OpCopyObject %v2float %v2float_50_60 + %10 = OpMatrixTimesVector %m2v2float %1 %2 + OpReturn + OpFunctionEnd +)"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{ + x_10 + none + __mat_2_2__f32 + { + Binary{ + Identifier{x_1} + multiply + Identifier{x_2} + } + } + })")) + << ToString(fe.ast_body()); +} + +TEST_F(SpvBinaryArithTestBasic, MatrixTimesMatrix) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpCopyObject %m2v2float %m2v2float_a + %2 = OpCopyObject %m2v2float %m2v2float_b + %10 = OpMatrixTimesMatrix %m2v2float %1 %2 + OpReturn + OpFunctionEnd +)"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{ + x_10 + none + __mat_2_2__f32 + { + Binary{ + Identifier{x_1} + multiply + Identifier{x_2} + } + } + })")) + << ToString(fe.ast_body()); +} + // TODO(dneto): OpSRem. Missing from WGSL // https://github.com/gpuweb/gpuweb/issues/702 // TODO(dneto): OpFRem. Missing from WGSL // https://github.com/gpuweb/gpuweb/issues/702 -// TODO(dneto): OpVectorTimesScalar -// TODO(dneto): OpMatrixTimesScalar -// TODO(dneto): OpVectorTimesMatrix -// TODO(dneto): OpMatrixTimesVector -// TODO(dneto): OpMatrixTimesMatrix // TODO(dneto): OpOuterProduct // TODO(dneto): OpDot // TODO(dneto): OpIAddCarry