diff --git a/include/tint/tint.h b/include/tint/tint.h index b6d0031785..6e72c470e1 100644 --- a/include/tint/tint.h +++ b/include/tint/tint.h @@ -25,6 +25,7 @@ #include "src/reader/reader.h" #include "src/transform/binding_remapper.h" #include "src/transform/bound_array_accessors.h" +#include "src/transform/canonicalize_entry_point_io.h" #include "src/transform/emit_vertex_point_size.h" #include "src/transform/first_index_offset.h" #include "src/transform/manager.h" diff --git a/samples/main.cc b/samples/main.cc index b1db3a7b79..34ba6eef81 100644 --- a/samples/main.cc +++ b/samples/main.cc @@ -707,6 +707,8 @@ int main(int argc, const char** argv) { #endif // TINT_BUILD_SPV_WRITER #if TINT_BUILD_MSL_WRITER case Format::kMsl: + transform_manager.append( + std::make_unique()); transform_manager.append(std::make_unique()); break; #endif // TINT_BUILD_MSL_WRITER diff --git a/src/transform/msl.cc b/src/transform/msl.cc index ffcb6b102e..af4bc448e6 100644 --- a/src/transform/msl.cc +++ b/src/transform/msl.cc @@ -14,13 +14,9 @@ #include "src/transform/msl.h" -#include -#include #include -#include #include "src/program_builder.h" -#include "src/semantic/variable.h" namespace tint { namespace transform { @@ -270,134 +266,9 @@ Transform::Output Msl::Run(const Program* in, const DataMap&) { ProgramBuilder out; CloneContext ctx(&out, in); RenameReservedKeywords(&ctx, kReservedKeywords); - HandleEntryPointIOTypes(ctx); ctx.Clone(); - return Output{Program(std::move(out))}; } -void Msl::HandleEntryPointIOTypes(CloneContext& ctx) const { - // Collect location-decorated entry point parameters into a struct. - // Insert function-scope const declarations to replace those parameters. - // - // Before: - // ``` - // [[stage(fragment)]] - // fn frag_main([[builtin(frag_coord)]] coord : vec4, - // [[location(1)]] loc1 : f32, - // [[location(2)]] loc2 : vec4) -> void { - // var col : f32 = (coord.x * loc1); - // } - // ``` - // - // After: - // ``` - // struct frag_main_in { - // [[location(1)]] loc1 : f32; - // [[location(2)]] loc2 : vec4 - // }; - - // [[stage(fragment)]] - // fn frag_main([[builtin(frag_coord)]] coord : vec4, - // in : frag_main_in) -> void { - // const loc1 : f32 = in.loc1; - // const loc2 : vec4 = in.loc2; - // var col : f32 = (coord.x * loc1); - // } - // ``` - - for (auto* func : ctx.src->AST().Functions()) { - if (!func->IsEntryPoint()) { - continue; - } - - // Find location-decorated parameters and build a struct to hold them. - ast::StructMemberList struct_members; - std::unordered_set builtins; - for (auto* param : func->params()) { - // TODO(jrprice): Handle structs (collate members into a single struct). - if (param->decorations().size() != 1) { - TINT_ICE(ctx.dst->Diagnostics()) << "Unsupported entry point parameter"; - } - - auto* deco = param->decorations()[0]; - if (deco->Is()) { - // Keep any builtin-decorated parameters unchanged. - builtins.insert(param); - continue; - } else if (auto* loc = deco->As()) { - // Create a struct member with the location decoration. - std::string name = ctx.src->Symbols().NameFor(param->symbol()); - auto* type = ctx.Clone(ctx.src->Sem().Get(param)->Type()); - struct_members.push_back( - ctx.dst->Member(name, type, ast::DecorationList{ctx.Clone(loc)})); - } else { - TINT_ICE(ctx.dst->Diagnostics()) - << "Unsupported entry point parameter decoration"; - } - } - - if (struct_members.empty()) { - // Nothing to do. - continue; - } - - ast::VariableList new_parameters; - ast::StatementList new_body; - - // Create a struct type to hold all of the user-defined input parameters. - auto* in_struct = ctx.dst->create( - ctx.dst->Symbols().New(), - ctx.dst->create(struct_members, ast::DecorationList{})); - ctx.InsertBefore(func, in_struct); - - // Create a new function parameter using this struct type. - auto struct_param_symbol = ctx.dst->Symbols().New(); - auto* struct_param = - ctx.dst->Var(struct_param_symbol, in_struct, ast::StorageClass::kNone); - new_parameters.push_back(struct_param); - - // Replace the original parameters with function-scope constants. - for (auto* param : func->params()) { - if (builtins.count(param)) { - // Keep any builtin-decorated parameters unchanged. - new_parameters.push_back(ctx.Clone(param)); - continue; - } - - auto name = ctx.src->Symbols().NameFor(param->symbol()); - - // Create a function-scope const to replace the parameter. - // Initialize it with the value extracted from the struct parameter. - auto func_const_symbol = ctx.dst->Symbols().Register(name); - auto* type = ctx.Clone(ctx.src->Sem().Get(param)->Type()); - auto* constructor = ctx.dst->MemberAccessor(struct_param_symbol, name); - auto* func_const = ctx.dst->Const(func_const_symbol, type, constructor); - - new_body.push_back(ctx.dst->WrapInStatement(func_const)); - - // Replace all uses of the function parameter with the function const. - for (auto* user : ctx.src->Sem().Get(param)->Users()) { - ctx.Replace(user->Declaration(), - ctx.dst->Expr(func_const_symbol)); - } - } - - // Copy over the rest of the function body unchanged. - for (auto* stmt : func->body()->list()) { - new_body.push_back(ctx.Clone(stmt)); - } - - // Rewrite the function header with the new parameters. - auto* new_func = ctx.dst->create( - func->source(), ctx.Clone(func->symbol()), new_parameters, - ctx.Clone(func->return_type()), - ctx.dst->create(new_body), - ctx.Clone(func->decorations()), - ctx.Clone(func->return_type_decorations())); - ctx.Replace(func, new_func); - } -} - } // namespace transform } // namespace tint diff --git a/src/transform/msl.h b/src/transform/msl.h index 13abdca57e..3121b67bca 100644 --- a/src/transform/msl.h +++ b/src/transform/msl.h @@ -34,10 +34,6 @@ class Msl : public Transform { /// @param data optional extra transform-specific input data /// @returns the transformation result Output Run(const Program* program, const DataMap& data = {}) override; - - private: - /// Hoist location-decorated entry point parameters out to struct members. - void HandleEntryPointIOTypes(CloneContext& ctx) const; }; } // namespace transform diff --git a/src/transform/msl_test.cc b/src/transform/msl_test.cc index d1ddac8e94..7bf1931478 100644 --- a/src/transform/msl_test.cc +++ b/src/transform/msl_test.cc @@ -326,109 +326,6 @@ INSTANTIATE_TEST_SUITE_P(MslReservedKeywordTest, "vec", "vertex")); -using MslEntryPointIOTest = TransformTest; - -TEST_F(MslEntryPointIOTest, HandleEntryPointIOTypes_Parameters) { - auto* src = R"( -[[stage(fragment)]] -fn frag_main([[builtin(frag_coord)]] coord : vec4, - [[location(1)]] loc1 : f32, - [[location(2)]] loc2 : vec4) -> void { - var col : f32 = (coord.x * loc1); -} -)"; - - auto* expect = R"( -struct tint_symbol_3 { - [[location(1)]] - loc1 : f32; - [[location(2)]] - loc2 : vec4; -}; - -[[stage(fragment)]] -fn frag_main(tint_symbol_4 : tint_symbol_3, [[builtin(frag_coord)]] coord : vec4) -> void { - const loc1 : f32 = tint_symbol_4.loc1; - const loc2 : vec4 = tint_symbol_4.loc2; - var col : f32 = (coord.x * loc1); -} -)"; - - auto got = Run(src); - - EXPECT_EQ(expect, str(got)); -} - -TEST_F(MslEntryPointIOTest, HandleEntryPointIOTypes_OnlyBuiltinParameters) { - // Expect no change. - auto* src = R"( -[[stage(fragment)]] -fn frag_main([[builtin(frag_coord)]] coord : vec4) -> void { -} -)"; - - auto got = Run(src); - - EXPECT_EQ(src, str(got)); -} - -TEST_F(MslEntryPointIOTest, HandleEntryPointIOTypes_Parameter_TypeAlias) { - auto* src = R"( -type myf32 = f32; - -[[stage(fragment)]] -fn frag_main([[location(1)]] loc1 : myf32) -> void { -} -)"; - - auto* expect = R"( -type myf32 = f32; - -struct tint_symbol_3 { - [[location(1)]] - loc1 : myf32; -}; - -[[stage(fragment)]] -fn frag_main(tint_symbol_4 : tint_symbol_3) -> void { - const loc1 : myf32 = tint_symbol_4.loc1; -} -)"; - - auto got = Run(src); - - EXPECT_EQ(expect, str(got)); -} - -TEST_F(MslEntryPointIOTest, HandleEntryPointIOTypes_Parameters_EmptyBody) { - auto* src = R"( -[[stage(fragment)]] -fn frag_main([[builtin(frag_coord)]] coord : vec4, - [[location(1)]] loc1 : f32, - [[location(2)]] loc2 : vec4) -> void { -} -)"; - - auto* expect = R"( -struct tint_symbol_3 { - [[location(1)]] - loc1 : f32; - [[location(2)]] - loc2 : vec4; -}; - -[[stage(fragment)]] -fn frag_main(tint_symbol_4 : tint_symbol_3, [[builtin(frag_coord)]] coord : vec4) -> void { - const loc1 : f32 = tint_symbol_4.loc1; - const loc2 : vec4 = tint_symbol_4.loc2; -} -)"; - - auto got = Run(src); - - EXPECT_EQ(expect, str(got)); -} - } // namespace } // namespace transform } // namespace tint diff --git a/src/writer/msl/generator_impl.cc b/src/writer/msl/generator_impl.cc index cf7e396b2e..52152df3c9 100644 --- a/src/writer/msl/generator_impl.cc +++ b/src/writer/msl/generator_impl.cc @@ -959,8 +959,8 @@ bool GeneratorImpl::EmitLiteral(ast::Literal* lit) { return true; } -// TODO(jrprice): Remove this when we remove support for entry point params as -// module-scope globals. +// TODO(crbug.com/tint/697): Remove this when we remove support for entry point +// params as module-scope globals. bool GeneratorImpl::EmitEntryPointData(ast::Function* func) { auto* func_sem = program_->Sem().Get(func); @@ -1376,14 +1376,19 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) { auto out_data = ep_sym_to_out_data_.find(current_ep_sym_); bool has_out_data = out_data != ep_sym_to_out_data_.end(); if (has_out_data) { + // TODO(crbug.com/tint/697): Remove this. + if (!func->return_type()->Is()) { + TINT_ICE(diagnostics_) << "Mixing module-scope variables and return " + "types for shader outputs"; + } out_ << out_data->second.struct_name; } else { - out_ << "void"; + out_ << func->return_type()->FriendlyName(program_->Symbols()); } out_ << " " << program_->Symbols().NameFor(func->symbol()) << "("; bool first = true; - // TODO(jrprice): Remove this when we remove support for builtins as globals. + // TODO(crbug.com/tint/697): Remove this. auto in_data = ep_sym_to_in_data_.find(current_ep_sym_); if (in_data != ep_sym_to_in_data_.end()) { out_ << in_data->second.struct_name << " " << in_data->second.var_name @@ -1432,7 +1437,7 @@ bool GeneratorImpl::EmitEntryPointFunction(ast::Function* func) { } } - // TODO(jrprice): Remove this when we remove support for builtins as globals. + // TODO(crbug.com/tint/697): Remove this. for (auto data : func_sem->ReferencedBuiltinVariables()) { auto* var = data.first; if (var->StorageClass() != ast::StorageClass::kInput) { @@ -1740,12 +1745,14 @@ bool GeneratorImpl::EmitReturn(ast::ReturnStatement* stmt) { out_ << "return"; + // TODO(crbug.com/tint/697): Remove this conditional. if (generating_entry_point_) { auto out_data = ep_sym_to_out_data_.find(current_ep_sym_); if (out_data != ep_sym_to_out_data_.end()) { out_ << " " << out_data->second.var_name; } - } else if (stmt->has_value()) { + } + if (stmt->has_value()) { out_ << " "; if (!EmitExpression(stmt->value())) { return false; @@ -2095,8 +2102,35 @@ bool GeneratorImpl::EmitStructType(const type::Struct* str) { // Emit decorations for (auto* deco : mem->decorations()) { - if (auto* loc = deco->As()) { - out_ << " [[user(locn" + std::to_string(loc->value()) + ")]]"; + if (auto* builtin = deco->As()) { + auto attr = builtin_to_attribute(builtin->value()); + if (attr.empty()) { + diagnostics_.add_error("unknown builtin"); + return false; + } + out_ << " [[" << attr << "]]"; + } else if (auto* loc = deco->As()) { + auto& pipeline_stage_uses = + program_->Sem().Get(str)->PipelineStageUses(); + if (pipeline_stage_uses.size() != 1) { + TINT_ICE(diagnostics_) << "invalid entry point IO struct uses"; + } + + if (pipeline_stage_uses.count( + semantic::PipelineStageUsage::kVertexInput)) { + out_ << " [[attribute(" + std::to_string(loc->value()) + ")]]"; + } else if (pipeline_stage_uses.count( + semantic::PipelineStageUsage::kVertexOutput)) { + out_ << " [[user(locn" + std::to_string(loc->value()) + ")]]"; + } else if (pipeline_stage_uses.count( + semantic::PipelineStageUsage::kFragmentInput)) { + out_ << " [[user(locn" + std::to_string(loc->value()) + ")]]"; + } else if (pipeline_stage_uses.count( + semantic::PipelineStageUsage::kFragmentOutput)) { + out_ << " [[color(" + std::to_string(loc->value()) + ")]]"; + } else { + TINT_ICE(diagnostics_) << "invalid use of location decoration"; + } } } diff --git a/src/writer/msl/generator_impl_function_test.cc b/src/writer/msl/generator_impl_function_test.cc index 8c67520efc..265ce0a069 100644 --- a/src/writer/msl/generator_impl_function_test.cc +++ b/src/writer/msl/generator_impl_function_test.cc @@ -92,60 +92,16 @@ fragment void main() { )"); } -TEST_F(MslGeneratorImplTest, Emit_Decoration_EntryPoint_NoReturn_InOut) { - auto* foo_in = Var("foo", ty.f32(), ast::StorageClass::kNone, nullptr, - ast::DecorationList{create(0)}); - - // TODO(jrprice): Make this the return value when supported. - Global("bar", ty.f32(), ast::StorageClass::kOutput, nullptr, - ast::DecorationList{create(1)}); - - Func("main", ast::VariableList{foo_in}, ty.void_(), - ast::StatementList{ - create(Expr("bar"), Expr("foo")), - /* no explicit return */}, - ast::DecorationList{ - create(ast::PipelineStage::kFragment)}); - - GeneratorImpl& gen = SanitizeAndBuild(); - - ASSERT_TRUE(gen.Generate()) << gen.error(); - EXPECT_EQ(gen.result(), R"(#include - -using namespace metal; -struct tint_symbol_2 { - float foo [[user(locn0)]]; -}; - -struct _tint_main_out { - float bar [[color(1)]]; -}; - -fragment _tint_main_out _tint_main(tint_symbol_2 tint_symbol_3 [[stage_in]]) { - _tint_main_out _tint_out = {}; - const float foo = tint_symbol_3.foo; - _tint_out.bar = foo; - return _tint_out; -} - -)"); -} - TEST_F(MslGeneratorImplTest, Emit_Decoration_EntryPoint_WithInOutVars) { - auto* foo_in = Var("foo", ty.f32(), ast::StorageClass::kNone, nullptr, - ast::DecorationList{create(0)}); - - // TODO(jrprice): Make this the return value when supported. - Global("bar", ty.f32(), ast::StorageClass::kOutput, nullptr, - ast::DecorationList{create(1)}); - - auto body = ast::StatementList{ - create(Expr("bar"), Expr("foo")), - create(), - }; - Func("frag_main", ast::VariableList{foo_in}, ty.void_(), body, - ast::DecorationList{ - create(ast::PipelineStage::kFragment)}); + // fn frag_main([[location(0)]] foo : f32) -> [[location(1)]] f32 { + // return foo; + // } + auto* foo_in = + Const("foo", ty.f32(), nullptr, {create(0)}); + Func("frag_main", ast::VariableList{foo_in}, ty.f32(), + {create(Expr("foo"))}, + {create(ast::PipelineStage::kFragment)}, + {create(1)}); GeneratorImpl& gen = SanitizeAndBuild(); @@ -153,45 +109,32 @@ TEST_F(MslGeneratorImplTest, Emit_Decoration_EntryPoint_WithInOutVars) { EXPECT_EQ(gen.result(), R"(#include using namespace metal; -struct tint_symbol_2 { +struct tint_symbol_3 { float foo [[user(locn0)]]; }; - -struct frag_main_out { - float bar [[color(1)]]; +struct tint_symbol_5 { + float value [[color(1)]]; }; -fragment frag_main_out frag_main(tint_symbol_2 tint_symbol_3 [[stage_in]]) { - frag_main_out _tint_out = {}; - const float foo = tint_symbol_3.foo; - _tint_out.bar = foo; - return _tint_out; +fragment tint_symbol_5 frag_main(tint_symbol_3 tint_symbol_1 [[stage_in]]) { + const float foo = tint_symbol_1.foo; + return {foo}; } )"); } TEST_F(MslGeneratorImplTest, Emit_Decoration_EntryPoint_WithInOut_Builtins) { + // fn frag_main([[position(0)]] coord : vec4) -> [[frag_depth]] f32 { + // return coord.x; + // } auto* coord_in = - Var("coord", ty.vec4(), ast::StorageClass::kNone, nullptr, - ast::DecorationList{ - create(ast::Builtin::kFragCoord)}); - - // TODO(jrprice): Make this the return value when supported. - Global("depth", ty.f32(), ast::StorageClass::kOutput, nullptr, - ast::DecorationList{ - create(ast::Builtin::kFragDepth)}); - - auto body = ast::StatementList{ - create(Expr("depth"), - MemberAccessor("coord", "x")), - create(), - }; - - Func("frag_main", ast::VariableList{coord_in}, ty.void_(), body, - ast::DecorationList{ - create(ast::PipelineStage::kFragment), - }); + Const("coord", ty.vec4(), nullptr, + {create(ast::Builtin::kFragCoord)}); + Func("frag_main", ast::VariableList{coord_in}, ty.f32(), + {create(MemberAccessor("coord", "x"))}, + {create(ast::PipelineStage::kFragment)}, + {create(ast::Builtin::kFragDepth)}); GeneratorImpl& gen = SanitizeAndBuild(); @@ -199,50 +142,155 @@ TEST_F(MslGeneratorImplTest, Emit_Decoration_EntryPoint_WithInOut_Builtins) { EXPECT_EQ(gen.result(), R"(#include using namespace metal; -struct frag_main_out { - float depth [[depth(any)]]; +struct tint_symbol_3 { + float4 coord [[position]]; +}; +struct tint_symbol_5 { + float value [[depth(any)]]; }; -fragment frag_main_out frag_main(float4 coord [[position]]) { - frag_main_out _tint_out = {}; - _tint_out.depth = coord.x; - return _tint_out; +fragment tint_symbol_5 frag_main(tint_symbol_3 tint_symbol_1 [[stage_in]]) { + const float4 coord = tint_symbol_1.coord; + return {coord.x}; } )"); } -TEST_F(MslGeneratorImplTest, Emit_Decoration_EntryPoint_With_Uniform) { - Global("coord", ty.vec4(), ast::StorageClass::kUniform, nullptr, - ast::DecorationList{create(0), - create(1)}); +TEST_F(MslGeneratorImplTest, + Emit_Decoration_EntryPoint_SharedStruct_DifferentStages) { + // struct Interface { + // [[location(1)]] col1 : f32; + // [[location(2)]] col2 : f32; + // }; + // fn vert_main() -> Interface { + // return Interface(0.4, 0.6); + // } + // fn frag_main(colors : Interface) -> void { + // const r = colors.col1; + // const g = colors.col2; + // } + auto* interface_struct = Structure( + "Interface", + {Member("col1", ty.f32(), {create(1)}), + Member("col2", ty.f32(), {create(2)})}); - auto* var = Var("v", ty.f32(), ast::StorageClass::kFunction, - MemberAccessor("coord", "x")); + Func("vert_main", {}, interface_struct, + {create( + Construct(interface_struct, Expr(0.5f), Expr(0.25f)))}, + {create(ast::PipelineStage::kVertex)}); - Func("frag_main", ast::VariableList{}, ty.void_(), - ast::StatementList{ - create(var), - create(), + Func("frag_main", {Const("colors", interface_struct)}, ty.void_(), + { + WrapInStatement( + Const("r", ty.f32(), MemberAccessor(Expr("colors"), "col1"))), + WrapInStatement( + Const("g", ty.f32(), MemberAccessor(Expr("colors"), "col2"))), }, - ast::DecorationList{ - create(ast::PipelineStage::kFragment), - }); + {create(ast::PipelineStage::kFragment)}); - GeneratorImpl& gen = Build(); + GeneratorImpl& gen = SanitizeAndBuild(); ASSERT_TRUE(gen.Generate()) << gen.error(); EXPECT_EQ(gen.result(), R"(#include using namespace metal; -fragment void frag_main(constant float4& coord [[buffer(0)]]) { - float v = coord.x; +struct Interface { + float col1; + float col2; +}; +struct tint_symbol_4 { + float col1 [[user(locn1)]]; + float col2 [[user(locn2)]]; +}; +struct tint_symbol_9 { + float col1 [[user(locn1)]]; + float col2 [[user(locn2)]]; +}; + +vertex tint_symbol_4 vert_main() { + const Interface tint_symbol_5 = {0.5f, 0.25f}; + return {tint_symbol_5.col1, tint_symbol_5.col2}; +} + +fragment void frag_main(tint_symbol_9 tint_symbol_7 [[stage_in]]) { + const Interface colors = {tint_symbol_7.col1, tint_symbol_7.col2}; + const float r = colors.col1; + const float g = colors.col2; return; } )"); } +TEST_F(MslGeneratorImplTest, + Emit_Decoration_EntryPoint_SharedStruct_HelperFunction) { + // struct VertexOutput { + // [[builtin(position)]] pos : vec4; + // }; + // fn foo(x : f32) -> VertexOutput { + // return VertexOutput(vec4(x, x, x, 1.0)); + // } + // fn vert_main1() -> VertexOutput { + // return foo(0.5); + // } + // fn vert_main2() -> VertexOutput { + // return foo(0.25); + // } + auto* vertex_output_struct = Structure( + "VertexOutput", + {Member("pos", ty.vec4(), + {create(ast::Builtin::kPosition)})}); + + Func("foo", {Const("x", ty.f32())}, vertex_output_struct, + {create(Construct( + vertex_output_struct, Construct(ty.vec4(), Expr("x"), Expr("x"), + Expr("x"), Expr(1.f))))}, + {}); + + Func("vert_main1", {}, vertex_output_struct, + {create( + Construct(vertex_output_struct, Expr(Call("foo", Expr(0.5f)))))}, + {create(ast::PipelineStage::kVertex)}); + + Func("vert_main2", {}, vertex_output_struct, + {create( + Construct(vertex_output_struct, Expr(Call("foo", Expr(0.25f)))))}, + {create(ast::PipelineStage::kVertex)}); + + GeneratorImpl& gen = SanitizeAndBuild(); + + ASSERT_TRUE(gen.Generate()) << gen.error(); + EXPECT_EQ(gen.result(), R"(#include + +using namespace metal; +struct VertexOutput { + float4 pos; +}; +struct tint_symbol_3 { + float4 pos [[position]]; +}; +struct tint_symbol_7 { + float4 pos [[position]]; +}; + +VertexOutput foo(float x) { + return {float4(x, x, x, 1.0f)}; +} + +vertex tint_symbol_3 vert_main1() { + const VertexOutput tint_symbol_5 = {foo(0.5f)}; + return {tint_symbol_5.pos}; +} + +vertex tint_symbol_7 vert_main2() { + const VertexOutput tint_symbol_8 = {foo(0.25f)}; + return {tint_symbol_8.pos}; +} + +)"); +} + TEST_F(MslGeneratorImplTest, Emit_FunctionDecoration_EntryPoint_With_RW_StorageBuffer) { auto* s = Structure("Data", { @@ -331,11 +379,12 @@ fragment void frag_main(const device Data& coord [[buffer(0)]]) { )"); } +// TODO(crbug.com/tint/697): Remove this test TEST_F( MslGeneratorImplTest, Emit_Decoration_Called_By_EntryPoints_WithLocationGlobals_And_Params) { // NOLINT - auto* foo_in = Var("foo", ty.f32(), ast::StorageClass::kNone, nullptr, - ast::DecorationList{create(0)}); + Global("foo", ty.f32(), ast::StorageClass::kInput, nullptr, + ast::DecorationList{create(0)}); Global("bar", ty.f32(), ast::StorageClass::kOutput, nullptr, ast::DecorationList{create(1)}); @@ -345,7 +394,6 @@ TEST_F( ast::VariableList params; params.push_back(Var("param", ty.f32(), ast::StorageClass::kNone)); - params.push_back(Var("foo", ty.f32(), ast::StorageClass::kNone)); auto body = ast::StatementList{ create(Expr("bar"), Expr("foo")), @@ -355,23 +403,22 @@ TEST_F( Func("sub_func", params, ty.f32(), body, ast::DecorationList{}); body = ast::StatementList{ - create(Expr("bar"), - Call("sub_func", 1.0f, Expr("foo"))), + create(Expr("bar"), Call("sub_func", 1.0f)), create(), }; - Func("ep_1", ast::VariableList{foo_in}, ty.void_(), body, + Func("ep_1", ast::VariableList{}, ty.void_(), body, ast::DecorationList{ create(ast::PipelineStage::kFragment), }); - GeneratorImpl& gen = SanitizeAndBuild(); + GeneratorImpl& gen = Build(); ASSERT_TRUE(gen.Generate()) << gen.error(); EXPECT_EQ(gen.result(), R"(#include using namespace metal; -struct tint_symbol_2 { +struct ep_1_in { float foo [[user(locn0)]]; }; @@ -380,22 +427,22 @@ struct ep_1_out { float val [[color(0)]]; }; -float sub_func_ep_1(thread ep_1_out& _tint_out, float param, float foo) { - _tint_out.bar = foo; +float sub_func_ep_1(thread ep_1_in& _tint_in, thread ep_1_out& _tint_out, float param) { + _tint_out.bar = _tint_in.foo; _tint_out.val = param; - return foo; + return _tint_in.foo; } -fragment ep_1_out ep_1(tint_symbol_2 tint_symbol_3 [[stage_in]]) { +fragment ep_1_out ep_1(ep_1_in _tint_in [[stage_in]]) { ep_1_out _tint_out = {}; - const float foo = tint_symbol_3.foo; - _tint_out.bar = sub_func_ep_1(_tint_out, 1.0f, foo); + _tint_out.bar = sub_func_ep_1(_tint_in, _tint_out, 1.0f); return _tint_out; } )"); } +// TODO(crbug.com/tint/697): Remove this test TEST_F(MslGeneratorImplTest, Emit_Decoration_Called_By_EntryPoints_NoUsedGlobals) { Global("depth", ty.f32(), ast::StorageClass::kOutput, nullptr, @@ -444,13 +491,13 @@ fragment ep_1_out ep_1() { )"); } +// TODO(crbug.com/tint/697): Remove this test TEST_F( MslGeneratorImplTest, Emit_Decoration_Called_By_EntryPoints_WithBuiltinGlobals_And_Params) { // NOLINT - auto* coord_in = - Var("coord", ty.vec4(), ast::StorageClass::kNone, nullptr, - ast::DecorationList{ - create(ast::Builtin::kFragCoord)}); + Global("coord", ty.vec4(), ast::StorageClass::kInput, nullptr, + ast::DecorationList{ + create(ast::Builtin::kFragCoord)}); Global("depth", ty.f32(), ast::StorageClass::kOutput, nullptr, ast::DecorationList{ @@ -458,7 +505,6 @@ TEST_F( ast::VariableList params; params.push_back(Var("param", ty.f32(), ast::StorageClass::kNone)); - params.push_back(Var("coord", ty.vec4(), ast::StorageClass::kNone)); auto body = ast::StatementList{ create(Expr("depth"), @@ -469,17 +515,16 @@ TEST_F( Func("sub_func", params, ty.f32(), body, ast::DecorationList{}); body = ast::StatementList{ - create(Expr("depth"), - Call("sub_func", 1.0f, Expr("coord"))), + create(Expr("depth"), Call("sub_func", 1.0f)), create(), }; - Func("ep_1", ast::VariableList{coord_in}, ty.void_(), body, + Func("ep_1", ast::VariableList{}, ty.void_(), body, ast::DecorationList{ create(ast::PipelineStage::kFragment), }); - GeneratorImpl& gen = SanitizeAndBuild(); + GeneratorImpl& gen = Build(); ASSERT_TRUE(gen.Generate()) << gen.error(); EXPECT_EQ(gen.result(), R"(#include @@ -489,14 +534,14 @@ struct ep_1_out { float depth [[depth(any)]]; }; -float sub_func_ep_1(thread ep_1_out& _tint_out, float param, float4 coord) { +float sub_func_ep_1(thread ep_1_out& _tint_out, thread float4& coord, float param) { _tint_out.depth = coord.x; return param; } fragment ep_1_out ep_1(float4 coord [[position]]) { ep_1_out _tint_out = {}; - _tint_out.depth = sub_func_ep_1(_tint_out, 1.0f, coord); + _tint_out.depth = sub_func_ep_1(_tint_out, coord, 1.0f); return _tint_out; } @@ -666,6 +711,7 @@ fragment void frag_main(const device Data& coord [[buffer(0)]]) { )"); } +// TODO(crbug.com/tint/697): Remove this test TEST_F(MslGeneratorImplTest, Emit_Decoration_EntryPoints_WithGlobal_Nested_Return) { Global("bar", ty.f32(), ast::StorageClass::kOutput, nullptr, diff --git a/src/writer/msl/test_helper.h b/src/writer/msl/test_helper.h index 7cbc608ca3..a40d2c3bcf 100644 --- a/src/writer/msl/test_helper.h +++ b/src/writer/msl/test_helper.h @@ -20,6 +20,8 @@ #include "gtest/gtest.h" #include "src/program_builder.h" +#include "src/transform/canonicalize_entry_point_io.h" +#include "src/transform/manager.h" #include "src/transform/msl.h" #include "src/writer/msl/generator_impl.h" @@ -73,7 +75,12 @@ class TestHelperBase : public BASE, public ProgramBuilder { ASSERT_TRUE(program->IsValid()) << diag::Formatter().format(program->Diagnostics()); }(); - auto result = transform::Msl().Run(program.get()); + + tint::transform::Manager transform_manager; + transform_manager.append( + std::make_unique()); + transform_manager.append(std::make_unique()); + auto result = transform_manager.Run(program.get()); [&]() { ASSERT_TRUE(result.program.IsValid()) << diag::Formatter().format(result.program.Diagnostics());