From 05e16ed1c5ba4437cf9c7a4da36e3fb1983463e2 Mon Sep 17 00:00:00 2001 From: James Price Date: Mon, 26 Apr 2021 19:37:46 +0000 Subject: [PATCH] transform/EmitVertexPointSize: Handle entry point parameters Generate a new struct that contains members of the original return type with the point size appended to it, and replace return statements as necessary. Fixed: tint:732 Change-Id: I2b5816144d5e95c65baca95dc0c50b4dfdd25ed3 Reviewed-on: https://dawn-review.googlesource.com/c/tint/+/48980 Auto-Submit: James Price Commit-Queue: James Price Reviewed-by: Ben Clayton --- src/program_builder.h | 3 + src/transform/emit_vertex_point_size.cc | 100 +++++-- src/transform/emit_vertex_point_size_test.cc | 292 ++++++++++++++++++- 3 files changed, 358 insertions(+), 37 deletions(-) diff --git a/src/program_builder.h b/src/program_builder.h index 1b5da81f9c..9e90d5b6ad 100644 --- a/src/program_builder.h +++ b/src/program_builder.h @@ -700,6 +700,9 @@ class ProgramBuilder { // AST helper methods ////////////////////////////////////////////////////////////////////////////// + /// @return a new unnamed symbol + Symbol Sym() { return Symbols().New(); } + /// @param name the symbol string /// @return a Symbol with the given name Symbol Sym(const std::string& name) { return Symbols().Register(name); } diff --git a/src/transform/emit_vertex_point_size.cc b/src/transform/emit_vertex_point_size.cc index 9854149be8..13f76a068e 100644 --- a/src/transform/emit_vertex_point_size.cc +++ b/src/transform/emit_vertex_point_size.cc @@ -14,10 +14,13 @@ #include "src/transform/emit_vertex_point_size.h" +#include #include -#include "src/ast/assignment_statement.h" #include "src/program_builder.h" +#include "src/sem/function.h" +#include "src/sem/statement.h" +#include "src/utils/get_or_create.h" namespace tint { namespace transform { @@ -26,34 +29,85 @@ EmitVertexPointSize::EmitVertexPointSize() = default; EmitVertexPointSize::~EmitVertexPointSize() = default; Output EmitVertexPointSize::Run(const Program* in, const DataMap&) { - if (!in->AST().Functions().HasStage(ast::PipelineStage::kVertex)) { - // If the module doesn't have any vertex stages, then there's nothing to do. - return Output(Program(in->Clone())); - } - ProgramBuilder out; - CloneContext ctx(&out, in); - Symbol pointsize = out.Symbols().New("tint_pointsize"); - - // Declare the pointsize builtin output variable. - out.Global(pointsize, out.ty.f32(), ast::StorageClass::kOutput, nullptr, - ast::DecorationList{ - out.Builtin(ast::Builtin::kPointSize), - }); - - // Add the pointsize assignment statement to the front of all vertex stages. - ctx.ReplaceAll([&](ast::Function* func) -> ast::Function* { + std::unordered_map struct_map; + for (auto* func : in->AST().Functions()) { if (func->pipeline_stage() != ast::PipelineStage::kVertex) { - return nullptr; // Just clone func + continue; } - return CloneWithStatementsAtStart(&ctx, func, - { - out.Assign(pointsize, 1.0f), - }); - }); + auto* sem_func = in->Sem().Get(func); + + // Create a struct for the return type that includes a point size member. + auto* new_struct = + utils::GetOrCreate(struct_map, sem_func->ReturnType(), [&]() { + // Gather struct members. + ast::StructMemberList new_struct_members; + if (auto* struct_ty = sem_func->ReturnType()->As()) { + for (auto* member : struct_ty->impl()->members()) { + new_struct_members.push_back(ctx.Clone(member)); + } + } else { + auto* ret_type = ctx.Clone(sem_func->ReturnType()); + auto ret_type_decos = ctx.Clone(func->return_type_decorations()); + new_struct_members.push_back( + out.Member("position", ret_type, std::move(ret_type_decos))); + } + + // Append a new member for the point size. + new_struct_members.push_back( + out.Member(out.Symbols().New("tint_pointsize"), out.ty.f32(), + {out.Builtin(ast::Builtin::kPointSize)})); + + // Create the new output struct. + return out.Structure(out.Sym(), new_struct_members); + }); + + // Replace return values using new output struct type constructors. + for (auto* ret : sem_func->ReturnStatements()) { + auto* ret_sem = in->Sem().Get(ret); + + ast::ExpressionList new_ret_values; + if (auto* struct_ty = sem_func->ReturnType()->As()) { + std::function ret_value = [&]() { + return ctx.Clone(ret->value()); + }; + + if (!ret->value()->Is()) { + // Capture the original return value in a local temporary. + auto* new_struct_ty = ctx.Clone(struct_ty); + auto* temp = out.Const(out.Sym(), new_struct_ty, ret_value()); + ctx.InsertBefore(ret_sem->Block()->statements(), ret, out.Decl(temp)); + ret_value = [&, temp]() { return out.Expr(temp); }; + } + + for (auto* member : struct_ty->impl()->members()) { + auto member_sym = ctx.Clone(member->symbol()); + new_ret_values.push_back(out.MemberAccessor(ret_value(), member_sym)); + } + } else { + new_ret_values.push_back(ctx.Clone(ret->value())); + } + + // Append the point size and replace the return statement. + new_ret_values.push_back(out.Expr(1.f)); + ctx.Replace(ret, out.Return(ret->source(), + out.Construct(new_struct, new_ret_values))); + } + + // Rewrite the function header with the new return type. + auto func_sym = ctx.Clone(func->symbol()); + auto params = ctx.Clone(func->params()); + auto* body = ctx.Clone(func->body()); + auto decos = ctx.Clone(func->decorations()); + auto* new_func = out.create( + func->source(), func_sym, std::move(params), new_struct, body, + std::move(decos), ast::DecorationList{}); + ctx.Replace(func, new_func); + } + ctx.Clone(); return Output(Program(std::move(out))); diff --git a/src/transform/emit_vertex_point_size_test.cc b/src/transform/emit_vertex_point_size_test.cc index f420115a3e..0b70aee8d4 100644 --- a/src/transform/emit_vertex_point_size_test.cc +++ b/src/transform/emit_vertex_point_size_test.cc @@ -29,7 +29,6 @@ fn non_entry_a() { [[stage(vertex)]] fn entry() -> [[builtin(position)]] vec4 { - var builtin_assignments_should_happen_before_this : f32; return vec4(); } @@ -38,16 +37,19 @@ fn non_entry_b() { )"; auto* expect = R"( -[[builtin(pointsize)]] var tint_pointsize : f32; +struct tint_symbol { + [[builtin(position)]] + position : vec4; + [[builtin(pointsize)]] + tint_pointsize : f32; +}; fn non_entry_a() { } [[stage(vertex)]] -fn entry() -> [[builtin(position)]] vec4 { - tint_pointsize = 1.0; - var builtin_assignments_should_happen_before_this : f32; - return vec4(); +fn entry() -> tint_symbol { + return tint_symbol(vec4(), 1.0); } fn non_entry_b() { @@ -59,6 +61,255 @@ fn non_entry_b() { EXPECT_EQ(expect, str(got)); } +TEST_F(EmitVertexPointSizeTest, VertexStageBasic_Struct) { + auto* src = R"( +struct VertexOut { + [[builtin(position)]] + pos : vec4; + [[location(0)]] + col : f32; +}; + +fn non_entry_a() { +} + +[[stage(vertex)]] +fn entry() -> VertexOut { + var output : VertexOut; + output.pos = vec4(); + output.col = 0.5; + return output; +} + +fn non_entry_b() { +} +)"; + + auto* expect = R"( +struct tint_symbol { + [[builtin(position)]] + pos : vec4; + [[location(0)]] + col : f32; + [[builtin(pointsize)]] + tint_pointsize : f32; +}; + +struct VertexOut { + [[builtin(position)]] + pos : vec4; + [[location(0)]] + col : f32; +}; + +fn non_entry_a() { +} + +[[stage(vertex)]] +fn entry() -> tint_symbol { + var output : VertexOut; + output.pos = vec4(); + output.col = 0.5; + return tint_symbol(output.pos, output.col, 1.0); +} + +fn non_entry_b() { +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +// Make sure we capture the function return value in a temporary instead of +// re-evaluating it multiple times. +TEST_F(EmitVertexPointSizeTest, VertexStage_ReturnStructFromFunctionCall) { + auto* src = R"( +struct VertexOut { + [[builtin(position)]] + pos : vec4; + [[location(0)]] + col : f32; +}; + +fn foo() -> VertexOut { + var output : VertexOut; + output.pos = vec4(); + output.col = 0.5; + return output; +} + +[[stage(vertex)]] +fn entry() -> VertexOut { + return foo(); +} +)"; + + auto* expect = R"( +struct tint_symbol { + [[builtin(position)]] + pos : vec4; + [[location(0)]] + col : f32; + [[builtin(pointsize)]] + tint_pointsize : f32; +}; + +struct VertexOut { + [[builtin(position)]] + pos : vec4; + [[location(0)]] + col : f32; +}; + +fn foo() -> VertexOut { + var output : VertexOut; + output.pos = vec4(); + output.col = 0.5; + return output; +} + +[[stage(vertex)]] +fn entry() -> tint_symbol { + let tint_symbol_1 : VertexOut = foo(); + return tint_symbol(tint_symbol_1.pos, tint_symbol_1.col, 1.0); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(EmitVertexPointSizeTest, VertexStage_MultipleReturnStatements) { + auto* src = R"( +[[stage(vertex)]] +fn entry([[location(0)]] toggle : u32) -> [[builtin(position)]] vec4 { + if (toggle == 1u) { + return vec4(0.5, 0.5, 0.5, 0.5); + } + return vec4(1.0, 1.0, 1.0, 1.0); +} +)"; + + auto* expect = R"( +struct tint_symbol { + [[builtin(position)]] + position : vec4; + [[builtin(pointsize)]] + tint_pointsize : f32; +}; + +[[stage(vertex)]] +fn entry([[location(0)]] toggle : u32) -> tint_symbol { + if ((toggle == 1u)) { + return tint_symbol(vec4(0.5, 0.5, 0.5, 0.5), 1.0); + } + return tint_symbol(vec4(1.0, 1.0, 1.0, 1.0), 1.0); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + +// Test that we re-use generated structures when we've seen the original return +// type before. +TEST_F(EmitVertexPointSizeTest, VertexStage_MultipleShaders) { + auto* src = R"( +struct VertexOut { + [[builtin(position)]] + pos : vec4; + [[location(0)]] + col : f32; +}; + +[[stage(vertex)]] +fn entry1() -> [[builtin(position)]] vec4 { + return vec4(); +} + +[[stage(vertex)]] +fn entry2() -> [[builtin(position)]] vec4 { + return vec4(1.0, 1.0, 1.0, 1.0); +} + +[[stage(vertex)]] +fn entry3() -> VertexOut { + var output : VertexOut; + output.pos = vec4(); + output.col = 0.5; + return output; +} + +[[stage(vertex)]] +fn entry4() -> VertexOut { + var output : VertexOut; + output.pos = vec4(); + output.col = 0.75; + return output; +} + +)"; + + auto* expect = R"( +struct tint_symbol { + [[builtin(position)]] + position : vec4; + [[builtin(pointsize)]] + tint_pointsize : f32; +}; + +struct tint_symbol_1 { + [[builtin(position)]] + pos : vec4; + [[location(0)]] + col : f32; + [[builtin(pointsize)]] + tint_pointsize_1 : f32; +}; + +struct VertexOut { + [[builtin(position)]] + pos : vec4; + [[location(0)]] + col : f32; +}; + +[[stage(vertex)]] +fn entry1() -> tint_symbol { + return tint_symbol(vec4(), 1.0); +} + +[[stage(vertex)]] +fn entry2() -> tint_symbol { + return tint_symbol(vec4(1.0, 1.0, 1.0, 1.0), 1.0); +} + +[[stage(vertex)]] +fn entry3() -> tint_symbol_1 { + var output : VertexOut; + output.pos = vec4(); + output.col = 0.5; + return tint_symbol_1(output.pos, output.col, 1.0); +} + +[[stage(vertex)]] +fn entry4() -> tint_symbol_1 { + var output : VertexOut; + output.pos = vec4(); + output.col = 0.75; + return tint_symbol_1(output.pos, output.col, 1.0); +} +)"; + + auto got = Run(src); + + EXPECT_EQ(expect, str(got)); +} + TEST_F(EmitVertexPointSizeTest, NonVertexStage) { auto* src = R"( [[stage(fragment)]] @@ -87,21 +338,34 @@ fn compute_entry() { TEST_F(EmitVertexPointSizeTest, AttemptSymbolCollision) { auto* src = R"( +struct VertexOut { + [[builtin(position)]] + tint_pointsize : vec4; +}; + [[stage(vertex)]] -fn entry() -> [[builtin(position)]] vec4 { - var tint_pointsize : f32; - return vec4(); +fn entry() -> VertexOut { + return VertexOut(vec4()); } )"; auto* expect = R"( -[[builtin(pointsize)]] var tint_pointsize_1 : f32; +struct tint_symbol { + [[builtin(position)]] + tint_pointsize : vec4; + [[builtin(pointsize)]] + tint_pointsize_1 : f32; +}; + +struct VertexOut { + [[builtin(position)]] + tint_pointsize : vec4; +}; [[stage(vertex)]] -fn entry() -> [[builtin(position)]] vec4 { - tint_pointsize_1 = 1.0; - var tint_pointsize : f32; - return vec4(); +fn entry() -> tint_symbol { + let tint_symbol_1 : VertexOut = VertexOut(vec4()); + return tint_symbol(tint_symbol_1.tint_pointsize, 1.0); } )";