spirv-reader: support OpDot, OpOuterProduct

Change-Id: I39f2369572a340be1c4c7c6e4a2c8e0e9347d792
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/32200
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: David Neto <dneto@google.com>
This commit is contained in:
David Neto 2020-11-11 13:59:24 +00:00 committed by Commit Bot service account
parent 6526bd4f73
commit a4f49d91dc
3 changed files with 111 additions and 2 deletions

View File

@ -42,6 +42,7 @@
#include "src/ast/fallthrough_statement.h"
#include "src/ast/identifier_expression.h"
#include "src/ast/if_statement.h"
#include "src/ast/intrinsic.h"
#include "src/ast/loop_statement.h"
#include "src/ast/member_accessor_expression.h"
#include "src/ast/return_statement.h"
@ -335,6 +336,20 @@ std::string GetGlslStd450FuncName(uint32_t ext_opcode) {
return "";
}
// Returns the WGSL standard library function instrinsic for the
// given instruction, or ast::Intrinsic::kNone
ast::Intrinsic GetIntrinsic(SpvOp opcode) {
switch (opcode) {
case SpvOpDot:
return ast::Intrinsic::kDot;
case SpvOpOuterProduct:
return ast::Intrinsic::kOuterProduct;
default:
break;
}
return ast::Intrinsic::kNone;
}
// @returns the merge block ID for the given basic block, or 0 if there is none.
uint32_t MergeFor(const spvtools::opt::BasicBlock& bb) {
// Get the OpSelectionMerge or OpLoopMerge instruction, if any.
@ -2715,6 +2730,11 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
std::move(params))};
}
const auto intrinsic = GetIntrinsic(opcode);
if (intrinsic != ast::Intrinsic::kNone) {
return MakeIntrinsicCall(inst);
}
if (opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain) {
return MakeAccessChain(inst);
}
@ -3505,6 +3525,29 @@ bool FunctionEmitter::EmitFunctionCall(const spvtools::opt::Instruction& inst) {
{result_type, std::move(call_expr)});
}
TypedExpression FunctionEmitter::MakeIntrinsicCall(
const spvtools::opt::Instruction& inst) {
const auto intrinsic = GetIntrinsic(inst.opcode());
std::ostringstream ss;
ss << intrinsic;
auto ident = std::make_unique<ast::IdentifierExpression>(ss.str());
ident->set_intrinsic(intrinsic);
ast::ExpressionList params;
for (uint32_t iarg = 0; iarg < inst.NumInOperands(); ++iarg) {
params.emplace_back(MakeOperand(inst, iarg).expr);
}
auto call_expr = std::make_unique<ast::CallExpression>(std::move(ident),
std::move(params));
auto* result_type = parser_impl_.ConvertType(inst.type_id());
if (!result_type) {
Fail() << "internal error: no mapped type result of call: "
<< inst.PrettyPrint();
return {};
}
return {result_type, std::move(call_expr)};
}
TypedExpression FunctionEmitter::MakeSimpleSelect(
const spvtools::opt::Instruction& inst) {
auto condition = MakeOperand(inst, 0);

View File

@ -689,6 +689,12 @@ class FunctionEmitter {
/// @returns false if emission failed
bool EmitFunctionCall(const spvtools::opt::Instruction& inst);
/// Returns an expression for a SPIR-V instruction that maps to a WGSL
/// intrinsic function call.
/// @param inst the SPIR-V instruction
/// @returns an expression
TypedExpression MakeIntrinsicCall(const spvtools::opt::Instruction& inst);
/// Returns an expression for an OpSelect, if its operands are scalars
/// or vectors. These translate directly to WGSL select. Otherwise, return
/// an expression with a null owned expression

View File

@ -1053,14 +1053,74 @@ TEST_F(SpvBinaryArithTestBasic, MatrixTimesMatrix) {
<< ToString(fe.ast_body());
}
TEST_F(SpvBinaryArithTestBasic, Dot) {
const auto assembly = CommonTypes() + R"(
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%1 = OpCopyObject %v2float %v2float_50_60
%2 = OpCopyObject %v2float %v2float_60_50
%3 = OpDot %float %1 %2
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody()) << p->error();
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableConst{
x_3
none
__f32
{
Call{
Identifier{dot}
(
Identifier{x_1}
Identifier{x_2}
)
}
}
})"))
<< ToString(fe.ast_body());
}
TEST_F(SpvBinaryArithTestBasic, OuterProduct) {
const auto assembly = CommonTypes() + R"(
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%1 = OpCopyObject %v2float %v2float_50_60
%2 = OpCopyObject %v2float %v2float_60_50
%3 = OpOuterProduct %m2v2float %1 %2
OpReturn
OpFunctionEnd
)";
auto* p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
FunctionEmitter fe(p, *spirv_function(100));
EXPECT_TRUE(fe.EmitBody()) << p->error();
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableConst{
x_3
none
__mat_2_2__f32
{
Call{
Identifier{outerProduct}
(
Identifier{x_1}
Identifier{x_2}
)
}
}
})"))
<< ToString(fe.ast_body());
}
// TODO(dneto): OpSRem. Missing from WGSL
// https://github.com/gpuweb/gpuweb/issues/702
// TODO(dneto): OpFRem. Missing from WGSL
// https://github.com/gpuweb/gpuweb/issues/702
// TODO(dneto): OpOuterProduct
// TODO(dneto): OpDot
// TODO(dneto): OpIAddCarry
// TODO(dneto): OpISubBorrow
// TODO(dneto): OpIMulExtended