[spirv-reader] Add mixed scalar/vector/matrix multiply
Bug: tint:3 Change-Id: I5875bf453b05c5d5c96f90122206da04f6799976 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23401 Reviewed-by: dan sinclair <dsinclair@google.com>
This commit is contained in:
parent
b961e0069b
commit
e12c5ff42e
|
@ -200,6 +200,11 @@ ast::BinaryOp ConvertBinaryOp(SpvOp opcode) {
|
|||
return ast::BinaryOp::kSubtract;
|
||||
case SpvOpIMul:
|
||||
case SpvOpFMul:
|
||||
case SpvOpVectorTimesScalar:
|
||||
case SpvOpMatrixTimesScalar:
|
||||
case SpvOpVectorTimesMatrix:
|
||||
case SpvOpMatrixTimesVector:
|
||||
case SpvOpMatrixTimesMatrix:
|
||||
return ast::BinaryOp::kMultiply;
|
||||
case SpvOpUDiv:
|
||||
case SpvOpSDiv:
|
||||
|
|
|
@ -58,6 +58,10 @@ std::string CommonTypes() {
|
|||
%v2int_40_30 = OpConstantComposite %v2int %int_40 %int_30
|
||||
%v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60
|
||||
%v2float_60_50 = OpConstantComposite %v2float %float_60 %float_50
|
||||
|
||||
%m2v2float = OpTypeMatrix %v2float 2
|
||||
%m2v2float_a = OpConstantComposite %m2v2float %v2float_50_60 %v2float_60_50
|
||||
%m2v2float_b = OpConstantComposite %m2v2float %v2float_60_50 %v2float_50_60
|
||||
)";
|
||||
}
|
||||
|
||||
|
@ -904,17 +908,157 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
"__vec_2__f32", AstFor("v2float_50_60"), "modulo",
|
||||
AstFor("v2float_60_50")}));
|
||||
|
||||
TEST_F(SpvBinaryArithTestBasic, VectorTimesScalar) {
|
||||
const auto assembly = CommonTypes() + R"(
|
||||
%100 = OpFunction %void None %voidfn
|
||||
%entry = OpLabel
|
||||
%1 = OpCopyObject %v2float %v2float_50_60
|
||||
%2 = OpCopyObject %float %float_50
|
||||
%10 = OpVectorTimesScalar %v2float %1 %2
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
auto* p = parser(test::Assemble(assembly));
|
||||
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
|
||||
FunctionEmitter fe(p, *spirv_function(100));
|
||||
EXPECT_TRUE(fe.EmitBody()) << p->error();
|
||||
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
|
||||
x_10
|
||||
none
|
||||
__vec_2__f32
|
||||
{
|
||||
Binary{
|
||||
Identifier{x_1}
|
||||
multiply
|
||||
Identifier{x_2}
|
||||
}
|
||||
}
|
||||
})"))
|
||||
<< ToString(fe.ast_body());
|
||||
}
|
||||
|
||||
TEST_F(SpvBinaryArithTestBasic, MatrixTimesScalar) {
|
||||
const auto assembly = CommonTypes() + R"(
|
||||
%100 = OpFunction %void None %voidfn
|
||||
%entry = OpLabel
|
||||
%1 = OpCopyObject %m2v2float %m2v2float_a
|
||||
%2 = OpCopyObject %float %float_50
|
||||
%10 = OpMatrixTimesScalar %m2v2float %1 %2
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
auto* p = parser(test::Assemble(assembly));
|
||||
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
|
||||
FunctionEmitter fe(p, *spirv_function(100));
|
||||
EXPECT_TRUE(fe.EmitBody()) << p->error();
|
||||
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
|
||||
x_10
|
||||
none
|
||||
__mat_2_2__f32
|
||||
{
|
||||
Binary{
|
||||
Identifier{x_1}
|
||||
multiply
|
||||
Identifier{x_2}
|
||||
}
|
||||
}
|
||||
})"))
|
||||
<< ToString(fe.ast_body());
|
||||
}
|
||||
|
||||
TEST_F(SpvBinaryArithTestBasic, VectorTimesMatrix) {
|
||||
const auto assembly = CommonTypes() + R"(
|
||||
%100 = OpFunction %void None %voidfn
|
||||
%entry = OpLabel
|
||||
%1 = OpCopyObject %v2float %v2float_50_60
|
||||
%2 = OpCopyObject %m2v2float %m2v2float_a
|
||||
%10 = OpMatrixTimesVector %m2v2float %1 %2
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
auto* p = parser(test::Assemble(assembly));
|
||||
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
|
||||
FunctionEmitter fe(p, *spirv_function(100));
|
||||
EXPECT_TRUE(fe.EmitBody()) << p->error();
|
||||
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
|
||||
x_10
|
||||
none
|
||||
__mat_2_2__f32
|
||||
{
|
||||
Binary{
|
||||
Identifier{x_1}
|
||||
multiply
|
||||
Identifier{x_2}
|
||||
}
|
||||
}
|
||||
})"))
|
||||
<< ToString(fe.ast_body());
|
||||
}
|
||||
|
||||
TEST_F(SpvBinaryArithTestBasic, MatrixTimesVector) {
|
||||
const auto assembly = CommonTypes() + R"(
|
||||
%100 = OpFunction %void None %voidfn
|
||||
%entry = OpLabel
|
||||
%1 = OpCopyObject %m2v2float %m2v2float_a
|
||||
%2 = OpCopyObject %v2float %v2float_50_60
|
||||
%10 = OpMatrixTimesVector %m2v2float %1 %2
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
auto* p = parser(test::Assemble(assembly));
|
||||
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
|
||||
FunctionEmitter fe(p, *spirv_function(100));
|
||||
EXPECT_TRUE(fe.EmitBody()) << p->error();
|
||||
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
|
||||
x_10
|
||||
none
|
||||
__mat_2_2__f32
|
||||
{
|
||||
Binary{
|
||||
Identifier{x_1}
|
||||
multiply
|
||||
Identifier{x_2}
|
||||
}
|
||||
}
|
||||
})"))
|
||||
<< ToString(fe.ast_body());
|
||||
}
|
||||
|
||||
TEST_F(SpvBinaryArithTestBasic, MatrixTimesMatrix) {
|
||||
const auto assembly = CommonTypes() + R"(
|
||||
%100 = OpFunction %void None %voidfn
|
||||
%entry = OpLabel
|
||||
%1 = OpCopyObject %m2v2float %m2v2float_a
|
||||
%2 = OpCopyObject %m2v2float %m2v2float_b
|
||||
%10 = OpMatrixTimesMatrix %m2v2float %1 %2
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
auto* p = parser(test::Assemble(assembly));
|
||||
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
|
||||
FunctionEmitter fe(p, *spirv_function(100));
|
||||
EXPECT_TRUE(fe.EmitBody()) << p->error();
|
||||
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
|
||||
x_10
|
||||
none
|
||||
__mat_2_2__f32
|
||||
{
|
||||
Binary{
|
||||
Identifier{x_1}
|
||||
multiply
|
||||
Identifier{x_2}
|
||||
}
|
||||
}
|
||||
})"))
|
||||
<< ToString(fe.ast_body());
|
||||
}
|
||||
|
||||
// TODO(dneto): OpSRem. Missing from WGSL
|
||||
// https://github.com/gpuweb/gpuweb/issues/702
|
||||
|
||||
// TODO(dneto): OpFRem. Missing from WGSL
|
||||
// https://github.com/gpuweb/gpuweb/issues/702
|
||||
|
||||
// TODO(dneto): OpVectorTimesScalar
|
||||
// TODO(dneto): OpMatrixTimesScalar
|
||||
// TODO(dneto): OpVectorTimesMatrix
|
||||
// TODO(dneto): OpMatrixTimesVector
|
||||
// TODO(dneto): OpMatrixTimesMatrix
|
||||
// TODO(dneto): OpOuterProduct
|
||||
// TODO(dneto): OpDot
|
||||
// TODO(dneto): OpIAddCarry
|
||||
|
|
Loading…
Reference in New Issue