Refactor GLSL type determination code.
This Cl cleanups and simplifies the type determination for the GLSL imports. Change-Id: I9dd85ac390ef37c91d9493f840f81ceb6736fc06 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/22820 Reviewed-by: David Neto <dneto@google.com>
This commit is contained in:
parent
b0d308c9fe
commit
53352044df
|
@ -617,16 +617,6 @@ bool TypeDeterminer::DetermineUnaryOp(ast::UnaryOpExpression* expr) {
|
|||
return true;
|
||||
}
|
||||
|
||||
ast::type::Type* TypeDeterminer::GetImportData(
|
||||
const Source& source,
|
||||
const std::string& path,
|
||||
const std::string& name,
|
||||
const ast::ExpressionList& params,
|
||||
uint32_t* id) {
|
||||
if (path != "GLSL.std.450") {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Most of these are floating-point general except the below which are only
|
||||
// FP16 and FP32. We only have FP32 at this point so the below works, if we
|
||||
// get FP64 support or otherwise we'll need to differentiate.
|
||||
|
@ -638,186 +628,128 @@ ast::type::Type* TypeDeterminer::GetImportData(
|
|||
// * asinh, acosh, atanh
|
||||
// * exp, exp2
|
||||
// * log, log2
|
||||
enum class GlslDataType { kFloatScalarOrVector, kIntScalarOrVector };
|
||||
struct GlslData {
|
||||
const char* name;
|
||||
uint8_t param_count;
|
||||
uint32_t op_id;
|
||||
GlslDataType type;
|
||||
};
|
||||
|
||||
if (name == "round" || name == "roundeven" || name == "trunc" ||
|
||||
name == "fabs" || name == "fsign" || name == "floor" || name == "ceil" ||
|
||||
name == "fract" || name == "radians" || name == "degrees" ||
|
||||
name == "sin" || name == "cos" || name == "tan" || name == "asin" ||
|
||||
name == "acos" || name == "atan" || name == "sinh" || name == "cosh" ||
|
||||
name == "tanh" || name == "asinh" || name == "acosh" || name == "atanh" ||
|
||||
name == "exp" || name == "log" || name == "exp2" || name == "log2" ||
|
||||
name == "sqrt" || name == "inversesqrt" || name == "normalize" ||
|
||||
name == "length") {
|
||||
if (params.size() != 1) {
|
||||
constexpr const GlslData kGlslData[] = {
|
||||
{"acos", 1, GLSLstd450Acos, GlslDataType::kFloatScalarOrVector},
|
||||
{"acosh", 1, GLSLstd450Acosh, GlslDataType::kFloatScalarOrVector},
|
||||
{"asin", 1, GLSLstd450Asin, GlslDataType::kFloatScalarOrVector},
|
||||
{"asinh", 1, GLSLstd450Asinh, GlslDataType::kFloatScalarOrVector},
|
||||
{"atan", 1, GLSLstd450Atan, GlslDataType::kFloatScalarOrVector},
|
||||
{"atan2", 2, GLSLstd450Atan2, GlslDataType::kFloatScalarOrVector},
|
||||
{"atanh", 1, GLSLstd450Atanh, GlslDataType::kFloatScalarOrVector},
|
||||
{"ceil", 1, GLSLstd450Ceil, GlslDataType::kFloatScalarOrVector},
|
||||
{"cos", 1, GLSLstd450Cos, GlslDataType::kFloatScalarOrVector},
|
||||
{"cosh", 1, GLSLstd450Cosh, GlslDataType::kFloatScalarOrVector},
|
||||
{"degrees", 1, GLSLstd450Degrees, GlslDataType::kFloatScalarOrVector},
|
||||
{"distance", 2, GLSLstd450Distance, GlslDataType::kFloatScalarOrVector},
|
||||
{"exp", 1, GLSLstd450Exp, GlslDataType::kFloatScalarOrVector},
|
||||
{"exp2", 1, GLSLstd450Exp2, GlslDataType::kFloatScalarOrVector},
|
||||
{"fabs", 1, GLSLstd450FAbs, GlslDataType::kFloatScalarOrVector},
|
||||
{"faceforward", 3, GLSLstd450FaceForward,
|
||||
GlslDataType::kFloatScalarOrVector},
|
||||
{"fclamp", 3, GLSLstd450FClamp, GlslDataType::kFloatScalarOrVector},
|
||||
{"floor", 1, GLSLstd450Floor, GlslDataType::kFloatScalarOrVector},
|
||||
{"fma", 3, GLSLstd450Fma, GlslDataType::kFloatScalarOrVector},
|
||||
{"fmax", 2, GLSLstd450FMax, GlslDataType::kFloatScalarOrVector},
|
||||
{"fmin", 2, GLSLstd450FMin, GlslDataType::kFloatScalarOrVector},
|
||||
{"fmix", 3, GLSLstd450FMix, GlslDataType::kFloatScalarOrVector},
|
||||
{"fract", 1, GLSLstd450Fract, GlslDataType::kFloatScalarOrVector},
|
||||
{"fsign", 1, GLSLstd450FSign, GlslDataType::kFloatScalarOrVector},
|
||||
{"inversesqrt", 1, GLSLstd450InverseSqrt,
|
||||
GlslDataType::kFloatScalarOrVector},
|
||||
{"length", 1, GLSLstd450Length, GlslDataType::kFloatScalarOrVector},
|
||||
{"log", 1, GLSLstd450Log, GlslDataType::kFloatScalarOrVector},
|
||||
{"log2", 1, GLSLstd450Log2, GlslDataType::kFloatScalarOrVector},
|
||||
{"nclamp", 3, GLSLstd450NClamp, GlslDataType::kFloatScalarOrVector},
|
||||
{"nmax", 2, GLSLstd450NMax, GlslDataType::kFloatScalarOrVector},
|
||||
{"nmin", 2, GLSLstd450NMin, GlslDataType::kFloatScalarOrVector},
|
||||
{"normalize", 1, GLSLstd450Normalize, GlslDataType::kFloatScalarOrVector},
|
||||
{"pow", 2, GLSLstd450Pow, GlslDataType::kFloatScalarOrVector},
|
||||
{"radians", 1, GLSLstd450Radians, GlslDataType::kFloatScalarOrVector},
|
||||
{"reflect", 2, GLSLstd450Reflect, GlslDataType::kFloatScalarOrVector},
|
||||
{"round", 1, GLSLstd450Round, GlslDataType::kFloatScalarOrVector},
|
||||
{"roundeven", 1, GLSLstd450RoundEven, GlslDataType::kFloatScalarOrVector},
|
||||
{"sin", 1, GLSLstd450Sin, GlslDataType::kFloatScalarOrVector},
|
||||
{"sinh", 1, GLSLstd450Sinh, GlslDataType::kFloatScalarOrVector},
|
||||
{"smoothstep", 3, GLSLstd450SmoothStep, GlslDataType::kFloatScalarOrVector},
|
||||
{"sqrt", 1, GLSLstd450Sqrt, GlslDataType::kFloatScalarOrVector},
|
||||
{"step", 2, GLSLstd450Step, GlslDataType::kFloatScalarOrVector},
|
||||
{"tan", 1, GLSLstd450Tan, GlslDataType::kFloatScalarOrVector},
|
||||
{"tanh", 1, GLSLstd450Tanh, GlslDataType::kFloatScalarOrVector},
|
||||
{"trunc", 1, GLSLstd450Trunc, GlslDataType::kFloatScalarOrVector},
|
||||
};
|
||||
constexpr const uint32_t kGlslDataCount = sizeof(kGlslData) / sizeof(GlslData);
|
||||
|
||||
ast::type::Type* TypeDeterminer::GetImportData(
|
||||
const Source& source,
|
||||
const std::string& path,
|
||||
const std::string& name,
|
||||
const ast::ExpressionList& params,
|
||||
uint32_t* id) {
|
||||
if (path != "GLSL.std.450") {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const GlslData* data = nullptr;
|
||||
for (uint32_t i = 0; i < kGlslDataCount; ++i) {
|
||||
if (name == kGlslData[i].name) {
|
||||
data = &kGlslData[i];
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (data == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (params.size() != data->param_count) {
|
||||
set_error(source, "incorrect number of parameters for " + name +
|
||||
". Expected 1 got " +
|
||||
std::to_string(params.size()));
|
||||
". Expected " + std::to_string(data->param_count) +
|
||||
" got " + std::to_string(params.size()));
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto* result_type = params[0]->result_type()->UnwrapPtrIfNeeded();
|
||||
if (!result_type->is_float_scalar_or_vector()) {
|
||||
set_error(source, "incorrect type for " + name +
|
||||
". Requires a float scalar or a float vector");
|
||||
std::vector<ast::type::Type*> result_types;
|
||||
for (uint32_t i = 0; i < data->param_count; ++i) {
|
||||
result_types.push_back(params[i]->result_type()->UnwrapPtrIfNeeded());
|
||||
|
||||
switch (data->type) {
|
||||
case GlslDataType::kFloatScalarOrVector:
|
||||
if (!result_types.back()->is_float_scalar_or_vector()) {
|
||||
set_error(source, "incorrect type for " + name + ". " +
|
||||
"Requires float scalar or float vector values");
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (name == "round") {
|
||||
*id = GLSLstd450Round;
|
||||
} else if (name == "roundeven") {
|
||||
*id = GLSLstd450RoundEven;
|
||||
} else if (name == "trunc") {
|
||||
*id = GLSLstd450Trunc;
|
||||
} else if (name == "fabs") {
|
||||
*id = GLSLstd450FAbs;
|
||||
} else if (name == "fsign") {
|
||||
*id = GLSLstd450FSign;
|
||||
} else if (name == "floor") {
|
||||
*id = GLSLstd450Floor;
|
||||
} else if (name == "ceil") {
|
||||
*id = GLSLstd450Ceil;
|
||||
} else if (name == "fract") {
|
||||
*id = GLSLstd450Fract;
|
||||
} else if (name == "radians") {
|
||||
*id = GLSLstd450Radians;
|
||||
} else if (name == "degrees") {
|
||||
*id = GLSLstd450Degrees;
|
||||
} else if (name == "sin") {
|
||||
*id = GLSLstd450Sin;
|
||||
} else if (name == "cos") {
|
||||
*id = GLSLstd450Cos;
|
||||
} else if (name == "tan") {
|
||||
*id = GLSLstd450Tan;
|
||||
} else if (name == "asin") {
|
||||
*id = GLSLstd450Asin;
|
||||
} else if (name == "acos") {
|
||||
*id = GLSLstd450Acos;
|
||||
} else if (name == "atan") {
|
||||
*id = GLSLstd450Atan;
|
||||
} else if (name == "sinh") {
|
||||
*id = GLSLstd450Sinh;
|
||||
} else if (name == "cosh") {
|
||||
*id = GLSLstd450Cosh;
|
||||
} else if (name == "tanh") {
|
||||
*id = GLSLstd450Tanh;
|
||||
} else if (name == "asinh") {
|
||||
*id = GLSLstd450Asinh;
|
||||
} else if (name == "acosh") {
|
||||
*id = GLSLstd450Acosh;
|
||||
} else if (name == "atanh") {
|
||||
*id = GLSLstd450Atanh;
|
||||
} else if (name == "exp") {
|
||||
*id = GLSLstd450Exp;
|
||||
} else if (name == "log") {
|
||||
*id = GLSLstd450Log;
|
||||
} else if (name == "exp2") {
|
||||
*id = GLSLstd450Exp2;
|
||||
} else if (name == "log2") {
|
||||
*id = GLSLstd450Log2;
|
||||
} else if (name == "sqrt") {
|
||||
*id = GLSLstd450Sqrt;
|
||||
} else if (name == "inversesqrt") {
|
||||
*id = GLSLstd450InverseSqrt;
|
||||
} else if (name == "normalize") {
|
||||
*id = GLSLstd450Normalize;
|
||||
} else if (name == "length") {
|
||||
*id = GLSLstd450Length;
|
||||
|
||||
// Length returns a scalar of the same type as the parameter.
|
||||
return result_type->is_float_scalar() ? result_type
|
||||
: result_type->AsVector()->type();
|
||||
break;
|
||||
case GlslDataType::kIntScalarOrVector:
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return result_type;
|
||||
} else if (name == "atan2" || name == "pow" || name == "fmin" ||
|
||||
name == "fmax" || name == "step" || name == "reflect" ||
|
||||
name == "nmin" || name == "nmax" || name == "distance") {
|
||||
if (params.size() != 2) {
|
||||
error_ = "incorrect number of parameters for " + name +
|
||||
". Expected 2 got " + std::to_string(params.size());
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
auto* result_type_0 = params[0]->result_type()->UnwrapPtrIfNeeded();
|
||||
auto* result_type_1 = params[1]->result_type()->UnwrapPtrIfNeeded();
|
||||
if (!result_type_0->is_float_scalar_or_vector() ||
|
||||
!result_type_1->is_float_scalar_or_vector()) {
|
||||
error_ = "incorrect type for " + name +
|
||||
". Requires float scalar or a float vector values";
|
||||
return nullptr;
|
||||
}
|
||||
if (result_type_0 != result_type_1) {
|
||||
// Verify all the parameter types match
|
||||
for (size_t i = 1; i < data->param_count; ++i) {
|
||||
if (result_types[0] != result_types[i]) {
|
||||
error_ = "mismatched parameter types for " + name;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (name == "atan2") {
|
||||
*id = GLSLstd450Atan2;
|
||||
} else if (name == "pow") {
|
||||
*id = GLSLstd450Pow;
|
||||
} else if (name == "fmin") {
|
||||
*id = GLSLstd450FMin;
|
||||
} else if (name == "fmax") {
|
||||
*id = GLSLstd450FMax;
|
||||
} else if (name == "step") {
|
||||
*id = GLSLstd450Step;
|
||||
} else if (name == "reflect") {
|
||||
*id = GLSLstd450Reflect;
|
||||
} else if (name == "nmin") {
|
||||
*id = GLSLstd450NMin;
|
||||
} else if (name == "nmax") {
|
||||
*id = GLSLstd450NMax;
|
||||
} else if (name == "distance") {
|
||||
*id = GLSLstd450Distance;
|
||||
|
||||
// Distance returns a scalar of the same type as the parameter.
|
||||
return result_type_0->is_float_scalar()
|
||||
? result_type_0
|
||||
: result_type_0->AsVector()->type();
|
||||
}
|
||||
|
||||
return result_type_0;
|
||||
} else if (name == "fclamp" || name == "fmix" || name == "smoothstep" ||
|
||||
name == "fma" || name == "nclamp" || name == "faceforward") {
|
||||
if (params.size() != 3) {
|
||||
error_ = "incorrect number of parameters for " + name +
|
||||
". Expected 3 got " + std::to_string(params.size());
|
||||
return nullptr;
|
||||
}
|
||||
*id = data->op_id;
|
||||
|
||||
auto* result_type_0 = params[0]->result_type()->UnwrapPtrIfNeeded();
|
||||
auto* result_type_1 = params[1]->result_type()->UnwrapPtrIfNeeded();
|
||||
auto* result_type_2 = params[2]->result_type()->UnwrapPtrIfNeeded();
|
||||
if (!result_type_0->is_float_scalar_or_vector() ||
|
||||
!result_type_1->is_float_scalar_or_vector() ||
|
||||
!result_type_2->is_float_scalar_or_vector()) {
|
||||
error_ = "incorrect type for " + name +
|
||||
". Requires float scalar or a float vector values";
|
||||
return nullptr;
|
||||
// Handle functions which aways return the type, even if a vector is provided.
|
||||
if (name == "length" || name == "distance") {
|
||||
return result_types[0]->is_float_scalar()
|
||||
? result_types[0]
|
||||
: result_types[0]->AsVector()->type();
|
||||
}
|
||||
if (result_type_0 != result_type_1 || result_type_0 != result_type_2) {
|
||||
error_ = "mismatched parameter types for " + name;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
if (name == "fclamp") {
|
||||
*id = GLSLstd450FClamp;
|
||||
} else if (name == "fmix") {
|
||||
*id = GLSLstd450FMix;
|
||||
} else if (name == "smoothstep") {
|
||||
*id = GLSLstd450SmoothStep;
|
||||
} else if (name == "fma") {
|
||||
*id = GLSLstd450Fma;
|
||||
} else if (name == "nclamp") {
|
||||
*id = GLSLstd450NClamp;
|
||||
} else if (name == "faceforward") {
|
||||
*id = GLSLstd450FaceForward;
|
||||
}
|
||||
|
||||
return result_type_0;
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
return result_types[0];
|
||||
}
|
||||
|
||||
} // namespace tint
|
||||
|
|
|
@ -1787,8 +1787,9 @@ TEST_P(ImportData_SingleParamTest, Error_Integer) {
|
|||
auto* type =
|
||||
td()->GetImportData({0, 0}, "GLSL.std.450", param.name, params, &id);
|
||||
ASSERT_EQ(type, nullptr);
|
||||
EXPECT_EQ(td()->error(), std::string("incorrect type for ") + param.name +
|
||||
". Requires a float scalar or a float vector");
|
||||
EXPECT_EQ(td()->error(),
|
||||
std::string("incorrect type for ") + param.name +
|
||||
". Requires float scalar or float vector values");
|
||||
}
|
||||
|
||||
TEST_P(ImportData_SingleParamTest, Error_NoParams) {
|
||||
|
@ -1914,9 +1915,9 @@ TEST_F(TypeDeterminerTest, ImportData_Length_Error_Integer) {
|
|||
auto* type =
|
||||
td()->GetImportData({0, 0}, "GLSL.std.450", "length", params, &id);
|
||||
ASSERT_EQ(type, nullptr);
|
||||
EXPECT_EQ(
|
||||
td()->error(),
|
||||
"incorrect type for length. Requires a float scalar or a float vector");
|
||||
EXPECT_EQ(td()->error(),
|
||||
"incorrect type for length. Requires float scalar or float vector "
|
||||
"values");
|
||||
}
|
||||
|
||||
TEST_F(TypeDeterminerTest, ImportData_Length_Error_NoParams) {
|
||||
|
@ -2030,7 +2031,7 @@ TEST_P(ImportData_TwoParamTest, Error_Integer) {
|
|||
ASSERT_EQ(type, nullptr);
|
||||
EXPECT_EQ(td()->error(),
|
||||
std::string("incorrect type for ") + param.name +
|
||||
". Requires float scalar or a float vector values");
|
||||
". Requires float scalar or float vector values");
|
||||
}
|
||||
|
||||
TEST_P(ImportData_TwoParamTest, Error_NoParams) {
|
||||
|
@ -2234,7 +2235,7 @@ TEST_F(TypeDeterminerTest, ImportData_Distance_Error_Integer) {
|
|||
td()->GetImportData({0, 0}, "GLSL.std.450", "distance", params, &id);
|
||||
ASSERT_EQ(type, nullptr);
|
||||
EXPECT_EQ(td()->error(),
|
||||
"incorrect type for distance. Requires float scalar or a float "
|
||||
"incorrect type for distance. Requires float scalar or float "
|
||||
"vector values");
|
||||
}
|
||||
|
||||
|
@ -2440,7 +2441,7 @@ TEST_P(ImportData_ThreeParamTest, Error_Integer) {
|
|||
ASSERT_EQ(type, nullptr);
|
||||
EXPECT_EQ(td()->error(),
|
||||
std::string("incorrect type for ") + param.name +
|
||||
". Requires float scalar or a float vector values");
|
||||
". Requires float scalar or float vector values");
|
||||
}
|
||||
|
||||
TEST_P(ImportData_ThreeParamTest, Error_NoParams) {
|
||||
|
|
Loading…
Reference in New Issue