diff --git a/src/ast/decorated_variable.cc b/src/ast/decorated_variable.cc index 89db9e4dd6..8588b99f53 100644 --- a/src/ast/decorated_variable.cc +++ b/src/ast/decorated_variable.cc @@ -35,6 +35,15 @@ bool DecoratedVariable::HasLocationDecoration() const { return false; } +bool DecoratedVariable::HasBuiltinDecoration() const { + for (const auto& deco : decorations_) { + if (deco->IsBuiltin()) { + return true; + } + } + return false; +} + bool DecoratedVariable::IsDecorated() const { return true; } diff --git a/src/ast/decorated_variable.h b/src/ast/decorated_variable.h index d2e381a2dc..489dcbc7e2 100644 --- a/src/ast/decorated_variable.h +++ b/src/ast/decorated_variable.h @@ -47,6 +47,8 @@ class DecoratedVariable : public Variable { /// @returns true if the decorations include a LocationDecoration bool HasLocationDecoration() const; + /// @returns true if the deocrations include a BuiltinDecoration + bool HasBuiltinDecoration() const; /// @returns true if this is a decorated variable bool IsDecorated() const override; diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index fbb9c845d7..9fea11f3fd 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc @@ -333,11 +333,26 @@ bool GeneratorImpl::EmitCall(ast::CallExpression* expr) { if (!first) { out_ << ", "; } - out_ << var_name; first = false; + out_ << var_name; } - // TODO(dsinclair): Emit builtins + auto* func = module_->FindFunctionByName(ident->name()); + if (func == nullptr) { + error_ = "Unable to find function: " + name; + return false; + } + for (const auto& data : func->referenced_builtin_variables()) { + auto* var = data.first; + if (var->storage_class() != ast::StorageClass::kInput) { + continue; + } + if (!first) { + out_ << ", "; + } + first = false; + out_ << var->name(); + } const auto& params = expr->params(); for (const auto& param : params) { @@ -517,15 +532,25 @@ bool GeneratorImpl::EmitEntryPointData(ast::EntryPoint* ep) { } std::vector> in_locations; - std::vector> out_locations; + std::vector> + out_variables; for (auto data : func->referenced_location_variables()) { auto var = data.first; - auto locn_deco = data.second; + auto deco = data.second; if (var->storage_class() == ast::StorageClass::kInput) { - in_locations.push_back({var, locn_deco->value()}); + in_locations.push_back({var, deco->value()}); } else if (var->storage_class() == ast::StorageClass::kOutput) { - out_locations.push_back({var, locn_deco->value()}); + out_variables.push_back({var, deco}); + } + } + + for (auto data : func->referenced_builtin_variables()) { + auto var = data.first; + auto deco = data.second; + + if (var->storage_class() == ast::StorageClass::kOutput) { + out_variables.push_back({var, deco}); } } @@ -575,7 +600,7 @@ bool GeneratorImpl::EmitEntryPointData(ast::EntryPoint* ep) { out_ << "};" << std::endl << std::endl; } - if (!out_locations.empty()) { + if (!out_variables.empty()) { auto out_struct_name = generate_name(ep_name + "_" + kOutStructNameSuffix); auto out_var_name = generate_name(kTintStructOutVarPrefix); ep_name_to_out_data_[ep_name] = {out_struct_name, out_var_name}; @@ -584,9 +609,9 @@ bool GeneratorImpl::EmitEntryPointData(ast::EntryPoint* ep) { out_ << "struct " << out_struct_name << " {" << std::endl; increment_indent(); - for (auto& data : out_locations) { + for (auto& data : out_variables) { auto* var = data.first; - uint32_t loc = data.second; + auto* deco = data.second; make_indent(); if (!EmitType(var->type(), var->name())) { @@ -594,12 +619,26 @@ bool GeneratorImpl::EmitEntryPointData(ast::EntryPoint* ep) { } out_ << " " << var->name() << " [["; - if (ep->stage() == ast::PipelineStage::kVertex) { - out_ << "user(locn" << loc << ")"; - } else if (ep->stage() == ast::PipelineStage::kFragment) { - out_ << "color(" << loc << ")"; + + if (deco->IsLocation()) { + auto loc = deco->AsLocation()->value(); + if (ep->stage() == ast::PipelineStage::kVertex) { + out_ << "user(locn" << loc << ")"; + } else if (ep->stage() == ast::PipelineStage::kFragment) { + out_ << "color(" << loc << ")"; + } else { + error_ = "invalid location variable for pipeline stage"; + return false; + } + } else if (deco->IsBuiltin()) { + auto attr = builtin_to_attribute(deco->AsBuiltin()->value()); + if (attr.empty()) { + error_ = "unsupported builtin"; + return false; + } + out_ << attr; } else { - error_ = "invalid location variable for pipeline stage"; + error_ = "unsupported variable decoration for entry point output"; return false; } out_ << "]];" << std::endl; @@ -739,7 +778,22 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, } } - // TODO(dsinclair): Handle any entry point builtin params used here + for (const auto& data : func->referenced_builtin_variables()) { + auto* var = data.first; + if (var->storage_class() != ast::StorageClass::kInput) { + continue; + } + if (!first) { + out_ << ", "; + } + first = false; + + out_ << "thread "; + if (!EmitType(var->type(), "")) { + return false; + } + out_ << "& " << var->name(); + } // TODO(dsinclair): Binding/Set inputs @@ -771,6 +825,41 @@ bool GeneratorImpl::EmitFunctionInternal(ast::Function* func, return true; } +std::string GeneratorImpl::builtin_to_attribute(ast::Builtin builtin) const { + switch (builtin) { + case ast::Builtin::kPosition: + return "position"; + case ast::Builtin::kVertexIdx: + return "vertex_id"; + case ast::Builtin::kInstanceIdx: + return "instance_id"; + case ast::Builtin::kFrontFacing: + return "front_facing"; + case ast::Builtin::kFragCoord: + return "position"; + case ast::Builtin::kFragDepth: + return "depth(any)"; + // TODO(dsinclair): Ignore for now, I believe it will be removed from WGSL + // https://github.com/gpuweb/gpuweb/issues/920 + case ast::Builtin::kNumWorkgroups: + return ""; + // TODO(dsinclair): Ignore for now. This has been removed as a builtin + // in the spec. Need to update Tint to match. + // https://github.com/gpuweb/gpuweb/pull/824 + case ast::Builtin::kWorkgroupSize: + return ""; + case ast::Builtin::kLocalInvocationId: + return "thread_position_in_threadgroup"; + case ast::Builtin::kLocalInvocationIdx: + return "thread_index_in_threadgroup"; + case ast::Builtin::kGlobalInvocationId: + return "thread_position_in_grid"; + default: + break; + } + return ""; +} + bool GeneratorImpl::EmitEntryPointFunction(ast::EntryPoint* ep) { make_indent(); @@ -799,13 +888,38 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::EntryPoint* ep) { } out_ << " " << namer_.NameFor(current_ep_name_) << "("; + bool first = true; auto in_data = ep_name_to_in_data_.find(current_ep_name_); if (in_data != ep_name_to_in_data_.end()) { out_ << in_data->second.struct_name << " " << in_data->second.var_name << " [[stage_in]]"; + first = false; } - // TODO(dsinclair): Output other builtin inputs + for (auto data : func->referenced_builtin_variables()) { + auto* var = data.first; + if (var->storage_class() != ast::StorageClass::kInput) { + continue; + } + + if (!first) { + out_ << ", "; + } + first = false; + + auto* builtin = data.second; + + if (!EmitType(var->type(), "")) { + return false; + } + + auto attr = builtin_to_attribute(builtin->value()); + if (attr.empty()) { + error_ = "unknown builtin"; + return false; + } + out_ << " " << var->name() << " [[" << attr << "]]"; + } // TODO(dsinclair): Binding/Set inputs @@ -845,9 +959,14 @@ bool GeneratorImpl::EmitIdentifier(ast::IdentifierExpression* expr) { ast::Variable* var = nullptr; if (global_variables_.get(ident->name(), &var)) { - if (var->IsDecorated() && var->AsDecorated()->HasLocationDecoration() && + bool in_or_out_struct_has_location = + var->IsDecorated() && var->AsDecorated()->HasLocationDecoration() && (var->storage_class() == ast::StorageClass::kInput || - var->storage_class() == ast::StorageClass::kOutput)) { + var->storage_class() == ast::StorageClass::kOutput); + bool in_struct_has_builtin = + var->IsDecorated() && var->AsDecorated()->HasBuiltinDecoration() && + var->storage_class() == ast::StorageClass::kOutput; + if (in_or_out_struct_has_location || in_struct_has_builtin) { auto var_type = var->storage_class() == ast::StorageClass::kInput ? VarType::kIn : VarType::kOut; diff --git a/src/writer/msl/generator_impl.h b/src/writer/msl/generator_impl.h index 6f7f6e5fad..4290d8071f 100644 --- a/src/writer/msl/generator_impl.h +++ b/src/writer/msl/generator_impl.h @@ -198,6 +198,11 @@ class GeneratorImpl : public TextGenerator { /// @returns the name std::string generate_name(const std::string& prefix); + /// Converts a builtin to an attribute name + /// @param builtin the builtin to convert + /// @returns the string name of the builtin or blank on error + std::string builtin_to_attribute(ast::Builtin builtin) const; + /// @returns the namer for testing Namer* namer_for_testing() { return &namer_; } diff --git a/src/writer/msl/generator_impl_call_test.cc b/src/writer/msl/generator_impl_call_test.cc index b9347cbabd..af4defea3b 100644 --- a/src/writer/msl/generator_impl_call_test.cc +++ b/src/writer/msl/generator_impl_call_test.cc @@ -16,7 +16,10 @@ #include "gtest/gtest.h" #include "src/ast/call_expression.h" +#include "src/ast/function.h" #include "src/ast/identifier_expression.h" +#include "src/ast/module.h" +#include "src/ast/type/void_type.h" #include "src/writer/msl/generator_impl.h" namespace tint { @@ -27,22 +30,42 @@ namespace { using MslGeneratorImplTest = testing::Test; TEST_F(MslGeneratorImplTest, EmitExpression_Call_WithoutParams) { + ast::type::VoidType void_type; + auto id = std::make_unique("my_func"); ast::CallExpression call(std::move(id), {}); + auto func = std::make_unique("my_func", ast::VariableList{}, + &void_type); + + ast::Module m; + m.AddFunction(std::move(func)); + GeneratorImpl g; + g.set_module_for_testing(&m); + ASSERT_TRUE(g.EmitExpression(&call)) << g.error(); EXPECT_EQ(g.result(), "my_func()"); } TEST_F(MslGeneratorImplTest, EmitExpression_Call_WithParams) { + ast::type::VoidType void_type; + auto id = std::make_unique("my_func"); ast::ExpressionList params; params.push_back(std::make_unique("param1")); params.push_back(std::make_unique("param2")); ast::CallExpression call(std::move(id), std::move(params)); + auto func = std::make_unique("my_func", ast::VariableList{}, + &void_type); + + ast::Module m; + m.AddFunction(std::move(func)); + GeneratorImpl g; + g.set_module_for_testing(&m); + ASSERT_TRUE(g.EmitExpression(&call)) << g.error(); EXPECT_EQ(g.result(), "my_func(param1, param2)"); } diff --git a/src/writer/msl/generator_impl_entry_point_test.cc b/src/writer/msl/generator_impl_entry_point_test.cc index 824ac6ceab..a5bd1f3db6 100644 --- a/src/writer/msl/generator_impl_entry_point_test.cc +++ b/src/writer/msl/generator_impl_entry_point_test.cc @@ -18,9 +18,12 @@ #include "src/ast/entry_point.h" #include "src/ast/identifier_expression.h" #include "src/ast/location_decoration.h" +#include "src/ast/member_accessor_expression.h" #include "src/ast/module.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/i32_type.h" +#include "src/ast/type/vector_type.h" +#include "src/ast/type/void_type.h" #include "src/ast/variable.h" #include "src/context.h" #include "src/type_determiner.h" @@ -419,6 +422,78 @@ TEST_F(MslGeneratorImplTest, EmitEntryPointData_Compute_Output) { EXPECT_EQ(g.error(), R"(invalid location variable for pipeline stage)"); } +TEST_F(MslGeneratorImplTest, EmitEntryPointData_Builtins) { + // Output builtins go in the output struct, input builtins will be passed + // as input parameters to the entry point function. + + // [[builtin frag_coord]] var coord : vec4; + // [[builtin frag_depth]] var depth : f32; + // + // struct main_out { + // float depth [[depth(any)]]; + // }; + + ast::type::F32Type f32; + ast::type::VoidType void_type; + ast::type::VectorType vec4(&f32, 4); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kInput, &vec4)); + + ast::VariableDecorationList decos; + decos.push_back( + std::make_unique(ast::Builtin::kFragCoord)); + coord_var->set_decorations(std::move(decos)); + + auto depth_var = + std::make_unique(std::make_unique( + "depth", ast::StorageClass::kOutput, &f32)); + decos.push_back( + std::make_unique(ast::Builtin::kFragDepth)); + depth_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(coord_var.get()); + td.RegisterVariableForTesting(depth_var.get()); + + mod.AddGlobalVariable(std::move(coord_var)); + mod.AddGlobalVariable(std::move(depth_var)); + + ast::VariableList params; + auto func = std::make_unique("frag_main", std::move(params), + &void_type); + + ast::StatementList body; + body.push_back(std::make_unique( + std::make_unique("depth"), + std::make_unique( + std::make_unique("coord"), + std::make_unique("x")))); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + + auto ep = std::make_unique(ast::PipelineStage::kFragment, + "main", "frag_main"); + auto* ep_ptr = ep.get(); + + mod.AddEntryPoint(std::move(ep)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g; + g.set_module_for_testing(&mod); + ASSERT_TRUE(g.EmitEntryPointData(ep_ptr)) << g.error(); + EXPECT_EQ(g.result(), R"(struct main_out { + float depth [[depth(any)]]; +}; + +)"); +} + } // namespace } // namespace msl } // namespace writer diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc index 7d49b1dee0..5e25073914 100644 --- a/src/writer/msl/generator_impl_function_test.cc +++ b/src/writer/msl/generator_impl_function_test.cc @@ -22,6 +22,7 @@ #include "src/ast/identifier_expression.h" #include "src/ast/if_statement.h" #include "src/ast/location_decoration.h" +#include "src/ast/member_accessor_expression.h" #include "src/ast/module.h" #include "src/ast/return_statement.h" #include "src/ast/scalar_constructor_expression.h" @@ -29,6 +30,7 @@ #include "src/ast/type/array_type.h" #include "src/ast/type/f32_type.h" #include "src/ast/type/i32_type.h" +#include "src/ast/type/vector_type.h" #include "src/ast/type/void_type.h" #include "src/ast/variable.h" #include "src/ast/variable_decl_statement.h" @@ -216,8 +218,76 @@ fragment frag_main_out frag_main(frag_main_in tint_in [[stage_in]]) { )"); } +TEST_F(MslGeneratorImplTest, Emit_Function_EntryPoint_WithInOut_Builtins) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::VectorType vec4(&f32, 4); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kInput, &vec4)); + + ast::VariableDecorationList decos; + decos.push_back( + std::make_unique(ast::Builtin::kFragCoord)); + coord_var->set_decorations(std::move(decos)); + + auto depth_var = + std::make_unique(std::make_unique( + "depth", ast::StorageClass::kOutput, &f32)); + decos.push_back( + std::make_unique(ast::Builtin::kFragDepth)); + depth_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(coord_var.get()); + td.RegisterVariableForTesting(depth_var.get()); + + mod.AddGlobalVariable(std::move(coord_var)); + mod.AddGlobalVariable(std::move(depth_var)); + + ast::VariableList params; + auto func = std::make_unique("frag_main", std::move(params), + &void_type); + + ast::StatementList body; + body.push_back(std::make_unique( + std::make_unique("depth"), + std::make_unique( + std::make_unique("coord"), + std::make_unique("x")))); + body.push_back(std::make_unique()); + func->set_body(std::move(body)); + + mod.AddFunction(std::move(func)); + + auto ep = std::make_unique(ast::PipelineStage::kFragment, "", + "frag_main"); + mod.AddEntryPoint(std::move(ep)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g; + ASSERT_TRUE(g.Generate(mod)) << g.error(); + EXPECT_EQ(g.result(), R"(#include + +struct frag_main_out { + float depth [[depth(any)]]; +}; + +fragment frag_main_out frag_main(float4 coord [[position]]) { + frag_main_out tint_out = {}; + tint_out.depth = coord.x; + return tint_out; +} + +)"); +} + TEST_F(MslGeneratorImplTest, - Emit_Function_Called_By_EntryPoints_WithGlobals_And_Params) { + Emit_Function_Called_By_EntryPoints_WithLocationGlobals_And_Params) { ast::type::VoidType void_type; ast::type::F32Type f32; @@ -318,6 +388,99 @@ fragment ep_1_out ep_1(ep_1_in tint_in [[stage_in]]) { )"); } +TEST_F(MslGeneratorImplTest, + Emit_Function_Called_By_EntryPoints_WithBuiltinGlobals_And_Params) { + ast::type::VoidType void_type; + ast::type::F32Type f32; + ast::type::VectorType vec4(&f32, 4); + + auto coord_var = + std::make_unique(std::make_unique( + "coord", ast::StorageClass::kInput, &vec4)); + + ast::VariableDecorationList decos; + decos.push_back( + std::make_unique(ast::Builtin::kFragCoord)); + coord_var->set_decorations(std::move(decos)); + + auto depth_var = + std::make_unique(std::make_unique( + "depth", ast::StorageClass::kOutput, &f32)); + decos.push_back( + std::make_unique(ast::Builtin::kFragDepth)); + depth_var->set_decorations(std::move(decos)); + + Context ctx; + ast::Module mod; + TypeDeterminer td(&ctx, &mod); + td.RegisterVariableForTesting(coord_var.get()); + td.RegisterVariableForTesting(depth_var.get()); + + mod.AddGlobalVariable(std::move(coord_var)); + mod.AddGlobalVariable(std::move(depth_var)); + + ast::VariableList params; + params.push_back(std::make_unique( + "param", ast::StorageClass::kFunction, &f32)); + auto sub_func = + std::make_unique("sub_func", std::move(params), &f32); + + ast::StatementList body; + body.push_back(std::make_unique( + std::make_unique("depth"), + std::make_unique( + std::make_unique("coord"), + std::make_unique("x")))); + body.push_back(std::make_unique( + std::make_unique("param"))); + sub_func->set_body(std::move(body)); + + mod.AddFunction(std::move(sub_func)); + + auto func_1 = std::make_unique("frag_1_main", + std::move(params), &void_type); + + ast::ExpressionList expr; + expr.push_back(std::make_unique( + std::make_unique(&f32, 1.0f))); + body.push_back(std::make_unique( + std::make_unique("depth"), + std::make_unique( + std::make_unique("sub_func"), + std::move(expr)))); + body.push_back(std::make_unique()); + func_1->set_body(std::move(body)); + + mod.AddFunction(std::move(func_1)); + + auto ep1 = std::make_unique(ast::PipelineStage::kFragment, + "ep_1", "frag_1_main"); + mod.AddEntryPoint(std::move(ep1)); + + ASSERT_TRUE(td.Determine()) << td.error(); + + GeneratorImpl g; + ASSERT_TRUE(g.Generate(mod)) << g.error(); + EXPECT_EQ(g.result(), R"(#include + +struct ep_1_out { + float depth [[depth(any)]]; +}; + +float sub_func_ep_1(thread ep_1_out& tint_out, thread float4& coord, float param) { + tint_out.depth = coord.x; + return param; +} + +fragment ep_1_out ep_1(float4 coord [[position]]) { + ep_1_out tint_out = {}; + tint_out.depth = sub_func_ep_1(tint_out, coord, 1.00000000f); + return tint_out; +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_Function_Called_Two_EntryPoints_WithGlobals) { ast::type::VoidType void_type; ast::type::F32Type f32; diff --git a/src/writer/msl/generator_impl_test.cc b/src/writer/msl/generator_impl_test.cc index 063fd1ff15..c7c6975b2a 100644 --- a/src/writer/msl/generator_impl_test.cc +++ b/src/writer/msl/generator_impl_test.cc @@ -74,6 +74,40 @@ TEST_F(MslGeneratorImplTest, NameConflictWith_InputStructName) { EXPECT_EQ(g.result(), "func_main_in_0"); } +struct MslBuiltinData { + ast::Builtin builtin; + const char* attribute_name; +}; +inline std::ostream& operator<<(std::ostream& out, MslBuiltinData data) { + out << data.builtin; + return out; +} +using MslBuiltinConversionTest = testing::TestWithParam; +TEST_P(MslBuiltinConversionTest, Emit) { + auto params = GetParam(); + + GeneratorImpl g; + EXPECT_EQ(g.builtin_to_attribute(params.builtin), + std::string(params.attribute_name)); +} +INSTANTIATE_TEST_SUITE_P( + MslGeneratorImplTest, + MslBuiltinConversionTest, + testing::Values(MslBuiltinData{ast::Builtin::kPosition, "position"}, + MslBuiltinData{ast::Builtin::kVertexIdx, "vertex_id"}, + MslBuiltinData{ast::Builtin::kInstanceIdx, "instance_id"}, + MslBuiltinData{ast::Builtin::kFrontFacing, "front_facing"}, + MslBuiltinData{ast::Builtin::kFragCoord, "position"}, + MslBuiltinData{ast::Builtin::kFragDepth, "depth(any)"}, + MslBuiltinData{ast::Builtin::kNumWorkgroups, ""}, + MslBuiltinData{ast::Builtin::kWorkgroupSize, ""}, + MslBuiltinData{ast::Builtin::kLocalInvocationId, + "thread_position_in_threadgroup"}, + MslBuiltinData{ast::Builtin::kLocalInvocationIdx, + "thread_index_in_threadgroup"}, + MslBuiltinData{ast::Builtin::kGlobalInvocationId, + "thread_position_in_grid"})); + } // namespace } // namespace msl } // namespace writer