Validate multiply of invalid vector/matrix sizes

Added tests that test all combos of vec*mat, mat*vec, and mat*mat for 2,
3, and 4 dimensions.

Bug: tint:698
Change-Id: I4a407228261cf8ea2a93bc7077544e5a9244d854
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46660
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Antonio Maiorano 2021-04-01 19:58:37 +00:00 committed by Commit Bot service account
parent f1b643ee70
commit 61dabab673
3 changed files with 121 additions and 12 deletions

View File

@ -928,7 +928,7 @@ class ProgramBuilder {
/// @param rhs the right hand argument to the addition operation /// @param rhs the right hand argument to the addition operation
/// @returns a `ast::BinaryExpression` summing the arguments `lhs` and `rhs` /// @returns a `ast::BinaryExpression` summing the arguments `lhs` and `rhs`
template <typename LHS, typename RHS> template <typename LHS, typename RHS>
ast::Expression* Add(LHS&& lhs, RHS&& rhs) { ast::BinaryExpression* Add(LHS&& lhs, RHS&& rhs) {
return create<ast::BinaryExpression>(ast::BinaryOp::kAdd, return create<ast::BinaryExpression>(ast::BinaryOp::kAdd,
Expr(std::forward<LHS>(lhs)), Expr(std::forward<LHS>(lhs)),
Expr(std::forward<RHS>(rhs))); Expr(std::forward<RHS>(rhs)));
@ -938,7 +938,7 @@ class ProgramBuilder {
/// @param rhs the right hand argument to the subtraction operation /// @param rhs the right hand argument to the subtraction operation
/// @returns a `ast::BinaryExpression` subtracting `rhs` from `lhs` /// @returns a `ast::BinaryExpression` subtracting `rhs` from `lhs`
template <typename LHS, typename RHS> template <typename LHS, typename RHS>
ast::Expression* Sub(LHS&& lhs, RHS&& rhs) { ast::BinaryExpression* Sub(LHS&& lhs, RHS&& rhs) {
return create<ast::BinaryExpression>(ast::BinaryOp::kSubtract, return create<ast::BinaryExpression>(ast::BinaryOp::kSubtract,
Expr(std::forward<LHS>(lhs)), Expr(std::forward<LHS>(lhs)),
Expr(std::forward<RHS>(rhs))); Expr(std::forward<RHS>(rhs)));
@ -948,17 +948,28 @@ class ProgramBuilder {
/// @param rhs the right hand argument to the multiplication operation /// @param rhs the right hand argument to the multiplication operation
/// @returns a `ast::BinaryExpression` multiplying `rhs` from `lhs` /// @returns a `ast::BinaryExpression` multiplying `rhs` from `lhs`
template <typename LHS, typename RHS> template <typename LHS, typename RHS>
ast::Expression* Mul(LHS&& lhs, RHS&& rhs) { ast::BinaryExpression* Mul(LHS&& lhs, RHS&& rhs) {
return create<ast::BinaryExpression>(ast::BinaryOp::kMultiply, return create<ast::BinaryExpression>(ast::BinaryOp::kMultiply,
Expr(std::forward<LHS>(lhs)), Expr(std::forward<LHS>(lhs)),
Expr(std::forward<RHS>(rhs))); Expr(std::forward<RHS>(rhs)));
} }
/// @param source the source information
/// @param lhs the left hand argument to the multiplication operation
/// @param rhs the right hand argument to the multiplication operation
/// @returns a `ast::BinaryExpression` multiplying `rhs` from `lhs`
template <typename LHS, typename RHS>
ast::BinaryExpression* Mul(const Source& source, LHS&& lhs, RHS&& rhs) {
return create<ast::BinaryExpression>(source, ast::BinaryOp::kMultiply,
Expr(std::forward<LHS>(lhs)),
Expr(std::forward<RHS>(rhs)));
}
/// @param arr the array argument for the array accessor expression /// @param arr the array argument for the array accessor expression
/// @param idx the index argument for the array accessor expression /// @param idx the index argument for the array accessor expression
/// @returns a `ast::ArrayAccessorExpression` that indexes `arr` with `idx` /// @returns a `ast::ArrayAccessorExpression` that indexes `arr` with `idx`
template <typename ARR, typename IDX> template <typename ARR, typename IDX>
ast::Expression* IndexAccessor(ARR&& arr, IDX&& idx) { ast::ArrayAccessorExpression* IndexAccessor(ARR&& arr, IDX&& idx) {
return create<ast::ArrayAccessorExpression>(Expr(std::forward<ARR>(arr)), return create<ast::ArrayAccessorExpression>(Expr(std::forward<ARR>(arr)),
Expr(std::forward<IDX>(idx))); Expr(std::forward<IDX>(idx)));
} }

View File

@ -1040,8 +1040,10 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
auto* rhs_vec_elem_type = auto* rhs_vec_elem_type =
rhs_vec ? rhs_vec->type()->UnwrapAliasIfNeeded() : nullptr; rhs_vec ? rhs_vec->type()->UnwrapAliasIfNeeded() : nullptr;
const bool matching_vec_elem_types = lhs_vec_elem_type && rhs_vec_elem_type && const bool matching_vec_elem_types =
(lhs_vec_elem_type == rhs_vec_elem_type); lhs_vec_elem_type && rhs_vec_elem_type &&
(lhs_vec_elem_type == rhs_vec_elem_type) &&
(lhs_vec->size() == rhs_vec->size());
const bool matching_types = matching_vec_elem_types || (lhs_type == rhs_type); const bool matching_types = matching_vec_elem_types || (lhs_type == rhs_type);
@ -1106,19 +1108,22 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
// Vector times matrix // Vector times matrix
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() && if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>()) { rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
(lhs_vec->size() == rhs_mat->rows())) {
return true; return true;
} }
// Matrix times vector // Matrix times vector
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>()) { rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>() &&
(lhs_mat->columns() == rhs_vec->size())) {
return true; return true;
} }
// Matrix times matrix // Matrix times matrix
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() && if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>()) { rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
(lhs_mat->columns() == rhs_mat->rows())) {
return true; return true;
} }
} }

View File

@ -1263,11 +1263,11 @@ TEST_P(Expr_Binary_Test_Invalid, All) {
Global("lhs", lhs_type, ast::StorageClass::kNone); Global("lhs", lhs_type, ast::StorageClass::kNone);
Global("rhs", rhs_type, ast::StorageClass::kNone); Global("rhs", rhs_type, ast::StorageClass::kNone);
auto* expr = create<ast::BinaryExpression>( auto* expr = create<ast::BinaryExpression>(Source{{12, 34}}, params.op,
Source{Source::Location{12, 34}}, params.op, Expr("lhs"), Expr("rhs")); Expr("lhs"), Expr("rhs"));
WrapInFunction(expr); WrapInFunction(expr);
ASSERT_FALSE(r()->Resolve()) << r()->error(); ASSERT_FALSE(r()->Resolve());
ASSERT_EQ(r()->error(), ASSERT_EQ(r()->error(),
"12:34 error: Binary expression operand types are invalid for " "12:34 error: Binary expression operand types are invalid for "
"this operation: " + "this operation: " +
@ -1280,6 +1280,99 @@ INSTANTIATE_TEST_SUITE_P(
Expr_Binary_Test_Invalid, Expr_Binary_Test_Invalid,
testing::Combine(testing::ValuesIn(all_valid_cases), testing::Combine(testing::ValuesIn(all_valid_cases),
testing::ValuesIn(all_create_type_funcs))); testing::ValuesIn(all_create_type_funcs)));
using Expr_Binary_Test_Invalid_VectorMatrixMultiply =
ResolverTestWithParam<std::tuple<bool, uint32_t, uint32_t, uint32_t>>;
TEST_P(Expr_Binary_Test_Invalid_VectorMatrixMultiply, All) {
bool vec_by_mat = std::get<0>(GetParam());
uint32_t vec_size = std::get<1>(GetParam());
uint32_t mat_rows = std::get<2>(GetParam());
uint32_t mat_cols = std::get<3>(GetParam());
type::Type* lhs_type;
type::Type* rhs_type;
type::Type* result_type;
bool is_valid_expr;
if (vec_by_mat) {
lhs_type = create<type::Vector>(ty.f32(), vec_size);
rhs_type = create<type::Matrix>(ty.f32(), mat_rows, mat_cols);
result_type = create<type::Vector>(ty.f32(), mat_cols);
is_valid_expr = vec_size == mat_rows;
} else {
lhs_type = create<type::Matrix>(ty.f32(), mat_rows, mat_cols);
rhs_type = create<type::Vector>(ty.f32(), vec_size);
result_type = create<type::Vector>(ty.f32(), mat_rows);
is_valid_expr = vec_size == mat_cols;
}
Global("lhs", lhs_type, ast::StorageClass::kNone);
Global("rhs", rhs_type, ast::StorageClass::kNone);
auto* expr = Mul(Source{{12, 34}}, Expr("lhs"), Expr("rhs"));
WrapInFunction(expr);
if (is_valid_expr) {
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_TRUE(TypeOf(expr) == result_type);
} else {
ASSERT_FALSE(r()->Resolve());
ASSERT_EQ(r()->error(),
"12:34 error: Binary expression operand types are invalid for "
"this operation: " +
lhs_type->FriendlyName(Symbols()) + " " +
FriendlyName(expr->op()) + " " +
rhs_type->FriendlyName(Symbols()));
}
}
auto all_dimension_values = testing::Values(2u, 3u, 4u);
INSTANTIATE_TEST_SUITE_P(ResolverTest,
Expr_Binary_Test_Invalid_VectorMatrixMultiply,
testing::Combine(testing::Values(true, false),
all_dimension_values,
all_dimension_values,
all_dimension_values));
using Expr_Binary_Test_Invalid_MatrixMatrixMultiply =
ResolverTestWithParam<std::tuple<uint32_t, uint32_t, uint32_t, uint32_t>>;
TEST_P(Expr_Binary_Test_Invalid_MatrixMatrixMultiply, All) {
uint32_t lhs_mat_rows = std::get<0>(GetParam());
uint32_t lhs_mat_cols = std::get<1>(GetParam());
uint32_t rhs_mat_rows = std::get<2>(GetParam());
uint32_t rhs_mat_cols = std::get<3>(GetParam());
auto* lhs_type = create<type::Matrix>(ty.f32(), lhs_mat_rows, lhs_mat_cols);
auto* rhs_type = create<type::Matrix>(ty.f32(), rhs_mat_rows, rhs_mat_cols);
auto* result_type =
create<type::Matrix>(ty.f32(), lhs_mat_rows, rhs_mat_cols);
Global("lhs", lhs_type, ast::StorageClass::kNone);
Global("rhs", rhs_type, ast::StorageClass::kNone);
auto* expr = Mul(Source{{12, 34}}, Expr("lhs"), Expr("rhs"));
WrapInFunction(expr);
bool is_valid_expr = lhs_mat_cols == rhs_mat_rows;
if (is_valid_expr) {
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_TRUE(TypeOf(expr) == result_type);
} else {
ASSERT_FALSE(r()->Resolve());
ASSERT_EQ(r()->error(),
"12:34 error: Binary expression operand types are invalid for "
"this operation: " +
lhs_type->FriendlyName(Symbols()) + " " +
FriendlyName(expr->op()) + " " +
rhs_type->FriendlyName(Symbols()));
}
}
INSTANTIATE_TEST_SUITE_P(ResolverTest,
Expr_Binary_Test_Invalid_MatrixMatrixMultiply,
testing::Combine(all_dimension_values,
all_dimension_values,
all_dimension_values,
all_dimension_values));
} // namespace ExprBinaryTest } // namespace ExprBinaryTest
using UnaryOpExpressionTest = ResolverTestWithParam<ast::UnaryOp>; using UnaryOpExpressionTest = ResolverTestWithParam<ast::UnaryOp>;