Add a symbol to the Function AST node.

This Cl adds a Symbol representing the function name to the function
AST. The symbol is added alongside the name for now. When all usages of
the function name are removed then the string version will be removed
from the constructor.

Change-Id: Ib2450e5fe531e988b25bb7d2937acc6af2187871
Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/35220
Commit-Queue: dan sinclair <dsinclair@chromium.org>
Reviewed-by: Ben Clayton <bclayton@google.com>
Auto-Submit: dan sinclair <dsinclair@chromium.org>
This commit is contained in:
dan sinclair 2020-12-11 18:24:53 +00:00 committed by Commit Bot service account
parent cd9e5f6e91
commit a41132fcd8
48 changed files with 923 additions and 658 deletions

View File

@ -31,12 +31,14 @@ namespace tint {
namespace ast { namespace ast {
Function::Function(const Source& source, Function::Function(const Source& source,
Symbol symbol,
const std::string& name, const std::string& name,
VariableList params, VariableList params,
type::Type* return_type, type::Type* return_type,
BlockStatement* body, BlockStatement* body,
FunctionDecorationList decorations) FunctionDecorationList decorations)
: Base(source), : Base(source),
symbol_(symbol),
name_(name), name_(name),
params_(std::move(params)), params_(std::move(params)),
return_type_(return_type), return_type_(return_type),
@ -202,7 +204,7 @@ Function::local_referenced_builtin_variables() const {
return ret; return ret;
} }
void Function::add_ancestor_entry_point(const std::string& ep) { void Function::add_ancestor_entry_point(Symbol ep) {
for (const auto& point : ancestor_entry_points_) { for (const auto& point : ancestor_entry_points_) {
if (point == ep) { if (point == ep) {
return; return;
@ -211,9 +213,9 @@ void Function::add_ancestor_entry_point(const std::string& ep) {
ancestor_entry_points_.push_back(ep); ancestor_entry_points_.push_back(ep);
} }
bool Function::HasAncestorEntryPoint(const std::string& name) const { bool Function::HasAncestorEntryPoint(Symbol symbol) const {
for (const auto& point : ancestor_entry_points_) { for (const auto& point : ancestor_entry_points_) {
if (point == name) { if (point == symbol) {
return true; return true;
} }
} }
@ -226,7 +228,7 @@ const Statement* Function::get_last_statement() const {
Function* Function::Clone(CloneContext* ctx) const { Function* Function::Clone(CloneContext* ctx) const {
return ctx->mod->create<Function>( return ctx->mod->create<Function>(
ctx->Clone(source()), name_, ctx->Clone(params_), ctx->Clone(source()), symbol_, name_, ctx->Clone(params_),
ctx->Clone(return_type_), ctx->Clone(body_), ctx->Clone(decorations_)); ctx->Clone(return_type_), ctx->Clone(body_), ctx->Clone(decorations_));
} }
@ -238,7 +240,7 @@ bool Function::IsValid() const {
if (body_ == nullptr || !body_->IsValid()) { if (body_ == nullptr || !body_->IsValid()) {
return false; return false;
} }
if (name_.length() == 0) { if (name_.length() == 0 || !symbol_.IsValid()) {
return false; return false;
} }
if (return_type_ == nullptr) { if (return_type_ == nullptr) {
@ -249,7 +251,7 @@ bool Function::IsValid() const {
void Function::to_str(std::ostream& out, size_t indent) const { void Function::to_str(std::ostream& out, size_t indent) const {
make_indent(out, indent); make_indent(out, indent);
out << "Function " << name_ << " -> " << return_type_->type_name() out << "Function " << symbol_.to_str() << " -> " << return_type_->type_name()
<< std::endl; << std::endl;
for (auto* deco : decorations()) { for (auto* deco : decorations()) {

View File

@ -35,6 +35,7 @@
#include "src/ast/type/sampler_type.h" #include "src/ast/type/sampler_type.h"
#include "src/ast/type/type.h" #include "src/ast/type/type.h"
#include "src/ast/variable.h" #include "src/ast/variable.h"
#include "src/symbol.h"
namespace tint { namespace tint {
namespace ast { namespace ast {
@ -52,12 +53,14 @@ class Function : public Castable<Function, Node> {
/// Create a function /// Create a function
/// @param source the variable source /// @param source the variable source
/// @param symbol the function symbol
/// @param name the function name /// @param name the function name
/// @param params the function parameters /// @param params the function parameters
/// @param return_type the return type /// @param return_type the return type
/// @param body the function body /// @param body the function body
/// @param decorations the function decorations /// @param decorations the function decorations
Function(const Source& source, Function(const Source& source,
Symbol symbol,
const std::string& name, const std::string& name,
VariableList params, VariableList params,
type::Type* return_type, type::Type* return_type,
@ -68,6 +71,8 @@ class Function : public Castable<Function, Node> {
~Function() override; ~Function() override;
/// @returns the function symbol
Symbol symbol() const { return symbol_; }
/// @returns the function name /// @returns the function name
const std::string& name() { return name_; } const std::string& name() { return name_; }
/// @returns the function params /// @returns the function params
@ -150,15 +155,15 @@ class Function : public Castable<Function, Node> {
/// Adds an ancestor entry point /// Adds an ancestor entry point
/// @param ep the entry point ancestor /// @param ep the entry point ancestor
void add_ancestor_entry_point(const std::string& ep); void add_ancestor_entry_point(Symbol ep);
/// @returns the ancestor entry points /// @returns the ancestor entry points
const std::vector<std::string>& ancestor_entry_points() const { const std::vector<Symbol>& ancestor_entry_points() const {
return ancestor_entry_points_; return ancestor_entry_points_;
} }
/// Checks if the given entry point is an ancestor /// Checks if the given entry point is an ancestor
/// @param name the entry point name /// @param sym the entry point symbol
/// @returns true if `name` is an ancestor entry point of this function /// @returns true if `sym` is an ancestor entry point of this function
bool HasAncestorEntryPoint(const std::string& name) const; bool HasAncestorEntryPoint(Symbol sym) const;
/// @returns the function return type. /// @returns the function return type.
type::Type* return_type() const { return return_type_; } type::Type* return_type() const { return return_type_; }
@ -197,13 +202,14 @@ class Function : public Castable<Function, Node> {
const std::vector<std::pair<Variable*, Function::BindingInfo>> const std::vector<std::pair<Variable*, Function::BindingInfo>>
ReferencedSampledTextureVariablesImpl(bool multisampled) const; ReferencedSampledTextureVariablesImpl(bool multisampled) const;
Symbol symbol_;
std::string name_; std::string name_;
VariableList params_; VariableList params_;
type::Type* return_type_ = nullptr; type::Type* return_type_ = nullptr;
BlockStatement* body_ = nullptr; BlockStatement* body_ = nullptr;
std::vector<Variable*> referenced_module_vars_; std::vector<Variable*> referenced_module_vars_;
std::vector<Variable*> local_referenced_module_vars_; std::vector<Variable*> local_referenced_module_vars_;
std::vector<std::string> ancestor_entry_points_; std::vector<Symbol> ancestor_entry_points_;
FunctionDecorationList decorations_; FunctionDecorationList decorations_;
}; };

View File

@ -35,14 +35,18 @@ TEST_F(FunctionTest, Creation) {
type::Void void_type; type::Void void_type;
type::I32 i32; type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params; VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32, params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr, false, nullptr,
ast::VariableDecorationList{})); ast::VariableDecorationList{}));
auto* var = params[0]; auto* var = params[0];
Function f(Source{}, "func", params, &void_type, create<BlockStatement>(), Function f(Source{}, func_sym, "func", params, &void_type,
FunctionDecorationList{}); create<BlockStatement>(), FunctionDecorationList{});
EXPECT_EQ(f.symbol(), func_sym);
EXPECT_EQ(f.name(), "func"); EXPECT_EQ(f.name(), "func");
ASSERT_EQ(f.params().size(), 1u); ASSERT_EQ(f.params().size(), 1u);
EXPECT_EQ(f.return_type(), &void_type); EXPECT_EQ(f.return_type(), &void_type);
@ -53,13 +57,16 @@ TEST_F(FunctionTest, Creation_WithSource) {
type::Void void_type; type::Void void_type;
type::I32 i32; type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params; VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32, params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr, false, nullptr,
ast::VariableDecorationList{})); ast::VariableDecorationList{}));
Function f(Source{Source::Location{20, 2}}, "func", params, &void_type, Function f(Source{Source::Location{20, 2}}, func_sym, "func", params,
create<BlockStatement>(), FunctionDecorationList{}); &void_type, create<BlockStatement>(), FunctionDecorationList{});
auto src = f.source(); auto src = f.source();
EXPECT_EQ(src.range.begin.line, 20u); EXPECT_EQ(src.range.begin.line, 20u);
EXPECT_EQ(src.range.begin.column, 2u); EXPECT_EQ(src.range.begin.column, 2u);
@ -69,9 +76,12 @@ TEST_F(FunctionTest, AddDuplicateReferencedVariables) {
type::Void void_type; type::Void void_type;
type::I32 i32; type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
Variable v(Source{}, "var", StorageClass::kInput, &i32, false, nullptr, Variable v(Source{}, "var", StorageClass::kInput, &i32, false, nullptr,
ast::VariableDecorationList{}); ast::VariableDecorationList{});
Function f(Source{}, "func", VariableList{}, &void_type, Function f(Source{}, func_sym, "func", VariableList{}, &void_type,
create<BlockStatement>(), FunctionDecorationList{}); create<BlockStatement>(), FunctionDecorationList{});
f.add_referenced_module_variable(&v); f.add_referenced_module_variable(&v);
@ -92,6 +102,9 @@ TEST_F(FunctionTest, GetReferenceLocations) {
type::Void void_type; type::Void void_type;
type::I32 i32; type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
auto* loc1 = create<Variable>(Source{}, "loc1", StorageClass::kInput, &i32, auto* loc1 = create<Variable>(Source{}, "loc1", StorageClass::kInput, &i32,
false, nullptr, false, nullptr,
ast::VariableDecorationList{ ast::VariableDecorationList{
@ -116,7 +129,7 @@ TEST_F(FunctionTest, GetReferenceLocations) {
create<BuiltinDecoration>(Builtin::kFragDepth, Source{}), create<BuiltinDecoration>(Builtin::kFragDepth, Source{}),
}); });
Function f(Source{}, "func", VariableList{}, &void_type, Function f(Source{}, func_sym, "func", VariableList{}, &void_type,
create<BlockStatement>(), FunctionDecorationList{}); create<BlockStatement>(), FunctionDecorationList{});
f.add_referenced_module_variable(loc1); f.add_referenced_module_variable(loc1);
@ -137,6 +150,9 @@ TEST_F(FunctionTest, GetReferenceBuiltins) {
type::Void void_type; type::Void void_type;
type::I32 i32; type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
auto* loc1 = create<Variable>(Source{}, "loc1", StorageClass::kInput, &i32, auto* loc1 = create<Variable>(Source{}, "loc1", StorageClass::kInput, &i32,
false, nullptr, false, nullptr,
ast::VariableDecorationList{ ast::VariableDecorationList{
@ -161,7 +177,7 @@ TEST_F(FunctionTest, GetReferenceBuiltins) {
create<BuiltinDecoration>(Builtin::kFragDepth, Source{}), create<BuiltinDecoration>(Builtin::kFragDepth, Source{}),
}); });
Function f(Source{}, "func", VariableList{}, &void_type, Function f(Source{}, func_sym, "func", VariableList{}, &void_type,
create<BlockStatement>(), FunctionDecorationList{}); create<BlockStatement>(), FunctionDecorationList{});
f.add_referenced_module_variable(loc1); f.add_referenced_module_variable(loc1);
@ -180,22 +196,30 @@ TEST_F(FunctionTest, GetReferenceBuiltins) {
TEST_F(FunctionTest, AddDuplicateEntryPoints) { TEST_F(FunctionTest, AddDuplicateEntryPoints) {
type::Void void_type; type::Void void_type;
Function f(Source{}, "func", VariableList{}, &void_type,
Module m;
auto func_sym = m.RegisterSymbol("func");
auto main_sym = m.RegisterSymbol("main");
Function f(Source{}, func_sym, "func", VariableList{}, &void_type,
create<BlockStatement>(), FunctionDecorationList{}); create<BlockStatement>(), FunctionDecorationList{});
f.add_ancestor_entry_point("main"); f.add_ancestor_entry_point(main_sym);
ASSERT_EQ(1u, f.ancestor_entry_points().size()); ASSERT_EQ(1u, f.ancestor_entry_points().size());
EXPECT_EQ("main", f.ancestor_entry_points()[0]); EXPECT_EQ(main_sym, f.ancestor_entry_points()[0]);
f.add_ancestor_entry_point("main"); f.add_ancestor_entry_point(main_sym);
ASSERT_EQ(1u, f.ancestor_entry_points().size()); ASSERT_EQ(1u, f.ancestor_entry_points().size());
EXPECT_EQ("main", f.ancestor_entry_points()[0]); EXPECT_EQ(main_sym, f.ancestor_entry_points()[0]);
} }
TEST_F(FunctionTest, IsValid) { TEST_F(FunctionTest, IsValid) {
type::Void void_type; type::Void void_type;
type::I32 i32; type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params; VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32, params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr, false, nullptr,
@ -204,21 +228,27 @@ TEST_F(FunctionTest, IsValid) {
auto* body = create<BlockStatement>(); auto* body = create<BlockStatement>();
body->append(create<DiscardStatement>()); body->append(create<DiscardStatement>());
Function f(Source{}, "func", params, &void_type, body, Function f(Source{}, func_sym, "func", params, &void_type, body,
FunctionDecorationList{}); FunctionDecorationList{});
EXPECT_TRUE(f.IsValid()); EXPECT_TRUE(f.IsValid());
} }
TEST_F(FunctionTest, IsValid_EmptyName) { TEST_F(FunctionTest, IsValid_InvalidName) {
type::Void void_type; type::Void void_type;
type::I32 i32; type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("");
VariableList params; VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32, params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr, false, nullptr,
ast::VariableDecorationList{})); ast::VariableDecorationList{}));
Function f(Source{}, "", params, &void_type, create<BlockStatement>(), auto* body = create<BlockStatement>();
body->append(create<DiscardStatement>());
Function f(Source{}, func_sym, "", params, &void_type, body,
FunctionDecorationList{}); FunctionDecorationList{});
EXPECT_FALSE(f.IsValid()); EXPECT_FALSE(f.IsValid());
} }
@ -226,13 +256,16 @@ TEST_F(FunctionTest, IsValid_EmptyName) {
TEST_F(FunctionTest, IsValid_MissingReturnType) { TEST_F(FunctionTest, IsValid_MissingReturnType) {
type::I32 i32; type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params; VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32, params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr, false, nullptr,
ast::VariableDecorationList{})); ast::VariableDecorationList{}));
Function f(Source{}, "func", params, nullptr, create<BlockStatement>(), Function f(Source{}, func_sym, "func", params, nullptr,
FunctionDecorationList{}); create<BlockStatement>(), FunctionDecorationList{});
EXPECT_FALSE(f.IsValid()); EXPECT_FALSE(f.IsValid());
} }
@ -240,27 +273,33 @@ TEST_F(FunctionTest, IsValid_NullParam) {
type::Void void_type; type::Void void_type;
type::I32 i32; type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params; VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32, params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr, false, nullptr,
ast::VariableDecorationList{})); ast::VariableDecorationList{}));
params.push_back(nullptr); params.push_back(nullptr);
Function f(Source{}, "func", params, &void_type, create<BlockStatement>(), Function f(Source{}, func_sym, "func", params, &void_type,
FunctionDecorationList{}); create<BlockStatement>(), FunctionDecorationList{});
EXPECT_FALSE(f.IsValid()); EXPECT_FALSE(f.IsValid());
} }
TEST_F(FunctionTest, IsValid_InvalidParam) { TEST_F(FunctionTest, IsValid_InvalidParam) {
type::Void void_type; type::Void void_type;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params; VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone,
nullptr, false, nullptr, nullptr, false, nullptr,
ast::VariableDecorationList{})); ast::VariableDecorationList{}));
Function f(Source{}, "func", params, &void_type, create<BlockStatement>(), Function f(Source{}, func_sym, "func", params, &void_type,
FunctionDecorationList{}); create<BlockStatement>(), FunctionDecorationList{});
EXPECT_FALSE(f.IsValid()); EXPECT_FALSE(f.IsValid());
} }
@ -268,6 +307,9 @@ TEST_F(FunctionTest, IsValid_NullBodyStatement) {
type::Void void_type; type::Void void_type;
type::I32 i32; type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params; VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32, params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr, false, nullptr,
@ -277,7 +319,7 @@ TEST_F(FunctionTest, IsValid_NullBodyStatement) {
body->append(create<DiscardStatement>()); body->append(create<DiscardStatement>());
body->append(nullptr); body->append(nullptr);
Function f(Source{}, "func", params, &void_type, body, Function f(Source{}, func_sym, "func", params, &void_type, body,
FunctionDecorationList{}); FunctionDecorationList{});
EXPECT_FALSE(f.IsValid()); EXPECT_FALSE(f.IsValid());
@ -287,6 +329,9 @@ TEST_F(FunctionTest, IsValid_InvalidBodyStatement) {
type::Void void_type; type::Void void_type;
type::I32 i32; type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params; VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32, params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr, false, nullptr,
@ -296,7 +341,7 @@ TEST_F(FunctionTest, IsValid_InvalidBodyStatement) {
body->append(create<DiscardStatement>()); body->append(create<DiscardStatement>());
body->append(nullptr); body->append(nullptr);
Function f(Source{}, "func", params, &void_type, body, Function f(Source{}, func_sym, "func", params, &void_type, body,
FunctionDecorationList{}); FunctionDecorationList{});
EXPECT_FALSE(f.IsValid()); EXPECT_FALSE(f.IsValid());
} }
@ -305,14 +350,18 @@ TEST_F(FunctionTest, ToStr) {
type::Void void_type; type::Void void_type;
type::I32 i32; type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
auto* body = create<BlockStatement>(); auto* body = create<BlockStatement>();
body->append(create<DiscardStatement>()); body->append(create<DiscardStatement>());
Function f(Source{}, "func", {}, &void_type, body, FunctionDecorationList{}); Function f(Source{}, func_sym, "func", {}, &void_type, body,
FunctionDecorationList{});
std::ostringstream out; std::ostringstream out;
f.to_str(out, 2); f.to_str(out, 2);
EXPECT_EQ(out.str(), R"( Function func -> __void EXPECT_EQ(out.str(), R"( Function tint_symbol_1 -> __void
() ()
{ {
Discard{} Discard{}
@ -324,16 +373,19 @@ TEST_F(FunctionTest, ToStr_WithDecoration) {
type::Void void_type; type::Void void_type;
type::I32 i32; type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
auto* body = create<BlockStatement>(); auto* body = create<BlockStatement>();
body->append(create<DiscardStatement>()); body->append(create<DiscardStatement>());
Function f( Function f(
Source{}, "func", {}, &void_type, body, Source{}, func_sym, "func", {}, &void_type, body,
FunctionDecorationList{create<WorkgroupDecoration>(2, 4, 6, Source{})}); FunctionDecorationList{create<WorkgroupDecoration>(2, 4, 6, Source{})});
std::ostringstream out; std::ostringstream out;
f.to_str(out, 2); f.to_str(out, 2);
EXPECT_EQ(out.str(), R"( Function func -> __void EXPECT_EQ(out.str(), R"( Function tint_symbol_1 -> __void
WorkgroupDecoration{2 4 6} WorkgroupDecoration{2 4 6}
() ()
{ {
@ -346,6 +398,9 @@ TEST_F(FunctionTest, ToStr_WithParams) {
type::Void void_type; type::Void void_type;
type::I32 i32; type::I32 i32;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params; VariableList params;
params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32, params.push_back(create<Variable>(Source{}, "var", StorageClass::kNone, &i32,
false, nullptr, false, nullptr,
@ -354,12 +409,12 @@ TEST_F(FunctionTest, ToStr_WithParams) {
auto* body = create<BlockStatement>(); auto* body = create<BlockStatement>();
body->append(create<DiscardStatement>()); body->append(create<DiscardStatement>());
Function f(Source{}, "func", params, &void_type, body, Function f(Source{}, func_sym, "func", params, &void_type, body,
FunctionDecorationList{}); FunctionDecorationList{});
std::ostringstream out; std::ostringstream out;
f.to_str(out, 2); f.to_str(out, 2);
EXPECT_EQ(out.str(), R"( Function func -> __void EXPECT_EQ(out.str(), R"( Function tint_symbol_1 -> __void
( (
Variable{ Variable{
var var
@ -376,8 +431,11 @@ TEST_F(FunctionTest, ToStr_WithParams) {
TEST_F(FunctionTest, TypeName) { TEST_F(FunctionTest, TypeName) {
type::Void void_type; type::Void void_type;
Function f(Source{}, "func", {}, &void_type, create<BlockStatement>(), Module m;
FunctionDecorationList{}); auto func_sym = m.RegisterSymbol("func");
Function f(Source{}, func_sym, "func", {}, &void_type,
create<BlockStatement>(), FunctionDecorationList{});
EXPECT_EQ(f.type_name(), "__func__void"); EXPECT_EQ(f.type_name(), "__func__void");
} }
@ -386,6 +444,9 @@ TEST_F(FunctionTest, TypeName_WithParams) {
type::I32 i32; type::I32 i32;
type::F32 f32; type::F32 f32;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params; VariableList params;
params.push_back(create<Variable>(Source{}, "var1", StorageClass::kNone, &i32, params.push_back(create<Variable>(Source{}, "var1", StorageClass::kNone, &i32,
false, nullptr, false, nullptr,
@ -394,19 +455,22 @@ TEST_F(FunctionTest, TypeName_WithParams) {
false, nullptr, false, nullptr,
ast::VariableDecorationList{})); ast::VariableDecorationList{}));
Function f(Source{}, "func", params, &void_type, create<BlockStatement>(), Function f(Source{}, func_sym, "func", params, &void_type,
FunctionDecorationList{}); create<BlockStatement>(), FunctionDecorationList{});
EXPECT_EQ(f.type_name(), "__func__void__i32__f32"); EXPECT_EQ(f.type_name(), "__func__void__i32__f32");
} }
TEST_F(FunctionTest, GetLastStatement) { TEST_F(FunctionTest, GetLastStatement) {
type::Void void_type; type::Void void_type;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params; VariableList params;
auto* body = create<BlockStatement>(); auto* body = create<BlockStatement>();
auto* stmt = create<DiscardStatement>(); auto* stmt = create<DiscardStatement>();
body->append(stmt); body->append(stmt);
Function f(Source{}, "func", params, &void_type, body, Function f(Source{}, func_sym, "func", params, &void_type, body,
FunctionDecorationList{}); FunctionDecorationList{});
EXPECT_EQ(f.get_last_statement(), stmt); EXPECT_EQ(f.get_last_statement(), stmt);
@ -415,9 +479,12 @@ TEST_F(FunctionTest, GetLastStatement) {
TEST_F(FunctionTest, GetLastStatement_nullptr) { TEST_F(FunctionTest, GetLastStatement_nullptr) {
type::Void void_type; type::Void void_type;
Module m;
auto func_sym = m.RegisterSymbol("func");
VariableList params; VariableList params;
auto* body = create<BlockStatement>(); auto* body = create<BlockStatement>();
Function f(Source{}, "func", params, &void_type, body, Function f(Source{}, func_sym, "func", params, &void_type, body,
FunctionDecorationList{}); FunctionDecorationList{});
EXPECT_EQ(f.get_last_statement(), nullptr); EXPECT_EQ(f.get_last_statement(), nullptr);
@ -425,8 +492,12 @@ TEST_F(FunctionTest, GetLastStatement_nullptr) {
TEST_F(FunctionTest, WorkgroupSize_NoneSet) { TEST_F(FunctionTest, WorkgroupSize_NoneSet) {
type::Void void_type; type::Void void_type;
Function f(Source{}, "f", {}, &void_type, create<BlockStatement>(),
FunctionDecorationList{}); Module m;
auto func_sym = m.RegisterSymbol("func");
Function f(Source{}, func_sym, "func", {}, &void_type,
create<BlockStatement>(), FunctionDecorationList{});
uint32_t x = 0; uint32_t x = 0;
uint32_t y = 0; uint32_t y = 0;
uint32_t z = 0; uint32_t z = 0;
@ -438,7 +509,12 @@ TEST_F(FunctionTest, WorkgroupSize_NoneSet) {
TEST_F(FunctionTest, WorkgroupSize) { TEST_F(FunctionTest, WorkgroupSize) {
type::Void void_type; type::Void void_type;
Function f(Source{}, "f", {}, &void_type, create<BlockStatement>(),
Module m;
auto func_sym = m.RegisterSymbol("func");
Function f(Source{}, func_sym, "func", {}, &void_type,
create<BlockStatement>(),
{create<WorkgroupDecoration>(2u, 4u, 6u, Source{})}); {create<WorkgroupDecoration>(2u, 4u, 6u, Source{})});
uint32_t x = 0; uint32_t x = 0;

View File

@ -47,21 +47,23 @@ void Module::Clone(CloneContext* ctx) {
for (auto* func : functions_) { for (auto* func : functions_) {
ctx->mod->functions_.emplace_back(ctx->Clone(func)); ctx->mod->functions_.emplace_back(ctx->Clone(func));
} }
ctx->mod->symbol_table_ = symbol_table_;
} }
Function* Module::FindFunctionByName(const std::string& name) const { Function* Module::FindFunctionBySymbol(Symbol sym) const {
for (auto* func : functions_) { for (auto* func : functions_) {
if (func->name() == name) { if (func->symbol() == sym) {
return func; return func;
} }
} }
return nullptr; return nullptr;
} }
Function* Module::FindFunctionByNameAndStage(const std::string& name, Function* Module::FindFunctionBySymbolAndStage(Symbol sym,
PipelineStage stage) const { PipelineStage stage) const {
for (auto* func : functions_) { for (auto* func : functions_) {
if (func->name() == name && func->pipeline_stage() == stage) { if (func->symbol() == sym && func->pipeline_stage() == stage) {
return func; return func;
} }
} }
@ -81,6 +83,10 @@ Symbol Module::RegisterSymbol(const std::string& name) {
return symbol_table_.Register(name); return symbol_table_.Register(name);
} }
Symbol Module::GetSymbol(const std::string& name) const {
return symbol_table_.GetSymbol(name);
}
std::string Module::SymbolToName(const Symbol sym) const { std::string Module::SymbolToName(const Symbol sym) const {
return symbol_table_.NameFor(sym); return symbol_table_.NameFor(sym);
} }

View File

@ -89,15 +89,14 @@ class Module {
/// @returns the modules functions /// @returns the modules functions
const FunctionList& functions() const { return functions_; } const FunctionList& functions() const { return functions_; }
/// Returns the function with the given name /// Returns the function with the given name
/// @param name the name to search for /// @param sym the function symbol to search for
/// @returns the associated function or nullptr if none exists /// @returns the associated function or nullptr if none exists
Function* FindFunctionByName(const std::string& name) const; Function* FindFunctionBySymbol(Symbol sym) const;
/// Returns the function with the given name /// Returns the function with the given name
/// @param name the name to search for /// @param sym the function symbol to search for
/// @param stage the pipeline stage /// @param stage the pipeline stage
/// @returns the associated function or nullptr if none exists /// @returns the associated function or nullptr if none exists
Function* FindFunctionByNameAndStage(const std::string& name, Function* FindFunctionBySymbolAndStage(Symbol sym, PipelineStage stage) const;
PipelineStage stage) const;
/// @param stage the pipeline stage /// @param stage the pipeline stage
/// @returns true if the module contains an entrypoint function with the given /// @returns true if the module contains an entrypoint function with the given
/// stage /// stage
@ -169,6 +168,11 @@ class Module {
/// previously generated symbol will be returned. /// previously generated symbol will be returned.
Symbol RegisterSymbol(const std::string& name); Symbol RegisterSymbol(const std::string& name);
/// Returns the symbol for `name`
/// @param name the name to lookup
/// @returns the symbol for name or symbol::kInvalid
Symbol GetSymbol(const std::string& name) const;
/// Returns the `name` for `sym` /// Returns the `name` for `sym`
/// @param sym the symbol to retrieve the name for /// @param sym the symbol to retrieve the name for
/// @returns the use provided `name` for the symbol or "" if not found /// @returns the use provided `name` for the symbol or "" if not found

View File

@ -48,16 +48,17 @@ TEST_F(ModuleTest, LookupFunction) {
type::F32 f32; type::F32 f32;
Module m; Module m;
auto func_sym = m.RegisterSymbol("main");
auto* func = auto* func =
create<Function>(Source{}, "main", VariableList{}, &f32, create<Function>(Source{}, func_sym, "main", VariableList{}, &f32,
create<BlockStatement>(), ast::FunctionDecorationList{}); create<BlockStatement>(), ast::FunctionDecorationList{});
m.AddFunction(func); m.AddFunction(func);
EXPECT_EQ(func, m.FindFunctionByName("main")); EXPECT_EQ(func, m.FindFunctionBySymbol(func_sym));
} }
TEST_F(ModuleTest, LookupFunctionMissing) { TEST_F(ModuleTest, LookupFunctionMissing) {
Module m; Module m;
EXPECT_EQ(nullptr, m.FindFunctionByName("Missing")); EXPECT_EQ(nullptr, m.FindFunctionBySymbol(m.RegisterSymbol("Missing")));
} }
TEST_F(ModuleTest, IsValid_Empty) { TEST_F(ModuleTest, IsValid_Empty) {
@ -127,11 +128,12 @@ TEST_F(ModuleTest, IsValid_Struct_EmptyName) {
TEST_F(ModuleTest, IsValid_Function) { TEST_F(ModuleTest, IsValid_Function) {
type::F32 f32; type::F32 f32;
auto* func =
create<Function>(Source{}, "main", VariableList(), &f32,
create<BlockStatement>(), ast::FunctionDecorationList{});
Module m; Module m;
auto* func = create<Function>(Source{}, m.RegisterSymbol("main"), "main",
VariableList(), &f32, create<BlockStatement>(),
ast::FunctionDecorationList{});
m.AddFunction(func); m.AddFunction(func);
EXPECT_TRUE(m.IsValid()); EXPECT_TRUE(m.IsValid());
} }
@ -144,10 +146,13 @@ TEST_F(ModuleTest, IsValid_Null_Function) {
TEST_F(ModuleTest, IsValid_Invalid_Function) { TEST_F(ModuleTest, IsValid_Invalid_Function) {
VariableList p; VariableList p;
auto* func = create<Function>(Source{}, "", p, nullptr, nullptr,
ast::FunctionDecorationList{});
Module m; Module m;
auto* func =
create<Function>(Source{}, m.RegisterSymbol("main"), "main", p, nullptr,
nullptr, ast::FunctionDecorationList{});
m.AddFunction(func); m.AddFunction(func);
EXPECT_FALSE(m.IsValid()); EXPECT_FALSE(m.IsValid());
} }

View File

@ -267,7 +267,7 @@ std::vector<ResourceBinding> Inspector::GetMultisampledTextureResourceBindings(
} }
ast::Function* Inspector::FindEntryPointByName(const std::string& name) { ast::Function* Inspector::FindEntryPointByName(const std::string& name) {
auto* func = module_.FindFunctionByName(name); auto* func = module_.FindFunctionBySymbol(module_.GetSymbol(name));
if (!func) { if (!func) {
error_ += name + " was not found!"; error_ += name + " was not found!";
return nullptr; return nullptr;

View File

@ -83,8 +83,9 @@ class InspectorHelper {
ast::FunctionDecorationList decorations = {}) { ast::FunctionDecorationList decorations = {}) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
return create<ast::Function>(Source{}, name, ast::VariableList(), return create<ast::Function>(Source{}, mod()->RegisterSymbol(name), name,
void_type(), body, decorations); ast::VariableList(), void_type(), body,
decorations);
} }
/// Generates a function that calls another /// Generates a function that calls another
@ -102,8 +103,9 @@ class InspectorHelper {
create<ast::CallExpression>(ident_expr, ast::ExpressionList()); create<ast::CallExpression>(ident_expr, ast::ExpressionList());
body->append(create<ast::CallStatement>(call_expr)); body->append(create<ast::CallStatement>(call_expr));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
return create<ast::Function>(Source{}, caller, ast::VariableList(), return create<ast::Function>(Source{}, mod()->RegisterSymbol(caller),
void_type(), body, decorations); caller, ast::VariableList(), void_type(), body,
decorations);
} }
/// Add In/Out variables to the global variables /// Add In/Out variables to the global variables
@ -154,8 +156,9 @@ class InspectorHelper {
create<ast::IdentifierExpression>(in))); create<ast::IdentifierExpression>(in)));
} }
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
return create<ast::Function>(Source{}, name, ast::VariableList(), return create<ast::Function>(Source{}, mod()->RegisterSymbol(name), name,
void_type(), body, decorations); ast::VariableList(), void_type(), body,
decorations);
} }
/// Generates a function that references in/out variables and calls another /// Generates a function that references in/out variables and calls another
@ -184,8 +187,9 @@ class InspectorHelper {
create<ast::CallExpression>(ident_expr, ast::ExpressionList()); create<ast::CallExpression>(ident_expr, ast::ExpressionList());
body->append(create<ast::CallStatement>(call_expr)); body->append(create<ast::CallStatement>(call_expr));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
return create<ast::Function>(Source{}, caller, ast::VariableList(), return create<ast::Function>(Source{}, mod()->RegisterSymbol(caller),
void_type(), body, decorations); caller, ast::VariableList(), void_type(), body,
decorations);
} }
/// Add a Constant ID to the global variables. /// Add a Constant ID to the global variables.
@ -445,9 +449,9 @@ class InspectorHelper {
} }
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
return create<ast::Function>(Source{}, func_name, ast::VariableList(), return create<ast::Function>(Source{}, mod()->RegisterSymbol(func_name),
void_type(), body, func_name, ast::VariableList(), void_type(),
ast::FunctionDecorationList{}); body, ast::FunctionDecorationList{});
} }
/// Adds a regular sampler variable to the module /// Adds a regular sampler variable to the module
@ -587,8 +591,9 @@ class InspectorHelper {
create<ast::IdentifierExpression>("sampler_result"), call_expr)); create<ast::IdentifierExpression>("sampler_result"), call_expr));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
return create<ast::Function>(Source{}, func_name, ast::VariableList(), return create<ast::Function>(Source{}, mod()->RegisterSymbol(func_name),
void_type(), body, decorations); func_name, ast::VariableList(), void_type(),
body, decorations);
} }
/// Generates a function that references a specific sampler variable /// Generates a function that references a specific sampler variable
@ -634,8 +639,9 @@ class InspectorHelper {
create<ast::IdentifierExpression>("sampler_result"), call_expr)); create<ast::IdentifierExpression>("sampler_result"), call_expr));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
return create<ast::Function>(Source{}, func_name, ast::VariableList(), return create<ast::Function>(Source{}, mod()->RegisterSymbol(func_name),
void_type(), body, decorations); func_name, ast::VariableList(), void_type(),
body, decorations);
} }
/// Generates a function that references a specific comparison sampler /// Generates a function that references a specific comparison sampler
@ -682,8 +688,9 @@ class InspectorHelper {
create<ast::IdentifierExpression>("sampler_result"), call_expr)); create<ast::IdentifierExpression>("sampler_result"), call_expr));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
return create<ast::Function>(Source{}, func_name, ast::VariableList(), return create<ast::Function>(Source{}, mod()->RegisterSymbol(func_name),
void_type(), body, decorations); func_name, ast::VariableList(), void_type(),
body, decorations);
} }
/// Gets an appropriate type for the data in a given texture type. /// Gets an appropriate type for the data in a given texture type.
@ -1513,7 +1520,8 @@ TEST_F(InspectorGetUniformBufferResourceBindingsTest, MultipleUniformBuffers) {
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
ast::Function* func = create<ast::Function>( ast::Function* func = create<ast::Function>(
Source{}, "ep_func", ast::VariableList(), void_type(), body, Source{}, mod()->RegisterSymbol("ep_func"), "ep_func",
ast::VariableList(), void_type(), body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });
@ -1659,7 +1667,8 @@ TEST_F(InspectorGetStorageBufferResourceBindingsTest, MultipleStorageBuffers) {
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
ast::Function* func = create<ast::Function>( ast::Function* func = create<ast::Function>(
Source{}, "ep_func", ast::VariableList(), void_type(), body, Source{}, mod()->RegisterSymbol("ep_func"), "ep_func",
ast::VariableList(), void_type(), body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });
@ -1832,7 +1841,8 @@ TEST_F(InspectorGetReadOnlyStorageBufferResourceBindingsTest,
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
ast::Function* func = create<ast::Function>( ast::Function* func = create<ast::Function>(
Source{}, "ep_func", ast::VariableList(), void_type(), body, Source{}, mod()->RegisterSymbol("ep_func"), "ep_func",
ast::VariableList(), void_type(), body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });

View File

@ -761,9 +761,10 @@ bool FunctionEmitter::Emit() {
} }
auto* body = statements_stack_[0].statements_; auto* body = statements_stack_[0].statements_;
ast_module_.AddFunction(create<ast::Function>( ast_module_.AddFunction(
decl.source, decl.name, std::move(decl.params), decl.return_type, body, create<ast::Function>(decl.source, ast_module_.RegisterSymbol(decl.name),
std::move(decl.decorations))); decl.name, std::move(decl.params), decl.return_type,
body, std::move(decl.decorations)));
// Maintain the invariant by repopulating the one and only element. // Maintain the invariant by repopulating the one and only element.
statements_stack_.clear(); statements_stack_.clear();

View File

@ -46,14 +46,16 @@ TEST_F(SpvParserTest, EmitStatement_VoidCallNoParams) {
OpFunctionEnd OpFunctionEnd
)")); )"));
ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error(); ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error();
const auto module_ast_str = p->module().to_str(); const auto module_ast_str = p->get_module().to_str();
EXPECT_THAT(module_ast_str, Eq(R"(Module{ EXPECT_THAT(module_ast_str, Eq(R"(Module{
Function x_50 -> __void Function )" + p->get_module().GetSymbol("x_50").to_str() +
R"( -> __void
() ()
{ {
Return{} Return{}
} }
Function x_100 -> __void Function )" + p->get_module().GetSymbol("x_100").to_str() +
R"( -> __void
() ()
{ {
Call[not set]{ Call[not set]{
@ -214,9 +216,10 @@ TEST_F(SpvParserTest, EmitStatement_CallWithParams) {
)")); )"));
ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error(); ASSERT_TRUE(p->BuildAndParseInternalModule()) << p->error();
EXPECT_TRUE(p->error().empty()); EXPECT_TRUE(p->error().empty());
const auto module_ast_str = p->module().to_str(); const auto module_ast_str = p->get_module().to_str();
EXPECT_THAT(module_ast_str, HasSubstr(R"(Module{ EXPECT_THAT(module_ast_str, HasSubstr(R"(Module{
Function x_50 -> __u32 Function )" + p->get_module().GetSymbol("x_50").to_str() +
R"( -> __u32
( (
VariableConst{ VariableConst{
x_51 x_51
@ -240,7 +243,8 @@ TEST_F(SpvParserTest, EmitStatement_CallWithParams) {
} }
} }
} }
Function x_100 -> __void Function )" + p->get_module().GetSymbol("x_100").to_str() +
R"( -> __void
() ()
{ {
VariableDeclStatement{ VariableDeclStatement{

View File

@ -59,9 +59,10 @@ TEST_F(SpvParserTest, Emit_VoidFunctionWithoutParams) {
ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions()); ASSERT_TRUE(p->BuildAndParseInternalModuleExceptFunctions());
FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
EXPECT_TRUE(fe.Emit()); EXPECT_TRUE(fe.Emit());
auto got = p->module().to_str(); auto got = p->get_module().to_str();
auto* expect = R"(Module{ auto expect = R"(Module{
Function x_100 -> __void Function )" + p->get_module().GetSymbol("x_100").to_str() +
R"( -> __void
() ()
{ {
Return{} Return{}
@ -83,9 +84,10 @@ TEST_F(SpvParserTest, Emit_NonVoidResultType) {
FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
EXPECT_TRUE(fe.Emit()); EXPECT_TRUE(fe.Emit());
auto got = p->module().to_str(); auto got = p->get_module().to_str();
auto* expect = R"(Module{ auto expect = R"(Module{
Function x_100 -> __f32 Function )" + p->get_module().GetSymbol("x_100").to_str() +
R"( -> __f32
() ()
{ {
Return{ Return{
@ -115,9 +117,10 @@ TEST_F(SpvParserTest, Emit_MixedParamTypes) {
FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
EXPECT_TRUE(fe.Emit()); EXPECT_TRUE(fe.Emit());
auto got = p->module().to_str(); auto got = p->get_module().to_str();
auto* expect = R"(Module{ auto expect = R"(Module{
Function x_100 -> __void Function )" + p->get_module().GetSymbol("x_100").to_str() +
R"( -> __void
( (
VariableConst{ VariableConst{
a a
@ -159,9 +162,10 @@ TEST_F(SpvParserTest, Emit_GenerateParamNames) {
FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100)); FunctionEmitter fe(p.get(), *spirv_function(p.get(), 100));
EXPECT_TRUE(fe.Emit()); EXPECT_TRUE(fe.Emit());
auto got = p->module().to_str(); auto got = p->get_module().to_str();
auto* expect = R"(Module{ auto expect = R"(Module{
Function x_100 -> __void Function )" + p->get_module().GetSymbol("x_100").to_str() +
R"( -> __void
( (
VariableConst{ VariableConst{
x_14 x_14

View File

@ -53,7 +53,7 @@ TEST_F(SpvParserTest, EmitFunctions_NoFunctions) {
auto p = parser(test::Assemble(CommonTypes())); auto p = parser(test::Assemble(CommonTypes()));
EXPECT_TRUE(p->BuildAndParseInternalModule()); EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty()); EXPECT_TRUE(p->error().empty());
const auto module_ast = p->module().to_str(); const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, Not(HasSubstr("Function{"))); EXPECT_THAT(module_ast, Not(HasSubstr("Function{")));
} }
@ -64,7 +64,7 @@ TEST_F(SpvParserTest, EmitFunctions_FunctionWithoutBody) {
)")); )"));
EXPECT_TRUE(p->BuildAndParseInternalModule()); EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty()); EXPECT_TRUE(p->error().empty());
const auto module_ast = p->module().to_str(); const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, Not(HasSubstr("Function{"))); EXPECT_THAT(module_ast, Not(HasSubstr("Function{")));
} }
@ -79,9 +79,10 @@ OpFunctionEnd)";
auto p = parser(test::Assemble(input)); auto p = parser(test::Assemble(input));
ASSERT_TRUE(p->BuildAndParseInternalModule()); ASSERT_TRUE(p->BuildAndParseInternalModule());
ASSERT_TRUE(p->error().empty()) << p->error(); ASSERT_TRUE(p->error().empty()) << p->error();
const auto module_ast = p->module().to_str(); const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"( EXPECT_THAT(module_ast, HasSubstr(R"(
Function main -> __void Function )" + p->get_module().GetSymbol("main").to_str() +
R"( -> __void
StageDecoration{vertex} StageDecoration{vertex}
() ()
{)")); {)"));
@ -98,9 +99,10 @@ OpFunctionEnd)";
auto p = parser(test::Assemble(input)); auto p = parser(test::Assemble(input));
ASSERT_TRUE(p->BuildAndParseInternalModule()); ASSERT_TRUE(p->BuildAndParseInternalModule());
ASSERT_TRUE(p->error().empty()) << p->error(); ASSERT_TRUE(p->error().empty()) << p->error();
const auto module_ast = p->module().to_str(); const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"( EXPECT_THAT(module_ast, HasSubstr(R"(
Function main -> __void Function )" + p->get_module().GetSymbol("main").to_str() +
R"( -> __void
StageDecoration{fragment} StageDecoration{fragment}
() ()
{)")); {)"));
@ -117,9 +119,10 @@ OpFunctionEnd)";
auto p = parser(test::Assemble(input)); auto p = parser(test::Assemble(input));
ASSERT_TRUE(p->BuildAndParseInternalModule()); ASSERT_TRUE(p->BuildAndParseInternalModule());
ASSERT_TRUE(p->error().empty()) << p->error(); ASSERT_TRUE(p->error().empty()) << p->error();
const auto module_ast = p->module().to_str(); const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"( EXPECT_THAT(module_ast, HasSubstr(R"(
Function main -> __void Function )" + p->get_module().GetSymbol("main").to_str() +
R"( -> __void
StageDecoration{compute} StageDecoration{compute}
() ()
{)")); {)"));
@ -138,14 +141,16 @@ OpFunctionEnd)";
auto p = parser(test::Assemble(input)); auto p = parser(test::Assemble(input));
ASSERT_TRUE(p->BuildAndParseInternalModule()); ASSERT_TRUE(p->BuildAndParseInternalModule());
ASSERT_TRUE(p->error().empty()) << p->error(); ASSERT_TRUE(p->error().empty()) << p->error();
const auto module_ast = p->module().to_str(); const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"( EXPECT_THAT(module_ast, HasSubstr(R"(
Function frag_main -> __void Function )" + p->get_module().GetSymbol("frag_main").to_str() +
R"( -> __void
StageDecoration{fragment} StageDecoration{fragment}
() ()
{)")); {)"));
EXPECT_THAT(module_ast, HasSubstr(R"( EXPECT_THAT(module_ast, HasSubstr(R"(
Function comp_main -> __void Function )" + p->get_module().GetSymbol("comp_main").to_str() +
R"( -> __void
StageDecoration{compute} StageDecoration{compute}
() ()
{)")); {)"));
@ -160,9 +165,10 @@ TEST_F(SpvParserTest, EmitFunctions_VoidFunctionWithoutParams) {
)")); )"));
EXPECT_TRUE(p->BuildAndParseInternalModule()); EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty()); EXPECT_TRUE(p->error().empty());
const auto module_ast = p->module().to_str(); const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"( EXPECT_THAT(module_ast, HasSubstr(R"(
Function main -> __void Function )" + p->get_module().GetSymbol("main").to_str() +
R"( -> __void
() ()
{)")); {)"));
} }
@ -193,9 +199,10 @@ TEST_F(SpvParserTest, EmitFunctions_CalleePrecedesCaller) {
)")); )"));
EXPECT_TRUE(p->BuildAndParseInternalModule()); EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty()); EXPECT_TRUE(p->error().empty());
const auto module_ast = p->module().to_str(); const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"( EXPECT_THAT(module_ast, HasSubstr(R"(
Function leaf -> __u32 Function )" + p->get_module().GetSymbol("leaf").to_str() +
R"( -> __u32
() ()
{ {
Return{ Return{
@ -204,7 +211,8 @@ TEST_F(SpvParserTest, EmitFunctions_CalleePrecedesCaller) {
} }
} }
} }
Function branch -> __u32 Function )" + p->get_module().GetSymbol("branch").to_str() +
R"( -> __u32
() ()
{ {
VariableDeclStatement{ VariableDeclStatement{
@ -227,7 +235,8 @@ TEST_F(SpvParserTest, EmitFunctions_CalleePrecedesCaller) {
} }
} }
} }
Function root -> __void Function )" + p->get_module().GetSymbol("root").to_str() +
R"( -> __void
() ()
{ {
VariableDeclStatement{ VariableDeclStatement{
@ -260,9 +269,10 @@ TEST_F(SpvParserTest, EmitFunctions_NonVoidResultType) {
)")); )"));
EXPECT_TRUE(p->BuildAndParseInternalModule()); EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty()); EXPECT_TRUE(p->error().empty());
const auto module_ast = p->module().to_str(); const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"( EXPECT_THAT(module_ast, HasSubstr(R"(
Function ret_float -> __f32 Function )" + p->get_module().GetSymbol("ret_float").to_str() +
R"( -> __f32
() ()
{ {
Return{ Return{
@ -289,9 +299,10 @@ TEST_F(SpvParserTest, EmitFunctions_MixedParamTypes) {
)")); )"));
EXPECT_TRUE(p->BuildAndParseInternalModule()); EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty()); EXPECT_TRUE(p->error().empty());
const auto module_ast = p->module().to_str(); const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"( EXPECT_THAT(module_ast, HasSubstr(R"(
Function mixed_params -> __void Function )" + p->get_module().GetSymbol("mixed_params").to_str() +
R"( -> __void
( (
VariableConst{ VariableConst{
a a
@ -328,9 +339,10 @@ TEST_F(SpvParserTest, EmitFunctions_GenerateParamNames) {
)")); )"));
EXPECT_TRUE(p->BuildAndParseInternalModule()); EXPECT_TRUE(p->BuildAndParseInternalModule());
EXPECT_TRUE(p->error().empty()); EXPECT_TRUE(p->error().empty());
const auto module_ast = p->module().to_str(); const auto module_ast = p->get_module().to_str();
EXPECT_THAT(module_ast, HasSubstr(R"( EXPECT_THAT(module_ast, HasSubstr(R"(
Function mixed_params -> __void Function )" + p->get_module().GetSymbol("mixed_params").to_str() +
R"( -> __void
( (
VariableConst{ VariableConst{
x_14 x_14

View File

@ -1280,9 +1280,9 @@ Maybe<ast::Function*> ParserImpl::function_decl(ast::DecorationList& decos) {
if (errored) if (errored)
return Failure::kErrored; return Failure::kErrored;
return create<ast::Function>(header->source, header->name, header->params, return create<ast::Function>(
header->return_type, body.value, header->source, module_.RegisterSymbol(header->name), header->name,
func_decos.value); header->params, header->return_type, body.value, func_decos.value);
} }
// function_type_decl // function_type_decl

View File

@ -18,10 +18,14 @@ namespace tint {
SymbolTable::SymbolTable() = default; SymbolTable::SymbolTable() = default;
SymbolTable::SymbolTable(const SymbolTable&) = default;
SymbolTable::SymbolTable(SymbolTable&&) = default; SymbolTable::SymbolTable(SymbolTable&&) = default;
SymbolTable::~SymbolTable() = default; SymbolTable::~SymbolTable() = default;
SymbolTable& SymbolTable::operator=(const SymbolTable& other) = default;
SymbolTable& SymbolTable::operator=(SymbolTable&&) = default; SymbolTable& SymbolTable::operator=(SymbolTable&&) = default;
Symbol SymbolTable::Register(const std::string& name) { Symbol SymbolTable::Register(const std::string& name) {
@ -41,6 +45,11 @@ Symbol SymbolTable::Register(const std::string& name) {
return sym; return sym;
} }
Symbol SymbolTable::GetSymbol(const std::string& name) const {
auto it = name_to_symbol_.find(name);
return it != name_to_symbol_.end() ? it->second : Symbol();
}
std::string SymbolTable::NameFor(const Symbol symbol) const { std::string SymbolTable::NameFor(const Symbol symbol) const {
auto it = symbol_to_name_.find(symbol.value()); auto it = symbol_to_name_.find(symbol.value());
if (it == symbol_to_name_.end()) if (it == symbol_to_name_.end())

View File

@ -27,11 +27,17 @@ class SymbolTable {
public: public:
/// Constructor /// Constructor
SymbolTable(); SymbolTable();
/// Copy constructor
SymbolTable(const SymbolTable&);
/// Move Constructor /// Move Constructor
SymbolTable(SymbolTable&&); SymbolTable(SymbolTable&&);
/// Destructor /// Destructor
~SymbolTable(); ~SymbolTable();
/// Copy assignment
/// @param other the symbol table to copy
/// @returns the new symbol table
SymbolTable& operator=(const SymbolTable& other);
/// Move assignment /// Move assignment
/// @param other the symbol table to move /// @param other the symbol table to move
/// @returns the symbol table /// @returns the symbol table
@ -42,6 +48,11 @@ class SymbolTable {
/// @returns the symbol representing the given name /// @returns the symbol representing the given name
Symbol Register(const std::string& name); Symbol Register(const std::string& name);
/// Returns the symbol for the given `name`
/// @param name the name to lookup
/// @returns the symbol for the name or symbol::kInvalid if not found.
Symbol GetSymbol(const std::string& name) const;
/// Returns the name for the given symbol /// Returns the name for the given symbol
/// @param symbol the symbol to retrieve the name for /// @param symbol the symbol to retrieve the name for
/// @returns the symbol name or "" if not found /// @returns the symbol name or "" if not found

View File

@ -51,7 +51,7 @@ namespace {
template <typename T = ast::Expression> template <typename T = ast::Expression>
T* FindVariable(ast::Module* mod, std::string name) { T* FindVariable(ast::Module* mod, std::string name) {
if (auto* func = mod->FindFunctionByName("func")) { if (auto* func = mod->FindFunctionBySymbol(mod->RegisterSymbol("func"))) {
for (auto* stmt : *func->body()) { for (auto* stmt : *func->body()) {
if (auto* decl = stmt->As<ast::VariableDeclStatement>()) { if (auto* decl = stmt->As<ast::VariableDeclStatement>()) {
if (auto* var = decl->variable()) { if (auto* var = decl->variable()) {
@ -92,9 +92,9 @@ class BoundArrayAccessorsTest : public testing::Test {
struct ModuleBuilder : public ast::BuilderWithModule { struct ModuleBuilder : public ast::BuilderWithModule {
ModuleBuilder() : body_(create<ast::BlockStatement>()) { ModuleBuilder() : body_(create<ast::BlockStatement>()) {
mod->AddFunction(create<ast::Function>(Source{}, "func", mod->AddFunction(create<ast::Function>(
ast::VariableList{}, ty.void_, body_, Source{}, mod->RegisterSymbol("func"), "func", ast::VariableList{},
ast::FunctionDecorationList{})); ty.void_, body_, ast::FunctionDecorationList{}));
} }
ast::Module Module() { ast::Module Module() {

View File

@ -58,22 +58,25 @@ TEST_F(EmitVertexPointSizeTest, VertexStageBasic) {
Var("builtin_assignments_should_happen_before_this", Var("builtin_assignments_should_happen_before_this",
tint::ast::StorageClass::kFunction, ty.f32))); tint::ast::StorageClass::kFunction, ty.f32)));
mod->AddFunction( auto a_sym = mod->RegisterSymbol("non_entry_a");
create<ast::Function>(Source{}, "non_entry_a", ast::VariableList{}, mod->AddFunction(create<ast::Function>(
ty.void_, create<ast::BlockStatement>(Source{}), Source{}, a_sym, "non_entry_a", ast::VariableList{}, ty.void_,
create<ast::BlockStatement>(Source{}),
ast::FunctionDecorationList{})); ast::FunctionDecorationList{}));
auto entry_sym = mod->RegisterSymbol("entry");
auto* entry = create<ast::Function>( auto* entry = create<ast::Function>(
Source{}, "entry", ast::VariableList{}, ty.void_, block, Source{}, entry_sym, "entry", ast::VariableList{}, ty.void_, block,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, create<ast::StageDecoration>(ast::PipelineStage::kVertex,
Source{}), Source{}),
}); });
mod->AddFunction(entry); mod->AddFunction(entry);
mod->AddFunction( auto b_sym = mod->RegisterSymbol("non_entry_b");
create<ast::Function>(Source{}, "non_entry_b", ast::VariableList{}, mod->AddFunction(create<ast::Function>(
ty.void_, create<ast::BlockStatement>(Source{}), Source{}, b_sym, "non_entry_b", ast::VariableList{}, ty.void_,
create<ast::BlockStatement>(Source{}),
ast::FunctionDecorationList{})); ast::FunctionDecorationList{}));
} }
}; };
@ -82,7 +85,7 @@ TEST_F(EmitVertexPointSizeTest, VertexStageBasic) {
ASSERT_FALSE(result.diagnostics.contains_errors()) ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics); << diag::Formatter().format(result.diagnostics);
auto* expected = R"(Module{ auto expected = R"(Module{
Variable{ Variable{
Decorations{ Decorations{
BuiltinDecoration{pointsize} BuiltinDecoration{pointsize}
@ -91,11 +94,13 @@ TEST_F(EmitVertexPointSizeTest, VertexStageBasic) {
out out
__f32 __f32
} }
Function non_entry_a -> __void Function )" + result.module.RegisterSymbol("non_entry_a").to_str() +
R"( -> __void
() ()
{ {
} }
Function entry -> __void Function )" + result.module.RegisterSymbol("entry").to_str() +
R"( -> __void
StageDecoration{vertex} StageDecoration{vertex}
() ()
{ {
@ -111,7 +116,8 @@ TEST_F(EmitVertexPointSizeTest, VertexStageBasic) {
} }
} }
} }
Function non_entry_b -> __void Function )" + result.module.RegisterSymbol("non_entry_b").to_str() +
R"( -> __void
() ()
{ {
} }
@ -123,22 +129,25 @@ TEST_F(EmitVertexPointSizeTest, VertexStageBasic) {
TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) { TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) {
struct Builder : ModuleBuilder { struct Builder : ModuleBuilder {
void Build() override { void Build() override {
mod->AddFunction( auto a_sym = mod->RegisterSymbol("non_entry_a");
create<ast::Function>(Source{}, "non_entry_a", ast::VariableList{}, mod->AddFunction(create<ast::Function>(
ty.void_, create<ast::BlockStatement>(Source{}), Source{}, a_sym, "non_entry_a", ast::VariableList{}, ty.void_,
create<ast::BlockStatement>(Source{}),
ast::FunctionDecorationList{})); ast::FunctionDecorationList{}));
mod->AddFunction( auto entry_sym = mod->RegisterSymbol("entry");
create<ast::Function>(Source{}, "entry", ast::VariableList{}, mod->AddFunction(create<ast::Function>(
ty.void_, create<ast::BlockStatement>(Source{}), Source{}, entry_sym, "entry", ast::VariableList{}, ty.void_,
create<ast::BlockStatement>(Source{}),
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>( create<ast::StageDecoration>(ast::PipelineStage::kVertex,
ast::PipelineStage::kVertex, Source{}), Source{}),
})); }));
mod->AddFunction( auto b_sym = mod->RegisterSymbol("non_entry_b");
create<ast::Function>(Source{}, "non_entry_b", ast::VariableList{}, mod->AddFunction(create<ast::Function>(
ty.void_, create<ast::BlockStatement>(Source{}), Source{}, b_sym, "non_entry_b", ast::VariableList{}, ty.void_,
create<ast::BlockStatement>(Source{}),
ast::FunctionDecorationList{})); ast::FunctionDecorationList{}));
} }
}; };
@ -147,7 +156,7 @@ TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) {
ASSERT_FALSE(result.diagnostics.contains_errors()) ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics); << diag::Formatter().format(result.diagnostics);
auto* expected = R"(Module{ auto expected = R"(Module{
Variable{ Variable{
Decorations{ Decorations{
BuiltinDecoration{pointsize} BuiltinDecoration{pointsize}
@ -156,11 +165,13 @@ TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) {
out out
__f32 __f32
} }
Function non_entry_a -> __void Function )" + result.module.RegisterSymbol("non_entry_a").to_str() +
R"( -> __void
() ()
{ {
} }
Function entry -> __void Function )" + result.module.RegisterSymbol("entry").to_str() +
R"( -> __void
StageDecoration{vertex} StageDecoration{vertex}
() ()
{ {
@ -169,7 +180,8 @@ TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) {
ScalarConstructor[__f32]{1.000000} ScalarConstructor[__f32]{1.000000}
} }
} }
Function non_entry_b -> __void Function )" + result.module.RegisterSymbol("non_entry_b").to_str() +
R"( -> __void
() ()
{ {
} }
@ -181,8 +193,9 @@ TEST_F(EmitVertexPointSizeTest, VertexStageEmpty) {
TEST_F(EmitVertexPointSizeTest, NonVertexStage) { TEST_F(EmitVertexPointSizeTest, NonVertexStage) {
struct Builder : ModuleBuilder { struct Builder : ModuleBuilder {
void Build() override { void Build() override {
auto frag_sym = mod->RegisterSymbol("fragment_entry");
auto* fragment_entry = create<ast::Function>( auto* fragment_entry = create<ast::Function>(
Source{}, "fragment_entry", ast::VariableList{}, ty.void_, Source{}, frag_sym, "fragment_entry", ast::VariableList{}, ty.void_,
create<ast::BlockStatement>(Source{}), create<ast::BlockStatement>(Source{}),
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, create<ast::StageDecoration>(ast::PipelineStage::kFragment,
@ -190,12 +203,13 @@ TEST_F(EmitVertexPointSizeTest, NonVertexStage) {
}); });
mod->AddFunction(fragment_entry); mod->AddFunction(fragment_entry);
auto* compute_entry = auto comp_sym = mod->RegisterSymbol("compute_entry");
create<ast::Function>(Source{}, "compute_entry", ast::VariableList{}, auto* compute_entry = create<ast::Function>(
ty.void_, create<ast::BlockStatement>(Source{}), Source{}, comp_sym, "compute_entry", ast::VariableList{}, ty.void_,
create<ast::BlockStatement>(Source{}),
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>( create<ast::StageDecoration>(ast::PipelineStage::kCompute,
ast::PipelineStage::kCompute, Source{}), Source{}),
}); });
mod->AddFunction(compute_entry); mod->AddFunction(compute_entry);
} }
@ -205,13 +219,15 @@ TEST_F(EmitVertexPointSizeTest, NonVertexStage) {
ASSERT_FALSE(result.diagnostics.contains_errors()) ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics); << diag::Formatter().format(result.diagnostics);
auto* expected = R"(Module{ auto expected = R"(Module{
Function fragment_entry -> __void Function )" + result.module.RegisterSymbol("fragment_entry").to_str() +
R"( -> __void
StageDecoration{fragment} StageDecoration{fragment}
() ()
{ {
} }
Function compute_entry -> __void Function )" + result.module.RegisterSymbol("compute_entry").to_str() +
R"( -> __void
StageDecoration{compute} StageDecoration{compute}
() ()
{ {

View File

@ -169,9 +169,9 @@ Transform::Output FirstIndexOffset::Run(ast::Module* in) {
body->append(ctx.Clone(s)); body->append(ctx.Clone(s));
} }
return ctx.mod->create<ast::Function>( return ctx.mod->create<ast::Function>(
ctx.Clone(func->source()), func->name(), ctx.Clone(func->params()), ctx.Clone(func->source()), func->symbol(), func->name(),
ctx.Clone(func->return_type()), ctx.Clone(body), ctx.Clone(func->params()), ctx.Clone(func->return_type()),
ctx.Clone(func->decorations())); ctx.Clone(body), ctx.Clone(func->decorations()));
}); });
in->Clone(&ctx); in->Clone(&ctx);

View File

@ -58,9 +58,9 @@ struct ModuleBuilder : public ast::BuilderWithModule {
ast::Function* AddFunction(const std::string& name, ast::Function* AddFunction(const std::string& name,
ast::VariableList params = {}) { ast::VariableList params = {}) {
auto* func = create<ast::Function>(Source{}, name, std::move(params), auto* func = create<ast::Function>(
ty.u32, create<ast::BlockStatement>(), Source{}, mod->RegisterSymbol(name), name, std::move(params), ty.u32,
ast::FunctionDecorationList()); create<ast::BlockStatement>(), ast::FunctionDecorationList());
mod->AddFunction(func); mod->AddFunction(func);
return func; return func;
} }
@ -154,7 +154,7 @@ TEST_F(FirstIndexOffsetTest, BasicModuleVertexIndex) {
uniform uniform
__struct_TintFirstIndexOffsetData __struct_TintFirstIndexOffsetData
} }
Function test -> __u32 Function tint_symbol_1 -> __u32
() ()
{ {
VariableDeclStatement{ VariableDeclStatement{
@ -229,7 +229,7 @@ TEST_F(FirstIndexOffsetTest, BasicModuleInstanceIndex) {
uniform uniform
__struct_TintFirstIndexOffsetData __struct_TintFirstIndexOffsetData
} }
Function test -> __u32 Function tint_symbol_1 -> __u32
() ()
{ {
VariableDeclStatement{ VariableDeclStatement{
@ -317,7 +317,7 @@ TEST_F(FirstIndexOffsetTest, BasicModuleBothIndex) {
uniform uniform
__struct_TintFirstIndexOffsetData __struct_TintFirstIndexOffsetData
} }
Function test -> __u32 Function tint_symbol_1 -> __u32
() ()
{ {
Return{ Return{
@ -389,7 +389,7 @@ TEST_F(FirstIndexOffsetTest, NestedCalls) {
uniform uniform
__struct_TintFirstIndexOffsetData __struct_TintFirstIndexOffsetData
} }
Function func1 -> __u32 Function tint_symbol_1 -> __u32
() ()
{ {
VariableDeclStatement{ VariableDeclStatement{
@ -415,7 +415,7 @@ TEST_F(FirstIndexOffsetTest, NestedCalls) {
} }
} }
} }
Function func2 -> __u32 Function tint_symbol_2 -> __u32
() ()
{ {
Return{ Return{

View File

@ -84,8 +84,8 @@ Transform::Output VertexPulling::Run(ast::Module* in) {
} }
// Find entry point // Find entry point
auto* func = mod->FindFunctionByNameAndStage(cfg.entry_point_name, auto* func = mod->FindFunctionBySymbolAndStage(
ast::PipelineStage::kVertex); mod->GetSymbol(cfg.entry_point_name), ast::PipelineStage::kVertex);
if (func == nullptr) { if (func == nullptr) {
diag::Diagnostic err; diag::Diagnostic err;
err.severity = diag::Severity::Error; err.severity = diag::Severity::Error;
@ -94,9 +94,6 @@ Transform::Output VertexPulling::Run(ast::Module* in) {
return out; return out;
} }
// Save the vertex function
auto* vertex_func = mod->FindFunctionByName(func->name());
// TODO(idanr): Need to check shader locations in descriptor cover all // TODO(idanr): Need to check shader locations in descriptor cover all
// attributes // attributes
@ -108,7 +105,7 @@ Transform::Output VertexPulling::Run(ast::Module* in) {
state.FindOrInsertInstanceIndexIfUsed(); state.FindOrInsertInstanceIndexIfUsed();
state.ConvertVertexInputVariablesToPrivate(); state.ConvertVertexInputVariablesToPrivate();
state.AddVertexStorageBuffers(); state.AddVertexStorageBuffers();
state.AddVertexPullingPreamble(vertex_func); state.AddVertexPullingPreamble(func);
return out; return out;
} }

View File

@ -47,8 +47,8 @@ class VertexPullingHelper {
// Create basic module with an entry point and vertex function // Create basic module with an entry point and vertex function
void InitBasicModule() { void InitBasicModule() {
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "main", ast::VariableList{}, mod_->create<ast::type::Void>(), Source{}, mod_->RegisterSymbol("main"), "main", ast::VariableList{},
create<ast::BlockStatement>(), mod_->create<ast::type::Void>(), create<ast::BlockStatement>(),
ast::FunctionDecorationList{create<ast::StageDecoration>( ast::FunctionDecorationList{create<ast::StageDecoration>(
ast::PipelineStage::kVertex, Source{})}); ast::PipelineStage::kVertex, Source{})});
mod()->AddFunction(func); mod()->AddFunction(func);
@ -134,8 +134,8 @@ TEST_F(VertexPullingTest, Error_InvalidEntryPoint) {
TEST_F(VertexPullingTest, Error_EntryPointWrongStage) { TEST_F(VertexPullingTest, Error_EntryPointWrongStage) {
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "main", ast::VariableList{}, mod()->create<ast::type::Void>(), Source{}, mod()->RegisterSymbol("main"), "main", ast::VariableList{},
create<ast::BlockStatement>(), mod()->create<ast::type::Void>(), create<ast::BlockStatement>(),
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -152,7 +152,8 @@ TEST_F(VertexPullingTest, BasicModule) {
InitBasicModule(); InitBasicModule();
InitTransform({}); InitTransform({});
auto result = manager()->Run(mod()); auto result = manager()->Run(mod());
ASSERT_FALSE(result.diagnostics.contains_errors()); ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
} }
TEST_F(VertexPullingTest, OneAttribute) { TEST_F(VertexPullingTest, OneAttribute) {
@ -164,7 +165,8 @@ TEST_F(VertexPullingTest, OneAttribute) {
InitTransform({{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}}); InitTransform({{{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}});
auto result = manager()->Run(mod()); auto result = manager()->Run(mod());
ASSERT_FALSE(result.diagnostics.contains_errors()); ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
EXPECT_EQ(R"(Module{ EXPECT_EQ(R"(Module{
TintVertexData Struct{ TintVertexData Struct{
@ -193,7 +195,8 @@ TEST_F(VertexPullingTest, OneAttribute) {
storage_buffer storage_buffer
__struct_TintVertexData __struct_TintVertexData
} }
Function main -> __void Function )" + result.module.GetSymbol("main").to_str() +
R"( -> __void
StageDecoration{vertex} StageDecoration{vertex}
() ()
{ {
@ -250,7 +253,8 @@ TEST_F(VertexPullingTest, OneInstancedAttribute) {
{{{4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 0}}}}}); {{{4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 0}}}}});
auto result = manager()->Run(mod()); auto result = manager()->Run(mod());
ASSERT_FALSE(result.diagnostics.contains_errors()); ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
EXPECT_EQ(R"(Module{ EXPECT_EQ(R"(Module{
TintVertexData Struct{ TintVertexData Struct{
@ -279,7 +283,8 @@ TEST_F(VertexPullingTest, OneInstancedAttribute) {
storage_buffer storage_buffer
__struct_TintVertexData __struct_TintVertexData
} }
Function main -> __void Function )" + result.module.GetSymbol("main").to_str() +
R"( -> __void
StageDecoration{vertex} StageDecoration{vertex}
() ()
{ {
@ -336,7 +341,8 @@ TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet) {
transform()->SetPullingBufferBindingSet(5); transform()->SetPullingBufferBindingSet(5);
auto result = manager()->Run(mod()); auto result = manager()->Run(mod());
ASSERT_FALSE(result.diagnostics.contains_errors()); ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
EXPECT_EQ(R"(Module{ EXPECT_EQ(R"(Module{
TintVertexData Struct{ TintVertexData Struct{
@ -365,7 +371,8 @@ TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet) {
storage_buffer storage_buffer
__struct_TintVertexData __struct_TintVertexData
} }
Function main -> __void Function )" + result.module.GetSymbol("main").to_str() +
R"( -> __void
StageDecoration{vertex} StageDecoration{vertex}
() ()
{ {
@ -451,7 +458,8 @@ TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) {
{4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 1}}}}}); {4, InputStepMode::kInstance, {{VertexFormat::kF32, 0, 1}}}}});
auto result = manager()->Run(mod()); auto result = manager()->Run(mod());
ASSERT_FALSE(result.diagnostics.contains_errors()); ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
EXPECT_EQ(R"(Module{ EXPECT_EQ(R"(Module{
TintVertexData Struct{ TintVertexData Struct{
@ -502,7 +510,8 @@ TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) {
storage_buffer storage_buffer
__struct_TintVertexData __struct_TintVertexData
} }
Function main -> __void Function )" + result.module.GetSymbol("main").to_str() +
R"( -> __void
StageDecoration{vertex} StageDecoration{vertex}
() ()
{ {
@ -592,7 +601,8 @@ TEST_F(VertexPullingTest, TwoAttributesSameBuffer) {
{{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}}); {{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}});
auto result = manager()->Run(mod()); auto result = manager()->Run(mod());
ASSERT_FALSE(result.diagnostics.contains_errors()); ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
EXPECT_EQ(R"(Module{ EXPECT_EQ(R"(Module{
TintVertexData Struct{ TintVertexData Struct{
@ -626,7 +636,8 @@ TEST_F(VertexPullingTest, TwoAttributesSameBuffer) {
storage_buffer storage_buffer
__struct_TintVertexData __struct_TintVertexData
} }
Function main -> __void Function )" + result.module.GetSymbol("main").to_str() +
R"( -> __void
StageDecoration{vertex} StageDecoration{vertex}
() ()
{ {
@ -778,7 +789,8 @@ TEST_F(VertexPullingTest, FloatVectorAttributes) {
{16, InputStepMode::kVertex, {{VertexFormat::kVec4F32, 0, 2}}}}}); {16, InputStepMode::kVertex, {{VertexFormat::kVec4F32, 0, 2}}}}});
auto result = manager()->Run(mod()); auto result = manager()->Run(mod());
ASSERT_FALSE(result.diagnostics.contains_errors()); ASSERT_FALSE(result.diagnostics.contains_errors())
<< diag::Formatter().format(result.diagnostics);
EXPECT_EQ(R"(Module{ EXPECT_EQ(R"(Module{
TintVertexData Struct{ TintVertexData Struct{
@ -835,7 +847,8 @@ TEST_F(VertexPullingTest, FloatVectorAttributes) {
storage_buffer storage_buffer
__struct_TintVertexData __struct_TintVertexData
} }
Function main -> __void Function )" + result.module.GetSymbol("main").to_str() +
R"( -> __void
StageDecoration{vertex} StageDecoration{vertex}
() ()
{ {

View File

@ -122,7 +122,7 @@ bool TypeDeterminer::Determine() {
continue; continue;
} }
for (const auto& callee : caller_to_callee_[func->name()]) { for (const auto& callee : caller_to_callee_[func->name()]) {
set_entry_points(callee, func->name()); set_entry_points(callee, func->symbol());
} }
} }
@ -130,11 +130,11 @@ bool TypeDeterminer::Determine() {
} }
void TypeDeterminer::set_entry_points(const std::string& fn_name, void TypeDeterminer::set_entry_points(const std::string& fn_name,
const std::string& ep_name) { Symbol ep_sym) {
name_to_function_[fn_name]->add_ancestor_entry_point(ep_name); name_to_function_[fn_name]->add_ancestor_entry_point(ep_sym);
for (const auto& callee : caller_to_callee_[fn_name]) { for (const auto& callee : caller_to_callee_[fn_name]) {
set_entry_points(callee, ep_name); set_entry_points(callee, ep_sym);
} }
} }
@ -389,7 +389,8 @@ bool TypeDeterminer::DetermineCall(ast::CallExpression* expr) {
if (current_function_) { if (current_function_) {
caller_to_callee_[current_function_->name()].push_back(ident->name()); caller_to_callee_[current_function_->name()].push_back(ident->name());
auto* callee_func = mod_->FindFunctionByName(ident->name()); auto* callee_func =
mod_->FindFunctionBySymbol(mod_->GetSymbol(ident->name()));
if (callee_func == nullptr) { if (callee_func == nullptr) {
set_error(expr->source(), set_error(expr->source(),
"unable to find called function: " + ident->name()); "unable to find called function: " + ident->name());

View File

@ -113,7 +113,7 @@ class TypeDeterminer {
private: private:
void set_error(const Source& src, const std::string& msg); void set_error(const Source& src, const std::string& msg);
void set_referenced_from_function_if_needed(ast::Variable* var, bool local); void set_referenced_from_function_if_needed(ast::Variable* var, bool local);
void set_entry_points(const std::string& fn_name, const std::string& ep_name); void set_entry_points(const std::string& fn_name, Symbol ep_sym);
bool DetermineArrayAccessor(ast::ArrayAccessorExpression* expr); bool DetermineArrayAccessor(ast::ArrayAccessorExpression* expr);
bool DetermineBinary(ast::BinaryExpression* expr); bool DetermineBinary(ast::BinaryExpression* expr);

View File

@ -341,9 +341,9 @@ TEST_F(TypeDeterminerTest, Stmt_Call) {
ast::type::F32 f32; ast::type::F32 f32;
ast::VariableList params; ast::VariableList params;
auto* func = create<ast::Function>(Source{}, "my_func", params, &f32, auto* func = create<ast::Function>(
create<ast::BlockStatement>(), Source{}, mod->RegisterSymbol("my_func"), "my_func", params, &f32,
ast::FunctionDecorationList{}); create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod->AddFunction(func); mod->AddFunction(func);
// Register the function // Register the function
@ -372,15 +372,16 @@ TEST_F(TypeDeterminerTest, Stmt_Call_undeclared) {
auto* main_body = create<ast::BlockStatement>(); auto* main_body = create<ast::BlockStatement>();
main_body->append(create<ast::CallStatement>(call_expr)); main_body->append(create<ast::CallStatement>(call_expr));
main_body->append(create<ast::ReturnStatement>(Source{})); main_body->append(create<ast::ReturnStatement>(Source{}));
auto* func_main = auto* func_main = create<ast::Function>(Source{}, mod->RegisterSymbol("main"),
create<ast::Function>(Source{}, "main", params0, &f32, main_body, "main", params0, &f32, main_body,
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
mod->AddFunction(func_main); mod->AddFunction(func_main);
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(Source{}, "func", params0, &f32, body, auto* func =
ast::FunctionDecorationList{}); create<ast::Function>(Source{}, mod->RegisterSymbol("func"), "func",
params0, &f32, body, ast::FunctionDecorationList{});
mod->AddFunction(func); mod->AddFunction(func);
EXPECT_FALSE(td()->Determine()) << td()->error(); EXPECT_FALSE(td()->Determine()) << td()->error();
@ -639,9 +640,9 @@ TEST_F(TypeDeterminerTest, Expr_Call) {
ast::type::F32 f32; ast::type::F32 f32;
ast::VariableList params; ast::VariableList params;
auto* func = create<ast::Function>(Source{}, "my_func", params, &f32, auto* func = create<ast::Function>(
create<ast::BlockStatement>(), Source{}, mod->RegisterSymbol("my_func"), "my_func", params, &f32,
ast::FunctionDecorationList{}); create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod->AddFunction(func); mod->AddFunction(func);
// Register the function // Register the function
@ -659,9 +660,9 @@ TEST_F(TypeDeterminerTest, Expr_Call_WithParams) {
ast::type::F32 f32; ast::type::F32 f32;
ast::VariableList params; ast::VariableList params;
auto* func = create<ast::Function>(Source{}, "my_func", params, &f32, auto* func = create<ast::Function>(
create<ast::BlockStatement>(), Source{}, mod->RegisterSymbol("my_func"), "my_func", params, &f32,
ast::FunctionDecorationList{}); create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod->AddFunction(func); mod->AddFunction(func);
// Register the function // Register the function
@ -809,8 +810,8 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable_Const) {
body->append(create<ast::AssignmentStatement>( body->append(create<ast::AssignmentStatement>(
my_var, create<ast::IdentifierExpression>("my_var"))); my_var, create<ast::IdentifierExpression>("my_var")));
ast::Function f(Source{}, "my_func", {}, &f32, body, ast::Function f(Source{}, mod->RegisterSymbol("my_func"), "my_func", {}, &f32,
ast::FunctionDecorationList{}); body, ast::FunctionDecorationList{});
EXPECT_TRUE(td()->DetermineFunction(&f)); EXPECT_TRUE(td()->DetermineFunction(&f));
@ -836,8 +837,8 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_FunctionVariable) {
body->append(create<ast::AssignmentStatement>( body->append(create<ast::AssignmentStatement>(
my_var, create<ast::IdentifierExpression>("my_var"))); my_var, create<ast::IdentifierExpression>("my_var")));
ast::Function f(Source{}, "my_func", {}, &f32, body, ast::Function f(Source{}, mod->RegisterSymbol("myfunc"), "my_func", {}, &f32,
ast::FunctionDecorationList{}); body, ast::FunctionDecorationList{});
EXPECT_TRUE(td()->DetermineFunction(&f)); EXPECT_TRUE(td()->DetermineFunction(&f));
@ -868,8 +869,8 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_Function_Ptr) {
body->append(create<ast::AssignmentStatement>( body->append(create<ast::AssignmentStatement>(
my_var, create<ast::IdentifierExpression>("my_var"))); my_var, create<ast::IdentifierExpression>("my_var")));
ast::Function f(Source{}, "my_func", {}, &f32, body, ast::Function f(Source{}, mod->RegisterSymbol("my_func"), "my_func", {}, &f32,
ast::FunctionDecorationList{}); body, ast::FunctionDecorationList{});
EXPECT_TRUE(td()->DetermineFunction(&f)); EXPECT_TRUE(td()->DetermineFunction(&f));
@ -885,9 +886,9 @@ TEST_F(TypeDeterminerTest, Expr_Identifier_Function) {
ast::type::F32 f32; ast::type::F32 f32;
ast::VariableList params; ast::VariableList params;
auto* func = create<ast::Function>(Source{}, "my_func", params, &f32, auto* func = create<ast::Function>(
create<ast::BlockStatement>(), Source{}, mod->RegisterSymbol("my_func"), "my_func", params, &f32,
ast::FunctionDecorationList{}); create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod->AddFunction(func); mod->AddFunction(func);
// Register the function // Register the function
@ -968,8 +969,9 @@ TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables) {
body->append(create<ast::AssignmentStatement>( body->append(create<ast::AssignmentStatement>(
create<ast::IdentifierExpression>("priv_var"), create<ast::IdentifierExpression>("priv_var"),
create<ast::IdentifierExpression>("priv_var"))); create<ast::IdentifierExpression>("priv_var")));
auto* func = create<ast::Function>(Source{}, "my_func", params, &f32, body, auto* func =
ast::FunctionDecorationList{}); create<ast::Function>(Source{}, mod->RegisterSymbol("my_func"), "my_func",
params, &f32, body, ast::FunctionDecorationList{});
mod->AddFunction(func); mod->AddFunction(func);
@ -1049,8 +1051,9 @@ TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables_SubFunction) {
create<ast::IdentifierExpression>("priv_var"), create<ast::IdentifierExpression>("priv_var"),
create<ast::IdentifierExpression>("priv_var"))); create<ast::IdentifierExpression>("priv_var")));
ast::VariableList params; ast::VariableList params;
auto* func = create<ast::Function>(Source{}, "my_func", params, &f32, body, auto* func =
ast::FunctionDecorationList{}); create<ast::Function>(Source{}, mod->RegisterSymbol("my_func"), "my_func",
params, &f32, body, ast::FunctionDecorationList{});
mod->AddFunction(func); mod->AddFunction(func);
@ -1059,8 +1062,9 @@ TEST_F(TypeDeterminerTest, Function_RegisterInputOutputVariables_SubFunction) {
create<ast::IdentifierExpression>("out_var"), create<ast::IdentifierExpression>("out_var"),
create<ast::CallExpression>(create<ast::IdentifierExpression>("my_func"), create<ast::CallExpression>(create<ast::IdentifierExpression>("my_func"),
ast::ExpressionList{}))); ast::ExpressionList{})));
auto* func2 = create<ast::Function>(Source{}, "func", params, &f32, body, auto* func2 =
ast::FunctionDecorationList{}); create<ast::Function>(Source{}, mod->RegisterSymbol("func"), "func",
params, &f32, body, ast::FunctionDecorationList{});
mod->AddFunction(func2); mod->AddFunction(func2);
@ -1096,8 +1100,9 @@ TEST_F(TypeDeterminerTest, Function_NotRegisterFunctionVariable) {
create<ast::FloatLiteral>(&f32, 1.f)))); create<ast::FloatLiteral>(&f32, 1.f))));
ast::VariableList params; ast::VariableList params;
auto* func = create<ast::Function>(Source{}, "my_func", params, &f32, body, auto* func =
ast::FunctionDecorationList{}); create<ast::Function>(Source{}, mod->RegisterSymbol("my_func"), "my_func",
params, &f32, body, ast::FunctionDecorationList{});
mod->AddFunction(func); mod->AddFunction(func);
@ -2636,8 +2641,9 @@ TEST_F(TypeDeterminerTest, StorageClass_SetsIfMissing) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(stmt); body->append(stmt);
auto* func = create<ast::Function>(Source{}, "func", ast::VariableList{}, auto* func = create<ast::Function>(Source{}, mod->RegisterSymbol("func"),
&i32, body, ast::FunctionDecorationList{}); "func", ast::VariableList{}, &i32, body,
ast::FunctionDecorationList{});
mod->AddFunction(func); mod->AddFunction(func);
@ -2660,8 +2666,9 @@ TEST_F(TypeDeterminerTest, StorageClass_DoesNotSetOnConst) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(stmt); body->append(stmt);
auto* func = create<ast::Function>(Source{}, "func", ast::VariableList{}, auto* func = create<ast::Function>(Source{}, mod->RegisterSymbol("func"),
&i32, body, ast::FunctionDecorationList{}); "func", ast::VariableList{}, &i32, body,
ast::FunctionDecorationList{});
mod->AddFunction(func); mod->AddFunction(func);
@ -2684,8 +2691,9 @@ TEST_F(TypeDeterminerTest, StorageClass_NonFunctionClassError) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(stmt); body->append(stmt);
auto* func = create<ast::Function>(Source{}, "func", ast::VariableList{}, auto* func = create<ast::Function>(Source{}, mod->RegisterSymbol("func"),
&i32, body, ast::FunctionDecorationList{}); "func", ast::VariableList{}, &i32, body,
ast::FunctionDecorationList{});
mod->AddFunction(func); mod->AddFunction(func);
@ -4857,24 +4865,27 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) {
ast::VariableList params; ast::VariableList params;
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
auto* func_b = create<ast::Function>(Source{}, "b", params, &f32, body, auto* func_b =
ast::FunctionDecorationList{}); create<ast::Function>(Source{}, mod->RegisterSymbol("b"), "b", params,
&f32, body, ast::FunctionDecorationList{});
body = create<ast::BlockStatement>(); body = create<ast::BlockStatement>();
body->append(create<ast::AssignmentStatement>( body->append(create<ast::AssignmentStatement>(
create<ast::IdentifierExpression>("second"), create<ast::IdentifierExpression>("second"),
create<ast::CallExpression>(create<ast::IdentifierExpression>("b"), create<ast::CallExpression>(create<ast::IdentifierExpression>("b"),
ast::ExpressionList{}))); ast::ExpressionList{})));
auto* func_c = create<ast::Function>(Source{}, "c", params, &f32, body, auto* func_c =
ast::FunctionDecorationList{}); create<ast::Function>(Source{}, mod->RegisterSymbol("c"), "c", params,
&f32, body, ast::FunctionDecorationList{});
body = create<ast::BlockStatement>(); body = create<ast::BlockStatement>();
body->append(create<ast::AssignmentStatement>( body->append(create<ast::AssignmentStatement>(
create<ast::IdentifierExpression>("first"), create<ast::IdentifierExpression>("first"),
create<ast::CallExpression>(create<ast::IdentifierExpression>("c"), create<ast::CallExpression>(create<ast::IdentifierExpression>("c"),
ast::ExpressionList{}))); ast::ExpressionList{})));
auto* func_a = create<ast::Function>(Source{}, "a", params, &f32, body, auto* func_a =
ast::FunctionDecorationList{}); create<ast::Function>(Source{}, mod->RegisterSymbol("a"), "a", params,
&f32, body, ast::FunctionDecorationList{});
body = create<ast::BlockStatement>(); body = create<ast::BlockStatement>();
body->append(create<ast::AssignmentStatement>( body->append(create<ast::AssignmentStatement>(
@ -4886,7 +4897,7 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) {
create<ast::CallExpression>(create<ast::IdentifierExpression>("b"), create<ast::CallExpression>(create<ast::IdentifierExpression>("b"),
ast::ExpressionList{}))); ast::ExpressionList{})));
auto* ep_1 = create<ast::Function>( auto* ep_1 = create<ast::Function>(
Source{}, "ep_1", params, &f32, body, Source{}, mod->RegisterSymbol("ep_1"), "ep_1", params, &f32, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });
@ -4897,7 +4908,7 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) {
create<ast::CallExpression>(create<ast::IdentifierExpression>("c"), create<ast::CallExpression>(create<ast::IdentifierExpression>("c"),
ast::ExpressionList{}))); ast::ExpressionList{})));
auto* ep_2 = create<ast::Function>( auto* ep_2 = create<ast::Function>(
Source{}, "ep_2", params, &f32, body, Source{}, mod->RegisterSymbol("ep_2"), "ep_2", params, &f32, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });
@ -4954,17 +4965,17 @@ TEST_F(TypeDeterminerTest, Function_EntryPoints_StageDecoration) {
const auto& b_eps = func_b->ancestor_entry_points(); const auto& b_eps = func_b->ancestor_entry_points();
ASSERT_EQ(2u, b_eps.size()); ASSERT_EQ(2u, b_eps.size());
EXPECT_EQ("ep_1", b_eps[0]); EXPECT_EQ(mod->RegisterSymbol("ep_1"), b_eps[0]);
EXPECT_EQ("ep_2", b_eps[1]); EXPECT_EQ(mod->RegisterSymbol("ep_2"), b_eps[1]);
const auto& a_eps = func_a->ancestor_entry_points(); const auto& a_eps = func_a->ancestor_entry_points();
ASSERT_EQ(1u, a_eps.size()); ASSERT_EQ(1u, a_eps.size());
EXPECT_EQ("ep_1", a_eps[0]); EXPECT_EQ(mod->RegisterSymbol("ep_1"), a_eps[0]);
const auto& c_eps = func_c->ancestor_entry_points(); const auto& c_eps = func_c->ancestor_entry_points();
ASSERT_EQ(2u, c_eps.size()); ASSERT_EQ(2u, c_eps.size());
EXPECT_EQ("ep_1", c_eps[0]); EXPECT_EQ(mod->RegisterSymbol("ep_1"), c_eps[0]);
EXPECT_EQ("ep_2", c_eps[1]); EXPECT_EQ(mod->RegisterSymbol("ep_2"), c_eps[1]);
EXPECT_TRUE(ep_1->ancestor_entry_points().empty()); EXPECT_TRUE(ep_1->ancestor_entry_points().empty());
EXPECT_TRUE(ep_2->ancestor_entry_points().empty()); EXPECT_TRUE(ep_2->ancestor_entry_points().empty());

View File

@ -54,7 +54,8 @@ TEST_F(ValidateFunctionTest, VoidFunctionEndWithoutReturnStatement_Pass) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::VariableDeclStatement>(var)); body->append(create<ast::VariableDeclStatement>(var));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{Source::Location{12, 34}}, "func", params, &void_type, body, Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func",
params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });
@ -71,8 +72,8 @@ TEST_F(ValidateFunctionTest,
ast::type::Void void_type; ast::type::Void void_type;
ast::VariableList params; ast::VariableList params;
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{Source::Location{12, 34}}, "func", params, &void_type, Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func",
create<ast::BlockStatement>(), params, &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });
@ -100,9 +101,9 @@ TEST_F(ValidateFunctionTest, FunctionEndWithoutReturnStatement_Fail) {
ast::type::Void void_type; ast::type::Void void_type;
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::VariableDeclStatement>(var)); body->append(create<ast::VariableDeclStatement>(var));
auto* func = auto* func = create<ast::Function>(
create<ast::Function>(Source{Source::Location{12, 34}}, "func", params, Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func",
&i32, body, ast::FunctionDecorationList{}); params, &i32, body, ast::FunctionDecorationList{});
mod()->AddFunction(func); mod()->AddFunction(func);
EXPECT_TRUE(td()->Determine()) << td()->error(); EXPECT_TRUE(td()->Determine()) << td()->error();
@ -117,8 +118,9 @@ TEST_F(ValidateFunctionTest, FunctionEndWithoutReturnStatementEmptyBody_Fail) {
ast::type::I32 i32; ast::type::I32 i32;
ast::VariableList params; ast::VariableList params;
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{Source::Location{12, 34}}, "func", params, &i32, Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func",
create<ast::BlockStatement>(), ast::FunctionDecorationList{}); params, &i32, create<ast::BlockStatement>(),
ast::FunctionDecorationList{});
mod()->AddFunction(func); mod()->AddFunction(func);
EXPECT_TRUE(td()->Determine()) << td()->error(); EXPECT_TRUE(td()->Determine()) << td()->error();
@ -136,7 +138,7 @@ TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementType_Pass) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "func", params, &void_type, body, Source{}, mod()->RegisterSymbol("func"), "func", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });
@ -157,7 +159,8 @@ TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementType_fail) {
body->append(create<ast::ReturnStatement>(Source{Source::Location{12, 34}}, body->append(create<ast::ReturnStatement>(Source{Source::Location{12, 34}},
return_expr)); return_expr));
auto* func = create<ast::Function>(Source{}, "func", params, &void_type, body, auto* func = create<ast::Function>(Source{}, mod()->RegisterSymbol("func"),
"func", params, &void_type, body,
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
mod()->AddFunction(func); mod()->AddFunction(func);
@ -180,8 +183,9 @@ TEST_F(ValidateFunctionTest, FunctionTypeMustMatchReturnStatementTypeF32_fail) {
body->append(create<ast::ReturnStatement>(Source{Source::Location{12, 34}}, body->append(create<ast::ReturnStatement>(Source{Source::Location{12, 34}},
return_expr)); return_expr));
auto* func = create<ast::Function>(Source{}, "func", params, &f32, body, auto* func =
ast::FunctionDecorationList{}); create<ast::Function>(Source{}, mod()->RegisterSymbol("func"), "func",
params, &f32, body, ast::FunctionDecorationList{});
mod()->AddFunction(func); mod()->AddFunction(func);
EXPECT_TRUE(td()->Determine()) << td()->error(); EXPECT_TRUE(td()->Determine()) << td()->error();
@ -204,8 +208,9 @@ TEST_F(ValidateFunctionTest, FunctionNamesMustBeUnique_fail) {
create<ast::SintLiteral>(&i32, 2)); create<ast::SintLiteral>(&i32, 2));
body->append(create<ast::ReturnStatement>(Source{}, return_expr)); body->append(create<ast::ReturnStatement>(Source{}, return_expr));
auto* func = create<ast::Function>(Source{}, "func", params, &i32, body, auto* func =
ast::FunctionDecorationList{}); create<ast::Function>(Source{}, mod()->RegisterSymbol("func"), "func",
params, &i32, body, ast::FunctionDecorationList{});
ast::VariableList params_copy; ast::VariableList params_copy;
auto* body_copy = create<ast::BlockStatement>(); auto* body_copy = create<ast::BlockStatement>();
@ -213,9 +218,9 @@ TEST_F(ValidateFunctionTest, FunctionNamesMustBeUnique_fail) {
create<ast::SintLiteral>(&i32, 2)); create<ast::SintLiteral>(&i32, 2));
body_copy->append(create<ast::ReturnStatement>(Source{}, return_expr_copy)); body_copy->append(create<ast::ReturnStatement>(Source{}, return_expr_copy));
auto* func_copy = create<ast::Function>(Source{Source::Location{12, 34}}, auto* func_copy = create<ast::Function>(
"func", params_copy, &i32, body_copy, Source{Source::Location{12, 34}}, mod()->RegisterSymbol("func"), "func",
ast::FunctionDecorationList{}); params_copy, &i32, body_copy, ast::FunctionDecorationList{});
mod()->AddFunction(func); mod()->AddFunction(func);
mod()->AddFunction(func_copy); mod()->AddFunction(func_copy);
@ -237,7 +242,8 @@ TEST_F(ValidateFunctionTest, RecursionIsNotAllowed_Fail) {
auto* body0 = create<ast::BlockStatement>(); auto* body0 = create<ast::BlockStatement>();
body0->append(create<ast::CallStatement>(call_expr)); body0->append(create<ast::CallStatement>(call_expr));
body0->append(create<ast::ReturnStatement>(Source{})); body0->append(create<ast::ReturnStatement>(Source{}));
auto* func0 = create<ast::Function>(Source{}, "func", params0, &f32, body0, auto* func0 = create<ast::Function>(Source{}, mod()->RegisterSymbol("func"),
"func", params0, &f32, body0,
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
mod()->AddFunction(func0); mod()->AddFunction(func0);
@ -268,7 +274,8 @@ TEST_F(ValidateFunctionTest, RecursionIsNotAllowedExpr_Fail) {
create<ast::SintLiteral>(&i32, 2)); create<ast::SintLiteral>(&i32, 2));
body0->append(create<ast::ReturnStatement>(Source{}, return_expr)); body0->append(create<ast::ReturnStatement>(Source{}, return_expr));
auto* func0 = create<ast::Function>(Source{}, "func", params0, &i32, body0, auto* func0 = create<ast::Function>(Source{}, mod()->RegisterSymbol("func"),
"func", params0, &i32, body0,
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
mod()->AddFunction(func0); mod()->AddFunction(func0);
@ -288,7 +295,8 @@ TEST_F(ValidateFunctionTest, Function_WithPipelineStage_NotVoid_Fail) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{}, return_expr)); body->append(create<ast::ReturnStatement>(Source{}, return_expr));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{Source::Location{12, 34}}, "vtx_main", params, &i32, body, Source{Source::Location{12, 34}}, mod()->RegisterSymbol("vtx_main"),
"vtx_main", params, &i32, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });
@ -317,7 +325,8 @@ TEST_F(ValidateFunctionTest, Function_WithPipelineStage_WithParams_Fail) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{Source::Location{12, 34}}, "vtx_func", params, &void_type, body, Source{Source::Location{12, 34}}, mod()->RegisterSymbol("vtx_func"),
"vtx_func", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });
@ -339,7 +348,8 @@ TEST_F(ValidateFunctionTest, PipelineStage_MustBeUnique_Fail) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{Source::Location{12, 34}}, "main", params, &void_type, body, Source{Source::Location{12, 34}}, mod()->RegisterSymbol("main"), "main",
params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
@ -361,7 +371,8 @@ TEST_F(ValidateFunctionTest, OnePipelineStageFunctionMustBePresent_Pass) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "vtx_func", params, &void_type, body, Source{}, mod()->RegisterSymbol("vtx_func"), "vtx_func", params,
&void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });
@ -377,8 +388,9 @@ TEST_F(ValidateFunctionTest, OnePipelineStageFunctionMustBePresent_Fail) {
ast::VariableList params; ast::VariableList params;
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(Source{}, "vtx_func", params, &void_type, auto* func = create<ast::Function>(
body, ast::FunctionDecorationList{}); Source{}, mod()->RegisterSymbol("vtx_func"), "vtx_func", params,
&void_type, body, ast::FunctionDecorationList{});
mod()->AddFunction(func); mod()->AddFunction(func);
EXPECT_TRUE(td()->Determine()) << td()->error(); EXPECT_TRUE(td()->Determine()) << td()->error();

View File

@ -332,7 +332,8 @@ TEST_F(ValidatorTest, UsingUndefinedVariableGlobalVariable_Fail) {
body->append(create<ast::AssignmentStatement>( body->append(create<ast::AssignmentStatement>(
Source{Source::Location{12, 34}}, lhs, rhs)); Source{Source::Location{12, 34}}, lhs, rhs));
auto* func = create<ast::Function>(Source{}, "my_func", params, &f32, body, auto* func = create<ast::Function>(Source{}, mod()->RegisterSymbol("my_func"),
"my_func", params, &f32, body,
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
mod()->AddFunction(func); mod()->AddFunction(func);
@ -370,7 +371,8 @@ TEST_F(ValidatorTest, UsingUndefinedVariableGlobalVariable_Pass) {
Source{Source::Location{12, 34}}, lhs, rhs)); Source{Source::Location{12, 34}}, lhs, rhs));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "my_func", params, &void_type, body, Source{}, mod()->RegisterSymbol("my_func"), "my_func", params, &void_type,
body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });
@ -587,8 +589,9 @@ TEST_F(ValidatorTest, GlobalVariableFunctionVariableNotUnique_Fail) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::VariableDeclStatement>( body->append(create<ast::VariableDeclStatement>(
Source{Source::Location{12, 34}}, var)); Source{Source::Location{12, 34}}, var));
auto* func = create<ast::Function>(Source{}, "my_func", params, &void_type, auto* func = create<ast::Function>(Source{}, mod()->RegisterSymbol("my_func"),
body, ast::FunctionDecorationList{}); "my_func", params, &void_type, body,
ast::FunctionDecorationList{});
mod()->AddFunction(func); mod()->AddFunction(func);
@ -631,8 +634,9 @@ TEST_F(ValidatorTest, RedeclaredIndentifier_Fail) {
body->append(create<ast::VariableDeclStatement>(var)); body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::VariableDeclStatement>( body->append(create<ast::VariableDeclStatement>(
Source{Source::Location{12, 34}}, var_a_float)); Source{Source::Location{12, 34}}, var_a_float));
auto* func = create<ast::Function>(Source{}, "my_func", params, &void_type, auto* func = create<ast::Function>(Source{}, mod()->RegisterSymbol("my_func"),
body, ast::FunctionDecorationList{}); "my_func", params, &void_type, body,
ast::FunctionDecorationList{});
mod()->AddFunction(func); mod()->AddFunction(func);
@ -759,8 +763,9 @@ TEST_F(ValidatorTest, RedeclaredIdentifierDifferentFunctions_Pass) {
body0->append(create<ast::VariableDeclStatement>( body0->append(create<ast::VariableDeclStatement>(
Source{Source::Location{12, 34}}, var0)); Source{Source::Location{12, 34}}, var0));
body0->append(create<ast::ReturnStatement>(Source{})); body0->append(create<ast::ReturnStatement>(Source{}));
auto* func0 = create<ast::Function>(Source{}, "func0", params0, &void_type, auto* func0 = create<ast::Function>(Source{}, mod()->RegisterSymbol("func0"),
body0, ast::FunctionDecorationList{}); "func0", params0, &void_type, body0,
ast::FunctionDecorationList{});
ast::VariableList params1; ast::VariableList params1;
auto* body1 = create<ast::BlockStatement>(); auto* body1 = create<ast::BlockStatement>();
@ -768,7 +773,8 @@ TEST_F(ValidatorTest, RedeclaredIdentifierDifferentFunctions_Pass) {
Source{Source::Location{13, 34}}, var1)); Source{Source::Location{13, 34}}, var1));
body1->append(create<ast::ReturnStatement>(Source{})); body1->append(create<ast::ReturnStatement>(Source{}));
auto* func1 = create<ast::Function>( auto* func1 = create<ast::Function>(
Source{}, "func1", params1, &void_type, body1, Source{}, mod()->RegisterSymbol("func1"), "func1", params1, &void_type,
body1,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });

View File

@ -206,8 +206,9 @@ TEST_F(ValidatorTypeTest, RuntimeArrayInFunction_Fail) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::VariableDeclStatement>( body->append(create<ast::VariableDeclStatement>(
Source{Source::Location{12, 34}}, var)); Source{Source::Location{12, 34}}, var));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "func", params, &void_type, body, Source{}, mod()->RegisterSymbol("func"), "func", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });

View File

@ -196,15 +196,15 @@ std::string GeneratorImpl::current_ep_var_name(VarType type) {
std::string name = ""; std::string name = "";
switch (type) { switch (type) {
case VarType::kIn: { case VarType::kIn: {
auto in_it = ep_name_to_in_data_.find(current_ep_name_); auto in_it = ep_sym_to_in_data_.find(current_ep_sym_.value());
if (in_it != ep_name_to_in_data_.end()) { if (in_it != ep_sym_to_in_data_.end()) {
name = in_it->second.var_name; name = in_it->second.var_name;
} }
break; break;
} }
case VarType::kOut: { case VarType::kOut: {
auto outit = ep_name_to_out_data_.find(current_ep_name_); auto outit = ep_sym_to_out_data_.find(current_ep_sym_.value());
if (outit != ep_name_to_out_data_.end()) { if (outit != ep_sym_to_out_data_.end()) {
name = outit->second.var_name; name = outit->second.var_name;
} }
break; break;
@ -668,12 +668,14 @@ bool GeneratorImpl::EmitCall(std::ostream& pre,
} }
auto name = ident->name(); auto name = ident->name();
auto it = ep_func_name_remapped_.find(current_ep_name_ + "_" + name); auto caller_sym = module_->GetSymbol(name);
auto it = ep_func_name_remapped_.find(current_ep_sym_.to_str() + "_" +
caller_sym.to_str());
if (it != ep_func_name_remapped_.end()) { if (it != ep_func_name_remapped_.end()) {
name = it->second; name = it->second;
} }
auto* func = module_->FindFunctionByName(ident->name()); auto* func = module_->FindFunctionBySymbol(module_->GetSymbol(ident->name()));
if (func == nullptr) { if (func == nullptr) {
error_ = "Unable to find function: " + name; error_ = "Unable to find function: " + name;
return false; return false;
@ -1189,15 +1191,15 @@ bool GeneratorImpl::EmitFunction(std::ostream& out, ast::Function* func) {
has_referenced_var_needing_struct(func); has_referenced_var_needing_struct(func);
if (emit_duplicate_functions) { if (emit_duplicate_functions) {
for (const auto& ep_name : func->ancestor_entry_points()) { for (const auto& ep_sym : func->ancestor_entry_points()) {
if (!EmitFunctionInternal(out, func, emit_duplicate_functions, ep_name)) { if (!EmitFunctionInternal(out, func, emit_duplicate_functions, ep_sym)) {
return false; return false;
} }
out << std::endl; out << std::endl;
} }
} else { } else {
// Emit as non-duplicated // Emit as non-duplicated
if (!EmitFunctionInternal(out, func, false, "")) { if (!EmitFunctionInternal(out, func, false, Symbol())) {
return false; return false;
} }
out << std::endl; out << std::endl;
@ -1209,8 +1211,8 @@ bool GeneratorImpl::EmitFunction(std::ostream& out, ast::Function* func) {
bool GeneratorImpl::EmitFunctionInternal(std::ostream& out, bool GeneratorImpl::EmitFunctionInternal(std::ostream& out,
ast::Function* func, ast::Function* func,
bool emit_duplicate_functions, bool emit_duplicate_functions,
const std::string& ep_name) { Symbol ep_sym) {
auto name = func->name(); auto name = func->symbol().to_str();
if (!EmitType(out, func->return_type(), "")) { if (!EmitType(out, func->return_type(), "")) {
return false; return false;
@ -1219,10 +1221,15 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out,
out << " "; out << " ";
if (emit_duplicate_functions) { if (emit_duplicate_functions) {
name = generate_name(name + "_" + ep_name); auto func_name = name;
ep_func_name_remapped_[ep_name + "_" + func->name()] = name; auto ep_name = ep_sym.to_str();
// TODO(dsinclair): The SymbolToName should go away and just use
// to_str() here when the conversion is complete.
name = generate_name(func->name() + "_" + module_->SymbolToName(ep_sym));
ep_func_name_remapped_[ep_name + "_" + func_name] = name;
} else { } else {
name = namer_.NameFor(name); // TODO(dsinclair): this should be updated to a remapped name
name = namer_.NameFor(func->name());
} }
out << name << "("; out << name << "(";
@ -1234,15 +1241,15 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out,
// //
// We emit both of them if they're there regardless of if they're both used. // We emit both of them if they're there regardless of if they're both used.
if (emit_duplicate_functions) { if (emit_duplicate_functions) {
auto in_it = ep_name_to_in_data_.find(ep_name); auto in_it = ep_sym_to_in_data_.find(ep_sym.value());
if (in_it != ep_name_to_in_data_.end()) { if (in_it != ep_sym_to_in_data_.end()) {
out << "in " << in_it->second.struct_name << " " out << "in " << in_it->second.struct_name << " "
<< in_it->second.var_name; << in_it->second.var_name;
first = false; first = false;
} }
auto outit = ep_name_to_out_data_.find(ep_name); auto outit = ep_sym_to_out_data_.find(ep_sym.value());
if (outit != ep_name_to_out_data_.end()) { if (outit != ep_sym_to_out_data_.end()) {
if (!first) { if (!first) {
out << ", "; out << ", ";
} }
@ -1269,13 +1276,13 @@ bool GeneratorImpl::EmitFunctionInternal(std::ostream& out,
out << ") "; out << ") ";
current_ep_name_ = ep_name; current_ep_sym_ = ep_sym;
if (!EmitBlockAndNewline(out, func->body())) { if (!EmitBlockAndNewline(out, func->body())) {
return false; return false;
} }
current_ep_name_ = ""; current_ep_sym_ = Symbol();
return true; return true;
} }
@ -1392,7 +1399,7 @@ bool GeneratorImpl::EmitEntryPointData(
auto in_struct_name = auto in_struct_name =
generate_name(func->name() + "_" + kInStructNameSuffix); generate_name(func->name() + "_" + kInStructNameSuffix);
auto in_var_name = generate_name(kTintStructInVarPrefix); auto in_var_name = generate_name(kTintStructInVarPrefix);
ep_name_to_in_data_[func->name()] = {in_struct_name, in_var_name}; ep_sym_to_in_data_[func->symbol().value()] = {in_struct_name, in_var_name};
make_indent(out); make_indent(out);
out << "struct " << in_struct_name << " {" << std::endl; out << "struct " << in_struct_name << " {" << std::endl;
@ -1438,7 +1445,7 @@ bool GeneratorImpl::EmitEntryPointData(
auto outstruct_name = auto outstruct_name =
generate_name(func->name() + "_" + kOutStructNameSuffix); generate_name(func->name() + "_" + kOutStructNameSuffix);
auto outvar_name = generate_name(kTintStructOutVarPrefix); auto outvar_name = generate_name(kTintStructOutVarPrefix);
ep_name_to_out_data_[func->name()] = {outstruct_name, outvar_name}; ep_sym_to_out_data_[func->symbol().value()] = {outstruct_name, outvar_name};
make_indent(out); make_indent(out);
out << "struct " << outstruct_name << " {" << std::endl; out << "struct " << outstruct_name << " {" << std::endl;
@ -1516,7 +1523,7 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out,
ast::Function* func) { ast::Function* func) {
make_indent(out); make_indent(out);
current_ep_name_ = func->name(); current_ep_sym_ = func->symbol();
if (func->pipeline_stage() == ast::PipelineStage::kCompute) { if (func->pipeline_stage() == ast::PipelineStage::kCompute) {
uint32_t x = 0; uint32_t x = 0;
@ -1528,17 +1535,18 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out,
make_indent(out); make_indent(out);
} }
auto outdata = ep_name_to_out_data_.find(current_ep_name_); auto outdata = ep_sym_to_out_data_.find(current_ep_sym_.value());
bool has_outdata = outdata != ep_name_to_out_data_.end(); bool has_outdata = outdata != ep_sym_to_out_data_.end();
if (has_outdata) { if (has_outdata) {
out << outdata->second.struct_name; out << outdata->second.struct_name;
} else { } else {
out << "void"; out << "void";
} }
out << " " << namer_.NameFor(current_ep_name_) << "("; // TODO(dsinclair): This should output the remapped name
out << " " << namer_.NameFor(module_->SymbolToName(current_ep_sym_)) << "(";
auto in_data = ep_name_to_in_data_.find(current_ep_name_); auto in_data = ep_sym_to_in_data_.find(current_ep_sym_.value());
if (in_data != ep_name_to_in_data_.end()) { if (in_data != ep_sym_to_in_data_.end()) {
out << in_data->second.struct_name << " " << in_data->second.var_name; out << in_data->second.struct_name << " " << in_data->second.var_name;
} }
out << ") {" << std::endl; out << ") {" << std::endl;
@ -1563,7 +1571,7 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out,
make_indent(out); make_indent(out);
out << "}" << std::endl; out << "}" << std::endl;
current_ep_name_ = ""; current_ep_sym_ = Symbol();
return true; return true;
} }
@ -1966,8 +1974,8 @@ bool GeneratorImpl::EmitReturn(std::ostream& out, ast::ReturnStatement* stmt) {
if (generating_entry_point_) { if (generating_entry_point_) {
out << "return"; out << "return";
auto outdata = ep_name_to_out_data_.find(current_ep_name_); auto outdata = ep_sym_to_out_data_.find(current_ep_sym_.value());
if (outdata != ep_name_to_out_data_.end()) { if (outdata != ep_sym_to_out_data_.end()) {
out << " " << outdata->second.var_name; out << " " << outdata->second.var_name;
} }
} else if (stmt->has_value()) { } else if (stmt->has_value()) {

View File

@ -210,12 +210,12 @@ class GeneratorImpl {
/// @param func the function to emit /// @param func the function to emit
/// @param emit_duplicate_functions set true if we need to duplicate per entry /// @param emit_duplicate_functions set true if we need to duplicate per entry
/// point /// point
/// @param ep_name the current entry point or blank if none set /// @param ep_sym the current entry point or symbol::kInvalid if none set
/// @returns true if the function was emitted. /// @returns true if the function was emitted.
bool EmitFunctionInternal(std::ostream& out, bool EmitFunctionInternal(std::ostream& out,
ast::Function* func, ast::Function* func,
bool emit_duplicate_functions, bool emit_duplicate_functions,
const std::string& ep_name); Symbol ep_sym);
/// Handles emitting information for an entry point /// Handles emitting information for an entry point
/// @param out the output stream /// @param out the output stream
/// @param func the entry point /// @param func the entry point
@ -397,12 +397,12 @@ class GeneratorImpl {
Namer namer_; Namer namer_;
ast::Module* module_ = nullptr; ast::Module* module_ = nullptr;
std::string current_ep_name_; Symbol current_ep_sym_;
bool generating_entry_point_ = false; bool generating_entry_point_ = false;
uint32_t loop_emission_counter_ = 0; uint32_t loop_emission_counter_ = 0;
ScopeStack<ast::Variable*> global_variables_; ScopeStack<ast::Variable*> global_variables_;
std::unordered_map<std::string, EntryPointData> ep_name_to_in_data_; std::unordered_map<uint32_t, EntryPointData> ep_sym_to_in_data_;
std::unordered_map<std::string, EntryPointData> ep_name_to_out_data_; std::unordered_map<uint32_t, EntryPointData> ep_sym_to_out_data_;
// This maps an input of "<entry_point_name>_<function_name>" to a remapped // This maps an input of "<entry_point_name>_<function_name>" to a remapped
// function name. If there is no entry for a given key then function did // function name. If there is no entry for a given key then function did

View File

@ -613,9 +613,9 @@ TEST_F(HlslGeneratorImplTest_Binary, Call_WithLogical) {
ast::type::Void void_type; ast::type::Void void_type;
auto* func = create<ast::Function>(Source{}, "foo", ast::VariableList{}, auto* func = create<ast::Function>(
&void_type, create<ast::BlockStatement>(), Source{}, mod.RegisterSymbol("foo"), "foo", ast::VariableList{},
ast::FunctionDecorationList{}); &void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func); mod.AddFunction(func);
ast::ExpressionList params; ast::ExpressionList params;

View File

@ -35,9 +35,9 @@ TEST_F(HlslGeneratorImplTest_Call, EmitExpression_Call_WithoutParams) {
auto* id = create<ast::IdentifierExpression>("my_func"); auto* id = create<ast::IdentifierExpression>("my_func");
ast::CallExpression call(id, {}); ast::CallExpression call(id, {});
auto* func = create<ast::Function>(Source{}, "my_func", ast::VariableList{}, auto* func = create<ast::Function>(
&void_type, create<ast::BlockStatement>(), Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
ast::FunctionDecorationList{}); &void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func); mod.AddFunction(func);
ASSERT_TRUE(gen.EmitExpression(pre, out, &call)) << gen.error(); ASSERT_TRUE(gen.EmitExpression(pre, out, &call)) << gen.error();
@ -53,9 +53,9 @@ TEST_F(HlslGeneratorImplTest_Call, EmitExpression_Call_WithParams) {
params.push_back(create<ast::IdentifierExpression>("param2")); params.push_back(create<ast::IdentifierExpression>("param2"));
ast::CallExpression call(id, params); ast::CallExpression call(id, params);
auto* func = create<ast::Function>(Source{}, "my_func", ast::VariableList{}, auto* func = create<ast::Function>(
&void_type, create<ast::BlockStatement>(), Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
ast::FunctionDecorationList{}); &void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func); mod.AddFunction(func);
ASSERT_TRUE(gen.EmitExpression(pre, out, &call)) << gen.error(); ASSERT_TRUE(gen.EmitExpression(pre, out, &call)) << gen.error();
@ -71,9 +71,9 @@ TEST_F(HlslGeneratorImplTest_Call, EmitStatement_Call) {
params.push_back(create<ast::IdentifierExpression>("param2")); params.push_back(create<ast::IdentifierExpression>("param2"));
ast::CallStatement call(create<ast::CallExpression>(id, params)); ast::CallStatement call(create<ast::CallExpression>(id, params));
auto* func = create<ast::Function>(Source{}, "my_func", ast::VariableList{}, auto* func = create<ast::Function>(
&void_type, create<ast::BlockStatement>(), Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
ast::FunctionDecorationList{}); &void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func); mod.AddFunction(func);
gen.increment_indent(); gen.increment_indent();
ASSERT_TRUE(gen.EmitStatement(out, &call)) << gen.error(); ASSERT_TRUE(gen.EmitStatement(out, &call)) << gen.error();

View File

@ -91,7 +91,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
create<ast::IdentifierExpression>("bar"))); create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "vtx_main", params, &f32, body, Source{}, mod.RegisterSymbol("vtx_main"), "vtx_main", params, &f32, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });
@ -164,7 +164,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
create<ast::IdentifierExpression>("bar"))); create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "vtx_main", params, &f32, body, Source{}, mod.RegisterSymbol("vtx_main"), "vtx_main", params, &f32, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });
@ -237,7 +237,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
create<ast::IdentifierExpression>("bar"))); create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "main", params, &f32, body, Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });
@ -309,7 +309,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
create<ast::IdentifierExpression>("bar"))); create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "main", params, &f32, body, Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -378,7 +378,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
create<ast::IdentifierExpression>("bar"))); create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "main", params, &f32, body, Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
}); });
@ -442,7 +442,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
create<ast::IdentifierExpression>("bar"))); create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "main", params, &f32, body, Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
}); });
@ -512,7 +512,7 @@ TEST_F(HlslGeneratorImplTest_EntryPoint,
create<ast::IdentifierExpression>("x")))); create<ast::IdentifierExpression>("x"))));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "main", params, &void_type, body, Source{}, mod.RegisterSymbol("main"), "main", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });

View File

@ -57,9 +57,9 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Function) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = auto* func = create<ast::Function>(Source{}, mod.RegisterSymbol("my_func"),
create<ast::Function>(Source{}, "my_func", ast::VariableList{}, "my_func", ast::VariableList{}, &void_type,
&void_type, body, ast::FunctionDecorationList{}); body, ast::FunctionDecorationList{});
mod.AddFunction(func); mod.AddFunction(func);
gen.increment_indent(); gen.increment_indent();
@ -77,9 +77,9 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Function_Name_Collision) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = auto* func = create<ast::Function>(
create<ast::Function>(Source{}, "GeometryShader", ast::VariableList{}, Source{}, mod.RegisterSymbol("GeometryShader"), "GeometryShader",
&void_type, body, ast::FunctionDecorationList{}); ast::VariableList{}, &void_type, body, ast::FunctionDecorationList{});
mod.AddFunction(func); mod.AddFunction(func);
gen.increment_indent(); gen.increment_indent();
@ -118,8 +118,9 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithParams) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(Source{}, "my_func", params, &void_type, auto* func = create<ast::Function>(Source{}, mod.RegisterSymbol("my_func"),
body, ast::FunctionDecorationList{}); "my_func", params, &void_type, body,
ast::FunctionDecorationList{});
mod.AddFunction(func); mod.AddFunction(func);
gen.increment_indent(); gen.increment_indent();
@ -174,7 +175,8 @@ TEST_F(HlslGeneratorImplTest_Function,
create<ast::IdentifierExpression>("foo"))); create<ast::IdentifierExpression>("foo")));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body, Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -245,7 +247,8 @@ TEST_F(HlslGeneratorImplTest_Function,
create<ast::IdentifierExpression>("x")))); create<ast::IdentifierExpression>("x"))));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body, Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -309,7 +312,8 @@ TEST_F(HlslGeneratorImplTest_Function,
body->append(create<ast::VariableDeclStatement>(var)); body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body, Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -380,7 +384,8 @@ TEST_F(HlslGeneratorImplTest_Function,
body->append(create<ast::VariableDeclStatement>(var)); body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body, Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -455,7 +460,8 @@ TEST_F(HlslGeneratorImplTest_Function,
body->append(create<ast::VariableDeclStatement>(var)); body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body, Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -526,7 +532,8 @@ TEST_F(HlslGeneratorImplTest_Function,
body->append(create<ast::VariableDeclStatement>(var)); body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body, Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -594,7 +601,8 @@ TEST_F(HlslGeneratorImplTest_Function,
body->append(assign); body->append(assign);
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body, Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -682,8 +690,9 @@ TEST_F(
create<ast::IdentifierExpression>("param"))); create<ast::IdentifierExpression>("param")));
body->append(create<ast::ReturnStatement>( body->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("foo"))); Source{}, create<ast::IdentifierExpression>("foo")));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32, auto* sub_func = create<ast::Function>(
body, ast::FunctionDecorationList{}); Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func); mod.AddFunction(sub_func);
@ -698,7 +707,7 @@ TEST_F(
expr))); expr)));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>( auto* func_1 = create<ast::Function>(
Source{}, "ep_1", params, &void_type, body, Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -766,8 +775,9 @@ TEST_F(HlslGeneratorImplTest_Function,
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>( body->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("param"))); Source{}, create<ast::IdentifierExpression>("param")));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32, auto* sub_func = create<ast::Function>(
body, ast::FunctionDecorationList{}); Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func); mod.AddFunction(sub_func);
@ -782,7 +792,7 @@ TEST_F(HlslGeneratorImplTest_Function,
expr))); expr)));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>( auto* func_1 = create<ast::Function>(
Source{}, "ep_1", params, &void_type, body, Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -863,8 +873,9 @@ TEST_F(
create<ast::IdentifierExpression>("x")))); create<ast::IdentifierExpression>("x"))));
body->append(create<ast::ReturnStatement>( body->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("param"))); Source{}, create<ast::IdentifierExpression>("param")));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32, auto* sub_func = create<ast::Function>(
body, ast::FunctionDecorationList{}); Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func); mod.AddFunction(sub_func);
@ -879,7 +890,7 @@ TEST_F(
expr))); expr)));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>( auto* func_1 = create<ast::Function>(
Source{}, "ep_1", params, &void_type, body, Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -948,8 +959,9 @@ TEST_F(HlslGeneratorImplTest_Function,
Source{}, create<ast::MemberAccessorExpression>( Source{}, create<ast::MemberAccessorExpression>(
create<ast::IdentifierExpression>("coord"), create<ast::IdentifierExpression>("coord"),
create<ast::IdentifierExpression>("x")))); create<ast::IdentifierExpression>("x"))));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32, auto* sub_func = create<ast::Function>(
body, ast::FunctionDecorationList{}); Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func); mod.AddFunction(sub_func);
@ -971,7 +983,8 @@ TEST_F(HlslGeneratorImplTest_Function,
body->append(create<ast::VariableDeclStatement>(var)); body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body, Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -1034,8 +1047,9 @@ TEST_F(HlslGeneratorImplTest_Function,
Source{}, create<ast::MemberAccessorExpression>( Source{}, create<ast::MemberAccessorExpression>(
create<ast::IdentifierExpression>("coord"), create<ast::IdentifierExpression>("coord"),
create<ast::IdentifierExpression>("x")))); create<ast::IdentifierExpression>("x"))));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32, auto* sub_func = create<ast::Function>(
body, ast::FunctionDecorationList{}); Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func); mod.AddFunction(sub_func);
@ -1057,7 +1071,8 @@ TEST_F(HlslGeneratorImplTest_Function,
body->append(create<ast::VariableDeclStatement>(var)); body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body, Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -1122,7 +1137,7 @@ TEST_F(HlslGeneratorImplTest_Function,
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>( auto* func_1 = create<ast::Function>(
Source{}, "ep_1", params, &void_type, body, Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -1152,8 +1167,8 @@ TEST_F(HlslGeneratorImplTest_Function,
ast::type::Void void_type; ast::type::Void void_type;
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "GeometryShader", ast::VariableList{}, &void_type, Source{}, mod.RegisterSymbol("GeometryShader"), "GeometryShader",
create<ast::BlockStatement>(), ast::VariableList{}, &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -1175,7 +1190,7 @@ TEST_F(HlslGeneratorImplTest_Function,
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "main", params, &void_type, body, Source{}, mod.RegisterSymbol("main"), "main", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
}); });
@ -1200,7 +1215,7 @@ TEST_F(HlslGeneratorImplTest_Function,
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "main", params, &void_type, body, Source{}, mod.RegisterSymbol("main"), "main", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
create<ast::WorkgroupDecoration>(2u, 4u, 6u, Source{}), create<ast::WorkgroupDecoration>(2u, 4u, 6u, Source{}),
@ -1236,8 +1251,9 @@ TEST_F(HlslGeneratorImplTest_Function, Emit_Function_WithArrayParams) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(Source{}, "my_func", params, &void_type, auto* func = create<ast::Function>(Source{}, mod.RegisterSymbol("my_func"),
body, ast::FunctionDecorationList{}); "my_func", params, &void_type, body,
ast::FunctionDecorationList{});
mod.AddFunction(func); mod.AddFunction(func);
gen.increment_indent(); gen.increment_indent();
@ -1317,11 +1333,11 @@ TEST_F(HlslGeneratorImplTest_Function,
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::VariableDeclStatement>(var)); body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = auto* func = create<ast::Function>(
create<ast::Function>(Source{}, "a", params, &void_type, body, Source{}, mod.RegisterSymbol("a"), "a", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>( create<ast::StageDecoration>(ast::PipelineStage::kCompute,
ast::PipelineStage::kCompute, Source{}), Source{}),
}); });
mod.AddFunction(func); mod.AddFunction(func);
@ -1343,11 +1359,11 @@ TEST_F(HlslGeneratorImplTest_Function,
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::VariableDeclStatement>(var)); body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = auto* func = create<ast::Function>(
create<ast::Function>(Source{}, "b", params, &void_type, body, Source{}, mod.RegisterSymbol("b"), "b", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>( create<ast::StageDecoration>(ast::PipelineStage::kCompute,
ast::PipelineStage::kCompute, Source{}), Source{}),
}); });
mod.AddFunction(func); mod.AddFunction(func);

View File

@ -29,9 +29,9 @@ using HlslGeneratorImplTest = TestHelper;
TEST_F(HlslGeneratorImplTest, Generate) { TEST_F(HlslGeneratorImplTest, Generate) {
ast::type::Void void_type; ast::type::Void void_type;
auto* func = create<ast::Function>(Source{}, "my_func", ast::VariableList{}, auto* func = create<ast::Function>(
&void_type, create<ast::BlockStatement>(), Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
ast::FunctionDecorationList{}); &void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func); mod.AddFunction(func);
ASSERT_TRUE(gen.Generate(out)) << gen.error(); ASSERT_TRUE(gen.Generate(out)) << gen.error();

View File

@ -411,15 +411,15 @@ std::string GeneratorImpl::current_ep_var_name(VarType type) {
std::string name = ""; std::string name = "";
switch (type) { switch (type) {
case VarType::kIn: { case VarType::kIn: {
auto in_it = ep_name_to_in_data_.find(current_ep_name_); auto in_it = ep_sym_to_in_data_.find(current_ep_sym_.value());
if (in_it != ep_name_to_in_data_.end()) { if (in_it != ep_sym_to_in_data_.end()) {
name = in_it->second.var_name; name = in_it->second.var_name;
} }
break; break;
} }
case VarType::kOut: { case VarType::kOut: {
auto out_it = ep_name_to_out_data_.find(current_ep_name_); auto out_it = ep_sym_to_out_data_.find(current_ep_sym_.value());
if (out_it != ep_name_to_out_data_.end()) { if (out_it != ep_sym_to_out_data_.end()) {
name = out_it->second.var_name; name = out_it->second.var_name;
} }
break; break;
@ -573,12 +573,14 @@ bool GeneratorImpl::EmitCall(ast::CallExpression* expr) {
} }
auto name = ident->name(); auto name = ident->name();
auto it = ep_func_name_remapped_.find(current_ep_name_ + "_" + name); auto caller_sym = module_->GetSymbol(name);
auto it = ep_func_name_remapped_.find(current_ep_sym_.to_str() + "_" +
caller_sym.to_str());
if (it != ep_func_name_remapped_.end()) { if (it != ep_func_name_remapped_.end()) {
name = it->second; name = it->second;
} }
auto* func = module_->FindFunctionByName(ident->name()); auto* func = module_->FindFunctionBySymbol(module_->GetSymbol(ident->name()));
if (func == nullptr) { if (func == nullptr) {
error_ = "Unable to find function: " + name; error_ = "Unable to find function: " + name;
return false; return false;
@ -1026,7 +1028,7 @@ bool GeneratorImpl::EmitEntryPointData(ast::Function* func) {
auto in_struct_name = auto in_struct_name =
generate_name(func->name() + "_" + kInStructNameSuffix); generate_name(func->name() + "_" + kInStructNameSuffix);
auto in_var_name = generate_name(kTintStructInVarPrefix); auto in_var_name = generate_name(kTintStructInVarPrefix);
ep_name_to_in_data_[func->name()] = {in_struct_name, in_var_name}; ep_sym_to_in_data_[func->symbol().value()] = {in_struct_name, in_var_name};
make_indent(); make_indent();
out_ << "struct " << in_struct_name << " {" << std::endl; out_ << "struct " << in_struct_name << " {" << std::endl;
@ -1063,7 +1065,8 @@ bool GeneratorImpl::EmitEntryPointData(ast::Function* func) {
auto out_struct_name = auto out_struct_name =
generate_name(func->name() + "_" + kOutStructNameSuffix); generate_name(func->name() + "_" + kOutStructNameSuffix);
auto out_var_name = generate_name(kTintStructOutVarPrefix); auto out_var_name = generate_name(kTintStructOutVarPrefix);
ep_name_to_out_data_[func->name()] = {out_struct_name, out_var_name}; ep_sym_to_out_data_[func->symbol().value()] = {out_struct_name,
out_var_name};
make_indent(); make_indent();
out_ << "struct " << out_struct_name << " {" << std::endl; out_ << "struct " << out_struct_name << " {" << std::endl;
@ -1205,15 +1208,15 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) {
has_referenced_var_needing_struct(func); has_referenced_var_needing_struct(func);
if (emit_duplicate_functions) { if (emit_duplicate_functions) {
for (const auto& ep_name : func->ancestor_entry_points()) { for (const auto& ep_sym : func->ancestor_entry_points()) {
if (!EmitFunctionInternal(func, emit_duplicate_functions, ep_name)) { if (!EmitFunctionInternal(func, emit_duplicate_functions, ep_sym)) {
return false; return false;
} }
out_ << std::endl; out_ << std::endl;
} }
} else { } else {
// Emit as non-duplicated // Emit as non-duplicated
if (!EmitFunctionInternal(func, false, "")) { if (!EmitFunctionInternal(func, false, Symbol())) {
return false; return false;
} }
out_ << std::endl; out_ << std::endl;
@ -1224,19 +1227,23 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) {
bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, bool GeneratorImpl::EmitFunctionInternal(ast::Function* func,
bool emit_duplicate_functions, bool emit_duplicate_functions,
const std::string& ep_name) { Symbol ep_sym) {
auto name = func->name(); auto name = func->symbol().to_str();
if (!EmitType(func->return_type(), "")) { if (!EmitType(func->return_type(), "")) {
return false; return false;
} }
out_ << " "; out_ << " ";
if (emit_duplicate_functions) { if (emit_duplicate_functions) {
name = generate_name(name + "_" + ep_name); auto func_name = name;
ep_func_name_remapped_[ep_name + "_" + func->name()] = name; auto ep_name = ep_sym.to_str();
// TODO(dsinclair): The SymbolToName should go away and just use
// to_str() here when the conversion is complete.
name = generate_name(func->name() + "_" + module_->SymbolToName(ep_sym));
ep_func_name_remapped_[ep_name + "_" + func_name] = name;
} else { } else {
name = namer_.NameFor(name); // TODO(dsinclair): this should be updated to a remapped name
name = namer_.NameFor(func->name());
} }
out_ << name << "("; out_ << name << "(";
@ -1247,15 +1254,15 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func,
// //
// We emit both of them if they're there regardless of if they're both used. // We emit both of them if they're there regardless of if they're both used.
if (emit_duplicate_functions) { if (emit_duplicate_functions) {
auto in_it = ep_name_to_in_data_.find(ep_name); auto in_it = ep_sym_to_in_data_.find(ep_sym.value());
if (in_it != ep_name_to_in_data_.end()) { if (in_it != ep_sym_to_in_data_.end()) {
out_ << "thread " << in_it->second.struct_name << "& " out_ << "thread " << in_it->second.struct_name << "& "
<< in_it->second.var_name; << in_it->second.var_name;
first = false; first = false;
} }
auto out_it = ep_name_to_out_data_.find(ep_name); auto out_it = ep_sym_to_out_data_.find(ep_sym.value());
if (out_it != ep_name_to_out_data_.end()) { if (out_it != ep_sym_to_out_data_.end()) {
if (!first) { if (!first) {
out_ << ", "; out_ << ", ";
} }
@ -1337,13 +1344,13 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func,
out_ << ") "; out_ << ") ";
current_ep_name_ = ep_name; current_ep_sym_ = ep_sym;
if (!EmitBlockAndNewline(func->body())) { if (!EmitBlockAndNewline(func->body())) {
return false; return false;
} }
current_ep_name_ = ""; current_ep_sym_ = Symbol();
return true; return true;
} }
@ -1377,25 +1384,25 @@ std::string GeneratorImpl::builtin_to_attribute(ast::Builtin builtin) const {
bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) { bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) {
make_indent(); make_indent();
current_ep_name_ = func->name(); current_ep_sym_ = func->symbol();
EmitStage(func->pipeline_stage()); EmitStage(func->pipeline_stage());
out_ << " "; out_ << " ";
// This is an entry point, the return type is the entry point output structure // This is an entry point, the return type is the entry point output structure
// if one exists, or void otherwise. // if one exists, or void otherwise.
auto out_data = ep_name_to_out_data_.find(current_ep_name_); auto out_data = ep_sym_to_out_data_.find(current_ep_sym_.value());
bool has_out_data = out_data != ep_name_to_out_data_.end(); bool has_out_data = out_data != ep_sym_to_out_data_.end();
if (has_out_data) { if (has_out_data) {
out_ << out_data->second.struct_name; out_ << out_data->second.struct_name;
} else { } else {
out_ << "void"; out_ << "void";
} }
out_ << " " << namer_.NameFor(current_ep_name_) << "("; out_ << " " << namer_.NameFor(func->name()) << "(";
bool first = true; bool first = true;
auto in_data = ep_name_to_in_data_.find(current_ep_name_); auto in_data = ep_sym_to_in_data_.find(current_ep_sym_.value());
if (in_data != ep_name_to_in_data_.end()) { if (in_data != ep_sym_to_in_data_.end()) {
out_ << in_data->second.struct_name << " " << in_data->second.var_name out_ << in_data->second.struct_name << " " << in_data->second.var_name
<< " [[stage_in]]"; << " [[stage_in]]";
first = false; first = false;
@ -1503,7 +1510,7 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) {
make_indent(); make_indent();
out_ << "}" << std::endl; out_ << "}" << std::endl;
current_ep_name_ = ""; current_ep_sym_ = Symbol();
return true; return true;
} }
@ -1687,8 +1694,8 @@ bool GeneratorImpl::EmitReturn(ast::ReturnStatement* stmt) {
out_ << "return"; out_ << "return";
if (generating_entry_point_) { if (generating_entry_point_) {
auto out_data = ep_name_to_out_data_.find(current_ep_name_); auto out_data = ep_sym_to_out_data_.find(current_ep_sym_.value());
if (out_data != ep_name_to_out_data_.end()) { if (out_data != ep_sym_to_out_data_.end()) {
out_ << " " << out_data->second.var_name; out_ << " " << out_data->second.var_name;
} }
} else if (stmt->has_value()) { } else if (stmt->has_value()) {

View File

@ -156,11 +156,11 @@ class GeneratorImpl : public TextGenerator {
/// @param func the function to emit /// @param func the function to emit
/// @param emit_duplicate_functions set true if we need to duplicate per entry /// @param emit_duplicate_functions set true if we need to duplicate per entry
/// point /// point
/// @param ep_name the current entry point or blank if none set /// @param ep_sym the current entry point or symbol::kInvalid if not set
/// @returns true if the function was emitted. /// @returns true if the function was emitted.
bool EmitFunctionInternal(ast::Function* func, bool EmitFunctionInternal(ast::Function* func,
bool emit_duplicate_functions, bool emit_duplicate_functions,
const std::string& ep_name); Symbol ep_sym);
/// Handles generating an identifier expression /// Handles generating an identifier expression
/// @param expr the identifier expression /// @param expr the identifier expression
/// @returns true if the identifier was emitted /// @returns true if the identifier was emitted
@ -282,13 +282,13 @@ class GeneratorImpl : public TextGenerator {
Namer namer_; Namer namer_;
ScopeStack<ast::Variable*> global_variables_; ScopeStack<ast::Variable*> global_variables_;
std::string current_ep_name_; Symbol current_ep_sym_;
bool generating_entry_point_ = false; bool generating_entry_point_ = false;
const ast::Module* module_ = nullptr; const ast::Module* module_ = nullptr;
uint32_t loop_emission_counter_ = 0; uint32_t loop_emission_counter_ = 0;
std::unordered_map<std::string, EntryPointData> ep_name_to_in_data_; std::unordered_map<uint32_t, EntryPointData> ep_sym_to_in_data_;
std::unordered_map<std::string, EntryPointData> ep_name_to_out_data_; std::unordered_map<uint32_t, EntryPointData> ep_sym_to_out_data_;
// This maps an input of "<entry_point_name>_<function_name>" to a remapped // This maps an input of "<entry_point_name>_<function_name>" to a remapped
// function name. If there is no entry for a given key then function did // function name. If there is no entry for a given key then function did

View File

@ -37,9 +37,9 @@ TEST_F(MslGeneratorImplTest, EmitExpression_Call_WithoutParams) {
auto* id = create<ast::IdentifierExpression>("my_func"); auto* id = create<ast::IdentifierExpression>("my_func");
ast::CallExpression call(id, {}); ast::CallExpression call(id, {});
auto* func = create<ast::Function>(Source{}, "my_func", ast::VariableList{}, auto* func = create<ast::Function>(
&void_type, create<ast::BlockStatement>(), Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
ast::FunctionDecorationList{}); &void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func); mod.AddFunction(func);
ASSERT_TRUE(gen.EmitExpression(&call)) << gen.error(); ASSERT_TRUE(gen.EmitExpression(&call)) << gen.error();
@ -55,9 +55,9 @@ TEST_F(MslGeneratorImplTest, EmitExpression_Call_WithParams) {
params.push_back(create<ast::IdentifierExpression>("param2")); params.push_back(create<ast::IdentifierExpression>("param2"));
ast::CallExpression call(id, params); ast::CallExpression call(id, params);
auto* func = create<ast::Function>(Source{}, "my_func", ast::VariableList{}, auto* func = create<ast::Function>(
&void_type, create<ast::BlockStatement>(), Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
ast::FunctionDecorationList{}); &void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func); mod.AddFunction(func);
ASSERT_TRUE(gen.EmitExpression(&call)) << gen.error(); ASSERT_TRUE(gen.EmitExpression(&call)) << gen.error();
@ -73,9 +73,9 @@ TEST_F(MslGeneratorImplTest, EmitStatement_Call) {
params.push_back(create<ast::IdentifierExpression>("param2")); params.push_back(create<ast::IdentifierExpression>("param2"));
ast::CallStatement call(create<ast::CallExpression>(id, params)); ast::CallStatement call(create<ast::CallExpression>(id, params));
auto* func = create<ast::Function>(Source{}, "my_func", ast::VariableList{}, auto* func = create<ast::Function>(
&void_type, create<ast::BlockStatement>(), Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
ast::FunctionDecorationList{}); &void_type, create<ast::BlockStatement>(), ast::FunctionDecorationList{});
mod.AddFunction(func); mod.AddFunction(func);
gen.increment_indent(); gen.increment_indent();

View File

@ -90,7 +90,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Vertex_Input) {
create<ast::IdentifierExpression>("bar"))); create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "vtx_main", params, &f32, body, Source{}, mod.RegisterSymbol("vtx_main"), "vtx_main", params, &f32, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });
@ -160,7 +160,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Vertex_Output) {
create<ast::IdentifierExpression>("bar"))); create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "vtx_main", params, &f32, body, Source{}, mod.RegisterSymbol("vtx_main"), "vtx_main", params, &f32, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });
@ -229,7 +229,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Fragment_Input) {
create<ast::IdentifierExpression>("bar"), create<ast::IdentifierExpression>("bar"),
create<ast::IdentifierExpression>("bar"))); create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "main", params, &f32, body, Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -299,7 +299,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Fragment_Output) {
create<ast::IdentifierExpression>("bar"))); create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "main", params, &f32, body, Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -366,7 +366,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Compute_Input) {
create<ast::IdentifierExpression>("bar"))); create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "main", params, &f32, body, Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
}); });
@ -428,7 +428,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Compute_Output) {
create<ast::IdentifierExpression>("bar"))); create<ast::IdentifierExpression>("bar")));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "main", params, &f32, body, Source{}, mod.RegisterSymbol("main"), "main", params, &f32, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
}); });
@ -496,7 +496,7 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Builtins) {
create<ast::IdentifierExpression>("x")))); create<ast::IdentifierExpression>("x"))));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "main", params, &void_type, body, Source{}, mod.RegisterSymbol("main"), "main", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });

View File

@ -60,9 +60,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = auto* func = create<ast::Function>(Source{}, mod.RegisterSymbol("my_func"),
create<ast::Function>(Source{}, "my_func", ast::VariableList{}, "my_func", ast::VariableList{}, &void_type,
&void_type, body, ast::FunctionDecorationList{}); body, ast::FunctionDecorationList{});
mod.AddFunction(func); mod.AddFunction(func);
gen.increment_indent(); gen.increment_indent();
@ -82,8 +82,8 @@ TEST_F(MslGeneratorImplTest, Emit_Function_Name_Collision) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = auto* func = create<ast::Function>(Source{}, mod.RegisterSymbol("main"),
create<ast::Function>(Source{}, "main", ast::VariableList{}, &void_type, "main", ast::VariableList{}, &void_type,
body, ast::FunctionDecorationList{}); body, ast::FunctionDecorationList{});
mod.AddFunction(func); mod.AddFunction(func);
@ -125,8 +125,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function_WithParams) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(Source{}, "my_func", params, &void_type, auto* func = create<ast::Function>(Source{}, mod.RegisterSymbol("my_func"),
body, ast::FunctionDecorationList{}); "my_func", params, &void_type, body,
ast::FunctionDecorationList{});
mod.AddFunction(func); mod.AddFunction(func);
gen.increment_indent(); gen.increment_indent();
@ -183,7 +184,8 @@ TEST_F(MslGeneratorImplTest, Emit_FunctionDecoration_EntryPoint_WithInOutVars) {
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body, Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{create<ast::StageDecoration>( ast::FunctionDecorationList{create<ast::StageDecoration>(
ast::PipelineStage::kFragment, Source{})}); ast::PipelineStage::kFragment, Source{})});
@ -257,7 +259,8 @@ TEST_F(MslGeneratorImplTest,
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body, Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -321,7 +324,8 @@ TEST_F(MslGeneratorImplTest, Emit_FunctionDecoration_EntryPoint_With_Uniform) {
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body, Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -397,7 +401,8 @@ TEST_F(MslGeneratorImplTest,
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body, Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -478,7 +483,8 @@ TEST_F(MslGeneratorImplTest,
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body, Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -572,8 +578,9 @@ TEST_F(
create<ast::IdentifierExpression>("param"))); create<ast::IdentifierExpression>("param")));
body->append(create<ast::ReturnStatement>( body->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("foo"))); Source{}, create<ast::IdentifierExpression>("foo")));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32, auto* sub_func = create<ast::Function>(
body, ast::FunctionDecorationList{}); Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func); mod.AddFunction(sub_func);
@ -588,7 +595,7 @@ TEST_F(
expr))); expr)));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>( auto* func_1 = create<ast::Function>(
Source{}, "ep_1", params, &void_type, body, Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -659,8 +666,9 @@ TEST_F(MslGeneratorImplTest,
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>( body->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("param"))); Source{}, create<ast::IdentifierExpression>("param")));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32, auto* sub_func = create<ast::Function>(
body, ast::FunctionDecorationList{}); Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func); mod.AddFunction(sub_func);
@ -676,7 +684,7 @@ TEST_F(MslGeneratorImplTest,
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>( auto* func_1 = create<ast::Function>(
Source{}, "ep_1", params, &void_type, body, Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -760,8 +768,9 @@ TEST_F(
create<ast::IdentifierExpression>("x")))); create<ast::IdentifierExpression>("x"))));
body->append(create<ast::ReturnStatement>( body->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("param"))); Source{}, create<ast::IdentifierExpression>("param")));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32, auto* sub_func = create<ast::Function>(
body, ast::FunctionDecorationList{}); Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func); mod.AddFunction(sub_func);
@ -776,7 +785,7 @@ TEST_F(
expr))); expr)));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>( auto* func_1 = create<ast::Function>(
Source{}, "ep_1", params, &void_type, body, Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -843,8 +852,9 @@ TEST_F(MslGeneratorImplTest,
Source{}, create<ast::MemberAccessorExpression>( Source{}, create<ast::MemberAccessorExpression>(
create<ast::IdentifierExpression>("coord"), create<ast::IdentifierExpression>("coord"),
create<ast::IdentifierExpression>("x")))); create<ast::IdentifierExpression>("x"))));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32, auto* sub_func = create<ast::Function>(
body, ast::FunctionDecorationList{}); Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func); mod.AddFunction(sub_func);
@ -867,7 +877,8 @@ TEST_F(MslGeneratorImplTest,
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body, Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -943,8 +954,9 @@ TEST_F(MslGeneratorImplTest,
Source{}, create<ast::MemberAccessorExpression>( Source{}, create<ast::MemberAccessorExpression>(
create<ast::IdentifierExpression>("coord"), create<ast::IdentifierExpression>("coord"),
create<ast::IdentifierExpression>("b")))); create<ast::IdentifierExpression>("b"))));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32, auto* sub_func = create<ast::Function>(
body, ast::FunctionDecorationList{}); Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func); mod.AddFunction(sub_func);
@ -967,7 +979,8 @@ TEST_F(MslGeneratorImplTest,
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body, Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -1049,8 +1062,9 @@ TEST_F(MslGeneratorImplTest,
Source{}, create<ast::MemberAccessorExpression>( Source{}, create<ast::MemberAccessorExpression>(
create<ast::IdentifierExpression>("coord"), create<ast::IdentifierExpression>("coord"),
create<ast::IdentifierExpression>("b")))); create<ast::IdentifierExpression>("b"))));
auto* sub_func = create<ast::Function>(Source{}, "sub_func", params, &f32, auto* sub_func = create<ast::Function>(
body, ast::FunctionDecorationList{}); Source{}, mod.RegisterSymbol("sub_func"), "sub_func", params, &f32, body,
ast::FunctionDecorationList{});
mod.AddFunction(sub_func); mod.AddFunction(sub_func);
@ -1073,7 +1087,8 @@ TEST_F(MslGeneratorImplTest,
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "frag_main", params, &void_type, body, Source{}, mod.RegisterSymbol("frag_main"), "frag_main", params,
&void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -1145,7 +1160,7 @@ TEST_F(MslGeneratorImplTest,
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func_1 = create<ast::Function>( auto* func_1 = create<ast::Function>(
Source{}, "ep_1", params, &void_type, body, Source{}, mod.RegisterSymbol("ep_1"), "ep_1", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -1177,8 +1192,8 @@ TEST_F(MslGeneratorImplTest,
ast::type::Void void_type; ast::type::Void void_type;
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "main", ast::VariableList{}, &void_type, Source{}, mod.RegisterSymbol("main"), "main", ast::VariableList{},
create<ast::BlockStatement>(), &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
}); });
@ -1212,8 +1227,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function_WithArrayParams) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = create<ast::Function>(Source{}, "my_func", params, &void_type, auto* func = create<ast::Function>(Source{}, mod.RegisterSymbol("my_func"),
body, ast::FunctionDecorationList{}); "my_func", params, &void_type, body,
ast::FunctionDecorationList{});
mod.AddFunction(func); mod.AddFunction(func);
@ -1298,11 +1314,11 @@ TEST_F(MslGeneratorImplTest,
body->append(create<ast::VariableDeclStatement>(var)); body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = auto* func = create<ast::Function>(
create<ast::Function>(Source{}, "a", params, &void_type, body, Source{}, mod.RegisterSymbol("a"), "a", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>( create<ast::StageDecoration>(ast::PipelineStage::kCompute,
ast::PipelineStage::kCompute, Source{}), Source{}),
}); });
mod.AddFunction(func); mod.AddFunction(func);
@ -1325,11 +1341,11 @@ TEST_F(MslGeneratorImplTest,
body->append(create<ast::VariableDeclStatement>(var)); body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = auto* func = create<ast::Function>(
create<ast::Function>(Source{}, "b", params, &void_type, body, Source{}, mod.RegisterSymbol("b"), "b", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>( create<ast::StageDecoration>(ast::PipelineStage::kCompute,
ast::PipelineStage::kCompute, Source{}), Source{}),
}); });
mod.AddFunction(func); mod.AddFunction(func);

View File

@ -51,8 +51,8 @@ TEST_F(MslGeneratorImplTest, Generate) {
ast::type::Void void_type; ast::type::Void void_type;
auto* func = create<ast::Function>( auto* func = create<ast::Function>(
Source{}, "my_func", ast::VariableList{}, &void_type, Source{}, mod.RegisterSymbol("my_func"), "my_func", ast::VariableList{},
create<ast::BlockStatement>(), &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
}); });

View File

@ -65,11 +65,11 @@ TEST_F(BuilderTest, Expression_Call) {
Source{}, create<ast::BinaryExpression>( Source{}, create<ast::BinaryExpression>(
ast::BinaryOp::kAdd, create<ast::IdentifierExpression>("a"), ast::BinaryOp::kAdd, create<ast::IdentifierExpression>("a"),
create<ast::IdentifierExpression>("b")))); create<ast::IdentifierExpression>("b"))));
ast::Function a_func(Source{}, "a_func", func_params, &f32, body, ast::Function a_func(Source{}, mod->RegisterSymbol("a_func"), "a_func",
ast::FunctionDecorationList{}); func_params, &f32, body, ast::FunctionDecorationList{});
ast::Function func(Source{}, "main", {}, &void_type, ast::Function func(Source{}, mod->RegisterSymbol("main"), "main", {},
create<ast::BlockStatement>(), &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ast::ExpressionList call_params; ast::ExpressionList call_params;
@ -143,11 +143,12 @@ TEST_F(BuilderTest, Statement_Call) {
ast::BinaryOp::kAdd, create<ast::IdentifierExpression>("a"), ast::BinaryOp::kAdd, create<ast::IdentifierExpression>("a"),
create<ast::IdentifierExpression>("b")))); create<ast::IdentifierExpression>("b"))));
ast::Function a_func(Source{}, "a_func", func_params, &void_type, body, ast::Function a_func(Source{}, mod->RegisterSymbol("a_func"), "a_func",
func_params, &void_type, body,
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ast::Function func(Source{}, "main", {}, &void_type, ast::Function func(Source{}, mod->RegisterSymbol("main"), "main", {},
create<ast::BlockStatement>(), &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ast::ExpressionList call_params; ast::ExpressionList call_params;

View File

@ -42,7 +42,8 @@ TEST_F(BuilderTest, FunctionDecoration_Stage) {
ast::type::Void void_type; ast::type::Void void_type;
ast::Function func( ast::Function func(
Source{}, "main", {}, &void_type, create<ast::BlockStatement>(), Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type,
create<ast::BlockStatement>(),
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });
@ -67,8 +68,8 @@ TEST_P(FunctionDecoration_StageTest, Emit) {
ast::type::Void void_type; ast::type::Void void_type;
ast::Function func(Source{}, "main", {}, &void_type, ast::Function func(Source{}, mod->RegisterSymbol("main"), "main", {},
create<ast::BlockStatement>(), &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(params.stage, Source{}), create<ast::StageDecoration>(params.stage, Source{}),
}); });
@ -97,7 +98,8 @@ TEST_F(BuilderTest, FunctionDecoration_Stage_WithUnusedInterfaceIds) {
ast::type::Void void_type; ast::type::Void void_type;
ast::Function func( ast::Function func(
Source{}, "main", {}, &void_type, create<ast::BlockStatement>(), Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type,
create<ast::BlockStatement>(),
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });
@ -174,7 +176,7 @@ TEST_F(BuilderTest, FunctionDecoration_Stage_WithUsedInterfaceIds) {
create<ast::IdentifierExpression>("my_in"))); create<ast::IdentifierExpression>("my_in")));
ast::Function func( ast::Function func(
Source{}, "main", {}, &void_type, body, Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kVertex, Source{}),
}); });
@ -244,7 +246,8 @@ TEST_F(BuilderTest, FunctionDecoration_ExecutionMode_Fragment_OriginUpperLeft) {
ast::type::Void void_type; ast::type::Void void_type;
ast::Function func( ast::Function func(
Source{}, "main", {}, &void_type, create<ast::BlockStatement>(), Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type,
create<ast::BlockStatement>(),
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -259,7 +262,8 @@ TEST_F(BuilderTest, FunctionDecoration_WorkgroupSize_Default) {
ast::type::Void void_type; ast::type::Void void_type;
ast::Function func( ast::Function func(
Source{}, "main", {}, &void_type, create<ast::BlockStatement>(), Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type,
create<ast::BlockStatement>(),
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
}); });
@ -274,7 +278,8 @@ TEST_F(BuilderTest, FunctionDecoration_WorkgroupSize) {
ast::type::Void void_type; ast::type::Void void_type;
ast::Function func( ast::Function func(
Source{}, "main", {}, &void_type, create<ast::BlockStatement>(), Source{}, mod->RegisterSymbol("main"), "main", {}, &void_type,
create<ast::BlockStatement>(),
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::WorkgroupDecoration>(2u, 4u, 6u, Source{}), create<ast::WorkgroupDecoration>(2u, 4u, 6u, Source{}),
create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kCompute, Source{}),
@ -290,13 +295,15 @@ TEST_F(BuilderTest, FunctionDecoration_ExecutionMode_MultipleFragment) {
ast::type::Void void_type; ast::type::Void void_type;
ast::Function func1( ast::Function func1(
Source{}, "main1", {}, &void_type, create<ast::BlockStatement>(), Source{}, mod->RegisterSymbol("main1"), "main1", {}, &void_type,
create<ast::BlockStatement>(),
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
ast::Function func2( ast::Function func2(
Source{}, "main2", {}, &void_type, create<ast::BlockStatement>(), Source{}, mod->RegisterSymbol("main2"), "main2", {}, &void_type,
create<ast::BlockStatement>(),
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });

View File

@ -47,8 +47,8 @@ using BuilderTest = TestHelper;
TEST_F(BuilderTest, Function_Empty) { TEST_F(BuilderTest, Function_Empty) {
ast::type::Void void_type; ast::type::Void void_type;
ast::Function func(Source{}, "a_func", {}, &void_type, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)); ASSERT_TRUE(b.GenerateFunction(&func));
@ -68,8 +68,8 @@ TEST_F(BuilderTest, Function_Terminator_Return) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
ast::Function func(Source{}, "a_func", {}, &void_type, body, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ast::FunctionDecorationList{}); &void_type, body, ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)); ASSERT_TRUE(b.GenerateFunction(&func));
EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func" EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func"
@ -101,8 +101,8 @@ TEST_F(BuilderTest, Function_Terminator_ReturnValue) {
Source{}, create<ast::IdentifierExpression>("a"))); Source{}, create<ast::IdentifierExpression>("a")));
ASSERT_TRUE(td.DetermineResultType(body)) << td.error(); ASSERT_TRUE(td.DetermineResultType(body)) << td.error();
ast::Function func(Source{}, "a_func", {}, &void_type, body, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ast::FunctionDecorationList{}); &void_type, body, ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(var_a)) << b.error(); ASSERT_TRUE(b.GenerateGlobalVariable(var_a)) << b.error();
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -128,8 +128,8 @@ TEST_F(BuilderTest, Function_Terminator_Discard) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::DiscardStatement>()); body->append(create<ast::DiscardStatement>());
ast::Function func(Source{}, "a_func", {}, &void_type, body, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ast::FunctionDecorationList{}); &void_type, body, ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)); ASSERT_TRUE(b.GenerateFunction(&func));
EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func" EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func"
@ -168,8 +168,8 @@ TEST_F(BuilderTest, Function_WithParams) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>( body->append(create<ast::ReturnStatement>(
Source{}, create<ast::IdentifierExpression>("a"))); Source{}, create<ast::IdentifierExpression>("a")));
ast::Function func(Source{}, "a_func", params, &f32, body, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", params,
ast::FunctionDecorationList{}); &f32, body, ast::FunctionDecorationList{});
td.RegisterVariableForTesting(func.params()[0]); td.RegisterVariableForTesting(func.params()[0]);
td.RegisterVariableForTesting(func.params()[1]); td.RegisterVariableForTesting(func.params()[1]);
@ -197,8 +197,8 @@ TEST_F(BuilderTest, Function_WithBody) {
auto* body = create<ast::BlockStatement>(); auto* body = create<ast::BlockStatement>();
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
ast::Function func(Source{}, "a_func", {}, &void_type, body, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
ast::FunctionDecorationList{}); &void_type, body, ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)); ASSERT_TRUE(b.GenerateFunction(&func));
EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func" EXPECT_EQ(DumpBuilder(b), R"(OpName %3 "a_func"
@ -213,8 +213,8 @@ OpFunctionEnd
TEST_F(BuilderTest, FunctionType) { TEST_F(BuilderTest, FunctionType) {
ast::type::Void void_type; ast::type::Void void_type;
ast::Function func(Source{}, "a_func", {}, &void_type, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)); ASSERT_TRUE(b.GenerateFunction(&func));
@ -225,11 +225,11 @@ TEST_F(BuilderTest, FunctionType) {
TEST_F(BuilderTest, FunctionType_DeDuplicate) { TEST_F(BuilderTest, FunctionType_DeDuplicate) {
ast::type::Void void_type; ast::type::Void void_type;
ast::Function func1(Source{}, "a_func", {}, &void_type, ast::Function func1(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ast::Function func2(Source{}, "b_func", {}, &void_type, ast::Function func2(Source{}, mod->RegisterSymbol("b_func"), "b_func", {},
create<ast::BlockStatement>(), &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func1)); ASSERT_TRUE(b.GenerateFunction(&func1));
@ -307,11 +307,11 @@ TEST_F(BuilderTest, Emit_Multiple_EntryPoint_With_Same_ModuleVar) {
body->append(create<ast::VariableDeclStatement>(var)); body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = auto* func = create<ast::Function>(
create<ast::Function>(Source{}, "a", params, &void_type, body, Source{}, mod->RegisterSymbol("a"), "a", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>( create<ast::StageDecoration>(ast::PipelineStage::kCompute,
ast::PipelineStage::kCompute, Source{}), Source{}),
}); });
mod->AddFunction(func); mod->AddFunction(func);
@ -334,11 +334,11 @@ TEST_F(BuilderTest, Emit_Multiple_EntryPoint_With_Same_ModuleVar) {
body->append(create<ast::VariableDeclStatement>(var)); body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = auto* func = create<ast::Function>(
create<ast::Function>(Source{}, "b", params, &void_type, body, Source{}, mod->RegisterSymbol("b"), "b", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>( create<ast::StageDecoration>(ast::PipelineStage::kCompute,
ast::PipelineStage::kCompute, Source{}), Source{}),
}); });
mod->AddFunction(func); mod->AddFunction(func);

View File

@ -471,8 +471,8 @@ TEST_F(IntrinsicBuilderTest, Call_GLSLMethod_WithLoad) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error(); ASSERT_TRUE(b.GenerateGlobalVariable(var)) << b.error();
@ -505,8 +505,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Float_Test, Call_Scalar) {
auto expr = Call(param.name, 1.0f); auto expr = Call(param.name, 1.0f);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -533,8 +533,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Float_Test, Call_Vector) {
auto expr = Call(param.name, vec2<f32>(1.0f, 1.0f)); auto expr = Call(param.name, vec2<f32>(1.0f, 1.0f));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -587,8 +587,8 @@ TEST_F(IntrinsicBuilderTest, Call_Length_Scalar) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -612,8 +612,8 @@ TEST_F(IntrinsicBuilderTest, Call_Length_Vector) {
auto expr = Call("length", vec2<f32>(1.0f, 1.0f)); auto expr = Call("length", vec2<f32>(1.0f, 1.0f));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -639,8 +639,8 @@ TEST_F(IntrinsicBuilderTest, Call_Normalize) {
auto expr = Call("normalize", vec2<f32>(1.0f, 1.0f)); auto expr = Call("normalize", vec2<f32>(1.0f, 1.0f));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -671,8 +671,8 @@ TEST_P(Intrinsic_Builtin_DualParam_Float_Test, Call_Scalar) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -700,8 +700,8 @@ TEST_P(Intrinsic_Builtin_DualParam_Float_Test, Call_Vector) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -737,8 +737,8 @@ TEST_F(IntrinsicBuilderTest, Call_Distance_Scalar) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -763,8 +763,8 @@ TEST_F(IntrinsicBuilderTest, Call_Distance_Vector) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -792,8 +792,8 @@ TEST_F(IntrinsicBuilderTest, Call_Cross) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -823,8 +823,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Float_Test, Call_Scalar) {
auto expr = Call(param.name, 1.0f, 1.0f, 1.0f); auto expr = Call(param.name, 1.0f, 1.0f, 1.0f);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -853,8 +853,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Float_Test, Call_Vector) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -894,8 +894,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Sint_Test, Call_Scalar) {
auto expr = Call(param.name, 1); auto expr = Call(param.name, 1);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -922,8 +922,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Sint_Test, Call_Vector) {
auto expr = Call(param.name, vec2<i32>(1, 1)); auto expr = Call(param.name, vec2<i32>(1, 1));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -957,8 +957,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Uint_Test, Call_Scalar) {
auto expr = Call(param.name, 1u); auto expr = Call(param.name, 1u);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -985,8 +985,8 @@ TEST_P(Intrinsic_Builtin_SingleParam_Uint_Test, Call_Vector) {
auto expr = Call(param.name, vec2<u32>(1u, 1u)); auto expr = Call(param.name, vec2<u32>(1u, 1u));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1020,8 +1020,8 @@ TEST_P(Intrinsic_Builtin_DualParam_SInt_Test, Call_Scalar) {
auto expr = Call(param.name, 1, 1); auto expr = Call(param.name, 1, 1);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1048,8 +1048,8 @@ TEST_P(Intrinsic_Builtin_DualParam_SInt_Test, Call_Vector) {
auto expr = Call(param.name, vec2<i32>(1, 1), vec2<i32>(1, 1)); auto expr = Call(param.name, vec2<i32>(1, 1), vec2<i32>(1, 1));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1084,8 +1084,8 @@ TEST_P(Intrinsic_Builtin_DualParam_UInt_Test, Call_Scalar) {
auto expr = Call(param.name, 1u, 1u); auto expr = Call(param.name, 1u, 1u);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1112,8 +1112,8 @@ TEST_P(Intrinsic_Builtin_DualParam_UInt_Test, Call_Vector) {
auto expr = Call(param.name, vec2<u32>(1u, 1u), vec2<u32>(1u, 1u)); auto expr = Call(param.name, vec2<u32>(1u, 1u), vec2<u32>(1u, 1u));
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1148,8 +1148,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Sint_Test, Call_Scalar) {
auto expr = Call(param.name, 1, 1, 1); auto expr = Call(param.name, 1, 1, 1);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1178,8 +1178,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Sint_Test, Call_Vector) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1213,8 +1213,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Uint_Test, Call_Scalar) {
auto expr = Call(param.name, 1u, 1u, 1u); auto expr = Call(param.name, 1u, 1u, 1u);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1243,8 +1243,8 @@ TEST_P(Intrinsic_Builtin_ThreeParam_Uint_Test, Call_Vector) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1276,8 +1276,8 @@ TEST_F(IntrinsicBuilderTest, Call_Determinant) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1320,8 +1320,8 @@ TEST_F(IntrinsicBuilderTest, Call_ArrayLength) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1360,8 +1360,8 @@ TEST_F(IntrinsicBuilderTest, Call_ArrayLength_OtherMembersInStruct) {
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();
@ -1405,8 +1405,8 @@ TEST_F(IntrinsicBuilderTest, DISABLED_Call_ArrayLength_Ptr) {
auto expr = Call("arrayLength", "ptr_var"); auto expr = Call("arrayLength", "ptr_var");
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, ty.void_, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), ty.void_, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateFunction(&func)) << b.error(); ASSERT_TRUE(b.GenerateFunction(&func)) << b.error();

View File

@ -121,8 +121,8 @@ TEST_F(BuilderTest, Switch_WithCase) {
td.RegisterVariableForTesting(a); td.RegisterVariableForTesting(a);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, &i32, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), &i32, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
@ -201,8 +201,8 @@ TEST_F(BuilderTest, Switch_WithDefault) {
td.RegisterVariableForTesting(a); td.RegisterVariableForTesting(a);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, &i32, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), &i32, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
@ -300,8 +300,8 @@ TEST_F(BuilderTest, Switch_WithCaseAndDefault) {
td.RegisterVariableForTesting(a); td.RegisterVariableForTesting(a);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, &i32, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), &i32, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
@ -408,8 +408,8 @@ TEST_F(BuilderTest, Switch_CaseWithFallthrough) {
td.RegisterVariableForTesting(a); td.RegisterVariableForTesting(a);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, &i32, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), &i32, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
@ -495,8 +495,8 @@ TEST_F(BuilderTest, Switch_CaseFallthroughLastStatement) {
td.RegisterVariableForTesting(a); td.RegisterVariableForTesting(a);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, &i32, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), &i32, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();
@ -563,8 +563,8 @@ TEST_F(BuilderTest, Switch_WithNestedBreak) {
td.RegisterVariableForTesting(a); td.RegisterVariableForTesting(a);
ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error(); ASSERT_TRUE(td.DetermineResultType(&expr)) << td.error();
ast::Function func(Source{}, "a_func", {}, &i32, ast::Function func(Source{}, mod->RegisterSymbol("a_func"), "a_func", {},
create<ast::BlockStatement>(), &i32, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}); ast::FunctionDecorationList{});
ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error(); ASSERT_TRUE(b.GenerateGlobalVariable(v)) << b.error();

View File

@ -113,7 +113,8 @@ bool GeneratorImpl::Generate(const ast::Module& module) {
bool GeneratorImpl::GenerateEntryPoint(const ast::Module& module, bool GeneratorImpl::GenerateEntryPoint(const ast::Module& module,
ast::PipelineStage stage, ast::PipelineStage stage,
const std::string& name) { const std::string& name) {
auto* func = module.FindFunctionByNameAndStage(name, stage); auto* func =
module.FindFunctionBySymbolAndStage(module.GetSymbol(name), stage);
if (func == nullptr) { if (func == nullptr) {
error_ = "Unable to find requested entry point: " + name; error_ = "Unable to find requested entry point: " + name;
return false; return false;
@ -153,7 +154,7 @@ bool GeneratorImpl::GenerateEntryPoint(const ast::Module& module,
} }
for (auto* f : module.functions()) { for (auto* f : module.functions()) {
if (!f->HasAncestorEntryPoint(name)) { if (!f->HasAncestorEntryPoint(module.GetSymbol(name))) {
continue; continue;
} }

View File

@ -46,8 +46,8 @@ TEST_F(WgslGeneratorImplTest, Emit_Function) {
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
ast::type::Void void_type; ast::type::Void void_type;
ast::Function func(Source{}, "my_func", {}, &void_type, body, ast::Function func(Source{}, mod.RegisterSymbol("my_func"), "my_func", {},
ast::FunctionDecorationList{}); &void_type, body, ast::FunctionDecorationList{});
gen.increment_indent(); gen.increment_indent();
@ -85,8 +85,8 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithParams) {
ast::VariableDecorationList{})); // decorations ast::VariableDecorationList{})); // decorations
ast::type::Void void_type; ast::type::Void void_type;
ast::Function func(Source{}, "my_func", params, &void_type, body, ast::Function func(Source{}, mod.RegisterSymbol("my_func"), "my_func", params,
ast::FunctionDecorationList{}); &void_type, body, ast::FunctionDecorationList{});
gen.increment_indent(); gen.increment_indent();
@ -104,7 +104,8 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_WorkgroupSize) {
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
ast::type::Void void_type; ast::type::Void void_type;
ast::Function func(Source{}, "my_func", {}, &void_type, body, ast::Function func(Source{}, mod.RegisterSymbol("my_func"), "my_func", {},
&void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::WorkgroupDecoration>(2u, 4u, 6u, Source{}), create<ast::WorkgroupDecoration>(2u, 4u, 6u, Source{}),
}); });
@ -127,7 +128,7 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_Stage) {
ast::type::Void void_type; ast::type::Void void_type;
ast::Function func( ast::Function func(
Source{}, "my_func", {}, &void_type, body, Source{}, mod.RegisterSymbol("my_func"), "my_func", {}, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
}); });
@ -150,7 +151,7 @@ TEST_F(WgslGeneratorImplTest, Emit_Function_WithDecoration_Multiple) {
ast::type::Void void_type; ast::type::Void void_type;
ast::Function func( ast::Function func(
Source{}, "my_func", {}, &void_type, body, Source{}, mod.RegisterSymbol("my_func"), "my_func", {}, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}), create<ast::StageDecoration>(ast::PipelineStage::kFragment, Source{}),
create<ast::WorkgroupDecoration>(2u, 4u, 6u, Source{}), create<ast::WorkgroupDecoration>(2u, 4u, 6u, Source{}),
@ -237,11 +238,11 @@ TEST_F(WgslGeneratorImplTest,
body->append(create<ast::VariableDeclStatement>(var)); body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = auto* func = create<ast::Function>(
create<ast::Function>(Source{}, "a", params, &void_type, body, Source{}, mod.RegisterSymbol("a"), "a", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>( create<ast::StageDecoration>(ast::PipelineStage::kCompute,
ast::PipelineStage::kCompute, Source{}), Source{}),
}); });
mod.AddFunction(func); mod.AddFunction(func);
@ -264,11 +265,11 @@ TEST_F(WgslGeneratorImplTest,
body->append(create<ast::VariableDeclStatement>(var)); body->append(create<ast::VariableDeclStatement>(var));
body->append(create<ast::ReturnStatement>(Source{})); body->append(create<ast::ReturnStatement>(Source{}));
auto* func = auto* func = create<ast::Function>(
create<ast::Function>(Source{}, "b", params, &void_type, body, Source{}, mod.RegisterSymbol("b"), "b", params, &void_type, body,
ast::FunctionDecorationList{ ast::FunctionDecorationList{
create<ast::StageDecoration>( create<ast::StageDecoration>(ast::PipelineStage::kCompute,
ast::PipelineStage::kCompute, Source{}), Source{}),
}); });
mod.AddFunction(func); mod.AddFunction(func);

View File

@ -33,8 +33,9 @@ TEST_F(WgslGeneratorImplTest, Generate) {
ast::type::Void void_type; ast::type::Void void_type;
mod.AddFunction(create<ast::Function>( mod.AddFunction(create<ast::Function>(
Source{}, "my_func", ast::VariableList{}, &void_type, Source{}, mod.RegisterSymbol("a_func"), "my_func", ast::VariableList{},
create<ast::BlockStatement>(), ast::FunctionDecorationList{})); &void_type, create<ast::BlockStatement>(),
ast::FunctionDecorationList{}));
ASSERT_TRUE(gen.Generate(mod)) << gen.error(); ASSERT_TRUE(gen.Generate(mod)) << gen.error();
EXPECT_EQ(gen.result(), R"(fn my_func() -> void { EXPECT_EQ(gen.result(), R"(fn my_func() -> void {