From d42713de7aa3227f1feb405beddac4e37051b679 Mon Sep 17 00:00:00 2001 From: Corentin Wallez Date: Thu, 5 Nov 2020 13:25:16 +0000 Subject: [PATCH] Remove support for multiple entrypoints with the same name Previsouly having a ShaderModule with multiple entrypoints with the same name and different stages was valid in Dawn. However it is disallowed by the WGSL specification so change Dawn to index the ShaderModule's entrypoints only by their name (instead of name and stage). Bug: dawn:216 Change-Id: Id6fc80a03436b008c2f057bd30c70fdf240919e8 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/31665 Reviewed-by: dan sinclair Reviewed-by: Austin Eng Commit-Queue: Corentin Wallez --- src/dawn_native/Pipeline.cpp | 15 +++++-- src/dawn_native/PipelineLayout.cpp | 7 ++- src/dawn_native/RenderPipeline.cpp | 7 ++- src/dawn_native/ShaderModule.cpp | 20 ++++----- src/dawn_native/ShaderModule.h | 13 +++--- src/dawn_native/d3d12/ShaderModuleD3D12.cpp | 2 +- src/dawn_native/metal/ShaderModuleMTL.mm | 2 +- src/dawn_native/opengl/ShaderModuleGL.cpp | 3 +- src/tests/end2end/EntryPointTests.cpp | 48 --------------------- 9 files changed, 34 insertions(+), 83 deletions(-) diff --git a/src/dawn_native/Pipeline.cpp b/src/dawn_native/Pipeline.cpp index 928f7f2920..a006ddb8dc 100644 --- a/src/dawn_native/Pipeline.cpp +++ b/src/dawn_native/Pipeline.cpp @@ -29,15 +29,20 @@ namespace dawn_native { const ShaderModuleBase* module = descriptor->module; DAWN_TRY(device->ValidateObject(module)); - if (!module->HasEntryPoint(descriptor->entryPoint, stage)) { + if (!module->HasEntryPoint(descriptor->entryPoint)) { return DAWN_VALIDATION_ERROR("Entry point doesn't exist in the module"); } + const EntryPointMetadata& metadata = module->GetEntryPoint(descriptor->entryPoint); + + if (metadata.stage != stage) { + return DAWN_VALIDATION_ERROR("Entry point isn't for the correct stage"); + } + if (layout != nullptr) { - const EntryPointMetadata& metadata = - module->GetEntryPoint(descriptor->entryPoint, stage); DAWN_TRY(ValidateCompatibilityWithPipelineLayout(device, metadata, layout)); } + return {}; } @@ -54,7 +59,9 @@ namespace dawn_native { SingleShaderStage shaderStage = stage.first; ShaderModuleBase* module = stage.second->module; const char* entryPointName = stage.second->entryPoint; - const EntryPointMetadata& metadata = module->GetEntryPoint(entryPointName, shaderStage); + + const EntryPointMetadata& metadata = module->GetEntryPoint(entryPointName); + ASSERT(metadata.stage == shaderStage); // Record them internally. bool isFirstStage = mStageMask == wgpu::ShaderStage::None; diff --git a/src/dawn_native/PipelineLayout.cpp b/src/dawn_native/PipelineLayout.cpp index 883d9f9ead..dedef9094a 100644 --- a/src/dawn_native/PipelineLayout.cpp +++ b/src/dawn_native/PipelineLayout.cpp @@ -148,9 +148,8 @@ namespace dawn_native { // Loops over all the reflected BindGroupLayoutEntries from shaders. for (const StageAndDescriptor& stage : stages) { - SingleShaderStage shaderStage = stage.first; const EntryPointMetadata::BindingInfo& info = - stage.second->module->GetEntryPoint(stage.second->entryPoint, shaderStage).bindings; + stage.second->module->GetEntryPoint(stage.second->entryPoint).bindings; for (BindGroupIndex group(0); group < info.size(); ++group) { for (const auto& bindingIt : info[group]) { @@ -160,7 +159,7 @@ namespace dawn_native { // Create the BindGroupLayoutEntry BindGroupLayoutEntry entry = ConvertMetadataToEntry(shaderBinding); entry.binding = static_cast(bindingNumber); - entry.visibility = StageBit(shaderStage); + entry.visibility = StageBit(stage.first); // Add it to our map of all entries, if there is an existing entry, then we // need to merge, if we can. @@ -206,7 +205,7 @@ namespace dawn_native { // Sanity check in debug that the pipeline layout is compatible with the current pipeline. for (const StageAndDescriptor& stage : stages) { const EntryPointMetadata& metadata = - stage.second->module->GetEntryPoint(stage.second->entryPoint, stage.first); + stage.second->module->GetEntryPoint(stage.second->entryPoint); ASSERT(ValidateCompatibilityWithPipelineLayout(device, metadata, pipelineLayout) .IsSuccess()); } diff --git a/src/dawn_native/RenderPipeline.cpp b/src/dawn_native/RenderPipeline.cpp index c9ad7bea67..25e0a05b1f 100644 --- a/src/dawn_native/RenderPipeline.cpp +++ b/src/dawn_native/RenderPipeline.cpp @@ -330,8 +330,8 @@ namespace dawn_native { DAWN_TRY(ValidateRasterizationStateDescriptor(descriptor->rasterizationState)); } - const EntryPointMetadata& vertexMetadata = descriptor->vertexStage.module->GetEntryPoint( - descriptor->vertexStage.entryPoint, SingleShaderStage::Vertex); + const EntryPointMetadata& vertexMetadata = + descriptor->vertexStage.module->GetEntryPoint(descriptor->vertexStage.entryPoint); if ((vertexMetadata.usedVertexAttributes & ~attributesSetMask).any()) { return DAWN_VALIDATION_ERROR( "Pipeline vertex stage uses vertex buffers not in the vertex state"); @@ -352,8 +352,7 @@ namespace dawn_native { ASSERT(descriptor->fragmentStage != nullptr); const EntryPointMetadata& fragmentMetadata = - descriptor->fragmentStage->module->GetEntryPoint(descriptor->fragmentStage->entryPoint, - SingleShaderStage::Fragment); + descriptor->fragmentStage->module->GetEntryPoint(descriptor->fragmentStage->entryPoint); for (ColorAttachmentIndex i(uint8_t(0)); i < ColorAttachmentIndex(static_cast(descriptor->colorStateCount)); ++i) { DAWN_TRY(ValidateColorStateDescriptor( diff --git a/src/dawn_native/ShaderModule.cpp b/src/dawn_native/ShaderModule.cpp index 8acee3b52d..a35d4c870f 100644 --- a/src/dawn_native/ShaderModule.cpp +++ b/src/dawn_native/ShaderModule.cpp @@ -844,19 +844,13 @@ namespace dawn_native { return new ShaderModuleBase(device, ObjectBase::kError); } - bool ShaderModuleBase::HasEntryPoint(const std::string& entryPoint, - SingleShaderStage stage) const { - auto entryPointsForNameIt = mEntryPoints.find(entryPoint); - if (entryPointsForNameIt == mEntryPoints.end()) { - return false; - } - return entryPointsForNameIt->second[stage] != nullptr; + bool ShaderModuleBase::HasEntryPoint(const std::string& entryPoint) const { + return mEntryPoints.count(entryPoint) > 0; } - const EntryPointMetadata& ShaderModuleBase::GetEntryPoint(const std::string& entryPoint, - SingleShaderStage stage) const { - ASSERT(HasEntryPoint(entryPoint, stage)); - return *mEntryPoints.at(entryPoint)[stage]; + const EntryPointMetadata& ShaderModuleBase::GetEntryPoint(const std::string& entryPoint) const { + ASSERT(HasEntryPoint(entryPoint)); + return *mEntryPoints.at(entryPoint); } size_t ShaderModuleBase::HashFunc::operator()(const ShaderModuleBase* module) const { @@ -921,13 +915,15 @@ namespace dawn_native { spirv_cross::Compiler compiler(mSpirv); for (const spirv_cross::EntryPoint& entryPoint : compiler.get_entry_points_and_stages()) { + ASSERT(mEntryPoints.count(entryPoint.name) == 0); + SingleShaderStage stage = ExecutionModelToShaderStage(entryPoint.execution_model); compiler.set_entry_point(entryPoint.name, entryPoint.execution_model); std::unique_ptr metadata; DAWN_TRY_ASSIGN(metadata, ExtractSpirvInfo(GetDevice(), compiler, entryPoint.name, stage)); - mEntryPoints[entryPoint.name][stage] = std::move(metadata); + mEntryPoints[entryPoint.name] = std::move(metadata); } return {}; diff --git a/src/dawn_native/ShaderModule.h b/src/dawn_native/ShaderModule.h index ef1f8adb07..7b5be55cf5 100644 --- a/src/dawn_native/ShaderModule.h +++ b/src/dawn_native/ShaderModule.h @@ -94,13 +94,12 @@ namespace dawn_native { static ShaderModuleBase* MakeError(DeviceBase* device); - // Return true iff the module has an entrypoint called `entryPoint` for stage `stage`. - bool HasEntryPoint(const std::string& entryPoint, SingleShaderStage stage) const; + // Return true iff the module has an entrypoint called `entryPoint`. + bool HasEntryPoint(const std::string& entryPoint) const; - // Returns the metadata for the given `entryPoint` and `stage`. HasEntryPoint with the same - // arguments must be true. - const EntryPointMetadata& GetEntryPoint(const std::string& entryPoint, - SingleShaderStage stage) const; + // Returns the metadata for the given `entryPoint`. HasEntryPoint with the same argument + // must be true. + const EntryPointMetadata& GetEntryPoint(const std::string& entryPoint) const; // Functors necessary for the unordered_set-based cache. struct HashFunc { @@ -132,7 +131,7 @@ namespace dawn_native { std::string mWgsl; // A map from [name, stage] to EntryPointMetadata. - std::unordered_map>> mEntryPoints; + std::unordered_map> mEntryPoints; }; } // namespace dawn_native diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp index 1c9fbb5cba..7dc7a29d83 100644 --- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp +++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp @@ -209,7 +209,7 @@ namespace dawn_native { namespace d3d12 { compiler.set_entry_point(entryPointName, ShaderStageToExecutionModel(stage)); const EntryPointMetadata::BindingInfo& moduleBindingInfo = - GetEntryPoint(entryPointName, stage).bindings; + GetEntryPoint(entryPointName).bindings; for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group)); diff --git a/src/dawn_native/metal/ShaderModuleMTL.mm b/src/dawn_native/metal/ShaderModuleMTL.mm index ca19ae13d7..0dce6b10ab 100644 --- a/src/dawn_native/metal/ShaderModuleMTL.mm +++ b/src/dawn_native/metal/ShaderModuleMTL.mm @@ -182,7 +182,7 @@ namespace dawn_native { namespace metal { out->needsStorageBufferLength = compiler.needs_buffer_size_buffer(); if (GetDevice()->IsToggleEnabled(Toggle::MetalEnableVertexPulling) && - GetEntryPoint(entryPointName, stage).usedVertexAttributes.any()) { + GetEntryPoint(entryPointName).usedVertexAttributes.any()) { out->needsStorageBufferLength = true; } diff --git a/src/dawn_native/opengl/ShaderModuleGL.cpp b/src/dawn_native/opengl/ShaderModuleGL.cpp index 81c8c80054..98c542a411 100644 --- a/src/dawn_native/opengl/ShaderModuleGL.cpp +++ b/src/dawn_native/opengl/ShaderModuleGL.cpp @@ -108,8 +108,7 @@ namespace dawn_native { namespace opengl { compiler.set_name(combined.combined_id, info->GetName()); } - const EntryPointMetadata::BindingInfo& bindingInfo = - GetEntryPoint(entryPointName, stage).bindings; + const EntryPointMetadata::BindingInfo& bindingInfo = GetEntryPoint(entryPointName).bindings; // Change binding names to be "dawn_binding__". // Also unsets the SPIRV "Binding" decoration as it outputs "layout(binding=)" which diff --git a/src/tests/end2end/EntryPointTests.cpp b/src/tests/end2end/EntryPointTests.cpp index 6426d85c5d..0f27026dbe 100644 --- a/src/tests/end2end/EntryPointTests.cpp +++ b/src/tests/end2end/EntryPointTests.cpp @@ -67,54 +67,6 @@ TEST_P(EntryPointTests, FragAndVertexSameModule) { EXPECT_PIXEL_RGBA8_EQ(RGBA8::kRed, renderPass.color, 0, 0); } -// Test creating a render pipeline from two entryPoints in the same module with the same name. -TEST_P(EntryPointTests, FragAndVertexSameModuleSameName) { - // TODO: Reenable once Tint is able to produce Vulkan 1.0 / 1.1 SPIR-V. - DAWN_SKIP_TEST_IF(IsVulkan()); - - wgpu::ShaderModule module = utils::CreateShaderModuleFromWGSL(device, R"( - [[builtin(position)]] var Position : vec4; - - [[stage(vertex)]] - fn main() -> void { - Position = vec4(0.0, 0.0, 0.0, 1.0); - return; - } - - [[location(0)]] var outColor : vec4; - - [[stage(fragment)]] - fn main() -> void { - outColor = vec4(1.0, 0.0, 0.0, 1.0); - return; - } - )"); - - // Create a point pipeline from the module. - utils::ComboRenderPipelineDescriptor desc(device); - desc.vertexStage.module = module; - desc.vertexStage.entryPoint = "main"; - desc.cFragmentStage.module = module; - desc.cFragmentStage.entryPoint = "main"; - desc.cColorStates[0].format = wgpu::TextureFormat::RGBA8Unorm; - desc.primitiveTopology = wgpu::PrimitiveTopology::PointList; - wgpu::RenderPipeline pipeline = device.CreateRenderPipeline(&desc); - - // Render the point and check that it was rendered. - utils::BasicRenderPass renderPass = utils::CreateBasicRenderPass(device, 1, 1); - wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); - { - wgpu::RenderPassEncoder pass = encoder.BeginRenderPass(&renderPass.renderPassInfo); - pass.SetPipeline(pipeline); - pass.Draw(1); - pass.EndPass(); - } - wgpu::CommandBuffer commands = encoder.Finish(); - queue.Submit(1, &commands); - - EXPECT_PIXEL_RGBA8_EQ(RGBA8::kRed, renderPass.color, 0, 0); -} - // Test creating two compute pipelines from the same module. TEST_P(EntryPointTests, TwoComputeInModule) { // TODO: Reenable once Tint is able to produce Vulkan 1.0 / 1.1 SPIR-V.