D3D12: Make HLSL generation per-entrypoint.

Also make the CompileShaderDXC/FXC standalone functions because
they don't use ShaderModule except to GetDevice().

Bug: dawn:216
Change-Id: Iaec9abe52ad4422891474086c3b973baf07046a5
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/28243
Commit-Queue: Kai Ninomiya <kainino@chromium.org>
Reviewed-by: Kai Ninomiya <kainino@chromium.org>
This commit is contained in:
Corentin Wallez 2020-09-09 22:55:17 +00:00 committed by Commit Bot service account
parent 28efed139f
commit b8712c01c1
4 changed files with 117 additions and 113 deletions

View File

@ -42,8 +42,12 @@ namespace dawn_native { namespace d3d12 {
compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR; compileFlags |= D3DCOMPILE_PACK_MATRIX_ROW_MAJOR;
ShaderModule* module = ToBackend(descriptor->computeStage.module); ShaderModule* module = ToBackend(descriptor->computeStage.module);
// Note that the HLSL will always use entryPoint "main".
std::string hlslSource; std::string hlslSource;
DAWN_TRY_ASSIGN(hlslSource, module->GetHLSLSource(ToBackend(GetLayout()))); DAWN_TRY_ASSIGN(hlslSource, module->TranslateToHLSL(descriptor->computeStage.entryPoint,
SingleShaderStage::Compute,
ToBackend(GetLayout())));
D3D12_COMPUTE_PIPELINE_STATE_DESC d3dDesc = {}; D3D12_COMPUTE_PIPELINE_STATE_DESC d3dDesc = {};
d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature(); d3dDesc.pRootSignature = ToBackend(GetLayout())->GetRootSignature();
@ -52,18 +56,14 @@ namespace dawn_native { namespace d3d12 {
ComPtr<ID3DBlob> compiledFXCShader; ComPtr<ID3DBlob> compiledFXCShader;
if (device->IsToggleEnabled(Toggle::UseDXC)) { if (device->IsToggleEnabled(Toggle::UseDXC)) {
DAWN_TRY_ASSIGN( DAWN_TRY_ASSIGN(compiledDXCShader, CompileShaderDXC(device, SingleShaderStage::Compute,
compiledDXCShader, hlslSource, "main", compileFlags));
module->CompileShaderDXC(SingleShaderStage::Compute, hlslSource,
descriptor->computeStage.entryPoint, compileFlags));
d3dDesc.CS.pShaderBytecode = compiledDXCShader->GetBufferPointer(); d3dDesc.CS.pShaderBytecode = compiledDXCShader->GetBufferPointer();
d3dDesc.CS.BytecodeLength = compiledDXCShader->GetBufferSize(); d3dDesc.CS.BytecodeLength = compiledDXCShader->GetBufferSize();
} else { } else {
DAWN_TRY_ASSIGN( DAWN_TRY_ASSIGN(compiledFXCShader, CompileShaderFXC(device, SingleShaderStage::Compute,
compiledFXCShader, hlslSource, "main", compileFlags));
module->CompileShaderFXC(SingleShaderStage::Compute, hlslSource,
descriptor->computeStage.entryPoint, compileFlags));
d3dDesc.CS.pShaderBytecode = compiledFXCShader->GetBufferPointer(); d3dDesc.CS.pShaderBytecode = compiledFXCShader->GetBufferPointer();
d3dDesc.CS.BytecodeLength = compiledFXCShader->GetBufferSize(); d3dDesc.CS.BytecodeLength = compiledFXCShader->GetBufferSize();
} }

View File

@ -327,20 +327,21 @@ namespace dawn_native { namespace d3d12 {
wgpu::ShaderStage renderStages = wgpu::ShaderStage::Vertex | wgpu::ShaderStage::Fragment; wgpu::ShaderStage renderStages = wgpu::ShaderStage::Vertex | wgpu::ShaderStage::Fragment;
for (auto stage : IterateStages(renderStages)) { for (auto stage : IterateStages(renderStages)) {
// Note that the HLSL entryPoint will always be "main".
std::string hlslSource; std::string hlslSource;
DAWN_TRY_ASSIGN(hlslSource, modules[stage]->GetHLSLSource(ToBackend(GetLayout()))); DAWN_TRY_ASSIGN(hlslSource,
modules[stage]->TranslateToHLSL(GetStage(stage).entryPoint.c_str(),
stage, ToBackend(GetLayout())));
if (device->IsToggleEnabled(Toggle::UseDXC)) { if (device->IsToggleEnabled(Toggle::UseDXC)) {
DAWN_TRY_ASSIGN(compiledDXCShader[stage], DAWN_TRY_ASSIGN(compiledDXCShader[stage],
modules[stage]->CompileShaderDXC(stage, hlslSource, CompileShaderDXC(device, stage, hlslSource, "main", compileFlags));
entryPoints[stage], compileFlags));
shaders[stage]->pShaderBytecode = compiledDXCShader[stage]->GetBufferPointer(); shaders[stage]->pShaderBytecode = compiledDXCShader[stage]->GetBufferPointer();
shaders[stage]->BytecodeLength = compiledDXCShader[stage]->GetBufferSize(); shaders[stage]->BytecodeLength = compiledDXCShader[stage]->GetBufferSize();
} else { } else {
DAWN_TRY_ASSIGN(compiledFXCShader[stage], DAWN_TRY_ASSIGN(compiledFXCShader[stage],
modules[stage]->CompileShaderFXC(stage, hlslSource, CompileShaderFXC(device, stage, hlslSource, "main", compileFlags));
entryPoints[stage], compileFlags));
shaders[stage]->pShaderBytecode = compiledFXCShader[stage]->GetBufferPointer(); shaders[stage]->pShaderBytecode = compiledFXCShader[stage]->GetBufferPointer();
shaders[stage]->BytecodeLength = compiledFXCShader[stage]->GetBufferSize(); shaders[stage]->BytecodeLength = compiledFXCShader[stage]->GetBufferSize();

View File

@ -17,6 +17,7 @@
#include "common/Assert.h" #include "common/Assert.h"
#include "common/BitSetIterator.h" #include "common/BitSetIterator.h"
#include "common/Log.h" #include "common/Log.h"
#include "dawn_native/SpirvUtils.h"
#include "dawn_native/d3d12/BindGroupLayoutD3D12.h" #include "dawn_native/d3d12/BindGroupLayoutD3D12.h"
#include "dawn_native/d3d12/D3D12Error.h" #include "dawn_native/d3d12/D3D12Error.h"
#include "dawn_native/d3d12/DeviceD3D12.h" #include "dawn_native/d3d12/DeviceD3D12.h"
@ -84,11 +85,90 @@ namespace dawn_native { namespace d3d12 {
} // anonymous namespace } // anonymous namespace
ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(Device* device,
SingleShaderStage stage,
const std::string& hlslSource,
const char* entryPoint,
uint32_t compileFlags) {
IDxcLibrary* dxcLibrary;
DAWN_TRY_ASSIGN(dxcLibrary, device->GetOrCreateDxcLibrary());
ComPtr<IDxcBlobEncoding> sourceBlob;
DAWN_TRY(CheckHRESULT(dxcLibrary->CreateBlobWithEncodingOnHeapCopy(
hlslSource.c_str(), hlslSource.length(), CP_UTF8, &sourceBlob),
"DXC create blob"));
IDxcCompiler* dxcCompiler;
DAWN_TRY_ASSIGN(dxcCompiler, device->GetOrCreateDxcCompiler());
std::wstring entryPointW;
DAWN_TRY_ASSIGN(entryPointW, ConvertStringToWstring(entryPoint));
std::vector<const wchar_t*> arguments =
GetDXCArguments(compileFlags, device->IsExtensionEnabled(Extension::ShaderFloat16));
ComPtr<IDxcOperationResult> result;
DAWN_TRY(CheckHRESULT(
dxcCompiler->Compile(sourceBlob.Get(), nullptr, entryPointW.c_str(),
device->GetDeviceInfo().shaderProfiles[stage].c_str(),
arguments.data(), arguments.size(), nullptr, 0, nullptr, &result),
"DXC compile"));
HRESULT hr;
DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status"));
if (FAILED(hr)) {
ComPtr<IDxcBlobEncoding> errors;
DAWN_TRY(CheckHRESULT(result->GetErrorBuffer(&errors), "DXC get error buffer"));
std::string message = std::string("DXC compile failed with ") +
static_cast<char*>(errors->GetBufferPointer());
return DAWN_INTERNAL_ERROR(message);
}
ComPtr<IDxcBlob> compiledShader;
DAWN_TRY(CheckHRESULT(result->GetResult(&compiledShader), "DXC get result"));
return std::move(compiledShader);
}
ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(Device* device,
SingleShaderStage stage,
const std::string& hlslSource,
const char* entryPoint,
uint32_t compileFlags) {
const char* targetProfile = nullptr;
switch (stage) {
case SingleShaderStage::Vertex:
targetProfile = "vs_5_1";
break;
case SingleShaderStage::Fragment:
targetProfile = "ps_5_1";
break;
case SingleShaderStage::Compute:
targetProfile = "cs_5_1";
break;
}
ComPtr<ID3DBlob> compiledShader;
ComPtr<ID3DBlob> errors;
const PlatformFunctions* functions = device->GetFunctions();
if (FAILED(functions->d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr,
nullptr, entryPoint, targetProfile, compileFlags, 0,
&compiledShader, &errors))) {
std::string message = std::string("D3D compile failed with ") +
static_cast<char*>(errors->GetBufferPointer());
return DAWN_INTERNAL_ERROR(message);
}
return std::move(compiledShader);
}
// static // static
ResultOrError<ShaderModule*> ShaderModule::Create(Device* device, ResultOrError<ShaderModule*> ShaderModule::Create(Device* device,
const ShaderModuleDescriptor* descriptor) { const ShaderModuleDescriptor* descriptor) {
Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor)); Ref<ShaderModule> module = AcquireRef(new ShaderModule(device, descriptor));
DAWN_TRY(module->Initialize()); DAWN_TRY(module->InitializeBase());
return module.Detach(); return module.Detach();
} }
@ -96,14 +176,10 @@ namespace dawn_native { namespace d3d12 {
: ShaderModuleBase(device, descriptor) { : ShaderModuleBase(device, descriptor) {
} }
MaybeError ShaderModule::Initialize() { ResultOrError<std::string> ShaderModule::TranslateToHLSL(const char* entryPointName,
return InitializeBase(); SingleShaderStage stage,
} PipelineLayout* layout) const {
ResultOrError<std::string> ShaderModule::GetHLSLSource(PipelineLayout* layout) {
ASSERT(!IsError()); ASSERT(!IsError());
const std::vector<uint32_t>& spirv = GetSpirv();
// If these options are changed, the values in DawnSPIRVCrossHLSLFastFuzzer.cpp need to // If these options are changed, the values in DawnSPIRVCrossHLSLFastFuzzer.cpp need to
// be updated. // be updated.
spirv_cross::CompilerGLSL::Options options_glsl; spirv_cross::CompilerGLSL::Options options_glsl;
@ -127,12 +203,13 @@ namespace dawn_native { namespace d3d12 {
options_hlsl.point_size_compat = true; options_hlsl.point_size_compat = true;
options_hlsl.nonwritable_uav_texture_as_srv = true; options_hlsl.nonwritable_uav_texture_as_srv = true;
spirv_cross::CompilerHLSL compiler(spirv); spirv_cross::CompilerHLSL compiler(GetSpirv());
compiler.set_common_options(options_glsl); compiler.set_common_options(options_glsl);
compiler.set_hlsl_options(options_hlsl); compiler.set_hlsl_options(options_hlsl);
compiler.set_entry_point(entryPointName, ShaderStageToExecutionModel(stage));
const EntryPointMetadata::BindingInfo& moduleBindingInfo = const EntryPointMetadata::BindingInfo& moduleBindingInfo =
GetEntryPoint("main", GetMainEntryPointStageForTransition()).bindings; GetEntryPoint(entryPointName, stage).bindings;
for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) { for (BindGroupIndex group : IterateBitSet(layout->GetBindGroupLayoutsMask())) {
const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group)); const BindGroupLayout* bgl = ToBackend(layout->GetBindGroupLayout(group));
@ -158,85 +235,8 @@ namespace dawn_native { namespace d3d12 {
} }
} }
} }
return compiler.compile(); return compiler.compile();
} }
ResultOrError<ComPtr<IDxcBlob>> ShaderModule::CompileShaderDXC(SingleShaderStage stage,
const std::string& hlslSource,
const char* entryPoint,
uint32_t compileFlags) {
IDxcLibrary* dxcLibrary;
DAWN_TRY_ASSIGN(dxcLibrary, ToBackend(GetDevice())->GetOrCreateDxcLibrary());
ComPtr<IDxcBlobEncoding> sourceBlob;
DAWN_TRY(CheckHRESULT(dxcLibrary->CreateBlobWithEncodingOnHeapCopy(
hlslSource.c_str(), hlslSource.length(), CP_UTF8, &sourceBlob),
"DXC create blob"));
IDxcCompiler* dxcCompiler;
DAWN_TRY_ASSIGN(dxcCompiler, ToBackend(GetDevice())->GetOrCreateDxcCompiler());
std::wstring entryPointW;
DAWN_TRY_ASSIGN(entryPointW, ConvertStringToWstring(entryPoint));
std::vector<const wchar_t*> arguments = GetDXCArguments(
compileFlags, GetDevice()->IsExtensionEnabled(Extension::ShaderFloat16));
ComPtr<IDxcOperationResult> result;
DAWN_TRY(
CheckHRESULT(dxcCompiler->Compile(
sourceBlob.Get(), nullptr, entryPointW.c_str(),
ToBackend(GetDevice())->GetDeviceInfo().shaderProfiles[stage].c_str(),
arguments.data(), arguments.size(), nullptr, 0, nullptr, &result),
"DXC compile"));
HRESULT hr;
DAWN_TRY(CheckHRESULT(result->GetStatus(&hr), "DXC get status"));
if (FAILED(hr)) {
ComPtr<IDxcBlobEncoding> errors;
DAWN_TRY(CheckHRESULT(result->GetErrorBuffer(&errors), "DXC get error buffer"));
std::string message = std::string("DXC compile failed with ") +
static_cast<char*>(errors->GetBufferPointer());
return DAWN_INTERNAL_ERROR(message);
}
ComPtr<IDxcBlob> compiledShader;
DAWN_TRY(CheckHRESULT(result->GetResult(&compiledShader), "DXC get result"));
return std::move(compiledShader);
}
ResultOrError<ComPtr<ID3DBlob>> ShaderModule::CompileShaderFXC(SingleShaderStage stage,
const std::string& hlslSource,
const char* entryPoint,
uint32_t compileFlags) {
const char* targetProfile = nullptr;
switch (stage) {
case SingleShaderStage::Vertex:
targetProfile = "vs_5_1";
break;
case SingleShaderStage::Fragment:
targetProfile = "ps_5_1";
break;
case SingleShaderStage::Compute:
targetProfile = "cs_5_1";
break;
}
ComPtr<ID3DBlob> compiledShader;
ComPtr<ID3DBlob> errors;
const PlatformFunctions* functions = ToBackend(GetDevice())->GetFunctions();
if (FAILED(functions->d3dCompile(hlslSource.c_str(), hlslSource.length(), nullptr, nullptr,
nullptr, entryPoint, targetProfile, compileFlags, 0,
&compiledShader, &errors))) {
std::string message = std::string("D3D compile failed with ") +
static_cast<char*>(errors->GetBufferPointer());
return DAWN_INTERNAL_ERROR(message);
}
return std::move(compiledShader);
}
}} // namespace dawn_native::d3d12 }} // namespace dawn_native::d3d12

View File

@ -24,26 +24,29 @@ namespace dawn_native { namespace d3d12 {
class Device; class Device;
class PipelineLayout; class PipelineLayout;
ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(Device* device,
SingleShaderStage stage,
const std::string& hlslSource,
const char* entryPoint,
uint32_t compileFlags);
ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(Device* device,
SingleShaderStage stage,
const std::string& hlslSource,
const char* entryPoint,
uint32_t compileFlags);
class ShaderModule final : public ShaderModuleBase { class ShaderModule final : public ShaderModuleBase {
public: public:
static ResultOrError<ShaderModule*> Create(Device* device, static ResultOrError<ShaderModule*> Create(Device* device,
const ShaderModuleDescriptor* descriptor); const ShaderModuleDescriptor* descriptor);
ResultOrError<std::string> GetHLSLSource(PipelineLayout* layout); ResultOrError<std::string> TranslateToHLSL(const char* entryPointName,
SingleShaderStage stage,
ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(SingleShaderStage stage, PipelineLayout* layout) const;
const std::string& hlslSource,
const char* entryPoint,
uint32_t compileFlags);
ResultOrError<ComPtr<ID3DBlob>> CompileShaderFXC(SingleShaderStage stage,
const std::string& hlslSource,
const char* entryPoint,
uint32_t compileFlags);
private: private:
ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor); ShaderModule(Device* device, const ShaderModuleDescriptor* descriptor);
~ShaderModule() override = default; ~ShaderModule() override = default;
MaybeError Initialize();
}; };
}} // namespace dawn_native::d3d12 }} // namespace dawn_native::d3d12