resolver: Refactor binary operator type resolution

This same logic will be used for resolving and validating compound
assignment statements, so pull the core out into a separate function
that decouples it from ast::BinaryExpression.

Bug: tint:1325
Change-Id: Ibdb5a7fc8153dac0dd7f9ae3d5164e23585068cd
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/74360
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
James Price 2022-03-31 22:30:10 +00:00
parent b349710476
commit daea034bd1
3 changed files with 96 additions and 59 deletions

View File

@ -122,7 +122,9 @@ class BinaryExpression final : public Castable<BinaryExpression, Expression> {
const Expression* const rhs;
};
inline bool BinaryExpression::IsArithmetic() const {
/// @param op the operator
/// @returns true if the op is an arithmetic operation
inline bool IsArithmetic(BinaryOp op) {
switch (op) {
case ast::BinaryOp::kAdd:
case ast::BinaryOp::kSubtract:
@ -135,7 +137,9 @@ inline bool BinaryExpression::IsArithmetic() const {
}
}
inline bool BinaryExpression::IsComparison() const {
/// @param op the operator
/// @returns true if the op is a comparison operation
inline bool IsComparison(BinaryOp op) {
switch (op) {
case ast::BinaryOp::kEqual:
case ast::BinaryOp::kNotEqual:
@ -149,7 +153,9 @@ inline bool BinaryExpression::IsComparison() const {
}
}
inline bool BinaryExpression::IsBitwise() const {
/// @param op the operator
/// @returns true if the op is a bitwise operation
inline bool IsBitwise(BinaryOp op) {
switch (op) {
case ast::BinaryOp::kAnd:
case ast::BinaryOp::kOr:
@ -160,7 +166,9 @@ inline bool BinaryExpression::IsBitwise() const {
}
}
inline bool BinaryExpression::IsBitshift() const {
/// @param op the operator
/// @returns true if the op is a bit shift operation
inline bool IsBitshift(BinaryOp op) {
switch (op) {
case ast::BinaryOp::kShiftLeft:
case ast::BinaryOp::kShiftRight:
@ -180,6 +188,22 @@ inline bool BinaryExpression::IsLogical() const {
}
}
inline bool BinaryExpression::IsArithmetic() const {
return ast::IsArithmetic(op);
}
inline bool BinaryExpression::IsComparison() const {
return ast::IsComparison(op);
}
inline bool BinaryExpression::IsBitwise() const {
return ast::IsBitwise(op);
}
inline bool BinaryExpression::IsBitshift() const {
return ast::IsBitshift(op);
}
/// @returns the human readable name of the given BinaryOp
/// @param op the BinaryOp
constexpr const char* FriendlyName(BinaryOp op) {

View File

@ -1838,6 +1838,33 @@ sem::Expression* Resolver::MemberAccessor(
}
sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
auto* lhs = Sem(expr->lhs);
auto* rhs = Sem(expr->rhs);
auto* lhs_ty = lhs->Type()->UnwrapRef();
auto* rhs_ty = rhs->Type()->UnwrapRef();
auto* ty = BinaryOpType(lhs_ty, rhs_ty, expr->op);
if (!ty) {
AddError(
"Binary expression operand types are invalid for this operation: " +
TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " +
TypeNameOf(rhs_ty),
expr->source);
return nullptr;
}
auto val = EvaluateConstantValue(expr, ty);
bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects();
auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_,
val, has_side_effects);
sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
return sem;
}
const sem::Type* Resolver::BinaryOpType(const sem::Type* lhs_ty,
const sem::Type* rhs_ty,
ast::BinaryOp op) {
using Bool = sem::Bool;
using F32 = sem::F32;
using I32 = sem::I32;
@ -1845,12 +1872,6 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
using Matrix = sem::Matrix;
using Vector = sem::Vector;
auto* lhs = Sem(expr->lhs);
auto* rhs = Sem(expr->rhs);
auto* lhs_ty = lhs->Type()->UnwrapRef();
auto* rhs_ty = rhs->Type()->UnwrapRef();
auto* lhs_vec = lhs_ty->As<Vector>();
auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
auto* rhs_vec = rhs_ty->As<Vector>();
@ -1863,51 +1884,42 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
const bool matching_types = matching_vec_elem_types || (lhs_ty == rhs_ty);
auto build = [&](const sem::Type* ty) {
auto val = EvaluateConstantValue(expr, ty);
bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects();
auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_,
val, has_side_effects);
sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
return sem;
};
// Binary logical expressions
if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
if (op == ast::BinaryOp::kLogicalAnd || op == ast::BinaryOp::kLogicalOr) {
if (matching_types && lhs_ty->Is<Bool>()) {
return build(lhs_ty);
return lhs_ty;
}
}
if (expr->IsOr() || expr->IsAnd()) {
if (op == ast::BinaryOp::kOr || op == ast::BinaryOp::kAnd) {
if (matching_types && lhs_ty->Is<Bool>()) {
return build(lhs_ty);
return lhs_ty;
}
if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is<Bool>()) {
return build(lhs_ty);
return lhs_ty;
}
}
// Arithmetic expressions
if (expr->IsArithmetic()) {
if (ast::IsArithmetic(op)) {
// Binary arithmetic expressions over scalars
if (matching_types && lhs_ty->is_numeric_scalar()) {
return build(lhs_ty);
return lhs_ty;
}
// Binary arithmetic expressions over vectors
if (matching_types && lhs_vec_elem_type &&
lhs_vec_elem_type->is_numeric_scalar()) {
return build(lhs_ty);
return lhs_ty;
}
// Binary arithmetic expressions with mixed scalar and vector operands
if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_ty) &&
rhs_ty->is_numeric_scalar()) {
return build(lhs_ty);
return lhs_ty;
}
if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_ty) &&
lhs_ty->is_numeric_scalar()) {
return build(rhs_ty);
return rhs_ty;
}
}
@ -1917,106 +1929,101 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
auto* rhs_mat = rhs_ty->As<Matrix>();
auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
// Addition and subtraction of float matrices
if ((expr->IsAdd() || expr->IsSubtract()) && lhs_mat_elem_type &&
lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type &&
if ((op == ast::BinaryOp::kAdd || op == ast::BinaryOp::kSubtract) &&
lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type &&
rhs_mat_elem_type->Is<F32>() &&
(lhs_mat->columns() == rhs_mat->columns()) &&
(lhs_mat->rows() == rhs_mat->rows())) {
return build(rhs_ty);
return rhs_ty;
}
if (expr->IsMultiply()) {
if (op == ast::BinaryOp::kMultiply) {
// Multiplication of a matrix and a scalar
if (lhs_ty->Is<F32>() && rhs_mat_elem_type &&
rhs_mat_elem_type->Is<F32>()) {
return build(rhs_ty);
return rhs_ty;
}
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_ty->Is<F32>()) {
return build(lhs_ty);
return lhs_ty;
}
// Vector times matrix
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
(lhs_vec->Width() == rhs_mat->rows())) {
return build(
builder_->create<sem::Vector>(lhs_vec->type(), rhs_mat->columns()));
return builder_->create<sem::Vector>(lhs_vec->type(), rhs_mat->columns());
}
// Matrix times vector
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>() &&
(lhs_mat->columns() == rhs_vec->Width())) {
return build(
builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows()));
return builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows());
}
// Matrix times matrix
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
(lhs_mat->columns() == rhs_mat->rows())) {
return build(builder_->create<sem::Matrix>(
return builder_->create<sem::Matrix>(
builder_->create<sem::Vector>(lhs_mat_elem_type, lhs_mat->rows()),
rhs_mat->columns()));
rhs_mat->columns());
}
}
// Comparison expressions
if (expr->IsComparison()) {
if (ast::IsComparison(op)) {
if (matching_types) {
// Special case for bools: only == and !=
if (lhs_ty->Is<Bool>() && (expr->IsEqual() || expr->IsNotEqual())) {
return build(builder_->create<sem::Bool>());
if (lhs_ty->Is<Bool>() &&
(op == ast::BinaryOp::kEqual || op == ast::BinaryOp::kNotEqual)) {
return builder_->create<sem::Bool>();
}
// For the rest, we can compare i32, u32, and f32
if (lhs_ty->IsAnyOf<I32, U32, F32>()) {
return build(builder_->create<sem::Bool>());
return builder_->create<sem::Bool>();
}
}
// Same for vectors
if (matching_vec_elem_types) {
if (lhs_vec_elem_type->Is<Bool>() &&
(expr->IsEqual() || expr->IsNotEqual())) {
return build(builder_->create<sem::Vector>(
builder_->create<sem::Bool>(), lhs_vec->Width()));
(op == ast::BinaryOp::kEqual || op == ast::BinaryOp::kNotEqual)) {
return builder_->create<sem::Vector>(builder_->create<sem::Bool>(),
lhs_vec->Width());
}
if (lhs_vec_elem_type->is_numeric_scalar()) {
return build(builder_->create<sem::Vector>(
builder_->create<sem::Bool>(), lhs_vec->Width()));
return builder_->create<sem::Vector>(builder_->create<sem::Bool>(),
lhs_vec->Width());
}
}
}
// Binary bitwise operations
if (expr->IsBitwise()) {
if (ast::IsBitwise(op)) {
if (matching_types && lhs_ty->is_integer_scalar_or_vector()) {
return build(lhs_ty);
return lhs_ty;
}
}
// Bit shift expressions
if (expr->IsBitshift()) {
if (ast::IsBitshift(op)) {
// Type validation rules are the same for left or right shift, despite
// differences in computation rules (i.e. right shift can be arithmetic or
// logical depending on lhs type).
if (lhs_ty->IsAnyOf<I32, U32>() && rhs_ty->Is<U32>()) {
return build(lhs_ty);
return lhs_ty;
}
if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf<I32, U32>() &&
rhs_vec_elem_type && rhs_vec_elem_type->Is<U32>()) {
return build(lhs_ty);
return lhs_ty;
}
}
AddError("Binary expression operand types are invalid for this operation: " +
TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " +
TypeNameOf(rhs_ty),
expr->source);
return nullptr;
}

View File

@ -229,6 +229,12 @@ class Resolver {
sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*);
bool Statements(const ast::StatementList&);
// Resolve the result type of a binary operator.
// Returns nullptr if the types are not valid for this operator.
const sem::Type* BinaryOpType(const sem::Type* lhs_ty,
const sem::Type* rhs_ty,
ast::BinaryOp op);
// AST and Type validation methods
// Each return true on success, false on failure.
bool ValidateAlias(const ast::Alias*);