diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index ba4c935e1c..61e560d3fb 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc @@ -44,12 +44,16 @@ GeneratorImpl::GeneratorImpl() = default; GeneratorImpl::~GeneratorImpl() = default; bool GeneratorImpl::Generate(const ast::Module& module) { + module_ = &module; + for (const auto& func : module.functions()) { if (!EmitFunction(func.get())) { return false; } out_ << std::endl; } + + module_ = nullptr; return true; } @@ -227,14 +231,51 @@ bool GeneratorImpl::EmitExpression(ast::Expression* expr) { return false; } +void GeneratorImpl::EmitStage(ast::PipelineStage stage) { + switch (stage) { + case ast::PipelineStage::kFragment: + out_ << "fragment"; + break; + case ast::PipelineStage::kVertex: + out_ << "vertex"; + break; + case ast::PipelineStage::kCompute: + out_ << "kernel"; + break; + case ast::PipelineStage::kNone: + break; + } + return; +} + bool GeneratorImpl::EmitFunction(ast::Function* func) { make_indent(); + // TODO(dsinclair): Technically this is wrong as you could, in theory, have + // multiple entry points pointing at the same function. I'm ignoring that for + // now. It will either go away with the entry_point changes in the spec + // or we'll have to figure out how to deal with it. + + auto name = func->name(); + + for (const auto& ep : module_->entry_points()) { + if (ep->function_name() == name) { + EmitStage(ep->stage()); + out_ << " "; + + if (!ep->name().empty()) { + name = ep->name(); + } + + break; + } + } + if (!EmitType(func->return_type(), "")) { return false; } - out_ << " " << func->name() << "("; + out_ << " " << name << "("; bool first = true; for (const auto& v : func->params()) { diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h index 491f120e9c..10d907324b 100644 --- a/src/writer/msl/generator_impl.h +++ b/src/writer/msl/generator_impl.h @@ -76,6 +76,9 @@ class GeneratorImpl : public TextGenerator { /// @param expr the scalar constructor expression /// @returns true if the scalar constructor is emitted bool EmitScalarConstructor(ast::ScalarConstructorExpression* expr); + /// Handles emitting a pipeline stage name + /// @param stage the stage to emit + void EmitStage(ast::PipelineStage stage); /// Handles a brace-enclosed list of statements. /// @param statements the statements to output /// @returns true if the statements were emitted @@ -97,6 +100,9 @@ class GeneratorImpl : public TextGenerator { /// @param expr the type constructor expression /// @returns true if the constructor is emitted bool EmitTypeConstructor(ast::TypeConstructorExpression* expr); + + private: + const ast::Module* module_ = nullptr; }; } // namespace msl diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc index d47864ebde..965eabacbf 100644 --- a/src/writer/msl/generator_impl_function_test.cc +++ b/src/writer/msl/generator_impl_function_test.cc @@ -14,6 +14,7 @@ #include "gtest/gtest.h" #include "src/ast/function.h" +#include "src/ast/module.h" #include "src/ast/return_statement.h" #include "src/ast/type/array_type.h" #include "src/ast/type/f32_type.h" @@ -32,19 +33,24 @@ using MslGeneratorImplTest = testing::Test; TEST_F(MslGeneratorImplTest, Emit_Function) { ast::type::VoidType void_type; - ast::Function func("my_func", {}, &void_type); + auto func = std::make_unique("my_func", ast::VariableList{}, + &void_type); ast::StatementList body; body.push_back(std::make_unique()); - func.set_body(std::move(body)); + func->set_body(std::move(body)); + + ast::Module m; + m.AddFunction(std::move(func)); GeneratorImpl g; g.increment_indent(); - ASSERT_TRUE(g.EmitFunction(&func)); + ASSERT_TRUE(g.Generate(m)) << g.error(); EXPECT_EQ(g.result(), R"( void my_func() { return; } + )"); } @@ -59,19 +65,64 @@ TEST_F(MslGeneratorImplTest, Emit_Function_WithParams) { std::make_unique("b", ast::StorageClass::kNone, &i32)); ast::type::VoidType void_type; - ast::Function func("my_func", std::move(params), &void_type); + auto func = + std::make_unique("my_func", std::move(params), &void_type); ast::StatementList body; body.push_back(std::make_unique()); - func.set_body(std::move(body)); + func->set_body(std::move(body)); + + ast::Module m; + m.AddFunction(std::move(func)); GeneratorImpl g; g.increment_indent(); - ASSERT_TRUE(g.EmitFunction(&func)); + ASSERT_TRUE(g.Generate(m)) << g.error(); EXPECT_EQ(g.result(), R"( void my_func(float a, int b) { return; } + +)"); +} + +TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_NoName) { + ast::type::VoidType void_type; + + auto func = std::make_unique("frag_main", ast::VariableList{}, + &void_type); + auto ep = std::make_unique(ast::PipelineStage::kFragment, "", + "frag_main"); + + ast::Module m; + m.AddFunction(std::move(func)); + m.AddEntryPoint(std::move(ep)); + + GeneratorImpl g; + ASSERT_TRUE(g.Generate(m)) << g.error(); + EXPECT_EQ(g.result(), R"(fragment void frag_main() { +} + +)"); +} + +TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_WithName) { + ast::type::VoidType void_type; + + auto func = std::make_unique("comp_main", ast::VariableList{}, + &void_type); + auto ep = std::make_unique(ast::PipelineStage::kCompute, + "main", "comp_main"); + + ast::Module m; + m.AddFunction(std::move(func)); + m.AddEntryPoint(std::move(ep)); + + GeneratorImpl g; + ASSERT_TRUE(g.Generate(m)) << g.error(); + EXPECT_EQ(g.result(), R"(kernel void main() { +} + )"); }