diff --git a/src/transform/vertex_pulling.cc b/src/transform/vertex_pulling.cc index 98e4000753..f84bb9bc26 100644 --- a/src/transform/vertex_pulling.cc +++ b/src/transform/vertex_pulling.cc @@ -47,12 +47,14 @@ struct State { CloneContext& ctx; VertexPulling::Config const cfg; - std::unordered_map location_to_var; - Symbol vertex_index_name; - Symbol instance_index_name; + std::unordered_map> + location_to_expr; + std::function vertex_index_expr = nullptr; + std::function instance_index_expr = nullptr; Symbol pulling_position_name; Symbol struct_buffer_name; std::unordered_map vertex_buffer_names; + ast::VariableList new_function_parameters; /// Generate the vertex buffer binding name /// @param index index to append to buffer name @@ -106,7 +108,9 @@ struct State { for (auto* d : v->decorations()) { if (auto* builtin = d->As()) { if (builtin->value() == ast::Builtin::kVertexIndex) { - vertex_index_name = ctx.Clone(v->symbol()); + vertex_index_expr = [this, v]() { + return ctx.dst->Expr(ctx.Clone(v->symbol())); + }; return; } } @@ -114,11 +118,10 @@ struct State { } // We didn't find a vertex index builtin, so create one - static const char kDefaultVertexIndexName[] = "tint_pulling_vertex_index"; - vertex_index_name = ctx.dst->Symbols().New(kDefaultVertexIndexName); + auto name = ctx.dst->Symbols().New("tint_pulling_vertex_index"); + vertex_index_expr = [this, name]() { return ctx.dst->Expr(name); }; - ctx.dst->Global(vertex_index_name, ctx.dst->ty.u32(), - ast::StorageClass::kInput, nullptr, + ctx.dst->Global(name, ctx.dst->ty.u32(), ast::StorageClass::kInput, nullptr, ast::DecorationList{ ctx.dst->Builtin(ast::Builtin::kVertexIndex), }); @@ -147,7 +150,9 @@ struct State { for (auto* d : v->decorations()) { if (auto* builtin = d->As()) { if (builtin->value() == ast::Builtin::kInstanceIndex) { - instance_index_name = ctx.Clone(v->symbol()); + instance_index_expr = [this, v]() { + return ctx.dst->Expr(ctx.Clone(v->symbol())); + }; return; } } @@ -155,12 +160,10 @@ struct State { } // We didn't find an instance index builtin, so create one - static const char kDefaultInstanceIndexName[] = - "tint_pulling_instance_index"; - instance_index_name = ctx.dst->Symbols().New(kDefaultInstanceIndexName); + auto name = ctx.dst->Symbols().New("tint_pulling_instance_index"); + instance_index_expr = [this, name]() { return ctx.dst->Expr(name); }; - ctx.dst->Global(instance_index_name, ctx.dst->ty.u32(), - ast::StorageClass::kInput, nullptr, + ctx.dst->Global(name, ctx.dst->ty.u32(), ast::StorageClass::kInput, nullptr, ast::DecorationList{ ctx.dst->Builtin(ast::Builtin::kInstanceIndex), }); @@ -180,10 +183,12 @@ struct State { // This is where the replacement is created. Expressions use // identifier strings instead of pointers, so we don't need to update // any other place in the AST. - auto* replacement = ctx.dst->Var(ctx.Clone(v->symbol()), - ctx.Clone(v->declared_type()), + auto name = ctx.Clone(v->symbol()); + auto* replacement = ctx.dst->Var(name, ctx.Clone(v->declared_type()), ast::StorageClass::kPrivate); - location_to_var[location] = replacement; + location_to_expr[location] = [this, name]() { + return ctx.dst->Expr(name); + }; ctx.Replace(v, replacement); break; } @@ -237,30 +242,29 @@ struct State { for (const VertexAttributeDescriptor& attribute_desc : buffer_layout.attributes) { - auto it = location_to_var.find(attribute_desc.shader_location); - if (it == location_to_var.end()) { + auto it = location_to_expr.find(attribute_desc.shader_location); + if (it == location_to_expr.end()) { continue; } - auto* v = it->second; + auto* ident = it->second(); - auto name = buffer_layout.step_mode == InputStepMode::kVertex - ? vertex_index_name - : instance_index_name; + auto* index_expr = buffer_layout.step_mode == InputStepMode::kVertex + ? vertex_index_expr() + : instance_index_expr(); // An expression for the start of the read in the buffer in bytes auto* pos_value = ctx.dst->Add( - ctx.dst->Mul(name, + ctx.dst->Mul(index_expr, static_cast(buffer_layout.array_stride)), static_cast(attribute_desc.offset)); // Update position of the read - auto* set_pos_expr = ctx.dst->create( - ctx.dst->Expr(GetPullingPositionName()), pos_value); + auto* set_pos_expr = + ctx.dst->Assign(ctx.dst->Expr(GetPullingPositionName()), pos_value); stmts.emplace_back(set_pos_expr); - stmts.emplace_back(ctx.dst->create( - ctx.dst->create(v->symbol()), - AccessByFormat(i, attribute_desc.format))); + stmts.emplace_back( + ctx.dst->Assign(ident, AccessByFormat(i, attribute_desc.format))); } } @@ -379,6 +383,186 @@ struct State { return ctx.dst->create( ctx.dst->create(base_type, count), std::move(expr_list)); } + + /// Process a non-struct entry point parameter. + /// Generate function-scope variables for location parameters, and record + /// vertex_index and instance_index builtins if present. + /// @param func the entry point function + /// @param param the parameter to process + void ProcessNonStructParameter(ast::Function* func, ast::Variable* param) { + if (auto* location = + ast::GetDecoration(param->decorations())) { + // Create a function-scope variable to replace the parameter. + auto func_var_sym = ctx.Clone(param->symbol()); + auto* func_var_type = ctx.Clone(param->declared_type()); + auto* func_var = ctx.dst->Var(func_var_sym, func_var_type, + ast::StorageClass::kFunction); + ctx.InsertBefore(func->body()->statements(), *func->body()->begin(), + ctx.dst->Decl(func_var)); + // Capture mapping from location to the new variable. + location_to_expr[location->value()] = [this, func_var]() { + return ctx.dst->Expr(func_var); + }; + } else if (auto* builtin = ast::GetDecoration( + param->decorations())) { + // Check for existing vertex_index and instance_index builtins. + if (builtin->value() == ast::Builtin::kVertexIndex) { + vertex_index_expr = [this, param]() { + return ctx.dst->Expr(ctx.Clone(param->symbol())); + }; + } else if (builtin->value() == ast::Builtin::kInstanceIndex) { + instance_index_expr = [this, param]() { + return ctx.dst->Expr(ctx.Clone(param->symbol())); + }; + } + new_function_parameters.push_back(ctx.Clone(param)); + } else { + TINT_ICE(ctx.dst->Diagnostics()) << "Invalid entry point parameter"; + } + } + + /// Process a struct entry point parameter. + /// If the struct has members with location attributes, push the parameter to + /// a function-scope variable and create a new struct parameter without those + /// attributes. Record expressions for members that are vertex_index and + /// instance_index builtins. + /// @param func the entry point function + /// @param param the parameter to process + void ProcessStructParameter(ast::Function* func, ast::Variable* param) { + auto* struct_ty = param->declared_type()->As(); + if (!struct_ty) { + TINT_ICE(ctx.dst->Diagnostics()) << "Invalid struct parameter"; + } + + auto param_sym = ctx.Clone(param->symbol()); + + // Process the struct members. + bool has_locations = false; + ast::StructMemberList members_to_clone; + for (auto* member : struct_ty->impl()->members()) { + auto member_sym = ctx.Clone(member->symbol()); + std::function member_expr = [this, param_sym, + member_sym]() { + return ctx.dst->MemberAccessor(param_sym, member_sym); + }; + + if (auto* location = ast::GetDecoration( + member->decorations())) { + // Capture mapping from location to struct member. + location_to_expr[location->value()] = member_expr; + has_locations = true; + } else if (auto* builtin = ast::GetDecoration( + member->decorations())) { + // Check for existing vertex_index and instance_index builtins. + if (builtin->value() == ast::Builtin::kVertexIndex) { + vertex_index_expr = member_expr; + } else if (builtin->value() == ast::Builtin::kInstanceIndex) { + instance_index_expr = member_expr; + } + members_to_clone.push_back(member); + } else { + TINT_ICE(ctx.dst->Diagnostics()) << "Invalid entry point parameter"; + } + } + + if (!has_locations) { + // Nothing to do. + new_function_parameters.push_back(ctx.Clone(param)); + return; + } + + // Create a function-scope variable to replace the parameter. + auto* func_var = ctx.dst->Var(param_sym, ctx.Clone(param->declared_type()), + ast::StorageClass::kFunction); + ctx.InsertBefore(func->body()->statements(), *func->body()->begin(), + ctx.dst->Decl(func_var)); + + if (!members_to_clone.empty()) { + // Create a new struct without the location attributes. + ast::StructMemberList new_members; + for (auto* member : members_to_clone) { + auto member_sym = ctx.Clone(member->symbol()); + auto member_type = ctx.Clone(member->type()); + auto member_decos = ctx.Clone(member->decorations()); + new_members.push_back( + ctx.dst->Member(member_sym, member_type, std::move(member_decos))); + } + auto new_struct = + ctx.dst->Structure(ctx.dst->Symbols().New(), new_members); + + // Create a new function parameter with this struct. + auto* new_param = ctx.dst->Param(ctx.dst->Symbols().New(), new_struct); + new_function_parameters.push_back(new_param); + + // Copy values from the new parameter to the function-scope variable. + for (auto* member : members_to_clone) { + auto member_name = ctx.Clone(member->symbol()); + ctx.InsertBefore( + func->body()->statements(), *func->body()->begin(), + ctx.dst->Assign(ctx.dst->MemberAccessor(func_var, member_name), + ctx.dst->MemberAccessor(new_param, member_name))); + } + } + } + + /// Process an entry point function. + /// @param func the entry point function + void Process(ast::Function* func) { + if (func->body()->empty()) { + return; + } + + // Process entry point parameters. + for (auto* param : func->params()) { + auto* sem = ctx.src->Sem().Get(param); + if (sem->Type()->Is()) { + ProcessStructParameter(func, param); + } else { + ProcessNonStructParameter(func, param); + } + } + + // Insert new parameters for vertex_index and instance_index if needed. + if (!vertex_index_expr) { + for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) { + if (layout.step_mode == InputStepMode::kVertex) { + auto name = ctx.dst->Symbols().New("tint_pulling_vertex_index"); + new_function_parameters.push_back( + ctx.dst->Param(name, ctx.dst->ty.u32(), + {ctx.dst->Builtin(ast::Builtin::kVertexIndex)})); + vertex_index_expr = [this, name]() { return ctx.dst->Expr(name); }; + break; + } + } + } + if (!instance_index_expr) { + for (const VertexBufferLayoutDescriptor& layout : cfg.vertex_state) { + if (layout.step_mode == InputStepMode::kInstance) { + auto name = ctx.dst->Symbols().New("tint_pulling_instance_index"); + new_function_parameters.push_back( + ctx.dst->Param(name, ctx.dst->ty.u32(), + {ctx.dst->Builtin(ast::Builtin::kInstanceIndex)})); + instance_index_expr = [this, name]() { return ctx.dst->Expr(name); }; + break; + } + } + } + + // Generate vertex pulling preamble. + ctx.InsertBefore(func->body()->statements(), *func->body()->begin(), + CreateVertexPullingPreamble()); + + // Rewrite the function header with the new parameters. + auto func_sym = ctx.Clone(func->symbol()); + auto ret_type = ctx.Clone(func->return_type()); + auto* body = ctx.Clone(func->body()); + auto decos = ctx.Clone(func->decorations()); + auto ret_decos = ctx.Clone(func->return_type_decorations()); + auto* new_func = ctx.dst->create( + func->source(), func_sym, new_function_parameters, ret_type, body, + std::move(decos), std::move(ret_decos)); + ctx.Replace(func, new_func); + } }; } // namespace @@ -413,18 +597,26 @@ Output VertexPulling::Run(const Program* in, const DataMap& data) { CloneContext ctx(&out, in); State state{ctx, cfg}; - state.FindOrInsertVertexIndexIfUsed(); - state.FindOrInsertInstanceIndexIfUsed(); - state.ConvertVertexInputVariablesToPrivate(); - state.AddVertexStorageBuffers(); - ctx.ReplaceAll([&](ast::Function* f) -> ast::Function* { - if (f == func) { - return CloneWithStatementsAtStart(&ctx, f, - {state.CreateVertexPullingPreamble()}); - } - return nullptr; // Just clone func - }); + if (func->params().empty()) { + // TODO(crbug.com/tint/697): Remove this path for the old shader IO syntax. + state.FindOrInsertVertexIndexIfUsed(); + state.FindOrInsertInstanceIndexIfUsed(); + state.ConvertVertexInputVariablesToPrivate(); + state.AddVertexStorageBuffers(); + + ctx.ReplaceAll([&](ast::Function* f) -> ast::Function* { + if (f == func) { + return CloneWithStatementsAtStart( + &ctx, f, {state.CreateVertexPullingPreamble()}); + } + return nullptr; // Just clone func + }); + } else { + state.AddVertexStorageBuffers(); + state.Process(func); + } + ctx.Clone(); return Output(Program(std::move(out))); diff --git a/src/transform/vertex_pulling_test.cc b/src/transform/vertex_pulling_test.cc index 8effc23c30..1e6ff9ed00 100644 --- a/src/transform/vertex_pulling_test.cc +++ b/src/transform/vertex_pulling_test.cc @@ -109,17 +109,13 @@ fn main() -> [[builtin(position)]] vec4 { TEST_F(VertexPullingTest, OneAttribute) { auto* src = R"( -[[location(0)]] var var_a : f32; - [[stage(vertex)]] -fn main() -> [[builtin(position)]] vec4 { - return vec4(); +fn main([[location(0)]] var_a : f32) -> [[builtin(position)]] vec4 { + return vec4(var_a, 0.0, 0.0, 1.0); } )"; auto* expect = R"( -[[builtin(vertex_index)]] var tint_pulling_vertex_index : u32; - [[block]] struct TintVertexData { tint_vertex_data : [[stride(4)]] array; @@ -127,16 +123,15 @@ struct TintVertexData { [[binding(0), group(4)]] var tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData; -var var_a : f32; - [[stage(vertex)]] -fn main() -> [[builtin(position)]] vec4 { +fn main([[builtin(vertex_index)]] tint_pulling_vertex_index : u32) -> [[builtin(position)]] vec4 { + var var_a : f32; { var tint_pulling_pos : u32; tint_pulling_pos = ((tint_pulling_vertex_index * 4u) + 0u); var_a = bitcast(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]); } - return vec4(); + return vec4(var_a, 0.0, 0.0, 1.0); } )"; @@ -154,17 +149,13 @@ fn main() -> [[builtin(position)]] vec4 { TEST_F(VertexPullingTest, OneInstancedAttribute) { auto* src = R"( -[[location(0)]] var var_a : f32; - [[stage(vertex)]] -fn main() -> [[builtin(position)]] vec4 { - return vec4(); +fn main([[location(0)]] var_a : f32) -> [[builtin(position)]] vec4 { + return vec4(var_a, 0.0, 0.0, 1.0); } )"; auto* expect = R"( -[[builtin(instance_index)]] var tint_pulling_instance_index : u32; - [[block]] struct TintVertexData { tint_vertex_data : [[stride(4)]] array; @@ -172,16 +163,15 @@ struct TintVertexData { [[binding(0), group(4)]] var tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData; -var var_a : f32; - [[stage(vertex)]] -fn main() -> [[builtin(position)]] vec4 { +fn main([[builtin(instance_index)]] tint_pulling_instance_index : u32) -> [[builtin(position)]] vec4 { + var var_a : f32; { var tint_pulling_pos : u32; tint_pulling_pos = ((tint_pulling_instance_index * 4u) + 0u); var_a = bitcast(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]); } - return vec4(); + return vec4(var_a, 0.0, 0.0, 1.0); } )"; @@ -199,6 +189,472 @@ fn main() -> [[builtin(position)]] vec4 { TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet) { auto* src = R"( +[[stage(vertex)]] +fn main([[location(0)]] var_a : f32) -> [[builtin(position)]] vec4 { + return vec4(var_a, 0.0, 0.0, 1.0); +} +)"; + + auto* expect = R"( +[[block]] +struct TintVertexData { + tint_vertex_data : [[stride(4)]] array; +}; + +[[binding(0), group(5)]] var tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData; + +[[stage(vertex)]] +fn main([[builtin(vertex_index)]] tint_pulling_vertex_index : u32) -> [[builtin(position)]] vec4 { + var var_a : f32; + { + var tint_pulling_pos : u32; + tint_pulling_pos = ((tint_pulling_vertex_index * 4u) + 0u); + var_a = bitcast(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]); + } + return vec4(var_a, 0.0, 0.0, 1.0); +} +)"; + + VertexPulling::Config cfg; + cfg.vertex_state = { + {{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}}; + cfg.pulling_group = 5; + cfg.entry_point_name = "main"; + + DataMap data; + data.Add(cfg); + auto got = Run(src, data); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(VertexPullingTest, OneAttribute_Struct) { + auto* src = R"( +struct Inputs { + [[location(0)]] var_a : f32; +}; + +[[stage(vertex)]] +fn main(inputs : Inputs) -> [[builtin(position)]] vec4 { + return vec4(inputs.var_a, 0.0, 0.0, 1.0); +} +)"; + + auto* expect = R"( +[[block]] +struct TintVertexData { + tint_vertex_data : [[stride(4)]] array; +}; + +[[binding(0), group(4)]] var tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData; + +struct Inputs { + [[location(0)]] + var_a : f32; +}; + +[[stage(vertex)]] +fn main([[builtin(vertex_index)]] tint_pulling_vertex_index : u32) -> [[builtin(position)]] vec4 { + var inputs : Inputs; + { + var tint_pulling_pos : u32; + tint_pulling_pos = ((tint_pulling_vertex_index * 4u) + 0u); + inputs.var_a = bitcast(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]); + } + return vec4(inputs.var_a, 0.0, 0.0, 1.0); +} +)"; + + VertexPulling::Config cfg; + cfg.vertex_state = { + {{4, InputStepMode::kVertex, {{VertexFormat::kF32, 0, 0}}}}}; + cfg.entry_point_name = "main"; + + DataMap data; + data.Add(cfg); + auto got = Run(src, data); + + EXPECT_EQ(expect, str(got)); +} + +// We expect the transform to use an existing builtin variables if it finds them +TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) { + auto* src = R"( +[[stage(vertex)]] +fn main([[location(0)]] var_a : f32, + [[location(1)]] var_b : f32, + [[builtin(vertex_index)]] custom_vertex_index : u32, + [[builtin(instance_index)]] custom_instance_index : u32 + ) -> [[builtin(position)]] vec4 { + return vec4(var_a, var_b, 0.0, 1.0); +} +)"; + + auto* expect = R"( +[[block]] +struct TintVertexData { + tint_vertex_data : [[stride(4)]] array; +}; + +[[binding(0), group(4)]] var tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData; + +[[binding(1), group(4)]] var tint_pulling_vertex_buffer_1 : [[access(read)]] TintVertexData; + +[[stage(vertex)]] +fn main([[builtin(vertex_index)]] custom_vertex_index : u32, [[builtin(instance_index)]] custom_instance_index : u32) -> [[builtin(position)]] vec4 { + var var_a : f32; + var var_b : f32; + { + var tint_pulling_pos : u32; + tint_pulling_pos = ((custom_vertex_index * 4u) + 0u); + var_a = bitcast(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]); + tint_pulling_pos = ((custom_instance_index * 4u) + 0u); + var_b = bitcast(tint_pulling_vertex_buffer_1.tint_vertex_data[(tint_pulling_pos / 4u)]); + } + return vec4(var_a, var_b, 0.0, 1.0); +} +)"; + + VertexPulling::Config cfg; + cfg.vertex_state = {{ + { + 4, + InputStepMode::kVertex, + {{VertexFormat::kF32, 0, 0}}, + }, + { + 4, + InputStepMode::kInstance, + {{VertexFormat::kF32, 0, 1}}, + }, + }}; + cfg.entry_point_name = "main"; + + DataMap data; + data.Add(cfg); + auto got = Run(src, data); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex_Struct) { + auto* src = R"( +struct Inputs { + [[location(0)]] var_a : f32; + [[location(1)]] var_b : f32; + [[builtin(vertex_index)]] custom_vertex_index : u32; + [[builtin(instance_index)]] custom_instance_index : u32; +}; + +[[stage(vertex)]] +fn main(inputs : Inputs) -> [[builtin(position)]] vec4 { + return vec4(inputs.var_a, inputs.var_b, 0.0, 1.0); +} +)"; + + auto* expect = R"( +[[block]] +struct TintVertexData { + tint_vertex_data : [[stride(4)]] array; +}; + +[[binding(0), group(4)]] var tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData; + +[[binding(1), group(4)]] var tint_pulling_vertex_buffer_1 : [[access(read)]] TintVertexData; + +struct tint_symbol { + [[builtin(vertex_index)]] + custom_vertex_index : u32; + [[builtin(instance_index)]] + custom_instance_index : u32; +}; + +struct Inputs { + [[location(0)]] + var_a : f32; + [[location(1)]] + var_b : f32; + [[builtin(vertex_index)]] + custom_vertex_index : u32; + [[builtin(instance_index)]] + custom_instance_index : u32; +}; + +[[stage(vertex)]] +fn main(tint_symbol_1 : tint_symbol) -> [[builtin(position)]] vec4 { + var inputs : Inputs; + inputs.custom_vertex_index = tint_symbol_1.custom_vertex_index; + inputs.custom_instance_index = tint_symbol_1.custom_instance_index; + { + var tint_pulling_pos : u32; + tint_pulling_pos = ((inputs.custom_vertex_index * 4u) + 0u); + inputs.var_a = bitcast(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]); + tint_pulling_pos = ((inputs.custom_instance_index * 4u) + 0u); + inputs.var_b = bitcast(tint_pulling_vertex_buffer_1.tint_vertex_data[(tint_pulling_pos / 4u)]); + } + return vec4(inputs.var_a, inputs.var_b, 0.0, 1.0); +} +)"; + + VertexPulling::Config cfg; + cfg.vertex_state = {{ + { + 4, + InputStepMode::kVertex, + {{VertexFormat::kF32, 0, 0}}, + }, + { + 4, + InputStepMode::kInstance, + {{VertexFormat::kF32, 0, 1}}, + }, + }}; + cfg.entry_point_name = "main"; + + DataMap data; + data.Add(cfg); + auto got = Run(src, data); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex_SeparateStruct) { + auto* src = R"( +struct Inputs { + [[location(0)]] var_a : f32; + [[location(1)]] var_b : f32; +}; + +struct Indices { + [[builtin(vertex_index)]] custom_vertex_index : u32; + [[builtin(instance_index)]] custom_instance_index : u32; +}; + +[[stage(vertex)]] +fn main(inputs : Inputs, indices : Indices) -> [[builtin(position)]] vec4 { + return vec4(inputs.var_a, inputs.var_b, 0.0, 1.0); +} +)"; + + auto* expect = R"( +[[block]] +struct TintVertexData { + tint_vertex_data : [[stride(4)]] array; +}; + +[[binding(0), group(4)]] var tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData; + +[[binding(1), group(4)]] var tint_pulling_vertex_buffer_1 : [[access(read)]] TintVertexData; + +struct Inputs { + [[location(0)]] + var_a : f32; + [[location(1)]] + var_b : f32; +}; + +struct Indices { + [[builtin(vertex_index)]] + custom_vertex_index : u32; + [[builtin(instance_index)]] + custom_instance_index : u32; +}; + +[[stage(vertex)]] +fn main(indices : Indices) -> [[builtin(position)]] vec4 { + var inputs : Inputs; + { + var tint_pulling_pos : u32; + tint_pulling_pos = ((indices.custom_vertex_index * 4u) + 0u); + inputs.var_a = bitcast(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]); + tint_pulling_pos = ((indices.custom_instance_index * 4u) + 0u); + inputs.var_b = bitcast(tint_pulling_vertex_buffer_1.tint_vertex_data[(tint_pulling_pos / 4u)]); + } + return vec4(inputs.var_a, inputs.var_b, 0.0, 1.0); +} +)"; + + VertexPulling::Config cfg; + cfg.vertex_state = {{ + { + 4, + InputStepMode::kVertex, + {{VertexFormat::kF32, 0, 0}}, + }, + { + 4, + InputStepMode::kInstance, + {{VertexFormat::kF32, 0, 1}}, + }, + }}; + cfg.entry_point_name = "main"; + + DataMap data; + data.Add(cfg); + auto got = Run(src, data); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(VertexPullingTest, TwoAttributesSameBuffer) { + auto* src = R"( +[[stage(vertex)]] +fn main([[location(0)]] var_a : f32, + [[location(1)]] var_b : vec4) -> [[builtin(position)]] vec4 { + return vec4(); +} +)"; + + auto* expect = R"( +[[block]] +struct TintVertexData { + tint_vertex_data : [[stride(4)]] array; +}; + +[[binding(0), group(4)]] var tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData; + +[[stage(vertex)]] +fn main([[builtin(vertex_index)]] tint_pulling_vertex_index : u32) -> [[builtin(position)]] vec4 { + var var_a : f32; + var var_b : vec4; + { + var tint_pulling_pos : u32; + tint_pulling_pos = ((tint_pulling_vertex_index * 16u) + 0u); + var_a = bitcast(tint_pulling_vertex_buffer_0.tint_vertex_data[(tint_pulling_pos / 4u)]); + tint_pulling_pos = ((tint_pulling_vertex_index * 16u) + 0u); + var_b = vec4(bitcast(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 0u) / 4u)]), bitcast(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 4u) / 4u)]), bitcast(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 8u) / 4u)]), bitcast(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 12u) / 4u)])); + } + return vec4(); +} +)"; + + VertexPulling::Config cfg; + cfg.vertex_state = { + {{16, + InputStepMode::kVertex, + {{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}}; + cfg.entry_point_name = "main"; + + DataMap data; + data.Add(cfg); + auto got = Run(src, data); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(VertexPullingTest, FloatVectorAttributes) { + auto* src = R"( +[[stage(vertex)]] +fn main([[location(0)]] var_a : vec2, + [[location(1)]] var_b : vec3, + [[location(2)]] var_c : vec4 + ) -> [[builtin(position)]] vec4 { + return vec4(); +} +)"; + + auto* expect = R"( +[[block]] +struct TintVertexData { + tint_vertex_data : [[stride(4)]] array; +}; + +[[binding(0), group(4)]] var tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData; + +[[binding(1), group(4)]] var tint_pulling_vertex_buffer_1 : [[access(read)]] TintVertexData; + +[[binding(2), group(4)]] var tint_pulling_vertex_buffer_2 : [[access(read)]] TintVertexData; + +[[stage(vertex)]] +fn main([[builtin(vertex_index)]] tint_pulling_vertex_index : u32) -> [[builtin(position)]] vec4 { + var var_a : vec2; + var var_b : vec3; + var var_c : vec4; + { + var tint_pulling_pos : u32; + tint_pulling_pos = ((tint_pulling_vertex_index * 8u) + 0u); + var_a = vec2(bitcast(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 0u) / 4u)]), bitcast(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 4u) / 4u)])); + tint_pulling_pos = ((tint_pulling_vertex_index * 12u) + 0u); + var_b = vec3(bitcast(tint_pulling_vertex_buffer_1.tint_vertex_data[((tint_pulling_pos + 0u) / 4u)]), bitcast(tint_pulling_vertex_buffer_1.tint_vertex_data[((tint_pulling_pos + 4u) / 4u)]), bitcast(tint_pulling_vertex_buffer_1.tint_vertex_data[((tint_pulling_pos + 8u) / 4u)])); + tint_pulling_pos = ((tint_pulling_vertex_index * 16u) + 0u); + var_c = vec4(bitcast(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 0u) / 4u)]), bitcast(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 4u) / 4u)]), bitcast(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 8u) / 4u)]), bitcast(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 12u) / 4u)])); + } + return vec4(); +} +)"; + + VertexPulling::Config cfg; + cfg.vertex_state = {{ + {8, InputStepMode::kVertex, {{VertexFormat::kVec2F32, 0, 0}}}, + {12, InputStepMode::kVertex, {{VertexFormat::kVec3F32, 0, 1}}}, + {16, InputStepMode::kVertex, {{VertexFormat::kVec4F32, 0, 2}}}, + }}; + cfg.entry_point_name = "main"; + + DataMap data; + data.Add(cfg); + auto got = Run(src, data); + + EXPECT_EQ(expect, str(got)); +} + +TEST_F(VertexPullingTest, AttemptSymbolCollision) { + auto* src = R"( +[[stage(vertex)]] +fn main([[location(0)]] var_a : f32, + [[location(1)]] var_b : vec4) -> [[builtin(position)]] vec4 { + var tint_pulling_vertex_index : i32; + var tint_pulling_vertex_buffer_0 : i32; + var tint_vertex_data : i32; + var tint_pulling_pos : i32; + return vec4(); +} +)"; + + auto* expect = R"( +[[block]] +struct TintVertexData { + tint_vertex_data_1 : [[stride(4)]] array; +}; + +[[binding(0), group(4)]] var tint_pulling_vertex_buffer_0_1 : [[access(read)]] TintVertexData; + +[[stage(vertex)]] +fn main([[builtin(vertex_index)]] tint_pulling_vertex_index_1 : u32) -> [[builtin(position)]] vec4 { + var var_a : f32; + var var_b : vec4; + { + var tint_pulling_pos_1 : u32; + tint_pulling_pos_1 = ((tint_pulling_vertex_index_1 * 16u) + 0u); + var_a = bitcast(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[(tint_pulling_pos_1 / 4u)]); + tint_pulling_pos_1 = ((tint_pulling_vertex_index_1 * 16u) + 0u); + var_b = vec4(bitcast(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 0u) / 4u)]), bitcast(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 4u) / 4u)]), bitcast(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 8u) / 4u)]), bitcast(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 12u) / 4u)])); + } + var tint_pulling_vertex_index : i32; + var tint_pulling_vertex_buffer_0 : i32; + var tint_vertex_data : i32; + var tint_pulling_pos : i32; + return vec4(); +} +)"; + + VertexPulling::Config cfg; + cfg.vertex_state = { + {{16, + InputStepMode::kVertex, + {{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}}; + cfg.entry_point_name = "main"; + + DataMap data; + data.Add(cfg); + auto got = Run(src, std::move(data)); + + EXPECT_EQ(expect, str(got)); +} + +// TODO(crbug.com/tint/697): Remove this. +TEST_F(VertexPullingTest, OneAttributeDifferentOutputSet_Legacy) { + auto* src = R"( [[location(0)]] var var_a : f32; [[stage(vertex)]] @@ -243,8 +699,9 @@ fn main() -> [[builtin(position)]] vec4 { EXPECT_EQ(expect, str(got)); } +// TODO(crbug.com/tint/697): Remove this. // We expect the transform to use an existing builtin variables if it finds them -TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex) { +TEST_F(VertexPullingTest, ExistingVertexIndexAndInstanceIndex_Legacy) { auto* src = R"( [[location(0)]] var var_a : f32; [[location(1)]] var var_b : f32; @@ -310,7 +767,8 @@ fn main() -> [[builtin(position)]] vec4 { EXPECT_EQ(expect, str(got)); } -TEST_F(VertexPullingTest, TwoAttributesSameBuffer) { +// TODO(crbug.com/tint/697): Remove this. +TEST_F(VertexPullingTest, TwoAttributesSameBuffer_Legacy) { auto* src = R"( [[location(0)]] var var_a : f32; [[location(1)]] var var_b : vec4; @@ -362,127 +820,6 @@ fn main() -> [[builtin(position)]] vec4 { EXPECT_EQ(expect, str(got)); } -TEST_F(VertexPullingTest, FloatVectorAttributes) { - auto* src = R"( -[[location(0)]] var var_a : vec2; -[[location(1)]] var var_b : vec3; -[[location(2)]] var var_c : vec4; - -[[stage(vertex)]] -fn main() -> [[builtin(position)]] vec4 { - return vec4(); -} -)"; - - auto* expect = R"( -[[builtin(vertex_index)]] var tint_pulling_vertex_index : u32; - -[[block]] -struct TintVertexData { - tint_vertex_data : [[stride(4)]] array; -}; - -[[binding(0), group(4)]] var tint_pulling_vertex_buffer_0 : [[access(read)]] TintVertexData; - -[[binding(1), group(4)]] var tint_pulling_vertex_buffer_1 : [[access(read)]] TintVertexData; - -[[binding(2), group(4)]] var tint_pulling_vertex_buffer_2 : [[access(read)]] TintVertexData; - -var var_a : vec2; - -var var_b : vec3; - -var var_c : vec4; - -[[stage(vertex)]] -fn main() -> [[builtin(position)]] vec4 { - { - var tint_pulling_pos : u32; - tint_pulling_pos = ((tint_pulling_vertex_index * 8u) + 0u); - var_a = vec2(bitcast(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 0u) / 4u)]), bitcast(tint_pulling_vertex_buffer_0.tint_vertex_data[((tint_pulling_pos + 4u) / 4u)])); - tint_pulling_pos = ((tint_pulling_vertex_index * 12u) + 0u); - var_b = vec3(bitcast(tint_pulling_vertex_buffer_1.tint_vertex_data[((tint_pulling_pos + 0u) / 4u)]), bitcast(tint_pulling_vertex_buffer_1.tint_vertex_data[((tint_pulling_pos + 4u) / 4u)]), bitcast(tint_pulling_vertex_buffer_1.tint_vertex_data[((tint_pulling_pos + 8u) / 4u)])); - tint_pulling_pos = ((tint_pulling_vertex_index * 16u) + 0u); - var_c = vec4(bitcast(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 0u) / 4u)]), bitcast(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 4u) / 4u)]), bitcast(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 8u) / 4u)]), bitcast(tint_pulling_vertex_buffer_2.tint_vertex_data[((tint_pulling_pos + 12u) / 4u)])); - } - return vec4(); -} -)"; - - VertexPulling::Config cfg; - cfg.vertex_state = {{ - {8, InputStepMode::kVertex, {{VertexFormat::kVec2F32, 0, 0}}}, - {12, InputStepMode::kVertex, {{VertexFormat::kVec3F32, 0, 1}}}, - {16, InputStepMode::kVertex, {{VertexFormat::kVec4F32, 0, 2}}}, - }}; - cfg.entry_point_name = "main"; - - DataMap data; - data.Add(cfg); - auto got = Run(src, data); - - EXPECT_EQ(expect, str(got)); -} - -TEST_F(VertexPullingTest, AttemptSymbolCollision) { - auto* src = R"( -[[location(0)]] var var_a : f32; -[[location(1)]] var var_b : vec4; - -[[stage(vertex)]] -fn main() -> [[builtin(position)]] vec4 { - var tint_pulling_vertex_index : i32; - var tint_pulling_vertex_buffer_0 : i32; - var tint_vertex_data : i32; - var tint_pulling_pos : i32; - return vec4(); -} -)"; - - auto* expect = R"( -[[builtin(vertex_index)]] var tint_pulling_vertex_index_1 : u32; - -[[block]] -struct TintVertexData { - tint_vertex_data_1 : [[stride(4)]] array; -}; - -[[binding(0), group(4)]] var tint_pulling_vertex_buffer_0_1 : [[access(read)]] TintVertexData; - -var var_a : f32; - -var var_b : vec4; - -[[stage(vertex)]] -fn main() -> [[builtin(position)]] vec4 { - { - var tint_pulling_pos_1 : u32; - tint_pulling_pos_1 = ((tint_pulling_vertex_index_1 * 16u) + 0u); - var_a = bitcast(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[(tint_pulling_pos_1 / 4u)]); - tint_pulling_pos_1 = ((tint_pulling_vertex_index_1 * 16u) + 0u); - var_b = vec4(bitcast(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 0u) / 4u)]), bitcast(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 4u) / 4u)]), bitcast(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 8u) / 4u)]), bitcast(tint_pulling_vertex_buffer_0_1.tint_vertex_data_1[((tint_pulling_pos_1 + 12u) / 4u)])); - } - var tint_pulling_vertex_index : i32; - var tint_pulling_vertex_buffer_0 : i32; - var tint_vertex_data : i32; - var tint_pulling_pos : i32; - return vec4(); -} -)"; - - VertexPulling::Config cfg; - cfg.vertex_state = { - {{16, - InputStepMode::kVertex, - {{VertexFormat::kF32, 0, 0}, {VertexFormat::kVec4F32, 0, 1}}}}}; - cfg.entry_point_name = "main"; - - DataMap data; - data.Add(cfg); - auto got = Run(src, std::move(data)); - - EXPECT_EQ(expect, str(got)); -} } // namespace } // namespace transform } // namespace tint