diff --git a/src/ast/type/type.cc b/src/ast/type/type.cc index fc5ce2dc0d..17ee534094 100644 --- a/src/ast/type/type.cc +++ b/src/ast/type/type.cc @@ -80,6 +80,34 @@ bool Type::IsVoid() const { return false; } +bool Type::is_float_scalar() { + return IsF32(); +} + +bool Type::is_float_matrix() { + return IsMatrix() && AsMatrix()->type()->is_float_scalar(); +} + +bool Type::is_float_vector() { + return IsVector() && AsVector()->type()->is_float_scalar(); +} + +bool Type::is_float_scalar_or_vector() { + return is_float_scalar() || is_float_vector(); +} + +bool Type::is_unsigned_scalar_or_vector() { + return IsU32() || (IsVector() && AsVector()->type()->IsU32()); +} + +bool Type::is_signed_scalar_or_vector() { + return IsI32() || (IsVector() && AsVector()->type()->IsI32()); +} + +bool Type::is_integer_scalar_or_vector() { + return is_unsigned_scalar_or_vector() || is_signed_scalar_or_vector(); +} + AliasType* Type::AsAlias() { assert(IsAlias()); return static_cast(this); diff --git a/src/ast/type/type.h b/src/ast/type/type.h index 05e29ddf64..4fa44ab1d9 100644 --- a/src/ast/type/type.h +++ b/src/ast/type/type.h @@ -66,6 +66,21 @@ class Type { /// @returns the name for this type. The |type_name| is unique over all types. virtual std::string type_name() const = 0; + /// @returns true if this type is a float scalar + bool is_float_scalar(); + /// @returns true if this type is a float matrix + bool is_float_matrix(); + /// @returns true if this type is a float vector + bool is_float_vector(); + /// @returns true if this type is a float scalar or vector + bool is_float_scalar_or_vector(); + /// @returns true if this type is an unsigned scalar or vector + bool is_unsigned_scalar_or_vector(); + /// @returns true if this type is a signed scalar or vector + bool is_signed_scalar_or_vector(); + /// @returns true if this type is an integer scalar or vector + bool is_integer_scalar_or_vector(); + /// @returns the type as an alias type AliasType* AsAlias(); /// @returns the type as an array type diff --git a/src/writer/spirv/builder.cc b/src/writer/spirv/builder.cc index 0e1a76f92a..ed435dc3ff 100644 --- a/src/writer/spirv/builder.cc +++ b/src/writer/spirv/builder.cc @@ -80,36 +80,6 @@ uint32_t pipeline_stage_to_execution_model(ast::PipelineStage stage) { return model; } -bool is_float_scalar(ast::type::Type* type) { - return type->IsF32(); -} - -bool is_float_matrix(ast::type::Type* type) { - return type->IsMatrix() && type->AsMatrix()->type()->IsF32(); -} - -bool is_float_vector(ast::type::Type* type) { - return type->IsVector() && type->AsVector()->type()->IsF32(); -} - -bool is_float_scalar_or_vector(ast::type::Type* type) { - return is_float_scalar(type) || is_float_vector(type); -} - -bool is_unsigned_scalar_or_vector(ast::type::Type* type) { - return type->IsU32() || - (type->IsVector() && type->AsVector()->type()->IsU32()); -} - -bool is_signed_scalar_or_vector(ast::type::Type* type) { - return type->IsI32() || - (type->IsVector() && type->AsVector()->type()->IsI32()); -} - -bool is_integer_scalar_or_vector(ast::type::Type* type) { - return is_unsigned_scalar_or_vector(type) || is_signed_scalar_or_vector(type); -} - } // namespace Builder::Builder() : scope_stack_({}) {} @@ -601,8 +571,8 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { // should have been rejected by validation. auto* lhs_type = expr->lhs()->result_type(); auto* rhs_type = expr->rhs()->result_type(); - bool lhs_is_float_or_vec = is_float_scalar_or_vector(lhs_type); - bool lhs_is_unsigned = is_unsigned_scalar_or_vector(lhs_type); + bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector(); + bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector(); spv::Op op = spv::Op::OpNop; if (expr->IsAnd()) { @@ -660,39 +630,39 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) { op = spv::Op::OpSMod; } } else if (expr->IsMultiply()) { - if (is_integer_scalar_or_vector(lhs_type)) { + if (lhs_type->is_integer_scalar_or_vector()) { // If the left hand side is an integer then this _has_ to be OpIMul as // there there is no other integer multiplication. op = spv::Op::OpIMul; - } else if (is_float_scalar(lhs_type) && is_float_scalar(rhs_type)) { + } else if (lhs_type->is_float_scalar() && rhs_type->is_float_scalar()) { // Float scalars multiply with OpFMul op = spv::Op::OpFMul; - } else if (is_float_vector(lhs_type) && is_float_vector(rhs_type)) { + } else if (lhs_type->is_float_vector() && rhs_type->is_float_vector()) { // Float vectors must be validated to be the same size and then use OpFMul op = spv::Op::OpFMul; - } else if (is_float_scalar(lhs_type) && is_float_vector(rhs_type)) { + } else if (lhs_type->is_float_scalar() && rhs_type->is_float_vector()) { // Scalar * Vector we need to flip lhs and rhs types // because OpVectorTimesScalar expects , std::swap(lhs_id, rhs_id); op = spv::Op::OpVectorTimesScalar; - } else if (is_float_vector(lhs_type) && is_float_scalar(rhs_type)) { + } else if (lhs_type->is_float_vector() && rhs_type->is_float_scalar()) { // float vector * scalar op = spv::Op::OpVectorTimesScalar; - } else if (is_float_scalar(lhs_type) && is_float_matrix(rhs_type)) { + } else if (lhs_type->is_float_scalar() && rhs_type->is_float_matrix()) { // Scalar * Matrix we need to flip lhs and rhs types because // OpMatrixTimesScalar expects , std::swap(lhs_id, rhs_id); op = spv::Op::OpMatrixTimesScalar; - } else if (is_float_matrix(lhs_type) && is_float_scalar(rhs_type)) { + } else if (lhs_type->is_float_matrix() && rhs_type->is_float_scalar()) { // float matrix * scalar op = spv::Op::OpMatrixTimesScalar; - } else if (is_float_vector(lhs_type) && is_float_matrix(rhs_type)) { + } else if (lhs_type->is_float_vector() && rhs_type->is_float_matrix()) { // float vector * matrix op = spv::Op::OpVectorTimesMatrix; - } else if (is_float_matrix(lhs_type) && is_float_vector(rhs_type)) { + } else if (lhs_type->is_float_matrix() && rhs_type->is_float_vector()) { // float matrix * vector op = spv::Op::OpMatrixTimesVector; - } else if (is_float_matrix(lhs_type) && is_float_matrix(rhs_type)) { + } else if (lhs_type->is_float_matrix() && rhs_type->is_float_matrix()) { // float matrix * matrix op = spv::Op::OpMatrixTimesMatrix; } else {