[spirv-reader] Add mixed scalar/vector/matrix multiply

Bug: tint:3
Change-Id: I5875bf453b05c5d5c96f90122206da04f6799976
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23401
Reviewed-by: dan sinclair <dsinclair@google.com>
This commit is contained in:
David Neto 2020-06-17 23:33:35 +00:00 committed by dan sinclair
parent b961e0069b
commit e12c5ff42e
2 changed files with 154 additions and 5 deletions

View File

@ -200,6 +200,11 @@ ast::BinaryOp ConvertBinaryOp(SpvOp opcode) {
return ast::BinaryOp::kSubtract;
case SpvOpIMul:
case SpvOpFMul:
case SpvOpVectorTimesScalar:
case SpvOpMatrixTimesScalar:
case SpvOpVectorTimesMatrix:
case SpvOpMatrixTimesVector:
case SpvOpMatrixTimesMatrix:
return ast::BinaryOp::kMultiply;
case SpvOpUDiv:
case SpvOpSDiv:

View File

@ -58,6 +58,10 @@ std::string CommonTypes() {
%v2int_40_30 = OpConstantComposite %v2int %int_40 %int_30
%v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60
%v2float_60_50 = OpConstantComposite %v2float %float_60 %float_50
%m2v2float = OpTypeMatrix %v2float 2
%m2v2float_a = OpConstantComposite %m2v2float %v2float_50_60 %v2float_60_50
%m2v2float_b = OpConstantComposite %m2v2float %v2float_60_50 %v2float_50_60
)";
}
@ -904,17 +908,157 @@ INSTANTIATE_TEST_SUITE_P(
"__vec_2__f32", AstFor("v2float_50_60"), "modulo",
AstFor("v2float_60_50")}));
TEST_F(SpvBinaryArithTestBasic, VectorTimesScalar) {
const auto assembly = CommonTypes() + R"(
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%1 = OpCopyObject %v2float %v2float_50_60
%2 = OpCopyObject %float %float_50
%10 = OpVectorTimesScalar %v2float %1 %2
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody()) << p->error();
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
x_10
none
__vec_2__f32
{
Binary{
Identifier{x_1}
multiply
Identifier{x_2}
}
}
})"))
<< ToString(fe.ast_body());
}
TEST_F(SpvBinaryArithTestBasic, MatrixTimesScalar) {
const auto assembly = CommonTypes() + R"(
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%1 = OpCopyObject %m2v2float %m2v2float_a
%2 = OpCopyObject %float %float_50
%10 = OpMatrixTimesScalar %m2v2float %1 %2
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody()) << p->error();
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
x_10
none
__mat_2_2__f32
{
Binary{
Identifier{x_1}
multiply
Identifier{x_2}
}
}
})"))
<< ToString(fe.ast_body());
}
TEST_F(SpvBinaryArithTestBasic, VectorTimesMatrix) {
const auto assembly = CommonTypes() + R"(
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%1 = OpCopyObject %v2float %v2float_50_60
%2 = OpCopyObject %m2v2float %m2v2float_a
%10 = OpMatrixTimesVector %m2v2float %1 %2
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody()) << p->error();
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
x_10
none
__mat_2_2__f32
{
Binary{
Identifier{x_1}
multiply
Identifier{x_2}
}
}
})"))
<< ToString(fe.ast_body());
}
TEST_F(SpvBinaryArithTestBasic, MatrixTimesVector) {
const auto assembly = CommonTypes() + R"(
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%1 = OpCopyObject %m2v2float %m2v2float_a
%2 = OpCopyObject %v2float %v2float_50_60
%10 = OpMatrixTimesVector %m2v2float %1 %2
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody()) << p->error();
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
x_10
none
__mat_2_2__f32
{
Binary{
Identifier{x_1}
multiply
Identifier{x_2}
}
}
})"))
<< ToString(fe.ast_body());
}
TEST_F(SpvBinaryArithTestBasic, MatrixTimesMatrix) {
const auto assembly = CommonTypes() + R"(
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%1 = OpCopyObject %m2v2float %m2v2float_a
%2 = OpCopyObject %m2v2float %m2v2float_b
%10 = OpMatrixTimesMatrix %m2v2float %1 %2
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody()) << p->error();
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
x_10
none
__mat_2_2__f32
{
Binary{
Identifier{x_1}
multiply
Identifier{x_2}
}
}
})"))
<< ToString(fe.ast_body());
}
// TODO(dneto): OpSRem. Missing from WGSL
// https://github.com/gpuweb/gpuweb/issues/702
// TODO(dneto): OpFRem. Missing from WGSL
// https://github.com/gpuweb/gpuweb/issues/702
// TODO(dneto): OpVectorTimesScalar
// TODO(dneto): OpMatrixTimesScalar
// TODO(dneto): OpVectorTimesMatrix
// TODO(dneto): OpMatrixTimesVector
// TODO(dneto): OpMatrixTimesMatrix
// TODO(dneto): OpOuterProduct
// TODO(dneto): OpDot
// TODO(dneto): OpIAddCarry