diff --git a/src/type_determiner.cc b/src/type_determiner.cc index 8d455bd2ef..34fc9224cf 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -617,6 +617,76 @@ bool TypeDeterminer::DetermineUnaryOp(ast::UnaryOpExpression* expr) { return true; } +// Most of these are floating-point general except the below which are only +// FP16 and FP32. We only have FP32 at this point so the below works, if we +// get FP64 support or otherwise we'll need to differentiate. +// * radians +// * degrees +// * sin, cos, tan +// * asin, acos, atan +// * sinh, cosh, tanh +// * asinh, acosh, atanh +// * exp, exp2 +// * log, log2 +enum class GlslDataType { kFloatScalarOrVector, kIntScalarOrVector }; +struct GlslData { + const char* name; + uint8_t param_count; + uint32_t op_id; + GlslDataType type; +}; + +constexpr const GlslData kGlslData[] = { + {"acos", 1, GLSLstd450Acos, GlslDataType::kFloatScalarOrVector}, + {"acosh", 1, GLSLstd450Acosh, GlslDataType::kFloatScalarOrVector}, + {"asin", 1, GLSLstd450Asin, GlslDataType::kFloatScalarOrVector}, + {"asinh", 1, GLSLstd450Asinh, GlslDataType::kFloatScalarOrVector}, + {"atan", 1, GLSLstd450Atan, GlslDataType::kFloatScalarOrVector}, + {"atan2", 2, GLSLstd450Atan2, GlslDataType::kFloatScalarOrVector}, + {"atanh", 1, GLSLstd450Atanh, GlslDataType::kFloatScalarOrVector}, + {"ceil", 1, GLSLstd450Ceil, GlslDataType::kFloatScalarOrVector}, + {"cos", 1, GLSLstd450Cos, GlslDataType::kFloatScalarOrVector}, + {"cosh", 1, GLSLstd450Cosh, GlslDataType::kFloatScalarOrVector}, + {"degrees", 1, GLSLstd450Degrees, GlslDataType::kFloatScalarOrVector}, + {"distance", 2, GLSLstd450Distance, GlslDataType::kFloatScalarOrVector}, + {"exp", 1, GLSLstd450Exp, GlslDataType::kFloatScalarOrVector}, + {"exp2", 1, GLSLstd450Exp2, GlslDataType::kFloatScalarOrVector}, + {"fabs", 1, GLSLstd450FAbs, GlslDataType::kFloatScalarOrVector}, + {"faceforward", 3, GLSLstd450FaceForward, + GlslDataType::kFloatScalarOrVector}, + {"fclamp", 3, GLSLstd450FClamp, GlslDataType::kFloatScalarOrVector}, + {"floor", 1, GLSLstd450Floor, GlslDataType::kFloatScalarOrVector}, + {"fma", 3, GLSLstd450Fma, GlslDataType::kFloatScalarOrVector}, + {"fmax", 2, GLSLstd450FMax, GlslDataType::kFloatScalarOrVector}, + {"fmin", 2, GLSLstd450FMin, GlslDataType::kFloatScalarOrVector}, + {"fmix", 3, GLSLstd450FMix, GlslDataType::kFloatScalarOrVector}, + {"fract", 1, GLSLstd450Fract, GlslDataType::kFloatScalarOrVector}, + {"fsign", 1, GLSLstd450FSign, GlslDataType::kFloatScalarOrVector}, + {"inversesqrt", 1, GLSLstd450InverseSqrt, + GlslDataType::kFloatScalarOrVector}, + {"length", 1, GLSLstd450Length, GlslDataType::kFloatScalarOrVector}, + {"log", 1, GLSLstd450Log, GlslDataType::kFloatScalarOrVector}, + {"log2", 1, GLSLstd450Log2, GlslDataType::kFloatScalarOrVector}, + {"nclamp", 3, GLSLstd450NClamp, GlslDataType::kFloatScalarOrVector}, + {"nmax", 2, GLSLstd450NMax, GlslDataType::kFloatScalarOrVector}, + {"nmin", 2, GLSLstd450NMin, GlslDataType::kFloatScalarOrVector}, + {"normalize", 1, GLSLstd450Normalize, GlslDataType::kFloatScalarOrVector}, + {"pow", 2, GLSLstd450Pow, GlslDataType::kFloatScalarOrVector}, + {"radians", 1, GLSLstd450Radians, GlslDataType::kFloatScalarOrVector}, + {"reflect", 2, GLSLstd450Reflect, GlslDataType::kFloatScalarOrVector}, + {"round", 1, GLSLstd450Round, GlslDataType::kFloatScalarOrVector}, + {"roundeven", 1, GLSLstd450RoundEven, GlslDataType::kFloatScalarOrVector}, + {"sin", 1, GLSLstd450Sin, GlslDataType::kFloatScalarOrVector}, + {"sinh", 1, GLSLstd450Sinh, GlslDataType::kFloatScalarOrVector}, + {"smoothstep", 3, GLSLstd450SmoothStep, GlslDataType::kFloatScalarOrVector}, + {"sqrt", 1, GLSLstd450Sqrt, GlslDataType::kFloatScalarOrVector}, + {"step", 2, GLSLstd450Step, GlslDataType::kFloatScalarOrVector}, + {"tan", 1, GLSLstd450Tan, GlslDataType::kFloatScalarOrVector}, + {"tanh", 1, GLSLstd450Tanh, GlslDataType::kFloatScalarOrVector}, + {"trunc", 1, GLSLstd450Trunc, GlslDataType::kFloatScalarOrVector}, +}; +constexpr const uint32_t kGlslDataCount = sizeof(kGlslData) / sizeof(GlslData); + ast::type::Type* TypeDeterminer::GetImportData( const Source& source, const std::string& path, @@ -627,197 +697,59 @@ ast::type::Type* TypeDeterminer::GetImportData( return nullptr; } - // Most of these are floating-point general except the below which are only - // FP16 and FP32. We only have FP32 at this point so the below works, if we - // get FP64 support or otherwise we'll need to differentiate. - // * radians - // * degrees - // * sin, cos, tan - // * asin, acos, atan - // * sinh, cosh, tanh - // * asinh, acosh, atanh - // * exp, exp2 - // * log, log2 - - if (name == "round" || name == "roundeven" || name == "trunc" || - name == "fabs" || name == "fsign" || name == "floor" || name == "ceil" || - name == "fract" || name == "radians" || name == "degrees" || - name == "sin" || name == "cos" || name == "tan" || name == "asin" || - name == "acos" || name == "atan" || name == "sinh" || name == "cosh" || - name == "tanh" || name == "asinh" || name == "acosh" || name == "atanh" || - name == "exp" || name == "log" || name == "exp2" || name == "log2" || - name == "sqrt" || name == "inversesqrt" || name == "normalize" || - name == "length") { - if (params.size() != 1) { - set_error(source, "incorrect number of parameters for " + name + - ". Expected 1 got " + - std::to_string(params.size())); - return nullptr; + const GlslData* data = nullptr; + for (uint32_t i = 0; i < kGlslDataCount; ++i) { + if (name == kGlslData[i].name) { + data = &kGlslData[i]; + break; } - - auto* result_type = params[0]->result_type()->UnwrapPtrIfNeeded(); - if (!result_type->is_float_scalar_or_vector()) { - set_error(source, "incorrect type for " + name + - ". Requires a float scalar or a float vector"); - return nullptr; - } - - if (name == "round") { - *id = GLSLstd450Round; - } else if (name == "roundeven") { - *id = GLSLstd450RoundEven; - } else if (name == "trunc") { - *id = GLSLstd450Trunc; - } else if (name == "fabs") { - *id = GLSLstd450FAbs; - } else if (name == "fsign") { - *id = GLSLstd450FSign; - } else if (name == "floor") { - *id = GLSLstd450Floor; - } else if (name == "ceil") { - *id = GLSLstd450Ceil; - } else if (name == "fract") { - *id = GLSLstd450Fract; - } else if (name == "radians") { - *id = GLSLstd450Radians; - } else if (name == "degrees") { - *id = GLSLstd450Degrees; - } else if (name == "sin") { - *id = GLSLstd450Sin; - } else if (name == "cos") { - *id = GLSLstd450Cos; - } else if (name == "tan") { - *id = GLSLstd450Tan; - } else if (name == "asin") { - *id = GLSLstd450Asin; - } else if (name == "acos") { - *id = GLSLstd450Acos; - } else if (name == "atan") { - *id = GLSLstd450Atan; - } else if (name == "sinh") { - *id = GLSLstd450Sinh; - } else if (name == "cosh") { - *id = GLSLstd450Cosh; - } else if (name == "tanh") { - *id = GLSLstd450Tanh; - } else if (name == "asinh") { - *id = GLSLstd450Asinh; - } else if (name == "acosh") { - *id = GLSLstd450Acosh; - } else if (name == "atanh") { - *id = GLSLstd450Atanh; - } else if (name == "exp") { - *id = GLSLstd450Exp; - } else if (name == "log") { - *id = GLSLstd450Log; - } else if (name == "exp2") { - *id = GLSLstd450Exp2; - } else if (name == "log2") { - *id = GLSLstd450Log2; - } else if (name == "sqrt") { - *id = GLSLstd450Sqrt; - } else if (name == "inversesqrt") { - *id = GLSLstd450InverseSqrt; - } else if (name == "normalize") { - *id = GLSLstd450Normalize; - } else if (name == "length") { - *id = GLSLstd450Length; - - // Length returns a scalar of the same type as the parameter. - return result_type->is_float_scalar() ? result_type - : result_type->AsVector()->type(); - } - - return result_type; - } else if (name == "atan2" || name == "pow" || name == "fmin" || - name == "fmax" || name == "step" || name == "reflect" || - name == "nmin" || name == "nmax" || name == "distance") { - if (params.size() != 2) { - error_ = "incorrect number of parameters for " + name + - ". Expected 2 got " + std::to_string(params.size()); - return nullptr; - } - - auto* result_type_0 = params[0]->result_type()->UnwrapPtrIfNeeded(); - auto* result_type_1 = params[1]->result_type()->UnwrapPtrIfNeeded(); - if (!result_type_0->is_float_scalar_or_vector() || - !result_type_1->is_float_scalar_or_vector()) { - error_ = "incorrect type for " + name + - ". Requires float scalar or a float vector values"; - return nullptr; - } - if (result_type_0 != result_type_1) { - error_ = "mismatched parameter types for " + name; - return nullptr; - } - - if (name == "atan2") { - *id = GLSLstd450Atan2; - } else if (name == "pow") { - *id = GLSLstd450Pow; - } else if (name == "fmin") { - *id = GLSLstd450FMin; - } else if (name == "fmax") { - *id = GLSLstd450FMax; - } else if (name == "step") { - *id = GLSLstd450Step; - } else if (name == "reflect") { - *id = GLSLstd450Reflect; - } else if (name == "nmin") { - *id = GLSLstd450NMin; - } else if (name == "nmax") { - *id = GLSLstd450NMax; - } else if (name == "distance") { - *id = GLSLstd450Distance; - - // Distance returns a scalar of the same type as the parameter. - return result_type_0->is_float_scalar() - ? result_type_0 - : result_type_0->AsVector()->type(); - } - - return result_type_0; - } else if (name == "fclamp" || name == "fmix" || name == "smoothstep" || - name == "fma" || name == "nclamp" || name == "faceforward") { - if (params.size() != 3) { - error_ = "incorrect number of parameters for " + name + - ". Expected 3 got " + std::to_string(params.size()); - return nullptr; - } - - auto* result_type_0 = params[0]->result_type()->UnwrapPtrIfNeeded(); - auto* result_type_1 = params[1]->result_type()->UnwrapPtrIfNeeded(); - auto* result_type_2 = params[2]->result_type()->UnwrapPtrIfNeeded(); - if (!result_type_0->is_float_scalar_or_vector() || - !result_type_1->is_float_scalar_or_vector() || - !result_type_2->is_float_scalar_or_vector()) { - error_ = "incorrect type for " + name + - ". Requires float scalar or a float vector values"; - return nullptr; - } - if (result_type_0 != result_type_1 || result_type_0 != result_type_2) { - error_ = "mismatched parameter types for " + name; - return nullptr; - } - - if (name == "fclamp") { - *id = GLSLstd450FClamp; - } else if (name == "fmix") { - *id = GLSLstd450FMix; - } else if (name == "smoothstep") { - *id = GLSLstd450SmoothStep; - } else if (name == "fma") { - *id = GLSLstd450Fma; - } else if (name == "nclamp") { - *id = GLSLstd450NClamp; - } else if (name == "faceforward") { - *id = GLSLstd450FaceForward; - } - - return result_type_0; + } + if (data == nullptr) { + return nullptr; } - return nullptr; + if (params.size() != data->param_count) { + set_error(source, "incorrect number of parameters for " + name + + ". Expected " + std::to_string(data->param_count) + + " got " + std::to_string(params.size())); + return nullptr; + } + + std::vector result_types; + for (uint32_t i = 0; i < data->param_count; ++i) { + result_types.push_back(params[i]->result_type()->UnwrapPtrIfNeeded()); + + switch (data->type) { + case GlslDataType::kFloatScalarOrVector: + if (!result_types.back()->is_float_scalar_or_vector()) { + set_error(source, "incorrect type for " + name + ". " + + "Requires float scalar or float vector values"); + return nullptr; + } + + break; + case GlslDataType::kIntScalarOrVector: + break; + } + } + + // Verify all the parameter types match + for (size_t i = 1; i < data->param_count; ++i) { + if (result_types[0] != result_types[i]) { + error_ = "mismatched parameter types for " + name; + return nullptr; + } + } + + *id = data->op_id; + + // Handle functions which aways return the type, even if a vector is provided. + if (name == "length" || name == "distance") { + return result_types[0]->is_float_scalar() + ? result_types[0] + : result_types[0]->AsVector()->type(); + } + return result_types[0]; } } // namespace tint diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index 1cdca1aba0..4c24c90129 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -1787,8 +1787,9 @@ TEST_P(ImportData_SingleParamTest, Error_Integer) { auto* type = td()->GetImportData({0, 0}, "GLSL.std.450", param.name, params, &id); ASSERT_EQ(type, nullptr); - EXPECT_EQ(td()->error(), std::string("incorrect type for ") + param.name + - ". Requires a float scalar or a float vector"); + EXPECT_EQ(td()->error(), + std::string("incorrect type for ") + param.name + + ". Requires float scalar or float vector values"); } TEST_P(ImportData_SingleParamTest, Error_NoParams) { @@ -1914,9 +1915,9 @@ TEST_F(TypeDeterminerTest, ImportData_Length_Error_Integer) { auto* type = td()->GetImportData({0, 0}, "GLSL.std.450", "length", params, &id); ASSERT_EQ(type, nullptr); - EXPECT_EQ( - td()->error(), - "incorrect type for length. Requires a float scalar or a float vector"); + EXPECT_EQ(td()->error(), + "incorrect type for length. Requires float scalar or float vector " + "values"); } TEST_F(TypeDeterminerTest, ImportData_Length_Error_NoParams) { @@ -2030,7 +2031,7 @@ TEST_P(ImportData_TwoParamTest, Error_Integer) { ASSERT_EQ(type, nullptr); EXPECT_EQ(td()->error(), std::string("incorrect type for ") + param.name + - ". Requires float scalar or a float vector values"); + ". Requires float scalar or float vector values"); } TEST_P(ImportData_TwoParamTest, Error_NoParams) { @@ -2234,7 +2235,7 @@ TEST_F(TypeDeterminerTest, ImportData_Distance_Error_Integer) { td()->GetImportData({0, 0}, "GLSL.std.450", "distance", params, &id); ASSERT_EQ(type, nullptr); EXPECT_EQ(td()->error(), - "incorrect type for distance. Requires float scalar or a float " + "incorrect type for distance. Requires float scalar or float " "vector values"); } @@ -2440,7 +2441,7 @@ TEST_P(ImportData_ThreeParamTest, Error_Integer) { ASSERT_EQ(type, nullptr); EXPECT_EQ(td()->error(), std::string("incorrect type for ") + param.name + - ". Requires float scalar or a float vector values"); + ". Requires float scalar or float vector values"); } TEST_P(ImportData_ThreeParamTest, Error_NoParams) {