Add a symbol to the Function AST node.

This Cl adds a Symbol representing the function name to the function
AST. The symbol is added alongside the name for now. When all usages of
the function name are removed then the string version will be removed
from the constructor.

Change-Id: Ib2450e5fe531e988b25bb7d2937acc6af2187871
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/35220
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Auto-Submit: dan sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair 2020-12-11 18:24:53 +00:00 committed by Commit Bot service account
parent cd9e5f6e91
commit a41132fcd8
48 changed files with 923 additions and 658 deletions

View File

@ -31,12 +31,14 @@ namespace tint {
namespace ast {
Function::Function(const Source& source,
Symbol symbol,
const std::string& name,
VariableList params,
type::Type* return_type,
BlockStatement* body,
FunctionDecorationList decorations)
: Base(source),
symbol_(symbol),
name_(name),
params_(std::move(params)),
return_type_(return_type),
@ -202,7 +204,7 @@ Function::local_referenced_builtin_variables() const {
return ret;
}
void Function::add_ancestor_entry_point(const std::string& ep) {
void Function::add_ancestor_entry_point(Symbol ep) {
for (const auto& point : ancestor_entry_points_) {
if (point == ep) {
return;
@ -211,9 +213,9 @@ void Function::add_ancestor_entry_point(const std::string& ep) {
ancestor_entry_points_.push_back(ep);
}
bool Function::HasAncestorEntryPoint(const std::string& name) const {
bool Function::HasAncestorEntryPoint(Symbol symbol) const {
for (const auto& point : ancestor_entry_points_) {
if (point == name) {
if (point == symbol) {
return true;
}
}
@ -226,7 +228,7 @@ const Statement* Function::get_last_statement() const {
Function* Function::Clone(CloneContext* ctx) const {
return ctx->mod->create<Function>(
ctx->Clone(source()), name_, ctx->Clone(params_),
ctx->Clone(source()), symbol_, name_, ctx->Clone(params_),
ctx->Clone(return_type_), ctx->Clone(body_), ctx->Clone(decorations_));
}
@ -238,7 +240,7 @@ bool Function::IsValid() const {
if (body_ == nullptr || !body_->IsValid()) {
return false;
}
if (name_.length() == 0) {
if (name_.length() == 0 || !symbol_.IsValid()) {
return false;
}
if (return_type_ == nullptr) {
@ -249,7 +251,7 @@ bool Function::IsValid() const {
void Function::to_str(std::ostream& out, size_t indent) const {
make_indent(out, indent);
out << "Function " << name_ << " -> " << return_type_->type_name()
out << "Function " << symbol_.to_str() << " -> " << return_type_->type_name()
<< std::endl;
for (auto* deco : decorations()) {

View File

@ -35,6 +35,7 @@
#include "src/ast/type/sampler_type.h"
#include "src/ast/type/type.h"
#include "src/ast/variable.h"
#include "src/symbol.h"
namespace tint {
namespace ast {
@ -52,12 +53,14 @@ class Function : public Castable<Function, Node> {
/// Create a function
/// @param source the variable source
/// @param symbol the function symbol
/// @param name the function name
/// @param params the function parameters
/// @param return_type the return type
/// @param body the function body
/// @param decorations the function decorations
Function(const Source& source,
Symbol symbol,
const std::string& name,
VariableList params,
type::Type* return_type,
@ -68,6 +71,8 @@ class Function : public Castable<Function, Node> {
~Function() override;
/// @returns the function symbol
Symbol symbol() const { return symbol_; }
/// @returns the function name
const std::string& name() { return name_; }
/// @returns the function params
@ -150,15 +155,15 @@ class Function : public Castable<Function, Node> {
/// Adds an ancestor entry point
/// @param ep the entry point ancestor
void add_ancestor_entry_point(const std::string& ep);
void add_ancestor_entry_point(Symbol ep);
/// @returns the ancestor entry points
const std::vector<std::string>& ancestor_entry_points() const {
const std::vector<Symbol>& ancestor_entry_points() const {
return ancestor_entry_points_;
}
/// Checks if the given entry point is an ancestor
/// @param name the entry point name
/// @returns true if `name` is an ancestor entry point of this function
bool HasAncestorEntryPoint(const std::string& name) const;
/// @param sym the entry point symbol
/// @returns true if `sym` is an ancestor entry point of this function
bool HasAncestorEntryPoint(Symbol sym) const;
/// @returns the function return type.
type::Type* return_type() const { return return_type_; }
@ -197,13 +202,14 @@ class Function : public Castable<Function, Node> {
const std::vector<std::pair<Variable*, Function::BindingInfo>>
ReferencedSampledTextureVariablesImpl(bool multisampled) const;
Symbol symbol_;
std::string name_;
VariableList params_;
type::Type* return_type_ = nullptr;
BlockStatement* body_ = nullptr;
std::vector<Variable*> referenced_module_vars_;
std::vector<Variable*> local_referenced_module_vars_;
std::vector<std::string> ancestor_entry_points_;
std::vector<Symbol> ancestor_entry_points_;
FunctionDecorationList decorations_;
};

View File

@ -35,14 +35,18 @@ TEST_F(FunctionTest, Creation) {
type::Void void_type;
type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr,
ast::VariableDecorationList{}));
auto* var = params[0];
Function f(Source{}, "func", params, &void_type, create<BlockStatement>(),
FunctionDecorationList{});
Function f(Source{}, func_sym, "func", params, &void_type,
create<BlockStatement>(), FunctionDecorationList{});
EXPECT_EQ(f.symbol(), func_sym);
EXPECT_EQ(f.name(), "func");
ASSERT_EQ(f.params().size(), 1u);
EXPECT_EQ(f.return_type(), &void_type);
@ -53,13 +57,16 @@ TEST_F(FunctionTest, Creation_WithSource) {
type::Void void_type;
type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr,
ast::VariableDecorationList{}));
Function f(Source{Source::Location{20, 2}}, "func", params, &void_type,
create<BlockStatement>(), FunctionDecorationList{});
Function f(Source{Source::Location{20, 2}}, func_sym, "func", params,
&void_type, create<BlockStatement>(), FunctionDecorationList{});
auto src = f.source();
EXPECT_EQ(src.range.begin.line, 20u);
EXPECT_EQ(src.range.begin.column, 2u);
@ -69,9 +76,12 @@ TEST_F(FunctionTest, AddDuplicateReferencedVariables) {
type::Void void_type;
type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
Variable v(Source{}, "var", StorageClass::kInput, &i32, false, nullptr,
ast::VariableDecorationList{});
Function f(Source{}, "func", VariableList{}, &void_type,
Function f(Source{}, func_sym, "func", VariableList{}, &void_type,
create<BlockStatement>(), FunctionDecorationList{});
f.add_referenced_module_variable(&v);
@ -92,6 +102,9 @@ TEST_F(FunctionTest, GetReferenceLocations) {
type::Void void_type;
type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
auto* loc1 = create<Variable>(Source{}, "loc1", StorageClass::kInput, &i32,
false, nullptr,
ast::VariableDecorationList{
@ -116,7 +129,7 @@ TEST_F(FunctionTest, GetReferenceLocations) {
create<BuiltinDecoration>(Builtin::kFragDepth, Source{}),
});
Function f(Source{}, "func", VariableList{}, &void_type,
Function f(Source{}, func_sym, "func", VariableList{}, &void_type,
create<BlockStatement>(), FunctionDecorationList{});
f.add_referenced_module_variable(loc1);
@ -137,6 +150,9 @@ TEST_F(FunctionTest, GetReferenceBuiltins) {
type::Void void_type;
type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
auto* loc1 = create<Variable>(Source{}, "loc1", StorageClass::kInput, &i32,
false, nullptr,
ast::VariableDecorationList{
@ -161,7 +177,7 @@ TEST_F(FunctionTest, GetReferenceBuiltins) {
create<BuiltinDecoration>(Builtin::kFragDepth, Source{}),
});
Function f(Source{}, "func", VariableList{}, &void_type,
Function f(Source{}, func_sym, "func", VariableList{}, &void_type,
create<BlockStatement>(), FunctionDecorationList{});
f.add_referenced_module_variable(loc1);
@ -180,22 +196,30 @@ TEST_F(FunctionTest, GetReferenceBuiltins) {
TEST_F(FunctionTest, AddDuplicateEntryPoints) {
type::Void void_type;
Function f(Source{}, "func", VariableList{}, &void_type,
Module m;
auto func_sym = m.RegisterSymbol("func");
auto main_sym = m.RegisterSymbol("main");
Function f(Source{}, func_sym, "func", VariableList{}, &void_type,
create<BlockStatement>(), FunctionDecorationList{});
f.add_ancestor_entry_point("main");
f.add_ancestor_entry_point(main_sym);
ASSERT_EQ(1u, f.ancestor_entry_points().size());
EXPECT_EQ("main", f.ancestor_entry_points()[0]);
EXPECT_EQ(main_sym, f.ancestor_entry_points()[0]);
f.add_ancestor_entry_point("main");
f.add_ancestor_entry_point(main_sym);
ASSERT_EQ(1u, f.ancestor_entry_points().size());
EXPECT_EQ("main", f.ancestor_entry_points()[0]);
EXPECT_EQ(main_sym, f.ancestor_entry_points()[0]);
}
TEST_F(FunctionTest, IsValid) {
type::Void void_type;
type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr,
@ -204,21 +228,27 @@ TEST_F(FunctionTest, IsValid) {
auto* body = create<BlockStatement>();
body->append(create<DiscardStatement>());
Function f(Source{}, "func", params, &void_type, body,
Function f(Source{}, func_sym, "func", params, &void_type, body,
FunctionDecorationList{});
EXPECT_TRUE(f.IsValid());
}
TEST_F(FunctionTest, IsValid_EmptyName) {
TEST_F(FunctionTest, IsValid_InvalidName) {
type::Void void_type;
type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("");
VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr,
ast::VariableDecorationList{}));
Function f(Source{}, "", params, &void_type, create<BlockStatement>(),
auto* body = create<BlockStatement>();
body->append(create<DiscardStatement>());
Function f(Source{}, func_sym, "", params, &void_type, body,
FunctionDecorationList{});
EXPECT_FALSE(f.IsValid());
}
@ -226,13 +256,16 @@ TEST_F(FunctionTest, IsValid_EmptyName) {
TEST_F(FunctionTest, IsValid_MissingReturnType) {
type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr,
ast::VariableDecorationList{}));
Function f(Source{}, "func", params, nullptr, create<BlockStatement>(),
FunctionDecorationList{});
Function f(Source{}, func_sym, "func", params, nullptr,
create<BlockStatement>(), FunctionDecorationList{});
EXPECT_FALSE(f.IsValid());
}
@ -240,27 +273,33 @@ TEST_F(FunctionTest, IsValid_NullParam) {
type::Void void_type;
type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr,
ast::VariableDecorationList{}));
params.push_back(nullptr);
Function f(Source{}, "func", params, &void_type, create<BlockStatement>(),
FunctionDecorationList{});
Function f(Source{}, func_sym, "func", params, &void_type,
create<BlockStatement>(), FunctionDecorationList{});
EXPECT_FALSE(f.IsValid());
}
TEST_F(FunctionTest, IsValid_InvalidParam) {
type::Void void_type;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone,
nullptr, false, nullptr,
ast::VariableDecorationList{}));
Function f(Source{}, "func", params, &void_type, create<BlockStatement>(),
FunctionDecorationList{});
Function f(Source{}, func_sym, "func", params, &void_type,
create<BlockStatement>(), FunctionDecorationList{});
EXPECT_FALSE(f.IsValid());
}
@ -268,6 +307,9 @@ TEST_F(FunctionTest, IsValid_NullBodyStatement) {
type::Void void_type;
type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr,
@ -277,7 +319,7 @@ TEST_F(FunctionTest, IsValid_NullBodyStatement) {
body->append(create<DiscardStatement>());
body->append(nullptr);
Function f(Source{}, "func", params, &void_type, body,
Function f(Source{}, func_sym, "func", params, &void_type, body,
FunctionDecorationList{});
EXPECT_FALSE(f.IsValid());
@ -287,6 +329,9 @@ TEST_F(FunctionTest, IsValid_InvalidBodyStatement) {
type::Void void_type;
type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr,
@ -296,7 +341,7 @@ TEST_F(FunctionTest, IsValid_InvalidBodyStatement) {
body->append(create<DiscardStatement>());
body->append(nullptr);
Function f(Source{}, "func", params, &void_type, body,
Function f(Source{}, func_sym, "func", params, &void_type, body,
FunctionDecorationList{});
EXPECT_FALSE(f.IsValid());
}
@ -305,14 +350,18 @@ TEST_F(FunctionTest, ToStr) {
type::Void void_type;
type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
auto* body = create<BlockStatement>();
body->append(create<DiscardStatement>());
Function f(Source{}, "func", {}, &void_type, body, FunctionDecorationList{});
Function f(Source{}, func_sym, "func", {}, &void_type, body,
FunctionDecorationList{});
std::ostringstream out;
f.to_str(out, 2);
EXPECT_EQ(out.str(), R"( Function func -> __void
EXPECT_EQ(out.str(), R"( Function tint_symbol_1 -> __void
()
{
Discard{}
@ -324,16 +373,19 @@ TEST_F(FunctionTest, ToStr_WithDecoration) {
type::Void void_type;
type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
auto* body = create<BlockStatement>();
body->append(create<DiscardStatement>());
Function f(
Source{}, "func", {}, &void_type, body,
Source{}, func_sym, "func", {}, &void_type, body,
FunctionDecorationList{create<WorkgroupDecoration>(2, 4, 6, Source{})});
std::ostringstream out;
f.to_str(out, 2);
EXPECT_EQ(out.str(), R"( Function func -> __void
EXPECT_EQ(out.str(), R"( Function tint_symbol_1 -> __void
WorkgroupDecoration{2 4 6}
()
{
@ -346,6 +398,9 @@ TEST_F(FunctionTest, ToStr_WithParams) {
type::Void void_type;
type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr,
@ -354,12 +409,12 @@ TEST_F(FunctionTest, ToStr_WithParams) {
auto* body = create<BlockStatement>();
body->append(create<DiscardStatement>());
Function f(Source{}, "func", params, &void_type, body,
Function f(Source{}, func_sym, "func", params, &void_type, body,
FunctionDecorationList{});
std::ostringstream out;
f.to_str(out, 2);
EXPECT_EQ(out.str(), R"( Function func -> __void
EXPECT_EQ(out.str(), R"( Function tint_symbol_1 -> __void
(
Variable{
var
@ -376,8 +431,11 @@ TEST_F(FunctionTest, ToStr_WithParams) {
TEST_F(FunctionTest, TypeName) {
type::Void void_type;
Function f(Source{}, "func", {}, &void_type, create<BlockStatement>(),
FunctionDecorationList{});
Module m;
auto func_sym = m.RegisterSymbol("func");
Function f(Source{}, func_sym, "func", {}, &void_type,
create<BlockStatement>(), FunctionDecorationList{});
EXPECT_EQ(f.type_name(), "__func__void");
}
@ -386,6 +444,9 @@ TEST_F(FunctionTest, TypeName_WithParams) {
type::I32 i32;
type::F32 f32;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params;
params.push_back(create<Variable>(Source{}, "var1", StorageClass::kNone, &i32,
false, nullptr,
@ -394,19 +455,22 @@ TEST_F(FunctionTest, TypeName_WithParams) {
false, nullptr,
ast::VariableDecorationList{}));
Function f(Source{}, "func", params, &void_type, create<BlockStatement>(),
FunctionDecorationList{});
Function f(Source{}, func_sym, "func", params, &void_type,
create<BlockStatement>(), FunctionDecorationList{});
EXPECT_EQ(f.type_name(), "__func__void__i32__f32");
}
TEST_F(FunctionTest, GetLastStatement) {
type::Void void_type;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params;
auto* body = create<BlockStatement>();
auto* stmt = create<DiscardStatement>();
body->append(stmt);
Function f(Source{}, "func", params, &void_type, body,
Function f(Source{}, func_sym, "func", params, &void_type, body,
FunctionDecorationList{});
EXPECT_EQ(f.get_last_statement(), stmt);
@ -415,9 +479,12 @@ TEST_F(FunctionTest, GetLastStatement) {
TEST_F(FunctionTest, GetLastStatement_nullptr) {
type::Void void_type;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params;
auto* body = create<BlockStatement>();
Function f(Source{}, "func", params, &void_type, body,
Function f(Source{}, func_sym, "func", params, &void_type, body,
FunctionDecorationList{});
EXPECT_EQ(f.get_last_statement(), nullptr);
@ -425,8 +492,12 @@ TEST_F(FunctionTest, GetLastStatement_nullptr) {
TEST_F(FunctionTest, WorkgroupSize_NoneSet) {
type::Void void_type;
Function f(Source{}, "f", {}, &void_type, create<BlockStatement>(),
FunctionDecorationList{});
Module m;
auto func_sym = m.RegisterSymbol("func");
Function f(Source{}, func_sym, "func", {}, &void_type,
create<BlockStatement>(), FunctionDecorationList{});
uint32_t x = 0;
uint32_t y = 0;
uint32_t z = 0;
@ -438,7 +509,12 @@ TEST_F(FunctionTest, WorkgroupSize_NoneSet) {
TEST_F(FunctionTest, WorkgroupSize) {
type::Void void_type;
Function f(Source{}, "f", {}, &void_type, create<BlockStatement>(),
Module m;
auto func_sym = m.RegisterSymbol("func");
Function f(Source{}, func_sym, "func", {}, &void_type,
create<BlockStatement>(),
{create<WorkgroupDecoration>(2u, 4u, 6u, Source{})});
uint32_t x = 0;

View File

@ -47,21 +47,23 @@ void Module::Clone(CloneContext* ctx) {
for (auto* func : functions_) {
ctx->mod->functions_.emplace_back(ctx->Clone(func));
}
ctx->mod->symbol_table_ = symbol_table_;
}
Function* Module::FindFunctionByName(const std::string& name) const {
Function* Module::FindFunctionBySymbol(Symbol sym) const {
for (auto* func : functions_) {
if (func->name() == name) {
if (func->symbol() == sym) {
return func;
}
}
return nullptr;
}
Function* Module::FindFunctionByNameAndStage(const std::string& name,
PipelineStage stage) const {
Function* Module::FindFunctionBySymbolAndStage(Symbol sym,
PipelineStage stage) const {
for (auto* func : functions_) {
if (func->name() == name && func->pipeline_stage() == stage) {
if (func->symbol() == sym && func->pipeline_stage() == stage) {
return func;
}
}
@ -81,6 +83,10 @@ Symbol Module::RegisterSymbol(const std::string& name) {
return symbol_table_.Register(name);
}
Symbol Module::GetSymbol(const std::string& name) const {
return symbol_table_.GetSymbol(name);
}
std::string Module::SymbolToName(const Symbol sym) const {
return symbol_table_.NameFor(sym);
}

View File

@ -89,15 +89,14 @@ class Module {
/// @returns the modules functions
const FunctionList& functions() const { return functions_; }
/// Returns the function with the given name
/// @param name the name to search for
/// @param sym the function symbol to search for
/// @returns the associated function or nullptr if none exists
Function* FindFunctionByName(const std::string& name) const;
Function* FindFunctionBySymbol(Symbol sym) const;
/// Returns the function with the given name
/// @param name the name to search for
/// @param sym the function symbol to search for
/// @param stage the pipeline stage
/// @returns the associated function or nullptr if none exists
Function* FindFunctionByNameAndStage(const std::string& name,
PipelineStage stage) const;
Function* FindFunctionBySymbolAndStage(Symbol sym, PipelineStage stage) const;
/// @param stage the pipeline stage
/// @returns true if the module contains an entrypoint function with the given
/// stage
@ -169,6 +168,11 @@ class Module {
/// previously generated symbol will be returned.
Symbol RegisterSymbol(const std::string& name);
/// Returns the symbol for `name`
/// @param name the name to lookup
/// @returns the symbol for name or symbol::kInvalid
Symbol GetSymbol(const std::string& name) const;
/// Returns the `name` for `sym`
/// @param sym the symbol to retrieve the name for
/// @returns the use provided `name` for the symbol or "" if not found

View File

@ -48,16 +48,17 @@ TEST_F(ModuleTest, LookupFunction) {
type::F32 f32;
Module m;
auto func_sym = m.RegisterSymbol("main");
auto* func =
create<Function>(Source{}, "main", VariableList{}, &f32,
create<Function>(Source{}, func_sym, "main", VariableList{}, &f32,
create<BlockStatement>(), ast::FunctionDecorationList{});
m.AddFunction(func);
EXPECT_EQ(func, m.FindFunctionByName("main"));
EXPECT_EQ(func, m.FindFunctionBySymbol(func_sym));
}
TEST_F(ModuleTest, LookupFunctionMissing) {
Module m;
EXPECT_EQ(nullptr, m.FindFunctionByName("Missing"));
EXPECT_EQ(nullptr, m.FindFunctionBySymbol(m.RegisterSymbol("Missing")));
}
TEST_F(ModuleTest, IsValid_Empty) {
@ -127,11 +128,12 @@ TEST_F(ModuleTest, IsValid_Struct_EmptyName) {
TEST_F(ModuleTest, IsValid_Function) {
type::F32 f32;
auto* func =
create<Function>(Source{}, "main", VariableList(), &f32,
create<BlockStatement>(), ast::FunctionDecorationList{});
Module m;
auto* func = create<Function>(Source{}, m.RegisterSymbol("main"), "main",
VariableList(), &f32, create<BlockStatement>(),
ast::FunctionDecorationList{});
m.AddFunction(func);
EXPECT_TRUE(m.IsValid());
}
@ -144,10 +146,13 @@ TEST_F(ModuleTest, IsValid_Null_Function) {
TEST_F(ModuleTest, IsValid_Invalid_Function) {
VariableList p;
auto* func = create<Function>(Source{}, "", p, nullptr, nullptr,
ast::FunctionDecorationList{});
Module m;
auto* func =
create<Function>(Source{}, m.RegisterSymbol("main"), "main", p, nullptr,
nullptr, ast::FunctionDecorationList{});
m.AddFunction(func);
EXPECT_FALSE(m.IsValid());
}

View File

@ -267,7 +267,7 @@ std::vector<ResourceBinding> Inspector::GetMultisampledTextureResourceBindings(
}
ast::Function* Inspector::FindEntryPointByName(const std::string& name) {
auto* func = module_.FindFunctionByName(name);
auto* func = module_.FindFunctionBySymbol(module_.GetSymbol(name));
if (!func) {
error_ += name + " was not found!";
return nullptr;

View File

@ -83,8 +83,9 @@ class InspectorHelper {
ast::FunctionDecorationList decorations = {}) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
return create<ast::Function>(Source{}, name, ast::VariableList(),
void_type(), body, decorations);
return create<ast::Function>(Source{}, mod()->RegisterSymbol(name), name,
ast::VariableList(), void_type(), body,
decorations);
}
/// Generates a function that calls another
@ -102,8 +103,9 @@ class InspectorHelper {
create<ast::CallExpression>(ident_expr, ast::ExpressionList());
body->append(create<ast::CallStatement>(call_expr));
body->append(create<ast::ReturnStatement>(Source{}));
return create<ast::Function>(Source{}, caller, ast::VariableList(),
void_type(), body, decorations);
return create<ast::Function>(Source{}, mod()->RegisterSymbol(caller),
caller, ast::VariableList(), void_type(), body,
decorations);
}
/// Add In/Out variables to the global variables
@ -154,8 +156,9 @@ class InspectorHelper {
create<ast::IdentifierExpression>(in)));
}
body->append(create<ast::ReturnStatement>(Source{}));
return create<ast::Function>(Source{}, name, ast::VariableList(),
void_type(), body, decorations);
return create<ast::Function>(Source{}, mod()->RegisterSymbol(name), name,
ast::VariableList(), void_type(), body,
decorations);
}
/// Generates a function that references in/out variables and calls another
@ -184,8 +187,9 @@ class InspectorHelper {
create<ast::CallExpression>(ident_expr, ast::ExpressionList());
body->append(create<ast::CallStatement>(call_expr));
body->append(create<ast::ReturnStatement>(Source{}));
return create<ast::Function>(Source{}, caller, ast::VariableList(),
void_type(), body, decorations);
return create<ast::Function>(Source{}, mod()->RegisterSymbol(caller),
caller, ast::VariableList(), void_type(), body,
decorations);
}
/// Add a Constant ID to the global variables.
@ -445,9 +449,9 @@ class InspectorHelper {
}
body->append(create<ast::ReturnStatement>(Source{}));
return create<ast::Function>(Source{}, func_name, ast::VariableList(),
void_type(), body,
ast::FunctionDecorationList{});
return create<ast::Function>(Source{}, mod()->RegisterSymbol(func_name),
func_name, ast::VariableList(), void_type(),
body, ast::FunctionDecorationList{});
}
/// Adds a regular sampler variable to the module
@ -587,8 +591,9 @@ class InspectorHelper {
create<ast::IdentifierExpression>("sampler_result"), call_expr));
body->append(create<ast::ReturnStatement>(Source{}));
return create<ast::Function>(Source{}, func_name, ast::VariableList(),
void_type(), body, decorations);
return create<ast::Function>(Source{}, mod()->RegisterSymbol(func_name),
func_name, ast::VariableList(), void_type(),
body, decorations);
}
/// Generates a function that references a specific sampler variable
@ -634,8 +639,9 @@ class InspectorHelper {
create<ast::IdentifierExpression>("sampler_result"), call_expr));
body->append(create<ast::ReturnStatement>(Source{}));
return create<ast::Function>(Source{}, func_name, ast::VariableList(),
void_type(), body, decorations);
return create<ast::Function>(Source{}, mod()->RegisterSymbol(func_name),
func_name, ast::VariableList(), void_type(),
body, decorations);
}
/// Generates a function that references a specific comparison sampler
@ -682,8 +688,9 @@ class InspectorHelper {
create<ast::IdentifierExpression>("sampler_result"), call_expr));
body->append(create<ast::ReturnStatement>(Source{}));
return create<ast::Function>(Source{}, func_name, ast::VariableList(),
void_type(), body, decorations);
return create<ast::Function>(Source{}, mod()->RegisterSymbol(func_name),
func_name, ast::VariableList(), void_type(),
body, decorations);
}
/// Gets an appropriate type for the data in a given texture type.
@ -1513,7 +1520,8 @@ TEST_F(InspectorGetUniformBufferResourceBindingsTest, MultipleUniformBuffers) {
body->append(create<ast::ReturnStatement>(Source{}));
ast::Function* func = create<ast::Function>(
Source{}, "ep_func", ast::VariableList(), void_type(), body,
Source{}, mod()->RegisterSymbol("ep_func"), "ep_func",
ast::VariableList(), void_type(), body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@ -1659,7 +1667,8 @@ TEST_F(InspectorGetStorageBufferResourceBindingsTest, MultipleStorageBuffers) {
body->append(create<ast::ReturnStatement>(Source{}));
ast::Function* func = create<ast::Function>(
Source{}, "ep_func", ast::VariableList(), void_type(), body,
Source{}, mod()->RegisterSymbol("ep_func"), "ep_func",
ast::VariableList(), void_type(), body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@ -1832,7 +1841,8 @@ TEST_F(InspectorGetReadOnlyStorageBufferResourceBindingsTest,
body->append(create<ast::ReturnStatement>(Source{}));
ast::Function* func = create<ast::Function>(
Source{}, "ep_func", ast::VariableList(), void_type(), body,
Source{}, mod()->RegisterSymbol("ep_func"), "ep_func",
ast::VariableList(), void_type(), body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});

View File

@ -761,9 +761,10 @@ bool FunctionEmitter::Emit() {
}
auto* body = statements_stack_[0].statements_;
ast_module_.AddFunction(create<ast::Function>(
decl.source, decl.name, std::move(decl.params), decl.return_type, body,
std::move(decl.decorations)));
ast_module_.AddFunction(
create<ast::Function>(decl.source, ast_module_.RegisterSymbol(decl.name),
decl.name, std::move(decl.params), decl.return_type,
body, std::move(decl.decorations)));
// Maintain the invariant by repopulating the one and only element.
statements_stack_.clear();

View File

@ -46,14 +46,16 @@ TEST_F(SpvParserTest, EmitStatement_VoidCallNoParams) {
OpFunctionEnd
)"));
ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error();
const auto module_ast_str = p->module().to_str();
const auto module_ast_str = p->get_module().to_str();
EXPECT_THAT(module_ast_str, Eq(R"(Module{
Function x_50 -> __void
Function )" + p->get_module().GetSymbol("x_50").to_str() +
R"( -> __void
()
{
Return{}
}
Function x_100 -> __void
Function )" + p->get_module().GetSymbol("x_100").to_str() +
R"( -> __void
()
{
Call[not set]{
@ -214,9 +216,10 @@ TEST_F(SpvParserTest, EmitStatement_CallWithParams) {
)"));
ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error();
EXPECT_TRUE(p->error().empty());
const auto module_ast_str = p->module().to_str();
const auto module_ast_str = p->get_module().to_str();
EXPECT_THAT(module_ast_str, HasSubstr(R"(Module{
Function x_50 -> __u32
Function )" + p->get_module().GetSymbol("x_50").to_str() +
R"( -> __u32
(
VariableConst{
x_51
@ -240,7 +243,8 @@ TEST_F(SpvParserTest, EmitStatement_CallWithParams) {
}
}
}
Function x_100 -> __void
Function )" + p->get_module().GetSymbol("x_100").to_str() +
R"( -> __void
()
{
VariableDeclStatement{

View File

@ -59,9 +59,10 @@ TEST_F(SpvParserTest, Emit_VoidFunctionWithoutParams) {
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
EXPECT_TRUE(fe.Emit());
auto got = p->module().to_str();
auto* expect = R"(Module{
Function x_100 -> __void
auto got = p->get_module().to_str();
auto expect = R"(Module{
Function )" + p->get_module().GetSymbol("x_100").to_str() +
R"( -> __void
()
{
Return{}
@ -83,9 +84,10 @@ TEST_F(SpvParserTest, Emit_NonVoidResultType) {
FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
EXPECT_TRUE(fe.Emit());
auto got = p->module().to_str();
auto* expect = R"(Module{
Function x_100 -> __f32
auto got = p->get_module().to_str();
auto expect = R"(Module{
Function )" + p->get_module().GetSymbol("x_100").to_str() +
R"( -> __f32
()
{
Return{
@ -115,9 +117,10 @@ TEST_F(SpvParserTest, Emit_MixedParamTypes) {
FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
EXPECT_TRUE(fe.Emit());
auto got = p->module().to_str();
auto* expect = R"(Module{
Function x_100 -> __void
auto got = p->get_module().to_str();
auto expect = R"(Module{
Function )" + p->get_module().GetSymbol("x_100").to_str() +
R"( -> __void
(
VariableConst{
a
@ -159,9 +162,10 @@ TEST_F(SpvParserTest, Emit_GenerateParamNames) {
FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
EXPECT_TRUE(fe.Emit());
auto got = p->module().to_str();
auto* expect = R"(Module{
Function x_100 -> __void
auto got = p->get_module().to_str();
auto expect = R"(Module{
Function )" + p->get_module().GetSymbol("x_100").to_str() +
R"( -> __void
(
VariableConst{
x_14

View File

@ -53,7 +53,7 @@ TEST_F(SpvParserTest, EmitFunctions_NoFunctions) {
auto p = parser(test::Assemble(CommonTypes()));
EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty());
const auto module_ast = p->module().to_str();
const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, Not(HasSubstr("Function{")));
}
@ -64,7 +64,7 @@ TEST_F(SpvParserTest, EmitFunctions_FunctionWithoutBody) {
)"));
EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty());
const auto module_ast = p->module().to_str();
const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, Not(HasSubstr("Function{")));
}
@ -79,9 +79,10 @@ OpFunctionEnd)";
auto p = parser(test::Assemble(input));
ASSERT_TRUE(p->BuildAndParseInternalModule());
ASSERT_TRUE(p->error().empty()) << p->error();
const auto module_ast = p->module().to_str();
const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"(
Function main -> __void
Function )" + p->get_module().GetSymbol("main").to_str() +
R"( -> __void
StageDecoration{vertex}
()
{)"));
@ -98,9 +99,10 @@ OpFunctionEnd)";
auto p = parser(test::Assemble(input));
ASSERT_TRUE(p->BuildAndParseInternalModule());
ASSERT_TRUE(p->error().empty()) << p->error();
const auto module_ast = p->module().to_str();
const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"(
Function main -> __void
Function )" + p->get_module().GetSymbol("main").to_str() +
R"( -> __void
StageDecoration{fragment}
()
{)"));
@ -117,9 +119,10 @@ OpFunctionEnd)";
auto p = parser(test::Assemble(input));
ASSERT_TRUE(p->BuildAndParseInternalModule());
ASSERT_TRUE(p->error().empty()) << p->error();
const auto module_ast = p->module().to_str();
const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"(
Function main -> __void
Function )" + p->get_module().GetSymbol("main").to_str() +
R"( -> __void
StageDecoration{compute}
()
{)"));
@ -138,14 +141,16 @@ OpFunctionEnd)";
auto p = parser(test::Assemble(input));
ASSERT_TRUE(p->BuildAndParseInternalModule());
ASSERT_TRUE(p->error().empty()) << p->error();
const auto module_ast = p->module().to_str();
const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"(
Function frag_main -> __void
Function )" + p->get_module().GetSymbol("frag_main").to_str() +
R"( -> __void
StageDecoration{fragment}
()
{)"));
EXPECT_THAT(module_ast, HasSubstr(R"(
Function comp_main -> __void
Function )" + p->get_module().GetSymbol("comp_main").to_str() +
R"( -> __void
StageDecoration{compute}
()
{)"));
@ -160,9 +165,10 @@ TEST_F(SpvParserTest, EmitFunctions_VoidFunctionWithoutParams) {
)"));
EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty());
const auto module_ast = p->module().to_str();
const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"(
Function main -> __void
Function )" + p->get_module().GetSymbol("main").to_str() +
R"( -> __void
()
{)"));
}
@ -193,9 +199,10 @@ TEST_F(SpvParserTest, EmitFunctions_CalleePrecedesCaller) {
)"));
EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty());
const auto module_ast = p->module().to_str();
const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"(
Function leaf -> __u32
Function )" + p->get_module().GetSymbol("leaf").to_str() +
R"( -> __u32
()
{
Return{
@ -204,7 +211,8 @@ TEST_F(SpvParserTest, EmitFunctions_CalleePrecedesCaller) {
}
}
}
Function branch -> __u32
Function )" + p->get_module().GetSymbol("branch").to_str() +
R"( -> __u32
()
{
VariableDeclStatement{
@ -227,7 +235,8 @@ TEST_F(SpvParserTest, EmitFunctions_CalleePrecedesCaller) {
}
}
}
Function root -> __void
Function )" + p->get_module().GetSymbol("root").to_str() +
R"( -> __void
()
{
VariableDeclStatement{
@ -260,9 +269,10 @@ TEST_F(SpvParserTest, EmitFunctions_NonVoidResultType) {
)"));
EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty());
const auto module_ast = p->module().to_str();
const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"(
Function ret_float -> __f32
Function )" + p->get_module().GetSymbol("ret_float").to_str() +
R"( -> __f32
()
{
Return{
@ -289,9 +299,10 @@ TEST_F(SpvParserTest, EmitFunctions_MixedParamTypes) {
)"));
EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty());
const auto module_ast = p->module().to_str();
const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"(
Function mixed_params -> __void
Function )" + p->get_module().GetSymbol("mixed_params").to_str() +
R"( -> __void
(
VariableConst{
a
@ -328,9 +339,10 @@ TEST_F(SpvParserTest, EmitFunctions_GenerateParamNames) {
)"));
EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty());
const auto module_ast = p->module().to_str();
const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"(
Function mixed_params -> __void
Function )" + p->get_module().GetSymbol("mixed_params").to_str() +
R"( -> __void
(
VariableConst{
x_14

View File

@ -1280,9 +1280,9 @@ Maybe<ast::Function*> ParserImpl::function_decl(ast::DecorationList& decos) {
if (errored)
return Failure::kErrored;
return create<ast::Function>(header->source, header->name, header->params,
header->return_type, body.value,
func_decos.value);
return create<ast::Function>(
header->source, module_.RegisterSymbol(header->name), header->name,
header->params, header->return_type, body.value, func_decos.value);
}
// function_type_decl

View File

@ -18,10 +18,14 @@ namespace tint {
SymbolTable::SymbolTable() = default;
SymbolTable::SymbolTable(const SymbolTable&) = default;
SymbolTable::SymbolTable(SymbolTable&&) = default;
SymbolTable::~SymbolTable() = default;
SymbolTable& SymbolTable::operator=(const SymbolTable& other) = default;
SymbolTable& SymbolTable::operator=(SymbolTable&&) = default;
Symbol SymbolTable::Register(const std::string& name) {
@ -41,6 +45,11 @@ Symbol SymbolTable::Register(const std::string& name) {
return sym;
}
Symbol SymbolTable::GetSymbol(const std::string& name) const {
auto it = name_to_symbol_.find(name);
return it != name_to_symbol_.end() ? it->second : Symbol();
}
std::string SymbolTable::NameFor(const Symbol symbol) const {
auto it = symbol_to_name_.find(symbol.value());
if (it == symbol_to_name_.end())

View File

@ -27,11 +27,17 @@ class SymbolTable {
public:
/// Constructor
SymbolTable();
/// Copy constructor
SymbolTable(const SymbolTable&);
/// Move Constructor
SymbolTable(SymbolTable&&);
/// Destructor
~SymbolTable();
/// Copy assignment
/// @param other the symbol table to copy
/// @returns the new symbol table
SymbolTable& operator=(const SymbolTable& other);
/// Move assignment
/// @param other the symbol table to move
/// @returns the symbol table
@ -42,6 +48,11 @@ class SymbolTable {
/// @returns the symbol representing the given name
Symbol Register(const std::string& name);
/// Returns the symbol for the given `name`
/// @param name the name to lookup
/// @returns the symbol for the name or symbol::kInvalid if not found.
Symbol GetSymbol(const std::string& name) const;
/// Returns the name for the given symbol
/// @param symbol the symbol to retrieve the name for
/// @returns the symbol name or "" if not found

View File

@ -51,7 +51,7 @@ namespace {
template <typename T = ast::Expression>
T* FindVariable(ast::Module* mod, std::string name) {
if (auto* func = mod->FindFunctionByName("func")) {
if (auto* func = mod->FindFunctionBySymbol(mod->RegisterSymbol("func"))) {
for (auto* stmt : *func->body()) {
if (auto* decl = stmt->As<ast::VariableDeclStatement>()) {
if (auto* var = decl->variable()) {
@ -92,9 +92,9 @@ class BoundArrayAccessorsTest : public testing::Test {
struct ModuleBuilder : public ast::BuilderWithModule {
ModuleBuilder() : body_(create<ast::BlockStatement>()) {
mod->AddFunction(create<ast::Function>(Source{}, "func",
ast::VariableList{}, ty.void_, body_,
ast::FunctionDecorationList{}));
mod->AddFunction(create<ast::Function>(
Source{}, mod->RegisterSymbol("func"), "func", ast::VariableList{},
ty.void_, body_, ast::FunctionDecorationList{}));
}
ast::Module Module() {

View File

@ -58,23 +58,26 @@ TEST_F(EmitVertexPointSizeTest, VertexStageBasic) {
Var("builtin_assignments_should_happen_before_this",
tint::ast::StorageClass::kFunction, ty.f32)));
mod->AddFunction(
create<ast::Function>(Source{}, "non_entry_a", ast::VariableList{},
ty.void_, create<ast::BlockStatement>(Source{}),
ast::FunctionDecorationList{}));
auto a_sym = mod->RegisterSymbol("non_entry_a");
mod->AddFunction(create<ast::Function>(
Source{}, a_sym, "non_entry_a", ast::VariableList{}, ty.void_,
create<ast::BlockStatement>(Source{}),
ast::FunctionDecorationList{}));
auto entry_sym = mod->RegisterSymbol("entry");
auto* entry = create<ast::Function>(
Source{}, "entry", ast::VariableList{}, ty.void_, block,
Source{}, entry_sym, "entry", ast::VariableList{}, ty.void_, block,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex,
Source{}),
});
mod->AddFunction(entry);
mod->AddFunction(
create<ast::Function>(Source{}, "non_entry_b", ast::VariableList{},
ty.void_, create<ast::BlockStatement>(Source{}),
ast::FunctionDecorationList{}));
auto b_sym = mod->RegisterSymbol("non_entry_b");
mod->AddFunction(create<ast::Function>(
Source{}, b_sym, "non_entry_b", ast::VariableList{}, ty.void_,
create<ast::BlockStatement>(Source{}),
ast::FunctionDecorationList{}));
}
};
@ -82,7 +85,7 @@ TEST_F(EmitVertexPointSizeTest, VertexStageBasic) {
ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
auto* expected = R"(Module{
auto expected = R"(Module{
Variable{
Decorations{
BuiltinDecoration{pointsize}
@ -91,11 +94,13 @@ TEST_F(EmitVertexPointSizeTest, VertexStageBasic) {
out
__f32
}
Function non_entry_a -> __void
Function )" + result.module.RegisterSymbol("non_entry_a").to_str() +
R"( -> __void
()
{
}
Function entry -> __void
Function )" + result.module.RegisterSymbol("entry").to_str() +
R"( -> __void
StageDecoration{vertex}
()
{
@ -111,7 +116,8 @@ TEST_F(EmitVertexPointSizeTest, VertexStageBasic) {
}
}
}
Function non_entry_b -> __void
Function )" + result.module.RegisterSymbol("non_entry_b").to_str() +
R"( -> __void
()
{
}
@ -123,23 +129,26 @@ TEST_F(EmitVertexPointSizeTest, VertexStageBasic) {
TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) {
struct Builder : ModuleBuilder {
void Build() override {
mod->AddFunction(
create<ast::Function>(Source{}, "non_entry_a", ast::VariableList{},
ty.void_, create<ast::BlockStatement>(Source{}),
ast::FunctionDecorationList{}));
auto a_sym = mod->RegisterSymbol("non_entry_a");
mod->AddFunction(create<ast::Function>(
Source{}, a_sym, "non_entry_a", ast::VariableList{}, ty.void_,
create<ast::BlockStatement>(Source{}),
ast::FunctionDecorationList{}));
mod->AddFunction(
create<ast::Function>(Source{}, "entry", ast::VariableList{},
ty.void_, create<ast::BlockStatement>(Source{}),
ast::FunctionDecorationList{
create<ast::StageDecoration>(
ast::PipelineStage::kVertex, Source{}),
}));
auto entry_sym = mod->RegisterSymbol("entry");
mod->AddFunction(create<ast::Function>(
Source{}, entry_sym, "entry", ast::VariableList{}, ty.void_,
create<ast::BlockStatement>(Source{}),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex,
Source{}),
}));
mod->AddFunction(
create<ast::Function>(Source{}, "non_entry_b", ast::VariableList{},
ty.void_, create<ast::BlockStatement>(Source{}),
ast::FunctionDecorationList{}));
auto b_sym = mod->RegisterSymbol("non_entry_b");
mod->AddFunction(create<ast::Function>(
Source{}, b_sym, "non_entry_b", ast::VariableList{}, ty.void_,
create<ast::BlockStatement>(Source{}),
ast::FunctionDecorationList{}));
}
};
@ -147,7 +156,7 @@ TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) {
ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
auto* expected = R"(Module{
auto expected = R"(Module{
Variable{
Decorations{
BuiltinDecoration{pointsize}
@ -156,11 +165,13 @@ TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) {
out
__f32
}
Function non_entry_a -> __void
Function )" + result.module.RegisterSymbol("non_entry_a").to_str() +
R"( -> __void
()
{
}
Function entry -> __void
Function )" + result.module.RegisterSymbol("entry").to_str() +
R"( -> __void
StageDecoration{vertex}
()
{
@ -169,7 +180,8 @@ TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) {
ScalarConstructor[__f32]{1.000000}
}
}
Function non_entry_b -> __void
Function )" + result.module.RegisterSymbol("non_entry_b").to_str() +
R"( -> __void
()
{
}
@ -181,8 +193,9 @@ TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) {
TEST_F(EmitVertexPointSizeTest, NonVertexStage) {
struct Builder : ModuleBuilder {
void Build() override {
auto frag_sym = mod->RegisterSymbol("fragment_entry");
auto* fragment_entry = create<ast::Function>(
Source{}, "fragment_entry", ast::VariableList{}, ty.void_,
Source{}, frag_sym, "fragment_entry", ast::VariableList{}, ty.void_,
create<ast::BlockStatement>(Source{}),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment,
@ -190,13 +203,14 @@ TEST_F(EmitVertexPointSizeTest, NonVertexStage) {
});
mod->AddFunction(fragment_entry);
auto* compute_entry =
create<ast::Function>(Source{}, "compute_entry", ast::VariableList{},
ty.void_, create<ast::BlockStatement>(Source{}),
ast::FunctionDecorationList{
create<ast::StageDecoration>(
ast::PipelineStage::kCompute, Source{}),
});
auto comp_sym = mod->RegisterSymbol("compute_entry");
auto* compute_entry = create<ast::Function>(
Source{}, comp_sym, "compute_entry", ast::VariableList{}, ty.void_,
create<ast::BlockStatement>(Source{}),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute,
Source{}),
});
mod->AddFunction(compute_entry);
}
};
@ -205,13 +219,15 @@ TEST_F(EmitVertexPointSizeTest, NonVertexStage) {
ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
auto* expected = R"(Module{
Function fragment_entry -> __void
auto expected = R"(Module{
Function )" + result.module.RegisterSymbol("fragment_entry").to_str() +
R"( -> __void
StageDecoration{fragment}
()
{
}
Function compute_entry -> __void
Function )" + result.module.RegisterSymbol("compute_entry").to_str() +
R"( -> __void
StageDecoration{compute}
()
{

View File

@ -169,9 +169,9 @@ Transform::Output FirstIndexOffset::Run(ast::Module* in) {
body->append(ctx.Clone(s));
}
return ctx.mod->create<ast::Function>(
ctx.Clone(func->source()), func->name(), ctx.Clone(func->params()),
ctx.Clone(func->return_type()), ctx.Clone(body),
ctx.Clone(func->decorations()));
ctx.Clone(func->source()), func->symbol(), func->name(),
ctx.Clone(func->params()), ctx.Clone(func->return_type()),
ctx.Clone(body), ctx.Clone(func->decorations()));
});
in->Clone(&ctx);

View File

@ -58,9 +58,9 @@ struct ModuleBuilder : public ast::BuilderWithModule {
ast::Function* AddFunction(const std::string& name,
ast::VariableList params = {}) {
auto* func = create<ast::Function>(Source{}, name, std::move(params),
ty.u32, create<ast::BlockStatement>(),
ast::FunctionDecorationList());
auto* func = create<ast::Function>(
Source{}, mod->RegisterSymbol(name), name, std::move(params), ty.u32,
create<ast::BlockStatement>(), ast::FunctionDecorationList());
mod->AddFunction(func);
return func;
}
@ -154,7 +154,7 @@ TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex) {
uniform
__struct_TintFirstIndexOffsetData
}
Function test -> __u32
Function tint_symbol_1 -> __u32
()
{
VariableDeclStatement{
@ -229,7 +229,7 @@ TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex) {
uniform
__struct_TintFirstIndexOffsetData
}
Function test -> __u32
Function tint_symbol_1 -> __u32
()
{
VariableDeclStatement{
@ -317,7 +317,7 @@ TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex) {
uniform
__struct_TintFirstIndexOffsetData
}
Function test -> __u32
Function tint_symbol_1 -> __u32
()
{
Return{
@ -389,7 +389,7 @@ TEST_F(FirstIndexOffsetTest, NestedCalls) {
uniform
__struct_TintFirstIndexOffsetData
}
Function func1 -> __u32
Function tint_symbol_1 -> __u32
()
{
VariableDeclStatement{
@ -415,7 +415,7 @@ TEST_F(FirstIndexOffsetTest, NestedCalls) {
}
}
}
Function func2 -> __u32
Function tint_symbol_2 -> __u32
()
{
Return{

View File

@ -84,8 +84,8 @@ Transform::Output VertexPulling::Run(ast::Module* in) {
}
// Find entry point
auto* func = mod->FindFunctionByNameAndStage(cfg.entry_point_name,
ast::PipelineStage::kVertex);
auto* func = mod->FindFunctionBySymbolAndStage(
mod->GetSymbol(cfg.entry_point_name), ast::PipelineStage::kVertex);
if (func == nullptr) {
diag::Diagnostic err;
err.severity = diag::Severity::Error;
@ -94,9 +94,6 @@ Transform::Output VertexPulling::Run(ast::Module* in) {
return out;
}
// Save the vertex function
auto* vertex_func = mod->FindFunctionByName(func->name());
// TODO(idanr): Need to check shader locations in descriptor cover all
// attributes
@ -108,7 +105,7 @@ Transform::Output VertexPulling::Run(ast::Module* in) {
state.FindOrInsertInstanceIndexIfUsed();
state.ConvertVertexInputVariablesToPrivate();
state.AddVertexStorageBuffers();
state.AddVertexPullingPreamble(vertex_func);
state.AddVertexPullingPreamble(func);
return out;
}

View File

@ -47,8 +47,8 @@ class VertexPullingHelper {
// Create basic module with an entry point and vertex function
void InitBasicModule() {
auto* func = create<ast::Function>(
Source{}, "main", ast::VariableList{}, mod_->create<ast::type::Void>(),
create<ast::BlockStatement>(),
Source{}, mod_->RegisterSymbol("main"), "main", ast::VariableList{},
mod_->create<ast::type::Void>(), create<ast::BlockStatement>(),
ast::FunctionDecorationList{create<ast::StageDecoration>(
ast::PipelineStage::kVertex, Source{})});
mod()->AddFunction(func);
@ -134,8 +134,8 @@ TEST_F(VertexPullingTest, Error_InvalidEntryPoint) {
TEST_F(VertexPullingTest, Error_EntryPointWrongStage) {
auto* func = create<ast::Function>(
Source{}, "main", ast::VariableList{}, mod()->create<ast::type::Void>(),
create<ast::BlockStatement>(),
Source{}, mod()->RegisterSymbol("main"), "main", ast::VariableList{},
mod()->create<ast::type::Void>(), create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -152,7 +152,8 @@ TEST_F(VertexPullingTest, BasicModule) {
InitBasicModule();
InitTransform({});
auto result = manager()->Run(mod());
ASSERT_FALSE(result.diagnostics.contains_errors());
ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
}
TEST_F(VertexPullingTest, OneAttribute) {
@ -164,7 +165,8 @@ TEST_F(VertexPullingTest, OneAttribute) {
InitTransform({{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}});
auto result = manager()->Run(mod());
ASSERT_FALSE(result.diagnostics.contains_errors());
ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
EXPECT_EQ(R"(Module{
TintVertexData Struct{
@ -193,7 +195,8 @@ TEST_F(VertexPullingTest, OneAttribute) {
storage_buffer
__struct_TintVertexData
}
Function main -> __void
Function )" + result.module.GetSymbol("main").to_str() +
R"( -> __void
StageDecoration{vertex}
()
{
@ -250,7 +253,8 @@ TEST_F(VertexPullingTest, OneInstancedAttribute) {
{{{4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 0}}}}});
auto result = manager()->Run(mod());
ASSERT_FALSE(result.diagnostics.contains_errors());
ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
EXPECT_EQ(R"(Module{
TintVertexData Struct{
@ -279,7 +283,8 @@ TEST_F(VertexPullingTest, OneInstancedAttribute) {
storage_buffer
__struct_TintVertexData
}
Function main -> __void
Function )" + result.module.GetSymbol("main").to_str() +
R"( -> __void
StageDecoration{vertex}
()
{
@ -336,7 +341,8 @@ TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet) {
transform()->SetPullingBufferBindingSet(5);
auto result = manager()->Run(mod());
ASSERT_FALSE(result.diagnostics.contains_errors());
ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
EXPECT_EQ(R"(Module{
TintVertexData Struct{
@ -365,7 +371,8 @@ TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet) {
storage_buffer
__struct_TintVertexData
}
Function main -> __void
Function )" + result.module.GetSymbol("main").to_str() +
R"( -> __void
StageDecoration{vertex}
()
{
@ -451,7 +458,8 @@ TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) {
{4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 1}}}}});
auto result = manager()->Run(mod());
ASSERT_FALSE(result.diagnostics.contains_errors());
ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
EXPECT_EQ(R"(Module{
TintVertexData Struct{
@ -502,7 +510,8 @@ TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) {
storage_buffer
__struct_TintVertexData
}
Function main -> __void
Function )" + result.module.GetSymbol("main").to_str() +
R"( -> __void
StageDecoration{vertex}
()
{
@ -592,7 +601,8 @@ TEST_F(VertexPullingTest, TwoAttributesSameBuffer) {
{{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}});
auto result = manager()->Run(mod());
ASSERT_FALSE(result.diagnostics.contains_errors());
ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
EXPECT_EQ(R"(Module{
TintVertexData Struct{
@ -626,7 +636,8 @@ TEST_F(VertexPullingTest, TwoAttributesSameBuffer) {
storage_buffer
__struct_TintVertexData
}
Function main -> __void
Function )" + result.module.GetSymbol("main").to_str() +
R"( -> __void
StageDecoration{vertex}
()
{
@ -778,7 +789,8 @@ TEST_F(VertexPullingTest, FloatVectorAttributes) {
{16, InputStepMode::kVertex, {{VertexFormat::kVec4F32, 0, 2}}}}});
auto result = manager()->Run(mod());
ASSERT_FALSE(result.diagnostics.contains_errors());
ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
EXPECT_EQ(R"(Module{
TintVertexData Struct{
@ -835,7 +847,8 @@ TEST_F(VertexPullingTest, FloatVectorAttributes) {
storage_buffer
__struct_TintVertexData
}
Function main -> __void
Function )" + result.module.GetSymbol("main").to_str() +
R"( -> __void
StageDecoration{vertex}
()
{

View File

@ -122,7 +122,7 @@ bool TypeDeterminer::Determine() {
continue;
}
for (const auto& callee : caller_to_callee_[func->name()]) {
set_entry_points(callee, func->name());
set_entry_points(callee, func->symbol());
}
}
@ -130,11 +130,11 @@ bool TypeDeterminer::Determine() {
}
void TypeDeterminer::set_entry_points(const std::string& fn_name,
const std::string& ep_name) {
name_to_function_[fn_name]->add_ancestor_entry_point(ep_name);
Symbol ep_sym) {
name_to_function_[fn_name]->add_ancestor_entry_point(ep_sym);
for (const auto& callee : caller_to_callee_[fn_name]) {
set_entry_points(callee, ep_name);
set_entry_points(callee, ep_sym);
}
}
@ -389,7 +389,8 @@ bool TypeDeterminer::DetermineCall(ast::CallExpression* expr) {
if (current_function_) {
caller_to_callee_[current_function_->name()].push_back(ident->name());
auto* callee_func = mod_->FindFunctionByName(ident->name());
auto* callee_func =
mod_->FindFunctionBySymbol(mod_->GetSymbol(ident->name()));
if (callee_func == nullptr) {
set_error(expr->source(),
"unable to find called function: " + ident->name());

View File

@ -113,7 +113,7 @@ class TypeDeterminer {
private:
void set_error(const Source& src, const std::string& msg);
void set_referenced_from_function_if_needed(ast::Variable* var, bool local);
void set_entry_points(const std::string& fn_name, const std::string& ep_name);
void set_entry_points(const std::string& fn_name, Symbol ep_sym);
bool DetermineArrayAccessor(ast::ArrayAccessorExpression* expr);
bool DetermineBinary(ast::BinaryExpression* expr);

View File

@ -341,9 +341,9 @@ TEST_F(TypeDeterminerTest, Stmt_Call) {
ast::type::F32 f32;
ast::VariableList params;
auto* func = create<ast::Function>(Source{}, "my_func", params, &f32,
create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
auto* func = create<ast::Function>(
Source{}, mod->RegisterSymbol("my_func"), "my_func", params, &f32,
create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod->AddFunction(func);
// Register the function
@ -372,15 +372,16 @@ TEST_F(TypeDeterminerTest, Stmt_Call_undeclared) {
auto* main_body = create<ast::BlockStatement>();
main_body->append(create<ast::CallStatement>(call_expr));
main_body->append(create<ast::ReturnStatement>(Source{}));
auto* func_main =
create<ast::Function>(Source{}, "main", params0, &f32, main_body,
ast::FunctionDecorationList{});
auto* func_main = create<ast::Function>(Source{}, mod->RegisterSymbol("main"),
"main", params0, &f32, main_body,
ast::FunctionDecorationList{});
mod->AddFunction(func_main);
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(Source{}, "func", params0, &f32, body,
ast::FunctionDecorationList{});
auto* func =
create<ast::Function>(Source{}, mod->RegisterSymbol("func"), "func",
params0, &f32, body, ast::FunctionDecorationList{});
mod->AddFunction(func);
EXPECT_FALSE(td()->Determine()) << td()->error();
@ -639,9 +640,9 @@ TEST_F(TypeDeterminerTest, Expr_Call) {
ast::type::F32 f32;
ast::VariableList params;
auto* func = create<ast::Function>(Source{}, "my_func", params, &f32,
create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
auto* func = create<ast::Function>(
Source{}, mod->RegisterSymbol("my_func"), "my_func", params, &f32,
create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod->AddFunction(func);
// Register the function
@ -659,9 +660,9 @@ TEST_F(TypeDeterminerTest, Expr_Call_WithParams) {
ast::type::F32 f32;
ast::VariableList params;
auto* func = create<ast::Function>(Source{}, "my_func", params, &f32,
create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
auto* func = create<ast::Function>(
Source{}, mod->RegisterSymbol("my_func"), "my_func", params, &f32,
create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod->AddFunction(func);
// Register the function
@ -809,8 +810,8 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable_Const) {
body->append(create<ast::AssignmentStatement>(
my_var, create<ast::IdentifierExpression>("my_var")));
ast::Function f(Source{}, "my_func", {}, &f32, body,
ast::FunctionDecorationList{});
ast::Function f(Source{}, mod->RegisterSymbol("my_func"), "my_func", {}, &f32,
body, ast::FunctionDecorationList{});
EXPECT_TRUE(td()->DetermineFunction(&f));
@ -836,8 +837,8 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable) {
body->append(create<ast::AssignmentStatement>(
my_var, create<ast::IdentifierExpression>("my_var")));
ast::Function f(Source{}, "my_func", {}, &f32, body,
ast::FunctionDecorationList{});
ast::Function f(Source{}, mod->RegisterSymbol("myfunc"), "my_func", {}, &f32,
body, ast::FunctionDecorationList{});
EXPECT_TRUE(td()->DetermineFunction(&f));
@ -868,8 +869,8 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_Function_Ptr) {
body->append(create<ast::AssignmentStatement>(
my_var, create<ast::IdentifierExpression>("my_var")));
ast::Function f(Source{}, "my_func", {}, &f32, body,
ast::FunctionDecorationList{});
ast::Function f(Source{}, mod->RegisterSymbol("my_func"), "my_func", {}, &f32,
body, ast::FunctionDecorationList{});
EXPECT_TRUE(td()->DetermineFunction(&f));
@ -885,9 +886,9 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_Function) {
ast::type::F32 f32;
ast::VariableList params;
auto* func = create<ast::Function>(Source{}, "my_func", params, &f32,
create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
auto* func = create<ast::Function>(
Source{}, mod->RegisterSymbol("my_func"), "my_func", params, &f32,
create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod->AddFunction(func);
// Register the function
@ -968,8 +969,9 @@ TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables) {
body->append(create<ast::AssignmentStatement>(
create<ast::IdentifierExpression>("priv_var"),
create<ast::IdentifierExpression>("priv_var")));
auto* func = create<ast::Function>(Source{}, "my_func", params, &f32, body,
ast::FunctionDecorationList{});
auto* func =
create<ast::Function>(Source{}, mod->RegisterSymbol("my_func"), "my_func",
params, &f32, body, ast::FunctionDecorationList{});
mod->AddFunction(func);
@ -1049,8 +1051,9 @@ TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables_SubFunction) {
create<ast::IdentifierExpression>("priv_var"),
create<ast::IdentifierExpression>("priv_var")));
ast::VariableList params;
auto* func = create<ast::Function>(Source{}, "my_func", params, &f32, body,
ast::FunctionDecorationList{});
auto* func =
create<ast::Function>(Source{}, mod->RegisterSymbol("my_func"), "my_func",
params, &f32, body, ast::FunctionDecorationList{});
mod->AddFunction(func);
@ -1059,8 +1062,9 @@ TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables_SubFunction) {
create<ast::IdentifierExpression>("out_var"),
create<ast::CallExpression>(create<ast::IdentifierExpression>("my_func"),
ast::ExpressionList{})));
auto* func2 = create<ast::Function>(Source{}, "func", params, &f32, body,
ast::FunctionDecorationList{});
auto* func2 =
create<ast::Function>(Source{}, mod->RegisterSymbol("func"), "func",
params, &f32, body, ast::FunctionDecorationList{});
mod->AddFunction(func2);
@ -1096,8 +1100,9 @@ TEST_F(TypeDeterminerTest, Function_NotRegisterFunctionVariable) {
create<ast::FloatLiteral>(&f32, 1.f))));
ast::VariableList params;
auto* func = create<ast::Function>(Source{}, "my_func", params, &f32, body,
ast::FunctionDecorationList{});
auto* func =
create<ast::Function>(Source{}, mod->RegisterSymbol("my_func"), "my_func",
params, &f32, body, ast::FunctionDecorationList{});
mod->AddFunction(func);
@ -2636,8 +2641,9 @@ TEST_F(TypeDeterminerTest, StorageClass_SetsIfMissing) {
auto* body = create<ast::BlockStatement>();
body->append(stmt);
auto* func = create<ast::Function>(Source{}, "func", ast::VariableList{},
&i32, body, ast::FunctionDecorationList{});
auto* func = create<ast::Function>(Source{}, mod->RegisterSymbol("func"),
"func", ast::VariableList{}, &i32, body,
ast::FunctionDecorationList{});
mod->AddFunction(func);
@ -2660,8 +2666,9 @@ TEST_F(TypeDeterminerTest, StorageClass_DoesNotSetOnConst) {
auto* body = create<ast::BlockStatement>();
body->append(stmt);
auto* func = create<ast::Function>(Source{}, "func", ast::VariableList{},
&i32, body, ast::FunctionDecorationList{});
auto* func = create<ast::Function>(Source{}, mod->RegisterSymbol("func"),
"func", ast::VariableList{}, &i32, body,
ast::FunctionDecorationList{});
mod->AddFunction(func);
@ -2684,8 +2691,9 @@ TEST_F(TypeDeterminerTest, StorageClass_NonFunctionClassError) {
auto* body = create<ast::BlockStatement>();
body->append(stmt);
auto* func = create<ast::Function>(Source{}, "func", ast::VariableList{},
&i32, body, ast::FunctionDecorationList{});
auto* func = create<ast::Function>(Source{}, mod->RegisterSymbol("func"),
"func", ast::VariableList{}, &i32, body,
ast::FunctionDecorationList{});
mod->AddFunction(func);
@ -4857,24 +4865,27 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) {
ast::VariableList params;
auto* body = create<ast::BlockStatement>();
auto* func_b = create<ast::Function>(Source{}, "b", params, &f32, body,
ast::FunctionDecorationList{});
auto* func_b =
create<ast::Function>(Source{}, mod->RegisterSymbol("b"), "b", params,
&f32, body, ast::FunctionDecorationList{});
body = create<ast::BlockStatement>();
body->append(create<ast::AssignmentStatement>(
create<ast::IdentifierExpression>("second"),
create<ast::CallExpression>(create<ast::IdentifierExpression>("b"),
ast::ExpressionList{})));
auto* func_c = create<ast::Function>(Source{}, "c", params, &f32, body,
ast::FunctionDecorationList{});
auto* func_c =
create<ast::Function>(Source{}, mod->RegisterSymbol("c"), "c", params,
&f32, body, ast::FunctionDecorationList{});
body = create<ast::BlockStatement>();
body->append(create<ast::AssignmentStatement>(
create<ast::IdentifierExpression>("first"),
create<ast::CallExpression>(create<ast::IdentifierExpression>("c"),
ast::ExpressionList{})));
auto* func_a = create<ast::Function>(Source{}, "a", params, &f32, body,
ast::FunctionDecorationList{});
auto* func_a =
create<ast::Function>(Source{}, mod->RegisterSymbol("a"), "a", params,
&f32, body, ast::FunctionDecorationList{});
body = create<ast::BlockStatement>();
body->append(create<ast::AssignmentStatement>(
@ -4886,7 +4897,7 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) {
create<ast::CallExpression>(create<ast::IdentifierExpression>("b"),
ast::ExpressionList{})));
auto* ep_1 = create<ast::Function>(
Source{}, "ep_1", params, &f32, body,
Source{}, mod->RegisterSymbol("ep_1"), "ep_1", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@ -4897,7 +4908,7 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) {
create<ast::CallExpression>(create<ast::IdentifierExpression>("c"),
ast::ExpressionList{})));
auto* ep_2 = create<ast::Function>(
Source{}, "ep_2", params, &f32, body,
Source{}, mod->RegisterSymbol("ep_2"), "ep_2", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@ -4954,17 +4965,17 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) {
const auto& b_eps = func_b->ancestor_entry_points();
ASSERT_EQ(2u, b_eps.size());
EXPECT_EQ("ep_1", b_eps[0]);
EXPECT_EQ("ep_2", b_eps[1]);
EXPECT_EQ(mod->RegisterSymbol("ep_1"), b_eps[0]);
EXPECT_EQ(mod->RegisterSymbol("ep_2"), b_eps[1]);
const auto& a_eps = func_a->ancestor_entry_points();
ASSERT_EQ(1u, a_eps.size());
EXPECT_EQ("ep_1", a_eps[0]);
EXPECT_EQ(mod->RegisterSymbol("ep_1"), a_eps[0]);
const auto& c_eps = func_c->ancestor_entry_points();
ASSERT_EQ(2u, c_eps.size());
EXPECT_EQ("ep_1", c_eps[0]);
EXPECT_EQ("ep_2", c_eps[1]);
EXPECT_EQ(mod->RegisterSymbol("ep_1"), c_eps[0]);
EXPECT_EQ(mod->RegisterSymbol("ep_2"), c_eps[1]);
EXPECT_TRUE(ep_1->ancestor_entry_points().empty());
EXPECT_TRUE(ep_2->ancestor_entry_points().empty());

View File

@ -54,7 +54,8 @@ TEST_F(ValidateFunctionTest, VoidFunctionEndWithoutReturnStatement_Pass) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::VariableDeclStatement>(var));
auto* func = create<ast::Function>(
Source{Source::Location{12, 34}}, "func", params, &void_type, body,
Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func",
params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@ -71,8 +72,8 @@ TEST_F(ValidateFunctionTest,
ast::type::Void void_type;
ast::VariableList params;
auto* func = create<ast::Function>(
Source{Source::Location{12, 34}}, "func", params, &void_type,
create<ast::BlockStatement>(),
Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func",
params, &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@ -100,9 +101,9 @@ TEST_F(ValidateFunctionTest, FunctionEndWithoutReturnStatement_Fail) {
ast::type::Void void_type;
auto* body = create<ast::BlockStatement>();
body->append(create<ast::VariableDeclStatement>(var));
auto* func =
create<ast::Function>(Source{Source::Location{12, 34}}, "func", params,
&i32, body, ast::FunctionDecorationList{});
auto* func = create<ast::Function>(
Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func",
params, &i32, body, ast::FunctionDecorationList{});
mod()->AddFunction(func);
EXPECT_TRUE(td()->Determine()) << td()->error();
@ -117,8 +118,9 @@ TEST_F(ValidateFunctionTest, FunctionEndWithoutReturnStatementEmptyBody_Fail) {
ast::type::I32 i32;
ast::VariableList params;
auto* func = create<ast::Function>(
Source{Source::Location{12, 34}}, "func", params, &i32,
create<ast::BlockStatement>(), ast::FunctionDecorationList{});
Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func",
params, &i32, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
mod()->AddFunction(func);
EXPECT_TRUE(td()->Determine()) << td()->error();
@ -136,7 +138,7 @@ TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementType_Pass) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "func", params, &void_type, body,
Source{}, mod()->RegisterSymbol("func"), "func", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@ -157,7 +159,8 @@ TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementType_fail) {
body->append(create<ast::ReturnStatement>(Source{Source::Location{12, 34}},
return_expr));
auto* func = create<ast::Function>(Source{}, "func", params, &void_type, body,
auto* func = create<ast::Function>(Source{}, mod()->RegisterSymbol("func"),
"func", params, &void_type, body,
ast::FunctionDecorationList{});
mod()->AddFunction(func);
@ -180,8 +183,9 @@ TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementTypeF32_fail) {
body->append(create<ast::ReturnStatement>(Source{Source::Location{12, 34}},
return_expr));
auto* func = create<ast::Function>(Source{}, "func", params, &f32, body,
ast::FunctionDecorationList{});
auto* func =
create<ast::Function>(Source{}, mod()->RegisterSymbol("func"), "func",
params, &f32, body, ast::FunctionDecorationList{});
mod()->AddFunction(func);
EXPECT_TRUE(td()->Determine()) << td()->error();
@ -204,8 +208,9 @@ TEST_F(ValidateFunctionTest, FunctionNamesMustBeUnique_fail) {
create<ast::SintLiteral>(&i32, 2));
body->append(create<ast::ReturnStatement>(Source{}, return_expr));
auto* func = create<ast::Function>(Source{}, "func", params, &i32, body,
ast::FunctionDecorationList{});
auto* func =
create<ast::Function>(Source{}, mod()->RegisterSymbol("func"), "func",
params, &i32, body, ast::FunctionDecorationList{});
ast::VariableList params_copy;
auto* body_copy = create<ast::BlockStatement>();
@ -213,9 +218,9 @@ TEST_F(ValidateFunctionTest, FunctionNamesMustBeUnique_fail) {
create<ast::SintLiteral>(&i32, 2));
body_copy->append(create<ast::ReturnStatement>(Source{}, return_expr_copy));
auto* func_copy = create<ast::Function>(Source{Source::Location{12, 34}},
"func", params_copy, &i32, body_copy,
ast::FunctionDecorationList{});
auto* func_copy = create<ast::Function>(
Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func",
params_copy, &i32, body_copy, ast::FunctionDecorationList{});
mod()->AddFunction(func);
mod()->AddFunction(func_copy);
@ -237,7 +242,8 @@ TEST_F(ValidateFunctionTest, RecursionIsNotAllowed_Fail) {
auto* body0 = create<ast::BlockStatement>();
body0->append(create<ast::CallStatement>(call_expr));
body0->append(create<ast::ReturnStatement>(Source{}));
auto* func0 = create<ast::Function>(Source{}, "func", params0, &f32, body0,
auto* func0 = create<ast::Function>(Source{}, mod()->RegisterSymbol("func"),
"func", params0, &f32, body0,
ast::FunctionDecorationList{});
mod()->AddFunction(func0);
@ -268,7 +274,8 @@ TEST_F(ValidateFunctionTest, RecursionIsNotAllowedExpr_Fail) {
create<ast::SintLiteral>(&i32, 2));
body0->append(create<ast::ReturnStatement>(Source{}, return_expr));
auto* func0 = create<ast::Function>(Source{}, "func", params0, &i32, body0,
auto* func0 = create<ast::Function>(Source{}, mod()->RegisterSymbol("func"),
"func", params0, &i32, body0,
ast::FunctionDecorationList{});
mod()->AddFunction(func0);
@ -288,7 +295,8 @@ TEST_F(ValidateFunctionTest, Function_WithPipelineStage_NotVoid_Fail) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}, return_expr));
auto* func = create<ast::Function>(
Source{Source::Location{12, 34}}, "vtx_main", params, &i32, body,
Source{Source::Location{12, 34}}, mod()->RegisterSymbol("vtx_main"),
"vtx_main", params, &i32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@ -317,7 +325,8 @@ TEST_F(ValidateFunctionTest, Function_WithPipelineStage_WithParams_Fail) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{Source::Location{12, 34}}, "vtx_func", params, &void_type, body,
Source{Source::Location{12, 34}}, mod()->RegisterSymbol("vtx_func"),
"vtx_func", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@ -339,7 +348,8 @@ TEST_F(ValidateFunctionTest, PipelineStage_MustBeUnique_Fail) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{Source::Location{12, 34}}, "main", params, &void_type, body,
Source{Source::Location{12, 34}}, mod()->RegisterSymbol("main"), "main",
params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
@ -361,7 +371,8 @@ TEST_F(ValidateFunctionTest, OnePipelineStageFunctionMustBePresent_Pass) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "vtx_func", params, &void_type, body,
Source{}, mod()->RegisterSymbol("vtx_func"), "vtx_func", params,
&void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@ -377,8 +388,9 @@ TEST_F(ValidateFunctionTest, OnePipelineStageFunctionMustBePresent_Fail) {
ast::VariableList params;
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(Source{}, "vtx_func", params, &void_type,
body, ast::FunctionDecorationList{});
auto* func = create<ast::Function>(
Source{}, mod()->RegisterSymbol("vtx_func"), "vtx_func", params,
&void_type, body, ast::FunctionDecorationList{});
mod()->AddFunction(func);
EXPECT_TRUE(td()->Determine()) << td()->error();

View File

@ -332,7 +332,8 @@ TEST_F(ValidatorTest, UsingUndefinedVariableGlobalVariable_Fail) {
body->append(create<ast::AssignmentStatement>(
Source{Source::Location{12, 34}}, lhs, rhs));
auto* func = create<ast::Function>(Source{}, "my_func", params, &f32, body,
auto* func = create<ast::Function>(Source{}, mod()->RegisterSymbol("my_func"),
"my_func", params, &f32, body,
ast::FunctionDecorationList{});
mod()->AddFunction(func);
@ -370,7 +371,8 @@ TEST_F(ValidatorTest, UsingUndefinedVariableGlobalVariable_Pass) {
Source{Source::Location{12, 34}}, lhs, rhs));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "my_func", params, &void_type, body,
Source{}, mod()->RegisterSymbol("my_func"), "my_func", params, &void_type,
body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@ -587,8 +589,9 @@ TEST_F(ValidatorTest, GlobalVariableFunctionVariableNotUnique_Fail) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::VariableDeclStatement>(
Source{Source::Location{12, 34}}, var));
auto* func = create<ast::Function>(Source{}, "my_func", params, &void_type,
body, ast::FunctionDecorationList{});
auto* func = create<ast::Function>(Source{}, mod()->RegisterSymbol("my_func"),
"my_func", params, &void_type, body,
ast::FunctionDecorationList{});
mod()->AddFunction(func);
@ -631,8 +634,9 @@ TEST_F(ValidatorTest, RedeclaredIndentifier_Fail) {
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::VariableDeclStatement>(
Source{Source::Location{12, 34}}, var_a_float));
auto* func = create<ast::Function>(Source{}, "my_func", params, &void_type,
body, ast::FunctionDecorationList{});
auto* func = create<ast::Function>(Source{}, mod()->RegisterSymbol("my_func"),
"my_func", params, &void_type, body,
ast::FunctionDecorationList{});
mod()->AddFunction(func);
@ -759,8 +763,9 @@ TEST_F(ValidatorTest, RedeclaredIdentifierDifferentFunctions_Pass) {
body0->append(create<ast::VariableDeclStatement>(
Source{Source::Location{12, 34}}, var0));
body0->append(create<ast::ReturnStatement>(Source{}));
auto* func0 = create<ast::Function>(Source{}, "func0", params0, &void_type,
body0, ast::FunctionDecorationList{});
auto* func0 = create<ast::Function>(Source{}, mod()->RegisterSymbol("func0"),
"func0", params0, &void_type, body0,
ast::FunctionDecorationList{});
ast::VariableList params1;
auto* body1 = create<ast::BlockStatement>();
@ -768,7 +773,8 @@ TEST_F(ValidatorTest, RedeclaredIdentifierDifferentFunctions_Pass) {
Source{Source::Location{13, 34}}, var1));
body1->append(create<ast::ReturnStatement>(Source{}));
auto* func1 = create<ast::Function>(
Source{}, "func1", params1, &void_type, body1,
Source{}, mod()->RegisterSymbol("func1"), "func1", params1, &void_type,
body1,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});

View File

@ -206,8 +206,9 @@ TEST_F(ValidatorTypeTest, RuntimeArrayInFunction_Fail) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::VariableDeclStatement>(
Source{Source::Location{12, 34}}, var));
auto* func = create<ast::Function>(
Source{}, "func", params, &void_type, body,
Source{}, mod()->RegisterSymbol("func"), "func", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});

View File

@ -196,15 +196,15 @@ std::string GeneratorImpl::current_ep_var_name(VarType type) {
std::string name = "";
switch (type) {
case VarType::kIn: {
auto in_it = ep_name_to_in_data_.find(current_ep_name_);
if (in_it != ep_name_to_in_data_.end()) {
auto in_it = ep_sym_to_in_data_.find(current_ep_sym_.value());
if (in_it != ep_sym_to_in_data_.end()) {
name = in_it->second.var_name;
}
break;
}
case VarType::kOut: {
auto outit = ep_name_to_out_data_.find(current_ep_name_);
if (outit != ep_name_to_out_data_.end()) {
auto outit = ep_sym_to_out_data_.find(current_ep_sym_.value());
if (outit != ep_sym_to_out_data_.end()) {
name = outit->second.var_name;
}
break;
@ -668,12 +668,14 @@ bool GeneratorImpl::EmitCall(std::ostream& pre,
}
auto name = ident->name();
auto it = ep_func_name_remapped_.find(current_ep_name_ + "_" + name);
auto caller_sym = module_->GetSymbol(name);
auto it = ep_func_name_remapped_.find(current_ep_sym_.to_str() + "_" +
caller_sym.to_str());
if (it != ep_func_name_remapped_.end()) {
name = it->second;
}
auto* func = module_->FindFunctionByName(ident->name());
auto* func = module_->FindFunctionBySymbol(module_->GetSymbol(ident->name()));
if (func == nullptr) {
error_ = "Unable to find function: " + name;
return false;
@ -1189,15 +1191,15 @@ bool GeneratorImpl::EmitFunction(std::ostream& out, ast::Function* func) {
has_referenced_var_needing_struct(func);
if (emit_duplicate_functions) {
for (const auto& ep_name : func->ancestor_entry_points()) {
if (!EmitFunctionInternal(out, func, emit_duplicate_functions, ep_name)) {
for (const auto& ep_sym : func->ancestor_entry_points()) {
if (!EmitFunctionInternal(out, func, emit_duplicate_functions, ep_sym)) {
return false;
}
out << std::endl;
}
} else {
// Emit as non-duplicated
if (!EmitFunctionInternal(out, func, false, "")) {
if (!EmitFunctionInternal(out, func, false, Symbol())) {
return false;
}
out << std::endl;
@ -1209,8 +1211,8 @@ bool GeneratorImpl::EmitFunction(std::ostream& out, ast::Function* func) {
bool GeneratorImpl::EmitFunctionInternal(std::ostream& out,
ast::Function* func,
bool emit_duplicate_functions,
const std::string& ep_name) {
auto name = func->name();
Symbol ep_sym) {
auto name = func->symbol().to_str();
if (!EmitType(out, func->return_type(), "")) {
return false;
@ -1219,10 +1221,15 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out,
out << " ";
if (emit_duplicate_functions) {
name = generate_name(name + "_" + ep_name);
ep_func_name_remapped_[ep_name + "_" + func->name()] = name;
auto func_name = name;
auto ep_name = ep_sym.to_str();
// TODO(dsinclair): The SymbolToName should go away and just use
// to_str() here when the conversion is complete.
name = generate_name(func->name() + "_" + module_->SymbolToName(ep_sym));
ep_func_name_remapped_[ep_name + "_" + func_name] = name;
} else {
name = namer_.NameFor(name);
// TODO(dsinclair): this should be updated to a remapped name
name = namer_.NameFor(func->name());
}
out << name << "(";
@ -1234,15 +1241,15 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out,
//
// We emit both of them if they're there regardless of if they're both used.
if (emit_duplicate_functions) {
auto in_it = ep_name_to_in_data_.find(ep_name);
if (in_it != ep_name_to_in_data_.end()) {
auto in_it = ep_sym_to_in_data_.find(ep_sym.value());
if (in_it != ep_sym_to_in_data_.end()) {
out << "in " << in_it->second.struct_name << " "
<< in_it->second.var_name;
first = false;
}
auto outit = ep_name_to_out_data_.find(ep_name);
if (outit != ep_name_to_out_data_.end()) {
auto outit = ep_sym_to_out_data_.find(ep_sym.value());
if (outit != ep_sym_to_out_data_.end()) {
if (!first) {
out << ", ";
}
@ -1269,13 +1276,13 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out,
out << ") ";
current_ep_name_ = ep_name;
current_ep_sym_ = ep_sym;
if (!EmitBlockAndNewline(out, func->body())) {
return false;
}
current_ep_name_ = "";
current_ep_sym_ = Symbol();
return true;
}
@ -1392,7 +1399,7 @@ bool GeneratorImpl::EmitEntryPointData(
auto in_struct_name =
generate_name(func->name() + "_" + kInStructNameSuffix);
auto in_var_name = generate_name(kTintStructInVarPrefix);
ep_name_to_in_data_[func->name()] = {in_struct_name, in_var_name};
ep_sym_to_in_data_[func->symbol().value()] = {in_struct_name, in_var_name};
make_indent(out);
out << "struct " << in_struct_name << " {" << std::endl;
@ -1438,7 +1445,7 @@ bool GeneratorImpl::EmitEntryPointData(
auto outstruct_name =
generate_name(func->name() + "_" + kOutStructNameSuffix);
auto outvar_name = generate_name(kTintStructOutVarPrefix);
ep_name_to_out_data_[func->name()] = {outstruct_name, outvar_name};
ep_sym_to_out_data_[func->symbol().value()] = {outstruct_name, outvar_name};
make_indent(out);
out << "struct " << outstruct_name << " {" << std::endl;
@ -1516,7 +1523,7 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out,
ast::Function* func) {
make_indent(out);
current_ep_name_ = func->name();
current_ep_sym_ = func->symbol();
if (func->pipeline_stage() == ast::PipelineStage::kCompute) {
uint32_t x = 0;
@ -1528,17 +1535,18 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out,
make_indent(out);
}
auto outdata = ep_name_to_out_data_.find(current_ep_name_);
bool has_outdata = outdata != ep_name_to_out_data_.end();
auto outdata = ep_sym_to_out_data_.find(current_ep_sym_.value());
bool has_outdata = outdata != ep_sym_to_out_data_.end();
if (has_outdata) {
out << outdata->second.struct_name;
} else {
out << "void";
}
out << " " << namer_.NameFor(current_ep_name_) << "(";
// TODO(dsinclair): This should output the remapped name
out << " " << namer_.NameFor(module_->SymbolToName(current_ep_sym_)) << "(";
auto in_data = ep_name_to_in_data_.find(current_ep_name_);
if (in_data != ep_name_to_in_data_.end()) {
auto in_data = ep_sym_to_in_data_.find(current_ep_sym_.value());
if (in_data != ep_sym_to_in_data_.end()) {
out << in_data->second.struct_name << " " << in_data->second.var_name;
}
out << ") {" << std::endl;
@ -1563,7 +1571,7 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out,
make_indent(out);
out << "}" << std::endl;
current_ep_name_ = "";
current_ep_sym_ = Symbol();
return true;
}
@ -1966,8 +1974,8 @@ bool GeneratorImpl::EmitReturn(std::ostream& out, ast::ReturnStatement* stmt) {
if (generating_entry_point_) {
out << "return";
auto outdata = ep_name_to_out_data_.find(current_ep_name_);
if (outdata != ep_name_to_out_data_.end()) {
auto outdata = ep_sym_to_out_data_.find(current_ep_sym_.value());
if (outdata != ep_sym_to_out_data_.end()) {
out << " " << outdata->second.var_name;
}
} else if (stmt->has_value()) {

View File

@ -210,12 +210,12 @@ class GeneratorImpl {
/// @param func the function to emit
/// @param emit_duplicate_functions set true if we need to duplicate per entry
/// point
/// @param ep_name the current entry point or blank if none set
/// @param ep_sym the current entry point or symbol::kInvalid if none set
/// @returns true if the function was emitted.
bool EmitFunctionInternal(std::ostream& out,
ast::Function* func,
bool emit_duplicate_functions,
const std::string& ep_name);
Symbol ep_sym);
/// Handles emitting information for an entry point
/// @param out the output stream
/// @param func the entry point
@ -397,12 +397,12 @@ class GeneratorImpl {
Namer namer_;
ast::Module* module_ = nullptr;
std::string current_ep_name_;
Symbol current_ep_sym_;
bool generating_entry_point_ = false;
uint32_t loop_emission_counter_ = 0;
ScopeStack<ast::Variable*> global_variables_;
std::unordered_map<std::string, EntryPointData> ep_name_to_in_data_;
std::unordered_map<std::string, EntryPointData> ep_name_to_out_data_;
std::unordered_map<uint32_t, EntryPointData> ep_sym_to_in_data_;
std::unordered_map<uint32_t, EntryPointData> ep_sym_to_out_data_;
// This maps an input of "<entry_point_name>_<function_name>" to a remapped
// function name. If there is no entry for a given key then function did

View File

@ -613,9 +613,9 @@ TEST_F(HlslGeneratorImplTest_Binary, Call_WithLogical) {
ast::type::Void void_type;
auto* func = create<ast::Function>(Source{}, "foo", ast::VariableList{},
&void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
auto* func = create<ast::Function>(
Source{}, mod.RegisterSymbol("foo"), "foo", ast::VariableList{},
&void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func);
ast::ExpressionList params;

View File

@ -35,9 +35,9 @@ TEST_F(HlslGeneratorImplTest_Call, EmitExpression_Call_WithoutParams) {
auto* id = create<ast::IdentifierExpression>("my_func");
ast::CallExpression call(id, {});
auto* func = create<ast::Function>(Source{}, "my_func", ast::VariableList{},
&void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
auto* func = create<ast::Function>(
Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
&void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func);
ASSERT_TRUE(gen.EmitExpression(pre, out, &call)) << gen.error();
@ -53,9 +53,9 @@ TEST_F(HlslGeneratorImplTest_Call, EmitExpression_Call_WithParams) {
params.push_back(create<ast::IdentifierExpression>("param2"));
ast::CallExpression call(id, params);
auto* func = create<ast::Function>(Source{}, "my_func", ast::VariableList{},
&void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
auto* func = create<ast::Function>(
Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
&void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func);
ASSERT_TRUE(gen.EmitExpression(pre, out, &call)) << gen.error();
@ -71,9 +71,9 @@ TEST_F(HlslGeneratorImplTest_Call, EmitStatement_Call) {
params.push_back(create<ast::IdentifierExpression>("param2"));
ast::CallStatement call(create<ast::CallExpression>(id, params));
auto* func = create<ast::Function>(Source{}, "my_func", ast::VariableList{},
&void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
auto* func = create<ast::Function>(
Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
&void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func);
gen.increment_indent();
ASSERT_TRUE(gen.EmitStatement(out, &call)) << gen.error();

View File

@ -91,7 +91,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
Source{}, "vtx_main", params, &f32, body,
Source{}, mod.RegisterSymbol("vtx_main"), "vtx_main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@ -164,7 +164,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
Source{}, "vtx_main", params, &f32, body,
Source{}, mod.RegisterSymbol("vtx_main"), "vtx_main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@ -237,7 +237,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
Source{}, "main", params, &f32, body,
Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@ -309,7 +309,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
Source{}, "main", params, &f32, body,
Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -378,7 +378,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
Source{}, "main", params, &f32, body,
Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
});
@ -442,7 +442,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
Source{}, "main", params, &f32, body,
Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
});
@ -512,7 +512,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
create<ast::IdentifierExpression>("x"))));
auto* func = create<ast::Function>(
Source{}, "main", params, &void_type, body,
Source{}, mod.RegisterSymbol("main"), "main", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});

View File

@ -57,9 +57,9 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Function) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func =
create<ast::Function>(Source{}, "my_func", ast::VariableList{},
&void_type, body, ast::FunctionDecorationList{});
auto* func = create<ast::Function>(Source{}, mod.RegisterSymbol("my_func"),
"my_func", ast::VariableList{}, &void_type,
body, ast::FunctionDecorationList{});
mod.AddFunction(func);
gen.increment_indent();
@ -77,9 +77,9 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Function_Name_Collision) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func =
create<ast::Function>(Source{}, "GeometryShader", ast::VariableList{},
&void_type, body, ast::FunctionDecorationList{});
auto* func = create<ast::Function>(
Source{}, mod.RegisterSymbol("GeometryShader"), "GeometryShader",
ast::VariableList{}, &void_type, body, ast::FunctionDecorationList{});
mod.AddFunction(func);
gen.increment_indent();
@ -118,8 +118,9 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithParams) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(Source{}, "my_func", params, &void_type,
body, ast::FunctionDecorationList{});
auto* func = create<ast::Function>(Source{}, mod.RegisterSymbol("my_func"),
"my_func", params, &void_type, body,
ast::FunctionDecorationList{});
mod.AddFunction(func);
gen.increment_indent();
@ -174,7 +175,8 @@ TEST_F(HlslGeneratorImplTest_Function,
create<ast::IdentifierExpression>("foo")));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body,
Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -245,7 +247,8 @@ TEST_F(HlslGeneratorImplTest_Function,
create<ast::IdentifierExpression>("x"))));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body,
Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -309,7 +312,8 @@ TEST_F(HlslGeneratorImplTest_Function,
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body,
Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -380,7 +384,8 @@ TEST_F(HlslGeneratorImplTest_Function,
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body,
Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -455,7 +460,8 @@ TEST_F(HlslGeneratorImplTest_Function,
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body,
Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -526,7 +532,8 @@ TEST_F(HlslGeneratorImplTest_Function,
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body,
Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -594,7 +601,8 @@ TEST_F(HlslGeneratorImplTest_Function,
body->append(assign);
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body,
Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -682,8 +690,9 @@ TEST_F(
create<ast::IdentifierExpression>("param")));
body->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("foo")));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
body, ast::FunctionDecorationList{});
auto* sub_func = create<ast::Function>(
Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@ -698,7 +707,7 @@ TEST_F(
expr)));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>(
Source{}, "ep_1", params, &void_type, body,
Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -766,8 +775,9 @@ TEST_F(HlslGeneratorImplTest_Function,
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("param")));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
body, ast::FunctionDecorationList{});
auto* sub_func = create<ast::Function>(
Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@ -782,7 +792,7 @@ TEST_F(HlslGeneratorImplTest_Function,
expr)));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>(
Source{}, "ep_1", params, &void_type, body,
Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -863,8 +873,9 @@ TEST_F(
create<ast::IdentifierExpression>("x"))));
body->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("param")));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
body, ast::FunctionDecorationList{});
auto* sub_func = create<ast::Function>(
Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@ -879,7 +890,7 @@ TEST_F(
expr)));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>(
Source{}, "ep_1", params, &void_type, body,
Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -948,8 +959,9 @@ TEST_F(HlslGeneratorImplTest_Function,
Source{}, create<ast::MemberAccessorExpression>(
create<ast::IdentifierExpression>("coord"),
create<ast::IdentifierExpression>("x"))));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
body, ast::FunctionDecorationList{});
auto* sub_func = create<ast::Function>(
Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@ -971,7 +983,8 @@ TEST_F(HlslGeneratorImplTest_Function,
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body,
Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -1034,8 +1047,9 @@ TEST_F(HlslGeneratorImplTest_Function,
Source{}, create<ast::MemberAccessorExpression>(
create<ast::IdentifierExpression>("coord"),
create<ast::IdentifierExpression>("x"))));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
body, ast::FunctionDecorationList{});
auto* sub_func = create<ast::Function>(
Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@ -1057,7 +1071,8 @@ TEST_F(HlslGeneratorImplTest_Function,
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body,
Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -1122,7 +1137,7 @@ TEST_F(HlslGeneratorImplTest_Function,
body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>(
Source{}, "ep_1", params, &void_type, body,
Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -1152,8 +1167,8 @@ TEST_F(HlslGeneratorImplTest_Function,
ast::type::Void void_type;
auto* func = create<ast::Function>(
Source{}, "GeometryShader", ast::VariableList{}, &void_type,
create<ast::BlockStatement>(),
Source{}, mod.RegisterSymbol("GeometryShader"), "GeometryShader",
ast::VariableList{}, &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -1175,7 +1190,7 @@ TEST_F(HlslGeneratorImplTest_Function,
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "main", params, &void_type, body,
Source{}, mod.RegisterSymbol("main"), "main", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
});
@ -1200,7 +1215,7 @@ TEST_F(HlslGeneratorImplTest_Function,
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "main", params, &void_type, body,
Source{}, mod.RegisterSymbol("main"), "main", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
create<ast::WorkgroupDecoration>(2u, 4u, 6u, Source{}),
@ -1236,8 +1251,9 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(Source{}, "my_func", params, &void_type,
body, ast::FunctionDecorationList{});
auto* func = create<ast::Function>(Source{}, mod.RegisterSymbol("my_func"),
"my_func", params, &void_type, body,
ast::FunctionDecorationList{});
mod.AddFunction(func);
gen.increment_indent();
@ -1317,12 +1333,12 @@ TEST_F(HlslGeneratorImplTest_Function,
auto* body = create<ast::BlockStatement>();
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func =
create<ast::Function>(Source{}, "a", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(
ast::PipelineStage::kCompute, Source{}),
});
auto* func = create<ast::Function>(
Source{}, mod.RegisterSymbol("a"), "a", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute,
Source{}),
});
mod.AddFunction(func);
}
@ -1343,12 +1359,12 @@ TEST_F(HlslGeneratorImplTest_Function,
auto* body = create<ast::BlockStatement>();
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func =
create<ast::Function>(Source{}, "b", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(
ast::PipelineStage::kCompute, Source{}),
});
auto* func = create<ast::Function>(
Source{}, mod.RegisterSymbol("b"), "b", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute,
Source{}),
});
mod.AddFunction(func);
}

View File

@ -29,9 +29,9 @@ using HlslGeneratorImplTest = TestHelper;
TEST_F(HlslGeneratorImplTest, Generate) {
ast::type::Void void_type;
auto* func = create<ast::Function>(Source{}, "my_func", ast::VariableList{},
&void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
auto* func = create<ast::Function>(
Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
&void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func);
ASSERT_TRUE(gen.Generate(out)) << gen.error();

View File

@ -411,15 +411,15 @@ std::string GeneratorImpl::current_ep_var_name(VarType type) {
std::string name = "";
switch (type) {
case VarType::kIn: {
auto in_it = ep_name_to_in_data_.find(current_ep_name_);
if (in_it != ep_name_to_in_data_.end()) {
auto in_it = ep_sym_to_in_data_.find(current_ep_sym_.value());
if (in_it != ep_sym_to_in_data_.end()) {
name = in_it->second.var_name;
}
break;
}
case VarType::kOut: {
auto out_it = ep_name_to_out_data_.find(current_ep_name_);
if (out_it != ep_name_to_out_data_.end()) {
auto out_it = ep_sym_to_out_data_.find(current_ep_sym_.value());
if (out_it != ep_sym_to_out_data_.end()) {
name = out_it->second.var_name;
}
break;
@ -573,12 +573,14 @@ bool GeneratorImpl::EmitCall(ast::CallExpression* expr) {
}
auto name = ident->name();
auto it = ep_func_name_remapped_.find(current_ep_name_ + "_" + name);
auto caller_sym = module_->GetSymbol(name);
auto it = ep_func_name_remapped_.find(current_ep_sym_.to_str() + "_" +
caller_sym.to_str());
if (it != ep_func_name_remapped_.end()) {
name = it->second;
}
auto* func = module_->FindFunctionByName(ident->name());
auto* func = module_->FindFunctionBySymbol(module_->GetSymbol(ident->name()));
if (func == nullptr) {
error_ = "Unable to find function: " + name;
return false;
@ -1026,7 +1028,7 @@ bool GeneratorImpl::EmitEntryPointData(ast::Function* func) {
auto in_struct_name =
generate_name(func->name() + "_" + kInStructNameSuffix);
auto in_var_name = generate_name(kTintStructInVarPrefix);
ep_name_to_in_data_[func->name()] = {in_struct_name, in_var_name};
ep_sym_to_in_data_[func->symbol().value()] = {in_struct_name, in_var_name};
make_indent();
out_ << "struct " << in_struct_name << " {" << std::endl;
@ -1063,7 +1065,8 @@ bool GeneratorImpl::EmitEntryPointData(ast::Function* func) {
auto out_struct_name =
generate_name(func->name() + "_" + kOutStructNameSuffix);
auto out_var_name = generate_name(kTintStructOutVarPrefix);
ep_name_to_out_data_[func->name()] = {out_struct_name, out_var_name};
ep_sym_to_out_data_[func->symbol().value()] = {out_struct_name,
out_var_name};
make_indent();
out_ << "struct " << out_struct_name << " {" << std::endl;
@ -1205,15 +1208,15 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) {
has_referenced_var_needing_struct(func);
if (emit_duplicate_functions) {
for (const auto& ep_name : func->ancestor_entry_points()) {
if (!EmitFunctionInternal(func, emit_duplicate_functions, ep_name)) {
for (const auto& ep_sym : func->ancestor_entry_points()) {
if (!EmitFunctionInternal(func, emit_duplicate_functions, ep_sym)) {
return false;
}
out_ << std::endl;
}
} else {
// Emit as non-duplicated
if (!EmitFunctionInternal(func, false, "")) {
if (!EmitFunctionInternal(func, false, Symbol())) {
return false;
}
out_ << std::endl;
@ -1224,19 +1227,23 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) {
bool GeneratorImpl::EmitFunctionInternal(ast::Function* func,
bool emit_duplicate_functions,
const std::string& ep_name) {
auto name = func->name();
Symbol ep_sym) {
auto name = func->symbol().to_str();
if (!EmitType(func->return_type(), "")) {
return false;
}
out_ << " ";
if (emit_duplicate_functions) {
name = generate_name(name + "_" + ep_name);
ep_func_name_remapped_[ep_name + "_" + func->name()] = name;
auto func_name = name;
auto ep_name = ep_sym.to_str();
// TODO(dsinclair): The SymbolToName should go away and just use
// to_str() here when the conversion is complete.
name = generate_name(func->name() + "_" + module_->SymbolToName(ep_sym));
ep_func_name_remapped_[ep_name + "_" + func_name] = name;
} else {
name = namer_.NameFor(name);
// TODO(dsinclair): this should be updated to a remapped name
name = namer_.NameFor(func->name());
}
out_ << name << "(";
@ -1247,15 +1254,15 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func,
//
// We emit both of them if they're there regardless of if they're both used.
if (emit_duplicate_functions) {
auto in_it = ep_name_to_in_data_.find(ep_name);
if (in_it != ep_name_to_in_data_.end()) {
auto in_it = ep_sym_to_in_data_.find(ep_sym.value());
if (in_it != ep_sym_to_in_data_.end()) {
out_ << "thread " << in_it->second.struct_name << "& "
<< in_it->second.var_name;
first = false;
}
auto out_it = ep_name_to_out_data_.find(ep_name);
if (out_it != ep_name_to_out_data_.end()) {
auto out_it = ep_sym_to_out_data_.find(ep_sym.value());
if (out_it != ep_sym_to_out_data_.end()) {
if (!first) {
out_ << ", ";
}
@ -1337,13 +1344,13 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func,
out_ << ") ";
current_ep_name_ = ep_name;
current_ep_sym_ = ep_sym;
if (!EmitBlockAndNewline(func->body())) {
return false;
}
current_ep_name_ = "";
current_ep_sym_ = Symbol();
return true;
}
@ -1377,25 +1384,25 @@ std::string GeneratorImpl::builtin_to_attribute(ast::Builtin builtin) const {
bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) {
make_indent();
current_ep_name_ = func->name();
current_ep_sym_ = func->symbol();
EmitStage(func->pipeline_stage());
out_ << " ";
// This is an entry point, the return type is the entry point output structure
// if one exists, or void otherwise.
auto out_data = ep_name_to_out_data_.find(current_ep_name_);
bool has_out_data = out_data != ep_name_to_out_data_.end();
auto out_data = ep_sym_to_out_data_.find(current_ep_sym_.value());
bool has_out_data = out_data != ep_sym_to_out_data_.end();
if (has_out_data) {
out_ << out_data->second.struct_name;
} else {
out_ << "void";
}
out_ << " " << namer_.NameFor(current_ep_name_) << "(";
out_ << " " << namer_.NameFor(func->name()) << "(";
bool first = true;
auto in_data = ep_name_to_in_data_.find(current_ep_name_);
if (in_data != ep_name_to_in_data_.end()) {
auto in_data = ep_sym_to_in_data_.find(current_ep_sym_.value());
if (in_data != ep_sym_to_in_data_.end()) {
out_ << in_data->second.struct_name << " " << in_data->second.var_name
<< " [[stage_in]]";
first = false;
@ -1503,7 +1510,7 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) {
make_indent();
out_ << "}" << std::endl;
current_ep_name_ = "";
current_ep_sym_ = Symbol();
return true;
}
@ -1687,8 +1694,8 @@ bool GeneratorImpl::EmitReturn(ast::ReturnStatement* stmt) {
out_ << "return";
if (generating_entry_point_) {
auto out_data = ep_name_to_out_data_.find(current_ep_name_);
if (out_data != ep_name_to_out_data_.end()) {
auto out_data = ep_sym_to_out_data_.find(current_ep_sym_.value());
if (out_data != ep_sym_to_out_data_.end()) {
out_ << " " << out_data->second.var_name;
}
} else if (stmt->has_value()) {

View File

@ -156,11 +156,11 @@ class GeneratorImpl : public TextGenerator {
/// @param func the function to emit
/// @param emit_duplicate_functions set true if we need to duplicate per entry
/// point
/// @param ep_name the current entry point or blank if none set
/// @param ep_sym the current entry point or symbol::kInvalid if not set
/// @returns true if the function was emitted.
bool EmitFunctionInternal(ast::Function* func,
bool emit_duplicate_functions,
const std::string& ep_name);
Symbol ep_sym);
/// Handles generating an identifier expression
/// @param expr the identifier expression
/// @returns true if the identifier was emitted
@ -282,13 +282,13 @@ class GeneratorImpl : public TextGenerator {
Namer namer_;
ScopeStack<ast::Variable*> global_variables_;
std::string current_ep_name_;
Symbol current_ep_sym_;
bool generating_entry_point_ = false;
const ast::Module* module_ = nullptr;
uint32_t loop_emission_counter_ = 0;
std::unordered_map<std::string, EntryPointData> ep_name_to_in_data_;
std::unordered_map<std::string, EntryPointData> ep_name_to_out_data_;
std::unordered_map<uint32_t, EntryPointData> ep_sym_to_in_data_;
std::unordered_map<uint32_t, EntryPointData> ep_sym_to_out_data_;
// This maps an input of "<entry_point_name>_<function_name>" to a remapped
// function name. If there is no entry for a given key then function did

View File

@ -37,9 +37,9 @@ TEST_F(MslGeneratorImplTest, EmitExpression_Call_WithoutParams) {
auto* id = create<ast::IdentifierExpression>("my_func");
ast::CallExpression call(id, {});
auto* func = create<ast::Function>(Source{}, "my_func", ast::VariableList{},
&void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
auto* func = create<ast::Function>(
Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
&void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func);
ASSERT_TRUE(gen.EmitExpression(&call)) << gen.error();
@ -55,9 +55,9 @@ TEST_F(MslGeneratorImplTest, EmitExpression_Call_WithParams) {
params.push_back(create<ast::IdentifierExpression>("param2"));
ast::CallExpression call(id, params);
auto* func = create<ast::Function>(Source{}, "my_func", ast::VariableList{},
&void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
auto* func = create<ast::Function>(
Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
&void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func);
ASSERT_TRUE(gen.EmitExpression(&call)) << gen.error();
@ -73,9 +73,9 @@ TEST_F(MslGeneratorImplTest, EmitStatement_Call) {
params.push_back(create<ast::IdentifierExpression>("param2"));
ast::CallStatement call(create<ast::CallExpression>(id, params));
auto* func = create<ast::Function>(Source{}, "my_func", ast::VariableList{},
&void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
auto* func = create<ast::Function>(
Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
&void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func);
gen.increment_indent();

View File

@ -90,7 +90,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Vertex_Input) {
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
Source{}, "vtx_main", params, &f32, body,
Source{}, mod.RegisterSymbol("vtx_main"), "vtx_main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@ -160,7 +160,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Vertex_Output) {
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
Source{}, "vtx_main", params, &f32, body,
Source{}, mod.RegisterSymbol("vtx_main"), "vtx_main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@ -229,7 +229,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Fragment_Input) {
create<ast::IdentifierExpression>("bar"),
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
Source{}, "main", params, &f32, body,
Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -299,7 +299,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Fragment_Output) {
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
Source{}, "main", params, &f32, body,
Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -366,7 +366,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Compute_Input) {
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
Source{}, "main", params, &f32, body,
Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
});
@ -428,7 +428,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Compute_Output) {
create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>(
Source{}, "main", params, &f32, body,
Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
});
@ -496,7 +496,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Builtins) {
create<ast::IdentifierExpression>("x"))));
auto* func = create<ast::Function>(
Source{}, "main", params, &void_type, body,
Source{}, mod.RegisterSymbol("main"), "main", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});

View File

@ -60,9 +60,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func =
create<ast::Function>(Source{}, "my_func", ast::VariableList{},
&void_type, body, ast::FunctionDecorationList{});
auto* func = create<ast::Function>(Source{}, mod.RegisterSymbol("my_func"),
"my_func", ast::VariableList{}, &void_type,
body, ast::FunctionDecorationList{});
mod.AddFunction(func);
gen.increment_indent();
@ -82,9 +82,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function_Name_Collision) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func =
create<ast::Function>(Source{}, "main", ast::VariableList{}, &void_type,
body, ast::FunctionDecorationList{});
auto* func = create<ast::Function>(Source{}, mod.RegisterSymbol("main"),
"main", ast::VariableList{}, &void_type,
body, ast::FunctionDecorationList{});
mod.AddFunction(func);
gen.increment_indent();
@ -125,8 +125,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function_WithParams) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(Source{}, "my_func", params, &void_type,
body, ast::FunctionDecorationList{});
auto* func = create<ast::Function>(Source{}, mod.RegisterSymbol("my_func"),
"my_func", params, &void_type, body,
ast::FunctionDecorationList{});
mod.AddFunction(func);
gen.increment_indent();
@ -183,7 +184,8 @@ TEST_F(MslGeneratorImplTest, Emit_FunctionDecoration_EntryPoint_WithInOutVars) {
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body,
Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{create<ast::StageDecoration>(
ast::PipelineStage::kFragment, Source{})});
@ -257,7 +259,8 @@ TEST_F(MslGeneratorImplTest,
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body,
Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -321,7 +324,8 @@ TEST_F(MslGeneratorImplTest, Emit_FunctionDecoration_EntryPoint_With_Uniform) {
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body,
Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -397,7 +401,8 @@ TEST_F(MslGeneratorImplTest,
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body,
Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -478,7 +483,8 @@ TEST_F(MslGeneratorImplTest,
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body,
Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -572,8 +578,9 @@ TEST_F(
create<ast::IdentifierExpression>("param")));
body->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("foo")));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
body, ast::FunctionDecorationList{});
auto* sub_func = create<ast::Function>(
Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@ -588,7 +595,7 @@ TEST_F(
expr)));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>(
Source{}, "ep_1", params, &void_type, body,
Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -659,8 +666,9 @@ TEST_F(MslGeneratorImplTest,
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("param")));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
body, ast::FunctionDecorationList{});
auto* sub_func = create<ast::Function>(
Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@ -676,7 +684,7 @@ TEST_F(MslGeneratorImplTest,
body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>(
Source{}, "ep_1", params, &void_type, body,
Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -760,8 +768,9 @@ TEST_F(
create<ast::IdentifierExpression>("x"))));
body->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("param")));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
body, ast::FunctionDecorationList{});
auto* sub_func = create<ast::Function>(
Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@ -776,7 +785,7 @@ TEST_F(
expr)));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>(
Source{}, "ep_1", params, &void_type, body,
Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -843,8 +852,9 @@ TEST_F(MslGeneratorImplTest,
Source{}, create<ast::MemberAccessorExpression>(
create<ast::IdentifierExpression>("coord"),
create<ast::IdentifierExpression>("x"))));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
body, ast::FunctionDecorationList{});
auto* sub_func = create<ast::Function>(
Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@ -867,7 +877,8 @@ TEST_F(MslGeneratorImplTest,
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body,
Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -943,8 +954,9 @@ TEST_F(MslGeneratorImplTest,
Source{}, create<ast::MemberAccessorExpression>(
create<ast::IdentifierExpression>("coord"),
create<ast::IdentifierExpression>("b"))));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
body, ast::FunctionDecorationList{});
auto* sub_func = create<ast::Function>(
Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@ -967,7 +979,8 @@ TEST_F(MslGeneratorImplTest,
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body,
Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -1049,8 +1062,9 @@ TEST_F(MslGeneratorImplTest,
Source{}, create<ast::MemberAccessorExpression>(
create<ast::IdentifierExpression>("coord"),
create<ast::IdentifierExpression>("b"))));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32,
body, ast::FunctionDecorationList{});
auto* sub_func = create<ast::Function>(
Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func);
@ -1073,7 +1087,8 @@ TEST_F(MslGeneratorImplTest,
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body,
Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -1145,7 +1160,7 @@ TEST_F(MslGeneratorImplTest,
body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>(
Source{}, "ep_1", params, &void_type, body,
Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -1177,8 +1192,8 @@ TEST_F(MslGeneratorImplTest,
ast::type::Void void_type;
auto* func = create<ast::Function>(
Source{}, "main", ast::VariableList{}, &void_type,
create<ast::BlockStatement>(),
Source{}, mod.RegisterSymbol("main"), "main", ast::VariableList{},
&void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
});
@ -1212,8 +1227,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function_WithArrayParams) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(Source{}, "my_func", params, &void_type,
body, ast::FunctionDecorationList{});
auto* func = create<ast::Function>(Source{}, mod.RegisterSymbol("my_func"),
"my_func", params, &void_type, body,
ast::FunctionDecorationList{});
mod.AddFunction(func);
@ -1298,12 +1314,12 @@ TEST_F(MslGeneratorImplTest,
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func =
create<ast::Function>(Source{}, "a", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(
ast::PipelineStage::kCompute, Source{}),
});
auto* func = create<ast::Function>(
Source{}, mod.RegisterSymbol("a"), "a", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute,
Source{}),
});
mod.AddFunction(func);
}
@ -1325,12 +1341,12 @@ TEST_F(MslGeneratorImplTest,
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func =
create<ast::Function>(Source{}, "b", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(
ast::PipelineStage::kCompute, Source{}),
});
auto* func = create<ast::Function>(
Source{}, mod.RegisterSymbol("b"), "b", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute,
Source{}),
});
mod.AddFunction(func);
}

View File

@ -51,8 +51,8 @@ TEST_F(MslGeneratorImplTest, Generate) {
ast::type::Void void_type;
auto* func = create<ast::Function>(
Source{}, "my_func", ast::VariableList{}, &void_type,
create<ast::BlockStatement>(),
Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
&void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
});

View File

@ -65,11 +65,11 @@ TEST_F(BuilderTest, Expression_Call) {
Source{}, create<ast::BinaryExpression>(
ast::BinaryOp::kAdd, create<ast::IdentifierExpression>("a"),
create<ast::IdentifierExpression>("b"))));
ast::Function a_func(Source{}, "a_func", func_params, &f32, body,
ast::FunctionDecorationList{});
ast::Function a_func(Source{}, mod->RegisterSymbol("a_func"), "a_func",
func_params, &f32, body, ast::FunctionDecorationList{});
ast::Function func(Source{}, "main", {}, &void_type,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("main"), "main", {},
&void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ast::ExpressionList call_params;
@ -143,11 +143,12 @@ TEST_F(BuilderTest, Statement_Call) {
ast::BinaryOp::kAdd, create<ast::IdentifierExpression>("a"),
create<ast::IdentifierExpression>("b"))));
ast::Function a_func(Source{}, "a_func", func_params, &void_type, body,
ast::Function a_func(Source{}, mod->RegisterSymbol("a_func"), "a_func",
func_params, &void_type, body,
ast::FunctionDecorationList{});
ast::Function func(Source{}, "main", {}, &void_type,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("main"), "main", {},
&void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ast::ExpressionList call_params;

View File

@ -42,7 +42,8 @@ TEST_F(BuilderTest, FunctionDecoration_Stage) {
ast::type::Void void_type;
ast::Function func(
Source{}, "main", {}, &void_type, create<ast::BlockStatement>(),
Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type,
create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@ -67,8 +68,8 @@ TEST_P(FunctionDecoration_StageTest, Emit) {
ast::type::Void void_type;
ast::Function func(Source{}, "main", {}, &void_type,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("main"), "main", {},
&void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(params.stage, Source{}),
});
@ -97,7 +98,8 @@ TEST_F(BuilderTest, FunctionDecoration_Stage_WithUnusedInterfaceIds) {
ast::type::Void void_type;
ast::Function func(
Source{}, "main", {}, &void_type, create<ast::BlockStatement>(),
Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type,
create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@ -174,7 +176,7 @@ TEST_F(BuilderTest, FunctionDecoration_Stage_WithUsedInterfaceIds) {
create<ast::IdentifierExpression>("my_in")));
ast::Function func(
Source{}, "main", {}, &void_type, body,
Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
});
@ -244,7 +246,8 @@ TEST_F(BuilderTest, FunctionDecoration_ExecutionMode_Fragment_OriginUpperLeft) {
ast::type::Void void_type;
ast::Function func(
Source{}, "main", {}, &void_type, create<ast::BlockStatement>(),
Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type,
create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -259,7 +262,8 @@ TEST_F(BuilderTest, FunctionDecoration_WorkgroupSize_Default) {
ast::type::Void void_type;
ast::Function func(
Source{}, "main", {}, &void_type, create<ast::BlockStatement>(),
Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type,
create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
});
@ -274,7 +278,8 @@ TEST_F(BuilderTest, FunctionDecoration_WorkgroupSize) {
ast::type::Void void_type;
ast::Function func(
Source{}, "main", {}, &void_type, create<ast::BlockStatement>(),
Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type,
create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::WorkgroupDecoration>(2u, 4u, 6u, Source{}),
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
@ -290,13 +295,15 @@ TEST_F(BuilderTest, FunctionDecoration_ExecutionMode_MultipleFragment) {
ast::type::Void void_type;
ast::Function func1(
Source{}, "main1", {}, &void_type, create<ast::BlockStatement>(),
Source{}, mod->RegisterSymbol("main1"), "main1", {}, &void_type,
create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
ast::Function func2(
Source{}, "main2", {}, &void_type, create<ast::BlockStatement>(),
Source{}, mod->RegisterSymbol("main2"), "main2", {}, &void_type,
create<ast::BlockStatement>(),
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});

View File

@ -47,8 +47,8 @@ using BuilderTest = TestHelper;
TEST_F(BuilderTest, Function_Empty) {
ast::type::Void void_type;
ast::Function func(Source{}, "a_func", {}, &void_type,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
&void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func));
@ -68,8 +68,8 @@ TEST_F(BuilderTest, Function_Terminator_Return) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
ast::Function func(Source{}, "a_func", {}, &void_type, body,
ast::FunctionDecorationList{});
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
&void_type, body, ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func));
EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func"
@ -101,8 +101,8 @@ TEST_F(BuilderTest, Function_Terminator_ReturnValue) {
Source{}, create<ast::IdentifierExpression>("a")));
ASSERT_TRUE(td.DetermineResultType(body)) << td.error();
ast::Function func(Source{}, "a_func", {}, &void_type, body,
ast::FunctionDecorationList{});
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
&void_type, body, ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(var_a)) << b.error();
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -128,8 +128,8 @@ TEST_F(BuilderTest, Function_Terminator_Discard) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::DiscardStatement>());
ast::Function func(Source{}, "a_func", {}, &void_type, body,
ast::FunctionDecorationList{});
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
&void_type, body, ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func));
EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func"
@ -168,8 +168,8 @@ TEST_F(BuilderTest, Function_WithParams) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("a")));
ast::Function func(Source{}, "a_func", params, &f32, body,
ast::FunctionDecorationList{});
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", params,
&f32, body, ast::FunctionDecorationList{});
td.RegisterVariableForTesting(func.params()[0]);
td.RegisterVariableForTesting(func.params()[1]);
@ -197,8 +197,8 @@ TEST_F(BuilderTest, Function_WithBody) {
auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}));
ast::Function func(Source{}, "a_func", {}, &void_type, body,
ast::FunctionDecorationList{});
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
&void_type, body, ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func));
EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func"
@ -213,8 +213,8 @@ OpFunctionEnd
TEST_F(BuilderTest, FunctionType) {
ast::type::Void void_type;
ast::Function func(Source{}, "a_func", {}, &void_type,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
&void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func));
@ -225,11 +225,11 @@ TEST_F(BuilderTest, FunctionType) {
TEST_F(BuilderTest, FunctionType_DeDuplicate) {
ast::type::Void void_type;
ast::Function func1(Source{}, "a_func", {}, &void_type,
create<ast::BlockStatement>(),
ast::Function func1(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
&void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ast::Function func2(Source{}, "b_func", {}, &void_type,
create<ast::BlockStatement>(),
ast::Function func2(Source{}, mod->RegisterSymbol("b_func"), "b_func", {},
&void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func1));
@ -307,12 +307,12 @@ TEST_F(BuilderTest, Emit_Multiple_EntryPoint_With_Same_ModuleVar) {
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func =
create<ast::Function>(Source{}, "a", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(
ast::PipelineStage::kCompute, Source{}),
});
auto* func = create<ast::Function>(
Source{}, mod->RegisterSymbol("a"), "a", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute,
Source{}),
});
mod->AddFunction(func);
}
@ -334,12 +334,12 @@ TEST_F(BuilderTest, Emit_Multiple_EntryPoint_With_Same_ModuleVar) {
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func =
create<ast::Function>(Source{}, "b", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(
ast::PipelineStage::kCompute, Source{}),
});
auto* func = create<ast::Function>(
Source{}, mod->RegisterSymbol("b"), "b", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute,
Source{}),
});
mod->AddFunction(func);
}

View File

@ -471,8 +471,8 @@ TEST_F(IntrinsicBuilderTest, Call_GLSLMethod_WithLoad) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
@ -505,8 +505,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Float_Test, Call_Scalar) {
auto expr = Call(param.name, 1.0f);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -533,8 +533,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Float_Test, Call_Vector) {
auto expr = Call(param.name, vec2<f32>(1.0f, 1.0f));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -587,8 +587,8 @@ TEST_F(IntrinsicBuilderTest, Call_Length_Scalar) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -612,8 +612,8 @@ TEST_F(IntrinsicBuilderTest, Call_Length_Vector) {
auto expr = Call("length", vec2<f32>(1.0f, 1.0f));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -639,8 +639,8 @@ TEST_F(IntrinsicBuilderTest, Call_Normalize) {
auto expr = Call("normalize", vec2<f32>(1.0f, 1.0f));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -671,8 +671,8 @@ TEST_P(Intrinsic_Builtin_DualParam_Float_Test, Call_Scalar) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -700,8 +700,8 @@ TEST_P(Intrinsic_Builtin_DualParam_Float_Test, Call_Vector) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -737,8 +737,8 @@ TEST_F(IntrinsicBuilderTest, Call_Distance_Scalar) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -763,8 +763,8 @@ TEST_F(IntrinsicBuilderTest, Call_Distance_Vector) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -792,8 +792,8 @@ TEST_F(IntrinsicBuilderTest, Call_Cross) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -823,8 +823,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Float_Test, Call_Scalar) {
auto expr = Call(param.name, 1.0f, 1.0f, 1.0f);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -853,8 +853,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Float_Test, Call_Vector) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -894,8 +894,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Sint_Test, Call_Scalar) {
auto expr = Call(param.name, 1);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -922,8 +922,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Sint_Test, Call_Vector) {
auto expr = Call(param.name, vec2<i32>(1, 1));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -957,8 +957,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Uint_Test, Call_Scalar) {
auto expr = Call(param.name, 1u);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -985,8 +985,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Uint_Test, Call_Vector) {
auto expr = Call(param.name, vec2<u32>(1u, 1u));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1020,8 +1020,8 @@ TEST_P(Intrinsic_Builtin_DualParam_SInt_Test, Call_Scalar) {
auto expr = Call(param.name, 1, 1);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1048,8 +1048,8 @@ TEST_P(Intrinsic_Builtin_DualParam_SInt_Test, Call_Vector) {
auto expr = Call(param.name, vec2<i32>(1, 1), vec2<i32>(1, 1));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1084,8 +1084,8 @@ TEST_P(Intrinsic_Builtin_DualParam_UInt_Test, Call_Scalar) {
auto expr = Call(param.name, 1u, 1u);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1112,8 +1112,8 @@ TEST_P(Intrinsic_Builtin_DualParam_UInt_Test, Call_Vector) {
auto expr = Call(param.name, vec2<u32>(1u, 1u), vec2<u32>(1u, 1u));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1148,8 +1148,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Sint_Test, Call_Scalar) {
auto expr = Call(param.name, 1, 1, 1);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1178,8 +1178,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Sint_Test, Call_Vector) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1213,8 +1213,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Uint_Test, Call_Scalar) {
auto expr = Call(param.name, 1u, 1u, 1u);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1243,8 +1243,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Uint_Test, Call_Vector) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1276,8 +1276,8 @@ TEST_F(IntrinsicBuilderTest, Call_Determinant) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1320,8 +1320,8 @@ TEST_F(IntrinsicBuilderTest, Call_ArrayLength) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1360,8 +1360,8 @@ TEST_F(IntrinsicBuilderTest, Call_ArrayLength_OtherMembersInStruct) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1405,8 +1405,8 @@ TEST_F(IntrinsicBuilderTest, DISABLED_Call_ArrayLength_Ptr) {
auto expr = Call("arrayLength", "ptr_var");
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();

View File

@ -121,8 +121,8 @@ TEST_F(BuilderTest, Switch_WithCase) {
td.RegisterVariableForTesting(a);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, &i32,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
&i32, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
@ -201,8 +201,8 @@ TEST_F(BuilderTest, Switch_WithDefault) {
td.RegisterVariableForTesting(a);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, &i32,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
&i32, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
@ -300,8 +300,8 @@ TEST_F(BuilderTest, Switch_WithCaseAndDefault) {
td.RegisterVariableForTesting(a);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, &i32,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
&i32, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
@ -408,8 +408,8 @@ TEST_F(BuilderTest, Switch_CaseWithFallthrough) {
td.RegisterVariableForTesting(a);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, &i32,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
&i32, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
@ -495,8 +495,8 @@ TEST_F(BuilderTest, Switch_CaseFallthroughLastStatement) {
td.RegisterVariableForTesting(a);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, &i32,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
&i32, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
@ -563,8 +563,8 @@ TEST_F(BuilderTest, Switch_WithNestedBreak) {
td.RegisterVariableForTesting(a);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, &i32,
create<ast::BlockStatement>(),
ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
&i32, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();

View File

@ -113,7 +113,8 @@ bool GeneratorImpl::Generate(const ast::Module& module) {
bool GeneratorImpl::GenerateEntryPoint(const ast::Module& module,
ast::PipelineStage stage,
const std::string& name) {
auto* func = module.FindFunctionByNameAndStage(name, stage);
auto* func =
module.FindFunctionBySymbolAndStage(module.GetSymbol(name), stage);
if (func == nullptr) {
error_ = "Unable to find requested entry point: " + name;
return false;
@ -153,7 +154,7 @@ bool GeneratorImpl::GenerateEntryPoint(const ast::Module& module,
}
for (auto* f : module.functions()) {
if (!f->HasAncestorEntryPoint(name)) {
if (!f->HasAncestorEntryPoint(module.GetSymbol(name))) {
continue;
}

View File

@ -46,8 +46,8 @@ TEST_F(WgslGeneratorImplTest, Emit_Function) {
body->append(create<ast::ReturnStatement>(Source{}));
ast::type::Void void_type;
ast::Function func(Source{}, "my_func", {}, &void_type, body,
ast::FunctionDecorationList{});
ast::Function func(Source{}, mod.RegisterSymbol("my_func"), "my_func", {},
&void_type, body, ast::FunctionDecorationList{});
gen.increment_indent();
@ -85,8 +85,8 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithParams) {
ast::VariableDecorationList{})); // decorations
ast::type::Void void_type;
ast::Function func(Source{}, "my_func", params, &void_type, body,
ast::FunctionDecorationList{});
ast::Function func(Source{}, mod.RegisterSymbol("my_func"), "my_func", params,
&void_type, body, ast::FunctionDecorationList{});
gen.increment_indent();
@ -104,7 +104,8 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_WorkgroupSize) {
body->append(create<ast::ReturnStatement>(Source{}));
ast::type::Void void_type;
ast::Function func(Source{}, "my_func", {}, &void_type, body,
ast::Function func(Source{}, mod.RegisterSymbol("my_func"), "my_func", {},
&void_type, body,
ast::FunctionDecorationList{
create<ast::WorkgroupDecoration>(2u, 4u, 6u, Source{}),
});
@ -127,7 +128,7 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_Stage) {
ast::type::Void void_type;
ast::Function func(
Source{}, "my_func", {}, &void_type, body,
Source{}, mod.RegisterSymbol("my_func"), "my_func", {}, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
});
@ -150,7 +151,7 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_Multiple) {
ast::type::Void void_type;
ast::Function func(
Source{}, "my_func", {}, &void_type, body,
Source{}, mod.RegisterSymbol("my_func"), "my_func", {}, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
create<ast::WorkgroupDecoration>(2u, 4u, 6u, Source{}),
@ -237,12 +238,12 @@ TEST_F(WgslGeneratorImplTest,
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func =
create<ast::Function>(Source{}, "a", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(
ast::PipelineStage::kCompute, Source{}),
});
auto* func = create<ast::Function>(
Source{}, mod.RegisterSymbol("a"), "a", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute,
Source{}),
});
mod.AddFunction(func);
}
@ -264,12 +265,12 @@ TEST_F(WgslGeneratorImplTest,
body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{}));
auto* func =
create<ast::Function>(Source{}, "b", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(
ast::PipelineStage::kCompute, Source{}),
});
auto* func = create<ast::Function>(
Source{}, mod.RegisterSymbol("b"), "b", params, &void_type, body,
ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute,
Source{}),
});
mod.AddFunction(func);
}

View File

@ -33,8 +33,9 @@ TEST_F(WgslGeneratorImplTest, Generate) {
ast::type::Void void_type;
mod.AddFunction(create<ast::Function>(
Source{}, "my_func", ast::VariableList{}, &void_type,
create<ast::BlockStatement>(), ast::FunctionDecorationList{}));
Source{}, mod.RegisterSymbol("a_func"), "my_func", ast::VariableList{},
&void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}));
ASSERT_TRUE(gen.Generate(mod)) << gen.error();
EXPECT_EQ(gen.result(), R"(fn my_func() -> void {