diff --git a/src/dawn_native/d3d12/BackendD3D12.cpp b/src/dawn_native/d3d12/BackendD3D12.cpp index 946ce7bd98..8d6120820d 100644 --- a/src/dawn_native/d3d12/BackendD3D12.cpp +++ b/src/dawn_native/d3d12/BackendD3D12.cpp @@ -128,6 +128,16 @@ namespace dawn_native { namespace d3d12 { return mDxcCompiler.Get(); } + ResultOrError Backend::GetOrCreateDxcValidator() { + if (mDxcValidator == nullptr) { + DAWN_TRY(CheckHRESULT( + mFunctions->dxcCreateInstance(CLSID_DxcValidator, IID_PPV_ARGS(&mDxcValidator)), + "DXC create validator")); + ASSERT(mDxcValidator != nullptr); + } + return mDxcValidator.Get(); + } + const PlatformFunctions* Backend::GetFunctions() const { return mFunctions.get(); } diff --git a/src/dawn_native/d3d12/BackendD3D12.h b/src/dawn_native/d3d12/BackendD3D12.h index 87c2d13f04..0490b60175 100644 --- a/src/dawn_native/d3d12/BackendD3D12.h +++ b/src/dawn_native/d3d12/BackendD3D12.h @@ -32,6 +32,7 @@ namespace dawn_native { namespace d3d12 { ComPtr GetFactory() const; ResultOrError GetOrCreateDxcLibrary(); ResultOrError GetOrCreateDxcCompiler(); + ResultOrError GetOrCreateDxcValidator(); const PlatformFunctions* GetFunctions() const; std::vector> DiscoverDefaultAdapters() override; @@ -45,6 +46,7 @@ namespace dawn_native { namespace d3d12 { ComPtr mFactory; ComPtr mDxcLibrary; ComPtr mDxcCompiler; + ComPtr mDxcValidator; }; }} // namespace dawn_native::d3d12 diff --git a/src/dawn_native/d3d12/DeviceD3D12.cpp b/src/dawn_native/d3d12/DeviceD3D12.cpp index 1d15eb58c1..a03b04874c 100644 --- a/src/dawn_native/d3d12/DeviceD3D12.cpp +++ b/src/dawn_native/d3d12/DeviceD3D12.cpp @@ -218,6 +218,10 @@ namespace dawn_native { namespace d3d12 { return ToBackend(GetAdapter())->GetBackend()->GetOrCreateDxcCompiler(); } + ResultOrError Device::GetOrCreateDxcValidator() const { + return ToBackend(GetAdapter())->GetBackend()->GetOrCreateDxcValidator(); + } + const PlatformFunctions* Device::GetFunctions() const { return ToBackend(GetAdapter())->GetBackend()->GetFunctions(); } diff --git a/src/dawn_native/d3d12/DeviceD3D12.h b/src/dawn_native/d3d12/DeviceD3D12.h index 732e187986..c35df07f7c 100644 --- a/src/dawn_native/d3d12/DeviceD3D12.h +++ b/src/dawn_native/d3d12/DeviceD3D12.h @@ -74,6 +74,7 @@ namespace dawn_native { namespace d3d12 { ComPtr GetFactory() const; ResultOrError GetOrCreateDxcLibrary() const; ResultOrError GetOrCreateDxcCompiler() const; + ResultOrError GetOrCreateDxcValidator() const; ResultOrError GetPendingCommandContext(); diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp index 96bb4ba87d..b117e3f05d 100644 --- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp +++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp @@ -319,9 +319,12 @@ namespace dawn_native { namespace d3d12 { // layout. The pipeline layout is only required if we key from WGSL: two different pipeline // layouts could be used to produce different shader blobs and the wrong shader blob could // be loaded since the pipeline layout was missing from the key. + // The compiler flags or version used could also produce different HLSL source. HLSL key + // needs both to ensure the shader cache key is unique to the HLSL source. // TODO(dawn:549): Consider keying from WGSL and serialize the pipeline layout it used. - const PersistentCacheKey& shaderCacheKey = - CreateHLSLKey(entryPointName, stage, hlslSource, compileFlags); + PersistentCacheKey shaderCacheKey; + DAWN_TRY_ASSIGN(shaderCacheKey, + CreateHLSLKey(entryPointName, stage, hlslSource, compileFlags)); CompiledShader compiledShader = {}; DAWN_TRY_ASSIGN(compiledShader.cachedShader, @@ -357,10 +360,10 @@ namespace dawn_native { namespace d3d12 { return {}; } - PersistentCacheKey ShaderModule::CreateHLSLKey(const char* entryPointName, - SingleShaderStage stage, - const std::string& hlslSource, - uint32_t compileFlags) const { + ResultOrError ShaderModule::CreateHLSLKey(const char* entryPointName, + SingleShaderStage stage, + const std::string& hlslSource, + uint32_t compileFlags) const { std::stringstream stream; // Prefix the key with the type to avoid collisions from another type that could have the @@ -383,7 +386,15 @@ namespace dawn_native { namespace d3d12 { stream << compileFlags; - // TODO(dawn:549): add the HLSL compiler version for good measure. + // Add the HLSL compiler version for good measure. + // Prepend the compiler name to ensure the version is always unique. + if (GetDevice()->IsToggleEnabled(Toggle::UseDXC)) { + uint64_t dxCompilerVersion; + DAWN_TRY_ASSIGN(dxCompilerVersion, GetDXCompilerVersion()); + stream << "DXC" << dxCompilerVersion; + } else { + stream << "FXC" << GetD3DCompilerVersion(); + } // If the source contains multiple entry points, ensure they are cached seperately // per stage since DX shader code can only be compiled per stage using the same @@ -394,4 +405,24 @@ namespace dawn_native { namespace d3d12 { return PersistentCacheKey(std::istreambuf_iterator{stream}, std::istreambuf_iterator{}); } + + ResultOrError ShaderModule::GetDXCompilerVersion() const { + ComPtr dxcValidator; + DAWN_TRY_ASSIGN(dxcValidator, ToBackend(GetDevice())->GetOrCreateDxcValidator()); + + ComPtr versionInfo; + DAWN_TRY(CheckHRESULT(dxcValidator.As(&versionInfo), + "D3D12 QueryInterface IDxcValidator to IDxcVersionInfo")); + + uint32_t compilerMajor, compilerMinor; + DAWN_TRY(CheckHRESULT(versionInfo->GetVersion(&compilerMajor, &compilerMinor), + "IDxcVersionInfo::GetVersion")); + + // Pack both into a single version number. + return (uint64_t(compilerMajor) << uint64_t(32)) + compilerMinor; + } + + uint64_t ShaderModule::GetD3DCompilerVersion() const { + return D3D_COMPILER_VERSION; + } }} // namespace dawn_native::d3d12 diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.h b/src/dawn_native/d3d12/ShaderModuleD3D12.h index f5a94f0f21..1f04759890 100644 --- a/src/dawn_native/d3d12/ShaderModuleD3D12.h +++ b/src/dawn_native/d3d12/ShaderModuleD3D12.h @@ -59,10 +59,13 @@ namespace dawn_native { namespace d3d12 { SingleShaderStage stage, PipelineLayout* layout) const; - PersistentCacheKey CreateHLSLKey(const char* entryPointName, - SingleShaderStage stage, - const std::string& hlslSource, - uint32_t compileFlags) const; + ResultOrError CreateHLSLKey(const char* entryPointName, + SingleShaderStage stage, + const std::string& hlslSource, + uint32_t compileFlags) const; + + ResultOrError GetDXCompilerVersion() const; + uint64_t GetD3DCompilerVersion() const; #ifdef DAWN_ENABLE_WGSL std::unique_ptr mTintModule;