Refactor GLSL type determination code.

This Cl cleanups and simplifies the type determination for the GLSL
imports.

Change-Id: I9dd85ac390ef37c91d9493f840f81ceb6736fc06
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/22820
Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
dan sinclair 2020-06-08 18:49:31 +00:00 committed by dan sinclair
parent b0d308c9fe
commit 53352044df
2 changed files with 129 additions and 196 deletions

View File

@ -617,16 +617,6 @@ bool TypeDeterminer::DetermineUnaryOp(ast::UnaryOpExpression* expr) {
return true;
}
ast::type::Type* TypeDeterminer::GetImportData(
const Source& source,
const std::string& path,
const std::string& name,
const ast::ExpressionList& params,
uint32_t* id) {
if (path != "GLSL.std.450") {
return nullptr;
}
// Most of these are floating-point general except the below which are only
// FP16 and FP32. We only have FP32 at this point so the below works, if we
// get FP64 support or otherwise we'll need to differentiate.
@ -638,186 +628,128 @@ ast::type::Type* TypeDeterminer::GetImportData(
// * asinh, acosh, atanh
// * exp, exp2
// * log, log2
enum class GlslDataType { kFloatScalarOrVector, kIntScalarOrVector };
struct GlslData {
const char* name;
uint8_t param_count;
uint32_t op_id;
GlslDataType type;
};
if (name == "round" || name == "roundeven" || name == "trunc" ||
name == "fabs" || name == "fsign" || name == "floor" || name == "ceil" ||
name == "fract" || name == "radians" || name == "degrees" ||
name == "sin" || name == "cos" || name == "tan" || name == "asin" ||
name == "acos" || name == "atan" || name == "sinh" || name == "cosh" ||
name == "tanh" || name == "asinh" || name == "acosh" || name == "atanh" ||
name == "exp" || name == "log" || name == "exp2" || name == "log2" ||
name == "sqrt" || name == "inversesqrt" || name == "normalize" ||
name == "length") {
if (params.size() != 1) {
constexpr const GlslData kGlslData[] = {
{"acos", 1, GLSLstd450Acos, GlslDataType::kFloatScalarOrVector},
{"acosh", 1, GLSLstd450Acosh, GlslDataType::kFloatScalarOrVector},
{"asin", 1, GLSLstd450Asin, GlslDataType::kFloatScalarOrVector},
{"asinh", 1, GLSLstd450Asinh, GlslDataType::kFloatScalarOrVector},
{"atan", 1, GLSLstd450Atan, GlslDataType::kFloatScalarOrVector},
{"atan2", 2, GLSLstd450Atan2, GlslDataType::kFloatScalarOrVector},
{"atanh", 1, GLSLstd450Atanh, GlslDataType::kFloatScalarOrVector},
{"ceil", 1, GLSLstd450Ceil, GlslDataType::kFloatScalarOrVector},
{"cos", 1, GLSLstd450Cos, GlslDataType::kFloatScalarOrVector},
{"cosh", 1, GLSLstd450Cosh, GlslDataType::kFloatScalarOrVector},
{"degrees", 1, GLSLstd450Degrees, GlslDataType::kFloatScalarOrVector},
{"distance", 2, GLSLstd450Distance, GlslDataType::kFloatScalarOrVector},
{"exp", 1, GLSLstd450Exp, GlslDataType::kFloatScalarOrVector},
{"exp2", 1, GLSLstd450Exp2, GlslDataType::kFloatScalarOrVector},
{"fabs", 1, GLSLstd450FAbs, GlslDataType::kFloatScalarOrVector},
{"faceforward", 3, GLSLstd450FaceForward,
GlslDataType::kFloatScalarOrVector},
{"fclamp", 3, GLSLstd450FClamp, GlslDataType::kFloatScalarOrVector},
{"floor", 1, GLSLstd450Floor, GlslDataType::kFloatScalarOrVector},
{"fma", 3, GLSLstd450Fma, GlslDataType::kFloatScalarOrVector},
{"fmax", 2, GLSLstd450FMax, GlslDataType::kFloatScalarOrVector},
{"fmin", 2, GLSLstd450FMin, GlslDataType::kFloatScalarOrVector},
{"fmix", 3, GLSLstd450FMix, GlslDataType::kFloatScalarOrVector},
{"fract", 1, GLSLstd450Fract, GlslDataType::kFloatScalarOrVector},
{"fsign", 1, GLSLstd450FSign, GlslDataType::kFloatScalarOrVector},
{"inversesqrt", 1, GLSLstd450InverseSqrt,
GlslDataType::kFloatScalarOrVector},
{"length", 1, GLSLstd450Length, GlslDataType::kFloatScalarOrVector},
{"log", 1, GLSLstd450Log, GlslDataType::kFloatScalarOrVector},
{"log2", 1, GLSLstd450Log2, GlslDataType::kFloatScalarOrVector},
{"nclamp", 3, GLSLstd450NClamp, GlslDataType::kFloatScalarOrVector},
{"nmax", 2, GLSLstd450NMax, GlslDataType::kFloatScalarOrVector},
{"nmin", 2, GLSLstd450NMin, GlslDataType::kFloatScalarOrVector},
{"normalize", 1, GLSLstd450Normalize, GlslDataType::kFloatScalarOrVector},
{"pow", 2, GLSLstd450Pow, GlslDataType::kFloatScalarOrVector},
{"radians", 1, GLSLstd450Radians, GlslDataType::kFloatScalarOrVector},
{"reflect", 2, GLSLstd450Reflect, GlslDataType::kFloatScalarOrVector},
{"round", 1, GLSLstd450Round, GlslDataType::kFloatScalarOrVector},
{"roundeven", 1, GLSLstd450RoundEven, GlslDataType::kFloatScalarOrVector},
{"sin", 1, GLSLstd450Sin, GlslDataType::kFloatScalarOrVector},
{"sinh", 1, GLSLstd450Sinh, GlslDataType::kFloatScalarOrVector},
{"smoothstep", 3, GLSLstd450SmoothStep, GlslDataType::kFloatScalarOrVector},
{"sqrt", 1, GLSLstd450Sqrt, GlslDataType::kFloatScalarOrVector},
{"step", 2, GLSLstd450Step, GlslDataType::kFloatScalarOrVector},
{"tan", 1, GLSLstd450Tan, GlslDataType::kFloatScalarOrVector},
{"tanh", 1, GLSLstd450Tanh, GlslDataType::kFloatScalarOrVector},
{"trunc", 1, GLSLstd450Trunc, GlslDataType::kFloatScalarOrVector},
};
constexpr const uint32_t kGlslDataCount = sizeof(kGlslData) / sizeof(GlslData);
ast::type::Type* TypeDeterminer::GetImportData(
const Source& source,
const std::string& path,
const std::string& name,
const ast::ExpressionList& params,
uint32_t* id) {
if (path != "GLSL.std.450") {
return nullptr;
}
const GlslData* data = nullptr;
for (uint32_t i = 0; i < kGlslDataCount; ++i) {
if (name == kGlslData[i].name) {
data = &kGlslData[i];
break;
}
}
if (data == nullptr) {
return nullptr;
}
if (params.size() != data->param_count) {
set_error(source, "incorrect number of parameters for " + name +
". Expected 1 got " +
std::to_string(params.size()));
". Expected " + std::to_string(data->param_count) +
" got " + std::to_string(params.size()));
return nullptr;
}
auto* result_type = params[0]->result_type()->UnwrapPtrIfNeeded();
if (!result_type->is_float_scalar_or_vector()) {
set_error(source, "incorrect type for " + name +
". Requires a float scalar or a float vector");
std::vector<ast::type::Type*> result_types;
for (uint32_t i = 0; i < data->param_count; ++i) {
result_types.push_back(params[i]->result_type()->UnwrapPtrIfNeeded());
switch (data->type) {
case GlslDataType::kFloatScalarOrVector:
if (!result_types.back()->is_float_scalar_or_vector()) {
set_error(source, "incorrect type for " + name + ". " +
"Requires float scalar or float vector values");
return nullptr;
}
if (name == "round") {
*id = GLSLstd450Round;
} else if (name == "roundeven") {
*id = GLSLstd450RoundEven;
} else if (name == "trunc") {
*id = GLSLstd450Trunc;
} else if (name == "fabs") {
*id = GLSLstd450FAbs;
} else if (name == "fsign") {
*id = GLSLstd450FSign;
} else if (name == "floor") {
*id = GLSLstd450Floor;
} else if (name == "ceil") {
*id = GLSLstd450Ceil;
} else if (name == "fract") {
*id = GLSLstd450Fract;
} else if (name == "radians") {
*id = GLSLstd450Radians;
} else if (name == "degrees") {
*id = GLSLstd450Degrees;
} else if (name == "sin") {
*id = GLSLstd450Sin;
} else if (name == "cos") {
*id = GLSLstd450Cos;
} else if (name == "tan") {
*id = GLSLstd450Tan;
} else if (name == "asin") {
*id = GLSLstd450Asin;
} else if (name == "acos") {
*id = GLSLstd450Acos;
} else if (name == "atan") {
*id = GLSLstd450Atan;
} else if (name == "sinh") {
*id = GLSLstd450Sinh;
} else if (name == "cosh") {
*id = GLSLstd450Cosh;
} else if (name == "tanh") {
*id = GLSLstd450Tanh;
} else if (name == "asinh") {
*id = GLSLstd450Asinh;
} else if (name == "acosh") {
*id = GLSLstd450Acosh;
} else if (name == "atanh") {
*id = GLSLstd450Atanh;
} else if (name == "exp") {
*id = GLSLstd450Exp;
} else if (name == "log") {
*id = GLSLstd450Log;
} else if (name == "exp2") {
*id = GLSLstd450Exp2;
} else if (name == "log2") {
*id = GLSLstd450Log2;
} else if (name == "sqrt") {
*id = GLSLstd450Sqrt;
} else if (name == "inversesqrt") {
*id = GLSLstd450InverseSqrt;
} else if (name == "normalize") {
*id = GLSLstd450Normalize;
} else if (name == "length") {
*id = GLSLstd450Length;
// Length returns a scalar of the same type as the parameter.
return result_type->is_float_scalar() ? result_type
: result_type->AsVector()->type();
break;
case GlslDataType::kIntScalarOrVector:
break;
}
}
return result_type;
} else if (name == "atan2" || name == "pow" || name == "fmin" ||
name == "fmax" || name == "step" || name == "reflect" ||
name == "nmin" || name == "nmax" || name == "distance") {
if (params.size() != 2) {
error_ = "incorrect number of parameters for " + name +
". Expected 2 got " + std::to_string(params.size());
return nullptr;
}
auto* result_type_0 = params[0]->result_type()->UnwrapPtrIfNeeded();
auto* result_type_1 = params[1]->result_type()->UnwrapPtrIfNeeded();
if (!result_type_0->is_float_scalar_or_vector() ||
!result_type_1->is_float_scalar_or_vector()) {
error_ = "incorrect type for " + name +
". Requires float scalar or a float vector values";
return nullptr;
}
if (result_type_0 != result_type_1) {
// Verify all the parameter types match
for (size_t i = 1; i < data->param_count; ++i) {
if (result_types[0] != result_types[i]) {
error_ = "mismatched parameter types for " + name;
return nullptr;
}
if (name == "atan2") {
*id = GLSLstd450Atan2;
} else if (name == "pow") {
*id = GLSLstd450Pow;
} else if (name == "fmin") {
*id = GLSLstd450FMin;
} else if (name == "fmax") {
*id = GLSLstd450FMax;
} else if (name == "step") {
*id = GLSLstd450Step;
} else if (name == "reflect") {
*id = GLSLstd450Reflect;
} else if (name == "nmin") {
*id = GLSLstd450NMin;
} else if (name == "nmax") {
*id = GLSLstd450NMax;
} else if (name == "distance") {
*id = GLSLstd450Distance;
// Distance returns a scalar of the same type as the parameter.
return result_type_0->is_float_scalar()
? result_type_0
: result_type_0->AsVector()->type();
}
return result_type_0;
} else if (name == "fclamp" || name == "fmix" || name == "smoothstep" ||
name == "fma" || name == "nclamp" || name == "faceforward") {
if (params.size() != 3) {
error_ = "incorrect number of parameters for " + name +
". Expected 3 got " + std::to_string(params.size());
return nullptr;
}
*id = data->op_id;
auto* result_type_0 = params[0]->result_type()->UnwrapPtrIfNeeded();
auto* result_type_1 = params[1]->result_type()->UnwrapPtrIfNeeded();
auto* result_type_2 = params[2]->result_type()->UnwrapPtrIfNeeded();
if (!result_type_0->is_float_scalar_or_vector() ||
!result_type_1->is_float_scalar_or_vector() ||
!result_type_2->is_float_scalar_or_vector()) {
error_ = "incorrect type for " + name +
". Requires float scalar or a float vector values";
return nullptr;
// Handle functions which aways return the type, even if a vector is provided.
if (name == "length" || name == "distance") {
return result_types[0]->is_float_scalar()
? result_types[0]
: result_types[0]->AsVector()->type();
}
if (result_type_0 != result_type_1 || result_type_0 != result_type_2) {
error_ = "mismatched parameter types for " + name;
return nullptr;
}
if (name == "fclamp") {
*id = GLSLstd450FClamp;
} else if (name == "fmix") {
*id = GLSLstd450FMix;
} else if (name == "smoothstep") {
*id = GLSLstd450SmoothStep;
} else if (name == "fma") {
*id = GLSLstd450Fma;
} else if (name == "nclamp") {
*id = GLSLstd450NClamp;
} else if (name == "faceforward") {
*id = GLSLstd450FaceForward;
}
return result_type_0;
}
return nullptr;
return result_types[0];
}
} // namespace tint

View File

@ -1787,8 +1787,9 @@ TEST_P(ImportData_SingleParamTest, Error_Integer) {
auto* type =
td()->GetImportData({0, 0}, "GLSL.std.450", param.name, params, &id);
ASSERT_EQ(type, nullptr);
EXPECT_EQ(td()->error(), std::string("incorrect type for ") + param.name +
". Requires a float scalar or a float vector");
EXPECT_EQ(td()->error(),
std::string("incorrect type for ") + param.name +
". Requires float scalar or float vector values");
}
TEST_P(ImportData_SingleParamTest, Error_NoParams) {
@ -1914,9 +1915,9 @@ TEST_F(TypeDeterminerTest, ImportData_Length_Error_Integer) {
auto* type =
td()->GetImportData({0, 0}, "GLSL.std.450", "length", params, &id);
ASSERT_EQ(type, nullptr);
EXPECT_EQ(
td()->error(),
"incorrect type for length. Requires a float scalar or a float vector");
EXPECT_EQ(td()->error(),
"incorrect type for length. Requires float scalar or float vector "
"values");
}
TEST_F(TypeDeterminerTest, ImportData_Length_Error_NoParams) {
@ -2030,7 +2031,7 @@ TEST_P(ImportData_TwoParamTest, Error_Integer) {
ASSERT_EQ(type, nullptr);
EXPECT_EQ(td()->error(),
std::string("incorrect type for ") + param.name +
". Requires float scalar or a float vector values");
". Requires float scalar or float vector values");
}
TEST_P(ImportData_TwoParamTest, Error_NoParams) {
@ -2234,7 +2235,7 @@ TEST_F(TypeDeterminerTest, ImportData_Distance_Error_Integer) {
td()->GetImportData({0, 0}, "GLSL.std.450", "distance", params, &id);
ASSERT_EQ(type, nullptr);
EXPECT_EQ(td()->error(),
"incorrect type for distance. Requires float scalar or a float "
"incorrect type for distance. Requires float scalar or float "
"vector values");
}
@ -2440,7 +2441,7 @@ TEST_P(ImportData_ThreeParamTest, Error_Integer) {
ASSERT_EQ(type, nullptr);
EXPECT_EQ(td()->error(),
std::string("incorrect type for ") + param.name +
". Requires float scalar or a float vector values");
". Requires float scalar or float vector values");
}
TEST_P(ImportData_ThreeParamTest, Error_NoParams) {