diff --git a/src/ast/function.cc b/src/ast/function.cc index 3ca975e846..44d560360d 100644 --- a/src/ast/function.cc +++ b/src/ast/function.cc @@ -71,6 +71,34 @@ Function::referenced_location_variables() const { return ret; } +const std::vector> +Function::referenced_uniform_variables() const { + std::vector> ret; + + for (auto* var : referenced_module_variables()) { + if (!var->IsDecorated() || + var->storage_class() != ast::StorageClass::kUniform) { + continue; + } + + BindingDecoration* binding = nullptr; + SetDecoration* set = nullptr; + for (const auto& deco : var->AsDecorated()->decorations()) { + if (deco->IsBinding()) { + binding = deco->AsBinding(); + } else if (deco->IsSet()) { + set = deco->AsSet(); + } + } + if (binding == nullptr || set == nullptr) { + continue; + } + + ret.push_back({var, BindingInfo{binding, set}}); + } + return ret; +} + const std::vector> Function::referenced_builtin_variables() const { std::vector> ret; diff --git a/src/ast/function.h b/src/ast/function.h index a3722e1211..4a8c33fd34 100644 --- a/src/ast/function.h +++ b/src/ast/function.h @@ -21,10 +21,12 @@ #include #include +#include "src/ast/binding_decoration.h" #include "src/ast/builtin_decoration.h" #include "src/ast/expression.h" #include "src/ast/location_decoration.h" #include "src/ast/node.h" +#include "src/ast/set_decoration.h" #include "src/ast/statement.h" #include "src/ast/type/type.h" #include "src/ast/variable.h" @@ -35,6 +37,14 @@ namespace ast { /// A Function statement. class Function : public Node { public: + /// Information about a binding + struct BindingInfo { + /// The binding decoration + BindingDecoration* binding = nullptr; + /// The set decoration + SetDecoration* set = nullptr; + }; + /// Create a new empty function statement Function(); /// Create a function @@ -86,6 +96,11 @@ class Function : public Node { /// @returns the pair. const std::vector> referenced_builtin_variables() const; + /// Retrieves any referenced uniform variables. Note, the uniform must be + /// decorated with both binding and set decorations. + /// @returns the referenced uniforms + const std::vector> + referenced_uniform_variables() const; /// Adds an ancestor entry point /// @param ep the entry point ancestor diff --git a/src/reader/spirv/function.cc b/src/reader/spirv/function.cc index 6f8de075bc..38e4e606b5 100644 --- a/src/reader/spirv/function.cc +++ b/src/reader/spirv/function.cc @@ -2915,7 +2915,7 @@ TypedExpression FunctionEmitter::MakeAccessChain( type_mgr_->FindPointerToType(pointee_type_id, storage_class); auto* ast_pointer_type = parser_impl_.ConvertType(pointer_type_id); assert(ast_pointer_type); - assert(ast_pointer_type->IsPointer); + assert(ast_pointer_type->IsPointer()); current_expr.reset(TypedExpression(ast_pointer_type, std::move(next_expr))); } return current_expr; diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index 46ce013bcc..dbfff5c7e1 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc @@ -448,6 +448,7 @@ bool GeneratorImpl::EmitCall(ast::CallExpression* expr) { error_ = "Unable to find function: " + name; return false; } + for (const auto& data : func->referenced_builtin_variables()) { auto* var = data.first; if (var->storage_class() != ast::StorageClass::kInput) { @@ -460,6 +461,15 @@ bool GeneratorImpl::EmitCall(ast::CallExpression* expr) { out_ << var->name(); } + for (const auto& data : func->referenced_uniform_variables()) { + auto* var = data.first; + if (!first) { + out_ << ", "; + } + first = false; + out_ << var->name(); + } + const auto& params = expr->params(); for (const auto& param : params) { if (!first) { @@ -814,6 +824,25 @@ void GeneratorImpl::EmitStage(ast::PipelineStage stage) { return; } +bool GeneratorImpl::has_referenced_var_needing_struct(ast::Function* func) { + for (auto data : func->referenced_location_variables()) { + auto var = data.first; + if (var->storage_class() == ast::StorageClass::kInput || + var->storage_class() == ast::StorageClass::kOutput) { + return true; + } + } + + for (auto data : func->referenced_builtin_variables()) { + auto var = data.first; + if (var->storage_class() == ast::StorageClass::kOutput) { + return true; + } + } + + return false; +} + bool GeneratorImpl::EmitFunction(ast::Function* func) { make_indent(); @@ -825,9 +854,8 @@ bool GeneratorImpl::EmitFunction(ast::Function* func) { // TODO(dsinclair): This could be smarter. If the input/outputs for multiple // entry points are the same we could generate a single struct and then have // this determine it's the same struct and just emit once. - bool emit_duplicate_functions = - func->ancestor_entry_points().size() > 0 && - func->referenced_module_variables().size() > 0; + bool emit_duplicate_functions = func->ancestor_entry_points().size() > 0 && + has_referenced_var_needing_struct(func); if (emit_duplicate_functions) { for (const auto& ep_name : func->ancestor_entry_points()) { @@ -857,7 +885,6 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, } out_ << " "; - if (emit_duplicate_functions) { name = generate_name(name + "_" + ep_name); ep_func_name_remapped_[ep_name + "_" + func->name()] = name; @@ -908,6 +935,21 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, out_ << "& " << var->name(); } + for (const auto& data : func->referenced_uniform_variables()) { + auto* var = data.first; + if (!first) { + out_ << ", "; + } + first = false; + + out_ << "constant "; + // TODO(dsinclair): Can arrays be uniform? If so, fix this ... + if (!EmitType(var->type(), "")) { + return false; + } + out_ << "& " << var->name(); + } + // TODO(dsinclair): Binding/Set inputs for (const auto& v : func->params()) { @@ -1034,6 +1076,28 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::EntryPoint* ep) { out_ << " " << var->name() << " [[" << attr << "]]"; } + for (auto data : func->referenced_uniform_variables()) { + if (!first) { + out_ << ", "; + } + first = false; + + auto* var = data.first; + // TODO(dsinclair): We're using the binding to make up the buffer number but + // we should instead be using a provided mapping that uses both buffer and + // set. https://bugs.chromium.org/p/tint/issues/detail?id=104 + auto* binding = data.second.binding; + // auto* set = data.second.set; + + out_ << "constant "; + // TODO(dsinclair): Can you have a uniform array? If so, this needs to be + // updated to handle arrays property. + if (!EmitType(var->type(), "")) { + return false; + } + out_ << "& " << var->name() << " [[buffer(" << binding->value() << ")]]"; + } + // TODO(dsinclair): Binding/Set inputs out_ << ") {" << std::endl; diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h index df90697eca..094ba8095b 100644 --- a/src/writer/msl/generator_impl.h +++ b/src/writer/msl/generator_impl.h @@ -208,6 +208,11 @@ class GeneratorImpl : public TextGenerator { /// @param mod the module to set. void set_module_for_testing(ast::Module* mod); + /// Determines if any used module variable requires an input or output struct. + /// @param func the function to check + /// @returns true if an input or output struct is required. + bool has_referenced_var_needing_struct(ast::Function* func); + /// Generates a name for the prefix /// @param prefix the prefix of the name to generate /// @returns the name diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc index 5e25073914..0558247c6f 100644 --- a/src/writer/msl/generator_impl_function_test.cc +++ b/src/writer/msl/generator_impl_function_test.cc @@ -15,6 +15,7 @@ #include "gtest/gtest.h" #include "src/ast/assignment_statement.h" #include "src/ast/binary_expression.h" +#include "src/ast/binding_decoration.h" #include "src/ast/call_expression.h" #include "src/ast/decorated_variable.h" #include "src/ast/float_literal.h" @@ -26,6 +27,7 @@ #include "src/ast/module.h" #include "src/ast/return_statement.h" #include "src/ast/scalar_constructor_expression.h" +#include "src/ast/set_decoration.h" #include "src/ast/sint_literal.h" #include "src/ast/type/array_type.h" #include "src/ast/type/f32_type.h" @@ -286,6 +288,62 @@ fragment frag_main_out frag_main(float4 coord [[position]]) { )"); } +TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_With_Uniform) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::VectorType vec4(&f32, 4); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kUniform, &vec4)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + decos.push_back(std::make_unique(1)); + coord_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(coord_var.get()); + + mod.AddGlobalVariable(std::move(coord_var)); + + ast::VariableList params; + auto func = std::make_unique("frag_main", std::move(params), + &void_type); + + auto var = + std::make_unique("v", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique("coord"), + std::make_unique("x"))); + + ast::StatementList body; + body.push_back(std::make_unique(std::move(var))); + body.push_back(std::make_unique()); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + + auto ep = std::make_unique(ast::PipelineStage::kFragment, "", + "frag_main"); + mod.AddEntryPoint(std::move(ep)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g; + ASSERT_TRUE(g.Generate(mod)) << g.error(); + EXPECT_EQ(g.result(), R"(#include + +fragment void frag_main(constant float4& coord [[buffer(0)]]) { + float v = coord.x; + return; +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_Function_Called_By_EntryPoints_WithLocationGlobals_And_Params) { ast::type::VoidType void_type; @@ -481,6 +539,83 @@ fragment ep_1_out ep_1(float4 coord [[position]]) { )"); } +TEST_F(MslGeneratorImplTest, Emit_Function_Called_By_EntryPoint_With_Uniform) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::VectorType vec4(&f32, 4); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kUniform, &vec4)); + + ast::VariableDecorationList decos; + decos.push_back(std::make_unique(0)); + decos.push_back(std::make_unique(1)); + coord_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(coord_var.get()); + + mod.AddGlobalVariable(std::move(coord_var)); + + ast::VariableList params; + params.push_back(std::make_unique( + "param", ast::StorageClass::kFunction, &f32)); + auto sub_func = + std::make_unique("sub_func", std::move(params), &f32); + + ast::StatementList body; + body.push_back(std::make_unique( + std::make_unique( + std::make_unique("coord"), + std::make_unique("x")))); + sub_func->set_body(std::move(body)); + + mod.AddFunction(std::move(sub_func)); + + auto func = std::make_unique("frag_main", std::move(params), + &void_type); + + ast::ExpressionList expr; + expr.push_back(std::make_unique( + std::make_unique(&f32, 1.0f))); + + auto var = + std::make_unique("v", ast::StorageClass::kFunction, &f32); + var->set_constructor(std::make_unique( + std::make_unique("sub_func"), + std::move(expr))); + + body.push_back(std::make_unique(std::move(var))); + body.push_back(std::make_unique()); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + + auto ep = std::make_unique(ast::PipelineStage::kFragment, "", + "frag_main"); + mod.AddEntryPoint(std::move(ep)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g; + ASSERT_TRUE(g.Generate(mod)) << g.error(); + EXPECT_EQ(g.result(), R"(#include + +float sub_func(constant float4& coord, float param) { + return coord.x; +} + +fragment void frag_main(constant float4& coord [[buffer(0)]]) { + float v = sub_func(coord, 1.00000000f); + return; +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_Function_Called_Two_EntryPoints_WithGlobals) { ast::type::VoidType void_type; ast::type::F32Type f32;