Refactor D3D12 shader define strings code
Instead of using template functions, generate define string pairs in std::string. For DXC path, convert them to std::wstring. It is okay since shader generation is already expensive. By the way generalize it to all kinds of defines in ShaderCompilationRequest. Bug: dawn:1137 Change-Id: I5518e992b56497e28c8ac7e818bf19b4853dee4a Reviewed-on: https://dawn-review.googlesource.com/c/dawn/+/70120 Commit-Queue: Shrek Shao <shrekshao@google.com> Reviewed-by: Corentin Wallez <cwallez@chromium.org>
This commit is contained in:
parent
98e42aaa5a
commit
36200d1943
|
@ -34,94 +34,6 @@
|
|||
#include <sstream>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace dawn_native {
|
||||
template <typename StringType, typename T = int32_t>
|
||||
struct NumberToString {
|
||||
static StringType ToStringAsValue(T v);
|
||||
static StringType ToStringAsId(T v);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct NumberToString<std::string, T> {
|
||||
static constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
|
||||
static std::string ToStringAsValue(T v) {
|
||||
return std::to_string(v);
|
||||
}
|
||||
static std::string ToStringAsId(T v) {
|
||||
return std::to_string(v);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct NumberToString<std::wstring, T> {
|
||||
static constexpr WCHAR kSpecConstantPrefix[] = L"WGSL_SPEC_CONSTANT_";
|
||||
static std::wstring ToStringAsValue(T v) {
|
||||
return std::to_wstring(v);
|
||||
}
|
||||
static std::wstring ToStringAsId(T v) {
|
||||
return std::to_wstring(v);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumberToString<std::string, float> {
|
||||
static std::string ToStringAsValue(float v) {
|
||||
std::ostringstream out;
|
||||
// 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
|
||||
out.precision(8);
|
||||
out << std::fixed << v;
|
||||
return out.str();
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumberToString<std::wstring, float> {
|
||||
static std::wstring ToStringAsValue(float v) {
|
||||
std::basic_ostringstream<WCHAR> out;
|
||||
// 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
|
||||
out.precision(8);
|
||||
out << std::fixed << v;
|
||||
return out.str();
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumberToString<std::string, uint32_t> {
|
||||
static std::string ToStringAsValue(uint32_t v) {
|
||||
return std::to_string(v) + "u";
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NumberToString<std::wstring, uint32_t> {
|
||||
static std::wstring ToStringAsValue(uint32_t v) {
|
||||
return std::to_wstring(v) + L"u";
|
||||
}
|
||||
};
|
||||
|
||||
template <typename StringType>
|
||||
StringType GetHLSLValueString(EntryPointMetadata::OverridableConstant::Type dawnType,
|
||||
const OverridableConstantScalar* entry,
|
||||
double value = 0) {
|
||||
switch (dawnType) {
|
||||
case EntryPointMetadata::OverridableConstant::Type::Boolean:
|
||||
return NumberToString<StringType, int32_t>::ToStringAsValue(
|
||||
entry ? entry->b : static_cast<int32_t>(value));
|
||||
case EntryPointMetadata::OverridableConstant::Type::Float32:
|
||||
return NumberToString<StringType, float>::ToStringAsValue(
|
||||
entry ? entry->f32 : static_cast<float>(value));
|
||||
case EntryPointMetadata::OverridableConstant::Type::Int32:
|
||||
return NumberToString<StringType, int32_t>::ToStringAsValue(
|
||||
entry ? entry->i32 : static_cast<int32_t>(value));
|
||||
case EntryPointMetadata::OverridableConstant::Type::Uint32:
|
||||
return NumberToString<StringType, uint32_t>::ToStringAsValue(
|
||||
entry ? entry->u32 : static_cast<uint32_t>(value));
|
||||
default:
|
||||
UNREACHABLE();
|
||||
}
|
||||
}
|
||||
} // namespace dawn_native
|
||||
|
||||
namespace dawn_native { namespace d3d12 {
|
||||
|
||||
namespace {
|
||||
|
@ -181,6 +93,73 @@ namespace dawn_native { namespace d3d12 {
|
|||
output << ")";
|
||||
}
|
||||
|
||||
// 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough
|
||||
std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) {
|
||||
std::ostringstream out;
|
||||
out.precision(n);
|
||||
out << std::fixed << v;
|
||||
return out.str();
|
||||
}
|
||||
|
||||
std::string GetHLSLValueString(EntryPointMetadata::OverridableConstant::Type dawnType,
|
||||
const OverridableConstantScalar* entry,
|
||||
double value = 0) {
|
||||
switch (dawnType) {
|
||||
case EntryPointMetadata::OverridableConstant::Type::Boolean:
|
||||
return std::to_string(entry ? entry->b : static_cast<int32_t>(value));
|
||||
case EntryPointMetadata::OverridableConstant::Type::Float32:
|
||||
return FloatToStringWithPrecision(entry ? entry->f32
|
||||
: static_cast<float>(value));
|
||||
case EntryPointMetadata::OverridableConstant::Type::Int32:
|
||||
return std::to_string(entry ? entry->i32 : static_cast<int32_t>(value));
|
||||
case EntryPointMetadata::OverridableConstant::Type::Uint32:
|
||||
return std::to_string(entry ? entry->u32 : static_cast<uint32_t>(value));
|
||||
default:
|
||||
UNREACHABLE();
|
||||
}
|
||||
}
|
||||
|
||||
constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_";
|
||||
|
||||
void GetOverridableConstantsDefines(
|
||||
std::vector<std::pair<std::string, std::string>>* defineStrings,
|
||||
const PipelineConstantEntries* pipelineConstantEntries,
|
||||
const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants) {
|
||||
std::unordered_set<std::string> overriddenConstants;
|
||||
|
||||
// Set pipeline overridden values
|
||||
for (const auto& pipelineConstant : *pipelineConstantEntries) {
|
||||
const std::string& name = pipelineConstant.first;
|
||||
double value = pipelineConstant.second;
|
||||
|
||||
overriddenConstants.insert(name);
|
||||
|
||||
// This is already validated so `name` must exist
|
||||
const auto& moduleConstant = shaderEntryPointConstants->at(name);
|
||||
|
||||
defineStrings->emplace_back(
|
||||
kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
|
||||
GetHLSLValueString(moduleConstant.type, nullptr, value));
|
||||
}
|
||||
|
||||
// Set shader initialized default values
|
||||
for (const auto& iter : *shaderEntryPointConstants) {
|
||||
const std::string& name = iter.first;
|
||||
if (overriddenConstants.count(name) != 0) {
|
||||
// This constant already has overridden value
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto& moduleConstant = shaderEntryPointConstants->at(name);
|
||||
|
||||
// Uninitialized default values are okay since they ar only defined to pass
|
||||
// compilation but not used
|
||||
defineStrings->emplace_back(
|
||||
kSpecConstantPrefix + std::to_string(static_cast<int32_t>(moduleConstant.id)),
|
||||
GetHLSLValueString(moduleConstant.type, &moduleConstant.defaultValue));
|
||||
}
|
||||
}
|
||||
|
||||
// The inputs to a shader compilation. These have been intentionally isolated from the
|
||||
// device to help ensure that the pipeline cache key contains all inputs for compilation.
|
||||
struct ShaderCompilationRequest {
|
||||
|
@ -199,8 +178,7 @@ namespace dawn_native { namespace d3d12 {
|
|||
bool usesNumWorkgroups;
|
||||
uint32_t numWorkgroupsRegisterSpace;
|
||||
uint32_t numWorkgroupsShaderRegister;
|
||||
const PipelineConstantEntries* pipelineConstantEntries;
|
||||
const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants;
|
||||
std::vector<std::pair<std::string, std::string>> defineStrings;
|
||||
|
||||
// FXC/DXC common inputs
|
||||
bool disableWorkgroupInit;
|
||||
|
@ -293,10 +271,12 @@ namespace dawn_native { namespace d3d12 {
|
|||
request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0;
|
||||
request.deviceInfo = &device->GetDeviceInfo();
|
||||
request.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16);
|
||||
request.pipelineConstantEntries = &programmableStage.constants;
|
||||
request.shaderEntryPointConstants =
|
||||
|
||||
GetOverridableConstantsDefines(
|
||||
&request.defineStrings, &programmableStage.constants,
|
||||
&programmableStage.module->GetEntryPoint(programmableStage.entryPoint)
|
||||
.overridableConstants;
|
||||
.overridableConstants);
|
||||
|
||||
return std::move(request);
|
||||
}
|
||||
|
||||
|
@ -349,17 +329,9 @@ namespace dawn_native { namespace d3d12 {
|
|||
stream << " dxcVersion=" << dxcVersion;
|
||||
stream << " hasShaderFloat16Feature=" << hasShaderFloat16Feature;
|
||||
|
||||
stream << " overridableConstants={";
|
||||
for (const auto& pipelineConstant : *pipelineConstantEntries) {
|
||||
const std::string& name = pipelineConstant.first;
|
||||
double value = pipelineConstant.second;
|
||||
|
||||
// This is already validated so `name` must exist
|
||||
const auto& moduleConstant = shaderEntryPointConstants->at(name);
|
||||
|
||||
stream << " <" << name << ","
|
||||
<< GetHLSLValueString<std::string>(moduleConstant.type, nullptr, value)
|
||||
<< ">";
|
||||
stream << " defines={";
|
||||
for (const auto& it : defineStrings) {
|
||||
stream << " <" << it.first << "," << it.second << ">";
|
||||
}
|
||||
stream << " }";
|
||||
|
||||
|
@ -422,53 +394,6 @@ namespace dawn_native { namespace d3d12 {
|
|||
return arguments;
|
||||
}
|
||||
|
||||
template <typename StringType>
|
||||
const std::vector<std::pair<StringType, StringType>> GetOverridableConstantsDefines(
|
||||
const PipelineConstantEntries* pipelineConstantEntries,
|
||||
const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants) {
|
||||
std::vector<std::pair<StringType, StringType>> defineStrings;
|
||||
|
||||
std::unordered_set<std::string> overriddenConstants;
|
||||
|
||||
// Set pipeline overridden values
|
||||
for (const auto& pipelineConstant : *pipelineConstantEntries) {
|
||||
const std::string& name = pipelineConstant.first;
|
||||
double value = pipelineConstant.second;
|
||||
|
||||
overriddenConstants.insert(name);
|
||||
|
||||
// This is already validated so `name` must exist
|
||||
const auto& moduleConstant = shaderEntryPointConstants->at(name);
|
||||
|
||||
defineStrings.emplace_back(
|
||||
NumberToString<StringType>::kSpecConstantPrefix +
|
||||
NumberToString<StringType>::ToStringAsId(
|
||||
static_cast<int32_t>(moduleConstant.id)),
|
||||
GetHLSLValueString<StringType>(moduleConstant.type, nullptr, value));
|
||||
}
|
||||
|
||||
// Set shader initialized default values
|
||||
for (const auto& iter : *shaderEntryPointConstants) {
|
||||
const std::string& name = iter.first;
|
||||
if (overriddenConstants.count(name) != 0) {
|
||||
// This constant already has overridden value
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto& moduleConstant = shaderEntryPointConstants->at(name);
|
||||
|
||||
// Uninitialized default values are okay since they are only defined to pass
|
||||
// compilation but not used
|
||||
defineStrings.emplace_back(NumberToString<StringType>::kSpecConstantPrefix +
|
||||
NumberToString<StringType>::ToStringAsId(
|
||||
static_cast<int32_t>(moduleConstant.id)),
|
||||
GetHLSLValueString<StringType>(
|
||||
moduleConstant.type, &moduleConstant.defaultValue));
|
||||
}
|
||||
|
||||
return defineStrings;
|
||||
}
|
||||
|
||||
ResultOrError<ComPtr<IDxcBlob>> CompileShaderDXC(IDxcLibrary* dxcLibrary,
|
||||
IDxcCompiler* dxcCompiler,
|
||||
const ShaderCompilationRequest& request,
|
||||
|
@ -486,8 +411,13 @@ namespace dawn_native { namespace d3d12 {
|
|||
GetDXCArguments(request.compileFlags, request.hasShaderFloat16Feature);
|
||||
|
||||
// Build defines for overridable constants
|
||||
const auto& defineStrings = GetOverridableConstantsDefines<std::wstring>(
|
||||
request.pipelineConstantEntries, request.shaderEntryPointConstants);
|
||||
std::vector<std::pair<std::wstring, std::wstring>> defineStrings;
|
||||
defineStrings.reserve(request.defineStrings.size());
|
||||
for (const auto& it : request.defineStrings) {
|
||||
defineStrings.emplace_back(UTF8ToWStr(it.first.c_str()),
|
||||
UTF8ToWStr(it.second.c_str()));
|
||||
}
|
||||
|
||||
std::vector<DxcDefine> dxcDefines;
|
||||
dxcDefines.reserve(defineStrings.size());
|
||||
for (const auto& d : defineStrings) {
|
||||
|
@ -538,14 +468,11 @@ namespace dawn_native { namespace d3d12 {
|
|||
ComPtr<ID3DBlob> errors;
|
||||
|
||||
// Build defines for overridable constants
|
||||
const auto& defineStrings = GetOverridableConstantsDefines<std::string>(
|
||||
request.pipelineConstantEntries, request.shaderEntryPointConstants);
|
||||
|
||||
const D3D_SHADER_MACRO* pDefines = nullptr;
|
||||
std::vector<D3D_SHADER_MACRO> fxcDefines;
|
||||
if (defineStrings.size() > 0) {
|
||||
fxcDefines.reserve(defineStrings.size() + 1);
|
||||
for (const auto& d : defineStrings) {
|
||||
if (request.defineStrings.size() > 0) {
|
||||
fxcDefines.reserve(request.defineStrings.size() + 1);
|
||||
for (const auto& d : request.defineStrings) {
|
||||
fxcDefines.push_back({d.first.c_str(), d.second.c_str()});
|
||||
}
|
||||
// d3dCompile D3D_SHADER_MACRO* pDefines is a nullptr terminated array
|
||||
|
|
Loading…
Reference in New Issue