diff --git a/src/dawn_native/d3d12/BackendD3D12.cpp b/src/dawn_native/d3d12/BackendD3D12.cpp index 57548c7ef8..5c0cca6518 100644 --- a/src/dawn_native/d3d12/BackendD3D12.cpp +++ b/src/dawn_native/d3d12/BackendD3D12.cpp @@ -102,34 +102,49 @@ namespace dawn_native { namespace d3d12 { return mFactory; } - ResultOrError Backend::GetOrCreateDxcLibrary() { + MaybeError Backend::EnsureDxcLibrary() { if (mDxcLibrary == nullptr) { DAWN_TRY(CheckHRESULT( mFunctions->dxcCreateInstance(CLSID_DxcLibrary, IID_PPV_ARGS(&mDxcLibrary)), "DXC create library")); ASSERT(mDxcLibrary != nullptr); } - return mDxcLibrary.Get(); + return {}; } - ResultOrError Backend::GetOrCreateDxcCompiler() { + MaybeError Backend::EnsureDxcCompiler() { if (mDxcCompiler == nullptr) { DAWN_TRY(CheckHRESULT( mFunctions->dxcCreateInstance(CLSID_DxcCompiler, IID_PPV_ARGS(&mDxcCompiler)), "DXC create compiler")); ASSERT(mDxcCompiler != nullptr); } - return mDxcCompiler.Get(); + return {}; } - ResultOrError Backend::GetOrCreateDxcValidator() { + MaybeError Backend::EnsureDxcValidator() { if (mDxcValidator == nullptr) { DAWN_TRY(CheckHRESULT( mFunctions->dxcCreateInstance(CLSID_DxcValidator, IID_PPV_ARGS(&mDxcValidator)), "DXC create validator")); ASSERT(mDxcValidator != nullptr); } - return mDxcValidator.Get(); + return {}; + } + + ComPtr Backend::GetDxcLibrary() const { + ASSERT(mDxcLibrary != nullptr); + return mDxcLibrary; + } + + ComPtr Backend::GetDxcCompiler() const { + ASSERT(mDxcCompiler != nullptr); + return mDxcCompiler; + } + + ComPtr Backend::GetDxcValidator() const { + ASSERT(mDxcValidator != nullptr); + return mDxcValidator; } const PlatformFunctions* Backend::GetFunctions() const { diff --git a/src/dawn_native/d3d12/BackendD3D12.h b/src/dawn_native/d3d12/BackendD3D12.h index 0490b60175..17f77ccec3 100644 --- a/src/dawn_native/d3d12/BackendD3D12.h +++ b/src/dawn_native/d3d12/BackendD3D12.h @@ -30,9 +30,14 @@ namespace dawn_native { namespace d3d12 { MaybeError Initialize(); ComPtr GetFactory() const; - ResultOrError GetOrCreateDxcLibrary(); - ResultOrError GetOrCreateDxcCompiler(); - ResultOrError GetOrCreateDxcValidator(); + + MaybeError EnsureDxcLibrary(); + MaybeError EnsureDxcCompiler(); + MaybeError EnsureDxcValidator(); + ComPtr GetDxcLibrary() const; + ComPtr GetDxcCompiler() const; + ComPtr GetDxcValidator() const; + const PlatformFunctions* GetFunctions() const; std::vector> DiscoverDefaultAdapters() override; diff --git a/src/dawn_native/d3d12/DeviceD3D12.cpp b/src/dawn_native/d3d12/DeviceD3D12.cpp index bb8320e04e..0ef0c124c9 100644 --- a/src/dawn_native/d3d12/DeviceD3D12.cpp +++ b/src/dawn_native/d3d12/DeviceD3D12.cpp @@ -160,7 +160,7 @@ namespace dawn_native { namespace d3d12 { // The environment can only use DXC when it's available. Override the decision if it is not // applicable. - ApplyUseDxcToggle(); + DAWN_TRY(ApplyUseDxcToggle()); return {}; } @@ -196,25 +196,33 @@ namespace dawn_native { namespace d3d12 { return ToBackend(GetAdapter())->GetBackend()->GetFactory(); } - void Device::ApplyUseDxcToggle() { + MaybeError Device::ApplyUseDxcToggle() { if (!ToBackend(GetAdapter())->GetBackend()->GetFunctions()->IsDXCAvailable()) { ForceSetToggle(Toggle::UseDXC, false); } else if (IsExtensionEnabled(Extension::ShaderFloat16)) { // Currently we can only use DXC to compile HLSL shaders using float16. ForceSetToggle(Toggle::UseDXC, true); } + + if (IsToggleEnabled(Toggle::UseDXC)) { + DAWN_TRY(ToBackend(GetAdapter())->GetBackend()->EnsureDxcCompiler()); + DAWN_TRY(ToBackend(GetAdapter())->GetBackend()->EnsureDxcLibrary()); + DAWN_TRY(ToBackend(GetAdapter())->GetBackend()->EnsureDxcValidator()); + } + + return {}; } - ResultOrError Device::GetOrCreateDxcLibrary() const { - return ToBackend(GetAdapter())->GetBackend()->GetOrCreateDxcLibrary(); + ComPtr Device::GetDxcLibrary() const { + return ToBackend(GetAdapter())->GetBackend()->GetDxcLibrary(); } - ResultOrError Device::GetOrCreateDxcCompiler() const { - return ToBackend(GetAdapter())->GetBackend()->GetOrCreateDxcCompiler(); + ComPtr Device::GetDxcCompiler() const { + return ToBackend(GetAdapter())->GetBackend()->GetDxcCompiler(); } - ResultOrError Device::GetOrCreateDxcValidator() const { - return ToBackend(GetAdapter())->GetBackend()->GetOrCreateDxcValidator(); + ComPtr Device::GetDxcValidator() const { + return ToBackend(GetAdapter())->GetBackend()->GetDxcValidator(); } const PlatformFunctions* Device::GetFunctions() const { diff --git a/src/dawn_native/d3d12/DeviceD3D12.h b/src/dawn_native/d3d12/DeviceD3D12.h index 4819dd4f92..915430c019 100644 --- a/src/dawn_native/d3d12/DeviceD3D12.h +++ b/src/dawn_native/d3d12/DeviceD3D12.h @@ -65,9 +65,9 @@ namespace dawn_native { namespace d3d12 { const PlatformFunctions* GetFunctions() const; ComPtr GetFactory() const; - ResultOrError GetOrCreateDxcLibrary() const; - ResultOrError GetOrCreateDxcCompiler() const; - ResultOrError GetOrCreateDxcValidator() const; + ComPtr GetDxcLibrary() const; + ComPtr GetDxcCompiler() const; + ComPtr GetDxcValidator() const; ResultOrError GetPendingCommandContext(); @@ -177,7 +177,7 @@ namespace dawn_native { namespace d3d12 { MaybeError CheckDebugLayerAndGenerateErrors(); - void ApplyUseDxcToggle(); + MaybeError ApplyUseDxcToggle(); ComPtr mFence; HANDLE mFenceEvent = nullptr; diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp index 327e3a1648..33d85a5bda 100644 --- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp +++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp @@ -97,16 +97,14 @@ namespace dawn_native { namespace d3d12 { const std::string& hlslSource, const char* entryPoint, uint32_t compileFlags) { - IDxcLibrary* dxcLibrary; - DAWN_TRY_ASSIGN(dxcLibrary, device->GetOrCreateDxcLibrary()); + ComPtr dxcLibrary = device->GetDxcLibrary(); ComPtr sourceBlob; DAWN_TRY(CheckHRESULT(dxcLibrary->CreateBlobWithEncodingOnHeapCopy( hlslSource.c_str(), hlslSource.length(), CP_UTF8, &sourceBlob), "DXC create blob")); - IDxcCompiler* dxcCompiler; - DAWN_TRY_ASSIGN(dxcCompiler, device->GetOrCreateDxcCompiler()); + ComPtr dxcCompiler = device->GetDxcCompiler(); std::wstring entryPointW; DAWN_TRY_ASSIGN(entryPointW, ConvertStringToWstring(entryPoint)); @@ -478,8 +476,7 @@ namespace dawn_native { namespace d3d12 { } ResultOrError ShaderModule::GetDXCompilerVersion() const { - ComPtr dxcValidator; - DAWN_TRY_ASSIGN(dxcValidator, ToBackend(GetDevice())->GetOrCreateDxcValidator()); + ComPtr dxcValidator = ToBackend(GetDevice())->GetDxcValidator(); ComPtr versionInfo; DAWN_TRY(CheckHRESULT(dxcValidator.As(&versionInfo),