diff --git a/src/dawn/native/TintUtils.cpp b/src/dawn/native/TintUtils.cpp index 74f377e31b..a2cf5dc4e7 100644 --- a/src/dawn/native/TintUtils.cpp +++ b/src/dawn/native/TintUtils.cpp @@ -154,8 +154,10 @@ tint::transform::MultiplanarExternalTexture::BindingsMap BuildExternalTextureTra tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig( const RenderPipelineBase& renderPipeline, + const std::string_view& entryPoint, BindGroupIndex pullingBufferBindingSet) { tint::transform::VertexPulling::Config cfg; + cfg.entry_point_name = entryPoint; cfg.pulling_group = static_cast(pullingBufferBindingSet); cfg.vertex_state.resize(renderPipeline.GetVertexBufferCount()); diff --git a/src/dawn/native/TintUtils.h b/src/dawn/native/TintUtils.h index fb73e4d4d7..d6bea1d88f 100644 --- a/src/dawn/native/TintUtils.h +++ b/src/dawn/native/TintUtils.h @@ -45,6 +45,7 @@ tint::transform::MultiplanarExternalTexture::BindingsMap BuildExternalTextureTra tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig( const RenderPipelineBase& renderPipeline, + const std::string_view& entryPoint, BindGroupIndex pullingBufferBindingSet); tint::transform::SubstituteOverride::Config BuildSubstituteOverridesTransformConfig( diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm index 8ab06286d8..e07cd1daa5 100644 --- a/src/dawn/native/metal/ShaderModuleMTL.mm +++ b/src/dawn/native/metal/ShaderModuleMTL.mm @@ -141,8 +141,8 @@ ResultOrError> TranslateToMSL( std::optional vertexPullingTransformConfig; if (stage == SingleShaderStage::Vertex && device->IsToggleEnabled(Toggle::MetalEnableVertexPulling)) { - vertexPullingTransformConfig = - BuildVertexPullingTransformConfig(*renderPipeline, kPullingBufferBindingSet); + vertexPullingTransformConfig = BuildVertexPullingTransformConfig( + *renderPipeline, programmableStage.entryPoint.c_str(), kPullingBufferBindingSet); for (VertexBufferSlot slot : IterateBitSet(renderPipeline->GetVertexBufferSlotsUsed())) { uint32_t metalIndex = renderPipeline->GetMtlVertexBufferIndex(slot); diff --git a/src/tint/transform/vertex_pulling.cc b/src/tint/transform/vertex_pulling.cc index 00b5a065ea..2ec12d5159 100644 --- a/src/tint/transform/vertex_pulling.cc +++ b/src/tint/transform/vertex_pulling.cc @@ -882,18 +882,8 @@ void VertexPulling::Run(CloneContext& ctx, const DataMap& inputs, DataMap&) cons } // Find entry point - const ast::Function* func = nullptr; - for (auto* fn : ctx.src->AST().Functions()) { - if (fn->PipelineStage() == ast::PipelineStage::kVertex) { - if (func != nullptr) { - ctx.dst->Diagnostics().add_error( - diag::System::Transform, - "VertexPulling found more than one vertex entry point"); - return; - } - func = fn; - } - } + auto* func = ctx.src->AST().Functions().Find(ctx.src->Symbols().Get(cfg.entry_point_name), + ast::PipelineStage::kVertex); if (func == nullptr) { ctx.dst->Diagnostics().add_error(diag::System::Transform, "Vertex stage entry point not found"); diff --git a/src/tint/transform/vertex_pulling.h b/src/tint/transform/vertex_pulling.h index 6dd35bc85a..255a49a639 100644 --- a/src/tint/transform/vertex_pulling.h +++ b/src/tint/transform/vertex_pulling.h @@ -135,8 +135,6 @@ using VertexStateDescriptor = std::vector; /// code, but these are types that the data may arrive as. We need to convert /// these smaller types into the base types such as `f32` and `u32` for the /// shader to use. -/// -/// The SingleEntryPoint transform must have run before VertexPulling. class VertexPulling final : public Castable { public: /// Configuration options for the transform @@ -154,6 +152,9 @@ class VertexPulling final : public Castable { /// @returns this Config Config& operator=(const Config&); + /// The entry point to add assignments into + std::string entry_point_name; + /// The vertex state descriptor, containing info about attributes VertexStateDescriptor vertex_state; @@ -162,7 +163,7 @@ class VertexPulling final : public Castable { uint32_t pulling_group = 4u; /// Reflect the fields of this class so that it can be used by tint::ForeachField() - TINT_REFLECT(vertex_state, pulling_group); + TINT_REFLECT(entry_point_name, vertex_state, pulling_group); }; /// Constructor diff --git a/src/tint/transform/vertex_pulling_test.cc b/src/tint/transform/vertex_pulling_test.cc index 7c774c579d..5fb8b1cd40 100644 --- a/src/tint/transform/vertex_pulling_test.cc +++ b/src/tint/transform/vertex_pulling_test.cc @@ -35,21 +35,18 @@ TEST_F(VertexPullingTest, Error_NoEntryPoint) { EXPECT_EQ(expect, str(got)); } -TEST_F(VertexPullingTest, Error_MultipleEntryPoint) { +TEST_F(VertexPullingTest, Error_InvalidEntryPoint) { auto* src = R"( @vertex fn main() -> @builtin(position) vec4 { return vec4(); } -@vertex -fn main2() -> @builtin(position) vec4 { - return vec4(); -} )"; - auto* expect = "error: VertexPulling found more than one vertex entry point"; + auto* expect = "error: Vertex stage entry point not found"; VertexPulling::Config cfg; + cfg.entry_point_name = "_"; DataMap data; data.Add(cfg); @@ -67,6 +64,7 @@ fn main() {} auto* expect = "error: Vertex stage entry point not found"; VertexPulling::Config cfg; + cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -89,6 +87,7 @@ fn main(@location(0) var_a : f32) -> @builtin(position) vec4 { VertexPulling::Config cfg; cfg.vertex_state = {{{15, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}}; + cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -117,6 +116,7 @@ fn main() -> @builtin(position) vec4 { )"; VertexPulling::Config cfg; + cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -153,6 +153,7 @@ fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(posi VertexPulling::Config cfg; cfg.vertex_state = {{{4, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}}; + cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -189,6 +190,7 @@ fn main(@builtin(instance_index) tint_pulling_instance_index : u32) -> @builtin( VertexPulling::Config cfg; cfg.vertex_state = {{{4, VertexStepMode::kInstance, {{VertexFormat::kFloat32, 0, 0}}}}}; + cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -226,6 +228,7 @@ fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(posi VertexPulling::Config cfg; cfg.vertex_state = {{{4, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}}; cfg.pulling_group = 5; + cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -271,6 +274,7 @@ fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(posi VertexPulling::Config cfg; cfg.vertex_state = {{{4, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}}}}}; + cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -328,6 +332,7 @@ fn main(@builtin(vertex_index) custom_vertex_index : u32, @builtin(instance_inde {{VertexFormat::kFloat32, 0, 1}}, }, }}; + cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -406,6 +411,7 @@ fn main(tint_symbol_1 : tint_symbol) -> @builtin(position) vec4 { {{VertexFormat::kFloat32, 0, 1}}, }, }}; + cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -484,6 +490,7 @@ struct Inputs { {{VertexFormat::kFloat32, 0, 1}}, }, }}; + cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -559,6 +566,7 @@ fn main(indices : Indices) -> @builtin(position) vec4 { {{VertexFormat::kFloat32, 0, 1}}, }, }}; + cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -634,6 +642,7 @@ struct Indices { {{VertexFormat::kFloat32, 0, 1}}, }, }}; + cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -675,6 +684,7 @@ fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(posi cfg.vertex_state = {{{16, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}, {VertexFormat::kFloat32x4, 0, 1}}}}}; + cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -728,6 +738,7 @@ fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(posi {12, VertexStepMode::kVertex, {{VertexFormat::kFloat32x3, 0, 1}}}, {16, VertexStepMode::kVertex, {{VertexFormat::kFloat32x4, 0, 2}}}, }}; + cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -777,6 +788,7 @@ fn main(@builtin(vertex_index) tint_pulling_vertex_index_1 : u32) -> @builtin(po cfg.vertex_state = {{{16, VertexStepMode::kVertex, {{VertexFormat::kFloat32, 0, 0}, {VertexFormat::kFloat32x4, 0, 1}}}}}; + cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -921,6 +933,7 @@ fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(posi {VertexFormat::kSint32, 64, 26}, {VertexFormat::kSint32x2, 64, 27}, {VertexFormat::kSint32x3, 64, 28}, {VertexFormat::kSint32x4, 64, 29}, }}}}; + cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -1066,6 +1079,7 @@ fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(posi {VertexFormat::kSint32, 63, 26}, {VertexFormat::kSint32x2, 63, 27}, {VertexFormat::kSint32x3, 63, 28}, {VertexFormat::kSint32x4, 63, 29}, }}}}; + cfg.entry_point_name = "main"; DataMap data; data.Add(cfg); @@ -1210,6 +1224,7 @@ fn main(@builtin(vertex_index) tint_pulling_vertex_index : u32) -> @builtin(posi {VertexFormat::kSint32, 64, 26}, {VertexFormat::kSint32x2, 64, 27}, {VertexFormat::kSint32x3, 64, 28}, {VertexFormat::kSint32x4, 64, 29}, }}}}; + cfg.entry_point_name = "main"; DataMap data; data.Add(cfg);