diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc index 1645aaba1e..7a088c775d 100644 --- a/src/resolver/resolver_test.cc +++ b/src/resolver/resolver_test.cc @@ -1216,6 +1216,27 @@ struct Params { create_sem_type_func_ptr create_result_type; }; +static constexpr ast::BinaryOp all_ops[] = { + ast::BinaryOp::kAnd, + ast::BinaryOp::kOr, + ast::BinaryOp::kXor, + ast::BinaryOp::kLogicalAnd, + ast::BinaryOp::kLogicalOr, + ast::BinaryOp::kEqual, + ast::BinaryOp::kNotEqual, + ast::BinaryOp::kLessThan, + ast::BinaryOp::kGreaterThan, + ast::BinaryOp::kLessThanEqual, + ast::BinaryOp::kGreaterThanEqual, + ast::BinaryOp::kShiftLeft, + ast::BinaryOp::kShiftRight, + ast::BinaryOp::kAdd, + ast::BinaryOp::kSubtract, + ast::BinaryOp::kMultiply, + ast::BinaryOp::kDivide, + ast::BinaryOp::kModulo, +}; + static constexpr create_ast_type_func_ptr all_create_type_funcs[] = { ast_bool, ast_u32, ast_i32, ast_f32, ast_vec3, ast_vec3, ast_vec3, ast_vec3, @@ -1467,44 +1488,41 @@ INSTANTIATE_TEST_SUITE_P( BinaryExprSide::Right, BinaryExprSide::Both))); +// This test works by taking the cartesian product of all possible +// (type * type * op), and processing only the triplets that are not found in +// the `all_valid_cases` table. using Expr_Binary_Test_Invalid = - ResolverTestWithParam>; + ResolverTestWithParam>; TEST_P(Expr_Binary_Test_Invalid, All) { - const Params& params = std::get<0>(GetParam()); - auto& create_type_func = std::get<1>(GetParam()); + const create_ast_type_func_ptr& lhs_create_type_func = + std::get<0>(GetParam()); + const create_ast_type_func_ptr& rhs_create_type_func = + std::get<1>(GetParam()); + const ast::BinaryOp op = std::get<2>(GetParam()); - // Currently, for most operations, for a given lhs type, there is exactly one - // rhs type allowed. The only exception is for multiplication, which allows - // any permutation of f32, vecN, and matNxN. We are fed valid inputs - // only via `params`, and all possible types via `create_type_func`, so we - // test invalid combinations by testing every other rhs type, modulo - // exceptions. - - // Skip valid rhs type - if (params.create_rhs_type == create_type_func) { - return; + // Skip if valid case + // TODO(amaiorano): replace linear lookup with O(1) if too slow + for (auto& c : all_valid_cases) { + if (c.create_lhs_type == lhs_create_type_func && + c.create_rhs_type == rhs_create_type_func && c.op == op) { + return; + } } - auto* lhs_type = params.create_lhs_type(ty); - auto* rhs_type = create_type_func(ty); - - // Skip exceptions: multiplication of f32, vecN, and matNxN - if (params.op == Op::kMultiply && - lhs_type->is_float_scalar_or_vector_or_matrix() && - rhs_type->is_float_scalar_or_vector_or_matrix()) { - return; - } + auto* lhs_type = lhs_create_type_func(ty); + auto* rhs_type = rhs_create_type_func(ty); std::stringstream ss; - ss << FriendlyName(lhs_type) << " " << params.op << " " - << FriendlyName(rhs_type); + ss << FriendlyName(lhs_type) << " " << op << " " << FriendlyName(rhs_type); SCOPED_TRACE(ss.str()); Global("lhs", lhs_type, ast::StorageClass::kInput); Global("rhs", rhs_type, ast::StorageClass::kInput); - auto* expr = create(Source{{12, 34}}, params.op, - Expr("lhs"), Expr("rhs")); + auto* expr = create(Source{{12, 34}}, op, Expr("lhs"), + Expr("rhs")); WrapInFunction(expr); ASSERT_FALSE(r()->Resolve()); @@ -1517,8 +1535,9 @@ TEST_P(Expr_Binary_Test_Invalid, All) { INSTANTIATE_TEST_SUITE_P( ResolverTest, Expr_Binary_Test_Invalid, - testing::Combine(testing::ValuesIn(all_valid_cases), - testing::ValuesIn(all_create_type_funcs))); + testing::Combine(testing::ValuesIn(all_create_type_funcs), + testing::ValuesIn(all_create_type_funcs), + testing::ValuesIn(all_ops))); using Expr_Binary_Test_Invalid_VectorMatrixMultiply = ResolverTestWithParam>;