Move type methods to type class

This CL moves the checks for different types into the type class so it
can be used in both the type determinater and the SPIR-V builder.

Change-Id: I9142adaf5fc1d6048792645d7892f8d8900fcf59
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/19921
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair
2020-04-20 14:07:43 +00:00
committed by dan sinclair
parent 6866cb7677
commit c954788b59
3 changed files with 55 additions and 42 deletions

View File

@@ -80,36 +80,6 @@ uint32_t pipeline_stage_to_execution_model(ast::PipelineStage stage) {
return model;
}
bool is_float_scalar(ast::type::Type* type) {
return type->IsF32();
}
bool is_float_matrix(ast::type::Type* type) {
return type->IsMatrix() && type->AsMatrix()->type()->IsF32();
}
bool is_float_vector(ast::type::Type* type) {
return type->IsVector() && type->AsVector()->type()->IsF32();
}
bool is_float_scalar_or_vector(ast::type::Type* type) {
return is_float_scalar(type) || is_float_vector(type);
}
bool is_unsigned_scalar_or_vector(ast::type::Type* type) {
return type->IsU32() ||
(type->IsVector() && type->AsVector()->type()->IsU32());
}
bool is_signed_scalar_or_vector(ast::type::Type* type) {
return type->IsI32() ||
(type->IsVector() && type->AsVector()->type()->IsI32());
}
bool is_integer_scalar_or_vector(ast::type::Type* type) {
return is_unsigned_scalar_or_vector(type) || is_signed_scalar_or_vector(type);
}
} // namespace
Builder::Builder() : scope_stack_({}) {}
@@ -601,8 +571,8 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
// should have been rejected by validation.
auto* lhs_type = expr->lhs()->result_type();
auto* rhs_type = expr->rhs()->result_type();
bool lhs_is_float_or_vec = is_float_scalar_or_vector(lhs_type);
bool lhs_is_unsigned = is_unsigned_scalar_or_vector(lhs_type);
bool lhs_is_float_or_vec = lhs_type->is_float_scalar_or_vector();
bool lhs_is_unsigned = lhs_type->is_unsigned_scalar_or_vector();
spv::Op op = spv::Op::OpNop;
if (expr->IsAnd()) {
@@ -660,39 +630,39 @@ uint32_t Builder::GenerateBinaryExpression(ast::BinaryExpression* expr) {
op = spv::Op::OpSMod;
}
} else if (expr->IsMultiply()) {
if (is_integer_scalar_or_vector(lhs_type)) {
if (lhs_type->is_integer_scalar_or_vector()) {
// If the left hand side is an integer then this _has_ to be OpIMul as
// there there is no other integer multiplication.
op = spv::Op::OpIMul;
} else if (is_float_scalar(lhs_type) && is_float_scalar(rhs_type)) {
} else if (lhs_type->is_float_scalar() && rhs_type->is_float_scalar()) {
// Float scalars multiply with OpFMul
op = spv::Op::OpFMul;
} else if (is_float_vector(lhs_type) && is_float_vector(rhs_type)) {
} else if (lhs_type->is_float_vector() && rhs_type->is_float_vector()) {
// Float vectors must be validated to be the same size and then use OpFMul
op = spv::Op::OpFMul;
} else if (is_float_scalar(lhs_type) && is_float_vector(rhs_type)) {
} else if (lhs_type->is_float_scalar() && rhs_type->is_float_vector()) {
// Scalar * Vector we need to flip lhs and rhs types
// because OpVectorTimesScalar expects <vector>, <scalar>
std::swap(lhs_id, rhs_id);
op = spv::Op::OpVectorTimesScalar;
} else if (is_float_vector(lhs_type) && is_float_scalar(rhs_type)) {
} else if (lhs_type->is_float_vector() && rhs_type->is_float_scalar()) {
// float vector * scalar
op = spv::Op::OpVectorTimesScalar;
} else if (is_float_scalar(lhs_type) && is_float_matrix(rhs_type)) {
} else if (lhs_type->is_float_scalar() && rhs_type->is_float_matrix()) {
// Scalar * Matrix we need to flip lhs and rhs types because
// OpMatrixTimesScalar expects <matrix>, <scalar>
std::swap(lhs_id, rhs_id);
op = spv::Op::OpMatrixTimesScalar;
} else if (is_float_matrix(lhs_type) && is_float_scalar(rhs_type)) {
} else if (lhs_type->is_float_matrix() && rhs_type->is_float_scalar()) {
// float matrix * scalar
op = spv::Op::OpMatrixTimesScalar;
} else if (is_float_vector(lhs_type) && is_float_matrix(rhs_type)) {
} else if (lhs_type->is_float_vector() && rhs_type->is_float_matrix()) {
// float vector * matrix
op = spv::Op::OpVectorTimesMatrix;
} else if (is_float_matrix(lhs_type) && is_float_vector(rhs_type)) {
} else if (lhs_type->is_float_matrix() && rhs_type->is_float_vector()) {
// float matrix * vector
op = spv::Op::OpMatrixTimesVector;
} else if (is_float_matrix(lhs_type) && is_float_matrix(rhs_type)) {
} else if (lhs_type->is_float_matrix() && rhs_type->is_float_matrix()) {
// float matrix * matrix
op = spv::Op::OpMatrixTimesMatrix;
} else {