Fix binary expression resolving and validation with type aliases
* Fixed resolving logical compares with lhs alias * Fixed resolving multiply with lhs or rhs alias * Fixed resolving ops with vecN<alias>and matNxM<alias> * Fixed validation with lhs or rhs alias * Fixed spir-v generation with lhs/rhs alias and added missing error message * Added tests for all valid binary expressions with lhs, rhs, or both as alias Bug: tint:680 Change-Id: I095255a3c63ec20b2e974c6866be9470e7e6ec6a Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46560 Kokoro: Kokoro <noreply+kokoro@google.com> Reviewed-by: Ben Clayton <bclayton@google.com> Reviewed-by: James Price <jrprice@google.com> Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
parent
c6aa48e4ea
commit
15c6ed048c
|
@ -1030,18 +1030,21 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
|
||||||
using Matrix = type::Matrix;
|
using Matrix = type::Matrix;
|
||||||
using Vector = type::Vector;
|
using Vector = type::Vector;
|
||||||
|
|
||||||
auto* lhs_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded();
|
auto* lhs_type = TypeOf(expr->lhs())->UnwrapAll();
|
||||||
auto* rhs_type = TypeOf(expr->rhs())->UnwrapPtrIfNeeded();
|
auto* rhs_type = TypeOf(expr->rhs())->UnwrapAll();
|
||||||
|
|
||||||
auto* lhs_vec = lhs_type->As<Vector>();
|
auto* lhs_vec = lhs_type->As<Vector>();
|
||||||
auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
|
auto* lhs_vec_elem_type =
|
||||||
|
lhs_vec ? lhs_vec->type()->UnwrapAliasIfNeeded() : nullptr;
|
||||||
auto* rhs_vec = rhs_type->As<Vector>();
|
auto* rhs_vec = rhs_type->As<Vector>();
|
||||||
auto* rhs_vec_elem_type = rhs_vec ? rhs_vec->type() : nullptr;
|
auto* rhs_vec_elem_type =
|
||||||
|
rhs_vec ? rhs_vec->type()->UnwrapAliasIfNeeded() : nullptr;
|
||||||
|
|
||||||
const bool matching_types = lhs_type == rhs_type;
|
|
||||||
const bool matching_vec_elem_types = lhs_vec_elem_type && rhs_vec_elem_type &&
|
const bool matching_vec_elem_types = lhs_vec_elem_type && rhs_vec_elem_type &&
|
||||||
(lhs_vec_elem_type == rhs_vec_elem_type);
|
(lhs_vec_elem_type == rhs_vec_elem_type);
|
||||||
|
|
||||||
|
const bool matching_types = matching_vec_elem_types || (lhs_type == rhs_type);
|
||||||
|
|
||||||
// Binary logical expressions
|
// Binary logical expressions
|
||||||
if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
|
if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
|
||||||
if (matching_types && lhs_type->Is<Bool>()) {
|
if (matching_types && lhs_type->Is<Bool>()) {
|
||||||
|
@ -1085,9 +1088,11 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
|
||||||
}
|
}
|
||||||
|
|
||||||
auto* lhs_mat = lhs_type->As<Matrix>();
|
auto* lhs_mat = lhs_type->As<Matrix>();
|
||||||
auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
|
auto* lhs_mat_elem_type =
|
||||||
|
lhs_mat ? lhs_mat->type()->UnwrapAliasIfNeeded() : nullptr;
|
||||||
auto* rhs_mat = rhs_type->As<Matrix>();
|
auto* rhs_mat = rhs_type->As<Matrix>();
|
||||||
auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
|
auto* rhs_mat_elem_type =
|
||||||
|
rhs_mat ? rhs_mat->type()->UnwrapAliasIfNeeded() : nullptr;
|
||||||
|
|
||||||
// Multiplication of a matrix and a scalar
|
// Multiplication of a matrix and a scalar
|
||||||
if (lhs_type->Is<F32>() && rhs_mat_elem_type &&
|
if (lhs_type->Is<F32>() && rhs_mat_elem_type &&
|
||||||
|
@ -1195,7 +1200,7 @@ bool Resolver::Binary(ast::BinaryExpression* expr) {
|
||||||
expr->IsNotEqual() || expr->IsLessThan() || expr->IsGreaterThan() ||
|
expr->IsNotEqual() || expr->IsLessThan() || expr->IsGreaterThan() ||
|
||||||
expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) {
|
expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) {
|
||||||
auto* bool_type = builder_->create<type::Bool>();
|
auto* bool_type = builder_->create<type::Bool>();
|
||||||
auto* param_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded();
|
auto* param_type = TypeOf(expr->lhs())->UnwrapAll();
|
||||||
type::Type* result_type = bool_type;
|
type::Type* result_type = bool_type;
|
||||||
if (auto* vec = param_type->As<type::Vector>()) {
|
if (auto* vec = param_type->As<type::Vector>()) {
|
||||||
result_type = builder_->create<type::Vector>(bool_type, vec->size());
|
result_type = builder_->create<type::Vector>(bool_type, vec->size());
|
||||||
|
@ -1204,8 +1209,8 @@ bool Resolver::Binary(ast::BinaryExpression* expr) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
if (expr->IsMultiply()) {
|
if (expr->IsMultiply()) {
|
||||||
auto* lhs_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded();
|
auto* lhs_type = TypeOf(expr->lhs())->UnwrapAll();
|
||||||
auto* rhs_type = TypeOf(expr->rhs())->UnwrapPtrIfNeeded();
|
auto* rhs_type = TypeOf(expr->rhs())->UnwrapAll();
|
||||||
|
|
||||||
// Note, the ordering here matters. The later checks depend on the prior
|
// Note, the ordering here matters. The later checks depend on the prior
|
||||||
// checks having been done.
|
// checks having been done.
|
||||||
|
|
|
@ -1143,6 +1143,11 @@ TEST_P(Expr_Binary_Test_Valid, All) {
|
||||||
auto* rhs_type = params.create_rhs_type(ty);
|
auto* rhs_type = params.create_rhs_type(ty);
|
||||||
auto* result_type = params.create_result_type(ty);
|
auto* result_type = params.create_result_type(ty);
|
||||||
|
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << lhs_type->FriendlyName(Symbols()) << " " << params.op << " "
|
||||||
|
<< rhs_type->FriendlyName(Symbols());
|
||||||
|
SCOPED_TRACE(ss.str());
|
||||||
|
|
||||||
Global("lhs", lhs_type, ast::StorageClass::kNone);
|
Global("lhs", lhs_type, ast::StorageClass::kNone);
|
||||||
Global("rhs", rhs_type, ast::StorageClass::kNone);
|
Global("rhs", rhs_type, ast::StorageClass::kNone);
|
||||||
|
|
||||||
|
@ -1158,6 +1163,70 @@ INSTANTIATE_TEST_SUITE_P(ResolverTest,
|
||||||
Expr_Binary_Test_Valid,
|
Expr_Binary_Test_Valid,
|
||||||
testing::ValuesIn(all_valid_cases));
|
testing::ValuesIn(all_valid_cases));
|
||||||
|
|
||||||
|
enum class BinaryExprSide { Left, Right, Both };
|
||||||
|
using Expr_Binary_Test_WithAlias_Valid =
|
||||||
|
ResolverTestWithParam<std::tuple<Params, BinaryExprSide>>;
|
||||||
|
TEST_P(Expr_Binary_Test_WithAlias_Valid, All) {
|
||||||
|
const Params& params = std::get<0>(GetParam());
|
||||||
|
BinaryExprSide side = std::get<1>(GetParam());
|
||||||
|
|
||||||
|
auto* lhs_type = params.create_lhs_type(ty);
|
||||||
|
auto* rhs_type = params.create_rhs_type(ty);
|
||||||
|
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << lhs_type->FriendlyName(Symbols()) << " " << params.op << " "
|
||||||
|
<< rhs_type->FriendlyName(Symbols());
|
||||||
|
|
||||||
|
// For vectors and matrices, wrap the sub type in an alias
|
||||||
|
auto make_alias = [this](type::Type* type) -> type::Type* {
|
||||||
|
type::Type* result;
|
||||||
|
if (auto* v = type->As<type::Vector>()) {
|
||||||
|
result = create<type::Vector>(
|
||||||
|
create<type::Alias>(Symbols().New(), v->type()), v->size());
|
||||||
|
} else if (auto* m = type->As<type::Matrix>()) {
|
||||||
|
result =
|
||||||
|
create<type::Matrix>(create<type::Alias>(Symbols().New(), m->type()),
|
||||||
|
m->rows(), m->columns());
|
||||||
|
} else {
|
||||||
|
result = create<type::Alias>(Symbols().New(), type);
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Wrap in alias
|
||||||
|
if (side == BinaryExprSide::Left || side == BinaryExprSide::Both) {
|
||||||
|
lhs_type = make_alias(lhs_type);
|
||||||
|
}
|
||||||
|
if (side == BinaryExprSide::Right || side == BinaryExprSide::Both) {
|
||||||
|
rhs_type = make_alias(rhs_type);
|
||||||
|
}
|
||||||
|
|
||||||
|
ss << ", After aliasing: " << lhs_type->FriendlyName(Symbols()) << " "
|
||||||
|
<< params.op << " " << rhs_type->FriendlyName(Symbols());
|
||||||
|
SCOPED_TRACE(ss.str());
|
||||||
|
|
||||||
|
Global("lhs", lhs_type, ast::StorageClass::kNone);
|
||||||
|
Global("rhs", rhs_type, ast::StorageClass::kNone);
|
||||||
|
|
||||||
|
auto* expr =
|
||||||
|
create<ast::BinaryExpression>(params.op, Expr("lhs"), Expr("rhs"));
|
||||||
|
WrapInFunction(expr);
|
||||||
|
|
||||||
|
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||||
|
ASSERT_NE(TypeOf(expr), nullptr);
|
||||||
|
// TODO(amaiorano): Bring this back once we have a way to get the canonical
|
||||||
|
// type
|
||||||
|
// auto* result_type = params.create_result_type(ty);
|
||||||
|
// ASSERT_TRUE(TypeOf(expr) == result_type);
|
||||||
|
}
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
ResolverTest,
|
||||||
|
Expr_Binary_Test_WithAlias_Valid,
|
||||||
|
testing::Combine(testing::ValuesIn(all_valid_cases),
|
||||||
|
testing::Values(BinaryExprSide::Left,
|
||||||
|
BinaryExprSide::Right,
|
||||||
|
BinaryExprSide::Both)));
|
||||||
|
|
||||||
using Expr_Binary_Test_Invalid =
|
using Expr_Binary_Test_Invalid =
|
||||||
ResolverTestWithParam<std::tuple<Params, create_type_func_ptr>>;
|
ResolverTestWithParam<std::tuple<Params, create_type_func_ptr>>;
|
||||||
TEST_P(Expr_Binary_Test_Invalid, All) {
|
TEST_P(Expr_Binary_Test_Invalid, All) {
|
||||||
|
@ -1186,6 +1255,11 @@ TEST_P(Expr_Binary_Test_Invalid, All) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::stringstream ss;
|
||||||
|
ss << lhs_type->FriendlyName(Symbols()) << " " << params.op << " "
|
||||||
|
<< rhs_type->FriendlyName(Symbols());
|
||||||
|
SCOPED_TRACE(ss.str());
|
||||||
|
|
||||||
Global("lhs", lhs_type, ast::StorageClass::kNone);
|
Global("lhs", lhs_type, ast::StorageClass::kNone);
|
||||||
Global("rhs", rhs_type, ast::StorageClass::kNone);
|
Global("rhs", rhs_type, ast::StorageClass::kNone);
|
||||||
|
|
||||||
|
|
|
@ -1723,8 +1723,8 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
|
||||||
|
|
||||||
// Handle int and float and the vectors of those types. Other types
|
// Handle int and float and the vectors of those types. Other types
|
||||||
// should have been rejected by validation.
|
// should have been rejected by validation.
|
||||||
auto* lhs_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded();
|
auto* lhs_type = TypeOf(expr->lhs())->UnwrapAll();
|
||||||
auto* rhs_type = TypeOf(expr->rhs())->UnwrapPtrIfNeeded();
|
auto* rhs_type = TypeOf(expr->rhs())->UnwrapAll();
|
||||||
bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector();
|
bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector();
|
||||||
bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector();
|
bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector();
|
||||||
|
|
||||||
|
@ -1820,6 +1820,7 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
|
||||||
// float matrix * matrix
|
// float matrix * matrix
|
||||||
op = spv::Op::OpMatrixTimesMatrix;
|
op = spv::Op::OpMatrixTimesMatrix;
|
||||||
} else {
|
} else {
|
||||||
|
error_ = "invalid multiply expression";
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
} else if (expr->IsNotEqual()) {
|
} else if (expr->IsNotEqual()) {
|
||||||
|
|
Loading…
Reference in New Issue