From a4f49d91dc8d9ff0aac711485767a316aabea0c1 Mon Sep 17 00:00:00 2001 From: David Neto Date: Wed, 11 Nov 2020 13:59:24 +0000 Subject: [PATCH] spirv-reader: support OpDot, OpOuterProduct Change-Id: I39f2369572a340be1c4c7c6e4a2c8e0e9347d792 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/32200 Reviewed-by: dan sinclair Reviewed-by: Ben Clayton Commit-Queue: David Neto --- src/reader/spirv/function.cc | 43 +++++++++++++ src/reader/spirv/function.h | 6 ++ src/reader/spirv/function_arithmetic_test.cc | 64 +++++++++++++++++++- 3 files changed, 111 insertions(+), 2 deletions(-) diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 18f715b99c..b42713fbbc 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -42,6 +42,7 @@ #include "src/ast/fallthrough_statement.h" #include "src/ast/identifier_expression.h" #include "src/ast/if_statement.h" +#include "src/ast/intrinsic.h" #include "src/ast/loop_statement.h" #include "src/ast/member_accessor_expression.h" #include "src/ast/return_statement.h" @@ -335,6 +336,20 @@ std::string GetGlslStd450FuncName(uint32_t ext_opcode) { return ""; } +// Returns the WGSL standard library function instrinsic for the +// given instruction, or ast::Intrinsic::kNone +ast::Intrinsic GetIntrinsic(SpvOp opcode) { + switch (opcode) { + case SpvOpDot: + return ast::Intrinsic::kDot; + case SpvOpOuterProduct: + return ast::Intrinsic::kOuterProduct; + default: + break; + } + return ast::Intrinsic::kNone; +} + // @returns the merge block ID for the given basic block, or 0 if there is none. uint32_t MergeFor(const spvtools::opt::BasicBlock& bb) { // Get the OpSelectionMerge or OpLoopMerge instruction, if any. @@ -2715,6 +2730,11 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue( std::move(params))}; } + const auto intrinsic = GetIntrinsic(opcode); + if (intrinsic != ast::Intrinsic::kNone) { + return MakeIntrinsicCall(inst); + } + if (opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain) { return MakeAccessChain(inst); } @@ -3505,6 +3525,29 @@ bool FunctionEmitter::EmitFunctionCall(const spvtools::opt::Instruction& inst) { {result_type, std::move(call_expr)}); } +TypedExpression FunctionEmitter::MakeIntrinsicCall( + const spvtools::opt::Instruction& inst) { + const auto intrinsic = GetIntrinsic(inst.opcode()); + std::ostringstream ss; + ss << intrinsic; + auto ident = std::make_unique(ss.str()); + ident->set_intrinsic(intrinsic); + + ast::ExpressionList params; + for (uint32_t iarg = 0; iarg < inst.NumInOperands(); ++iarg) { + params.emplace_back(MakeOperand(inst, iarg).expr); + } + auto call_expr = std::make_unique(std::move(ident), + std::move(params)); + auto* result_type = parser_impl_.ConvertType(inst.type_id()); + if (!result_type) { + Fail() << "internal error: no mapped type result of call: " + << inst.PrettyPrint(); + return {}; + } + return {result_type, std::move(call_expr)}; +} + TypedExpression FunctionEmitter::MakeSimpleSelect( const spvtools::opt::Instruction& inst) { auto condition = MakeOperand(inst, 0); diff --git a/src/reader/spirv/function.h b/src/reader/spirv/function.h index c6862c0625..c9df72aaa1 100644 --- a/src/reader/spirv/function.h +++ b/src/reader/spirv/function.h @@ -689,6 +689,12 @@ class FunctionEmitter { /// @returns false if emission failed bool EmitFunctionCall(const spvtools::opt::Instruction& inst); + /// Returns an expression for a SPIR-V instruction that maps to a WGSL + /// intrinsic function call. + /// @param inst the SPIR-V instruction + /// @returns an expression + TypedExpression MakeIntrinsicCall(const spvtools::opt::Instruction& inst); + /// Returns an expression for an OpSelect, if its operands are scalars /// or vectors. These translate directly to WGSL select. Otherwise, return /// an expression with a null owned expression diff --git a/src/reader/spirv/function_arithmetic_test.cc b/src/reader/spirv/function_arithmetic_test.cc index bd86258cef..da958c1a8a 100644 --- a/src/reader/spirv/function_arithmetic_test.cc +++ b/src/reader/spirv/function_arithmetic_test.cc @@ -1053,14 +1053,74 @@ TEST_F(SpvBinaryArithTestBasic, MatrixTimesMatrix) { << ToString(fe.ast_body()); } +TEST_F(SpvBinaryArithTestBasic, Dot) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpCopyObject %v2float %v2float_50_60 + %2 = OpCopyObject %v2float %v2float_60_50 + %3 = OpDot %float %1 %2 + OpReturn + OpFunctionEnd +)"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableConst{ + x_3 + none + __f32 + { + Call{ + Identifier{dot} + ( + Identifier{x_1} + Identifier{x_2} + ) + } + } + })")) + << ToString(fe.ast_body()); +} + +TEST_F(SpvBinaryArithTestBasic, OuterProduct) { + const auto assembly = CommonTypes() + R"( + %100 = OpFunction %void None %voidfn + %entry = OpLabel + %1 = OpCopyObject %v2float %v2float_50_60 + %2 = OpCopyObject %v2float %v2float_60_50 + %3 = OpOuterProduct %m2v2float %1 %2 + OpReturn + OpFunctionEnd +)"; + auto* p = parser(test::Assemble(assembly)); + ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly; + FunctionEmitter fe(p, *spirv_function(100)); + EXPECT_TRUE(fe.EmitBody()) << p->error(); + EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableConst{ + x_3 + none + __mat_2_2__f32 + { + Call{ + Identifier{outerProduct} + ( + Identifier{x_1} + Identifier{x_2} + ) + } + } + })")) + << ToString(fe.ast_body()); +} + // TODO(dneto): OpSRem. Missing from WGSL // https://github.com/gpuweb/gpuweb/issues/702 // TODO(dneto): OpFRem. Missing from WGSL // https://github.com/gpuweb/gpuweb/issues/702 -// TODO(dneto): OpOuterProduct -// TODO(dneto): OpDot // TODO(dneto): OpIAddCarry // TODO(dneto): OpISubBorrow // TODO(dneto): OpIMulExtended