Implement addition and subtraction of float matrices
Bug: tint:316 Change-Id: I3a1082c41c47daacf0220d029cb2a5f118684959 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/52580 Commit-Queue: Antonio Maiorano <amaiorano@google.com> Commit-Queue: David Neto <dneto@google.com> Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
parent
34c58a932c
commit
c91f8f2822
|
@ -2124,13 +2124,20 @@ bool Resolver::Binary(ast::BinaryExpression* expr) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Matrix arithmetic
|
// Matrix arithmetic
|
||||||
// TODO(amaiorano): matrix-matrix addition and subtraction
|
auto* lhs_mat = lhs_type->As<Matrix>();
|
||||||
|
auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
|
||||||
|
auto* rhs_mat = rhs_type->As<Matrix>();
|
||||||
|
auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
|
||||||
|
// Addition and subtraction of float matrices
|
||||||
|
if ((expr->IsAdd() || expr->IsSubtract()) && lhs_mat_elem_type &&
|
||||||
|
lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type &&
|
||||||
|
rhs_mat_elem_type->Is<F32>() &&
|
||||||
|
(lhs_mat->columns() == rhs_mat->columns()) &&
|
||||||
|
(lhs_mat->rows() == rhs_mat->rows())) {
|
||||||
|
SetType(expr, rhs_type);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (expr->IsMultiply()) {
|
if (expr->IsMultiply()) {
|
||||||
auto* lhs_mat = lhs_type->As<Matrix>();
|
|
||||||
auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
|
|
||||||
auto* rhs_mat = rhs_type->As<Matrix>();
|
|
||||||
auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
|
|
||||||
|
|
||||||
// Multiplication of a matrix and a scalar
|
// Multiplication of a matrix and a scalar
|
||||||
if (lhs_type->Is<F32>() && rhs_mat_elem_type &&
|
if (lhs_type->Is<F32>() && rhs_mat_elem_type &&
|
||||||
rhs_mat_elem_type->Is<F32>()) {
|
rhs_mat_elem_type->Is<F32>()) {
|
||||||
|
|
|
@ -1238,9 +1238,12 @@ static constexpr ast::BinaryOp all_ops[] = {
|
||||||
};
|
};
|
||||||
|
|
||||||
static constexpr create_ast_type_func_ptr all_create_type_funcs[] = {
|
static constexpr create_ast_type_func_ptr all_create_type_funcs[] = {
|
||||||
ast_bool, ast_u32, ast_i32, ast_f32,
|
ast_bool, ast_u32, ast_i32, ast_f32,
|
||||||
ast_vec3<bool>, ast_vec3<i32>, ast_vec3<u32>, ast_vec3<f32>,
|
ast_vec3<bool>, ast_vec3<i32>, ast_vec3<u32>, ast_vec3<f32>,
|
||||||
ast_mat3x3<i32>, ast_mat3x3<u32>, ast_mat3x3<f32>};
|
ast_mat3x3<i32>, ast_mat3x3<u32>, ast_mat3x3<f32>, //
|
||||||
|
ast_mat2x3<i32>, ast_mat2x3<u32>, ast_mat2x3<f32>, //
|
||||||
|
ast_mat3x2<i32>, ast_mat3x2<u32>, ast_mat3x2<f32> //
|
||||||
|
};
|
||||||
|
|
||||||
// A list of all valid test cases for 'lhs op rhs', except that for vecN and
|
// A list of all valid test cases for 'lhs op rhs', except that for vecN and
|
||||||
// matNxN, we only test N=3.
|
// matNxN, we only test N=3.
|
||||||
|
@ -1338,14 +1341,43 @@ static constexpr Params all_valid_cases[] = {
|
||||||
// Params{Op::kModulo, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
|
// Params{Op::kModulo, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
|
||||||
|
|
||||||
// Matrix arithmetic
|
// Matrix arithmetic
|
||||||
|
Params{Op::kMultiply, ast_mat2x3<f32>, ast_f32, sem_mat2x3<sem_f32>},
|
||||||
|
Params{Op::kMultiply, ast_mat3x2<f32>, ast_f32, sem_mat3x2<sem_f32>},
|
||||||
Params{Op::kMultiply, ast_mat3x3<f32>, ast_f32, sem_mat3x3<sem_f32>},
|
Params{Op::kMultiply, ast_mat3x3<f32>, ast_f32, sem_mat3x3<sem_f32>},
|
||||||
|
|
||||||
|
Params{Op::kMultiply, ast_f32, ast_mat2x3<f32>, sem_mat2x3<sem_f32>},
|
||||||
|
Params{Op::kMultiply, ast_f32, ast_mat3x2<f32>, sem_mat3x2<sem_f32>},
|
||||||
Params{Op::kMultiply, ast_f32, ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
|
Params{Op::kMultiply, ast_f32, ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
|
||||||
|
|
||||||
|
Params{Op::kMultiply, ast_vec3<f32>, ast_mat2x3<f32>, sem_vec2<sem_f32>},
|
||||||
|
Params{Op::kMultiply, ast_vec2<f32>, ast_mat3x2<f32>, sem_vec3<sem_f32>},
|
||||||
Params{Op::kMultiply, ast_vec3<f32>, ast_mat3x3<f32>, sem_vec3<sem_f32>},
|
Params{Op::kMultiply, ast_vec3<f32>, ast_mat3x3<f32>, sem_vec3<sem_f32>},
|
||||||
|
|
||||||
|
Params{Op::kMultiply, ast_mat3x2<f32>, ast_vec3<f32>, sem_vec2<sem_f32>},
|
||||||
|
Params{Op::kMultiply, ast_mat2x3<f32>, ast_vec2<f32>, sem_vec3<sem_f32>},
|
||||||
Params{Op::kMultiply, ast_mat3x3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
|
Params{Op::kMultiply, ast_mat3x3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
|
||||||
// TODO(amaiorano): add mat+mat and mat-mat
|
|
||||||
|
Params{Op::kMultiply, ast_mat2x3<f32>, ast_mat3x2<f32>,
|
||||||
|
sem_mat3x3<sem_f32>},
|
||||||
|
Params{Op::kMultiply, ast_mat3x2<f32>, ast_mat2x3<f32>,
|
||||||
|
sem_mat2x2<sem_f32>},
|
||||||
|
Params{Op::kMultiply, ast_mat3x2<f32>, ast_mat3x3<f32>,
|
||||||
|
sem_mat3x2<sem_f32>},
|
||||||
Params{Op::kMultiply, ast_mat3x3<f32>, ast_mat3x3<f32>,
|
Params{Op::kMultiply, ast_mat3x3<f32>, ast_mat3x3<f32>,
|
||||||
sem_mat3x3<sem_f32>},
|
sem_mat3x3<sem_f32>},
|
||||||
|
Params{Op::kMultiply, ast_mat3x3<f32>, ast_mat2x3<f32>,
|
||||||
|
sem_mat2x3<sem_f32>},
|
||||||
|
|
||||||
|
Params{Op::kAdd, ast_mat2x3<f32>, ast_mat2x3<f32>, sem_mat2x3<sem_f32>},
|
||||||
|
Params{Op::kAdd, ast_mat3x2<f32>, ast_mat3x2<f32>, sem_mat3x2<sem_f32>},
|
||||||
|
Params{Op::kAdd, ast_mat3x3<f32>, ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
|
||||||
|
|
||||||
|
Params{Op::kSubtract, ast_mat2x3<f32>, ast_mat2x3<f32>,
|
||||||
|
sem_mat2x3<sem_f32>},
|
||||||
|
Params{Op::kSubtract, ast_mat3x2<f32>, ast_mat3x2<f32>,
|
||||||
|
sem_mat3x2<sem_f32>},
|
||||||
|
Params{Op::kSubtract, ast_mat3x3<f32>, ast_mat3x3<f32>,
|
||||||
|
sem_mat3x3<sem_f32>},
|
||||||
|
|
||||||
// Comparison expressions
|
// Comparison expressions
|
||||||
// https://gpuweb.github.io/gpuweb/wgsl.html#comparison-expr
|
// https://gpuweb.github.io/gpuweb/wgsl.html#comparison-expr
|
||||||
|
|
|
@ -176,6 +176,26 @@ ast::Type* ast_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
|
||||||
return ty.mat2x2(create_type(ty));
|
return ty.mat2x2(create_type(ty));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
ast::Type* ast_mat2x3(const ProgramBuilder::TypesBuilder& ty) {
|
||||||
|
return ty.mat2x3<T>();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <create_ast_type_func_ptr create_type>
|
||||||
|
ast::Type* ast_mat2x3(const ProgramBuilder::TypesBuilder& ty) {
|
||||||
|
return ty.mat2x3(create_type(ty));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
ast::Type* ast_mat3x2(const ProgramBuilder::TypesBuilder& ty) {
|
||||||
|
return ty.mat3x2<T>();
|
||||||
|
}
|
||||||
|
|
||||||
|
template <create_ast_type_func_ptr create_type>
|
||||||
|
ast::Type* ast_mat3x2(const ProgramBuilder::TypesBuilder& ty) {
|
||||||
|
return ty.mat3x2(create_type(ty));
|
||||||
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
ast::Type* ast_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
|
ast::Type* ast_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
|
||||||
return ty.mat3x3<T>();
|
return ty.mat3x3<T>();
|
||||||
|
@ -249,6 +269,18 @@ sem::Type* sem_mat2x2(const ProgramBuilder::TypesBuilder& ty) {
|
||||||
return ty.builder->create<sem::Matrix>(column_type, 2u);
|
return ty.builder->create<sem::Matrix>(column_type, 2u);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <create_sem_type_func_ptr create_type>
|
||||||
|
sem::Type* sem_mat2x3(const ProgramBuilder::TypesBuilder& ty) {
|
||||||
|
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 3u);
|
||||||
|
return ty.builder->create<sem::Matrix>(column_type, 2u);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <create_sem_type_func_ptr create_type>
|
||||||
|
sem::Type* sem_mat3x2(const ProgramBuilder::TypesBuilder& ty) {
|
||||||
|
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 2u);
|
||||||
|
return ty.builder->create<sem::Matrix>(column_type, 3u);
|
||||||
|
}
|
||||||
|
|
||||||
template <create_sem_type_func_ptr create_type>
|
template <create_sem_type_func_ptr create_type>
|
||||||
sem::Type* sem_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
|
sem::Type* sem_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
|
||||||
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 3u);
|
auto* column_type = ty.builder->create<sem::Vector>(create_type(ty), 3u);
|
||||||
|
|
|
@ -1748,6 +1748,67 @@ uint32_t Builder::GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type) {
|
||||||
return splat_result.to_i();
|
return splat_result.to_i();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint32_t Builder::GenerateMatrixAddOrSub(uint32_t lhs_id,
|
||||||
|
uint32_t rhs_id,
|
||||||
|
const sem::Matrix* type,
|
||||||
|
spv::Op op) {
|
||||||
|
// Example addition of two matrices:
|
||||||
|
// %31 = OpLoad %mat3v4float %m34
|
||||||
|
// %32 = OpLoad %mat3v4float %m34
|
||||||
|
// %33 = OpCompositeExtract %v4float %31 0
|
||||||
|
// %34 = OpCompositeExtract %v4float %32 0
|
||||||
|
// %35 = OpFAdd %v4float %33 %34
|
||||||
|
// %36 = OpCompositeExtract %v4float %31 1
|
||||||
|
// %37 = OpCompositeExtract %v4float %32 1
|
||||||
|
// %38 = OpFAdd %v4float %36 %37
|
||||||
|
// %39 = OpCompositeExtract %v4float %31 2
|
||||||
|
// %40 = OpCompositeExtract %v4float %32 2
|
||||||
|
// %41 = OpFAdd %v4float %39 %40
|
||||||
|
// %42 = OpCompositeConstruct %mat3v4float %35 %38 %41
|
||||||
|
|
||||||
|
auto* column_type = builder_.create<sem::Vector>(type->type(), type->rows());
|
||||||
|
auto column_type_id = GenerateTypeIfNeeded(column_type);
|
||||||
|
|
||||||
|
OperandList ops;
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < type->columns(); ++i) {
|
||||||
|
// Extract column `i` from lhs mat
|
||||||
|
auto lhs_column_id = result_op();
|
||||||
|
if (!push_function_inst(spv::Op::OpCompositeExtract,
|
||||||
|
{Operand::Int(column_type_id), lhs_column_id,
|
||||||
|
Operand::Int(lhs_id), Operand::Int(i)})) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract column `i` from rhs mat
|
||||||
|
auto rhs_column_id = result_op();
|
||||||
|
if (!push_function_inst(spv::Op::OpCompositeExtract,
|
||||||
|
{Operand::Int(column_type_id), rhs_column_id,
|
||||||
|
Operand::Int(rhs_id), Operand::Int(i)})) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add or subtract the two columns
|
||||||
|
auto result = result_op();
|
||||||
|
if (!push_function_inst(op, {Operand::Int(column_type_id), result,
|
||||||
|
lhs_column_id, rhs_column_id})) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
ops.push_back(result);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create the result matrix from the added/subtracted column vectors
|
||||||
|
auto result_mat_id = result_op();
|
||||||
|
ops.insert(ops.begin(), result_mat_id);
|
||||||
|
ops.insert(ops.begin(), Operand::Int(GenerateTypeIfNeeded(type)));
|
||||||
|
if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) {
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return result_mat_id.to_i();
|
||||||
|
}
|
||||||
|
|
||||||
uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
|
uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
|
||||||
// There is special logic for short circuiting operators.
|
// There is special logic for short circuiting operators.
|
||||||
if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
|
if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
|
||||||
|
@ -1779,6 +1840,24 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
|
||||||
auto* lhs_type = TypeOf(expr->lhs())->UnwrapRef();
|
auto* lhs_type = TypeOf(expr->lhs())->UnwrapRef();
|
||||||
auto* rhs_type = TypeOf(expr->rhs())->UnwrapRef();
|
auto* rhs_type = TypeOf(expr->rhs())->UnwrapRef();
|
||||||
|
|
||||||
|
// Handle matrix-matrix addition and subtraction
|
||||||
|
if ((expr->IsAdd() || expr->IsSubtract()) && lhs_type->is_float_matrix() &&
|
||||||
|
rhs_type->is_float_matrix()) {
|
||||||
|
auto* lhs_mat = lhs_type->As<sem::Matrix>();
|
||||||
|
auto* rhs_mat = rhs_type->As<sem::Matrix>();
|
||||||
|
|
||||||
|
// This should already have been validated by resolver
|
||||||
|
if (lhs_mat->rows() != rhs_mat->rows() ||
|
||||||
|
lhs_mat->columns() != rhs_mat->columns()) {
|
||||||
|
error_ = "matrices must have same dimensionality for add or subtract";
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
return GenerateMatrixAddOrSub(
|
||||||
|
lhs_id, rhs_id, lhs_mat,
|
||||||
|
expr->IsAdd() ? spv::Op::OpFAdd : spv::Op::OpFSub);
|
||||||
|
}
|
||||||
|
|
||||||
// For vector-scalar arithmetic operations, splat scalar into a vector. We
|
// For vector-scalar arithmetic operations, splat scalar into a vector. We
|
||||||
// skip this for multiply as we can use OpVectorTimesScalar.
|
// skip this for multiply as we can use OpVectorTimesScalar.
|
||||||
const bool is_float_scalar_vector_multiply =
|
const bool is_float_scalar_vector_multiply =
|
||||||
|
|
|
@ -482,6 +482,17 @@ class Builder {
|
||||||
/// @returns id of the new vector
|
/// @returns id of the new vector
|
||||||
uint32_t GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type);
|
uint32_t GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type);
|
||||||
|
|
||||||
|
/// Generates instructions to add or subtract two matrices
|
||||||
|
/// @param lhs_id id of multiplicand
|
||||||
|
/// @param rhs_id id of multiplier
|
||||||
|
/// @param type type of both matrices and of result
|
||||||
|
/// @param op one of `spv::Op::OpFAdd` or `spv::Op::OpFSub`
|
||||||
|
/// @returns id of the result matrix
|
||||||
|
uint32_t GenerateMatrixAddOrSub(uint32_t lhs_id,
|
||||||
|
uint32_t rhs_id,
|
||||||
|
const sem::Matrix* type,
|
||||||
|
spv::Op op);
|
||||||
|
|
||||||
/// Converts AST image format to SPIR-V and pushes an appropriate capability.
|
/// Converts AST image format to SPIR-V and pushes an appropriate capability.
|
||||||
/// @param format AST image format type
|
/// @param format AST image format type
|
||||||
/// @returns SPIR-V image format type
|
/// @returns SPIR-V image format type
|
||||||
|
|
|
@ -863,8 +863,10 @@ OpBranch %9
|
||||||
)");
|
)");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
namespace BinaryArithVectorScalar {
|
||||||
|
|
||||||
enum class Type { f32, i32, u32 };
|
enum class Type { f32, i32, u32 };
|
||||||
ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) {
|
static ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case Type::f32:
|
case Type::f32:
|
||||||
return builder->vec3<ProgramBuilder::f32>(1.f, 1.f, 1.f);
|
return builder->vec3<ProgramBuilder::f32>(1.f, 1.f, 1.f);
|
||||||
|
@ -875,7 +877,7 @@ ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) {
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type) {
|
static ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type) {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case Type::f32:
|
case Type::f32:
|
||||||
return builder->Expr(1.f);
|
return builder->Expr(1.f);
|
||||||
|
@ -886,7 +888,7 @@ ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type) {
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
std::string OpTypeDecl(Type type) {
|
static std::string OpTypeDecl(Type type) {
|
||||||
switch (type) {
|
switch (type) {
|
||||||
case Type::f32:
|
case Type::f32:
|
||||||
return "OpTypeFloat 32";
|
return "OpTypeFloat 32";
|
||||||
|
@ -904,8 +906,8 @@ struct Param {
|
||||||
std::string name;
|
std::string name;
|
||||||
};
|
};
|
||||||
|
|
||||||
using MixedBinaryArithTest = TestParamHelper<Param>;
|
using BinaryArithVectorScalarTest = TestParamHelper<Param>;
|
||||||
TEST_P(MixedBinaryArithTest, VectorScalar) {
|
TEST_P(BinaryArithVectorScalarTest, VectorScalar) {
|
||||||
auto& param = GetParam();
|
auto& param = GetParam();
|
||||||
|
|
||||||
ast::Expression* lhs = MakeVectorExpr(this, param.type);
|
ast::Expression* lhs = MakeVectorExpr(this, param.type);
|
||||||
|
@ -943,7 +945,7 @@ OpFunctionEnd
|
||||||
|
|
||||||
Validate(b);
|
Validate(b);
|
||||||
}
|
}
|
||||||
TEST_P(MixedBinaryArithTest, ScalarVector) {
|
TEST_P(BinaryArithVectorScalarTest, ScalarVector) {
|
||||||
auto& param = GetParam();
|
auto& param = GetParam();
|
||||||
|
|
||||||
ast::Expression* lhs = MakeScalarExpr(this, param.type);
|
ast::Expression* lhs = MakeScalarExpr(this, param.type);
|
||||||
|
@ -983,7 +985,7 @@ OpFunctionEnd
|
||||||
}
|
}
|
||||||
INSTANTIATE_TEST_SUITE_P(
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
BuilderTest,
|
BuilderTest,
|
||||||
MixedBinaryArithTest,
|
BinaryArithVectorScalarTest,
|
||||||
testing::Values(Param{Type::f32, ast::BinaryOp::kAdd, "OpFAdd"},
|
testing::Values(Param{Type::f32, ast::BinaryOp::kAdd, "OpFAdd"},
|
||||||
Param{Type::f32, ast::BinaryOp::kDivide, "OpFDiv"},
|
Param{Type::f32, ast::BinaryOp::kDivide, "OpFDiv"},
|
||||||
// NOTE: Modulo not allowed on mixed float scalar-vector
|
// NOTE: Modulo not allowed on mixed float scalar-vector
|
||||||
|
@ -1005,8 +1007,8 @@ INSTANTIATE_TEST_SUITE_P(
|
||||||
Param{Type::u32, ast::BinaryOp::kMultiply, "OpIMul"},
|
Param{Type::u32, ast::BinaryOp::kMultiply, "OpIMul"},
|
||||||
Param{Type::u32, ast::BinaryOp::kSubtract, "OpISub"}));
|
Param{Type::u32, ast::BinaryOp::kSubtract, "OpISub"}));
|
||||||
|
|
||||||
using MixedBinaryArithMultiplyTest = TestParamHelper<Param>;
|
using BinaryArithVectorScalarMultiplyTest = TestParamHelper<Param>;
|
||||||
TEST_P(MixedBinaryArithMultiplyTest, VectorScalar) {
|
TEST_P(BinaryArithVectorScalarMultiplyTest, VectorScalar) {
|
||||||
auto& param = GetParam();
|
auto& param = GetParam();
|
||||||
|
|
||||||
ast::Expression* lhs = MakeVectorExpr(this, param.type);
|
ast::Expression* lhs = MakeVectorExpr(this, param.type);
|
||||||
|
@ -1040,7 +1042,7 @@ OpFunctionEnd
|
||||||
|
|
||||||
Validate(b);
|
Validate(b);
|
||||||
}
|
}
|
||||||
TEST_P(MixedBinaryArithMultiplyTest, ScalarVector) {
|
TEST_P(BinaryArithVectorScalarMultiplyTest, ScalarVector) {
|
||||||
auto& param = GetParam();
|
auto& param = GetParam();
|
||||||
|
|
||||||
ast::Expression* lhs = MakeScalarExpr(this, param.type);
|
ast::Expression* lhs = MakeScalarExpr(this, param.type);
|
||||||
|
@ -1075,10 +1077,113 @@ OpFunctionEnd
|
||||||
Validate(b);
|
Validate(b);
|
||||||
}
|
}
|
||||||
INSTANTIATE_TEST_SUITE_P(BuilderTest,
|
INSTANTIATE_TEST_SUITE_P(BuilderTest,
|
||||||
MixedBinaryArithMultiplyTest,
|
BinaryArithVectorScalarMultiplyTest,
|
||||||
testing::Values(Param{
|
testing::Values(Param{
|
||||||
Type::f32, ast::BinaryOp::kMultiply, "OpFMul"}));
|
Type::f32, ast::BinaryOp::kMultiply, "OpFMul"}));
|
||||||
|
|
||||||
|
} // namespace BinaryArithVectorScalar
|
||||||
|
|
||||||
|
namespace BinaryArithMatrixMatrix {
|
||||||
|
|
||||||
|
struct Param {
|
||||||
|
ast::BinaryOp op;
|
||||||
|
std::string name;
|
||||||
|
};
|
||||||
|
|
||||||
|
using BinaryArithMatrixMatrix = TestParamHelper<Param>;
|
||||||
|
TEST_P(BinaryArithMatrixMatrix, AddOrSubtract) {
|
||||||
|
auto& param = GetParam();
|
||||||
|
|
||||||
|
ast::Expression* lhs = mat3x4<f32>();
|
||||||
|
ast::Expression* rhs = mat3x4<f32>();
|
||||||
|
|
||||||
|
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
|
||||||
|
|
||||||
|
WrapInFunction(expr);
|
||||||
|
|
||||||
|
spirv::Builder& b = Build();
|
||||||
|
ASSERT_TRUE(b.Build()) << b.error();
|
||||||
|
|
||||||
|
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
|
||||||
|
OpMemoryModel Logical GLSL450
|
||||||
|
OpEntryPoint GLCompute %3 "test_function"
|
||||||
|
OpExecutionMode %3 LocalSize 1 1 1
|
||||||
|
OpName %3 "test_function"
|
||||||
|
%2 = OpTypeVoid
|
||||||
|
%1 = OpTypeFunction %2
|
||||||
|
%7 = OpTypeFloat 32
|
||||||
|
%6 = OpTypeVector %7 4
|
||||||
|
%5 = OpTypeMatrix %6 3
|
||||||
|
%8 = OpConstantNull %5
|
||||||
|
%3 = OpFunction %2 None %1
|
||||||
|
%4 = OpLabel
|
||||||
|
%10 = OpCompositeExtract %6 %8 0
|
||||||
|
%11 = OpCompositeExtract %6 %8 0
|
||||||
|
%12 = )" + param.name + R"( %6 %10 %11
|
||||||
|
%13 = OpCompositeExtract %6 %8 1
|
||||||
|
%14 = OpCompositeExtract %6 %8 1
|
||||||
|
%15 = )" + param.name + R"( %6 %13 %14
|
||||||
|
%16 = OpCompositeExtract %6 %8 2
|
||||||
|
%17 = OpCompositeExtract %6 %8 2
|
||||||
|
%18 = )" + param.name + R"( %6 %16 %17
|
||||||
|
%19 = OpCompositeConstruct %5 %12 %15 %18
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
||||||
|
)");
|
||||||
|
|
||||||
|
Validate(b);
|
||||||
|
}
|
||||||
|
INSTANTIATE_TEST_SUITE_P( //
|
||||||
|
BuilderTest,
|
||||||
|
BinaryArithMatrixMatrix,
|
||||||
|
testing::Values(Param{ast::BinaryOp::kAdd, "OpFAdd"},
|
||||||
|
Param{ast::BinaryOp::kSubtract, "OpFSub"}));
|
||||||
|
|
||||||
|
using BinaryArithMatrixMatrixMultiply = TestParamHelper<Param>;
|
||||||
|
TEST_P(BinaryArithMatrixMatrixMultiply, Multiply) {
|
||||||
|
auto& param = GetParam();
|
||||||
|
|
||||||
|
ast::Expression* lhs = mat3x4<f32>();
|
||||||
|
ast::Expression* rhs = mat4x3<f32>();
|
||||||
|
|
||||||
|
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
|
||||||
|
|
||||||
|
WrapInFunction(expr);
|
||||||
|
|
||||||
|
spirv::Builder& b = Build();
|
||||||
|
ASSERT_TRUE(b.Build()) << b.error();
|
||||||
|
|
||||||
|
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
|
||||||
|
OpMemoryModel Logical GLSL450
|
||||||
|
OpEntryPoint GLCompute %3 "test_function"
|
||||||
|
OpExecutionMode %3 LocalSize 1 1 1
|
||||||
|
OpName %3 "test_function"
|
||||||
|
%2 = OpTypeVoid
|
||||||
|
%1 = OpTypeFunction %2
|
||||||
|
%7 = OpTypeFloat 32
|
||||||
|
%6 = OpTypeVector %7 4
|
||||||
|
%5 = OpTypeMatrix %6 3
|
||||||
|
%8 = OpConstantNull %5
|
||||||
|
%10 = OpTypeVector %7 3
|
||||||
|
%9 = OpTypeMatrix %10 4
|
||||||
|
%11 = OpConstantNull %9
|
||||||
|
%13 = OpTypeMatrix %6 4
|
||||||
|
%3 = OpFunction %2 None %1
|
||||||
|
%4 = OpLabel
|
||||||
|
%12 = OpMatrixTimesMatrix %13 %8 %11
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
||||||
|
)");
|
||||||
|
|
||||||
|
Validate(b);
|
||||||
|
}
|
||||||
|
INSTANTIATE_TEST_SUITE_P( //
|
||||||
|
BuilderTest,
|
||||||
|
BinaryArithMatrixMatrixMultiply,
|
||||||
|
testing::Values(Param{ast::BinaryOp::kMultiply, "OpFMul"}));
|
||||||
|
|
||||||
|
} // namespace BinaryArithMatrixMatrix
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace spirv
|
} // namespace spirv
|
||||||
} // namespace writer
|
} // namespace writer
|
||||||
|
|
|
@ -64,6 +64,19 @@ fn scalar_vector_u32() {
|
||||||
r = s % v;
|
r = s % v;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn matrix_matrix_f32() {
|
||||||
|
var m34 : mat3x4<f32>;
|
||||||
|
var m43 : mat4x3<f32>;
|
||||||
|
var m33 : mat3x3<f32>;
|
||||||
|
var m44 : mat4x4<f32>;
|
||||||
|
|
||||||
|
m34 = m34 + m34;
|
||||||
|
m34 = m34 - m34;
|
||||||
|
|
||||||
|
m33 = m43 * m34;
|
||||||
|
m44 = m34 * m43;
|
||||||
|
}
|
||||||
|
|
||||||
[[stage(fragment)]]
|
[[stage(fragment)]]
|
||||||
fn main() -> [[location(0)]] vec4<f32> {
|
fn main() -> [[location(0)]] vec4<f32> {
|
||||||
return vec4<f32>(0.0,0.0,0.0,0.0);
|
return vec4<f32>(0.0,0.0,0.0,0.0);
|
||||||
|
|
|
@ -66,6 +66,17 @@ void scalar_vector_u32() {
|
||||||
r = (s % v);
|
r = (s % v);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void matrix_matrix_f32() {
|
||||||
|
float3x4 m34 = float3x4(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||||
|
float4x3 m43 = float4x3(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||||
|
float3x3 m33 = float3x3(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||||
|
float4x4 m44 = float4x4(0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f);
|
||||||
|
m34 = (m34 + m34);
|
||||||
|
m34 = (m34 - m34);
|
||||||
|
m33 = mul(m34, m43);
|
||||||
|
m44 = mul(m43, m34);
|
||||||
|
}
|
||||||
|
|
||||||
tint_symbol main() {
|
tint_symbol main() {
|
||||||
const tint_symbol tint_symbol_1 = {float4(0.0f, 0.0f, 0.0f, 0.0f)};
|
const tint_symbol tint_symbol_1 = {float4(0.0f, 0.0f, 0.0f, 0.0f)};
|
||||||
return tint_symbol_1;
|
return tint_symbol_1;
|
||||||
|
|
|
@ -69,6 +69,17 @@ void scalar_vector_u32() {
|
||||||
r = (s % v);
|
r = (s % v);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void matrix_matrix_f32() {
|
||||||
|
float3x4 m34 = float3x4(0.0f);
|
||||||
|
float4x3 m43 = float4x3(0.0f);
|
||||||
|
float3x3 m33 = float3x3(0.0f);
|
||||||
|
float4x4 m44 = float4x4(0.0f);
|
||||||
|
m34 = (m34 + m34);
|
||||||
|
m34 = (m34 - m34);
|
||||||
|
m33 = (m43 * m34);
|
||||||
|
m44 = (m34 * m43);
|
||||||
|
}
|
||||||
|
|
||||||
fragment tint_symbol_1 tint_symbol() {
|
fragment tint_symbol_1 tint_symbol() {
|
||||||
return {float4(0.0f, 0.0f, 0.0f, 0.0f)};
|
return {float4(0.0f, 0.0f, 0.0f, 0.0f)};
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
; SPIR-V
|
; SPIR-V
|
||||||
; Version: 1.3
|
; Version: 1.3
|
||||||
; Generator: Google Tint Compiler; 0
|
; Generator: Google Tint Compiler; 0
|
||||||
; Bound: 200
|
; Bound: 250
|
||||||
; Schema: 0
|
; Schema: 0
|
||||||
OpCapability Shader
|
OpCapability Shader
|
||||||
OpMemoryModel Logical GLSL450
|
OpMemoryModel Logical GLSL450
|
||||||
|
@ -32,6 +32,11 @@
|
||||||
OpName %v_4 "v"
|
OpName %v_4 "v"
|
||||||
OpName %s_4 "s"
|
OpName %s_4 "s"
|
||||||
OpName %r_4 "r"
|
OpName %r_4 "r"
|
||||||
|
OpName %matrix_matrix_f32 "matrix_matrix_f32"
|
||||||
|
OpName %m34 "m34"
|
||||||
|
OpName %m43 "m43"
|
||||||
|
OpName %m33 "m33"
|
||||||
|
OpName %m44 "m44"
|
||||||
OpName %tint_symbol_2 "tint_symbol_2"
|
OpName %tint_symbol_2 "tint_symbol_2"
|
||||||
OpName %tint_symbol "tint_symbol"
|
OpName %tint_symbol "tint_symbol"
|
||||||
OpName %main "main"
|
OpName %main "main"
|
||||||
|
@ -60,9 +65,21 @@
|
||||||
%78 = OpConstantNull %v3uint
|
%78 = OpConstantNull %v3uint
|
||||||
%_ptr_Function_uint = OpTypePointer Function %uint
|
%_ptr_Function_uint = OpTypePointer Function %uint
|
||||||
%81 = OpConstantNull %uint
|
%81 = OpConstantNull %uint
|
||||||
%191 = OpTypeFunction %void %v4float
|
%mat3v4float = OpTypeMatrix %v4float 3
|
||||||
|
%_ptr_Function_mat3v4float = OpTypePointer Function %mat3v4float
|
||||||
|
%196 = OpConstantNull %mat3v4float
|
||||||
|
%mat4v3float = OpTypeMatrix %v3float 4
|
||||||
|
%_ptr_Function_mat4v3float = OpTypePointer Function %mat4v3float
|
||||||
|
%200 = OpConstantNull %mat4v3float
|
||||||
|
%mat3v3float = OpTypeMatrix %v3float 3
|
||||||
|
%_ptr_Function_mat3v3float = OpTypePointer Function %mat3v3float
|
||||||
|
%204 = OpConstantNull %mat3v3float
|
||||||
|
%mat4v4float = OpTypeMatrix %v4float 4
|
||||||
|
%_ptr_Function_mat4v4float = OpTypePointer Function %mat4v4float
|
||||||
|
%208 = OpConstantNull %mat4v4float
|
||||||
|
%241 = OpTypeFunction %void %v4float
|
||||||
%float_0 = OpConstant %float 0
|
%float_0 = OpConstant %float 0
|
||||||
%199 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
|
%249 = OpConstantComposite %v4float %float_0 %float_0 %float_0 %float_0
|
||||||
%vector_scalar_f32 = OpFunction %void None %6
|
%vector_scalar_f32 = OpFunction %void None %6
|
||||||
%9 = OpLabel
|
%9 = OpLabel
|
||||||
%v = OpVariable %_ptr_Function_v3float Function %13
|
%v = OpVariable %_ptr_Function_v3float Function %13
|
||||||
|
@ -269,14 +286,56 @@
|
||||||
OpStore %r_4 %188
|
OpStore %r_4 %188
|
||||||
OpReturn
|
OpReturn
|
||||||
OpFunctionEnd
|
OpFunctionEnd
|
||||||
%tint_symbol_2 = OpFunction %void None %191
|
%matrix_matrix_f32 = OpFunction %void None %6
|
||||||
|
%192 = OpLabel
|
||||||
|
%m34 = OpVariable %_ptr_Function_mat3v4float Function %196
|
||||||
|
%m43 = OpVariable %_ptr_Function_mat4v3float Function %200
|
||||||
|
%m33 = OpVariable %_ptr_Function_mat3v3float Function %204
|
||||||
|
%m44 = OpVariable %_ptr_Function_mat4v4float Function %208
|
||||||
|
%209 = OpLoad %mat3v4float %m34
|
||||||
|
%210 = OpLoad %mat3v4float %m34
|
||||||
|
%212 = OpCompositeExtract %v4float %209 0
|
||||||
|
%213 = OpCompositeExtract %v4float %210 0
|
||||||
|
%214 = OpFAdd %v4float %212 %213
|
||||||
|
%215 = OpCompositeExtract %v4float %209 1
|
||||||
|
%216 = OpCompositeExtract %v4float %210 1
|
||||||
|
%217 = OpFAdd %v4float %215 %216
|
||||||
|
%218 = OpCompositeExtract %v4float %209 2
|
||||||
|
%219 = OpCompositeExtract %v4float %210 2
|
||||||
|
%220 = OpFAdd %v4float %218 %219
|
||||||
|
%221 = OpCompositeConstruct %mat3v4float %214 %217 %220
|
||||||
|
OpStore %m34 %221
|
||||||
|
%222 = OpLoad %mat3v4float %m34
|
||||||
|
%223 = OpLoad %mat3v4float %m34
|
||||||
|
%225 = OpCompositeExtract %v4float %222 0
|
||||||
|
%226 = OpCompositeExtract %v4float %223 0
|
||||||
|
%227 = OpFSub %v4float %225 %226
|
||||||
|
%228 = OpCompositeExtract %v4float %222 1
|
||||||
|
%229 = OpCompositeExtract %v4float %223 1
|
||||||
|
%230 = OpFSub %v4float %228 %229
|
||||||
|
%231 = OpCompositeExtract %v4float %222 2
|
||||||
|
%232 = OpCompositeExtract %v4float %223 2
|
||||||
|
%233 = OpFSub %v4float %231 %232
|
||||||
|
%234 = OpCompositeConstruct %mat3v4float %227 %230 %233
|
||||||
|
OpStore %m34 %234
|
||||||
|
%235 = OpLoad %mat4v3float %m43
|
||||||
|
%236 = OpLoad %mat3v4float %m34
|
||||||
|
%237 = OpMatrixTimesMatrix %mat3v3float %235 %236
|
||||||
|
OpStore %m33 %237
|
||||||
|
%238 = OpLoad %mat3v4float %m34
|
||||||
|
%239 = OpLoad %mat4v3float %m43
|
||||||
|
%240 = OpMatrixTimesMatrix %mat4v4float %238 %239
|
||||||
|
OpStore %m44 %240
|
||||||
|
OpReturn
|
||||||
|
OpFunctionEnd
|
||||||
|
%tint_symbol_2 = OpFunction %void None %241
|
||||||
%tint_symbol = OpFunctionParameter %v4float
|
%tint_symbol = OpFunctionParameter %v4float
|
||||||
%194 = OpLabel
|
%244 = OpLabel
|
||||||
OpStore %tint_symbol_1 %tint_symbol
|
OpStore %tint_symbol_1 %tint_symbol
|
||||||
OpReturn
|
OpReturn
|
||||||
OpFunctionEnd
|
OpFunctionEnd
|
||||||
%main = OpFunction %void None %6
|
%main = OpFunction %void None %6
|
||||||
%196 = OpLabel
|
%246 = OpLabel
|
||||||
%197 = OpFunctionCall %void %tint_symbol_2 %199
|
%247 = OpFunctionCall %void %tint_symbol_2 %249
|
||||||
OpReturn
|
OpReturn
|
||||||
OpFunctionEnd
|
OpFunctionEnd
|
||||||
|
|
|
@ -62,6 +62,17 @@ fn scalar_vector_u32() {
|
||||||
r = (s % v);
|
r = (s % v);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn matrix_matrix_f32() {
|
||||||
|
var m34 : mat3x4<f32>;
|
||||||
|
var m43 : mat4x3<f32>;
|
||||||
|
var m33 : mat3x3<f32>;
|
||||||
|
var m44 : mat4x4<f32>;
|
||||||
|
m34 = (m34 + m34);
|
||||||
|
m34 = (m34 - m34);
|
||||||
|
m33 = (m43 * m34);
|
||||||
|
m44 = (m34 * m43);
|
||||||
|
}
|
||||||
|
|
||||||
[[stage(fragment)]]
|
[[stage(fragment)]]
|
||||||
fn main() -> [[location(0)]] vec4<f32> {
|
fn main() -> [[location(0)]] vec4<f32> {
|
||||||
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
|
return vec4<f32>(0.0, 0.0, 0.0, 0.0);
|
||||||
|
|
Loading…
Reference in New Issue