Implement addition and subtraction of float matrices
Bug: tint:316 Change-Id: I3a1082c41c47daacf0220d029cb2a5f118684959 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/52580 Commit-Queue: Antonio Maiorano <amaiorano@google.com> Commit-Queue: David Neto <dneto@google.com> Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
parent
34c58a932c
commit
c91f8f2822
|
@ -2124,13 +2124,20 @@ bool Resolver::Binary(ast::BinaryExpression* expr) {
|
|||
}
|
||||
|
||||
// Matrix arithmetic
|
||||
// TODO(amaiorano): matrix-matrix addition and subtraction
|
||||
auto* lhs_mat = lhs_type->As<Matrix>();
|
||||
auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
|
||||
auto* rhs_mat = rhs_type->As<Matrix>();
|
||||
auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
|
||||
// Addition and subtraction of float matrices
|
||||
if ((expr->IsAdd() || expr->IsSubtract()) && lhs_mat_elem_type &&
|
||||
lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type &&
|
||||
rhs_mat_elem_type->Is<F32>() &&
|
||||
(lhs_mat->columns() == rhs_mat->columns()) &&
|
||||
(lhs_mat->rows() == rhs_mat->rows())) {
|
||||
SetType(expr, rhs_type);
|
||||
return true;
|
||||
}
|
||||
if (expr->IsMultiply()) {
|
||||
auto* lhs_mat = lhs_type->As<Matrix>();
|
||||
auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
|
||||
auto* rhs_mat = rhs_type->As<Matrix>();
|
||||
auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
|
||||
|
||||
// Multiplication of a matrix and a scalar
|
||||
if (lhs_type->Is<F32>() && rhs_mat_elem_type &&
|
||||
rhs_mat_elem_type->Is<F32>()) {
|
||||
|
|
|
@ -1238,9 +1238,12 @@ static constexpr ast::BinaryOp all_ops[] = {
|
|||
};
|
||||
|
||||
static constexpr create_ast_type_func_ptr all_create_type_funcs[] = {
|
||||
ast_bool, ast_u32, ast_i32, ast_f32,
|
||||
ast_vec3<bool>, ast_vec3<i32>, ast_vec3<u32>, ast_vec3<f32>,
|
||||
ast_mat3x3<i32>, ast_mat3x3<u32>, ast_mat3x3<f32>};
|
||||
ast_bool, ast_u32, ast_i32, ast_f32,
|
||||
ast_vec3<bool>, ast_vec3<i32>, ast_vec3<u32>, ast_vec3<f32>,
|
||||
ast_mat3x3<i32>, ast_mat3x3<u32>, ast_mat3x3<f32>, //
|
||||
ast_mat2x3<i32>, ast_mat2x3<u32>, ast_mat2x3<f32>, //
|
||||
ast_mat3x2<i32>, ast_mat3x2<u32>, ast_mat3x2<f32> //
|
||||
};
|
||||
|
||||
// A list of all valid test cases for 'lhs op rhs', except that for vecN and
|
||||
// matNxN, we only test N=3.
|
||||
|
@ -1338,14 +1341,43 @@ static constexpr Params all_valid_cases[] = {
|
|||
// Params{Op::kModulo, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
|
||||
|
||||
// Matrix arithmetic
|
||||
Params{Op::kMultiply, ast_mat2x3<f32>, ast_f32, sem_mat2x3<sem_f32>},
|
||||
Params{Op::kMultiply, ast_mat3x2<f32>, ast_f32, sem_mat3x2<sem_f32>},
|
||||
Params{Op::kMultiply, ast_mat3x3<f32>, ast_f32, sem_mat3x3<sem_f32>},
|
||||
|
||||
Params{Op::kMultiply, ast_f32, ast_mat2x3<f32>, sem_mat2x3<sem_f32>},
|
||||
Params{Op::kMultiply, ast_f32, ast_mat3x2<f32>, sem_mat3x2<sem_f32>},
|
||||
Params{Op::kMultiply, ast_f32, ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
|
||||
|
||||
Params{Op::kMultiply, ast_vec3<f32>, ast_mat2x3<f32>, sem_vec2<sem_f32>},
|
||||
Params{Op::kMultiply, ast_vec2<f32>, ast_mat3x2<f32>, sem_vec3<sem_f32>},
|
||||
Params{Op::kMultiply, ast_vec3<f32>, ast_mat3x3<f32>, sem_vec3<sem_f32>},
|
||||
|
||||
Params{Op::kMultiply, ast_mat3x2<f32>, ast_vec3<f32>, sem_vec2<sem_f32>},
|
||||
Params{Op::kMultiply, ast_mat2x3<f32>, ast_vec2<f32>, sem_vec3<sem_f32>},
|
||||
Params{Op::kMultiply, ast_mat3x3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
|
||||
// TODO(amaiorano): add mat+mat and mat-mat
|
||||
|
||||
Params{Op::kMultiply, ast_mat2x3<f32>, ast_mat3x2<f32>,
|
||||
sem_mat3x3<sem_f32>},
|
||||
Params{Op::kMultiply, ast_mat3x2<f32>, ast_mat2x3<f32>,
|
||||
sem_mat2x2<sem_f32>},
|
||||
Params{Op::kMultiply, ast_mat3x2<f32>, ast_mat3x3<f32>,
|
||||
sem_mat3x2<sem_f32>},
|
||||
Params{Op::kMultiply, ast_mat3x3<f32>, ast_mat3x3<f32>,
|
||||
sem_mat3x3<sem_f32>},
|
||||
Params{Op::kMultiply, ast_mat3x3<f32>, ast_mat2x3<f32>,
|
||||
sem_mat2x3<sem_f32>},
|
||||
|
||||
Params{Op::kAdd, ast_mat2x3<f32>, ast_mat2x3<f32>, sem_mat2x3<sem_f32>},
|
||||
Params{Op::kAdd, ast_mat3x2<f32>, ast_mat3x2<f32>, sem_mat3x2<sem_f32>},
|
||||
Params{Op::kAdd, ast_mat3x3<f32>, ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
|
||||
|
||||
Params{Op::kSubtract, ast_mat2x3<f32>, ast_mat2x3<f32>,
|
||||
sem_mat2x3<sem_f32>},
|
||||
Params{Op::kSubtract, ast_mat3x2<f32>, ast_mat3x2<f32>,
|
||||
sem_mat3x2<sem_f32>},
|
||||
Params{Op::kSubtract, ast_mat3x3<f32>, ast_mat3x3<f32>,
|
||||
sem_mat3x3<sem_f32>},
|
||||
|
||||
// Comparison expressions
|
||||
// https://gpuweb.github.io/gpuweb/wgsl.html#comparison-expr
|
||||
|
|
|
@ -176,6 +176,26 @@ ast::Type* ast_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
|
|||
return ty.mat2x2(create_type(ty));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ast::Type* ast_mat2x3(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.mat2x3<T>();
|
||||
}
|
||||
|
||||
template <create_ast_type_func_ptr create_type>
|
||||
ast::Type* ast_mat2x3(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.mat2x3(create_type(ty));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ast::Type* ast_mat3x2(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.mat3x2<T>();
|
||||
}
|
||||
|
||||
template <create_ast_type_func_ptr create_type>
|
||||
ast::Type* ast_mat3x2(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.mat3x2(create_type(ty));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
ast::Type* ast_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.mat3x3<T>();
|
||||
|
@ -249,6 +269,18 @@ sem::Type* sem_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
|
|||
return ty.builder->create<sem::Matrix>(column_type, 2u);
|
||||
}
|
||||
|
||||
template <create_sem_type_func_ptr create_type>
|
||||
sem::Type* sem_mat2x3(const ProgramBuilder::TypesBuilder& ty) {
|
||||
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 3u);
|
||||
return ty.builder->create<sem::Matrix>(column_type, 2u);
|
||||
}
|
||||
|
||||
template <create_sem_type_func_ptr create_type>
|
||||
sem::Type* sem_mat3x2(const ProgramBuilder::TypesBuilder& ty) {
|
||||
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 2u);
|
||||
return ty.builder->create<sem::Matrix>(column_type, 3u);
|
||||
}
|
||||
|
||||
template <create_sem_type_func_ptr create_type>
|
||||
sem::Type* sem_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
|
||||
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 3u);
|
||||
|
|
|
@ -1748,6 +1748,67 @@ uint32_t Builder::GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type) {
|
|||
return splat_result.to_i();
|
||||
}
|
||||
|
||||
uint32_t Builder::GenerateMatrixAddOrSub(uint32_t lhs_id,
|
||||
uint32_t rhs_id,
|
||||
const sem::Matrix* type,
|
||||
spv::Op op) {
|
||||
// Example addition of two matrices:
|
||||
// %31 = OpLoad %mat3v4float %m34
|
||||
// %32 = OpLoad %mat3v4float %m34
|
||||
// %33 = OpCompositeExtract %v4float %31 0
|
||||
// %34 = OpCompositeExtract %v4float %32 0
|
||||
// %35 = OpFAdd %v4float %33 %34
|
||||
// %36 = OpCompositeExtract %v4float %31 1
|
||||
// %37 = OpCompositeExtract %v4float %32 1
|
||||
// %38 = OpFAdd %v4float %36 %37
|
||||
// %39 = OpCompositeExtract %v4float %31 2
|
||||
// %40 = OpCompositeExtract %v4float %32 2
|
||||
// %41 = OpFAdd %v4float %39 %40
|
||||
// %42 = OpCompositeConstruct %mat3v4float %35 %38 %41
|
||||
|
||||
auto* column_type = builder_.create<sem::Vector>(type->type(), type->rows());
|
||||
auto column_type_id = GenerateTypeIfNeeded(column_type);
|
||||
|
||||
OperandList ops;
|
||||
|
||||
for (uint32_t i = 0; i < type->columns(); ++i) {
|
||||
// Extract column `i` from lhs mat
|
||||
auto lhs_column_id = result_op();
|
||||
if (!push_function_inst(spv::Op::OpCompositeExtract,
|
||||
{Operand::Int(column_type_id), lhs_column_id,
|
||||
Operand::Int(lhs_id), Operand::Int(i)})) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Extract column `i` from rhs mat
|
||||
auto rhs_column_id = result_op();
|
||||
if (!push_function_inst(spv::Op::OpCompositeExtract,
|
||||
{Operand::Int(column_type_id), rhs_column_id,
|
||||
Operand::Int(rhs_id), Operand::Int(i)})) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Add or subtract the two columns
|
||||
auto result = result_op();
|
||||
if (!push_function_inst(op, {Operand::Int(column_type_id), result,
|
||||
lhs_column_id, rhs_column_id})) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
ops.push_back(result);
|
||||
}
|
||||
|
||||
// Create the result matrix from the added/subtracted column vectors
|
||||
auto result_mat_id = result_op();
|
||||
ops.insert(ops.begin(), result_mat_id);
|
||||
ops.insert(ops.begin(), Operand::Int(GenerateTypeIfNeeded(type)));
|
||||
if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return result_mat_id.to_i();
|
||||
}
|
||||
|
||||
uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
|
||||
// There is special logic for short circuiting operators.
|
||||
if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
|
||||
|
@ -1779,6 +1840,24 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
|
|||
auto* lhs_type = TypeOf(expr->lhs())->UnwrapRef();
|
||||
auto* rhs_type = TypeOf(expr->rhs())->UnwrapRef();
|
||||
|
||||
// Handle matrix-matrix addition and subtraction
|
||||
if ((expr->IsAdd() || expr->IsSubtract()) && lhs_type->is_float_matrix() &&
|
||||
rhs_type->is_float_matrix()) {
|
||||
auto* lhs_mat = lhs_type->As<sem::Matrix>();
|
||||
auto* rhs_mat = rhs_type->As<sem::Matrix>();
|
||||
|
||||
// This should already have been validated by resolver
|
||||
if (lhs_mat->rows() != rhs_mat->rows() ||
|
||||
lhs_mat->columns() != rhs_mat->columns()) {
|
||||
error_ = "matrices must have same dimensionality for add or subtract";
|
||||
return 0;
|
||||
}
|
||||
|
||||
return GenerateMatrixAddOrSub(
|
||||
lhs_id, rhs_id, lhs_mat,
|
||||
expr->IsAdd() ? spv::Op::OpFAdd : spv::Op::OpFSub);
|
||||
}
|
||||
|
||||
// For vector-scalar arithmetic operations, splat scalar into a vector. We
|
||||
// skip this for multiply as we can use OpVectorTimesScalar.
|
||||
const bool is_float_scalar_vector_multiply =
|
||||
|
|
|
@ -482,6 +482,17 @@ class Builder {
|
|||
/// @returns id of the new vector
|
||||
uint32_t GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type);
|
||||
|
||||
/// Generates instructions to add or subtract two matrices
|
||||
/// @param lhs_id id of multiplicand
|
||||
/// @param rhs_id id of multiplier
|
||||
/// @param type type of both matrices and of result
|
||||
/// @param op one of `spv::Op::OpFAdd` or `spv::Op::OpFSub`
|
||||
/// @returns id of the result matrix
|
||||
uint32_t GenerateMatrixAddOrSub(uint32_t lhs_id,
|
||||
uint32_t rhs_id,
|
||||
const sem::Matrix* type,
|
||||
spv::Op op);
|
||||
|
||||
/// Converts AST image format to SPIR-V and pushes an appropriate capability.
|
||||
/// @param format AST image format type
|
||||
/// @returns SPIR-V image format type
|
||||
|
|
|
@ -863,8 +863,10 @@ OpBranch %9
|
|||
)");
|
||||
}
|
||||
|
||||
namespace BinaryArithVectorScalar {
|
||||
|
||||
enum class Type { f32, i32, u32 };
|
||||
ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) {
|
||||
static ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) {
|
||||
switch (type) {
|
||||
case Type::f32:
|
||||
return builder->vec3<ProgramBuilder::f32>(1.f, 1.f, 1.f);
|
||||
|
@ -875,7 +877,7 @@ ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) {
|
|||
}
|
||||
return nullptr;
|
||||
}
|
||||
ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type) {
|
||||
static ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type) {
|
||||
switch (type) {
|
||||
case Type::f32:
|
||||
return builder->Expr(1.f);
|
||||
|
@ -886,7 +888,7 @@ ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type) {
|
|||
}
|
||||
return nullptr;
|
||||
}
|
||||
std::string OpTypeDecl(Type type) {
|
||||
static std::string OpTypeDecl(Type type) {
|
||||
switch (type) {
|
||||
case Type::f32:
|
||||
return "OpTypeFloat 32";
|
||||
|
@ -904,8 +906,8 @@ struct Param {
|
|||
std::string name;
|
||||
};
|
||||
|
||||
using MixedBinaryArithTest = TestParamHelper<Param>;
|
||||
TEST_P(MixedBinaryArithTest, VectorScalar) {
|
||||
using BinaryArithVectorScalarTest = TestParamHelper<Param>;
|
||||
TEST_P(BinaryArithVectorScalarTest, VectorScalar) {
|
||||
auto& param = GetParam();
|
||||
|
||||
ast::Expression* lhs = MakeVectorExpr(this, param.type);
|
||||
|
@ -943,7 +945,7 @@ OpFunctionEnd
|
|||
|
||||
Validate(b);
|
||||
}
|
||||
TEST_P(MixedBinaryArithTest, ScalarVector) {
|
||||
TEST_P(BinaryArithVectorScalarTest, ScalarVector) {
|
||||
auto& param = GetParam();
|
||||
|
||||
ast::Expression* lhs = MakeScalarExpr(this, param.type);
|
||||
|
@ -983,7 +985,7 @@ OpFunctionEnd
|
|||
}
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
BuilderTest,
|
||||
MixedBinaryArithTest,
|
||||
BinaryArithVectorScalarTest,
|
||||
testing::Values(Param{Type::f32, ast::BinaryOp::kAdd, "OpFAdd"},
|
||||
Param{Type::f32, ast::BinaryOp::kDivide, "OpFDiv"},
|
||||
// NOTE: Modulo not allowed on mixed float scalar-vector
|
||||
|
@ -1005,8 +1007,8 @@ INSTANTIATE_TEST_SUITE_P(
|
|||
Param{Type::u32, ast::BinaryOp::kMultiply, "OpIMul"},
|
||||
Param{Type::u32, ast::BinaryOp::kSubtract, "OpISub"}));
|
||||
|
||||
using MixedBinaryArithMultiplyTest = TestParamHelper<Param>;
|
||||
TEST_P(MixedBinaryArithMultiplyTest, VectorScalar) {
|
||||
using BinaryArithVectorScalarMultiplyTest = TestParamHelper<Param>;
|
||||
TEST_P(BinaryArithVectorScalarMultiplyTest, VectorScalar) {
|
||||
auto& param = GetParam();
|
||||
|
||||
ast::Expression* lhs = MakeVectorExpr(this, param.type);
|
||||
|
@ -1040,7 +1042,7 @@ OpFunctionEnd
|
|||
|
||||
Validate(b);
|
||||
}
|
||||
TEST_P(MixedBinaryArithMultiplyTest, ScalarVector) {
|
||||
TEST_P(BinaryArithVectorScalarMultiplyTest, ScalarVector) {
|
||||
auto& param = GetParam();
|
||||
|
||||
ast::Expression* lhs = MakeScalarExpr(this, param.type);
|
||||
|
@ -1075,10 +1077,113 @@ OpFunctionEnd
|
|||
Validate(b);
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(BuilderTest,
|
||||
MixedBinaryArithMultiplyTest,
|
||||
BinaryArithVectorScalarMultiplyTest,
|
||||
testing::Values(Param{
|
||||
Type::f32, ast::BinaryOp::kMultiply, "OpFMul"}));
|
||||
|
||||
} // namespace BinaryArithVectorScalar
|
||||
|
||||
namespace BinaryArithMatrixMatrix {
|
||||
|
||||
struct Param {
|
||||
ast::BinaryOp op;
|
||||
std::string name;
|
||||
};
|
||||
|
||||
using BinaryArithMatrixMatrix = TestParamHelper<Param>;
|
||||
TEST_P(BinaryArithMatrixMatrix, AddOrSubtract) {
|
||||
auto& param = GetParam();
|
||||
|
||||
ast::Expression* lhs = mat3x4<f32>();
|
||||
ast::Expression* rhs = mat3x4<f32>();
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
|
||||
|
||||
WrapInFunction(expr);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
ASSERT_TRUE(b.Build()) << b.error();
|
||||
|
||||
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %3 "test_function"
|
||||
OpExecutionMode %3 LocalSize 1 1 1
|
||||
OpName %3 "test_function"
|
||||
%2 = OpTypeVoid
|
||||
%1 = OpTypeFunction %2
|
||||
%7 = OpTypeFloat 32
|
||||
%6 = OpTypeVector %7 4
|
||||
%5 = OpTypeMatrix %6 3
|
||||
%8 = OpConstantNull %5
|
||||
%3 = OpFunction %2 None %1
|
||||
%4 = OpLabel
|
||||
%10 = OpCompositeExtract %6 %8 0
|
||||
%11 = OpCompositeExtract %6 %8 0
|
||||
%12 = )" + param.name + R"( %6 %10 %11
|
||||
%13 = OpCompositeExtract %6 %8 1
|
||||
%14 = OpCompositeExtract %6 %8 1
|
||||
%15 = )" + param.name + R"( %6 %13 %14
|
||||
%16 = OpCompositeExtract %6 %8 2
|
||||
%17 = OpCompositeExtract %6 %8 2
|
||||
%18 = )" + param.name + R"( %6 %16 %17
|
||||
%19 = OpCompositeConstruct %5 %12 %15 %18
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)");
|
||||
|
||||
Validate(b);
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
BuilderTest,
|
||||
BinaryArithMatrixMatrix,
|
||||
testing::Values(Param{ast::BinaryOp::kAdd, "OpFAdd"},
|
||||
Param{ast::BinaryOp::kSubtract, "OpFSub"}));
|
||||
|
||||
using BinaryArithMatrixMatrixMultiply = TestParamHelper<Param>;
|
||||
TEST_P(BinaryArithMatrixMatrixMultiply, Multiply) {
|
||||
auto& param = GetParam();
|
||||
|
||||
ast::Expression* lhs = mat3x4<f32>();
|
||||
ast::Expression* rhs = mat4x3<f32>();
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
|
||||
|
||||
WrapInFunction(expr);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
ASSERT_TRUE(b.Build()) << b.error();
|
||||
|
||||
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %3 "test_function"
|
||||
OpExecutionMode %3 LocalSize 1 1 1
|
||||
OpName %3 "test_function"
|
||||
%2 = OpTypeVoid
|
||||
%1 = OpTypeFunction %2
|
||||
%7 = OpTypeFloat 32
|
||||
%6 = OpTypeVector %7 4
|
||||
%5 = OpTypeMatrix %6 3
|
||||
%8 = OpConstantNull %5
|
||||
%10 = OpTypeVector %7 3
|
||||
%9 = OpTypeMatrix %10 4
|
||||
%11 = OpConstantNull %9
|
||||
%13 = OpTypeMatrix %6 4
|
||||
%3 = OpFunction %2 None %1
|
||||
%4 = OpLabel
|
||||
%12 = OpMatrixTimesMatrix %13 %8 %11
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)");
|
||||
|
||||
Validate(b);
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P( //
|
||||
BuilderTest,
|
||||
BinaryArithMatrixMatrixMultiply,
|
||||
testing::Values(Param{ast::BinaryOp::kMultiply, "OpFMul"}));
|
||||
|
||||
} // namespace BinaryArithMatrixMatrix
|
||||
|
||||
} // namespace
|
||||
} // namespace spirv
|
||||
} // namespace writer
|
||||
|
|
|
@ -64,6 +64,19 @@ fn scalar_vector_u32() {
|
|||
r = s % v;
|
||||
}
|
||||
|
||||
fn matrix_matrix_f32() {
|
||||
var m34 : mat3x4<f32>;
|
||||
var m43 : mat4x3<f32>;
|
||||
var m33 : mat3x3<f32>;
|
||||
var m44 : mat4x4<f32>;
|
||||
|
||||
m34 = m34 + m34;
|
||||
m34 = m34 - m34;
|
||||
|
||||
m33 = m43 * m34;
|
||||
m44 = m34 * m43;
|
||||
}
|
||||
|
||||
[[stage(fragment)]]
|
||||
fn main() -> [[location(0)]] vec4<f32> {
|
||||
return vec4<f32>(0.0,0.0,0.0,0.0);
|
||||
|
|
|
@ -66,6 +66,17 @@ void scalar_vector_u32() {
|
|||
r = (s % v);
|
||||
}
|
||||
|
||||
void matrix_matrix_f32() {
|
||||
float3x4 m34 = float3x4(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||
float4x3 m43 = float4x3(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||
float3x3 m33 = float3x3(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||
float4x4 m44 = float4x4(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||
m34 = (m34 + m34);
|
||||
m34 = (m34 - m34);
|
||||
m33 = mul(m34, m43);
|
||||
m44 = mul(m43, m34);
|
||||
}
|
||||
|
||||
tint_symbol main() {
|
||||
const tint_symbol tint_symbol_1 = {float4(0.0f, 0.0f, 0.0f, 0.0f)};
|
||||
return tint_symbol_1;
|
||||
|
|
|
@ -69,6 +69,17 @@ void scalar_vector_u32() {
|
|||
r = (s % v);
|
||||
}
|
||||
|
||||
void matrix_matrix_f32() {
|
||||
float3x4 m34 = float3x4(0.0f);
|
||||
float4x3 m43 = float4x3(0.0f);
|
||||
float3x3 m33 = float3x3(0.0f);
|
||||
float4x4 m44 = float4x4(0.0f);
|
||||
m34 = (m34 + m34);
|
||||
m34 = (m34 - m34);
|
||||
m33 = (m43 * m34);
|
||||
m44 = (m34 * m43);
|
||||
}
|
||||
|
||||
fragment tint_symbol_1 tint_symbol() {
|
||||
return {float4(0.0f, 0.0f, 0.0f, 0.0f)};
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
; SPIR-V
|
||||
; Version: 1.3
|
||||
; Generator: Google Tint Compiler; 0
|
||||
; Bound: 200
|
||||
; Bound: 250
|
||||
; Schema: 0
|
||||
OpCapability Shader
|
||||
OpMemoryModel Logical GLSL450
|
||||
|
@ -32,6 +32,11 @@
|
|||
OpName %v_4 "v"
|
||||
OpName %s_4 "s"
|
||||
OpName %r_4 "r"
|
||||
OpName %matrix_matrix_f32 "matrix_matrix_f32"
|
||||
OpName %m34 "m34"
|
||||
OpName %m43 "m43"
|
||||
OpName %m33 "m33"
|
||||
OpName %m44 "m44"
|
||||
OpName %tint_symbol_2 "tint_symbol_2"
|
||||
OpName %tint_symbol "tint_symbol"
|
||||
OpName %main "main"
|
||||
|
@ -60,9 +65,21 @@
|
|||
%78 = OpConstantNull %v3uint
|
||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||
%81 = OpConstantNull %uint
|
||||
%191 = OpTypeFunction %void %v4float
|
||||
%mat3v4float = OpTypeMatrix %v4float 3
|
||||
%_ptr_Function_mat3v4float = OpTypePointer Function %mat3v4float
|
||||
%196 = OpConstantNull %mat3v4float
|
||||
%mat4v3float = OpTypeMatrix %v3float 4
|
||||
%_ptr_Function_mat4v3float = OpTypePointer Function %mat4v3float
|
||||
%200 = OpConstantNull %mat4v3float
|
||||
%mat3v3float = OpTypeMatrix %v3float 3
|
||||
%_ptr_Function_mat3v3float = OpTypePointer Function %mat3v3float
|
||||
%204 = OpConstantNull %mat3v3float
|
||||
%mat4v4float = OpTypeMatrix %v4float 4
|
||||
%_ptr_Function_mat4v4float = OpTypePointer Function %mat4v4float
|
||||
%208 = OpConstantNull %mat4v4float
|
||||
%241 = OpTypeFunction %void %v4float
|
||||
%float_0 = OpConstant %float 0
|
||||
%199 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
|
||||
%249 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
|
||||
%vector_scalar_f32 = OpFunction %void None %6
|
||||
%9 = OpLabel
|
||||
%v = OpVariable %_ptr_Function_v3float Function %13
|
||||
|
@ -269,14 +286,56 @@
|
|||
OpStore %r_4 %188
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%tint_symbol_2 = OpFunction %void None %191
|
||||
%matrix_matrix_f32 = OpFunction %void None %6
|
||||
%192 = OpLabel
|
||||
%m34 = OpVariable %_ptr_Function_mat3v4float Function %196
|
||||
%m43 = OpVariable %_ptr_Function_mat4v3float Function %200
|
||||
%m33 = OpVariable %_ptr_Function_mat3v3float Function %204
|
||||
%m44 = OpVariable %_ptr_Function_mat4v4float Function %208
|
||||
%209 = OpLoad %mat3v4float %m34
|
||||
%210 = OpLoad %mat3v4float %m34
|
||||
%212 = OpCompositeExtract %v4float %209 0
|
||||
%213 = OpCompositeExtract %v4float %210 0
|
||||
%214 = OpFAdd %v4float %212 %213
|
||||
%215 = OpCompositeExtract %v4float %209 1
|
||||
%216 = OpCompositeExtract %v4float %210 1
|
||||
%217 = OpFAdd %v4float %215 %216
|
||||
%218 = OpCompositeExtract %v4float %209 2
|
||||
%219 = OpCompositeExtract %v4float %210 2
|
||||
%220 = OpFAdd %v4float %218 %219
|
||||
%221 = OpCompositeConstruct %mat3v4float %214 %217 %220
|
||||
OpStore %m34 %221
|
||||
%222 = OpLoad %mat3v4float %m34
|
||||
%223 = OpLoad %mat3v4float %m34
|
||||
%225 = OpCompositeExtract %v4float %222 0
|
||||
%226 = OpCompositeExtract %v4float %223 0
|
||||
%227 = OpFSub %v4float %225 %226
|
||||
%228 = OpCompositeExtract %v4float %222 1
|
||||
%229 = OpCompositeExtract %v4float %223 1
|
||||
%230 = OpFSub %v4float %228 %229
|
||||
%231 = OpCompositeExtract %v4float %222 2
|
||||
%232 = OpCompositeExtract %v4float %223 2
|
||||
%233 = OpFSub %v4float %231 %232
|
||||
%234 = OpCompositeConstruct %mat3v4float %227 %230 %233
|
||||
OpStore %m34 %234
|
||||
%235 = OpLoad %mat4v3float %m43
|
||||
%236 = OpLoad %mat3v4float %m34
|
||||
%237 = OpMatrixTimesMatrix %mat3v3float %235 %236
|
||||
OpStore %m33 %237
|
||||
%238 = OpLoad %mat3v4float %m34
|
||||
%239 = OpLoad %mat4v3float %m43
|
||||
%240 = OpMatrixTimesMatrix %mat4v4float %238 %239
|
||||
OpStore %m44 %240
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%tint_symbol_2 = OpFunction %void None %241
|
||||
%tint_symbol = OpFunctionParameter %v4float
|
||||
%194 = OpLabel
|
||||
%244 = OpLabel
|
||||
OpStore %tint_symbol_1 %tint_symbol
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
%main = OpFunction %void None %6
|
||||
%196 = OpLabel
|
||||
%197 = OpFunctionCall %void %tint_symbol_2 %199
|
||||
%246 = OpLabel
|
||||
%247 = OpFunctionCall %void %tint_symbol_2 %249
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
|
|
|
@ -62,6 +62,17 @@ fn scalar_vector_u32() {
|
|||
r = (s % v);
|
||||
}
|
||||
|
||||
fn matrix_matrix_f32() {
|
||||
var m34 : mat3x4<f32>;
|
||||
var m43 : mat4x3<f32>;
|
||||
var m33 : mat3x3<f32>;
|
||||
var m44 : mat4x4<f32>;
|
||||
m34 = (m34 + m34);
|
||||
m34 = (m34 - m34);
|
||||
m33 = (m43 * m34);
|
||||
m44 = (m34 * m43);
|
||||
}
|
||||
|
||||
[[stage(fragment)]]
|
||||
fn main() -> [[location(0)]] vec4<f32> {
|
||||
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
|
||||
|
|
Loading…
Reference in New Issue