From a41132fcd860179352f80bd4750b1096654ad435 Mon Sep 17 00:00:00 2001 From: dan sinclair Date: Fri, 11 Dec 2020 18:24:53 +0000 Subject: [PATCH] Add a symbol to the Function AST node. This Cl adds a Symbol representing the function name to the function AST. The symbol is added alongside the name for now. When all usages of the function name are removed then the string version will be removed from the constructor. Change-Id: Ib2450e5fe531e988b25bb7d2937acc6af2187871 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/35220 Commit-Queue: dan sinclair Reviewed-by: Ben Clayton Auto-Submit: dan sinclair --- src/ast/function.cc | 14 +- src/ast/function.h | 18 ++- src/ast/function_test.cc | 152 +++++++++++++----- src/ast/module.cc | 16 +- src/ast/module.h | 14 +- src/ast/module_test.cc | 21 ++- src/inspector/inspector.cc | 2 +- src/inspector/inspector_test.cc | 50 +++--- src/reader/spirv/function.cc | 7 +- src/reader/spirv/function_call_test.cc | 16 +- src/reader/spirv/function_decl_test.cc | 28 ++-- .../spirv/parser_impl_function_decl_test.cc | 58 ++++--- src/reader/wgsl/parser_impl.cc | 6 +- src/symbol_table.cc | 9 ++ src/symbol_table.h | 11 ++ src/transform/bound_array_accessors_test.cc | 8 +- src/transform/emit_vertex_point_size_test.cc | 102 +++++++----- src/transform/first_index_offset.cc | 6 +- src/transform/first_index_offset_test.cc | 16 +- src/transform/vertex_pulling.cc | 9 +- src/transform/vertex_pulling_test.cc | 47 ++++-- src/type_determiner.cc | 11 +- src/type_determiner.h | 2 +- src/type_determiner_test.cc | 111 +++++++------ src/validator/validator_function_test.cc | 62 ++++--- src/validator/validator_test.cc | 24 +-- src/validator/validator_type_test.cc | 3 +- src/writer/hlsl/generator_impl.cc | 70 ++++---- src/writer/hlsl/generator_impl.h | 10 +- src/writer/hlsl/generator_impl_binary_test.cc | 6 +- src/writer/hlsl/generator_impl_call_test.cc | 18 +-- ...tor_impl_function_entry_point_data_test.cc | 14 +- .../hlsl/generator_impl_function_test.cc | 114 +++++++------ src/writer/hlsl/generator_impl_test.cc | 6 +- src/writer/msl/generator_impl.cc | 71 ++++---- src/writer/msl/generator_impl.h | 10 +- src/writer/msl/generator_impl_call_test.cc | 18 +-- ...tor_impl_function_entry_point_data_test.cc | 14 +- .../msl/generator_impl_function_test.cc | 112 +++++++------ src/writer/msl/generator_impl_test.cc | 4 +- src/writer/spirv/builder_call_test.cc | 15 +- .../spirv/builder_function_decoration_test.cc | 27 ++-- src/writer/spirv/builder_function_test.cc | 60 +++---- src/writer/spirv/builder_intrinsic_test.cc | 116 ++++++------- src/writer/spirv/builder_switch_test.cc | 24 +-- src/writer/wgsl/generator_impl.cc | 5 +- .../wgsl/generator_impl_function_test.cc | 39 ++--- src/writer/wgsl/generator_impl_test.cc | 5 +- 48 files changed, 923 insertions(+), 658 deletions(-) diff --git a/src/ast/function.cc b/src/ast/function.cc index 26522e825e..ae692875d8 100644 --- a/src/ast/function.cc +++ b/src/ast/function.cc @@ -31,12 +31,14 @@ namespace tint { namespace ast { Function::Function(const Source& source, + Symbol symbol, const std::string& name, VariableList params, type::Type* return_type, BlockStatement* body, FunctionDecorationList decorations) : Base(source), + symbol_(symbol), name_(name), params_(std::move(params)), return_type_(return_type), @@ -202,7 +204,7 @@ Function::local_referenced_builtin_variables() const { return ret; } -void Function::add_ancestor_entry_point(const std::string& ep) { +void Function::add_ancestor_entry_point(Symbol ep) { for (const auto& point : ancestor_entry_points_) { if (point == ep) { return; @@ -211,9 +213,9 @@ void Function::add_ancestor_entry_point(const std::string& ep) { ancestor_entry_points_.push_back(ep); } -bool Function::HasAncestorEntryPoint(const std::string& name) const { +bool Function::HasAncestorEntryPoint(Symbol symbol) const { for (const auto& point : ancestor_entry_points_) { - if (point == name) { + if (point == symbol) { return true; } } @@ -226,7 +228,7 @@ const Statement* Function::get_last_statement() const { Function* Function::Clone(CloneContext* ctx) const { return ctx->mod->create( - ctx->Clone(source()), name_, ctx->Clone(params_), + ctx->Clone(source()), symbol_, name_, ctx->Clone(params_), ctx->Clone(return_type_), ctx->Clone(body_), ctx->Clone(decorations_)); } @@ -238,7 +240,7 @@ bool Function::IsValid() const { if (body_ == nullptr || !body_->IsValid()) { return false; } - if (name_.length() == 0) { + if (name_.length() == 0 || !symbol_.IsValid()) { return false; } if (return_type_ == nullptr) { @@ -249,7 +251,7 @@ bool Function::IsValid() const { void Function::to_str(std::ostream& out, size_t indent) const { make_indent(out, indent); - out << "Function " << name_ << " -> " << return_type_->type_name() + out << "Function " << symbol_.to_str() << " -> " << return_type_->type_name() << std::endl; for (auto* deco : decorations()) { diff --git a/src/ast/function.h b/src/ast/function.h index 61eb529013..14b6789376 100644 --- a/src/ast/function.h +++ b/src/ast/function.h @@ -35,6 +35,7 @@ #include "src/ast/type/sampler_type.h" #include "src/ast/type/type.h" #include "src/ast/variable.h" +#include "src/symbol.h" namespace tint { namespace ast { @@ -52,12 +53,14 @@ class Function : public Castable { /// Create a function /// @param source the variable source + /// @param symbol the function symbol /// @param name the function name /// @param params the function parameters /// @param return_type the return type /// @param body the function body /// @param decorations the function decorations Function(const Source& source, + Symbol symbol, const std::string& name, VariableList params, type::Type* return_type, @@ -68,6 +71,8 @@ class Function : public Castable { ~Function() override; + /// @returns the function symbol + Symbol symbol() const { return symbol_; } /// @returns the function name const std::string& name() { return name_; } /// @returns the function params @@ -150,15 +155,15 @@ class Function : public Castable { /// Adds an ancestor entry point /// @param ep the entry point ancestor - void add_ancestor_entry_point(const std::string& ep); + void add_ancestor_entry_point(Symbol ep); /// @returns the ancestor entry points - const std::vector& ancestor_entry_points() const { + const std::vector& ancestor_entry_points() const { return ancestor_entry_points_; } /// Checks if the given entry point is an ancestor - /// @param name the entry point name - /// @returns true if `name` is an ancestor entry point of this function - bool HasAncestorEntryPoint(const std::string& name) const; + /// @param sym the entry point symbol + /// @returns true if `sym` is an ancestor entry point of this function + bool HasAncestorEntryPoint(Symbol sym) const; /// @returns the function return type. type::Type* return_type() const { return return_type_; } @@ -197,13 +202,14 @@ class Function : public Castable { const std::vector> ReferencedSampledTextureVariablesImpl(bool multisampled) const; + Symbol symbol_; std::string name_; VariableList params_; type::Type* return_type_ = nullptr; BlockStatement* body_ = nullptr; std::vector referenced_module_vars_; std::vector local_referenced_module_vars_; - std::vector ancestor_entry_points_; + std::vector ancestor_entry_points_; FunctionDecorationList decorations_; }; diff --git a/src/ast/function_test.cc b/src/ast/function_test.cc index 68f9921943..e341fc83c7 100644 --- a/src/ast/function_test.cc +++ b/src/ast/function_test.cc @@ -35,14 +35,18 @@ TEST_F(FunctionTest, Creation) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; params.push_back(create(Source{}, "var", StorageClass::kNone, &i32, false, nullptr, ast::VariableDecorationList{})); auto* var = params[0]; - Function f(Source{}, "func", params, &void_type, create(), - FunctionDecorationList{}); + Function f(Source{}, func_sym, "func", params, &void_type, + create(), FunctionDecorationList{}); + EXPECT_EQ(f.symbol(), func_sym); EXPECT_EQ(f.name(), "func"); ASSERT_EQ(f.params().size(), 1u); EXPECT_EQ(f.return_type(), &void_type); @@ -53,13 +57,16 @@ TEST_F(FunctionTest, Creation_WithSource) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; params.push_back(create(Source{}, "var", StorageClass::kNone, &i32, false, nullptr, ast::VariableDecorationList{})); - Function f(Source{Source::Location{20, 2}}, "func", params, &void_type, - create(), FunctionDecorationList{}); + Function f(Source{Source::Location{20, 2}}, func_sym, "func", params, + &void_type, create(), FunctionDecorationList{}); auto src = f.source(); EXPECT_EQ(src.range.begin.line, 20u); EXPECT_EQ(src.range.begin.column, 2u); @@ -69,9 +76,12 @@ TEST_F(FunctionTest, AddDuplicateReferencedVariables) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + Variable v(Source{}, "var", StorageClass::kInput, &i32, false, nullptr, ast::VariableDecorationList{}); - Function f(Source{}, "func", VariableList{}, &void_type, + Function f(Source{}, func_sym, "func", VariableList{}, &void_type, create(), FunctionDecorationList{}); f.add_referenced_module_variable(&v); @@ -92,6 +102,9 @@ TEST_F(FunctionTest, GetReferenceLocations) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + auto* loc1 = create(Source{}, "loc1", StorageClass::kInput, &i32, false, nullptr, ast::VariableDecorationList{ @@ -116,7 +129,7 @@ TEST_F(FunctionTest, GetReferenceLocations) { create(Builtin::kFragDepth, Source{}), }); - Function f(Source{}, "func", VariableList{}, &void_type, + Function f(Source{}, func_sym, "func", VariableList{}, &void_type, create(), FunctionDecorationList{}); f.add_referenced_module_variable(loc1); @@ -137,6 +150,9 @@ TEST_F(FunctionTest, GetReferenceBuiltins) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + auto* loc1 = create(Source{}, "loc1", StorageClass::kInput, &i32, false, nullptr, ast::VariableDecorationList{ @@ -161,7 +177,7 @@ TEST_F(FunctionTest, GetReferenceBuiltins) { create(Builtin::kFragDepth, Source{}), }); - Function f(Source{}, "func", VariableList{}, &void_type, + Function f(Source{}, func_sym, "func", VariableList{}, &void_type, create(), FunctionDecorationList{}); f.add_referenced_module_variable(loc1); @@ -180,22 +196,30 @@ TEST_F(FunctionTest, GetReferenceBuiltins) { TEST_F(FunctionTest, AddDuplicateEntryPoints) { type::Void void_type; - Function f(Source{}, "func", VariableList{}, &void_type, + + Module m; + auto func_sym = m.RegisterSymbol("func"); + auto main_sym = m.RegisterSymbol("main"); + + Function f(Source{}, func_sym, "func", VariableList{}, &void_type, create(), FunctionDecorationList{}); - f.add_ancestor_entry_point("main"); + f.add_ancestor_entry_point(main_sym); ASSERT_EQ(1u, f.ancestor_entry_points().size()); - EXPECT_EQ("main", f.ancestor_entry_points()[0]); + EXPECT_EQ(main_sym, f.ancestor_entry_points()[0]); - f.add_ancestor_entry_point("main"); + f.add_ancestor_entry_point(main_sym); ASSERT_EQ(1u, f.ancestor_entry_points().size()); - EXPECT_EQ("main", f.ancestor_entry_points()[0]); + EXPECT_EQ(main_sym, f.ancestor_entry_points()[0]); } TEST_F(FunctionTest, IsValid) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; params.push_back(create(Source{}, "var", StorageClass::kNone, &i32, false, nullptr, @@ -204,21 +228,27 @@ TEST_F(FunctionTest, IsValid) { auto* body = create(); body->append(create()); - Function f(Source{}, "func", params, &void_type, body, + Function f(Source{}, func_sym, "func", params, &void_type, body, FunctionDecorationList{}); EXPECT_TRUE(f.IsValid()); } -TEST_F(FunctionTest, IsValid_EmptyName) { +TEST_F(FunctionTest, IsValid_InvalidName) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol(""); + VariableList params; params.push_back(create(Source{}, "var", StorageClass::kNone, &i32, false, nullptr, ast::VariableDecorationList{})); - Function f(Source{}, "", params, &void_type, create(), + auto* body = create(); + body->append(create()); + + Function f(Source{}, func_sym, "", params, &void_type, body, FunctionDecorationList{}); EXPECT_FALSE(f.IsValid()); } @@ -226,13 +256,16 @@ TEST_F(FunctionTest, IsValid_EmptyName) { TEST_F(FunctionTest, IsValid_MissingReturnType) { type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; params.push_back(create(Source{}, "var", StorageClass::kNone, &i32, false, nullptr, ast::VariableDecorationList{})); - Function f(Source{}, "func", params, nullptr, create(), - FunctionDecorationList{}); + Function f(Source{}, func_sym, "func", params, nullptr, + create(), FunctionDecorationList{}); EXPECT_FALSE(f.IsValid()); } @@ -240,27 +273,33 @@ TEST_F(FunctionTest, IsValid_NullParam) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; params.push_back(create(Source{}, "var", StorageClass::kNone, &i32, false, nullptr, ast::VariableDecorationList{})); params.push_back(nullptr); - Function f(Source{}, "func", params, &void_type, create(), - FunctionDecorationList{}); + Function f(Source{}, func_sym, "func", params, &void_type, + create(), FunctionDecorationList{}); EXPECT_FALSE(f.IsValid()); } TEST_F(FunctionTest, IsValid_InvalidParam) { type::Void void_type; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; params.push_back(create(Source{}, "var", StorageClass::kNone, nullptr, false, nullptr, ast::VariableDecorationList{})); - Function f(Source{}, "func", params, &void_type, create(), - FunctionDecorationList{}); + Function f(Source{}, func_sym, "func", params, &void_type, + create(), FunctionDecorationList{}); EXPECT_FALSE(f.IsValid()); } @@ -268,6 +307,9 @@ TEST_F(FunctionTest, IsValid_NullBodyStatement) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; params.push_back(create(Source{}, "var", StorageClass::kNone, &i32, false, nullptr, @@ -277,7 +319,7 @@ TEST_F(FunctionTest, IsValid_NullBodyStatement) { body->append(create()); body->append(nullptr); - Function f(Source{}, "func", params, &void_type, body, + Function f(Source{}, func_sym, "func", params, &void_type, body, FunctionDecorationList{}); EXPECT_FALSE(f.IsValid()); @@ -287,6 +329,9 @@ TEST_F(FunctionTest, IsValid_InvalidBodyStatement) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; params.push_back(create(Source{}, "var", StorageClass::kNone, &i32, false, nullptr, @@ -296,7 +341,7 @@ TEST_F(FunctionTest, IsValid_InvalidBodyStatement) { body->append(create()); body->append(nullptr); - Function f(Source{}, "func", params, &void_type, body, + Function f(Source{}, func_sym, "func", params, &void_type, body, FunctionDecorationList{}); EXPECT_FALSE(f.IsValid()); } @@ -305,14 +350,18 @@ TEST_F(FunctionTest, ToStr) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + auto* body = create(); body->append(create()); - Function f(Source{}, "func", {}, &void_type, body, FunctionDecorationList{}); + Function f(Source{}, func_sym, "func", {}, &void_type, body, + FunctionDecorationList{}); std::ostringstream out; f.to_str(out, 2); - EXPECT_EQ(out.str(), R"( Function func -> __void + EXPECT_EQ(out.str(), R"( Function tint_symbol_1 -> __void () { Discard{} @@ -324,16 +373,19 @@ TEST_F(FunctionTest, ToStr_WithDecoration) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + auto* body = create(); body->append(create()); Function f( - Source{}, "func", {}, &void_type, body, + Source{}, func_sym, "func", {}, &void_type, body, FunctionDecorationList{create(2, 4, 6, Source{})}); std::ostringstream out; f.to_str(out, 2); - EXPECT_EQ(out.str(), R"( Function func -> __void + EXPECT_EQ(out.str(), R"( Function tint_symbol_1 -> __void WorkgroupDecoration{2 4 6} () { @@ -346,6 +398,9 @@ TEST_F(FunctionTest, ToStr_WithParams) { type::Void void_type; type::I32 i32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; params.push_back(create(Source{}, "var", StorageClass::kNone, &i32, false, nullptr, @@ -354,12 +409,12 @@ TEST_F(FunctionTest, ToStr_WithParams) { auto* body = create(); body->append(create()); - Function f(Source{}, "func", params, &void_type, body, + Function f(Source{}, func_sym, "func", params, &void_type, body, FunctionDecorationList{}); std::ostringstream out; f.to_str(out, 2); - EXPECT_EQ(out.str(), R"( Function func -> __void + EXPECT_EQ(out.str(), R"( Function tint_symbol_1 -> __void ( Variable{ var @@ -376,8 +431,11 @@ TEST_F(FunctionTest, ToStr_WithParams) { TEST_F(FunctionTest, TypeName) { type::Void void_type; - Function f(Source{}, "func", {}, &void_type, create(), - FunctionDecorationList{}); + Module m; + auto func_sym = m.RegisterSymbol("func"); + + Function f(Source{}, func_sym, "func", {}, &void_type, + create(), FunctionDecorationList{}); EXPECT_EQ(f.type_name(), "__func__void"); } @@ -386,6 +444,9 @@ TEST_F(FunctionTest, TypeName_WithParams) { type::I32 i32; type::F32 f32; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; params.push_back(create(Source{}, "var1", StorageClass::kNone, &i32, false, nullptr, @@ -394,19 +455,22 @@ TEST_F(FunctionTest, TypeName_WithParams) { false, nullptr, ast::VariableDecorationList{})); - Function f(Source{}, "func", params, &void_type, create(), - FunctionDecorationList{}); + Function f(Source{}, func_sym, "func", params, &void_type, + create(), FunctionDecorationList{}); EXPECT_EQ(f.type_name(), "__func__void__i32__f32"); } TEST_F(FunctionTest, GetLastStatement) { type::Void void_type; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; auto* body = create(); auto* stmt = create(); body->append(stmt); - Function f(Source{}, "func", params, &void_type, body, + Function f(Source{}, func_sym, "func", params, &void_type, body, FunctionDecorationList{}); EXPECT_EQ(f.get_last_statement(), stmt); @@ -415,9 +479,12 @@ TEST_F(FunctionTest, GetLastStatement) { TEST_F(FunctionTest, GetLastStatement_nullptr) { type::Void void_type; + Module m; + auto func_sym = m.RegisterSymbol("func"); + VariableList params; auto* body = create(); - Function f(Source{}, "func", params, &void_type, body, + Function f(Source{}, func_sym, "func", params, &void_type, body, FunctionDecorationList{}); EXPECT_EQ(f.get_last_statement(), nullptr); @@ -425,8 +492,12 @@ TEST_F(FunctionTest, GetLastStatement_nullptr) { TEST_F(FunctionTest, WorkgroupSize_NoneSet) { type::Void void_type; - Function f(Source{}, "f", {}, &void_type, create(), - FunctionDecorationList{}); + + Module m; + auto func_sym = m.RegisterSymbol("func"); + + Function f(Source{}, func_sym, "func", {}, &void_type, + create(), FunctionDecorationList{}); uint32_t x = 0; uint32_t y = 0; uint32_t z = 0; @@ -438,7 +509,12 @@ TEST_F(FunctionTest, WorkgroupSize_NoneSet) { TEST_F(FunctionTest, WorkgroupSize) { type::Void void_type; - Function f(Source{}, "f", {}, &void_type, create(), + + Module m; + auto func_sym = m.RegisterSymbol("func"); + + Function f(Source{}, func_sym, "func", {}, &void_type, + create(), {create(2u, 4u, 6u, Source{})}); uint32_t x = 0; diff --git a/src/ast/module.cc b/src/ast/module.cc index 5dcc12bd51..2132e1f935 100644 --- a/src/ast/module.cc +++ b/src/ast/module.cc @@ -47,21 +47,23 @@ void Module::Clone(CloneContext* ctx) { for (auto* func : functions_) { ctx->mod->functions_.emplace_back(ctx->Clone(func)); } + + ctx->mod->symbol_table_ = symbol_table_; } -Function* Module::FindFunctionByName(const std::string& name) const { +Function* Module::FindFunctionBySymbol(Symbol sym) const { for (auto* func : functions_) { - if (func->name() == name) { + if (func->symbol() == sym) { return func; } } return nullptr; } -Function* Module::FindFunctionByNameAndStage(const std::string& name, - PipelineStage stage) const { +Function* Module::FindFunctionBySymbolAndStage(Symbol sym, + PipelineStage stage) const { for (auto* func : functions_) { - if (func->name() == name && func->pipeline_stage() == stage) { + if (func->symbol() == sym && func->pipeline_stage() == stage) { return func; } } @@ -81,6 +83,10 @@ Symbol Module::RegisterSymbol(const std::string& name) { return symbol_table_.Register(name); } +Symbol Module::GetSymbol(const std::string& name) const { + return symbol_table_.GetSymbol(name); +} + std::string Module::SymbolToName(const Symbol sym) const { return symbol_table_.NameFor(sym); } diff --git a/src/ast/module.h b/src/ast/module.h index b274be110f..1facd79d8b 100644 --- a/src/ast/module.h +++ b/src/ast/module.h @@ -89,15 +89,14 @@ class Module { /// @returns the modules functions const FunctionList& functions() const { return functions_; } /// Returns the function with the given name - /// @param name the name to search for + /// @param sym the function symbol to search for /// @returns the associated function or nullptr if none exists - Function* FindFunctionByName(const std::string& name) const; + Function* FindFunctionBySymbol(Symbol sym) const; /// Returns the function with the given name - /// @param name the name to search for + /// @param sym the function symbol to search for /// @param stage the pipeline stage /// @returns the associated function or nullptr if none exists - Function* FindFunctionByNameAndStage(const std::string& name, - PipelineStage stage) const; + Function* FindFunctionBySymbolAndStage(Symbol sym, PipelineStage stage) const; /// @param stage the pipeline stage /// @returns true if the module contains an entrypoint function with the given /// stage @@ -169,6 +168,11 @@ class Module { /// previously generated symbol will be returned. Symbol RegisterSymbol(const std::string& name); + /// Returns the symbol for `name` + /// @param name the name to lookup + /// @returns the symbol for name or symbol::kInvalid + Symbol GetSymbol(const std::string& name) const; + /// Returns the `name` for `sym` /// @param sym the symbol to retrieve the name for /// @returns the use provided `name` for the symbol or "" if not found diff --git a/src/ast/module_test.cc b/src/ast/module_test.cc index 3303235729..6914eccea9 100644 --- a/src/ast/module_test.cc +++ b/src/ast/module_test.cc @@ -48,16 +48,17 @@ TEST_F(ModuleTest, LookupFunction) { type::F32 f32; Module m; + auto func_sym = m.RegisterSymbol("main"); auto* func = - create(Source{}, "main", VariableList{}, &f32, + create(Source{}, func_sym, "main", VariableList{}, &f32, create(), ast::FunctionDecorationList{}); m.AddFunction(func); - EXPECT_EQ(func, m.FindFunctionByName("main")); + EXPECT_EQ(func, m.FindFunctionBySymbol(func_sym)); } TEST_F(ModuleTest, LookupFunctionMissing) { Module m; - EXPECT_EQ(nullptr, m.FindFunctionByName("Missing")); + EXPECT_EQ(nullptr, m.FindFunctionBySymbol(m.RegisterSymbol("Missing"))); } TEST_F(ModuleTest, IsValid_Empty) { @@ -127,11 +128,12 @@ TEST_F(ModuleTest, IsValid_Struct_EmptyName) { TEST_F(ModuleTest, IsValid_Function) { type::F32 f32; - auto* func = - create(Source{}, "main", VariableList(), &f32, - create(), ast::FunctionDecorationList{}); Module m; + + auto* func = create(Source{}, m.RegisterSymbol("main"), "main", + VariableList(), &f32, create(), + ast::FunctionDecorationList{}); m.AddFunction(func); EXPECT_TRUE(m.IsValid()); } @@ -144,10 +146,13 @@ TEST_F(ModuleTest, IsValid_Null_Function) { TEST_F(ModuleTest, IsValid_Invalid_Function) { VariableList p; - auto* func = create(Source{}, "", p, nullptr, nullptr, - ast::FunctionDecorationList{}); Module m; + + auto* func = + create(Source{}, m.RegisterSymbol("main"), "main", p, nullptr, + nullptr, ast::FunctionDecorationList{}); + m.AddFunction(func); EXPECT_FALSE(m.IsValid()); } diff --git a/src/inspector/inspector.cc b/src/inspector/inspector.cc index 1bbbe8bc64..7ef231184b 100644 --- a/src/inspector/inspector.cc +++ b/src/inspector/inspector.cc @@ -267,7 +267,7 @@ std::vector Inspector::GetMultisampledTextureResourceBindings( } ast::Function* Inspector::FindEntryPointByName(const std::string& name) { - auto* func = module_.FindFunctionByName(name); + auto* func = module_.FindFunctionBySymbol(module_.GetSymbol(name)); if (!func) { error_ += name + " was not found!"; return nullptr; diff --git a/src/inspector/inspector_test.cc b/src/inspector/inspector_test.cc index 46ac6e456b..fbf24e5dcc 100644 --- a/src/inspector/inspector_test.cc +++ b/src/inspector/inspector_test.cc @@ -83,8 +83,9 @@ class InspectorHelper { ast::FunctionDecorationList decorations = {}) { auto* body = create(); body->append(create(Source{})); - return create(Source{}, name, ast::VariableList(), - void_type(), body, decorations); + return create(Source{}, mod()->RegisterSymbol(name), name, + ast::VariableList(), void_type(), body, + decorations); } /// Generates a function that calls another @@ -102,8 +103,9 @@ class InspectorHelper { create(ident_expr, ast::ExpressionList()); body->append(create(call_expr)); body->append(create(Source{})); - return create(Source{}, caller, ast::VariableList(), - void_type(), body, decorations); + return create(Source{}, mod()->RegisterSymbol(caller), + caller, ast::VariableList(), void_type(), body, + decorations); } /// Add In/Out variables to the global variables @@ -154,8 +156,9 @@ class InspectorHelper { create(in))); } body->append(create(Source{})); - return create(Source{}, name, ast::VariableList(), - void_type(), body, decorations); + return create(Source{}, mod()->RegisterSymbol(name), name, + ast::VariableList(), void_type(), body, + decorations); } /// Generates a function that references in/out variables and calls another @@ -184,8 +187,9 @@ class InspectorHelper { create(ident_expr, ast::ExpressionList()); body->append(create(call_expr)); body->append(create(Source{})); - return create(Source{}, caller, ast::VariableList(), - void_type(), body, decorations); + return create(Source{}, mod()->RegisterSymbol(caller), + caller, ast::VariableList(), void_type(), body, + decorations); } /// Add a Constant ID to the global variables. @@ -445,9 +449,9 @@ class InspectorHelper { } body->append(create(Source{})); - return create(Source{}, func_name, ast::VariableList(), - void_type(), body, - ast::FunctionDecorationList{}); + return create(Source{}, mod()->RegisterSymbol(func_name), + func_name, ast::VariableList(), void_type(), + body, ast::FunctionDecorationList{}); } /// Adds a regular sampler variable to the module @@ -587,8 +591,9 @@ class InspectorHelper { create("sampler_result"), call_expr)); body->append(create(Source{})); - return create(Source{}, func_name, ast::VariableList(), - void_type(), body, decorations); + return create(Source{}, mod()->RegisterSymbol(func_name), + func_name, ast::VariableList(), void_type(), + body, decorations); } /// Generates a function that references a specific sampler variable @@ -634,8 +639,9 @@ class InspectorHelper { create("sampler_result"), call_expr)); body->append(create(Source{})); - return create(Source{}, func_name, ast::VariableList(), - void_type(), body, decorations); + return create(Source{}, mod()->RegisterSymbol(func_name), + func_name, ast::VariableList(), void_type(), + body, decorations); } /// Generates a function that references a specific comparison sampler @@ -682,8 +688,9 @@ class InspectorHelper { create("sampler_result"), call_expr)); body->append(create(Source{})); - return create(Source{}, func_name, ast::VariableList(), - void_type(), body, decorations); + return create(Source{}, mod()->RegisterSymbol(func_name), + func_name, ast::VariableList(), void_type(), + body, decorations); } /// Gets an appropriate type for the data in a given texture type. @@ -1513,7 +1520,8 @@ TEST_F(InspectorGetUniformBufferResourceBindingsTest, MultipleUniformBuffers) { body->append(create(Source{})); ast::Function* func = create( - Source{}, "ep_func", ast::VariableList(), void_type(), body, + Source{}, mod()->RegisterSymbol("ep_func"), "ep_func", + ast::VariableList(), void_type(), body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -1659,7 +1667,8 @@ TEST_F(InspectorGetStorageBufferResourceBindingsTest, MultipleStorageBuffers) { body->append(create(Source{})); ast::Function* func = create( - Source{}, "ep_func", ast::VariableList(), void_type(), body, + Source{}, mod()->RegisterSymbol("ep_func"), "ep_func", + ast::VariableList(), void_type(), body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -1832,7 +1841,8 @@ TEST_F(InspectorGetReadOnlyStorageBufferResourceBindingsTest, body->append(create(Source{})); ast::Function* func = create( - Source{}, "ep_func", ast::VariableList(), void_type(), body, + Source{}, mod()->RegisterSymbol("ep_func"), "ep_func", + ast::VariableList(), void_type(), body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index ec024bb9b4..c8affa7f28 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -761,9 +761,10 @@ bool FunctionEmitter::Emit() { } auto* body = statements_stack_[0].statements_; - ast_module_.AddFunction(create( - decl.source, decl.name, std::move(decl.params), decl.return_type, body, - std::move(decl.decorations))); + ast_module_.AddFunction( + create(decl.source, ast_module_.RegisterSymbol(decl.name), + decl.name, std::move(decl.params), decl.return_type, + body, std::move(decl.decorations))); // Maintain the invariant by repopulating the one and only element. statements_stack_.clear(); diff --git a/src/reader/spirv/function_call_test.cc b/src/reader/spirv/function_call_test.cc index a30ec374b8..fabf8191f7 100644 --- a/src/reader/spirv/function_call_test.cc +++ b/src/reader/spirv/function_call_test.cc @@ -46,14 +46,16 @@ TEST_F(SpvParserTest, EmitStatement_VoidCallNoParams) { OpFunctionEnd )")); ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error(); - const auto module_ast_str = p->module().to_str(); + const auto module_ast_str = p->get_module().to_str(); EXPECT_THAT(module_ast_str, Eq(R"(Module{ - Function x_50 -> __void + Function )" + p->get_module().GetSymbol("x_50").to_str() + + R"( -> __void () { Return{} } - Function x_100 -> __void + Function )" + p->get_module().GetSymbol("x_100").to_str() + + R"( -> __void () { Call[not set]{ @@ -214,9 +216,10 @@ TEST_F(SpvParserTest, EmitStatement_CallWithParams) { )")); ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error(); EXPECT_TRUE(p->error().empty()); - const auto module_ast_str = p->module().to_str(); + const auto module_ast_str = p->get_module().to_str(); EXPECT_THAT(module_ast_str, HasSubstr(R"(Module{ - Function x_50 -> __u32 + Function )" + p->get_module().GetSymbol("x_50").to_str() + + R"( -> __u32 ( VariableConst{ x_51 @@ -240,7 +243,8 @@ TEST_F(SpvParserTest, EmitStatement_CallWithParams) { } } } - Function x_100 -> __void + Function )" + p->get_module().GetSymbol("x_100").to_str() + + R"( -> __void () { VariableDeclStatement{ diff --git a/src/reader/spirv/function_decl_test.cc b/src/reader/spirv/function_decl_test.cc index 2223893f3f..398f8fa8a1 100644 --- a/src/reader/spirv/function_decl_test.cc +++ b/src/reader/spirv/function_decl_test.cc @@ -59,9 +59,10 @@ TEST_F(SpvParserTest, Emit_VoidFunctionWithoutParams) { ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); EXPECT_TRUE(fe.Emit()); - auto got = p->module().to_str(); - auto* expect = R"(Module{ - Function x_100 -> __void + auto got = p->get_module().to_str(); + auto expect = R"(Module{ + Function )" + p->get_module().GetSymbol("x_100").to_str() + + R"( -> __void () { Return{} @@ -83,9 +84,10 @@ TEST_F(SpvParserTest, Emit_NonVoidResultType) { FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); EXPECT_TRUE(fe.Emit()); - auto got = p->module().to_str(); - auto* expect = R"(Module{ - Function x_100 -> __f32 + auto got = p->get_module().to_str(); + auto expect = R"(Module{ + Function )" + p->get_module().GetSymbol("x_100").to_str() + + R"( -> __f32 () { Return{ @@ -115,9 +117,10 @@ TEST_F(SpvParserTest, Emit_MixedParamTypes) { FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); EXPECT_TRUE(fe.Emit()); - auto got = p->module().to_str(); - auto* expect = R"(Module{ - Function x_100 -> __void + auto got = p->get_module().to_str(); + auto expect = R"(Module{ + Function )" + p->get_module().GetSymbol("x_100").to_str() + + R"( -> __void ( VariableConst{ a @@ -159,9 +162,10 @@ TEST_F(SpvParserTest, Emit_GenerateParamNames) { FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); EXPECT_TRUE(fe.Emit()); - auto got = p->module().to_str(); - auto* expect = R"(Module{ - Function x_100 -> __void + auto got = p->get_module().to_str(); + auto expect = R"(Module{ + Function )" + p->get_module().GetSymbol("x_100").to_str() + + R"( -> __void ( VariableConst{ x_14 diff --git a/src/reader/spirv/parser_impl_function_decl_test.cc b/src/reader/spirv/parser_impl_function_decl_test.cc index f322de25a9..dab5f3a0e3 100644 --- a/src/reader/spirv/parser_impl_function_decl_test.cc +++ b/src/reader/spirv/parser_impl_function_decl_test.cc @@ -53,7 +53,7 @@ TEST_F(SpvParserTest, EmitFunctions_NoFunctions) { auto p = parser(test::Assemble(CommonTypes())); EXPECT_TRUE(p->BuildAndParseInternalModule()); EXPECT_TRUE(p->error().empty()); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, Not(HasSubstr("Function{"))); } @@ -64,7 +64,7 @@ TEST_F(SpvParserTest, EmitFunctions_FunctionWithoutBody) { )")); EXPECT_TRUE(p->BuildAndParseInternalModule()); EXPECT_TRUE(p->error().empty()); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, Not(HasSubstr("Function{"))); } @@ -79,9 +79,10 @@ OpFunctionEnd)"; auto p = parser(test::Assemble(input)); ASSERT_TRUE(p->BuildAndParseInternalModule()); ASSERT_TRUE(p->error().empty()) << p->error(); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, HasSubstr(R"( - Function main -> __void + Function )" + p->get_module().GetSymbol("main").to_str() + + R"( -> __void StageDecoration{vertex} () {)")); @@ -98,9 +99,10 @@ OpFunctionEnd)"; auto p = parser(test::Assemble(input)); ASSERT_TRUE(p->BuildAndParseInternalModule()); ASSERT_TRUE(p->error().empty()) << p->error(); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, HasSubstr(R"( - Function main -> __void + Function )" + p->get_module().GetSymbol("main").to_str() + + R"( -> __void StageDecoration{fragment} () {)")); @@ -117,9 +119,10 @@ OpFunctionEnd)"; auto p = parser(test::Assemble(input)); ASSERT_TRUE(p->BuildAndParseInternalModule()); ASSERT_TRUE(p->error().empty()) << p->error(); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, HasSubstr(R"( - Function main -> __void + Function )" + p->get_module().GetSymbol("main").to_str() + + R"( -> __void StageDecoration{compute} () {)")); @@ -138,14 +141,16 @@ OpFunctionEnd)"; auto p = parser(test::Assemble(input)); ASSERT_TRUE(p->BuildAndParseInternalModule()); ASSERT_TRUE(p->error().empty()) << p->error(); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, HasSubstr(R"( - Function frag_main -> __void + Function )" + p->get_module().GetSymbol("frag_main").to_str() + + R"( -> __void StageDecoration{fragment} () {)")); EXPECT_THAT(module_ast, HasSubstr(R"( - Function comp_main -> __void + Function )" + p->get_module().GetSymbol("comp_main").to_str() + + R"( -> __void StageDecoration{compute} () {)")); @@ -160,9 +165,10 @@ TEST_F(SpvParserTest, EmitFunctions_VoidFunctionWithoutParams) { )")); EXPECT_TRUE(p->BuildAndParseInternalModule()); EXPECT_TRUE(p->error().empty()); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, HasSubstr(R"( - Function main -> __void + Function )" + p->get_module().GetSymbol("main").to_str() + + R"( -> __void () {)")); } @@ -193,9 +199,10 @@ TEST_F(SpvParserTest, EmitFunctions_CalleePrecedesCaller) { )")); EXPECT_TRUE(p->BuildAndParseInternalModule()); EXPECT_TRUE(p->error().empty()); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, HasSubstr(R"( - Function leaf -> __u32 + Function )" + p->get_module().GetSymbol("leaf").to_str() + + R"( -> __u32 () { Return{ @@ -204,7 +211,8 @@ TEST_F(SpvParserTest, EmitFunctions_CalleePrecedesCaller) { } } } - Function branch -> __u32 + Function )" + p->get_module().GetSymbol("branch").to_str() + + R"( -> __u32 () { VariableDeclStatement{ @@ -227,7 +235,8 @@ TEST_F(SpvParserTest, EmitFunctions_CalleePrecedesCaller) { } } } - Function root -> __void + Function )" + p->get_module().GetSymbol("root").to_str() + + R"( -> __void () { VariableDeclStatement{ @@ -260,9 +269,10 @@ TEST_F(SpvParserTest, EmitFunctions_NonVoidResultType) { )")); EXPECT_TRUE(p->BuildAndParseInternalModule()); EXPECT_TRUE(p->error().empty()); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, HasSubstr(R"( - Function ret_float -> __f32 + Function )" + p->get_module().GetSymbol("ret_float").to_str() + + R"( -> __f32 () { Return{ @@ -289,9 +299,10 @@ TEST_F(SpvParserTest, EmitFunctions_MixedParamTypes) { )")); EXPECT_TRUE(p->BuildAndParseInternalModule()); EXPECT_TRUE(p->error().empty()); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, HasSubstr(R"( - Function mixed_params -> __void + Function )" + p->get_module().GetSymbol("mixed_params").to_str() + + R"( -> __void ( VariableConst{ a @@ -328,9 +339,10 @@ TEST_F(SpvParserTest, EmitFunctions_GenerateParamNames) { )")); EXPECT_TRUE(p->BuildAndParseInternalModule()); EXPECT_TRUE(p->error().empty()); - const auto module_ast = p->module().to_str(); + const auto module_ast = p->get_module().to_str(); EXPECT_THAT(module_ast, HasSubstr(R"( - Function mixed_params -> __void + Function )" + p->get_module().GetSymbol("mixed_params").to_str() + + R"( -> __void ( VariableConst{ x_14 diff --git a/src/reader/wgsl/parser_impl.cc b/src/reader/wgsl/parser_impl.cc index 67a4651261..04ff79bb1c 100644 --- a/src/reader/wgsl/parser_impl.cc +++ b/src/reader/wgsl/parser_impl.cc @@ -1280,9 +1280,9 @@ Maybe ParserImpl::function_decl(ast::DecorationList& decos) { if (errored) return Failure::kErrored; - return create(header->source, header->name, header->params, - header->return_type, body.value, - func_decos.value); + return create( + header->source, module_.RegisterSymbol(header->name), header->name, + header->params, header->return_type, body.value, func_decos.value); } // function_type_decl diff --git a/src/symbol_table.cc b/src/symbol_table.cc index 13fe13fac4..8998ee2cf8 100644 --- a/src/symbol_table.cc +++ b/src/symbol_table.cc @@ -18,10 +18,14 @@ namespace tint { SymbolTable::SymbolTable() = default; +SymbolTable::SymbolTable(const SymbolTable&) = default; + SymbolTable::SymbolTable(SymbolTable&&) = default; SymbolTable::~SymbolTable() = default; +SymbolTable& SymbolTable::operator=(const SymbolTable& other) = default; + SymbolTable& SymbolTable::operator=(SymbolTable&&) = default; Symbol SymbolTable::Register(const std::string& name) { @@ -41,6 +45,11 @@ Symbol SymbolTable::Register(const std::string& name) { return sym; } +Symbol SymbolTable::GetSymbol(const std::string& name) const { + auto it = name_to_symbol_.find(name); + return it != name_to_symbol_.end() ? it->second : Symbol(); +} + std::string SymbolTable::NameFor(const Symbol symbol) const { auto it = symbol_to_name_.find(symbol.value()); if (it == symbol_to_name_.end()) diff --git a/src/symbol_table.h b/src/symbol_table.h index e3351e16a5..1c085988b5 100644 --- a/src/symbol_table.h +++ b/src/symbol_table.h @@ -27,11 +27,17 @@ class SymbolTable { public: /// Constructor SymbolTable(); + /// Copy constructor + SymbolTable(const SymbolTable&); /// Move Constructor SymbolTable(SymbolTable&&); /// Destructor ~SymbolTable(); + /// Copy assignment + /// @param other the symbol table to copy + /// @returns the new symbol table + SymbolTable& operator=(const SymbolTable& other); /// Move assignment /// @param other the symbol table to move /// @returns the symbol table @@ -42,6 +48,11 @@ class SymbolTable { /// @returns the symbol representing the given name Symbol Register(const std::string& name); + /// Returns the symbol for the given `name` + /// @param name the name to lookup + /// @returns the symbol for the name or symbol::kInvalid if not found. + Symbol GetSymbol(const std::string& name) const; + /// Returns the name for the given symbol /// @param symbol the symbol to retrieve the name for /// @returns the symbol name or "" if not found diff --git a/src/transform/bound_array_accessors_test.cc b/src/transform/bound_array_accessors_test.cc index 6e1f1717b0..24a7242c49 100644 --- a/src/transform/bound_array_accessors_test.cc +++ b/src/transform/bound_array_accessors_test.cc @@ -51,7 +51,7 @@ namespace { template T* FindVariable(ast::Module* mod, std::string name) { - if (auto* func = mod->FindFunctionByName("func")) { + if (auto* func = mod->FindFunctionBySymbol(mod->RegisterSymbol("func"))) { for (auto* stmt : *func->body()) { if (auto* decl = stmt->As()) { if (auto* var = decl->variable()) { @@ -92,9 +92,9 @@ class BoundArrayAccessorsTest : public testing::Test { struct ModuleBuilder : public ast::BuilderWithModule { ModuleBuilder() : body_(create()) { - mod->AddFunction(create(Source{}, "func", - ast::VariableList{}, ty.void_, body_, - ast::FunctionDecorationList{})); + mod->AddFunction(create( + Source{}, mod->RegisterSymbol("func"), "func", ast::VariableList{}, + ty.void_, body_, ast::FunctionDecorationList{})); } ast::Module Module() { diff --git a/src/transform/emit_vertex_point_size_test.cc b/src/transform/emit_vertex_point_size_test.cc index 7a08e5fcfe..57a904713b 100644 --- a/src/transform/emit_vertex_point_size_test.cc +++ b/src/transform/emit_vertex_point_size_test.cc @@ -58,23 +58,26 @@ TEST_F(EmitVertexPointSizeTest, VertexStageBasic) { Var("builtin_assignments_should_happen_before_this", tint::ast::StorageClass::kFunction, ty.f32))); - mod->AddFunction( - create(Source{}, "non_entry_a", ast::VariableList{}, - ty.void_, create(Source{}), - ast::FunctionDecorationList{})); + auto a_sym = mod->RegisterSymbol("non_entry_a"); + mod->AddFunction(create( + Source{}, a_sym, "non_entry_a", ast::VariableList{}, ty.void_, + create(Source{}), + ast::FunctionDecorationList{})); + auto entry_sym = mod->RegisterSymbol("entry"); auto* entry = create( - Source{}, "entry", ast::VariableList{}, ty.void_, block, + Source{}, entry_sym, "entry", ast::VariableList{}, ty.void_, block, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); mod->AddFunction(entry); - mod->AddFunction( - create(Source{}, "non_entry_b", ast::VariableList{}, - ty.void_, create(Source{}), - ast::FunctionDecorationList{})); + auto b_sym = mod->RegisterSymbol("non_entry_b"); + mod->AddFunction(create( + Source{}, b_sym, "non_entry_b", ast::VariableList{}, ty.void_, + create(Source{}), + ast::FunctionDecorationList{})); } }; @@ -82,7 +85,7 @@ TEST_F(EmitVertexPointSizeTest, VertexStageBasic) { ASSERT_FALSE(result.diagnostics.contains_errors()) << diag::Formatter().format(result.diagnostics); - auto* expected = R"(Module{ + auto expected = R"(Module{ Variable{ Decorations{ BuiltinDecoration{pointsize} @@ -91,11 +94,13 @@ TEST_F(EmitVertexPointSizeTest, VertexStageBasic) { out __f32 } - Function non_entry_a -> __void + Function )" + result.module.RegisterSymbol("non_entry_a").to_str() + + R"( -> __void () { } - Function entry -> __void + Function )" + result.module.RegisterSymbol("entry").to_str() + + R"( -> __void StageDecoration{vertex} () { @@ -111,7 +116,8 @@ TEST_F(EmitVertexPointSizeTest, VertexStageBasic) { } } } - Function non_entry_b -> __void + Function )" + result.module.RegisterSymbol("non_entry_b").to_str() + + R"( -> __void () { } @@ -123,23 +129,26 @@ TEST_F(EmitVertexPointSizeTest, VertexStageBasic) { TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) { struct Builder : ModuleBuilder { void Build() override { - mod->AddFunction( - create(Source{}, "non_entry_a", ast::VariableList{}, - ty.void_, create(Source{}), - ast::FunctionDecorationList{})); + auto a_sym = mod->RegisterSymbol("non_entry_a"); + mod->AddFunction(create( + Source{}, a_sym, "non_entry_a", ast::VariableList{}, ty.void_, + create(Source{}), + ast::FunctionDecorationList{})); - mod->AddFunction( - create(Source{}, "entry", ast::VariableList{}, - ty.void_, create(Source{}), - ast::FunctionDecorationList{ - create( - ast::PipelineStage::kVertex, Source{}), - })); + auto entry_sym = mod->RegisterSymbol("entry"); + mod->AddFunction(create( + Source{}, entry_sym, "entry", ast::VariableList{}, ty.void_, + create(Source{}), + ast::FunctionDecorationList{ + create(ast::PipelineStage::kVertex, + Source{}), + })); - mod->AddFunction( - create(Source{}, "non_entry_b", ast::VariableList{}, - ty.void_, create(Source{}), - ast::FunctionDecorationList{})); + auto b_sym = mod->RegisterSymbol("non_entry_b"); + mod->AddFunction(create( + Source{}, b_sym, "non_entry_b", ast::VariableList{}, ty.void_, + create(Source{}), + ast::FunctionDecorationList{})); } }; @@ -147,7 +156,7 @@ TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) { ASSERT_FALSE(result.diagnostics.contains_errors()) << diag::Formatter().format(result.diagnostics); - auto* expected = R"(Module{ + auto expected = R"(Module{ Variable{ Decorations{ BuiltinDecoration{pointsize} @@ -156,11 +165,13 @@ TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) { out __f32 } - Function non_entry_a -> __void + Function )" + result.module.RegisterSymbol("non_entry_a").to_str() + + R"( -> __void () { } - Function entry -> __void + Function )" + result.module.RegisterSymbol("entry").to_str() + + R"( -> __void StageDecoration{vertex} () { @@ -169,7 +180,8 @@ TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) { ScalarConstructor[__f32]{1.000000} } } - Function non_entry_b -> __void + Function )" + result.module.RegisterSymbol("non_entry_b").to_str() + + R"( -> __void () { } @@ -181,8 +193,9 @@ TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) { TEST_F(EmitVertexPointSizeTest, NonVertexStage) { struct Builder : ModuleBuilder { void Build() override { + auto frag_sym = mod->RegisterSymbol("fragment_entry"); auto* fragment_entry = create( - Source{}, "fragment_entry", ast::VariableList{}, ty.void_, + Source{}, frag_sym, "fragment_entry", ast::VariableList{}, ty.void_, create(Source{}), ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, @@ -190,13 +203,14 @@ TEST_F(EmitVertexPointSizeTest, NonVertexStage) { }); mod->AddFunction(fragment_entry); - auto* compute_entry = - create(Source{}, "compute_entry", ast::VariableList{}, - ty.void_, create(Source{}), - ast::FunctionDecorationList{ - create( - ast::PipelineStage::kCompute, Source{}), - }); + auto comp_sym = mod->RegisterSymbol("compute_entry"); + auto* compute_entry = create( + Source{}, comp_sym, "compute_entry", ast::VariableList{}, ty.void_, + create(Source{}), + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute, + Source{}), + }); mod->AddFunction(compute_entry); } }; @@ -205,13 +219,15 @@ TEST_F(EmitVertexPointSizeTest, NonVertexStage) { ASSERT_FALSE(result.diagnostics.contains_errors()) << diag::Formatter().format(result.diagnostics); - auto* expected = R"(Module{ - Function fragment_entry -> __void + auto expected = R"(Module{ + Function )" + result.module.RegisterSymbol("fragment_entry").to_str() + + R"( -> __void StageDecoration{fragment} () { } - Function compute_entry -> __void + Function )" + result.module.RegisterSymbol("compute_entry").to_str() + + R"( -> __void StageDecoration{compute} () { diff --git a/src/transform/first_index_offset.cc b/src/transform/first_index_offset.cc index 8738b7e790..b398935922 100644 --- a/src/transform/first_index_offset.cc +++ b/src/transform/first_index_offset.cc @@ -169,9 +169,9 @@ Transform::Output FirstIndexOffset::Run(ast::Module* in) { body->append(ctx.Clone(s)); } return ctx.mod->create( - ctx.Clone(func->source()), func->name(), ctx.Clone(func->params()), - ctx.Clone(func->return_type()), ctx.Clone(body), - ctx.Clone(func->decorations())); + ctx.Clone(func->source()), func->symbol(), func->name(), + ctx.Clone(func->params()), ctx.Clone(func->return_type()), + ctx.Clone(body), ctx.Clone(func->decorations())); }); in->Clone(&ctx); diff --git a/src/transform/first_index_offset_test.cc b/src/transform/first_index_offset_test.cc index c3b857c420..7a275577e8 100644 --- a/src/transform/first_index_offset_test.cc +++ b/src/transform/first_index_offset_test.cc @@ -58,9 +58,9 @@ struct ModuleBuilder : public ast::BuilderWithModule { ast::Function* AddFunction(const std::string& name, ast::VariableList params = {}) { - auto* func = create(Source{}, name, std::move(params), - ty.u32, create(), - ast::FunctionDecorationList()); + auto* func = create( + Source{}, mod->RegisterSymbol(name), name, std::move(params), ty.u32, + create(), ast::FunctionDecorationList()); mod->AddFunction(func); return func; } @@ -154,7 +154,7 @@ TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex) { uniform __struct_TintFirstIndexOffsetData } - Function test -> __u32 + Function tint_symbol_1 -> __u32 () { VariableDeclStatement{ @@ -229,7 +229,7 @@ TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex) { uniform __struct_TintFirstIndexOffsetData } - Function test -> __u32 + Function tint_symbol_1 -> __u32 () { VariableDeclStatement{ @@ -317,7 +317,7 @@ TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex) { uniform __struct_TintFirstIndexOffsetData } - Function test -> __u32 + Function tint_symbol_1 -> __u32 () { Return{ @@ -389,7 +389,7 @@ TEST_F(FirstIndexOffsetTest, NestedCalls) { uniform __struct_TintFirstIndexOffsetData } - Function func1 -> __u32 + Function tint_symbol_1 -> __u32 () { VariableDeclStatement{ @@ -415,7 +415,7 @@ TEST_F(FirstIndexOffsetTest, NestedCalls) { } } } - Function func2 -> __u32 + Function tint_symbol_2 -> __u32 () { Return{ diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc index d35f3d870d..6143d15601 100644 --- a/src/transform/vertex_pulling.cc +++ b/src/transform/vertex_pulling.cc @@ -84,8 +84,8 @@ Transform::Output VertexPulling::Run(ast::Module* in) { } // Find entry point - auto* func = mod->FindFunctionByNameAndStage(cfg.entry_point_name, - ast::PipelineStage::kVertex); + auto* func = mod->FindFunctionBySymbolAndStage( + mod->GetSymbol(cfg.entry_point_name), ast::PipelineStage::kVertex); if (func == nullptr) { diag::Diagnostic err; err.severity = diag::Severity::Error; @@ -94,9 +94,6 @@ Transform::Output VertexPulling::Run(ast::Module* in) { return out; } - // Save the vertex function - auto* vertex_func = mod->FindFunctionByName(func->name()); - // TODO(idanr): Need to check shader locations in descriptor cover all // attributes @@ -108,7 +105,7 @@ Transform::Output VertexPulling::Run(ast::Module* in) { state.FindOrInsertInstanceIndexIfUsed(); state.ConvertVertexInputVariablesToPrivate(); state.AddVertexStorageBuffers(); - state.AddVertexPullingPreamble(vertex_func); + state.AddVertexPullingPreamble(func); return out; } diff --git a/src/transform/vertex_pulling_test.cc b/src/transform/vertex_pulling_test.cc index 8f2a177e1c..c0279201dd 100644 --- a/src/transform/vertex_pulling_test.cc +++ b/src/transform/vertex_pulling_test.cc @@ -47,8 +47,8 @@ class VertexPullingHelper { // Create basic module with an entry point and vertex function void InitBasicModule() { auto* func = create( - Source{}, "main", ast::VariableList{}, mod_->create(), - create(), + Source{}, mod_->RegisterSymbol("main"), "main", ast::VariableList{}, + mod_->create(), create(), ast::FunctionDecorationList{create( ast::PipelineStage::kVertex, Source{})}); mod()->AddFunction(func); @@ -134,8 +134,8 @@ TEST_F(VertexPullingTest, Error_InvalidEntryPoint) { TEST_F(VertexPullingTest, Error_EntryPointWrongStage) { auto* func = create( - Source{}, "main", ast::VariableList{}, mod()->create(), - create(), + Source{}, mod()->RegisterSymbol("main"), "main", ast::VariableList{}, + mod()->create(), create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -152,7 +152,8 @@ TEST_F(VertexPullingTest, BasicModule) { InitBasicModule(); InitTransform({}); auto result = manager()->Run(mod()); - ASSERT_FALSE(result.diagnostics.contains_errors()); + ASSERT_FALSE(result.diagnostics.contains_errors()) + << diag::Formatter().format(result.diagnostics); } TEST_F(VertexPullingTest, OneAttribute) { @@ -164,7 +165,8 @@ TEST_F(VertexPullingTest, OneAttribute) { InitTransform({{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}}); auto result = manager()->Run(mod()); - ASSERT_FALSE(result.diagnostics.contains_errors()); + ASSERT_FALSE(result.diagnostics.contains_errors()) + << diag::Formatter().format(result.diagnostics); EXPECT_EQ(R"(Module{ TintVertexData Struct{ @@ -193,7 +195,8 @@ TEST_F(VertexPullingTest, OneAttribute) { storage_buffer __struct_TintVertexData } - Function main -> __void + Function )" + result.module.GetSymbol("main").to_str() + + R"( -> __void StageDecoration{vertex} () { @@ -250,7 +253,8 @@ TEST_F(VertexPullingTest, OneInstancedAttribute) { {{{4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 0}}}}}); auto result = manager()->Run(mod()); - ASSERT_FALSE(result.diagnostics.contains_errors()); + ASSERT_FALSE(result.diagnostics.contains_errors()) + << diag::Formatter().format(result.diagnostics); EXPECT_EQ(R"(Module{ TintVertexData Struct{ @@ -279,7 +283,8 @@ TEST_F(VertexPullingTest, OneInstancedAttribute) { storage_buffer __struct_TintVertexData } - Function main -> __void + Function )" + result.module.GetSymbol("main").to_str() + + R"( -> __void StageDecoration{vertex} () { @@ -336,7 +341,8 @@ TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet) { transform()->SetPullingBufferBindingSet(5); auto result = manager()->Run(mod()); - ASSERT_FALSE(result.diagnostics.contains_errors()); + ASSERT_FALSE(result.diagnostics.contains_errors()) + << diag::Formatter().format(result.diagnostics); EXPECT_EQ(R"(Module{ TintVertexData Struct{ @@ -365,7 +371,8 @@ TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet) { storage_buffer __struct_TintVertexData } - Function main -> __void + Function )" + result.module.GetSymbol("main").to_str() + + R"( -> __void StageDecoration{vertex} () { @@ -451,7 +458,8 @@ TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) { {4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 1}}}}}); auto result = manager()->Run(mod()); - ASSERT_FALSE(result.diagnostics.contains_errors()); + ASSERT_FALSE(result.diagnostics.contains_errors()) + << diag::Formatter().format(result.diagnostics); EXPECT_EQ(R"(Module{ TintVertexData Struct{ @@ -502,7 +510,8 @@ TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) { storage_buffer __struct_TintVertexData } - Function main -> __void + Function )" + result.module.GetSymbol("main").to_str() + + R"( -> __void StageDecoration{vertex} () { @@ -592,7 +601,8 @@ TEST_F(VertexPullingTest, TwoAttributesSameBuffer) { {{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}}); auto result = manager()->Run(mod()); - ASSERT_FALSE(result.diagnostics.contains_errors()); + ASSERT_FALSE(result.diagnostics.contains_errors()) + << diag::Formatter().format(result.diagnostics); EXPECT_EQ(R"(Module{ TintVertexData Struct{ @@ -626,7 +636,8 @@ TEST_F(VertexPullingTest, TwoAttributesSameBuffer) { storage_buffer __struct_TintVertexData } - Function main -> __void + Function )" + result.module.GetSymbol("main").to_str() + + R"( -> __void StageDecoration{vertex} () { @@ -778,7 +789,8 @@ TEST_F(VertexPullingTest, FloatVectorAttributes) { {16, InputStepMode::kVertex, {{VertexFormat::kVec4F32, 0, 2}}}}}); auto result = manager()->Run(mod()); - ASSERT_FALSE(result.diagnostics.contains_errors()); + ASSERT_FALSE(result.diagnostics.contains_errors()) + << diag::Formatter().format(result.diagnostics); EXPECT_EQ(R"(Module{ TintVertexData Struct{ @@ -835,7 +847,8 @@ TEST_F(VertexPullingTest, FloatVectorAttributes) { storage_buffer __struct_TintVertexData } - Function main -> __void + Function )" + result.module.GetSymbol("main").to_str() + + R"( -> __void StageDecoration{vertex} () { diff --git a/src/type_determiner.cc b/src/type_determiner.cc index 91fdba5b78..86e8d2e83c 100644 --- a/src/type_determiner.cc +++ b/src/type_determiner.cc @@ -122,7 +122,7 @@ bool TypeDeterminer::Determine() { continue; } for (const auto& callee : caller_to_callee_[func->name()]) { - set_entry_points(callee, func->name()); + set_entry_points(callee, func->symbol()); } } @@ -130,11 +130,11 @@ bool TypeDeterminer::Determine() { } void TypeDeterminer::set_entry_points(const std::string& fn_name, - const std::string& ep_name) { - name_to_function_[fn_name]->add_ancestor_entry_point(ep_name); + Symbol ep_sym) { + name_to_function_[fn_name]->add_ancestor_entry_point(ep_sym); for (const auto& callee : caller_to_callee_[fn_name]) { - set_entry_points(callee, ep_name); + set_entry_points(callee, ep_sym); } } @@ -389,7 +389,8 @@ bool TypeDeterminer::DetermineCall(ast::CallExpression* expr) { if (current_function_) { caller_to_callee_[current_function_->name()].push_back(ident->name()); - auto* callee_func = mod_->FindFunctionByName(ident->name()); + auto* callee_func = + mod_->FindFunctionBySymbol(mod_->GetSymbol(ident->name())); if (callee_func == nullptr) { set_error(expr->source(), "unable to find called function: " + ident->name()); diff --git a/src/type_determiner.h b/src/type_determiner.h index f7cb587b6a..0b1157ac49 100644 --- a/src/type_determiner.h +++ b/src/type_determiner.h @@ -113,7 +113,7 @@ class TypeDeterminer { private: void set_error(const Source& src, const std::string& msg); void set_referenced_from_function_if_needed(ast::Variable* var, bool local); - void set_entry_points(const std::string& fn_name, const std::string& ep_name); + void set_entry_points(const std::string& fn_name, Symbol ep_sym); bool DetermineArrayAccessor(ast::ArrayAccessorExpression* expr); bool DetermineBinary(ast::BinaryExpression* expr); diff --git a/src/type_determiner_test.cc b/src/type_determiner_test.cc index c798986304..a45e8961ec 100644 --- a/src/type_determiner_test.cc +++ b/src/type_determiner_test.cc @@ -341,9 +341,9 @@ TEST_F(TypeDeterminerTest, Stmt_Call) { ast::type::F32 f32; ast::VariableList params; - auto* func = create(Source{}, "my_func", params, &f32, - create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod->RegisterSymbol("my_func"), "my_func", params, &f32, + create(), ast::FunctionDecorationList{}); mod->AddFunction(func); // Register the function @@ -372,15 +372,16 @@ TEST_F(TypeDeterminerTest, Stmt_Call_undeclared) { auto* main_body = create(); main_body->append(create(call_expr)); main_body->append(create(Source{})); - auto* func_main = - create(Source{}, "main", params0, &f32, main_body, - ast::FunctionDecorationList{}); + auto* func_main = create(Source{}, mod->RegisterSymbol("main"), + "main", params0, &f32, main_body, + ast::FunctionDecorationList{}); mod->AddFunction(func_main); auto* body = create(); body->append(create(Source{})); - auto* func = create(Source{}, "func", params0, &f32, body, - ast::FunctionDecorationList{}); + auto* func = + create(Source{}, mod->RegisterSymbol("func"), "func", + params0, &f32, body, ast::FunctionDecorationList{}); mod->AddFunction(func); EXPECT_FALSE(td()->Determine()) << td()->error(); @@ -639,9 +640,9 @@ TEST_F(TypeDeterminerTest, Expr_Call) { ast::type::F32 f32; ast::VariableList params; - auto* func = create(Source{}, "my_func", params, &f32, - create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod->RegisterSymbol("my_func"), "my_func", params, &f32, + create(), ast::FunctionDecorationList{}); mod->AddFunction(func); // Register the function @@ -659,9 +660,9 @@ TEST_F(TypeDeterminerTest, Expr_Call_WithParams) { ast::type::F32 f32; ast::VariableList params; - auto* func = create(Source{}, "my_func", params, &f32, - create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod->RegisterSymbol("my_func"), "my_func", params, &f32, + create(), ast::FunctionDecorationList{}); mod->AddFunction(func); // Register the function @@ -809,8 +810,8 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable_Const) { body->append(create( my_var, create("my_var"))); - ast::Function f(Source{}, "my_func", {}, &f32, body, - ast::FunctionDecorationList{}); + ast::Function f(Source{}, mod->RegisterSymbol("my_func"), "my_func", {}, &f32, + body, ast::FunctionDecorationList{}); EXPECT_TRUE(td()->DetermineFunction(&f)); @@ -836,8 +837,8 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable) { body->append(create( my_var, create("my_var"))); - ast::Function f(Source{}, "my_func", {}, &f32, body, - ast::FunctionDecorationList{}); + ast::Function f(Source{}, mod->RegisterSymbol("myfunc"), "my_func", {}, &f32, + body, ast::FunctionDecorationList{}); EXPECT_TRUE(td()->DetermineFunction(&f)); @@ -868,8 +869,8 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_Function_Ptr) { body->append(create( my_var, create("my_var"))); - ast::Function f(Source{}, "my_func", {}, &f32, body, - ast::FunctionDecorationList{}); + ast::Function f(Source{}, mod->RegisterSymbol("my_func"), "my_func", {}, &f32, + body, ast::FunctionDecorationList{}); EXPECT_TRUE(td()->DetermineFunction(&f)); @@ -885,9 +886,9 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_Function) { ast::type::F32 f32; ast::VariableList params; - auto* func = create(Source{}, "my_func", params, &f32, - create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod->RegisterSymbol("my_func"), "my_func", params, &f32, + create(), ast::FunctionDecorationList{}); mod->AddFunction(func); // Register the function @@ -968,8 +969,9 @@ TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables) { body->append(create( create("priv_var"), create("priv_var"))); - auto* func = create(Source{}, "my_func", params, &f32, body, - ast::FunctionDecorationList{}); + auto* func = + create(Source{}, mod->RegisterSymbol("my_func"), "my_func", + params, &f32, body, ast::FunctionDecorationList{}); mod->AddFunction(func); @@ -1049,8 +1051,9 @@ TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables_SubFunction) { create("priv_var"), create("priv_var"))); ast::VariableList params; - auto* func = create(Source{}, "my_func", params, &f32, body, - ast::FunctionDecorationList{}); + auto* func = + create(Source{}, mod->RegisterSymbol("my_func"), "my_func", + params, &f32, body, ast::FunctionDecorationList{}); mod->AddFunction(func); @@ -1059,8 +1062,9 @@ TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables_SubFunction) { create("out_var"), create(create("my_func"), ast::ExpressionList{}))); - auto* func2 = create(Source{}, "func", params, &f32, body, - ast::FunctionDecorationList{}); + auto* func2 = + create(Source{}, mod->RegisterSymbol("func"), "func", + params, &f32, body, ast::FunctionDecorationList{}); mod->AddFunction(func2); @@ -1096,8 +1100,9 @@ TEST_F(TypeDeterminerTest, Function_NotRegisterFunctionVariable) { create(&f32, 1.f)))); ast::VariableList params; - auto* func = create(Source{}, "my_func", params, &f32, body, - ast::FunctionDecorationList{}); + auto* func = + create(Source{}, mod->RegisterSymbol("my_func"), "my_func", + params, &f32, body, ast::FunctionDecorationList{}); mod->AddFunction(func); @@ -2636,8 +2641,9 @@ TEST_F(TypeDeterminerTest, StorageClass_SetsIfMissing) { auto* body = create(); body->append(stmt); - auto* func = create(Source{}, "func", ast::VariableList{}, - &i32, body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod->RegisterSymbol("func"), + "func", ast::VariableList{}, &i32, body, + ast::FunctionDecorationList{}); mod->AddFunction(func); @@ -2660,8 +2666,9 @@ TEST_F(TypeDeterminerTest, StorageClass_DoesNotSetOnConst) { auto* body = create(); body->append(stmt); - auto* func = create(Source{}, "func", ast::VariableList{}, - &i32, body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod->RegisterSymbol("func"), + "func", ast::VariableList{}, &i32, body, + ast::FunctionDecorationList{}); mod->AddFunction(func); @@ -2684,8 +2691,9 @@ TEST_F(TypeDeterminerTest, StorageClass_NonFunctionClassError) { auto* body = create(); body->append(stmt); - auto* func = create(Source{}, "func", ast::VariableList{}, - &i32, body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod->RegisterSymbol("func"), + "func", ast::VariableList{}, &i32, body, + ast::FunctionDecorationList{}); mod->AddFunction(func); @@ -4857,24 +4865,27 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) { ast::VariableList params; auto* body = create(); - auto* func_b = create(Source{}, "b", params, &f32, body, - ast::FunctionDecorationList{}); + auto* func_b = + create(Source{}, mod->RegisterSymbol("b"), "b", params, + &f32, body, ast::FunctionDecorationList{}); body = create(); body->append(create( create("second"), create(create("b"), ast::ExpressionList{}))); - auto* func_c = create(Source{}, "c", params, &f32, body, - ast::FunctionDecorationList{}); + auto* func_c = + create(Source{}, mod->RegisterSymbol("c"), "c", params, + &f32, body, ast::FunctionDecorationList{}); body = create(); body->append(create( create("first"), create(create("c"), ast::ExpressionList{}))); - auto* func_a = create(Source{}, "a", params, &f32, body, - ast::FunctionDecorationList{}); + auto* func_a = + create(Source{}, mod->RegisterSymbol("a"), "a", params, + &f32, body, ast::FunctionDecorationList{}); body = create(); body->append(create( @@ -4886,7 +4897,7 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) { create(create("b"), ast::ExpressionList{}))); auto* ep_1 = create( - Source{}, "ep_1", params, &f32, body, + Source{}, mod->RegisterSymbol("ep_1"), "ep_1", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -4897,7 +4908,7 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) { create(create("c"), ast::ExpressionList{}))); auto* ep_2 = create( - Source{}, "ep_2", params, &f32, body, + Source{}, mod->RegisterSymbol("ep_2"), "ep_2", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -4954,17 +4965,17 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) { const auto& b_eps = func_b->ancestor_entry_points(); ASSERT_EQ(2u, b_eps.size()); - EXPECT_EQ("ep_1", b_eps[0]); - EXPECT_EQ("ep_2", b_eps[1]); + EXPECT_EQ(mod->RegisterSymbol("ep_1"), b_eps[0]); + EXPECT_EQ(mod->RegisterSymbol("ep_2"), b_eps[1]); const auto& a_eps = func_a->ancestor_entry_points(); ASSERT_EQ(1u, a_eps.size()); - EXPECT_EQ("ep_1", a_eps[0]); + EXPECT_EQ(mod->RegisterSymbol("ep_1"), a_eps[0]); const auto& c_eps = func_c->ancestor_entry_points(); ASSERT_EQ(2u, c_eps.size()); - EXPECT_EQ("ep_1", c_eps[0]); - EXPECT_EQ("ep_2", c_eps[1]); + EXPECT_EQ(mod->RegisterSymbol("ep_1"), c_eps[0]); + EXPECT_EQ(mod->RegisterSymbol("ep_2"), c_eps[1]); EXPECT_TRUE(ep_1->ancestor_entry_points().empty()); EXPECT_TRUE(ep_2->ancestor_entry_points().empty()); diff --git a/src/validator/validator_function_test.cc b/src/validator/validator_function_test.cc index c335c226c9..780b5427d2 100644 --- a/src/validator/validator_function_test.cc +++ b/src/validator/validator_function_test.cc @@ -54,7 +54,8 @@ TEST_F(ValidateFunctionTest, VoidFunctionEndWithoutReturnStatement_Pass) { auto* body = create(); body->append(create(var)); auto* func = create( - Source{Source::Location{12, 34}}, "func", params, &void_type, body, + Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func", + params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -71,8 +72,8 @@ TEST_F(ValidateFunctionTest, ast::type::Void void_type; ast::VariableList params; auto* func = create( - Source{Source::Location{12, 34}}, "func", params, &void_type, - create(), + Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func", + params, &void_type, create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -100,9 +101,9 @@ TEST_F(ValidateFunctionTest, FunctionEndWithoutReturnStatement_Fail) { ast::type::Void void_type; auto* body = create(); body->append(create(var)); - auto* func = - create(Source{Source::Location{12, 34}}, "func", params, - &i32, body, ast::FunctionDecorationList{}); + auto* func = create( + Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func", + params, &i32, body, ast::FunctionDecorationList{}); mod()->AddFunction(func); EXPECT_TRUE(td()->Determine()) << td()->error(); @@ -117,8 +118,9 @@ TEST_F(ValidateFunctionTest, FunctionEndWithoutReturnStatementEmptyBody_Fail) { ast::type::I32 i32; ast::VariableList params; auto* func = create( - Source{Source::Location{12, 34}}, "func", params, &i32, - create(), ast::FunctionDecorationList{}); + Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func", + params, &i32, create(), + ast::FunctionDecorationList{}); mod()->AddFunction(func); EXPECT_TRUE(td()->Determine()) << td()->error(); @@ -136,7 +138,7 @@ TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementType_Pass) { auto* body = create(); body->append(create(Source{})); auto* func = create( - Source{}, "func", params, &void_type, body, + Source{}, mod()->RegisterSymbol("func"), "func", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -157,7 +159,8 @@ TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementType_fail) { body->append(create(Source{Source::Location{12, 34}}, return_expr)); - auto* func = create(Source{}, "func", params, &void_type, body, + auto* func = create(Source{}, mod()->RegisterSymbol("func"), + "func", params, &void_type, body, ast::FunctionDecorationList{}); mod()->AddFunction(func); @@ -180,8 +183,9 @@ TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementTypeF32_fail) { body->append(create(Source{Source::Location{12, 34}}, return_expr)); - auto* func = create(Source{}, "func", params, &f32, body, - ast::FunctionDecorationList{}); + auto* func = + create(Source{}, mod()->RegisterSymbol("func"), "func", + params, &f32, body, ast::FunctionDecorationList{}); mod()->AddFunction(func); EXPECT_TRUE(td()->Determine()) << td()->error(); @@ -204,8 +208,9 @@ TEST_F(ValidateFunctionTest, FunctionNamesMustBeUnique_fail) { create(&i32, 2)); body->append(create(Source{}, return_expr)); - auto* func = create(Source{}, "func", params, &i32, body, - ast::FunctionDecorationList{}); + auto* func = + create(Source{}, mod()->RegisterSymbol("func"), "func", + params, &i32, body, ast::FunctionDecorationList{}); ast::VariableList params_copy; auto* body_copy = create(); @@ -213,9 +218,9 @@ TEST_F(ValidateFunctionTest, FunctionNamesMustBeUnique_fail) { create(&i32, 2)); body_copy->append(create(Source{}, return_expr_copy)); - auto* func_copy = create(Source{Source::Location{12, 34}}, - "func", params_copy, &i32, body_copy, - ast::FunctionDecorationList{}); + auto* func_copy = create( + Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func", + params_copy, &i32, body_copy, ast::FunctionDecorationList{}); mod()->AddFunction(func); mod()->AddFunction(func_copy); @@ -237,7 +242,8 @@ TEST_F(ValidateFunctionTest, RecursionIsNotAllowed_Fail) { auto* body0 = create(); body0->append(create(call_expr)); body0->append(create(Source{})); - auto* func0 = create(Source{}, "func", params0, &f32, body0, + auto* func0 = create(Source{}, mod()->RegisterSymbol("func"), + "func", params0, &f32, body0, ast::FunctionDecorationList{}); mod()->AddFunction(func0); @@ -268,7 +274,8 @@ TEST_F(ValidateFunctionTest, RecursionIsNotAllowedExpr_Fail) { create(&i32, 2)); body0->append(create(Source{}, return_expr)); - auto* func0 = create(Source{}, "func", params0, &i32, body0, + auto* func0 = create(Source{}, mod()->RegisterSymbol("func"), + "func", params0, &i32, body0, ast::FunctionDecorationList{}); mod()->AddFunction(func0); @@ -288,7 +295,8 @@ TEST_F(ValidateFunctionTest, Function_WithPipelineStage_NotVoid_Fail) { auto* body = create(); body->append(create(Source{}, return_expr)); auto* func = create( - Source{Source::Location{12, 34}}, "vtx_main", params, &i32, body, + Source{Source::Location{12, 34}}, mod()->RegisterSymbol("vtx_main"), + "vtx_main", params, &i32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -317,7 +325,8 @@ TEST_F(ValidateFunctionTest, Function_WithPipelineStage_WithParams_Fail) { auto* body = create(); body->append(create(Source{})); auto* func = create( - Source{Source::Location{12, 34}}, "vtx_func", params, &void_type, body, + Source{Source::Location{12, 34}}, mod()->RegisterSymbol("vtx_func"), + "vtx_func", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -339,7 +348,8 @@ TEST_F(ValidateFunctionTest, PipelineStage_MustBeUnique_Fail) { auto* body = create(); body->append(create(Source{})); auto* func = create( - Source{Source::Location{12, 34}}, "main", params, &void_type, body, + Source{Source::Location{12, 34}}, mod()->RegisterSymbol("main"), "main", + params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), create(ast::PipelineStage::kFragment, Source{}), @@ -361,7 +371,8 @@ TEST_F(ValidateFunctionTest, OnePipelineStageFunctionMustBePresent_Pass) { auto* body = create(); body->append(create(Source{})); auto* func = create( - Source{}, "vtx_func", params, &void_type, body, + Source{}, mod()->RegisterSymbol("vtx_func"), "vtx_func", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -377,8 +388,9 @@ TEST_F(ValidateFunctionTest, OnePipelineStageFunctionMustBePresent_Fail) { ast::VariableList params; auto* body = create(); body->append(create(Source{})); - auto* func = create(Source{}, "vtx_func", params, &void_type, - body, ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod()->RegisterSymbol("vtx_func"), "vtx_func", params, + &void_type, body, ast::FunctionDecorationList{}); mod()->AddFunction(func); EXPECT_TRUE(td()->Determine()) << td()->error(); diff --git a/src/validator/validator_test.cc b/src/validator/validator_test.cc index d993b1ea3f..5931fc0688 100644 --- a/src/validator/validator_test.cc +++ b/src/validator/validator_test.cc @@ -332,7 +332,8 @@ TEST_F(ValidatorTest, UsingUndefinedVariableGlobalVariable_Fail) { body->append(create( Source{Source::Location{12, 34}}, lhs, rhs)); - auto* func = create(Source{}, "my_func", params, &f32, body, + auto* func = create(Source{}, mod()->RegisterSymbol("my_func"), + "my_func", params, &f32, body, ast::FunctionDecorationList{}); mod()->AddFunction(func); @@ -370,7 +371,8 @@ TEST_F(ValidatorTest, UsingUndefinedVariableGlobalVariable_Pass) { Source{Source::Location{12, 34}}, lhs, rhs)); body->append(create(Source{})); auto* func = create( - Source{}, "my_func", params, &void_type, body, + Source{}, mod()->RegisterSymbol("my_func"), "my_func", params, &void_type, + body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -587,8 +589,9 @@ TEST_F(ValidatorTest, GlobalVariableFunctionVariableNotUnique_Fail) { auto* body = create(); body->append(create( Source{Source::Location{12, 34}}, var)); - auto* func = create(Source{}, "my_func", params, &void_type, - body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod()->RegisterSymbol("my_func"), + "my_func", params, &void_type, body, + ast::FunctionDecorationList{}); mod()->AddFunction(func); @@ -631,8 +634,9 @@ TEST_F(ValidatorTest, RedeclaredIndentifier_Fail) { body->append(create(var)); body->append(create( Source{Source::Location{12, 34}}, var_a_float)); - auto* func = create(Source{}, "my_func", params, &void_type, - body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod()->RegisterSymbol("my_func"), + "my_func", params, &void_type, body, + ast::FunctionDecorationList{}); mod()->AddFunction(func); @@ -759,8 +763,9 @@ TEST_F(ValidatorTest, RedeclaredIdentifierDifferentFunctions_Pass) { body0->append(create( Source{Source::Location{12, 34}}, var0)); body0->append(create(Source{})); - auto* func0 = create(Source{}, "func0", params0, &void_type, - body0, ast::FunctionDecorationList{}); + auto* func0 = create(Source{}, mod()->RegisterSymbol("func0"), + "func0", params0, &void_type, body0, + ast::FunctionDecorationList{}); ast::VariableList params1; auto* body1 = create(); @@ -768,7 +773,8 @@ TEST_F(ValidatorTest, RedeclaredIdentifierDifferentFunctions_Pass) { Source{Source::Location{13, 34}}, var1)); body1->append(create(Source{})); auto* func1 = create( - Source{}, "func1", params1, &void_type, body1, + Source{}, mod()->RegisterSymbol("func1"), "func1", params1, &void_type, + body1, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); diff --git a/src/validator/validator_type_test.cc b/src/validator/validator_type_test.cc index 2fbaeffac2..ac8cc05de7 100644 --- a/src/validator/validator_type_test.cc +++ b/src/validator/validator_type_test.cc @@ -206,8 +206,9 @@ TEST_F(ValidatorTypeTest, RuntimeArrayInFunction_Fail) { auto* body = create(); body->append(create( Source{Source::Location{12, 34}}, var)); + auto* func = create( - Source{}, "func", params, &void_type, body, + Source{}, mod()->RegisterSymbol("func"), "func", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index c52c8df572..fa49a4b0e0 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -196,15 +196,15 @@ std::string GeneratorImpl::current_ep_var_name(VarType type) { std::string name = ""; switch (type) { case VarType::kIn: { - auto in_it = ep_name_to_in_data_.find(current_ep_name_); - if (in_it != ep_name_to_in_data_.end()) { + auto in_it = ep_sym_to_in_data_.find(current_ep_sym_.value()); + if (in_it != ep_sym_to_in_data_.end()) { name = in_it->second.var_name; } break; } case VarType::kOut: { - auto outit = ep_name_to_out_data_.find(current_ep_name_); - if (outit != ep_name_to_out_data_.end()) { + auto outit = ep_sym_to_out_data_.find(current_ep_sym_.value()); + if (outit != ep_sym_to_out_data_.end()) { name = outit->second.var_name; } break; @@ -668,12 +668,14 @@ bool GeneratorImpl::EmitCall(std::ostream& pre, } auto name = ident->name(); - auto it = ep_func_name_remapped_.find(current_ep_name_ + "_" + name); + auto caller_sym = module_->GetSymbol(name); + auto it = ep_func_name_remapped_.find(current_ep_sym_.to_str() + "_" + + caller_sym.to_str()); if (it != ep_func_name_remapped_.end()) { name = it->second; } - auto* func = module_->FindFunctionByName(ident->name()); + auto* func = module_->FindFunctionBySymbol(module_->GetSymbol(ident->name())); if (func == nullptr) { error_ = "Unable to find function: " + name; return false; @@ -1189,15 +1191,15 @@ bool GeneratorImpl::EmitFunction(std::ostream& out, ast::Function* func) { has_referenced_var_needing_struct(func); if (emit_duplicate_functions) { - for (const auto& ep_name : func->ancestor_entry_points()) { - if (!EmitFunctionInternal(out, func, emit_duplicate_functions, ep_name)) { + for (const auto& ep_sym : func->ancestor_entry_points()) { + if (!EmitFunctionInternal(out, func, emit_duplicate_functions, ep_sym)) { return false; } out << std::endl; } } else { // Emit as non-duplicated - if (!EmitFunctionInternal(out, func, false, "")) { + if (!EmitFunctionInternal(out, func, false, Symbol())) { return false; } out << std::endl; @@ -1209,8 +1211,8 @@ bool GeneratorImpl::EmitFunction(std::ostream& out, ast::Function* func) { bool GeneratorImpl::EmitFunctionInternal(std::ostream& out, ast::Function* func, bool emit_duplicate_functions, - const std::string& ep_name) { - auto name = func->name(); + Symbol ep_sym) { + auto name = func->symbol().to_str(); if (!EmitType(out, func->return_type(), "")) { return false; @@ -1219,10 +1221,15 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out, out << " "; if (emit_duplicate_functions) { - name = generate_name(name + "_" + ep_name); - ep_func_name_remapped_[ep_name + "_" + func->name()] = name; + auto func_name = name; + auto ep_name = ep_sym.to_str(); + // TODO(dsinclair): The SymbolToName should go away and just use + // to_str() here when the conversion is complete. + name = generate_name(func->name() + "_" + module_->SymbolToName(ep_sym)); + ep_func_name_remapped_[ep_name + "_" + func_name] = name; } else { - name = namer_.NameFor(name); + // TODO(dsinclair): this should be updated to a remapped name + name = namer_.NameFor(func->name()); } out << name << "("; @@ -1234,15 +1241,15 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out, // // We emit both of them if they're there regardless of if they're both used. if (emit_duplicate_functions) { - auto in_it = ep_name_to_in_data_.find(ep_name); - if (in_it != ep_name_to_in_data_.end()) { + auto in_it = ep_sym_to_in_data_.find(ep_sym.value()); + if (in_it != ep_sym_to_in_data_.end()) { out << "in " << in_it->second.struct_name << " " << in_it->second.var_name; first = false; } - auto outit = ep_name_to_out_data_.find(ep_name); - if (outit != ep_name_to_out_data_.end()) { + auto outit = ep_sym_to_out_data_.find(ep_sym.value()); + if (outit != ep_sym_to_out_data_.end()) { if (!first) { out << ", "; } @@ -1269,13 +1276,13 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out, out << ") "; - current_ep_name_ = ep_name; + current_ep_sym_ = ep_sym; if (!EmitBlockAndNewline(out, func->body())) { return false; } - current_ep_name_ = ""; + current_ep_sym_ = Symbol(); return true; } @@ -1392,7 +1399,7 @@ bool GeneratorImpl::EmitEntryPointData( auto in_struct_name = generate_name(func->name() + "_" + kInStructNameSuffix); auto in_var_name = generate_name(kTintStructInVarPrefix); - ep_name_to_in_data_[func->name()] = {in_struct_name, in_var_name}; + ep_sym_to_in_data_[func->symbol().value()] = {in_struct_name, in_var_name}; make_indent(out); out << "struct " << in_struct_name << " {" << std::endl; @@ -1438,7 +1445,7 @@ bool GeneratorImpl::EmitEntryPointData( auto outstruct_name = generate_name(func->name() + "_" + kOutStructNameSuffix); auto outvar_name = generate_name(kTintStructOutVarPrefix); - ep_name_to_out_data_[func->name()] = {outstruct_name, outvar_name}; + ep_sym_to_out_data_[func->symbol().value()] = {outstruct_name, outvar_name}; make_indent(out); out << "struct " << outstruct_name << " {" << std::endl; @@ -1516,7 +1523,7 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out, ast::Function* func) { make_indent(out); - current_ep_name_ = func->name(); + current_ep_sym_ = func->symbol(); if (func->pipeline_stage() == ast::PipelineStage::kCompute) { uint32_t x = 0; @@ -1528,17 +1535,18 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out, make_indent(out); } - auto outdata = ep_name_to_out_data_.find(current_ep_name_); - bool has_outdata = outdata != ep_name_to_out_data_.end(); + auto outdata = ep_sym_to_out_data_.find(current_ep_sym_.value()); + bool has_outdata = outdata != ep_sym_to_out_data_.end(); if (has_outdata) { out << outdata->second.struct_name; } else { out << "void"; } - out << " " << namer_.NameFor(current_ep_name_) << "("; + // TODO(dsinclair): This should output the remapped name + out << " " << namer_.NameFor(module_->SymbolToName(current_ep_sym_)) << "("; - auto in_data = ep_name_to_in_data_.find(current_ep_name_); - if (in_data != ep_name_to_in_data_.end()) { + auto in_data = ep_sym_to_in_data_.find(current_ep_sym_.value()); + if (in_data != ep_sym_to_in_data_.end()) { out << in_data->second.struct_name << " " << in_data->second.var_name; } out << ") {" << std::endl; @@ -1563,7 +1571,7 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out, make_indent(out); out << "}" << std::endl; - current_ep_name_ = ""; + current_ep_sym_ = Symbol(); return true; } @@ -1966,8 +1974,8 @@ bool GeneratorImpl::EmitReturn(std::ostream& out, ast::ReturnStatement* stmt) { if (generating_entry_point_) { out << "return"; - auto outdata = ep_name_to_out_data_.find(current_ep_name_); - if (outdata != ep_name_to_out_data_.end()) { + auto outdata = ep_sym_to_out_data_.find(current_ep_sym_.value()); + if (outdata != ep_sym_to_out_data_.end()) { out << " " << outdata->second.var_name; } } else if (stmt->has_value()) { diff --git a/src/writer/hlsl/generator_impl.h b/src/writer/hlsl/generator_impl.h index 6694e26040..ebd98b993d 100644 --- a/src/writer/hlsl/generator_impl.h +++ b/src/writer/hlsl/generator_impl.h @@ -210,12 +210,12 @@ class GeneratorImpl { /// @param func the function to emit /// @param emit_duplicate_functions set true if we need to duplicate per entry /// point - /// @param ep_name the current entry point or blank if none set + /// @param ep_sym the current entry point or symbol::kInvalid if none set /// @returns true if the function was emitted. bool EmitFunctionInternal(std::ostream& out, ast::Function* func, bool emit_duplicate_functions, - const std::string& ep_name); + Symbol ep_sym); /// Handles emitting information for an entry point /// @param out the output stream /// @param func the entry point @@ -397,12 +397,12 @@ class GeneratorImpl { Namer namer_; ast::Module* module_ = nullptr; - std::string current_ep_name_; + Symbol current_ep_sym_; bool generating_entry_point_ = false; uint32_t loop_emission_counter_ = 0; ScopeStack global_variables_; - std::unordered_map ep_name_to_in_data_; - std::unordered_map ep_name_to_out_data_; + std::unordered_map ep_sym_to_in_data_; + std::unordered_map ep_sym_to_out_data_; // This maps an input of "_" to a remapped // function name. If there is no entry for a given key then function did diff --git a/src/writer/hlsl/generator_impl_binary_test.cc b/src/writer/hlsl/generator_impl_binary_test.cc index 20654ebd41..318c68b720 100644 --- a/src/writer/hlsl/generator_impl_binary_test.cc +++ b/src/writer/hlsl/generator_impl_binary_test.cc @@ -613,9 +613,9 @@ TEST_F(HlslGeneratorImplTest_Binary, Call_WithLogical) { ast::type::Void void_type; - auto* func = create(Source{}, "foo", ast::VariableList{}, - &void_type, create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod.RegisterSymbol("foo"), "foo", ast::VariableList{}, + &void_type, create(), ast::FunctionDecorationList{}); mod.AddFunction(func); ast::ExpressionList params; diff --git a/src/writer/hlsl/generator_impl_call_test.cc b/src/writer/hlsl/generator_impl_call_test.cc index e185dcc095..3311837bae 100644 --- a/src/writer/hlsl/generator_impl_call_test.cc +++ b/src/writer/hlsl/generator_impl_call_test.cc @@ -35,9 +35,9 @@ TEST_F(HlslGeneratorImplTest_Call, EmitExpression_Call_WithoutParams) { auto* id = create("my_func"); ast::CallExpression call(id, {}); - auto* func = create(Source{}, "my_func", ast::VariableList{}, - &void_type, create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{}, + &void_type, create(), ast::FunctionDecorationList{}); mod.AddFunction(func); ASSERT_TRUE(gen.EmitExpression(pre, out, &call)) << gen.error(); @@ -53,9 +53,9 @@ TEST_F(HlslGeneratorImplTest_Call, EmitExpression_Call_WithParams) { params.push_back(create("param2")); ast::CallExpression call(id, params); - auto* func = create(Source{}, "my_func", ast::VariableList{}, - &void_type, create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{}, + &void_type, create(), ast::FunctionDecorationList{}); mod.AddFunction(func); ASSERT_TRUE(gen.EmitExpression(pre, out, &call)) << gen.error(); @@ -71,9 +71,9 @@ TEST_F(HlslGeneratorImplTest_Call, EmitStatement_Call) { params.push_back(create("param2")); ast::CallStatement call(create(id, params)); - auto* func = create(Source{}, "my_func", ast::VariableList{}, - &void_type, create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{}, + &void_type, create(), ast::FunctionDecorationList{}); mod.AddFunction(func); gen.increment_indent(); ASSERT_TRUE(gen.EmitStatement(out, &call)) << gen.error(); diff --git a/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc b/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc index 57c8553103..320b478c32 100644 --- a/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc +++ b/src/writer/hlsl/generator_impl_function_entry_point_data_test.cc @@ -91,7 +91,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint, create("bar"))); auto* func = create( - Source{}, "vtx_main", params, &f32, body, + Source{}, mod.RegisterSymbol("vtx_main"), "vtx_main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -164,7 +164,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint, create("bar"))); auto* func = create( - Source{}, "vtx_main", params, &f32, body, + Source{}, mod.RegisterSymbol("vtx_main"), "vtx_main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -237,7 +237,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint, create("bar"))); auto* func = create( - Source{}, "main", params, &f32, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -309,7 +309,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint, create("bar"))); auto* func = create( - Source{}, "main", params, &f32, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -378,7 +378,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint, create("bar"))); auto* func = create( - Source{}, "main", params, &f32, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kCompute, Source{}), }); @@ -442,7 +442,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint, create("bar"))); auto* func = create( - Source{}, "main", params, &f32, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kCompute, Source{}), }); @@ -512,7 +512,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint, create("x")))); auto* func = create( - Source{}, "main", params, &void_type, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc index 95543a7279..cab53fc611 100644 --- a/src/writer/hlsl/generator_impl_function_test.cc +++ b/src/writer/hlsl/generator_impl_function_test.cc @@ -57,9 +57,9 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Function) { auto* body = create(); body->append(create(Source{})); - auto* func = - create(Source{}, "my_func", ast::VariableList{}, - &void_type, body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod.RegisterSymbol("my_func"), + "my_func", ast::VariableList{}, &void_type, + body, ast::FunctionDecorationList{}); mod.AddFunction(func); gen.increment_indent(); @@ -77,9 +77,9 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Function_Name_Collision) { auto* body = create(); body->append(create(Source{})); - auto* func = - create(Source{}, "GeometryShader", ast::VariableList{}, - &void_type, body, ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod.RegisterSymbol("GeometryShader"), "GeometryShader", + ast::VariableList{}, &void_type, body, ast::FunctionDecorationList{}); mod.AddFunction(func); gen.increment_indent(); @@ -118,8 +118,9 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithParams) { auto* body = create(); body->append(create(Source{})); - auto* func = create(Source{}, "my_func", params, &void_type, - body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod.RegisterSymbol("my_func"), + "my_func", params, &void_type, body, + ast::FunctionDecorationList{}); mod.AddFunction(func); gen.increment_indent(); @@ -174,7 +175,8 @@ TEST_F(HlslGeneratorImplTest_Function, create("foo"))); body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -245,7 +247,8 @@ TEST_F(HlslGeneratorImplTest_Function, create("x")))); body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -309,7 +312,8 @@ TEST_F(HlslGeneratorImplTest_Function, body->append(create(var)); body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -380,7 +384,8 @@ TEST_F(HlslGeneratorImplTest_Function, body->append(create(var)); body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -455,7 +460,8 @@ TEST_F(HlslGeneratorImplTest_Function, body->append(create(var)); body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -526,7 +532,8 @@ TEST_F(HlslGeneratorImplTest_Function, body->append(create(var)); body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -594,7 +601,8 @@ TEST_F(HlslGeneratorImplTest_Function, body->append(assign); body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -682,8 +690,9 @@ TEST_F( create("param"))); body->append(create( Source{}, create("foo"))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -698,7 +707,7 @@ TEST_F( expr))); body->append(create(Source{})); auto* func_1 = create( - Source{}, "ep_1", params, &void_type, body, + Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -766,8 +775,9 @@ TEST_F(HlslGeneratorImplTest_Function, auto* body = create(); body->append(create( Source{}, create("param"))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -782,7 +792,7 @@ TEST_F(HlslGeneratorImplTest_Function, expr))); body->append(create(Source{})); auto* func_1 = create( - Source{}, "ep_1", params, &void_type, body, + Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -863,8 +873,9 @@ TEST_F( create("x")))); body->append(create( Source{}, create("param"))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -879,7 +890,7 @@ TEST_F( expr))); body->append(create(Source{})); auto* func_1 = create( - Source{}, "ep_1", params, &void_type, body, + Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -948,8 +959,9 @@ TEST_F(HlslGeneratorImplTest_Function, Source{}, create( create("coord"), create("x")))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -971,7 +983,8 @@ TEST_F(HlslGeneratorImplTest_Function, body->append(create(var)); body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -1034,8 +1047,9 @@ TEST_F(HlslGeneratorImplTest_Function, Source{}, create( create("coord"), create("x")))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -1057,7 +1071,8 @@ TEST_F(HlslGeneratorImplTest_Function, body->append(create(var)); body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -1122,7 +1137,7 @@ TEST_F(HlslGeneratorImplTest_Function, body->append(create(Source{})); auto* func_1 = create( - Source{}, "ep_1", params, &void_type, body, + Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -1152,8 +1167,8 @@ TEST_F(HlslGeneratorImplTest_Function, ast::type::Void void_type; auto* func = create( - Source{}, "GeometryShader", ast::VariableList{}, &void_type, - create(), + Source{}, mod.RegisterSymbol("GeometryShader"), "GeometryShader", + ast::VariableList{}, &void_type, create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -1175,7 +1190,7 @@ TEST_F(HlslGeneratorImplTest_Function, auto* body = create(); body->append(create(Source{})); auto* func = create( - Source{}, "main", params, &void_type, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kCompute, Source{}), }); @@ -1200,7 +1215,7 @@ TEST_F(HlslGeneratorImplTest_Function, auto* body = create(); body->append(create(Source{})); auto* func = create( - Source{}, "main", params, &void_type, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kCompute, Source{}), create(2u, 4u, 6u, Source{}), @@ -1236,8 +1251,9 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) { auto* body = create(); body->append(create(Source{})); - auto* func = create(Source{}, "my_func", params, &void_type, - body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod.RegisterSymbol("my_func"), + "my_func", params, &void_type, body, + ast::FunctionDecorationList{}); mod.AddFunction(func); gen.increment_indent(); @@ -1317,12 +1333,12 @@ TEST_F(HlslGeneratorImplTest_Function, auto* body = create(); body->append(create(var)); body->append(create(Source{})); - auto* func = - create(Source{}, "a", params, &void_type, body, - ast::FunctionDecorationList{ - create( - ast::PipelineStage::kCompute, Source{}), - }); + auto* func = create( + Source{}, mod.RegisterSymbol("a"), "a", params, &void_type, body, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute, + Source{}), + }); mod.AddFunction(func); } @@ -1343,12 +1359,12 @@ TEST_F(HlslGeneratorImplTest_Function, auto* body = create(); body->append(create(var)); body->append(create(Source{})); - auto* func = - create(Source{}, "b", params, &void_type, body, - ast::FunctionDecorationList{ - create( - ast::PipelineStage::kCompute, Source{}), - }); + auto* func = create( + Source{}, mod.RegisterSymbol("b"), "b", params, &void_type, body, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute, + Source{}), + }); mod.AddFunction(func); } diff --git a/src/writer/hlsl/generator_impl_test.cc b/src/writer/hlsl/generator_impl_test.cc index dfdaf9519a..1494fcf5e9 100644 --- a/src/writer/hlsl/generator_impl_test.cc +++ b/src/writer/hlsl/generator_impl_test.cc @@ -29,9 +29,9 @@ using HlslGeneratorImplTest = TestHelper; TEST_F(HlslGeneratorImplTest, Generate) { ast::type::Void void_type; - auto* func = create(Source{}, "my_func", ast::VariableList{}, - &void_type, create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{}, + &void_type, create(), ast::FunctionDecorationList{}); mod.AddFunction(func); ASSERT_TRUE(gen.Generate(out)) << gen.error(); diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index 000c3d7fed..4af9556c15 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc @@ -411,15 +411,15 @@ std::string GeneratorImpl::current_ep_var_name(VarType type) { std::string name = ""; switch (type) { case VarType::kIn: { - auto in_it = ep_name_to_in_data_.find(current_ep_name_); - if (in_it != ep_name_to_in_data_.end()) { + auto in_it = ep_sym_to_in_data_.find(current_ep_sym_.value()); + if (in_it != ep_sym_to_in_data_.end()) { name = in_it->second.var_name; } break; } case VarType::kOut: { - auto out_it = ep_name_to_out_data_.find(current_ep_name_); - if (out_it != ep_name_to_out_data_.end()) { + auto out_it = ep_sym_to_out_data_.find(current_ep_sym_.value()); + if (out_it != ep_sym_to_out_data_.end()) { name = out_it->second.var_name; } break; @@ -573,12 +573,14 @@ bool GeneratorImpl::EmitCall(ast::CallExpression* expr) { } auto name = ident->name(); - auto it = ep_func_name_remapped_.find(current_ep_name_ + "_" + name); + auto caller_sym = module_->GetSymbol(name); + auto it = ep_func_name_remapped_.find(current_ep_sym_.to_str() + "_" + + caller_sym.to_str()); if (it != ep_func_name_remapped_.end()) { name = it->second; } - auto* func = module_->FindFunctionByName(ident->name()); + auto* func = module_->FindFunctionBySymbol(module_->GetSymbol(ident->name())); if (func == nullptr) { error_ = "Unable to find function: " + name; return false; @@ -1026,7 +1028,7 @@ bool GeneratorImpl::EmitEntryPointData(ast::Function* func) { auto in_struct_name = generate_name(func->name() + "_" + kInStructNameSuffix); auto in_var_name = generate_name(kTintStructInVarPrefix); - ep_name_to_in_data_[func->name()] = {in_struct_name, in_var_name}; + ep_sym_to_in_data_[func->symbol().value()] = {in_struct_name, in_var_name}; make_indent(); out_ << "struct " << in_struct_name << " {" << std::endl; @@ -1063,7 +1065,8 @@ bool GeneratorImpl::EmitEntryPointData(ast::Function* func) { auto out_struct_name = generate_name(func->name() + "_" + kOutStructNameSuffix); auto out_var_name = generate_name(kTintStructOutVarPrefix); - ep_name_to_out_data_[func->name()] = {out_struct_name, out_var_name}; + ep_sym_to_out_data_[func->symbol().value()] = {out_struct_name, + out_var_name}; make_indent(); out_ << "struct " << out_struct_name << " {" << std::endl; @@ -1205,15 +1208,15 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) { has_referenced_var_needing_struct(func); if (emit_duplicate_functions) { - for (const auto& ep_name : func->ancestor_entry_points()) { - if (!EmitFunctionInternal(func, emit_duplicate_functions, ep_name)) { + for (const auto& ep_sym : func->ancestor_entry_points()) { + if (!EmitFunctionInternal(func, emit_duplicate_functions, ep_sym)) { return false; } out_ << std::endl; } } else { // Emit as non-duplicated - if (!EmitFunctionInternal(func, false, "")) { + if (!EmitFunctionInternal(func, false, Symbol())) { return false; } out_ << std::endl; @@ -1224,19 +1227,23 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) { bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, bool emit_duplicate_functions, - const std::string& ep_name) { - auto name = func->name(); - + Symbol ep_sym) { + auto name = func->symbol().to_str(); if (!EmitType(func->return_type(), "")) { return false; } out_ << " "; if (emit_duplicate_functions) { - name = generate_name(name + "_" + ep_name); - ep_func_name_remapped_[ep_name + "_" + func->name()] = name; + auto func_name = name; + auto ep_name = ep_sym.to_str(); + // TODO(dsinclair): The SymbolToName should go away and just use + // to_str() here when the conversion is complete. + name = generate_name(func->name() + "_" + module_->SymbolToName(ep_sym)); + ep_func_name_remapped_[ep_name + "_" + func_name] = name; } else { - name = namer_.NameFor(name); + // TODO(dsinclair): this should be updated to a remapped name + name = namer_.NameFor(func->name()); } out_ << name << "("; @@ -1247,15 +1254,15 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, // // We emit both of them if they're there regardless of if they're both used. if (emit_duplicate_functions) { - auto in_it = ep_name_to_in_data_.find(ep_name); - if (in_it != ep_name_to_in_data_.end()) { + auto in_it = ep_sym_to_in_data_.find(ep_sym.value()); + if (in_it != ep_sym_to_in_data_.end()) { out_ << "thread " << in_it->second.struct_name << "& " << in_it->second.var_name; first = false; } - auto out_it = ep_name_to_out_data_.find(ep_name); - if (out_it != ep_name_to_out_data_.end()) { + auto out_it = ep_sym_to_out_data_.find(ep_sym.value()); + if (out_it != ep_sym_to_out_data_.end()) { if (!first) { out_ << ", "; } @@ -1337,13 +1344,13 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, out_ << ") "; - current_ep_name_ = ep_name; + current_ep_sym_ = ep_sym; if (!EmitBlockAndNewline(func->body())) { return false; } - current_ep_name_ = ""; + current_ep_sym_ = Symbol(); return true; } @@ -1377,25 +1384,25 @@ std::string GeneratorImpl::builtin_to_attribute(ast::Builtin builtin) const { bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) { make_indent(); - current_ep_name_ = func->name(); + current_ep_sym_ = func->symbol(); EmitStage(func->pipeline_stage()); out_ << " "; // This is an entry point, the return type is the entry point output structure // if one exists, or void otherwise. - auto out_data = ep_name_to_out_data_.find(current_ep_name_); - bool has_out_data = out_data != ep_name_to_out_data_.end(); + auto out_data = ep_sym_to_out_data_.find(current_ep_sym_.value()); + bool has_out_data = out_data != ep_sym_to_out_data_.end(); if (has_out_data) { out_ << out_data->second.struct_name; } else { out_ << "void"; } - out_ << " " << namer_.NameFor(current_ep_name_) << "("; + out_ << " " << namer_.NameFor(func->name()) << "("; bool first = true; - auto in_data = ep_name_to_in_data_.find(current_ep_name_); - if (in_data != ep_name_to_in_data_.end()) { + auto in_data = ep_sym_to_in_data_.find(current_ep_sym_.value()); + if (in_data != ep_sym_to_in_data_.end()) { out_ << in_data->second.struct_name << " " << in_data->second.var_name << " [[stage_in]]"; first = false; @@ -1503,7 +1510,7 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) { make_indent(); out_ << "}" << std::endl; - current_ep_name_ = ""; + current_ep_sym_ = Symbol(); return true; } @@ -1687,8 +1694,8 @@ bool GeneratorImpl::EmitReturn(ast::ReturnStatement* stmt) { out_ << "return"; if (generating_entry_point_) { - auto out_data = ep_name_to_out_data_.find(current_ep_name_); - if (out_data != ep_name_to_out_data_.end()) { + auto out_data = ep_sym_to_out_data_.find(current_ep_sym_.value()); + if (out_data != ep_sym_to_out_data_.end()) { out_ << " " << out_data->second.var_name; } } else if (stmt->has_value()) { diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h index 0e4ef7d2eb..d087c7cb4b 100644 --- a/src/writer/msl/generator_impl.h +++ b/src/writer/msl/generator_impl.h @@ -156,11 +156,11 @@ class GeneratorImpl : public TextGenerator { /// @param func the function to emit /// @param emit_duplicate_functions set true if we need to duplicate per entry /// point - /// @param ep_name the current entry point or blank if none set + /// @param ep_sym the current entry point or symbol::kInvalid if not set /// @returns true if the function was emitted. bool EmitFunctionInternal(ast::Function* func, bool emit_duplicate_functions, - const std::string& ep_name); + Symbol ep_sym); /// Handles generating an identifier expression /// @param expr the identifier expression /// @returns true if the identifier was emitted @@ -282,13 +282,13 @@ class GeneratorImpl : public TextGenerator { Namer namer_; ScopeStack global_variables_; - std::string current_ep_name_; + Symbol current_ep_sym_; bool generating_entry_point_ = false; const ast::Module* module_ = nullptr; uint32_t loop_emission_counter_ = 0; - std::unordered_map ep_name_to_in_data_; - std::unordered_map ep_name_to_out_data_; + std::unordered_map ep_sym_to_in_data_; + std::unordered_map ep_sym_to_out_data_; // This maps an input of "_" to a remapped // function name. If there is no entry for a given key then function did diff --git a/src/writer/msl/generator_impl_call_test.cc b/src/writer/msl/generator_impl_call_test.cc index c27e55689f..71024249ca 100644 --- a/src/writer/msl/generator_impl_call_test.cc +++ b/src/writer/msl/generator_impl_call_test.cc @@ -37,9 +37,9 @@ TEST_F(MslGeneratorImplTest, EmitExpression_Call_WithoutParams) { auto* id = create("my_func"); ast::CallExpression call(id, {}); - auto* func = create(Source{}, "my_func", ast::VariableList{}, - &void_type, create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{}, + &void_type, create(), ast::FunctionDecorationList{}); mod.AddFunction(func); ASSERT_TRUE(gen.EmitExpression(&call)) << gen.error(); @@ -55,9 +55,9 @@ TEST_F(MslGeneratorImplTest, EmitExpression_Call_WithParams) { params.push_back(create("param2")); ast::CallExpression call(id, params); - auto* func = create(Source{}, "my_func", ast::VariableList{}, - &void_type, create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{}, + &void_type, create(), ast::FunctionDecorationList{}); mod.AddFunction(func); ASSERT_TRUE(gen.EmitExpression(&call)) << gen.error(); @@ -73,9 +73,9 @@ TEST_F(MslGeneratorImplTest, EmitStatement_Call) { params.push_back(create("param2")); ast::CallStatement call(create(id, params)); - auto* func = create(Source{}, "my_func", ast::VariableList{}, - &void_type, create(), - ast::FunctionDecorationList{}); + auto* func = create( + Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{}, + &void_type, create(), ast::FunctionDecorationList{}); mod.AddFunction(func); gen.increment_indent(); diff --git a/src/writer/msl/generator_impl_function_entry_point_data_test.cc b/src/writer/msl/generator_impl_function_entry_point_data_test.cc index ac6b4c8eba..f43ea3233b 100644 --- a/src/writer/msl/generator_impl_function_entry_point_data_test.cc +++ b/src/writer/msl/generator_impl_function_entry_point_data_test.cc @@ -90,7 +90,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Vertex_Input) { create("bar"))); auto* func = create( - Source{}, "vtx_main", params, &f32, body, + Source{}, mod.RegisterSymbol("vtx_main"), "vtx_main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -160,7 +160,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Vertex_Output) { create("bar"))); auto* func = create( - Source{}, "vtx_main", params, &f32, body, + Source{}, mod.RegisterSymbol("vtx_main"), "vtx_main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -229,7 +229,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Fragment_Input) { create("bar"), create("bar"))); auto* func = create( - Source{}, "main", params, &f32, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -299,7 +299,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Fragment_Output) { create("bar"))); auto* func = create( - Source{}, "main", params, &f32, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -366,7 +366,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Compute_Input) { create("bar"))); auto* func = create( - Source{}, "main", params, &f32, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kCompute, Source{}), }); @@ -428,7 +428,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Compute_Output) { create("bar"))); auto* func = create( - Source{}, "main", params, &f32, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kCompute, Source{}), }); @@ -496,7 +496,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Builtins) { create("x")))); auto* func = create( - Source{}, "main", params, &void_type, body, + Source{}, mod.RegisterSymbol("main"), "main", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc index 03934cb2bb..950e8fe8db 100644 --- a/src/writer/msl/generator_impl_function_test.cc +++ b/src/writer/msl/generator_impl_function_test.cc @@ -60,9 +60,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function) { auto* body = create(); body->append(create(Source{})); - auto* func = - create(Source{}, "my_func", ast::VariableList{}, - &void_type, body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod.RegisterSymbol("my_func"), + "my_func", ast::VariableList{}, &void_type, + body, ast::FunctionDecorationList{}); mod.AddFunction(func); gen.increment_indent(); @@ -82,9 +82,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function_Name_Collision) { auto* body = create(); body->append(create(Source{})); - auto* func = - create(Source{}, "main", ast::VariableList{}, &void_type, - body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod.RegisterSymbol("main"), + "main", ast::VariableList{}, &void_type, + body, ast::FunctionDecorationList{}); mod.AddFunction(func); gen.increment_indent(); @@ -125,8 +125,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function_WithParams) { auto* body = create(); body->append(create(Source{})); - auto* func = create(Source{}, "my_func", params, &void_type, - body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod.RegisterSymbol("my_func"), + "my_func", params, &void_type, body, + ast::FunctionDecorationList{}); mod.AddFunction(func); gen.increment_indent(); @@ -183,7 +184,8 @@ TEST_F(MslGeneratorImplTest, Emit_FunctionDecoration_EntryPoint_WithInOutVars) { body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{create( ast::PipelineStage::kFragment, Source{})}); @@ -257,7 +259,8 @@ TEST_F(MslGeneratorImplTest, body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -321,7 +324,8 @@ TEST_F(MslGeneratorImplTest, Emit_FunctionDecoration_EntryPoint_With_Uniform) { body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -397,7 +401,8 @@ TEST_F(MslGeneratorImplTest, body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -478,7 +483,8 @@ TEST_F(MslGeneratorImplTest, body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -572,8 +578,9 @@ TEST_F( create("param"))); body->append(create( Source{}, create("foo"))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -588,7 +595,7 @@ TEST_F( expr))); body->append(create(Source{})); auto* func_1 = create( - Source{}, "ep_1", params, &void_type, body, + Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -659,8 +666,9 @@ TEST_F(MslGeneratorImplTest, auto* body = create(); body->append(create( Source{}, create("param"))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -676,7 +684,7 @@ TEST_F(MslGeneratorImplTest, body->append(create(Source{})); auto* func_1 = create( - Source{}, "ep_1", params, &void_type, body, + Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -760,8 +768,9 @@ TEST_F( create("x")))); body->append(create( Source{}, create("param"))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -776,7 +785,7 @@ TEST_F( expr))); body->append(create(Source{})); auto* func_1 = create( - Source{}, "ep_1", params, &void_type, body, + Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -843,8 +852,9 @@ TEST_F(MslGeneratorImplTest, Source{}, create( create("coord"), create("x")))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -867,7 +877,8 @@ TEST_F(MslGeneratorImplTest, body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -943,8 +954,9 @@ TEST_F(MslGeneratorImplTest, Source{}, create( create("coord"), create("b")))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -967,7 +979,8 @@ TEST_F(MslGeneratorImplTest, body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -1049,8 +1062,9 @@ TEST_F(MslGeneratorImplTest, Source{}, create( create("coord"), create("b")))); - auto* sub_func = create(Source{}, "sub_func", params, &f32, - body, ast::FunctionDecorationList{}); + auto* sub_func = create( + Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body, + ast::FunctionDecorationList{}); mod.AddFunction(sub_func); @@ -1073,7 +1087,8 @@ TEST_F(MslGeneratorImplTest, body->append(create(Source{})); auto* func = create( - Source{}, "frag_main", params, &void_type, body, + Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params, + &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -1145,7 +1160,7 @@ TEST_F(MslGeneratorImplTest, body->append(create(Source{})); auto* func_1 = create( - Source{}, "ep_1", params, &void_type, body, + Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -1177,8 +1192,8 @@ TEST_F(MslGeneratorImplTest, ast::type::Void void_type; auto* func = create( - Source{}, "main", ast::VariableList{}, &void_type, - create(), + Source{}, mod.RegisterSymbol("main"), "main", ast::VariableList{}, + &void_type, create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kCompute, Source{}), }); @@ -1212,8 +1227,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function_WithArrayParams) { auto* body = create(); body->append(create(Source{})); - auto* func = create(Source{}, "my_func", params, &void_type, - body, ast::FunctionDecorationList{}); + auto* func = create(Source{}, mod.RegisterSymbol("my_func"), + "my_func", params, &void_type, body, + ast::FunctionDecorationList{}); mod.AddFunction(func); @@ -1298,12 +1314,12 @@ TEST_F(MslGeneratorImplTest, body->append(create(var)); body->append(create(Source{})); - auto* func = - create(Source{}, "a", params, &void_type, body, - ast::FunctionDecorationList{ - create( - ast::PipelineStage::kCompute, Source{}), - }); + auto* func = create( + Source{}, mod.RegisterSymbol("a"), "a", params, &void_type, body, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute, + Source{}), + }); mod.AddFunction(func); } @@ -1325,12 +1341,12 @@ TEST_F(MslGeneratorImplTest, body->append(create(var)); body->append(create(Source{})); - auto* func = - create(Source{}, "b", params, &void_type, body, - ast::FunctionDecorationList{ - create( - ast::PipelineStage::kCompute, Source{}), - }); + auto* func = create( + Source{}, mod.RegisterSymbol("b"), "b", params, &void_type, body, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute, + Source{}), + }); mod.AddFunction(func); } diff --git a/src/writer/msl/generator_impl_test.cc b/src/writer/msl/generator_impl_test.cc index 06fdbc7fdf..0ba3247204 100644 --- a/src/writer/msl/generator_impl_test.cc +++ b/src/writer/msl/generator_impl_test.cc @@ -51,8 +51,8 @@ TEST_F(MslGeneratorImplTest, Generate) { ast::type::Void void_type; auto* func = create( - Source{}, "my_func", ast::VariableList{}, &void_type, - create(), + Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{}, + &void_type, create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kCompute, Source{}), }); diff --git a/src/writer/spirv/builder_call_test.cc b/src/writer/spirv/builder_call_test.cc index 1f2d50b9fb..a4585a5189 100644 --- a/src/writer/spirv/builder_call_test.cc +++ b/src/writer/spirv/builder_call_test.cc @@ -65,11 +65,11 @@ TEST_F(BuilderTest, Expression_Call) { Source{}, create( ast::BinaryOp::kAdd, create("a"), create("b")))); - ast::Function a_func(Source{}, "a_func", func_params, &f32, body, - ast::FunctionDecorationList{}); + ast::Function a_func(Source{}, mod->RegisterSymbol("a_func"), "a_func", + func_params, &f32, body, ast::FunctionDecorationList{}); - ast::Function func(Source{}, "main", {}, &void_type, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("main"), "main", {}, + &void_type, create(), ast::FunctionDecorationList{}); ast::ExpressionList call_params; @@ -143,11 +143,12 @@ TEST_F(BuilderTest, Statement_Call) { ast::BinaryOp::kAdd, create("a"), create("b")))); - ast::Function a_func(Source{}, "a_func", func_params, &void_type, body, + ast::Function a_func(Source{}, mod->RegisterSymbol("a_func"), "a_func", + func_params, &void_type, body, ast::FunctionDecorationList{}); - ast::Function func(Source{}, "main", {}, &void_type, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("main"), "main", {}, + &void_type, create(), ast::FunctionDecorationList{}); ast::ExpressionList call_params; diff --git a/src/writer/spirv/builder_function_decoration_test.cc b/src/writer/spirv/builder_function_decoration_test.cc index 6f835e1cc5..dfab021c72 100644 --- a/src/writer/spirv/builder_function_decoration_test.cc +++ b/src/writer/spirv/builder_function_decoration_test.cc @@ -42,7 +42,8 @@ TEST_F(BuilderTest, FunctionDecoration_Stage) { ast::type::Void void_type; ast::Function func( - Source{}, "main", {}, &void_type, create(), + Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type, + create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -67,8 +68,8 @@ TEST_P(FunctionDecoration_StageTest, Emit) { ast::type::Void void_type; - ast::Function func(Source{}, "main", {}, &void_type, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("main"), "main", {}, + &void_type, create(), ast::FunctionDecorationList{ create(params.stage, Source{}), }); @@ -97,7 +98,8 @@ TEST_F(BuilderTest, FunctionDecoration_Stage_WithUnusedInterfaceIds) { ast::type::Void void_type; ast::Function func( - Source{}, "main", {}, &void_type, create(), + Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type, + create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -174,7 +176,7 @@ TEST_F(BuilderTest, FunctionDecoration_Stage_WithUsedInterfaceIds) { create("my_in"))); ast::Function func( - Source{}, "main", {}, &void_type, body, + Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kVertex, Source{}), }); @@ -244,7 +246,8 @@ TEST_F(BuilderTest, FunctionDecoration_ExecutionMode_Fragment_OriginUpperLeft) { ast::type::Void void_type; ast::Function func( - Source{}, "main", {}, &void_type, create(), + Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type, + create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -259,7 +262,8 @@ TEST_F(BuilderTest, FunctionDecoration_WorkgroupSize_Default) { ast::type::Void void_type; ast::Function func( - Source{}, "main", {}, &void_type, create(), + Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type, + create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kCompute, Source{}), }); @@ -274,7 +278,8 @@ TEST_F(BuilderTest, FunctionDecoration_WorkgroupSize) { ast::type::Void void_type; ast::Function func( - Source{}, "main", {}, &void_type, create(), + Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type, + create(), ast::FunctionDecorationList{ create(2u, 4u, 6u, Source{}), create(ast::PipelineStage::kCompute, Source{}), @@ -290,13 +295,15 @@ TEST_F(BuilderTest, FunctionDecoration_ExecutionMode_MultipleFragment) { ast::type::Void void_type; ast::Function func1( - Source{}, "main1", {}, &void_type, create(), + Source{}, mod->RegisterSymbol("main1"), "main1", {}, &void_type, + create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); ast::Function func2( - Source{}, "main2", {}, &void_type, create(), + Source{}, mod->RegisterSymbol("main2"), "main2", {}, &void_type, + create(), ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); diff --git a/src/writer/spirv/builder_function_test.cc b/src/writer/spirv/builder_function_test.cc index 51da30623d..983a3b2f0b 100644 --- a/src/writer/spirv/builder_function_test.cc +++ b/src/writer/spirv/builder_function_test.cc @@ -47,8 +47,8 @@ using BuilderTest = TestHelper; TEST_F(BuilderTest, Function_Empty) { ast::type::Void void_type; - ast::Function func(Source{}, "a_func", {}, &void_type, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &void_type, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)); @@ -68,8 +68,8 @@ TEST_F(BuilderTest, Function_Terminator_Return) { auto* body = create(); body->append(create(Source{})); - ast::Function func(Source{}, "a_func", {}, &void_type, body, - ast::FunctionDecorationList{}); + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &void_type, body, ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)); EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func" @@ -101,8 +101,8 @@ TEST_F(BuilderTest, Function_Terminator_ReturnValue) { Source{}, create("a"))); ASSERT_TRUE(td.DetermineResultType(body)) << td.error(); - ast::Function func(Source{}, "a_func", {}, &void_type, body, - ast::FunctionDecorationList{}); + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &void_type, body, ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateGlobalVariable(var_a)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -128,8 +128,8 @@ TEST_F(BuilderTest, Function_Terminator_Discard) { auto* body = create(); body->append(create()); - ast::Function func(Source{}, "a_func", {}, &void_type, body, - ast::FunctionDecorationList{}); + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &void_type, body, ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)); EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func" @@ -168,8 +168,8 @@ TEST_F(BuilderTest, Function_WithParams) { auto* body = create(); body->append(create( Source{}, create("a"))); - ast::Function func(Source{}, "a_func", params, &f32, body, - ast::FunctionDecorationList{}); + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", params, + &f32, body, ast::FunctionDecorationList{}); td.RegisterVariableForTesting(func.params()[0]); td.RegisterVariableForTesting(func.params()[1]); @@ -197,8 +197,8 @@ TEST_F(BuilderTest, Function_WithBody) { auto* body = create(); body->append(create(Source{})); - ast::Function func(Source{}, "a_func", {}, &void_type, body, - ast::FunctionDecorationList{}); + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &void_type, body, ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)); EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func" @@ -213,8 +213,8 @@ OpFunctionEnd TEST_F(BuilderTest, FunctionType) { ast::type::Void void_type; - ast::Function func(Source{}, "a_func", {}, &void_type, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &void_type, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)); @@ -225,11 +225,11 @@ TEST_F(BuilderTest, FunctionType) { TEST_F(BuilderTest, FunctionType_DeDuplicate) { ast::type::Void void_type; - ast::Function func1(Source{}, "a_func", {}, &void_type, - create(), + ast::Function func1(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &void_type, create(), ast::FunctionDecorationList{}); - ast::Function func2(Source{}, "b_func", {}, &void_type, - create(), + ast::Function func2(Source{}, mod->RegisterSymbol("b_func"), "b_func", {}, + &void_type, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func1)); @@ -307,12 +307,12 @@ TEST_F(BuilderTest, Emit_Multiple_EntryPoint_With_Same_ModuleVar) { body->append(create(var)); body->append(create(Source{})); - auto* func = - create(Source{}, "a", params, &void_type, body, - ast::FunctionDecorationList{ - create( - ast::PipelineStage::kCompute, Source{}), - }); + auto* func = create( + Source{}, mod->RegisterSymbol("a"), "a", params, &void_type, body, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute, + Source{}), + }); mod->AddFunction(func); } @@ -334,12 +334,12 @@ TEST_F(BuilderTest, Emit_Multiple_EntryPoint_With_Same_ModuleVar) { body->append(create(var)); body->append(create(Source{})); - auto* func = - create(Source{}, "b", params, &void_type, body, - ast::FunctionDecorationList{ - create( - ast::PipelineStage::kCompute, Source{}), - }); + auto* func = create( + Source{}, mod->RegisterSymbol("b"), "b", params, &void_type, body, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute, + Source{}), + }); mod->AddFunction(func); } diff --git a/src/writer/spirv/builder_intrinsic_test.cc b/src/writer/spirv/builder_intrinsic_test.cc index 98397116d3..d839c9c1e2 100644 --- a/src/writer/spirv/builder_intrinsic_test.cc +++ b/src/writer/spirv/builder_intrinsic_test.cc @@ -471,8 +471,8 @@ TEST_F(IntrinsicBuilderTest, Call_GLSLMethod_WithLoad) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error(); @@ -505,8 +505,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Float_Test, Call_Scalar) { auto expr = Call(param.name, 1.0f); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -533,8 +533,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Float_Test, Call_Vector) { auto expr = Call(param.name, vec2(1.0f, 1.0f)); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -587,8 +587,8 @@ TEST_F(IntrinsicBuilderTest, Call_Length_Scalar) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -612,8 +612,8 @@ TEST_F(IntrinsicBuilderTest, Call_Length_Vector) { auto expr = Call("length", vec2(1.0f, 1.0f)); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -639,8 +639,8 @@ TEST_F(IntrinsicBuilderTest, Call_Normalize) { auto expr = Call("normalize", vec2(1.0f, 1.0f)); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -671,8 +671,8 @@ TEST_P(Intrinsic_Builtin_DualParam_Float_Test, Call_Scalar) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -700,8 +700,8 @@ TEST_P(Intrinsic_Builtin_DualParam_Float_Test, Call_Vector) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -737,8 +737,8 @@ TEST_F(IntrinsicBuilderTest, Call_Distance_Scalar) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -763,8 +763,8 @@ TEST_F(IntrinsicBuilderTest, Call_Distance_Vector) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -792,8 +792,8 @@ TEST_F(IntrinsicBuilderTest, Call_Cross) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -823,8 +823,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Float_Test, Call_Scalar) { auto expr = Call(param.name, 1.0f, 1.0f, 1.0f); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -853,8 +853,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Float_Test, Call_Vector) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -894,8 +894,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Sint_Test, Call_Scalar) { auto expr = Call(param.name, 1); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -922,8 +922,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Sint_Test, Call_Vector) { auto expr = Call(param.name, vec2(1, 1)); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -957,8 +957,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Uint_Test, Call_Scalar) { auto expr = Call(param.name, 1u); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -985,8 +985,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Uint_Test, Call_Vector) { auto expr = Call(param.name, vec2(1u, 1u)); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1020,8 +1020,8 @@ TEST_P(Intrinsic_Builtin_DualParam_SInt_Test, Call_Scalar) { auto expr = Call(param.name, 1, 1); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1048,8 +1048,8 @@ TEST_P(Intrinsic_Builtin_DualParam_SInt_Test, Call_Vector) { auto expr = Call(param.name, vec2(1, 1), vec2(1, 1)); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1084,8 +1084,8 @@ TEST_P(Intrinsic_Builtin_DualParam_UInt_Test, Call_Scalar) { auto expr = Call(param.name, 1u, 1u); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1112,8 +1112,8 @@ TEST_P(Intrinsic_Builtin_DualParam_UInt_Test, Call_Vector) { auto expr = Call(param.name, vec2(1u, 1u), vec2(1u, 1u)); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1148,8 +1148,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Sint_Test, Call_Scalar) { auto expr = Call(param.name, 1, 1, 1); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1178,8 +1178,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Sint_Test, Call_Vector) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1213,8 +1213,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Uint_Test, Call_Scalar) { auto expr = Call(param.name, 1u, 1u, 1u); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1243,8 +1243,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Uint_Test, Call_Vector) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1276,8 +1276,8 @@ TEST_F(IntrinsicBuilderTest, Call_Determinant) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1320,8 +1320,8 @@ TEST_F(IntrinsicBuilderTest, Call_ArrayLength) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1360,8 +1360,8 @@ TEST_F(IntrinsicBuilderTest, Call_ArrayLength_OtherMembersInStruct) { ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); @@ -1405,8 +1405,8 @@ TEST_F(IntrinsicBuilderTest, DISABLED_Call_ArrayLength_Ptr) { auto expr = Call("arrayLength", "ptr_var"); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, ty.void_, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + ty.void_, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); diff --git a/src/writer/spirv/builder_switch_test.cc b/src/writer/spirv/builder_switch_test.cc index 8ae74b0a27..b57cee98a5 100644 --- a/src/writer/spirv/builder_switch_test.cc +++ b/src/writer/spirv/builder_switch_test.cc @@ -121,8 +121,8 @@ TEST_F(BuilderTest, Switch_WithCase) { td.RegisterVariableForTesting(a); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, &i32, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &i32, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); @@ -201,8 +201,8 @@ TEST_F(BuilderTest, Switch_WithDefault) { td.RegisterVariableForTesting(a); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, &i32, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &i32, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); @@ -300,8 +300,8 @@ TEST_F(BuilderTest, Switch_WithCaseAndDefault) { td.RegisterVariableForTesting(a); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, &i32, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &i32, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); @@ -408,8 +408,8 @@ TEST_F(BuilderTest, Switch_CaseWithFallthrough) { td.RegisterVariableForTesting(a); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, &i32, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &i32, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); @@ -495,8 +495,8 @@ TEST_F(BuilderTest, Switch_CaseFallthroughLastStatement) { td.RegisterVariableForTesting(a); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, &i32, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &i32, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); @@ -563,8 +563,8 @@ TEST_F(BuilderTest, Switch_WithNestedBreak) { td.RegisterVariableForTesting(a); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); - ast::Function func(Source{}, "a_func", {}, &i32, - create(), + ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {}, + &i32, create(), ast::FunctionDecorationList{}); ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); diff --git a/src/writer/wgsl/generator_impl.cc b/src/writer/wgsl/generator_impl.cc index 5abd8db8f3..58982ddd2e 100644 --- a/src/writer/wgsl/generator_impl.cc +++ b/src/writer/wgsl/generator_impl.cc @@ -113,7 +113,8 @@ bool GeneratorImpl::Generate(const ast::Module& module) { bool GeneratorImpl::GenerateEntryPoint(const ast::Module& module, ast::PipelineStage stage, const std::string& name) { - auto* func = module.FindFunctionByNameAndStage(name, stage); + auto* func = + module.FindFunctionBySymbolAndStage(module.GetSymbol(name), stage); if (func == nullptr) { error_ = "Unable to find requested entry point: " + name; return false; @@ -153,7 +154,7 @@ bool GeneratorImpl::GenerateEntryPoint(const ast::Module& module, } for (auto* f : module.functions()) { - if (!f->HasAncestorEntryPoint(name)) { + if (!f->HasAncestorEntryPoint(module.GetSymbol(name))) { continue; } diff --git a/src/writer/wgsl/generator_impl_function_test.cc b/src/writer/wgsl/generator_impl_function_test.cc index b9d3a0c0dd..573b391ae2 100644 --- a/src/writer/wgsl/generator_impl_function_test.cc +++ b/src/writer/wgsl/generator_impl_function_test.cc @@ -46,8 +46,8 @@ TEST_F(WgslGeneratorImplTest, Emit_Function) { body->append(create(Source{})); ast::type::Void void_type; - ast::Function func(Source{}, "my_func", {}, &void_type, body, - ast::FunctionDecorationList{}); + ast::Function func(Source{}, mod.RegisterSymbol("my_func"), "my_func", {}, + &void_type, body, ast::FunctionDecorationList{}); gen.increment_indent(); @@ -85,8 +85,8 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithParams) { ast::VariableDecorationList{})); // decorations ast::type::Void void_type; - ast::Function func(Source{}, "my_func", params, &void_type, body, - ast::FunctionDecorationList{}); + ast::Function func(Source{}, mod.RegisterSymbol("my_func"), "my_func", params, + &void_type, body, ast::FunctionDecorationList{}); gen.increment_indent(); @@ -104,7 +104,8 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_WorkgroupSize) { body->append(create(Source{})); ast::type::Void void_type; - ast::Function func(Source{}, "my_func", {}, &void_type, body, + ast::Function func(Source{}, mod.RegisterSymbol("my_func"), "my_func", {}, + &void_type, body, ast::FunctionDecorationList{ create(2u, 4u, 6u, Source{}), }); @@ -127,7 +128,7 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_Stage) { ast::type::Void void_type; ast::Function func( - Source{}, "my_func", {}, &void_type, body, + Source{}, mod.RegisterSymbol("my_func"), "my_func", {}, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), }); @@ -150,7 +151,7 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_Multiple) { ast::type::Void void_type; ast::Function func( - Source{}, "my_func", {}, &void_type, body, + Source{}, mod.RegisterSymbol("my_func"), "my_func", {}, &void_type, body, ast::FunctionDecorationList{ create(ast::PipelineStage::kFragment, Source{}), create(2u, 4u, 6u, Source{}), @@ -237,12 +238,12 @@ TEST_F(WgslGeneratorImplTest, body->append(create(var)); body->append(create(Source{})); - auto* func = - create(Source{}, "a", params, &void_type, body, - ast::FunctionDecorationList{ - create( - ast::PipelineStage::kCompute, Source{}), - }); + auto* func = create( + Source{}, mod.RegisterSymbol("a"), "a", params, &void_type, body, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute, + Source{}), + }); mod.AddFunction(func); } @@ -264,12 +265,12 @@ TEST_F(WgslGeneratorImplTest, body->append(create(var)); body->append(create(Source{})); - auto* func = - create(Source{}, "b", params, &void_type, body, - ast::FunctionDecorationList{ - create( - ast::PipelineStage::kCompute, Source{}), - }); + auto* func = create( + Source{}, mod.RegisterSymbol("b"), "b", params, &void_type, body, + ast::FunctionDecorationList{ + create(ast::PipelineStage::kCompute, + Source{}), + }); mod.AddFunction(func); } diff --git a/src/writer/wgsl/generator_impl_test.cc b/src/writer/wgsl/generator_impl_test.cc index 53431531fe..1b8da48a8a 100644 --- a/src/writer/wgsl/generator_impl_test.cc +++ b/src/writer/wgsl/generator_impl_test.cc @@ -33,8 +33,9 @@ TEST_F(WgslGeneratorImplTest, Generate) { ast::type::Void void_type; mod.AddFunction(create( - Source{}, "my_func", ast::VariableList{}, &void_type, - create(), ast::FunctionDecorationList{})); + Source{}, mod.RegisterSymbol("a_func"), "my_func", ast::VariableList{}, + &void_type, create(), + ast::FunctionDecorationList{})); ASSERT_TRUE(gen.Generate(mod)) << gen.error(); EXPECT_EQ(gen.result(), R"(fn my_func() -> void {