diff --git a/src/ast/type/type.cc b/src/ast/type/type.cc index 55d045a779..f9e96949d9 100644 --- a/src/ast/type/type.cc +++ b/src/ast/type/type.cc @@ -103,6 +103,17 @@ bool Type::is_float_scalar_or_vector() { return is_float_scalar() || is_float_vector(); } +bool Type::is_integer_scalar() { + return IsU32() || IsI32(); +} + +bool Type::is_integer_vector() { + if (!IsVector()) { + return false; + } + return AsVector()->type()->IsU32() || AsVector()->type()->IsI32(); +} + bool Type::is_unsigned_scalar_or_vector() { return IsU32() || (IsVector() && AsVector()->type()->IsU32()); } diff --git a/src/ast/type/type.h b/src/ast/type/type.h index 52e8c81952..12c9b118eb 100644 --- a/src/ast/type/type.h +++ b/src/ast/type/type.h @@ -77,6 +77,10 @@ class Type { bool is_float_vector(); /// @returns true if this type is a float scalar or vector bool is_float_scalar_or_vector(); + /// @returns ture if this type is an integer scalar + bool is_integer_scalar(); + /// @returns ture if this type is an integer vector + bool is_integer_vector(); /// @returns true if this type is an unsigned scalar or vector bool is_unsigned_scalar_or_vector(); /// @returns true if this type is a signed scalar or vector diff --git a/src/type_determiner.cc b/src/type_determiner.cc index 34fc9224cf..3978bfef39 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -676,10 +676,12 @@ constexpr const GlslData kGlslData[] = { {"reflect", 2, GLSLstd450Reflect, GlslDataType::kFloatScalarOrVector}, {"round", 1, GLSLstd450Round, GlslDataType::kFloatScalarOrVector}, {"roundeven", 1, GLSLstd450RoundEven, GlslDataType::kFloatScalarOrVector}, + {"sabs", 1, GLSLstd450SAbs, GlslDataType::kIntScalarOrVector}, {"sin", 1, GLSLstd450Sin, GlslDataType::kFloatScalarOrVector}, {"sinh", 1, GLSLstd450Sinh, GlslDataType::kFloatScalarOrVector}, {"smoothstep", 3, GLSLstd450SmoothStep, GlslDataType::kFloatScalarOrVector}, {"sqrt", 1, GLSLstd450Sqrt, GlslDataType::kFloatScalarOrVector}, + {"ssign", 1, GLSLstd450SSign, GlslDataType::kIntScalarOrVector}, {"step", 2, GLSLstd450Step, GlslDataType::kFloatScalarOrVector}, {"tan", 1, GLSLstd450Tan, GlslDataType::kFloatScalarOrVector}, {"tanh", 1, GLSLstd450Tanh, GlslDataType::kFloatScalarOrVector}, @@ -729,6 +731,12 @@ ast::type::Type* TypeDeterminer::GetImportData( break; case GlslDataType::kIntScalarOrVector: + if (!result_types.back()->is_integer_scalar_or_vector()) { + set_error(source, + "incorrect type for " + name + ". " + + "Requires integer scalar or integer vector values"); + return nullptr; + } break; } } diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index 4c24c90129..c229837a7b 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -2606,5 +2606,114 @@ INSTANTIATE_TEST_SUITE_P( GLSLData{"fma", GLSLstd450Fma}, GLSLData{"faceforward", GLSLstd450FaceForward}, GLSLData{"nclamp", GLSLstd450NClamp})); + +using ImportData_Int_SingleParamTest = TypeDeterminerTestWithParam; +TEST_P(ImportData_Int_SingleParamTest, Scalar) { + auto param = GetParam(); + + ast::type::I32Type i32; + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&i32, 1))); + + ASSERT_TRUE(td()->DetermineResultType(params)) << td()->error(); + + uint32_t id = 0; + auto* type = + td()->GetImportData({0, 0}, "GLSL.std.450", param.name, params, &id); + ASSERT_NE(type, nullptr); + EXPECT_TRUE(type->is_integer_scalar()); + EXPECT_EQ(id, param.value); +} + +TEST_P(ImportData_Int_SingleParamTest, Vector) { + auto param = GetParam(); + + ast::type::I32Type i32; + ast::type::VectorType vec(&i32, 3); + + ast::ExpressionList vals; + vals.push_back(std::make_unique( + std::make_unique(&i32, 1))); + vals.push_back(std::make_unique( + std::make_unique(&i32, 1))); + vals.push_back(std::make_unique( + std::make_unique(&i32, 3))); + + ast::ExpressionList params; + params.push_back( + std::make_unique(&vec, std::move(vals))); + + ASSERT_TRUE(td()->DetermineResultType(params)) << td()->error(); + + uint32_t id = 0; + auto* type = + td()->GetImportData({0, 0}, "GLSL.std.450", param.name, params, &id); + ASSERT_NE(type, nullptr); + EXPECT_TRUE(type->is_integer_vector()); + EXPECT_EQ(type->AsVector()->size(), 3u); + EXPECT_EQ(id, param.value); +} + +TEST_P(ImportData_Int_SingleParamTest, Error_Float) { + auto param = GetParam(); + + ast::type::F32Type f32; + + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&f32, 1.f))); + + ASSERT_TRUE(td()->DetermineResultType(params)) << td()->error(); + + uint32_t id = 0; + auto* type = + td()->GetImportData({0, 0}, "GLSL.std.450", param.name, params, &id); + ASSERT_EQ(type, nullptr); + EXPECT_EQ(td()->error(), + std::string("incorrect type for ") + param.name + + ". Requires integer scalar or integer vector values"); +} + +TEST_P(ImportData_Int_SingleParamTest, Error_NoParams) { + auto param = GetParam(); + + ast::ExpressionList params; + uint32_t id = 0; + auto* type = + td()->GetImportData({0, 0}, "GLSL.std.450", param.name, params, &id); + ASSERT_EQ(type, nullptr); + EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + + param.name + ". Expected 1 got 0"); +} + +TEST_P(ImportData_Int_SingleParamTest, Error_MultipleParams) { + auto param = GetParam(); + + ast::type::I32Type i32; + ast::ExpressionList params; + params.push_back(std::make_unique( + std::make_unique(&i32, 1))); + params.push_back(std::make_unique( + std::make_unique(&i32, 1))); + params.push_back(std::make_unique( + std::make_unique(&i32, 1))); + + ASSERT_TRUE(td()->DetermineResultType(params)) << td()->error(); + + uint32_t id = 0; + auto* type = + td()->GetImportData({0, 0}, "GLSL.std.450", param.name, params, &id); + ASSERT_EQ(type, nullptr); + EXPECT_EQ(td()->error(), std::string("incorrect number of parameters for ") + + param.name + ". Expected 1 got 3"); +} + +INSTANTIATE_TEST_SUITE_P(TypeDeterminerTest, + ImportData_Int_SingleParamTest, + testing::Values(GLSLData{"sabs", GLSLstd450SAbs}, + GLSLData{"ssign", GLSLstd450SSign})); + } // namespace } // namespace tint