From daea034bd19df563ab5367b07c0b0634be770975 Mon Sep 17 00:00:00 2001 From: James Price Date: Thu, 31 Mar 2022 22:30:10 +0000 Subject: [PATCH] resolver: Refactor binary operator type resolution This same logic will be used for resolving and validating compound assignment statements, so pull the core out into a separate function that decouples it from ast::BinaryExpression. Bug: tint:1325 Change-Id: Ibdb5a7fc8153dac0dd7f9ae3d5164e23585068cd Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/74360 Reviewed-by: Antonio Maiorano Reviewed-by: Ben Clayton Kokoro: Kokoro --- src/tint/ast/binary_expression.h | 32 +++++++-- src/tint/resolver/resolver.cc | 117 ++++++++++++++++--------------- src/tint/resolver/resolver.h | 6 ++ 3 files changed, 96 insertions(+), 59 deletions(-) diff --git a/src/tint/ast/binary_expression.h b/src/tint/ast/binary_expression.h index 7f0f5e7978..bcacdd4ea7 100644 --- a/src/tint/ast/binary_expression.h +++ b/src/tint/ast/binary_expression.h @@ -122,7 +122,9 @@ class BinaryExpression final : public Castable { const Expression* const rhs; }; -inline bool BinaryExpression::IsArithmetic() const { +/// @param op the operator +/// @returns true if the op is an arithmetic operation +inline bool IsArithmetic(BinaryOp op) { switch (op) { case ast::BinaryOp::kAdd: case ast::BinaryOp::kSubtract: @@ -135,7 +137,9 @@ inline bool BinaryExpression::IsArithmetic() const { } } -inline bool BinaryExpression::IsComparison() const { +/// @param op the operator +/// @returns true if the op is a comparison operation +inline bool IsComparison(BinaryOp op) { switch (op) { case ast::BinaryOp::kEqual: case ast::BinaryOp::kNotEqual: @@ -149,7 +153,9 @@ inline bool BinaryExpression::IsComparison() const { } } -inline bool BinaryExpression::IsBitwise() const { +/// @param op the operator +/// @returns true if the op is a bitwise operation +inline bool IsBitwise(BinaryOp op) { switch (op) { case ast::BinaryOp::kAnd: case ast::BinaryOp::kOr: @@ -160,7 +166,9 @@ inline bool BinaryExpression::IsBitwise() const { } } -inline bool BinaryExpression::IsBitshift() const { +/// @param op the operator +/// @returns true if the op is a bit shift operation +inline bool IsBitshift(BinaryOp op) { switch (op) { case ast::BinaryOp::kShiftLeft: case ast::BinaryOp::kShiftRight: @@ -180,6 +188,22 @@ inline bool BinaryExpression::IsLogical() const { } } +inline bool BinaryExpression::IsArithmetic() const { + return ast::IsArithmetic(op); +} + +inline bool BinaryExpression::IsComparison() const { + return ast::IsComparison(op); +} + +inline bool BinaryExpression::IsBitwise() const { + return ast::IsBitwise(op); +} + +inline bool BinaryExpression::IsBitshift() const { + return ast::IsBitshift(op); +} + /// @returns the human readable name of the given BinaryOp /// @param op the BinaryOp constexpr const char* FriendlyName(BinaryOp op) { diff --git a/src/tint/resolver/resolver.cc b/src/tint/resolver/resolver.cc index 6068a73ce5..e10425ac13 100644 --- a/src/tint/resolver/resolver.cc +++ b/src/tint/resolver/resolver.cc @@ -1838,6 +1838,33 @@ sem::Expression* Resolver::MemberAccessor( } sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) { + auto* lhs = Sem(expr->lhs); + auto* rhs = Sem(expr->rhs); + auto* lhs_ty = lhs->Type()->UnwrapRef(); + auto* rhs_ty = rhs->Type()->UnwrapRef(); + + auto* ty = BinaryOpType(lhs_ty, rhs_ty, expr->op); + if (!ty) { + AddError( + "Binary expression operand types are invalid for this operation: " + + TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " + + TypeNameOf(rhs_ty), + expr->source); + return nullptr; + } + + auto val = EvaluateConstantValue(expr, ty); + bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects(); + auto* sem = builder_->create(expr, ty, current_statement_, + val, has_side_effects); + sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors(); + + return sem; +} + +const sem::Type* Resolver::BinaryOpType(const sem::Type* lhs_ty, + const sem::Type* rhs_ty, + ast::BinaryOp op) { using Bool = sem::Bool; using F32 = sem::F32; using I32 = sem::I32; @@ -1845,12 +1872,6 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) { using Matrix = sem::Matrix; using Vector = sem::Vector; - auto* lhs = Sem(expr->lhs); - auto* rhs = Sem(expr->rhs); - - auto* lhs_ty = lhs->Type()->UnwrapRef(); - auto* rhs_ty = rhs->Type()->UnwrapRef(); - auto* lhs_vec = lhs_ty->As(); auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr; auto* rhs_vec = rhs_ty->As(); @@ -1863,51 +1884,42 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) { const bool matching_types = matching_vec_elem_types || (lhs_ty == rhs_ty); - auto build = [&](const sem::Type* ty) { - auto val = EvaluateConstantValue(expr, ty); - bool has_side_effects = lhs->HasSideEffects() || rhs->HasSideEffects(); - auto* sem = builder_->create(expr, ty, current_statement_, - val, has_side_effects); - sem->Behaviors() = lhs->Behaviors() + rhs->Behaviors(); - return sem; - }; - // Binary logical expressions - if (expr->IsLogicalAnd() || expr->IsLogicalOr()) { + if (op == ast::BinaryOp::kLogicalAnd || op == ast::BinaryOp::kLogicalOr) { if (matching_types && lhs_ty->Is()) { - return build(lhs_ty); + return lhs_ty; } } - if (expr->IsOr() || expr->IsAnd()) { + if (op == ast::BinaryOp::kOr || op == ast::BinaryOp::kAnd) { if (matching_types && lhs_ty->Is()) { - return build(lhs_ty); + return lhs_ty; } if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is()) { - return build(lhs_ty); + return lhs_ty; } } // Arithmetic expressions - if (expr->IsArithmetic()) { + if (ast::IsArithmetic(op)) { // Binary arithmetic expressions over scalars if (matching_types && lhs_ty->is_numeric_scalar()) { - return build(lhs_ty); + return lhs_ty; } // Binary arithmetic expressions over vectors if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->is_numeric_scalar()) { - return build(lhs_ty); + return lhs_ty; } // Binary arithmetic expressions with mixed scalar and vector operands if (lhs_vec_elem_type && (lhs_vec_elem_type == rhs_ty) && rhs_ty->is_numeric_scalar()) { - return build(lhs_ty); + return lhs_ty; } if (rhs_vec_elem_type && (rhs_vec_elem_type == lhs_ty) && lhs_ty->is_numeric_scalar()) { - return build(rhs_ty); + return rhs_ty; } } @@ -1917,106 +1929,101 @@ sem::Expression* Resolver::Binary(const ast::BinaryExpression* expr) { auto* rhs_mat = rhs_ty->As(); auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr; // Addition and subtraction of float matrices - if ((expr->IsAdd() || expr->IsSubtract()) && lhs_mat_elem_type && - lhs_mat_elem_type->Is() && rhs_mat_elem_type && + if ((op == ast::BinaryOp::kAdd || op == ast::BinaryOp::kSubtract) && + lhs_mat_elem_type && lhs_mat_elem_type->Is() && rhs_mat_elem_type && rhs_mat_elem_type->Is() && (lhs_mat->columns() == rhs_mat->columns()) && (lhs_mat->rows() == rhs_mat->rows())) { - return build(rhs_ty); + return rhs_ty; } - if (expr->IsMultiply()) { + if (op == ast::BinaryOp::kMultiply) { // Multiplication of a matrix and a scalar if (lhs_ty->Is() && rhs_mat_elem_type && rhs_mat_elem_type->Is()) { - return build(rhs_ty); + return rhs_ty; } if (lhs_mat_elem_type && lhs_mat_elem_type->Is() && rhs_ty->Is()) { - return build(lhs_ty); + return lhs_ty; } // Vector times matrix if (lhs_vec_elem_type && lhs_vec_elem_type->Is() && rhs_mat_elem_type && rhs_mat_elem_type->Is() && (lhs_vec->Width() == rhs_mat->rows())) { - return build( - builder_->create(lhs_vec->type(), rhs_mat->columns())); + return builder_->create(lhs_vec->type(), rhs_mat->columns()); } // Matrix times vector if (lhs_mat_elem_type && lhs_mat_elem_type->Is() && rhs_vec_elem_type && rhs_vec_elem_type->Is() && (lhs_mat->columns() == rhs_vec->Width())) { - return build( - builder_->create(rhs_vec->type(), lhs_mat->rows())); + return builder_->create(rhs_vec->type(), lhs_mat->rows()); } // Matrix times matrix if (lhs_mat_elem_type && lhs_mat_elem_type->Is() && rhs_mat_elem_type && rhs_mat_elem_type->Is() && (lhs_mat->columns() == rhs_mat->rows())) { - return build(builder_->create( + return builder_->create( builder_->create(lhs_mat_elem_type, lhs_mat->rows()), - rhs_mat->columns())); + rhs_mat->columns()); } } // Comparison expressions - if (expr->IsComparison()) { + if (ast::IsComparison(op)) { if (matching_types) { // Special case for bools: only == and != - if (lhs_ty->Is() && (expr->IsEqual() || expr->IsNotEqual())) { - return build(builder_->create()); + if (lhs_ty->Is() && + (op == ast::BinaryOp::kEqual || op == ast::BinaryOp::kNotEqual)) { + return builder_->create(); } // For the rest, we can compare i32, u32, and f32 if (lhs_ty->IsAnyOf()) { - return build(builder_->create()); + return builder_->create(); } } // Same for vectors if (matching_vec_elem_types) { if (lhs_vec_elem_type->Is() && - (expr->IsEqual() || expr->IsNotEqual())) { - return build(builder_->create( - builder_->create(), lhs_vec->Width())); + (op == ast::BinaryOp::kEqual || op == ast::BinaryOp::kNotEqual)) { + return builder_->create(builder_->create(), + lhs_vec->Width()); } if (lhs_vec_elem_type->is_numeric_scalar()) { - return build(builder_->create( - builder_->create(), lhs_vec->Width())); + return builder_->create(builder_->create(), + lhs_vec->Width()); } } } // Binary bitwise operations - if (expr->IsBitwise()) { + if (ast::IsBitwise(op)) { if (matching_types && lhs_ty->is_integer_scalar_or_vector()) { - return build(lhs_ty); + return lhs_ty; } } // Bit shift expressions - if (expr->IsBitshift()) { + if (ast::IsBitshift(op)) { // Type validation rules are the same for left or right shift, despite // differences in computation rules (i.e. right shift can be arithmetic or // logical depending on lhs type). if (lhs_ty->IsAnyOf() && rhs_ty->Is()) { - return build(lhs_ty); + return lhs_ty; } if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf() && rhs_vec_elem_type && rhs_vec_elem_type->Is()) { - return build(lhs_ty); + return lhs_ty; } } - AddError("Binary expression operand types are invalid for this operation: " + - TypeNameOf(lhs_ty) + " " + FriendlyName(expr->op) + " " + - TypeNameOf(rhs_ty), - expr->source); return nullptr; } diff --git a/src/tint/resolver/resolver.h b/src/tint/resolver/resolver.h index fe7e865f30..7c3d217385 100644 --- a/src/tint/resolver/resolver.h +++ b/src/tint/resolver/resolver.h @@ -229,6 +229,12 @@ class Resolver { sem::Statement* VariableDeclStatement(const ast::VariableDeclStatement*); bool Statements(const ast::StatementList&); + // Resolve the result type of a binary operator. + // Returns nullptr if the types are not valid for this operator. + const sem::Type* BinaryOpType(const sem::Type* lhs_ty, + const sem::Type* rhs_ty, + ast::BinaryOp op); + // AST and Type validation methods // Each return true on success, false on failure. bool ValidateAlias(const ast::Alias*);