spirv-reader: expand OuterProduct to primitive ops
Bug: tint:3 Change-Id: Id6de3554d945bc743a484e80b494690c26552079 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/37660 Reviewed-by: dan sinclair <dsinclair@chromium.org> Commit-Queue: David Neto <dneto@google.com> Auto-Submit: David Neto <dneto@google.com>
This commit is contained in:
parent
0a68b365eb
commit
0e17caa361
|
@ -57,6 +57,7 @@
|
|||
#include "src/ast/type/depth_texture_type.h"
|
||||
#include "src/ast/type/f32_type.h"
|
||||
#include "src/ast/type/i32_type.h"
|
||||
#include "src/ast/type/matrix_type.h"
|
||||
#include "src/ast/type/pointer_type.h"
|
||||
#include "src/ast/type/storage_texture_type.h"
|
||||
#include "src/ast/type/texture_type.h"
|
||||
|
@ -3013,6 +3014,10 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) {
|
|||
return EmitConstDefOrWriteToHoistedVar(inst, expr);
|
||||
}
|
||||
|
||||
case SpvOpOuterProduct:
|
||||
// Synthesize an outer product expression in its own statement.
|
||||
return EmitConstDefOrWriteToHoistedVar(inst, MakeOuterProduct(inst));
|
||||
|
||||
case SpvOpFunctionCall:
|
||||
return EmitFunctionCall(inst);
|
||||
|
||||
|
@ -3707,7 +3712,8 @@ void FunctionEmitter::FindValuesNeedingNamedOrHoistedDefinition() {
|
|||
// but only if they are defined in this function as well.
|
||||
for (auto& id_def_info_pair : def_info_) {
|
||||
const auto& inst = id_def_info_pair.second->inst;
|
||||
if (inst.opcode() == SpvOpVectorShuffle) {
|
||||
const auto opcode = inst.opcode();
|
||||
if ((opcode == SpvOpVectorShuffle) || (opcode == SpvOpOuterProduct)) {
|
||||
// We might access the vector operands multiple times. Make sure they
|
||||
// are evaluated only once.
|
||||
for (auto vector_arg : std::array<uint32_t, 2>{0, 1}) {
|
||||
|
@ -4578,6 +4584,52 @@ TypedExpression FunctionEmitter::MakeArrayLength(
|
|||
return {parser_impl_.ConvertType(inst.type_id()), call_expr};
|
||||
}
|
||||
|
||||
TypedExpression FunctionEmitter::MakeOuterProduct(
|
||||
const spvtools::opt::Instruction& inst) {
|
||||
// Synthesize the result.
|
||||
auto col = MakeOperand(inst, 0);
|
||||
auto row = MakeOperand(inst, 1);
|
||||
auto* col_ty = col.type->As<ast::type::Vector>();
|
||||
auto* row_ty = row.type->As<ast::type::Vector>();
|
||||
auto* result_ty =
|
||||
parser_impl_.ConvertType(inst.type_id())->As<ast::type::Matrix>();
|
||||
if (!col_ty || !col_ty || !result_ty || result_ty->type() != col_ty->type() ||
|
||||
result_ty->type() != row_ty->type() ||
|
||||
result_ty->columns() != row_ty->size() ||
|
||||
result_ty->rows() != col_ty->size()) {
|
||||
Fail() << "invalid outer product instruction: bad types "
|
||||
<< inst.PrettyPrint();
|
||||
return {};
|
||||
}
|
||||
|
||||
// Example:
|
||||
// c : vec3 column vector
|
||||
// r : vec2 row vector
|
||||
// OuterProduct c r : mat2x3 (2 columns, 3 rows)
|
||||
// Result:
|
||||
// | c.x * r.x c.x * r.y |
|
||||
// | c.y * r.x c.y * r.y |
|
||||
// | c.z * r.x c.z * r.y |
|
||||
|
||||
ast::ExpressionList result_columns;
|
||||
for (uint32_t icol = 0; icol < result_ty->columns(); icol++) {
|
||||
ast::ExpressionList result_row;
|
||||
auto* row_factor = create<ast::MemberAccessorExpression>(Source{}, row.expr,
|
||||
Swizzle(icol));
|
||||
for (uint32_t irow = 0; irow < result_ty->rows(); irow++) {
|
||||
auto* column_factor = create<ast::MemberAccessorExpression>(
|
||||
Source{}, col.expr, Swizzle(irow));
|
||||
auto* elem = create<ast::BinaryExpression>(
|
||||
Source{}, ast::BinaryOp::kMultiply, row_factor, column_factor);
|
||||
result_row.push_back(elem);
|
||||
}
|
||||
result_columns.push_back(
|
||||
create<ast::TypeConstructorExpression>(Source{}, col_ty, result_row));
|
||||
}
|
||||
return {result_ty, create<ast::TypeConstructorExpression>(Source{}, result_ty,
|
||||
result_columns)};
|
||||
}
|
||||
|
||||
FunctionEmitter::FunctionDeclaration::FunctionDeclaration() = default;
|
||||
FunctionEmitter::FunctionDeclaration::~FunctionDeclaration() = default;
|
||||
|
||||
|
|
|
@ -851,6 +851,11 @@ class FunctionEmitter {
|
|||
/// @returns an expression
|
||||
TypedExpression MakeArrayLength(const spvtools::opt::Instruction& inst);
|
||||
|
||||
/// Generates an expression for a SPIR-V OpOuterProduct instruction.
|
||||
/// @param inst the SPIR-V instruction
|
||||
/// @returns an expression
|
||||
TypedExpression MakeOuterProduct(const spvtools::opt::Instruction& inst);
|
||||
|
||||
/// Emits a texture builtin function call for a SPIR-V instruction that
|
||||
/// accesses an image or sampled image.
|
||||
/// @param inst the SPIR-V instruction
|
||||
|
|
|
@ -43,6 +43,7 @@ std::string CommonTypes() {
|
|||
%int_40 = OpConstant %int 40
|
||||
%float_50 = OpConstant %float 50
|
||||
%float_60 = OpConstant %float 60
|
||||
%float_70 = OpConstant %float 70
|
||||
|
||||
%ptr_uint = OpTypePointer Function %uint
|
||||
%ptr_int = OpTypePointer Function %int
|
||||
|
@ -51,6 +52,7 @@ std::string CommonTypes() {
|
|||
%v2uint = OpTypeVector %uint 2
|
||||
%v2int = OpTypeVector %int 2
|
||||
%v2float = OpTypeVector %float 2
|
||||
%v3float = OpTypeVector %float 3
|
||||
|
||||
%v2uint_10_20 = OpConstantComposite %v2uint %uint_10 %uint_20
|
||||
%v2uint_20_10 = OpConstantComposite %v2uint %uint_20 %uint_10
|
||||
|
@ -58,10 +60,12 @@ std::string CommonTypes() {
|
|||
%v2int_40_30 = OpConstantComposite %v2int %int_40 %int_30
|
||||
%v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60
|
||||
%v2float_60_50 = OpConstantComposite %v2float %float_60 %float_50
|
||||
%v3float_50_60_70 = OpConstantComposite %v2float %float_50 %float_60 %float_70
|
||||
|
||||
%m2v2float = OpTypeMatrix %v2float 2
|
||||
%m2v2float_a = OpConstantComposite %m2v2float %v2float_50_60 %v2float_60_50
|
||||
%m2v2float_b = OpConstantComposite %m2v2float %v2float_60_50 %v2float_50_60
|
||||
%m2v3float = OpTypeMatrix %v3float 2
|
||||
)";
|
||||
}
|
||||
|
||||
|
@ -1099,6 +1103,108 @@ TEST_F(SpvBinaryArithTestBasic, Dot) {
|
|||
<< ToString(p->get_module(), fe.ast_body());
|
||||
}
|
||||
|
||||
TEST_F(SpvBinaryArithTestBasic, OuterProduct) {
|
||||
// OpOuterProduct is expanded to basic operations.
|
||||
// The operands, even if used once, are given their own const definitions.
|
||||
const auto assembly = CommonTypes() + R"(
|
||||
%100 = OpFunction %void None %voidfn
|
||||
%entry = OpLabel
|
||||
%1 = OpFAdd %v3float %v3float_50_60_70 %v3float_50_60_70 ; column vector
|
||||
%2 = OpFAdd %v2float %v2float_60_50 %v2float_50_60 ; row vector
|
||||
%3 = OpOuterProduct %m2v3float %1 %2
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)";
|
||||
auto p = parser(test::Assemble(assembly));
|
||||
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
|
||||
FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
|
||||
EXPECT_TRUE(fe.EmitBody()) << p->error();
|
||||
auto got = ToString(p->get_module(), fe.ast_body());
|
||||
EXPECT_THAT(got, HasSubstr(R"(VariableConst{
|
||||
x_3
|
||||
none
|
||||
__mat_3_2__f32
|
||||
{
|
||||
TypeConstructor[not set]{
|
||||
__mat_3_2__f32
|
||||
TypeConstructor[not set]{
|
||||
__vec_3__f32
|
||||
Binary[not set]{
|
||||
MemberAccessor[not set]{
|
||||
Identifier[not set]{x_2}
|
||||
Identifier[not set]{x}
|
||||
}
|
||||
multiply
|
||||
MemberAccessor[not set]{
|
||||
Identifier[not set]{x_1}
|
||||
Identifier[not set]{x}
|
||||
}
|
||||
}
|
||||
Binary[not set]{
|
||||
MemberAccessor[not set]{
|
||||
Identifier[not set]{x_2}
|
||||
Identifier[not set]{x}
|
||||
}
|
||||
multiply
|
||||
MemberAccessor[not set]{
|
||||
Identifier[not set]{x_1}
|
||||
Identifier[not set]{y}
|
||||
}
|
||||
}
|
||||
Binary[not set]{
|
||||
MemberAccessor[not set]{
|
||||
Identifier[not set]{x_2}
|
||||
Identifier[not set]{x}
|
||||
}
|
||||
multiply
|
||||
MemberAccessor[not set]{
|
||||
Identifier[not set]{x_1}
|
||||
Identifier[not set]{z}
|
||||
}
|
||||
}
|
||||
}
|
||||
TypeConstructor[not set]{
|
||||
__vec_3__f32
|
||||
Binary[not set]{
|
||||
MemberAccessor[not set]{
|
||||
Identifier[not set]{x_2}
|
||||
Identifier[not set]{y}
|
||||
}
|
||||
multiply
|
||||
MemberAccessor[not set]{
|
||||
Identifier[not set]{x_1}
|
||||
Identifier[not set]{x}
|
||||
}
|
||||
}
|
||||
Binary[not set]{
|
||||
MemberAccessor[not set]{
|
||||
Identifier[not set]{x_2}
|
||||
Identifier[not set]{y}
|
||||
}
|
||||
multiply
|
||||
MemberAccessor[not set]{
|
||||
Identifier[not set]{x_1}
|
||||
Identifier[not set]{y}
|
||||
}
|
||||
}
|
||||
Binary[not set]{
|
||||
MemberAccessor[not set]{
|
||||
Identifier[not set]{x_2}
|
||||
Identifier[not set]{y}
|
||||
}
|
||||
multiply
|
||||
MemberAccessor[not set]{
|
||||
Identifier[not set]{x_1}
|
||||
Identifier[not set]{z}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})"))
|
||||
<< got;
|
||||
}
|
||||
|
||||
// TODO(dneto): OpSRem. Missing from WGSL
|
||||
// https://github.com/gpuweb/gpuweb/issues/702
|
||||
|
||||
|
|
Loading…
Reference in New Issue