mirror of
https://github.com/encounter/dawn-cmake.git
synced 2025-12-20 18:29:23 +00:00
Implement mixed vector-scalar float % operator
W3C consensus on https://github.com/gpuweb/gpuweb/issues/2450 Spec change: https://github.com/gpuweb/gpuweb/pull/2495 Bug: tint:1370 Change-Id: I85bb9c802b0355bc53aa8dbacca8427fb7be1ff6 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/84880 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: James Price <jrprice@google.com> Reviewed-by: Ben Clayton <bclayton@google.com> Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
committed by
Tint LUCI CQ
parent
3b671cb377
commit
9e5484264a
@@ -59,16 +59,10 @@ bool CanReplaceAddSubtractWith(const sem::Type* lhs_type,
|
||||
// type-compatible if the matrices are square.
|
||||
return !lhs_type->is_float_matrix() || lhs_type->is_square_float_matrix();
|
||||
case ast::BinaryOp::kDivide:
|
||||
case ast::BinaryOp::kModulo:
|
||||
// '/' is not defined for matrices.
|
||||
return lhs_type->is_numeric_scalar_or_vector() &&
|
||||
rhs_type->is_numeric_scalar_or_vector();
|
||||
case ast::BinaryOp::kModulo:
|
||||
// TODO(https://crbug.com/tint/1370): once fixed, the rules should be the
|
||||
// same as for divide.
|
||||
if (lhs_type->is_float_vector() || rhs_type->is_float_vector()) {
|
||||
return lhs_type == rhs_type;
|
||||
}
|
||||
return !lhs_type->is_float_matrix() && !rhs_type->is_float_matrix();
|
||||
case ast::BinaryOp::kShiftLeft:
|
||||
case ast::BinaryOp::kShiftRight:
|
||||
return IsSuitableForShift(lhs_type, rhs_type);
|
||||
@@ -102,16 +96,10 @@ bool CanReplaceMultiplyWith(const sem::Type* lhs_type,
|
||||
// These operators require homogeneous integer types.
|
||||
return lhs_type == rhs_type && lhs_type->is_integer_scalar_or_vector();
|
||||
case ast::BinaryOp::kDivide:
|
||||
case ast::BinaryOp::kModulo:
|
||||
// '/' is not defined for matrices.
|
||||
return lhs_type->is_numeric_scalar_or_vector() &&
|
||||
rhs_type->is_numeric_scalar_or_vector();
|
||||
case ast::BinaryOp::kModulo:
|
||||
// TODO(https://crbug.com/tint/1370): once fixed, this should be the same
|
||||
// as for divide
|
||||
if (lhs_type->is_float_vector() || rhs_type->is_float_vector()) {
|
||||
return lhs_type == rhs_type;
|
||||
}
|
||||
return !lhs_type->is_float_matrix() && !rhs_type->is_float_matrix();
|
||||
case ast::BinaryOp::kShiftLeft:
|
||||
case ast::BinaryOp::kShiftRight:
|
||||
return IsSuitableForShift(lhs_type, rhs_type);
|
||||
@@ -120,9 +108,9 @@ bool CanReplaceMultiplyWith(const sem::Type* lhs_type,
|
||||
}
|
||||
}
|
||||
|
||||
bool CanReplaceDivideWith(const sem::Type* lhs_type,
|
||||
const sem::Type* rhs_type,
|
||||
ast::BinaryOp new_operator) {
|
||||
bool CanReplaceDivideOrModuloWith(const sem::Type* lhs_type,
|
||||
const sem::Type* rhs_type,
|
||||
ast::BinaryOp new_operator) {
|
||||
// The program is assumed to be well-typed, so this method determines when
|
||||
// 'new_operator' can be used as a type-preserving replacement in a '/'
|
||||
// expression.
|
||||
@@ -131,12 +119,9 @@ bool CanReplaceDivideWith(const sem::Type* lhs_type,
|
||||
case ast::BinaryOp::kSubtract:
|
||||
case ast::BinaryOp::kMultiply:
|
||||
case ast::BinaryOp::kDivide:
|
||||
case ast::BinaryOp::kModulo:
|
||||
// These operators work in all contexts where '/' works.
|
||||
return true;
|
||||
case ast::BinaryOp::kModulo:
|
||||
// TODO(https://crbug.com/tint/1370): this special case should not be
|
||||
// required; modulo and divide should work in the same contexts.
|
||||
return lhs_type->is_integer_scalar_or_vector() || lhs_type == rhs_type;
|
||||
case ast::BinaryOp::kAnd:
|
||||
case ast::BinaryOp::kOr:
|
||||
case ast::BinaryOp::kXor:
|
||||
@@ -150,30 +135,6 @@ bool CanReplaceDivideWith(const sem::Type* lhs_type,
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(https://crbug.com/tint/1370): once fixed, this method will be removed
|
||||
// and the same method will be used to check Divide and Modulo.
|
||||
bool CanReplaceModuloWith(const sem::Type* lhs_type,
|
||||
const sem::Type* rhs_type,
|
||||
ast::BinaryOp new_operator) {
|
||||
switch (new_operator) {
|
||||
case ast::BinaryOp::kAdd:
|
||||
case ast::BinaryOp::kSubtract:
|
||||
case ast::BinaryOp::kMultiply:
|
||||
case ast::BinaryOp::kDivide:
|
||||
case ast::BinaryOp::kModulo:
|
||||
return true;
|
||||
case ast::BinaryOp::kAnd:
|
||||
case ast::BinaryOp::kOr:
|
||||
case ast::BinaryOp::kXor:
|
||||
return lhs_type == rhs_type && lhs_type->is_integer_scalar_or_vector();
|
||||
case ast::BinaryOp::kShiftLeft:
|
||||
case ast::BinaryOp::kShiftRight:
|
||||
return IsSuitableForShift(lhs_type, rhs_type);
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool CanReplaceLogicalAndLogicalOrWith(ast::BinaryOp new_operator) {
|
||||
switch (new_operator) {
|
||||
case ast::BinaryOp::kLogicalAnd:
|
||||
@@ -362,9 +323,9 @@ bool MutationChangeBinaryOperator::CanReplaceBinaryOperator(
|
||||
return CanReplaceMultiplyWith(lhs_basic_type, rhs_basic_type,
|
||||
new_operator);
|
||||
case ast::BinaryOp::kDivide:
|
||||
return CanReplaceDivideWith(lhs_basic_type, rhs_basic_type, new_operator);
|
||||
case ast::BinaryOp::kModulo:
|
||||
return CanReplaceModuloWith(lhs_basic_type, rhs_basic_type, new_operator);
|
||||
return CanReplaceDivideOrModuloWith(lhs_basic_type, rhs_basic_type,
|
||||
new_operator);
|
||||
case ast::BinaryOp::kAnd:
|
||||
case ast::BinaryOp::kOr:
|
||||
return CanReplaceAndOrWith(lhs_basic_type, rhs_basic_type, new_operator);
|
||||
|
||||
@@ -281,18 +281,12 @@ TEST(ChangeBinaryOperatorTest, AddSubtract) {
|
||||
}
|
||||
for (std::string vector_type : {"vec2<f32>", "vec3<f32>", "vec4<f32>"}) {
|
||||
std::string scalar_type = "f32";
|
||||
CheckMutations(
|
||||
vector_type, scalar_type, vector_type, op,
|
||||
{
|
||||
other_op, ast::BinaryOp::kMultiply, ast::BinaryOp::kDivide
|
||||
// TODO(https://crbug.com/tint/1370): once fixed, add kModulo
|
||||
});
|
||||
CheckMutations(
|
||||
scalar_type, vector_type, vector_type, op,
|
||||
{
|
||||
other_op, ast::BinaryOp::kMultiply, ast::BinaryOp::kDivide
|
||||
// TODO(https://crbug.com/tint/1370): once fixed, add kModulo
|
||||
});
|
||||
CheckMutations(vector_type, scalar_type, vector_type, op,
|
||||
{other_op, ast::BinaryOp::kMultiply,
|
||||
ast::BinaryOp::kDivide, ast::BinaryOp::kModulo});
|
||||
CheckMutations(scalar_type, vector_type, vector_type, op,
|
||||
{other_op, ast::BinaryOp::kMultiply,
|
||||
ast::BinaryOp::kDivide, ast::BinaryOp::kModulo});
|
||||
}
|
||||
for (std::string square_matrix_type :
|
||||
{"mat2x2<f32>", "mat3x3<f32>", "mat4x4<f32>"}) {
|
||||
@@ -353,20 +347,14 @@ TEST(ChangeBinaryOperatorTest, Mul) {
|
||||
}
|
||||
for (std::string vector_type : {"vec2<f32>", "vec3<f32>", "vec4<f32>"}) {
|
||||
std::string scalar_type = "f32";
|
||||
CheckMutations(
|
||||
vector_type, scalar_type, vector_type, ast::BinaryOp::kMultiply,
|
||||
{
|
||||
ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
|
||||
ast::BinaryOp::kDivide
|
||||
// TODO(https://crbug.com/tint/1370): once fixed, add kModulo
|
||||
});
|
||||
CheckMutations(
|
||||
scalar_type, vector_type, vector_type, ast::BinaryOp::kMultiply,
|
||||
{
|
||||
ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
|
||||
ast::BinaryOp::kDivide
|
||||
// TODO(https://crbug.com/tint/1370): once fixed, add kModulo
|
||||
});
|
||||
CheckMutations(vector_type, scalar_type, vector_type,
|
||||
ast::BinaryOp::kMultiply,
|
||||
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
|
||||
ast::BinaryOp::kDivide, ast::BinaryOp::kModulo});
|
||||
CheckMutations(scalar_type, vector_type, vector_type,
|
||||
ast::BinaryOp::kMultiply,
|
||||
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
|
||||
ast::BinaryOp::kDivide, ast::BinaryOp::kModulo});
|
||||
}
|
||||
for (std::string square_matrix_type :
|
||||
{"mat2x2<f32>", "mat3x3<f32>", "mat4x4<f32>"}) {
|
||||
@@ -472,7 +460,7 @@ TEST(ChangeBinaryOperatorTest, Mul) {
|
||||
ast::BinaryOp::kMultiply, {});
|
||||
}
|
||||
|
||||
TEST(ChangeBinaryOperatorTest, Divide) {
|
||||
TEST(ChangeBinaryOperatorTest, DivideAndModulo) {
|
||||
for (std::string type : {"i32", "vec2<i32>", "vec3<i32>", "vec4<i32>"}) {
|
||||
CheckMutations(
|
||||
type, type, type, ast::BinaryOp::kDivide,
|
||||
@@ -517,26 +505,15 @@ TEST(ChangeBinaryOperatorTest, Divide) {
|
||||
}
|
||||
for (std::string vector_type : {"vec2<f32>", "vec3<f32>", "vec4<f32>"}) {
|
||||
std::string scalar_type = "f32";
|
||||
CheckMutations(
|
||||
vector_type, scalar_type, vector_type, ast::BinaryOp::kDivide,
|
||||
{
|
||||
ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
|
||||
ast::BinaryOp::kMultiply
|
||||
// TODO(https://crbug.com/tint/1370): once fixed, add kModulo
|
||||
});
|
||||
CheckMutations(
|
||||
scalar_type, vector_type, vector_type, ast::BinaryOp::kDivide,
|
||||
{
|
||||
ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
|
||||
ast::BinaryOp::kMultiply
|
||||
// TODO(https://crbug.com/tint/1370): once fixed, add kModulo
|
||||
});
|
||||
CheckMutations(vector_type, scalar_type, vector_type,
|
||||
ast::BinaryOp::kDivide,
|
||||
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
|
||||
ast::BinaryOp::kMultiply, ast::BinaryOp::kModulo});
|
||||
CheckMutations(scalar_type, vector_type, vector_type,
|
||||
ast::BinaryOp::kDivide,
|
||||
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
|
||||
ast::BinaryOp::kMultiply, ast::BinaryOp::kModulo});
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(https://crbug.com/tint/1370): once fixed, combine this with the Divide
|
||||
// test
|
||||
TEST(ChangeBinaryOperatorTest, Modulo) {
|
||||
for (std::string type : {"i32", "vec2<i32>", "vec3<i32>", "vec4<i32>"}) {
|
||||
CheckMutations(
|
||||
type, type, type, ast::BinaryOp::kModulo,
|
||||
@@ -579,8 +556,6 @@ TEST(ChangeBinaryOperatorTest, Modulo) {
|
||||
{ast::BinaryOp::kAdd, ast::BinaryOp::kSubtract,
|
||||
ast::BinaryOp::kMultiply, ast::BinaryOp::kDivide});
|
||||
}
|
||||
// TODO(https://crbug.com/tint/1370): mixed float scalars/vectors will be
|
||||
// added when this test is combined with the Divide test
|
||||
}
|
||||
|
||||
TEST(ChangeBinaryOperatorTest, AndOrXor) {
|
||||
|
||||
@@ -1901,23 +1901,13 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
|
||||
}
|
||||
|
||||
// Binary arithmetic expressions with mixed scalar and vector operands
|
||||
if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_ty)) {
|
||||
if (expr->IsModulo()) {
|
||||
if (rhs_ty->is_integer_scalar()) {
|
||||
return build(lhs_ty);
|
||||
}
|
||||
} else if (rhs_ty->is_numeric_scalar()) {
|
||||
return build(lhs_ty);
|
||||
}
|
||||
if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_ty) &&
|
||||
rhs_ty->is_numeric_scalar()) {
|
||||
return build(lhs_ty);
|
||||
}
|
||||
if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_ty)) {
|
||||
if (expr->IsModulo()) {
|
||||
if (lhs_ty->is_integer_scalar()) {
|
||||
return build(rhs_ty);
|
||||
}
|
||||
} else if (lhs_ty->is_numeric_scalar()) {
|
||||
return build(rhs_ty);
|
||||
}
|
||||
if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_ty) &&
|
||||
lhs_ty->is_numeric_scalar()) {
|
||||
return build(rhs_ty);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1367,15 +1367,13 @@ static constexpr Params all_valid_cases[] = {
|
||||
ParamsFor<vec3<f32>, f32, vec3<f32>>(Op::kSubtract),
|
||||
ParamsFor<vec3<f32>, f32, vec3<f32>>(Op::kMultiply),
|
||||
ParamsFor<vec3<f32>, f32, vec3<f32>>(Op::kDivide),
|
||||
// NOTE: no kModulo for vec3<f32>, f32
|
||||
// ParamsFor<vec3<f32>, f32, vec3<f32>>(Op::kModulo),
|
||||
ParamsFor<vec3<f32>, f32, vec3<f32>>(Op::kModulo),
|
||||
|
||||
ParamsFor<f32, vec3<f32>, vec3<f32>>(Op::kAdd),
|
||||
ParamsFor<f32, vec3<f32>, vec3<f32>>(Op::kSubtract),
|
||||
ParamsFor<f32, vec3<f32>, vec3<f32>>(Op::kMultiply),
|
||||
ParamsFor<f32, vec3<f32>, vec3<f32>>(Op::kDivide),
|
||||
// NOTE: no kModulo for f32, vec3<f32>
|
||||
// ParamsFor<f32, vec3<f32>, vec3<f32>>(Op::kModulo),
|
||||
ParamsFor<f32, vec3<f32>, vec3<f32>>(Op::kModulo),
|
||||
|
||||
// Matrix arithmetic
|
||||
ParamsFor<mat2x3<f32>, f32, mat2x3<f32>>(Op::kMultiply),
|
||||
|
||||
Reference in New Issue
Block a user