From 363a995068ade0f0110eef0100973b6c55744a15 Mon Sep 17 00:00:00 2001 From: Corentin Wallez Date: Mon, 31 Oct 2022 11:17:50 +0000 Subject: [PATCH] ShaderModuleVk: Add a renamer to make other transforms safe. Without this other Tint transforms may end up calling user code instead of builtins (for example for the min() used in robustness). This commit does the following changes: - Changes ShaderModuleVk to return a CompiledSpirv object instead of just a Spirv Blob so that a remappedEntryPoint can be stored in the cache alongside the SPIR-V. - Inlines the logic and simplifies TransformedConcurrentShaderModuleCache slightly (by introducing a struct instead of std::pair, and adding a conversion method to ModuleAndSpirv). - Adds the Renamer transform to ShaderModuleVk and adapt the code to use the remappedEntryPoint where needed (pipeline creation and post-compilation reflection). - Adds a test where the min() used by the robustness transform is overriden to return a constant 0. - Moves the Renamer transform to be just after the SingleEntryPoint transform in D3D12 and Metal as well so as to make the test pass. Fixed: dawn:1583 Bug: dawn:1585 dawn:1587 Change-Id: Ia9de38d391a7901ed04b097f4a8d439759f7556e Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/107020 Reviewed-by: Ben Clayton Kokoro: Kokoro Commit-Queue: Corentin Wallez --- src/dawn/native/d3d12/ShaderModuleD3D12.cpp | 23 +- src/dawn/native/metal/ShaderModuleMTL.mm | 17 +- src/dawn/native/vulkan/ComputePipelineVk.cpp | 2 +- src/dawn/native/vulkan/RenderPipelineVk.cpp | 2 +- src/dawn/native/vulkan/ShaderModuleVk.cpp | 210 ++++++++++--------- src/dawn/native/vulkan/ShaderModuleVk.h | 2 +- src/dawn/tests/end2end/ShaderTests.cpp | 53 +++++ webgpu-cts/expectations.txt | 6 + 8 files changed, 193 insertions(+), 122 deletions(-) diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp index a5392a1f7e..19a665e76d 100644 --- a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp +++ b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp @@ -297,6 +297,18 @@ ResultOrError TranslateToHLSL( tint::transform::Manager transformManager; tint::transform::DataMap transformInputs; + // Run before the renamer so that the entry point name matches `entryPointName` still. + transformManager.Add(); + transformInputs.Add(r.entryPointName.data()); + + // Needs to run before all other transforms so that they can use builtin names safely. + transformManager.Add(); + if (r.disableSymbolRenaming) { + // We still need to rename HLSL reserved keywords + transformInputs.Add( + tint::transform::Renamer::Target::kHlslKeywords); + } + if (!r.newBindingsMap.empty()) { transformManager.Add(); transformInputs.Add( @@ -315,17 +327,6 @@ ResultOrError TranslateToHLSL( transformManager.Add(); - transformManager.Add(); - transformInputs.Add(r.entryPointName.data()); - - transformManager.Add(); - - if (r.disableSymbolRenaming) { - // We still need to rename HLSL reserved keywords - transformInputs.Add( - tint::transform::Renamer::Target::kHlslKeywords); - } - if (r.substituteOverrideConfig) { // This needs to run after SingleEntryPoint transform which removes unused overrides for // current entry point. diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm index 8ab06286d8..80627da845 100644 --- a/src/dawn/native/metal/ShaderModuleMTL.mm +++ b/src/dawn/native/metal/ShaderModuleMTL.mm @@ -190,9 +190,18 @@ ResultOrError> TranslateToMSL( // We only remap bindings for the target entry point, so we need to strip all other // entry points to avoid generating invalid bindings for them. + // Run before the renamer so that the entry point name matches `entryPointName` still. transformManager.Add(); transformInputs.Add(r.entryPointName); + // Needs to run before all other transforms so that they can use builtin names safely. + transformManager.Add(); + if (r.disableSymbolRenaming) { + // We still need to rename MSL reserved keywords + transformInputs.Add( + tint::transform::Renamer::Target::kMslKeywords); + } + if (!r.externalTextureBindings.empty()) { transformManager.Add(); transformInputs.Add( @@ -221,14 +230,6 @@ ResultOrError> TranslateToMSL( BindingRemapper::AccessControls{}, /* mayCollide */ true); - transformManager.Add(); - - if (r.disableSymbolRenaming) { - // We still need to rename MSL reserved keywords - transformInputs.Add( - tint::transform::Renamer::Target::kMslKeywords); - } - tint::Program program; tint::transform::DataMap transformOutputs; { diff --git a/src/dawn/native/vulkan/ComputePipelineVk.cpp b/src/dawn/native/vulkan/ComputePipelineVk.cpp index ea8c59a6bf..49522e2aa1 100644 --- a/src/dawn/native/vulkan/ComputePipelineVk.cpp +++ b/src/dawn/native/vulkan/ComputePipelineVk.cpp @@ -64,7 +64,7 @@ MaybeError ComputePipeline::Initialize() { module->GetHandleAndSpirv(SingleShaderStage::Compute, computeStage, layout)); createInfo.stage.module = moduleAndSpirv.module; - createInfo.stage.pName = computeStage.entryPoint.c_str(); + createInfo.stage.pName = moduleAndSpirv.remappedEntryPoint; createInfo.stage.pSpecializationInfo = nullptr; diff --git a/src/dawn/native/vulkan/RenderPipelineVk.cpp b/src/dawn/native/vulkan/RenderPipelineVk.cpp index df61ac77b3..0ca94216d0 100644 --- a/src/dawn/native/vulkan/RenderPipelineVk.cpp +++ b/src/dawn/native/vulkan/RenderPipelineVk.cpp @@ -358,7 +358,7 @@ MaybeError RenderPipeline::Initialize() { shaderStage.pNext = nullptr; shaderStage.flags = 0; shaderStage.pSpecializationInfo = nullptr; - shaderStage.pName = programmableStage.entryPoint.c_str(); + shaderStage.pName = moduleAndSpirv.remappedEntryPoint; switch (stage) { case dawn::native::SingleShaderStage::Vertex: { diff --git a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp index b5796d6170..4703954946 100644 --- a/src/dawn/native/vulkan/ShaderModuleVk.cpp +++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp @@ -21,6 +21,7 @@ #include #include "dawn/native/CacheRequest.h" +#include "dawn/native/Serializable.h" #include "dawn/native/SpirvValidation.h" #include "dawn/native/TintUtils.h" #include "dawn/native/vulkan/BindGroupLayoutVk.h" @@ -35,30 +36,13 @@ namespace dawn::native::vulkan { -// Spirv is a wrapper around Blob that exposes the data as uint32_t words. -class ShaderModule::Spirv : private Blob { - public: - static Spirv FromBlob(Blob&& blob) { - // Vulkan drivers expect the SPIRV to be aligned like an array of uint32_t values. - blob.AlignTo(alignof(uint32_t)); - return static_cast(blob); - } +#define COMPILED_SPIRV_MEMBERS(X) \ + X(std::vector, spirv) \ + X(std::string, remappedEntryPoint) - const Blob& ToBlob() const { return *this; } - - static Spirv Create(std::vector code) { - Blob blob = CreateBlob(std::move(code)); - ASSERT(IsPtrAligned(blob.Data(), alignof(uint32_t))); - return static_cast(std::move(blob)); - } - - const uint32_t* Code() const { return reinterpret_cast(Data()); } - size_t WordCount() const { return Size() / sizeof(uint32_t); } -}; - -} // namespace dawn::native::vulkan - -namespace dawn::native::vulkan { +// Represents the result and metadata for a SPIR-V compilation. +DAWN_SERIALIZABLE(struct, CompiledSpirv, COMPILED_SPIRV_MEMBERS){}; +#undef COMPILED_SPIRV_MEMBERS bool TransformedShaderModuleCacheKey::operator==( const TransformedShaderModuleCacheKey& other) const { @@ -84,16 +68,62 @@ size_t TransformedShaderModuleCacheKeyHashFunc::operator()( class ShaderModule::ConcurrentTransformedShaderModuleCache { public: - explicit ConcurrentTransformedShaderModuleCache(Device* device); - ~ConcurrentTransformedShaderModuleCache(); + explicit ConcurrentTransformedShaderModuleCache(Device* device) : mDevice(device) {} - std::optional Find(const TransformedShaderModuleCacheKey& key); + ~ConcurrentTransformedShaderModuleCache() { + std::lock_guard lock(mMutex); + + for (const auto& [_, moduleAndSpirv] : mTransformedShaderModuleCache) { + mDevice->GetFencedDeleter()->DeleteWhenUnused(moduleAndSpirv.vkModule); + } + } + + std::optional Find(const TransformedShaderModuleCacheKey& key) { + std::lock_guard lock(mMutex); + + auto iter = mTransformedShaderModuleCache.find(key); + if (iter != mTransformedShaderModuleCache.end()) { + return iter->second.AsRefs(); + } + return {}; + } ModuleAndSpirv AddOrGet(const TransformedShaderModuleCacheKey& key, VkShaderModule module, - Spirv&& spirv); + CompiledSpirv compilation) { + ASSERT(module != VK_NULL_HANDLE); + std::lock_guard lock(mMutex); + + auto iter = mTransformedShaderModuleCache.find(key); + if (iter == mTransformedShaderModuleCache.end()) { + bool added = false; + std::tie(iter, added) = mTransformedShaderModuleCache.emplace( + key, Entry{module, std::move(compilation.spirv), + std::move(compilation.remappedEntryPoint)}); + ASSERT(added); + } else { + // No need to use FencedDeleter since this shader module was just created and does + // not need to wait for queue operations to complete. + // Also, use of fenced deleter here is not thread safe. + mDevice->fn.DestroyShaderModule(mDevice->GetVkDevice(), module, nullptr); + } + return iter->second.AsRefs(); + } private: - using Entry = std::pair; + struct Entry { + VkShaderModule vkModule; + std::vector spirv; + std::string remappedEntryPoint; + + ModuleAndSpirv AsRefs() const { + return { + vkModule, + spirv.data(), + spirv.size(), + remappedEntryPoint.c_str(), + }; + } + }; Device* mDevice; std::mutex mMutex; @@ -103,57 +133,6 @@ class ShaderModule::ConcurrentTransformedShaderModuleCache { mTransformedShaderModuleCache; }; -ShaderModule::ConcurrentTransformedShaderModuleCache::ConcurrentTransformedShaderModuleCache( - Device* device) - : mDevice(device) {} - -ShaderModule::ConcurrentTransformedShaderModuleCache::~ConcurrentTransformedShaderModuleCache() { - std::lock_guard lock(mMutex); - for (const auto& [_, moduleAndSpirv] : mTransformedShaderModuleCache) { - mDevice->GetFencedDeleter()->DeleteWhenUnused(moduleAndSpirv.first); - } -} - -std::optional -ShaderModule::ConcurrentTransformedShaderModuleCache::Find( - const TransformedShaderModuleCacheKey& key) { - std::lock_guard lock(mMutex); - auto iter = mTransformedShaderModuleCache.find(key); - if (iter != mTransformedShaderModuleCache.end()) { - return ModuleAndSpirv{ - iter->second.first, - iter->second.second.Code(), - iter->second.second.WordCount(), - }; - } - return {}; -} - -ShaderModule::ModuleAndSpirv ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGet( - const TransformedShaderModuleCacheKey& key, - VkShaderModule module, - Spirv&& spirv) { - ASSERT(module != VK_NULL_HANDLE); - std::lock_guard lock(mMutex); - auto iter = mTransformedShaderModuleCache.find(key); - if (iter == mTransformedShaderModuleCache.end()) { - bool added = false; - std::tie(iter, added) = - mTransformedShaderModuleCache.emplace(key, std::make_pair(module, std::move(spirv))); - ASSERT(added); - } else { - // No need to use FencedDeleter since this shader module was just created and does - // not need to wait for queue operations to complete. - // Also, use of fenced deleter here is not thread safe. - mDevice->fn.DestroyShaderModule(mDevice->GetVkDevice(), module, nullptr); - } - return ModuleAndSpirv{ - iter->second.first, - iter->second.second.Code(), - iter->second.second.WordCount(), - }; -} - // static ResultOrError> ShaderModule::Create( Device* device, @@ -194,6 +173,7 @@ ShaderModule::~ShaderModule() = default; X(std::string_view, entryPointName) \ X(bool, isRobustnessEnabled) \ X(bool, disableWorkgroupInit) \ + X(bool, disableSymbolRenaming) \ X(bool, useZeroInitializeWorkgroupMemoryExtension) \ X(CacheKey::UnsafeUnkeyedValue, tracePlatform) @@ -275,6 +255,7 @@ ResultOrError ShaderModule::GetHandleAndSpirv( req.entryPointName = programmableStage.entryPoint; req.isRobustnessEnabled = GetDevice()->IsRobustnessEnabled(); req.disableWorkgroupInit = GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit); + req.disableSymbolRenaming = GetDevice()->IsToggleEnabled(Toggle::DisableSymbolRenaming); req.useZeroInitializeWorkgroupMemoryExtension = GetDevice()->IsToggleEnabled(Toggle::VulkanUseZeroInitializeWorkgroupMemoryExtension); req.tracePlatform = UnsafeUnkeyedValue(GetDevice()->GetPlatform()); @@ -283,22 +264,31 @@ ResultOrError ShaderModule::GetHandleAndSpirv( const CombinedLimits& limits = GetDevice()->GetLimits(); req.limits = LimitsForCompilationRequest::Create(limits.v1); - CacheResult spirv; + CacheResult compilation; DAWN_TRY_LOAD_OR_RUN( - spirv, GetDevice(), std::move(req), Spirv::FromBlob, - [](SpirvCompilationRequest r) -> ResultOrError { + compilation, GetDevice(), std::move(req), CompiledSpirv::FromBlob, + [](SpirvCompilationRequest r) -> ResultOrError { tint::transform::Manager transformManager; + tint::transform::DataMap transformInputs; + + // Many Vulkan drivers can't handle multi-entrypoint shader modules. + // Run before the renamer so that the entry point name matches `entryPointName` still. + transformManager.append(std::make_unique()); + transformInputs.Add( + std::string(r.entryPointName)); + + // Needs to run before all other transforms so that they can use builtin names safely. + if (!r.disableSymbolRenaming) { + transformManager.Add(); + } + if (r.isRobustnessEnabled) { transformManager.append(std::make_unique()); } - // Many Vulkan drivers can't handle multi-entrypoint shader modules. - transformManager.append(std::make_unique()); + // Run the binding remapper after SingleEntryPoint to avoid collisions with // unused entryPoints. transformManager.append(std::make_unique()); - tint::transform::DataMap transformInputs; - transformInputs.Add( - std::string(r.entryPointName)); transformInputs.Add(std::move(r.bindingPoints), BindingRemapper::AccessControls{}, /* mayCollide */ false); @@ -314,18 +304,35 @@ ResultOrError ShaderModule::GetHandleAndSpirv( transformInputs.Add( std::move(r.substituteOverrideConfig).value()); } + tint::Program program; + tint::transform::DataMap transformOutputs; { TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "RunTransforms"); - DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, r.inputProgram, - transformInputs, nullptr, nullptr)); + DAWN_TRY_ASSIGN(program, + RunTransforms(&transformManager, r.inputProgram, transformInputs, + &transformOutputs, nullptr)); } + // Get the entry point name after the renamer pass. + std::string remappedEntryPoint; + if (r.disableSymbolRenaming) { + remappedEntryPoint = r.entryPointName; + } else { + auto* data = transformOutputs.Get(); + ASSERT(data != nullptr); + + auto it = data->remappings.find(r.entryPointName.data()); + ASSERT(it != data->remappings.end()); + remappedEntryPoint = it->second; + } + ASSERT(remappedEntryPoint != ""); + + // Validate workgroup size after program runs transforms. if (r.stage == SingleShaderStage::Compute) { - // Validate workgroup size after program runs transforms. Extent3D _; DAWN_TRY_ASSIGN(_, ValidateComputeStageWorkgroupSize( - program, r.entryPointName.data(), r.limits)); + program, remappedEntryPoint.c_str(), r.limits)); } tint::writer::spirv::Options options; @@ -336,22 +343,25 @@ ResultOrError ShaderModule::GetHandleAndSpirv( TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "tint::writer::spirv::Generate()"); - auto result = tint::writer::spirv::Generate(&program, options); - DAWN_INVALID_IF(!result.success, "An error occured while generating SPIR-V: %s.", - result.error); + auto tintResult = tint::writer::spirv::Generate(&program, options); + DAWN_INVALID_IF(!tintResult.success, "An error occured while generating SPIR-V: %s.", + tintResult.error); - return Spirv::Create(std::move(result.spirv)); + CompiledSpirv result; + result.spirv = std::move(tintResult.spirv); + result.remappedEntryPoint = remappedEntryPoint; + return result; }); - DAWN_TRY(ValidateSpirv(GetDevice(), spirv->Code(), spirv->WordCount(), + DAWN_TRY(ValidateSpirv(GetDevice(), compilation->spirv.data(), compilation->spirv.size(), GetDevice()->IsToggleEnabled(Toggle::DumpShaders))); VkShaderModuleCreateInfo createInfo; createInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; createInfo.pNext = nullptr; createInfo.flags = 0; - createInfo.codeSize = spirv->WordCount() * sizeof(uint32_t); - createInfo.pCode = spirv->Code(); + createInfo.codeSize = compilation->spirv.size() * sizeof(uint32_t); + createInfo.pCode = compilation->spirv.data(); Device* device = ToBackend(GetDevice()); @@ -366,13 +376,13 @@ ResultOrError ShaderModule::GetHandleAndSpirv( ModuleAndSpirv moduleAndSpirv; if (newHandle != VK_NULL_HANDLE) { if (BlobCache* cache = device->GetBlobCache()) { - cache->EnsureStored(spirv); + cache->EnsureStored(compilation); } // Set the label on `newHandle` now, and not on `moduleAndSpirv.module` later // since `moduleAndSpirv.module` may be in use by multiple threads. SetDebugName(ToBackend(GetDevice()), newHandle, "Dawn_ShaderModule", GetLabel()); moduleAndSpirv = - mTransformedShaderModuleCache->AddOrGet(cacheKey, newHandle, spirv.Acquire()); + mTransformedShaderModuleCache->AddOrGet(cacheKey, newHandle, compilation.Acquire()); } return std::move(moduleAndSpirv); diff --git a/src/dawn/native/vulkan/ShaderModuleVk.h b/src/dawn/native/vulkan/ShaderModuleVk.h index 7c6a5dff46..b84090e993 100644 --- a/src/dawn/native/vulkan/ShaderModuleVk.h +++ b/src/dawn/native/vulkan/ShaderModuleVk.h @@ -50,11 +50,11 @@ class PipelineLayout; class ShaderModule final : public ShaderModuleBase { public: - class Spirv; struct ModuleAndSpirv { VkShaderModule module; const uint32_t* spirv; size_t wordCount; + const char* remappedEntryPoint; }; static ResultOrError> Create(Device* device, diff --git a/src/dawn/tests/end2end/ShaderTests.cpp b/src/dawn/tests/end2end/ShaderTests.cpp index b8731759a3..8909e6e359 100644 --- a/src/dawn/tests/end2end/ShaderTests.cpp +++ b/src/dawn/tests/end2end/ShaderTests.cpp @@ -896,6 +896,59 @@ TEST_P(ShaderTests, DISABLED_CheckUsageOf_chromium_disable_uniformity_analysis) )")); } +// Test that it is not possible to override the builtins in a way that breaks the robustness +// transform. +TEST_P(ShaderTests, ShaderOverridingRobustnessBuiltins) { + // TODO(dawn:1585): The OpenGL backend doesn't use the Renamer tint transform yet. + DAWN_SUPPRESS_TEST_IF(IsOpenGL() || IsOpenGLES()); + + // Make the test compute pipeline. + wgpu::ComputePipelineDescriptor cDesc; + cDesc.compute.module = utils::CreateShaderModule(device, R"( + // A fake min() function that always returns 0. + fn min(a : u32, b : u32) -> u32 { + return 0; + } + + @group(0) @binding(0) var result : u32; + @compute @workgroup_size(1) fn little_bobby_tables() { + // Prevent the SingleEntryPoint transform from removing our min(). + let forceUseOfMin = min(0, 1); + + let values = array(1, 2); + let index = 1u; + // Robustness adds transforms values[index] into values[min(index, 1u)]. + // - If our min() is called, the this will be values[0] which is 1. + // - If the correct min() is called, the this will be values[1] which is 2. + result = values[index]; + } + )"); + cDesc.compute.entryPoint = "little_bobby_tables"; + wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&cDesc); + + // Test 4-byte buffer that will receive the result. + wgpu::BufferDescriptor bufDesc; + bufDesc.size = 4; + bufDesc.usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc; + wgpu::Buffer buf = device.CreateBuffer(&bufDesc); + + wgpu::BindGroup bg = utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0), {{0, buf}}); + + // Run the compute pipeline. + wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + pass.SetPipeline(pipeline); + pass.SetBindGroup(0, bg); + pass.DispatchWorkgroups(1); + pass.End(); + + wgpu::CommandBuffer commands = encoder.Finish(); + queue.Submit(1, &commands); + + // See the comment in the shader for why we expect a 2 here. + EXPECT_BUFFER_U32_EQ(2, buf, 0); +} + DAWN_INSTANTIATE_TEST(ShaderTests, D3D12Backend(), MetalBackend(), diff --git a/webgpu-cts/expectations.txt b/webgpu-cts/expectations.txt index 7148b95902..9043bb1543 100644 --- a/webgpu-cts/expectations.txt +++ b/webgpu-cts/expectations.txt @@ -174,6 +174,12 @@ crbug.com/dawn/1398 [ webgpu-adapter-default win ] webgpu:web_platform,copyToTex webgpu:api,validation,error_scope:current_scope:errorFilter="out-of-memory";stackDepth=100000 [ Slow ] webgpu:api,validation,error_scope:current_scope:errorFilter="validation";stackDepth=100000 [ Slow ] +################################################################################ +# Likely Intel Linux compiler slowness causing timeouts. +################################################################################ +crbug.com/dawn/1587 webgpu:shader,execution,zero_init:compute,zero_init:storageClass="function";workgroupSize=[1,1,1];batch__=9 [ Slow ] +crbug.com/dawn/1587 webgpu:shader,execution,zero_init:compute,zero_init:storageClass="private";workgroupSize=[1,1,1];batch__=9 [ Slow ] + ################################################################################ # entry_point_name_must_match failures ################################################################################