Validate binary operations
This change validates that the operand types and result type of every binary operation is valid. * Added two unit tests which test all valid and invalid param combos. I also removed the old tests, many of which failed once I added this validation, and the rest are obviated by the new tests. * Fixed VertexPulling transform, as well as many tests, that were using invalid operand types for binary operations. Fixed: tint:354 Change-Id: Ia3f48384256993da61b341f17ba5583741011819 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44341 Reviewed-by: Ben Clayton <bclayton@google.com> Commit-Queue: Antonio Maiorano <amaiorano@google.com>
This commit is contained in:
parent
1691401179
commit
be0fc4e929
|
@ -23,11 +23,11 @@ namespace ast {
|
|||
/// The operator type
|
||||
enum class BinaryOp {
|
||||
kNone = 0,
|
||||
kAnd,
|
||||
kOr,
|
||||
kAnd, // &
|
||||
kOr, // |
|
||||
kXor,
|
||||
kLogicalAnd,
|
||||
kLogicalOr,
|
||||
kLogicalAnd, // &&
|
||||
kLogicalOr, // ||
|
||||
kEqual,
|
||||
kNotEqual,
|
||||
kLessThan,
|
||||
|
@ -98,6 +98,14 @@ class BinaryExpression : public Castable<BinaryExpression, Expression> {
|
|||
bool IsDivide() const { return op_ == BinaryOp::kDivide; }
|
||||
/// @returns true if the op is modulo
|
||||
bool IsModulo() const { return op_ == BinaryOp::kModulo; }
|
||||
/// @returns true if the op is an arithmetic operation
|
||||
bool IsArithmetic() const;
|
||||
/// @returns true if the op is a comparison operation
|
||||
bool IsComparison() const;
|
||||
/// @returns true if the op is a bitwise operation
|
||||
bool IsBitwise() const;
|
||||
/// @returns true if the op is a bit shift operation
|
||||
bool IsBitshift() const;
|
||||
|
||||
/// @returns the left side expression
|
||||
Expression* lhs() const { return lhs_; }
|
||||
|
@ -126,6 +134,54 @@ class BinaryExpression : public Castable<BinaryExpression, Expression> {
|
|||
Expression* const rhs_;
|
||||
};
|
||||
|
||||
inline bool BinaryExpression::IsArithmetic() const {
|
||||
switch (op_) {
|
||||
case ast::BinaryOp::kAdd:
|
||||
case ast::BinaryOp::kSubtract:
|
||||
case ast::BinaryOp::kMultiply:
|
||||
case ast::BinaryOp::kDivide:
|
||||
case ast::BinaryOp::kModulo:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool BinaryExpression::IsComparison() const {
|
||||
switch (op_) {
|
||||
case ast::BinaryOp::kEqual:
|
||||
case ast::BinaryOp::kNotEqual:
|
||||
case ast::BinaryOp::kLessThan:
|
||||
case ast::BinaryOp::kLessThanEqual:
|
||||
case ast::BinaryOp::kGreaterThan:
|
||||
case ast::BinaryOp::kGreaterThanEqual:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool BinaryExpression::IsBitwise() const {
|
||||
switch (op_) {
|
||||
case ast::BinaryOp::kAnd:
|
||||
case ast::BinaryOp::kOr:
|
||||
case ast::BinaryOp::kXor:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
inline bool BinaryExpression::IsBitshift() const {
|
||||
switch (op_) {
|
||||
case ast::BinaryOp::kShiftLeft:
|
||||
case ast::BinaryOp::kShiftRight:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, BinaryOp op) {
|
||||
switch (op) {
|
||||
case BinaryOp::kNone:
|
||||
|
|
|
@ -877,11 +877,167 @@ bool Resolver::MemberAccessor(ast::MemberAccessorExpression* expr) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool Resolver::ValidateBinary(ast::BinaryExpression* expr) {
|
||||
using Bool = type::Bool;
|
||||
using F32 = type::F32;
|
||||
using I32 = type::I32;
|
||||
using U32 = type::U32;
|
||||
using Matrix = type::Matrix;
|
||||
using Vector = type::Vector;
|
||||
|
||||
auto* lhs_type = TypeOf(expr->lhs())->UnwrapPtrIfNeeded();
|
||||
auto* rhs_type = TypeOf(expr->rhs())->UnwrapPtrIfNeeded();
|
||||
|
||||
auto* lhs_vec = lhs_type->As<Vector>();
|
||||
auto* lhs_vec_elem_type = lhs_vec ? lhs_vec->type() : nullptr;
|
||||
auto* rhs_vec = rhs_type->As<Vector>();
|
||||
auto* rhs_vec_elem_type = rhs_vec ? rhs_vec->type() : nullptr;
|
||||
|
||||
const bool matching_types = lhs_type == rhs_type;
|
||||
const bool matching_vec_elem_types = lhs_vec_elem_type && rhs_vec_elem_type &&
|
||||
(lhs_vec_elem_type == rhs_vec_elem_type);
|
||||
|
||||
// Binary logical expressions
|
||||
if (expr->IsLogicalAnd() || expr->IsLogicalOr()) {
|
||||
if (matching_types && lhs_type->Is<Bool>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if (expr->IsOr() || expr->IsAnd()) {
|
||||
if (matching_types && lhs_type->Is<Bool>()) {
|
||||
return true;
|
||||
}
|
||||
if (matching_types && lhs_vec_elem_type && lhs_vec_elem_type->Is<Bool>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Arithmetic expressions
|
||||
if (expr->IsArithmetic()) {
|
||||
// Binary arithmetic expressions over scalars
|
||||
if (matching_types && lhs_type->IsAnyOf<I32, F32, U32>()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Binary arithmetic expressions over vectors
|
||||
if (matching_types && lhs_vec_elem_type &&
|
||||
lhs_vec_elem_type->IsAnyOf<I32, F32, U32>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Binary arithmetic expressions with mixed scalar, vector, and matrix
|
||||
// operands
|
||||
if (expr->IsMultiply()) {
|
||||
// Multiplication of a vector and a scalar
|
||||
if (lhs_type->Is<F32>() && rhs_vec_elem_type &&
|
||||
rhs_vec_elem_type->Is<F32>()) {
|
||||
return true;
|
||||
}
|
||||
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
|
||||
rhs_type->Is<F32>()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
auto* lhs_mat = lhs_type->As<Matrix>();
|
||||
auto* lhs_mat_elem_type = lhs_mat ? lhs_mat->type() : nullptr;
|
||||
auto* rhs_mat = rhs_type->As<Matrix>();
|
||||
auto* rhs_mat_elem_type = rhs_mat ? rhs_mat->type() : nullptr;
|
||||
|
||||
// Multiplication of a matrix and a scalar
|
||||
if (lhs_type->Is<F32>() && rhs_mat_elem_type &&
|
||||
rhs_mat_elem_type->Is<F32>()) {
|
||||
return true;
|
||||
}
|
||||
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
|
||||
rhs_type->Is<F32>()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Vector times matrix
|
||||
if (lhs_vec_elem_type && lhs_vec_elem_type->Is<F32>() &&
|
||||
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Matrix times vector
|
||||
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
|
||||
rhs_vec_elem_type && rhs_vec_elem_type->Is<F32>()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Matrix times matrix
|
||||
if (lhs_mat_elem_type && lhs_mat_elem_type->Is<F32>() &&
|
||||
rhs_mat_elem_type && rhs_mat_elem_type->Is<F32>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Comparison expressions
|
||||
if (expr->IsComparison()) {
|
||||
if (matching_types) {
|
||||
// Special case for bools: only == and !=
|
||||
if (lhs_type->Is<Bool>() && (expr->IsEqual() || expr->IsNotEqual())) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// For the rest, we can compare i32, u32, and f32
|
||||
if (lhs_type->IsAnyOf<I32, U32, F32>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Same for vectors
|
||||
if (matching_vec_elem_types) {
|
||||
if (lhs_vec_elem_type->Is<Bool>() &&
|
||||
(expr->IsEqual() || expr->IsNotEqual())) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (lhs_vec_elem_type->IsAnyOf<I32, U32, F32>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Binary bitwise operations
|
||||
if (expr->IsBitwise()) {
|
||||
if (matching_types && lhs_type->IsAnyOf<I32, U32>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
// Bit shift expressions
|
||||
if (expr->IsBitshift()) {
|
||||
// Type validation rules are the same for left or right shift, despite
|
||||
// differences in computation rules (i.e. right shift can be arithmetic or
|
||||
// logical depending on lhs type).
|
||||
|
||||
if (lhs_type->IsAnyOf<I32, U32>() && rhs_type->Is<U32>()) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if (lhs_vec_elem_type && lhs_vec_elem_type->IsAnyOf<I32, U32>() &&
|
||||
rhs_vec_elem_type && rhs_vec_elem_type->Is<U32>()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
diagnostics_.add_error(
|
||||
"Binary expression operand types are invalid for this operation",
|
||||
expr->source());
|
||||
return false;
|
||||
}
|
||||
|
||||
bool Resolver::Binary(ast::BinaryExpression* expr) {
|
||||
if (!Expression(expr->lhs()) || !Expression(expr->rhs())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!ValidateBinary(expr)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Result type matches first parameter type
|
||||
if (expr->IsAnd() || expr->IsOr() || expr->IsXor() || expr->IsShiftLeft() ||
|
||||
expr->IsShiftRight() || expr->IsAdd() || expr->IsSubtract() ||
|
||||
|
|
|
@ -171,6 +171,7 @@ class Resolver {
|
|||
// AST and Type traversal methods
|
||||
// Each return true on success, false on failure.
|
||||
bool ArrayAccessor(ast::ArrayAccessorExpression*);
|
||||
bool ValidateBinary(ast::BinaryExpression* expr);
|
||||
bool Binary(ast::BinaryExpression*);
|
||||
bool Bitcast(ast::BitcastExpression*);
|
||||
bool BlockStatement(const ast::BlockStatement*);
|
||||
|
|
|
@ -14,6 +14,8 @@
|
|||
|
||||
#include "src/resolver/resolver.h"
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gmock/gmock.h"
|
||||
#include "src/ast/assignment_statement.h"
|
||||
#include "src/ast/bitcast_expression.h"
|
||||
|
@ -971,246 +973,276 @@ TEST_F(ResolverTest, Expr_MemberAccessor_InBinaryOp) {
|
|||
EXPECT_TRUE(TypeOf(expr)->Is<type::F32>());
|
||||
}
|
||||
|
||||
using Expr_Binary_BitwiseTest = ResolverTestWithParam<ast::BinaryOp>;
|
||||
TEST_P(Expr_Binary_BitwiseTest, Scalar) {
|
||||
auto op = GetParam();
|
||||
namespace ExprBinaryTest {
|
||||
|
||||
Global("val", ty.i32(), ast::StorageClass::kNone);
|
||||
using create_type_func_ptr =
|
||||
type::Type* (*)(const ProgramBuilder::TypesBuilder& ty);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
|
||||
WrapInFunction(expr);
|
||||
struct Params {
|
||||
ast::BinaryOp op;
|
||||
create_type_func_ptr create_lhs_type;
|
||||
create_type_func_ptr create_rhs_type;
|
||||
create_type_func_ptr create_result_type;
|
||||
};
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
ASSERT_NE(TypeOf(expr), nullptr);
|
||||
EXPECT_TRUE(TypeOf(expr)->Is<type::I32>());
|
||||
// Helpers and typedefs to make building the table below more succinct
|
||||
|
||||
using i32 = ProgramBuilder::i32;
|
||||
using u32 = ProgramBuilder::u32;
|
||||
using f32 = ProgramBuilder::f32;
|
||||
using Op = ast::BinaryOp;
|
||||
|
||||
type::Type* ty_bool_(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.bool_();
|
||||
}
|
||||
type::Type* ty_i32(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.i32();
|
||||
}
|
||||
type::Type* ty_u32(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.u32();
|
||||
}
|
||||
type::Type* ty_f32(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.f32();
|
||||
}
|
||||
|
||||
TEST_P(Expr_Binary_BitwiseTest, Vector) {
|
||||
auto op = GetParam();
|
||||
template <typename T>
|
||||
type::Type* ty_vec3(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.vec3<T>();
|
||||
}
|
||||
|
||||
Global("val", ty.vec3<i32>(), ast::StorageClass::kNone);
|
||||
template <typename T>
|
||||
type::Type* ty_mat3x3(const ProgramBuilder::TypesBuilder& ty) {
|
||||
return ty.mat3x3<T>();
|
||||
}
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
|
||||
static constexpr create_type_func_ptr all_create_type_funcs[] = {
|
||||
ty_bool_, ty_u32, ty_i32, ty_f32,
|
||||
ty_vec3<bool>, ty_vec3<i32>, ty_vec3<u32>, ty_vec3<f32>,
|
||||
ty_mat3x3<i32>, ty_mat3x3<u32>, ty_mat3x3<f32>};
|
||||
|
||||
// A list of all valid test cases for 'lhs op rhs', except that for vecN and
|
||||
// matNxN, we only test N=3.
|
||||
static constexpr Params all_valid_cases[] = {
|
||||
// Logical expressions
|
||||
// https://gpuweb.github.io/gpuweb/wgsl.html#logical-expr
|
||||
|
||||
// Binary logical expressions
|
||||
Params{Op::kLogicalAnd, ty_bool_, ty_bool_, ty_bool_},
|
||||
Params{Op::kLogicalOr, ty_bool_, ty_bool_, ty_bool_},
|
||||
|
||||
Params{Op::kAnd, ty_bool_, ty_bool_, ty_bool_},
|
||||
Params{Op::kOr, ty_bool_, ty_bool_, ty_bool_},
|
||||
Params{Op::kAnd, ty_vec3<bool>, ty_vec3<bool>, ty_vec3<bool>},
|
||||
Params{Op::kOr, ty_vec3<bool>, ty_vec3<bool>, ty_vec3<bool>},
|
||||
|
||||
// Arithmetic expressions
|
||||
// https://gpuweb.github.io/gpuweb/wgsl.html#arithmetic-expr
|
||||
|
||||
// Binary arithmetic expressions over scalars
|
||||
Params{Op::kAdd, ty_i32, ty_i32, ty_i32},
|
||||
Params{Op::kSubtract, ty_i32, ty_i32, ty_i32},
|
||||
Params{Op::kMultiply, ty_i32, ty_i32, ty_i32},
|
||||
Params{Op::kDivide, ty_i32, ty_i32, ty_i32},
|
||||
Params{Op::kModulo, ty_i32, ty_i32, ty_i32},
|
||||
|
||||
Params{Op::kAdd, ty_u32, ty_u32, ty_u32},
|
||||
Params{Op::kSubtract, ty_u32, ty_u32, ty_u32},
|
||||
Params{Op::kMultiply, ty_u32, ty_u32, ty_u32},
|
||||
Params{Op::kDivide, ty_u32, ty_u32, ty_u32},
|
||||
Params{Op::kModulo, ty_u32, ty_u32, ty_u32},
|
||||
|
||||
Params{Op::kAdd, ty_f32, ty_f32, ty_f32},
|
||||
Params{Op::kSubtract, ty_f32, ty_f32, ty_f32},
|
||||
Params{Op::kMultiply, ty_f32, ty_f32, ty_f32},
|
||||
Params{Op::kDivide, ty_f32, ty_f32, ty_f32},
|
||||
Params{Op::kModulo, ty_f32, ty_f32, ty_f32},
|
||||
|
||||
// Binary arithmetic expressions over vectors
|
||||
Params{Op::kAdd, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
|
||||
Params{Op::kSubtract, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
|
||||
Params{Op::kMultiply, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
|
||||
Params{Op::kDivide, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
|
||||
Params{Op::kModulo, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<i32>},
|
||||
|
||||
Params{Op::kAdd, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
|
||||
Params{Op::kSubtract, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
|
||||
Params{Op::kMultiply, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
|
||||
Params{Op::kDivide, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
|
||||
Params{Op::kModulo, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
|
||||
|
||||
Params{Op::kAdd, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
|
||||
Params{Op::kSubtract, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
|
||||
Params{Op::kMultiply, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
|
||||
Params{Op::kDivide, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
|
||||
Params{Op::kModulo, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<f32>},
|
||||
|
||||
// Binary arithmetic expressions with mixed scalar, vector, and matrix
|
||||
// operands
|
||||
Params{Op::kMultiply, ty_vec3<f32>, ty_f32, ty_vec3<f32>},
|
||||
Params{Op::kMultiply, ty_f32, ty_vec3<f32>, ty_vec3<f32>},
|
||||
|
||||
Params{Op::kMultiply, ty_mat3x3<f32>, ty_f32, ty_mat3x3<f32>},
|
||||
Params{Op::kMultiply, ty_f32, ty_mat3x3<f32>, ty_mat3x3<f32>},
|
||||
|
||||
Params{Op::kMultiply, ty_vec3<f32>, ty_mat3x3<f32>, ty_vec3<f32>},
|
||||
Params{Op::kMultiply, ty_mat3x3<f32>, ty_vec3<f32>, ty_vec3<f32>},
|
||||
Params{Op::kMultiply, ty_mat3x3<f32>, ty_mat3x3<f32>, ty_mat3x3<f32>},
|
||||
|
||||
// Comparison expressions
|
||||
// https://gpuweb.github.io/gpuweb/wgsl.html#comparison-expr
|
||||
|
||||
// Comparisons over scalars
|
||||
Params{Op::kEqual, ty_bool_, ty_bool_, ty_bool_},
|
||||
Params{Op::kNotEqual, ty_bool_, ty_bool_, ty_bool_},
|
||||
|
||||
Params{Op::kEqual, ty_i32, ty_i32, ty_bool_},
|
||||
Params{Op::kNotEqual, ty_i32, ty_i32, ty_bool_},
|
||||
Params{Op::kLessThan, ty_i32, ty_i32, ty_bool_},
|
||||
Params{Op::kLessThanEqual, ty_i32, ty_i32, ty_bool_},
|
||||
Params{Op::kGreaterThan, ty_i32, ty_i32, ty_bool_},
|
||||
Params{Op::kGreaterThanEqual, ty_i32, ty_i32, ty_bool_},
|
||||
|
||||
Params{Op::kEqual, ty_u32, ty_u32, ty_bool_},
|
||||
Params{Op::kNotEqual, ty_u32, ty_u32, ty_bool_},
|
||||
Params{Op::kLessThan, ty_u32, ty_u32, ty_bool_},
|
||||
Params{Op::kLessThanEqual, ty_u32, ty_u32, ty_bool_},
|
||||
Params{Op::kGreaterThan, ty_u32, ty_u32, ty_bool_},
|
||||
Params{Op::kGreaterThanEqual, ty_u32, ty_u32, ty_bool_},
|
||||
|
||||
Params{Op::kEqual, ty_f32, ty_f32, ty_bool_},
|
||||
Params{Op::kNotEqual, ty_f32, ty_f32, ty_bool_},
|
||||
Params{Op::kLessThan, ty_f32, ty_f32, ty_bool_},
|
||||
Params{Op::kLessThanEqual, ty_f32, ty_f32, ty_bool_},
|
||||
Params{Op::kGreaterThan, ty_f32, ty_f32, ty_bool_},
|
||||
Params{Op::kGreaterThanEqual, ty_f32, ty_f32, ty_bool_},
|
||||
|
||||
// Comparisons over vectors
|
||||
Params{Op::kEqual, ty_vec3<bool>, ty_vec3<bool>, ty_vec3<bool>},
|
||||
Params{Op::kNotEqual, ty_vec3<bool>, ty_vec3<bool>, ty_vec3<bool>},
|
||||
|
||||
Params{Op::kEqual, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
|
||||
Params{Op::kNotEqual, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
|
||||
Params{Op::kLessThan, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
|
||||
Params{Op::kLessThanEqual, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
|
||||
Params{Op::kGreaterThan, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
|
||||
Params{Op::kGreaterThanEqual, ty_vec3<i32>, ty_vec3<i32>, ty_vec3<bool>},
|
||||
|
||||
Params{Op::kEqual, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
|
||||
Params{Op::kNotEqual, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
|
||||
Params{Op::kLessThan, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
|
||||
Params{Op::kLessThanEqual, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
|
||||
Params{Op::kGreaterThan, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
|
||||
Params{Op::kGreaterThanEqual, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<bool>},
|
||||
|
||||
Params{Op::kEqual, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
|
||||
Params{Op::kNotEqual, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
|
||||
Params{Op::kLessThan, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
|
||||
Params{Op::kLessThanEqual, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
|
||||
Params{Op::kGreaterThan, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
|
||||
Params{Op::kGreaterThanEqual, ty_vec3<f32>, ty_vec3<f32>, ty_vec3<bool>},
|
||||
|
||||
// Bit expressions
|
||||
// https://gpuweb.github.io/gpuweb/wgsl.html#bit-expr
|
||||
|
||||
// Binary bitwise operations
|
||||
Params{Op::kOr, ty_i32, ty_i32, ty_i32},
|
||||
Params{Op::kAnd, ty_i32, ty_i32, ty_i32},
|
||||
Params{Op::kXor, ty_i32, ty_i32, ty_i32},
|
||||
|
||||
Params{Op::kOr, ty_u32, ty_u32, ty_u32},
|
||||
Params{Op::kAnd, ty_u32, ty_u32, ty_u32},
|
||||
Params{Op::kXor, ty_u32, ty_u32, ty_u32},
|
||||
|
||||
// Bit shift expressions
|
||||
Params{Op::kShiftLeft, ty_i32, ty_u32, ty_i32},
|
||||
Params{Op::kShiftLeft, ty_vec3<i32>, ty_vec3<u32>, ty_vec3<i32>},
|
||||
|
||||
Params{Op::kShiftLeft, ty_u32, ty_u32, ty_u32},
|
||||
Params{Op::kShiftLeft, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>},
|
||||
|
||||
Params{Op::kShiftRight, ty_i32, ty_u32, ty_i32},
|
||||
Params{Op::kShiftRight, ty_vec3<i32>, ty_vec3<u32>, ty_vec3<i32>},
|
||||
|
||||
Params{Op::kShiftRight, ty_u32, ty_u32, ty_u32},
|
||||
Params{Op::kShiftRight, ty_vec3<u32>, ty_vec3<u32>, ty_vec3<u32>}};
|
||||
|
||||
using Expr_Binary_Test_Valid = ResolverTestWithParam<Params>;
|
||||
TEST_P(Expr_Binary_Test_Valid, All) {
|
||||
auto& params = GetParam();
|
||||
|
||||
auto* lhs_type = params.create_lhs_type(ty);
|
||||
auto* rhs_type = params.create_rhs_type(ty);
|
||||
auto* result_type = params.create_result_type(ty);
|
||||
|
||||
SCOPED_TRACE(testing::Message()
|
||||
<< lhs_type->FriendlyName(Symbols()) << " " << params.op << " "
|
||||
<< rhs_type->FriendlyName(Symbols()));
|
||||
|
||||
Global("lhs", lhs_type, ast::StorageClass::kNone);
|
||||
Global("rhs", rhs_type, ast::StorageClass::kNone);
|
||||
|
||||
auto* expr =
|
||||
create<ast::BinaryExpression>(params.op, Expr("lhs"), Expr("rhs"));
|
||||
WrapInFunction(expr);
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
ASSERT_NE(TypeOf(expr), nullptr);
|
||||
ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
|
||||
EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::I32>());
|
||||
EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
|
||||
ASSERT_TRUE(TypeOf(expr) == result_type);
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(ResolverTest,
|
||||
Expr_Binary_BitwiseTest,
|
||||
testing::Values(ast::BinaryOp::kAnd,
|
||||
ast::BinaryOp::kOr,
|
||||
ast::BinaryOp::kXor,
|
||||
ast::BinaryOp::kShiftLeft,
|
||||
ast::BinaryOp::kShiftRight,
|
||||
ast::BinaryOp::kAdd,
|
||||
ast::BinaryOp::kSubtract,
|
||||
ast::BinaryOp::kDivide,
|
||||
ast::BinaryOp::kModulo));
|
||||
Expr_Binary_Test_Valid,
|
||||
testing::ValuesIn(all_valid_cases));
|
||||
|
||||
using Expr_Binary_LogicalTest = ResolverTestWithParam<ast::BinaryOp>;
|
||||
TEST_P(Expr_Binary_LogicalTest, Scalar) {
|
||||
auto op = GetParam();
|
||||
using Expr_Binary_Test_Invalid =
|
||||
ResolverTestWithParam<std::tuple<Params, create_type_func_ptr>>;
|
||||
TEST_P(Expr_Binary_Test_Invalid, All) {
|
||||
const Params& params = std::get<0>(GetParam());
|
||||
const create_type_func_ptr& create_type_func = std::get<1>(GetParam());
|
||||
|
||||
Global("val", ty.bool_(), ast::StorageClass::kNone);
|
||||
// Currently, for most operations, for a given lhs type, there is exactly one
|
||||
// rhs type allowed. The only exception is for multiplication, which allows
|
||||
// any permutation of f32, vecN<f32>, and matNxN<f32>. We are fed valid inputs
|
||||
// only via `params`, and all possible types via `create_type_func`, so we
|
||||
// test invalid combinations by testing every other rhs type, modulo
|
||||
// exceptions.
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
|
||||
// Skip valid rhs type
|
||||
if (params.create_rhs_type == create_type_func) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto* lhs_type = params.create_lhs_type(ty);
|
||||
auto* rhs_type = create_type_func(ty);
|
||||
|
||||
// Skip exceptions: multiplication of f32, vecN<f32>, and matNxN<f32>
|
||||
if (params.op == Op::kMultiply &&
|
||||
lhs_type->is_float_scalar_or_vector_or_matrix() &&
|
||||
rhs_type->is_float_scalar_or_vector_or_matrix()) {
|
||||
return;
|
||||
}
|
||||
|
||||
SCOPED_TRACE(testing::Message()
|
||||
<< lhs_type->FriendlyName(Symbols()) << " " << params.op << " "
|
||||
<< rhs_type->FriendlyName(Symbols()));
|
||||
|
||||
Global("lhs", lhs_type, ast::StorageClass::kNone);
|
||||
Global("rhs", rhs_type, ast::StorageClass::kNone);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(
|
||||
Source{Source::Location{12, 34}}, params.op, Expr("lhs"), Expr("rhs"));
|
||||
WrapInFunction(expr);
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
ASSERT_NE(TypeOf(expr), nullptr);
|
||||
EXPECT_TRUE(TypeOf(expr)->Is<type::Bool>());
|
||||
}
|
||||
|
||||
TEST_P(Expr_Binary_LogicalTest, Vector) {
|
||||
auto op = GetParam();
|
||||
|
||||
Global("val", ty.vec3<bool>(), ast::StorageClass::kNone);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
|
||||
WrapInFunction(expr);
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
ASSERT_NE(TypeOf(expr), nullptr);
|
||||
ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
|
||||
EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::Bool>());
|
||||
EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(ResolverTest,
|
||||
Expr_Binary_LogicalTest,
|
||||
testing::Values(ast::BinaryOp::kLogicalAnd,
|
||||
ast::BinaryOp::kLogicalOr));
|
||||
|
||||
using Expr_Binary_CompareTest = ResolverTestWithParam<ast::BinaryOp>;
|
||||
TEST_P(Expr_Binary_CompareTest, Scalar) {
|
||||
auto op = GetParam();
|
||||
|
||||
Global("val", ty.i32(), ast::StorageClass::kNone);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
|
||||
WrapInFunction(expr);
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
ASSERT_NE(TypeOf(expr), nullptr);
|
||||
EXPECT_TRUE(TypeOf(expr)->Is<type::Bool>());
|
||||
}
|
||||
|
||||
TEST_P(Expr_Binary_CompareTest, Vector) {
|
||||
auto op = GetParam();
|
||||
|
||||
Global("val", ty.vec3<i32>(), ast::StorageClass::kNone);
|
||||
|
||||
auto* expr = create<ast::BinaryExpression>(op, Expr("val"), Expr("val"));
|
||||
WrapInFunction(expr);
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
ASSERT_NE(TypeOf(expr), nullptr);
|
||||
ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
|
||||
EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::Bool>());
|
||||
EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(ResolverTest,
|
||||
Expr_Binary_CompareTest,
|
||||
testing::Values(ast::BinaryOp::kEqual,
|
||||
ast::BinaryOp::kNotEqual,
|
||||
ast::BinaryOp::kLessThan,
|
||||
ast::BinaryOp::kGreaterThan,
|
||||
ast::BinaryOp::kLessThanEqual,
|
||||
ast::BinaryOp::kGreaterThanEqual));
|
||||
|
||||
TEST_F(ResolverTest, Expr_Binary_Multiply_Scalar_Scalar) {
|
||||
Global("val", ty.i32(), ast::StorageClass::kNone);
|
||||
|
||||
auto* expr = Mul("val", "val");
|
||||
WrapInFunction(expr);
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
ASSERT_NE(TypeOf(expr), nullptr);
|
||||
EXPECT_TRUE(TypeOf(expr)->Is<type::I32>());
|
||||
}
|
||||
|
||||
TEST_F(ResolverTest, Expr_Binary_Multiply_Vector_Scalar) {
|
||||
Global("scalar", ty.f32(), ast::StorageClass::kNone);
|
||||
Global("vector", ty.vec3<f32>(), ast::StorageClass::kNone);
|
||||
|
||||
auto* expr = Mul("vector", "scalar");
|
||||
WrapInFunction(expr);
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
ASSERT_NE(TypeOf(expr), nullptr);
|
||||
ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
|
||||
EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
|
||||
EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
|
||||
}
|
||||
|
||||
TEST_F(ResolverTest, Expr_Binary_Multiply_Scalar_Vector) {
|
||||
Global("scalar", ty.f32(), ast::StorageClass::kNone);
|
||||
Global("vector", ty.vec3<f32>(), ast::StorageClass::kNone);
|
||||
|
||||
auto* expr = Mul("scalar", "vector");
|
||||
WrapInFunction(expr);
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
ASSERT_NE(TypeOf(expr), nullptr);
|
||||
ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
|
||||
EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
|
||||
EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
|
||||
}
|
||||
|
||||
TEST_F(ResolverTest, Expr_Binary_Multiply_Vector_Vector) {
|
||||
Global("vector", ty.vec3<f32>(), ast::StorageClass::kNone);
|
||||
|
||||
auto* expr = Mul("vector", "vector");
|
||||
WrapInFunction(expr);
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
ASSERT_NE(TypeOf(expr), nullptr);
|
||||
ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
|
||||
EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
|
||||
EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
|
||||
}
|
||||
|
||||
TEST_F(ResolverTest, Expr_Binary_Multiply_Matrix_Scalar) {
|
||||
Global("scalar", ty.f32(), ast::StorageClass::kNone);
|
||||
Global("matrix", ty.mat2x3<f32>(), ast::StorageClass::kNone);
|
||||
|
||||
auto* expr = Mul("matrix", "scalar");
|
||||
WrapInFunction(expr);
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
ASSERT_NE(TypeOf(expr), nullptr);
|
||||
ASSERT_TRUE(TypeOf(expr)->Is<type::Matrix>());
|
||||
|
||||
auto* mat = TypeOf(expr)->As<type::Matrix>();
|
||||
EXPECT_TRUE(mat->type()->Is<type::F32>());
|
||||
EXPECT_EQ(mat->rows(), 3u);
|
||||
EXPECT_EQ(mat->columns(), 2u);
|
||||
}
|
||||
|
||||
TEST_F(ResolverTest, Expr_Binary_Multiply_Scalar_Matrix) {
|
||||
Global("scalar", ty.f32(), ast::StorageClass::kNone);
|
||||
Global("matrix", ty.mat2x3<f32>(), ast::StorageClass::kNone);
|
||||
|
||||
auto* expr = Mul("scalar", "matrix");
|
||||
WrapInFunction(expr);
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
ASSERT_NE(TypeOf(expr), nullptr);
|
||||
ASSERT_TRUE(TypeOf(expr)->Is<type::Matrix>());
|
||||
|
||||
auto* mat = TypeOf(expr)->As<type::Matrix>();
|
||||
EXPECT_TRUE(mat->type()->Is<type::F32>());
|
||||
EXPECT_EQ(mat->rows(), 3u);
|
||||
EXPECT_EQ(mat->columns(), 2u);
|
||||
}
|
||||
|
||||
TEST_F(ResolverTest, Expr_Binary_Multiply_Matrix_Vector) {
|
||||
Global("vector", ty.vec3<f32>(), ast::StorageClass::kNone);
|
||||
Global("matrix", ty.mat2x3<f32>(), ast::StorageClass::kNone);
|
||||
|
||||
auto* expr = Mul("matrix", "vector");
|
||||
WrapInFunction(expr);
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
ASSERT_NE(TypeOf(expr), nullptr);
|
||||
ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
|
||||
EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
|
||||
EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 3u);
|
||||
}
|
||||
|
||||
TEST_F(ResolverTest, Expr_Binary_Multiply_Vector_Matrix) {
|
||||
Global("vector", ty.vec3<f32>(), ast::StorageClass::kNone);
|
||||
Global("matrix", ty.mat2x3<f32>(), ast::StorageClass::kNone);
|
||||
|
||||
auto* expr = Mul("vector", "matrix");
|
||||
WrapInFunction(expr);
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
ASSERT_NE(TypeOf(expr), nullptr);
|
||||
ASSERT_TRUE(TypeOf(expr)->Is<type::Vector>());
|
||||
EXPECT_TRUE(TypeOf(expr)->As<type::Vector>()->type()->Is<type::F32>());
|
||||
EXPECT_EQ(TypeOf(expr)->As<type::Vector>()->size(), 2u);
|
||||
}
|
||||
|
||||
TEST_F(ResolverTest, Expr_Binary_Multiply_Matrix_Matrix) {
|
||||
Global("mat3x4", ty.mat3x4<f32>(), ast::StorageClass::kNone);
|
||||
Global("mat4x3", ty.mat4x3<f32>(), ast::StorageClass::kNone);
|
||||
|
||||
auto* expr = Mul("mat3x4", "mat4x3");
|
||||
WrapInFunction(expr);
|
||||
|
||||
ASSERT_TRUE(r()->Resolve()) << r()->error();
|
||||
ASSERT_NE(TypeOf(expr), nullptr);
|
||||
ASSERT_TRUE(TypeOf(expr)->Is<type::Matrix>());
|
||||
|
||||
auto* mat = TypeOf(expr)->As<type::Matrix>();
|
||||
EXPECT_TRUE(mat->type()->Is<type::F32>());
|
||||
EXPECT_EQ(mat->rows(), 4u);
|
||||
EXPECT_EQ(mat->columns(), 4u);
|
||||
ASSERT_FALSE(r()->Resolve()) << r()->error();
|
||||
ASSERT_EQ(r()->error(),
|
||||
"12:34 error: Binary expression operand types are invalid for "
|
||||
"this operation");
|
||||
}
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
ResolverTest,
|
||||
Expr_Binary_Test_Invalid,
|
||||
testing::Combine(testing::ValuesIn(all_valid_cases),
|
||||
testing::ValuesIn(all_create_type_funcs)));
|
||||
} // namespace ExprBinaryTest
|
||||
|
||||
using UnaryOpExpressionTest = ResolverTestWithParam<ast::UnaryOp>;
|
||||
TEST_P(UnaryOpExpressionTest, Expr_UnaryOp) {
|
||||
|
|
|
@ -104,7 +104,7 @@ TEST_F(BoundArrayAccessorsTest, Array_Idx_Expr) {
|
|||
auto* src = R"(
|
||||
var a : array<f32, 3>;
|
||||
|
||||
var c : u32;
|
||||
var c : i32;
|
||||
|
||||
fn f() -> void {
|
||||
var b : f32 = a[c + 2 - 3];
|
||||
|
@ -114,7 +114,7 @@ fn f() -> void {
|
|||
auto* expect = R"(
|
||||
var a : array<f32, 3>;
|
||||
|
||||
var c : u32;
|
||||
var c : i32;
|
||||
|
||||
fn f() -> void {
|
||||
var b : f32 = a[min(u32(((c + 2) - 3)), 2u)];
|
||||
|
@ -196,7 +196,7 @@ TEST_F(BoundArrayAccessorsTest, Vector_Idx_Expr) {
|
|||
auto* src = R"(
|
||||
var a : vec3<f32>;
|
||||
|
||||
var c : u32;
|
||||
var c : i32;
|
||||
|
||||
fn f() -> void {
|
||||
var b : f32 = a[c + 2 - 3];
|
||||
|
@ -206,7 +206,7 @@ fn f() -> void {
|
|||
auto* expect = R"(
|
||||
var a : vec3<f32>;
|
||||
|
||||
var c : u32;
|
||||
var c : i32;
|
||||
|
||||
fn f() -> void {
|
||||
var b : f32 = a[min(u32(((c + 2) - 3)), 2u)];
|
||||
|
@ -244,7 +244,7 @@ TEST_F(BoundArrayAccessorsTest, Vector_Swizzle_Idx_Var) {
|
|||
auto* src = R"(
|
||||
var a : vec3<f32>;
|
||||
|
||||
var c : u32;
|
||||
var c : i32;
|
||||
|
||||
fn f() -> void {
|
||||
var b : f32 = a.xy[c];
|
||||
|
@ -254,7 +254,7 @@ fn f() -> void {
|
|||
auto* expect = R"(
|
||||
var a : vec3<f32>;
|
||||
|
||||
var c : u32;
|
||||
var c : i32;
|
||||
|
||||
fn f() -> void {
|
||||
var b : f32 = a.xy[min(u32(c), 1u)];
|
||||
|
@ -269,7 +269,7 @@ TEST_F(BoundArrayAccessorsTest, Vector_Swizzle_Idx_Expr) {
|
|||
auto* src = R"(
|
||||
var a : vec3<f32>;
|
||||
|
||||
var c : u32;
|
||||
var c : i32;
|
||||
|
||||
fn f() -> void {
|
||||
var b : f32 = a.xy[c + 2 - 3];
|
||||
|
@ -279,7 +279,7 @@ fn f() -> void {
|
|||
auto* expect = R"(
|
||||
var a : vec3<f32>;
|
||||
|
||||
var c : u32;
|
||||
var c : i32;
|
||||
|
||||
fn f() -> void {
|
||||
var b : f32 = a.xy[min(u32(((c + 2) - 3)), 1u)];
|
||||
|
@ -361,7 +361,7 @@ TEST_F(BoundArrayAccessorsTest, Matrix_Idx_Expr_Column) {
|
|||
auto* src = R"(
|
||||
var a : mat3x2<f32>;
|
||||
|
||||
var c : u32;
|
||||
var c : i32;
|
||||
|
||||
fn f() -> void {
|
||||
var b : f32 = a[c + 2 - 3][1];
|
||||
|
@ -371,7 +371,7 @@ fn f() -> void {
|
|||
auto* expect = R"(
|
||||
var a : mat3x2<f32>;
|
||||
|
||||
var c : u32;
|
||||
var c : i32;
|
||||
|
||||
fn f() -> void {
|
||||
var b : f32 = a[min(u32(((c + 2) - 3)), 2u)][1];
|
||||
|
@ -387,7 +387,7 @@ TEST_F(BoundArrayAccessorsTest, Matrix_Idx_Expr_Row) {
|
|||
auto* src = R"(
|
||||
var a : mat3x2<f32>;
|
||||
|
||||
var c : u32;
|
||||
var c : i32;
|
||||
|
||||
fn f() -> void {
|
||||
var b : f32 = a[1][c + 2 - 3];
|
||||
|
@ -397,7 +397,7 @@ fn f() -> void {
|
|||
auto* expect = R"(
|
||||
var a : mat3x2<f32>;
|
||||
|
||||
var c : u32;
|
||||
var c : i32;
|
||||
|
||||
fn f() -> void {
|
||||
var b : f32 = a[1][min(u32(((c + 2) - 3)), 1u)];
|
||||
|
|
|
@ -132,7 +132,7 @@ void VertexPulling::State::FindOrInsertVertexIndexIfUsed() {
|
|||
Source{}, // source
|
||||
ctx.dst->Symbols().Register(vertex_index_name), // symbol
|
||||
ast::StorageClass::kInput, // storage_class
|
||||
GetI32Type(), // type
|
||||
GetU32Type(), // type
|
||||
false, // is_const
|
||||
nullptr, // constructor
|
||||
ast::DecorationList{
|
||||
|
@ -179,7 +179,7 @@ void VertexPulling::State::FindOrInsertInstanceIndexIfUsed() {
|
|||
Source{}, // source
|
||||
ctx.dst->Symbols().Register(instance_index_name), // symbol
|
||||
ast::StorageClass::kInput, // storage_class
|
||||
GetI32Type(), // type
|
||||
GetU32Type(), // type
|
||||
false, // is_const
|
||||
nullptr, // constructor
|
||||
ast::DecorationList{
|
||||
|
@ -273,7 +273,7 @@ ast::BlockStatement* VertexPulling::State::CreateVertexPullingPreamble() const {
|
|||
Source{}, // source
|
||||
ctx.dst->Symbols().Register(kPullingPosVarName), // symbol
|
||||
ast::StorageClass::kFunction, // storage_class
|
||||
GetI32Type(), // type
|
||||
GetU32Type(), // type
|
||||
false, // is_const
|
||||
nullptr, // constructor
|
||||
ast::DecorationList{})); // decorations
|
||||
|
|
|
@ -89,7 +89,7 @@ struct TintVertexData {
|
|||
[[stage(vertex)]]
|
||||
fn main() -> void {
|
||||
{
|
||||
var _tint_pulling_pos : i32;
|
||||
var _tint_pulling_pos : u32;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
@ -113,7 +113,7 @@ fn main() -> void {}
|
|||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : i32;
|
||||
[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : u32;
|
||||
|
||||
[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
|
||||
|
||||
|
@ -127,7 +127,7 @@ var<private> var_a : f32;
|
|||
[[stage(vertex)]]
|
||||
fn main() -> void {
|
||||
{
|
||||
var _tint_pulling_pos : i32;
|
||||
var _tint_pulling_pos : u32;
|
||||
_tint_pulling_pos = ((_tint_pulling_vertex_index * 4u) + 0u);
|
||||
var_a = bitcast<f32>(_tint_pulling_vertex_buffer_0._tint_vertex_data[(_tint_pulling_pos / 4u)]);
|
||||
}
|
||||
|
@ -155,7 +155,7 @@ fn main() -> void {}
|
|||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
[[builtin(instance_index)]] var<in> _tint_pulling_instance_index : i32;
|
||||
[[builtin(instance_index)]] var<in> _tint_pulling_instance_index : u32;
|
||||
|
||||
[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
|
||||
|
||||
|
@ -169,7 +169,7 @@ var<private> var_a : f32;
|
|||
[[stage(vertex)]]
|
||||
fn main() -> void {
|
||||
{
|
||||
var _tint_pulling_pos : i32;
|
||||
var _tint_pulling_pos : u32;
|
||||
_tint_pulling_pos = ((_tint_pulling_instance_index * 4u) + 0u);
|
||||
var_a = bitcast<f32>(_tint_pulling_vertex_buffer_0._tint_vertex_data[(_tint_pulling_pos / 4u)]);
|
||||
}
|
||||
|
@ -197,7 +197,7 @@ fn main() -> void {}
|
|||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : i32;
|
||||
[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : u32;
|
||||
|
||||
[[binding(0), group(5)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
|
||||
|
||||
|
@ -211,7 +211,7 @@ var<private> var_a : f32;
|
|||
[[stage(vertex)]]
|
||||
fn main() -> void {
|
||||
{
|
||||
var _tint_pulling_pos : i32;
|
||||
var _tint_pulling_pos : u32;
|
||||
_tint_pulling_pos = ((_tint_pulling_vertex_index * 4u) + 0u);
|
||||
var_a = bitcast<f32>(_tint_pulling_vertex_buffer_0._tint_vertex_data[(_tint_pulling_pos / 4u)]);
|
||||
}
|
||||
|
@ -236,8 +236,8 @@ TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) {
|
|||
auto* src = R"(
|
||||
[[location(0)]] var<in> var_a : f32;
|
||||
[[location(1)]] var<in> var_b : f32;
|
||||
[[builtin(vertex_index)]] var<in> custom_vertex_index : i32;
|
||||
[[builtin(instance_index)]] var<in> custom_instance_index : i32;
|
||||
[[builtin(vertex_index)]] var<in> custom_vertex_index : u32;
|
||||
[[builtin(instance_index)]] var<in> custom_instance_index : u32;
|
||||
|
||||
[[stage(vertex)]]
|
||||
fn main() -> void {}
|
||||
|
@ -257,14 +257,14 @@ var<private> var_a : f32;
|
|||
|
||||
var<private> var_b : f32;
|
||||
|
||||
[[builtin(vertex_index)]] var<in> custom_vertex_index : i32;
|
||||
[[builtin(vertex_index)]] var<in> custom_vertex_index : u32;
|
||||
|
||||
[[builtin(instance_index)]] var<in> custom_instance_index : i32;
|
||||
[[builtin(instance_index)]] var<in> custom_instance_index : u32;
|
||||
|
||||
[[stage(vertex)]]
|
||||
fn main() -> void {
|
||||
{
|
||||
var _tint_pulling_pos : i32;
|
||||
var _tint_pulling_pos : u32;
|
||||
_tint_pulling_pos = ((custom_vertex_index * 4u) + 0u);
|
||||
var_a = bitcast<f32>(_tint_pulling_vertex_buffer_0._tint_vertex_data[(_tint_pulling_pos / 4u)]);
|
||||
_tint_pulling_pos = ((custom_instance_index * 4u) + 0u);
|
||||
|
@ -305,7 +305,7 @@ fn main() -> void {}
|
|||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : i32;
|
||||
[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : u32;
|
||||
|
||||
[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
|
||||
|
||||
|
@ -321,7 +321,7 @@ var<private> var_b : array<f32, 4>;
|
|||
[[stage(vertex)]]
|
||||
fn main() -> void {
|
||||
{
|
||||
var _tint_pulling_pos : i32;
|
||||
var _tint_pulling_pos : u32;
|
||||
_tint_pulling_pos = ((_tint_pulling_vertex_index * 16u) + 0u);
|
||||
var_a = bitcast<f32>(_tint_pulling_vertex_buffer_0._tint_vertex_data[(_tint_pulling_pos / 4u)]);
|
||||
_tint_pulling_pos = ((_tint_pulling_vertex_index * 16u) + 0u);
|
||||
|
@ -355,7 +355,7 @@ fn main() -> void {}
|
|||
)";
|
||||
|
||||
auto* expect = R"(
|
||||
[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : i32;
|
||||
[[builtin(vertex_index)]] var<in> _tint_pulling_vertex_index : u32;
|
||||
|
||||
[[binding(0), group(4)]] var<storage> _tint_pulling_vertex_buffer_0 : TintVertexData;
|
||||
|
||||
|
@ -377,7 +377,7 @@ var<private> var_c : array<f32, 4>;
|
|||
[[stage(vertex)]]
|
||||
fn main() -> void {
|
||||
{
|
||||
var _tint_pulling_pos : i32;
|
||||
var _tint_pulling_pos : u32;
|
||||
_tint_pulling_pos = ((_tint_pulling_vertex_index * 8u) + 0u);
|
||||
var_a = vec2<f32>(bitcast<f32>(_tint_pulling_vertex_buffer_0._tint_vertex_data[((_tint_pulling_pos + 0u) / 4u)]), bitcast<f32>(_tint_pulling_vertex_buffer_0._tint_vertex_data[((_tint_pulling_pos + 4u) / 4u)]));
|
||||
_tint_pulling_pos = ((_tint_pulling_vertex_index * 12u) + 0u);
|
||||
|
|
|
@ -92,6 +92,10 @@ bool Type::is_float_scalar_or_vector() const {
|
|||
return is_float_scalar() || is_float_vector();
|
||||
}
|
||||
|
||||
bool Type::is_float_scalar_or_vector_or_matrix() const {
|
||||
return is_float_scalar() || is_float_vector() || is_float_matrix();
|
||||
}
|
||||
|
||||
bool Type::is_integer_scalar() const {
|
||||
return IsAnyOf<U32, I32>();
|
||||
}
|
||||
|
|
|
@ -77,6 +77,8 @@ class Type : public Castable<Type, Cloneable> {
|
|||
bool is_float_vector() const;
|
||||
/// @returns true if this type is a float scalar or vector
|
||||
bool is_float_scalar_or_vector() const;
|
||||
/// @returns true if this type is a float scalar or vector or matrix
|
||||
bool is_float_scalar_or_vector_or_matrix() const;
|
||||
/// @returns true if this type is an integer scalar
|
||||
bool is_integer_scalar() const;
|
||||
/// @returns true if this type is a signed integer vector
|
||||
|
|
|
@ -36,6 +36,14 @@ using HlslBinaryTest = TestParamHelper<BinaryData>;
|
|||
TEST_P(HlslBinaryTest, Emit_f32) {
|
||||
auto params = GetParam();
|
||||
|
||||
// Skip ops that are illegal for this type
|
||||
if (params.op == ast::BinaryOp::kAnd || params.op == ast::BinaryOp::kOr ||
|
||||
params.op == ast::BinaryOp::kXor ||
|
||||
params.op == ast::BinaryOp::kShiftLeft ||
|
||||
params.op == ast::BinaryOp::kShiftRight) {
|
||||
return;
|
||||
}
|
||||
|
||||
Global("left", ty.f32(), ast::StorageClass::kFunction);
|
||||
Global("right", ty.f32(), ast::StorageClass::kFunction);
|
||||
|
||||
|
@ -72,6 +80,12 @@ TEST_P(HlslBinaryTest, Emit_u32) {
|
|||
TEST_P(HlslBinaryTest, Emit_i32) {
|
||||
auto params = GetParam();
|
||||
|
||||
// Skip ops that are illegal for this type
|
||||
if (params.op == ast::BinaryOp::kShiftLeft ||
|
||||
params.op == ast::BinaryOp::kShiftRight) {
|
||||
return;
|
||||
}
|
||||
|
||||
Global("left", ty.i32(), ast::StorageClass::kFunction);
|
||||
Global("right", ty.i32(), ast::StorageClass::kFunction);
|
||||
|
||||
|
|
|
@ -58,6 +58,12 @@ TEST_P(BinaryArithSignedIntegerTest, Scalar) {
|
|||
TEST_P(BinaryArithSignedIntegerTest, Vector) {
|
||||
auto param = GetParam();
|
||||
|
||||
// Skip ops that are illegal for this type
|
||||
if (param.op == ast::BinaryOp::kAnd || param.op == ast::BinaryOp::kOr ||
|
||||
param.op == ast::BinaryOp::kXor) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto* lhs = vec3<i32>(1, 1, 1);
|
||||
auto* rhs = vec3<i32>(1, 1, 1);
|
||||
|
||||
|
@ -111,15 +117,13 @@ TEST_P(BinaryArithSignedIntegerTest, Scalar_Loads) {
|
|||
INSTANTIATE_TEST_SUITE_P(
|
||||
BuilderTest,
|
||||
BinaryArithSignedIntegerTest,
|
||||
// NOTE: No left and right shift as they require u32 for rhs operand
|
||||
testing::Values(BinaryData{ast::BinaryOp::kAdd, "OpIAdd"},
|
||||
BinaryData{ast::BinaryOp::kAnd, "OpBitwiseAnd"},
|
||||
BinaryData{ast::BinaryOp::kDivide, "OpSDiv"},
|
||||
BinaryData{ast::BinaryOp::kModulo, "OpSMod"},
|
||||
BinaryData{ast::BinaryOp::kMultiply, "OpIMul"},
|
||||
BinaryData{ast::BinaryOp::kOr, "OpBitwiseOr"},
|
||||
BinaryData{ast::BinaryOp::kShiftLeft, "OpShiftLeftLogical"},
|
||||
BinaryData{ast::BinaryOp::kShiftRight,
|
||||
"OpShiftRightArithmetic"},
|
||||
BinaryData{ast::BinaryOp::kSubtract, "OpISub"},
|
||||
BinaryData{ast::BinaryOp::kXor, "OpBitwiseXor"}));
|
||||
|
||||
|
@ -149,6 +153,12 @@ TEST_P(BinaryArithUnsignedIntegerTest, Scalar) {
|
|||
TEST_P(BinaryArithUnsignedIntegerTest, Vector) {
|
||||
auto param = GetParam();
|
||||
|
||||
// Skip ops that are illegal for this type
|
||||
if (param.op == ast::BinaryOp::kAnd || param.op == ast::BinaryOp::kOr ||
|
||||
param.op == ast::BinaryOp::kXor) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto* lhs = vec3<u32>(1u, 1u, 1u);
|
||||
auto* rhs = vec3<u32>(1u, 1u, 1u);
|
||||
|
||||
|
|
Loading…
Reference in New Issue