diff --git a/src/transform/msl.cc b/src/transform/msl.cc index cd0a5a9b23..28cae988dc 100644 --- a/src/transform/msl.cc +++ b/src/transform/msl.cc @@ -14,6 +14,7 @@ #include "src/transform/msl.h" +#include #include #include @@ -309,12 +310,9 @@ void Msl::HandleEntryPointIOTypes(CloneContext& ctx) const { continue; } - std::vector worklist; + // Find location-decorated parameters and build a struct to hold them. ast::StructMemberList struct_members; - ast::VariableList new_parameters; - ast::StatementList new_body; - - // Find location-decorated parameters. + std::unordered_set builtins; for (auto* param : func->params()) { // TODO(jrprice): Handle structs (collate members into a single struct). if (param->decorations().size() != 1) { @@ -324,24 +322,27 @@ void Msl::HandleEntryPointIOTypes(CloneContext& ctx) const { auto* deco = param->decorations()[0]; if (auto* builtin = deco->As()) { // Keep any builtin-decorated parameters unchanged. - new_parameters.push_back(ctx.Clone(param)); + builtins.insert(param); + continue; } else if (auto* loc = deco->As()) { // Create a struct member with the location decoration. - struct_members.push_back( - ctx.dst->Member(param->symbol().to_str(), ctx.Clone(param->type()), - ast::DecorationList{ctx.Clone(loc)})); - worklist.push_back(param); + struct_members.push_back(ctx.dst->Member( + ctx.src->Symbols().NameFor(param->symbol()), + ctx.Clone(param->type()), ast::DecorationList{ctx.Clone(loc)})); } else { TINT_ICE(ctx.dst->Diagnostics()) << "Unsupported entry point parameter decoration"; } } - if (worklist.empty()) { + if (struct_members.empty()) { // Nothing to do. continue; } + ast::VariableList new_parameters; + ast::StatementList new_body; + // Create a struct type to hold all of the user-defined input parameters. auto* in_struct = ctx.dst->create( ctx.dst->Symbols().New(), @@ -355,14 +356,21 @@ void Msl::HandleEntryPointIOTypes(CloneContext& ctx) const { new_parameters.push_back(struct_param); // Replace the original parameters with function-scope constants. - for (auto* param : worklist) { + for (auto* param : func->params()) { + if (builtins.count(param)) { + // Keep any builtin-decorated parameters unchanged. + new_parameters.push_back(ctx.Clone(param)); + continue; + } + + auto name = ctx.src->Symbols().NameFor(param->symbol()); + // Create a function-scope const to replace the parameter. // Initialize it with the value extracted from the struct parameter. - auto func_const_symbol = ctx.dst->Symbols().New(); + auto func_const_symbol = ctx.dst->Symbols().Register(name); auto* func_const = ctx.dst->Const(func_const_symbol, ctx.Clone(param->type()), - ctx.dst->MemberAccessor(struct_param_symbol, - param->symbol().to_str())); + ctx.dst->MemberAccessor(struct_param_symbol, name)); new_body.push_back(ctx.dst->WrapInStatement(func_const)); diff --git a/src/transform/msl_test.cc b/src/transform/msl_test.cc index 82d6d95ab5..915d591aac 100644 --- a/src/transform/msl_test.cc +++ b/src/transform/msl_test.cc @@ -339,18 +339,18 @@ fn frag_main([[builtin(frag_coord)]] coord : vec4, )"; auto* expect = R"( -struct tint_symbol_4 { +struct tint_symbol_3 { [[location(1)]] - tint_symbol_2 : f32; + loc1 : f32; [[location(2)]] - tint_symbol_3 : vec4; + loc2 : vec4; }; [[stage(fragment)]] -fn frag_main([[builtin(frag_coord)]] coord : vec4, tint_symbol_5 : tint_symbol_4) -> void { - const tint_symbol_6 : f32 = tint_symbol_5.tint_symbol_2; - const tint_symbol_7 : vec4 = tint_symbol_5.tint_symbol_3; - var col : f32 = (coord.x * tint_symbol_6); +fn frag_main(tint_symbol_4 : tint_symbol_3, [[builtin(frag_coord)]] coord : vec4) -> void { + const loc1 : f32 = tint_symbol_4.loc1; + const loc2 : vec4 = tint_symbol_4.loc2; + var col : f32 = (coord.x * loc1); } )"; @@ -359,6 +359,19 @@ fn frag_main([[builtin(frag_coord)]] coord : vec4, tint_symbol_5 : tint_sym EXPECT_EQ(expect, str(got)); } +TEST_F(MslEntryPointIOTest, HandleEntryPointIOTypes_OnlyBuiltinParameters) { + // Expect no change. + auto* src = R"( +[[stage(fragment)]] +fn frag_main([[builtin(frag_coord)]] coord : vec4) -> void { +} +)"; + + auto got = Transform(src); + + EXPECT_EQ(src, str(got)); +} + TEST_F(MslEntryPointIOTest, HandleEntryPointIOTypes_Parameters_EmptyBody) { auto* src = R"( [[stage(fragment)]] @@ -369,17 +382,17 @@ fn frag_main([[builtin(frag_coord)]] coord : vec4, )"; auto* expect = R"( -struct tint_symbol_4 { +struct tint_symbol_3 { [[location(1)]] - tint_symbol_2 : f32; + loc1 : f32; [[location(2)]] - tint_symbol_3 : vec4; + loc2 : vec4; }; [[stage(fragment)]] -fn frag_main([[builtin(frag_coord)]] coord : vec4, tint_symbol_5 : tint_symbol_4) -> void { - const tint_symbol_6 : f32 = tint_symbol_5.tint_symbol_2; - const tint_symbol_7 : vec4 = tint_symbol_5.tint_symbol_3; +fn frag_main(tint_symbol_4 : tint_symbol_3, [[builtin(frag_coord)]] coord : vec4) -> void { + const loc1 : f32 = tint_symbol_4.loc1; + const loc2 : vec4 = tint_symbol_4.loc2; } )"; diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc index 8f4de92bf5..fc8c55d650 100644 --- a/src/writer/msl/generator_impl_function_test.cc +++ b/src/writer/msl/generator_impl_function_test.cc @@ -114,7 +114,7 @@ TEST_F(MslGeneratorImplTest, Emit_Decoration_EntryPoint_NoReturn_InOut) { using namespace metal; struct tint_symbol_2 { - float tint_symbol_1 [[user(locn0)]]; + float foo [[user(locn0)]]; }; struct _tint_main_out { @@ -123,8 +123,8 @@ struct _tint_main_out { fragment _tint_main_out _tint_main(tint_symbol_2 tint_symbol_3 [[stage_in]]) { _tint_main_out _tint_out = {}; - const float tint_symbol_4 = tint_symbol_3.tint_symbol_1; - _tint_out.bar = tint_symbol_4; + const float foo = tint_symbol_3.foo; + _tint_out.bar = foo; return _tint_out; } @@ -154,7 +154,7 @@ TEST_F(MslGeneratorImplTest, Emit_Decoration_EntryPoint_WithInOutVars) { using namespace metal; struct tint_symbol_2 { - float tint_symbol_1 [[user(locn0)]]; + float foo [[user(locn0)]]; }; struct frag_main_out { @@ -163,8 +163,8 @@ struct frag_main_out { fragment frag_main_out frag_main(tint_symbol_2 tint_symbol_3 [[stage_in]]) { frag_main_out _tint_out = {}; - const float tint_symbol_4 = tint_symbol_3.tint_symbol_1; - _tint_out.bar = tint_symbol_4; + const float foo = tint_symbol_3.foo; + _tint_out.bar = foo; return _tint_out; } @@ -372,7 +372,7 @@ TEST_F( using namespace metal; struct tint_symbol_2 { - float tint_symbol_1 [[user(locn0)]]; + float foo [[user(locn0)]]; }; struct ep_1_out { @@ -388,8 +388,8 @@ float sub_func_ep_1(thread ep_1_out& _tint_out, float param, float foo) { fragment ep_1_out ep_1(tint_symbol_2 tint_symbol_3 [[stage_in]]) { ep_1_out _tint_out = {}; - const float tint_symbol_4 = tint_symbol_3.tint_symbol_1; - _tint_out.bar = sub_func_ep_1(_tint_out, 1.0f, tint_symbol_4); + const float foo = tint_symbol_3.foo; + _tint_out.bar = sub_func_ep_1(_tint_out, 1.0f, foo); return _tint_out; }