diff --git a/src/reader/spirv/parser_impl.cc b/src/reader/spirv/parser_impl.cc index 7771643c7d..5b5b2f0a16 100644 --- a/src/reader/spirv/parser_impl.cc +++ b/src/reader/spirv/parser_impl.cc @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -608,6 +609,22 @@ bool ParserImpl::RegisterUserAndStructMemberNames() { return true; } +bool ParserImpl::IsValidIdentifier(const std::string& str) { + if (str.empty()) { + return false; + } + std::locale c_locale("C"); + if (!std::isalpha(str[0], c_locale)) { + return false; + } + for (const char& ch : str) { + if ((ch != '_') && !std::isalnum(ch, c_locale)) { + return false; + } + } + return true; +} + bool ParserImpl::EmitEntryPoints() { for (const spvtools::opt::Instruction& entry_point : module_->entry_points()) { @@ -616,6 +633,11 @@ bool ParserImpl::EmitEntryPoints() { const std::string ep_name = entry_point.GetOperand(2).AsString(); const std::string name = namer_.GetName(function_id); + if (!IsValidIdentifier(ep_name)) { + return Fail() << "entry point name is not a valid WGSL identifier: " + << ep_name; + } + ast_module_.AddEntryPoint(std::make_unique( enum_converter_.ToPipelineStage(stage), ep_name, name)); } diff --git a/src/reader/spirv/parser_impl.h b/src/reader/spirv/parser_impl.h index deeed0ad98..649bdc0902 100644 --- a/src/reader/spirv/parser_impl.h +++ b/src/reader/spirv/parser_impl.h @@ -372,6 +372,10 @@ class ParserImpl : Reader { /// @return the Source record, or a default one Source GetSourceForInst(const spvtools::opt::Instruction* inst) const; + /// @param str a candidate identifier + /// @returns true if the given string is a valid WGSL identifier. + static bool IsValidIdentifier(const std::string& str); + private: /// Converts a specific SPIR-V type to a Tint type. Integer case ast::type::Type* ConvertType(const spvtools::opt::analysis::Integer* int_ty); diff --git a/src/reader/spirv/parser_impl_entry_point_test.cc b/src/reader/spirv/parser_impl_entry_point_test.cc index 30dbebc320..507e5f1a72 100644 --- a/src/reader/spirv/parser_impl_entry_point_test.cc +++ b/src/reader/spirv/parser_impl_entry_point_test.cc @@ -24,6 +24,7 @@ namespace reader { namespace spirv { namespace { +using ::testing::Eq; using ::testing::HasSubstr; std::string MakeEntryPoint(const std::string& stage, @@ -81,13 +82,11 @@ TEST_F(SpvParserTest, EntryPoint_MultiNameConflict) { HasSubstr(R"(EntryPoint{fragment as work = work_2})")); } -TEST_F(SpvParserTest, EntryPoint_NameIsSanitized) { +TEST_F(SpvParserTest, EntryPoint_MustBeWgslIdentifier) { auto* p = parser(test::Assemble(MakeEntryPoint("GLCompute", ".1234"))); - EXPECT_TRUE(p->BuildAndParseInternalModule()); - EXPECT_TRUE(p->error().empty()); - const auto module_str = p->module().to_str(); - EXPECT_THAT(module_str, - HasSubstr(R"(EntryPoint{compute as .1234 = x_1234})")); + EXPECT_FALSE(p->BuildAndParseInternalModule()); + EXPECT_THAT(p->error(), + Eq("entry point name is not a valid WGSL identifier: .1234")); } } // namespace diff --git a/src/reader/spirv/parser_impl_test.cc b/src/reader/spirv/parser_impl_test.cc index d1facdc7f2..fa35918aaf 100644 --- a/src/reader/spirv/parser_impl_test.cc +++ b/src/reader/spirv/parser_impl_test.cc @@ -205,6 +205,27 @@ TEST_F(SpvParserTest, Impl_Source_InvalidId) { EXPECT_EQ(0u, s99.column); } +TEST_F(SpvParserTest, Impl_IsValidIdentifier) { + EXPECT_FALSE(ParserImpl::IsValidIdentifier("")); // empty + EXPECT_FALSE( + ParserImpl::IsValidIdentifier("_")); // leading underscore, but ok later + EXPECT_FALSE( + ParserImpl::IsValidIdentifier("9")); // leading digit, but ok later + EXPECT_FALSE(ParserImpl::IsValidIdentifier(" ")); // leading space + EXPECT_FALSE(ParserImpl::IsValidIdentifier("a ")); // trailing space + EXPECT_FALSE(ParserImpl::IsValidIdentifier("a 1")); // space in the middle + EXPECT_FALSE(ParserImpl::IsValidIdentifier(".")); // weird character + + // a simple identifier + EXPECT_TRUE(ParserImpl::IsValidIdentifier("A")); + // each upper case letter + EXPECT_TRUE(ParserImpl::IsValidIdentifier("ABCDEFGHIJKLMNOPQRSTUVWXYZ")); + // each lower case letter + EXPECT_TRUE(ParserImpl::IsValidIdentifier("abcdefghijklmnopqrstuvwxyz")); + EXPECT_TRUE(ParserImpl::IsValidIdentifier("a0123456789")); // each digit + EXPECT_TRUE(ParserImpl::IsValidIdentifier("x_")); // has underscore +} + } // namespace } // namespace spirv } // namespace reader