resolver: Refactor binary operator type resolution

This same logic will be used for resolving and validating compound
assignment statements, so pull the core out into a separate function
that decouples it from ast::BinaryExpression.

Bug: tint:1325
Change-Id: Ibdb5a7fc8153dac0dd7f9ae3d5164e23585068cd
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/74360
Reviewed-by: Antonio Maiorano <amaiorano@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
James Price 2022-03-31 22:30:10 +00:00
parent b349710476
commit daea034bd1
3 changed files with 96 additions and 59 deletions

View File

@ -122,7 +122,9 @@ class BinaryExpression final : public Castable<BinaryExpression, Expression> {
const Expression* const rhs; const Expression* const rhs;
}; };
inline bool BinaryExpression::IsArithmetic() const { /// @param op the operator
/// @returns true if the op is an arithmetic operation
inline bool IsArithmetic(BinaryOp op) {
switch (op) { switch (op) {
case ast::BinaryOp::kAdd: case ast::BinaryOp::kAdd:
case ast::BinaryOp::kSubtract: case ast::BinaryOp::kSubtract:
@ -135,7 +137,9 @@ inline bool BinaryExpression::IsArithmetic() const {
} }
} }
inline bool BinaryExpression::IsComparison() const { /// @param op the operator
/// @returns true if the op is a comparison operation
inline bool IsComparison(BinaryOp op) {
switch (op) { switch (op) {
case ast::BinaryOp::kEqual: case ast::BinaryOp::kEqual:
case ast::BinaryOp::kNotEqual: case ast::BinaryOp::kNotEqual:
@ -149,7 +153,9 @@ inline bool BinaryExpression::IsComparison() const {
} }
} }
inline bool BinaryExpression::IsBitwise() const { /// @param op the operator
/// @returns true if the op is a bitwise operation
inline bool IsBitwise(BinaryOp op) {
switch (op) { switch (op) {
case ast::BinaryOp::kAnd: case ast::BinaryOp::kAnd:
case ast::BinaryOp::kOr: case ast::BinaryOp::kOr:
@ -160,7 +166,9 @@ inline bool BinaryExpression::IsBitwise() const {
} }
} }
inline bool BinaryExpression::IsBitshift() const { /// @param op the operator
/// @returns true if the op is a bit shift operation
inline bool IsBitshift(BinaryOp op) {
switch (op) { switch (op) {
case ast::BinaryOp::kShiftLeft: case ast::BinaryOp::kShiftLeft:
case ast::BinaryOp::kShiftRight: case ast::BinaryOp::kShiftRight:
@ -180,6 +188,22 @@ inline bool BinaryExpression::IsLogical() const {
} }
} }
inline bool BinaryExpression::IsArithmetic() const {
return ast::IsArithmetic(op);
}
inline bool BinaryExpression::IsComparison() const {
return ast::IsComparison(op);
}
inline bool BinaryExpression::IsBitwise() const {
return ast::IsBitwise(op);
}
inline bool BinaryExpression::IsBitshift() const {
return ast::IsBitshift(op);
}
/// @returns the human readable name of the given BinaryOp /// @returns the human readable name of the given BinaryOp
/// @param op the BinaryOp /// @param op the BinaryOp
constexpr const char* FriendlyName(BinaryOp op) { constexpr const char* FriendlyName(BinaryOp op) {

View File

@ -1838,6 +1838,33 @@ sem::Expression* Resolver::MemberAccessor(
} }
sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) { sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
auto* lhs = Sem(expr->lhs);
auto* rhs = Sem(expr->rhs);
auto* lhs_ty = lhs->Type()->UnwrapRef();
auto* rhs_ty = rhs->Type()->UnwrapRef();
auto* ty = BinaryOpType(lhs_ty, rhs_ty, expr->op);
if (!ty) {
AddError(
"Binary expression operand types are invalid for this operation: " +
TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " +
TypeNameOf(rhs_ty),
expr->source);
return nullptr;
}
auto val = EvaluateConstantValue(expr, ty);
bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects();
auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_,
val, has_side_effects);
sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
return sem;
}
const sem::Type* Resolver::BinaryOpType(const sem::Type* lhs_ty,
const sem::Type* rhs_ty,
ast::BinaryOp op) {
using Bool = sem::Bool; using Bool = sem::Bool;
using F32 = sem::F32; using F32 = sem::F32;
using I32 = sem::I32; using I32 = sem::I32;
@ -1845,12 +1872,6 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
using Matrix = sem::Matrix; using Matrix = sem::Matrix;
using Vector = sem::Vector; using Vector = sem::Vector;
auto* lhs = Sem(expr->lhs);
auto* rhs = Sem(expr->rhs);
auto* lhs_ty = lhs->Type()->UnwrapRef();
auto* rhs_ty = rhs->Type()->UnwrapRef();
auto* lhs_vec = lhs_ty->As<Vector>(); auto* lhs_vec = lhs_ty->As<Vector>();
auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr; auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
auto* rhs_vec = rhs_ty->As<Vector>(); auto* rhs_vec = rhs_ty->As<Vector>();
@ -1863,51 +1884,42 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
const bool matching_types = matching_vec_elem_types || (lhs_ty == rhs_ty); const bool matching_types = matching_vec_elem_types || (lhs_ty == rhs_ty);
auto build = [&](const sem::Type* ty) {
auto val = EvaluateConstantValue(expr, ty);
bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects();
auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_,
val, has_side_effects);
sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
return sem;
};
// Binary logical expressions // Binary logical expressions
if (expr->IsLogicalAnd() || expr->IsLogicalOr()) { if (op == ast::BinaryOp::kLogicalAnd || op == ast::BinaryOp::kLogicalOr) {
if (matching_types && lhs_ty->Is<Bool>()) { if (matching_types && lhs_ty->Is<Bool>()) {
return build(lhs_ty); return lhs_ty;
} }
} }
if (expr->IsOr() || expr->IsAnd()) { if (op == ast::BinaryOp::kOr || op == ast::BinaryOp::kAnd) {
if (matching_types && lhs_ty->Is<Bool>()) { if (matching_types && lhs_ty->Is<Bool>()) {
return build(lhs_ty); return lhs_ty;
} }
if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is<Bool>()) { if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is<Bool>()) {
return build(lhs_ty); return lhs_ty;
} }
} }
// Arithmetic expressions // Arithmetic expressions
if (expr->IsArithmetic()) { if (ast::IsArithmetic(op)) {
// Binary arithmetic expressions over scalars // Binary arithmetic expressions over scalars
if (matching_types && lhs_ty->is_numeric_scalar()) { if (matching_types && lhs_ty->is_numeric_scalar()) {
return build(lhs_ty); return lhs_ty;
} }
// Binary arithmetic expressions over vectors // Binary arithmetic expressions over vectors
if (matching_types && lhs_vec_elem_type && if (matching_types && lhs_vec_elem_type &&
lhs_vec_elem_type->is_numeric_scalar()) { lhs_vec_elem_type->is_numeric_scalar()) {
return build(lhs_ty); return lhs_ty;
} }
// Binary arithmetic expressions with mixed scalar and vector operands // Binary arithmetic expressions with mixed scalar and vector operands
if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_ty) && if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_ty) &&
rhs_ty->is_numeric_scalar()) { rhs_ty->is_numeric_scalar()) {
return build(lhs_ty); return lhs_ty;
} }
if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_ty) && if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_ty) &&
lhs_ty->is_numeric_scalar()) { lhs_ty->is_numeric_scalar()) {
return build(rhs_ty); return rhs_ty;
} }
} }
@ -1917,106 +1929,101 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
auto* rhs_mat = rhs_ty->As<Matrix>(); auto* rhs_mat = rhs_ty->As<Matrix>();
auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr; auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
// Addition and subtraction of float matrices // Addition and subtraction of float matrices
if ((expr->IsAdd() || expr->IsSubtract()) && lhs_mat_elem_type && if ((op == ast::BinaryOp::kAdd || op == ast::BinaryOp::kSubtract) &&
lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type && lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type &&
rhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type->Is<F32>() &&
(lhs_mat->columns() == rhs_mat->columns()) && (lhs_mat->columns() == rhs_mat->columns()) &&
(lhs_mat->rows() == rhs_mat->rows())) { (lhs_mat->rows() == rhs_mat->rows())) {
return build(rhs_ty); return rhs_ty;
} }
if (expr->IsMultiply()) { if (op == ast::BinaryOp::kMultiply) {
// Multiplication of a matrix and a scalar // Multiplication of a matrix and a scalar
if (lhs_ty->Is<F32>() && rhs_mat_elem_type && if (lhs_ty->Is<F32>() && rhs_mat_elem_type &&
rhs_mat_elem_type->Is<F32>()) { rhs_mat_elem_type->Is<F32>()) {
return build(rhs_ty); return rhs_ty;
} }
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_ty->Is<F32>()) { rhs_ty->Is<F32>()) {
return build(lhs_ty); return lhs_ty;
} }
// Vector times matrix // Vector times matrix
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() && if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
(lhs_vec->Width() == rhs_mat->rows())) { (lhs_vec->Width() == rhs_mat->rows())) {
return build( return builder_->create<sem::Vector>(lhs_vec->type(), rhs_mat->columns());
builder_->create<sem::Vector>(lhs_vec->type(), rhs_mat->columns()));
} }
// Matrix times vector // Matrix times vector
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>() && rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>() &&
(lhs_mat->columns() == rhs_vec->Width())) { (lhs_mat->columns() == rhs_vec->Width())) {
return build( return builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows());
builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows()));
} }
// Matrix times matrix // Matrix times matrix
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
(lhs_mat->columns() == rhs_mat->rows())) { (lhs_mat->columns() == rhs_mat->rows())) {
return build(builder_->create<sem::Matrix>( return builder_->create<sem::Matrix>(
builder_->create<sem::Vector>(lhs_mat_elem_type, lhs_mat->rows()), builder_->create<sem::Vector>(lhs_mat_elem_type, lhs_mat->rows()),
rhs_mat->columns())); rhs_mat->columns());
} }
} }
// Comparison expressions // Comparison expressions
if (expr->IsComparison()) { if (ast::IsComparison(op)) {
if (matching_types) { if (matching_types) {
// Special case for bools: only == and != // Special case for bools: only == and !=
if (lhs_ty->Is<Bool>() && (expr->IsEqual() || expr->IsNotEqual())) { if (lhs_ty->Is<Bool>() &&
return build(builder_->create<sem::Bool>()); (op == ast::BinaryOp::kEqual || op == ast::BinaryOp::kNotEqual)) {
return builder_->create<sem::Bool>();
} }
// For the rest, we can compare i32, u32, and f32 // For the rest, we can compare i32, u32, and f32
if (lhs_ty->IsAnyOf<I32, U32, F32>()) { if (lhs_ty->IsAnyOf<I32, U32, F32>()) {
return build(builder_->create<sem::Bool>()); return builder_->create<sem::Bool>();
} }
} }
// Same for vectors // Same for vectors
if (matching_vec_elem_types) { if (matching_vec_elem_types) {
if (lhs_vec_elem_type->Is<Bool>() && if (lhs_vec_elem_type->Is<Bool>() &&
(expr->IsEqual() || expr->IsNotEqual())) { (op == ast::BinaryOp::kEqual || op == ast::BinaryOp::kNotEqual)) {
return build(builder_->create<sem::Vector>( return builder_->create<sem::Vector>(builder_->create<sem::Bool>(),
builder_->create<sem::Bool>(), lhs_vec->Width())); lhs_vec->Width());
} }
if (lhs_vec_elem_type->is_numeric_scalar()) { if (lhs_vec_elem_type->is_numeric_scalar()) {
return build(builder_->create<sem::Vector>( return builder_->create<sem::Vector>(builder_->create<sem::Bool>(),
builder_->create<sem::Bool>(), lhs_vec->Width())); lhs_vec->Width());
} }
} }
} }
// Binary bitwise operations // Binary bitwise operations
if (expr->IsBitwise()) { if (ast::IsBitwise(op)) {
if (matching_types && lhs_ty->is_integer_scalar_or_vector()) { if (matching_types && lhs_ty->is_integer_scalar_or_vector()) {
return build(lhs_ty); return lhs_ty;
} }
} }
// Bit shift expressions // Bit shift expressions
if (expr->IsBitshift()) { if (ast::IsBitshift(op)) {
// Type validation rules are the same for left or right shift, despite // Type validation rules are the same for left or right shift, despite
// differences in computation rules (i.e. right shift can be arithmetic or // differences in computation rules (i.e. right shift can be arithmetic or
// logical depending on lhs type). // logical depending on lhs type).
if (lhs_ty->IsAnyOf<I32, U32>() && rhs_ty->Is<U32>()) { if (lhs_ty->IsAnyOf<I32, U32>() && rhs_ty->Is<U32>()) {
return build(lhs_ty); return lhs_ty;
} }
if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf<I32, U32>() && if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf<I32, U32>() &&
rhs_vec_elem_type && rhs_vec_elem_type->Is<U32>()) { rhs_vec_elem_type && rhs_vec_elem_type->Is<U32>()) {
return build(lhs_ty); return lhs_ty;
} }
} }
AddError("Binary expression operand types are invalid for this operation: " +
TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " +
TypeNameOf(rhs_ty),
expr->source);
return nullptr; return nullptr;
} }

View File

@ -229,6 +229,12 @@ class Resolver {
sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*); sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*);
bool Statements(const ast::StatementList&); bool Statements(const ast::StatementList&);
// Resolve the result type of a binary operator.
// Returns nullptr if the types are not valid for this operator.
const sem::Type* BinaryOpType(const sem::Type* lhs_ty,
const sem::Type* rhs_ty,
ast::BinaryOp op);
// AST and Type validation methods // AST and Type validation methods
// Each return true on success, false on failure. // Each return true on success, false on failure.
bool ValidateAlias(const ast::Alias*); bool ValidateAlias(const ast::Alias*);