diff --git a/src/ast/function.cc b/src/ast/function.cc index 26522e825e..ae692875d8 100644 --- a/src/ast/function.cc +++ b/src/ast/function.cc @@ -31,12 +31,14 @@ namespace tint { namespace ast { Function::Function(const Source& source, + Symbol symbol, const std::string& name, VariableList params, type::Type* return_type, BlockStatement* body, FunctionDecorationList decorations) : Base(source), + symbol_(symbol), name_(name), params_(std::move(params)), return_type_(return_type), @@ -202,7 +204,7 @@ Function::local_referenced_builtin_variables() const { return ret; } -void Function::add_ancestor_entry_point(const std::string& ep) { +void Function::add_ancestor_entry_point(Symbol ep) { for (const auto& point : ancestor_entry_points_) { if (point == ep) { return; @@ -211,9 +213,9 @@ void Function::add_ancestor_entry_point(const std::string& ep) { ancestor_entry_points_.push_back(ep); } -bool Function::HasAncestorEntryPoint(const std::string& name) const { +bool Function::HasAncestorEntryPoint(Symbol symbol) const { for (const auto& point : ancestor_entry_points_) { - if (point == name) { + if (point == symbol) { return true; } } @@ -226,7 +228,7 @@ const Statement* Function::get_last_statement() const { Function* Function::Clone(CloneContext* ctx) const { return ctx->mod->create( - ctx->Clone(source()), name_, ctx->Clone(params_), + ctx->Clone(source()), symbol_, name_, ctx->Clone(params_), ctx->Clone(return_type_), ctx->Clone(body_), ctx->Clone(decorations_)); } @@ -238,7 +240,7 @@ bool Function::IsValid() const { if (body_ == nullptr || !body_->IsValid()) { return false; } - if (name_.length() == 0) { + if (name_.length() == 0 || !symbol_.IsValid()) { return false; } if (return_type_ == nullptr) { @@ -249,7 +251,7 @@ bool Function::IsValid() const { void Function::to_str(std::ostream& out, size_t indent) const { make_indent(out, indent); - out << "Function " << name_ << " -> " << return_type_->type_name() + out << "Function " << symbol_.to_str() << " -> " << return_type_->type_name() << std::endl; for (auto* deco : decorations()) { diff --git a/src/ast/function.h b/src/ast/function.h index 61eb529013..14b6789376 100644 --- a/src/ast/function.h +++ b/src/ast/function.h @@ -35,6 +35,7 @@ #include "src/ast/type/sampler_type.h" #include "src/ast/type/type.h" #include "src/ast/variable.h" +#include "src/symbol.h" namespace tint { namespace ast { @@ -52,12 +53,14 @@ class Function : public Castable { /// Create a function /// @param source the variable source + /// @param symbol the function symbol /// @param name the function name /// @param params the function parameters /// @param return_type the return type /// @param body the function body /// @param decorations the function decorations Function(const Source& source, + Symbol symbol, const std::string& name, VariableList params, type::Type* return_type, @@ -68,6 +71,8 @@ class Function : public Castable { ~Function() override; + /// @returns the function symbol + Symbol symbol() const { return symbol_; } /// @returns the function name const std::string& name() { return name_; } /// @returns the function params @@ -150,15 +155,15 @@ class Function : public Castable { /// Adds an ancestor entry point /// @param ep the entry point ancestor - void add_ancestor_entry_point(const std::string& ep); + void add_ancestor_entry_point(Symbol ep); /// @returns the ancestor entry points - const std::vector& ancestor_entry_points() const { + const std::vector& ancestor_entry_points() const { return ancestor_entry_points_; } /// Checks if the given entry point is an ancestor - /// @param name the entry point name - /// @returns true if `name` is an ancestor entry point of this function - bool HasAncestorEntryPoint(const std::string& name) const; + /// @param sym the entry point symbol + /// @returns true if `sym` is an ancestor entry point of this function + bool HasAncestorEntryPoint(Symbol sym) const; /// @returns the function return type. type::Type* return_type() const { return return_type_; } @@ -197,13 +202,14 @@ class Function : public Castable { const std::vector> ReferencedSampledTextureVariablesImpl(bool multisampled) const; + Symbol symbol_; std::string name_; VariableList params_; type::Type* return_type_ = nullptr; BlockStatement* body_ = nullptr; std::vector referenced_module_vars_; std::vector local_referenced_module_vars_; - std::vector ancestor_entry_points_; + std::vector ancestor_entry_points_; FunctionDecorationList decorations_; }; diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc index 68f9921943..e341fc83c7 100644 --- a/src/ast/function_test.cc +++ b/src/ast/function_test.cc @@ -35,14 +35,18 @@ TEST_F(FunctionTest, Creation) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; params.push_back(create(Source{}, "var", StorageClass::kNone, &i32, false, nullptr, ast::VariableDecorationList{})); auto* var = params[0]; - Function f(Source{}, "func", params, &void_type, create(), - FunctionDecorationList{}); + Function f(Source{}, func_sym, "func", params, &void_type, + create(), FunctionDecorationList{}); + EXPECT_EQ(f.symbol(), func_sym); EXPECT_EQ(f.name(), "func"); ASSERT_EQ(f.params().size(), 1u); EXPECT_EQ(f.return_type(), &void_type); @@ -53,13 +57,16 @@ TEST_F(FunctionTest, Creation_WithSource) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; params.push_back(create(Source{}, "var", StorageClass::kNone, &i32, false, nullptr, ast::VariableDecorationList{})); - Function f(Source{Source::Location{20, 2}}, "func", params, &void_type, - create(), FunctionDecorationList{}); + Function f(Source{Source::Location{20, 2}}, func_sym, "func", params, + &void_type, create(), FunctionDecorationList{}); auto src = f.source(); EXPECT_EQ(src.range.begin.line, 20u); EXPECT_EQ(src.range.begin.column, 2u); @@ -69,9 +76,12 @@ TEST_F(FunctionTest, AddDuplicateReferencedVariables) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + Variable v(Source{}, "var", StorageClass::kInput, &i32, false, nullptr, ast::VariableDecorationList{}); - Function f(Source{}, "func", VariableList{}, &void_type, + Function f(Source{}, func_sym, "func", VariableList{}, &void_type, create(), FunctionDecorationList{}); f.add_referenced_module_variable(&v); @@ -92,6 +102,9 @@ TEST_F(FunctionTest, GetReferenceLocations) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + auto* loc1 = create(Source{}, "loc1", StorageClass::kInput, &i32, false, nullptr, ast::VariableDecorationList{ @@ -116,7 +129,7 @@ TEST_F(FunctionTest, GetReferenceLocations) { create(Builtin::kFragDepth, Source{}), }); - Function f(Source{}, "func", VariableList{}, &void_type, + Function f(Source{}, func_sym, "func", VariableList{}, &void_type, create(), FunctionDecorationList{}); f.add_referenced_module_variable(loc1); @@ -137,6 +150,9 @@ TEST_F(FunctionTest, GetReferenceBuiltins) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + auto* loc1 = create(Source{}, "loc1", StorageClass::kInput, &i32, false, nullptr, ast::VariableDecorationList{ @@ -161,7 +177,7 @@ TEST_F(FunctionTest, GetReferenceBuiltins) { create(Builtin::kFragDepth, Source{}), }); - Function f(Source{}, "func", VariableList{}, &void_type, + Function f(Source{}, func_sym, "func", VariableList{}, &void_type, create(), FunctionDecorationList{}); f.add_referenced_module_variable(loc1); @@ -180,22 +196,30 @@ TEST_F(FunctionTest, GetReferenceBuiltins) { TEST_F(FunctionTest, AddDuplicateEntryPoints) { type::Void void_type; - Function f(Source{}, "func", VariableList{}, &void_type, + + Module m; + auto func_sym = m.RegisterSymbol("func"); + auto main_sym = m.RegisterSymbol("main"); + + Function f(Source{}, func_sym, "func", VariableList{}, &void_type, create(), FunctionDecorationList{}); - f.add_ancestor_entry_point("main"); + f.add_ancestor_entry_point(main_sym); ASSERT_EQ(1u, f.ancestor_entry_points().size()); - EXPECT_EQ("main", f.ancestor_entry_points()[0]); + EXPECT_EQ(main_sym, f.ancestor_entry_points()[0]); - f.add_ancestor_entry_point("main"); + f.add_ancestor_entry_point(main_sym); ASSERT_EQ(1u, f.ancestor_entry_points().size()); - EXPECT_EQ("main", f.ancestor_entry_points()[0]); + EXPECT_EQ(main_sym, f.ancestor_entry_points()[0]); } TEST_F(FunctionTest, IsValid) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; params.push_back(create(Source{}, "var", StorageClass::kNone, &i32, false, nullptr, @@ -204,21 +228,27 @@ TEST_F(FunctionTest, IsValid) { auto* body = create(); body->append(create()); - Function f(Source{}, "func", params, &void_type, body, + Function f(Source{}, func_sym, "func", params, &void_type, body, FunctionDecorationList{}); EXPECT_TRUE(f.IsValid()); } -TEST_F(FunctionTest, IsValid_EmptyName) { +TEST_F(FunctionTest, IsValid_InvalidName) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol(""); + VariableList params; params.push_back(create(Source{}, "var", StorageClass::kNone, &i32, false, nullptr, ast::VariableDecorationList{})); - Function f(Source{}, "", params, &void_type, create(), + auto* body = create(); + body->append(create()); + + Function f(Source{}, func_sym, "", params, &void_type, body, FunctionDecorationList{}); EXPECT_FALSE(f.IsValid()); } @@ -226,13 +256,16 @@ TEST_F(FunctionTest, IsValid_EmptyName) { TEST_F(FunctionTest, IsValid_MissingReturnType) { type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; params.push_back(create(Source{}, "var", StorageClass::kNone, &i32, false, nullptr, ast::VariableDecorationList{})); - Function f(Source{}, "func", params, nullptr, create(), - FunctionDecorationList{}); + Function f(Source{}, func_sym, "func", params, nullptr, + create(), FunctionDecorationList{}); EXPECT_FALSE(f.IsValid()); } @@ -240,27 +273,33 @@ TEST_F(FunctionTest, IsValid_NullParam) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; params.push_back(create(Source{}, "var", StorageClass::kNone, &i32, false, nullptr, ast::VariableDecorationList{})); params.push_back(nullptr); - Function f(Source{}, "func", params, &void_type, create(), - FunctionDecorationList{}); + Function f(Source{}, func_sym, "func", params, &void_type, + create(), FunctionDecorationList{}); EXPECT_FALSE(f.IsValid()); } TEST_F(FunctionTest, IsValid_InvalidParam) { type::Void void_type; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; params.push_back(create(Source{}, "var", StorageClass::kNone, nullptr, false, nullptr, ast::VariableDecorationList{})); - Function f(Source{}, "func", params, &void_type, create(), - FunctionDecorationList{}); + Function f(Source{}, func_sym, "func", params, &void_type, + create(), FunctionDecorationList{}); EXPECT_FALSE(f.IsValid()); } @@ -268,6 +307,9 @@ TEST_F(FunctionTest, IsValid_NullBodyStatement) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; params.push_back(create(Source{}, "var", StorageClass::kNone, &i32, false, nullptr, @@ -277,7 +319,7 @@ TEST_F(FunctionTest, IsValid_NullBodyStatement) { body->append(create()); body->append(nullptr); - Function f(Source{}, "func", params, &void_type, body, + Function f(Source{}, func_sym, "func", params, &void_type, body, FunctionDecorationList{}); EXPECT_FALSE(f.IsValid()); @@ -287,6 +329,9 @@ TEST_F(FunctionTest, IsValid_InvalidBodyStatement) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; params.push_back(create(Source{}, "var", StorageClass::kNone, &i32, false, nullptr, @@ -296,7 +341,7 @@ TEST_F(FunctionTest, IsValid_InvalidBodyStatement) { body->append(create()); body->append(nullptr); - Function f(Source{}, "func", params, &void_type, body, + Function f(Source{}, func_sym, "func", params, &void_type, body, FunctionDecorationList{}); EXPECT_FALSE(f.IsValid()); } @@ -305,14 +350,18 @@ TEST_F(FunctionTest, ToStr) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + auto* body = create(); body->append(create()); - Function f(Source{}, "func", {}, &void_type, body, FunctionDecorationList{}); + Function f(Source{}, func_sym, "func", {}, &void_type, body, + FunctionDecorationList{}); std::ostringstream out; f.to_str(out, 2); - EXPECT_EQ(out.str(), R"( Function func -> __void + EXPECT_EQ(out.str(), R"( Function tint_symbol_1 -> __void () { Discard{} @@ -324,16 +373,19 @@ TEST_F(FunctionTest, ToStr_WithDecoration) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + auto* body = create(); body->append(create()); Function f( - Source{}, "func", {}, &void_type, body, + Source{}, func_sym, "func", {}, &void_type, body, FunctionDecorationList{create(2, 4, 6, Source{})}); std::ostringstream out; f.to_str(out, 2); - EXPECT_EQ(out.str(), R"( Function func -> __void + EXPECT_EQ(out.str(), R"( Function tint_symbol_1 -> __void WorkgroupDecoration{2 4 6} () { @@ -346,6 +398,9 @@ TEST_F(FunctionTest, ToStr_WithParams) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; params.push_back(create(Source{}, "var", StorageClass::kNone, &i32, false, nullptr, @@ -354,12 +409,12 @@ TEST_F(FunctionTest, ToStr_WithParams) { auto* body = create(); body->append(create()); - Function f(Source{}, "func", params, &void_type, body, + Function f(Source{}, func_sym, "func", params, &void_type, body, FunctionDecorationList{}); std::ostringstream out; f.to_str(out, 2); - EXPECT_EQ(out.str(), R"( Function func -> __void + EXPECT_EQ(out.str(), R"( Function tint_symbol_1 -> __void ( Variable{ var @@ -376,8 +431,11 @@ TEST_F(FunctionTest, ToStr_WithParams) { TEST_F(FunctionTest, TypeName) { type::Void void_type; - Function f(Source{}, "func", {}, &void_type, create(), - FunctionDecorationList{}); + Module m; + auto func_sym = m.RegisterSymbol("func"); + + Function f(Source{}, func_sym, "func", {}, &void_type, + create(), FunctionDecorationList{}); EXPECT_EQ(f.type_name(), "__func__void"); } @@ -386,6 +444,9 @@ TEST_F(FunctionTest, TypeName_WithParams) { type::I32 i32; type::F32 f32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; params.push_back(create(Source{}, "var1", StorageClass::kNone, &i32, false, nullptr, @@ -394,19 +455,22 @@ TEST_F(FunctionTest, TypeName_WithParams) { false, nullptr, ast::VariableDecorationList{})); - Function f(Source{}, "func", params, &void_type, create(), - FunctionDecorationList{}); + Function f(Source{}, func_sym, "func", params, &void_type, + create(), FunctionDecorationList{}); EXPECT_EQ(f.type_name(), "__func__void__i32__f32"); } TEST_F(FunctionTest, GetLastStatement) { type::Void void_type; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; auto* body = create(); auto* stmt = create(); body->append(stmt); - Function f(Source{}, "func", params, &void_type, body, + Function f(Source{}, func_sym, "func", params, &void_type, body, FunctionDecorationList{}); EXPECT_EQ(f.get_last_statement(), stmt); @@ -415,9 +479,12 @@ TEST_F(FunctionTest, GetLastStatement) { TEST_F(FunctionTest, GetLastStatement_nullptr) { type::Void void_type; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; auto* body = create(); - Function f(Source{}, "func", params, &void_type, body, + Function f(Source{}, func_sym, "func", params, &void_type, body, FunctionDecorationList{}); EXPECT_EQ(f.get_last_statement(), nullptr); @@ -425,8 +492,12 @@ TEST_F(FunctionTest, GetLastStatement_nullptr) { TEST_F(FunctionTest, WorkgroupSize_NoneSet) { type::Void void_type; - Function f(Source{}, "f", {}, &void_type, create(), - FunctionDecorationList{}); + + Module m; + auto func_sym = m.RegisterSymbol("func"); + + Function f(Source{}, func_sym, "func", {}, &void_type, + create(), FunctionDecorationList{}); uint32_t x = 0; uint32_t y = 0; uint32_t z = 0; @@ -438,7 +509,12 @@ TEST_F(FunctionTest, WorkgroupSize_NoneSet) { TEST_F(FunctionTest, WorkgroupSize) { type::Void void_type; - Function f(Source{}, "f", {}, &void_type, create(), + + Module m; + auto func_sym = m.RegisterSymbol("func"); + + Function f(Source{}, func_sym, "func", {}, &void_type, + create(), {create(2u, 4u, 6u, Source{})}); uint32_t x = 0; diff --git a/src/ast/module.cc b/src/ast/module.cc index 5dcc12bd51..2132e1f935 100644 --- a/src/ast/module.cc +++ b/src/ast/module.cc @@ -47,21 +47,23 @@ void Module::Clone(CloneContext* ctx) { for (auto* func : functions_) { ctx->mod->functions_.emplace_back(ctx->Clone(func)); } + + ctx->mod->symbol_table_ = symbol_table_; } -Function* Module::FindFunctionByName(const std::string& name) const { +Function* Module::FindFunctionBySymbol(Symbol sym) const { for (auto* func : functions_) { - if (func->name() == name) { + if (func->symbol() == sym) { return func; } } return nullptr; } -Function* Module::FindFunctionByNameAndStage(const std::string& name, - PipelineStage stage) const { +Function* Module::FindFunctionBySymbolAndStage(Symbol sym, + PipelineStage stage) const { for (auto* func : functions_) { - if (func->name() == name && func->pipeline_stage() == stage) { + if (func->symbol() == sym && func->pipeline_stage() == stage) { return func; } } @@ -81,6 +83,10 @@ Symbol Module::RegisterSymbol(const std::string& name) { return symbol_table_.Register(name); } +Symbol Module::GetSymbol(const std::string& name) const { + return symbol_table_.GetSymbol(name); +} + std::string Module::SymbolToName(const Symbol sym) const { return symbol_table_.NameFor(sym); } diff --git a/src/ast/module.h b/src/ast/module.h index b274be110f..1facd79d8b 100644 --- a/src/ast/module.h +++ b/src/ast/module.h @@ -89,15 +89,14 @@ class Module { /// @returns the modules functions const FunctionList& functions() const { return functions_; } /// Returns the function with the given name - /// @param name the name to search for + /// @param sym the function symbol to search for /// @returns the associated function or nullptr if none exists - Function* FindFunctionByName(const std::string& name) const; + Function* FindFunctionBySymbol(Symbol sym) const; /// Returns the function with the given name - /// @param name the name to search for + /// @param sym the function symbol to search for /// @param stage the pipeline stage /// @returns the associated function or nullptr if none exists - Function* FindFunctionByNameAndStage(const std::string& name, - PipelineStage stage) const; + Function* FindFunctionBySymbolAndStage(Symbol sym, PipelineStage stage) const; /// @param stage the pipeline stage /// @returns true if the module contains an entrypoint function with the given /// stage @@ -169,6 +168,11 @@ class Module { /// previously generated symbol will be returned. Symbol RegisterSymbol(const std::string& name); + /// Returns the symbol for `name` + /// @param name the name to lookup + /// @returns the symbol for name or symbol::kInvalid + Symbol GetSymbol(const std::string& name) const; + /// Returns the `name` for `sym` /// @param sym the symbol to retrieve the name for /// @returns the use provided `name` for the symbol or "" if not found diff --git a/src/ast/module_test.cc b/src/ast/module_test.cc index 3303235729..6914eccea9 100644 --- a/src/ast/module_test.cc +++ b/src/ast/module_test.cc @@ -48,16 +48,17 @@ TEST_F(ModuleTest, LookupFunction) { type::F32 f32; Module m; + auto func_sym = m.RegisterSymbol("main"); auto* func = - create(Source{}, "main", VariableList{}, &f32, + create(Source{}, func_sym, "main", VariableList{}, &f32, create(), ast::FunctionDecorationList{}); m.AddFunction(func); - EXPECT_EQ(func, m.FindFunctionByName("main")); + EXPECT_EQ(func, m.FindFunctionBySymbol(func_sym)); } TEST_F(ModuleTest, LookupFunctionMissing) { Module m; - EXPECT_EQ(nullptr, m.FindFunctionByName("Missing")); + EXPECT_EQ(nullptr, m.FindFunctionBySymbol(m.RegisterSymbol("Missing"))); } TEST_F(ModuleTest, IsValid_Empty) { @@ -127,11 +128,12 @@ TEST_F(ModuleTest, IsValid_Struct_EmptyName) { TEST_F(ModuleTest, IsValid_Function) { type::F32 f32; - auto* func = - create(Source{}, "main", VariableList(), &f32, - create(), ast::FunctionDecorationList{}); Module m; + + auto* func = create(Source{}, m.RegisterSymbol("main"), "main", + VariableList(), &f32, create(), + ast::FunctionDecorationList{}); m.AddFunction(func); EXPECT_TRUE(m.IsValid()); } @@ -144,10 +146,13 @@ TEST_F(ModuleTest, IsValid_Null_Function) { TEST_F(ModuleTest, IsValid_Invalid_Function) { VariableList p; - auto* func = create(Source{}, "", p, nullptr, nullptr, - ast::FunctionDecorationList{}); Module m; + + auto* func = + create(Source{}, m.RegisterSymbol("main"), "main", p, nullptr, + nullptr, ast::FunctionDecorationList{}); + m.AddFunction(func); EXPECT_FALSE(m.IsValid()); } diff --git a/src/inspector/inspector.cc b/src/inspector/inspector.cc index 1bbbe8bc64..7ef231184b 100644 --- a/src/inspector/inspector.cc +++ b/src/inspector/inspector.cc @@ -267,7 +267,7 @@ std::vector Inspector::GetMultisampledTextureResourceBindings( } ast::Function* Inspector::FindEntryPointByName(const std::string& name) { - auto* func = module_.FindFunctionByName(name); + auto* func = module_.FindFunctionBySymbol(module_.GetSymbol(name)); if (!func) { error_ += name + " was not found!"; return nullptr; diff --git a/src/inspector/inspector_test.cc b/src/inspector/inspector_test.cc index 46ac6e456b..fbf24e5dcc 100644 --- a/src/inspector/inspector_test.cc +++ b/src/inspector/inspector_test.cc @@ -83,8 +83,9 @@ class InspectorHelper { ast::FunctionDecorationList decorations = {}) { auto* body = create(); body->append(create(Source{})); - return create(Source{}, name, ast::VariableList(), - void_type(), body, decorations); + return create(Source{}, mod()->RegisterSymbol(name), name, + ast::VariableList(), void_type(), body, + decorations); } /// Generates a function that calls another @@ -102,8 +103,9 @@ class InspectorHelper { create(ident_expr, ast::ExpressionList()); body->append(create(call_expr)); body->append(create(Source{})); - return create(Source{}, caller, ast::VariableList(), - void_type(), body, decorations); + return create(Source{}, mod()->RegisterSymbol(caller), + caller, ast::VariableList(), void_type(), body, + decorations); } /// Add In/Out variables to the global variables @@ -154,8 +156,9 @@ class InspectorHelper { create(in))); } body->append(create(Source{})); - return create(Source{}, name, ast::VariableList(), - void_type(), body, decorations); + return create(Source{}, mod()->RegisterSymbol(name), name, + ast::VariableList(), void_type(), body, + decorations); } /// Generates a function that references in/out variables and calls another @@ -184,8 +187,9 @@ class InspectorHelper { create(ident_expr, ast::ExpressionList()); body->append(create(call_expr)); body->append(create(Source{})); - return create(Source{}, caller, ast::VariableList(), - void_type(), body, decorations); + return create(Source{}, mod()->RegisterSymbol(caller), + caller, ast::VariableList(), void_type(), body, + decorations); } /// Add a Constant ID to the global variables. @@ -445,9 +449,9 @@ class InspectorHelper { } body->append(create(Source{})); - return create(Source{}, func_name, ast::VariableList(), - void_type(), body, - ast::FunctionDecorationList{}); + return create(Source{}, mod()->RegisterSymbol(func_name), + func_name, ast::VariableList(), void_type(), + body, ast::FunctionDecorationList{}); } /// Adds a regular sampler variable to the module @@ -587,8 +591,9 @@ class InspectorHelper { create("sampler_result"), call_expr)); body->append(create(Source{})); - return create(Source{}, func_name, ast::VariableList(), - void_type(), body, decorations); + return create(Source{}, mod()->RegisterSymbol(func_name), + func_name, ast::VariableList(), void_type(), + body, decorations); } /// Generates a function that references a specific sampler variable @@ -634,8 +639,9 @@ class InspectorHelper { create("sampler_result"), call_expr)); body->append(create(Source{})); - return create(Source{}, func_name, ast::VariableList(), - void_type(), body, decorations); + return create(Source{}, mod()->RegisterSymbol(func_name), + func_name, ast::VariableList(), void_type(), + body, decorations); } /// Generates a function that references a specific comparison sampler @@ -682,8 +688,9 @@ class InspectorHelper { create("sampler_result"), call_expr)); body->append(create(Source{})); - return create(Source{}, func_name, ast::VariableList(), - void_type(), body, decorations); + return create(Source{}, mod()->RegisterSymbol(func_name), + func_name, ast::VariableList(), void_type(), + body, decorations); } /// Gets an appropriate type for the data in a given texture type. @@ -1513,7 +1520,8 @@ TEST_F(InspectorGetUniformBufferResourceBindingsTest, MultipleUniformBuffers) { body->append(create(Source{})); ast::Function* func = create( - Source{}, "ep_func", ast::VariableList(), void_type(), body, + Source{}, mod()->RegisterSymbol("ep_func"), "ep_func", + ast::VariableList(), void_type(), body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -1659,7 +1667,8 @@ TEST_F(InspectorGetStorageBufferResourceBindingsTest, MultipleStorageBuffers) { body->append(create(Source{})); ast::Function* func = create( - Source{}, "ep_func", ast::VariableList(), void_type(), body, + Source{}, mod()->RegisterSymbol("ep_func"), "ep_func", + ast::VariableList(), void_type(), body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -1832,7 +1841,8 @@ TEST_F(InspectorGetReadOnlyStorageBufferResourceBindingsTest, body->append(create(Source{})); ast::Function* func = create( - Source{}, "ep_func", ast::VariableList(), void_type(), body, + Source{}, mod()->RegisterSymbol("ep_func"), "ep_func", + ast::VariableList(), void_type(), body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index ec024bb9b4..c8affa7f28 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -761,9 +761,10 @@ bool FunctionEmitter::Emit() { } auto* body = statements_stack_[0].statements_; - ast_module_.AddFunction(create( - decl.source, decl.name, std::move(decl.params), decl.return_type, body, - std::move(decl.decorations))); + ast_module_.AddFunction( + create(decl.source, ast_module_.RegisterSymbol(decl.name), + decl.name, std::move(decl.params), decl.return_type, + body, std::move(decl.decorations))); // Maintain the invariant by repopulating the one and only element. statements_stack_.clear(); diff --git a/src/reader/spirv/function_call_test.cc b/src/reader/spirv/function_call_test.cc index a30ec374b8..fabf8191f7 100644 --- a/src/reader/spirv/function_call_test.cc +++ b/src/reader/spirv/function_call_test.cc @@ -46,14 +46,16 @@ TEST_F(SpvParserTest, EmitStatement_VoidCallNoParams) { OpFunctionEnd )")); ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error(); - const auto module_ast_str = p->module().to_str(); + const auto module_ast_str = p->get_module().to_str(); EXPECT_THAT(module_ast_str, Eq(R"(Module{ - Function x_50 -> __void + Function )" + p->get_module().GetSymbol("x_50").to_str() + + R"( -> __void () { Return{} } - Function x_100 -> __void + Function )" + p->get_module().GetSymbol("x_100").to_str() + + R"( -> __void () { Call[not set]{ @@ -214,9 +216,10 @@ TEST_F(SpvParserTest, EmitStatement_CallWithParams) { )")); ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error(); EXPECT_TRUE(p->error().empty()); - const auto module_ast_str = p->module().to_str(); + const auto module_ast_str = p->get_module().to_str(); EXPECT_THAT(module_ast_str, HasSubstr(R"(Module{ - Function x_50 -> __u32 + Function )" + p->get_module().GetSymbol("x_50").to_str() + + R"( -> __u32 ( VariableConst{ x_51 @@ -240,7 +243,8 @@ TEST_F(SpvParserTest, EmitStatement_CallWithParams) { } } } - Function x_100 -> __void + Function )" + p->get_module().GetSymbol("x_100").to_str() + + R"( -> __void () { VariableDeclStatement{ diff --git a/src/reader/spirv/function_decl_test.cc b/src/reader/spirv/function_decl_test.cc index 2223893f3f..398f8fa8a1 100644 --- a/src/reader/spirv/function_decl_test.cc +++ b/src/reader/spirv/function_decl_test.cc @@ -59,9 +59,10 @@ TEST_F(SpvParserTest, Emit_VoidFunctionWithoutParams) { ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); EXPECT_TRUE(fe.Emit()); - auto got = p->module().to_str(); - auto* expect = R"(Module{ - Function x_100 -> __void + auto got = p->get_module().to_str(); + auto expect = R"(Module{ + Function )" + p->get_module().GetSymbol("x_100").to_str() + + R"( -> __void () { Return{} @@ -83,9 +84,10 @@ TEST_F(SpvParserTest, Emit_NonVoidResultType) { FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); EXPECT_TRUE(fe.Emit()); - auto got = p->module().to_str(); - auto* expect = R"(Module{ - Function x_100 -> __f32 + auto got = p->get_module().to_str(); + auto expect = R"(Module{ + Function )" + p->get_module().GetSymbol("x_100").to_str() + + R"( -> __f32 () { Return{ @@ -115,9 +117,10 @@ TEST_F(SpvParserTest, Emit_MixedParamTypes) { FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); EXPECT_TRUE(fe.Emit()); - auto got = p->module().to_str(); - auto* expect = R"(Module{ - Function x_100 -> __void + auto got = p->get_module().to_str(); + auto expect = R"(Module{ + Function )" + p->get_module().GetSymbol("x_100").to_str() + + R"( -> __void ( VariableConst{ a @@ -159,9 +162,10 @@ TEST_F(SpvParserTest, Emit_GenerateParamNames) { FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); EXPECT_TRUE(fe.Emit()); - auto got = p->module().to_str(); - auto* expect = R"(Module{ - Function x_100 -> __void + auto got = p->get_module().to_str(); + auto expect = R"(Module{ + Function )" + p->get_module().GetSymbol("x_100").to_str() + + R"( -> __void ( VariableConst{ x_14 diff --git a/src/reader/spirv/parser_impl_function_decl_test.cc b/src/reader/spirv/parser_impl_function_decl_test.cc index f322de25a9..dab5f3a0e3 100644 --- a/src/reader/spirv/parser_impl_function_decl_test.cc +++ b/src/reader/spirv/parser_impl_function_decl_test.cc @@ -53,7 +53,7 @@ TEST_F(SpvParserTest, EmitFunctions_NoFunctions) { auto p = parser(test::Assemble(CommonTypes())); EXPECT_TRUE(p->BuildAndParseInternalModule()); EXPECT_TRUE(p->error().empty()); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, Not(HasSubstr("Function{"))); } @@ -64,7 +64,7 @@ TEST_F(SpvParserTest, EmitFunctions_FunctionWithoutBody) { )")); EXPECT_TRUE(p->BuildAndParseInternalModule()); EXPECT_TRUE(p->error().empty()); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, Not(HasSubstr("Function{"))); } @@ -79,9 +79,10 @@ OpFunctionEnd)"; auto p = parser(test::Assemble(input)); ASSERT_TRUE(p->BuildAndParseInternalModule()); ASSERT_TRUE(p->error().empty()) << p->error(); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, HasSubstr(R"( - Function main -> __void + Function )" + p->get_module().GetSymbol("main").to_str() + + R"( -> __void StageDecoration{vertex} () {)")); @@ -98,9 +99,10 @@ OpFunctionEnd)"; auto p = parser(test::Assemble(input)); ASSERT_TRUE(p->BuildAndParseInternalModule()); ASSERT_TRUE(p->error().empty()) << p->error(); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, HasSubstr(R"( - Function main -> __void + Function )" + p->get_module().GetSymbol("main").to_str() + + R"( -> __void StageDecoration{fragment} () {)")); @@ -117,9 +119,10 @@ OpFunctionEnd)"; auto p = parser(test::Assemble(input)); ASSERT_TRUE(p->BuildAndParseInternalModule()); ASSERT_TRUE(p->error().empty()) << p->error(); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, HasSubstr(R"( - Function main -> __void + Function )" + p->get_module().GetSymbol("main").to_str() + + R"( -> __void StageDecoration{compute} () {)")); @@ -138,14 +141,16 @@ OpFunctionEnd)"; auto p = parser(test::Assemble(input)); ASSERT_TRUE(p->BuildAndParseInternalModule()); ASSERT_TRUE(p->error().empty()) << p->error(); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, HasSubstr(R"( - Function frag_main -> __void + Function )" + p->get_module().GetSymbol("frag_main").to_str() + + R"( -> __void StageDecoration{fragment} () {)")); EXPECT_THAT(module_ast, HasSubstr(R"( - Function comp_main -> __void + Function )" + p->get_module().GetSymbol("comp_main").to_str() + + R"( -> __void StageDecoration{compute} () {)")); @@ -160,9 +165,10 @@ TEST_F(SpvParserTest, EmitFunctions_VoidFunctionWithoutParams) { )")); EXPECT_TRUE(p->BuildAndParseInternalModule()); EXPECT_TRUE(p->error().empty()); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, HasSubstr(R"( - Function main -> __void + Function )" + p->get_module().GetSymbol("main").to_str() + + R"( -> __void () {)")); } @@ -193,9 +199,10 @@ TEST_F(SpvParserTest, EmitFunctions_CalleePrecedesCaller) { )")); EXPECT_TRUE(p->BuildAndParseInternalModule()); EXPECT_TRUE(p->error().empty()); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, HasSubstr(R"( - Function leaf -> __u32 + Function )" + p->get_module().GetSymbol("leaf").to_str() + + R"( -> __u32 () { Return{ @@ -204,7 +211,8 @@ TEST_F(SpvParserTest, EmitFunctions_CalleePrecedesCaller) { } } } - Function branch -> __u32 + Function )" + p->get_module().GetSymbol("branch").to_str() + + R"( -> __u32 () { VariableDeclStatement{ @@ -227,7 +235,8 @@ TEST_F(SpvParserTest, EmitFunctions_CalleePrecedesCaller) { } } } - Function root -> __void + Function )" + p->get_module().GetSymbol("root").to_str() + + R"( -> __void () { VariableDeclStatement{ @@ -260,9 +269,10 @@ TEST_F(SpvParserTest, EmitFunctions_NonVoidResultType) { )")); EXPECT_TRUE(p->BuildAndParseInternalModule()); EXPECT_TRUE(p->error().empty()); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, HasSubstr(R"( - Function ret_float -> __f32 + Function )" + p->get_module().GetSymbol("ret_float").to_str() + + R"( -> __f32 () { Return{ @@ -289,9 +299,10 @@ TEST_F(SpvParserTest, EmitFunctions_MixedParamTypes) { )")); EXPECT_TRUE(p->BuildAndParseInternalModule()); EXPECT_TRUE(p->error().empty()); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, HasSubstr(R"( - Function mixed_params -> __void + Function )" + p->get_module().GetSymbol("mixed_params").to_str() + + R"( -> __void ( VariableConst{ a @@ -328,9 +339,10 @@ TEST_F(SpvParserTest, EmitFunctions_GenerateParamNames) { )")); EXPECT_TRUE(p->BuildAndParseInternalModule()); EXPECT_TRUE(p->error().empty()); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, HasSubstr(R"( - Function mixed_params -> __void + Function )" + p->get_module().GetSymbol("mixed_params").to_str() + + R"( -> __void ( VariableConst{ x_14 diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc index 67a4651261..04ff79bb1c 100644 --- a/src/reader/wgsl/parser_impl.cc +++ b/src/reader/wgsl/parser_impl.cc @@ -1280,9 +1280,9 @@ Maybe ParserImpl::function_decl(ast::DecorationList& decos) { if (errored) return Failure::kErrored; - return create(header->source, header->name, header->params, - header->return_type, body.value, - func_decos.value); + return create( + header->source, module_.RegisterSymbol(header->name), header->name, + header->params, header->return_type, body.value, func_decos.value); } // function_type_decl diff --git a/src/symbol_table.cc b/src/symbol_table.cc index 13fe13fac4..8998ee2cf8 100644 --- a/src/symbol_table.cc +++ b/src/symbol_table.cc @@ -18,10 +18,14 @@ namespace tint { SymbolTable::SymbolTable() = default; +SymbolTable::SymbolTable(const SymbolTable&) = default; + SymbolTable::SymbolTable(SymbolTable&&) = default; SymbolTable::~SymbolTable() = default; +SymbolTable& SymbolTable::operator=(const SymbolTable& other) = default; + SymbolTable& SymbolTable::operator=(SymbolTable&&) = default; Symbol SymbolTable::Register(const std::string& name) { @@ -41,6 +45,11 @@ Symbol SymbolTable::Register(const std::string& name) { return sym; } +Symbol SymbolTable::GetSymbol(const std::string& name) const { + auto it = name_to_symbol_.find(name); + return it != name_to_symbol_.end() ? it->second : Symbol(); +} + std::string SymbolTable::NameFor(const Symbol symbol) const { auto it = symbol_to_name_.find(symbol.value()); if (it == symbol_to_name_.end()) diff --git a/src/symbol_table.h b/src/symbol_table.h index e3351e16a5..1c085988b5 100644 --- a/src/symbol_table.h +++ b/src/symbol_table.h @@ -27,11 +27,17 @@ class SymbolTable { public: /// Constructor SymbolTable(); + /// Copy constructor + SymbolTable(const SymbolTable&); /// Move Constructor SymbolTable(SymbolTable&&); /// Destructor ~SymbolTable(); + /// Copy assignment + /// @param other the symbol table to copy + /// @returns the new symbol table + SymbolTable& operator=(const SymbolTable& other); /// Move assignment /// @param other the symbol table to move /// @returns the symbol table @@ -42,6 +48,11 @@ class SymbolTable { /// @returns the symbol representing the given name Symbol Register(const std::string& name); + /// Returns the symbol for the given `name` + /// @param name the name to lookup + /// @returns the symbol for the name or symbol::kInvalid if not found. + Symbol GetSymbol(const std::string& name) const; + /// Returns the name for the given symbol /// @param symbol the symbol to retrieve the name for /// @returns the symbol name or "" if not found diff --git a/src/transform/bound_array_accessors_test.cc b/src/transform/bound_array_accessors_test.cc index 6e1f1717b0..24a7242c49 100644 --- a/src/transform/bound_array_accessors_test.cc +++ b/src/transform/bound_array_accessors_test.cc @@ -51,7 +51,7 @@ namespace { template T* FindVariable(ast::Module* mod, std::string name) { - if (auto* func = mod->FindFunctionByName("func")) { + if (auto* func = mod->FindFunctionBySymbol(mod->RegisterSymbol("func"))) { for (auto* stmt : *func->body()) { if (auto* decl = stmt->As()) { if (auto* var = decl->variable()) { @@ -92,9 +92,9 @@ class BoundArrayAccessorsTest : public testing::Test { struct ModuleBuilder : public ast::BuilderWithModule { ModuleBuilder() : body_(create()) { - mod->AddFunction(create(Source{}, "func", - ast::VariableList{}, ty.void_, body_, - ast::FunctionDecorationList{})); + mod->AddFunction(create( + Source{}, mod->RegisterSymbol("func"), "func", ast::VariableList{}, + ty.void_, body_, ast::FunctionDecorationList{})); } ast::Module Module() { diff --git a/src/transform/emit_vertex_point_size_test.cc b/src/transform/emit_vertex_point_size_test.cc index 7a08e5fcfe..57a904713b 100644 --- a/src/transform/emit_vertex_point_size_test.cc +++ b/src/transform/emit_vertex_point_size_test.cc @@ -58,23 +58,26 @@ TEST_F(EmitVertexPointSizeTest, VertexStageBasic) { Var("builtin_assignments_should_happen_before_this", tint::ast::StorageClass::kFunction, ty.f32))); - mod->AddFunction( - create(Source{}, "non_entry_a", ast::VariableList{}, - ty.void_, create(Source{}), - ast::FunctionDecorationList{})); + auto a_sym = mod->RegisterSymbol("non_entry_a"); + mod->AddFunction(create( + Source{}, a_sym, "non_entry_a", ast::VariableList{}, ty.void_, + create(Source{}), + ast::FunctionDecorationList{})); + auto entry_sym = mod->RegisterSymbol("entry"); auto* entry = create( - Source{}, "entry", ast::VariableList{}, ty.void_, block, + Source{}, entry_sym, "entry", ast::VariableList{}, ty.void_, block, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); mod->AddFunction(entry); - mod->AddFunction( - create(Source{}, "non_entry_b", ast::VariableList{}, - ty.void_, create(Source{}), - ast::FunctionDecorationList{})); + auto b_sym = mod->RegisterSymbol("non_entry_b"); + mod->AddFunction(create( + Source{}, b_sym, "non_entry_b", ast::VariableList{}, ty.void_, + create(Source{}), + ast::FunctionDecorationList{})); } }; @@ -82,7 +85,7 @@ TEST_F(EmitVertexPointSizeTest, VertexStageBasic) { ASSERT_FALSE(result.diagnostics.contains_errors()) << diag::Formatter().format(result.diagnostics); - auto* expected = R"(Module{ + auto expected = R"(Module{ Variable{ Decorations{ BuiltinDecoration{pointsize} @@ -91,11 +94,13 @@ TEST_F(EmitVertexPointSizeTest, VertexStageBasic) { out __f32 } - Function non_entry_a -> __void + Function )" + result.module.RegisterSymbol("non_entry_a").to_str() + + R"( -> __void () { } - Function entry -> __void + Function )" + result.module.RegisterSymbol("entry").to_str() + + R"( -> __void StageDecoration{vertex} () { @@ -111,7 +116,8 @@ TEST_F(EmitVertexPointSizeTest, VertexStageBasic) { } } } - Function non_entry_b -> __void + Function )" + result.module.RegisterSymbol("non_entry_b").to_str() + + R"( -> __void () { } @@ -123,23 +129,26 @@ TEST_F(EmitVertexPointSizeTest, VertexStageBasic) { TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) { struct Builder : ModuleBuilder { void Build() override { - mod->AddFunction( - create(Source{}, "non_entry_a", ast::VariableList{}, - ty.void_, create(Source{}), - ast::FunctionDecorationList{})); + auto a_sym = mod->RegisterSymbol("non_entry_a"); + mod->AddFunction(create( + Source{}, a_sym, "non_entry_a", ast::VariableList{}, ty.void_, + create(Source{}), + ast::FunctionDecorationList{})); - mod->AddFunction( - create(Source{}, "entry", ast::VariableList{}, - ty.void_, create(Source{}), - ast::FunctionDecorationList{ - create( - ast::PipelineStage::kVertex, Source{}), - })); + auto entry_sym = mod->RegisterSymbol("entry"); + mod->AddFunction(create( + Source{}, entry_sym, "entry", ast::VariableList{}, ty.void_, + create(Source{}), + ast::FunctionDecorationList{ + create(ast::PipelineStage::kVertex, + Source{}), + })); - mod->AddFunction( - create(Source{}, "non_entry_b", ast::VariableList{}, - ty.void_, create(Source{}), - ast::FunctionDecorationList{})); + auto b_sym = mod->RegisterSymbol("non_entry_b"); + mod->AddFunction(create( + Source{}, b_sym, "non_entry_b", ast::VariableList{}, ty.void_, + create(Source{}), + ast::FunctionDecorationList{})); } }; @@ -147,7 +156,7 @@ TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) { ASSERT_FALSE(result.diagnostics.contains_errors()) << diag::Formatter().format(result.diagnostics); - auto* expected = R"(Module{ + auto expected = R"(Module{ Variable{ Decorations{ BuiltinDecoration{pointsize} @@ -156,11 +165,13 @@ TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) { out __f32 } - Function non_entry_a -> __void + Function )" + result.module.RegisterSymbol("non_entry_a").to_str() + + R"( -> __void () { } - Function entry -> __void + Function )" + result.module.RegisterSymbol("entry").to_str() + + R"( -> __void StageDecoration{vertex} () { @@ -169,7 +180,8 @@ TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) { ScalarConstructor[__f32]{1.000000} } } - Function non_entry_b -> __void + Function )" + result.module.RegisterSymbol("non_entry_b").to_str() + + R"( -> __void () { } @@ -181,8 +193,9 @@ TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) { TEST_F(EmitVertexPointSizeTest, NonVertexStage) { struct Builder : ModuleBuilder { void Build() override { + auto frag_sym = mod->RegisterSymbol("fragment_entry"); auto* fragment_entry = create( - Source{}, "fragment_entry", ast::VariableList{}, ty.void_, + Source{}, frag_sym, "fragment_entry", ast::VariableList{}, ty.void_, create(Source{}), ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, @@ -190,13 +203,14 @@ TEST_F(EmitVertexPointSizeTest, NonVertexStage) { }); mod->AddFunction(fragment_entry); - auto* compute_entry = - create(Source{}, "compute_entry", ast::VariableList{}, - ty.void_, create(Source{}), - ast::FunctionDecorationList{ - create( - ast::PipelineStage::kCompute, Source{}), - }); + auto comp_sym = mod->RegisterSymbol("compute_entry"); + auto* compute_entry = create( + Source{}, comp_sym, "compute_entry", ast::VariableList{}, ty.void_, + create(Source{}), + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute, + Source{}), + }); mod->AddFunction(compute_entry); } }; @@ -205,13 +219,15 @@ TEST_F(EmitVertexPointSizeTest, NonVertexStage) { ASSERT_FALSE(result.diagnostics.contains_errors()) << diag::Formatter().format(result.diagnostics); - auto* expected = R"(Module{ - Function fragment_entry -> __void + auto expected = R"(Module{ + Function )" + result.module.RegisterSymbol("fragment_entry").to_str() + + R"( -> __void StageDecoration{fragment} () { } - Function compute_entry -> __void + Function )" + result.module.RegisterSymbol("compute_entry").to_str() + + R"( -> __void StageDecoration{compute} () { diff --git a/src/transform/first_index_offset.cc b/src/transform/first_index_offset.cc index 8738b7e790..b398935922 100644 --- a/src/transform/first_index_offset.cc +++ b/src/transform/first_index_offset.cc @@ -169,9 +169,9 @@ Transform::Output FirstIndexOffset::Run(ast::Module* in) { body->append(ctx.Clone(s)); } return ctx.mod->create( - ctx.Clone(func->source()), func->name(), ctx.Clone(func->params()), - ctx.Clone(func->return_type()), ctx.Clone(body), - ctx.Clone(func->decorations())); + ctx.Clone(func->source()), func->symbol(), func->name(), + ctx.Clone(func->params()), ctx.Clone(func->return_type()), + ctx.Clone(body), ctx.Clone(func->decorations())); }); in->Clone(&ctx); diff --git a/src/transform/first_index_offset_test.cc b/src/transform/first_index_offset_test.cc index c3b857c420..7a275577e8 100644 --- a/src/transform/first_index_offset_test.cc +++ b/src/transform/first_index_offset_test.cc @@ -58,9 +58,9 @@ struct ModuleBuilder : public ast::BuilderWithModule { ast::Function* AddFunction(const std::string& name, ast::VariableList params = {}) { - auto* func = create(Source{}, name, std::move(params), - ty.u32, create(), - ast::FunctionDecorationList()); + auto* func = create( + Source{}, mod->RegisterSymbol(name), name, std::move(params), ty.u32, + create(), ast::FunctionDecorationList()); mod->AddFunction(func); return func; } @@ -154,7 +154,7 @@ TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex) { uniform __struct_TintFirstIndexOffsetData } - Function test -> __u32 + Function tint_symbol_1 -> __u32 () { VariableDeclStatement{ @@ -229,7 +229,7 @@ TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex) { uniform __struct_TintFirstIndexOffsetData } - Function test -> __u32 + Function tint_symbol_1 -> __u32 () { VariableDeclStatement{ @@ -317,7 +317,7 @@ TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex) { uniform __struct_TintFirstIndexOffsetData } - Function test -> __u32 + Function tint_symbol_1 -> __u32 () { Return{ @@ -389,7 +389,7 @@ TEST_F(FirstIndexOffsetTest, NestedCalls) { uniform __struct_TintFirstIndexOffsetData } - Function func1 -> __u32 + Function tint_symbol_1 -> __u32 () { VariableDeclStatement{ @@ -415,7 +415,7 @@ TEST_F(FirstIndexOffsetTest, NestedCalls) { } } } - Function func2 -> __u32 + Function tint_symbol_2 -> __u32 () { Return{ diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc index d35f3d870d..6143d15601 100644 --- a/src/transform/vertex_pulling.cc +++ b/src/transform/vertex_pulling.cc @@ -84,8 +84,8 @@ Transform::Output VertexPulling::Run(ast::Module* in) { } // Find entry point - auto* func = mod->FindFunctionByNameAndStage(cfg.entry_point_name, - ast::PipelineStage::kVertex); + auto* func = mod->FindFunctionBySymbolAndStage( + mod->GetSymbol(cfg.entry_point_name), ast::PipelineStage::kVertex); if (func == nullptr) { diag::Diagnostic err; err.severity = diag::Severity::Error; @@ -94,9 +94,6 @@ Transform::Output VertexPulling::Run(ast::Module* in) { return out; } - // Save the vertex function - auto* vertex_func = mod->FindFunctionByName(func->name()); - // TODO(idanr): Need to check shader locations in descriptor cover all // attributes @@ -108,7 +105,7 @@ Transform::Output VertexPulling::Run(ast::Module* in) { state.FindOrInsertInstanceIndexIfUsed(); state.ConvertVertexInputVariablesToPrivate(); state.AddVertexStorageBuffers(); - state.AddVertexPullingPreamble(vertex_func); + state.AddVertexPullingPreamble(func); return out; } diff --git a/src/transform/vertex_pulling_test.cc b/src/transform/vertex_pulling_test.cc index 8f2a177e1c..c0279201dd 100644 --- a/src/transform/vertex_pulling_test.cc +++ b/src/transform/vertex_pulling_test.cc @@ -47,8 +47,8 @@ class VertexPullingHelper { // Create basic module with an entry point and vertex function void InitBasicModule() { auto* func = create( - Source{}, "main", ast::VariableList{}, mod_->create(), - create(), + Source{}, mod_->RegisterSymbol("main"), "main", ast::VariableList{}, + mod_->create(), create(), ast::FunctionDecorationList{create( ast::PipelineStage::kVertex, Source{})}); mod()->AddFunction(func); @@ -134,8 +134,8 @@ TEST_F(VertexPullingTest, Error_InvalidEntryPoint) { TEST_F(VertexPullingTest, Error_EntryPointWrongStage) { auto* func = create( - Source{}, "main", ast::VariableList{}, mod()->create(), - create(), + Source{}, mod()->RegisterSymbol("main"), "main", ast::VariableList{}, + mod()->create(), create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -152,7 +152,8 @@ TEST_F(VertexPullingTest, BasicModule) { InitBasicModule(); InitTransform({}); auto result = manager()->Run(mod()); - ASSERT_FALSE(result.diagnostics.contains_errors()); + ASSERT_FALSE(result.diagnostics.contains_errors()) + << diag::Formatter().format(result.diagnostics); } TEST_F(VertexPullingTest, OneAttribute) { @@ -164,7 +165,8 @@ TEST_F(VertexPullingTest, OneAttribute) { InitTransform({{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}}); auto result = manager()->Run(mod()); - ASSERT_FALSE(result.diagnostics.contains_errors()); + ASSERT_FALSE(result.diagnostics.contains_errors()) + << diag::Formatter().format(result.diagnostics); EXPECT_EQ(R"(Module{ TintVertexData Struct{ @@ -193,7 +195,8 @@ TEST_F(VertexPullingTest, OneAttribute) { storage_buffer __struct_TintVertexData } - Function main -> __void + Function )" + result.module.GetSymbol("main").to_str() + + R"( -> __void StageDecoration{vertex} () { @@ -250,7 +253,8 @@ TEST_F(VertexPullingTest, OneInstancedAttribute) { {{{4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 0}}}}}); auto result = manager()->Run(mod()); - ASSERT_FALSE(result.diagnostics.contains_errors()); + ASSERT_FALSE(result.diagnostics.contains_errors()) + << diag::Formatter().format(result.diagnostics); EXPECT_EQ(R"(Module{ TintVertexData Struct{ @@ -279,7 +283,8 @@ TEST_F(VertexPullingTest, OneInstancedAttribute) { storage_buffer __struct_TintVertexData } - Function main -> __void + Function )" + result.module.GetSymbol("main").to_str() + + R"( -> __void StageDecoration{vertex} () { @@ -336,7 +341,8 @@ TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet) { transform()->SetPullingBufferBindingSet(5); auto result = manager()->Run(mod()); - ASSERT_FALSE(result.diagnostics.contains_errors()); + ASSERT_FALSE(result.diagnostics.contains_errors()) + << diag::Formatter().format(result.diagnostics); EXPECT_EQ(R"(Module{ TintVertexData Struct{ @@ -365,7 +371,8 @@ TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet) { storage_buffer __struct_TintVertexData } - Function main -> __void + Function )" + result.module.GetSymbol("main").to_str() + + R"( -> __void StageDecoration{vertex} () { @@ -451,7 +458,8 @@ TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) { {4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 1}}}}}); auto result = manager()->Run(mod()); - ASSERT_FALSE(result.diagnostics.contains_errors()); + ASSERT_FALSE(result.diagnostics.contains_errors()) + << diag::Formatter().format(result.diagnostics); EXPECT_EQ(R"(Module{ TintVertexData Struct{ @@ -502,7 +510,8 @@ TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) { storage_buffer __struct_TintVertexData } - Function main -> __void + Function )" + result.module.GetSymbol("main").to_str() + + R"( -> __void StageDecoration{vertex} () { @@ -592,7 +601,8 @@ TEST_F(VertexPullingTest, TwoAttributesSameBuffer) { {{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}}); auto result = manager()->Run(mod()); - ASSERT_FALSE(result.diagnostics.contains_errors()); + ASSERT_FALSE(result.diagnostics.contains_errors()) + << diag::Formatter().format(result.diagnostics); EXPECT_EQ(R"(Module{ TintVertexData Struct{ @@ -626,7 +636,8 @@ TEST_F(VertexPullingTest, TwoAttributesSameBuffer) { storage_buffer __struct_TintVertexData } - Function main -> __void + Function )" + result.module.GetSymbol("main").to_str() + + R"( -> __void StageDecoration{vertex} () { @@ -778,7 +789,8 @@ TEST_F(VertexPullingTest, FloatVectorAttributes) { {16, InputStepMode::kVertex, {{VertexFormat::kVec4F32, 0, 2}}}}}); auto result = manager()->Run(mod()); - ASSERT_FALSE(result.diagnostics.contains_errors()); + ASSERT_FALSE(result.diagnostics.contains_errors()) + << diag::Formatter().format(result.diagnostics); EXPECT_EQ(R"(Module{ TintVertexData Struct{ @@ -835,7 +847,8 @@ TEST_F(VertexPullingTest, FloatVectorAttributes) { storage_buffer __struct_TintVertexData } - Function main -> __void + Function )" + result.module.GetSymbol("main").to_str() + + R"( -> __void StageDecoration{vertex} () { diff --git a/src/type_determiner.cc b/src/type_determiner.cc index 91fdba5b78..86e8d2e83c 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -122,7 +122,7 @@ bool TypeDeterminer::Determine() { continue; } for (const auto& callee : caller_to_callee_[func->name()]) { - set_entry_points(callee, func->name()); + set_entry_points(callee, func->symbol()); } } @@ -130,11 +130,11 @@ bool TypeDeterminer::Determine() { } void TypeDeterminer::set_entry_points(const std::string& fn_name, - const std::string& ep_name) { - name_to_function_[fn_name]->add_ancestor_entry_point(ep_name); + Symbol ep_sym) { + name_to_function_[fn_name]->add_ancestor_entry_point(ep_sym); for (const auto& callee : caller_to_callee_[fn_name]) { - set_entry_points(callee, ep_name); + set_entry_points(callee, ep_sym); } } @@ -389,7 +389,8 @@ bool TypeDeterminer::DetermineCall(ast::CallExpression* expr) { if (current_function_) { caller_to_callee_[current_function_->name()].push_back(ident->name()); - auto* callee_func = mod_->FindFunctionByName(ident->name()); + auto* callee_func = + mod_->FindFunctionBySymbol(mod_->GetSymbol(ident->name())); if (callee_func == nullptr) { set_error(expr->source(), "unable to find called function: " + ident->name()); diff --git a/src/type_determiner.h b/src/type_determiner.h index f7cb587b6a..0b1157ac49 100644 --- a/src/type_determiner.h +++ b/src/type_determiner.h @@ -113,7 +113,7 @@ class TypeDeterminer { private: void set_error(const Source& src, const std::string& msg); void set_referenced_from_function_if_needed(ast::Variable* var, bool local); - void set_entry_points(const std::string& fn_name, const std::string& ep_name); + void set_entry_points(const std::string& fn_name, Symbol ep_sym); bool DetermineArrayAccessor(ast::ArrayAccessorExpression* expr); bool DetermineBinary(ast::BinaryExpression* expr); diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index c798986304..a45e8961ec 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -341,9 +341,9 @@ TEST_F(TypeDeterminerTest, Stmt_Call) { ast::type::F32 f32; ast::VariableList params; - auto* func = create(Source{}, "my_func", params, &f32, - create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod->RegisterSymbol("my_func"), "my_func", params, &f32, + create(), ast::FunctionDecorationList{}); mod->AddFunction(func); // Register the function @@ -372,15 +372,16 @@ TEST_F(TypeDeterminerTest, Stmt_Call_undeclared) { auto* main_body = create(); main_body->append(create(call_expr)); main_body->append(create(Source{})); - auto* func_main = - create(Source{}, "main", params0, &f32, main_body, - ast::FunctionDecorationList{}); + auto* func_main = create(Source{}, mod->RegisterSymbol("main"), + "main", params0, &f32, main_body, + ast::FunctionDecorationList{}); mod->AddFunction(func_main); auto* body = create(); body->append(create(Source{})); - auto* func = create(Source{}, "func", params0, &f32, body, - ast::FunctionDecorationList{}); + auto* func = + create(Source{}, mod->RegisterSymbol("func"), "func", + params0, &f32, body, ast::FunctionDecorationList{}); mod->AddFunction(func); EXPECT_FALSE(td()->Determine()) << td()->error(); @@ -639,9 +640,9 @@ TEST_F(TypeDeterminerTest, Expr_Call) { ast::type::F32 f32; ast::VariableList params; - auto* func = create(Source{}, "my_func", params, &f32, - create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod->RegisterSymbol("my_func"), "my_func", params, &f32, + create(), ast::FunctionDecorationList{}); mod->AddFunction(func); // Register the function @@ -659,9 +660,9 @@ TEST_F(TypeDeterminerTest, Expr_Call_WithParams) { ast::type::F32 f32; ast::VariableList params; - auto* func = create(Source{}, "my_func", params, &f32, - create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod->RegisterSymbol("my_func"), "my_func", params, &f32, + create(), ast::FunctionDecorationList{}); mod->AddFunction(func); // Register the function @@ -809,8 +810,8 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable_Const) { body->append(create( my_var, create("my_var"))); - ast::Function f(Source{}, "my_func", {}, &f32, body, - ast::FunctionDecorationList{}); + ast::Function f(Source{}, mod->RegisterSymbol("my_func"), "my_func", {}, &f32, + body, ast::FunctionDecorationList{}); EXPECT_TRUE(td()->DetermineFunction(&f)); @@ -836,8 +837,8 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable) { body->append(create( my_var, create("my_var"))); - ast::Function f(Source{}, "my_func", {}, &f32, body, - ast::FunctionDecorationList{}); + ast::Function f(Source{}, mod->RegisterSymbol("myfunc"), "my_func", {}, &f32, + body, ast::FunctionDecorationList{}); EXPECT_TRUE(td()->DetermineFunction(&f)); @@ -868,8 +869,8 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_Function_Ptr) { body->append(create( my_var, create("my_var"))); - ast::Function f(Source{}, "my_func", {}, &f32, body, - ast::FunctionDecorationList{}); + ast::Function f(Source{}, mod->RegisterSymbol("my_func"), "my_func", {}, &f32, + body, ast::FunctionDecorationList{}); EXPECT_TRUE(td()->DetermineFunction(&f)); @@ -885,9 +886,9 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_Function) { ast::type::F32 f32; ast::VariableList params; - auto* func = create(Source{}, "my_func", params, &f32, - create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod->RegisterSymbol("my_func"), "my_func", params, &f32, + create(), ast::FunctionDecorationList{}); mod->AddFunction(func); // Register the function @@ -968,8 +969,9 @@ TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables) { body->append(create( create("priv_var"), create("priv_var"))); - auto* func = create(Source{}, "my_func", params, &f32, body, - ast::FunctionDecorationList{}); + auto* func = + create(Source{}, mod->RegisterSymbol("my_func"), "my_func", + params, &f32, body, ast::FunctionDecorationList{}); mod->AddFunction(func); @@ -1049,8 +1051,9 @@ TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables_SubFunction) { create("priv_var"), create("priv_var"))); ast::VariableList params; - auto* func = create(Source{}, "my_func", params, &f32, body, - ast::FunctionDecorationList{}); + auto* func = + create(Source{}, mod->RegisterSymbol("my_func"), "my_func", + params, &f32, body, ast::FunctionDecorationList{}); mod->AddFunction(func); @@ -1059,8 +1062,9 @@ TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables_SubFunction) { create("out_var"), create(create("my_func"), ast::ExpressionList{}))); - auto* func2 = create(Source{}, "func", params, &f32, body, - ast::FunctionDecorationList{}); + auto* func2 = + create(Source{}, mod->RegisterSymbol("func"), "func", + params, &f32, body, ast::FunctionDecorationList{}); mod->AddFunction(func2); @@ -1096,8 +1100,9 @@ TEST_F(TypeDeterminerTest, Function_NotRegisterFunctionVariable) { create(&f32, 1.f)))); ast::VariableList params; - auto* func = create(Source{}, "my_func", params, &f32, body, - ast::FunctionDecorationList{}); + auto* func = + create(Source{}, mod->RegisterSymbol("my_func"), "my_func", + params, &f32, body, ast::FunctionDecorationList{}); mod->AddFunction(func); @@ -2636,8 +2641,9 @@ TEST_F(TypeDeterminerTest, StorageClass_SetsIfMissing) { auto* body = create(); body->append(stmt); - auto* func = create(Source{}, "func", ast::VariableList{}, - &i32, body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod->RegisterSymbol("func"), + "func", ast::VariableList{}, &i32, body, + ast::FunctionDecorationList{}); mod->AddFunction(func); @@ -2660,8 +2666,9 @@ TEST_F(TypeDeterminerTest, StorageClass_DoesNotSetOnConst) { auto* body = create(); body->append(stmt); - auto* func = create(Source{}, "func", ast::VariableList{}, - &i32, body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod->RegisterSymbol("func"), + "func", ast::VariableList{}, &i32, body, + ast::FunctionDecorationList{}); mod->AddFunction(func); @@ -2684,8 +2691,9 @@ TEST_F(TypeDeterminerTest, StorageClass_NonFunctionClassError) { auto* body = create(); body->append(stmt); - auto* func = create(Source{}, "func", ast::VariableList{}, - &i32, body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod->RegisterSymbol("func"), + "func", ast::VariableList{}, &i32, body, + ast::FunctionDecorationList{}); mod->AddFunction(func); @@ -4857,24 +4865,27 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) { ast::VariableList params; auto* body = create(); - auto* func_b = create(Source{}, "b", params, &f32, body, - ast::FunctionDecorationList{}); + auto* func_b = + create(Source{}, mod->RegisterSymbol("b"), "b", params, + &f32, body, ast::FunctionDecorationList{}); body = create(); body->append(create( create("second"), create(create("b"), ast::ExpressionList{}))); - auto* func_c = create(Source{}, "c", params, &f32, body, - ast::FunctionDecorationList{}); + auto* func_c = + create(Source{}, mod->RegisterSymbol("c"), "c", params, + &f32, body, ast::FunctionDecorationList{}); body = create(); body->append(create( create("first"), create(create("c"), ast::ExpressionList{}))); - auto* func_a = create(Source{}, "a", params, &f32, body, - ast::FunctionDecorationList{}); + auto* func_a = + create(Source{}, mod->RegisterSymbol("a"), "a", params, + &f32, body, ast::FunctionDecorationList{}); body = create(); body->append(create( @@ -4886,7 +4897,7 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) { create(create("b"), ast::ExpressionList{}))); auto* ep_1 = create( - Source{}, "ep_1", params, &f32, body, + Source{}, mod->RegisterSymbol("ep_1"), "ep_1", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -4897,7 +4908,7 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) { create(create("c"), ast::ExpressionList{}))); auto* ep_2 = create( - Source{}, "ep_2", params, &f32, body, + Source{}, mod->RegisterSymbol("ep_2"), "ep_2", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -4954,17 +4965,17 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) { const auto& b_eps = func_b->ancestor_entry_points(); ASSERT_EQ(2u, b_eps.size()); - EXPECT_EQ("ep_1", b_eps[0]); - EXPECT_EQ("ep_2", b_eps[1]); + EXPECT_EQ(mod->RegisterSymbol("ep_1"), b_eps[0]); + EXPECT_EQ(mod->RegisterSymbol("ep_2"), b_eps[1]); const auto& a_eps = func_a->ancestor_entry_points(); ASSERT_EQ(1u, a_eps.size()); - EXPECT_EQ("ep_1", a_eps[0]); + EXPECT_EQ(mod->RegisterSymbol("ep_1"), a_eps[0]); const auto& c_eps = func_c->ancestor_entry_points(); ASSERT_EQ(2u, c_eps.size()); - EXPECT_EQ("ep_1", c_eps[0]); - EXPECT_EQ("ep_2", c_eps[1]); + EXPECT_EQ(mod->RegisterSymbol("ep_1"), c_eps[0]); + EXPECT_EQ(mod->RegisterSymbol("ep_2"), c_eps[1]); EXPECT_TRUE(ep_1->ancestor_entry_points().empty()); EXPECT_TRUE(ep_2->ancestor_entry_points().empty()); diff --git a/src/validator/validator_function_test.cc b/src/validator/validator_function_test.cc index c335c226c9..780b5427d2 100644 --- a/src/validator/validator_function_test.cc +++ b/src/validator/validator_function_test.cc @@ -54,7 +54,8 @@ TEST_F(ValidateFunctionTest, VoidFunctionEndWithoutReturnStatement_Pass) { auto* body = create(); body->append(create(var)); auto* func = create( - Source{Source::Location{12, 34}}, "func", params, &void_type, body, + Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func", + params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -71,8 +72,8 @@ TEST_F(ValidateFunctionTest, ast::type::Void void_type; ast::VariableList params; auto* func = create( - Source{Source::Location{12, 34}}, "func", params, &void_type, - create(), + Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func", + params, &void_type, create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -100,9 +101,9 @@ TEST_F(ValidateFunctionTest, FunctionEndWithoutReturnStatement_Fail) { ast::type::Void void_type; auto* body = create(); body->append(create(var)); - auto* func = - create(Source{Source::Location{12, 34}}, "func", params, - &i32, body, ast::FunctionDecorationList{}); + auto* func = create( + Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func", + params, &i32, body, ast::FunctionDecorationList{}); mod()->AddFunction(func); EXPECT_TRUE(td()->Determine()) << td()->error(); @@ -117,8 +118,9 @@ TEST_F(ValidateFunctionTest, FunctionEndWithoutReturnStatementEmptyBody_Fail) { ast::type::I32 i32; ast::VariableList params; auto* func = create( - Source{Source::Location{12, 34}}, "func", params, &i32, - create(), ast::FunctionDecorationList{}); + Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func", + params, &i32, create(), + ast::FunctionDecorationList{}); mod()->AddFunction(func); EXPECT_TRUE(td()->Determine()) << td()->error(); @@ -136,7 +138,7 @@ TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementType_Pass) { auto* body = create(); body->append(create(Source{})); auto* func = create( - Source{}, "func", params, &void_type, body, + Source{}, mod()->RegisterSymbol("func"), "func", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -157,7 +159,8 @@ TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementType_fail) { body->append(create(Source{Source::Location{12, 34}}, return_expr)); - auto* func = create(Source{}, "func", params, &void_type, body, + auto* func = create(Source{}, mod()->RegisterSymbol("func"), + "func", params, &void_type, body, ast::FunctionDecorationList{}); mod()->AddFunction(func); @@ -180,8 +183,9 @@ TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementTypeF32_fail) { body->append(create(Source{Source::Location{12, 34}}, return_expr)); - auto* func = create(Source{}, "func", params, &f32, body, - ast::FunctionDecorationList{}); + auto* func = + create(Source{}, mod()->RegisterSymbol("func"), "func", + params, &f32, body, ast::FunctionDecorationList{}); mod()->AddFunction(func); EXPECT_TRUE(td()->Determine()) << td()->error(); @@ -204,8 +208,9 @@ TEST_F(ValidateFunctionTest, FunctionNamesMustBeUnique_fail) { create(&i32, 2)); body->append(create(Source{}, return_expr)); - auto* func = create(Source{}, "func", params, &i32, body, - ast::FunctionDecorationList{}); + auto* func = + create(Source{}, mod()->RegisterSymbol("func"), "func", + params, &i32, body, ast::FunctionDecorationList{}); ast::VariableList params_copy; auto* body_copy = create(); @@ -213,9 +218,9 @@ TEST_F(ValidateFunctionTest, FunctionNamesMustBeUnique_fail) { create(&i32, 2)); body_copy->append(create(Source{}, return_expr_copy)); - auto* func_copy = create(Source{Source::Location{12, 34}}, - "func", params_copy, &i32, body_copy, - ast::FunctionDecorationList{}); + auto* func_copy = create( + Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func", + params_copy, &i32, body_copy, ast::FunctionDecorationList{}); mod()->AddFunction(func); mod()->AddFunction(func_copy); @@ -237,7 +242,8 @@ TEST_F(ValidateFunctionTest, RecursionIsNotAllowed_Fail) { auto* body0 = create(); body0->append(create(call_expr)); body0->append(create(Source{})); - auto* func0 = create(Source{}, "func", params0, &f32, body0, + auto* func0 = create(Source{}, mod()->RegisterSymbol("func"), + "func", params0, &f32, body0, ast::FunctionDecorationList{}); mod()->AddFunction(func0); @@ -268,7 +274,8 @@ TEST_F(ValidateFunctionTest, RecursionIsNotAllowedExpr_Fail) { create(&i32, 2)); body0->append(create(Source{}, return_expr)); - auto* func0 = create(Source{}, "func", params0, &i32, body0, + auto* func0 = create(Source{}, mod()->RegisterSymbol("func"), + "func", params0, &i32, body0, ast::FunctionDecorationList{}); mod()->AddFunction(func0); @@ -288,7 +295,8 @@ TEST_F(ValidateFunctionTest, Function_WithPipelineStage_NotVoid_Fail) { auto* body = create(); body->append(create(Source{}, return_expr)); auto* func = create( - Source{Source::Location{12, 34}}, "vtx_main", params, &i32, body, + Source{Source::Location{12, 34}}, mod()->RegisterSymbol("vtx_main"), + "vtx_main", params, &i32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -317,7 +325,8 @@ TEST_F(ValidateFunctionTest, Function_WithPipelineStage_WithParams_Fail) { auto* body = create(); body->append(create(Source{})); auto* func = create( - Source{Source::Location{12, 34}}, "vtx_func", params, &void_type, body, + Source{Source::Location{12, 34}}, mod()->RegisterSymbol("vtx_func"), + "vtx_func", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -339,7 +348,8 @@ TEST_F(ValidateFunctionTest, PipelineStage_MustBeUnique_Fail) { auto* body = create(); body->append(create(Source{})); auto* func = create( - Source{Source::Location{12, 34}}, "main", params, &void_type, body, + Source{Source::Location{12, 34}}, mod()->RegisterSymbol("main"), "main", + params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), create(ast::PipelineStage::kFragment, Source{}), @@ -361,7 +371,8 @@ TEST_F(ValidateFunctionTest, OnePipelineStageFunctionMustBePresent_Pass) { auto* body = create(); body->append(create(Source{})); auto* func = create( - Source{}, "vtx_func", params, &void_type, body, + Source{}, mod()->RegisterSymbol("vtx_func"), "vtx_func", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -377,8 +388,9 @@ TEST_F(ValidateFunctionTest, OnePipelineStageFunctionMustBePresent_Fail) { ast::VariableList params; auto* body = create(); body->append(create(Source{})); - auto* func = create(Source{}, "vtx_func", params, &void_type, - body, ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod()->RegisterSymbol("vtx_func"), "vtx_func", params, + &void_type, body, ast::FunctionDecorationList{}); mod()->AddFunction(func); EXPECT_TRUE(td()->Determine()) << td()->error(); diff --git a/src/validator/validator_test.cc b/src/validator/validator_test.cc index d993b1ea3f..5931fc0688 100644 --- a/src/validator/validator_test.cc +++ b/src/validator/validator_test.cc @@ -332,7 +332,8 @@ TEST_F(ValidatorTest, UsingUndefinedVariableGlobalVariable_Fail) { body->append(create( Source{Source::Location{12, 34}}, lhs, rhs)); - auto* func = create(Source{}, "my_func", params, &f32, body, + auto* func = create(Source{}, mod()->RegisterSymbol("my_func"), + "my_func", params, &f32, body, ast::FunctionDecorationList{}); mod()->AddFunction(func); @@ -370,7 +371,8 @@ TEST_F(ValidatorTest, UsingUndefinedVariableGlobalVariable_Pass) { Source{Source::Location{12, 34}}, lhs, rhs)); body->append(create(Source{})); auto* func = create( - Source{}, "my_func", params, &void_type, body, + Source{}, mod()->RegisterSymbol("my_func"), "my_func", params, &void_type, + body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -587,8 +589,9 @@ TEST_F(ValidatorTest, GlobalVariableFunctionVariableNotUnique_Fail) { auto* body = create(); body->append(create( Source{Source::Location{12, 34}}, var)); - auto* func = create(Source{}, "my_func", params, &void_type, - body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod()->RegisterSymbol("my_func"), + "my_func", params, &void_type, body, + ast::FunctionDecorationList{}); mod()->AddFunction(func); @@ -631,8 +634,9 @@ TEST_F(ValidatorTest, RedeclaredIndentifier_Fail) { body->append(create(var)); body->append(create( Source{Source::Location{12, 34}}, var_a_float)); - auto* func = create(Source{}, "my_func", params, &void_type, - body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod()->RegisterSymbol("my_func"), + "my_func", params, &void_type, body, + ast::FunctionDecorationList{}); mod()->AddFunction(func); @@ -759,8 +763,9 @@ TEST_F(ValidatorTest, RedeclaredIdentifierDifferentFunctions_Pass) { body0->append(create( Source{Source::Location{12, 34}}, var0)); body0->append(create(Source{})); - auto* func0 = create(Source{}, "func0", params0, &void_type, - body0, ast::FunctionDecorationList{}); + auto* func0 = create(Source{}, mod()->RegisterSymbol("func0"), + "func0", params0, &void_type, body0, + ast::FunctionDecorationList{}); ast::VariableList params1; auto* body1 = create(); @@ -768,7 +773,8 @@ TEST_F(ValidatorTest, RedeclaredIdentifierDifferentFunctions_Pass) { Source{Source::Location{13, 34}}, var1)); body1->append(create(Source{})); auto* func1 = create( - Source{}, "func1", params1, &void_type, body1, + Source{}, mod()->RegisterSymbol("func1"), "func1", params1, &void_type, + body1, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); diff --git a/src/validator/validator_type_test.cc b/src/validator/validator_type_test.cc index 2fbaeffac2..ac8cc05de7 100644 --- a/src/validator/validator_type_test.cc +++ b/src/validator/validator_type_test.cc @@ -206,8 +206,9 @@ TEST_F(ValidatorTypeTest, RuntimeArrayInFunction_Fail) { auto* body = create(); body->append(create( Source{Source::Location{12, 34}}, var)); + auto* func = create( - Source{}, "func", params, &void_type, body, + Source{}, mod()->RegisterSymbol("func"), "func", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index c52c8df572..fa49a4b0e0 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -196,15 +196,15 @@ std::string GeneratorImpl::current_ep_var_name(VarType type) { std::string name = ""; switch (type) { case VarType::kIn: { - auto in_it = ep_name_to_in_data_.find(current_ep_name_); - if (in_it != ep_name_to_in_data_.end()) { + auto in_it = ep_sym_to_in_data_.find(current_ep_sym_.value()); + if (in_it != ep_sym_to_in_data_.end()) { name = in_it->second.var_name; } break; } case VarType::kOut: { - auto outit = ep_name_to_out_data_.find(current_ep_name_); - if (outit != ep_name_to_out_data_.end()) { + auto outit = ep_sym_to_out_data_.find(current_ep_sym_.value()); + if (outit != ep_sym_to_out_data_.end()) { name = outit->second.var_name; } break; @@ -668,12 +668,14 @@ bool GeneratorImpl::EmitCall(std::ostream& pre, } auto name = ident->name(); - auto it = ep_func_name_remapped_.find(current_ep_name_ + "_" + name); + auto caller_sym = module_->GetSymbol(name); + auto it = ep_func_name_remapped_.find(current_ep_sym_.to_str() + "_" + + caller_sym.to_str()); if (it != ep_func_name_remapped_.end()) { name = it->second; } - auto* func = module_->FindFunctionByName(ident->name()); + auto* func = module_->FindFunctionBySymbol(module_->GetSymbol(ident->name())); if (func == nullptr) { error_ = "Unable to find function: " + name; return false; @@ -1189,15 +1191,15 @@ bool GeneratorImpl::EmitFunction(std::ostream& out, ast::Function* func) { has_referenced_var_needing_struct(func); if (emit_duplicate_functions) { - for (const auto& ep_name : func->ancestor_entry_points()) { - if (!EmitFunctionInternal(out, func, emit_duplicate_functions, ep_name)) { + for (const auto& ep_sym : func->ancestor_entry_points()) { + if (!EmitFunctionInternal(out, func, emit_duplicate_functions, ep_sym)) { return false; } out << std::endl; } } else { // Emit as non-duplicated - if (!EmitFunctionInternal(out, func, false, "")) { + if (!EmitFunctionInternal(out, func, false, Symbol())) { return false; } out << std::endl; @@ -1209,8 +1211,8 @@ bool GeneratorImpl::EmitFunction(std::ostream& out, ast::Function* func) { bool GeneratorImpl::EmitFunctionInternal(std::ostream& out, ast::Function* func, bool emit_duplicate_functions, - const std::string& ep_name) { - auto name = func->name(); + Symbol ep_sym) { + auto name = func->symbol().to_str(); if (!EmitType(out, func->return_type(), "")) { return false; @@ -1219,10 +1221,15 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out, out << " "; if (emit_duplicate_functions) { - name = generate_name(name + "_" + ep_name); - ep_func_name_remapped_[ep_name + "_" + func->name()] = name; + auto func_name = name; + auto ep_name = ep_sym.to_str(); + // TODO(dsinclair): The SymbolToName should go away and just use + // to_str() here when the conversion is complete. + name = generate_name(func->name() + "_" + module_->SymbolToName(ep_sym)); + ep_func_name_remapped_[ep_name + "_" + func_name] = name; } else { - name = namer_.NameFor(name); + // TODO(dsinclair): this should be updated to a remapped name + name = namer_.NameFor(func->name()); } out << name << "("; @@ -1234,15 +1241,15 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out, // // We emit both of them if they're there regardless of if they're both used. if (emit_duplicate_functions) { - auto in_it = ep_name_to_in_data_.find(ep_name); - if (in_it != ep_name_to_in_data_.end()) { + auto in_it = ep_sym_to_in_data_.find(ep_sym.value()); + if (in_it != ep_sym_to_in_data_.end()) { out << "in " << in_it->second.struct_name << " " << in_it->second.var_name; first = false; } - auto outit = ep_name_to_out_data_.find(ep_name); - if (outit != ep_name_to_out_data_.end()) { + auto outit = ep_sym_to_out_data_.find(ep_sym.value()); + if (outit != ep_sym_to_out_data_.end()) { if (!first) { out << ", "; } @@ -1269,13 +1276,13 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out, out << ") "; - current_ep_name_ = ep_name; + current_ep_sym_ = ep_sym; if (!EmitBlockAndNewline(out, func->body())) { return false; } - current_ep_name_ = ""; + current_ep_sym_ = Symbol(); return true; } @@ -1392,7 +1399,7 @@ bool GeneratorImpl::EmitEntryPointData( auto in_struct_name = generate_name(func->name() + "_" + kInStructNameSuffix); auto in_var_name = generate_name(kTintStructInVarPrefix); - ep_name_to_in_data_[func->name()] = {in_struct_name, in_var_name}; + ep_sym_to_in_data_[func->symbol().value()] = {in_struct_name, in_var_name}; make_indent(out); out << "struct " << in_struct_name << " {" << std::endl; @@ -1438,7 +1445,7 @@ bool GeneratorImpl::EmitEntryPointData( auto outstruct_name = generate_name(func->name() + "_" + kOutStructNameSuffix); auto outvar_name = generate_name(kTintStructOutVarPrefix); - ep_name_to_out_data_[func->name()] = {outstruct_name, outvar_name}; + ep_sym_to_out_data_[func->symbol().value()] = {outstruct_name, outvar_name}; make_indent(out); out << "struct " << outstruct_name << " {" << std::endl; @@ -1516,7 +1523,7 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out, ast::Function* func) { make_indent(out); - current_ep_name_ = func->name(); + current_ep_sym_ = func->symbol(); if (func->pipeline_stage() == ast::PipelineStage::kCompute) { uint32_t x = 0; @@ -1528,17 +1535,18 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out, make_indent(out); } - auto outdata = ep_name_to_out_data_.find(current_ep_name_); - bool has_outdata = outdata != ep_name_to_out_data_.end(); + auto outdata = ep_sym_to_out_data_.find(current_ep_sym_.value()); + bool has_outdata = outdata != ep_sym_to_out_data_.end(); if (has_outdata) { out << outdata->second.struct_name; } else { out << "void"; } - out << " " << namer_.NameFor(current_ep_name_) << "("; + // TODO(dsinclair): This should output the remapped name + out << " " << namer_.NameFor(module_->SymbolToName(current_ep_sym_)) << "("; - auto in_data = ep_name_to_in_data_.find(current_ep_name_); - if (in_data != ep_name_to_in_data_.end()) { + auto in_data = ep_sym_to_in_data_.find(current_ep_sym_.value()); + if (in_data != ep_sym_to_in_data_.end()) { out << in_data->second.struct_name << " " << in_data->second.var_name; } out << ") {" << std::endl; @@ -1563,7 +1571,7 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out, make_indent(out); out << "}" << std::endl; - current_ep_name_ = ""; + current_ep_sym_ = Symbol(); return true; } @@ -1966,8 +1974,8 @@ bool GeneratorImpl::EmitReturn(std::ostream& out, ast::ReturnStatement* stmt) { if (generating_entry_point_) { out << "return"; - auto outdata = ep_name_to_out_data_.find(current_ep_name_); - if (outdata != ep_name_to_out_data_.end()) { + auto outdata = ep_sym_to_out_data_.find(current_ep_sym_.value()); + if (outdata != ep_sym_to_out_data_.end()) { out << " " << outdata->second.var_name; } } else if (stmt->has_value()) { diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h index 6694e26040..ebd98b993d 100644 --- a/src/writer/hlsl/generator_impl.h +++ b/src/writer/hlsl/generator_impl.h @@ -210,12 +210,12 @@ class GeneratorImpl { /// @param func the function to emit /// @param emit_duplicate_functions set true if we need to duplicate per entry /// point - /// @param ep_name the current entry point or blank if none set + /// @param ep_sym the current entry point or symbol::kInvalid if none set /// @returns true if the function was emitted. bool EmitFunctionInternal(std::ostream& out, ast::Function* func, bool emit_duplicate_functions, - const std::string& ep_name); + Symbol ep_sym); /// Handles emitting information for an entry point /// @param out the output stream /// @param func the entry point @@ -397,12 +397,12 @@ class GeneratorImpl { Namer namer_; ast::Module* module_ = nullptr; - std::string current_ep_name_; + Symbol current_ep_sym_; bool generating_entry_point_ = false; uint32_t loop_emission_counter_ = 0; ScopeStack global_variables_; - std::unordered_map ep_name_to_in_data_; - std::unordered_map ep_name_to_out_data_; + std::unordered_map ep_sym_to_in_data_; + std::unordered_map ep_sym_to_out_data_; // This maps an input of "_" to a remapped // function name. If there is no entry for a given key then function did diff --git a/src/writer/hlsl/generator_impl_binary_test.cc b/src/writer/hlsl/generator_impl_binary_test.cc index 20654ebd41..318c68b720 100644 --- a/src/writer/hlsl/generator_impl_binary_test.cc +++ b/src/writer/hlsl/generator_impl_binary_test.cc @@ -613,9 +613,9 @@ TEST_F(HlslGeneratorImplTest_Binary, Call_WithLogical) { ast::type::Void void_type; - auto* func = create(Source{}, "foo", ast::VariableList{}, - &void_type, create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod.RegisterSymbol("foo"), "foo", ast::VariableList{}, + &void_type, create(), ast::FunctionDecorationList{}); mod.AddFunction(func); ast::ExpressionList params; diff --git a/src/writer/hlsl/generator_impl_call_test.cc b/src/writer/hlsl/generator_impl_call_test.cc index e185dcc095..3311837bae 100644 --- a/src/writer/hlsl/generator_impl_call_test.cc +++ b/src/writer/hlsl/generator_impl_call_test.cc @@ -35,9 +35,9 @@ TEST_F(HlslGeneratorImplTest_Call, EmitExpression_Call_WithoutParams) { auto* id = create("my_func"); ast::CallExpression call(id, {}); - auto* func = create(Source{}, "my_func", ast::VariableList{}, - &void_type, create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{}, + &void_type, create(), ast::FunctionDecorationList{}); mod.AddFunction(func); ASSERT_TRUE(gen.EmitExpression(pre, out, &call)) << gen.error(); @@ -53,9 +53,9 @@ TEST_F(HlslGeneratorImplTest_Call, EmitExpression_Call_WithParams) { params.push_back(create("param2")); ast::CallExpression call(id, params); - auto* func = create(Source{}, "my_func", ast::VariableList{}, - &void_type, create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{}, + &void_type, create(), ast::FunctionDecorationList{}); mod.AddFunction(func); ASSERT_TRUE(gen.EmitExpression(pre, out, &call)) << gen.error(); @@ -71,9 +71,9 @@ TEST_F(HlslGeneratorImplTest_Call, EmitStatement_Call) { params.push_back(create("param2")); ast::CallStatement call(create(id, params)); - auto* func = create(Source{}, "my_func", ast::VariableList{}, - &void_type, create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{}, + &void_type, create(), ast::FunctionDecorationList{}); mod.AddFunction(func); gen.increment_indent(); ASSERT_TRUE(gen.EmitStatement(out, &call)) << gen.error(); diff --git a/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc b/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc index 57c8553103..320b478c32 100644 --- a/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc +++ b/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc @@ -91,7 +91,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint, create("bar"))); auto* func = create( - Source{}, "vtx_main", params, &f32, body, + Source{}, mod.RegisterSymbol("vtx_main"), "vtx_main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -164,7 +164,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint, create("bar"))); auto* func = create( - Source{}, "vtx_main", params, &f32, body, + Source{}, mod.RegisterSymbol("vtx_main"), "vtx_main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -237,7 +237,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint, create("bar"))); auto* func = create( - Source{}, "main", params, &f32, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -309,7 +309,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint, create("bar"))); auto* func = create( - Source{}, "main", params, &f32, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -378,7 +378,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint, create("bar"))); auto* func = create( - Source{}, "main", params, &f32, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kCompute, Source{}), }); @@ -442,7 +442,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint, create("bar"))); auto* func = create( - Source{}, "main", params, &f32, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kCompute, Source{}), }); @@ -512,7 +512,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint, create("x")))); auto* func = create( - Source{}, "main", params, &void_type, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc index 95543a7279..cab53fc611 100644 --- a/src/writer/hlsl/generator_impl_function_test.cc +++ b/src/writer/hlsl/generator_impl_function_test.cc @@ -57,9 +57,9 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Function) { auto* body = create(); body->append(create(Source{})); - auto* func = - create(Source{}, "my_func", ast::VariableList{}, - &void_type, body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod.RegisterSymbol("my_func"), + "my_func", ast::VariableList{}, &void_type, + body, ast::FunctionDecorationList{}); mod.AddFunction(func); gen.increment_indent(); @@ -77,9 +77,9 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Function_Name_Collision) { auto* body = create(); body->append(create(Source{})); - auto* func = - create(Source{}, "GeometryShader", ast::VariableList{}, - &void_type, body, ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod.RegisterSymbol("GeometryShader"), "GeometryShader", + ast::VariableList{}, &void_type, body, ast::FunctionDecorationList{}); mod.AddFunction(func); gen.increment_indent(); @@ -118,8 +118,9 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithParams) { auto* body = create(); body->append(create(Source{})); - auto* func = create(Source{}, "my_func", params, &void_type, - body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod.RegisterSymbol("my_func"), + "my_func", params, &void_type, body, + ast::FunctionDecorationList{}); mod.AddFunction(func); gen.increment_indent(); @@ -174,7 +175,8 @@ TEST_F(HlslGeneratorImplTest_Function, create("foo"))); body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -245,7 +247,8 @@ TEST_F(HlslGeneratorImplTest_Function, create("x")))); body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -309,7 +312,8 @@ TEST_F(HlslGeneratorImplTest_Function, body->append(create(var)); body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -380,7 +384,8 @@ TEST_F(HlslGeneratorImplTest_Function, body->append(create(var)); body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -455,7 +460,8 @@ TEST_F(HlslGeneratorImplTest_Function, body->append(create(var)); body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -526,7 +532,8 @@ TEST_F(HlslGeneratorImplTest_Function, body->append(create(var)); body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -594,7 +601,8 @@ TEST_F(HlslGeneratorImplTest_Function, body->append(assign); body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -682,8 +690,9 @@ TEST_F( create("param"))); body->append(create( Source{}, create("foo"))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -698,7 +707,7 @@ TEST_F( expr))); body->append(create(Source{})); auto* func_1 = create( - Source{}, "ep_1", params, &void_type, body, + Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -766,8 +775,9 @@ TEST_F(HlslGeneratorImplTest_Function, auto* body = create(); body->append(create( Source{}, create("param"))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -782,7 +792,7 @@ TEST_F(HlslGeneratorImplTest_Function, expr))); body->append(create(Source{})); auto* func_1 = create( - Source{}, "ep_1", params, &void_type, body, + Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -863,8 +873,9 @@ TEST_F( create("x")))); body->append(create( Source{}, create("param"))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -879,7 +890,7 @@ TEST_F( expr))); body->append(create(Source{})); auto* func_1 = create( - Source{}, "ep_1", params, &void_type, body, + Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -948,8 +959,9 @@ TEST_F(HlslGeneratorImplTest_Function, Source{}, create( create("coord"), create("x")))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -971,7 +983,8 @@ TEST_F(HlslGeneratorImplTest_Function, body->append(create(var)); body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -1034,8 +1047,9 @@ TEST_F(HlslGeneratorImplTest_Function, Source{}, create( create("coord"), create("x")))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -1057,7 +1071,8 @@ TEST_F(HlslGeneratorImplTest_Function, body->append(create(var)); body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -1122,7 +1137,7 @@ TEST_F(HlslGeneratorImplTest_Function, body->append(create(Source{})); auto* func_1 = create( - Source{}, "ep_1", params, &void_type, body, + Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -1152,8 +1167,8 @@ TEST_F(HlslGeneratorImplTest_Function, ast::type::Void void_type; auto* func = create( - Source{}, "GeometryShader", ast::VariableList{}, &void_type, - create(), + Source{}, mod.RegisterSymbol("GeometryShader"), "GeometryShader", + ast::VariableList{}, &void_type, create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -1175,7 +1190,7 @@ TEST_F(HlslGeneratorImplTest_Function, auto* body = create(); body->append(create(Source{})); auto* func = create( - Source{}, "main", params, &void_type, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kCompute, Source{}), }); @@ -1200,7 +1215,7 @@ TEST_F(HlslGeneratorImplTest_Function, auto* body = create(); body->append(create(Source{})); auto* func = create( - Source{}, "main", params, &void_type, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kCompute, Source{}), create(2u, 4u, 6u, Source{}), @@ -1236,8 +1251,9 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) { auto* body = create(); body->append(create(Source{})); - auto* func = create(Source{}, "my_func", params, &void_type, - body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod.RegisterSymbol("my_func"), + "my_func", params, &void_type, body, + ast::FunctionDecorationList{}); mod.AddFunction(func); gen.increment_indent(); @@ -1317,12 +1333,12 @@ TEST_F(HlslGeneratorImplTest_Function, auto* body = create(); body->append(create(var)); body->append(create(Source{})); - auto* func = - create(Source{}, "a", params, &void_type, body, - ast::FunctionDecorationList{ - create( - ast::PipelineStage::kCompute, Source{}), - }); + auto* func = create( + Source{}, mod.RegisterSymbol("a"), "a", params, &void_type, body, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute, + Source{}), + }); mod.AddFunction(func); } @@ -1343,12 +1359,12 @@ TEST_F(HlslGeneratorImplTest_Function, auto* body = create(); body->append(create(var)); body->append(create(Source{})); - auto* func = - create(Source{}, "b", params, &void_type, body, - ast::FunctionDecorationList{ - create( - ast::PipelineStage::kCompute, Source{}), - }); + auto* func = create( + Source{}, mod.RegisterSymbol("b"), "b", params, &void_type, body, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute, + Source{}), + }); mod.AddFunction(func); } diff --git a/src/writer/hlsl/generator_impl_test.cc b/src/writer/hlsl/generator_impl_test.cc index dfdaf9519a..1494fcf5e9 100644 --- a/src/writer/hlsl/generator_impl_test.cc +++ b/src/writer/hlsl/generator_impl_test.cc @@ -29,9 +29,9 @@ using HlslGeneratorImplTest = TestHelper; TEST_F(HlslGeneratorImplTest, Generate) { ast::type::Void void_type; - auto* func = create(Source{}, "my_func", ast::VariableList{}, - &void_type, create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{}, + &void_type, create(), ast::FunctionDecorationList{}); mod.AddFunction(func); ASSERT_TRUE(gen.Generate(out)) << gen.error(); diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index 000c3d7fed..4af9556c15 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc @@ -411,15 +411,15 @@ std::string GeneratorImpl::current_ep_var_name(VarType type) { std::string name = ""; switch (type) { case VarType::kIn: { - auto in_it = ep_name_to_in_data_.find(current_ep_name_); - if (in_it != ep_name_to_in_data_.end()) { + auto in_it = ep_sym_to_in_data_.find(current_ep_sym_.value()); + if (in_it != ep_sym_to_in_data_.end()) { name = in_it->second.var_name; } break; } case VarType::kOut: { - auto out_it = ep_name_to_out_data_.find(current_ep_name_); - if (out_it != ep_name_to_out_data_.end()) { + auto out_it = ep_sym_to_out_data_.find(current_ep_sym_.value()); + if (out_it != ep_sym_to_out_data_.end()) { name = out_it->second.var_name; } break; @@ -573,12 +573,14 @@ bool GeneratorImpl::EmitCall(ast::CallExpression* expr) { } auto name = ident->name(); - auto it = ep_func_name_remapped_.find(current_ep_name_ + "_" + name); + auto caller_sym = module_->GetSymbol(name); + auto it = ep_func_name_remapped_.find(current_ep_sym_.to_str() + "_" + + caller_sym.to_str()); if (it != ep_func_name_remapped_.end()) { name = it->second; } - auto* func = module_->FindFunctionByName(ident->name()); + auto* func = module_->FindFunctionBySymbol(module_->GetSymbol(ident->name())); if (func == nullptr) { error_ = "Unable to find function: " + name; return false; @@ -1026,7 +1028,7 @@ bool GeneratorImpl::EmitEntryPointData(ast::Function* func) { auto in_struct_name = generate_name(func->name() + "_" + kInStructNameSuffix); auto in_var_name = generate_name(kTintStructInVarPrefix); - ep_name_to_in_data_[func->name()] = {in_struct_name, in_var_name}; + ep_sym_to_in_data_[func->symbol().value()] = {in_struct_name, in_var_name}; make_indent(); out_ << "struct " << in_struct_name << " {" << std::endl; @@ -1063,7 +1065,8 @@ bool GeneratorImpl::EmitEntryPointData(ast::Function* func) { auto out_struct_name = generate_name(func->name() + "_" + kOutStructNameSuffix); auto out_var_name = generate_name(kTintStructOutVarPrefix); - ep_name_to_out_data_[func->name()] = {out_struct_name, out_var_name}; + ep_sym_to_out_data_[func->symbol().value()] = {out_struct_name, + out_var_name}; make_indent(); out_ << "struct " << out_struct_name << " {" << std::endl; @@ -1205,15 +1208,15 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) { has_referenced_var_needing_struct(func); if (emit_duplicate_functions) { - for (const auto& ep_name : func->ancestor_entry_points()) { - if (!EmitFunctionInternal(func, emit_duplicate_functions, ep_name)) { + for (const auto& ep_sym : func->ancestor_entry_points()) { + if (!EmitFunctionInternal(func, emit_duplicate_functions, ep_sym)) { return false; } out_ << std::endl; } } else { // Emit as non-duplicated - if (!EmitFunctionInternal(func, false, "")) { + if (!EmitFunctionInternal(func, false, Symbol())) { return false; } out_ << std::endl; @@ -1224,19 +1227,23 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) { bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, bool emit_duplicate_functions, - const std::string& ep_name) { - auto name = func->name(); - + Symbol ep_sym) { + auto name = func->symbol().to_str(); if (!EmitType(func->return_type(), "")) { return false; } out_ << " "; if (emit_duplicate_functions) { - name = generate_name(name + "_" + ep_name); - ep_func_name_remapped_[ep_name + "_" + func->name()] = name; + auto func_name = name; + auto ep_name = ep_sym.to_str(); + // TODO(dsinclair): The SymbolToName should go away and just use + // to_str() here when the conversion is complete. + name = generate_name(func->name() + "_" + module_->SymbolToName(ep_sym)); + ep_func_name_remapped_[ep_name + "_" + func_name] = name; } else { - name = namer_.NameFor(name); + // TODO(dsinclair): this should be updated to a remapped name + name = namer_.NameFor(func->name()); } out_ << name << "("; @@ -1247,15 +1254,15 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, // // We emit both of them if they're there regardless of if they're both used. if (emit_duplicate_functions) { - auto in_it = ep_name_to_in_data_.find(ep_name); - if (in_it != ep_name_to_in_data_.end()) { + auto in_it = ep_sym_to_in_data_.find(ep_sym.value()); + if (in_it != ep_sym_to_in_data_.end()) { out_ << "thread " << in_it->second.struct_name << "& " << in_it->second.var_name; first = false; } - auto out_it = ep_name_to_out_data_.find(ep_name); - if (out_it != ep_name_to_out_data_.end()) { + auto out_it = ep_sym_to_out_data_.find(ep_sym.value()); + if (out_it != ep_sym_to_out_data_.end()) { if (!first) { out_ << ", "; } @@ -1337,13 +1344,13 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, out_ << ") "; - current_ep_name_ = ep_name; + current_ep_sym_ = ep_sym; if (!EmitBlockAndNewline(func->body())) { return false; } - current_ep_name_ = ""; + current_ep_sym_ = Symbol(); return true; } @@ -1377,25 +1384,25 @@ std::string GeneratorImpl::builtin_to_attribute(ast::Builtin builtin) const { bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) { make_indent(); - current_ep_name_ = func->name(); + current_ep_sym_ = func->symbol(); EmitStage(func->pipeline_stage()); out_ << " "; // This is an entry point, the return type is the entry point output structure // if one exists, or void otherwise. - auto out_data = ep_name_to_out_data_.find(current_ep_name_); - bool has_out_data = out_data != ep_name_to_out_data_.end(); + auto out_data = ep_sym_to_out_data_.find(current_ep_sym_.value()); + bool has_out_data = out_data != ep_sym_to_out_data_.end(); if (has_out_data) { out_ << out_data->second.struct_name; } else { out_ << "void"; } - out_ << " " << namer_.NameFor(current_ep_name_) << "("; + out_ << " " << namer_.NameFor(func->name()) << "("; bool first = true; - auto in_data = ep_name_to_in_data_.find(current_ep_name_); - if (in_data != ep_name_to_in_data_.end()) { + auto in_data = ep_sym_to_in_data_.find(current_ep_sym_.value()); + if (in_data != ep_sym_to_in_data_.end()) { out_ << in_data->second.struct_name << " " << in_data->second.var_name << " [[stage_in]]"; first = false; @@ -1503,7 +1510,7 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) { make_indent(); out_ << "}" << std::endl; - current_ep_name_ = ""; + current_ep_sym_ = Symbol(); return true; } @@ -1687,8 +1694,8 @@ bool GeneratorImpl::EmitReturn(ast::ReturnStatement* stmt) { out_ << "return"; if (generating_entry_point_) { - auto out_data = ep_name_to_out_data_.find(current_ep_name_); - if (out_data != ep_name_to_out_data_.end()) { + auto out_data = ep_sym_to_out_data_.find(current_ep_sym_.value()); + if (out_data != ep_sym_to_out_data_.end()) { out_ << " " << out_data->second.var_name; } } else if (stmt->has_value()) { diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h index 0e4ef7d2eb..d087c7cb4b 100644 --- a/src/writer/msl/generator_impl.h +++ b/src/writer/msl/generator_impl.h @@ -156,11 +156,11 @@ class GeneratorImpl : public TextGenerator { /// @param func the function to emit /// @param emit_duplicate_functions set true if we need to duplicate per entry /// point - /// @param ep_name the current entry point or blank if none set + /// @param ep_sym the current entry point or symbol::kInvalid if not set /// @returns true if the function was emitted. bool EmitFunctionInternal(ast::Function* func, bool emit_duplicate_functions, - const std::string& ep_name); + Symbol ep_sym); /// Handles generating an identifier expression /// @param expr the identifier expression /// @returns true if the identifier was emitted @@ -282,13 +282,13 @@ class GeneratorImpl : public TextGenerator { Namer namer_; ScopeStack global_variables_; - std::string current_ep_name_; + Symbol current_ep_sym_; bool generating_entry_point_ = false; const ast::Module* module_ = nullptr; uint32_t loop_emission_counter_ = 0; - std::unordered_map ep_name_to_in_data_; - std::unordered_map ep_name_to_out_data_; + std::unordered_map ep_sym_to_in_data_; + std::unordered_map ep_sym_to_out_data_; // This maps an input of "_" to a remapped // function name. If there is no entry for a given key then function did diff --git a/src/writer/msl/generator_impl_call_test.cc b/src/writer/msl/generator_impl_call_test.cc index c27e55689f..71024249ca 100644 --- a/src/writer/msl/generator_impl_call_test.cc +++ b/src/writer/msl/generator_impl_call_test.cc @@ -37,9 +37,9 @@ TEST_F(MslGeneratorImplTest, EmitExpression_Call_WithoutParams) { auto* id = create("my_func"); ast::CallExpression call(id, {}); - auto* func = create(Source{}, "my_func", ast::VariableList{}, - &void_type, create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{}, + &void_type, create(), ast::FunctionDecorationList{}); mod.AddFunction(func); ASSERT_TRUE(gen.EmitExpression(&call)) << gen.error(); @@ -55,9 +55,9 @@ TEST_F(MslGeneratorImplTest, EmitExpression_Call_WithParams) { params.push_back(create("param2")); ast::CallExpression call(id, params); - auto* func = create(Source{}, "my_func", ast::VariableList{}, - &void_type, create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{}, + &void_type, create(), ast::FunctionDecorationList{}); mod.AddFunction(func); ASSERT_TRUE(gen.EmitExpression(&call)) << gen.error(); @@ -73,9 +73,9 @@ TEST_F(MslGeneratorImplTest, EmitStatement_Call) { params.push_back(create("param2")); ast::CallStatement call(create(id, params)); - auto* func = create(Source{}, "my_func", ast::VariableList{}, - &void_type, create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{}, + &void_type, create(), ast::FunctionDecorationList{}); mod.AddFunction(func); gen.increment_indent(); diff --git a/src/writer/msl/generator_impl_function_entry_point_data_test.cc b/src/writer/msl/generator_impl_function_entry_point_data_test.cc index ac6b4c8eba..f43ea3233b 100644 --- a/src/writer/msl/generator_impl_function_entry_point_data_test.cc +++ b/src/writer/msl/generator_impl_function_entry_point_data_test.cc @@ -90,7 +90,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Vertex_Input) { create("bar"))); auto* func = create( - Source{}, "vtx_main", params, &f32, body, + Source{}, mod.RegisterSymbol("vtx_main"), "vtx_main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -160,7 +160,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Vertex_Output) { create("bar"))); auto* func = create( - Source{}, "vtx_main", params, &f32, body, + Source{}, mod.RegisterSymbol("vtx_main"), "vtx_main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -229,7 +229,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Fragment_Input) { create("bar"), create("bar"))); auto* func = create( - Source{}, "main", params, &f32, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -299,7 +299,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Fragment_Output) { create("bar"))); auto* func = create( - Source{}, "main", params, &f32, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -366,7 +366,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Compute_Input) { create("bar"))); auto* func = create( - Source{}, "main", params, &f32, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kCompute, Source{}), }); @@ -428,7 +428,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Compute_Output) { create("bar"))); auto* func = create( - Source{}, "main", params, &f32, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kCompute, Source{}), }); @@ -496,7 +496,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Builtins) { create("x")))); auto* func = create( - Source{}, "main", params, &void_type, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc index 03934cb2bb..950e8fe8db 100644 --- a/src/writer/msl/generator_impl_function_test.cc +++ b/src/writer/msl/generator_impl_function_test.cc @@ -60,9 +60,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function) { auto* body = create(); body->append(create(Source{})); - auto* func = - create(Source{}, "my_func", ast::VariableList{}, - &void_type, body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod.RegisterSymbol("my_func"), + "my_func", ast::VariableList{}, &void_type, + body, ast::FunctionDecorationList{}); mod.AddFunction(func); gen.increment_indent(); @@ -82,9 +82,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function_Name_Collision) { auto* body = create(); body->append(create(Source{})); - auto* func = - create(Source{}, "main", ast::VariableList{}, &void_type, - body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod.RegisterSymbol("main"), + "main", ast::VariableList{}, &void_type, + body, ast::FunctionDecorationList{}); mod.AddFunction(func); gen.increment_indent(); @@ -125,8 +125,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function_WithParams) { auto* body = create(); body->append(create(Source{})); - auto* func = create(Source{}, "my_func", params, &void_type, - body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod.RegisterSymbol("my_func"), + "my_func", params, &void_type, body, + ast::FunctionDecorationList{}); mod.AddFunction(func); gen.increment_indent(); @@ -183,7 +184,8 @@ TEST_F(MslGeneratorImplTest, Emit_FunctionDecoration_EntryPoint_WithInOutVars) { body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{create( ast::PipelineStage::kFragment, Source{})}); @@ -257,7 +259,8 @@ TEST_F(MslGeneratorImplTest, body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -321,7 +324,8 @@ TEST_F(MslGeneratorImplTest, Emit_FunctionDecoration_EntryPoint_With_Uniform) { body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -397,7 +401,8 @@ TEST_F(MslGeneratorImplTest, body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -478,7 +483,8 @@ TEST_F(MslGeneratorImplTest, body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -572,8 +578,9 @@ TEST_F( create("param"))); body->append(create( Source{}, create("foo"))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -588,7 +595,7 @@ TEST_F( expr))); body->append(create(Source{})); auto* func_1 = create( - Source{}, "ep_1", params, &void_type, body, + Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -659,8 +666,9 @@ TEST_F(MslGeneratorImplTest, auto* body = create(); body->append(create( Source{}, create("param"))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -676,7 +684,7 @@ TEST_F(MslGeneratorImplTest, body->append(create(Source{})); auto* func_1 = create( - Source{}, "ep_1", params, &void_type, body, + Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -760,8 +768,9 @@ TEST_F( create("x")))); body->append(create( Source{}, create("param"))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -776,7 +785,7 @@ TEST_F( expr))); body->append(create(Source{})); auto* func_1 = create( - Source{}, "ep_1", params, &void_type, body, + Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -843,8 +852,9 @@ TEST_F(MslGeneratorImplTest, Source{}, create( create("coord"), create("x")))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -867,7 +877,8 @@ TEST_F(MslGeneratorImplTest, body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -943,8 +954,9 @@ TEST_F(MslGeneratorImplTest, Source{}, create( create("coord"), create("b")))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -967,7 +979,8 @@ TEST_F(MslGeneratorImplTest, body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -1049,8 +1062,9 @@ TEST_F(MslGeneratorImplTest, Source{}, create( create("coord"), create("b")))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -1073,7 +1087,8 @@ TEST_F(MslGeneratorImplTest, body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -1145,7 +1160,7 @@ TEST_F(MslGeneratorImplTest, body->append(create(Source{})); auto* func_1 = create( - Source{}, "ep_1", params, &void_type, body, + Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -1177,8 +1192,8 @@ TEST_F(MslGeneratorImplTest, ast::type::Void void_type; auto* func = create( - Source{}, "main", ast::VariableList{}, &void_type, - create(), + Source{}, mod.RegisterSymbol("main"), "main", ast::VariableList{}, + &void_type, create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kCompute, Source{}), }); @@ -1212,8 +1227,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function_WithArrayParams) { auto* body = create(); body->append(create(Source{})); - auto* func = create(Source{}, "my_func", params, &void_type, - body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod.RegisterSymbol("my_func"), + "my_func", params, &void_type, body, + ast::FunctionDecorationList{}); mod.AddFunction(func); @@ -1298,12 +1314,12 @@ TEST_F(MslGeneratorImplTest, body->append(create(var)); body->append(create(Source{})); - auto* func = - create(Source{}, "a", params, &void_type, body, - ast::FunctionDecorationList{ - create( - ast::PipelineStage::kCompute, Source{}), - }); + auto* func = create( + Source{}, mod.RegisterSymbol("a"), "a", params, &void_type, body, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute, + Source{}), + }); mod.AddFunction(func); } @@ -1325,12 +1341,12 @@ TEST_F(MslGeneratorImplTest, body->append(create(var)); body->append(create(Source{})); - auto* func = - create(Source{}, "b", params, &void_type, body, - ast::FunctionDecorationList{ - create( - ast::PipelineStage::kCompute, Source{}), - }); + auto* func = create( + Source{}, mod.RegisterSymbol("b"), "b", params, &void_type, body, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute, + Source{}), + }); mod.AddFunction(func); } diff --git a/src/writer/msl/generator_impl_test.cc b/src/writer/msl/generator_impl_test.cc index 06fdbc7fdf..0ba3247204 100644 --- a/src/writer/msl/generator_impl_test.cc +++ b/src/writer/msl/generator_impl_test.cc @@ -51,8 +51,8 @@ TEST_F(MslGeneratorImplTest, Generate) { ast::type::Void void_type; auto* func = create( - Source{}, "my_func", ast::VariableList{}, &void_type, - create(), + Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{}, + &void_type, create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kCompute, Source{}), }); diff --git a/src/writer/spirv/builder_call_test.cc b/src/writer/spirv/builder_call_test.cc index 1f2d50b9fb..a4585a5189 100644 --- a/src/writer/spirv/builder_call_test.cc +++ b/src/writer/spirv/builder_call_test.cc @@ -65,11 +65,11 @@ TEST_F(BuilderTest, Expression_Call) { Source{}, create( ast::BinaryOp::kAdd, create("a"), create("b")))); - ast::Function a_func(Source{}, "a_func", func_params, &f32, body, - ast::FunctionDecorationList{}); + ast::Function a_func(Source{}, mod->RegisterSymbol("a_func"), "a_func", + func_params, &f32, body, ast::FunctionDecorationList{}); - ast::Function func(Source{}, "main", {}, &void_type, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("main"), "main", {}, + &void_type, create(), ast::FunctionDecorationList{}); ast::ExpressionList call_params; @@ -143,11 +143,12 @@ TEST_F(BuilderTest, Statement_Call) { ast::BinaryOp::kAdd, create("a"), create("b")))); - ast::Function a_func(Source{}, "a_func", func_params, &void_type, body, + ast::Function a_func(Source{}, mod->RegisterSymbol("a_func"), "a_func", + func_params, &void_type, body, ast::FunctionDecorationList{}); - ast::Function func(Source{}, "main", {}, &void_type, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("main"), "main", {}, + &void_type, create(), ast::FunctionDecorationList{}); ast::ExpressionList call_params; diff --git a/src/writer/spirv/builder_function_decoration_test.cc b/src/writer/spirv/builder_function_decoration_test.cc index 6f835e1cc5..dfab021c72 100644 --- a/src/writer/spirv/builder_function_decoration_test.cc +++ b/src/writer/spirv/builder_function_decoration_test.cc @@ -42,7 +42,8 @@ TEST_F(BuilderTest, FunctionDecoration_Stage) { ast::type::Void void_type; ast::Function func( - Source{}, "main", {}, &void_type, create(), + Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type, + create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -67,8 +68,8 @@ TEST_P(FunctionDecoration_StageTest, Emit) { ast::type::Void void_type; - ast::Function func(Source{}, "main", {}, &void_type, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("main"), "main", {}, + &void_type, create(), ast::FunctionDecorationList{ create(params.stage, Source{}), }); @@ -97,7 +98,8 @@ TEST_F(BuilderTest, FunctionDecoration_Stage_WithUnusedInterfaceIds) { ast::type::Void void_type; ast::Function func( - Source{}, "main", {}, &void_type, create(), + Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type, + create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -174,7 +176,7 @@ TEST_F(BuilderTest, FunctionDecoration_Stage_WithUsedInterfaceIds) { create("my_in"))); ast::Function func( - Source{}, "main", {}, &void_type, body, + Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -244,7 +246,8 @@ TEST_F(BuilderTest, FunctionDecoration_ExecutionMode_Fragment_OriginUpperLeft) { ast::type::Void void_type; ast::Function func( - Source{}, "main", {}, &void_type, create(), + Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type, + create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -259,7 +262,8 @@ TEST_F(BuilderTest, FunctionDecoration_WorkgroupSize_Default) { ast::type::Void void_type; ast::Function func( - Source{}, "main", {}, &void_type, create(), + Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type, + create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kCompute, Source{}), }); @@ -274,7 +278,8 @@ TEST_F(BuilderTest, FunctionDecoration_WorkgroupSize) { ast::type::Void void_type; ast::Function func( - Source{}, "main", {}, &void_type, create(), + Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type, + create(), ast::FunctionDecorationList{ create(2u, 4u, 6u, Source{}), create(ast::PipelineStage::kCompute, Source{}), @@ -290,13 +295,15 @@ TEST_F(BuilderTest, FunctionDecoration_ExecutionMode_MultipleFragment) { ast::type::Void void_type; ast::Function func1( - Source{}, "main1", {}, &void_type, create(), + Source{}, mod->RegisterSymbol("main1"), "main1", {}, &void_type, + create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); ast::Function func2( - Source{}, "main2", {}, &void_type, create(), + Source{}, mod->RegisterSymbol("main2"), "main2", {}, &void_type, + create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); diff --git a/src/writer/spirv/builder_function_test.cc b/src/writer/spirv/builder_function_test.cc index 51da30623d..983a3b2f0b 100644 --- a/src/writer/spirv/builder_function_test.cc +++ b/src/writer/spirv/builder_function_test.cc @@ -47,8 +47,8 @@ using BuilderTest = TestHelper; TEST_F(BuilderTest, Function_Empty) { ast::type::Void void_type; - ast::Function func(Source{}, "a_func", {}, &void_type, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &void_type, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)); @@ -68,8 +68,8 @@ TEST_F(BuilderTest, Function_Terminator_Return) { auto* body = create(); body->append(create(Source{})); - ast::Function func(Source{}, "a_func", {}, &void_type, body, - ast::FunctionDecorationList{}); + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &void_type, body, ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)); EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func" @@ -101,8 +101,8 @@ TEST_F(BuilderTest, Function_Terminator_ReturnValue) { Source{}, create("a"))); ASSERT_TRUE(td.DetermineResultType(body)) << td.error(); - ast::Function func(Source{}, "a_func", {}, &void_type, body, - ast::FunctionDecorationList{}); + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &void_type, body, ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateGlobalVariable(var_a)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -128,8 +128,8 @@ TEST_F(BuilderTest, Function_Terminator_Discard) { auto* body = create(); body->append(create()); - ast::Function func(Source{}, "a_func", {}, &void_type, body, - ast::FunctionDecorationList{}); + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &void_type, body, ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)); EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func" @@ -168,8 +168,8 @@ TEST_F(BuilderTest, Function_WithParams) { auto* body = create(); body->append(create( Source{}, create("a"))); - ast::Function func(Source{}, "a_func", params, &f32, body, - ast::FunctionDecorationList{}); + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", params, + &f32, body, ast::FunctionDecorationList{}); td.RegisterVariableForTesting(func.params()[0]); td.RegisterVariableForTesting(func.params()[1]); @@ -197,8 +197,8 @@ TEST_F(BuilderTest, Function_WithBody) { auto* body = create(); body->append(create(Source{})); - ast::Function func(Source{}, "a_func", {}, &void_type, body, - ast::FunctionDecorationList{}); + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &void_type, body, ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)); EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func" @@ -213,8 +213,8 @@ OpFunctionEnd TEST_F(BuilderTest, FunctionType) { ast::type::Void void_type; - ast::Function func(Source{}, "a_func", {}, &void_type, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &void_type, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)); @@ -225,11 +225,11 @@ TEST_F(BuilderTest, FunctionType) { TEST_F(BuilderTest, FunctionType_DeDuplicate) { ast::type::Void void_type; - ast::Function func1(Source{}, "a_func", {}, &void_type, - create(), + ast::Function func1(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &void_type, create(), ast::FunctionDecorationList{}); - ast::Function func2(Source{}, "b_func", {}, &void_type, - create(), + ast::Function func2(Source{}, mod->RegisterSymbol("b_func"), "b_func", {}, + &void_type, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func1)); @@ -307,12 +307,12 @@ TEST_F(BuilderTest, Emit_Multiple_EntryPoint_With_Same_ModuleVar) { body->append(create(var)); body->append(create(Source{})); - auto* func = - create(Source{}, "a", params, &void_type, body, - ast::FunctionDecorationList{ - create( - ast::PipelineStage::kCompute, Source{}), - }); + auto* func = create( + Source{}, mod->RegisterSymbol("a"), "a", params, &void_type, body, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute, + Source{}), + }); mod->AddFunction(func); } @@ -334,12 +334,12 @@ TEST_F(BuilderTest, Emit_Multiple_EntryPoint_With_Same_ModuleVar) { body->append(create(var)); body->append(create(Source{})); - auto* func = - create(Source{}, "b", params, &void_type, body, - ast::FunctionDecorationList{ - create( - ast::PipelineStage::kCompute, Source{}), - }); + auto* func = create( + Source{}, mod->RegisterSymbol("b"), "b", params, &void_type, body, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute, + Source{}), + }); mod->AddFunction(func); } diff --git a/src/writer/spirv/builder_intrinsic_test.cc b/src/writer/spirv/builder_intrinsic_test.cc index 98397116d3..d839c9c1e2 100644 --- a/src/writer/spirv/builder_intrinsic_test.cc +++ b/src/writer/spirv/builder_intrinsic_test.cc @@ -471,8 +471,8 @@ TEST_F(IntrinsicBuilderTest, Call_GLSLMethod_WithLoad) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error(); @@ -505,8 +505,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Float_Test, Call_Scalar) { auto expr = Call(param.name, 1.0f); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -533,8 +533,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Float_Test, Call_Vector) { auto expr = Call(param.name, vec2(1.0f, 1.0f)); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -587,8 +587,8 @@ TEST_F(IntrinsicBuilderTest, Call_Length_Scalar) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -612,8 +612,8 @@ TEST_F(IntrinsicBuilderTest, Call_Length_Vector) { auto expr = Call("length", vec2(1.0f, 1.0f)); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -639,8 +639,8 @@ TEST_F(IntrinsicBuilderTest, Call_Normalize) { auto expr = Call("normalize", vec2(1.0f, 1.0f)); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -671,8 +671,8 @@ TEST_P(Intrinsic_Builtin_DualParam_Float_Test, Call_Scalar) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -700,8 +700,8 @@ TEST_P(Intrinsic_Builtin_DualParam_Float_Test, Call_Vector) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -737,8 +737,8 @@ TEST_F(IntrinsicBuilderTest, Call_Distance_Scalar) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -763,8 +763,8 @@ TEST_F(IntrinsicBuilderTest, Call_Distance_Vector) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -792,8 +792,8 @@ TEST_F(IntrinsicBuilderTest, Call_Cross) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -823,8 +823,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Float_Test, Call_Scalar) { auto expr = Call(param.name, 1.0f, 1.0f, 1.0f); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -853,8 +853,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Float_Test, Call_Vector) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -894,8 +894,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Sint_Test, Call_Scalar) { auto expr = Call(param.name, 1); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -922,8 +922,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Sint_Test, Call_Vector) { auto expr = Call(param.name, vec2(1, 1)); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -957,8 +957,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Uint_Test, Call_Scalar) { auto expr = Call(param.name, 1u); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -985,8 +985,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Uint_Test, Call_Vector) { auto expr = Call(param.name, vec2(1u, 1u)); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1020,8 +1020,8 @@ TEST_P(Intrinsic_Builtin_DualParam_SInt_Test, Call_Scalar) { auto expr = Call(param.name, 1, 1); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1048,8 +1048,8 @@ TEST_P(Intrinsic_Builtin_DualParam_SInt_Test, Call_Vector) { auto expr = Call(param.name, vec2(1, 1), vec2(1, 1)); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1084,8 +1084,8 @@ TEST_P(Intrinsic_Builtin_DualParam_UInt_Test, Call_Scalar) { auto expr = Call(param.name, 1u, 1u); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1112,8 +1112,8 @@ TEST_P(Intrinsic_Builtin_DualParam_UInt_Test, Call_Vector) { auto expr = Call(param.name, vec2(1u, 1u), vec2(1u, 1u)); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1148,8 +1148,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Sint_Test, Call_Scalar) { auto expr = Call(param.name, 1, 1, 1); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1178,8 +1178,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Sint_Test, Call_Vector) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1213,8 +1213,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Uint_Test, Call_Scalar) { auto expr = Call(param.name, 1u, 1u, 1u); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1243,8 +1243,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Uint_Test, Call_Vector) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1276,8 +1276,8 @@ TEST_F(IntrinsicBuilderTest, Call_Determinant) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1320,8 +1320,8 @@ TEST_F(IntrinsicBuilderTest, Call_ArrayLength) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1360,8 +1360,8 @@ TEST_F(IntrinsicBuilderTest, Call_ArrayLength_OtherMembersInStruct) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1405,8 +1405,8 @@ TEST_F(IntrinsicBuilderTest, DISABLED_Call_ArrayLength_Ptr) { auto expr = Call("arrayLength", "ptr_var"); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); diff --git a/src/writer/spirv/builder_switch_test.cc b/src/writer/spirv/builder_switch_test.cc index 8ae74b0a27..b57cee98a5 100644 --- a/src/writer/spirv/builder_switch_test.cc +++ b/src/writer/spirv/builder_switch_test.cc @@ -121,8 +121,8 @@ TEST_F(BuilderTest, Switch_WithCase) { td.RegisterVariableForTesting(a); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, &i32, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &i32, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); @@ -201,8 +201,8 @@ TEST_F(BuilderTest, Switch_WithDefault) { td.RegisterVariableForTesting(a); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, &i32, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &i32, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); @@ -300,8 +300,8 @@ TEST_F(BuilderTest, Switch_WithCaseAndDefault) { td.RegisterVariableForTesting(a); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, &i32, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &i32, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); @@ -408,8 +408,8 @@ TEST_F(BuilderTest, Switch_CaseWithFallthrough) { td.RegisterVariableForTesting(a); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, &i32, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &i32, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); @@ -495,8 +495,8 @@ TEST_F(BuilderTest, Switch_CaseFallthroughLastStatement) { td.RegisterVariableForTesting(a); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, &i32, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &i32, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); @@ -563,8 +563,8 @@ TEST_F(BuilderTest, Switch_WithNestedBreak) { td.RegisterVariableForTesting(a); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, &i32, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &i32, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc index 5abd8db8f3..58982ddd2e 100644 --- a/src/writer/wgsl/generator_impl.cc +++ b/src/writer/wgsl/generator_impl.cc @@ -113,7 +113,8 @@ bool GeneratorImpl::Generate(const ast::Module& module) { bool GeneratorImpl::GenerateEntryPoint(const ast::Module& module, ast::PipelineStage stage, const std::string& name) { - auto* func = module.FindFunctionByNameAndStage(name, stage); + auto* func = + module.FindFunctionBySymbolAndStage(module.GetSymbol(name), stage); if (func == nullptr) { error_ = "Unable to find requested entry point: " + name; return false; @@ -153,7 +154,7 @@ bool GeneratorImpl::GenerateEntryPoint(const ast::Module& module, } for (auto* f : module.functions()) { - if (!f->HasAncestorEntryPoint(name)) { + if (!f->HasAncestorEntryPoint(module.GetSymbol(name))) { continue; } diff --git a/src/writer/wgsl/generator_impl_function_test.cc b/src/writer/wgsl/generator_impl_function_test.cc index b9d3a0c0dd..573b391ae2 100644 --- a/src/writer/wgsl/generator_impl_function_test.cc +++ b/src/writer/wgsl/generator_impl_function_test.cc @@ -46,8 +46,8 @@ TEST_F(WgslGeneratorImplTest, Emit_Function) { body->append(create(Source{})); ast::type::Void void_type; - ast::Function func(Source{}, "my_func", {}, &void_type, body, - ast::FunctionDecorationList{}); + ast::Function func(Source{}, mod.RegisterSymbol("my_func"), "my_func", {}, + &void_type, body, ast::FunctionDecorationList{}); gen.increment_indent(); @@ -85,8 +85,8 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithParams) { ast::VariableDecorationList{})); // decorations ast::type::Void void_type; - ast::Function func(Source{}, "my_func", params, &void_type, body, - ast::FunctionDecorationList{}); + ast::Function func(Source{}, mod.RegisterSymbol("my_func"), "my_func", params, + &void_type, body, ast::FunctionDecorationList{}); gen.increment_indent(); @@ -104,7 +104,8 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_WorkgroupSize) { body->append(create(Source{})); ast::type::Void void_type; - ast::Function func(Source{}, "my_func", {}, &void_type, body, + ast::Function func(Source{}, mod.RegisterSymbol("my_func"), "my_func", {}, + &void_type, body, ast::FunctionDecorationList{ create(2u, 4u, 6u, Source{}), }); @@ -127,7 +128,7 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_Stage) { ast::type::Void void_type; ast::Function func( - Source{}, "my_func", {}, &void_type, body, + Source{}, mod.RegisterSymbol("my_func"), "my_func", {}, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -150,7 +151,7 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_Multiple) { ast::type::Void void_type; ast::Function func( - Source{}, "my_func", {}, &void_type, body, + Source{}, mod.RegisterSymbol("my_func"), "my_func", {}, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), create(2u, 4u, 6u, Source{}), @@ -237,12 +238,12 @@ TEST_F(WgslGeneratorImplTest, body->append(create(var)); body->append(create(Source{})); - auto* func = - create(Source{}, "a", params, &void_type, body, - ast::FunctionDecorationList{ - create( - ast::PipelineStage::kCompute, Source{}), - }); + auto* func = create( + Source{}, mod.RegisterSymbol("a"), "a", params, &void_type, body, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute, + Source{}), + }); mod.AddFunction(func); } @@ -264,12 +265,12 @@ TEST_F(WgslGeneratorImplTest, body->append(create(var)); body->append(create(Source{})); - auto* func = - create(Source{}, "b", params, &void_type, body, - ast::FunctionDecorationList{ - create( - ast::PipelineStage::kCompute, Source{}), - }); + auto* func = create( + Source{}, mod.RegisterSymbol("b"), "b", params, &void_type, body, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute, + Source{}), + }); mod.AddFunction(func); } diff --git a/src/writer/wgsl/generator_impl_test.cc b/src/writer/wgsl/generator_impl_test.cc index 53431531fe..1b8da48a8a 100644 --- a/src/writer/wgsl/generator_impl_test.cc +++ b/src/writer/wgsl/generator_impl_test.cc @@ -33,8 +33,9 @@ TEST_F(WgslGeneratorImplTest, Generate) { ast::type::Void void_type; mod.AddFunction(create( - Source{}, "my_func", ast::VariableList{}, &void_type, - create(), ast::FunctionDecorationList{})); + Source{}, mod.RegisterSymbol("a_func"), "my_func", ast::VariableList{}, + &void_type, create(), + ast::FunctionDecorationList{})); ASSERT_TRUE(gen.Generate(mod)) << gen.error(); EXPECT_EQ(gen.result(), R"(fn my_func() -> void {