From a0df1384f207c4b3b36351bf75dd6dd6ec857901 Mon Sep 17 00:00:00 2001 From: Austin Eng Date: Fri, 5 Aug 2022 20:09:07 +0000 Subject: [PATCH] Cache WGSL -> DXBC/DXIL compilation Bug: dawn:1480 Change-Id: I858111f62be457c2e7cd5017bbf4c10e76395e83 Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/95340 Reviewed-by: Loko Kung Commit-Queue: Austin Eng Kokoro: Kokoro --- src/dawn/native/Blob.cpp | 26 + src/dawn/native/StreamImplTint.cpp | 20 +- src/dawn/native/TintUtils.cpp | 7 +- src/dawn/native/TintUtils.h | 11 +- src/dawn/native/d3d12/BlobD3D12.cpp | 12 + src/dawn/native/d3d12/BlobD3D12.h | 1 + src/dawn/native/d3d12/ShaderModuleD3D12.cpp | 790 ++++++++---------- src/dawn/native/d3d12/ShaderModuleD3D12.h | 23 +- src/dawn/native/d3d12/UtilsD3D12.cpp | 8 +- src/dawn/native/d3d12/UtilsD3D12.h | 2 +- src/dawn/native/stream/Stream.cpp | 10 + src/dawn/native/stream/Stream.h | 7 +- .../tests/end2end/PipelineCachingTests.cpp | 5 +- .../tests/unittests/native/StreamTests.cpp | 59 ++ 14 files changed, 529 insertions(+), 452 deletions(-) diff --git a/src/dawn/native/Blob.cpp b/src/dawn/native/Blob.cpp index cecac30e4f..78b18e7d60 100644 --- a/src/dawn/native/Blob.cpp +++ b/src/dawn/native/Blob.cpp @@ -17,6 +17,7 @@ #include "dawn/common/Assert.h" #include "dawn/common/Math.h" #include "dawn/native/Blob.h" +#include "dawn/native/stream/Stream.h" namespace dawn::native { @@ -99,4 +100,29 @@ void Blob::AlignTo(size_t alignment) { *this = std::move(blob); } +template <> +void stream::Stream::Write(stream::Sink* s, const Blob& b) { + size_t size = b.Size(); + StreamIn(s, size); + if (size > 0) { + void* ptr = s->GetSpace(size); + memcpy(ptr, b.Data(), size); + } +} + +template <> +MaybeError stream::Stream::Read(stream::Source* s, Blob* b) { + size_t size; + DAWN_TRY(StreamOut(s, &size)); + if (size > 0) { + const void* ptr; + DAWN_TRY(s->Read(&ptr, size)); + *b = CreateBlob(size); + memcpy(b->Data(), ptr, size); + } else { + *b = Blob(); + } + return {}; +} + } // namespace 
dawn::native diff --git a/src/dawn/native/StreamImplTint.cpp b/src/dawn/native/StreamImplTint.cpp index 1c0c04de43..13a70cabb5 100644 --- a/src/dawn/native/StreamImplTint.cpp +++ b/src/dawn/native/StreamImplTint.cpp @@ -14,7 +14,8 @@ #include "dawn/native/stream/Stream.h" -#include "tint/tint.h" +#include "dawn/native/TintUtils.h" +#include "tint/writer/array_length_from_uniform_options.h" namespace dawn::native { @@ -96,4 +97,21 @@ void stream::Stream::Write( StreamIn(sink, attrib.format, attrib.offset, attrib.shader_location); } +// static +template <> +void stream::Stream::Write( + stream::Sink* sink, + const tint::writer::ArrayLengthFromUniformOptions& o) { + static_assert(offsetof(tint::writer::ArrayLengthFromUniformOptions, ubo_binding) == 0, + "Please update serialization for tint::writer::ArrayLengthFromUniformOptions"); + static_assert( + offsetof(tint::writer::ArrayLengthFromUniformOptions, bindpoint_to_size_index) == 8, + "Please update serialization for tint::writer::ArrayLengthFromUniformOptions"); + static_assert( + sizeof(tint::writer::ArrayLengthFromUniformOptions) == + 8 + sizeof(tint::writer::ArrayLengthFromUniformOptions::bindpoint_to_size_index), + "Please update serialization for tint::writer::ArrayLengthFromUniformOptions"); + StreamIn(sink, o.ubo_binding, o.bindpoint_to_size_index); +} + } // namespace dawn::native diff --git a/src/dawn/native/TintUtils.cpp b/src/dawn/native/TintUtils.cpp index c1585a57b6..f59e6e5cd1 100644 --- a/src/dawn/native/TintUtils.cpp +++ b/src/dawn/native/TintUtils.cpp @@ -185,7 +185,10 @@ tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig( } // namespace dawn::native -bool std::less::operator()(const tint::sem::BindingPoint& a, - const tint::sem::BindingPoint& b) const { +namespace tint::sem { + +bool operator<(const BindingPoint& a, const BindingPoint& b) { return std::tie(a.group, a.binding) < std::tie(b.group, b.binding); } + +} // namespace tint::sem diff --git a/src/dawn/native/TintUtils.h 
b/src/dawn/native/TintUtils.h index e17fc691d2..7c03881502 100644 --- a/src/dawn/native/TintUtils.h +++ b/src/dawn/native/TintUtils.h @@ -49,10 +49,11 @@ tint::transform::VertexPulling::Config BuildVertexPullingTransformConfig( } // namespace dawn::native -// std::less operator for std::map containing BindingPoint -template <> -struct std::less { - bool operator()(const tint::sem::BindingPoint& a, const tint::sem::BindingPoint& b) const; -}; +namespace tint::sem { + +// Define operator< for std::map containing BindingPoint +bool operator<(const BindingPoint& a, const BindingPoint& b); + +} // namespace tint::sem #endif // SRC_DAWN_NATIVE_TINTUTILS_H_ diff --git a/src/dawn/native/d3d12/BlobD3D12.cpp b/src/dawn/native/d3d12/BlobD3D12.cpp index ef9bbb9905..3b5965758d 100644 --- a/src/dawn/native/d3d12/BlobD3D12.cpp +++ b/src/dawn/native/d3d12/BlobD3D12.cpp @@ -28,4 +28,16 @@ Blob CreateBlob(ComPtr blob) { }); } +Blob CreateBlob(ComPtr blob) { + // Detach so the deleter callback can "own" the reference + IDxcBlob* ptr = blob.Detach(); + return Blob::UnsafeCreateWithDeleter(reinterpret_cast(ptr->GetBufferPointer()), + ptr->GetBufferSize(), [=]() { + // Reattach and drop to delete it. 
+ ComPtr b; + b.Attach(ptr); + b = nullptr; + }); +} + } // namespace dawn::native diff --git a/src/dawn/native/d3d12/BlobD3D12.h b/src/dawn/native/d3d12/BlobD3D12.h index 563ac7341c..cc8c99c902 100644 --- a/src/dawn/native/d3d12/BlobD3D12.h +++ b/src/dawn/native/d3d12/BlobD3D12.h @@ -18,5 +18,6 @@ namespace dawn::native { Blob CreateBlob(ComPtr blob); +Blob CreateBlob(ComPtr blob); } // namespace dawn::native diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp index 3e09cc00b6..2fd04bdfc4 100644 --- a/src/dawn/native/d3d12/ShaderModuleD3D12.cpp +++ b/src/dawn/native/d3d12/ShaderModuleD3D12.cpp @@ -29,79 +29,41 @@ #include "dawn/common/Log.h" #include "dawn/common/WindowsUtils.h" #include "dawn/native/CacheKey.h" +#include "dawn/native/CacheRequest.h" #include "dawn/native/Pipeline.h" #include "dawn/native/TintUtils.h" #include "dawn/native/d3d12/AdapterD3D12.h" #include "dawn/native/d3d12/BackendD3D12.h" #include "dawn/native/d3d12/BindGroupLayoutD3D12.h" +#include "dawn/native/d3d12/BlobD3D12.h" #include "dawn/native/d3d12/D3D12Error.h" #include "dawn/native/d3d12/DeviceD3D12.h" #include "dawn/native/d3d12/PipelineLayoutD3D12.h" #include "dawn/native/d3d12/PlatformFunctions.h" #include "dawn/native/d3d12/UtilsD3D12.h" +#include "dawn/native/stream/BlobSource.h" +#include "dawn/native/stream/ByteVectorSink.h" #include "dawn/platform/DawnPlatform.h" #include "dawn/platform/tracing/TraceEvent.h" #include "tint/tint.h" +namespace dawn::native::stream { + +// Define no-op serializations for pD3DCompile, IDxcLibrary, and IDxcCompiler. +// These are output-only interfaces used to generate bytecode. 
+template <> +void Stream::Write(Sink*, IDxcLibrary* const&) {} +template <> +void Stream::Write(Sink*, IDxcCompiler* const&) {} +template <> +void Stream::Write(Sink*, pD3DCompile const&) {} + +} // namespace dawn::native::stream + namespace dawn::native::d3d12 { namespace { -uint64_t GetD3DCompilerVersion() { - return D3D_COMPILER_VERSION; -} - -struct CompareBindingPoint { - constexpr bool operator()(const tint::transform::BindingPoint& lhs, - const tint::transform::BindingPoint& rhs) const { - if (lhs.group != rhs.group) { - return lhs.group < rhs.group; - } else { - return lhs.binding < rhs.binding; - } - } -}; - -void StreamIn(std::stringstream& output, const tint::ast::Access& access) { - output << access; -} - -void StreamIn(std::stringstream& output, const tint::transform::BindingPoint& binding_point) { - output << "(BindingPoint"; - output << " group=" << binding_point.group; - output << " binding=" << binding_point.binding; - output << ")"; -} - -template ::value>::type> -void StreamIn(std::stringstream& output, const T& val) { - output << val; -} - -template -void StreamIn(std::stringstream& output, - const std::unordered_map& map) { - output << "(map"; - - std::map sorted(map.begin(), map.end()); - for (auto& [bindingPoint, value] : sorted) { - output << " "; - StreamIn(output, bindingPoint); - output << "="; - StreamIn(output, value); - } - output << ")"; -} - -void StreamIn(std::stringstream& output, - const tint::writer::ArrayLengthFromUniformOptions& arrayLengthFromUniform) { - output << "(ArrayLengthFromUniformOptions"; - output << " ubo_binding="; - StreamIn(output, arrayLengthFromUniform.ubo_binding); - output << " bindpoint_to_size_index="; - StreamIn(output, arrayLengthFromUniform.bindpoint_to_size_index); - output << ")"; -} // 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) { @@ -130,253 +92,113 @@ std::string 
GetHLSLValueString(EntryPointMetadata::Override::Type dawnType, constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_"; -void GetOverridableConstantsDefines( - std::vector>* defineStrings, - const PipelineConstantEntries* pipelineConstantEntries, - const EntryPointMetadata::OverridesMap* shaderEntryPointConstants) { +using DefineStrings = std::vector>; + +DefineStrings GetOverridableConstantsDefines( + const PipelineConstantEntries& pipelineConstantEntries, + const EntryPointMetadata::OverridesMap& shaderEntryPointConstants) { + DefineStrings defineStrings; std::unordered_set overriddenConstants; // Set pipeline overridden values - for (const auto& [name, value] : *pipelineConstantEntries) { + for (const auto& [name, value] : pipelineConstantEntries) { overriddenConstants.insert(name); // This is already validated so `name` must exist - const auto& moduleConstant = shaderEntryPointConstants->at(name); + const auto& moduleConstant = shaderEntryPointConstants.at(name); - defineStrings->emplace_back( + defineStrings.emplace_back( kSpecConstantPrefix + std::to_string(static_cast(moduleConstant.id)), GetHLSLValueString(moduleConstant.type, nullptr, value)); } // Set shader initialized default values - for (const auto& iter : *shaderEntryPointConstants) { + for (const auto& iter : shaderEntryPointConstants) { const std::string& name = iter.first; if (overriddenConstants.count(name) != 0) { // This constant already has overridden value continue; } - const auto& moduleConstant = shaderEntryPointConstants->at(name); + const auto& moduleConstant = shaderEntryPointConstants.at(name); // Uninitialized default values are okay since they ar only defined to pass // compilation but not used - defineStrings->emplace_back( + defineStrings.emplace_back( kSpecConstantPrefix + std::to_string(static_cast(moduleConstant.id)), GetHLSLValueString(moduleConstant.type, &moduleConstant.defaultValue)); } + return defineStrings; } -// The inputs to a shader compilation. 
These have been intentionally isolated from the -// device to help ensure that the pipeline cache key contains all inputs for compilation. -struct ShaderCompilationRequest { - enum Compiler { FXC, DXC }; +enum class Compiler { FXC, DXC }; - // Common inputs - Compiler compiler; - const tint::Program* program; - const char* entryPointName; - SingleShaderStage stage; - uint32_t compileFlags; - bool disableSymbolRenaming; - tint::transform::BindingRemapper::BindingPoints remappedBindingPoints; - tint::transform::BindingRemapper::AccessControls remappedAccessControls; - bool isRobustnessEnabled; - bool usesNumWorkgroups; - uint32_t numWorkgroupsRegisterSpace; - uint32_t numWorkgroupsShaderRegister; - tint::writer::ArrayLengthFromUniformOptions arrayLengthFromUniform; - std::vector> defineStrings; +#define HLSL_COMPILATION_REQUEST_MEMBERS(X) \ + X(const tint::Program*, inputProgram) \ + X(std::string_view, entryPointName) \ + X(SingleShaderStage, stage) \ + X(uint32_t, shaderModel) \ + X(uint32_t, compileFlags) \ + X(Compiler, compiler) \ + X(uint64_t, compilerVersion) \ + X(std::wstring_view, dxcShaderProfile) \ + X(std::string_view, fxcShaderProfile) \ + X(pD3DCompile, d3dCompile) \ + X(IDxcLibrary*, dxcLibrary) \ + X(IDxcCompiler*, dxcCompiler) \ + X(uint32_t, firstIndexOffsetShaderRegister) \ + X(uint32_t, firstIndexOffsetRegisterSpace) \ + X(bool, usesNumWorkgroups) \ + X(uint32_t, numWorkgroupsShaderRegister) \ + X(uint32_t, numWorkgroupsRegisterSpace) \ + X(DefineStrings, defineStrings) \ + X(tint::transform::MultiplanarExternalTexture::BindingsMap, newBindingsMap) \ + X(tint::writer::ArrayLengthFromUniformOptions, arrayLengthFromUniform) \ + X(tint::transform::BindingRemapper::BindingPoints, remappedBindingPoints) \ + X(tint::transform::BindingRemapper::AccessControls, remappedAccessControls) \ + X(bool, disableSymbolRenaming) \ + X(bool, isRobustnessEnabled) \ + X(bool, disableWorkgroupInit) \ + X(bool, dumpShaders) - // FXC/DXC common inputs - bool 
disableWorkgroupInit; +#define D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS(X) \ + X(bool, hasShaderFloat16Feature) \ + X(uint32_t, compileFlags) \ + X(Compiler, compiler) \ + X(uint64_t, compilerVersion) \ + X(std::wstring_view, dxcShaderProfile) \ + X(std::string_view, fxcShaderProfile) \ + X(pD3DCompile, d3dCompile) \ + X(IDxcLibrary*, dxcLibrary) \ + X(IDxcCompiler*, dxcCompiler) \ + X(DefineStrings, defineStrings) - // FXC inputs - uint64_t fxcVersion; +struct HlslCompilationRequest { + DAWN_VISITABLE_MEMBERS(HLSL_COMPILATION_REQUEST_MEMBERS) - // DXC inputs - uint64_t dxcVersion; - const D3D12DeviceInfo* deviceInfo; - bool hasShaderFloat16Feature; - - static ResultOrError Create( - const char* entryPointName, - SingleShaderStage stage, - const PipelineLayout* layout, - uint32_t compileFlags, - const Device* device, - const tint::Program* program, - const EntryPointMetadata& entryPoint, - const ProgrammableStage& programmableStage) { - Compiler compiler; - uint64_t dxcVersion = 0; - if (device->IsToggleEnabled(Toggle::UseDXC)) { - compiler = Compiler::DXC; - DAWN_TRY_ASSIGN(dxcVersion, - ToBackend(device->GetAdapter())->GetBackend()->GetDXCompilerVersion()); - } else { - compiler = Compiler::FXC; - } - - using tint::transform::BindingPoint; - using tint::transform::BindingRemapper; - - BindingRemapper::BindingPoints remappedBindingPoints; - BindingRemapper::AccessControls remappedAccessControls; - - tint::writer::ArrayLengthFromUniformOptions arrayLengthFromUniform; - arrayLengthFromUniform.ubo_binding = { - layout->GetDynamicStorageBufferLengthsRegisterSpace(), - layout->GetDynamicStorageBufferLengthsShaderRegister()}; - - const BindingInfoArray& moduleBindingInfo = entryPoint.bindings; - for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { - const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group)); - const auto& groupBindingInfo = moduleBindingInfo[group]; - - // d3d12::BindGroupLayout packs the bindings per HLSL 
register-space. We modify - // the Tint AST to make the "bindings" decoration match the offset chosen by - // d3d12::BindGroupLayout so that Tint produces HLSL with the correct registers - // assigned to each interface variable. - for (const auto& [binding, bindingInfo] : groupBindingInfo) { - BindingIndex bindingIndex = bgl->GetBindingIndex(binding); - BindingPoint srcBindingPoint{static_cast(group), - static_cast(binding)}; - BindingPoint dstBindingPoint{static_cast(group), - bgl->GetShaderRegister(bindingIndex)}; - if (srcBindingPoint != dstBindingPoint) { - remappedBindingPoints.emplace(srcBindingPoint, dstBindingPoint); - } - - // Declaring a read-only storage buffer in HLSL but specifying a storage - // buffer in the BGL produces the wrong output. Force read-only storage - // buffer bindings to be treated as UAV instead of SRV. Internal storage - // buffer is a storage buffer used in the internal pipeline. - const bool forceStorageBufferAsUAV = - (bindingInfo.buffer.type == wgpu::BufferBindingType::ReadOnlyStorage && - (bgl->GetBindingInfo(bindingIndex).buffer.type == - wgpu::BufferBindingType::Storage || - bgl->GetBindingInfo(bindingIndex).buffer.type == - kInternalStorageBufferBinding)); - if (forceStorageBufferAsUAV) { - remappedAccessControls.emplace(srcBindingPoint, tint::ast::Access::kReadWrite); - } - } - - // Add arrayLengthFromUniform options - { - for (const auto& bindingAndRegisterOffset : - layout->GetDynamicStorageBufferLengthInfo()[group].bindingAndRegisterOffsets) { - BindingNumber binding = bindingAndRegisterOffset.binding; - uint32_t registerOffset = bindingAndRegisterOffset.registerOffset; - - BindingPoint bindingPoint{static_cast(group), - static_cast(binding)}; - // Get the renamed binding point if it was remapped. 
- auto it = remappedBindingPoints.find(bindingPoint); - if (it != remappedBindingPoints.end()) { - bindingPoint = it->second; - } - - arrayLengthFromUniform.bindpoint_to_size_index.emplace(bindingPoint, - registerOffset); - } - } - } - - ShaderCompilationRequest request; - request.compiler = compiler; - request.program = program; - request.entryPointName = entryPointName; - request.stage = stage; - request.compileFlags = compileFlags; - request.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming); - request.remappedBindingPoints = std::move(remappedBindingPoints); - request.remappedAccessControls = std::move(remappedAccessControls); - request.isRobustnessEnabled = device->IsRobustnessEnabled(); - request.disableWorkgroupInit = device->IsToggleEnabled(Toggle::DisableWorkgroupInit); - request.usesNumWorkgroups = entryPoint.usesNumWorkgroups; - request.numWorkgroupsShaderRegister = layout->GetNumWorkgroupsShaderRegister(); - request.numWorkgroupsRegisterSpace = layout->GetNumWorkgroupsRegisterSpace(); - request.arrayLengthFromUniform = std::move(arrayLengthFromUniform); - request.fxcVersion = compiler == Compiler::FXC ? GetD3DCompilerVersion() : 0; - request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0; - request.deviceInfo = &device->GetDeviceInfo(); - request.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16); - - GetOverridableConstantsDefines( - &request.defineStrings, &programmableStage.constants, - &programmableStage.module->GetEntryPoint(programmableStage.entryPoint).overrides); - - return std::move(request); - } - - // TODO(dawn:1341): Move to use CacheKey instead of the vector. - ResultOrError> CreateCacheKey() const { - // Generate the WGSL from the Tint program so it's normalized. - // TODO(tint:1180): Consider using a binary serialization of the tint AST for a more - // compact representation. 
- auto result = tint::writer::wgsl::Generate(program, tint::writer::wgsl::Options{}); - if (!result.success) { - std::ostringstream errorStream; - errorStream << "Tint WGSL failure:" << std::endl; - errorStream << "Generator: " << result.error << std::endl; - return DAWN_INTERNAL_ERROR(errorStream.str().c_str()); - } - - std::stringstream stream; - - // Prefix the key with the type to avoid collisions from another type that could - // have the same key. - stream << static_cast(CacheKey::Type::Shader); - stream << "\n"; - - stream << result.wgsl.length(); - stream << "\n"; - - stream << result.wgsl; - stream << "\n"; - - stream << "(ShaderCompilationRequest"; - stream << " compiler=" << compiler; - stream << " entryPointName=" << entryPointName; - stream << " stage=" << uint32_t(stage); - stream << " compileFlags=" << compileFlags; - stream << " disableSymbolRenaming=" << disableSymbolRenaming; - - stream << " remappedBindingPoints="; - StreamIn(stream, remappedBindingPoints); - - stream << " remappedAccessControls="; - StreamIn(stream, remappedAccessControls); - - stream << " useNumWorkgroups=" << usesNumWorkgroups; - stream << " numWorkgroupsRegisterSpace=" << numWorkgroupsRegisterSpace; - stream << " numWorkgroupsShaderRegister=" << numWorkgroupsShaderRegister; - - stream << " arrayLengthFromUniform="; - StreamIn(stream, arrayLengthFromUniform); - - stream << " shaderModel=" << deviceInfo->shaderModel; - stream << " disableWorkgroupInit=" << disableWorkgroupInit; - stream << " isRobustnessEnabled=" << isRobustnessEnabled; - stream << " fxcVersion=" << fxcVersion; - stream << " dxcVersion=" << dxcVersion; - stream << " hasShaderFloat16Feature=" << hasShaderFloat16Feature; - - stream << " defines={"; - for (const auto& [name, value] : defineStrings) { - stream << " <" << name << "," << value << ">"; - } - stream << " }"; - - stream << ")"; - stream << "\n"; - - return std::vector(std::istreambuf_iterator{stream}, - std::istreambuf_iterator{}); + friend void 
StreamIn(stream::Sink* sink, const HlslCompilationRequest& r) { + r.VisitAll([&](const auto&... members) { StreamIn(sink, members...); }); } }; +struct D3DBytecodeCompilationRequest { + DAWN_VISITABLE_MEMBERS(D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS) + + friend void StreamIn(stream::Sink* sink, const D3DBytecodeCompilationRequest& r) { + r.VisitAll([&](const auto&... members) { StreamIn(sink, members...); }); + } +}; + +#define D3D_COMPILATION_REQUEST_MEMBERS(X) \ + X(HlslCompilationRequest, hlsl) \ + X(D3DBytecodeCompilationRequest, bytecode) \ + X(CacheKey::UnsafeUnkeyedValue, tracePlatform) + +DAWN_MAKE_CACHE_REQUEST(D3DCompilationRequest, D3D_COMPILATION_REQUEST_MEMBERS); +#undef HLSL_COMPILATION_REQUEST_MEMBERS +#undef D3D_BYTECODE_COMPILATION_REQUEST_MEMBERS +#undef D3D_COMPILATION_REQUEST_MEMBERS + std::vector GetDXCArguments(uint32_t compileFlags, bool enable16BitTypes) { std::vector arguments; if (compileFlags & D3DCOMPILE_ENABLE_BACKWARDS_COMPATIBILITY) { @@ -429,25 +251,24 @@ std::vector GetDXCArguments(uint32_t compileFlags, bool enable16 return arguments; } -ResultOrError> CompileShaderDXC(IDxcLibrary* dxcLibrary, - IDxcCompiler* dxcCompiler, - const ShaderCompilationRequest& request, +ResultOrError> CompileShaderDXC(const D3DBytecodeCompilationRequest& r, + const std::string& entryPointName, const std::string& hlslSource) { ComPtr sourceBlob; - DAWN_TRY(CheckHRESULT(dxcLibrary->CreateBlobWithEncodingOnHeapCopy( + DAWN_TRY(CheckHRESULT(r.dxcLibrary->CreateBlobWithEncodingFromPinned( hlslSource.c_str(), hlslSource.length(), CP_UTF8, &sourceBlob), "DXC create blob")); std::wstring entryPointW; - DAWN_TRY_ASSIGN(entryPointW, ConvertStringToWstring(request.entryPointName)); + DAWN_TRY_ASSIGN(entryPointW, ConvertStringToWstring(entryPointName)); std::vector arguments = - GetDXCArguments(request.compileFlags, request.hasShaderFloat16Feature); + GetDXCArguments(r.compileFlags, r.hasShaderFloat16Feature); // Build defines for overridable constants std::vector> 
defineStrings; - defineStrings.reserve(request.defineStrings.size()); - for (const auto& [name, value] : request.defineStrings) { + defineStrings.reserve(r.defineStrings.size()); + for (const auto& [name, value] : r.defineStrings) { defineStrings.emplace_back(UTF8ToWStr(name.c_str()), UTF8ToWStr(value.c_str())); } @@ -458,12 +279,11 @@ ResultOrError> CompileShaderDXC(IDxcLibrary* dxcLibrary, } ComPtr result; - DAWN_TRY( - CheckHRESULT(dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(), - request.deviceInfo->shaderProfiles[request.stage].c_str(), - arguments.data(), arguments.size(), dxcDefines.data(), - dxcDefines.size(), nullptr, &result), - "DXC compile")); + DAWN_TRY(CheckHRESULT( + r.dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(), + r.dxcShaderProfile.data(), arguments.data(), arguments.size(), + dxcDefines.data(), dxcDefines.size(), nullptr, &result), + "DXC compile")); HRESULT hr; DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status")); @@ -543,31 +363,18 @@ std::string CompileFlagsToStringFXC(uint32_t compileFlags) { return result; } -ResultOrError> CompileShaderFXC(const PlatformFunctions* functions, - const ShaderCompilationRequest& request, +ResultOrError> CompileShaderFXC(const D3DBytecodeCompilationRequest& r, + const std::string& entryPointName, const std::string& hlslSource) { - const char* targetProfile = nullptr; - switch (request.stage) { - case SingleShaderStage::Vertex: - targetProfile = "vs_5_1"; - break; - case SingleShaderStage::Fragment: - targetProfile = "ps_5_1"; - break; - case SingleShaderStage::Compute: - targetProfile = "cs_5_1"; - break; - } - ComPtr compiledShader; ComPtr errors; // Build defines for overridable constants const D3D_SHADER_MACRO* pDefines = nullptr; std::vector fxcDefines; - if (request.defineStrings.size() > 0) { - fxcDefines.reserve(request.defineStrings.size() + 1); - for (const auto& [name, value] : request.defineStrings) { + if (r.defineStrings.size() > 0) { + 
fxcDefines.reserve(r.defineStrings.size() + 1); + for (const auto& [name, value] : r.defineStrings) { fxcDefines.push_back({name.c_str(), value.c_str()}); } // d3dCompile D3D_SHADER_MACRO* pDefines is a nullptr terminated array @@ -575,36 +382,49 @@ ResultOrError> CompileShaderFXC(const PlatformFunctions* functi pDefines = fxcDefines.data(); } - DAWN_INVALID_IF( - FAILED(functions->d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, pDefines, - nullptr, request.entryPointName, targetProfile, - request.compileFlags, 0, &compiledShader, &errors)), - "D3D compile failed with: %s", static_cast(errors->GetBufferPointer())); + DAWN_INVALID_IF(FAILED(r.d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, pDefines, + nullptr, entryPointName.c_str(), r.fxcShaderProfile.data(), + r.compileFlags, 0, &compiledShader, &errors)), + "D3D compile failed with: %s", static_cast(errors->GetBufferPointer())); return std::move(compiledShader); } -ResultOrError TranslateToHLSL(dawn::platform::Platform* platform, - const ShaderCompilationRequest& request, - std::string* remappedEntryPointName) { +ResultOrError TranslateToHLSL( + HlslCompilationRequest r, + CacheKey::UnsafeUnkeyedValue tracePlatform, + std::string* remappedEntryPointName, + bool* usesVertexOrInstanceIndex) { std::ostringstream errorStream; errorStream << "Tint HLSL failure:" << std::endl; tint::transform::Manager transformManager; tint::transform::DataMap transformInputs; - if (request.isRobustnessEnabled) { + if (!r.newBindingsMap.empty()) { + transformManager.Add(); + transformInputs.Add( + std::move(r.newBindingsMap)); + } + + if (r.stage == SingleShaderStage::Vertex) { + transformManager.Add(); + transformInputs.Add( + r.firstIndexOffsetShaderRegister, r.firstIndexOffsetRegisterSpace); + } + + if (r.isRobustnessEnabled) { transformManager.Add(); } transformManager.Add(); transformManager.Add(); - transformInputs.Add(request.entryPointName); + transformInputs.Add(r.entryPointName.data()); 
transformManager.Add(); - if (request.disableSymbolRenaming) { + if (r.disableSymbolRenaming) { // We still need to rename HLSL reserved keywords transformInputs.Add( tint::transform::Renamer::Target::kHlslKeywords); @@ -615,104 +435,92 @@ ResultOrError TranslateToHLSL(dawn::platform::Platform* platform, // different types. const bool mayCollide = true; transformInputs.Add( - std::move(request.remappedBindingPoints), std::move(request.remappedAccessControls), - mayCollide); + std::move(r.remappedBindingPoints), std::move(r.remappedAccessControls), mayCollide); tint::Program transformedProgram; tint::transform::DataMap transformOutputs; { - TRACE_EVENT0(platform, General, "RunTransforms"); + TRACE_EVENT0(tracePlatform.UnsafeGetValue(), General, "RunTransforms"); DAWN_TRY_ASSIGN(transformedProgram, - RunTransforms(&transformManager, request.program, transformInputs, + RunTransforms(&transformManager, r.inputProgram, transformInputs, &transformOutputs, nullptr)); } if (auto* data = transformOutputs.Get()) { - auto it = data->remappings.find(request.entryPointName); + auto it = data->remappings.find(r.entryPointName.data()); if (it != data->remappings.end()) { *remappedEntryPointName = it->second; } else { - DAWN_INVALID_IF(!request.disableSymbolRenaming, + DAWN_INVALID_IF(!r.disableSymbolRenaming, "Could not find remapped name for entry point."); - *remappedEntryPointName = request.entryPointName; + *remappedEntryPointName = r.entryPointName; } } else { return DAWN_FORMAT_VALIDATION_ERROR("Transform output missing renamer data."); } + if (r.stage == SingleShaderStage::Vertex) { + if (auto* data = transformOutputs.Get()) { + *usesVertexOrInstanceIndex = data->has_vertex_or_instance_index; + } else { + return DAWN_FORMAT_VALIDATION_ERROR( + "Transform output missing first index offset data."); + } + } + tint::writer::hlsl::Options options; - options.disable_workgroup_init = request.disableWorkgroupInit; - if (request.usesNumWorkgroups) { - 
options.root_constant_binding_point = tint::sem::BindingPoint{ - request.numWorkgroupsRegisterSpace, request.numWorkgroupsShaderRegister}; + options.disable_workgroup_init = r.disableWorkgroupInit; + if (r.usesNumWorkgroups) { + options.root_constant_binding_point = + tint::sem::BindingPoint{r.numWorkgroupsRegisterSpace, r.numWorkgroupsShaderRegister}; } // TODO(dawn:549): HLSL generation outputs the indices into the // array_length_from_uniform buffer that were actually used. When the blob cache can // store more than compiled shaders, we should reflect these used indices and store // them as well. This would allow us to only upload root constants that are actually // read by the shader. - options.array_length_from_uniform = request.arrayLengthFromUniform; - TRACE_EVENT0(platform, General, "tint::writer::hlsl::Generate"); + options.array_length_from_uniform = r.arrayLengthFromUniform; + TRACE_EVENT0(tracePlatform.UnsafeGetValue(), General, "tint::writer::hlsl::Generate"); auto result = tint::writer::hlsl::Generate(&transformedProgram, options); DAWN_INVALID_IF(!result.success, "An error occured while generating HLSL: %s", result.error); return std::move(result.hlsl); } -template -MaybeError CompileShader(dawn::platform::Platform* platform, - const PlatformFunctions* functions, - IDxcLibrary* dxcLibrary, - IDxcCompiler* dxcCompiler, - ShaderCompilationRequest&& request, - bool dumpShaders, - F&& DumpShadersEmitLog, - CompiledShader* compiledShader) { +ResultOrError CompileShader(D3DCompilationRequest r) { + CompiledShader compiledShader; // Compile the source shader to HLSL. 
- std::string hlslSource; std::string remappedEntryPoint; - DAWN_TRY_ASSIGN(hlslSource, TranslateToHLSL(platform, request, &remappedEntryPoint)); - if (dumpShaders) { - std::ostringstream dumpedMsg; - dumpedMsg << "/* Dumped generated HLSL */" << std::endl << hlslSource; - DumpShadersEmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str()); - } - request.entryPointName = remappedEntryPoint.c_str(); - switch (request.compiler) { - case ShaderCompilationRequest::Compiler::DXC: { - TRACE_EVENT0(platform, General, "CompileShaderDXC"); - DAWN_TRY_ASSIGN(compiledShader->compiledDXCShader, - CompileShaderDXC(dxcLibrary, dxcCompiler, request, hlslSource)); + DAWN_TRY_ASSIGN(compiledShader.hlslSource, + TranslateToHLSL(std::move(r.hlsl), r.tracePlatform, &remappedEntryPoint, + &compiledShader.usesVertexOrInstanceIndex)); + + switch (r.bytecode.compiler) { + case Compiler::DXC: { + TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "CompileShaderDXC"); + ComPtr compiledDXCShader; + DAWN_TRY_ASSIGN(compiledDXCShader, CompileShaderDXC(r.bytecode, remappedEntryPoint, + compiledShader.hlslSource)); + compiledShader.shaderBlob = CreateBlob(std::move(compiledDXCShader)); break; } - case ShaderCompilationRequest::Compiler::FXC: { - TRACE_EVENT0(platform, General, "CompileShaderFXC"); - DAWN_TRY_ASSIGN(compiledShader->compiledFXCShader, - CompileShaderFXC(functions, request, hlslSource)); + case Compiler::FXC: { + TRACE_EVENT0(r.tracePlatform.UnsafeGetValue(), General, "CompileShaderFXC"); + ComPtr compiledFXCShader; + DAWN_TRY_ASSIGN(compiledFXCShader, CompileShaderFXC(r.bytecode, remappedEntryPoint, + compiledShader.hlslSource)); + compiledShader.shaderBlob = CreateBlob(std::move(compiledFXCShader)); break; } } - if (dumpShaders && request.compiler == ShaderCompilationRequest::Compiler::FXC) { - std::ostringstream dumpedMsg; - dumpedMsg << "/* FXC compile flags */ " << std::endl - << CompileFlagsToStringFXC(request.compileFlags) << std::endl; - - dumpedMsg << "/* Dumped 
disassembled DXBC */" << std::endl; - - ComPtr disassembly; - if (FAILED(functions->d3dDisassemble(compiledShader->compiledFXCShader->GetBufferPointer(), - compiledShader->compiledFXCShader->GetBufferSize(), 0, - nullptr, &disassembly))) { - dumpedMsg << "D3D disassemble failed" << std::endl; - } else { - dumpedMsg << reinterpret_cast(disassembly->GetBufferPointer()); - } - DumpShadersEmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str()); + // If dumpShaders is false, we don't need the HLSL for logging. Clear the contents so it + // isn't stored into the cache. + if (!r.hlsl.dumpShaders) { + compiledShader.hlslSource = ""; } - - return {}; + return compiledShader; } } // anonymous namespace @@ -741,74 +549,202 @@ ResultOrError ShaderModule::Compile(const ProgrammableStage& pro SingleShaderStage stage, const PipelineLayout* layout, uint32_t compileFlags) { - TRACE_EVENT0(GetDevice()->GetPlatform(), General, "ShaderModuleD3D12::Compile"); + Device* device = ToBackend(GetDevice()); + TRACE_EVENT0(device->GetPlatform(), General, "ShaderModuleD3D12::Compile"); ASSERT(!IsError()); - ScopedTintICEHandler scopedICEHandler(GetDevice()); + ScopedTintICEHandler scopedICEHandler(device); + const EntryPointMetadata& entryPoint = GetEntryPoint(programmableStage.entryPoint); - Device* device = ToBackend(GetDevice()); + D3DCompilationRequest req = {}; + req.tracePlatform = UnsafeUnkeyedValue(device->GetPlatform()); + req.hlsl.shaderModel = device->GetDeviceInfo().shaderModel; + req.hlsl.disableSymbolRenaming = device->IsToggleEnabled(Toggle::DisableSymbolRenaming); + req.hlsl.isRobustnessEnabled = device->IsRobustnessEnabled(); + req.hlsl.disableWorkgroupInit = device->IsToggleEnabled(Toggle::DisableWorkgroupInit); + req.hlsl.dumpShaders = device->IsToggleEnabled(Toggle::DumpShaders); - CompiledShader compiledShader = {}; - - tint::transform::Manager transformManager; - tint::transform::DataMap transformInputs; - - const tint::Program* program = GetTintProgram(); - tint::Program 
programAsValue; - - auto externalTextureBindings = BuildExternalTextureTransformBindings(layout); - if (!externalTextureBindings.empty()) { - transformManager.Add(); - transformInputs.Add( - std::move(externalTextureBindings)); - } - - if (stage == SingleShaderStage::Vertex) { - transformManager.Add(); - transformInputs.Add( - layout->GetFirstIndexOffsetShaderRegister(), - layout->GetFirstIndexOffsetRegisterSpace()); - } - - tint::transform::DataMap transformOutputs; - DAWN_TRY_ASSIGN(programAsValue, RunTransforms(&transformManager, program, transformInputs, - &transformOutputs, nullptr)); - program = &programAsValue; - - if (stage == SingleShaderStage::Vertex) { - if (auto* data = transformOutputs.Get()) { - // TODO(dawn:549): Consider adding this information to the pipeline cache once we - // can store more than the shader blob in it. - compiledShader.usesVertexOrInstanceIndex = data->has_vertex_or_instance_index; + req.bytecode.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16); + req.bytecode.compileFlags = compileFlags; + req.bytecode.defineStrings = + GetOverridableConstantsDefines(programmableStage.constants, entryPoint.overrides); + if (device->IsToggleEnabled(Toggle::UseDXC)) { + req.bytecode.compiler = Compiler::DXC; + req.bytecode.dxcLibrary = device->GetDxcLibrary().Get(); + req.bytecode.dxcCompiler = device->GetDxcCompiler().Get(); + DAWN_TRY_ASSIGN(req.bytecode.compilerVersion, + ToBackend(device->GetAdapter())->GetBackend()->GetDXCompilerVersion()); + req.bytecode.dxcShaderProfile = device->GetDeviceInfo().shaderProfiles[stage]; + } else { + req.bytecode.compiler = Compiler::FXC; + req.bytecode.d3dCompile = device->GetFunctions()->d3dCompile; + req.bytecode.compilerVersion = D3D_COMPILER_VERSION; + switch (stage) { + case SingleShaderStage::Vertex: + req.bytecode.fxcShaderProfile = "vs_5_1"; + break; + case SingleShaderStage::Fragment: + req.bytecode.fxcShaderProfile = "ps_5_1"; + break; + case SingleShaderStage::Compute: + 
req.bytecode.fxcShaderProfile = "cs_5_1"; + break; } } - ShaderCompilationRequest request; - DAWN_TRY_ASSIGN(request, - ShaderCompilationRequest::Create( - programmableStage.entryPoint.c_str(), stage, layout, compileFlags, device, - program, GetEntryPoint(programmableStage.entryPoint), programmableStage)); + using tint::transform::BindingPoint; + using tint::transform::BindingRemapper; - // TODO(dawn:1341): Add shader cache key generation and caching for the compiled shader. - DAWN_TRY(CompileShader( - device->GetPlatform(), device->GetFunctions(), - device->IsToggleEnabled(Toggle::UseDXC) ? device->GetDxcLibrary().Get() : nullptr, - device->IsToggleEnabled(Toggle::UseDXC) ? device->GetDxcCompiler().Get() : nullptr, - std::move(request), device->IsToggleEnabled(Toggle::DumpShaders), - [&](WGPULoggingType loggingType, const char* message) { - GetDevice()->EmitLog(loggingType, message); - }, - &compiledShader)); - return std::move(compiledShader); + BindingRemapper::BindingPoints remappedBindingPoints; + BindingRemapper::AccessControls remappedAccessControls; + + tint::writer::ArrayLengthFromUniformOptions arrayLengthFromUniform; + arrayLengthFromUniform.ubo_binding = {layout->GetDynamicStorageBufferLengthsRegisterSpace(), + layout->GetDynamicStorageBufferLengthsShaderRegister()}; + + const BindingInfoArray& moduleBindingInfo = entryPoint.bindings; + for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { + const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group)); + const auto& groupBindingInfo = moduleBindingInfo[group]; + + // d3d12::BindGroupLayout packs the bindings per HLSL register-space. We modify + // the Tint AST to make the "bindings" decoration match the offset chosen by + // d3d12::BindGroupLayout so that Tint produces HLSL with the correct registers + // assigned to each interface variable. 
+ for (const auto& [binding, bindingInfo] : groupBindingInfo) { + BindingIndex bindingIndex = bgl->GetBindingIndex(binding); + BindingPoint srcBindingPoint{static_cast(group), + static_cast(binding)}; + BindingPoint dstBindingPoint{static_cast(group), + bgl->GetShaderRegister(bindingIndex)}; + if (srcBindingPoint != dstBindingPoint) { + remappedBindingPoints.emplace(srcBindingPoint, dstBindingPoint); + } + + // Declaring a read-only storage buffer in HLSL but specifying a storage + // buffer in the BGL produces the wrong output. Force read-only storage + // buffer bindings to be treated as UAV instead of SRV. Internal storage + // buffer is a storage buffer used in the internal pipeline. + const bool forceStorageBufferAsUAV = + (bindingInfo.buffer.type == wgpu::BufferBindingType::ReadOnlyStorage && + (bgl->GetBindingInfo(bindingIndex).buffer.type == + wgpu::BufferBindingType::Storage || + bgl->GetBindingInfo(bindingIndex).buffer.type == kInternalStorageBufferBinding)); + if (forceStorageBufferAsUAV) { + remappedAccessControls.emplace(srcBindingPoint, tint::ast::Access::kReadWrite); + } + } + + // Add arrayLengthFromUniform options + { + for (const auto& bindingAndRegisterOffset : + layout->GetDynamicStorageBufferLengthInfo()[group].bindingAndRegisterOffsets) { + BindingNumber binding = bindingAndRegisterOffset.binding; + uint32_t registerOffset = bindingAndRegisterOffset.registerOffset; + + BindingPoint bindingPoint{static_cast(group), + static_cast(binding)}; + // Get the renamed binding point if it was remapped. 
+ auto it = remappedBindingPoints.find(bindingPoint); + if (it != remappedBindingPoints.end()) { + bindingPoint = it->second; + } + + arrayLengthFromUniform.bindpoint_to_size_index.emplace(bindingPoint, + registerOffset); + } + } + } + + req.hlsl.inputProgram = GetTintProgram(); + req.hlsl.entryPointName = programmableStage.entryPoint.c_str(); + req.hlsl.stage = stage; + req.hlsl.firstIndexOffsetShaderRegister = layout->GetFirstIndexOffsetShaderRegister(); + req.hlsl.firstIndexOffsetRegisterSpace = layout->GetFirstIndexOffsetRegisterSpace(); + req.hlsl.usesNumWorkgroups = entryPoint.usesNumWorkgroups; + req.hlsl.numWorkgroupsShaderRegister = layout->GetNumWorkgroupsShaderRegister(); + req.hlsl.numWorkgroupsRegisterSpace = layout->GetNumWorkgroupsRegisterSpace(); + req.hlsl.remappedBindingPoints = std::move(remappedBindingPoints); + req.hlsl.remappedAccessControls = std::move(remappedAccessControls); + req.hlsl.newBindingsMap = BuildExternalTextureTransformBindings(layout); + req.hlsl.arrayLengthFromUniform = std::move(arrayLengthFromUniform); + + CacheResult compiledShader; + DAWN_TRY_LOAD_OR_RUN(compiledShader, device, std::move(req), CompiledShader::FromBlob, + CompileShader); + + if (device->IsToggleEnabled(Toggle::DumpShaders)) { + std::ostringstream dumpedMsg; + dumpedMsg << "/* Dumped generated HLSL */" << std::endl + << compiledShader->hlslSource << std::endl; + device->EmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str()); + + if (device->IsToggleEnabled(Toggle::UseDXC)) { + dumpedMsg << "/* Dumped disassembled DXIL */" << std::endl; + D3D12_SHADER_BYTECODE code = compiledShader->GetD3D12ShaderBytecode(); + ComPtr dxcBlob; + ComPtr disassembly; + if (FAILED(device->GetDxcLibrary()->CreateBlobWithEncodingFromPinned( + code.pShaderBytecode, code.BytecodeLength, 0, &dxcBlob)) || + FAILED(device->GetDxcCompiler()->Disassemble(dxcBlob.Get(), &disassembly))) { + dumpedMsg << "DXC disassemble failed" << std::endl; + } else { + dumpedMsg << std::string_view( + 
static_cast(disassembly->GetBufferPointer()), + disassembly->GetBufferSize()); + } + } else { + dumpedMsg << "/* FXC compile flags */ " << std::endl + << CompileFlagsToStringFXC(compileFlags) << std::endl; + dumpedMsg << "/* Dumped disassembled DXBC */" << std::endl; + ComPtr disassembly; + D3D12_SHADER_BYTECODE code = compiledShader->GetD3D12ShaderBytecode(); + if (FAILED(device->GetFunctions()->d3dDisassemble( + code.pShaderBytecode, code.BytecodeLength, 0, nullptr, &disassembly))) { + dumpedMsg << "D3D disassemble failed" << std::endl; + } else { + dumpedMsg << std::string_view( + static_cast(disassembly->GetBufferPointer()), + disassembly->GetBufferSize()); + } + } + device->EmitLog(WGPULoggingType_Info, dumpedMsg.str().c_str()); + } + + if (BlobCache* cache = device->GetBlobCache()) { + cache->EnsureStored(compiledShader); + } + + // Clear the hlslSource. It is only used for logging and should not be used + // outside of the compilation. + CompiledShader result = compiledShader.Acquire(); + result.hlslSource = ""; + return result; } D3D12_SHADER_BYTECODE CompiledShader::GetD3D12ShaderBytecode() const { - if (compiledFXCShader != nullptr) { - return {compiledFXCShader->GetBufferPointer(), compiledFXCShader->GetBufferSize()}; - } else if (compiledDXCShader != nullptr) { - return {compiledDXCShader->GetBufferPointer(), compiledDXCShader->GetBufferSize()}; - } - UNREACHABLE(); - return {}; + return {shaderBlob.Data(), shaderBlob.Size()}; } + } // namespace dawn::native::d3d12 + +namespace dawn::native { + +// Define the implementation to store d3d12::CompiledShader into the BlobCache. +template <> +void BlobCache::Store(const CacheKey& key, const d3d12::CompiledShader& c) { + stream::ByteVectorSink sink; + c.VisitAll([&](const auto&... members) { StreamIn(&sink, members...); }); + Store(key, CreateBlob(std::move(sink))); +} + +// Define the implementation to load d3d12::CompiledShader from a Blob. 
+// static +ResultOrError d3d12::CompiledShader::FromBlob(Blob blob) { + stream::BlobSource source(std::move(blob)); + d3d12::CompiledShader c; + DAWN_TRY(c.VisitAll([&](auto&... members) { return StreamOut(&source, &members...); })); + return c; +} + +} // namespace dawn::native diff --git a/src/dawn/native/d3d12/ShaderModuleD3D12.h b/src/dawn/native/d3d12/ShaderModuleD3D12.h index 7f68b10221..528e1e47ba 100644 --- a/src/dawn/native/d3d12/ShaderModuleD3D12.h +++ b/src/dawn/native/d3d12/ShaderModuleD3D12.h @@ -15,8 +15,11 @@ #ifndef SRC_DAWN_NATIVE_D3D12_SHADERMODULED3D12_H_ #define SRC_DAWN_NATIVE_D3D12_SHADERMODULED3D12_H_ -#include "dawn/native/ShaderModule.h" +#include +#include "dawn/native/Blob.h" +#include "dawn/native/ShaderModule.h" +#include "dawn/native/VisitableMembers.h" #include "dawn/native/d3d12/d3d12_platform.h" namespace dawn::native { @@ -28,14 +31,22 @@ namespace dawn::native::d3d12 { class Device; class PipelineLayout; -// Manages a ref to one of the various representations of shader blobs and information used to -// emulate vertex/instance index starts +#define COMPILED_SHADER_MEMBERS(X) \ + X(Blob, shaderBlob) \ + X(std::string, hlslSource) \ + X(bool, usesVertexOrInstanceIndex) + +// `CompiledShader` holds a ref to one of the various representations of shader blobs and +// information used to emulate vertex/instance index starts. It also holds the `hlslSource` for the +// shader compilation, which is only transiently available during Compile, and cleared before it +// returns. It is not written to or loaded from the cache unless Toggle dump_shaders is true. 
struct CompiledShader { - ComPtr compiledFXCShader; - ComPtr compiledDXCShader; + static ResultOrError FromBlob(Blob blob); + D3D12_SHADER_BYTECODE GetD3D12ShaderBytecode() const; - bool usesVertexOrInstanceIndex; + DAWN_VISITABLE_MEMBERS(COMPILED_SHADER_MEMBERS) +#undef COMPILED_SHADER_MEMBERS }; class ShaderModule final : public ShaderModuleBase { diff --git a/src/dawn/native/d3d12/UtilsD3D12.cpp b/src/dawn/native/d3d12/UtilsD3D12.cpp index 0e761f8f1c..e706d8f658 100644 --- a/src/dawn/native/d3d12/UtilsD3D12.cpp +++ b/src/dawn/native/d3d12/UtilsD3D12.cpp @@ -81,19 +81,19 @@ bool NeedBufferSizeWorkaroundForBufferTextureCopyOnD3D12(const BufferCopy& buffe } // anonymous namespace -ResultOrError ConvertStringToWstring(const char* str) { - size_t len = strlen(str); +ResultOrError ConvertStringToWstring(std::string_view s) { + size_t len = s.length(); if (len == 0) { return std::wstring(); } - int numChars = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, str, len, nullptr, 0); + int numChars = MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, s.data(), len, nullptr, 0); if (numChars == 0) { return DAWN_INTERNAL_ERROR("Failed to convert string to wide string"); } std::wstring result; result.resize(numChars); int numConvertedChars = - MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, str, len, &result[0], numChars); + MultiByteToWideChar(CP_UTF8, MB_ERR_INVALID_CHARS, s.data(), len, &result[0], numChars); if (numConvertedChars != numChars) { return DAWN_INTERNAL_ERROR("Failed to convert string to wide string"); } diff --git a/src/dawn/native/d3d12/UtilsD3D12.h b/src/dawn/native/d3d12/UtilsD3D12.h index 1418f54525..dcbe782cbe 100644 --- a/src/dawn/native/d3d12/UtilsD3D12.h +++ b/src/dawn/native/d3d12/UtilsD3D12.h @@ -26,7 +26,7 @@ namespace dawn::native::d3d12 { -ResultOrError ConvertStringToWstring(const char* str); +ResultOrError ConvertStringToWstring(std::string_view s); D3D12_COMPARISON_FUNC ToD3D12ComparisonFunc(wgpu::CompareFunction func); diff --git 
a/src/dawn/native/stream/Stream.cpp b/src/dawn/native/stream/Stream.cpp index 1ca241c29d..beb3823da0 100644 --- a/src/dawn/native/stream/Stream.cpp +++ b/src/dawn/native/stream/Stream.cpp @@ -48,4 +48,14 @@ void Stream::Write(Sink* s, const std::string_view& t) { } } +template <> +void Stream::Write(Sink* s, const std::wstring_view& t) { + StreamIn(s, t.length()); + size_t size = t.length() * sizeof(wchar_t); + if (size > 0) { + void* ptr = s->GetSpace(size); + memcpy(ptr, t.data(), size); + } +} + } // namespace dawn::native::stream diff --git a/src/dawn/native/stream/Stream.h b/src/dawn/native/stream/Stream.h index 34333179f8..d077cccc60 100644 --- a/src/dawn/native/stream/Stream.h +++ b/src/dawn/native/stream/Stream.h @@ -297,10 +297,9 @@ class Stream> { public: static void Write(stream::Sink* sink, const std::unordered_map& m) { std::vector> ordered(m.begin(), m.end()); - std::sort(ordered.begin(), ordered.end(), - [](const std::pair& a, const std::pair& b) { - return std::less{}(a.first, b.first); - }); + std::sort( + ordered.begin(), ordered.end(), + [](const std::pair& a, const std::pair& b) { return a.first < b.first; }); StreamIn(sink, ordered); } }; diff --git a/src/dawn/tests/end2end/PipelineCachingTests.cpp b/src/dawn/tests/end2end/PipelineCachingTests.cpp index 5dcf9180e4..cbf5a92f29 100644 --- a/src/dawn/tests/end2end/PipelineCachingTests.cpp +++ b/src/dawn/tests/end2end/PipelineCachingTests.cpp @@ -108,8 +108,8 @@ class PipelineCachingTests : public DawnTest { const EntryCounts counts = { // pipeline caching is only implemented on D3D12/Vulkan IsD3D12() || IsVulkan() ? 1u : 0u, - // shader module caching is only implemented on Vulkan/Metal - IsVulkan() || IsMetal() ? 1u : 0u, + // shader module caching is only implemented on Vulkan/D3D12/Metal + IsVulkan() || IsMetal() || IsD3D12() ? 
1u : 0u, }; NiceMock mMockCache; }; @@ -646,6 +646,7 @@ TEST_P(SinglePipelineCachingTests, RenderPipelineBlobCacheIsolationKey) { DAWN_INSTANTIATE_TEST(SinglePipelineCachingTests, D3D12Backend({"enable_blob_cache"}), + D3D12Backend({"enable_blob_cache", "use_dxc"}), MetalBackend({"enable_blob_cache"}), OpenGLBackend({"enable_blob_cache"}), OpenGLESBackend({"enable_blob_cache"}), diff --git a/src/dawn/tests/unittests/native/StreamTests.cpp b/src/dawn/tests/unittests/native/StreamTests.cpp index 6f257000d4..a1196f3465 100644 --- a/src/dawn/tests/unittests/native/StreamTests.cpp +++ b/src/dawn/tests/unittests/native/StreamTests.cpp @@ -174,6 +174,30 @@ TEST(SerializeTests, StdStringViews) { EXPECT_CACHE_KEY_EQ(str, expected); } +// Test that ByteVectorSink serializes std::wstring_views as expected. +TEST(SerializeTests, StdWStringViews) { + static constexpr std::wstring_view str(L"Hello world!"); + + ByteVectorSink expected; + StreamIn(&expected, size_t(str.length())); + size_t bytes = str.length() * sizeof(wchar_t); + memcpy(expected.GetSpace(bytes), str.data(), bytes); + + EXPECT_CACHE_KEY_EQ(str, expected); +} + +// Test that ByteVectorSink serializes Blobs as expected. +TEST(SerializeTests, Blob) { + uint8_t data[] = "dawn native Blob"; + Blob blob = Blob::UnsafeCreateWithDeleter(data, sizeof(data), []() {}); + + ByteVectorSink expected; + StreamIn(&expected, sizeof(data)); + expected.insert(expected.end(), data, data + sizeof(data)); + + EXPECT_CACHE_KEY_EQ(blob, expected); +} + // Test that ByteVectorSink serializes other ByteVectorSinks as expected. TEST(SerializeTests, ByteVectorSinks) { ByteVectorSink data = {'d', 'a', 't', 'a'}; @@ -309,6 +333,41 @@ TEST(StreamTests, SerializeDeserializeVisitableMembers) { } } +// Test that serializing then deserializing a Blob yields the same data. +// Tested here instead of in the type-parameterized tests since Blobs are not copyable. 
+TEST(StreamTests, SerializeDeserializeBlobs) { + // Test an empty blob + { + Blob blob; + EXPECT_EQ(blob.Size(), 0u); + + ByteVectorSink sink; + StreamIn(&sink, blob); + + BlobSource src(CreateBlob(sink)); + Blob out; + auto err = StreamOut(&src, &out); + EXPECT_FALSE(err.IsError()); + EXPECT_EQ(blob.Size(), out.Size()); + EXPECT_EQ(memcmp(blob.Data(), out.Data(), blob.Size()), 0); + } + + // Test a blob with some data + { + Blob blob = CreateBlob(std::vector{6.24, 3.12222}); + + ByteVectorSink sink; + StreamIn(&sink, blob); + + BlobSource src(CreateBlob(sink)); + Blob out; + auto err = StreamOut(&src, &out); + EXPECT_FALSE(err.IsError()); + EXPECT_EQ(blob.Size(), out.Size()); + EXPECT_EQ(memcmp(blob.Data(), out.Data(), blob.Size()), 0); + } +} + template std::bitset BitsetFromBitString(const char (&str)[N]) { // N - 1 because the last character is the null terminator.