resolver: Refactor binary operator type resolution
This same logic will be used for resolving and validating compound assignment statements, so pull the core out into a separate function that decouples it from ast::BinaryExpression. Bug: tint:1325 Change-Id: Ibdb5a7fc8153dac0dd7f9ae3d5164e23585068cd Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/74360 Reviewed-by: Antonio Maiorano <amaiorano@google.com> Reviewed-by: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com>
This commit is contained in:
parent
b349710476
commit
daea034bd1
|
@ -122,7 +122,9 @@ class BinaryExpression final : public Castable<BinaryExpression, Expression> {
|
||||||
const Expression* const rhs;
|
const Expression* const rhs;
|
||||||
};
|
};
|
||||||
|
|
||||||
inline bool BinaryExpression::IsArithmetic() const {
|
/// @param op the operator
|
||||||
|
/// @returns true if the op is an arithmetic operation
|
||||||
|
inline bool IsArithmetic(BinaryOp op) {
|
||||||
switch (op) {
|
switch (op) {
|
||||||
case ast::BinaryOp::kAdd:
|
case ast::BinaryOp::kAdd:
|
||||||
case ast::BinaryOp::kSubtract:
|
case ast::BinaryOp::kSubtract:
|
||||||
|
@ -135,7 +137,9 @@ inline bool BinaryExpression::IsArithmetic() const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool BinaryExpression::IsComparison() const {
|
/// @param op the operator
|
||||||
|
/// @returns true if the op is a comparison operation
|
||||||
|
inline bool IsComparison(BinaryOp op) {
|
||||||
switch (op) {
|
switch (op) {
|
||||||
case ast::BinaryOp::kEqual:
|
case ast::BinaryOp::kEqual:
|
||||||
case ast::BinaryOp::kNotEqual:
|
case ast::BinaryOp::kNotEqual:
|
||||||
|
@ -149,7 +153,9 @@ inline bool BinaryExpression::IsComparison() const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool BinaryExpression::IsBitwise() const {
|
/// @param op the operator
|
||||||
|
/// @returns true if the op is a bitwise operation
|
||||||
|
inline bool IsBitwise(BinaryOp op) {
|
||||||
switch (op) {
|
switch (op) {
|
||||||
case ast::BinaryOp::kAnd:
|
case ast::BinaryOp::kAnd:
|
||||||
case ast::BinaryOp::kOr:
|
case ast::BinaryOp::kOr:
|
||||||
|
@ -160,7 +166,9 @@ inline bool BinaryExpression::IsBitwise() const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
inline bool BinaryExpression::IsBitshift() const {
|
/// @param op the operator
|
||||||
|
/// @returns true if the op is a bit shift operation
|
||||||
|
inline bool IsBitshift(BinaryOp op) {
|
||||||
switch (op) {
|
switch (op) {
|
||||||
case ast::BinaryOp::kShiftLeft:
|
case ast::BinaryOp::kShiftLeft:
|
||||||
case ast::BinaryOp::kShiftRight:
|
case ast::BinaryOp::kShiftRight:
|
||||||
|
@ -180,6 +188,22 @@ inline bool BinaryExpression::IsLogical() const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline bool BinaryExpression::IsArithmetic() const {
|
||||||
|
return ast::IsArithmetic(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool BinaryExpression::IsComparison() const {
|
||||||
|
return ast::IsComparison(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool BinaryExpression::IsBitwise() const {
|
||||||
|
return ast::IsBitwise(op);
|
||||||
|
}
|
||||||
|
|
||||||
|
inline bool BinaryExpression::IsBitshift() const {
|
||||||
|
return ast::IsBitshift(op);
|
||||||
|
}
|
||||||
|
|
||||||
/// @returns the human readable name of the given BinaryOp
|
/// @returns the human readable name of the given BinaryOp
|
||||||
/// @param op the BinaryOp
|
/// @param op the BinaryOp
|
||||||
constexpr const char* FriendlyName(BinaryOp op) {
|
constexpr const char* FriendlyName(BinaryOp op) {
|
||||||
|
|
|
@ -1838,6 +1838,33 @@ sem::Expression* Resolver::MemberAccessor(
|
||||||
}
|
}
|
||||||
|
|
||||||
sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
|
sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
|
||||||
|
auto* lhs = Sem(expr->lhs);
|
||||||
|
auto* rhs = Sem(expr->rhs);
|
||||||
|
auto* lhs_ty = lhs->Type()->UnwrapRef();
|
||||||
|
auto* rhs_ty = rhs->Type()->UnwrapRef();
|
||||||
|
|
||||||
|
auto* ty = BinaryOpType(lhs_ty, rhs_ty, expr->op);
|
||||||
|
if (!ty) {
|
||||||
|
AddError(
|
||||||
|
"Binary expression operand types are invalid for this operation: " +
|
||||||
|
TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " +
|
||||||
|
TypeNameOf(rhs_ty),
|
||||||
|
expr->source);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto val = EvaluateConstantValue(expr, ty);
|
||||||
|
bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects();
|
||||||
|
auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_,
|
||||||
|
val, has_side_effects);
|
||||||
|
sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
|
||||||
|
|
||||||
|
return sem;
|
||||||
|
}
|
||||||
|
|
||||||
|
const sem::Type* Resolver::BinaryOpType(const sem::Type* lhs_ty,
|
||||||
|
const sem::Type* rhs_ty,
|
||||||
|
ast::BinaryOp op) {
|
||||||
using Bool = sem::Bool;
|
using Bool = sem::Bool;
|
||||||
using F32 = sem::F32;
|
using F32 = sem::F32;
|
||||||
using I32 = sem::I32;
|
using I32 = sem::I32;
|
||||||
|
@ -1845,12 +1872,6 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
|
||||||
using Matrix = sem::Matrix;
|
using Matrix = sem::Matrix;
|
||||||
using Vector = sem::Vector;
|
using Vector = sem::Vector;
|
||||||
|
|
||||||
auto* lhs = Sem(expr->lhs);
|
|
||||||
auto* rhs = Sem(expr->rhs);
|
|
||||||
|
|
||||||
auto* lhs_ty = lhs->Type()->UnwrapRef();
|
|
||||||
auto* rhs_ty = rhs->Type()->UnwrapRef();
|
|
||||||
|
|
||||||
auto* lhs_vec = lhs_ty->As<Vector>();
|
auto* lhs_vec = lhs_ty->As<Vector>();
|
||||||
auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
|
auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
|
||||||
auto* rhs_vec = rhs_ty->As<Vector>();
|
auto* rhs_vec = rhs_ty->As<Vector>();
|
||||||
|
@ -1863,51 +1884,42 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
|
||||||
|
|
||||||
const bool matching_types = matching_vec_elem_types || (lhs_ty == rhs_ty);
|
const bool matching_types = matching_vec_elem_types || (lhs_ty == rhs_ty);
|
||||||
|
|
||||||
auto build = [&](const sem::Type* ty) {
|
|
||||||
auto val = EvaluateConstantValue(expr, ty);
|
|
||||||
bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects();
|
|
||||||
auto* sem = builder_->create<sem::Expression>(expr, ty, current_statement_,
|
|
||||||
val, has_side_effects);
|
|
||||||
sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors();
|
|
||||||
return sem;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Binary logical expressions
|
// Binary logical expressions
|
||||||
if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
|
if (op == ast::BinaryOp::kLogicalAnd || op == ast::BinaryOp::kLogicalOr) {
|
||||||
if (matching_types && lhs_ty->Is<Bool>()) {
|
if (matching_types && lhs_ty->Is<Bool>()) {
|
||||||
return build(lhs_ty);
|
return lhs_ty;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (expr->IsOr() || expr->IsAnd()) {
|
if (op == ast::BinaryOp::kOr || op == ast::BinaryOp::kAnd) {
|
||||||
if (matching_types && lhs_ty->Is<Bool>()) {
|
if (matching_types && lhs_ty->Is<Bool>()) {
|
||||||
return build(lhs_ty);
|
return lhs_ty;
|
||||||
}
|
}
|
||||||
if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is<Bool>()) {
|
if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is<Bool>()) {
|
||||||
return build(lhs_ty);
|
return lhs_ty;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Arithmetic expressions
|
// Arithmetic expressions
|
||||||
if (expr->IsArithmetic()) {
|
if (ast::IsArithmetic(op)) {
|
||||||
// Binary arithmetic expressions over scalars
|
// Binary arithmetic expressions over scalars
|
||||||
if (matching_types && lhs_ty->is_numeric_scalar()) {
|
if (matching_types && lhs_ty->is_numeric_scalar()) {
|
||||||
return build(lhs_ty);
|
return lhs_ty;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Binary arithmetic expressions over vectors
|
// Binary arithmetic expressions over vectors
|
||||||
if (matching_types && lhs_vec_elem_type &&
|
if (matching_types && lhs_vec_elem_type &&
|
||||||
lhs_vec_elem_type->is_numeric_scalar()) {
|
lhs_vec_elem_type->is_numeric_scalar()) {
|
||||||
return build(lhs_ty);
|
return lhs_ty;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Binary arithmetic expressions with mixed scalar and vector operands
|
// Binary arithmetic expressions with mixed scalar and vector operands
|
||||||
if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_ty) &&
|
if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_ty) &&
|
||||||
rhs_ty->is_numeric_scalar()) {
|
rhs_ty->is_numeric_scalar()) {
|
||||||
return build(lhs_ty);
|
return lhs_ty;
|
||||||
}
|
}
|
||||||
if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_ty) &&
|
if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_ty) &&
|
||||||
lhs_ty->is_numeric_scalar()) {
|
lhs_ty->is_numeric_scalar()) {
|
||||||
return build(rhs_ty);
|
return rhs_ty;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1917,106 +1929,101 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) {
|
||||||
auto* rhs_mat = rhs_ty->As<Matrix>();
|
auto* rhs_mat = rhs_ty->As<Matrix>();
|
||||||
auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
|
auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
|
||||||
// Addition and subtraction of float matrices
|
// Addition and subtraction of float matrices
|
||||||
if ((expr->IsAdd() || expr->IsSubtract()) && lhs_mat_elem_type &&
|
if ((op == ast::BinaryOp::kAdd || op == ast::BinaryOp::kSubtract) &&
|
||||||
lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type &&
|
lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && rhs_mat_elem_type &&
|
||||||
rhs_mat_elem_type->Is<F32>() &&
|
rhs_mat_elem_type->Is<F32>() &&
|
||||||
(lhs_mat->columns() == rhs_mat->columns()) &&
|
(lhs_mat->columns() == rhs_mat->columns()) &&
|
||||||
(lhs_mat->rows() == rhs_mat->rows())) {
|
(lhs_mat->rows() == rhs_mat->rows())) {
|
||||||
return build(rhs_ty);
|
return rhs_ty;
|
||||||
}
|
}
|
||||||
if (expr->IsMultiply()) {
|
if (op == ast::BinaryOp::kMultiply) {
|
||||||
// Multiplication of a matrix and a scalar
|
// Multiplication of a matrix and a scalar
|
||||||
if (lhs_ty->Is<F32>() && rhs_mat_elem_type &&
|
if (lhs_ty->Is<F32>() && rhs_mat_elem_type &&
|
||||||
rhs_mat_elem_type->Is<F32>()) {
|
rhs_mat_elem_type->Is<F32>()) {
|
||||||
return build(rhs_ty);
|
return rhs_ty;
|
||||||
}
|
}
|
||||||
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
|
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
|
||||||
rhs_ty->Is<F32>()) {
|
rhs_ty->Is<F32>()) {
|
||||||
return build(lhs_ty);
|
return lhs_ty;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Vector times matrix
|
// Vector times matrix
|
||||||
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
|
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
|
||||||
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
|
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
|
||||||
(lhs_vec->Width() == rhs_mat->rows())) {
|
(lhs_vec->Width() == rhs_mat->rows())) {
|
||||||
return build(
|
return builder_->create<sem::Vector>(lhs_vec->type(), rhs_mat->columns());
|
||||||
builder_->create<sem::Vector>(lhs_vec->type(), rhs_mat->columns()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Matrix times vector
|
// Matrix times vector
|
||||||
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
|
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
|
||||||
rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>() &&
|
rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>() &&
|
||||||
(lhs_mat->columns() == rhs_vec->Width())) {
|
(lhs_mat->columns() == rhs_vec->Width())) {
|
||||||
return build(
|
return builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows());
|
||||||
builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows()));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Matrix times matrix
|
// Matrix times matrix
|
||||||
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
|
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
|
||||||
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
|
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
|
||||||
(lhs_mat->columns() == rhs_mat->rows())) {
|
(lhs_mat->columns() == rhs_mat->rows())) {
|
||||||
return build(builder_->create<sem::Matrix>(
|
return builder_->create<sem::Matrix>(
|
||||||
builder_->create<sem::Vector>(lhs_mat_elem_type, lhs_mat->rows()),
|
builder_->create<sem::Vector>(lhs_mat_elem_type, lhs_mat->rows()),
|
||||||
rhs_mat->columns()));
|
rhs_mat->columns());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Comparison expressions
|
// Comparison expressions
|
||||||
if (expr->IsComparison()) {
|
if (ast::IsComparison(op)) {
|
||||||
if (matching_types) {
|
if (matching_types) {
|
||||||
// Special case for bools: only == and !=
|
// Special case for bools: only == and !=
|
||||||
if (lhs_ty->Is<Bool>() && (expr->IsEqual() || expr->IsNotEqual())) {
|
if (lhs_ty->Is<Bool>() &&
|
||||||
return build(builder_->create<sem::Bool>());
|
(op == ast::BinaryOp::kEqual || op == ast::BinaryOp::kNotEqual)) {
|
||||||
|
return builder_->create<sem::Bool>();
|
||||||
}
|
}
|
||||||
|
|
||||||
// For the rest, we can compare i32, u32, and f32
|
// For the rest, we can compare i32, u32, and f32
|
||||||
if (lhs_ty->IsAnyOf<I32, U32, F32>()) {
|
if (lhs_ty->IsAnyOf<I32, U32, F32>()) {
|
||||||
return build(builder_->create<sem::Bool>());
|
return builder_->create<sem::Bool>();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Same for vectors
|
// Same for vectors
|
||||||
if (matching_vec_elem_types) {
|
if (matching_vec_elem_types) {
|
||||||
if (lhs_vec_elem_type->Is<Bool>() &&
|
if (lhs_vec_elem_type->Is<Bool>() &&
|
||||||
(expr->IsEqual() || expr->IsNotEqual())) {
|
(op == ast::BinaryOp::kEqual || op == ast::BinaryOp::kNotEqual)) {
|
||||||
return build(builder_->create<sem::Vector>(
|
return builder_->create<sem::Vector>(builder_->create<sem::Bool>(),
|
||||||
builder_->create<sem::Bool>(), lhs_vec->Width()));
|
lhs_vec->Width());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (lhs_vec_elem_type->is_numeric_scalar()) {
|
if (lhs_vec_elem_type->is_numeric_scalar()) {
|
||||||
return build(builder_->create<sem::Vector>(
|
return builder_->create<sem::Vector>(builder_->create<sem::Bool>(),
|
||||||
builder_->create<sem::Bool>(), lhs_vec->Width()));
|
lhs_vec->Width());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Binary bitwise operations
|
// Binary bitwise operations
|
||||||
if (expr->IsBitwise()) {
|
if (ast::IsBitwise(op)) {
|
||||||
if (matching_types && lhs_ty->is_integer_scalar_or_vector()) {
|
if (matching_types && lhs_ty->is_integer_scalar_or_vector()) {
|
||||||
return build(lhs_ty);
|
return lhs_ty;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Bit shift expressions
|
// Bit shift expressions
|
||||||
if (expr->IsBitshift()) {
|
if (ast::IsBitshift(op)) {
|
||||||
// Type validation rules are the same for left or right shift, despite
|
// Type validation rules are the same for left or right shift, despite
|
||||||
// differences in computation rules (i.e. right shift can be arithmetic or
|
// differences in computation rules (i.e. right shift can be arithmetic or
|
||||||
// logical depending on lhs type).
|
// logical depending on lhs type).
|
||||||
|
|
||||||
if (lhs_ty->IsAnyOf<I32, U32>() && rhs_ty->Is<U32>()) {
|
if (lhs_ty->IsAnyOf<I32, U32>() && rhs_ty->Is<U32>()) {
|
||||||
return build(lhs_ty);
|
return lhs_ty;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf<I32, U32>() &&
|
if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf<I32, U32>() &&
|
||||||
rhs_vec_elem_type && rhs_vec_elem_type->Is<U32>()) {
|
rhs_vec_elem_type && rhs_vec_elem_type->Is<U32>()) {
|
||||||
return build(lhs_ty);
|
return lhs_ty;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
AddError("Binary expression operand types are invalid for this operation: " +
|
|
||||||
TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " +
|
|
||||||
TypeNameOf(rhs_ty),
|
|
||||||
expr->source);
|
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -229,6 +229,12 @@ class Resolver {
|
||||||
sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*);
|
sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*);
|
||||||
bool Statements(const ast::StatementList&);
|
bool Statements(const ast::StatementList&);
|
||||||
|
|
||||||
|
// Resolve the result type of a binary operator.
|
||||||
|
// Returns nullptr if the types are not valid for this operator.
|
||||||
|
const sem::Type* BinaryOpType(const sem::Type* lhs_ty,
|
||||||
|
const sem::Type* rhs_ty,
|
||||||
|
ast::BinaryOp op);
|
||||||
|
|
||||||
// AST and Type validation methods
|
// AST and Type validation methods
|
||||||
// Each return true on success, false on failure.
|
// Each return true on success, false on failure.
|
||||||
bool ValidateAlias(const ast::Alias*);
|
bool ValidateAlias(const ast::Alias*);
|
||||||
|
|
Loading…
Reference in New Issue