Validate multiply of invalid vector/matrix sizes
Added tests that test all combos of vec*mat, mat*vec, and mat*mat for 2, 3, and 4 dimensions. Bug: tint:698 Change-Id: I4a407228261cf8ea2a93bc7077544e5a9244d854 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/46660 Reviewed-by: Ben Clayton <bclayton@google.com> Kokoro: Kokoro <noreply+kokoro@google.com> Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
parent
f1b643ee70
commit
61dabab673
|
@ -928,7 +928,7 @@ class ProgramBuilder {
|
||||||
/// @param rhs the right hand argument to the addition operation
|
/// @param rhs the right hand argument to the addition operation
|
||||||
/// @returns a `ast::BinaryExpression` summing the arguments `lhs` and `rhs`
|
/// @returns a `ast::BinaryExpression` summing the arguments `lhs` and `rhs`
|
||||||
template <typename LHS, typename RHS>
|
template <typename LHS, typename RHS>
|
||||||
ast::Expression* Add(LHS&& lhs, RHS&& rhs) {
|
ast::BinaryExpression* Add(LHS&& lhs, RHS&& rhs) {
|
||||||
return create<ast::BinaryExpression>(ast::BinaryOp::kAdd,
|
return create<ast::BinaryExpression>(ast::BinaryOp::kAdd,
|
||||||
Expr(std::forward<LHS>(lhs)),
|
Expr(std::forward<LHS>(lhs)),
|
||||||
Expr(std::forward<RHS>(rhs)));
|
Expr(std::forward<RHS>(rhs)));
|
||||||
|
@ -938,7 +938,7 @@ class ProgramBuilder {
|
||||||
/// @param rhs the right hand argument to the subtraction operation
|
/// @param rhs the right hand argument to the subtraction operation
|
||||||
/// @returns a `ast::BinaryExpression` subtracting `rhs` from `lhs`
|
/// @returns a `ast::BinaryExpression` subtracting `rhs` from `lhs`
|
||||||
template <typename LHS, typename RHS>
|
template <typename LHS, typename RHS>
|
||||||
ast::Expression* Sub(LHS&& lhs, RHS&& rhs) {
|
ast::BinaryExpression* Sub(LHS&& lhs, RHS&& rhs) {
|
||||||
return create<ast::BinaryExpression>(ast::BinaryOp::kSubtract,
|
return create<ast::BinaryExpression>(ast::BinaryOp::kSubtract,
|
||||||
Expr(std::forward<LHS>(lhs)),
|
Expr(std::forward<LHS>(lhs)),
|
||||||
Expr(std::forward<RHS>(rhs)));
|
Expr(std::forward<RHS>(rhs)));
|
||||||
|
@ -948,17 +948,28 @@ class ProgramBuilder {
|
||||||
/// @param rhs the right hand argument to the multiplication operation
|
/// @param rhs the right hand argument to the multiplication operation
|
||||||
/// @returns a `ast::BinaryExpression` multiplying `rhs` from `lhs`
|
/// @returns a `ast::BinaryExpression` multiplying `rhs` from `lhs`
|
||||||
template <typename LHS, typename RHS>
|
template <typename LHS, typename RHS>
|
||||||
ast::Expression* Mul(LHS&& lhs, RHS&& rhs) {
|
ast::BinaryExpression* Mul(LHS&& lhs, RHS&& rhs) {
|
||||||
return create<ast::BinaryExpression>(ast::BinaryOp::kMultiply,
|
return create<ast::BinaryExpression>(ast::BinaryOp::kMultiply,
|
||||||
Expr(std::forward<LHS>(lhs)),
|
Expr(std::forward<LHS>(lhs)),
|
||||||
Expr(std::forward<RHS>(rhs)));
|
Expr(std::forward<RHS>(rhs)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// @param source the source information
|
||||||
|
/// @param lhs the left hand argument to the multiplication operation
|
||||||
|
/// @param rhs the right hand argument to the multiplication operation
|
||||||
|
/// @returns a `ast::BinaryExpression` multiplying `rhs` from `lhs`
|
||||||
|
template <typename LHS, typename RHS>
|
||||||
|
ast::BinaryExpression* Mul(const Source& source, LHS&& lhs, RHS&& rhs) {
|
||||||
|
return create<ast::BinaryExpression>(source, ast::BinaryOp::kMultiply,
|
||||||
|
Expr(std::forward<LHS>(lhs)),
|
||||||
|
Expr(std::forward<RHS>(rhs)));
|
||||||
|
}
|
||||||
|
|
||||||
/// @param arr the array argument for the array accessor expression
|
/// @param arr the array argument for the array accessor expression
|
||||||
/// @param idx the index argument for the array accessor expression
|
/// @param idx the index argument for the array accessor expression
|
||||||
/// @returns a `ast::ArrayAccessorExpression` that indexes `arr` with `idx`
|
/// @returns a `ast::ArrayAccessorExpression` that indexes `arr` with `idx`
|
||||||
template <typename ARR, typename IDX>
|
template <typename ARR, typename IDX>
|
||||||
ast::Expression* IndexAccessor(ARR&& arr, IDX&& idx) {
|
ast::ArrayAccessorExpression* IndexAccessor(ARR&& arr, IDX&& idx) {
|
||||||
return create<ast::ArrayAccessorExpression>(Expr(std::forward<ARR>(arr)),
|
return create<ast::ArrayAccessorExpression>(Expr(std::forward<ARR>(arr)),
|
||||||
Expr(std::forward<IDX>(idx)));
|
Expr(std::forward<IDX>(idx)));
|
||||||
}
|
}
|
||||||
|
|
|
@ -1040,8 +1040,10 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
|
||||||
auto* rhs_vec_elem_type =
|
auto* rhs_vec_elem_type =
|
||||||
rhs_vec ? rhs_vec->type()->UnwrapAliasIfNeeded() : nullptr;
|
rhs_vec ? rhs_vec->type()->UnwrapAliasIfNeeded() : nullptr;
|
||||||
|
|
||||||
const bool matching_vec_elem_types = lhs_vec_elem_type && rhs_vec_elem_type &&
|
const bool matching_vec_elem_types =
|
||||||
(lhs_vec_elem_type == rhs_vec_elem_type);
|
lhs_vec_elem_type && rhs_vec_elem_type &&
|
||||||
|
(lhs_vec_elem_type == rhs_vec_elem_type) &&
|
||||||
|
(lhs_vec->size() == rhs_vec->size());
|
||||||
|
|
||||||
const bool matching_types = matching_vec_elem_types || (lhs_type == rhs_type);
|
const bool matching_types = matching_vec_elem_types || (lhs_type == rhs_type);
|
||||||
|
|
||||||
|
@ -1106,19 +1108,22 @@ bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
|
||||||
|
|
||||||
// Vector times matrix
|
// Vector times matrix
|
||||||
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
|
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
|
||||||
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>()) {
|
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
|
||||||
|
(lhs_vec->size() == rhs_mat->rows())) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Matrix times vector
|
// Matrix times vector
|
||||||
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
|
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
|
||||||
rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>()) {
|
rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>() &&
|
||||||
|
(lhs_mat->columns() == rhs_vec->size())) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Matrix times matrix
|
// Matrix times matrix
|
||||||
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
|
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
|
||||||
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>()) {
|
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>() &&
|
||||||
|
(lhs_mat->columns() == rhs_mat->rows())) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1263,11 +1263,11 @@ TEST_P(Expr_Binary_Test_Invalid, All) {
|
||||||
Global("lhs", lhs_type, ast::StorageClass::kNone);
|
Global("lhs", lhs_type, ast::StorageClass::kNone);
|
||||||
Global("rhs", rhs_type, ast::StorageClass::kNone);
|
Global("rhs", rhs_type, ast::StorageClass::kNone);
|
||||||
|
|
||||||
auto* expr = create<ast::BinaryExpression>(
|
auto* expr = create<ast::BinaryExpression>(Source{{12, 34}}, params.op,
|
||||||
Source{Source::Location{12, 34}}, params.op, Expr("lhs"), Expr("rhs"));
|
Expr("lhs"), Expr("rhs"));
|
||||||
WrapInFunction(expr);
|
WrapInFunction(expr);
|
||||||
|
|
||||||
ASSERT_FALSE(r()->Resolve()) << r()->error();
|
ASSERT_FALSE(r()->Resolve());
|
||||||
ASSERT_EQ(r()->error(),
|
ASSERT_EQ(r()->error(),
|
||||||
"12:34 error: Binary expression operand types are invalid for "
|
"12:34 error: Binary expression operand types are invalid for "
|
||||||
"this operation: " +
|
"this operation: " +
|
||||||
|
@ -1280,6 +1280,99 @@ INSTANTIATE_TEST_SUITE_P(
|
||||||
Expr_Binary_Test_Invalid,
|
Expr_Binary_Test_Invalid,
|
||||||
testing::Combine(testing::ValuesIn(all_valid_cases),
|
testing::Combine(testing::ValuesIn(all_valid_cases),
|
||||||
testing::ValuesIn(all_create_type_funcs)));
|
testing::ValuesIn(all_create_type_funcs)));
|
||||||
|
|
||||||
|
using Expr_Binary_Test_Invalid_VectorMatrixMultiply =
|
||||||
|
ResolverTestWithParam<std::tuple<bool, uint32_t, uint32_t, uint32_t>>;
|
||||||
|
TEST_P(Expr_Binary_Test_Invalid_VectorMatrixMultiply, All) {
|
||||||
|
bool vec_by_mat = std::get<0>(GetParam());
|
||||||
|
uint32_t vec_size = std::get<1>(GetParam());
|
||||||
|
uint32_t mat_rows = std::get<2>(GetParam());
|
||||||
|
uint32_t mat_cols = std::get<3>(GetParam());
|
||||||
|
|
||||||
|
type::Type* lhs_type;
|
||||||
|
type::Type* rhs_type;
|
||||||
|
type::Type* result_type;
|
||||||
|
bool is_valid_expr;
|
||||||
|
|
||||||
|
if (vec_by_mat) {
|
||||||
|
lhs_type = create<type::Vector>(ty.f32(), vec_size);
|
||||||
|
rhs_type = create<type::Matrix>(ty.f32(), mat_rows, mat_cols);
|
||||||
|
result_type = create<type::Vector>(ty.f32(), mat_cols);
|
||||||
|
is_valid_expr = vec_size == mat_rows;
|
||||||
|
} else {
|
||||||
|
lhs_type = create<type::Matrix>(ty.f32(), mat_rows, mat_cols);
|
||||||
|
rhs_type = create<type::Vector>(ty.f32(), vec_size);
|
||||||
|
result_type = create<type::Vector>(ty.f32(), mat_rows);
|
||||||
|
is_valid_expr = vec_size == mat_cols;
|
||||||
|
}
|
||||||
|
|
||||||
|
Global("lhs", lhs_type, ast::StorageClass::kNone);
|
||||||
|
Global("rhs", rhs_type, ast::StorageClass::kNone);
|
||||||
|
|
||||||
|
auto* expr = Mul(Source{{12, 34}}, Expr("lhs"), Expr("rhs"));
|
||||||
|
WrapInFunction(expr);
|
||||||
|
|
||||||
|
if (is_valid_expr) {
|
||||||
|
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||||
|
ASSERT_TRUE(TypeOf(expr) == result_type);
|
||||||
|
} else {
|
||||||
|
ASSERT_FALSE(r()->Resolve());
|
||||||
|
ASSERT_EQ(r()->error(),
|
||||||
|
"12:34 error: Binary expression operand types are invalid for "
|
||||||
|
"this operation: " +
|
||||||
|
lhs_type->FriendlyName(Symbols()) + " " +
|
||||||
|
FriendlyName(expr->op()) + " " +
|
||||||
|
rhs_type->FriendlyName(Symbols()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
auto all_dimension_values = testing::Values(2u, 3u, 4u);
|
||||||
|
INSTANTIATE_TEST_SUITE_P(ResolverTest,
|
||||||
|
Expr_Binary_Test_Invalid_VectorMatrixMultiply,
|
||||||
|
testing::Combine(testing::Values(true, false),
|
||||||
|
all_dimension_values,
|
||||||
|
all_dimension_values,
|
||||||
|
all_dimension_values));
|
||||||
|
|
||||||
|
using Expr_Binary_Test_Invalid_MatrixMatrixMultiply =
|
||||||
|
ResolverTestWithParam<std::tuple<uint32_t, uint32_t, uint32_t, uint32_t>>;
|
||||||
|
TEST_P(Expr_Binary_Test_Invalid_MatrixMatrixMultiply, All) {
|
||||||
|
uint32_t lhs_mat_rows = std::get<0>(GetParam());
|
||||||
|
uint32_t lhs_mat_cols = std::get<1>(GetParam());
|
||||||
|
uint32_t rhs_mat_rows = std::get<2>(GetParam());
|
||||||
|
uint32_t rhs_mat_cols = std::get<3>(GetParam());
|
||||||
|
|
||||||
|
auto* lhs_type = create<type::Matrix>(ty.f32(), lhs_mat_rows, lhs_mat_cols);
|
||||||
|
auto* rhs_type = create<type::Matrix>(ty.f32(), rhs_mat_rows, rhs_mat_cols);
|
||||||
|
auto* result_type =
|
||||||
|
create<type::Matrix>(ty.f32(), lhs_mat_rows, rhs_mat_cols);
|
||||||
|
|
||||||
|
Global("lhs", lhs_type, ast::StorageClass::kNone);
|
||||||
|
Global("rhs", rhs_type, ast::StorageClass::kNone);
|
||||||
|
|
||||||
|
auto* expr = Mul(Source{{12, 34}}, Expr("lhs"), Expr("rhs"));
|
||||||
|
WrapInFunction(expr);
|
||||||
|
|
||||||
|
bool is_valid_expr = lhs_mat_cols == rhs_mat_rows;
|
||||||
|
if (is_valid_expr) {
|
||||||
|
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||||
|
ASSERT_TRUE(TypeOf(expr) == result_type);
|
||||||
|
} else {
|
||||||
|
ASSERT_FALSE(r()->Resolve());
|
||||||
|
ASSERT_EQ(r()->error(),
|
||||||
|
"12:34 error: Binary expression operand types are invalid for "
|
||||||
|
"this operation: " +
|
||||||
|
lhs_type->FriendlyName(Symbols()) + " " +
|
||||||
|
FriendlyName(expr->op()) + " " +
|
||||||
|
rhs_type->FriendlyName(Symbols()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
INSTANTIATE_TEST_SUITE_P(ResolverTest,
|
||||||
|
Expr_Binary_Test_Invalid_MatrixMatrixMultiply,
|
||||||
|
testing::Combine(all_dimension_values,
|
||||||
|
all_dimension_values,
|
||||||
|
all_dimension_values,
|
||||||
|
all_dimension_values));
|
||||||
|
|
||||||
} // namespace ExprBinaryTest
|
} // namespace ExprBinaryTest
|
||||||
|
|
||||||
using UnaryOpExpressionTest = ResolverTestWithParam<ast::UnaryOp>;
|
using UnaryOpExpressionTest = ResolverTestWithParam<ast::UnaryOp>;
|
||||||
|
|
Loading…
Reference in New Issue