Refactor D3D12 shader define strings code

Instead of using template functions, generate define string pairs
in std::string. For DXC path, convert them to std::wstring. It is
okay since shader generation is already expensive.

By the way generalize it to all kinds of defines in
ShaderCompilationRequest.

Bug: dawn:1137
Change-Id: I5518e992b56497e28c8ac7e818bf19b4853dee4a
Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/70120
Commit-Queue: Shrek Shao <shrekshao@google.com>
Reviewed-by: Corentin Wallez <cwallez@chromium.org>
This commit is contained in:
shrekshao 2021-11-19 19:39:18 +00:00 committed by Dawn LUCI CQ
parent 98e42aaa5a
commit 36200d1943
1 changed files with 86 additions and 159 deletions

View File

@ -34,94 +34,6 @@
#include <sstream>
#include <unordered_map>
namespace dawn_native {
template <typename StringType, typename T = int32_t>
struct NumberToString {
static StringType ToStringAsValue(T v);
static StringType ToStringAsId(T v);
};
template <typename T>
struct NumberToString<std::string, T> {
static constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
static std::string ToStringAsValue(T v) {
return std::to_string(v);
}
static std::string ToStringAsId(T v) {
return std::to_string(v);
}
};
template <typename T>
struct NumberToString<std::wstring, T> {
static constexpr WCHAR kSpecConstantPrefix[] = L"WGSL_SPEC_CONSTANT_";
static std::wstring ToStringAsValue(T v) {
return std::to_wstring(v);
}
static std::wstring ToStringAsId(T v) {
return std::to_wstring(v);
}
};
template <>
struct NumberToString<std::string, float> {
static std::string ToStringAsValue(float v) {
std::ostringstream out;
// 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
out.precision(8);
out << std::fixed << v;
return out.str();
}
};
template <>
struct NumberToString<std::wstring, float> {
static std::wstring ToStringAsValue(float v) {
std::basic_ostringstream<WCHAR> out;
// 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
out.precision(8);
out << std::fixed << v;
return out.str();
}
};
template <>
struct NumberToString<std::string, uint32_t> {
static std::string ToStringAsValue(uint32_t v) {
return std::to_string(v) + "u";
}
};
template <>
struct NumberToString<std::wstring, uint32_t> {
static std::wstring ToStringAsValue(uint32_t v) {
return std::to_wstring(v) + L"u";
}
};
template <typename StringType>
StringType GetHLSLValueString(EntryPointMetadata::OverridableConstant::Type dawnType,
const OverridableConstantScalar* entry,
double value = 0) {
switch (dawnType) {
case EntryPointMetadata::OverridableConstant::Type::Boolean:
return NumberToString<StringType, int32_t>::ToStringAsValue(
entry ? entry->b : static_cast<int32_t>(value));
case EntryPointMetadata::OverridableConstant::Type::Float32:
return NumberToString<StringType, float>::ToStringAsValue(
entry ? entry->f32 : static_cast<float>(value));
case EntryPointMetadata::OverridableConstant::Type::Int32:
return NumberToString<StringType, int32_t>::ToStringAsValue(
entry ? entry->i32 : static_cast<int32_t>(value));
case EntryPointMetadata::OverridableConstant::Type::Uint32:
return NumberToString<StringType, uint32_t>::ToStringAsValue(
entry ? entry->u32 : static_cast<uint32_t>(value));
default:
UNREACHABLE();
}
}
} // namespace dawn_native
namespace dawn_native { namespace d3d12 {
namespace {
@ -181,6 +93,73 @@ namespace dawn_native { namespace d3d12 {
output << ")";
}
// 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) {
std::ostringstream out;
out.precision(n);
out << std::fixed << v;
return out.str();
}
std::string GetHLSLValueString(EntryPointMetadata::OverridableConstant::Type dawnType,
const OverridableConstantScalar* entry,
double value = 0) {
switch (dawnType) {
case EntryPointMetadata::OverridableConstant::Type::Boolean:
return std::to_string(entry ? entry->b : static_cast<int32_t>(value));
case EntryPointMetadata::OverridableConstant::Type::Float32:
return FloatToStringWithPrecision(entry ? entry->f32
: static_cast<float>(value));
case EntryPointMetadata::OverridableConstant::Type::Int32:
return std::to_string(entry ? entry->i32 : static_cast<int32_t>(value));
case EntryPointMetadata::OverridableConstant::Type::Uint32:
return std::to_string(entry ? entry->u32 : static_cast<uint32_t>(value));
default:
UNREACHABLE();
}
}
constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
void GetOverridableConstantsDefines(
std::vector<std::pair<std::string, std::string>>* defineStrings,
const PipelineConstantEntries* pipelineConstantEntries,
const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants) {
std::unordered_set<std::string> overriddenConstants;
// Set pipeline overridden values
for (const auto& pipelineConstant : *pipelineConstantEntries) {
const std::string& name = pipelineConstant.first;
double value = pipelineConstant.second;
overriddenConstants.insert(name);
// This is already validated so `name` must exist
const auto& moduleConstant = shaderEntryPointConstants->at(name);
defineStrings->emplace_back(
kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
GetHLSLValueString(moduleConstant.type, nullptr, value));
}
// Set shader initialized default values
for (const auto& iter : *shaderEntryPointConstants) {
const std::string& name = iter.first;
if (overriddenConstants.count(name) != 0) {
// This constant already has overridden value
continue;
}
const auto& moduleConstant = shaderEntryPointConstants->at(name);
// Uninitialized default values are okay since they ar only defined to pass
// compilation but not used
defineStrings->emplace_back(
kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
GetHLSLValueString(moduleConstant.type, &moduleConstant.defaultValue));
}
}
// The inputs to a shader compilation. These have been intentionally isolated from the
// device to help ensure that the pipeline cache key contains all inputs for compilation.
struct ShaderCompilationRequest {
@ -199,8 +178,7 @@ namespace dawn_native { namespace d3d12 {
bool usesNumWorkgroups;
uint32_t numWorkgroupsRegisterSpace;
uint32_t numWorkgroupsShaderRegister;
const PipelineConstantEntries* pipelineConstantEntries;
const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants;
std::vector<std::pair<std::string, std::string>> defineStrings;
// FXC/DXC common inputs
bool disableWorkgroupInit;
@ -293,10 +271,12 @@ namespace dawn_native { namespace d3d12 {
request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0;
request.deviceInfo = &device->GetDeviceInfo();
request.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16);
request.pipelineConstantEntries = &programmableStage.constants;
request.shaderEntryPointConstants =
GetOverridableConstantsDefines(
&request.defineStrings, &programmableStage.constants,
&programmableStage.module->GetEntryPoint(programmableStage.entryPoint)
.overridableConstants;
.overridableConstants);
return std::move(request);
}
@ -349,17 +329,9 @@ namespace dawn_native { namespace d3d12 {
stream << " dxcVersion=" << dxcVersion;
stream << " hasShaderFloat16Feature=" << hasShaderFloat16Feature;
stream << " overridableConstants={";
for (const auto& pipelineConstant : *pipelineConstantEntries) {
const std::string& name = pipelineConstant.first;
double value = pipelineConstant.second;
// This is already validated so `name` must exist
const auto& moduleConstant = shaderEntryPointConstants->at(name);
stream << " <" << name << ","
<< GetHLSLValueString<std::string>(moduleConstant.type, nullptr, value)
<< ">";
stream << " defines={";
for (const auto& it : defineStrings) {
stream << " <" << it.first << "," << it.second << ">";
}
stream << " }";
@ -422,53 +394,6 @@ namespace dawn_native { namespace d3d12 {
return arguments;
}
template <typename StringType>
const std::vector<std::pair<StringType, StringType>> GetOverridableConstantsDefines(
const PipelineConstantEntries* pipelineConstantEntries,
const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants) {
std::vector<std::pair<StringType, StringType>> defineStrings;
std::unordered_set<std::string> overriddenConstants;
// Set pipeline overridden values
for (const auto& pipelineConstant : *pipelineConstantEntries) {
const std::string& name = pipelineConstant.first;
double value = pipelineConstant.second;
overriddenConstants.insert(name);
// This is already validated so `name` must exist
const auto& moduleConstant = shaderEntryPointConstants->at(name);
defineStrings.emplace_back(
NumberToString<StringType>::kSpecConstantPrefix +
NumberToString<StringType>::ToStringAsId(
static_cast<int32_t>(moduleConstant.id)),
GetHLSLValueString<StringType>(moduleConstant.type, nullptr, value));
}
// Set shader initialized default values
for (const auto& iter : *shaderEntryPointConstants) {
const std::string& name = iter.first;
if (overriddenConstants.count(name) != 0) {
// This constant already has overridden value
continue;
}
const auto& moduleConstant = shaderEntryPointConstants->at(name);
// Uninitialized default values are okay since they are only defined to pass
// compilation but not used
defineStrings.emplace_back(NumberToString<StringType>::kSpecConstantPrefix +
NumberToString<StringType>::ToStringAsId(
static_cast<int32_t>(moduleConstant.id)),
GetHLSLValueString<StringType>(
moduleConstant.type, &moduleConstant.defaultValue));
}
return defineStrings;
}
ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(IDxcLibrary* dxcLibrary,
IDxcCompiler* dxcCompiler,
const ShaderCompilationRequest& request,
@ -486,8 +411,13 @@ namespace dawn_native { namespace d3d12 {
GetDXCArguments(request.compileFlags, request.hasShaderFloat16Feature);
// Build defines for overridable constants
const auto& defineStrings = GetOverridableConstantsDefines<std::wstring>(
request.pipelineConstantEntries, request.shaderEntryPointConstants);
std::vector<std::pair<std::wstring, std::wstring>> defineStrings;
defineStrings.reserve(request.defineStrings.size());
for (const auto& it : request.defineStrings) {
defineStrings.emplace_back(UTF8ToWStr(it.first.c_str()),
UTF8ToWStr(it.second.c_str()));
}
std::vector<DxcDefine> dxcDefines;
dxcDefines.reserve(defineStrings.size());
for (const auto& d : defineStrings) {
@ -538,14 +468,11 @@ namespace dawn_native { namespace d3d12 {
ComPtr<ID3DBlob> errors;
// Build defines for overridable constants
const auto& defineStrings = GetOverridableConstantsDefines<std::string>(
request.pipelineConstantEntries, request.shaderEntryPointConstants);
const D3D_SHADER_MACRO* pDefines = nullptr;
std::vector<D3D_SHADER_MACRO> fxcDefines;
if (defineStrings.size() > 0) {
fxcDefines.reserve(defineStrings.size() + 1);
for (const auto& d : defineStrings) {
if (request.defineStrings.size() > 0) {
fxcDefines.reserve(request.defineStrings.size() + 1);
for (const auto& d : request.defineStrings) {
fxcDefines.push_back({d.first.c_str(), d.second.c_str()});
}
// d3dCompile D3D_SHADER_MACRO* pDefines is a nullptr terminated array