From cf79a16fefa9169063d40c64b017e13c669708a4 Mon Sep 17 00:00:00 2001 From: James Price Date: Mon, 15 Mar 2021 16:39:21 +0000 Subject: [PATCH] [hlsl-writer] Handle non-struct entry point parameters Add a sanitizing transform to collect input parameters into a struct. HLSL does not allow non-struct entry-point parameters, so any location- or builtin-decorated inputs have to be provided via a struct instead. Bug: tint:511 Change-Id: I3784bcad3bfda757ebcf0efc98c499cfce639b5e Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/44420 Commit-Queue: James Price Auto-Submit: James Price Reviewed-by: Ben Clayton --- src/transform/hlsl.cc | 131 +++++++++++++++ src/transform/hlsl.h | 3 + src/transform/hlsl_test.cc | 99 +++++++++++ src/writer/hlsl/generator_impl.cc | 50 ++++++ .../hlsl/generator_impl_function_test.cc | 154 ++++++++++-------- 5 files changed, 367 insertions(+), 70 deletions(-) diff --git a/src/transform/hlsl.cc b/src/transform/hlsl.cc index f634b72f62..6da5a2a5da 100644 --- a/src/transform/hlsl.cc +++ b/src/transform/hlsl.cc @@ -15,11 +15,13 @@ #include "src/transform/hlsl.h" #include +#include #include "src/ast/variable_decl_statement.h" #include "src/program_builder.h" #include "src/semantic/expression.h" #include "src/semantic/statement.h" +#include "src/semantic/variable.h" namespace tint { namespace transform { @@ -31,6 +33,7 @@ Transform::Output Hlsl::Run(const Program* in) { ProgramBuilder out; CloneContext ctx(&out, in); PromoteArrayInitializerToConstVar(ctx); + HandleEntryPointIOTypes(ctx); ctx.Clone(); return Output{Program(std::move(out))}; } @@ -103,5 +106,133 @@ void Hlsl::PromoteArrayInitializerToConstVar(CloneContext& ctx) const { } } +void Hlsl::HandleEntryPointIOTypes(CloneContext& ctx) const { + // Collect entry point parameters into a struct. + // Insert function-scope const declarations to replace those parameters. + // + // Before: + // ``` + // [[stage(fragment)]] + // fn frag_main([[builtin(frag_coord)]] coord : vec4, + // [[location(1)]] loc1 : f32, + // [[location(2)]] loc2 : vec4) -> void { + // var col : f32 = (coord.x * loc1); + // } + // ``` + // + // After: + // ``` + // struct frag_main_in { + // [[builtin(frag_coord)]] coord : vec4; + // [[location(1)]] loc1 : f32; + // [[location(2)]] loc2 : vec4 + // }; + + // [[stage(fragment)]] + // fn frag_main(in : frag_main_in) -> void { + // const coord : vec4 = in.coord; + // const loc1 : f32 = in.loc1; + // const loc2 : vec4 = in.loc2; + // var col : f32 = (coord.x * loc1); + // } + // ``` + + for (auto* func : ctx.src->AST().Functions()) { + if (!func->IsEntryPoint()) { + continue; + } + + // Build a new structure to hold the non-struct input parameters. + ast::StructMemberList struct_members; + for (auto* param : func->params()) { + if (param->type()->Is()) { + // Already a struct, nothing to do. + continue; + } + + if (param->decorations().size() != 1) { + TINT_ICE(ctx.dst->Diagnostics()) << "Unsupported entry point parameter"; + } + + auto name = ctx.src->Symbols().NameFor(param->symbol()); + + auto* deco = param->decorations()[0]; + if (auto* builtin = deco->As()) { + // Create a struct member with the builtin decoration. + struct_members.push_back( + ctx.dst->Member(name, ctx.Clone(param->type()), + ast::DecorationList{ctx.Clone(builtin)})); + } else if (auto* loc = deco->As()) { + // Create a struct member with the location decoration. + struct_members.push_back( + ctx.dst->Member(name, ctx.Clone(param->type()), + ast::DecorationList{ctx.Clone(loc)})); + } else { + TINT_ICE(ctx.dst->Diagnostics()) + << "Unsupported entry point parameter decoration"; + } + } + + if (struct_members.empty()) { + // Nothing to do. + continue; + } + + ast::VariableList new_parameters; + ast::StatementList new_body; + + // Create a struct type to hold all of the non-struct input parameters. + auto* in_struct = ctx.dst->create( + ctx.dst->Symbols().New(), + ctx.dst->create(struct_members, ast::DecorationList{})); + ctx.dst->AST().AddConstructedType(in_struct); + + // Create a new function parameter using this struct type. + auto struct_param_symbol = ctx.dst->Symbols().New(); + auto* struct_param = + ctx.dst->Var(struct_param_symbol, in_struct, ast::StorageClass::kNone); + new_parameters.push_back(struct_param); + + // Replace the original parameters with function-scope constants. + for (auto* param : func->params()) { + if (param->type()->Is()) { + // Keep struct parameters unchanged. + new_parameters.push_back(ctx.Clone(param)); + continue; + } + + auto name = ctx.src->Symbols().NameFor(param->symbol()); + + // Create a function-scope const to replace the parameter. + // Initialize it with the value extracted from the struct parameter. + auto func_const_symbol = ctx.dst->Symbols().Register(name); + auto* func_const = + ctx.dst->Const(func_const_symbol, ctx.Clone(param->type()), + ctx.dst->MemberAccessor(struct_param_symbol, name)); + + new_body.push_back(ctx.dst->WrapInStatement(func_const)); + + // Replace all uses of the function parameter with the function const. + for (auto* user : ctx.src->Sem().Get(param)->Users()) { + ctx.Replace(user->Declaration(), + ctx.dst->Expr(func_const_symbol)); + } + } + + // Copy over the rest of the function body unchanged. + for (auto* stmt : func->body()->list()) { + new_body.push_back(ctx.Clone(stmt)); + } + + // Rewrite the function header with the new parameters. + auto* new_func = ctx.dst->create( + func->source(), ctx.Clone(func->symbol()), new_parameters, + ctx.Clone(func->return_type()), + ctx.dst->create(new_body), + ctx.Clone(func->decorations())); + ctx.Replace(func, new_func); + } +} + } // namespace transform } // namespace tint diff --git a/src/transform/hlsl.h b/src/transform/hlsl.h index a313c93f67..7213404ad9 100644 --- a/src/transform/hlsl.h +++ b/src/transform/hlsl.h @@ -43,6 +43,9 @@ class Hlsl : public Transform { /// the array usage statement. /// See crbug.com/tint/406 for more details void PromoteArrayInitializerToConstVar(CloneContext& ctx) const; + + /// Hoist entry point parameters out to struct members. + void HandleEntryPointIOTypes(CloneContext& ctx) const; }; } // namespace transform diff --git a/src/transform/hlsl_test.cc b/src/transform/hlsl_test.cc index a24f68b2a1..b05fdba0ed 100644 --- a/src/transform/hlsl_test.cc +++ b/src/transform/hlsl_test.cc @@ -143,6 +143,105 @@ fn main() -> void { EXPECT_EQ(expect, str(got)); } +TEST_F(HlslTest, HandleEntryPointIOTypes_Parameters) { + auto* src = R"( +struct FragIn { + [[location(2)]] + loc2 : f32; +}; + +[[stage(fragment)]] +fn frag_main([[builtin(frag_coord)]] coord : vec4, + [[location(1)]] loc1 : f32, + frag_in : FragIn) -> void { + var col : f32 = (coord.x * loc1 + frag_in.loc2); +} +)"; + + auto* expect = R"( +struct tint_symbol_3 { + [[builtin(frag_coord)]] + coord : vec4; + [[location(1)]] + loc1 : f32; +}; + +struct FragIn { + [[location(2)]] + loc2 : f32; +}; + +[[stage(fragment)]] +fn frag_main(tint_symbol_4 : tint_symbol_3, frag_in : FragIn) -> void { + const coord : vec4 = tint_symbol_4.coord; + const loc1 : f32 = tint_symbol_4.loc1; + var col : f32 = ((coord.x * loc1) + frag_in.loc2); +} +)"; + + auto got = Transform(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(HlslTest, HandleEntryPointIOTypes_OnlyStructParameters) { + // Expect no change. + auto* src = R"( +struct FragBuiltins { + [[builtin(frag_coord)]] + coord : vec4; +}; + +struct FragInputs { + [[location(1)]] + loc1 : f32; + [[location(2)]] + loc2 : vec4; +}; + +[[stage(fragment)]] +fn frag_main(builtins : FragBuiltins, inputs : FragInputs) -> void { + var col : f32 = (builtins.coord.x * inputs.loc1); +} +)"; + + auto got = Transform(src); + + EXPECT_EQ(src, str(got)); +} + +TEST_F(HlslTest, HandleEntryPointIOTypes_Parameters_EmptyBody) { + auto* src = R"( +[[stage(fragment)]] +fn frag_main([[builtin(frag_coord)]] coord : vec4, + [[location(1)]] loc1 : f32, + [[location(2)]] loc2 : vec4) -> void { +} +)"; + + auto* expect = R"( +struct tint_symbol_4 { + [[builtin(frag_coord)]] + coord : vec4; + [[location(1)]] + loc1 : f32; + [[location(2)]] + loc2 : vec4; +}; + +[[stage(fragment)]] +fn frag_main(tint_symbol_5 : tint_symbol_4) -> void { + const coord : vec4 = tint_symbol_5.coord; + const loc1 : f32 = tint_symbol_5.loc1; + const loc2 : vec4 = tint_symbol_5.loc2; +} +)"; + + auto got = Transform(src); + + EXPECT_EQ(expect, str(got)); +} + } // namespace } // namespace transform } // namespace tint diff --git a/src/writer/hlsl/generator_impl.cc b/src/writer/hlsl/generator_impl.cc index 47fda58ddf..675f4017c6 100644 --- a/src/writer/hlsl/generator_impl.cc +++ b/src/writer/hlsl/generator_impl.cc @@ -1527,6 +1527,8 @@ bool GeneratorImpl::EmitEntryPointData( auto* func_sem = builder_.Sem().Get(func); auto func_sym = func->symbol(); + // TODO(jrprice): Remove this when we remove support for entry point + // inputs/outputs as module-scope globals. for (auto data : func_sem->ReferencedLocationVariables()) { auto* var = data.first; auto* decl = var->Declaration(); @@ -1539,6 +1541,8 @@ bool GeneratorImpl::EmitEntryPointData( } } + // TODO(jrprice): Remove this when we remove support for entry point + // inputs/outputs as module-scope globals. for (auto data : func_sem->ReferencedBuiltinVariables()) { auto* var = data.first; auto* decl = var->Declaration(); @@ -1633,6 +1637,8 @@ bool GeneratorImpl::EmitEntryPointData( out << std::endl; } + // TODO(jrprice): Remove this when we remove support for entry point inputs as + // module-scope globals. if (!in_variables.empty()) { auto in_struct_name = generate_name(builder_.Symbols().NameFor(func_sym) + "_" + kInStructNameSuffix); @@ -1682,6 +1688,8 @@ bool GeneratorImpl::EmitEntryPointData( out << "};" << std::endl << std::endl; } + // TODO(jrprice): Remove this when we remove support for entry point outputs + // as module-scope globals. if (!outvariables.empty()) { auto outstruct_name = generate_name(builder_.Symbols().NameFor(func_sym) + "_" + kOutStructNameSuffix); @@ -1824,10 +1832,32 @@ bool GeneratorImpl::EmitEntryPointFunction(std::ostream& out, out << " " << namer_.NameFor(builder_.Symbols().NameFor(current_ep_sym_)) << "("; + bool first = true; + // TODO(jrprice): Remove this when we remove support for inputs as globals. auto in_data = ep_sym_to_in_data_.find(current_ep_sym_); if (in_data != ep_sym_to_in_data_.end()) { out << in_data->second.struct_name << " " << in_data->second.var_name; + first = false; } + + // Emit entry point parameters. + for (auto* var : func->params()) { + if (!var->type()->Is()) { + TINT_ICE(diagnostics_) << "Unsupported non-struct entry point parameter"; + } + + if (!first) { + out << ", "; + } + first = false; + + if (!EmitType(out, var->type(), "")) { + return false; + } + + out << " " << builder_.Symbols().NameFor(var->symbol()); + } + out << ") {" << std::endl; increment_indent(); @@ -2552,6 +2582,26 @@ bool GeneratorImpl::EmitStructType(std::ostream& out, if (!mem->type()->Is()) { out << " " << namer_.NameFor(builder_.Symbols().NameFor(mem->symbol())); } + + if (mem->decorations().size() > 0) { + auto* deco = mem->decorations()[0]; + if (auto* location = deco->As()) { + out << " : TEXCOORD" << location->value(); + } else if (auto* builtin = deco->As()) { + auto attr = builtin_to_attribute(builtin->value()); + if (attr.empty()) { + diagnostics_.add_error("unsupported builtin"); + return false; + } + out << " : " << attr; + } else if (auto* offset = deco->As()) { + // Nothing to do, offsets are handled at the point of access. + } else { + diagnostics_.add_error("unsupported struct member decoration"); + return false; + } + } + out << ";" << std::endl; } decrement_indent(); diff --git a/src/writer/hlsl/generator_impl_function_test.cc b/src/writer/hlsl/generator_impl_function_test.cc index e6ad1d5700..ba3e9c8435 100644 --- a/src/writer/hlsl/generator_impl_function_test.cc +++ b/src/writer/hlsl/generator_impl_function_test.cc @@ -109,17 +109,18 @@ TEST_F(HlslGeneratorImplTest_Function, TEST_F(HlslGeneratorImplTest_Function, Emit_Decoration_EntryPoint_NoReturn_InOut) { - Global("foo", ty.f32(), ast::StorageClass::kInput, nullptr, - ast::DecorationList{ - create(0), - }); + auto* foo_in = Var("foo", ty.f32(), ast::StorageClass::kNone, nullptr, + ast::DecorationList{ + create(0), + }); + // TODO(jrprice): Make this the return value when supported. Global("bar", ty.f32(), ast::StorageClass::kOutput, nullptr, ast::DecorationList{ create(1), }); - Func("main", ast::VariableList{}, ty.void_(), + Func("main", ast::VariableList{foo_in}, ty.void_(), ast::StatementList{ create(Expr("bar"), Expr("foo")), /* no explicit return */}, @@ -127,10 +128,10 @@ TEST_F(HlslGeneratorImplTest_Function, create(ast::PipelineStage::kFragment), }); - GeneratorImpl& gen = Build(); + GeneratorImpl& gen = SanitizeAndBuild(); ASSERT_TRUE(gen.Generate(out)) << gen.error(); - EXPECT_EQ(result(), R"(struct main_in { + EXPECT_EQ(result(), R"(struct tint_symbol_2 { float foo : TEXCOORD0; }; @@ -138,9 +139,10 @@ struct main_out { float bar : SV_Target1; }; -main_out main(main_in tint_in) { +main_out main(tint_symbol_2 tint_symbol_3) { main_out tint_out; - tint_out.bar = tint_in.foo; + const float foo = tint_symbol_3.foo; + tint_out.bar = foo; return tint_out; } @@ -149,17 +151,18 @@ main_out main(main_in tint_in) { TEST_F(HlslGeneratorImplTest_Function, Emit_Decoration_EntryPoint_WithInOutVars) { - Global("foo", ty.f32(), ast::StorageClass::kInput, nullptr, - ast::DecorationList{ - create(0), - }); + auto* foo_in = Var("foo", ty.f32(), ast::StorageClass::kNone, nullptr, + ast::DecorationList{ + create(0), + }); + // TODO(jrprice): Make this the return value when supported. Global("bar", ty.f32(), ast::StorageClass::kOutput, nullptr, ast::DecorationList{ create(1), }); - Func("frag_main", ast::VariableList{}, ty.void_(), + Func("frag_main", ast::VariableList{foo_in}, ty.void_(), ast::StatementList{ create(Expr("bar"), Expr("foo")), create(), @@ -168,10 +171,10 @@ TEST_F(HlslGeneratorImplTest_Function, create(ast::PipelineStage::kFragment), }); - GeneratorImpl& gen = Build(); + GeneratorImpl& gen = SanitizeAndBuild(); ASSERT_TRUE(gen.Generate(out)) << gen.error(); - EXPECT_EQ(result(), R"(struct frag_main_in { + EXPECT_EQ(result(), R"(struct tint_symbol_2 { float foo : TEXCOORD0; }; @@ -179,9 +182,10 @@ struct frag_main_out { float bar : SV_Target1; }; -frag_main_out frag_main(frag_main_in tint_in) { +frag_main_out frag_main(tint_symbol_2 tint_symbol_3) { frag_main_out tint_out; - tint_out.bar = tint_in.foo; + const float foo = tint_symbol_3.foo; + tint_out.bar = foo; return tint_out; } @@ -190,17 +194,19 @@ frag_main_out frag_main(frag_main_in tint_in) { TEST_F(HlslGeneratorImplTest_Function, Emit_Decoration_EntryPoint_WithInOut_Builtins) { - Global("coord", ty.vec4(), ast::StorageClass::kInput, nullptr, - ast::DecorationList{ - create(ast::Builtin::kFragCoord), - }); + auto* coord_in = + Var("coord", ty.vec4(), ast::StorageClass::kNone, nullptr, + ast::DecorationList{ + create(ast::Builtin::kFragCoord), + }); + // TODO(jrprice): Make this the return value when supported. Global("depth", ty.f32(), ast::StorageClass::kOutput, nullptr, ast::DecorationList{ create(ast::Builtin::kFragDepth), }); - Func("frag_main", ast::VariableList{}, ty.void_(), + Func("frag_main", ast::VariableList{coord_in}, ty.void_(), ast::StatementList{ create(Expr("depth"), MemberAccessor("coord", "x")), @@ -210,10 +216,10 @@ TEST_F(HlslGeneratorImplTest_Function, create(ast::PipelineStage::kFragment), }); - GeneratorImpl& gen = Build(); + GeneratorImpl& gen = SanitizeAndBuild(); ASSERT_TRUE(gen.Generate(out)) << gen.error(); - EXPECT_EQ(result(), R"(struct frag_main_in { + EXPECT_EQ(result(), R"(struct tint_symbol_2 { float4 coord : SV_Position; }; @@ -221,9 +227,10 @@ struct frag_main_out { float depth : SV_Depth; }; -frag_main_out frag_main(frag_main_in tint_in) { +frag_main_out frag_main(tint_symbol_2 tint_symbol_3) { frag_main_out tint_out; - tint_out.depth = tint_in.coord.x; + const float4 coord = tint_symbol_3.coord; + tint_out.depth = coord.x; return tint_out; } @@ -456,10 +463,10 @@ void frag_main() { TEST_F( HlslGeneratorImplTest_Function, Emit_Decoration_Called_By_EntryPoints_WithLocationGlobals_And_Params) { // NOLINT - Global("foo", ty.f32(), ast::StorageClass::kInput, nullptr, - ast::DecorationList{ - create(0), - }); + auto* foo_in = Var("foo", ty.f32(), ast::StorageClass::kNone, nullptr, + ast::DecorationList{ + create(0), + }); Global("bar", ty.f32(), ast::StorageClass::kOutput, nullptr, ast::DecorationList{ @@ -472,7 +479,8 @@ TEST_F( }); Func("sub_func", - ast::VariableList{Var("param", ty.f32(), ast::StorageClass::kFunction)}, + ast::VariableList{Var("param", ty.f32(), ast::StorageClass::kNone), + Var("foo", ty.f32(), ast::StorageClass::kNone)}, ty.f32(), ast::StatementList{ create(Expr("bar"), Expr("foo")), @@ -481,20 +489,20 @@ TEST_F( }, ast::DecorationList{}); - Func( - "ep_1", ast::VariableList{}, ty.void_(), - ast::StatementList{ - create(Expr("bar"), Call("sub_func", 1.0f)), - create(), - }, - ast::DecorationList{ - create(ast::PipelineStage::kFragment), - }); + Func("ep_1", ast::VariableList{foo_in}, ty.void_(), + ast::StatementList{ + create( + Expr("bar"), Call("sub_func", 1.0f, Expr("foo"))), + create(), + }, + ast::DecorationList{ + create(ast::PipelineStage::kFragment), + }); - GeneratorImpl& gen = Build(); + GeneratorImpl& gen = SanitizeAndBuild(); ASSERT_TRUE(gen.Generate(out)) << gen.error(); - EXPECT_EQ(result(), R"(struct ep_1_in { + EXPECT_EQ(result(), R"(struct tint_symbol_2 { float foo : TEXCOORD0; }; @@ -503,15 +511,16 @@ struct ep_1_out { float val : SV_Target0; }; -float sub_func_ep_1(in ep_1_in tint_in, out ep_1_out tint_out, float param) { - tint_out.bar = tint_in.foo; +float sub_func_ep_1(out ep_1_out tint_out, float param, float foo) { + tint_out.bar = foo; tint_out.val = param; - return tint_in.foo; + return foo; } -ep_1_out ep_1(ep_1_in tint_in) { +ep_1_out ep_1(tint_symbol_2 tint_symbol_3) { ep_1_out tint_out; - tint_out.bar = sub_func_ep_1(tint_in, tint_out, 1.0f); + const float foo = tint_symbol_3.foo; + tint_out.bar = sub_func_ep_1(tint_out, 1.0f, foo); return tint_out; } @@ -566,40 +575,44 @@ ep_1_out ep_1() { TEST_F( HlslGeneratorImplTest_Function, Emit_Decoration_Called_By_EntryPoints_WithBuiltinGlobals_And_Params) { // NOLINT - Global("coord", ty.vec4(), ast::StorageClass::kInput, nullptr, - ast::DecorationList{ - create(ast::Builtin::kFragCoord), - }); + auto* coord_in = + Var("coord", ty.vec4(), ast::StorageClass::kNone, nullptr, + ast::DecorationList{ + create(ast::Builtin::kFragCoord), + }); + // TODO(jrprice): Make this the return value when supported. Global("depth", ty.f32(), ast::StorageClass::kOutput, nullptr, ast::DecorationList{ create(ast::Builtin::kFragDepth), }); - Func("sub_func", - ast::VariableList{Var("param", ty.f32(), ast::StorageClass::kFunction)}, - ty.f32(), - ast::StatementList{ - create(Expr("depth"), - MemberAccessor("coord", "x")), - create(Expr("param")), - }, - ast::DecorationList{}); + Func( + "sub_func", + ast::VariableList{Var("param", ty.f32(), ast::StorageClass::kNone), + Var("coord", ty.vec4(), ast::StorageClass::kNone)}, + ty.f32(), + ast::StatementList{ + create(Expr("depth"), + MemberAccessor("coord", "x")), + create(Expr("param")), + }, + ast::DecorationList{}); - Func("ep_1", ast::VariableList{}, ty.void_(), + Func("ep_1", ast::VariableList{coord_in}, ty.void_(), ast::StatementList{ - create(Expr("depth"), - Call("sub_func", 1.0f)), + create( + Expr("depth"), Call("sub_func", 1.0f, Expr("coord"))), create(), }, ast::DecorationList{ create(ast::PipelineStage::kFragment), }); - GeneratorImpl& gen = Build(); + GeneratorImpl& gen = SanitizeAndBuild(); ASSERT_TRUE(gen.Generate(out)) << gen.error(); - EXPECT_EQ(result(), R"(struct ep_1_in { + EXPECT_EQ(result(), R"(struct tint_symbol_2 { float4 coord : SV_Position; }; @@ -607,14 +620,15 @@ struct ep_1_out { float depth : SV_Depth; }; -float sub_func_ep_1(in ep_1_in tint_in, out ep_1_out tint_out, float param) { - tint_out.depth = tint_in.coord.x; +float sub_func_ep_1(out ep_1_out tint_out, float param, float4 coord) { + tint_out.depth = coord.x; return param; } -ep_1_out ep_1(ep_1_in tint_in) { +ep_1_out ep_1(tint_symbol_2 tint_symbol_3) { ep_1_out tint_out; - tint_out.depth = sub_func_ep_1(tint_in, tint_out, 1.0f); + const float4 coord = tint_symbol_3.coord; + tint_out.depth = sub_func_ep_1(tint_out, 1.0f, coord); return tint_out; }