spirv-reader: expand OuterProduct to primitive ops

Bug: tint:3
Change-Id: Id6de3554d945bc743a484e80b494690c26552079
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/37660
Reviewed-by: dan sinclair <dsinclair@chromium.org>
Commit-Queue: David Neto <dneto@google.com>
Auto-Submit: David Neto <dneto@google.com>
This commit is contained in:
David Neto 2021-01-14 19:01:27 +00:00 committed by Commit Bot service account
parent 0a68b365eb
commit 0e17caa361
3 changed files with 164 additions and 1 deletions

View File

@ -57,6 +57,7 @@
#include "src/ast/type/depth_texture_type.h" #include "src/ast/type/depth_texture_type.h"
#include "src/ast/type/f32_type.h" #include "src/ast/type/f32_type.h"
#include "src/ast/type/i32_type.h" #include "src/ast/type/i32_type.h"
#include "src/ast/type/matrix_type.h"
#include "src/ast/type/pointer_type.h" #include "src/ast/type/pointer_type.h"
#include "src/ast/type/storage_texture_type.h" #include "src/ast/type/storage_texture_type.h"
#include "src/ast/type/texture_type.h" #include "src/ast/type/texture_type.h"
@ -3013,6 +3014,10 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) {
return EmitConstDefOrWriteToHoistedVar(inst, expr); return EmitConstDefOrWriteToHoistedVar(inst, expr);
} }
case SpvOpOuterProduct:
// Synthesize an outer product expression in its own statement.
return EmitConstDefOrWriteToHoistedVar(inst, MakeOuterProduct(inst));
case SpvOpFunctionCall: case SpvOpFunctionCall:
return EmitFunctionCall(inst); return EmitFunctionCall(inst);
@ -3707,7 +3712,8 @@ void FunctionEmitter::FindValuesNeedingNamedOrHoistedDefinition() {
// but only if they are defined in this function as well. // but only if they are defined in this function as well.
for (auto& id_def_info_pair : def_info_) { for (auto& id_def_info_pair : def_info_) {
const auto& inst = id_def_info_pair.second->inst; const auto& inst = id_def_info_pair.second->inst;
if (inst.opcode() == SpvOpVectorShuffle) { const auto opcode = inst.opcode();
if ((opcode == SpvOpVectorShuffle) || (opcode == SpvOpOuterProduct)) {
// We might access the vector operands multiple times. Make sure they // We might access the vector operands multiple times. Make sure they
// are evaluated only once. // are evaluated only once.
for (auto vector_arg : std::array<uint32_t, 2>{0, 1}) { for (auto vector_arg : std::array<uint32_t, 2>{0, 1}) {
@ -4578,6 +4584,52 @@ TypedExpression FunctionEmitter::MakeArrayLength(
return {parser_impl_.ConvertType(inst.type_id()), call_expr}; return {parser_impl_.ConvertType(inst.type_id()), call_expr};
} }
TypedExpression FunctionEmitter::MakeOuterProduct(
const spvtools::opt::Instruction& inst) {
// Synthesize the result.
auto col = MakeOperand(inst, 0);
auto row = MakeOperand(inst, 1);
auto* col_ty = col.type->As<ast::type::Vector>();
auto* row_ty = row.type->As<ast::type::Vector>();
auto* result_ty =
parser_impl_.ConvertType(inst.type_id())->As<ast::type::Matrix>();
if (!col_ty || !col_ty || !result_ty || result_ty->type() != col_ty->type() ||
result_ty->type() != row_ty->type() ||
result_ty->columns() != row_ty->size() ||
result_ty->rows() != col_ty->size()) {
Fail() << "invalid outer product instruction: bad types "
<< inst.PrettyPrint();
return {};
}
// Example:
// c : vec3 column vector
// r : vec2 row vector
// OuterProduct c r : mat2x3 (2 columns, 3 rows)
// Result:
// | c.x * r.x c.x * r.y |
// | c.y * r.x c.y * r.y |
// | c.z * r.x c.z * r.y |
ast::ExpressionList result_columns;
for (uint32_t icol = 0; icol < result_ty->columns(); icol++) {
ast::ExpressionList result_row;
auto* row_factor = create<ast::MemberAccessorExpression>(Source{}, row.expr,
Swizzle(icol));
for (uint32_t irow = 0; irow < result_ty->rows(); irow++) {
auto* column_factor = create<ast::MemberAccessorExpression>(
Source{}, col.expr, Swizzle(irow));
auto* elem = create<ast::BinaryExpression>(
Source{}, ast::BinaryOp::kMultiply, row_factor, column_factor);
result_row.push_back(elem);
}
result_columns.push_back(
create<ast::TypeConstructorExpression>(Source{}, col_ty, result_row));
}
return {result_ty, create<ast::TypeConstructorExpression>(Source{}, result_ty,
result_columns)};
}
FunctionEmitter::FunctionDeclaration::FunctionDeclaration() = default; FunctionEmitter::FunctionDeclaration::FunctionDeclaration() = default;
FunctionEmitter::FunctionDeclaration::~FunctionDeclaration() = default; FunctionEmitter::FunctionDeclaration::~FunctionDeclaration() = default;

View File

@ -851,6 +851,11 @@ class FunctionEmitter {
/// @returns an expression /// @returns an expression
TypedExpression MakeArrayLength(const spvtools::opt::Instruction& inst); TypedExpression MakeArrayLength(const spvtools::opt::Instruction& inst);
/// Generates an expression for a SPIR-V OpOuterProduct instruction.
/// @param inst the SPIR-V instruction
/// @returns an expression
TypedExpression MakeOuterProduct(const spvtools::opt::Instruction& inst);
/// Emits a texture builtin function call for a SPIR-V instruction that /// Emits a texture builtin function call for a SPIR-V instruction that
/// accesses an image or sampled image. /// accesses an image or sampled image.
/// @param inst the SPIR-V instruction /// @param inst the SPIR-V instruction

View File

@ -43,6 +43,7 @@ std::string CommonTypes() {
%int_40 = OpConstant %int 40 %int_40 = OpConstant %int 40
%float_50 = OpConstant %float 50 %float_50 = OpConstant %float 50
%float_60 = OpConstant %float 60 %float_60 = OpConstant %float 60
%float_70 = OpConstant %float 70
%ptr_uint = OpTypePointer Function %uint %ptr_uint = OpTypePointer Function %uint
%ptr_int = OpTypePointer Function %int %ptr_int = OpTypePointer Function %int
@ -51,6 +52,7 @@ std::string CommonTypes() {
%v2uint = OpTypeVector %uint 2 %v2uint = OpTypeVector %uint 2
%v2int = OpTypeVector %int 2 %v2int = OpTypeVector %int 2
%v2float = OpTypeVector %float 2 %v2float = OpTypeVector %float 2
%v3float = OpTypeVector %float 3
%v2uint_10_20 = OpConstantComposite %v2uint %uint_10 %uint_20 %v2uint_10_20 = OpConstantComposite %v2uint %uint_10 %uint_20
%v2uint_20_10 = OpConstantComposite %v2uint %uint_20 %uint_10 %v2uint_20_10 = OpConstantComposite %v2uint %uint_20 %uint_10
@ -58,10 +60,12 @@ std::string CommonTypes() {
%v2int_40_30 = OpConstantComposite %v2int %int_40 %int_30 %v2int_40_30 = OpConstantComposite %v2int %int_40 %int_30
%v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60 %v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60
%v2float_60_50 = OpConstantComposite %v2float %float_60 %float_50 %v2float_60_50 = OpConstantComposite %v2float %float_60 %float_50
%v3float_50_60_70 = OpConstantComposite %v2float %float_50 %float_60 %float_70
%m2v2float = OpTypeMatrix %v2float 2 %m2v2float = OpTypeMatrix %v2float 2
%m2v2float_a = OpConstantComposite %m2v2float %v2float_50_60 %v2float_60_50 %m2v2float_a = OpConstantComposite %m2v2float %v2float_50_60 %v2float_60_50
%m2v2float_b = OpConstantComposite %m2v2float %v2float_60_50 %v2float_50_60 %m2v2float_b = OpConstantComposite %m2v2float %v2float_60_50 %v2float_50_60
%m2v3float = OpTypeMatrix %v3float 2
)"; )";
} }
@ -1099,6 +1103,108 @@ TEST_F(SpvBinaryArithTestBasic, Dot) {
<< ToString(p->get_module(), fe.ast_body()); << ToString(p->get_module(), fe.ast_body());
} }
TEST_F(SpvBinaryArithTestBasic, OuterProduct) {
// OpOuterProduct is expanded to basic operations.
// The operands, even if used once, are given their own const definitions.
const auto assembly = CommonTypes() + R"(
%100 = OpFunction %void None %voidfn
%entry = OpLabel
%1 = OpFAdd %v3float %v3float_50_60_70 %v3float_50_60_70 ; column vector
%2 = OpFAdd %v2float %v2float_60_50 %v2float_50_60 ; row vector
%3 = OpOuterProduct %m2v3float %1 %2
OpReturn
OpFunctionEnd
)";
auto p = parser(test::Assemble(assembly));
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
EXPECT_TRUE(fe.EmitBody()) << p->error();
auto got = ToString(p->get_module(), fe.ast_body());
EXPECT_THAT(got, HasSubstr(R"(VariableConst{
x_3
none
__mat_3_2__f32
{
TypeConstructor[not set]{
__mat_3_2__f32
TypeConstructor[not set]{
__vec_3__f32
Binary[not set]{
MemberAccessor[not set]{
Identifier[not set]{x_2}
Identifier[not set]{x}
}
multiply
MemberAccessor[not set]{
Identifier[not set]{x_1}
Identifier[not set]{x}
}
}
Binary[not set]{
MemberAccessor[not set]{
Identifier[not set]{x_2}
Identifier[not set]{x}
}
multiply
MemberAccessor[not set]{
Identifier[not set]{x_1}
Identifier[not set]{y}
}
}
Binary[not set]{
MemberAccessor[not set]{
Identifier[not set]{x_2}
Identifier[not set]{x}
}
multiply
MemberAccessor[not set]{
Identifier[not set]{x_1}
Identifier[not set]{z}
}
}
}
TypeConstructor[not set]{
__vec_3__f32
Binary[not set]{
MemberAccessor[not set]{
Identifier[not set]{x_2}
Identifier[not set]{y}
}
multiply
MemberAccessor[not set]{
Identifier[not set]{x_1}
Identifier[not set]{x}
}
}
Binary[not set]{
MemberAccessor[not set]{
Identifier[not set]{x_2}
Identifier[not set]{y}
}
multiply
MemberAccessor[not set]{
Identifier[not set]{x_1}
Identifier[not set]{y}
}
}
Binary[not set]{
MemberAccessor[not set]{
Identifier[not set]{x_2}
Identifier[not set]{y}
}
multiply
MemberAccessor[not set]{
Identifier[not set]{x_1}
Identifier[not set]{z}
}
}
}
}
}
})"))
<< got;
}
// TODO(dneto): OpSRem. Missing from WGSL // TODO(dneto): OpSRem. Missing from WGSL
// https://github.com/gpuweb/gpuweb/issues/702 // https://github.com/gpuweb/gpuweb/issues/702