diff --git a/src/transform/msl.cc b/src/transform/msl.cc index a089fda847..cd0a5a9b23 100644 --- a/src/transform/msl.cc +++ b/src/transform/msl.cc @@ -15,8 +15,10 @@ #include "src/transform/msl.h" #include +#include #include "src/program_builder.h" +#include "src/semantic/variable.h" namespace tint { namespace transform { @@ -266,10 +268,125 @@ Transform::Output Msl::Run(const Program* in) { ProgramBuilder out; CloneContext ctx(&out, in); RenameReservedKeywords(&ctx, kReservedKeywords); + HandleEntryPointIOTypes(ctx); ctx.Clone(); return Output{Program(std::move(out))}; } +void Msl::HandleEntryPointIOTypes(CloneContext& ctx) const { + // Collect location-decorated entry point parameters into a struct. + // Insert function-scope const declarations to replace those parameters. + // + // Before: + // ``` + // [[stage(fragment)]] + // fn frag_main([[builtin(frag_coord)]] coord : vec4, + // [[location(1)]] loc1 : f32, + // [[location(2)]] loc2 : vec4) -> void { + // var col : f32 = (coord.x * loc1); + // } + // ``` + // + // After: + // ``` + // struct frag_main_in { + // [[location(1)]] loc1 : f32; + // [[location(2)]] loc2 : vec4 + // }; + + // [[stage(fragment)]] + // fn frag_main([[builtin(frag_coord)]] coord : vec4, + // in : frag_main_in) -> void { + // const loc1 : f32 = in.loc1; + // const loc2 : vec4 = in.loc2; + // var col : f32 = (coord.x * loc1); + // } + // ``` + + for (auto* func : ctx.src->AST().Functions()) { + if (!func->IsEntryPoint()) { + continue; + } + + std::vector worklist; + ast::StructMemberList struct_members; + ast::VariableList new_parameters; + ast::StatementList new_body; + + // Find location-decorated parameters. + for (auto* param : func->params()) { + // TODO(jrprice): Handle structs (collate members into a single struct). + if (param->decorations().size() != 1) { + TINT_ICE(ctx.dst->Diagnostics()) << "Unsupported entry point parameter"; + } + + auto* deco = param->decorations()[0]; + if (auto* builtin = deco->As()) { + // Keep any builtin-decorated parameters unchanged. + new_parameters.push_back(ctx.Clone(param)); + } else if (auto* loc = deco->As()) { + // Create a struct member with the location decoration. + struct_members.push_back( + ctx.dst->Member(param->symbol().to_str(), ctx.Clone(param->type()), + ast::DecorationList{ctx.Clone(loc)})); + worklist.push_back(param); + } else { + TINT_ICE(ctx.dst->Diagnostics()) + << "Unsupported entry point parameter decoration"; + } + } + + if (worklist.empty()) { + // Nothing to do. + continue; + } + + // Create a struct type to hold all of the user-defined input parameters. + auto* in_struct = ctx.dst->create( + ctx.dst->Symbols().New(), + ctx.dst->create(struct_members, ast::DecorationList{})); + ctx.dst->AST().AddConstructedType(in_struct); + + // Create a new function parameter using this struct type. + auto struct_param_symbol = ctx.dst->Symbols().New(); + auto* struct_param = + ctx.dst->Var(struct_param_symbol, in_struct, ast::StorageClass::kNone); + new_parameters.push_back(struct_param); + + // Replace the original parameters with function-scope constants. + for (auto* param : worklist) { + // Create a function-scope const to replace the parameter. + // Initialize it with the value extracted from the struct parameter. + auto func_const_symbol = ctx.dst->Symbols().New(); + auto* func_const = + ctx.dst->Const(func_const_symbol, ctx.Clone(param->type()), + ctx.dst->MemberAccessor(struct_param_symbol, + param->symbol().to_str())); + + new_body.push_back(ctx.dst->WrapInStatement(func_const)); + + // Replace all uses of the function parameter with the function const. + for (auto* user : ctx.src->Sem().Get(param)->Users()) { + ctx.Replace(user->Declaration(), + ctx.dst->Expr(func_const_symbol)); + } + } + + // Copy over the rest of the function body unchanged. + for (auto* stmt : func->body()->list()) { + new_body.push_back(ctx.Clone(stmt)); + } + + // Rewrite the function header with the new parameters. + auto* new_func = ctx.dst->create( + func->source(), ctx.Clone(func->symbol()), new_parameters, + ctx.Clone(func->return_type()), + ctx.dst->create(new_body), + ctx.Clone(func->decorations())); + ctx.Replace(func, new_func); + } +} + } // namespace transform } // namespace tint diff --git a/src/transform/msl.h b/src/transform/msl.h index d98ae6ca29..7b5e16c0ab 100644 --- a/src/transform/msl.h +++ b/src/transform/msl.h @@ -33,6 +33,9 @@ class Msl : public Transform { /// @param program the source program to transform /// @returns the transformation result Output Run(const Program* program) override; + + /// Hoist location-decorated entry point parameters out to struct members. + void HandleEntryPointIOTypes(CloneContext& ctx) const; }; } // namespace transform diff --git a/src/transform/msl_test.cc b/src/transform/msl_test.cc index c38feb4e43..82d6d95ab5 100644 --- a/src/transform/msl_test.cc +++ b/src/transform/msl_test.cc @@ -326,6 +326,68 @@ INSTANTIATE_TEST_SUITE_P(MslReservedKeywordTest, "vec", "vertex")); +using MslEntryPointIOTest = TransformTest; + +TEST_F(MslEntryPointIOTest, HandleEntryPointIOTypes_Parameters) { + auto* src = R"( +[[stage(fragment)]] +fn frag_main([[builtin(frag_coord)]] coord : vec4, + [[location(1)]] loc1 : f32, + [[location(2)]] loc2 : vec4) -> void { + var col : f32 = (coord.x * loc1); +} +)"; + + auto* expect = R"( +struct tint_symbol_4 { + [[location(1)]] + tint_symbol_2 : f32; + [[location(2)]] + tint_symbol_3 : vec4; +}; + +[[stage(fragment)]] +fn frag_main([[builtin(frag_coord)]] coord : vec4, tint_symbol_5 : tint_symbol_4) -> void { + const tint_symbol_6 : f32 = tint_symbol_5.tint_symbol_2; + const tint_symbol_7 : vec4 = tint_symbol_5.tint_symbol_3; + var col : f32 = (coord.x * tint_symbol_6); +} +)"; + + auto got = Transform(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(MslEntryPointIOTest, HandleEntryPointIOTypes_Parameters_EmptyBody) { + auto* src = R"( +[[stage(fragment)]] +fn frag_main([[builtin(frag_coord)]] coord : vec4, + [[location(1)]] loc1 : f32, + [[location(2)]] loc2 : vec4) -> void { +} +)"; + + auto* expect = R"( +struct tint_symbol_4 { + [[location(1)]] + tint_symbol_2 : f32; + [[location(2)]] + tint_symbol_3 : vec4; +}; + +[[stage(fragment)]] +fn frag_main([[builtin(frag_coord)]] coord : vec4, tint_symbol_5 : tint_symbol_4) -> void { + const tint_symbol_6 : f32 = tint_symbol_5.tint_symbol_2; + const tint_symbol_7 : vec4 = tint_symbol_5.tint_symbol_3; +} +)"; + + auto got = Transform(src); + + EXPECT_EQ(expect, str(got)); +} + } // namespace } // namespace transform } // namespace tint diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index c144b8924b..00f9ffb4eb 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc @@ -1036,6 +1036,8 @@ bool GeneratorImpl::EmitLiteral(ast::Literal* lit) { return true; } +// TODO(jrprice): Remove this when we remove support for entry point params as +// module-scope globals. bool GeneratorImpl::EmitEntryPointData(ast::Function* func) { auto* func_sem = program_->Sem().Get(func); @@ -1454,6 +1456,7 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) { out_ << " " << program_->Symbols().NameFor(func->symbol()) << "("; bool first = true; + // TODO(jrprice): Remove this when we remove support for builtins as globals. auto in_data = ep_sym_to_in_data_.find(current_ep_sym_); if (in_data != ep_sym_to_in_data_.end()) { out_ << in_data->second.struct_name << " " << in_data->second.var_name @@ -1461,6 +1464,46 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) { first = false; } + // Emit entry point parameters. + for (auto* var : func->params()) { + if (!first) { + out_ << ", "; + } + first = false; + + if (!EmitType(var->type(), "")) { + return false; + } + + out_ << " " << program_->Symbols().NameFor(var->symbol()); + + if (var->type()->Is()) { + out_ << " [[stage_in]]"; + } else { + auto& decos = var->decorations(); + bool builtin_found = false; + for (auto* deco : decos) { + auto* builtin = deco->As(); + if (!builtin) { + continue; + } + + builtin_found = true; + + auto attr = builtin_to_attribute(builtin->value()); + if (attr.empty()) { + diagnostics_.add_error("unknown builtin"); + return false; + } + out_ << " [[" << attr << "]]"; + } + if (!builtin_found) { + TINT_ICE(diagnostics_) << "Unsupported entry point parameter"; + } + } + } + + // TODO(jrprice): Remove this when we remove support for builtins as globals. for (auto data : func_sem->ReferencedBuiltinVariables()) { auto* var = data.first; if (var->StorageClass() != ast::StorageClass::kInput) { @@ -2036,6 +2079,8 @@ bool GeneratorImpl::EmitStructType(const type::Struct* str) { uint32_t current_offset = 0; uint32_t pad_count = 0; for (auto* mem : str->impl()->members()) { + std::string attributes; + make_indent(); for (auto* deco : mem->decorations()) { if (auto* o = deco->As()) { @@ -2047,6 +2092,8 @@ bool GeneratorImpl::EmitStructType(const type::Struct* str) { make_indent(); } current_offset = offset; + } else if (auto* loc = deco->As()) { + attributes = " [[user(locn" + std::to_string(loc->value()) + ")]]"; } else { diagnostics_.add_error("unsupported member decoration: " + program_->str(deco)); @@ -2069,6 +2116,9 @@ bool GeneratorImpl::EmitStructType(const type::Struct* str) { if (!mem->type()->Is()) { out_ << " " << program_->Symbols().NameFor(mem->symbol()); } + + out_ << attributes; + out_ << ";" << std::endl; } decrement_indent(); diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc index 18486a76e7..543d0a8f65 100644 --- a/src/writer/msl/generator_impl_function_test.cc +++ b/src/writer/msl/generator_impl_function_test.cc @@ -93,36 +93,38 @@ fragment void main() { } TEST_F(MslGeneratorImplTest, Emit_Decoration_EntryPoint_NoReturn_InOut) { - Global("foo", ty.f32(), ast::StorageClass::kInput, nullptr, - ast::DecorationList{create(0)}); + auto* foo_in = Var("foo", ty.f32(), ast::StorageClass::kNone, nullptr, + ast::DecorationList{create(0)}); + // TODO(jrprice): Make this the return value when supported. Global("bar", ty.f32(), ast::StorageClass::kOutput, nullptr, ast::DecorationList{create(1)}); - Func("main", ast::VariableList{}, ty.void_(), + Func("main", ast::VariableList{foo_in}, ty.void_(), ast::StatementList{ create(Expr("bar"), Expr("foo")), /* no explicit return */}, ast::DecorationList{ create(ast::PipelineStage::kFragment)}); - GeneratorImpl& gen = Build(); + GeneratorImpl& gen = SanitizeAndBuild(); ASSERT_TRUE(gen.Generate()) << gen.error(); EXPECT_EQ(gen.result(), R"(#include using namespace metal; -struct main_in { - float foo [[user(locn0)]]; +struct tint_symbol_2 { + float tint_symbol_1 [[user(locn0)]]; }; -struct main_out { +struct _tint_main_out { float bar [[color(1)]]; }; -fragment main_out main(main_in _tint_in [[stage_in]]) { - main_out _tint_out = {}; - _tint_out.bar = _tint_in.foo; +fragment _tint_main_out _tint_main(tint_symbol_2 tint_symbol_3 [[stage_in]]) { + _tint_main_out _tint_out = {}; + const float tint_symbol_4 = tint_symbol_3.tint_symbol_1; + _tint_out.bar = tint_symbol_4; return _tint_out; } @@ -130,9 +132,10 @@ fragment main_out main(main_in _tint_in [[stage_in]]) { } TEST_F(MslGeneratorImplTest, Emit_Decoration_EntryPoint_WithInOutVars) { - Global("foo", ty.f32(), ast::StorageClass::kInput, nullptr, - ast::DecorationList{create(0)}); + auto* foo_in = Var("foo", ty.f32(), ast::StorageClass::kNone, nullptr, + ast::DecorationList{create(0)}); + // TODO(jrprice): Make this the return value when supported. Global("bar", ty.f32(), ast::StorageClass::kOutput, nullptr, ast::DecorationList{create(1)}); @@ -140,27 +143,28 @@ TEST_F(MslGeneratorImplTest, Emit_Decoration_EntryPoint_WithInOutVars) { create(Expr("bar"), Expr("foo")), create(), }; - Func("frag_main", ast::VariableList{}, ty.void_(), body, + Func("frag_main", ast::VariableList{foo_in}, ty.void_(), body, ast::DecorationList{ create(ast::PipelineStage::kFragment)}); - GeneratorImpl& gen = Build(); + GeneratorImpl& gen = SanitizeAndBuild(); ASSERT_TRUE(gen.Generate()) << gen.error(); EXPECT_EQ(gen.result(), R"(#include using namespace metal; -struct frag_main_in { - float foo [[user(locn0)]]; +struct tint_symbol_2 { + float tint_symbol_1 [[user(locn0)]]; }; struct frag_main_out { float bar [[color(1)]]; }; -fragment frag_main_out frag_main(frag_main_in _tint_in [[stage_in]]) { +fragment frag_main_out frag_main(tint_symbol_2 tint_symbol_3 [[stage_in]]) { frag_main_out _tint_out = {}; - _tint_out.bar = _tint_in.foo; + const float tint_symbol_4 = tint_symbol_3.tint_symbol_1; + _tint_out.bar = tint_symbol_4; return _tint_out; } @@ -168,10 +172,12 @@ fragment frag_main_out frag_main(frag_main_in _tint_in [[stage_in]]) { } TEST_F(MslGeneratorImplTest, Emit_Decoration_EntryPoint_WithInOut_Builtins) { - Global("coord", ty.vec4(), ast::StorageClass::kInput, nullptr, - ast::DecorationList{ - create(ast::Builtin::kFragCoord)}); + auto* coord_in = + Var("coord", ty.vec4(), ast::StorageClass::kNone, nullptr, + ast::DecorationList{ + create(ast::Builtin::kFragCoord)}); + // TODO(jrprice): Make this the return value when supported. Global("depth", ty.f32(), ast::StorageClass::kOutput, nullptr, ast::DecorationList{ create(ast::Builtin::kFragDepth)}); @@ -182,12 +188,12 @@ TEST_F(MslGeneratorImplTest, Emit_Decoration_EntryPoint_WithInOut_Builtins) { create(), }; - Func("frag_main", ast::VariableList{}, ty.void_(), body, + Func("frag_main", ast::VariableList{coord_in}, ty.void_(), body, ast::DecorationList{ create(ast::PipelineStage::kFragment), }); - GeneratorImpl& gen = Build(); + GeneratorImpl& gen = SanitizeAndBuild(); ASSERT_TRUE(gen.Generate()) << gen.error(); EXPECT_EQ(gen.result(), R"(#include @@ -331,8 +337,8 @@ fragment void frag_main(const device Data& coord [[buffer(0)]]) { TEST_F( MslGeneratorImplTest, Emit_Decoration_Called_By_EntryPoints_WithLocationGlobals_And_Params) { // NOLINT - Global("foo", ty.f32(), ast::StorageClass::kInput, nullptr, - ast::DecorationList{create(0)}); + auto* foo_in = Var("foo", ty.f32(), ast::StorageClass::kNone, nullptr, + ast::DecorationList{create(0)}); Global("bar", ty.f32(), ast::StorageClass::kOutput, nullptr, ast::DecorationList{create(1)}); @@ -341,7 +347,8 @@ TEST_F( ast::DecorationList{create(0)}); ast::VariableList params; - params.push_back(Var("param", ty.f32(), ast::StorageClass::kFunction)); + params.push_back(Var("param", ty.f32(), ast::StorageClass::kNone)); + params.push_back(Var("foo", ty.f32(), ast::StorageClass::kNone)); auto body = ast::StatementList{ create(Expr("bar"), Expr("foo")), @@ -351,23 +358,24 @@ TEST_F( Func("sub_func", params, ty.f32(), body, ast::DecorationList{}); body = ast::StatementList{ - create(Expr("bar"), Call("sub_func", 1.0f)), + create(Expr("bar"), + Call("sub_func", 1.0f, Expr("foo"))), create(), }; - Func("ep_1", ast::VariableList{}, ty.void_(), body, + Func("ep_1", ast::VariableList{foo_in}, ty.void_(), body, ast::DecorationList{ create(ast::PipelineStage::kFragment), }); - GeneratorImpl& gen = Build(); + GeneratorImpl& gen = SanitizeAndBuild(); ASSERT_TRUE(gen.Generate()) << gen.error(); EXPECT_EQ(gen.result(), R"(#include using namespace metal; -struct ep_1_in { - float foo [[user(locn0)]]; +struct tint_symbol_2 { + float tint_symbol_1 [[user(locn0)]]; }; struct ep_1_out { @@ -375,15 +383,16 @@ struct ep_1_out { float val [[color(0)]]; }; -float sub_func_ep_1(thread ep_1_in& _tint_in, thread ep_1_out& _tint_out, float param) { - _tint_out.bar = _tint_in.foo; +float sub_func_ep_1(thread ep_1_out& _tint_out, float param, float foo) { + _tint_out.bar = foo; _tint_out.val = param; - return _tint_in.foo; + return foo; } -fragment ep_1_out ep_1(ep_1_in _tint_in [[stage_in]]) { +fragment ep_1_out ep_1(tint_symbol_2 tint_symbol_3 [[stage_in]]) { ep_1_out _tint_out = {}; - _tint_out.bar = sub_func_ep_1(_tint_in, _tint_out, 1.0f); + const float tint_symbol_4 = tint_symbol_3.tint_symbol_1; + _tint_out.bar = sub_func_ep_1(_tint_out, 1.0f, tint_symbol_4); return _tint_out; } @@ -441,16 +450,18 @@ fragment ep_1_out ep_1() { TEST_F( MslGeneratorImplTest, Emit_Decoration_Called_By_EntryPoints_WithBuiltinGlobals_And_Params) { // NOLINT - Global("coord", ty.vec4(), ast::StorageClass::kInput, nullptr, - ast::DecorationList{ - create(ast::Builtin::kFragCoord)}); + auto* coord_in = + Var("coord", ty.vec4(), ast::StorageClass::kNone, nullptr, + ast::DecorationList{ + create(ast::Builtin::kFragCoord)}); Global("depth", ty.f32(), ast::StorageClass::kOutput, nullptr, ast::DecorationList{ create(ast::Builtin::kFragDepth)}); ast::VariableList params; - params.push_back(Var("param", ty.f32(), ast::StorageClass::kFunction)); + params.push_back(Var("param", ty.f32(), ast::StorageClass::kNone)); + params.push_back(Var("coord", ty.vec4(), ast::StorageClass::kNone)); auto body = ast::StatementList{ create(Expr("depth"), @@ -461,16 +472,17 @@ TEST_F( Func("sub_func", params, ty.f32(), body, ast::DecorationList{}); body = ast::StatementList{ - create(Expr("depth"), Call("sub_func", 1.0f)), + create(Expr("depth"), + Call("sub_func", 1.0f, Expr("coord"))), create(), }; - Func("ep_1", ast::VariableList{}, ty.void_(), body, + Func("ep_1", ast::VariableList{coord_in}, ty.void_(), body, ast::DecorationList{ create(ast::PipelineStage::kFragment), }); - GeneratorImpl& gen = Build(); + GeneratorImpl& gen = SanitizeAndBuild(); ASSERT_TRUE(gen.Generate()) << gen.error(); EXPECT_EQ(gen.result(), R"(#include @@ -480,14 +492,14 @@ struct ep_1_out { float depth [[depth(any)]]; }; -float sub_func_ep_1(thread ep_1_out& _tint_out, thread float4& coord, float param) { +float sub_func_ep_1(thread ep_1_out& _tint_out, float param, float4 coord) { _tint_out.depth = coord.x; return param; } fragment ep_1_out ep_1(float4 coord [[position]]) { ep_1_out _tint_out = {}; - _tint_out.depth = sub_func_ep_1(_tint_out, coord, 1.0f); + _tint_out.depth = sub_func_ep_1(_tint_out, 1.0f, coord); return _tint_out; } diff --git a/src/writer/msl/test_helper.h b/src/writer/msl/test_helper.h index 4241036265..7cbc608ca3 100644 --- a/src/writer/msl/test_helper.h +++ b/src/writer/msl/test_helper.h @@ -20,6 +20,7 @@ #include "gtest/gtest.h" #include "src/program_builder.h" +#include "src/transform/msl.h" #include "src/writer/msl/generator_impl.h" namespace tint { @@ -54,6 +55,34 @@ class TestHelperBase : public BASE, public ProgramBuilder { return *gen_; } + /// Builds the program, runs the program through the transform::Msl sanitizer + /// and returns a GeneratorImpl from the sanitized program. + /// @note The generator is only built once. Multiple calls to Build() will + /// return the same GeneratorImpl without rebuilding. + /// @return the built generator + GeneratorImpl& SanitizeAndBuild() { + if (gen_) { + return *gen_; + } + [&]() { + ASSERT_TRUE(IsValid()) << "Builder program is not valid\n" + << diag::Formatter().format(Diagnostics()); + }(); + program = std::make_unique(std::move(*this)); + [&]() { + ASSERT_TRUE(program->IsValid()) + << diag::Formatter().format(program->Diagnostics()); + }(); + auto result = transform::Msl().Run(program.get()); + [&]() { + ASSERT_TRUE(result.program.IsValid()) + << diag::Formatter().format(result.program.Diagnostics()); + }(); + *program = std::move(result.program); + gen_ = std::make_unique(program.get()); + return *gen_; + } + /// The program built with a call to Build() std::unique_ptr program;