diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index affbbf2d31..0a2c72f5d8 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc @@ -69,6 +69,8 @@ GeneratorImpl::~GeneratorImpl() = default; bool GeneratorImpl::Generate(const ast::Module& module) { module_ = &module; + out_ << "#include " << std::endl << std::endl; + for (auto* const alias : module.alias_types()) { if (!EmitAliasType(alias)) { return false; diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc index 083daac986..e555ed5a82 100644 --- a/src/writer/msl/generator_impl_function_test.cc +++ b/src/writer/msl/generator_impl_function_test.cc @@ -47,7 +47,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function) { g.increment_indent(); ASSERT_TRUE(g.Generate(m)) << g.error(); - EXPECT_EQ(g.result(), R"( void my_func() { + EXPECT_EQ(g.result(), R"(#include + + void my_func() { return; } @@ -71,7 +73,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function_Name_Collision) { g.increment_indent(); ASSERT_TRUE(g.Generate(m)) << g.error(); - EXPECT_EQ(g.result(), R"( void main_tint_0() { + EXPECT_EQ(g.result(), R"(#include + + void main_tint_0() { return; } @@ -103,7 +107,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function_WithParams) { g.increment_indent(); ASSERT_TRUE(g.Generate(m)) << g.error(); - EXPECT_EQ(g.result(), R"( void my_func(float a, int b) { + EXPECT_EQ(g.result(), R"(#include + + void my_func(float a, int b) { return; } @@ -124,7 +130,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_NoName) { GeneratorImpl g; ASSERT_TRUE(g.Generate(m)) << g.error(); - EXPECT_EQ(g.result(), R"(fragment void frag_main() { + EXPECT_EQ(g.result(), R"(#include + +fragment void frag_main() { } )"); @@ -144,7 +152,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_WithName) { GeneratorImpl g; ASSERT_TRUE(g.Generate(m)) << g.error(); - EXPECT_EQ(g.result(), R"(kernel void my_main() { + EXPECT_EQ(g.result(), R"(#include + +kernel void my_main() { } )"); @@ -164,7 +174,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_WithNameCollision) { GeneratorImpl g; ASSERT_TRUE(g.Generate(m)) << g.error(); - EXPECT_EQ(g.result(), R"(kernel void main_tint_0() { + EXPECT_EQ(g.result(), R"(#include + +kernel void main_tint_0() { } )"); @@ -193,7 +205,9 @@ TEST_F(MslGeneratorImplTest, Emit_Function_WithArrayParams) { g.increment_indent(); ASSERT_TRUE(g.Generate(m)) << g.error(); - EXPECT_EQ(g.result(), R"( void my_func(float a[5]) { + EXPECT_EQ(g.result(), R"(#include + + void my_func(float a[5]) { return; }