Add support for binary arithmetic expressions with mixed scalar and vector operands

Bug: tint:376
Change-Id: I2994ff7394efa903050b470a850b41628d5b775c
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/52324
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
Antonio Maiorano
2021-05-27 21:07:56 +00:00
committed by Tint LUCI CQ
parent d5f4ea22f0
commit eaed2b6ce2
10 changed files with 909 additions and 18 deletions

View File

@@ -2097,23 +2097,35 @@ bool Resolver::Binary(ast::BinaryExpression* expr) {
SetType(expr, lhs_type);
return true;
}
// Binary arithmetic expressions with mixed scalar and vector operands
if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_type)) {
if (expr->IsModulo()) {
if (rhs_type->is_integer_scalar()) {
SetType(expr, lhs_type);
return true;
}
} else if (rhs_type->is_numeric_scalar()) {
SetType(expr, lhs_type);
return true;
}
}
if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_type)) {
if (expr->IsModulo()) {
if (lhs_type->is_integer_scalar()) {
SetType(expr, rhs_type);
return true;
}
} else if (lhs_type->is_numeric_scalar()) {
SetType(expr, rhs_type);
return true;
}
}
}
// Binary arithmetic expressions with mixed scalar, vector, and matrix
// operands
// Matrix arithmetic
// TODO(amaiorano): matrix-matrix addition and subtraction
if (expr->IsMultiply()) {
// Multiplication of a vector and a scalar
if (lhs_type->Is<F32>() && rhs_vec_elem_type &&
rhs_vec_elem_type->Is<F32>()) {
SetType(expr, rhs_type);
return true;
}
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
rhs_type->Is<F32>()) {
SetType(expr, lhs_type);
return true;
}
auto* lhs_mat = lhs_type->As<Matrix>();
auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
auto* rhs_mat = rhs_type->As<Matrix>();

View File

@@ -1298,16 +1298,52 @@ static constexpr Params all_valid_cases[] = {
Params{Op::kDivide, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
Params{Op::kModulo, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
// Binary arithmetic expressions with mixed scalar, vector, and matrix
// operands
Params{Op::kMultiply, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>},
Params{Op::kMultiply, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
// Binary arithmetic expressions with mixed scalar and vector operands
Params{Op::kAdd, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>},
Params{Op::kSubtract, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>},
Params{Op::kMultiply, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>},
Params{Op::kDivide, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>},
Params{Op::kModulo, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>},
Params{Op::kAdd, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>},
Params{Op::kSubtract, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>},
Params{Op::kMultiply, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>},
Params{Op::kDivide, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>},
Params{Op::kModulo, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>},
Params{Op::kAdd, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>},
Params{Op::kSubtract, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>},
Params{Op::kMultiply, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>},
Params{Op::kDivide, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>},
Params{Op::kModulo, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>},
Params{Op::kAdd, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>},
Params{Op::kSubtract, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>},
Params{Op::kMultiply, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>},
Params{Op::kDivide, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>},
Params{Op::kModulo, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>},
Params{Op::kAdd, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>},
Params{Op::kSubtract, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>},
Params{Op::kMultiply, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>},
Params{Op::kDivide, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>},
// NOTE: no kModulo for ast_vec3<f32>, ast_f32
// Params{Op::kModulo, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>},
Params{Op::kAdd, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
Params{Op::kSubtract, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
Params{Op::kMultiply, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
Params{Op::kDivide, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
// NOTE: no kModulo for ast_f32, ast_vec3<f32>
// Params{Op::kModulo, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
// Matrix arithmetic
Params{Op::kMultiply, ast_mat3x3<f32>, ast_f32, sem_mat3x3<sem_f32>},
Params{Op::kMultiply, ast_f32, ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
Params{Op::kMultiply, ast_vec3<f32>, ast_mat3x3<f32>, sem_vec3<sem_f32>},
Params{Op::kMultiply, ast_mat3x3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
// TODO(amaiorano): add mat+mat and mat-mat
Params{Op::kMultiply, ast_mat3x3<f32>, ast_mat3x3<f32>,
sem_mat3x3<sem_f32>},

View File

@@ -1719,6 +1719,31 @@ uint32_t Builder::GenerateShortCircuitBinaryExpression(
return result_id;
}
uint32_t Builder::GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type) {
// Create a new vector to splat scalar into
auto splat_vector = result_op();
auto* splat_vector_type =
builder_.create<sem::Pointer>(vec_type, ast::StorageClass::kFunction);
push_function_var(
{Operand::Int(GenerateTypeIfNeeded(splat_vector_type)), splat_vector,
Operand::Int(ConvertStorageClass(ast::StorageClass::kFunction)),
Operand::Int(GenerateConstantNullIfNeeded(vec_type))});
// Splat scalar into vector
auto splat_result = result_op();
OperandList ops;
ops.push_back(Operand::Int(GenerateTypeIfNeeded(vec_type)));
ops.push_back(splat_result);
for (size_t i = 0; i < vec_type->As<sem::Vector>()->size(); ++i) {
ops.push_back(Operand::Int(scalar_id));
}
if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) {
return 0;
}
return splat_result.to_i();
}
uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
// There is special logic for short circuiting operators.
if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
@@ -1749,6 +1774,33 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
// should have been rejected by validation.
auto* lhs_type = TypeOf(expr->lhs())->UnwrapRef();
auto* rhs_type = TypeOf(expr->rhs())->UnwrapRef();
// For vector-scalar arithmetic operations, splat scalar into a vector. We
// skip this for multiply as we can use OpVectorTimesScalar.
const bool is_float_scalar_vector_multiply =
expr->IsMultiply() &&
((lhs_type->is_float_scalar() && rhs_type->is_float_vector()) ||
(lhs_type->is_float_vector() && rhs_type->is_float_scalar()));
if (expr->IsArithmetic() && !is_float_scalar_vector_multiply) {
if (lhs_type->Is<sem::Vector>() && rhs_type->is_numeric_scalar()) {
uint32_t splat_vector_id = GenerateSplat(rhs_id, lhs_type);
if (splat_vector_id == 0) {
return 0;
}
rhs_id = splat_vector_id;
rhs_type = lhs_type;
} else if (lhs_type->is_numeric_scalar() && rhs_type->Is<sem::Vector>()) {
uint32_t splat_vector_id = GenerateSplat(lhs_id, rhs_type);
if (splat_vector_id == 0) {
return 0;
}
lhs_id = splat_vector_id;
lhs_type = rhs_type;
}
}
bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector();
bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector();

View File

@@ -473,6 +473,13 @@ class Builder {
/// @returns true if the vector was successfully generated
bool GenerateVectorType(const sem::Vector* vec, const Operand& result);
/// Generates instructions to splat `scalar_id` into a vector of type
/// `vec_type`
/// @param scalar_id scalar to splat
/// @param vec_type type of vector
/// @returns id of the new vector
uint32_t GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type);
/// Converts AST image format to SPIR-V and pushes an appropriate capability.
/// @param format AST image format type
/// @returns SPIR-V image format type

View File

@@ -863,6 +863,222 @@ OpBranch %9
)");
}
enum class Type { f32, i32, u32 };
ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) {
switch (type) {
case Type::f32:
return builder->vec3<ProgramBuilder::f32>(1.f, 1.f, 1.f);
case Type::i32:
return builder->vec3<ProgramBuilder::i32>(1, 1, 1);
case Type::u32:
return builder->vec3<ProgramBuilder::u32>(1u, 1u, 1u);
}
return nullptr;
}
ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type) {
switch (type) {
case Type::f32:
return builder->Expr(1.f);
case Type::i32:
return builder->Expr(1);
case Type::u32:
return builder->Expr(1u);
}
return nullptr;
}
std::string OpTypeDecl(Type type) {
switch (type) {
case Type::f32:
return "OpTypeFloat 32";
case Type::i32:
return "OpTypeInt 32 1";
case Type::u32:
return "OpTypeInt 32 0";
}
return {};
}
struct Param {
Type type;
ast::BinaryOp op;
std::string name;
};
using MixedBinaryArithTest = TestParamHelper<Param>;
TEST_P(MixedBinaryArithTest, VectorScalar) {
auto& param = GetParam();
ast::Expression* lhs = MakeVectorExpr(this, param.type);
ast::Expression* rhs = MakeScalarExpr(this, param.type);
std::string op_type_decl = OpTypeDecl(param.type);
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
WrapInFunction(expr);
spirv::Builder& b = Build();
ASSERT_TRUE(b.Build()) << b.error();
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %3 "test_function"
OpExecutionMode %3 LocalSize 1 1 1
OpName %3 "test_function"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%6 = )" + op_type_decl + R"(
%5 = OpTypeVector %6 3
%7 = OpConstant %6 1
%8 = OpConstantComposite %5 %7 %7 %7
%11 = OpTypePointer Function %5
%12 = OpConstantNull %5
%3 = OpFunction %2 None %1
%4 = OpLabel
%10 = OpVariable %11 Function %12
%13 = OpCompositeConstruct %5 %7 %7 %7
%9 = )" + param.name + R"( %5 %8 %13
OpReturn
OpFunctionEnd
)");
Validate(b);
}
TEST_P(MixedBinaryArithTest, ScalarVector) {
auto& param = GetParam();
ast::Expression* lhs = MakeScalarExpr(this, param.type);
ast::Expression* rhs = MakeVectorExpr(this, param.type);
std::string op_type_decl = OpTypeDecl(param.type);
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
WrapInFunction(expr);
spirv::Builder& b = Build();
ASSERT_TRUE(b.Build()) << b.error();
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %3 "test_function"
OpExecutionMode %3 LocalSize 1 1 1
OpName %3 "test_function"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%5 = )" + op_type_decl + R"(
%6 = OpConstant %5 1
%7 = OpTypeVector %5 3
%8 = OpConstantComposite %7 %6 %6 %6
%11 = OpTypePointer Function %7
%12 = OpConstantNull %7
%3 = OpFunction %2 None %1
%4 = OpLabel
%10 = OpVariable %11 Function %12
%13 = OpCompositeConstruct %7 %6 %6 %6
%9 = )" + param.name + R"( %7 %13 %8
OpReturn
OpFunctionEnd
)");
Validate(b);
}
INSTANTIATE_TEST_SUITE_P(
BuilderTest,
MixedBinaryArithTest,
testing::Values(Param{Type::f32, ast::BinaryOp::kAdd, "OpFAdd"},
Param{Type::f32, ast::BinaryOp::kDivide, "OpFDiv"},
// NOTE: Modulo not allowed on mixed float scalar-vector
// Param{Type::f32, ast::BinaryOp::kModulo, "OpFMod"},
// NOTE: We test f32 multiplies separately as we emit
// OpVectorTimesScalar for this case
// Param{Type::i32, ast::BinaryOp::kMultiply, "OpIMul"},
Param{Type::f32, ast::BinaryOp::kSubtract, "OpFSub"},
Param{Type::i32, ast::BinaryOp::kAdd, "OpIAdd"},
Param{Type::i32, ast::BinaryOp::kDivide, "OpSDiv"},
Param{Type::i32, ast::BinaryOp::kModulo, "OpSMod"},
Param{Type::i32, ast::BinaryOp::kMultiply, "OpIMul"},
Param{Type::i32, ast::BinaryOp::kSubtract, "OpISub"},
Param{Type::u32, ast::BinaryOp::kAdd, "OpIAdd"},
Param{Type::u32, ast::BinaryOp::kDivide, "OpUDiv"},
Param{Type::u32, ast::BinaryOp::kModulo, "OpUMod"},
Param{Type::u32, ast::BinaryOp::kMultiply, "OpIMul"},
Param{Type::u32, ast::BinaryOp::kSubtract, "OpISub"}));
using MixedBinaryArithMultiplyTest = TestParamHelper<Param>;
TEST_P(MixedBinaryArithMultiplyTest, VectorScalar) {
auto& param = GetParam();
ast::Expression* lhs = MakeVectorExpr(this, param.type);
ast::Expression* rhs = MakeScalarExpr(this, param.type);
std::string op_type_decl = OpTypeDecl(param.type);
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
WrapInFunction(expr);
spirv::Builder& b = Build();
ASSERT_TRUE(b.Build()) << b.error();
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %3 "test_function"
OpExecutionMode %3 LocalSize 1 1 1
OpName %3 "test_function"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%6 = )" + op_type_decl + R"(
%5 = OpTypeVector %6 3
%7 = OpConstant %6 1
%8 = OpConstantComposite %5 %7 %7 %7
%3 = OpFunction %2 None %1
%4 = OpLabel
%9 = OpVectorTimesScalar %5 %8 %7
OpReturn
OpFunctionEnd
)");
Validate(b);
}
TEST_P(MixedBinaryArithMultiplyTest, ScalarVector) {
auto& param = GetParam();
ast::Expression* lhs = MakeScalarExpr(this, param.type);
ast::Expression* rhs = MakeVectorExpr(this, param.type);
std::string op_type_decl = OpTypeDecl(param.type);
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
WrapInFunction(expr);
spirv::Builder& b = Build();
ASSERT_TRUE(b.Build()) << b.error();
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %3 "test_function"
OpExecutionMode %3 LocalSize 1 1 1
OpName %3 "test_function"
%2 = OpTypeVoid
%1 = OpTypeFunction %2
%5 = )" + op_type_decl + R"(
%6 = OpConstant %5 1
%7 = OpTypeVector %5 3
%8 = OpConstantComposite %7 %6 %6 %6
%3 = OpFunction %2 None %1
%4 = OpLabel
%9 = OpVectorTimesScalar %7 %8 %6
OpReturn
OpFunctionEnd
)");
Validate(b);
}
INSTANTIATE_TEST_SUITE_P(BuilderTest,
MixedBinaryArithMultiplyTest,
testing::Values(Param{
Type::f32, ast::BinaryOp::kMultiply, "OpFMul"}));
} // namespace
} // namespace spirv
} // namespace writer