spirv-reader: expand OuterProduct to primitive ops
Bug: tint:3 Change-Id: Id6de3554d945bc743a484e80b494690c26552079 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/37660 Reviewed-by: dan sinclair <dsinclair@chromium.org> Commit-Queue: David Neto <dneto@google.com> Auto-Submit: David Neto <dneto@google.com>
This commit is contained in:
parent
0a68b365eb
commit
0e17caa361
|
@ -57,6 +57,7 @@
|
||||||
#include "src/ast/type/depth_texture_type.h"
|
#include "src/ast/type/depth_texture_type.h"
|
||||||
#include "src/ast/type/f32_type.h"
|
#include "src/ast/type/f32_type.h"
|
||||||
#include "src/ast/type/i32_type.h"
|
#include "src/ast/type/i32_type.h"
|
||||||
|
#include "src/ast/type/matrix_type.h"
|
||||||
#include "src/ast/type/pointer_type.h"
|
#include "src/ast/type/pointer_type.h"
|
||||||
#include "src/ast/type/storage_texture_type.h"
|
#include "src/ast/type/storage_texture_type.h"
|
||||||
#include "src/ast/type/texture_type.h"
|
#include "src/ast/type/texture_type.h"
|
||||||
|
@ -3013,6 +3014,10 @@ bool FunctionEmitter::EmitStatement(const spvtools::opt::Instruction& inst) {
|
||||||
return EmitConstDefOrWriteToHoistedVar(inst, expr);
|
return EmitConstDefOrWriteToHoistedVar(inst, expr);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case SpvOpOuterProduct:
|
||||||
|
// Synthesize an outer product expression in its own statement.
|
||||||
|
return EmitConstDefOrWriteToHoistedVar(inst, MakeOuterProduct(inst));
|
||||||
|
|
||||||
case SpvOpFunctionCall:
|
case SpvOpFunctionCall:
|
||||||
return EmitFunctionCall(inst);
|
return EmitFunctionCall(inst);
|
||||||
|
|
||||||
|
@ -3707,7 +3712,8 @@ void FunctionEmitter::FindValuesNeedingNamedOrHoistedDefinition() {
|
||||||
// but only if they are defined in this function as well.
|
// but only if they are defined in this function as well.
|
||||||
for (auto& id_def_info_pair : def_info_) {
|
for (auto& id_def_info_pair : def_info_) {
|
||||||
const auto& inst = id_def_info_pair.second->inst;
|
const auto& inst = id_def_info_pair.second->inst;
|
||||||
if (inst.opcode() == SpvOpVectorShuffle) {
|
const auto opcode = inst.opcode();
|
||||||
|
if ((opcode == SpvOpVectorShuffle) || (opcode == SpvOpOuterProduct)) {
|
||||||
// We might access the vector operands multiple times. Make sure they
|
// We might access the vector operands multiple times. Make sure they
|
||||||
// are evaluated only once.
|
// are evaluated only once.
|
||||||
for (auto vector_arg : std::array<uint32_t, 2>{0, 1}) {
|
for (auto vector_arg : std::array<uint32_t, 2>{0, 1}) {
|
||||||
|
@ -4578,6 +4584,52 @@ TypedExpression FunctionEmitter::MakeArrayLength(
|
||||||
return {parser_impl_.ConvertType(inst.type_id()), call_expr};
|
return {parser_impl_.ConvertType(inst.type_id()), call_expr};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TypedExpression FunctionEmitter::MakeOuterProduct(
|
||||||
|
const spvtools::opt::Instruction& inst) {
|
||||||
|
// Synthesize the result.
|
||||||
|
auto col = MakeOperand(inst, 0);
|
||||||
|
auto row = MakeOperand(inst, 1);
|
||||||
|
auto* col_ty = col.type->As<ast::type::Vector>();
|
||||||
|
auto* row_ty = row.type->As<ast::type::Vector>();
|
||||||
|
auto* result_ty =
|
||||||
|
parser_impl_.ConvertType(inst.type_id())->As<ast::type::Matrix>();
|
||||||
|
if (!col_ty || !col_ty || !result_ty || result_ty->type() != col_ty->type() ||
|
||||||
|
result_ty->type() != row_ty->type() ||
|
||||||
|
result_ty->columns() != row_ty->size() ||
|
||||||
|
result_ty->rows() != col_ty->size()) {
|
||||||
|
Fail() << "invalid outer product instruction: bad types "
|
||||||
|
<< inst.PrettyPrint();
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Example:
|
||||||
|
// c : vec3 column vector
|
||||||
|
// r : vec2 row vector
|
||||||
|
// OuterProduct c r : mat2x3 (2 columns, 3 rows)
|
||||||
|
// Result:
|
||||||
|
// | c.x * r.x c.x * r.y |
|
||||||
|
// | c.y * r.x c.y * r.y |
|
||||||
|
// | c.z * r.x c.z * r.y |
|
||||||
|
|
||||||
|
ast::ExpressionList result_columns;
|
||||||
|
for (uint32_t icol = 0; icol < result_ty->columns(); icol++) {
|
||||||
|
ast::ExpressionList result_row;
|
||||||
|
auto* row_factor = create<ast::MemberAccessorExpression>(Source{}, row.expr,
|
||||||
|
Swizzle(icol));
|
||||||
|
for (uint32_t irow = 0; irow < result_ty->rows(); irow++) {
|
||||||
|
auto* column_factor = create<ast::MemberAccessorExpression>(
|
||||||
|
Source{}, col.expr, Swizzle(irow));
|
||||||
|
auto* elem = create<ast::BinaryExpression>(
|
||||||
|
Source{}, ast::BinaryOp::kMultiply, row_factor, column_factor);
|
||||||
|
result_row.push_back(elem);
|
||||||
|
}
|
||||||
|
result_columns.push_back(
|
||||||
|
create<ast::TypeConstructorExpression>(Source{}, col_ty, result_row));
|
||||||
|
}
|
||||||
|
return {result_ty, create<ast::TypeConstructorExpression>(Source{}, result_ty,
|
||||||
|
result_columns)};
|
||||||
|
}
|
||||||
|
|
||||||
FunctionEmitter::FunctionDeclaration::FunctionDeclaration() = default;
|
FunctionEmitter::FunctionDeclaration::FunctionDeclaration() = default;
|
||||||
FunctionEmitter::FunctionDeclaration::~FunctionDeclaration() = default;
|
FunctionEmitter::FunctionDeclaration::~FunctionDeclaration() = default;
|
||||||
|
|
||||||
|
|
|
@ -851,6 +851,11 @@ class FunctionEmitter {
|
||||||
/// @returns an expression
|
/// @returns an expression
|
||||||
TypedExpression MakeArrayLength(const spvtools::opt::Instruction& inst);
|
TypedExpression MakeArrayLength(const spvtools::opt::Instruction& inst);
|
||||||
|
|
||||||
|
/// Generates an expression for a SPIR-V OpOuterProduct instruction.
|
||||||
|
/// @param inst the SPIR-V instruction
|
||||||
|
/// @returns an expression
|
||||||
|
TypedExpression MakeOuterProduct(const spvtools::opt::Instruction& inst);
|
||||||
|
|
||||||
/// Emits a texture builtin function call for a SPIR-V instruction that
|
/// Emits a texture builtin function call for a SPIR-V instruction that
|
||||||
/// accesses an image or sampled image.
|
/// accesses an image or sampled image.
|
||||||
/// @param inst the SPIR-V instruction
|
/// @param inst the SPIR-V instruction
|
||||||
|
|
|
@ -43,6 +43,7 @@ std::string CommonTypes() {
|
||||||
%int_40 = OpConstant %int 40
|
%int_40 = OpConstant %int 40
|
||||||
%float_50 = OpConstant %float 50
|
%float_50 = OpConstant %float 50
|
||||||
%float_60 = OpConstant %float 60
|
%float_60 = OpConstant %float 60
|
||||||
|
%float_70 = OpConstant %float 70
|
||||||
|
|
||||||
%ptr_uint = OpTypePointer Function %uint
|
%ptr_uint = OpTypePointer Function %uint
|
||||||
%ptr_int = OpTypePointer Function %int
|
%ptr_int = OpTypePointer Function %int
|
||||||
|
@ -51,6 +52,7 @@ std::string CommonTypes() {
|
||||||
%v2uint = OpTypeVector %uint 2
|
%v2uint = OpTypeVector %uint 2
|
||||||
%v2int = OpTypeVector %int 2
|
%v2int = OpTypeVector %int 2
|
||||||
%v2float = OpTypeVector %float 2
|
%v2float = OpTypeVector %float 2
|
||||||
|
%v3float = OpTypeVector %float 3
|
||||||
|
|
||||||
%v2uint_10_20 = OpConstantComposite %v2uint %uint_10 %uint_20
|
%v2uint_10_20 = OpConstantComposite %v2uint %uint_10 %uint_20
|
||||||
%v2uint_20_10 = OpConstantComposite %v2uint %uint_20 %uint_10
|
%v2uint_20_10 = OpConstantComposite %v2uint %uint_20 %uint_10
|
||||||
|
@ -58,10 +60,12 @@ std::string CommonTypes() {
|
||||||
%v2int_40_30 = OpConstantComposite %v2int %int_40 %int_30
|
%v2int_40_30 = OpConstantComposite %v2int %int_40 %int_30
|
||||||
%v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60
|
%v2float_50_60 = OpConstantComposite %v2float %float_50 %float_60
|
||||||
%v2float_60_50 = OpConstantComposite %v2float %float_60 %float_50
|
%v2float_60_50 = OpConstantComposite %v2float %float_60 %float_50
|
||||||
|
%v3float_50_60_70 = OpConstantComposite %v2float %float_50 %float_60 %float_70
|
||||||
|
|
||||||
%m2v2float = OpTypeMatrix %v2float 2
|
%m2v2float = OpTypeMatrix %v2float 2
|
||||||
%m2v2float_a = OpConstantComposite %m2v2float %v2float_50_60 %v2float_60_50
|
%m2v2float_a = OpConstantComposite %m2v2float %v2float_50_60 %v2float_60_50
|
||||||
%m2v2float_b = OpConstantComposite %m2v2float %v2float_60_50 %v2float_50_60
|
%m2v2float_b = OpConstantComposite %m2v2float %v2float_60_50 %v2float_50_60
|
||||||
|
%m2v3float = OpTypeMatrix %v3float 2
|
||||||
)";
|
)";
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1099,6 +1103,108 @@ TEST_F(SpvBinaryArithTestBasic, Dot) {
|
||||||
<< ToString(p->get_module(), fe.ast_body());
|
<< ToString(p->get_module(), fe.ast_body());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(SpvBinaryArithTestBasic, OuterProduct) {
|
||||||
|
// OpOuterProduct is expanded to basic operations.
|
||||||
|
// The operands, even if used once, are given their own const definitions.
|
||||||
|
const auto assembly = CommonTypes() + R"(
|
||||||
|
%100 = OpFunction %void None %voidfn
|
||||||
|
%entry = OpLabel
|
||||||
|
%1 = OpFAdd %v3float %v3float_50_60_70 %v3float_50_60_70 ; column vector
|
||||||
|
%2 = OpFAdd %v2float %v2float_60_50 %v2float_50_60 ; row vector
|
||||||
|
%3 = OpOuterProduct %m2v3float %1 %2
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
||||||
|
)";
|
||||||
|
auto p = parser(test::Assemble(assembly));
|
||||||
|
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()) << assembly;
|
||||||
|
FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
|
||||||
|
EXPECT_TRUE(fe.EmitBody()) << p->error();
|
||||||
|
auto got = ToString(p->get_module(), fe.ast_body());
|
||||||
|
EXPECT_THAT(got, HasSubstr(R"(VariableConst{
|
||||||
|
x_3
|
||||||
|
none
|
||||||
|
__mat_3_2__f32
|
||||||
|
{
|
||||||
|
TypeConstructor[not set]{
|
||||||
|
__mat_3_2__f32
|
||||||
|
TypeConstructor[not set]{
|
||||||
|
__vec_3__f32
|
||||||
|
Binary[not set]{
|
||||||
|
MemberAccessor[not set]{
|
||||||
|
Identifier[not set]{x_2}
|
||||||
|
Identifier[not set]{x}
|
||||||
|
}
|
||||||
|
multiply
|
||||||
|
MemberAccessor[not set]{
|
||||||
|
Identifier[not set]{x_1}
|
||||||
|
Identifier[not set]{x}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Binary[not set]{
|
||||||
|
MemberAccessor[not set]{
|
||||||
|
Identifier[not set]{x_2}
|
||||||
|
Identifier[not set]{x}
|
||||||
|
}
|
||||||
|
multiply
|
||||||
|
MemberAccessor[not set]{
|
||||||
|
Identifier[not set]{x_1}
|
||||||
|
Identifier[not set]{y}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Binary[not set]{
|
||||||
|
MemberAccessor[not set]{
|
||||||
|
Identifier[not set]{x_2}
|
||||||
|
Identifier[not set]{x}
|
||||||
|
}
|
||||||
|
multiply
|
||||||
|
MemberAccessor[not set]{
|
||||||
|
Identifier[not set]{x_1}
|
||||||
|
Identifier[not set]{z}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TypeConstructor[not set]{
|
||||||
|
__vec_3__f32
|
||||||
|
Binary[not set]{
|
||||||
|
MemberAccessor[not set]{
|
||||||
|
Identifier[not set]{x_2}
|
||||||
|
Identifier[not set]{y}
|
||||||
|
}
|
||||||
|
multiply
|
||||||
|
MemberAccessor[not set]{
|
||||||
|
Identifier[not set]{x_1}
|
||||||
|
Identifier[not set]{x}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Binary[not set]{
|
||||||
|
MemberAccessor[not set]{
|
||||||
|
Identifier[not set]{x_2}
|
||||||
|
Identifier[not set]{y}
|
||||||
|
}
|
||||||
|
multiply
|
||||||
|
MemberAccessor[not set]{
|
||||||
|
Identifier[not set]{x_1}
|
||||||
|
Identifier[not set]{y}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Binary[not set]{
|
||||||
|
MemberAccessor[not set]{
|
||||||
|
Identifier[not set]{x_2}
|
||||||
|
Identifier[not set]{y}
|
||||||
|
}
|
||||||
|
multiply
|
||||||
|
MemberAccessor[not set]{
|
||||||
|
Identifier[not set]{x_1}
|
||||||
|
Identifier[not set]{z}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})"))
|
||||||
|
<< got;
|
||||||
|
}
|
||||||
|
|
||||||
// TODO(dneto): OpSRem. Missing from WGSL
|
// TODO(dneto): OpSRem. Missing from WGSL
|
||||||
// https://github.com/gpuweb/gpuweb/issues/702
|
// https://github.com/gpuweb/gpuweb/issues/702
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue