diff --git a/src/dawn_native/Pipeline.cpp b/src/dawn_native/Pipeline.cpp index 928f7f2920..a006ddb8dc 100644 --- a/src/dawn_native/Pipeline.cpp +++ b/src/dawn_native/Pipeline.cpp @@ -29,15 +29,20 @@ namespace dawn_native { const ShaderModuleBase* module = descriptor->module; DAWN_TRY(device->ValidateObject(module)); - if (!module->HasEntryPoint(descriptor->entryPoint, stage)) { + if (!module->HasEntryPoint(descriptor->entryPoint)) { return DAWN_VALIDATION_ERROR("Entry point doesn't exist in the module"); } + const EntryPointMetadata& metadata = module->GetEntryPoint(descriptor->entryPoint); + + if (metadata.stage != stage) { + return DAWN_VALIDATION_ERROR("Entry point isn't for the correct stage"); + } + if (layout != nullptr) { - const EntryPointMetadata& metadata = - module->GetEntryPoint(descriptor->entryPoint, stage); DAWN_TRY(ValidateCompatibilityWithPipelineLayout(device, metadata, layout)); } + return {}; } @@ -54,7 +59,9 @@ namespace dawn_native { SingleShaderStage shaderStage = stage.first; ShaderModuleBase* module = stage.second->module; const char* entryPointName = stage.second->entryPoint; - const EntryPointMetadata& metadata = module->GetEntryPoint(entryPointName, shaderStage); + + const EntryPointMetadata& metadata = module->GetEntryPoint(entryPointName); + ASSERT(metadata.stage == shaderStage); // Record them internally. bool isFirstStage = mStageMask == wgpu::ShaderStage::None; diff --git a/src/dawn_native/PipelineLayout.cpp b/src/dawn_native/PipelineLayout.cpp index 883d9f9ead..dedef9094a 100644 --- a/src/dawn_native/PipelineLayout.cpp +++ b/src/dawn_native/PipelineLayout.cpp @@ -148,9 +148,8 @@ namespace dawn_native { // Loops over all the reflected BindGroupLayoutEntries from shaders. for (const StageAndDescriptor& stage : stages) { - SingleShaderStage shaderStage = stage.first; const EntryPointMetadata::BindingInfo& info = - stage.second->module->GetEntryPoint(stage.second->entryPoint, shaderStage).bindings; + stage.second->module->GetEntryPoint(stage.second->entryPoint).bindings; for (BindGroupIndex group(0); group < info.size(); ++group) { for (const auto& bindingIt : info[group]) { @@ -160,7 +159,7 @@ namespace dawn_native { // Create the BindGroupLayoutEntry BindGroupLayoutEntry entry = ConvertMetadataToEntry(shaderBinding); entry.binding = static_cast(bindingNumber); - entry.visibility = StageBit(shaderStage); + entry.visibility = StageBit(stage.first); // Add it to our map of all entries, if there is an existing entry, then we // need to merge, if we can. @@ -206,7 +205,7 @@ namespace dawn_native { // Sanity check in debug that the pipeline layout is compatible with the current pipeline. for (const StageAndDescriptor& stage : stages) { const EntryPointMetadata& metadata = - stage.second->module->GetEntryPoint(stage.second->entryPoint, stage.first); + stage.second->module->GetEntryPoint(stage.second->entryPoint); ASSERT(ValidateCompatibilityWithPipelineLayout(device, metadata, pipelineLayout) .IsSuccess()); } diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp index c9ad7bea67..25e0a05b1f 100644 --- a/src/dawn_native/RenderPipeline.cpp +++ b/src/dawn_native/RenderPipeline.cpp @@ -330,8 +330,8 @@ namespace dawn_native { DAWN_TRY(ValidateRasterizationStateDescriptor(descriptor->rasterizationState)); } - const EntryPointMetadata& vertexMetadata = descriptor->vertexStage.module->GetEntryPoint( - descriptor->vertexStage.entryPoint, SingleShaderStage::Vertex); + const EntryPointMetadata& vertexMetadata = + descriptor->vertexStage.module->GetEntryPoint(descriptor->vertexStage.entryPoint); if ((vertexMetadata.usedVertexAttributes & ~attributesSetMask).any()) { return DAWN_VALIDATION_ERROR( "Pipeline vertex stage uses vertex buffers not in the vertex state"); @@ -352,8 +352,7 @@ namespace dawn_native { ASSERT(descriptor->fragmentStage != nullptr); const EntryPointMetadata& fragmentMetadata = - descriptor->fragmentStage->module->GetEntryPoint(descriptor->fragmentStage->entryPoint, - SingleShaderStage::Fragment); + descriptor->fragmentStage->module->GetEntryPoint(descriptor->fragmentStage->entryPoint); for (ColorAttachmentIndex i(uint8_t(0)); i < ColorAttachmentIndex(static_cast(descriptor->colorStateCount)); ++i) { DAWN_TRY(ValidateColorStateDescriptor( diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp index 8acee3b52d..a35d4c870f 100644 --- a/src/dawn_native/ShaderModule.cpp +++ b/src/dawn_native/ShaderModule.cpp @@ -844,19 +844,13 @@ namespace dawn_native { return new ShaderModuleBase(device, ObjectBase::kError); } - bool ShaderModuleBase::HasEntryPoint(const std::string& entryPoint, - SingleShaderStage stage) const { - auto entryPointsForNameIt = mEntryPoints.find(entryPoint); - if (entryPointsForNameIt == mEntryPoints.end()) { - return false; - } - return entryPointsForNameIt->second[stage] != nullptr; + bool ShaderModuleBase::HasEntryPoint(const std::string& entryPoint) const { + return mEntryPoints.count(entryPoint) > 0; } - const EntryPointMetadata& ShaderModuleBase::GetEntryPoint(const std::string& entryPoint, - SingleShaderStage stage) const { - ASSERT(HasEntryPoint(entryPoint, stage)); - return *mEntryPoints.at(entryPoint)[stage]; + const EntryPointMetadata& ShaderModuleBase::GetEntryPoint(const std::string& entryPoint) const { + ASSERT(HasEntryPoint(entryPoint)); + return *mEntryPoints.at(entryPoint); } size_t ShaderModuleBase::HashFunc::operator()(const ShaderModuleBase* module) const { @@ -921,13 +915,15 @@ namespace dawn_native { spirv_cross::Compiler compiler(mSpirv); for (const spirv_cross::EntryPoint& entryPoint : compiler.get_entry_points_and_stages()) { + ASSERT(mEntryPoints.count(entryPoint.name) == 0); + SingleShaderStage stage = ExecutionModelToShaderStage(entryPoint.execution_model); compiler.set_entry_point(entryPoint.name, entryPoint.execution_model); std::unique_ptr metadata; DAWN_TRY_ASSIGN(metadata, ExtractSpirvInfo(GetDevice(), compiler, entryPoint.name, stage)); - mEntryPoints[entryPoint.name][stage] = std::move(metadata); + mEntryPoints[entryPoint.name] = std::move(metadata); } return {}; diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h index ef1f8adb07..7b5be55cf5 100644 --- a/src/dawn_native/ShaderModule.h +++ b/src/dawn_native/ShaderModule.h @@ -94,13 +94,12 @@ namespace dawn_native { static ShaderModuleBase* MakeError(DeviceBase* device); - // Return true iff the module has an entrypoint called `entryPoint` for stage `stage`. - bool HasEntryPoint(const std::string& entryPoint, SingleShaderStage stage) const; + // Return true iff the module has an entrypoint called `entryPoint`. + bool HasEntryPoint(const std::string& entryPoint) const; - // Returns the metadata for the given `entryPoint` and `stage`. HasEntryPoint with the same - // arguments must be true. - const EntryPointMetadata& GetEntryPoint(const std::string& entryPoint, - SingleShaderStage stage) const; + // Returns the metadata for the given `entryPoint`. HasEntryPoint with the same argument + // must be true. + const EntryPointMetadata& GetEntryPoint(const std::string& entryPoint) const; // Functors necessary for the unordered_set-based cache. struct HashFunc { @@ -132,7 +131,7 @@ namespace dawn_native { std::string mWgsl; // A map from [name, stage] to EntryPointMetadata. - std::unordered_map>> mEntryPoints; + std::unordered_map> mEntryPoints; }; } // namespace dawn_native diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp index 1c9fbb5cba..7dc7a29d83 100644 --- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp +++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp @@ -209,7 +209,7 @@ namespace dawn_native { namespace d3d12 { compiler.set_entry_point(entryPointName, ShaderStageToExecutionModel(stage)); const EntryPointMetadata::BindingInfo& moduleBindingInfo = - GetEntryPoint(entryPointName, stage).bindings; + GetEntryPoint(entryPointName).bindings; for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group)); diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm index ca19ae13d7..0dce6b10ab 100644 --- a/src/dawn_native/metal/ShaderModuleMTL.mm +++ b/src/dawn_native/metal/ShaderModuleMTL.mm @@ -182,7 +182,7 @@ namespace dawn_native { namespace metal { out->needsStorageBufferLength = compiler.needs_buffer_size_buffer(); if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) && - GetEntryPoint(entryPointName, stage).usedVertexAttributes.any()) { + GetEntryPoint(entryPointName).usedVertexAttributes.any()) { out->needsStorageBufferLength = true; } diff --git a/src/dawn_native/opengl/ShaderModuleGL.cpp b/src/dawn_native/opengl/ShaderModuleGL.cpp index 81c8c80054..98c542a411 100644 --- a/src/dawn_native/opengl/ShaderModuleGL.cpp +++ b/src/dawn_native/opengl/ShaderModuleGL.cpp @@ -108,8 +108,7 @@ namespace dawn_native { namespace opengl { compiler.set_name(combined.combined_id, info->GetName()); } - const EntryPointMetadata::BindingInfo& bindingInfo = - GetEntryPoint(entryPointName, stage).bindings; + const EntryPointMetadata::BindingInfo& bindingInfo = GetEntryPoint(entryPointName).bindings; // Change binding names to be "dawn_binding__". // Also unsets the SPIRV "Binding" decoration as it outputs "layout(binding=)" which diff --git a/src/tests/end2end/EntryPointTests.cpp b/src/tests/end2end/EntryPointTests.cpp index 6426d85c5d..0f27026dbe 100644 --- a/src/tests/end2end/EntryPointTests.cpp +++ b/src/tests/end2end/EntryPointTests.cpp @@ -67,54 +67,6 @@ TEST_P(EntryPointTests, FragAndVertexSameModule) { EXPECT_PIXEL_RGBA8_EQ(RGBA8::kRed, renderPass.color, 0, 0); } -// Test creating a render pipeline from two entryPoints in the same module with the same name. -TEST_P(EntryPointTests, FragAndVertexSameModuleSameName) { - // TODO: Reenable once Tint is able to produce Vulkan 1.0 / 1.1 SPIR-V. - DAWN_SKIP_TEST_IF(IsVulkan()); - - wgpu::ShaderModule module = utils::CreateShaderModuleFromWGSL(device, R"( - [[builtin(position)]] var Position : vec4; - - [[stage(vertex)]] - fn main() -> void { - Position = vec4(0.0, 0.0, 0.0, 1.0); - return; - } - - [[location(0)]] var outColor : vec4; - - [[stage(fragment)]] - fn main() -> void { - outColor = vec4(1.0, 0.0, 0.0, 1.0); - return; - } - )"); - - // Create a point pipeline from the module. - utils::ComboRenderPipelineDescriptor desc(device); - desc.vertexStage.module = module; - desc.vertexStage.entryPoint = "main"; - desc.cFragmentStage.module = module; - desc.cFragmentStage.entryPoint = "main"; - desc.cColorStates[0].format = wgpu::TextureFormat::RGBA8Unorm; - desc.primitiveTopology = wgpu::PrimitiveTopology::PointList; - wgpu::RenderPipeline pipeline = device.CreateRenderPipeline(&desc); - - // Render the point and check that it was rendered. - utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, 1, 1); - wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); - { - wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo); - pass.SetPipeline(pipeline); - pass.Draw(1); - pass.EndPass(); - } - wgpu::CommandBuffer commands = encoder.Finish(); - queue.Submit(1, &commands); - - EXPECT_PIXEL_RGBA8_EQ(RGBA8::kRed, renderPass.color, 0, 0); -} - // Test creating two compute pipelines from the same module. TEST_P(EntryPointTests, TwoComputeInModule) { // TODO: Reenable once Tint is able to produce Vulkan 1.0 / 1.1 SPIR-V.