mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-08 13:14:56 +00:00
Add support for binary arithmetic expressions with mixed scalar and vector operands
Bug: tint:376 Change-Id: I2994ff7394efa903050b470a850b41628d5b775c Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/52324 Commit-Queue: Antonio Maiorano <amaiorano@google.com> Reviewed-by: Ben Clayton <bclayton@google.com>
This commit is contained in:
committed by
Tint LUCI CQ
parent
d5f4ea22f0
commit
eaed2b6ce2
@@ -2097,23 +2097,35 @@ bool Resolver::Binary(ast::BinaryExpression* expr) {
|
||||
SetType(expr, lhs_type);
|
||||
return true;
|
||||
}
|
||||
|
||||
// Binary arithmetic expressions with mixed scalar and vector operands
|
||||
if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_type)) {
|
||||
if (expr->IsModulo()) {
|
||||
if (rhs_type->is_integer_scalar()) {
|
||||
SetType(expr, lhs_type);
|
||||
return true;
|
||||
}
|
||||
} else if (rhs_type->is_numeric_scalar()) {
|
||||
SetType(expr, lhs_type);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_type)) {
|
||||
if (expr->IsModulo()) {
|
||||
if (lhs_type->is_integer_scalar()) {
|
||||
SetType(expr, rhs_type);
|
||||
return true;
|
||||
}
|
||||
} else if (lhs_type->is_numeric_scalar()) {
|
||||
SetType(expr, rhs_type);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Binary arithmetic expressions with mixed scalar, vector, and matrix
|
||||
// operands
|
||||
// Matrix arithmetic
|
||||
// TODO(amaiorano): matrix-matrix addition and subtraction
|
||||
if (expr->IsMultiply()) {
|
||||
// Multiplication of a vector and a scalar
|
||||
if (lhs_type->Is<F32>() && rhs_vec_elem_type &&
|
||||
rhs_vec_elem_type->Is<F32>()) {
|
||||
SetType(expr, rhs_type);
|
||||
return true;
|
||||
}
|
||||
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
|
||||
rhs_type->Is<F32>()) {
|
||||
SetType(expr, lhs_type);
|
||||
return true;
|
||||
}
|
||||
|
||||
auto* lhs_mat = lhs_type->As<Matrix>();
|
||||
auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
|
||||
auto* rhs_mat = rhs_type->As<Matrix>();
|
||||
|
||||
@@ -1298,16 +1298,52 @@ static constexpr Params all_valid_cases[] = {
|
||||
Params{Op::kDivide, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
|
||||
Params{Op::kModulo, ast_vec3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
|
||||
|
||||
// Binary arithmetic expressions with mixed scalar, vector, and matrix
|
||||
// operands
|
||||
Params{Op::kMultiply, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>},
|
||||
Params{Op::kMultiply, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
|
||||
// Binary arithmetic expressions with mixed scalar and vector operands
|
||||
Params{Op::kAdd, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>},
|
||||
Params{Op::kSubtract, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>},
|
||||
Params{Op::kMultiply, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>},
|
||||
Params{Op::kDivide, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>},
|
||||
Params{Op::kModulo, ast_vec3<i32>, ast_i32, sem_vec3<sem_i32>},
|
||||
|
||||
Params{Op::kAdd, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>},
|
||||
Params{Op::kSubtract, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>},
|
||||
Params{Op::kMultiply, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>},
|
||||
Params{Op::kDivide, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>},
|
||||
Params{Op::kModulo, ast_i32, ast_vec3<i32>, sem_vec3<sem_i32>},
|
||||
|
||||
Params{Op::kAdd, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>},
|
||||
Params{Op::kSubtract, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>},
|
||||
Params{Op::kMultiply, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>},
|
||||
Params{Op::kDivide, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>},
|
||||
Params{Op::kModulo, ast_vec3<u32>, ast_u32, sem_vec3<sem_u32>},
|
||||
|
||||
Params{Op::kAdd, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>},
|
||||
Params{Op::kSubtract, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>},
|
||||
Params{Op::kMultiply, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>},
|
||||
Params{Op::kDivide, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>},
|
||||
Params{Op::kModulo, ast_u32, ast_vec3<u32>, sem_vec3<sem_u32>},
|
||||
|
||||
Params{Op::kAdd, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>},
|
||||
Params{Op::kSubtract, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>},
|
||||
Params{Op::kMultiply, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>},
|
||||
Params{Op::kDivide, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>},
|
||||
// NOTE: no kModulo for ast_vec3<f32>, ast_f32
|
||||
// Params{Op::kModulo, ast_vec3<f32>, ast_f32, sem_vec3<sem_f32>},
|
||||
|
||||
Params{Op::kAdd, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
|
||||
Params{Op::kSubtract, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
|
||||
Params{Op::kMultiply, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
|
||||
Params{Op::kDivide, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
|
||||
// NOTE: no kModulo for ast_f32, ast_vec3<f32>
|
||||
// Params{Op::kModulo, ast_f32, ast_vec3<f32>, sem_vec3<sem_f32>},
|
||||
|
||||
// Matrix arithmetic
|
||||
Params{Op::kMultiply, ast_mat3x3<f32>, ast_f32, sem_mat3x3<sem_f32>},
|
||||
Params{Op::kMultiply, ast_f32, ast_mat3x3<f32>, sem_mat3x3<sem_f32>},
|
||||
|
||||
Params{Op::kMultiply, ast_vec3<f32>, ast_mat3x3<f32>, sem_vec3<sem_f32>},
|
||||
Params{Op::kMultiply, ast_mat3x3<f32>, ast_vec3<f32>, sem_vec3<sem_f32>},
|
||||
// TODO(amaiorano): add mat+mat and mat-mat
|
||||
Params{Op::kMultiply, ast_mat3x3<f32>, ast_mat3x3<f32>,
|
||||
sem_mat3x3<sem_f32>},
|
||||
|
||||
|
||||
@@ -1719,6 +1719,31 @@ uint32_t Builder::GenerateShortCircuitBinaryExpression(
|
||||
return result_id;
|
||||
}
|
||||
|
||||
uint32_t Builder::GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type) {
|
||||
// Create a new vector to splat scalar into
|
||||
auto splat_vector = result_op();
|
||||
auto* splat_vector_type =
|
||||
builder_.create<sem::Pointer>(vec_type, ast::StorageClass::kFunction);
|
||||
push_function_var(
|
||||
{Operand::Int(GenerateTypeIfNeeded(splat_vector_type)), splat_vector,
|
||||
Operand::Int(ConvertStorageClass(ast::StorageClass::kFunction)),
|
||||
Operand::Int(GenerateConstantNullIfNeeded(vec_type))});
|
||||
|
||||
// Splat scalar into vector
|
||||
auto splat_result = result_op();
|
||||
OperandList ops;
|
||||
ops.push_back(Operand::Int(GenerateTypeIfNeeded(vec_type)));
|
||||
ops.push_back(splat_result);
|
||||
for (size_t i = 0; i < vec_type->As<sem::Vector>()->size(); ++i) {
|
||||
ops.push_back(Operand::Int(scalar_id));
|
||||
}
|
||||
if (!push_function_inst(spv::Op::OpCompositeConstruct, ops)) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
return splat_result.to_i();
|
||||
}
|
||||
|
||||
uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
|
||||
// There is special logic for short circuiting operators.
|
||||
if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
|
||||
@@ -1749,6 +1774,33 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
|
||||
// should have been rejected by validation.
|
||||
auto* lhs_type = TypeOf(expr->lhs())->UnwrapRef();
|
||||
auto* rhs_type = TypeOf(expr->rhs())->UnwrapRef();
|
||||
|
||||
// For vector-scalar arithmetic operations, splat scalar into a vector. We
|
||||
// skip this for multiply as we can use OpVectorTimesScalar.
|
||||
const bool is_float_scalar_vector_multiply =
|
||||
expr->IsMultiply() &&
|
||||
((lhs_type->is_float_scalar() && rhs_type->is_float_vector()) ||
|
||||
(lhs_type->is_float_vector() && rhs_type->is_float_scalar()));
|
||||
|
||||
if (expr->IsArithmetic() && !is_float_scalar_vector_multiply) {
|
||||
if (lhs_type->Is<sem::Vector>() && rhs_type->is_numeric_scalar()) {
|
||||
uint32_t splat_vector_id = GenerateSplat(rhs_id, lhs_type);
|
||||
if (splat_vector_id == 0) {
|
||||
return 0;
|
||||
}
|
||||
rhs_id = splat_vector_id;
|
||||
rhs_type = lhs_type;
|
||||
|
||||
} else if (lhs_type->is_numeric_scalar() && rhs_type->Is<sem::Vector>()) {
|
||||
uint32_t splat_vector_id = GenerateSplat(lhs_id, rhs_type);
|
||||
if (splat_vector_id == 0) {
|
||||
return 0;
|
||||
}
|
||||
lhs_id = splat_vector_id;
|
||||
lhs_type = rhs_type;
|
||||
}
|
||||
}
|
||||
|
||||
bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector();
|
||||
bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector();
|
||||
|
||||
|
||||
@@ -473,6 +473,13 @@ class Builder {
|
||||
/// @returns true if the vector was successfully generated
|
||||
bool GenerateVectorType(const sem::Vector* vec, const Operand& result);
|
||||
|
||||
/// Generates instructions to splat `scalar_id` into a vector of type
|
||||
/// `vec_type`
|
||||
/// @param scalar_id scalar to splat
|
||||
/// @param vec_type type of vector
|
||||
/// @returns id of the new vector
|
||||
uint32_t GenerateSplat(uint32_t scalar_id, const sem::Type* vec_type);
|
||||
|
||||
/// Converts AST image format to SPIR-V and pushes an appropriate capability.
|
||||
/// @param format AST image format type
|
||||
/// @returns SPIR-V image format type
|
||||
|
||||
@@ -863,6 +863,222 @@ OpBranch %9
|
||||
)");
|
||||
}
|
||||
|
||||
enum class Type { f32, i32, u32 };
|
||||
ast::Expression* MakeVectorExpr(ProgramBuilder* builder, Type type) {
|
||||
switch (type) {
|
||||
case Type::f32:
|
||||
return builder->vec3<ProgramBuilder::f32>(1.f, 1.f, 1.f);
|
||||
case Type::i32:
|
||||
return builder->vec3<ProgramBuilder::i32>(1, 1, 1);
|
||||
case Type::u32:
|
||||
return builder->vec3<ProgramBuilder::u32>(1u, 1u, 1u);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
ast::Expression* MakeScalarExpr(ProgramBuilder* builder, Type type) {
|
||||
switch (type) {
|
||||
case Type::f32:
|
||||
return builder->Expr(1.f);
|
||||
case Type::i32:
|
||||
return builder->Expr(1);
|
||||
case Type::u32:
|
||||
return builder->Expr(1u);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
std::string OpTypeDecl(Type type) {
|
||||
switch (type) {
|
||||
case Type::f32:
|
||||
return "OpTypeFloat 32";
|
||||
case Type::i32:
|
||||
return "OpTypeInt 32 1";
|
||||
case Type::u32:
|
||||
return "OpTypeInt 32 0";
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
||||
struct Param {
|
||||
Type type;
|
||||
ast::BinaryOp op;
|
||||
std::string name;
|
||||
};
|
||||
|
||||
using MixedBinaryArithTest = TestParamHelper<Param>;
|
||||
TEST_P(MixedBinaryArithTest, VectorScalar) {
|
||||
auto& param = GetParam();
|
||||
|
||||
ast::Expression* lhs = MakeVectorExpr(this, param.type);
|
||||
ast::Expression* rhs = MakeScalarExpr(this, param.type);
|
||||
std::string op_type_decl = OpTypeDecl(param.type);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
|
||||
|
||||
WrapInFunction(expr);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
ASSERT_TRUE(b.Build()) << b.error();
|
||||
|
||||
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %3 "test_function"
|
||||
OpExecutionMode %3 LocalSize 1 1 1
|
||||
OpName %3 "test_function"
|
||||
%2 = OpTypeVoid
|
||||
%1 = OpTypeFunction %2
|
||||
%6 = )" + op_type_decl + R"(
|
||||
%5 = OpTypeVector %6 3
|
||||
%7 = OpConstant %6 1
|
||||
%8 = OpConstantComposite %5 %7 %7 %7
|
||||
%11 = OpTypePointer Function %5
|
||||
%12 = OpConstantNull %5
|
||||
%3 = OpFunction %2 None %1
|
||||
%4 = OpLabel
|
||||
%10 = OpVariable %11 Function %12
|
||||
%13 = OpCompositeConstruct %5 %7 %7 %7
|
||||
%9 = )" + param.name + R"( %5 %8 %13
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)");
|
||||
|
||||
Validate(b);
|
||||
}
|
||||
TEST_P(MixedBinaryArithTest, ScalarVector) {
|
||||
auto& param = GetParam();
|
||||
|
||||
ast::Expression* lhs = MakeScalarExpr(this, param.type);
|
||||
ast::Expression* rhs = MakeVectorExpr(this, param.type);
|
||||
std::string op_type_decl = OpTypeDecl(param.type);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
|
||||
|
||||
WrapInFunction(expr);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
ASSERT_TRUE(b.Build()) << b.error();
|
||||
|
||||
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %3 "test_function"
|
||||
OpExecutionMode %3 LocalSize 1 1 1
|
||||
OpName %3 "test_function"
|
||||
%2 = OpTypeVoid
|
||||
%1 = OpTypeFunction %2
|
||||
%5 = )" + op_type_decl + R"(
|
||||
%6 = OpConstant %5 1
|
||||
%7 = OpTypeVector %5 3
|
||||
%8 = OpConstantComposite %7 %6 %6 %6
|
||||
%11 = OpTypePointer Function %7
|
||||
%12 = OpConstantNull %7
|
||||
%3 = OpFunction %2 None %1
|
||||
%4 = OpLabel
|
||||
%10 = OpVariable %11 Function %12
|
||||
%13 = OpCompositeConstruct %7 %6 %6 %6
|
||||
%9 = )" + param.name + R"( %7 %13 %8
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)");
|
||||
|
||||
Validate(b);
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
BuilderTest,
|
||||
MixedBinaryArithTest,
|
||||
testing::Values(Param{Type::f32, ast::BinaryOp::kAdd, "OpFAdd"},
|
||||
Param{Type::f32, ast::BinaryOp::kDivide, "OpFDiv"},
|
||||
// NOTE: Modulo not allowed on mixed float scalar-vector
|
||||
// Param{Type::f32, ast::BinaryOp::kModulo, "OpFMod"},
|
||||
// NOTE: We test f32 multiplies separately as we emit
|
||||
// OpVectorTimesScalar for this case
|
||||
// Param{Type::i32, ast::BinaryOp::kMultiply, "OpIMul"},
|
||||
Param{Type::f32, ast::BinaryOp::kSubtract, "OpFSub"},
|
||||
|
||||
Param{Type::i32, ast::BinaryOp::kAdd, "OpIAdd"},
|
||||
Param{Type::i32, ast::BinaryOp::kDivide, "OpSDiv"},
|
||||
Param{Type::i32, ast::BinaryOp::kModulo, "OpSMod"},
|
||||
Param{Type::i32, ast::BinaryOp::kMultiply, "OpIMul"},
|
||||
Param{Type::i32, ast::BinaryOp::kSubtract, "OpISub"},
|
||||
|
||||
Param{Type::u32, ast::BinaryOp::kAdd, "OpIAdd"},
|
||||
Param{Type::u32, ast::BinaryOp::kDivide, "OpUDiv"},
|
||||
Param{Type::u32, ast::BinaryOp::kModulo, "OpUMod"},
|
||||
Param{Type::u32, ast::BinaryOp::kMultiply, "OpIMul"},
|
||||
Param{Type::u32, ast::BinaryOp::kSubtract, "OpISub"}));
|
||||
|
||||
using MixedBinaryArithMultiplyTest = TestParamHelper<Param>;
|
||||
TEST_P(MixedBinaryArithMultiplyTest, VectorScalar) {
|
||||
auto& param = GetParam();
|
||||
|
||||
ast::Expression* lhs = MakeVectorExpr(this, param.type);
|
||||
ast::Expression* rhs = MakeScalarExpr(this, param.type);
|
||||
std::string op_type_decl = OpTypeDecl(param.type);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
|
||||
|
||||
WrapInFunction(expr);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
ASSERT_TRUE(b.Build()) << b.error();
|
||||
|
||||
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %3 "test_function"
|
||||
OpExecutionMode %3 LocalSize 1 1 1
|
||||
OpName %3 "test_function"
|
||||
%2 = OpTypeVoid
|
||||
%1 = OpTypeFunction %2
|
||||
%6 = )" + op_type_decl + R"(
|
||||
%5 = OpTypeVector %6 3
|
||||
%7 = OpConstant %6 1
|
||||
%8 = OpConstantComposite %5 %7 %7 %7
|
||||
%3 = OpFunction %2 None %1
|
||||
%4 = OpLabel
|
||||
%9 = OpVectorTimesScalar %5 %8 %7
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)");
|
||||
|
||||
Validate(b);
|
||||
}
|
||||
TEST_P(MixedBinaryArithMultiplyTest, ScalarVector) {
|
||||
auto& param = GetParam();
|
||||
|
||||
ast::Expression* lhs = MakeScalarExpr(this, param.type);
|
||||
ast::Expression* rhs = MakeVectorExpr(this, param.type);
|
||||
std::string op_type_decl = OpTypeDecl(param.type);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(param.op, lhs, rhs);
|
||||
|
||||
WrapInFunction(expr);
|
||||
|
||||
spirv::Builder& b = Build();
|
||||
ASSERT_TRUE(b.Build()) << b.error();
|
||||
|
||||
EXPECT_EQ(DumpBuilder(b), R"(OpCapability Shader
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %3 "test_function"
|
||||
OpExecutionMode %3 LocalSize 1 1 1
|
||||
OpName %3 "test_function"
|
||||
%2 = OpTypeVoid
|
||||
%1 = OpTypeFunction %2
|
||||
%5 = )" + op_type_decl + R"(
|
||||
%6 = OpConstant %5 1
|
||||
%7 = OpTypeVector %5 3
|
||||
%8 = OpConstantComposite %7 %6 %6 %6
|
||||
%3 = OpFunction %2 None %1
|
||||
%4 = OpLabel
|
||||
%9 = OpVectorTimesScalar %7 %8 %6
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
)");
|
||||
|
||||
Validate(b);
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(BuilderTest,
|
||||
MixedBinaryArithMultiplyTest,
|
||||
testing::Values(Param{
|
||||
Type::f32, ast::BinaryOp::kMultiply, "OpFMul"}));
|
||||
|
||||
} // namespace
|
||||
} // namespace spirv
|
||||
} // namespace writer
|
||||
|
||||
Reference in New Issue
Block a user