Resolver: merge binary expression validation and resolving logic together

This avoid duplicating the logic in two places, and makes it easier to
implement according to the spec.

Bug: tint:376
Change-Id: If62f508e2c76b5b661e66aae9ff20b8e874a65d8
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/52323
Kokoro: Kokoro <noreply+kokoro@google.com>
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Antonio Maiorano 2021-05-27 18:25:06 +00:00 committed by Tint LUCI CQ
parent 25b7d4064f
commit 0131ce2205
2 changed files with 37 additions and 85 deletions

View File

@ -2035,7 +2035,14 @@ bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) {
return true;
}
bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
bool Resolver::Binary(ast::BinaryExpression* expr) {
Mark(expr->lhs());
Mark(expr->rhs());
if (!Expression(expr->lhs()) || !Expression(expr->rhs())) {
return false;
}
using Bool = sem::Bool;
using F32 = sem::F32;
using I32 = sem::I32;
@ -2061,14 +2068,17 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
// Binary logical expressions
if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
if (matching_types && lhs_type->Is<Bool>()) {
SetType(expr, lhs_type);
return true;
}
}
if (expr->IsOr() || expr->IsAnd()) {
if (matching_types && lhs_type->Is<Bool>()) {
SetType(expr, lhs_type);
return true;
}
if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is<Bool>()) {
SetType(expr, lhs_type);
return true;
}
}
@ -2076,13 +2086,15 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
// Arithmetic expressions
if (expr->IsArithmetic()) {
// Binary arithmetic expressions over scalars
if (matching_types && lhs_type->IsAnyOf<I32, F32, U32>()) {
if (matching_types && lhs_type->is_numeric_scalar()) {
SetType(expr, lhs_type);
return true;
}
// Binary arithmetic expressions over vectors
if (matching_types && lhs_vec_elem_type &&
lhs_vec_elem_type->IsAnyOf<I32, F32, U32>()) {
lhs_vec_elem_type->is_numeric_scalar()) {
SetType(expr, lhs_type);
return true;
}
}
@ -2093,10 +2105,12 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
// Multiplication of a vector and a scalar
if (lhs_type->Is<F32>() && rhs_vec_elem_type &&
rhs_vec_elem_type->Is<F32>()) {
SetType(expr, rhs_type);
return true;
}
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
rhs_type->Is<F32>()) {
SetType(expr, lhs_type);
return true;
}
@ -2108,10 +2122,12 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
// Multiplication of a matrix and a scalar
if (lhs_type->Is<F32>() && rhs_mat_elem_type &&
rhs_mat_elem_type->Is<F32>()) {
SetType(expr, rhs_type);
return true;
}
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_type->Is<F32>()) {
SetType(expr, lhs_type);
return true;
}
@ -2119,6 +2135,8 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
(lhs_vec->size() == rhs_mat->rows())) {
SetType(expr, builder_->create<sem::Vector>(lhs_vec->type(),
rhs_mat->columns()));
return true;
}
@ -2126,6 +2144,8 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>() &&
(lhs_mat->columns() == rhs_vec->size())) {
SetType(expr,
builder_->create<sem::Vector>(rhs_vec->type(), lhs_mat->rows()));
return true;
}
@ -2133,6 +2153,10 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
(lhs_mat->columns() == rhs_mat->rows())) {
SetType(expr, builder_->create<sem::Matrix>(
builder_->create<sem::Vector>(lhs_mat_elem_type,
lhs_mat->rows()),
rhs_mat->columns()));
return true;
}
}
@ -2142,11 +2166,13 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
if (matching_types) {
// Special case for bools: only == and !=
if (lhs_type->Is<Bool>() && (expr->IsEqual() || expr->IsNotEqual())) {
SetType(expr, builder_->create<sem::Bool>());
return true;
}
// For the rest, we can compare i32, u32, and f32
if (lhs_type->IsAnyOf<I32, U32, F32>()) {
SetType(expr, builder_->create<sem::Bool>());
return true;
}
}
@ -2155,10 +2181,14 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
if (matching_vec_elem_types) {
if (lhs_vec_elem_type->Is<Bool>() &&
(expr->IsEqual() || expr->IsNotEqual())) {
SetType(expr, builder_->create<sem::Vector>(
builder_->create<sem::Bool>(), lhs_vec->size()));
return true;
}
if (lhs_vec_elem_type->IsAnyOf<I32, U32, F32>()) {
if (lhs_vec_elem_type->is_numeric_scalar()) {
SetType(expr, builder_->create<sem::Vector>(
builder_->create<sem::Bool>(), lhs_vec->size()));
return true;
}
}
@ -2167,6 +2197,7 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
// Binary bitwise operations
if (expr->IsBitwise()) {
if (matching_types && lhs_type->IsAnyOf<I32, U32>()) {
SetType(expr, lhs_type);
return true;
}
}
@ -2178,11 +2209,13 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
// logical depending on lhs type).
if (lhs_type->IsAnyOf<I32, U32>() && rhs_type->Is<U32>()) {
SetType(expr, lhs_type);
return true;
}
if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf<I32, U32>() &&
rhs_vec_elem_type && rhs_vec_elem_type->Is<U32>()) {
SetType(expr, lhs_type);
return true;
}
}
@ -2196,86 +2229,6 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
return false;
}
bool Resolver::Binary(ast::BinaryExpression* expr) {
Mark(expr->lhs());
Mark(expr->rhs());
if (!Expression(expr->lhs()) || !Expression(expr->rhs())) {
return false;
}
if (!ValidateBinary(expr)) {
return false;
}
// Result type matches first parameter type
if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() ||
expr->IsShiftRight() || expr->IsAdd() || expr->IsSubtract() ||
expr->IsDivide() || expr->IsModulo()) {
SetType(expr, TypeOf(expr->lhs())->UnwrapRef());
return true;
}
// Result type is a scalar or vector of boolean type
if (expr->IsLogicalAnd() || expr->IsLogicalOr() || expr->IsEqual() ||
expr->IsNotEqual() || expr->IsLessThan() || expr->IsGreaterThan() ||
expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) {
auto* bool_type = builder_->create<sem::Bool>();
auto* param_type = TypeOf(expr->lhs())->UnwrapRef();
sem::Type* result_type = bool_type;
if (auto* vec = param_type->As<sem::Vector>()) {
result_type = builder_->create<sem::Vector>(bool_type, vec->size());
}
SetType(expr, result_type);
return true;
}
if (expr->IsMultiply()) {
auto* lhs_type = TypeOf(expr->lhs())->UnwrapRef();
auto* rhs_type = TypeOf(expr->rhs())->UnwrapRef();
// Note, the ordering here matters. The later checks depend on the prior
// checks having been done.
auto* lhs_mat = lhs_type->As<sem::Matrix>();
auto* rhs_mat = rhs_type->As<sem::Matrix>();
auto* lhs_vec = lhs_type->As<sem::Vector>();
auto* rhs_vec = rhs_type->As<sem::Vector>();
const sem::Type* result_type = nullptr;
if (lhs_mat && rhs_mat) {
auto* column_type =
builder_->create<sem::Vector>(lhs_mat->type(), lhs_mat->rows());
result_type =
builder_->create<sem::Matrix>(column_type, rhs_mat->columns());
} else if (lhs_mat && rhs_vec) {
result_type =
builder_->create<sem::Vector>(lhs_mat->type(), lhs_mat->rows());
} else if (lhs_vec && rhs_mat) {
result_type =
builder_->create<sem::Vector>(rhs_mat->type(), rhs_mat->columns());
} else if (lhs_mat) {
// matrix * scalar
result_type = lhs_type;
} else if (rhs_mat) {
// scalar * matrix
result_type = rhs_type;
} else if (lhs_vec && rhs_vec) {
result_type = lhs_type;
} else if (lhs_vec) {
// Vector * scalar
result_type = lhs_type;
} else if (rhs_vec) {
// Scalar * vector
result_type = rhs_type;
} else {
// Scalar * Scalar
result_type = lhs_type;
}
SetType(expr, result_type);
return true;
}
diagnostics_.add_error("Unknown binary expression", expr->source());
return false;
}
bool Resolver::UnaryOp(ast::UnaryOpExpression* unary) {
Mark(unary->expr());

View File

@ -236,7 +236,6 @@ class Resolver {
uint32_t el_align,
const Source& source);
bool ValidateAssignment(const ast::AssignmentStatement* a);
bool ValidateBinary(ast::BinaryExpression* expr);
bool ValidateEntryPoint(const ast::Function* func, const FunctionInfo* info);
bool ValidateFunction(const ast::Function* func, const FunctionInfo* info);
bool ValidateGlobalVariable(const VariableInfo* var);