// Patch overview (these lines sit ahead of the first `diff --git` header and belong to no hunk):
// - D3D12 / Metal: per the added comments, the Renamer transform is now registered at the front of the transform list (for D3D12 together with the entry-point selection transform that precedes it) instead of after the other transforms.
// - Vulkan: adds a `disableSymbolRenaming` request member, registers the same two transforms first, looks the entry point name up in the `remappings` of the transform outputs, and carries it as `remappedEntryPoint` through `CompiledSpirv`, the module cache `Entry` and `ModuleAndSpirv` into `pName` of the compute and render pipeline stages.
// - Tests: adds the `ShaderOverridingRobustnessBuiltins` end2end test and two `[ Slow ]` CTS expectations.
// NOTE(review): this copy has lost its line breaks and template arguments look stripped (e.g. `transformManager.Add();`), so it presumably will not apply as-is - confirm against the original change.
diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp index a5392a1f7e..19a665e76d 100644 --- a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp +++ b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp @@ -297,6 +297,18 @@ ResultOrError TranslateToHLSL( tint::transform::Manager transformManager; tint::transform::DataMap transformInputs; + // Run before the renamer so that the entry point name matches `entryPointName` still. + transformManager.Add(); + transformInputs.Add(r.entryPointName.data()); + + // Needs to run before all the transforms below so that they can use builtin names safely. + transformManager.Add(); + if (r.disableSymbolRenaming) { + // We still need to rename HLSL reserved keywords + transformInputs.Add( + tint::transform::Renamer::Target::kHlslKeywords); + } + if (!r.newBindingsMap.empty()) { transformManager.Add(); transformInputs.Add( @@ -315,17 +327,6 @@ ResultOrError TranslateToHLSL( transformManager.Add(); - transformManager.Add(); - transformInputs.Add(r.entryPointName.data()); - - transformManager.Add(); - - if (r.disableSymbolRenaming) { - // We still need to rename HLSL reserved keywords - transformInputs.Add( - tint::transform::Renamer::Target::kHlslKeywords); - } - if (r.substituteOverrideConfig) { // This needs to run after SingleEntryPoint transform which removes unused overrides for // current entry point. diff --git a/src/dawn/native/metal/ShaderModuleMTL.mm b/src/dawn/native/metal/ShaderModuleMTL.mm index 8ab06286d8..80627da845 100644 --- a/src/dawn/native/metal/ShaderModuleMTL.mm +++ b/src/dawn/native/metal/ShaderModuleMTL.mm @@ -190,9 +190,18 @@ ResultOrError> TranslateToMSL( // We only remap bindings for the target entry point, so we need to strip all other // entry points to avoid generating invalid bindings for them. + // Run before the renamer so that the entry point name matches `entryPointName` still. 
transformManager.Add(); transformInputs.Add(r.entryPointName); + // Needs to run before all the transforms below so that they can use builtin names safely. + transformManager.Add(); + if (r.disableSymbolRenaming) { + // We still need to rename MSL reserved keywords + transformInputs.Add( + tint::transform::Renamer::Target::kMslKeywords); + } + if (!r.externalTextureBindings.empty()) { transformManager.Add(); transformInputs.Add( @@ -221,14 +230,6 @@ ResultOrError> TranslateToMSL( BindingRemapper::AccessControls{}, /* mayCollide */ true); - transformManager.Add(); - - if (r.disableSymbolRenaming) { - // We still need to rename MSL reserved keywords - transformInputs.Add( - tint::transform::Renamer::Target::kMslKeywords); - } - tint::Program program; tint::transform::DataMap transformOutputs; { diff --git a/src/dawn/native/vulkan/ComputePipelineVk.cpp b/src/dawn/native/vulkan/ComputePipelineVk.cpp index ea8c59a6bf..49522e2aa1 100644 --- a/src/dawn/native/vulkan/ComputePipelineVk.cpp +++ b/src/dawn/native/vulkan/ComputePipelineVk.cpp @@ -64,7 +64,7 @@ MaybeError ComputePipeline::Initialize() { module->GetHandleAndSpirv(SingleShaderStage::Compute, computeStage, layout)); createInfo.stage.module = moduleAndSpirv.module; - createInfo.stage.pName = computeStage.entryPoint.c_str(); + createInfo.stage.pName = moduleAndSpirv.remappedEntryPoint; createInfo.stage.pSpecializationInfo = nullptr; diff --git a/src/dawn/native/vulkan/RenderPipelineVk.cpp b/src/dawn/native/vulkan/RenderPipelineVk.cpp index df61ac77b3..0ca94216d0 100644 --- a/src/dawn/native/vulkan/RenderPipelineVk.cpp +++ b/src/dawn/native/vulkan/RenderPipelineVk.cpp @@ -358,7 +358,7 @@ MaybeError RenderPipeline::Initialize() { shaderStage.pNext = nullptr; shaderStage.flags = 0; shaderStage.pSpecializationInfo = nullptr; - shaderStage.pName = programmableStage.entryPoint.c_str(); + shaderStage.pName = moduleAndSpirv.remappedEntryPoint; switch (stage) { case dawn::native::SingleShaderStage::Vertex: { diff --git 
a/src/dawn/native/vulkan/ShaderModuleVk.cpp b/src/dawn/native/vulkan/ShaderModuleVk.cpp index b5796d6170..4703954946 100644 --- a/src/dawn/native/vulkan/ShaderModuleVk.cpp +++ b/src/dawn/native/vulkan/ShaderModuleVk.cpp @@ -21,6 +21,7 @@ #include #include "dawn/native/CacheRequest.h" +#include "dawn/native/Serializable.h" #include "dawn/native/SpirvValidation.h" #include "dawn/native/TintUtils.h" #include "dawn/native/vulkan/BindGroupLayoutVk.h" @@ -35,30 +36,13 @@ namespace dawn::native::vulkan { -// Spirv is a wrapper around Blob that exposes the data as uint32_t words. -class ShaderModule::Spirv : private Blob { - public: - static Spirv FromBlob(Blob&& blob) { - // Vulkan drivers expect the SPIRV to be aligned like an array of uint32_t values. - blob.AlignTo(alignof(uint32_t)); - return static_cast(blob); - } +#define COMPILED_SPIRV_MEMBERS(X) \ + X(std::vector, spirv) \ + X(std::string, remappedEntryPoint) - const Blob& ToBlob() const { return *this; } - - static Spirv Create(std::vector code) { - Blob blob = CreateBlob(std::move(code)); - ASSERT(IsPtrAligned(blob.Data(), alignof(uint32_t))); - return static_cast(std::move(blob)); - } - - const uint32_t* Code() const { return reinterpret_cast(Data()); } - size_t WordCount() const { return Size() / sizeof(uint32_t); } -}; - -} // namespace dawn::native::vulkan - -namespace dawn::native::vulkan { +// Represents the result and metadata for a SPIR-V compilation. 
+DAWN_SERIALIZABLE(struct, CompiledSpirv, COMPILED_SPIRV_MEMBERS){}; +#undef COMPILED_SPIRV_MEMBERS bool TransformedShaderModuleCacheKey::operator==( const TransformedShaderModuleCacheKey& other) const { @@ -84,16 +68,62 @@ size_t TransformedShaderModuleCacheKeyHashFunc::operator()( class ShaderModule::ConcurrentTransformedShaderModuleCache { public: - explicit ConcurrentTransformedShaderModuleCache(Device* device); - ~ConcurrentTransformedShaderModuleCache(); + explicit ConcurrentTransformedShaderModuleCache(Device* device) : mDevice(device) {} - std::optional Find(const TransformedShaderModuleCacheKey& key); + ~ConcurrentTransformedShaderModuleCache() { + std::lock_guard lock(mMutex); + + for (const auto& [_, moduleAndSpirv] : mTransformedShaderModuleCache) { + mDevice->GetFencedDeleter()->DeleteWhenUnused(moduleAndSpirv.vkModule); + } + } + + std::optional Find(const TransformedShaderModuleCacheKey& key) { + std::lock_guard lock(mMutex); + + auto iter = mTransformedShaderModuleCache.find(key); + if (iter != mTransformedShaderModuleCache.end()) { + return iter->second.AsRefs(); + } + return {}; + } ModuleAndSpirv AddOrGet(const TransformedShaderModuleCacheKey& key, VkShaderModule module, - Spirv&& spirv); + CompiledSpirv compilation) { + ASSERT(module != VK_NULL_HANDLE); + std::lock_guard lock(mMutex); + + auto iter = mTransformedShaderModuleCache.find(key); + if (iter == mTransformedShaderModuleCache.end()) { + bool added = false; + std::tie(iter, added) = mTransformedShaderModuleCache.emplace( + key, Entry{module, std::move(compilation.spirv), + std::move(compilation.remappedEntryPoint)}); + ASSERT(added); + } else { + // No need to use FencedDeleter since this shader module was just created and does + // not need to wait for queue operations to complete. + // Also, use of fenced deleter here is not thread safe. 
+ mDevice->fn.DestroyShaderModule(mDevice->GetVkDevice(), module, nullptr); + } + return iter->second.AsRefs(); + } private: - using Entry = std::pair; + struct Entry { + VkShaderModule vkModule; + std::vector spirv; + std::string remappedEntryPoint; + + ModuleAndSpirv AsRefs() const { + return { + vkModule, + spirv.data(), + spirv.size(), + remappedEntryPoint.c_str(), + }; + } + }; Device* mDevice; std::mutex mMutex; @@ -103,57 +133,6 @@ class ShaderModule::ConcurrentTransformedShaderModuleCache { mTransformedShaderModuleCache; }; -ShaderModule::ConcurrentTransformedShaderModuleCache::ConcurrentTransformedShaderModuleCache( - Device* device) - : mDevice(device) {} - -ShaderModule::ConcurrentTransformedShaderModuleCache::~ConcurrentTransformedShaderModuleCache() { - std::lock_guard lock(mMutex); - for (const auto& [_, moduleAndSpirv] : mTransformedShaderModuleCache) { - mDevice->GetFencedDeleter()->DeleteWhenUnused(moduleAndSpirv.first); - } -} - -std::optional -ShaderModule::ConcurrentTransformedShaderModuleCache::Find( - const TransformedShaderModuleCacheKey& key) { - std::lock_guard lock(mMutex); - auto iter = mTransformedShaderModuleCache.find(key); - if (iter != mTransformedShaderModuleCache.end()) { - return ModuleAndSpirv{ - iter->second.first, - iter->second.second.Code(), - iter->second.second.WordCount(), - }; - } - return {}; -} - -ShaderModule::ModuleAndSpirv ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGet( - const TransformedShaderModuleCacheKey& key, - VkShaderModule module, - Spirv&& spirv) { - ASSERT(module != VK_NULL_HANDLE); - std::lock_guard lock(mMutex); - auto iter = mTransformedShaderModuleCache.find(key); - if (iter == mTransformedShaderModuleCache.end()) { - bool added = false; - std::tie(iter, added) = - mTransformedShaderModuleCache.emplace(key, std::make_pair(module, std::move(spirv))); - ASSERT(added); - } else { - // No need to use FencedDeleter since this shader module was just created and does - // not need to wait for 
queue operations to complete. - // Also, use of fenced deleter here is not thread safe. - mDevice->fn.DestroyShaderModule(mDevice->GetVkDevice(), module, nullptr); - } - return ModuleAndSpirv{ - iter->second.first, - iter->second.second.Code(), - iter->second.second.WordCount(), - }; -} - // static ResultOrError> ShaderModule::Create( Device* device, @@ -194,6 +173,7 @@ ShaderModule::~ShaderModule() = default; X(std::string_view, entryPointName) \ X(bool, isRobustnessEnabled) \ X(bool, disableWorkgroupInit) \ + X(bool, disableSymbolRenaming) \ X(bool, useZeroInitializeWorkgroupMemoryExtension) \ X(CacheKey::UnsafeUnkeyedValue, tracePlatform) @@ -275,6 +255,7 @@ ResultOrError ShaderModule::GetHandleAndSpirv( req.entryPointName = programmableStage.entryPoint; req.isRobustnessEnabled = GetDevice()->IsRobustnessEnabled(); req.disableWorkgroupInit = GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit); + req.disableSymbolRenaming = GetDevice()->IsToggleEnabled(Toggle::DisableSymbolRenaming); req.useZeroInitializeWorkgroupMemoryExtension = GetDevice()->IsToggleEnabled(Toggle::VulkanUseZeroInitializeWorkgroupMemoryExtension); req.tracePlatform = UnsafeUnkeyedValue(GetDevice()->GetPlatform()); @@ -283,22 +264,31 @@ ResultOrError ShaderModule::GetHandleAndSpirv( const CombinedLimits& limits = GetDevice()->GetLimits(); req.limits = LimitsForCompilationRequest::Create(limits.v1); - CacheResult spirv; + CacheResult compilation; DAWN_TRY_LOAD_OR_RUN( - spirv, GetDevice(), std::move(req), Spirv::FromBlob, - [](SpirvCompilationRequest r) -> ResultOrError { + compilation, GetDevice(), std::move(req), CompiledSpirv::FromBlob, + [](SpirvCompilationRequest r) -> ResultOrError { tint::transform::Manager transformManager; + tint::transform::DataMap transformInputs; + + // Many Vulkan drivers can't handle multi-entrypoint shader modules. + // Run before the renamer so that the entry point name matches `entryPointName` still. 
+ transformManager.append(std::make_unique()); + transformInputs.Add( + std::string(r.entryPointName)); + + // Needs to run before all the transforms below so that they can use builtin names safely. + if (!r.disableSymbolRenaming) { + transformManager.Add(); + } + if (r.isRobustnessEnabled) { transformManager.append(std::make_unique()); } - // Many Vulkan drivers can't handle multi-entrypoint shader modules. - transformManager.append(std::make_unique()); + // Run the binding remapper after SingleEntryPoint to avoid collisions with // unused entryPoints. transformManager.append(std::make_unique()); - tint::transform::DataMap transformInputs; - transformInputs.Add( - std::string(r.entryPointName)); transformInputs.Add(std::move(r.bindingPoints), BindingRemapper::AccessControls{}, /* mayCollide */ false); @@ -314,18 +304,35 @@ ResultOrError ShaderModule::GetHandleAndSpirv( transformInputs.Add( std::move(r.substituteOverrideConfig).value()); } + tint::Program program; + tint::transform::DataMap transformOutputs; { TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "RunTransforms"); - DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, r.inputProgram, - transformInputs, nullptr, nullptr)); + DAWN_TRY_ASSIGN(program, + RunTransforms(&transformManager, r.inputProgram, transformInputs, + &transformOutputs, nullptr)); } + // Get the entry point name after the renamer pass. + std::string remappedEntryPoint; + if (r.disableSymbolRenaming) { + remappedEntryPoint = r.entryPointName; + } else { + auto* data = transformOutputs.Get(); + ASSERT(data != nullptr); + + auto it = data->remappings.find(r.entryPointName.data()); + ASSERT(it != data->remappings.end()); + remappedEntryPoint = it->second; + } + ASSERT(remappedEntryPoint != ""); + + // Validate workgroup size after program runs transforms. if (r.stage == SingleShaderStage::Compute) { - // Validate workgroup size after program runs transforms. 
Extent3D _; DAWN_TRY_ASSIGN(_, ValidateComputeStageWorkgroupSize( - program, r.entryPointName.data(), r.limits)); + program, remappedEntryPoint.c_str(), r.limits)); } tint::writer::spirv::Options options; @@ -336,22 +343,25 @@ ResultOrError ShaderModule::GetHandleAndSpirv( TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "tint::writer::spirv::Generate()"); - auto result = tint::writer::spirv::Generate(&program, options); - DAWN_INVALID_IF(!result.success, "An error occured while generating SPIR-V: %s.", - result.error); + auto tintResult = tint::writer::spirv::Generate(&program, options); + DAWN_INVALID_IF(!tintResult.success, "An error occured while generating SPIR-V: %s.", + tintResult.error); - return Spirv::Create(std::move(result.spirv)); + CompiledSpirv result; + result.spirv = std::move(tintResult.spirv); + result.remappedEntryPoint = remappedEntryPoint; + return result; }); - DAWN_TRY(ValidateSpirv(GetDevice(), spirv->Code(), spirv->WordCount(), + DAWN_TRY(ValidateSpirv(GetDevice(), compilation->spirv.data(), compilation->spirv.size(), GetDevice()->IsToggleEnabled(Toggle::DumpShaders))); VkShaderModuleCreateInfo createInfo; createInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO; createInfo.pNext = nullptr; createInfo.flags = 0; - createInfo.codeSize = spirv->WordCount() * sizeof(uint32_t); - createInfo.pCode = spirv->Code(); + createInfo.codeSize = compilation->spirv.size() * sizeof(uint32_t); + createInfo.pCode = compilation->spirv.data(); Device* device = ToBackend(GetDevice()); @@ -366,13 +376,13 @@ ResultOrError ShaderModule::GetHandleAndSpirv( ModuleAndSpirv moduleAndSpirv; if (newHandle != VK_NULL_HANDLE) { if (BlobCache* cache = device->GetBlobCache()) { - cache->EnsureStored(spirv); + cache->EnsureStored(compilation); } // Set the label on `newHandle` now, and not on `moduleAndSpirv.module` later // since `moduleAndSpirv.module` may be in use by multiple threads. 
SetDebugName(ToBackend(GetDevice()), newHandle, "Dawn_ShaderModule", GetLabel()); moduleAndSpirv = - mTransformedShaderModuleCache->AddOrGet(cacheKey, newHandle, spirv.Acquire()); + mTransformedShaderModuleCache->AddOrGet(cacheKey, newHandle, compilation.Acquire()); } return std::move(moduleAndSpirv); diff --git a/src/dawn/native/vulkan/ShaderModuleVk.h b/src/dawn/native/vulkan/ShaderModuleVk.h index 7c6a5dff46..b84090e993 100644 --- a/src/dawn/native/vulkan/ShaderModuleVk.h +++ b/src/dawn/native/vulkan/ShaderModuleVk.h @@ -50,11 +50,11 @@ class PipelineLayout; class ShaderModule final : public ShaderModuleBase { public: - class Spirv; struct ModuleAndSpirv { VkShaderModule module; const uint32_t* spirv; size_t wordCount; + const char* remappedEntryPoint; }; static ResultOrError> Create(Device* device, diff --git a/src/dawn/tests/end2end/ShaderTests.cpp b/src/dawn/tests/end2end/ShaderTests.cpp index b8731759a3..8909e6e359 100644 --- a/src/dawn/tests/end2end/ShaderTests.cpp +++ b/src/dawn/tests/end2end/ShaderTests.cpp @@ -896,6 +896,59 @@ TEST_P(ShaderTests, DISABLED_CheckUsageOf_chromium_disable_uniformity_analysis) )")); } +// Test that it is not possible to override the builtins in a way that breaks the robustness +// transform. +TEST_P(ShaderTests, ShaderOverridingRobustnessBuiltins) { + // TODO(dawn:1585): The OpenGL backend doesn't use the Renamer tint transform yet. + DAWN_SUPPRESS_TEST_IF(IsOpenGL() || IsOpenGLES()); + + // Make the test compute pipeline. + wgpu::ComputePipelineDescriptor cDesc; + cDesc.compute.module = utils::CreateShaderModule(device, R"( + // A fake min() function that always returns 0. + fn min(a : u32, b : u32) -> u32 { + return 0; + } + + @group(0) @binding(0) var result : u32; + @compute @workgroup_size(1) fn little_bobby_tables() { + // Prevent the SingleEntryPoint transform from removing our min(). 
+ let forceUseOfMin = min(0, 1); + + let values = array(1, 2); + let index = 1u; + // Robustness adds transforms values[index] into values[min(index, 1u)]. + // - If our min() is called, the this will be values[0] which is 1. + // - If the correct min() is called, the this will be values[1] which is 2. + result = values[index]; + } + )"); + cDesc.compute.entryPoint = "little_bobby_tables"; + wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&cDesc); + + // Test 4-byte buffer that will receive the result. + wgpu::BufferDescriptor bufDesc; + bufDesc.size = 4; + bufDesc.usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc; + wgpu::Buffer buf = device.CreateBuffer(&bufDesc); + + wgpu::BindGroup bg = utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0), {{0, buf}}); + + // Run the compute pipeline. + wgpu::CommandEncoder encoder = device.CreateCommandEncoder(); + wgpu::ComputePassEncoder pass = encoder.BeginComputePass(); + pass.SetPipeline(pipeline); + pass.SetBindGroup(0, bg); + pass.DispatchWorkgroups(1); + pass.End(); + + wgpu::CommandBuffer commands = encoder.Finish(); + queue.Submit(1, &commands); + + // See the comment in the shader for why we expect a 2 here. + EXPECT_BUFFER_U32_EQ(2, buf, 0); +} + DAWN_INSTANTIATE_TEST(ShaderTests, D3D12Backend(), MetalBackend(), diff --git a/webgpu-cts/expectations.txt b/webgpu-cts/expectations.txt index 7148b95902..9043bb1543 100644 --- a/webgpu-cts/expectations.txt +++ b/webgpu-cts/expectations.txt @@ -174,6 +174,12 @@ crbug.com/dawn/1398 [ webgpu-adapter-default win ] webgpu:web_platform,copyToTex webgpu:api,validation,error_scope:current_scope:errorFilter="out-of-memory";stackDepth=100000 [ Slow ] webgpu:api,validation,error_scope:current_scope:errorFilter="validation";stackDepth=100000 [ Slow ] +################################################################################ +# Likely Intel Linux compiler slowness causing timeouts. 
+################################################################################ +crbug.com/dawn/1587 webgpu:shader,execution,zero_init:compute,zero_init:storageClass="function";workgroupSize=[1,1,1];batch__=9 [ Slow ] +crbug.com/dawn/1587 webgpu:shader,execution,zero_init:compute,zero_init:storageClass="private";workgroupSize=[1,1,1];batch__=9 [ Slow ] + ################################################################################ # entry_point_name_must_match failures ################################################################################