From be0fc4e929093faa52e65e50acd1ee133d6fed36 Mon Sep 17 00:00:00 2001 From: Antonio Maiorano Date: Tue, 16 Mar 2021 13:26:03 +0000 Subject: [PATCH] Validate binary operations This change validates that the operand types and result type of every binary operation is valid. * Added two unit tests which test all valid and invalid param combos. I also removed the old tests, many of which failed once I added this validation, and the rest are obviated by the new tests. * Fixed VertexPulling transform, as well as many tests, that were using invalid operand types for binary operations. Fixed: tint:354 Change-Id: Ia3f48384256993da61b341f17ba5583741011819 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44341 Reviewed-by: Ben Clayton Commit-Queue: Antonio Maiorano --- src/ast/binary_expression.h | 64 ++- src/resolver/resolver.cc | 156 ++++++ src/resolver/resolver.h | 1 + src/resolver/resolver_test.cc | 474 ++++++++++-------- src/transform/bound_array_accessors_test.cc | 24 +- src/transform/vertex_pulling.cc | 6 +- src/transform/vertex_pulling_test.cc | 32 +- src/type/type.cc | 4 + src/type/type.h | 2 + src/writer/hlsl/generator_impl_binary_test.cc | 14 + .../spirv/builder_binary_expression_test.cc | 16 +- 11 files changed, 534 insertions(+), 259 deletions(-) diff --git a/src/ast/binary_expression.h b/src/ast/binary_expression.h index 84834acaf1..b9b8c1c4c8 100644 --- a/src/ast/binary_expression.h +++ b/src/ast/binary_expression.h @@ -23,11 +23,11 @@ namespace ast { /// The operator type enum class BinaryOp { kNone = 0, - kAnd, - kOr, + kAnd, // & + kOr, // | kXor, - kLogicalAnd, - kLogicalOr, + kLogicalAnd, // && + kLogicalOr, // || kEqual, kNotEqual, kLessThan, @@ -98,6 +98,14 @@ class BinaryExpression : public Castable { bool IsDivide() const { return op_ == BinaryOp::kDivide; } /// @returns true if the op is modulo bool IsModulo() const { return op_ == BinaryOp::kModulo; } + /// @returns true if the op is an arithmetic operation + bool IsArithmetic() const; + /// @returns true if the op is a comparison operation + bool IsComparison() const; + /// @returns true if the op is a bitwise operation + bool IsBitwise() const; + /// @returns true if the op is a bit shift operation + bool IsBitshift() const; /// @returns the left side expression Expression* lhs() const { return lhs_; } @@ -126,6 +134,54 @@ class BinaryExpression : public Castable { Expression* const rhs_; }; +inline bool BinaryExpression::IsArithmetic() const { + switch (op_) { + case ast::BinaryOp::kAdd: + case ast::BinaryOp::kSubtract: + case ast::BinaryOp::kMultiply: + case ast::BinaryOp::kDivide: + case ast::BinaryOp::kModulo: + return true; + default: + return false; + } +} + +inline bool BinaryExpression::IsComparison() const { + switch (op_) { + case ast::BinaryOp::kEqual: + case ast::BinaryOp::kNotEqual: + case ast::BinaryOp::kLessThan: + case ast::BinaryOp::kLessThanEqual: + case ast::BinaryOp::kGreaterThan: + case ast::BinaryOp::kGreaterThanEqual: + return true; + default: + return false; + } +} + +inline bool BinaryExpression::IsBitwise() const { + switch (op_) { + case ast::BinaryOp::kAnd: + case ast::BinaryOp::kOr: + case ast::BinaryOp::kXor: + return true; + default: + return false; + } +} + +inline bool BinaryExpression::IsBitshift() const { + switch (op_) { + case ast::BinaryOp::kShiftLeft: + case ast::BinaryOp::kShiftRight: + return true; + default: + return false; + } +} + inline std::ostream& operator<<(std::ostream& out, BinaryOp op) { switch (op) { case BinaryOp::kNone: diff --git a/src/resolver/resolver.cc b/src/resolver/resolver.cc index 8b27d4703e..9198c396d0 100644 --- a/src/resolver/resolver.cc +++ b/src/resolver/resolver.cc @@ -877,11 +877,167 @@ bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) { return true; } +bool Resolver::ValidateBinary(ast::BinaryExpression* expr) { + using Bool = type::Bool; + using F32 = type::F32; + using I32 = type::I32; + using U32 = type::U32; + using Matrix = type::Matrix; + using Vector = type::Vector; + + auto* lhs_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded(); + auto* rhs_type = TypeOf(expr->rhs())->UnwrapPtrIfNeeded(); + + auto* lhs_vec = lhs_type->As(); + auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr; + auto* rhs_vec = rhs_type->As(); + auto* rhs_vec_elem_type = rhs_vec ? rhs_vec->type() : nullptr; + + const bool matching_types = lhs_type == rhs_type; + const bool matching_vec_elem_types = lhs_vec_elem_type && rhs_vec_elem_type && + (lhs_vec_elem_type == rhs_vec_elem_type); + + // Binary logical expressions + if (expr->IsLogicalAnd() || expr->IsLogicalOr()) { + if (matching_types && lhs_type->Is()) { + return true; + } + } + if (expr->IsOr() || expr->IsAnd()) { + if (matching_types && lhs_type->Is()) { + return true; + } + if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is()) { + return true; + } + } + + // Arithmetic expressions + if (expr->IsArithmetic()) { + // Binary arithmetic expressions over scalars + if (matching_types && lhs_type->IsAnyOf()) { + return true; + } + + // Binary arithmetic expressions over vectors + if (matching_types && lhs_vec_elem_type && + lhs_vec_elem_type->IsAnyOf()) { + return true; + } + } + + // Binary arithmetic expressions with mixed scalar, vector, and matrix + // operands + if (expr->IsMultiply()) { + // Multiplication of a vector and a scalar + if (lhs_type->Is() && rhs_vec_elem_type && + rhs_vec_elem_type->Is()) { + return true; + } + if (lhs_vec_elem_type && lhs_vec_elem_type->Is() && + rhs_type->Is()) { + return true; + } + + auto* lhs_mat = lhs_type->As(); + auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr; + auto* rhs_mat = rhs_type->As(); + auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr; + + // Multiplication of a matrix and a scalar + if (lhs_type->Is() && rhs_mat_elem_type && + rhs_mat_elem_type->Is()) { + return true; + } + if (lhs_mat_elem_type && lhs_mat_elem_type->Is() && + rhs_type->Is()) { + return true; + } + + // Vector times matrix + if (lhs_vec_elem_type && lhs_vec_elem_type->Is() && + rhs_mat_elem_type && rhs_mat_elem_type->Is()) { + return true; + } + + // Matrix times vector + if (lhs_mat_elem_type && lhs_mat_elem_type->Is() && + rhs_vec_elem_type && rhs_vec_elem_type->Is()) { + return true; + } + + // Matrix times matrix + if (lhs_mat_elem_type && lhs_mat_elem_type->Is() && + rhs_mat_elem_type && rhs_mat_elem_type->Is()) { + return true; + } + } + + // Comparison expressions + if (expr->IsComparison()) { + if (matching_types) { + // Special case for bools: only == and != + if (lhs_type->Is() && (expr->IsEqual() || expr->IsNotEqual())) { + return true; + } + + // For the rest, we can compare i32, u32, and f32 + if (lhs_type->IsAnyOf()) { + return true; + } + } + + // Same for vectors + if (matching_vec_elem_types) { + if (lhs_vec_elem_type->Is() && + (expr->IsEqual() || expr->IsNotEqual())) { + return true; + } + + if (lhs_vec_elem_type->IsAnyOf()) { + return true; + } + } + } + + // Binary bitwise operations + if (expr->IsBitwise()) { + if (matching_types && lhs_type->IsAnyOf()) { + return true; + } + } + + // Bit shift expressions + if (expr->IsBitshift()) { + // Type validation rules are the same for left or right shift, despite + // differences in computation rules (i.e. right shift can be arithmetic or + // logical depending on lhs type). + + if (lhs_type->IsAnyOf() && rhs_type->Is()) { + return true; + } + + if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf() && + rhs_vec_elem_type && rhs_vec_elem_type->Is()) { + return true; + } + } + + diagnostics_.add_error( + "Binary expression operand types are invalid for this operation", + expr->source()); + return false; +} + bool Resolver::Binary(ast::BinaryExpression* expr) { if (!Expression(expr->lhs()) || !Expression(expr->rhs())) { return false; } + if (!ValidateBinary(expr)) { + return false; + } + // Result type matches first parameter type if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() || expr->IsShiftRight() || expr->IsAdd() || expr->IsSubtract() || diff --git a/src/resolver/resolver.h b/src/resolver/resolver.h index ddf7c11ef7..d8a3d81f11 100644 --- a/src/resolver/resolver.h +++ b/src/resolver/resolver.h @@ -171,6 +171,7 @@ class Resolver { // AST and Type traversal methods // Each return true on success, false on failure. bool ArrayAccessor(ast::ArrayAccessorExpression*); + bool ValidateBinary(ast::BinaryExpression* expr); bool Binary(ast::BinaryExpression*); bool Bitcast(ast::BitcastExpression*); bool BlockStatement(const ast::BlockStatement*); diff --git a/src/resolver/resolver_test.cc b/src/resolver/resolver_test.cc index caf608e964..7ce593f268 100644 --- a/src/resolver/resolver_test.cc +++ b/src/resolver/resolver_test.cc @@ -14,6 +14,8 @@ #include "src/resolver/resolver.h" +#include + #include "gmock/gmock.h" #include "src/ast/assignment_statement.h" #include "src/ast/bitcast_expression.h" @@ -971,246 +973,276 @@ TEST_F(ResolverTest, Expr_MemberAccessor_InBinaryOp) { EXPECT_TRUE(TypeOf(expr)->Is()); } -using Expr_Binary_BitwiseTest = ResolverTestWithParam; -TEST_P(Expr_Binary_BitwiseTest, Scalar) { - auto op = GetParam(); +namespace ExprBinaryTest { - Global("val", ty.i32(), ast::StorageClass::kNone); +using create_type_func_ptr = + type::Type* (*)(const ProgramBuilder::TypesBuilder& ty); - auto* expr = create(op, Expr("val"), Expr("val")); - WrapInFunction(expr); +struct Params { + ast::BinaryOp op; + create_type_func_ptr create_lhs_type; + create_type_func_ptr create_rhs_type; + create_type_func_ptr create_result_type; +}; - ASSERT_TRUE(r()->Resolve()) << r()->error(); - ASSERT_NE(TypeOf(expr), nullptr); - EXPECT_TRUE(TypeOf(expr)->Is()); +// Helpers and typedefs to make building the table below more succinct + +using i32 = ProgramBuilder::i32; +using u32 = ProgramBuilder::u32; +using f32 = ProgramBuilder::f32; +using Op = ast::BinaryOp; + +type::Type* ty_bool_(const ProgramBuilder::TypesBuilder& ty) { + return ty.bool_(); +} +type::Type* ty_i32(const ProgramBuilder::TypesBuilder& ty) { + return ty.i32(); +} +type::Type* ty_u32(const ProgramBuilder::TypesBuilder& ty) { + return ty.u32(); +} +type::Type* ty_f32(const ProgramBuilder::TypesBuilder& ty) { + return ty.f32(); } -TEST_P(Expr_Binary_BitwiseTest, Vector) { - auto op = GetParam(); +template +type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) { + return ty.vec3(); +} - Global("val", ty.vec3(), ast::StorageClass::kNone); +template +type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) { + return ty.mat3x3(); +} - auto* expr = create(op, Expr("val"), Expr("val")); +static constexpr create_type_func_ptr all_create_type_funcs[] = { + ty_bool_, ty_u32, ty_i32, ty_f32, + ty_vec3, ty_vec3, ty_vec3, ty_vec3, + ty_mat3x3, ty_mat3x3, ty_mat3x3}; + +// A list of all valid test cases for 'lhs op rhs', except that for vecN and +// matNxN, we only test N=3. +static constexpr Params all_valid_cases[] = { + // Logical expressions + // https://gpuweb.github.io/gpuweb/wgsl.html#logical-expr + + // Binary logical expressions + Params{Op::kLogicalAnd, ty_bool_, ty_bool_, ty_bool_}, + Params{Op::kLogicalOr, ty_bool_, ty_bool_, ty_bool_}, + + Params{Op::kAnd, ty_bool_, ty_bool_, ty_bool_}, + Params{Op::kOr, ty_bool_, ty_bool_, ty_bool_}, + Params{Op::kAnd, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kOr, ty_vec3, ty_vec3, ty_vec3}, + + // Arithmetic expressions + // https://gpuweb.github.io/gpuweb/wgsl.html#arithmetic-expr + + // Binary arithmetic expressions over scalars + Params{Op::kAdd, ty_i32, ty_i32, ty_i32}, + Params{Op::kSubtract, ty_i32, ty_i32, ty_i32}, + Params{Op::kMultiply, ty_i32, ty_i32, ty_i32}, + Params{Op::kDivide, ty_i32, ty_i32, ty_i32}, + Params{Op::kModulo, ty_i32, ty_i32, ty_i32}, + + Params{Op::kAdd, ty_u32, ty_u32, ty_u32}, + Params{Op::kSubtract, ty_u32, ty_u32, ty_u32}, + Params{Op::kMultiply, ty_u32, ty_u32, ty_u32}, + Params{Op::kDivide, ty_u32, ty_u32, ty_u32}, + Params{Op::kModulo, ty_u32, ty_u32, ty_u32}, + + Params{Op::kAdd, ty_f32, ty_f32, ty_f32}, + Params{Op::kSubtract, ty_f32, ty_f32, ty_f32}, + Params{Op::kMultiply, ty_f32, ty_f32, ty_f32}, + Params{Op::kDivide, ty_f32, ty_f32, ty_f32}, + Params{Op::kModulo, ty_f32, ty_f32, ty_f32}, + + // Binary arithmetic expressions over vectors + Params{Op::kAdd, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kSubtract, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kMultiply, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kDivide, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kModulo, ty_vec3, ty_vec3, ty_vec3}, + + Params{Op::kAdd, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kSubtract, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kMultiply, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kDivide, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kModulo, ty_vec3, ty_vec3, ty_vec3}, + + Params{Op::kAdd, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kSubtract, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kMultiply, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kDivide, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kModulo, ty_vec3, ty_vec3, ty_vec3}, + + // Binary arithmetic expressions with mixed scalar, vector, and matrix + // operands + Params{Op::kMultiply, ty_vec3, ty_f32, ty_vec3}, + Params{Op::kMultiply, ty_f32, ty_vec3, ty_vec3}, + + Params{Op::kMultiply, ty_mat3x3, ty_f32, ty_mat3x3}, + Params{Op::kMultiply, ty_f32, ty_mat3x3, ty_mat3x3}, + + Params{Op::kMultiply, ty_vec3, ty_mat3x3, ty_vec3}, + Params{Op::kMultiply, ty_mat3x3, ty_vec3, ty_vec3}, + Params{Op::kMultiply, ty_mat3x3, ty_mat3x3, ty_mat3x3}, + + // Comparison expressions + // https://gpuweb.github.io/gpuweb/wgsl.html#comparison-expr + + // Comparisons over scalars + Params{Op::kEqual, ty_bool_, ty_bool_, ty_bool_}, + Params{Op::kNotEqual, ty_bool_, ty_bool_, ty_bool_}, + + Params{Op::kEqual, ty_i32, ty_i32, ty_bool_}, + Params{Op::kNotEqual, ty_i32, ty_i32, ty_bool_}, + Params{Op::kLessThan, ty_i32, ty_i32, ty_bool_}, + Params{Op::kLessThanEqual, ty_i32, ty_i32, ty_bool_}, + Params{Op::kGreaterThan, ty_i32, ty_i32, ty_bool_}, + Params{Op::kGreaterThanEqual, ty_i32, ty_i32, ty_bool_}, + + Params{Op::kEqual, ty_u32, ty_u32, ty_bool_}, + Params{Op::kNotEqual, ty_u32, ty_u32, ty_bool_}, + Params{Op::kLessThan, ty_u32, ty_u32, ty_bool_}, + Params{Op::kLessThanEqual, ty_u32, ty_u32, ty_bool_}, + Params{Op::kGreaterThan, ty_u32, ty_u32, ty_bool_}, + Params{Op::kGreaterThanEqual, ty_u32, ty_u32, ty_bool_}, + + Params{Op::kEqual, ty_f32, ty_f32, ty_bool_}, + Params{Op::kNotEqual, ty_f32, ty_f32, ty_bool_}, + Params{Op::kLessThan, ty_f32, ty_f32, ty_bool_}, + Params{Op::kLessThanEqual, ty_f32, ty_f32, ty_bool_}, + Params{Op::kGreaterThan, ty_f32, ty_f32, ty_bool_}, + Params{Op::kGreaterThanEqual, ty_f32, ty_f32, ty_bool_}, + + // Comparisons over vectors + Params{Op::kEqual, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kNotEqual, ty_vec3, ty_vec3, ty_vec3}, + + Params{Op::kEqual, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kNotEqual, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kLessThan, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kLessThanEqual, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kGreaterThan, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kGreaterThanEqual, ty_vec3, ty_vec3, ty_vec3}, + + Params{Op::kEqual, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kNotEqual, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kLessThan, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kLessThanEqual, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kGreaterThan, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kGreaterThanEqual, ty_vec3, ty_vec3, ty_vec3}, + + Params{Op::kEqual, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kNotEqual, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kLessThan, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kLessThanEqual, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kGreaterThan, ty_vec3, ty_vec3, ty_vec3}, + Params{Op::kGreaterThanEqual, ty_vec3, ty_vec3, ty_vec3}, + + // Bit expressions + // https://gpuweb.github.io/gpuweb/wgsl.html#bit-expr + + // Binary bitwise operations + Params{Op::kOr, ty_i32, ty_i32, ty_i32}, + Params{Op::kAnd, ty_i32, ty_i32, ty_i32}, + Params{Op::kXor, ty_i32, ty_i32, ty_i32}, + + Params{Op::kOr, ty_u32, ty_u32, ty_u32}, + Params{Op::kAnd, ty_u32, ty_u32, ty_u32}, + Params{Op::kXor, ty_u32, ty_u32, ty_u32}, + + // Bit shift expressions + Params{Op::kShiftLeft, ty_i32, ty_u32, ty_i32}, + Params{Op::kShiftLeft, ty_vec3, ty_vec3, ty_vec3}, + + Params{Op::kShiftLeft, ty_u32, ty_u32, ty_u32}, + Params{Op::kShiftLeft, ty_vec3, ty_vec3, ty_vec3}, + + Params{Op::kShiftRight, ty_i32, ty_u32, ty_i32}, + Params{Op::kShiftRight, ty_vec3, ty_vec3, ty_vec3}, + + Params{Op::kShiftRight, ty_u32, ty_u32, ty_u32}, + Params{Op::kShiftRight, ty_vec3, ty_vec3, ty_vec3}}; + +using Expr_Binary_Test_Valid = ResolverTestWithParam; +TEST_P(Expr_Binary_Test_Valid, All) { + auto& params = GetParam(); + + auto* lhs_type = params.create_lhs_type(ty); + auto* rhs_type = params.create_rhs_type(ty); + auto* result_type = params.create_result_type(ty); + + SCOPED_TRACE(testing::Message() + << lhs_type->FriendlyName(Symbols()) << " " << params.op << " " + << rhs_type->FriendlyName(Symbols())); + + Global("lhs", lhs_type, ast::StorageClass::kNone); + Global("rhs", rhs_type, ast::StorageClass::kNone); + + auto* expr = + create(params.op, Expr("lhs"), Expr("rhs")); WrapInFunction(expr); ASSERT_TRUE(r()->Resolve()) << r()->error(); ASSERT_NE(TypeOf(expr), nullptr); - ASSERT_TRUE(TypeOf(expr)->Is()); - EXPECT_TRUE(TypeOf(expr)->As()->type()->Is()); - EXPECT_EQ(TypeOf(expr)->As()->size(), 3u); + ASSERT_TRUE(TypeOf(expr) == result_type); } INSTANTIATE_TEST_SUITE_P(ResolverTest, - Expr_Binary_BitwiseTest, - testing::Values(ast::BinaryOp::kAnd, - ast::BinaryOp::kOr, - ast::BinaryOp::kXor, - ast::BinaryOp::kShiftLeft, - ast::BinaryOp::kShiftRight, - ast::BinaryOp::kAdd, - ast::BinaryOp::kSubtract, - ast::BinaryOp::kDivide, - ast::BinaryOp::kModulo)); + Expr_Binary_Test_Valid, + testing::ValuesIn(all_valid_cases)); -using Expr_Binary_LogicalTest = ResolverTestWithParam; -TEST_P(Expr_Binary_LogicalTest, Scalar) { - auto op = GetParam(); +using Expr_Binary_Test_Invalid = + ResolverTestWithParam>; +TEST_P(Expr_Binary_Test_Invalid, All) { + const Params& params = std::get<0>(GetParam()); + const create_type_func_ptr& create_type_func = std::get<1>(GetParam()); - Global("val", ty.bool_(), ast::StorageClass::kNone); + // Currently, for most operations, for a given lhs type, there is exactly one + // rhs type allowed. The only exception is for multiplication, which allows + // any permutation of f32, vecN, and matNxN. We are fed valid inputs + // only via `params`, and all possible types via `create_type_func`, so we + // test invalid combinations by testing every other rhs type, modulo + // exceptions. - auto* expr = create(op, Expr("val"), Expr("val")); + // Skip valid rhs type + if (params.create_rhs_type == create_type_func) { + return; + } + + auto* lhs_type = params.create_lhs_type(ty); + auto* rhs_type = create_type_func(ty); + + // Skip exceptions: multiplication of f32, vecN, and matNxN + if (params.op == Op::kMultiply && + lhs_type->is_float_scalar_or_vector_or_matrix() && + rhs_type->is_float_scalar_or_vector_or_matrix()) { + return; + } + + SCOPED_TRACE(testing::Message() + << lhs_type->FriendlyName(Symbols()) << " " << params.op << " " + << rhs_type->FriendlyName(Symbols())); + + Global("lhs", lhs_type, ast::StorageClass::kNone); + Global("rhs", rhs_type, ast::StorageClass::kNone); + + auto* expr = create( + Source{Source::Location{12, 34}}, params.op, Expr("lhs"), Expr("rhs")); WrapInFunction(expr); - ASSERT_TRUE(r()->Resolve()) << r()->error(); - ASSERT_NE(TypeOf(expr), nullptr); - EXPECT_TRUE(TypeOf(expr)->Is()); -} - -TEST_P(Expr_Binary_LogicalTest, Vector) { - auto op = GetParam(); - - Global("val", ty.vec3(), ast::StorageClass::kNone); - - auto* expr = create(op, Expr("val"), Expr("val")); - WrapInFunction(expr); - - ASSERT_TRUE(r()->Resolve()) << r()->error(); - ASSERT_NE(TypeOf(expr), nullptr); - ASSERT_TRUE(TypeOf(expr)->Is()); - EXPECT_TRUE(TypeOf(expr)->As()->type()->Is()); - EXPECT_EQ(TypeOf(expr)->As()->size(), 3u); -} -INSTANTIATE_TEST_SUITE_P(ResolverTest, - Expr_Binary_LogicalTest, - testing::Values(ast::BinaryOp::kLogicalAnd, - ast::BinaryOp::kLogicalOr)); - -using Expr_Binary_CompareTest = ResolverTestWithParam; -TEST_P(Expr_Binary_CompareTest, Scalar) { - auto op = GetParam(); - - Global("val", ty.i32(), ast::StorageClass::kNone); - - auto* expr = create(op, Expr("val"), Expr("val")); - WrapInFunction(expr); - - ASSERT_TRUE(r()->Resolve()) << r()->error(); - ASSERT_NE(TypeOf(expr), nullptr); - EXPECT_TRUE(TypeOf(expr)->Is()); -} - -TEST_P(Expr_Binary_CompareTest, Vector) { - auto op = GetParam(); - - Global("val", ty.vec3(), ast::StorageClass::kNone); - - auto* expr = create(op, Expr("val"), Expr("val")); - WrapInFunction(expr); - - ASSERT_TRUE(r()->Resolve()) << r()->error(); - ASSERT_NE(TypeOf(expr), nullptr); - ASSERT_TRUE(TypeOf(expr)->Is()); - EXPECT_TRUE(TypeOf(expr)->As()->type()->Is()); - EXPECT_EQ(TypeOf(expr)->As()->size(), 3u); -} -INSTANTIATE_TEST_SUITE_P(ResolverTest, - Expr_Binary_CompareTest, - testing::Values(ast::BinaryOp::kEqual, - ast::BinaryOp::kNotEqual, - ast::BinaryOp::kLessThan, - ast::BinaryOp::kGreaterThan, - ast::BinaryOp::kLessThanEqual, - ast::BinaryOp::kGreaterThanEqual)); - -TEST_F(ResolverTest, Expr_Binary_Multiply_Scalar_Scalar) { - Global("val", ty.i32(), ast::StorageClass::kNone); - - auto* expr = Mul("val", "val"); - WrapInFunction(expr); - - ASSERT_TRUE(r()->Resolve()) << r()->error(); - ASSERT_NE(TypeOf(expr), nullptr); - EXPECT_TRUE(TypeOf(expr)->Is()); -} - -TEST_F(ResolverTest, Expr_Binary_Multiply_Vector_Scalar) { - Global("scalar", ty.f32(), ast::StorageClass::kNone); - Global("vector", ty.vec3(), ast::StorageClass::kNone); - - auto* expr = Mul("vector", "scalar"); - WrapInFunction(expr); - - ASSERT_TRUE(r()->Resolve()) << r()->error(); - ASSERT_NE(TypeOf(expr), nullptr); - ASSERT_TRUE(TypeOf(expr)->Is()); - EXPECT_TRUE(TypeOf(expr)->As()->type()->Is()); - EXPECT_EQ(TypeOf(expr)->As()->size(), 3u); -} - -TEST_F(ResolverTest, Expr_Binary_Multiply_Scalar_Vector) { - Global("scalar", ty.f32(), ast::StorageClass::kNone); - Global("vector", ty.vec3(), ast::StorageClass::kNone); - - auto* expr = Mul("scalar", "vector"); - WrapInFunction(expr); - - ASSERT_TRUE(r()->Resolve()) << r()->error(); - ASSERT_NE(TypeOf(expr), nullptr); - ASSERT_TRUE(TypeOf(expr)->Is()); - EXPECT_TRUE(TypeOf(expr)->As()->type()->Is()); - EXPECT_EQ(TypeOf(expr)->As()->size(), 3u); -} - -TEST_F(ResolverTest, Expr_Binary_Multiply_Vector_Vector) { - Global("vector", ty.vec3(), ast::StorageClass::kNone); - - auto* expr = Mul("vector", "vector"); - WrapInFunction(expr); - - ASSERT_TRUE(r()->Resolve()) << r()->error(); - ASSERT_NE(TypeOf(expr), nullptr); - ASSERT_TRUE(TypeOf(expr)->Is()); - EXPECT_TRUE(TypeOf(expr)->As()->type()->Is()); - EXPECT_EQ(TypeOf(expr)->As()->size(), 3u); -} - -TEST_F(ResolverTest, Expr_Binary_Multiply_Matrix_Scalar) { - Global("scalar", ty.f32(), ast::StorageClass::kNone); - Global("matrix", ty.mat2x3(), ast::StorageClass::kNone); - - auto* expr = Mul("matrix", "scalar"); - WrapInFunction(expr); - - ASSERT_TRUE(r()->Resolve()) << r()->error(); - ASSERT_NE(TypeOf(expr), nullptr); - ASSERT_TRUE(TypeOf(expr)->Is()); - - auto* mat = TypeOf(expr)->As(); - EXPECT_TRUE(mat->type()->Is()); - EXPECT_EQ(mat->rows(), 3u); - EXPECT_EQ(mat->columns(), 2u); -} - -TEST_F(ResolverTest, Expr_Binary_Multiply_Scalar_Matrix) { - Global("scalar", ty.f32(), ast::StorageClass::kNone); - Global("matrix", ty.mat2x3(), ast::StorageClass::kNone); - - auto* expr = Mul("scalar", "matrix"); - WrapInFunction(expr); - - ASSERT_TRUE(r()->Resolve()) << r()->error(); - ASSERT_NE(TypeOf(expr), nullptr); - ASSERT_TRUE(TypeOf(expr)->Is()); - - auto* mat = TypeOf(expr)->As(); - EXPECT_TRUE(mat->type()->Is()); - EXPECT_EQ(mat->rows(), 3u); - EXPECT_EQ(mat->columns(), 2u); -} - -TEST_F(ResolverTest, Expr_Binary_Multiply_Matrix_Vector) { - Global("vector", ty.vec3(), ast::StorageClass::kNone); - Global("matrix", ty.mat2x3(), ast::StorageClass::kNone); - - auto* expr = Mul("matrix", "vector"); - WrapInFunction(expr); - - ASSERT_TRUE(r()->Resolve()) << r()->error(); - ASSERT_NE(TypeOf(expr), nullptr); - ASSERT_TRUE(TypeOf(expr)->Is()); - EXPECT_TRUE(TypeOf(expr)->As()->type()->Is()); - EXPECT_EQ(TypeOf(expr)->As()->size(), 3u); -} - -TEST_F(ResolverTest, Expr_Binary_Multiply_Vector_Matrix) { - Global("vector", ty.vec3(), ast::StorageClass::kNone); - Global("matrix", ty.mat2x3(), ast::StorageClass::kNone); - - auto* expr = Mul("vector", "matrix"); - WrapInFunction(expr); - - ASSERT_TRUE(r()->Resolve()) << r()->error(); - ASSERT_NE(TypeOf(expr), nullptr); - ASSERT_TRUE(TypeOf(expr)->Is()); - EXPECT_TRUE(TypeOf(expr)->As()->type()->Is()); - EXPECT_EQ(TypeOf(expr)->As()->size(), 2u); -} - -TEST_F(ResolverTest, Expr_Binary_Multiply_Matrix_Matrix) { - Global("mat3x4", ty.mat3x4(), ast::StorageClass::kNone); - Global("mat4x3", ty.mat4x3(), ast::StorageClass::kNone); - - auto* expr = Mul("mat3x4", "mat4x3"); - WrapInFunction(expr); - - ASSERT_TRUE(r()->Resolve()) << r()->error(); - ASSERT_NE(TypeOf(expr), nullptr); - ASSERT_TRUE(TypeOf(expr)->Is()); - - auto* mat = TypeOf(expr)->As(); - EXPECT_TRUE(mat->type()->Is()); - EXPECT_EQ(mat->rows(), 4u); - EXPECT_EQ(mat->columns(), 4u); + ASSERT_FALSE(r()->Resolve()) << r()->error(); + ASSERT_EQ(r()->error(), + "12:34 error: Binary expression operand types are invalid for " + "this operation"); } +INSTANTIATE_TEST_SUITE_P( + ResolverTest, + Expr_Binary_Test_Invalid, + testing::Combine(testing::ValuesIn(all_valid_cases), + testing::ValuesIn(all_create_type_funcs))); +} // namespace ExprBinaryTest using UnaryOpExpressionTest = ResolverTestWithParam; TEST_P(UnaryOpExpressionTest, Expr_UnaryOp) { diff --git a/src/transform/bound_array_accessors_test.cc b/src/transform/bound_array_accessors_test.cc index 67a831cf3f..7eedc187f7 100644 --- a/src/transform/bound_array_accessors_test.cc +++ b/src/transform/bound_array_accessors_test.cc @@ -104,7 +104,7 @@ TEST_F(BoundArrayAccessorsTest, Array_Idx_Expr) { auto* src = R"( var a : array; -var c : u32; +var c : i32; fn f() -> void { var b : f32 = a[c + 2 - 3]; @@ -114,7 +114,7 @@ fn f() -> void { auto* expect = R"( var a : array; -var c : u32; +var c : i32; fn f() -> void { var b : f32 = a[min(u32(((c + 2) - 3)), 2u)]; @@ -196,7 +196,7 @@ TEST_F(BoundArrayAccessorsTest, Vector_Idx_Expr) { auto* src = R"( var a : vec3; -var c : u32; +var c : i32; fn f() -> void { var b : f32 = a[c + 2 - 3]; @@ -206,7 +206,7 @@ fn f() -> void { auto* expect = R"( var a : vec3; -var c : u32; +var c : i32; fn f() -> void { var b : f32 = a[min(u32(((c + 2) - 3)), 2u)]; @@ -244,7 +244,7 @@ TEST_F(BoundArrayAccessorsTest, Vector_Swizzle_Idx_Var) { auto* src = R"( var a : vec3; -var c : u32; +var c : i32; fn f() -> void { var b : f32 = a.xy[c]; @@ -254,7 +254,7 @@ fn f() -> void { auto* expect = R"( var a : vec3; -var c : u32; +var c : i32; fn f() -> void { var b : f32 = a.xy[min(u32(c), 1u)]; @@ -269,7 +269,7 @@ TEST_F(BoundArrayAccessorsTest, Vector_Swizzle_Idx_Expr) { auto* src = R"( var a : vec3; -var c : u32; +var c : i32; fn f() -> void { var b : f32 = a.xy[c + 2 - 3]; @@ -279,7 +279,7 @@ fn f() -> void { auto* expect = R"( var a : vec3; -var c : u32; +var c : i32; fn f() -> void { var b : f32 = a.xy[min(u32(((c + 2) - 3)), 1u)]; @@ -361,7 +361,7 @@ TEST_F(BoundArrayAccessorsTest, Matrix_Idx_Expr_Column) { auto* src = R"( var a : mat3x2; -var c : u32; +var c : i32; fn f() -> void { var b : f32 = a[c + 2 - 3][1]; @@ -371,7 +371,7 @@ fn f() -> void { auto* expect = R"( var a : mat3x2; -var c : u32; +var c : i32; fn f() -> void { var b : f32 = a[min(u32(((c + 2) - 3)), 2u)][1]; @@ -387,7 +387,7 @@ TEST_F(BoundArrayAccessorsTest, Matrix_Idx_Expr_Row) { auto* src = R"( var a : mat3x2; -var c : u32; +var c : i32; fn f() -> void { var b : f32 = a[1][c + 2 - 3]; @@ -397,7 +397,7 @@ fn f() -> void { auto* expect = R"( var a : mat3x2; -var c : u32; +var c : i32; fn f() -> void { var b : f32 = a[1][min(u32(((c + 2) - 3)), 1u)]; diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc index b82a4d56c1..78e96c7482 100644 --- a/src/transform/vertex_pulling.cc +++ b/src/transform/vertex_pulling.cc @@ -132,7 +132,7 @@ void VertexPulling::State::FindOrInsertVertexIndexIfUsed() { Source{}, // source ctx.dst->Symbols().Register(vertex_index_name), // symbol ast::StorageClass::kInput, // storage_class - GetI32Type(), // type + GetU32Type(), // type false, // is_const nullptr, // constructor ast::DecorationList{ @@ -179,7 +179,7 @@ void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() { Source{}, // source ctx.dst->Symbols().Register(instance_index_name), // symbol ast::StorageClass::kInput, // storage_class - GetI32Type(), // type + GetU32Type(), // type false, // is_const nullptr, // constructor ast::DecorationList{ @@ -273,7 +273,7 @@ ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() const { Source{}, // source ctx.dst->Symbols().Register(kPullingPosVarName), // symbol ast::StorageClass::kFunction, // storage_class - GetI32Type(), // type + GetU32Type(), // type false, // is_const nullptr, // constructor ast::DecorationList{})); // decorations diff --git a/src/transform/vertex_pulling_test.cc b/src/transform/vertex_pulling_test.cc index d36dcaaf9f..c0b6e19f16 100644 --- a/src/transform/vertex_pulling_test.cc +++ b/src/transform/vertex_pulling_test.cc @@ -89,7 +89,7 @@ struct TintVertexData { [[stage(vertex)]] fn main() -> void { { - var _tint_pulling_pos : i32; + var _tint_pulling_pos : u32; } } )"; @@ -113,7 +113,7 @@ fn main() -> void {} )"; auto* expect = R"( -[[builtin(vertex_index)]] var _tint_pulling_vertex_index : i32; +[[builtin(vertex_index)]] var _tint_pulling_vertex_index : u32; [[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; @@ -127,7 +127,7 @@ var var_a : f32; [[stage(vertex)]] fn main() -> void { { - var _tint_pulling_pos : i32; + var _tint_pulling_pos : u32; _tint_pulling_pos = ((_tint_pulling_vertex_index * 4u) + 0u); var_a = bitcast(_tint_pulling_vertex_buffer_0._tint_vertex_data[(_tint_pulling_pos / 4u)]); } @@ -155,7 +155,7 @@ fn main() -> void {} )"; auto* expect = R"( -[[builtin(instance_index)]] var _tint_pulling_instance_index : i32; +[[builtin(instance_index)]] var _tint_pulling_instance_index : u32; [[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; @@ -169,7 +169,7 @@ var var_a : f32; [[stage(vertex)]] fn main() -> void { { - var _tint_pulling_pos : i32; + var _tint_pulling_pos : u32; _tint_pulling_pos = ((_tint_pulling_instance_index * 4u) + 0u); var_a = bitcast(_tint_pulling_vertex_buffer_0._tint_vertex_data[(_tint_pulling_pos / 4u)]); } @@ -197,7 +197,7 @@ fn main() -> void {} )"; auto* expect = R"( -[[builtin(vertex_index)]] var _tint_pulling_vertex_index : i32; +[[builtin(vertex_index)]] var _tint_pulling_vertex_index : u32; [[binding(0), group(5)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; @@ -211,7 +211,7 @@ var var_a : f32; [[stage(vertex)]] fn main() -> void { { - var _tint_pulling_pos : i32; + var _tint_pulling_pos : u32; _tint_pulling_pos = ((_tint_pulling_vertex_index * 4u) + 0u); var_a = bitcast(_tint_pulling_vertex_buffer_0._tint_vertex_data[(_tint_pulling_pos / 4u)]); } @@ -236,8 +236,8 @@ TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) { auto* src = R"( [[location(0)]] var var_a : f32; [[location(1)]] var var_b : f32; -[[builtin(vertex_index)]] var custom_vertex_index : i32; -[[builtin(instance_index)]] var custom_instance_index : i32; +[[builtin(vertex_index)]] var custom_vertex_index : u32; +[[builtin(instance_index)]] var custom_instance_index : u32; [[stage(vertex)]] fn main() -> void {} @@ -257,14 +257,14 @@ var var_a : f32; var var_b : f32; -[[builtin(vertex_index)]] var custom_vertex_index : i32; +[[builtin(vertex_index)]] var custom_vertex_index : u32; -[[builtin(instance_index)]] var custom_instance_index : i32; +[[builtin(instance_index)]] var custom_instance_index : u32; [[stage(vertex)]] fn main() -> void { { - var _tint_pulling_pos : i32; + var _tint_pulling_pos : u32; _tint_pulling_pos = ((custom_vertex_index * 4u) + 0u); var_a = bitcast(_tint_pulling_vertex_buffer_0._tint_vertex_data[(_tint_pulling_pos / 4u)]); _tint_pulling_pos = ((custom_instance_index * 4u) + 0u); @@ -305,7 +305,7 @@ fn main() -> void {} )"; auto* expect = R"( -[[builtin(vertex_index)]] var _tint_pulling_vertex_index : i32; +[[builtin(vertex_index)]] var _tint_pulling_vertex_index : u32; [[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; @@ -321,7 +321,7 @@ var var_b : array; [[stage(vertex)]] fn main() -> void { { - var _tint_pulling_pos : i32; + var _tint_pulling_pos : u32; _tint_pulling_pos = ((_tint_pulling_vertex_index * 16u) + 0u); var_a = bitcast(_tint_pulling_vertex_buffer_0._tint_vertex_data[(_tint_pulling_pos / 4u)]); _tint_pulling_pos = ((_tint_pulling_vertex_index * 16u) + 0u); @@ -355,7 +355,7 @@ fn main() -> void {} )"; auto* expect = R"( -[[builtin(vertex_index)]] var _tint_pulling_vertex_index : i32; +[[builtin(vertex_index)]] var _tint_pulling_vertex_index : u32; [[binding(0), group(4)]] var _tint_pulling_vertex_buffer_0 : TintVertexData; @@ -377,7 +377,7 @@ var var_c : array; [[stage(vertex)]] fn main() -> void { { - var _tint_pulling_pos : i32; + var _tint_pulling_pos : u32; _tint_pulling_pos = ((_tint_pulling_vertex_index * 8u) + 0u); var_a = vec2(bitcast(_tint_pulling_vertex_buffer_0._tint_vertex_data[((_tint_pulling_pos + 0u) / 4u)]), bitcast(_tint_pulling_vertex_buffer_0._tint_vertex_data[((_tint_pulling_pos + 4u) / 4u)])); _tint_pulling_pos = ((_tint_pulling_vertex_index * 12u) + 0u); diff --git a/src/type/type.cc b/src/type/type.cc index 2a7d6ba991..f74c20c046 100644 --- a/src/type/type.cc +++ b/src/type/type.cc @@ -92,6 +92,10 @@ bool Type::is_float_scalar_or_vector() const { return is_float_scalar() || is_float_vector(); } +bool Type::is_float_scalar_or_vector_or_matrix() const { + return is_float_scalar() || is_float_vector() || is_float_matrix(); +} + bool Type::is_integer_scalar() const { return IsAnyOf(); } diff --git a/src/type/type.h b/src/type/type.h index c4e7a36574..0dd9eafefb 100644 --- a/src/type/type.h +++ b/src/type/type.h @@ -77,6 +77,8 @@ class Type : public Castable { bool is_float_vector() const; /// @returns true if this type is a float scalar or vector bool is_float_scalar_or_vector() const; + /// @returns true if this type is a float scalar or vector or matrix + bool is_float_scalar_or_vector_or_matrix() const; /// @returns true if this type is an integer scalar bool is_integer_scalar() const; /// @returns true if this type is a signed integer vector diff --git a/src/writer/hlsl/generator_impl_binary_test.cc b/src/writer/hlsl/generator_impl_binary_test.cc index 4b748b934a..0f65366cc9 100644 --- a/src/writer/hlsl/generator_impl_binary_test.cc +++ b/src/writer/hlsl/generator_impl_binary_test.cc @@ -36,6 +36,14 @@ using HlslBinaryTest = TestParamHelper; TEST_P(HlslBinaryTest, Emit_f32) { auto params = GetParam(); + // Skip ops that are illegal for this type + if (params.op == ast::BinaryOp::kAnd || params.op == ast::BinaryOp::kOr || + params.op == ast::BinaryOp::kXor || + params.op == ast::BinaryOp::kShiftLeft || + params.op == ast::BinaryOp::kShiftRight) { + return; + } + Global("left", ty.f32(), ast::StorageClass::kFunction); Global("right", ty.f32(), ast::StorageClass::kFunction); @@ -72,6 +80,12 @@ TEST_P(HlslBinaryTest, Emit_u32) { TEST_P(HlslBinaryTest, Emit_i32) { auto params = GetParam(); + // Skip ops that are illegal for this type + if (params.op == ast::BinaryOp::kShiftLeft || + params.op == ast::BinaryOp::kShiftRight) { + return; + } + Global("left", ty.i32(), ast::StorageClass::kFunction); Global("right", ty.i32(), ast::StorageClass::kFunction); diff --git a/src/writer/spirv/builder_binary_expression_test.cc b/src/writer/spirv/builder_binary_expression_test.cc index 6a44c9a0c1..55000ce6e3 100644 --- a/src/writer/spirv/builder_binary_expression_test.cc +++ b/src/writer/spirv/builder_binary_expression_test.cc @@ -58,6 +58,12 @@ TEST_P(BinaryArithSignedIntegerTest, Scalar) { TEST_P(BinaryArithSignedIntegerTest, Vector) { auto param = GetParam(); + // Skip ops that are illegal for this type + if (param.op == ast::BinaryOp::kAnd || param.op == ast::BinaryOp::kOr || + param.op == ast::BinaryOp::kXor) { + return; + } + auto* lhs = vec3(1, 1, 1); auto* rhs = vec3(1, 1, 1); @@ -111,15 +117,13 @@ TEST_P(BinaryArithSignedIntegerTest, Scalar_Loads) { INSTANTIATE_TEST_SUITE_P( BuilderTest, BinaryArithSignedIntegerTest, + // NOTE: No left and right shift as they require u32 for rhs operand testing::Values(BinaryData{ast::BinaryOp::kAdd, "OpIAdd"}, BinaryData{ast::BinaryOp::kAnd, "OpBitwiseAnd"}, BinaryData{ast::BinaryOp::kDivide, "OpSDiv"}, BinaryData{ast::BinaryOp::kModulo, "OpSMod"}, BinaryData{ast::BinaryOp::kMultiply, "OpIMul"}, BinaryData{ast::BinaryOp::kOr, "OpBitwiseOr"}, - BinaryData{ast::BinaryOp::kShiftLeft, "OpShiftLeftLogical"}, - BinaryData{ast::BinaryOp::kShiftRight, - "OpShiftRightArithmetic"}, BinaryData{ast::BinaryOp::kSubtract, "OpISub"}, BinaryData{ast::BinaryOp::kXor, "OpBitwiseXor"})); @@ -149,6 +153,12 @@ TEST_P(BinaryArithUnsignedIntegerTest, Scalar) { TEST_P(BinaryArithUnsignedIntegerTest, Vector) { auto param = GetParam(); + // Skip ops that are illegal for this type + if (param.op == ast::BinaryOp::kAnd || param.op == ast::BinaryOp::kOr || + param.op == ast::BinaryOp::kXor) { + return; + } + auto* lhs = vec3(1u, 1u, 1u); auto* rhs = vec3(1u, 1u, 1u);