From b5bb2d91afc1d7802eb2c7addbfdc9f5357930da Mon Sep 17 00:00:00 2001 From: dan sinclair Date: Mon, 21 Sep 2020 18:58:01 +0000 Subject: [PATCH] [msl-writer] Update to emit based on pipeline stage. This CL updates the MSL writer to emit data base on the pipeline stage. Change-Id: I9fb2e146f0c898d9703d69a6a92f535757106bba Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/28703 Commit-Queue: dan sinclair Reviewed-by: David Neto --- BUILD.gn | 1 + src/CMakeLists.txt | 1 + src/writer/msl/generator_impl.cc | 261 +++++- src/writer/msl/generator_impl.h | 8 + ...tor_impl_function_entry_point_data_test.cc | 471 +++++++++++ .../msl/generator_impl_function_test.cc | 762 ++++++++++++++++++ 6 files changed, 1503 insertions(+), 1 deletion(-) create mode 100644 src/writer/msl/generator_impl_function_entry_point_data_test.cc diff --git a/BUILD.gn b/BUILD.gn index d1cec02c21..e9c4b5ee42 100644 --- a/BUILD.gn +++ b/BUILD.gn @@ -1044,6 +1044,7 @@ source_set("tint_unittests_msl_writer_src") { "src/writer/msl/generator_impl_continue_test.cc", "src/writer/msl/generator_impl_discard_test.cc", "src/writer/msl/generator_impl_entry_point_test.cc", + "src/writer/msl/generator_impl_function_entry_point_data_test.cc", "src/writer/msl/generator_impl_function_test.cc", "src/writer/msl/generator_impl_identifier_test.cc", "src/writer/msl/generator_impl_if_test.cc", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 3aa2b7ee99..aa142150f6 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -572,6 +572,7 @@ if(${TINT_BUILD_MSL_WRITER}) writer/msl/generator_impl_continue_test.cc writer/msl/generator_impl_discard_test.cc writer/msl/generator_impl_entry_point_test.cc + writer/msl/generator_impl_function_entry_point_data_test.cc writer/msl/generator_impl_function_test.cc writer/msl/generator_impl_identifier_test.cc writer/msl/generator_impl_if_test.cc diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index 2251e3b48b..34c9d6bdc8 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc @@ -130,6 +130,17 @@ bool GeneratorImpl::Generate() { } } + // Make sure all entry point data is emitted before the entry point functions + for (const auto& func : module_->functions()) { + if (!func->IsEntryPoint()) { + continue; + } + + if (!EmitEntryPointData(func.get())) { + return false; + } + } + for (const auto& func : module_->functions()) { if (!EmitFunction(func.get())) { return false; @@ -142,6 +153,15 @@ bool GeneratorImpl::Generate() { } out_ << std::endl; } + for (const auto& func : module_->functions()) { + if (!func->IsEntryPoint()) { + continue; + } + if (!EmitEntryPointFunction(func.get())) { + return false; + } + out_ << std::endl; + } return true; } @@ -1011,6 +1031,119 @@ bool GeneratorImpl::EmitEntryPointData(ast::EntryPoint* ep) { return true; } +bool GeneratorImpl::EmitEntryPointData(ast::Function* func) { + std::vector> in_locations; + std::vector> + out_variables; + for (auto data : func->referenced_location_variables()) { + auto* var = data.first; + auto* deco = data.second; + + if (var->storage_class() == ast::StorageClass::kInput) { + in_locations.push_back({var, deco->value()}); + } else if (var->storage_class() == ast::StorageClass::kOutput) { + out_variables.push_back({var, deco}); + } + } + + for (auto data : func->referenced_builtin_variables()) { + auto* var = data.first; + auto* deco = data.second; + + if (var->storage_class() == ast::StorageClass::kOutput) { + out_variables.push_back({var, deco}); + } + } + + if (!in_locations.empty()) { + auto in_struct_name = + generate_name(func->name() + "_" + kInStructNameSuffix); + auto in_var_name = generate_name(kTintStructInVarPrefix); + ep_name_to_in_data_[func->name()] = {in_struct_name, in_var_name}; + + make_indent(); + out_ << "struct " << in_struct_name << " {" << std::endl; + + increment_indent(); + + for (auto& data : in_locations) { + auto* var = data.first; + uint32_t loc = data.second; + + make_indent(); + if (!EmitType(var->type(), var->name())) { + return false; + } + + out_ << " " << var->name() << " [["; + if (func->pipeline_stage() == ast::PipelineStage::kVertex) { + out_ << "attribute(" << loc << ")"; + } else if (func->pipeline_stage() == ast::PipelineStage::kFragment) { + out_ << "user(locn" << loc << ")"; + } else { + error_ = "invalid location variable for pipeline stage"; + return false; + } + out_ << "]];" << std::endl; + } + decrement_indent(); + make_indent(); + + out_ << "};" << std::endl << std::endl; + } + + if (!out_variables.empty()) { + auto out_struct_name = + generate_name(func->name() + "_" + kOutStructNameSuffix); + auto out_var_name = generate_name(kTintStructOutVarPrefix); + ep_name_to_out_data_[func->name()] = {out_struct_name, out_var_name}; + + make_indent(); + out_ << "struct " << out_struct_name << " {" << std::endl; + + increment_indent(); + for (auto& data : out_variables) { + auto* var = data.first; + auto* deco = data.second; + + make_indent(); + if (!EmitType(var->type(), var->name())) { + return false; + } + + out_ << " " << var->name() << " [["; + + if (deco->IsLocation()) { + auto loc = deco->AsLocation()->value(); + if (func->pipeline_stage() == ast::PipelineStage::kVertex) { + out_ << "user(locn" << loc << ")"; + } else if (func->pipeline_stage() == ast::PipelineStage::kFragment) { + out_ << "color(" << loc << ")"; + } else { + error_ = "invalid location variable for pipeline stage"; + return false; + } + } else if (deco->IsBuiltin()) { + auto attr = builtin_to_attribute(deco->AsBuiltin()->value()); + if (attr.empty()) { + error_ = "unsupported builtin"; + return false; + } + out_ << attr; + } else { + error_ = "unsupported variable decoration for entry point output"; + return false; + } + out_ << "]];" << std::endl; + } + decrement_indent(); + make_indent(); + out_ << "};" << std::endl << std::endl; + } + + return true; +} + bool GeneratorImpl::EmitExpression(ast::Expression* expr) { if (expr->IsArrayAccessor()) { return EmitArrayAccessor(expr->AsArrayAccessor()); @@ -1097,7 +1230,7 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) { make_indent(); // Entry points will be emitted later, skip for now. - if (module_->IsFunctionEntryPoint(func->name())) { + if (func->IsEntryPoint() || module_->IsFunctionEntryPoint(func->name())) { return true; } @@ -1404,6 +1537,132 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::EntryPoint* ep) { return true; } +bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) { + make_indent(); + + current_ep_name_ = func->name(); + + EmitStage(func->pipeline_stage()); + out_ << " "; + + // This is an entry point, the return type is the entry point output structure + // if one exists, or void otherwise. + auto out_data = ep_name_to_out_data_.find(current_ep_name_); + bool has_out_data = out_data != ep_name_to_out_data_.end(); + if (has_out_data) { + out_ << out_data->second.struct_name; + } else { + out_ << "void"; + } + out_ << " " << namer_.NameFor(current_ep_name_) << "("; + + bool first = true; + auto in_data = ep_name_to_in_data_.find(current_ep_name_); + if (in_data != ep_name_to_in_data_.end()) { + out_ << in_data->second.struct_name << " " << in_data->second.var_name + << " [[stage_in]]"; + first = false; + } + + for (auto data : func->referenced_builtin_variables()) { + auto* var = data.first; + if (var->storage_class() != ast::StorageClass::kInput) { + continue; + } + + if (!first) { + out_ << ", "; + } + first = false; + + auto* builtin = data.second; + + if (!EmitType(var->type(), "")) { + return false; + } + + auto attr = builtin_to_attribute(builtin->value()); + if (attr.empty()) { + error_ = "unknown builtin"; + return false; + } + out_ << " " << var->name() << " [[" << attr << "]]"; + } + + for (auto data : func->referenced_uniform_variables()) { + if (!first) { + out_ << ", "; + } + first = false; + + auto* var = data.first; + // TODO(dsinclair): We're using the binding to make up the buffer number but + // we should instead be using a provided mapping that uses both buffer and + // set. https://bugs.chromium.org/p/tint/issues/detail?id=104 + auto* binding = data.second.binding; + if (binding == nullptr) { + error_ = "unable to find binding information for uniform: " + var->name(); + return false; + } + // auto* set = data.second.set; + + out_ << "constant "; + // TODO(dsinclair): Can you have a uniform array? If so, this needs to be + // updated to handle arrays property. + if (!EmitType(var->type(), "")) { + return false; + } + out_ << "& " << var->name() << " [[buffer(" << binding->value() << ")]]"; + } + + for (auto data : func->referenced_storagebuffer_variables()) { + if (!first) { + out_ << ", "; + } + first = false; + + auto* var = data.first; + // TODO(dsinclair): We're using the binding to make up the buffer number but + // we should instead be using a provided mapping that uses both buffer and + // set. https://bugs.chromium.org/p/tint/issues/detail?id=104 + auto* binding = data.second.binding; + // auto* set = data.second.set; + + out_ << "device "; + // TODO(dsinclair): Can you have a storagebuffer have an array? If so, this + // needs to be updated to handle arrays property. + if (!EmitType(var->type(), "")) { + return false; + } + out_ << "& " << var->name() << " [[buffer(" << binding->value() << ")]]"; + } + + out_ << ") {" << std::endl; + + increment_indent(); + + if (has_out_data) { + make_indent(); + out_ << out_data->second.struct_name << " " << out_data->second.var_name + << " = {};" << std::endl; + } + + generating_entry_point_ = true; + for (const auto& s : *(func->body())) { + if (!EmitStatement(s.get())) { + return false; + } + } + generating_entry_point_ = false; + + decrement_indent(); + make_indent(); + out_ << "}" << std::endl; + + current_ep_name_ = ""; + return true; +} + bool GeneratorImpl::global_is_in_struct(ast::Variable* var) const { bool in_or_out_struct_has_location = var->IsDecorated() && var->AsDecorated()->HasLocationDecoration() && diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h index b574b43b82..5852f30491 100644 --- a/src/writer/msl/generator_impl.h +++ b/src/writer/msl/generator_impl.h @@ -121,10 +121,18 @@ class GeneratorImpl : public TextGenerator { /// @param ep the entry point /// @returns true if the entry point data was emitted bool EmitEntryPointData(ast::EntryPoint* ep); + /// Handles emitting information for an entry point + /// @param func the entry point function + /// @returns true if the entry point data was emitted + bool EmitEntryPointData(ast::Function* func); /// Handles emitting the entry point function /// @param ep the entry point /// @returns true if the entry point function was emitted bool EmitEntryPointFunction(ast::EntryPoint* ep); + /// Handles emitting the entry point function + /// @param func the entry point function + /// @returns true if the entry point function was emitted + bool EmitEntryPointFunction(ast::Function* func); /// Handles generate an Expression /// @param expr the expression /// @returns true if the expression was emitted diff --git a/src/writer/msl/generator_impl_function_entry_point_data_test.cc b/src/writer/msl/generator_impl_function_entry_point_data_test.cc new file mode 100644 index 0000000000..29cddd2891 --- /dev/null +++ b/src/writer/msl/generator_impl_function_entry_point_data_test.cc @@ -0,0 +1,471 @@ +// Copyright 2020 The Tint Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "gtest/gtest.h" +#include "src/ast/assignment_statement.h" +#include "src/ast/decorated_variable.h" +#include "src/ast/identifier_expression.h" +#include "src/ast/location_decoration.h" +#include "src/ast/member_accessor_expression.h" +#include "src/ast/module.h" +#include "src/ast/pipeline_stage.h" +#include "src/ast/stage_decoration.h" +#include "src/ast/type/f32_type.h" +#include "src/ast/type/i32_type.h" +#include "src/ast/type/vector_type.h" +#include "src/ast/type/void_type.h" +#include "src/ast/variable.h" +#include "src/context.h" +#include "src/type_determiner.h" +#include "src/writer/msl/generator_impl.h" + +#include + +namespace tint { +namespace writer { +namespace msl { +namespace { + +using MslGeneratorImplTest = testing::Test; + +TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Vertex_Input) { + // [[location 0]] var foo : f32; + // [[location 1]] var bar : i32; + // + // struct vtx_main_in { + // float foo [[attribute(0)]]; + // int bar [[attribute(1)]]; + // }; + + ast::type::F32Type f32; + ast::type::I32Type i32; + + auto foo_var = std::make_unique( + std::make_unique("foo", ast::StorageClass::kInput, &f32)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + foo_var->set_decorations(std::move(decos)); + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kInput, &i32)); + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(foo_var.get()); + td.RegisterVariableForTesting(bar_var.get()); + + mod.AddGlobalVariable(std::move(foo_var)); + mod.AddGlobalVariable(std::move(bar_var)); + + ast::VariableList params; + auto func = + std::make_unique("vtx_main", std::move(params), &f32); + func->add_decoration( + std::make_unique(ast::PipelineStage::kVertex)); + auto* func_ptr = func.get(); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("foo"), + std::make_unique("foo"))); + body->append(std::make_unique( + std::make_unique("bar"), + std::make_unique("bar"))); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g(&mod); + ASSERT_TRUE(g.EmitEntryPointData(func_ptr)) << g.error(); + EXPECT_EQ(g.result(), R"(struct vtx_main_in { + float foo [[attribute(0)]]; + int bar [[attribute(1)]]; +}; + +)"); +} + +TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Vertex_Output) { + // [[location 0]] var foo : f32; + // [[location 1]] var bar : i32; + // + // struct vtx_main_out { + // float foo [[user(locn0)]]; + // int bar [[user(locn1)]]; + // }; + + ast::type::F32Type f32; + ast::type::I32Type i32; + + auto foo_var = std::make_unique( + std::make_unique("foo", ast::StorageClass::kOutput, &f32)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + foo_var->set_decorations(std::move(decos)); + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kOutput, &i32)); + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(foo_var.get()); + td.RegisterVariableForTesting(bar_var.get()); + + mod.AddGlobalVariable(std::move(foo_var)); + mod.AddGlobalVariable(std::move(bar_var)); + + ast::VariableList params; + auto func = + std::make_unique("vtx_main", std::move(params), &f32); + func->add_decoration( + std::make_unique(ast::PipelineStage::kVertex)); + auto* func_ptr = func.get(); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("foo"), + std::make_unique("foo"))); + body->append(std::make_unique( + std::make_unique("bar"), + std::make_unique("bar"))); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g(&mod); + ASSERT_TRUE(g.EmitEntryPointData(func_ptr)) << g.error(); + EXPECT_EQ(g.result(), R"(struct vtx_main_out { + float foo [[user(locn0)]]; + int bar [[user(locn1)]]; +}; + +)"); +} + +TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Fragment_Input) { + // [[location 0]] var foo : f32; + // [[location 1]] var bar : i32; + // + // struct frag_main_in { + // float foo [[user(locn0)]]; + // int bar [[user(locn1)]]; + // }; + + ast::type::F32Type f32; + ast::type::I32Type i32; + + auto foo_var = std::make_unique( + std::make_unique("foo", ast::StorageClass::kInput, &f32)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + foo_var->set_decorations(std::move(decos)); + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kInput, &i32)); + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(foo_var.get()); + td.RegisterVariableForTesting(bar_var.get()); + + mod.AddGlobalVariable(std::move(foo_var)); + mod.AddGlobalVariable(std::move(bar_var)); + + ast::VariableList params; + auto func = std::make_unique("main", std::move(params), &f32); + func->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + auto* func_ptr = func.get(); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("foo"), + std::make_unique("foo"))); + body->append(std::make_unique( + std::make_unique("bar"), + std::make_unique("bar"))); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g(&mod); + ASSERT_TRUE(g.EmitEntryPointData(func_ptr)) << g.error(); + EXPECT_EQ(g.result(), R"(struct main_in { + float foo [[user(locn0)]]; + int bar [[user(locn1)]]; +}; + +)"); +} + +TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Fragment_Output) { + // [[location 0]] var foo : f32; + // [[location 1]] var bar : i32; + // + // struct frag_main_out { + // float foo [[color(0)]]; + // int bar [[color(1)]]; + // }; + + ast::type::F32Type f32; + ast::type::I32Type i32; + + auto foo_var = std::make_unique( + std::make_unique("foo", ast::StorageClass::kOutput, &f32)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + foo_var->set_decorations(std::move(decos)); + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kOutput, &i32)); + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(foo_var.get()); + td.RegisterVariableForTesting(bar_var.get()); + + mod.AddGlobalVariable(std::move(foo_var)); + mod.AddGlobalVariable(std::move(bar_var)); + + ast::VariableList params; + auto func = std::make_unique("main", std::move(params), &f32); + func->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + auto* func_ptr = func.get(); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("foo"), + std::make_unique("foo"))); + body->append(std::make_unique( + std::make_unique("bar"), + std::make_unique("bar"))); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g(&mod); + ASSERT_TRUE(g.EmitEntryPointData(func_ptr)) << g.error(); + EXPECT_EQ(g.result(), R"(struct main_out { + float foo [[color(0)]]; + int bar [[color(1)]]; +}; + +)"); +} + +TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Compute_Input) { + // [[location 0]] var foo : f32; + // [[location 1]] var bar : i32; + // + // -> Error, not allowed + + ast::type::F32Type f32; + ast::type::I32Type i32; + + auto foo_var = std::make_unique( + std::make_unique("foo", ast::StorageClass::kInput, &f32)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + foo_var->set_decorations(std::move(decos)); + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kInput, &i32)); + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(foo_var.get()); + td.RegisterVariableForTesting(bar_var.get()); + + mod.AddGlobalVariable(std::move(foo_var)); + mod.AddGlobalVariable(std::move(bar_var)); + + ast::VariableList params; + auto func = std::make_unique("main", std::move(params), &f32); + func->add_decoration( + std::make_unique(ast::PipelineStage::kCompute)); + auto* func_ptr = func.get(); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("foo"), + std::make_unique("foo"))); + body->append(std::make_unique( + std::make_unique("bar"), + std::make_unique("bar"))); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g(&mod); + ASSERT_FALSE(g.EmitEntryPointData(func_ptr)) << g.error(); + EXPECT_EQ(g.error(), R"(invalid location variable for pipeline stage)"); +} + +TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Compute_Output) { + // [[location 0]] var foo : f32; + // [[location 1]] var bar : i32; + // + // -> Error not allowed + + ast::type::F32Type f32; + ast::type::I32Type i32; + + auto foo_var = std::make_unique( + std::make_unique("foo", ast::StorageClass::kOutput, &f32)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + foo_var->set_decorations(std::move(decos)); + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kOutput, &i32)); + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(foo_var.get()); + td.RegisterVariableForTesting(bar_var.get()); + + mod.AddGlobalVariable(std::move(foo_var)); + mod.AddGlobalVariable(std::move(bar_var)); + + ast::VariableList params; + auto func = std::make_unique("main", std::move(params), &f32); + func->add_decoration( + std::make_unique(ast::PipelineStage::kCompute)); + auto* func_ptr = func.get(); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("foo"), + std::make_unique("foo"))); + body->append(std::make_unique( + std::make_unique("bar"), + std::make_unique("bar"))); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g(&mod); + ASSERT_FALSE(g.EmitEntryPointData(func_ptr)) << g.error(); + EXPECT_EQ(g.error(), R"(invalid location variable for pipeline stage)"); +} + +TEST_F(MslGeneratorImplTest, Emit_Function_EntryPointData_Builtins) { + // Output builtins go in the output struct, input builtins will be passed + // as input parameters to the entry point function. + + // [[builtin frag_coord]] var coord : vec4; + // [[builtin frag_depth]] var depth : f32; + // + // struct main_out { + // float depth [[depth(any)]]; + // }; + + ast::type::F32Type f32; + ast::type::VoidType void_type; + ast::type::VectorType vec4(&f32, 4); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kInput, &vec4)); + + ast::VariableDecorationList decos; + decos.push_back( + std::make_unique(ast::Builtin::kFragCoord)); + coord_var->set_decorations(std::move(decos)); + + auto depth_var = + std::make_unique(std::make_unique( + "depth", ast::StorageClass::kOutput, &f32)); + decos.push_back( + std::make_unique(ast::Builtin::kFragDepth)); + depth_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(coord_var.get()); + td.RegisterVariableForTesting(depth_var.get()); + + mod.AddGlobalVariable(std::move(coord_var)); + mod.AddGlobalVariable(std::move(depth_var)); + + ast::VariableList params; + auto func = + std::make_unique("main", std::move(params), &void_type); + func->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + auto* func_ptr = func.get(); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("depth"), + std::make_unique( + std::make_unique("coord"), + std::make_unique("x")))); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g(&mod); + ASSERT_TRUE(g.EmitEntryPointData(func_ptr)) << g.error(); + EXPECT_EQ(g.result(), R"(struct main_out { + float depth [[depth(any)]]; +}; + +)"); +} + +} // namespace +} // namespace msl +} // namespace writer +} // namespace tint diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc index c8ea76481d..f162b6b28b 100644 --- a/src/writer/msl/generator_impl_function_test.cc +++ b/src/writer/msl/generator_impl_function_test.cc @@ -25,10 +25,12 @@ #include "src/ast/location_decoration.h" #include "src/ast/member_accessor_expression.h" #include "src/ast/module.h" +#include "src/ast/pipeline_stage.h" #include "src/ast/return_statement.h" #include "src/ast/scalar_constructor_expression.h" #include "src/ast/set_decoration.h" #include "src/ast/sint_literal.h" +#include "src/ast/stage_decoration.h" #include "src/ast/type/array_type.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/i32_type.h" @@ -220,6 +222,69 @@ fragment frag_main_out frag_main(frag_main_in tint_in [[stage_in]]) { )"); } +TEST_F(MslGeneratorImplTest, Emit_FunctionDecoration_EntryPoint_WithInOutVars) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + + auto foo_var = std::make_unique( + std::make_unique("foo", ast::StorageClass::kInput, &f32)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + foo_var->set_decorations(std::move(decos)); + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kOutput, &f32)); + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(foo_var.get()); + td.RegisterVariableForTesting(bar_var.get()); + + mod.AddGlobalVariable(std::move(foo_var)); + mod.AddGlobalVariable(std::move(bar_var)); + + ast::VariableList params; + auto func = std::make_unique("frag_main", std::move(params), + &void_type); + func->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("bar"), + std::make_unique("foo"))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g(&mod); + ASSERT_TRUE(g.Generate()) << g.error(); + EXPECT_EQ(g.result(), R"(#include + +struct frag_main_in { + float foo [[user(locn0)]]; +}; + +struct frag_main_out { + float bar [[color(1)]]; +}; + +fragment frag_main_out frag_main(frag_main_in tint_in [[stage_in]]) { + frag_main_out tint_out = {}; + tint_out.bar = tint_in.foo; + return tint_out; +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_WithInOut_Builtins) { ast::type::VoidType void_type; ast::type::F32Type f32; @@ -288,6 +353,73 @@ fragment frag_main_out frag_main(float4 coord [[position]]) { )"); } +TEST_F(MslGeneratorImplTest, + Emit_FunctionDecoration_EntryPoint_WithInOut_Builtins) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::VectorType vec4(&f32, 4); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kInput, &vec4)); + + ast::VariableDecorationList decos; + decos.push_back( + std::make_unique(ast::Builtin::kFragCoord)); + coord_var->set_decorations(std::move(decos)); + + auto depth_var = + std::make_unique(std::make_unique( + "depth", ast::StorageClass::kOutput, &f32)); + decos.push_back( + std::make_unique(ast::Builtin::kFragDepth)); + depth_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(coord_var.get()); + td.RegisterVariableForTesting(depth_var.get()); + + mod.AddGlobalVariable(std::move(coord_var)); + mod.AddGlobalVariable(std::move(depth_var)); + + ast::VariableList params; + auto func = std::make_unique("frag_main", std::move(params), + &void_type); + func->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("depth"), + std::make_unique( + std::make_unique("coord"), + std::make_unique("x")))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g(&mod); + ASSERT_TRUE(g.Generate()) << g.error(); + EXPECT_EQ(g.result(), R"(#include + +struct frag_main_out { + float depth [[depth(any)]]; +}; + +fragment frag_main_out frag_main(float4 coord [[position]]) { + frag_main_out tint_out = {}; + tint_out.depth = coord.x; + return tint_out; +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_With_Uniform) { ast::type::VoidType void_type; ast::type::F32Type f32; @@ -344,6 +476,60 @@ fragment void frag_main(constant float4& coord [[buffer(0)]]) { )"); } +TEST_F(MslGeneratorImplTest, Emit_FunctionDecoration_EntryPoint_With_Uniform) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::VectorType vec4(&f32, 4); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kUniform, &vec4)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + decos.push_back(std::make_unique(1)); + coord_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(coord_var.get()); + + mod.AddGlobalVariable(std::move(coord_var)); + + ast::VariableList params; + auto func = std::make_unique("frag_main", std::move(params), + &void_type); + func->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + auto var = + std::make_unique("v", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique("coord"), + std::make_unique("x"))); + + auto body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g(&mod); + ASSERT_TRUE(g.Generate()) << g.error(); + EXPECT_EQ(g.result(), R"(#include + +fragment void frag_main(constant float4& coord [[buffer(0)]]) { + float v = coord.x; + return; +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_With_StorageBuffer) { ast::type::VoidType void_type; ast::type::F32Type f32; @@ -400,6 +586,61 @@ fragment void frag_main(device float4& coord [[buffer(0)]]) { )"); } +TEST_F(MslGeneratorImplTest, + Emit_FunctionDecoration_EntryPoint_With_StorageBuffer) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::VectorType vec4(&f32, 4); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kStorageBuffer, &vec4)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + decos.push_back(std::make_unique(1)); + coord_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(coord_var.get()); + + mod.AddGlobalVariable(std::move(coord_var)); + + ast::VariableList params; + auto func = std::make_unique("frag_main", std::move(params), + &void_type); + func->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + auto var = + std::make_unique("v", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique("coord"), + std::make_unique("x"))); + + auto body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g(&mod); + ASSERT_TRUE(g.Generate()) << g.error(); + EXPECT_EQ(g.result(), R"(#include + +fragment void frag_main(device float4& coord [[buffer(0)]]) { + float v = coord.x; + return; +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_Function_Called_By_EntryPoints_WithLocationGlobals_And_Params) { ast::type::VoidType void_type; @@ -504,6 +745,109 @@ fragment ep_1_out ep_1(ep_1_in tint_in [[stage_in]]) { )"); } +TEST_F( + MslGeneratorImplTest, + Emit_FunctionDecoration_Called_By_EntryPoints_WithLocationGlobals_And_Params) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + + auto foo_var = std::make_unique( + std::make_unique("foo", ast::StorageClass::kInput, &f32)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + foo_var->set_decorations(std::move(decos)); + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kOutput, &f32)); + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + auto val_var = std::make_unique( + std::make_unique("val", ast::StorageClass::kOutput, &f32)); + decos.push_back(std::make_unique(0)); + val_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(foo_var.get()); + td.RegisterVariableForTesting(bar_var.get()); + td.RegisterVariableForTesting(val_var.get()); + + mod.AddGlobalVariable(std::move(foo_var)); + mod.AddGlobalVariable(std::move(bar_var)); + mod.AddGlobalVariable(std::move(val_var)); + + ast::VariableList params; + params.push_back(std::make_unique( + "param", ast::StorageClass::kFunction, &f32)); + auto sub_func = + std::make_unique("sub_func", std::move(params), &f32); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("bar"), + std::make_unique("foo"))); + body->append(std::make_unique( + std::make_unique("val"), + std::make_unique("param"))); + body->append(std::make_unique( + std::make_unique("foo"))); + sub_func->set_body(std::move(body)); + + mod.AddFunction(std::move(sub_func)); + + auto func_1 = + std::make_unique("ep_1", std::move(params), &void_type); + func_1->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + ast::ExpressionList expr; + expr.push_back(std::make_unique( + std::make_unique(&f32, 1.0f))); + + body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("bar"), + std::make_unique( + std::make_unique("sub_func"), + std::move(expr)))); + body->append(std::make_unique()); + func_1->set_body(std::move(body)); + + mod.AddFunction(std::move(func_1)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g(&mod); + ASSERT_TRUE(g.Generate()) << g.error(); + EXPECT_EQ(g.result(), R"(#include + +struct ep_1_in { + float foo [[user(locn0)]]; +}; + +struct ep_1_out { + float bar [[color(1)]]; + float val [[color(0)]]; +}; + +float sub_func_ep_1(thread ep_1_in& tint_in, thread ep_1_out& tint_out, float param) { + tint_out.bar = tint_in.foo; + tint_out.val = param; + return tint_in.foo; +} + +fragment ep_1_out ep_1(ep_1_in tint_in [[stage_in]]) { + ep_1_out tint_out = {}; + tint_out.bar = sub_func_ep_1(tint_in, tint_out, 1.00000000f); + return tint_out; +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_Function_Called_By_EntryPoints_NoUsedGlobals) { ast::type::VoidType void_type; @@ -584,6 +928,84 @@ fragment ep_1_out ep_1() { )"); } +TEST_F(MslGeneratorImplTest, + Emit_FunctionDecoration_Called_By_EntryPoints_NoUsedGlobals) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::VectorType vec4(&f32, 4); + + auto depth_var = + std::make_unique(std::make_unique( + "depth", ast::StorageClass::kOutput, &f32)); + + ast::VariableDecorationList decos; + decos.push_back( + std::make_unique(ast::Builtin::kFragDepth)); + depth_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(depth_var.get()); + + mod.AddGlobalVariable(std::move(depth_var)); + + ast::VariableList params; + params.push_back(std::make_unique( + "param", ast::StorageClass::kFunction, &f32)); + auto sub_func = + std::make_unique("sub_func", std::move(params), &f32); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("param"))); + sub_func->set_body(std::move(body)); + + mod.AddFunction(std::move(sub_func)); + + auto func_1 = + std::make_unique("ep_1", std::move(params), &void_type); + func_1->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + ast::ExpressionList expr; + expr.push_back(std::make_unique( + std::make_unique(&f32, 1.0f))); + + body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("depth"), + std::make_unique( + std::make_unique("sub_func"), + std::move(expr)))); + body->append(std::make_unique()); + func_1->set_body(std::move(body)); + + mod.AddFunction(std::move(func_1)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g(&mod); + ASSERT_TRUE(g.Generate()) << g.error(); + EXPECT_EQ(g.result(), R"(#include + +struct ep_1_out { + float depth [[depth(any)]]; +}; + +float sub_func(float param) { + return param; +} + +fragment ep_1_out ep_1() { + ep_1_out tint_out = {}; + tint_out.depth = sub_func(1.00000000f); + return tint_out; +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_Function_Called_By_EntryPoints_WithBuiltinGlobals_And_Params) { ast::type::VoidType void_type; @@ -679,6 +1101,100 @@ fragment ep_1_out ep_1(float4 coord [[position]]) { )"); } +TEST_F( + MslGeneratorImplTest, + Emit_FunctionDecoration_Called_By_EntryPoints_WithBuiltinGlobals_And_Params) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::VectorType vec4(&f32, 4); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kInput, &vec4)); + + ast::VariableDecorationList decos; + decos.push_back( + std::make_unique(ast::Builtin::kFragCoord)); + coord_var->set_decorations(std::move(decos)); + + auto depth_var = + std::make_unique(std::make_unique( + "depth", ast::StorageClass::kOutput, &f32)); + decos.push_back( + std::make_unique(ast::Builtin::kFragDepth)); + depth_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(coord_var.get()); + td.RegisterVariableForTesting(depth_var.get()); + + mod.AddGlobalVariable(std::move(coord_var)); + mod.AddGlobalVariable(std::move(depth_var)); + + ast::VariableList params; + params.push_back(std::make_unique( + "param", ast::StorageClass::kFunction, &f32)); + auto sub_func = + std::make_unique("sub_func", std::move(params), &f32); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("depth"), + std::make_unique( + std::make_unique("coord"), + std::make_unique("x")))); + body->append(std::make_unique( + std::make_unique("param"))); + sub_func->set_body(std::move(body)); + + mod.AddFunction(std::move(sub_func)); + + auto func_1 = + std::make_unique("ep_1", std::move(params), &void_type); + func_1->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + ast::ExpressionList expr; + expr.push_back(std::make_unique( + std::make_unique(&f32, 1.0f))); + + body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("depth"), + std::make_unique( + std::make_unique("sub_func"), + std::move(expr)))); + body->append(std::make_unique()); + func_1->set_body(std::move(body)); + + mod.AddFunction(std::move(func_1)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g(&mod); + ASSERT_TRUE(g.Generate()) << g.error(); + EXPECT_EQ(g.result(), R"(#include + +struct ep_1_out { + float depth [[depth(any)]]; +}; + +float sub_func_ep_1(thread ep_1_out& tint_out, thread float4& coord, float param) { + tint_out.depth = coord.x; + return param; +} + +fragment ep_1_out ep_1(float4 coord [[position]]) { + ep_1_out tint_out = {}; + tint_out.depth = sub_func_ep_1(tint_out, coord, 1.00000000f); + return tint_out; +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_Function_Called_By_EntryPoint_With_Uniform) { ast::type::VoidType void_type; ast::type::F32Type f32; @@ -757,6 +1273,83 @@ fragment void frag_main(constant float4& coord [[buffer(0)]]) { )"); } +TEST_F(MslGeneratorImplTest, + Emit_FunctionDecoration_Called_By_EntryPoint_With_Uniform) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::VectorType vec4(&f32, 4); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kUniform, &vec4)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + decos.push_back(std::make_unique(1)); + coord_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(coord_var.get()); + + mod.AddGlobalVariable(std::move(coord_var)); + + ast::VariableList params; + params.push_back(std::make_unique( + "param", ast::StorageClass::kFunction, &f32)); + auto sub_func = + std::make_unique("sub_func", std::move(params), &f32); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique( + std::make_unique("coord"), + std::make_unique("x")))); + sub_func->set_body(std::move(body)); + + mod.AddFunction(std::move(sub_func)); + + auto func = std::make_unique("frag_main", std::move(params), + &void_type); + func->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + ast::ExpressionList expr; + expr.push_back(std::make_unique( + std::make_unique(&f32, 1.0f))); + + auto var = + std::make_unique("v", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique("sub_func"), + std::move(expr))); + + body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g(&mod); + ASSERT_TRUE(g.Generate()) << g.error(); + EXPECT_EQ(g.result(), R"(#include + +float sub_func(constant float4& coord, float param) { + return coord.x; +} + +fragment void frag_main(constant float4& coord [[buffer(0)]]) { + float v = sub_func(coord, 1.00000000f); + return; +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_Function_Called_By_EntryPoint_With_StorageBuffer) { ast::type::VoidType void_type; @@ -836,6 +1429,83 @@ fragment void frag_main(device float4& coord [[buffer(0)]]) { )"); } +TEST_F(MslGeneratorImplTest, + Emit_FunctionDecoration_Called_By_EntryPoint_With_StorageBuffer) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::VectorType vec4(&f32, 4); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kStorageBuffer, &vec4)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + decos.push_back(std::make_unique(1)); + coord_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(coord_var.get()); + + mod.AddGlobalVariable(std::move(coord_var)); + + ast::VariableList params; + params.push_back(std::make_unique( + "param", ast::StorageClass::kFunction, &f32)); + auto sub_func = + std::make_unique("sub_func", std::move(params), &f32); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique( + std::make_unique("coord"), + std::make_unique("x")))); + sub_func->set_body(std::move(body)); + + mod.AddFunction(std::move(sub_func)); + + auto func = std::make_unique("frag_main", std::move(params), + &void_type); + func->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + ast::ExpressionList expr; + expr.push_back(std::make_unique( + std::make_unique(&f32, 1.0f))); + + auto var = + std::make_unique("v", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique("sub_func"), + std::move(expr))); + + body = std::make_unique(); + body->append(std::make_unique(std::move(var))); + body->append(std::make_unique()); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g(&mod); + ASSERT_TRUE(g.Generate()) << g.error(); + EXPECT_EQ(g.result(), R"(#include + +float sub_func(device float4& coord, float param) { + return coord.x; +} + +fragment void frag_main(device float4& coord [[buffer(0)]]) { + float v = sub_func(coord, 1.00000000f); + return; +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_Function_Called_Two_EntryPoints_WithGlobals) { ast::type::VoidType void_type; ast::type::F32Type f32; @@ -1014,6 +1684,75 @@ fragment ep_1_out ep_1() { )"); } +TEST_F(MslGeneratorImplTest, + Emit_FunctionDecoration_EntryPoints_WithGlobal_Nested_Return) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::I32Type i32; + + auto bar_var = std::make_unique( + std::make_unique("bar", ast::StorageClass::kOutput, &f32)); + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(1)); + bar_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(bar_var.get()); + mod.AddGlobalVariable(std::move(bar_var)); + + ast::VariableList params; + auto func_1 = + std::make_unique("ep_1", std::move(params), &void_type); + func_1->add_decoration( + std::make_unique(ast::PipelineStage::kFragment)); + + auto body = std::make_unique(); + body->append(std::make_unique( + std::make_unique("bar"), + std::make_unique( + std::make_unique(&f32, 1.0f)))); + + auto list = std::make_unique(); + list->append(std::make_unique()); + + body->append(std::make_unique( + std::make_unique( + ast::BinaryOp::kEqual, + std::make_unique( + std::make_unique(&i32, 1)), + std::make_unique( + std::make_unique(&i32, 1))), + std::move(list))); + + body->append(std::make_unique()); + func_1->set_body(std::move(body)); + + mod.AddFunction(std::move(func_1)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g(&mod); + ASSERT_TRUE(g.Generate()) << g.error(); + EXPECT_EQ(g.result(), R"(#include + +struct ep_1_out { + float bar [[color(1)]]; +}; + +fragment ep_1_out ep_1() { + ep_1_out tint_out = {}; + tint_out.bar = 1.00000000f; + if ((1 == 1)) { + return tint_out; + } + return tint_out; +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_Function_Called_Two_EntryPoints_WithoutGlobals) { ast::type::VoidType void_type; @@ -1080,6 +1819,7 @@ fragment void ep_2() { )"); } + TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_WithName) { ast::type::VoidType void_type; @@ -1124,6 +1864,28 @@ kernel void main_tint_0() { )"); } +TEST_F(MslGeneratorImplTest, + Emit_FunctionDecoration_EntryPoint_WithNameCollision) { + ast::type::VoidType void_type; + + auto func = + std::make_unique("main", ast::VariableList{}, &void_type); + func->add_decoration( + std::make_unique(ast::PipelineStage::kCompute)); + + ast::Module m; + m.AddFunction(std::move(func)); + + GeneratorImpl g(&m); + ASSERT_TRUE(g.Generate()) << g.error(); + EXPECT_EQ(g.result(), R"(#include + +kernel void main_tint_0() { +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_Function_WithArrayParams) { ast::type::F32Type f32; ast::type::ArrayType ary(&f32, 5);