Move type methods to type class

This CL moves the checks for different types into the type class so it
can be used in both the type determinater and the SPIR-V builder.

Change-Id: I9142adaf5fc1d6048792645d7892f8d8900fcf59
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/19921
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair 2020-04-20 14:07:43 +00:00 committed by dan sinclair
parent 6866cb7677
commit c954788b59
3 changed files with 55 additions and 42 deletions

View File

@ -80,6 +80,34 @@ bool Type::IsVoid() const {
return false;
}
bool Type::is_float_scalar() {
return IsF32();
}
bool Type::is_float_matrix() {
return IsMatrix() && AsMatrix()->type()->is_float_scalar();
}
bool Type::is_float_vector() {
return IsVector() && AsVector()->type()->is_float_scalar();
}
bool Type::is_float_scalar_or_vector() {
return is_float_scalar() || is_float_vector();
}
bool Type::is_unsigned_scalar_or_vector() {
return IsU32() || (IsVector() && AsVector()->type()->IsU32());
}
bool Type::is_signed_scalar_or_vector() {
return IsI32() || (IsVector() && AsVector()->type()->IsI32());
}
bool Type::is_integer_scalar_or_vector() {
return is_unsigned_scalar_or_vector() || is_signed_scalar_or_vector();
}
AliasType* Type::AsAlias() {
assert(IsAlias());
return static_cast<AliasType*>(this);

View File

@ -66,6 +66,21 @@ class Type {
/// @returns the name for this type. The |type_name| is unique over all types.
virtual std::string type_name() const = 0;
/// @returns true if this type is a float scalar
bool is_float_scalar();
/// @returns true if this type is a float matrix
bool is_float_matrix();
/// @returns true if this type is a float vector
bool is_float_vector();
/// @returns true if this type is a float scalar or vector
bool is_float_scalar_or_vector();
/// @returns true if this type is an unsigned scalar or vector
bool is_unsigned_scalar_or_vector();
/// @returns true if this type is a signed scalar or vector
bool is_signed_scalar_or_vector();
/// @returns true if this type is an integer scalar or vector
bool is_integer_scalar_or_vector();
/// @returns the type as an alias type
AliasType* AsAlias();
/// @returns the type as an array type

View File

@ -80,36 +80,6 @@ uint32_t pipeline_stage_to_execution_model(ast::PipelineStage stage) {
return model;
}
bool is_float_scalar(ast::type::Type* type) {
return type->IsF32();
}
bool is_float_matrix(ast::type::Type* type) {
return type->IsMatrix() && type->AsMatrix()->type()->IsF32();
}
bool is_float_vector(ast::type::Type* type) {
return type->IsVector() && type->AsVector()->type()->IsF32();
}
bool is_float_scalar_or_vector(ast::type::Type* type) {
return is_float_scalar(type) || is_float_vector(type);
}
bool is_unsigned_scalar_or_vector(ast::type::Type* type) {
return type->IsU32() ||
(type->IsVector() && type->AsVector()->type()->IsU32());
}
bool is_signed_scalar_or_vector(ast::type::Type* type) {
return type->IsI32() ||
(type->IsVector() && type->AsVector()->type()->IsI32());
}
bool is_integer_scalar_or_vector(ast::type::Type* type) {
return is_unsigned_scalar_or_vector(type) || is_signed_scalar_or_vector(type);
}
} // namespace
Builder::Builder() : scope_stack_({}) {}
@ -601,8 +571,8 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
// should have been rejected by validation.
auto* lhs_type = expr->lhs()->result_type();
auto* rhs_type = expr->rhs()->result_type();
bool lhs_is_float_or_vec = is_float_scalar_or_vector(lhs_type);
bool lhs_is_unsigned = is_unsigned_scalar_or_vector(lhs_type);
bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector();
bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector();
spv::Op op = spv::Op::OpNop;
if (expr->IsAnd()) {
@ -660,39 +630,39 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
op = spv::Op::OpSMod;
}
} else if (expr->IsMultiply()) {
if (is_integer_scalar_or_vector(lhs_type)) {
if (lhs_type->is_integer_scalar_or_vector()) {
// If the left hand side is an integer then this _has_ to be OpIMul as
// there there is no other integer multiplication.
op = spv::Op::OpIMul;
} else if (is_float_scalar(lhs_type) && is_float_scalar(rhs_type)) {
} else if (lhs_type->is_float_scalar() && rhs_type->is_float_scalar()) {
// Float scalars multiply with OpFMul
op = spv::Op::OpFMul;
} else if (is_float_vector(lhs_type) && is_float_vector(rhs_type)) {
} else if (lhs_type->is_float_vector() && rhs_type->is_float_vector()) {
// Float vectors must be validated to be the same size and then use OpFMul
op = spv::Op::OpFMul;
} else if (is_float_scalar(lhs_type) && is_float_vector(rhs_type)) {
} else if (lhs_type->is_float_scalar() && rhs_type->is_float_vector()) {
// Scalar * Vector we need to flip lhs and rhs types
// because OpVectorTimesScalar expects <vector>, <scalar>
std::swap(lhs_id, rhs_id);
op = spv::Op::OpVectorTimesScalar;
} else if (is_float_vector(lhs_type) && is_float_scalar(rhs_type)) {
} else if (lhs_type->is_float_vector() && rhs_type->is_float_scalar()) {
// float vector * scalar
op = spv::Op::OpVectorTimesScalar;
} else if (is_float_scalar(lhs_type) && is_float_matrix(rhs_type)) {
} else if (lhs_type->is_float_scalar() && rhs_type->is_float_matrix()) {
// Scalar * Matrix we need to flip lhs and rhs types because
// OpMatrixTimesScalar expects <matrix>, <scalar>
std::swap(lhs_id, rhs_id);
op = spv::Op::OpMatrixTimesScalar;
} else if (is_float_matrix(lhs_type) && is_float_scalar(rhs_type)) {
} else if (lhs_type->is_float_matrix() && rhs_type->is_float_scalar()) {
// float matrix * scalar
op = spv::Op::OpMatrixTimesScalar;
} else if (is_float_vector(lhs_type) && is_float_matrix(rhs_type)) {
} else if (lhs_type->is_float_vector() && rhs_type->is_float_matrix()) {
// float vector * matrix
op = spv::Op::OpVectorTimesMatrix;
} else if (is_float_matrix(lhs_type) && is_float_vector(rhs_type)) {
} else if (lhs_type->is_float_matrix() && rhs_type->is_float_vector()) {
// float matrix * vector
op = spv::Op::OpMatrixTimesVector;
} else if (is_float_matrix(lhs_type) && is_float_matrix(rhs_type)) {
} else if (lhs_type->is_float_matrix() && rhs_type->is_float_matrix()) {
// float matrix * matrix
op = spv::Op::OpMatrixTimesMatrix;
} else {