diff --git a/src/dawn/native/TintUtils.cpp b/src/dawn/native/TintUtils.cpp index a2cf5dc4e7..74f377e31b 100644 --- a/src/dawn/native/TintUtils.cpp +++ b/src/dawn/native/TintUtils.cpp @@ -154,10 +154,8 @@ tint::transform::MultiplanarExternalTexture::BindingsMap BuildExternalTextureTra tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig( const RenderPipelineBase& renderPipeline, - const std::string_view& entryPoint, BindGroupIndex pullingBufferBindingSet) { tint::transform::VertexPulling::Config cfg; - cfg.entry_point_name = entryPoint; cfg.pulling_group = static_cast(pullingBufferBindingSet); cfg.vertex_state.resize(renderPipeline.GetVertexBufferCount()); diff --git a/src/dawn/native/TintUtils.h b/src/dawn/native/TintUtils.h index d6bea1d88f..fb73e4d4d7 100644 --- a/src/dawn/native/TintUtils.h +++ b/src/dawn/native/TintUtils.h @@ -45,7 +45,6 @@ tint::transform::MultiplanarExternalTexture::BindingsMap BuildExternalTextureTra tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig( const RenderPipelineBase& renderPipeline, - const std::string_view& entryPoint, BindGroupIndex pullingBufferBindingSet); tint::transform::SubstituteOverride::Config BuildSubstituteOverridesTransformConfig( diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm index e07cd1daa5..8ab06286d8 100644 --- a/src/dawn/native/metal/ShaderModuleMTL.mm +++ b/src/dawn/native/metal/ShaderModuleMTL.mm @@ -141,8 +141,8 @@ ResultOrError> TranslateToMSL( std::optional vertexPullingTransformConfig; if (stage == SingleShaderStage::Vertex && device->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) { - vertexPullingTransformConfig = BuildVertexPullingTransformConfig( - *renderPipeline, programmableStage.entryPoint.c_str(), kPullingBufferBindingSet); + vertexPullingTransformConfig = + BuildVertexPullingTransformConfig(*renderPipeline, kPullingBufferBindingSet); for (VertexBufferSlot slot : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) { uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(slot); diff --git a/src/tint/fuzzers/transform_builder.h b/src/tint/fuzzers/transform_builder.h index d2c0f61877..58675260af 100644 --- a/src/tint/fuzzers/transform_builder.h +++ b/src/tint/fuzzers/transform_builder.h @@ -178,7 +178,6 @@ class TransformBuilder { /// @param tb - TransformBuilder to add transform to static void impl(TransformBuilder* tb) { transform::VertexPulling::Config cfg; - cfg.entry_point_name = tb->builder()->build(); cfg.vertex_state = tb->builder()->vector( GenerateVertexBufferLayoutDescriptor); cfg.pulling_group = tb->builder()->build(); diff --git a/src/tint/transform/vertex_pulling.cc b/src/tint/transform/vertex_pulling.cc index 2ec12d5159..00b5a065ea 100644 --- a/src/tint/transform/vertex_pulling.cc +++ b/src/tint/transform/vertex_pulling.cc @@ -882,8 +882,18 @@ void VertexPulling::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) cons } // Find entry point - auto* func = ctx.src->AST().Functions().Find(ctx.src->Symbols().Get(cfg.entry_point_name), - ast::PipelineStage::kVertex); + const ast::Function* func = nullptr; + for (auto* fn : ctx.src->AST().Functions()) { + if (fn->PipelineStage() == ast::PipelineStage::kVertex) { + if (func != nullptr) { + ctx.dst->Diagnostics().add_error( + diag::System::Transform, + "VertexPulling found more than one vertex entry point"); + return; + } + func = fn; + } + } if (func == nullptr) { ctx.dst->Diagnostics().add_error(diag::System::Transform, "Vertex stage entry point not found"); diff --git a/src/tint/transform/vertex_pulling.h b/src/tint/transform/vertex_pulling.h index 255a49a639..6dd35bc85a 100644 --- a/src/tint/transform/vertex_pulling.h +++ b/src/tint/transform/vertex_pulling.h @@ -135,6 +135,8 @@ using VertexStateDescriptor = std::vector; /// code, but these are types that the data may arrive as. We need to convert /// these smaller types into the base types such as `f32` and `u32` for the /// shader to use. +/// +/// The SingleEntryPoint transform must have run before VertexPulling. class VertexPulling final : public Castable { public: /// Configuration options for the transform @@ -152,9 +154,6 @@ class VertexPulling final : public Castable { /// @returns this Config Config& operator=(const Config&); - /// The entry point to add assignments into - std::string entry_point_name; - /// The vertex state descriptor, containing info about attributes VertexStateDescriptor vertex_state; @@ -163,7 +162,7 @@ class VertexPulling final : public Castable { uint32_t pulling_group = 4u; /// Reflect the fields of this class so that it can be used by tint::ForeachField() - TINT_REFLECT(entry_point_name, vertex_state, pulling_group); + TINT_REFLECT(vertex_state, pulling_group); }; /// Constructor diff --git a/src/tint/transform/vertex_pulling_test.cc b/src/tint/transform/vertex_pulling_test.cc index 5fb8b1cd40..7c774c579d 100644 --- a/src/tint/transform/vertex_pulling_test.cc +++ b/src/tint/transform/vertex_pulling_test.cc @@ -35,18 +35,21 @@ TEST_F(VertexPullingTest, Error_NoEntryPoint) { EXPECT_EQ(expect, str(got)); } -TEST_F(VertexPullingTest, Error_InvalidEntryPoint) { +TEST_F(VertexPullingTest, Error_MultipleEntryPoint) { auto* src = R"( @vertex fn main() -> @builtin(position) vec4 { return vec4(); } +@vertex +fn main2() -> @builtin(position) vec4 { + return vec4(); +} )"; - auto* expect = "error: Vertex stage entry point not found"; + auto* expect = "error: VertexPulling found more than one vertex entry point"; VertexPulling::Config cfg; - cfg.entry_point_name = "_"; DataMap data; data.Add(cfg); @@ -64,7 +67,6 @@ fn main() {} auto* expect = "error: Vertex stage entry point not found"; VertexPulling::Config cfg; - cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -87,7 +89,6 @@ fn main(@location(0) var_a : f32) -> @builtin(position) vec4 { VertexPulling::Config cfg; cfg.vertex_state = {{{15, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}}; - cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -116,7 +117,6 @@ fn main() -> @builtin(position) vec4 { )"; VertexPulling::Config cfg; - cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -153,7 +153,6 @@ fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(posi VertexPulling::Config cfg; cfg.vertex_state = {{{4, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}}; - cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -190,7 +189,6 @@ fn main(@builtin(instance_index) tint_pulling_instance_index : u32) -> @builtin( VertexPulling::Config cfg; cfg.vertex_state = {{{4, VertexStepMode::kInstance, {{VertexFormat::kFloat32, 0, 0}}}}}; - cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -228,7 +226,6 @@ fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(posi VertexPulling::Config cfg; cfg.vertex_state = {{{4, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}}; cfg.pulling_group = 5; - cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -274,7 +271,6 @@ fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(posi VertexPulling::Config cfg; cfg.vertex_state = {{{4, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}}; - cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -332,7 +328,6 @@ fn main(@builtin(vertex_index) custom_vertex_index : u32, @builtin(instance_inde {{VertexFormat::kFloat32, 0, 1}}, }, }}; - cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -411,7 +406,6 @@ fn main(tint_symbol_1 : tint_symbol) -> @builtin(position) vec4 { {{VertexFormat::kFloat32, 0, 1}}, }, }}; - cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -490,7 +484,6 @@ struct Inputs { {{VertexFormat::kFloat32, 0, 1}}, }, }}; - cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -566,7 +559,6 @@ fn main(indices : Indices) -> @builtin(position) vec4 { {{VertexFormat::kFloat32, 0, 1}}, }, }}; - cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -642,7 +634,6 @@ struct Indices { {{VertexFormat::kFloat32, 0, 1}}, }, }}; - cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -684,7 +675,6 @@ fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(posi cfg.vertex_state = {{{16, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}, {VertexFormat::kFloat32x4, 0, 1}}}}}; - cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -738,7 +728,6 @@ fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(posi {12, VertexStepMode::kVertex, {{VertexFormat::kFloat32x3, 0, 1}}}, {16, VertexStepMode::kVertex, {{VertexFormat::kFloat32x4, 0, 2}}}, }}; - cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -788,7 +777,6 @@ fn main(@builtin(vertex_index) tint_pulling_vertex_index_1 : u32) -> @builtin(po cfg.vertex_state = {{{16, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}, {VertexFormat::kFloat32x4, 0, 1}}}}}; - cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -933,7 +921,6 @@ fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(posi {VertexFormat::kSint32, 64, 26}, {VertexFormat::kSint32x2, 64, 27}, {VertexFormat::kSint32x3, 64, 28}, {VertexFormat::kSint32x4, 64, 29}, }}}}; - cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -1079,7 +1066,6 @@ fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(posi {VertexFormat::kSint32, 63, 26}, {VertexFormat::kSint32x2, 63, 27}, {VertexFormat::kSint32x3, 63, 28}, {VertexFormat::kSint32x4, 63, 29}, }}}}; - cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -1224,7 +1210,6 @@ fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(posi {VertexFormat::kSint32, 64, 26}, {VertexFormat::kSint32x2, 64, 27}, {VertexFormat::kSint32x3, 64, 28}, {VertexFormat::kSint32x4, 64, 29}, }}}}; - cfg.entry_point_name = "main"; DataMap data; data.Add(cfg);