ShaderModuleVk: Add a renamer to make other transforms safe.

Without this other Tint transforms may end up calling user code instead
of builtins (for example for the min() used in robustness).

This commit does the following changes:
 - Changes ShaderModuleVk to return a CompiledSpirv object instead of
   just a Spirv Blob so that a remappedEntryPoint can be stored in the
   cache alongside the SPIR-V.
 - Inlines the logic and simplifies TransformedConcurrentShaderModuleCache
   slightly (by introducing a struct instead of std::pair, and adding a
   conversion method to ModuleAndSpirv).
 - Adds the Renamer transform to ShaderModuleVk and adapt the code to
   use the remappedEntryPoint where needed (pipeline creation and
   post-compilation reflection).
 - Adds a test where the min() used by the robustness transform is
   overriden to return a constant 0.
 - Moves the Renamer transform to be just after the SingleEntryPoint
   transform in D3D12 and Metal as well so as to make the test pass.

Fixed: dawn:1583
Bug: dawn:1585 dawn:1587

Change-Id: Ia9de38d391a7901ed04b097f4a8d439759f7556e
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/107020
Reviewed-by: Ben Clayton <bclayton@google.com>
Kokoro: Kokoro <noreply+kokoro@google.com>
Commit-Queue: Corentin Wallez <cwallez@chromium.org>
This commit is contained in:
Corentin Wallez 2022-10-31 11:17:50 +00:00 committed by Dawn LUCI CQ
parent c14c8822b5
commit 363a995068
8 changed files with 193 additions and 122 deletions

View File

@ -297,6 +297,18 @@ ResultOrError<std::string> TranslateToHLSL(
tint::transform::Manager transformManager;
tint::transform::DataMap transformInputs;
// Run before the renamer so that the entry point name matches `entryPointName` still.
transformManager.Add<tint::transform::SingleEntryPoint>();
transformInputs.Add<tint::transform::SingleEntryPoint::Config>(r.entryPointName.data());
// Needs to run before all other transforms so that they can use builtin names safely.
transformManager.Add<tint::transform::Renamer>();
if (r.disableSymbolRenaming) {
// We still need to rename HLSL reserved keywords
transformInputs.Add<tint::transform::Renamer::Config>(
tint::transform::Renamer::Target::kHlslKeywords);
}
if (!r.newBindingsMap.empty()) {
transformManager.Add<tint::transform::MultiplanarExternalTexture>();
transformInputs.Add<tint::transform::MultiplanarExternalTexture::NewBindingPoints>(
@ -315,17 +327,6 @@ ResultOrError<std::string> TranslateToHLSL(
transformManager.Add<tint::transform::BindingRemapper>();
transformManager.Add<tint::transform::SingleEntryPoint>();
transformInputs.Add<tint::transform::SingleEntryPoint::Config>(r.entryPointName.data());
transformManager.Add<tint::transform::Renamer>();
if (r.disableSymbolRenaming) {
// We still need to rename HLSL reserved keywords
transformInputs.Add<tint::transform::Renamer::Config>(
tint::transform::Renamer::Target::kHlslKeywords);
}
if (r.substituteOverrideConfig) {
// This needs to run after SingleEntryPoint transform which removes unused overrides for
// current entry point.

View File

@ -190,9 +190,18 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(
// We only remap bindings for the target entry point, so we need to strip all other
// entry points to avoid generating invalid bindings for them.
// Run before the renamer so that the entry point name matches `entryPointName` still.
transformManager.Add<tint::transform::SingleEntryPoint>();
transformInputs.Add<tint::transform::SingleEntryPoint::Config>(r.entryPointName);
// Needs to run before all other transforms so that they can use builtin names safely.
transformManager.Add<tint::transform::Renamer>();
if (r.disableSymbolRenaming) {
// We still need to rename MSL reserved keywords
transformInputs.Add<tint::transform::Renamer::Config>(
tint::transform::Renamer::Target::kMslKeywords);
}
if (!r.externalTextureBindings.empty()) {
transformManager.Add<tint::transform::MultiplanarExternalTexture>();
transformInputs.Add<tint::transform::MultiplanarExternalTexture::NewBindingPoints>(
@ -221,14 +230,6 @@ ResultOrError<CacheResult<MslCompilation>> TranslateToMSL(
BindingRemapper::AccessControls{},
/* mayCollide */ true);
transformManager.Add<tint::transform::Renamer>();
if (r.disableSymbolRenaming) {
// We still need to rename MSL reserved keywords
transformInputs.Add<tint::transform::Renamer::Config>(
tint::transform::Renamer::Target::kMslKeywords);
}
tint::Program program;
tint::transform::DataMap transformOutputs;
{

View File

@ -64,7 +64,7 @@ MaybeError ComputePipeline::Initialize() {
module->GetHandleAndSpirv(SingleShaderStage::Compute, computeStage, layout));
createInfo.stage.module = moduleAndSpirv.module;
createInfo.stage.pName = computeStage.entryPoint.c_str();
createInfo.stage.pName = moduleAndSpirv.remappedEntryPoint;
createInfo.stage.pSpecializationInfo = nullptr;

View File

@ -358,7 +358,7 @@ MaybeError RenderPipeline::Initialize() {
shaderStage.pNext = nullptr;
shaderStage.flags = 0;
shaderStage.pSpecializationInfo = nullptr;
shaderStage.pName = programmableStage.entryPoint.c_str();
shaderStage.pName = moduleAndSpirv.remappedEntryPoint;
switch (stage) {
case dawn::native::SingleShaderStage::Vertex: {

View File

@ -21,6 +21,7 @@
#include <vector>
#include "dawn/native/CacheRequest.h"
#include "dawn/native/Serializable.h"
#include "dawn/native/SpirvValidation.h"
#include "dawn/native/TintUtils.h"
#include "dawn/native/vulkan/BindGroupLayoutVk.h"
@ -35,30 +36,13 @@
namespace dawn::native::vulkan {
// Spirv is a wrapper around Blob that exposes the data as uint32_t words.
class ShaderModule::Spirv : private Blob {
public:
static Spirv FromBlob(Blob&& blob) {
// Vulkan drivers expect the SPIRV to be aligned like an array of uint32_t values.
blob.AlignTo(alignof(uint32_t));
return static_cast<Spirv&&>(blob);
}
#define COMPILED_SPIRV_MEMBERS(X) \
X(std::vector<uint32_t>, spirv) \
X(std::string, remappedEntryPoint)
const Blob& ToBlob() const { return *this; }
static Spirv Create(std::vector<uint32_t> code) {
Blob blob = CreateBlob(std::move(code));
ASSERT(IsPtrAligned(blob.Data(), alignof(uint32_t)));
return static_cast<Spirv&&>(std::move(blob));
}
const uint32_t* Code() const { return reinterpret_cast<const uint32_t*>(Data()); }
size_t WordCount() const { return Size() / sizeof(uint32_t); }
};
} // namespace dawn::native::vulkan
namespace dawn::native::vulkan {
// Represents the result and metadata for a SPIR-V compilation.
DAWN_SERIALIZABLE(struct, CompiledSpirv, COMPILED_SPIRV_MEMBERS){};
#undef COMPILED_SPIRV_MEMBERS
bool TransformedShaderModuleCacheKey::operator==(
const TransformedShaderModuleCacheKey& other) const {
@ -84,16 +68,62 @@ size_t TransformedShaderModuleCacheKeyHashFunc::operator()(
class ShaderModule::ConcurrentTransformedShaderModuleCache {
public:
explicit ConcurrentTransformedShaderModuleCache(Device* device);
~ConcurrentTransformedShaderModuleCache();
explicit ConcurrentTransformedShaderModuleCache(Device* device) : mDevice(device) {}
std::optional<ModuleAndSpirv> Find(const TransformedShaderModuleCacheKey& key);
~ConcurrentTransformedShaderModuleCache() {
std::lock_guard<std::mutex> lock(mMutex);
for (const auto& [_, moduleAndSpirv] : mTransformedShaderModuleCache) {
mDevice->GetFencedDeleter()->DeleteWhenUnused(moduleAndSpirv.vkModule);
}
}
std::optional<ModuleAndSpirv> Find(const TransformedShaderModuleCacheKey& key) {
std::lock_guard<std::mutex> lock(mMutex);
auto iter = mTransformedShaderModuleCache.find(key);
if (iter != mTransformedShaderModuleCache.end()) {
return iter->second.AsRefs();
}
return {};
}
ModuleAndSpirv AddOrGet(const TransformedShaderModuleCacheKey& key,
VkShaderModule module,
Spirv&& spirv);
CompiledSpirv compilation) {
ASSERT(module != VK_NULL_HANDLE);
std::lock_guard<std::mutex> lock(mMutex);
auto iter = mTransformedShaderModuleCache.find(key);
if (iter == mTransformedShaderModuleCache.end()) {
bool added = false;
std::tie(iter, added) = mTransformedShaderModuleCache.emplace(
key, Entry{module, std::move(compilation.spirv),
std::move(compilation.remappedEntryPoint)});
ASSERT(added);
} else {
// No need to use FencedDeleter since this shader module was just created and does
// not need to wait for queue operations to complete.
// Also, use of fenced deleter here is not thread safe.
mDevice->fn.DestroyShaderModule(mDevice->GetVkDevice(), module, nullptr);
}
return iter->second.AsRefs();
}
private:
using Entry = std::pair<VkShaderModule, Spirv>;
struct Entry {
VkShaderModule vkModule;
std::vector<uint32_t> spirv;
std::string remappedEntryPoint;
ModuleAndSpirv AsRefs() const {
return {
vkModule,
spirv.data(),
spirv.size(),
remappedEntryPoint.c_str(),
};
}
};
Device* mDevice;
std::mutex mMutex;
@ -103,57 +133,6 @@ class ShaderModule::ConcurrentTransformedShaderModuleCache {
mTransformedShaderModuleCache;
};
ShaderModule::ConcurrentTransformedShaderModuleCache::ConcurrentTransformedShaderModuleCache(
Device* device)
: mDevice(device) {}
ShaderModule::ConcurrentTransformedShaderModuleCache::~ConcurrentTransformedShaderModuleCache() {
std::lock_guard<std::mutex> lock(mMutex);
for (const auto& [_, moduleAndSpirv] : mTransformedShaderModuleCache) {
mDevice->GetFencedDeleter()->DeleteWhenUnused(moduleAndSpirv.first);
}
}
std::optional<ShaderModule::ModuleAndSpirv>
ShaderModule::ConcurrentTransformedShaderModuleCache::Find(
const TransformedShaderModuleCacheKey& key) {
std::lock_guard<std::mutex> lock(mMutex);
auto iter = mTransformedShaderModuleCache.find(key);
if (iter != mTransformedShaderModuleCache.end()) {
return ModuleAndSpirv{
iter->second.first,
iter->second.second.Code(),
iter->second.second.WordCount(),
};
}
return {};
}
ShaderModule::ModuleAndSpirv ShaderModule::ConcurrentTransformedShaderModuleCache::AddOrGet(
const TransformedShaderModuleCacheKey& key,
VkShaderModule module,
Spirv&& spirv) {
ASSERT(module != VK_NULL_HANDLE);
std::lock_guard<std::mutex> lock(mMutex);
auto iter = mTransformedShaderModuleCache.find(key);
if (iter == mTransformedShaderModuleCache.end()) {
bool added = false;
std::tie(iter, added) =
mTransformedShaderModuleCache.emplace(key, std::make_pair(module, std::move(spirv)));
ASSERT(added);
} else {
// No need to use FencedDeleter since this shader module was just created and does
// not need to wait for queue operations to complete.
// Also, use of fenced deleter here is not thread safe.
mDevice->fn.DestroyShaderModule(mDevice->GetVkDevice(), module, nullptr);
}
return ModuleAndSpirv{
iter->second.first,
iter->second.second.Code(),
iter->second.second.WordCount(),
};
}
// static
ResultOrError<Ref<ShaderModule>> ShaderModule::Create(
Device* device,
@ -194,6 +173,7 @@ ShaderModule::~ShaderModule() = default;
X(std::string_view, entryPointName) \
X(bool, isRobustnessEnabled) \
X(bool, disableWorkgroupInit) \
X(bool, disableSymbolRenaming) \
X(bool, useZeroInitializeWorkgroupMemoryExtension) \
X(CacheKey::UnsafeUnkeyedValue<dawn::platform::Platform*>, tracePlatform)
@ -275,6 +255,7 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
req.entryPointName = programmableStage.entryPoint;
req.isRobustnessEnabled = GetDevice()->IsRobustnessEnabled();
req.disableWorkgroupInit = GetDevice()->IsToggleEnabled(Toggle::DisableWorkgroupInit);
req.disableSymbolRenaming = GetDevice()->IsToggleEnabled(Toggle::DisableSymbolRenaming);
req.useZeroInitializeWorkgroupMemoryExtension =
GetDevice()->IsToggleEnabled(Toggle::VulkanUseZeroInitializeWorkgroupMemoryExtension);
req.tracePlatform = UnsafeUnkeyedValue(GetDevice()->GetPlatform());
@ -283,22 +264,31 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
const CombinedLimits& limits = GetDevice()->GetLimits();
req.limits = LimitsForCompilationRequest::Create(limits.v1);
CacheResult<Spirv> spirv;
CacheResult<CompiledSpirv> compilation;
DAWN_TRY_LOAD_OR_RUN(
spirv, GetDevice(), std::move(req), Spirv::FromBlob,
[](SpirvCompilationRequest r) -> ResultOrError<Spirv> {
compilation, GetDevice(), std::move(req), CompiledSpirv::FromBlob,
[](SpirvCompilationRequest r) -> ResultOrError<CompiledSpirv> {
tint::transform::Manager transformManager;
tint::transform::DataMap transformInputs;
// Many Vulkan drivers can't handle multi-entrypoint shader modules.
// Run before the renamer so that the entry point name matches `entryPointName` still.
transformManager.append(std::make_unique<tint::transform::SingleEntryPoint>());
transformInputs.Add<tint::transform::SingleEntryPoint::Config>(
std::string(r.entryPointName));
// Needs to run before all other transforms so that they can use builtin names safely.
if (!r.disableSymbolRenaming) {
transformManager.Add<tint::transform::Renamer>();
}
if (r.isRobustnessEnabled) {
transformManager.append(std::make_unique<tint::transform::Robustness>());
}
// Many Vulkan drivers can't handle multi-entrypoint shader modules.
transformManager.append(std::make_unique<tint::transform::SingleEntryPoint>());
// Run the binding remapper after SingleEntryPoint to avoid collisions with
// unused entryPoints.
transformManager.append(std::make_unique<tint::transform::BindingRemapper>());
tint::transform::DataMap transformInputs;
transformInputs.Add<tint::transform::SingleEntryPoint::Config>(
std::string(r.entryPointName));
transformInputs.Add<BindingRemapper::Remappings>(std::move(r.bindingPoints),
BindingRemapper::AccessControls{},
/* mayCollide */ false);
@ -314,18 +304,35 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
transformInputs.Add<tint::transform::SubstituteOverride::Config>(
std::move(r.substituteOverrideConfig).value());
}
tint::Program program;
tint::transform::DataMap transformOutputs;
{
TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "RunTransforms");
DAWN_TRY_ASSIGN(program, RunTransforms(&transformManager, r.inputProgram,
transformInputs, nullptr, nullptr));
DAWN_TRY_ASSIGN(program,
RunTransforms(&transformManager, r.inputProgram, transformInputs,
&transformOutputs, nullptr));
}
// Get the entry point name after the renamer pass.
std::string remappedEntryPoint;
if (r.disableSymbolRenaming) {
remappedEntryPoint = r.entryPointName;
} else {
auto* data = transformOutputs.Get<tint::transform::Renamer::Data>();
ASSERT(data != nullptr);
auto it = data->remappings.find(r.entryPointName.data());
ASSERT(it != data->remappings.end());
remappedEntryPoint = it->second;
}
ASSERT(remappedEntryPoint != "");
// Validate workgroup size after program runs transforms.
if (r.stage == SingleShaderStage::Compute) {
// Validate workgroup size after program runs transforms.
Extent3D _;
DAWN_TRY_ASSIGN(_, ValidateComputeStageWorkgroupSize(
program, r.entryPointName.data(), r.limits));
program, remappedEntryPoint.c_str(), r.limits));
}
tint::writer::spirv::Options options;
@ -336,22 +343,25 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General,
"tint::writer::spirv::Generate()");
auto result = tint::writer::spirv::Generate(&program, options);
DAWN_INVALID_IF(!result.success, "An error occured while generating SPIR-V: %s.",
result.error);
auto tintResult = tint::writer::spirv::Generate(&program, options);
DAWN_INVALID_IF(!tintResult.success, "An error occured while generating SPIR-V: %s.",
tintResult.error);
return Spirv::Create(std::move(result.spirv));
CompiledSpirv result;
result.spirv = std::move(tintResult.spirv);
result.remappedEntryPoint = remappedEntryPoint;
return result;
});
DAWN_TRY(ValidateSpirv(GetDevice(), spirv->Code(), spirv->WordCount(),
DAWN_TRY(ValidateSpirv(GetDevice(), compilation->spirv.data(), compilation->spirv.size(),
GetDevice()->IsToggleEnabled(Toggle::DumpShaders)));
VkShaderModuleCreateInfo createInfo;
createInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
createInfo.pNext = nullptr;
createInfo.flags = 0;
createInfo.codeSize = spirv->WordCount() * sizeof(uint32_t);
createInfo.pCode = spirv->Code();
createInfo.codeSize = compilation->spirv.size() * sizeof(uint32_t);
createInfo.pCode = compilation->spirv.data();
Device* device = ToBackend(GetDevice());
@ -366,13 +376,13 @@ ResultOrError<ShaderModule::ModuleAndSpirv> ShaderModule::GetHandleAndSpirv(
ModuleAndSpirv moduleAndSpirv;
if (newHandle != VK_NULL_HANDLE) {
if (BlobCache* cache = device->GetBlobCache()) {
cache->EnsureStored(spirv);
cache->EnsureStored(compilation);
}
// Set the label on `newHandle` now, and not on `moduleAndSpirv.module` later
// since `moduleAndSpirv.module` may be in use by multiple threads.
SetDebugName(ToBackend(GetDevice()), newHandle, "Dawn_ShaderModule", GetLabel());
moduleAndSpirv =
mTransformedShaderModuleCache->AddOrGet(cacheKey, newHandle, spirv.Acquire());
mTransformedShaderModuleCache->AddOrGet(cacheKey, newHandle, compilation.Acquire());
}
return std::move(moduleAndSpirv);

View File

@ -50,11 +50,11 @@ class PipelineLayout;
class ShaderModule final : public ShaderModuleBase {
public:
class Spirv;
struct ModuleAndSpirv {
VkShaderModule module;
const uint32_t* spirv;
size_t wordCount;
const char* remappedEntryPoint;
};
static ResultOrError<Ref<ShaderModule>> Create(Device* device,

View File

@ -896,6 +896,59 @@ TEST_P(ShaderTests, DISABLED_CheckUsageOf_chromium_disable_uniformity_analysis)
)"));
}
// Test that it is not possible to override the builtins in a way that breaks the robustness
// transform.
TEST_P(ShaderTests, ShaderOverridingRobustnessBuiltins) {
// TODO(dawn:1585): The OpenGL backend doesn't use the Renamer tint transform yet.
DAWN_SUPPRESS_TEST_IF(IsOpenGL() || IsOpenGLES());
// Make the test compute pipeline.
wgpu::ComputePipelineDescriptor cDesc;
cDesc.compute.module = utils::CreateShaderModule(device, R"(
// A fake min() function that always returns 0.
fn min(a : u32, b : u32) -> u32 {
return 0;
}
@group(0) @binding(0) var<storage, read_write> result : u32;
@compute @workgroup_size(1) fn little_bobby_tables() {
// Prevent the SingleEntryPoint transform from removing our min().
let forceUseOfMin = min(0, 1);
let values = array<u32, 2>(1, 2);
let index = 1u;
// Robustness adds transforms values[index] into values[min(index, 1u)].
// - If our min() is called, the this will be values[0] which is 1.
// - If the correct min() is called, the this will be values[1] which is 2.
result = values[index];
}
)");
cDesc.compute.entryPoint = "little_bobby_tables";
wgpu::ComputePipeline pipeline = device.CreateComputePipeline(&cDesc);
// Test 4-byte buffer that will receive the result.
wgpu::BufferDescriptor bufDesc;
bufDesc.size = 4;
bufDesc.usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc;
wgpu::Buffer buf = device.CreateBuffer(&bufDesc);
wgpu::BindGroup bg = utils::MakeBindGroup(device, pipeline.GetBindGroupLayout(0), {{0, buf}});
// Run the compute pipeline.
wgpu::CommandEncoder encoder = device.CreateCommandEncoder();
wgpu::ComputePassEncoder pass = encoder.BeginComputePass();
pass.SetPipeline(pipeline);
pass.SetBindGroup(0, bg);
pass.DispatchWorkgroups(1);
pass.End();
wgpu::CommandBuffer commands = encoder.Finish();
queue.Submit(1, &commands);
// See the comment in the shader for why we expect a 2 here.
EXPECT_BUFFER_U32_EQ(2, buf, 0);
}
DAWN_INSTANTIATE_TEST(ShaderTests,
D3D12Backend(),
MetalBackend(),

View File

@ -174,6 +174,12 @@ crbug.com/dawn/1398 [ webgpu-adapter-default win ] webgpu:web_platform,copyToTex
webgpu:api,validation,error_scope:current_scope:errorFilter="out-of-memory";stackDepth=100000 [ Slow ]
webgpu:api,validation,error_scope:current_scope:errorFilter="validation";stackDepth=100000 [ Slow ]
################################################################################
# Likely Intel Linux compiler slowness causing timeouts.
################################################################################
crbug.com/dawn/1587 webgpu:shader,execution,zero_init:compute,zero_init:storageClass="function";workgroupSize=[1,1,1];batch__=9 [ Slow ]
crbug.com/dawn/1587 webgpu:shader,execution,zero_init:compute,zero_init:storageClass="private";workgroupSize=[1,1,1];batch__=9 [ Slow ]
################################################################################
# entry_point_name_must_match failures
################################################################################