From 943ded79d2b1cdb02d5f2deff613fa5816d0a64f Mon Sep 17 00:00:00 2001 From: David Neto Date: Thu, 22 Apr 2021 12:29:45 +0000 Subject: [PATCH] spirv-reader: support OpTranspose Change-Id: If338b22b703257e863e511579cfd3abcaa0bdfe7 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/48761 Auto-Submit: David Neto Kokoro: Kokoro Commit-Queue: Alan Baker Reviewed-by: Alan Baker --- src/reader/spirv/function.cc | 2 + src/reader/spirv/function_arithmetic_test.cc | 108 ++++++++++++++++++- 2 files changed, 109 insertions(+), 1 deletion(-) diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 888f7cce70..a6270c1bb8 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -175,6 +175,8 @@ const char* GetUnaryBuiltInFunctionName(SpvOp opcode) { return "isNan"; case SpvOpIsInf: return "isInf"; + case SpvOpTranspose: + return "transpose"; default: break; } diff --git a/src/reader/spirv/function_arithmetic_test.cc b/src/reader/spirv/function_arithmetic_test.cc index 9a05b46c90..869729a605 100644 --- a/src/reader/spirv/function_arithmetic_test.cc +++ b/src/reader/spirv/function_arithmetic_test.cc @@ -57,11 +57,15 @@ std::string CommonTypes() { %v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60 %v2float_60_50 = OpConstantComposite %v2float %float_60 %float_50 %v3float_50_60_70 = OpConstantComposite %v2float %float_50 %float_60 %float_70 + %v3float_60_70_50 = OpConstantComposite %v2float %float_60 %float_70 %float_50 %m2v2float = OpTypeMatrix %v2float 2 + %m2v3float = OpTypeMatrix %v3float 2 + %m3v2float = OpTypeMatrix %v2float 3 %m2v2float_a = OpConstantComposite %m2v2float %v2float_50_60 %v2float_60_50 %m2v2float_b = OpConstantComposite %m2v2float %v2float_60_50 %v2float_50_60 - %m2v3float = OpTypeMatrix %v3float 2 + %m3v2float_a = OpConstantComposite %m3v2float %v2float_50_60 %v2float_60_50 %v2float_50_60 + %m2v3float_a = OpConstantComposite %m2v3float %v3float_50_60_70 %v3float_60_70_50 )"; } @@ -1551,6 +1555,108 @@ INSTANTIATE_TEST_SUITE_P( ArgAndTypeData{"v2float", "v2float_50_60", "__vec_2__f32"}, ArgAndTypeData{"v3float", "v3float_50_60_70", "__vec_3__f32"}))); +TEST_F(SpvUnaryArithTest, Transpose_2x2) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpCopyObject %m2v2float %m2v2float_a + %2 = OpTranspose %m2v2float %1 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + const auto* expected = R"( +VariableDeclStatement{ + VariableConst{ + x_2 + none + __mat_2_2__f32 + { + Call[not set]{ + Identifier[not set]{transpose} + ( + Identifier[not set]{x_1} + ) + } + } + } +})"; + const auto got = ToString(p->builder(), fe.ast_body()); + EXPECT_THAT(got, HasSubstr(expected)) << got; +} + +TEST_F(SpvUnaryArithTest, Transpose_2x3) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpCopyObject %m2v3float %m2v3float_a + %2 = OpTranspose %m3v2float %1 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + // Note, in the AST dump mat_2_3 means 2 rows and 3 columns. + // So the column vectors have 2 elements. + // That is, %m3v2float is __mat_2_3__f32. + const auto* expected = R"( +VariableDeclStatement{ + VariableConst{ + x_2 + none + __mat_2_3__f32 + { + Call[not set]{ + Identifier[not set]{transpose} + ( + Identifier[not set]{x_1} + ) + } + } + } +})"; + const auto got = ToString(p->builder(), fe.ast_body()); + EXPECT_THAT(got, HasSubstr(expected)) << got; +} + +TEST_F(SpvUnaryArithTest, Transpose_3x2) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpCopyObject %m3v2float %m3v2float_a + %2 = OpTranspose %m2v3float %1 + OpReturn + OpFunctionEnd + )"; + auto p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); + FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + const auto* expected = R"( +VariableDeclStatement{ + VariableConst{ + x_2 + none + __mat_3_2__f32 + { + Call[not set]{ + Identifier[not set]{transpose} + ( + Identifier[not set]{x_1} + ) + } + } + } +})"; + const auto got = ToString(p->builder(), fe.ast_body()); + EXPECT_THAT(got, HasSubstr(expected)) << got; +} + // TODO(dneto): OpSRem. Missing from WGSL // https://github.com/gpuweb/gpuweb/issues/702