spirv-reader: support OpDot, OpOuterProduct
Change-Id: I39f2369572a340be1c4c7c6e4a2c8e0e9347d792 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/32200 Reviewed-by: dan sinclair <dsinclair@chromium.org> Reviewed-by: Ben Clayton <bclayton@google.com> Commit-Queue: David Neto <dneto@google.com>
This commit is contained in:
parent
6526bd4f73
commit
a4f49d91dc
|
@ -42,6 +42,7 @@
|
||||||
#include "src/ast/fallthrough_statement.h"
|
#include "src/ast/fallthrough_statement.h"
|
||||||
#include "src/ast/identifier_expression.h"
|
#include "src/ast/identifier_expression.h"
|
||||||
#include "src/ast/if_statement.h"
|
#include "src/ast/if_statement.h"
|
||||||
|
#include "src/ast/intrinsic.h"
|
||||||
#include "src/ast/loop_statement.h"
|
#include "src/ast/loop_statement.h"
|
||||||
#include "src/ast/member_accessor_expression.h"
|
#include "src/ast/member_accessor_expression.h"
|
||||||
#include "src/ast/return_statement.h"
|
#include "src/ast/return_statement.h"
|
||||||
|
@ -335,6 +336,20 @@ std::string GetGlslStd450FuncName(uint32_t ext_opcode) {
|
||||||
return "";
|
return "";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the WGSL standard library function instrinsic for the
|
||||||
|
// given instruction, or ast::Intrinsic::kNone
|
||||||
|
ast::Intrinsic GetIntrinsic(SpvOp opcode) {
|
||||||
|
switch (opcode) {
|
||||||
|
case SpvOpDot:
|
||||||
|
return ast::Intrinsic::kDot;
|
||||||
|
case SpvOpOuterProduct:
|
||||||
|
return ast::Intrinsic::kOuterProduct;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return ast::Intrinsic::kNone;
|
||||||
|
}
|
||||||
|
|
||||||
// @returns the merge block ID for the given basic block, or 0 if there is none.
|
// @returns the merge block ID for the given basic block, or 0 if there is none.
|
||||||
uint32_t MergeFor(const spvtools::opt::BasicBlock& bb) {
|
uint32_t MergeFor(const spvtools::opt::BasicBlock& bb) {
|
||||||
// Get the OpSelectionMerge or OpLoopMerge instruction, if any.
|
// Get the OpSelectionMerge or OpLoopMerge instruction, if any.
|
||||||
|
@ -2715,6 +2730,11 @@ TypedExpression FunctionEmitter::MaybeEmitCombinatorialValue(
|
||||||
std::move(params))};
|
std::move(params))};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const auto intrinsic = GetIntrinsic(opcode);
|
||||||
|
if (intrinsic != ast::Intrinsic::kNone) {
|
||||||
|
return MakeIntrinsicCall(inst);
|
||||||
|
}
|
||||||
|
|
||||||
if (opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain) {
|
if (opcode == SpvOpAccessChain || opcode == SpvOpInBoundsAccessChain) {
|
||||||
return MakeAccessChain(inst);
|
return MakeAccessChain(inst);
|
||||||
}
|
}
|
||||||
|
@ -3505,6 +3525,29 @@ bool FunctionEmitter::EmitFunctionCall(const spvtools::opt::Instruction& inst) {
|
||||||
{result_type, std::move(call_expr)});
|
{result_type, std::move(call_expr)});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TypedExpression FunctionEmitter::MakeIntrinsicCall(
|
||||||
|
const spvtools::opt::Instruction& inst) {
|
||||||
|
const auto intrinsic = GetIntrinsic(inst.opcode());
|
||||||
|
std::ostringstream ss;
|
||||||
|
ss << intrinsic;
|
||||||
|
auto ident = std::make_unique<ast::IdentifierExpression>(ss.str());
|
||||||
|
ident->set_intrinsic(intrinsic);
|
||||||
|
|
||||||
|
ast::ExpressionList params;
|
||||||
|
for (uint32_t iarg = 0; iarg < inst.NumInOperands(); ++iarg) {
|
||||||
|
params.emplace_back(MakeOperand(inst, iarg).expr);
|
||||||
|
}
|
||||||
|
auto call_expr = std::make_unique<ast::CallExpression>(std::move(ident),
|
||||||
|
std::move(params));
|
||||||
|
auto* result_type = parser_impl_.ConvertType(inst.type_id());
|
||||||
|
if (!result_type) {
|
||||||
|
Fail() << "internal error: no mapped type result of call: "
|
||||||
|
<< inst.PrettyPrint();
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
return {result_type, std::move(call_expr)};
|
||||||
|
}
|
||||||
|
|
||||||
TypedExpression FunctionEmitter::MakeSimpleSelect(
|
TypedExpression FunctionEmitter::MakeSimpleSelect(
|
||||||
const spvtools::opt::Instruction& inst) {
|
const spvtools::opt::Instruction& inst) {
|
||||||
auto condition = MakeOperand(inst, 0);
|
auto condition = MakeOperand(inst, 0);
|
||||||
|
|
|
@ -689,6 +689,12 @@ class FunctionEmitter {
|
||||||
/// @returns false if emission failed
|
/// @returns false if emission failed
|
||||||
bool EmitFunctionCall(const spvtools::opt::Instruction& inst);
|
bool EmitFunctionCall(const spvtools::opt::Instruction& inst);
|
||||||
|
|
||||||
|
/// Returns an expression for a SPIR-V instruction that maps to a WGSL
|
||||||
|
/// intrinsic function call.
|
||||||
|
/// @param inst the SPIR-V instruction
|
||||||
|
/// @returns an expression
|
||||||
|
TypedExpression MakeIntrinsicCall(const spvtools::opt::Instruction& inst);
|
||||||
|
|
||||||
/// Returns an expression for an OpSelect, if its operands are scalars
|
/// Returns an expression for an OpSelect, if its operands are scalars
|
||||||
/// or vectors. These translate directly to WGSL select. Otherwise, return
|
/// or vectors. These translate directly to WGSL select. Otherwise, return
|
||||||
/// an expression with a null owned expression
|
/// an expression with a null owned expression
|
||||||
|
|
|
@ -1053,14 +1053,74 @@ TEST_F(SpvBinaryArithTestBasic, MatrixTimesMatrix) {
|
||||||
<< ToString(fe.ast_body());
|
<< ToString(fe.ast_body());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(SpvBinaryArithTestBasic, Dot) {
|
||||||
|
const auto assembly = CommonTypes() + R"(
|
||||||
|
%100 = OpFunction %void None %voidfn
|
||||||
|
%entry = OpLabel
|
||||||
|
%1 = OpCopyObject %v2float %v2float_50_60
|
||||||
|
%2 = OpCopyObject %v2float %v2float_60_50
|
||||||
|
%3 = OpDot %float %1 %2
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
||||||
|
)";
|
||||||
|
auto* p = parser(test::Assemble(assembly));
|
||||||
|
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
|
||||||
|
FunctionEmitter fe(p, *spirv_function(100));
|
||||||
|
EXPECT_TRUE(fe.EmitBody()) << p->error();
|
||||||
|
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableConst{
|
||||||
|
x_3
|
||||||
|
none
|
||||||
|
__f32
|
||||||
|
{
|
||||||
|
Call{
|
||||||
|
Identifier{dot}
|
||||||
|
(
|
||||||
|
Identifier{x_1}
|
||||||
|
Identifier{x_2}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})"))
|
||||||
|
<< ToString(fe.ast_body());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(SpvBinaryArithTestBasic, OuterProduct) {
|
||||||
|
const auto assembly = CommonTypes() + R"(
|
||||||
|
%100 = OpFunction %void None %voidfn
|
||||||
|
%entry = OpLabel
|
||||||
|
%1 = OpCopyObject %v2float %v2float_50_60
|
||||||
|
%2 = OpCopyObject %v2float %v2float_60_50
|
||||||
|
%3 = OpOuterProduct %m2v2float %1 %2
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
||||||
|
)";
|
||||||
|
auto* p = parser(test::Assemble(assembly));
|
||||||
|
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
|
||||||
|
FunctionEmitter fe(p, *spirv_function(100));
|
||||||
|
EXPECT_TRUE(fe.EmitBody()) << p->error();
|
||||||
|
EXPECT_THAT(ToString(fe.ast_body()), HasSubstr(R"(VariableConst{
|
||||||
|
x_3
|
||||||
|
none
|
||||||
|
__mat_2_2__f32
|
||||||
|
{
|
||||||
|
Call{
|
||||||
|
Identifier{outerProduct}
|
||||||
|
(
|
||||||
|
Identifier{x_1}
|
||||||
|
Identifier{x_2}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})"))
|
||||||
|
<< ToString(fe.ast_body());
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(dneto): OpSRem. Missing from WGSL
|
// TODO(dneto): OpSRem. Missing from WGSL
|
||||||
// https://github.com/gpuweb/gpuweb/issues/702
|
// https://github.com/gpuweb/gpuweb/issues/702
|
||||||
|
|
||||||
// TODO(dneto): OpFRem. Missing from WGSL
|
// TODO(dneto): OpFRem. Missing from WGSL
|
||||||
// https://github.com/gpuweb/gpuweb/issues/702
|
// https://github.com/gpuweb/gpuweb/issues/702
|
||||||
|
|
||||||
// TODO(dneto): OpOuterProduct
|
|
||||||
// TODO(dneto): OpDot
|
|
||||||
// TODO(dneto): OpIAddCarry
|
// TODO(dneto): OpIAddCarry
|
||||||
// TODO(dneto): OpISubBorrow
|
// TODO(dneto): OpISubBorrow
|
||||||
// TODO(dneto): OpIMulExtended
|
// TODO(dneto): OpIMulExtended
|
||||||
|
|
Loading…
Reference in New Issue