From df415a8919d1c6d85d119e8aefc8396b37af0548 Mon Sep 17 00:00:00 2001 From: dan sinclair Date: Wed, 15 Jul 2020 18:04:11 +0000 Subject: [PATCH] [msl-writer] Generate entry point functions. This CL generates entry point functions and duplicate functions as needed to call from the entry points. Bug: tint:8 Change-Id: I8092ce463248e7a887c26ae05b0774e8fa21ab94 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/24764 Reviewed-by: David Neto --- src/ast/decorated_variable.cc | 9 + src/ast/decorated_variable.h | 3 + src/ast/module.cc | 8 + src/ast/module.h | 5 + src/ast/module_test.cc | 13 + src/writer/msl/generator_impl.cc | 260 +++++++++-- src/writer/msl/generator_impl.h | 43 +- .../msl/generator_impl_entry_point_test.cc | 36 +- .../msl/generator_impl_function_test.cc | 422 ++++++++++++++++++ src/writer/msl/generator_impl_test.cc | 12 +- test/triangle.wgsl | 4 +- 11 files changed, 742 insertions(+), 73 deletions(-) diff --git a/src/ast/decorated_variable.cc b/src/ast/decorated_variable.cc index 340075e436..89db9e4dd6 100644 --- a/src/ast/decorated_variable.cc +++ b/src/ast/decorated_variable.cc @@ -26,6 +26,15 @@ DecoratedVariable::DecoratedVariable(DecoratedVariable&&) = default; DecoratedVariable::~DecoratedVariable() = default; +bool DecoratedVariable::HasLocationDecoration() const { + for (const auto& deco : decorations_) { + if (deco->IsLocation()) { + return true; + } + } + return false; +} + bool DecoratedVariable::IsDecorated() const { return true; } diff --git a/src/ast/decorated_variable.h b/src/ast/decorated_variable.h index 992e641248..d2e381a2dc 100644 --- a/src/ast/decorated_variable.h +++ b/src/ast/decorated_variable.h @@ -45,6 +45,9 @@ class DecoratedVariable : public Variable { /// @returns the decorations attached to this variable const VariableDecorationList& decorations() const { return decorations_; } + /// @returns true if the decorations include a LocationDecoration + bool HasLocationDecoration() const; + /// @returns true if this is a decorated variable bool IsDecorated() const override; diff --git a/src/ast/module.cc b/src/ast/module.cc index 7ddd7e7771..74a081cce7 100644 --- a/src/ast/module.cc +++ b/src/ast/module.cc @@ -43,6 +43,14 @@ Function* Module::FindFunctionByName(const std::string& name) const { return nullptr; } +bool Module::IsFunctionEntryPoint(const std::string& name) const { + for (const auto& ep : entry_points_) { + if (ep->function_name() == name) + return true; + } + return false; +} + bool Module::IsValid() const { for (const auto& import : imports_) { if (import == nullptr || !import->IsValid()) { diff --git a/src/ast/module.h b/src/ast/module.h index 77e1924d97..d17a0cb44c 100644 --- a/src/ast/module.h +++ b/src/ast/module.h @@ -65,6 +65,11 @@ class Module { /// @returns the entry points in the module const EntryPointList& entry_points() const { return entry_points_; } + /// Checks if the given function name is an entry point function + /// @param name the function name + /// @returns true if name is an entry point function + bool IsFunctionEntryPoint(const std::string& name) const; + /// Adds a type alias to the module /// @param type the alias to add void AddAliasType(type::AliasType* type) { alias_types_.push_back(type); } diff --git a/src/ast/module_test.cc b/src/ast/module_test.cc index 869f3040e6..4cd0ebbaf3 100644 --- a/src/ast/module_test.cc +++ b/src/ast/module_test.cc @@ -91,6 +91,19 @@ TEST_F(ModuleTest, LookupFunction) { EXPECT_EQ(func_ptr, m.FindFunctionByName("main")); } +TEST_F(ModuleTest, IsEntryPoint) { + type::F32Type f32; + Module m; + + auto func = std::make_unique("other_func", VariableList{}, &f32); + m.AddFunction(std::move(func)); + + m.AddEntryPoint( + std::make_unique(PipelineStage::kVertex, "main", "vtx_main")); + EXPECT_TRUE(m.IsFunctionEntryPoint("vtx_main")); + EXPECT_FALSE(m.IsFunctionEntryPoint("other_func")); +} + TEST_F(ModuleTest, LookupFunctionMissing) { Module m; EXPECT_EQ(nullptr, m.FindFunctionByName("Missing")); diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index 5568d7ea5f..defabf5627 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc @@ -59,6 +59,8 @@ namespace { const char kInStructNameSuffix[] = "in"; const char kOutStructNameSuffix[] = "out"; +const char kTintStructInVarPrefix[] = "tint_in"; +const char kTintStructOutVarPrefix[] = "tint_out"; bool last_is_break_or_fallthrough(const ast::StatementList& stmts) { if (stmts.empty()) { @@ -78,13 +80,11 @@ void GeneratorImpl::set_module_for_testing(ast::Module* mod) { module_ = mod; } -std::string GeneratorImpl::generate_struct_name(ast::EntryPoint* ep, - const std::string& type) { - std::string base_name = ep->function_name() + "_" + type; - std::string name = base_name; +std::string GeneratorImpl::generate_name(const std::string& prefix) { + std::string name = prefix; uint32_t i = 0; while (namer_.IsMapped(name)) { - name = base_name + "_" + std::to_string(i); + name = prefix + "_" + std::to_string(i); ++i; } namer_.RegisterRemappedName(name); @@ -96,6 +96,10 @@ bool GeneratorImpl::Generate(const ast::Module& module) { out_ << "#include " << std::endl << std::endl; + for (const auto& global : module.global_variables()) { + global_variables_.set(global->name(), global.get()); + } + for (auto* const alias : module.alias_types()) { if (!EmitAliasType(alias)) { return false; @@ -106,7 +110,7 @@ bool GeneratorImpl::Generate(const ast::Module& module) { } for (const auto& ep : module.entry_points()) { - if (!EmitEntryPoint(ep.get())) { + if (!EmitEntryPointData(ep.get())) { return false; } } @@ -115,6 +119,12 @@ bool GeneratorImpl::Generate(const ast::Module& module) { if (!EmitFunction(func.get())) { return false; } + } + + for (const auto& ep : module.entry_points()) { + if (!EmitEntryPointFunction(ep.get())) { + return false; + } out_ << std::endl; } @@ -283,12 +293,32 @@ bool GeneratorImpl::EmitCall(ast::CallExpression* expr) { } if (!ident->has_path()) { - if (!EmitExpression(expr->func())) { - return false; + auto name = ident->name(); + auto it = ep_func_name_remapped_.find(current_ep_name_ + "_" + name); + if (it != ep_func_name_remapped_.end()) { + name = it->second; } - out_ << "("; + out_ << name << "("; bool first = true; + + auto in_it = ep_name_to_in_data_.find(current_ep_name_); + if (in_it != ep_name_to_in_data_.end()) { + out_ << in_it->second.var_name; + first = false; + } + + auto out_it = ep_name_to_out_data_.find(current_ep_name_); + if (out_it != ep_name_to_out_data_.end()) { + if (!first) { + out_ << ", "; + } + out_ << out_it->second.var_name; + first = false; + } + + // TODO(dsinclair): Emit builtins + const auto& params = expr->params(); for (const auto& param : params) { if (!first) { @@ -459,7 +489,7 @@ bool GeneratorImpl::EmitLiteral(ast::Literal* lit) { return true; } -bool GeneratorImpl::EmitEntryPoint(ast::EntryPoint* ep) { +bool GeneratorImpl::EmitEntryPointData(ast::EntryPoint* ep) { auto* func = module_->FindFunctionByName(ep->function_name()); if (func == nullptr) { error_ = "Unable to find entry point function: " + ep->function_name(); @@ -491,9 +521,20 @@ bool GeneratorImpl::EmitEntryPoint(ast::EntryPoint* ep) { } } + auto ep_name = ep->name(); + if (ep_name.empty()) { + ep_name = ep->function_name(); + } + + // TODO(dsinclair): There is a potential bug here. Entry points can have the + // same name in WGSL if they have different pipeline stages. This does not + // take that into account and will emit duplicate struct names. I'm ignoring + // this until https://github.com/gpuweb/gpuweb/issues/662 is resolved as it + // may remove this issue and entry point names will need to be unique. if (!in_locations.empty()) { - auto in_struct_name = generate_struct_name(ep, kInStructNameSuffix); - ep_name_to_in_struct_[ep->name()] = in_struct_name; + auto in_struct_name = generate_name(ep_name + "_" + kInStructNameSuffix); + auto in_var_name = generate_name(kTintStructInVarPrefix); + ep_name_to_in_data_[ep_name] = {in_struct_name, in_var_name}; make_indent(); out_ << "struct " << in_struct_name << " {" << std::endl; @@ -527,8 +568,9 @@ bool GeneratorImpl::EmitEntryPoint(ast::EntryPoint* ep) { } if (!out_locations.empty()) { - auto out_struct_name = generate_struct_name(ep, kOutStructNameSuffix); - ep_name_to_out_struct_[ep->name()] = out_struct_name; + auto out_struct_name = generate_name(ep_name + "_" + kOutStructNameSuffix); + auto out_var_name = generate_name(kTintStructOutVarPrefix); + ep_name_to_out_data_[ep_name] = {out_struct_name, out_var_name}; make_indent(); out_ << "struct " << out_struct_name << " {" << std::endl; @@ -615,33 +657,82 @@ void GeneratorImpl::EmitStage(ast::PipelineStage stage) { bool GeneratorImpl::EmitFunction(ast::Function* func) { make_indent(); - // TODO(dsinclair): Technically this is wrong as you could, in theory, have - // multiple entry points pointing at the same function. I'm ignoring that for - // now. It will either go away with the entry_point changes in the spec - // or we'll have to figure out how to deal with it. - - auto name = func->name(); - - for (const auto& ep : module_->entry_points()) { - if (ep->function_name() == name) { - EmitStage(ep->stage()); - out_ << " "; - - if (!ep->name().empty()) { - name = ep->name(); - } - - break; - } + // Entry points will be emitted later, skip for now. + if (module_->IsFunctionEntryPoint(func->name())) { + return true; } + // TODO(dsinclair): This could be smarter. If the input/outputs for multiple + // entry points are the same we could generate a single struct and then have + // this determine it's the same struct and just emit once. + bool emit_duplicate_functions = + func->ancestor_entry_points().size() > 0 && + func->referenced_module_variables().size() > 0; + + if (emit_duplicate_functions) { + for (const auto& ep_name : func->ancestor_entry_points()) { + if (!EmitFunctionInternal(func, emit_duplicate_functions, ep_name)) { + return false; + } + out_ << std::endl; + } + } else { + // Emit as non-duplicated + if (!EmitFunctionInternal(func, false, "")) { + return false; + } + out_ << std::endl; + } + + return true; +} + +bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, + bool emit_duplicate_functions, + const std::string& ep_name) { + auto name = func->name(); + if (!EmitType(func->return_type(), "")) { return false; } - out_ << " " << namer_.NameFor(name) << "("; + out_ << " "; + + if (emit_duplicate_functions) { + name = generate_name(name + "_" + ep_name); + ep_func_name_remapped_[ep_name + "_" + func->name()] = name; + } else { + name = namer_.NameFor(name); + } + out_ << name << "("; bool first = true; + + // If we're emitting duplicate functions that means the function takes + // the stage_in or stage_out value from the entry point, emit them. + // + // We emit both of them if they're there regardless of if they're both used. + if (emit_duplicate_functions) { + auto in_it = ep_name_to_in_data_.find(ep_name); + if (in_it != ep_name_to_in_data_.end()) { + out_ << "thread " << in_it->second.struct_name << "& " + << in_it->second.var_name; + first = false; + } + + auto out_it = ep_name_to_out_data_.find(ep_name); + if (out_it != ep_name_to_out_data_.end()) { + if (!first) { + out_ << ", "; + } + out_ << "thread " << out_it->second.struct_name << "& " + << out_it->second.var_name; + first = false; + } + } + + // TODO(dsinclair): Handle any entry point builtin params used here + for (const auto& v : func->params()) { if (!first) { out_ << ", "; @@ -656,9 +747,79 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) { out_ << " " << v->name(); } } + out_ << ")"; - return EmitStatementBlockAndNewline(func->body()); + current_ep_name_ = ep_name; + + if (!EmitStatementBlockAndNewline(func->body())) { + return false; + } + + current_ep_name_ = ""; + + return true; +} + +bool GeneratorImpl::EmitEntryPointFunction(ast::EntryPoint* ep) { + make_indent(); + + current_ep_name_ = ep->name(); + if (current_ep_name_.empty()) { + current_ep_name_ = ep->function_name(); + } + + auto* func = module_->FindFunctionByName(ep->function_name()); + if (func == nullptr) { + error_ = "unable to find function for entry point: " + ep->function_name(); + return false; + } + + EmitStage(ep->stage()); + out_ << " "; + + // This is an entry point, the return type is the entry point output structure + // if one exists, or void otherwise. + auto out_data = ep_name_to_out_data_.find(current_ep_name_); + bool has_out_data = out_data != ep_name_to_out_data_.end(); + if (has_out_data) { + out_ << out_data->second.struct_name; + } else { + out_ << "void"; + } + out_ << " " << namer_.NameFor(current_ep_name_) << "("; + + auto in_data = ep_name_to_in_data_.find(current_ep_name_); + if (in_data != ep_name_to_in_data_.end()) { + out_ << in_data->second.struct_name << " " << in_data->second.var_name + << " [[stage_in]]"; + } + + // TODO(dsinclair): Output other builtin inputs + out_ << ") {" << std::endl; + + increment_indent(); + + if (has_out_data) { + make_indent(); + out_ << out_data->second.struct_name << " " << out_data->second.var_name + << " = {};" << std::endl; + } + + generating_entry_point_ = true; + for (const auto& s : func->body()) { + if (!EmitStatement(s.get())) { + return false; + } + } + generating_entry_point_ = false; + + decrement_indent(); + make_indent(); + out_ << "}" << std::endl; + + current_ep_name_ = ""; + return true; } bool GeneratorImpl::EmitIdentifier(ast::IdentifierExpression* expr) { @@ -668,7 +829,30 @@ bool GeneratorImpl::EmitIdentifier(ast::IdentifierExpression* expr) { error_ = "Identifier paths not handled yet."; return false; } + + ast::Variable* var = nullptr; + if (global_variables_.get(ident->name(), &var)) { + if (var->storage_class() == ast::StorageClass::kInput && + var->IsDecorated() && var->AsDecorated()->HasLocationDecoration()) { + auto it = ep_name_to_in_data_.find(current_ep_name_); + if (it == ep_name_to_in_data_.end()) { + error_ = "unable to find entry point data for input"; + return false; + } + out_ << it->second.var_name << "."; + } else if (var->storage_class() == ast::StorageClass::kOutput && + var->IsDecorated() && + var->AsDecorated()->HasLocationDecoration()) { + auto it = ep_name_to_out_data_.find(current_ep_name_); + if (it == ep_name_to_out_data_.end()) { + error_ = "unable to find entry point data for output"; + return false; + } + out_ << it->second.var_name << "."; + } + } out_ << namer_.NameFor(ident->name()); + return true; } @@ -785,7 +969,13 @@ bool GeneratorImpl::EmitReturn(ast::ReturnStatement* stmt) { make_indent(); out_ << "return"; - if (stmt->has_value()) { + + if (generating_entry_point_) { + auto out_data = ep_name_to_out_data_.find(current_ep_name_); + if (out_data != ep_name_to_out_data_.end()) { + out_ << " " << out_data->second.var_name; + } + } else if (stmt->has_value()) { out_ << " "; if (!EmitExpression(stmt->value())) { return false; diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h index 4c94f3c446..c851ee76c3 100644 --- a/src/writer/msl/generator_impl.h +++ b/src/writer/msl/generator_impl.h @@ -23,6 +23,7 @@ #include "src/ast/module.h" #include "src/ast/scalar_constructor_expression.h" #include "src/ast/type_constructor_expression.h" +#include "src/scope_stack.h" #include "src/writer/msl/namer.h" #include "src/writer/text_generator.h" @@ -93,7 +94,11 @@ class GeneratorImpl : public TextGenerator { /// Handles emitting information for an entry point /// @param ep the entry point /// @returns true if the entry point data was emitted - bool EmitEntryPoint(ast::EntryPoint* ep); + bool EmitEntryPointData(ast::EntryPoint* ep); + /// Handles emitting the entry point function + /// @param ep the entry point + /// @returns true if the entry point function was emitted + bool EmitEntryPointFunction(ast::EntryPoint* ep); /// Handles generate an Expression /// @param expr the expression /// @returns true if the expression was emitted @@ -102,6 +107,15 @@ class GeneratorImpl : public TextGenerator { /// @param func the function to generate /// @returns true if the function was emitted bool EmitFunction(ast::Function* func); + /// Internal helper for emitting functions + /// @param func the function to emit + /// @param emit_duplicate_functions set true if we need to duplicate per entry + /// point + /// @param ep_name the current entry point or blank if none set + /// @returns true if the function was emitted. + bool EmitFunctionInternal(ast::Function* func, + bool emit_duplicate_functions, + const std::string& ep_name); /// Handles generating an identifier expression /// @param expr the identifier expression /// @returns true if the identifeir was emitted @@ -179,22 +193,33 @@ class GeneratorImpl : public TextGenerator { /// @param mod the module to set. void set_module_for_testing(ast::Module* mod); - /// Generates a name for the input struct - /// @param ep the entry point to generate for - /// @param type the type of struct to generate - /// @returns the input struct name - std::string generate_struct_name(ast::EntryPoint* ep, - const std::string& type); + /// Generates a name for the prefix + /// @param prefix the prefix of the name to generate + /// @returns the name + std::string generate_name(const std::string& prefix); /// @returns the namer for testing Namer* namer_for_testing() { return &namer_; } private: Namer namer_; + ScopeStack global_variables_; + std::string current_ep_name_; + bool generating_entry_point_ = false; const ast::Module* module_ = nullptr; uint32_t loop_emission_counter_ = 0; - std::unordered_map ep_name_to_in_struct_; - std::unordered_map ep_name_to_out_struct_; + + struct EntryPointData { + std::string struct_name; + std::string var_name; + }; + std::unordered_map ep_name_to_in_data_; + std::unordered_map ep_name_to_out_data_; + + // This maps an input of "_" to a remapped + // function name. If there is no entry for a given key then function did + // not need to be remapped for the entry point and can be emitted directly. + std::unordered_map ep_func_name_remapped_; }; } // namespace msl diff --git a/src/writer/msl/generator_impl_entry_point_test.cc b/src/writer/msl/generator_impl_entry_point_test.cc index a102300d41..824ac6ceab 100644 --- a/src/writer/msl/generator_impl_entry_point_test.cc +++ b/src/writer/msl/generator_impl_entry_point_test.cc @@ -33,7 +33,7 @@ namespace { using MslGeneratorImplTest = testing::Test; -TEST_F(MslGeneratorImplTest, EmitEntryPoint_Vertex_Input) { +TEST_F(MslGeneratorImplTest, EmitEntryPointData_Vertex_Input) { // [[location 0]] var foo : f32; // [[location 1]] var bar : i32; // @@ -81,8 +81,8 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Vertex_Input) { mod.AddFunction(std::move(func)); - auto ep = std::make_unique(ast::PipelineStage::kVertex, - "main", "vtx_main"); + auto ep = std::make_unique(ast::PipelineStage::kVertex, "", + "vtx_main"); auto* ep_ptr = ep.get(); mod.AddEntryPoint(std::move(ep)); @@ -91,7 +91,7 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Vertex_Input) { GeneratorImpl g; g.set_module_for_testing(&mod); - ASSERT_TRUE(g.EmitEntryPoint(ep_ptr)) << g.error(); + ASSERT_TRUE(g.EmitEntryPointData(ep_ptr)) << g.error(); EXPECT_EQ(g.result(), R"(struct vtx_main_in { float foo [[attribute(0)]]; int bar [[attribute(1)]]; @@ -100,7 +100,7 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Vertex_Input) { )"); } -TEST_F(MslGeneratorImplTest, EmitEntryPoint_Vertex_Output) { +TEST_F(MslGeneratorImplTest, EmitEntryPointData_Vertex_Output) { // [[location 0]] var foo : f32; // [[location 1]] var bar : i32; // @@ -148,8 +148,8 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Vertex_Output) { mod.AddFunction(std::move(func)); - auto ep = std::make_unique(ast::PipelineStage::kVertex, - "main", "vtx_main"); + auto ep = std::make_unique(ast::PipelineStage::kVertex, "", + "vtx_main"); auto* ep_ptr = ep.get(); mod.AddEntryPoint(std::move(ep)); @@ -158,7 +158,7 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Vertex_Output) { GeneratorImpl g; g.set_module_for_testing(&mod); - ASSERT_TRUE(g.EmitEntryPoint(ep_ptr)) << g.error(); + ASSERT_TRUE(g.EmitEntryPointData(ep_ptr)) << g.error(); EXPECT_EQ(g.result(), R"(struct vtx_main_out { float foo [[user(locn0)]]; int bar [[user(locn1)]]; @@ -167,7 +167,7 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Vertex_Output) { )"); } -TEST_F(MslGeneratorImplTest, EmitEntryPoint_Fragment_Input) { +TEST_F(MslGeneratorImplTest, EmitEntryPointData_Fragment_Input) { // [[location 0]] var foo : f32; // [[location 1]] var bar : i32; // @@ -225,8 +225,8 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Fragment_Input) { GeneratorImpl g; g.set_module_for_testing(&mod); - ASSERT_TRUE(g.EmitEntryPoint(ep_ptr)) << g.error(); - EXPECT_EQ(g.result(), R"(struct frag_main_in { + ASSERT_TRUE(g.EmitEntryPointData(ep_ptr)) << g.error(); + EXPECT_EQ(g.result(), R"(struct main_in { float foo [[user(locn0)]]; int bar [[user(locn1)]]; }; @@ -234,7 +234,7 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Fragment_Input) { )"); } -TEST_F(MslGeneratorImplTest, EmitEntryPoint_Fragment_Output) { +TEST_F(MslGeneratorImplTest, EmitEntryPointData_Fragment_Output) { // [[location 0]] var foo : f32; // [[location 1]] var bar : i32; // @@ -292,8 +292,8 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Fragment_Output) { GeneratorImpl g; g.set_module_for_testing(&mod); - ASSERT_TRUE(g.EmitEntryPoint(ep_ptr)) << g.error(); - EXPECT_EQ(g.result(), R"(struct frag_main_out { + ASSERT_TRUE(g.EmitEntryPointData(ep_ptr)) << g.error(); + EXPECT_EQ(g.result(), R"(struct main_out { float foo [[color(0)]]; int bar [[color(1)]]; }; @@ -301,7 +301,7 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Fragment_Output) { )"); } -TEST_F(MslGeneratorImplTest, EmitEntryPoint_Compute_Input) { +TEST_F(MslGeneratorImplTest, EmitEntryPointData_Compute_Input) { // [[location 0]] var foo : f32; // [[location 1]] var bar : i32; // @@ -356,11 +356,11 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Compute_Input) { GeneratorImpl g; g.set_module_for_testing(&mod); - ASSERT_FALSE(g.EmitEntryPoint(ep_ptr)) << g.error(); + ASSERT_FALSE(g.EmitEntryPointData(ep_ptr)) << g.error(); EXPECT_EQ(g.error(), R"(invalid location variable for pipeline stage)"); } -TEST_F(MslGeneratorImplTest, EmitEntryPoint_Compute_Output) { +TEST_F(MslGeneratorImplTest, EmitEntryPointData_Compute_Output) { // [[location 0]] var foo : f32; // [[location 1]] var bar : i32; // @@ -415,7 +415,7 @@ TEST_F(MslGeneratorImplTest, EmitEntryPoint_Compute_Output) { GeneratorImpl g; g.set_module_for_testing(&mod); - ASSERT_FALSE(g.EmitEntryPoint(ep_ptr)) << g.error(); + ASSERT_FALSE(g.EmitEntryPointData(ep_ptr)) << g.error(); EXPECT_EQ(g.error(), R"(invalid location variable for pipeline stage)"); } diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc index e555ed5a82..7d49b1dee0 100644 --- a/src/writer/msl/generator_impl_function_test.cc +++ b/src/writer/msl/generator_impl_function_test.cc @@ -13,14 +13,27 @@ // limitations under the License. #include "gtest/gtest.h" +#include "src/ast/assignment_statement.h" +#include "src/ast/binary_expression.h" +#include "src/ast/call_expression.h" +#include "src/ast/decorated_variable.h" +#include "src/ast/float_literal.h" #include "src/ast/function.h" +#include "src/ast/identifier_expression.h" +#include "src/ast/if_statement.h" +#include "src/ast/location_decoration.h" #include "src/ast/module.h" #include "src/ast/return_statement.h" +#include "src/ast/scalar_constructor_expression.h" +#include "src/ast/sint_literal.h" #include "src/ast/type/array_type.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/i32_type.h" #include "src/ast/type/void_type.h" #include "src/ast/variable.h" +#include "src/ast/variable_decl_statement.h" +#include "src/context.h" +#include "src/type_determiner.h" #include "src/writer/msl/generator_impl.h" namespace tint { @@ -138,6 +151,415 @@ fragment void frag_main() { )"); } +TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_WithInOutVars) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + + auto foo_var = std::make_unique( + std::make_unique("foo", ast::StorageClass::kInput, &f32)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + foo_var->set_decorations(std::move(decos)); + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kOutput, &f32)); + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(foo_var.get()); + td.RegisterVariableForTesting(bar_var.get()); + + mod.AddGlobalVariable(std::move(foo_var)); + mod.AddGlobalVariable(std::move(bar_var)); + + ast::VariableList params; + auto func = std::make_unique("frag_main", std::move(params), + &void_type); + + ast::StatementList body; + body.push_back(std::make_unique( + std::make_unique("bar"), + std::make_unique("foo"))); + body.push_back(std::make_unique()); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + + auto ep = std::make_unique(ast::PipelineStage::kFragment, "", + "frag_main"); + mod.AddEntryPoint(std::move(ep)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g; + ASSERT_TRUE(g.Generate(mod)) << g.error(); + EXPECT_EQ(g.result(), R"(#include + +struct frag_main_in { + float foo [[user(locn0)]]; +}; + +struct frag_main_out { + float bar [[color(1)]]; +}; + +fragment frag_main_out frag_main(frag_main_in tint_in [[stage_in]]) { + frag_main_out tint_out = {}; + tint_out.bar = tint_in.foo; + return tint_out; +} + +)"); +} + +TEST_F(MslGeneratorImplTest, + Emit_Function_Called_By_EntryPoints_WithGlobals_And_Params) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + + auto foo_var = std::make_unique( + std::make_unique("foo", ast::StorageClass::kInput, &f32)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + foo_var->set_decorations(std::move(decos)); + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kOutput, &f32)); + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + auto val_var = std::make_unique( + std::make_unique("val", ast::StorageClass::kOutput, &f32)); + decos.push_back(std::make_unique(0)); + val_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(foo_var.get()); + td.RegisterVariableForTesting(bar_var.get()); + td.RegisterVariableForTesting(val_var.get()); + + mod.AddGlobalVariable(std::move(foo_var)); + mod.AddGlobalVariable(std::move(bar_var)); + mod.AddGlobalVariable(std::move(val_var)); + + ast::VariableList params; + params.push_back(std::make_unique( + "param", ast::StorageClass::kFunction, &f32)); + auto sub_func = + std::make_unique("sub_func", std::move(params), &f32); + + ast::StatementList body; + body.push_back(std::make_unique( + std::make_unique("bar"), + std::make_unique("foo"))); + body.push_back(std::make_unique( + std::make_unique("val"), + std::make_unique("param"))); + body.push_back(std::make_unique( + std::make_unique("foo"))); + sub_func->set_body(std::move(body)); + + mod.AddFunction(std::move(sub_func)); + + auto func_1 = std::make_unique("frag_1_main", + std::move(params), &void_type); + + ast::ExpressionList expr; + expr.push_back(std::make_unique( + std::make_unique(&f32, 1.0f))); + body.push_back(std::make_unique( + std::make_unique("bar"), + std::make_unique( + std::make_unique("sub_func"), + std::move(expr)))); + body.push_back(std::make_unique()); + func_1->set_body(std::move(body)); + + mod.AddFunction(std::move(func_1)); + + auto ep1 = std::make_unique(ast::PipelineStage::kFragment, + "ep_1", "frag_1_main"); + mod.AddEntryPoint(std::move(ep1)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g; + ASSERT_TRUE(g.Generate(mod)) << g.error(); + EXPECT_EQ(g.result(), R"(#include + +struct ep_1_in { + float foo [[user(locn0)]]; +}; + +struct ep_1_out { + float bar [[color(1)]]; + float val [[color(0)]]; +}; + +float sub_func_ep_1(thread ep_1_in& tint_in, thread ep_1_out& tint_out, float param) { + tint_out.bar = tint_in.foo; + tint_out.val = param; + return tint_in.foo; +} + +fragment ep_1_out ep_1(ep_1_in tint_in [[stage_in]]) { + ep_1_out tint_out = {}; + tint_out.bar = sub_func_ep_1(tint_in, tint_out, 1.00000000f); + return tint_out; +} + +)"); +} + +TEST_F(MslGeneratorImplTest, Emit_Function_Called_Two_EntryPoints_WithGlobals) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + + auto foo_var = std::make_unique( + std::make_unique("foo", ast::StorageClass::kInput, &f32)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + foo_var->set_decorations(std::move(decos)); + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kOutput, &f32)); + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(foo_var.get()); + td.RegisterVariableForTesting(bar_var.get()); + + mod.AddGlobalVariable(std::move(foo_var)); + mod.AddGlobalVariable(std::move(bar_var)); + + ast::VariableList params; + auto sub_func = + std::make_unique("sub_func", std::move(params), &f32); + + ast::StatementList body; + body.push_back(std::make_unique( + std::make_unique("bar"), + std::make_unique("foo"))); + body.push_back(std::make_unique( + std::make_unique("foo"))); + sub_func->set_body(std::move(body)); + + mod.AddFunction(std::move(sub_func)); + + auto func_1 = std::make_unique("frag_1_main", + std::move(params), &void_type); + + body.push_back(std::make_unique( + std::make_unique("bar"), + std::make_unique( + std::make_unique("sub_func"), + ast::ExpressionList{}))); + body.push_back(std::make_unique()); + func_1->set_body(std::move(body)); + + mod.AddFunction(std::move(func_1)); + + auto ep1 = std::make_unique(ast::PipelineStage::kFragment, + "ep_1", "frag_1_main"); + auto ep2 = std::make_unique(ast::PipelineStage::kFragment, + "ep_2", "frag_1_main"); + mod.AddEntryPoint(std::move(ep1)); + mod.AddEntryPoint(std::move(ep2)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g; + ASSERT_TRUE(g.Generate(mod)) << g.error(); + EXPECT_EQ(g.result(), R"(#include + +struct ep_1_in { + float foo [[user(locn0)]]; +}; + +struct ep_1_out { + float bar [[color(1)]]; +}; + +struct ep_2_in { + float foo [[user(locn0)]]; +}; + +struct ep_2_out { + float bar [[color(1)]]; +}; + +float sub_func_ep_1(thread ep_1_in& tint_in, thread ep_1_out& tint_out) { + tint_out.bar = tint_in.foo; + return tint_in.foo; +} + +float sub_func_ep_2(thread ep_2_in& tint_in, thread ep_2_out& tint_out) { + tint_out.bar = tint_in.foo; + return tint_in.foo; +} + +fragment ep_1_out ep_1(ep_1_in tint_in [[stage_in]]) { + ep_1_out tint_out = {}; + tint_out.bar = sub_func_ep_1(tint_in, tint_out); + return tint_out; +} + +fragment ep_2_out ep_2(ep_2_in tint_in [[stage_in]]) { + ep_2_out tint_out = {}; + tint_out.bar = sub_func_ep_2(tint_in, tint_out); + return tint_out; +} + +)"); +} + +TEST_F(MslGeneratorImplTest, + Emit_Function_EntryPoints_WithGlobal_Nested_Return) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::I32Type i32; + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kOutput, &f32)); + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(bar_var.get()); + mod.AddGlobalVariable(std::move(bar_var)); + + ast::VariableList params; + auto func_1 = std::make_unique("frag_1_main", + std::move(params), &void_type); + + ast::StatementList body; + body.push_back(std::make_unique( + std::make_unique("bar"), + std::make_unique( + std::make_unique(&f32, 1.0f)))); + + ast::StatementList list; + list.push_back(std::make_unique()); + body.push_back(std::make_unique( + std::make_unique( + ast::BinaryOp::kEqual, + std::make_unique( + std::make_unique(&i32, 1)), + std::make_unique( + std::make_unique(&i32, 1))), + std::move(list))); + + body.push_back(std::make_unique()); + func_1->set_body(std::move(body)); + + mod.AddFunction(std::move(func_1)); + + auto ep1 = std::make_unique(ast::PipelineStage::kFragment, + "ep_1", "frag_1_main"); + mod.AddEntryPoint(std::move(ep1)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g; + ASSERT_TRUE(g.Generate(mod)) << g.error(); + EXPECT_EQ(g.result(), R"(#include + +struct ep_1_out { + float bar [[color(1)]]; +}; + +fragment ep_1_out ep_1() { + ep_1_out tint_out = {}; + tint_out.bar = 1.00000000f; + if ((1 == 1)) { + return tint_out; + } + return tint_out; +} + +)"); +} + +TEST_F(MslGeneratorImplTest, + Emit_Function_Called_Two_EntryPoints_WithoutGlobals) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + + ast::VariableList params; + auto sub_func = + std::make_unique("sub_func", std::move(params), &f32); + + ast::StatementList body; + body.push_back(std::make_unique( + std::make_unique( + std::make_unique(&f32, 1.0)))); + sub_func->set_body(std::move(body)); + + mod.AddFunction(std::move(sub_func)); + + auto func_1 = std::make_unique("frag_1_main", + std::move(params), &void_type); + + body.push_back(std::make_unique( + std::make_unique("foo", ast::StorageClass::kFunction, + &f32))); + body.back()->AsVariableDecl()->variable()->set_constructor( + std::make_unique( + std::make_unique("sub_func"), + ast::ExpressionList{})); + + body.push_back(std::make_unique()); + func_1->set_body(std::move(body)); + + mod.AddFunction(std::move(func_1)); + + auto ep1 = std::make_unique(ast::PipelineStage::kFragment, + "ep_1", "frag_1_main"); + auto ep2 = std::make_unique(ast::PipelineStage::kFragment, + "ep_2", "frag_1_main"); + mod.AddEntryPoint(std::move(ep1)); + mod.AddEntryPoint(std::move(ep2)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g; + ASSERT_TRUE(g.Generate(mod)) << g.error(); + EXPECT_EQ(g.result(), R"(#include + +float sub_func() { + return 1.00000000f; +} + +fragment void ep_1() { + float foo = sub_func(); + return; +} + +fragment void ep_2() { + float foo = sub_func(); + return; +} + +)"); +} TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_WithName) { ast::type::VoidType void_type; diff --git a/src/writer/msl/generator_impl_test.cc b/src/writer/msl/generator_impl_test.cc index 41b0cbb3d2..063fd1ff15 100644 --- a/src/writer/msl/generator_impl_test.cc +++ b/src/writer/msl/generator_impl_test.cc @@ -51,29 +51,23 @@ compute void my_func() { } TEST_F(MslGeneratorImplTest, InputStructName) { - ast::EntryPoint ep(ast::PipelineStage::kVertex, "main", "func_main"); - GeneratorImpl g; - ASSERT_EQ(g.generate_struct_name(&ep, "in"), "func_main_in"); + ASSERT_EQ(g.generate_name("func_main_in"), "func_main_in"); } TEST_F(MslGeneratorImplTest, InputStructName_ConflictWithExisting) { - ast::EntryPoint ep(ast::PipelineStage::kVertex, "main", "func_main"); - GeneratorImpl g; // Register the struct name as existing. auto* namer = g.namer_for_testing(); namer->NameFor("func_main_out"); - ASSERT_EQ(g.generate_struct_name(&ep, "out"), "func_main_out_0"); + ASSERT_EQ(g.generate_name("func_main_out"), "func_main_out_0"); } TEST_F(MslGeneratorImplTest, NameConflictWith_InputStructName) { - ast::EntryPoint ep(ast::PipelineStage::kVertex, "main", "func_main"); - GeneratorImpl g; - ASSERT_EQ(g.generate_struct_name(&ep, "in"), "func_main_in"); + ASSERT_EQ(g.generate_name("func_main_in"), "func_main_in"); ast::IdentifierExpression ident("func_main_in"); ASSERT_TRUE(g.EmitIdentifier(&ident)); diff --git a/test/triangle.wgsl b/test/triangle.wgsl index 864417f613..6eb9ed7b00 100644 --- a/test/triangle.wgsl +++ b/test/triangle.wgsl @@ -28,9 +28,9 @@ fn vtx_main() -> void { entry_point vertex as "main" = vtx_main; # Fragment shader -[[location 0]] var outColor : ptr>; +[[location 0]] var outColor : vec4; fn frag_main() -> void { outColor = vec4(1, 0, 0, 1); return; } -entry_point fragment as "main" = frag_main; +entry_point fragment = frag_main;