From 15c6ed048c2477a8b9ad2a7ab751f8d6ed9cd9e6 Mon Sep 17 00:00:00 2001 From: Antonio Maiorano Date: Thu, 1 Apr 2021 14:59:27 +0000 Subject: [PATCH] Fix binary expression resolving and validation with type aliases * Fixed resolving logical compares with lhs alias * Fixed resolving multiply with lhs or rhs alias * Fixed resolving ops with vecNand matNxM * Fixed validation with lhs or rhs alias * Fixed spir-v generation with lhs/rhs alias and added missing error message * Added tests for all valid binary expressions with lhs, rhs, or both as alias Bug: tint:680 Change-Id: I095255a3c63ec20b2e974c6866be9470e7e6ec6a Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46560 Kokoro: Kokoro Reviewed-by: Ben Clayton Reviewed-by: James Price Commit-Queue: Antonio Maiorano --- src/resolver/resolver.cc | 25 +++++++----- src/resolver/resolver_test.cc | 74 +++++++++++++++++++++++++++++++++++ src/writer/spirv/builder.cc | 5 ++- 3 files changed, 92 insertions(+), 12 deletions(-) diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index f1a5338f2d..900e16924a 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -1030,18 +1030,21 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) { using Matrix = type::Matrix; using Vector = type::Vector; - auto* lhs_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded(); - auto* rhs_type = TypeOf(expr->rhs())->UnwrapPtrIfNeeded(); + auto* lhs_type = TypeOf(expr->lhs())->UnwrapAll(); + auto* rhs_type = TypeOf(expr->rhs())->UnwrapAll(); auto* lhs_vec = lhs_type->As(); - auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr; + auto* lhs_vec_elem_type = + lhs_vec ? lhs_vec->type()->UnwrapAliasIfNeeded() : nullptr; auto* rhs_vec = rhs_type->As(); - auto* rhs_vec_elem_type = rhs_vec ? rhs_vec->type() : nullptr; + auto* rhs_vec_elem_type = + rhs_vec ? rhs_vec->type()->UnwrapAliasIfNeeded() : nullptr; - const bool matching_types = lhs_type == rhs_type; const bool matching_vec_elem_types = lhs_vec_elem_type && rhs_vec_elem_type && (lhs_vec_elem_type == rhs_vec_elem_type); + const bool matching_types = matching_vec_elem_types || (lhs_type == rhs_type); + // Binary logical expressions if (expr->IsLogicalAnd() || expr->IsLogicalOr()) { if (matching_types && lhs_type->Is()) { @@ -1085,9 +1088,11 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) { } auto* lhs_mat = lhs_type->As(); - auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr; + auto* lhs_mat_elem_type = + lhs_mat ? lhs_mat->type()->UnwrapAliasIfNeeded() : nullptr; auto* rhs_mat = rhs_type->As(); - auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr; + auto* rhs_mat_elem_type = + rhs_mat ? rhs_mat->type()->UnwrapAliasIfNeeded() : nullptr; // Multiplication of a matrix and a scalar if (lhs_type->Is() && rhs_mat_elem_type && @@ -1195,7 +1200,7 @@ bool Resolver::Binary(ast::BinaryExpression* expr) { expr->IsNotEqual() || expr->IsLessThan() || expr->IsGreaterThan() || expr->IsLessThanEqual() || expr->IsGreaterThanEqual()) { auto* bool_type = builder_->create(); - auto* param_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded(); + auto* param_type = TypeOf(expr->lhs())->UnwrapAll(); type::Type* result_type = bool_type; if (auto* vec = param_type->As()) { result_type = builder_->create(bool_type, vec->size()); @@ -1204,8 +1209,8 @@ bool Resolver::Binary(ast::BinaryExpression* expr) { return true; } if (expr->IsMultiply()) { - auto* lhs_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded(); - auto* rhs_type = TypeOf(expr->rhs())->UnwrapPtrIfNeeded(); + auto* lhs_type = TypeOf(expr->lhs())->UnwrapAll(); + auto* rhs_type = TypeOf(expr->rhs())->UnwrapAll(); // Note, the ordering here matters. The later checks depend on the prior // checks having been done. diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc index eadc0bcc3e..1c785d952e 100644 --- a/src/resolver/resolver_test.cc +++ b/src/resolver/resolver_test.cc @@ -1143,6 +1143,11 @@ TEST_P(Expr_Binary_Test_Valid, All) { auto* rhs_type = params.create_rhs_type(ty); auto* result_type = params.create_result_type(ty); + std::stringstream ss; + ss << lhs_type->FriendlyName(Symbols()) << " " << params.op << " " + << rhs_type->FriendlyName(Symbols()); + SCOPED_TRACE(ss.str()); + Global("lhs", lhs_type, ast::StorageClass::kNone); Global("rhs", rhs_type, ast::StorageClass::kNone); @@ -1158,6 +1163,70 @@ INSTANTIATE_TEST_SUITE_P(ResolverTest, Expr_Binary_Test_Valid, testing::ValuesIn(all_valid_cases)); +enum class BinaryExprSide { Left, Right, Both }; +using Expr_Binary_Test_WithAlias_Valid = + ResolverTestWithParam>; +TEST_P(Expr_Binary_Test_WithAlias_Valid, All) { + const Params& params = std::get<0>(GetParam()); + BinaryExprSide side = std::get<1>(GetParam()); + + auto* lhs_type = params.create_lhs_type(ty); + auto* rhs_type = params.create_rhs_type(ty); + + std::stringstream ss; + ss << lhs_type->FriendlyName(Symbols()) << " " << params.op << " " + << rhs_type->FriendlyName(Symbols()); + + // For vectors and matrices, wrap the sub type in an alias + auto make_alias = [this](type::Type* type) -> type::Type* { + type::Type* result; + if (auto* v = type->As()) { + result = create( + create(Symbols().New(), v->type()), v->size()); + } else if (auto* m = type->As()) { + result = + create(create(Symbols().New(), m->type()), + m->rows(), m->columns()); + } else { + result = create(Symbols().New(), type); + } + return result; + }; + + // Wrap in alias + if (side == BinaryExprSide::Left || side == BinaryExprSide::Both) { + lhs_type = make_alias(lhs_type); + } + if (side == BinaryExprSide::Right || side == BinaryExprSide::Both) { + rhs_type = make_alias(rhs_type); + } + + ss << ", After aliasing: " << lhs_type->FriendlyName(Symbols()) << " " + << params.op << " " << rhs_type->FriendlyName(Symbols()); + SCOPED_TRACE(ss.str()); + + Global("lhs", lhs_type, ast::StorageClass::kNone); + Global("rhs", rhs_type, ast::StorageClass::kNone); + + auto* expr = + create(params.op, Expr("lhs"), Expr("rhs")); + WrapInFunction(expr); + + ASSERT_TRUE(r()->Resolve()) << r()->error(); + ASSERT_NE(TypeOf(expr), nullptr); + // TODO(amaiorano): Bring this back once we have a way to get the canonical + // type + // auto* result_type = params.create_result_type(ty); + // ASSERT_TRUE(TypeOf(expr) == result_type); +} +INSTANTIATE_TEST_SUITE_P( + ResolverTest, + Expr_Binary_Test_WithAlias_Valid, + testing::Combine(testing::ValuesIn(all_valid_cases), + testing::Values(BinaryExprSide::Left, + BinaryExprSide::Right, + BinaryExprSide::Both))); + using Expr_Binary_Test_Invalid = ResolverTestWithParam>; TEST_P(Expr_Binary_Test_Invalid, All) { @@ -1186,6 +1255,11 @@ TEST_P(Expr_Binary_Test_Invalid, All) { return; } + std::stringstream ss; + ss << lhs_type->FriendlyName(Symbols()) << " " << params.op << " " + << rhs_type->FriendlyName(Symbols()); + SCOPED_TRACE(ss.str()); + Global("lhs", lhs_type, ast::StorageClass::kNone); Global("rhs", rhs_type, ast::StorageClass::kNone); diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index ca555ec698..af2ec2ff43 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -1723,8 +1723,8 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { // Handle int and float and the vectors of those types. Other types // should have been rejected by validation. - auto* lhs_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded(); - auto* rhs_type = TypeOf(expr->rhs())->UnwrapPtrIfNeeded(); + auto* lhs_type = TypeOf(expr->lhs())->UnwrapAll(); + auto* rhs_type = TypeOf(expr->rhs())->UnwrapAll(); bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector(); bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector(); @@ -1820,6 +1820,7 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { // float matrix * matrix op = spv::Op::OpMatrixTimesMatrix; } else { + error_ = "invalid multiply expression"; return 0; } } else if (expr->IsNotEqual()) {