[spirv-reader] Add mixed scalar/vector/matrix multiply
Bug: tint:3 Change-Id: I5875bf453b05c5d5c96f90122206da04f6799976 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/23401 Reviewed-by: dan sinclair <dsinclair@google.com>
This commit is contained in:
parent
b961e0069b
commit
e12c5ff42e
|
@ -200,6 +200,11 @@ ast::BinaryOp ConvertBinaryOp(SpvOp opcode) {
|
||||||
return ast::BinaryOp::kSubtract;
|
return ast::BinaryOp::kSubtract;
|
||||||
case SpvOpIMul:
|
case SpvOpIMul:
|
||||||
case SpvOpFMul:
|
case SpvOpFMul:
|
||||||
|
case SpvOpVectorTimesScalar:
|
||||||
|
case SpvOpMatrixTimesScalar:
|
||||||
|
case SpvOpVectorTimesMatrix:
|
||||||
|
case SpvOpMatrixTimesVector:
|
||||||
|
case SpvOpMatrixTimesMatrix:
|
||||||
return ast::BinaryOp::kMultiply;
|
return ast::BinaryOp::kMultiply;
|
||||||
case SpvOpUDiv:
|
case SpvOpUDiv:
|
||||||
case SpvOpSDiv:
|
case SpvOpSDiv:
|
||||||
|
|
|
@ -58,6 +58,10 @@ std::string CommonTypes() {
|
||||||
%v2int_40_30 = OpConstantComposite %v2int %int_40 %int_30
|
%v2int_40_30 = OpConstantComposite %v2int %int_40 %int_30
|
||||||
%v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60
|
%v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60
|
||||||
%v2float_60_50 = OpConstantComposite %v2float %float_60 %float_50
|
%v2float_60_50 = OpConstantComposite %v2float %float_60 %float_50
|
||||||
|
|
||||||
|
%m2v2float = OpTypeMatrix %v2float 2
|
||||||
|
%m2v2float_a = OpConstantComposite %m2v2float %v2float_50_60 %v2float_60_50
|
||||||
|
%m2v2float_b = OpConstantComposite %m2v2float %v2float_60_50 %v2float_50_60
|
||||||
)";
|
)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -904,17 +908,157 @@ INSTANTIATE_TEST_SUITE_P(
|
||||||
"__vec_2__f32", AstFor("v2float_50_60"), "modulo",
|
"__vec_2__f32", AstFor("v2float_50_60"), "modulo",
|
||||||
AstFor("v2float_60_50")}));
|
AstFor("v2float_60_50")}));
|
||||||
|
|
||||||
|
TEST_F(SpvBinaryArithTestBasic, VectorTimesScalar) {
|
||||||
|
const auto assembly = CommonTypes() + R"(
|
||||||
|
%100 = OpFunction %void None %voidfn
|
||||||
|
%entry = OpLabel
|
||||||
|
%1 = OpCopyObject %v2float %v2float_50_60
|
||||||
|
%2 = OpCopyObject %float %float_50
|
||||||
|
%10 = OpVectorTimesScalar %v2float %1 %2
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
||||||
|
)";
|
||||||
|
auto* p = parser(test::Assemble(assembly));
|
||||||
|
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
|
||||||
|
FunctionEmitter fe(p, *spirv_function(100));
|
||||||
|
EXPECT_TRUE(fe.EmitBody()) << p->error();
|
||||||
|
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
|
||||||
|
x_10
|
||||||
|
none
|
||||||
|
__vec_2__f32
|
||||||
|
{
|
||||||
|
Binary{
|
||||||
|
Identifier{x_1}
|
||||||
|
multiply
|
||||||
|
Identifier{x_2}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})"))
|
||||||
|
<< ToString(fe.ast_body());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SpvBinaryArithTestBasic, MatrixTimesScalar) {
|
||||||
|
const auto assembly = CommonTypes() + R"(
|
||||||
|
%100 = OpFunction %void None %voidfn
|
||||||
|
%entry = OpLabel
|
||||||
|
%1 = OpCopyObject %m2v2float %m2v2float_a
|
||||||
|
%2 = OpCopyObject %float %float_50
|
||||||
|
%10 = OpMatrixTimesScalar %m2v2float %1 %2
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
||||||
|
)";
|
||||||
|
auto* p = parser(test::Assemble(assembly));
|
||||||
|
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
|
||||||
|
FunctionEmitter fe(p, *spirv_function(100));
|
||||||
|
EXPECT_TRUE(fe.EmitBody()) << p->error();
|
||||||
|
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
|
||||||
|
x_10
|
||||||
|
none
|
||||||
|
__mat_2_2__f32
|
||||||
|
{
|
||||||
|
Binary{
|
||||||
|
Identifier{x_1}
|
||||||
|
multiply
|
||||||
|
Identifier{x_2}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})"))
|
||||||
|
<< ToString(fe.ast_body());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SpvBinaryArithTestBasic, VectorTimesMatrix) {
|
||||||
|
const auto assembly = CommonTypes() + R"(
|
||||||
|
%100 = OpFunction %void None %voidfn
|
||||||
|
%entry = OpLabel
|
||||||
|
%1 = OpCopyObject %v2float %v2float_50_60
|
||||||
|
%2 = OpCopyObject %m2v2float %m2v2float_a
|
||||||
|
%10 = OpMatrixTimesVector %m2v2float %1 %2
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
||||||
|
)";
|
||||||
|
auto* p = parser(test::Assemble(assembly));
|
||||||
|
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
|
||||||
|
FunctionEmitter fe(p, *spirv_function(100));
|
||||||
|
EXPECT_TRUE(fe.EmitBody()) << p->error();
|
||||||
|
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
|
||||||
|
x_10
|
||||||
|
none
|
||||||
|
__mat_2_2__f32
|
||||||
|
{
|
||||||
|
Binary{
|
||||||
|
Identifier{x_1}
|
||||||
|
multiply
|
||||||
|
Identifier{x_2}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})"))
|
||||||
|
<< ToString(fe.ast_body());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SpvBinaryArithTestBasic, MatrixTimesVector) {
|
||||||
|
const auto assembly = CommonTypes() + R"(
|
||||||
|
%100 = OpFunction %void None %voidfn
|
||||||
|
%entry = OpLabel
|
||||||
|
%1 = OpCopyObject %m2v2float %m2v2float_a
|
||||||
|
%2 = OpCopyObject %v2float %v2float_50_60
|
||||||
|
%10 = OpMatrixTimesVector %m2v2float %1 %2
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
||||||
|
)";
|
||||||
|
auto* p = parser(test::Assemble(assembly));
|
||||||
|
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
|
||||||
|
FunctionEmitter fe(p, *spirv_function(100));
|
||||||
|
EXPECT_TRUE(fe.EmitBody()) << p->error();
|
||||||
|
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
|
||||||
|
x_10
|
||||||
|
none
|
||||||
|
__mat_2_2__f32
|
||||||
|
{
|
||||||
|
Binary{
|
||||||
|
Identifier{x_1}
|
||||||
|
multiply
|
||||||
|
Identifier{x_2}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})"))
|
||||||
|
<< ToString(fe.ast_body());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SpvBinaryArithTestBasic, MatrixTimesMatrix) {
|
||||||
|
const auto assembly = CommonTypes() + R"(
|
||||||
|
%100 = OpFunction %void None %voidfn
|
||||||
|
%entry = OpLabel
|
||||||
|
%1 = OpCopyObject %m2v2float %m2v2float_a
|
||||||
|
%2 = OpCopyObject %m2v2float %m2v2float_b
|
||||||
|
%10 = OpMatrixTimesMatrix %m2v2float %1 %2
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
||||||
|
)";
|
||||||
|
auto* p = parser(test::Assemble(assembly));
|
||||||
|
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
|
||||||
|
FunctionEmitter fe(p, *spirv_function(100));
|
||||||
|
EXPECT_TRUE(fe.EmitBody()) << p->error();
|
||||||
|
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(Variable{
|
||||||
|
x_10
|
||||||
|
none
|
||||||
|
__mat_2_2__f32
|
||||||
|
{
|
||||||
|
Binary{
|
||||||
|
Identifier{x_1}
|
||||||
|
multiply
|
||||||
|
Identifier{x_2}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})"))
|
||||||
|
<< ToString(fe.ast_body());
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(dneto): OpSRem. Missing from WGSL
|
// TODO(dneto): OpSRem. Missing from WGSL
|
||||||
// https://github.com/gpuweb/gpuweb/issues/702
|
// https://github.com/gpuweb/gpuweb/issues/702
|
||||||
|
|
||||||
// TODO(dneto): OpFRem. Missing from WGSL
|
// TODO(dneto): OpFRem. Missing from WGSL
|
||||||
// https://github.com/gpuweb/gpuweb/issues/702
|
// https://github.com/gpuweb/gpuweb/issues/702
|
||||||
|
|
||||||
// TODO(dneto): OpVectorTimesScalar
|
|
||||||
// TODO(dneto): OpMatrixTimesScalar
|
|
||||||
// TODO(dneto): OpVectorTimesMatrix
|
|
||||||
// TODO(dneto): OpMatrixTimesVector
|
|
||||||
// TODO(dneto): OpMatrixTimesMatrix
|
|
||||||
// TODO(dneto): OpOuterProduct
|
// TODO(dneto): OpOuterProduct
|
||||||
// TODO(dneto): OpDot
|
// TODO(dneto): OpDot
|
||||||
// TODO(dneto): OpIAddCarry
|
// TODO(dneto): OpIAddCarry
|
||||||
|
|
Loading…
Reference in New Issue