Validate binary operations

This change validates that the operand types and result type of every
binary operation is valid.

* Added two unit tests which test all valid and invalid param combos. I
also removed the old tests, many of which failed once I added this
validation, and the rest are obviated by the new tests.

* Fixed VertexPulling transform, as well as many tests, that were using
invalid operand types for binary operations.

Fixed: tint:354
Change-Id: Ia3f48384256993da61b341f17ba5583741011819
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44341
Reviewed-by: Ben Clayton <bclayton@google.com>
Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
Antonio Maiorano
2021-03-16 13:26:03 +00:00
committed by Commit Bot service account
parent 1691401179
commit be0fc4e929
11 changed files with 534 additions and 259 deletions

View File

@@ -877,11 +877,167 @@ bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) {
return true;
}
bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
using Bool = type::Bool;
using F32 = type::F32;
using I32 = type::I32;
using U32 = type::U32;
using Matrix = type::Matrix;
using Vector = type::Vector;
auto* lhs_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded();
auto* rhs_type = TypeOf(expr->rhs())->UnwrapPtrIfNeeded();
auto* lhs_vec = lhs_type->As<Vector>();
auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
auto* rhs_vec = rhs_type->As<Vector>();
auto* rhs_vec_elem_type = rhs_vec ? rhs_vec->type() : nullptr;
const bool matching_types = lhs_type == rhs_type;
const bool matching_vec_elem_types = lhs_vec_elem_type && rhs_vec_elem_type &&
(lhs_vec_elem_type == rhs_vec_elem_type);
// Binary logical expressions
if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
if (matching_types && lhs_type->Is<Bool>()) {
return true;
}
}
if (expr->IsOr() || expr->IsAnd()) {
if (matching_types && lhs_type->Is<Bool>()) {
return true;
}
if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is<Bool>()) {
return true;
}
}
// Arithmetic expressions
if (expr->IsArithmetic()) {
// Binary arithmetic expressions over scalars
if (matching_types && lhs_type->IsAnyOf<I32, F32, U32>()) {
return true;
}
// Binary arithmetic expressions over vectors
if (matching_types && lhs_vec_elem_type &&
lhs_vec_elem_type->IsAnyOf<I32, F32, U32>()) {
return true;
}
}
// Binary arithmetic expressions with mixed scalar, vector, and matrix
// operands
if (expr->IsMultiply()) {
// Multiplication of a vector and a scalar
if (lhs_type->Is<F32>() && rhs_vec_elem_type &&
rhs_vec_elem_type->Is<F32>()) {
return true;
}
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
rhs_type->Is<F32>()) {
return true;
}
auto* lhs_mat = lhs_type->As<Matrix>();
auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
auto* rhs_mat = rhs_type->As<Matrix>();
auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
// Multiplication of a matrix and a scalar
if (lhs_type->Is<F32>() && rhs_mat_elem_type &&
rhs_mat_elem_type->Is<F32>()) {
return true;
}
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_type->Is<F32>()) {
return true;
}
// Vector times matrix
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>()) {
return true;
}
// Matrix times vector
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>()) {
return true;
}
// Matrix times matrix
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>()) {
return true;
}
}
// Comparison expressions
if (expr->IsComparison()) {
if (matching_types) {
// Special case for bools: only == and !=
if (lhs_type->Is<Bool>() && (expr->IsEqual() || expr->IsNotEqual())) {
return true;
}
// For the rest, we can compare i32, u32, and f32
if (lhs_type->IsAnyOf<I32, U32, F32>()) {
return true;
}
}
// Same for vectors
if (matching_vec_elem_types) {
if (lhs_vec_elem_type->Is<Bool>() &&
(expr->IsEqual() || expr->IsNotEqual())) {
return true;
}
if (lhs_vec_elem_type->IsAnyOf<I32, U32, F32>()) {
return true;
}
}
}
// Binary bitwise operations
if (expr->IsBitwise()) {
if (matching_types && lhs_type->IsAnyOf<I32, U32>()) {
return true;
}
}
// Bit shift expressions
if (expr->IsBitshift()) {
// Type validation rules are the same for left or right shift, despite
// differences in computation rules (i.e. right shift can be arithmetic or
// logical depending on lhs type).
if (lhs_type->IsAnyOf<I32, U32>() && rhs_type->Is<U32>()) {
return true;
}
if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf<I32, U32>() &&
rhs_vec_elem_type && rhs_vec_elem_type->Is<U32>()) {
return true;
}
}
diagnostics_.add_error(
"Binary expression operand types are invalid for this operation",
expr->source());
return false;
}
bool Resolver::Binary(ast::BinaryExpression* expr) {
if (!Expression(expr->lhs()) || !Expression(expr->rhs())) {
return false;
}
if (!ValidateBinary(expr)) {
return false;
}
// Result type matches first parameter type
if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() ||
expr->IsShiftRight() || expr->IsAdd() || expr->IsSubtract() ||

View File

@@ -171,6 +171,7 @@ class Resolver {
// AST and Type traversal methods
// Each return true on success, false on failure.
bool ArrayAccessor(ast::ArrayAccessorExpression*);
bool ValidateBinary(ast::BinaryExpression* expr);
bool Binary(ast::BinaryExpression*);
bool Bitcast(ast::BitcastExpression*);
bool BlockStatement(const ast::BlockStatement*);

View File

@@ -14,6 +14,8 @@
#include "src/resolver/resolver.h"
#include <tuple>
#include "gmock/gmock.h"
#include "src/ast/assignment_statement.h"
#include "src/ast/bitcast_expression.h"
@@ -971,246 +973,276 @@ TEST_F(ResolverTest, Expr_MemberAccessor_InBinaryOp) {
EXPECT_TRUE(TypeOf(expr)->Is<type::F32>());
}
using Expr_Binary_BitwiseTest = ResolverTestWithParam<ast::BinaryOp>;
TEST_P(Expr_Binary_BitwiseTest, Scalar) {
auto op = GetParam();
namespace ExprBinaryTest {
Global("val", ty.i32(), ast::StorageClass::kNone);
using create_type_func_ptr =
type::Type* (*)(const ProgramBuilder::TypesBuilder& ty);
auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
WrapInFunction(expr);
struct Params {
ast::BinaryOp op;
create_type_func_ptr create_lhs_type;
create_type_func_ptr create_rhs_type;
create_type_func_ptr create_result_type;
};
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
EXPECT_TRUE(TypeOf(expr)->Is<type::I32>());
// Helpers and typedefs to make building the table below more succinct
using i32 = ProgramBuilder::i32;
using u32 = ProgramBuilder::u32;
using f32 = ProgramBuilder::f32;
using Op = ast::BinaryOp;
type::Type* ty_bool_(const ProgramBuilder::TypesBuilder& ty) {
return ty.bool_();
}
type::Type* ty_i32(const ProgramBuilder::TypesBuilder& ty) {
return ty.i32();
}
type::Type* ty_u32(const ProgramBuilder::TypesBuilder& ty) {
return ty.u32();
}
type::Type* ty_f32(const ProgramBuilder::TypesBuilder& ty) {
return ty.f32();
}
TEST_P(Expr_Binary_BitwiseTest, Vector) {
auto op = GetParam();
template <typename T>
type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) {
return ty.vec3<T>();
}
Global("val", ty.vec3<i32>(), ast::StorageClass::kNone);
template <typename T>
type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
return ty.mat3x3<T>();
}
auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
static constexpr create_type_func_ptr all_create_type_funcs[] = {
ty_bool_, ty_u32, ty_i32, ty_f32,
ty_vec3<bool>, ty_vec3<i32>, ty_vec3<u32>, ty_vec3<f32>,
ty_mat3x3<i32>, ty_mat3x3<u32>, ty_mat3x3<f32>};
// A list of all valid test cases for 'lhs op rhs', except that for vecN and
// matNxN, we only test N=3.
static constexpr Params all_valid_cases[] = {
// Logical expressions
// https://gpuweb.github.io/gpuweb/wgsl.html#logical-expr
// Binary logical expressions
Params{Op::kLogicalAnd, ty_bool_, ty_bool_, ty_bool_},
Params{Op::kLogicalOr, ty_bool_, ty_bool_, ty_bool_},
Params{Op::kAnd, ty_bool_, ty_bool_, ty_bool_},
Params{Op::kOr, ty_bool_, ty_bool_, ty_bool_},
Params{Op::kAnd, ty_vec3<bool>, ty_vec3<bool>, ty_vec3<bool>},
Params{Op::kOr, ty_vec3<bool>, ty_vec3<bool>, ty_vec3<bool>},
// Arithmetic expressions
// https://gpuweb.github.io/gpuweb/wgsl.html#arithmetic-expr
// Binary arithmetic expressions over scalars
Params{Op::kAdd, ty_i32, ty_i32, ty_i32},
Params{Op::kSubtract, ty_i32, ty_i32, ty_i32},
Params{Op::kMultiply, ty_i32, ty_i32, ty_i32},
Params{Op::kDivide, ty_i32, ty_i32, ty_i32},
Params{Op::kModulo, ty_i32, ty_i32, ty_i32},
Params{Op::kAdd, ty_u32, ty_u32, ty_u32},
Params{Op::kSubtract, ty_u32, ty_u32, ty_u32},
Params{Op::kMultiply, ty_u32, ty_u32, ty_u32},
Params{Op::kDivide, ty_u32, ty_u32, ty_u32},
Params{Op::kModulo, ty_u32, ty_u32, ty_u32},
Params{Op::kAdd, ty_f32, ty_f32, ty_f32},
Params{Op::kSubtract, ty_f32, ty_f32, ty_f32},
Params{Op::kMultiply, ty_f32, ty_f32, ty_f32},
Params{Op::kDivide, ty_f32, ty_f32, ty_f32},
Params{Op::kModulo, ty_f32, ty_f32, ty_f32},
// Binary arithmetic expressions over vectors
Params{Op::kAdd, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
Params{Op::kSubtract, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
Params{Op::kMultiply, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
Params{Op::kDivide, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
Params{Op::kModulo, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
Params{Op::kAdd, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
Params{Op::kSubtract, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
Params{Op::kMultiply, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
Params{Op::kDivide, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
Params{Op::kModulo, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
Params{Op::kAdd, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
Params{Op::kSubtract, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
Params{Op::kMultiply, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
Params{Op::kDivide, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
Params{Op::kModulo, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
// Binary arithmetic expressions with mixed scalar, vector, and matrix
// operands
Params{Op::kMultiply, ty_vec3<f32>, ty_f32, ty_vec3<f32>},
Params{Op::kMultiply, ty_f32, ty_vec3<f32>, ty_vec3<f32>},
Params{Op::kMultiply, ty_mat3x3<f32>, ty_f32, ty_mat3x3<f32>},
Params{Op::kMultiply, ty_f32, ty_mat3x3<f32>, ty_mat3x3<f32>},
Params{Op::kMultiply, ty_vec3<f32>, ty_mat3x3<f32>, ty_vec3<f32>},
Params{Op::kMultiply, ty_mat3x3<f32>, ty_vec3<f32>, ty_vec3<f32>},
Params{Op::kMultiply, ty_mat3x3<f32>, ty_mat3x3<f32>, ty_mat3x3<f32>},
// Comparison expressions
// https://gpuweb.github.io/gpuweb/wgsl.html#comparison-expr
// Comparisons over scalars
Params{Op::kEqual, ty_bool_, ty_bool_, ty_bool_},
Params{Op::kNotEqual, ty_bool_, ty_bool_, ty_bool_},
Params{Op::kEqual, ty_i32, ty_i32, ty_bool_},
Params{Op::kNotEqual, ty_i32, ty_i32, ty_bool_},
Params{Op::kLessThan, ty_i32, ty_i32, ty_bool_},
Params{Op::kLessThanEqual, ty_i32, ty_i32, ty_bool_},
Params{Op::kGreaterThan, ty_i32, ty_i32, ty_bool_},
Params{Op::kGreaterThanEqual, ty_i32, ty_i32, ty_bool_},
Params{Op::kEqual, ty_u32, ty_u32, ty_bool_},
Params{Op::kNotEqual, ty_u32, ty_u32, ty_bool_},
Params{Op::kLessThan, ty_u32, ty_u32, ty_bool_},
Params{Op::kLessThanEqual, ty_u32, ty_u32, ty_bool_},
Params{Op::kGreaterThan, ty_u32, ty_u32, ty_bool_},
Params{Op::kGreaterThanEqual, ty_u32, ty_u32, ty_bool_},
Params{Op::kEqual, ty_f32, ty_f32, ty_bool_},
Params{Op::kNotEqual, ty_f32, ty_f32, ty_bool_},
Params{Op::kLessThan, ty_f32, ty_f32, ty_bool_},
Params{Op::kLessThanEqual, ty_f32, ty_f32, ty_bool_},
Params{Op::kGreaterThan, ty_f32, ty_f32, ty_bool_},
Params{Op::kGreaterThanEqual, ty_f32, ty_f32, ty_bool_},
// Comparisons over vectors
Params{Op::kEqual, ty_vec3<bool>, ty_vec3<bool>, ty_vec3<bool>},
Params{Op::kNotEqual, ty_vec3<bool>, ty_vec3<bool>, ty_vec3<bool>},
Params{Op::kEqual, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
Params{Op::kNotEqual, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
Params{Op::kLessThan, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
Params{Op::kLessThanEqual, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
Params{Op::kGreaterThan, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
Params{Op::kGreaterThanEqual, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
Params{Op::kEqual, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
Params{Op::kNotEqual, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
Params{Op::kLessThan, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
Params{Op::kLessThanEqual, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
Params{Op::kGreaterThan, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
Params{Op::kGreaterThanEqual, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
Params{Op::kEqual, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
Params{Op::kNotEqual, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
Params{Op::kLessThan, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
Params{Op::kLessThanEqual, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
Params{Op::kGreaterThan, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
Params{Op::kGreaterThanEqual, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
// Bit expressions
// https://gpuweb.github.io/gpuweb/wgsl.html#bit-expr
// Binary bitwise operations
Params{Op::kOr, ty_i32, ty_i32, ty_i32},
Params{Op::kAnd, ty_i32, ty_i32, ty_i32},
Params{Op::kXor, ty_i32, ty_i32, ty_i32},
Params{Op::kOr, ty_u32, ty_u32, ty_u32},
Params{Op::kAnd, ty_u32, ty_u32, ty_u32},
Params{Op::kXor, ty_u32, ty_u32, ty_u32},
// Bit shift expressions
Params{Op::kShiftLeft, ty_i32, ty_u32, ty_i32},
Params{Op::kShiftLeft, ty_vec3<i32>, ty_vec3<u32>, ty_vec3<i32>},
Params{Op::kShiftLeft, ty_u32, ty_u32, ty_u32},
Params{Op::kShiftLeft, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
Params{Op::kShiftRight, ty_i32, ty_u32, ty_i32},
Params{Op::kShiftRight, ty_vec3<i32>, ty_vec3<u32>, ty_vec3<i32>},
Params{Op::kShiftRight, ty_u32, ty_u32, ty_u32},
Params{Op::kShiftRight, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>}};
using Expr_Binary_Test_Valid = ResolverTestWithParam<Params>;
TEST_P(Expr_Binary_Test_Valid, All) {
auto& params = GetParam();
auto* lhs_type = params.create_lhs_type(ty);
auto* rhs_type = params.create_rhs_type(ty);
auto* result_type = params.create_result_type(ty);
SCOPED_TRACE(testing::Message()
<< lhs_type->FriendlyName(Symbols()) << " " << params.op << " "
<< rhs_type->FriendlyName(Symbols()));
Global("lhs", lhs_type, ast::StorageClass::kNone);
Global("rhs", rhs_type, ast::StorageClass::kNone);
auto* expr =
create<ast::BinaryExpression>(params.op, Expr("lhs"), Expr("rhs"));
WrapInFunction(expr);
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::I32>());
EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
ASSERT_TRUE(TypeOf(expr) == result_type);
}
INSTANTIATE_TEST_SUITE_P(ResolverTest,
Expr_Binary_BitwiseTest,
testing::Values(ast::BinaryOp::kAnd,
ast::BinaryOp::kOr,
ast::BinaryOp::kXor,
ast::BinaryOp::kShiftLeft,
ast::BinaryOp::kShiftRight,
ast::BinaryOp::kAdd,
ast::BinaryOp::kSubtract,
ast::BinaryOp::kDivide,
ast::BinaryOp::kModulo));
Expr_Binary_Test_Valid,
testing::ValuesIn(all_valid_cases));
using Expr_Binary_LogicalTest = ResolverTestWithParam<ast::BinaryOp>;
TEST_P(Expr_Binary_LogicalTest, Scalar) {
auto op = GetParam();
using Expr_Binary_Test_Invalid =
ResolverTestWithParam<std::tuple<Params, create_type_func_ptr>>;
TEST_P(Expr_Binary_Test_Invalid, All) {
const Params& params = std::get<0>(GetParam());
const create_type_func_ptr& create_type_func = std::get<1>(GetParam());
Global("val", ty.bool_(), ast::StorageClass::kNone);
// Currently, for most operations, for a given lhs type, there is exactly one
// rhs type allowed. The only exception is for multiplication, which allows
// any permutation of f32, vecN<f32>, and matNxN<f32>. We are fed valid inputs
// only via `params`, and all possible types via `create_type_func`, so we
// test invalid combinations by testing every other rhs type, modulo
// exceptions.
auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
// Skip valid rhs type
if (params.create_rhs_type == create_type_func) {
return;
}
auto* lhs_type = params.create_lhs_type(ty);
auto* rhs_type = create_type_func(ty);
// Skip exceptions: multiplication of f32, vecN<f32>, and matNxN<f32>
if (params.op == Op::kMultiply &&
lhs_type->is_float_scalar_or_vector_or_matrix() &&
rhs_type->is_float_scalar_or_vector_or_matrix()) {
return;
}
SCOPED_TRACE(testing::Message()
<< lhs_type->FriendlyName(Symbols()) << " " << params.op << " "
<< rhs_type->FriendlyName(Symbols()));
Global("lhs", lhs_type, ast::StorageClass::kNone);
Global("rhs", rhs_type, ast::StorageClass::kNone);
auto* expr = create<ast::BinaryExpression>(
Source{Source::Location{12, 34}}, params.op, Expr("lhs"), Expr("rhs"));
WrapInFunction(expr);
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
EXPECT_TRUE(TypeOf(expr)->Is<type::Bool>());
}
TEST_P(Expr_Binary_LogicalTest, Vector) {
auto op = GetParam();
Global("val", ty.vec3<bool>(), ast::StorageClass::kNone);
auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
WrapInFunction(expr);
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::Bool>());
EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
}
INSTANTIATE_TEST_SUITE_P(ResolverTest,
Expr_Binary_LogicalTest,
testing::Values(ast::BinaryOp::kLogicalAnd,
ast::BinaryOp::kLogicalOr));
using Expr_Binary_CompareTest = ResolverTestWithParam<ast::BinaryOp>;
TEST_P(Expr_Binary_CompareTest, Scalar) {
auto op = GetParam();
Global("val", ty.i32(), ast::StorageClass::kNone);
auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
WrapInFunction(expr);
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
EXPECT_TRUE(TypeOf(expr)->Is<type::Bool>());
}
TEST_P(Expr_Binary_CompareTest, Vector) {
auto op = GetParam();
Global("val", ty.vec3<i32>(), ast::StorageClass::kNone);
auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
WrapInFunction(expr);
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::Bool>());
EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
}
INSTANTIATE_TEST_SUITE_P(ResolverTest,
Expr_Binary_CompareTest,
testing::Values(ast::BinaryOp::kEqual,
ast::BinaryOp::kNotEqual,
ast::BinaryOp::kLessThan,
ast::BinaryOp::kGreaterThan,
ast::BinaryOp::kLessThanEqual,
ast::BinaryOp::kGreaterThanEqual));
TEST_F(ResolverTest, Expr_Binary_Multiply_Scalar_Scalar) {
Global("val", ty.i32(), ast::StorageClass::kNone);
auto* expr = Mul("val", "val");
WrapInFunction(expr);
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
EXPECT_TRUE(TypeOf(expr)->Is<type::I32>());
}
TEST_F(ResolverTest, Expr_Binary_Multiply_Vector_Scalar) {
Global("scalar", ty.f32(), ast::StorageClass::kNone);
Global("vector", ty.vec3<f32>(), ast::StorageClass::kNone);
auto* expr = Mul("vector", "scalar");
WrapInFunction(expr);
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
}
TEST_F(ResolverTest, Expr_Binary_Multiply_Scalar_Vector) {
Global("scalar", ty.f32(), ast::StorageClass::kNone);
Global("vector", ty.vec3<f32>(), ast::StorageClass::kNone);
auto* expr = Mul("scalar", "vector");
WrapInFunction(expr);
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
}
TEST_F(ResolverTest, Expr_Binary_Multiply_Vector_Vector) {
Global("vector", ty.vec3<f32>(), ast::StorageClass::kNone);
auto* expr = Mul("vector", "vector");
WrapInFunction(expr);
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
}
TEST_F(ResolverTest, Expr_Binary_Multiply_Matrix_Scalar) {
Global("scalar", ty.f32(), ast::StorageClass::kNone);
Global("matrix", ty.mat2x3<f32>(), ast::StorageClass::kNone);
auto* expr = Mul("matrix", "scalar");
WrapInFunction(expr);
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
ASSERT_TRUE(TypeOf(expr)->Is<type::Matrix>());
auto* mat = TypeOf(expr)->As<type::Matrix>();
EXPECT_TRUE(mat->type()->Is<type::F32>());
EXPECT_EQ(mat->rows(), 3u);
EXPECT_EQ(mat->columns(), 2u);
}
TEST_F(ResolverTest, Expr_Binary_Multiply_Scalar_Matrix) {
Global("scalar", ty.f32(), ast::StorageClass::kNone);
Global("matrix", ty.mat2x3<f32>(), ast::StorageClass::kNone);
auto* expr = Mul("scalar", "matrix");
WrapInFunction(expr);
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
ASSERT_TRUE(TypeOf(expr)->Is<type::Matrix>());
auto* mat = TypeOf(expr)->As<type::Matrix>();
EXPECT_TRUE(mat->type()->Is<type::F32>());
EXPECT_EQ(mat->rows(), 3u);
EXPECT_EQ(mat->columns(), 2u);
}
TEST_F(ResolverTest, Expr_Binary_Multiply_Matrix_Vector) {
Global("vector", ty.vec3<f32>(), ast::StorageClass::kNone);
Global("matrix", ty.mat2x3<f32>(), ast::StorageClass::kNone);
auto* expr = Mul("matrix", "vector");
WrapInFunction(expr);
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
}
TEST_F(ResolverTest, Expr_Binary_Multiply_Vector_Matrix) {
Global("vector", ty.vec3<f32>(), ast::StorageClass::kNone);
Global("matrix", ty.mat2x3<f32>(), ast::StorageClass::kNone);
auto* expr = Mul("vector", "matrix");
WrapInFunction(expr);
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 2u);
}
TEST_F(ResolverTest, Expr_Binary_Multiply_Matrix_Matrix) {
Global("mat3x4", ty.mat3x4<f32>(), ast::StorageClass::kNone);
Global("mat4x3", ty.mat4x3<f32>(), ast::StorageClass::kNone);
auto* expr = Mul("mat3x4", "mat4x3");
WrapInFunction(expr);
ASSERT_TRUE(r()->Resolve()) << r()->error();
ASSERT_NE(TypeOf(expr), nullptr);
ASSERT_TRUE(TypeOf(expr)->Is<type::Matrix>());
auto* mat = TypeOf(expr)->As<type::Matrix>();
EXPECT_TRUE(mat->type()->Is<type::F32>());
EXPECT_EQ(mat->rows(), 4u);
EXPECT_EQ(mat->columns(), 4u);
ASSERT_FALSE(r()->Resolve()) << r()->error();
ASSERT_EQ(r()->error(),
"12:34 error: Binary expression operand types are invalid for "
"this operation");
}
INSTANTIATE_TEST_SUITE_P(
ResolverTest,
Expr_Binary_Test_Invalid,
testing::Combine(testing::ValuesIn(all_valid_cases),
testing::ValuesIn(all_create_type_funcs)));
} // namespace ExprBinaryTest
using UnaryOpExpressionTest = ResolverTestWithParam<ast::UnaryOp>;
TEST_P(UnaryOpExpressionTest, Expr_UnaryOp) {