diff --git a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp index ec5b3b3863..9fb8b19eb1 100644 --- a/src/dawn_native/d3d12/ShaderModuleD3D12.cpp +++ b/src/dawn_native/d3d12/ShaderModuleD3D12.cpp @@ -34,94 +34,6 @@ #include #include -namespace dawn_native { - template - struct NumberToString { - static StringType ToStringAsValue(T v); - static StringType ToStringAsId(T v); - }; - - template - struct NumberToString { - static constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_"; - static std::string ToStringAsValue(T v) { - return std::to_string(v); - } - static std::string ToStringAsId(T v) { - return std::to_string(v); - } - }; - - template - struct NumberToString { - static constexpr WCHAR kSpecConstantPrefix[] = L"WGSL_SPEC_CONSTANT_"; - static std::wstring ToStringAsValue(T v) { - return std::to_wstring(v); - } - static std::wstring ToStringAsId(T v) { - return std::to_wstring(v); - } - }; - - template <> - struct NumberToString { - static std::string ToStringAsValue(float v) { - std::ostringstream out; - // 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough - out.precision(8); - out << std::fixed << v; - return out.str(); - } - }; - - template <> - struct NumberToString { - static std::wstring ToStringAsValue(float v) { - std::basic_ostringstream out; - // 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough - out.precision(8); - out << std::fixed << v; - return out.str(); - } - }; - - template <> - struct NumberToString { - static std::string ToStringAsValue(uint32_t v) { - return std::to_string(v) + "u"; - } - }; - - template <> - struct NumberToString { - static std::wstring ToStringAsValue(uint32_t v) { - return std::to_wstring(v) + L"u"; - } - }; - - template - StringType GetHLSLValueString(EntryPointMetadata::OverridableConstant::Type dawnType, - const OverridableConstantScalar* entry, - double value = 0) { - switch (dawnType) { - case EntryPointMetadata::OverridableConstant::Type::Boolean: - return NumberToString::ToStringAsValue( - entry ? entry->b : static_cast(value)); - case EntryPointMetadata::OverridableConstant::Type::Float32: - return NumberToString::ToStringAsValue( - entry ? entry->f32 : static_cast(value)); - case EntryPointMetadata::OverridableConstant::Type::Int32: - return NumberToString::ToStringAsValue( - entry ? entry->i32 : static_cast(value)); - case EntryPointMetadata::OverridableConstant::Type::Uint32: - return NumberToString::ToStringAsValue( - entry ? entry->u32 : static_cast(value)); - default: - UNREACHABLE(); - } - } -} // namespace dawn_native - namespace dawn_native { namespace d3d12 { namespace { @@ -181,6 +93,73 @@ namespace dawn_native { namespace d3d12 { output << ")"; } + // 32 bit float has 7 decimal digits of precision so setting n to 8 should be enough + std::string FloatToStringWithPrecision(float v, std::streamsize n = 8) { + std::ostringstream out; + out.precision(n); + out << std::fixed << v; + return out.str(); + } + + std::string GetHLSLValueString(EntryPointMetadata::OverridableConstant::Type dawnType, + const OverridableConstantScalar* entry, + double value = 0) { + switch (dawnType) { + case EntryPointMetadata::OverridableConstant::Type::Boolean: + return std::to_string(entry ? entry->b : static_cast(value)); + case EntryPointMetadata::OverridableConstant::Type::Float32: + return FloatToStringWithPrecision(entry ? entry->f32 + : static_cast(value)); + case EntryPointMetadata::OverridableConstant::Type::Int32: + return std::to_string(entry ? entry->i32 : static_cast(value)); + case EntryPointMetadata::OverridableConstant::Type::Uint32: + return std::to_string(entry ? entry->u32 : static_cast(value)); + default: + UNREACHABLE(); + } + } + + constexpr char kSpecConstantPrefix[] = "WGSL_SPEC_CONSTANT_"; + + void GetOverridableConstantsDefines( + std::vector>* defineStrings, + const PipelineConstantEntries* pipelineConstantEntries, + const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants) { + std::unordered_set overriddenConstants; + + // Set pipeline overridden values + for (const auto& pipelineConstant : *pipelineConstantEntries) { + const std::string& name = pipelineConstant.first; + double value = pipelineConstant.second; + + overriddenConstants.insert(name); + + // This is already validated so `name` must exist + const auto& moduleConstant = shaderEntryPointConstants->at(name); + + defineStrings->emplace_back( + kSpecConstantPrefix + std::to_string(static_cast(moduleConstant.id)), + GetHLSLValueString(moduleConstant.type, nullptr, value)); + } + + // Set shader initialized default values + for (const auto& iter : *shaderEntryPointConstants) { + const std::string& name = iter.first; + if (overriddenConstants.count(name) != 0) { + // This constant already has overridden value + continue; + } + + const auto& moduleConstant = shaderEntryPointConstants->at(name); + + // Uninitialized default values are okay since they ar only defined to pass + // compilation but not used + defineStrings->emplace_back( + kSpecConstantPrefix + std::to_string(static_cast(moduleConstant.id)), + GetHLSLValueString(moduleConstant.type, &moduleConstant.defaultValue)); + } + } + // The inputs to a shader compilation. These have been intentionally isolated from the // device to help ensure that the pipeline cache key contains all inputs for compilation. struct ShaderCompilationRequest { @@ -199,8 +178,7 @@ namespace dawn_native { namespace d3d12 { bool usesNumWorkgroups; uint32_t numWorkgroupsRegisterSpace; uint32_t numWorkgroupsShaderRegister; - const PipelineConstantEntries* pipelineConstantEntries; - const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants; + std::vector> defineStrings; // FXC/DXC common inputs bool disableWorkgroupInit; @@ -293,10 +271,12 @@ namespace dawn_native { namespace d3d12 { request.dxcVersion = compiler == Compiler::DXC ? dxcVersion : 0; request.deviceInfo = &device->GetDeviceInfo(); request.hasShaderFloat16Feature = device->IsFeatureEnabled(Feature::ShaderFloat16); - request.pipelineConstantEntries = &programmableStage.constants; - request.shaderEntryPointConstants = + + GetOverridableConstantsDefines( + &request.defineStrings, &programmableStage.constants, &programmableStage.module->GetEntryPoint(programmableStage.entryPoint) - .overridableConstants; + .overridableConstants); + return std::move(request); } @@ -349,17 +329,9 @@ namespace dawn_native { namespace d3d12 { stream << " dxcVersion=" << dxcVersion; stream << " hasShaderFloat16Feature=" << hasShaderFloat16Feature; - stream << " overridableConstants={"; - for (const auto& pipelineConstant : *pipelineConstantEntries) { - const std::string& name = pipelineConstant.first; - double value = pipelineConstant.second; - - // This is already validated so `name` must exist - const auto& moduleConstant = shaderEntryPointConstants->at(name); - - stream << " <" << name << "," - << GetHLSLValueString(moduleConstant.type, nullptr, value) - << ">"; + stream << " defines={"; + for (const auto& it : defineStrings) { + stream << " <" << it.first << "," << it.second << ">"; } stream << " }"; @@ -422,53 +394,6 @@ namespace dawn_native { namespace d3d12 { return arguments; } - template - const std::vector> GetOverridableConstantsDefines( - const PipelineConstantEntries* pipelineConstantEntries, - const EntryPointMetadata::OverridableConstantsMap* shaderEntryPointConstants) { - std::vector> defineStrings; - - std::unordered_set overriddenConstants; - - // Set pipeline overridden values - for (const auto& pipelineConstant : *pipelineConstantEntries) { - const std::string& name = pipelineConstant.first; - double value = pipelineConstant.second; - - overriddenConstants.insert(name); - - // This is already validated so `name` must exist - const auto& moduleConstant = shaderEntryPointConstants->at(name); - - defineStrings.emplace_back( - NumberToString::kSpecConstantPrefix + - NumberToString::ToStringAsId( - static_cast(moduleConstant.id)), - GetHLSLValueString(moduleConstant.type, nullptr, value)); - } - - // Set shader initialized default values - for (const auto& iter : *shaderEntryPointConstants) { - const std::string& name = iter.first; - if (overriddenConstants.count(name) != 0) { - // This constant already has overridden value - continue; - } - - const auto& moduleConstant = shaderEntryPointConstants->at(name); - - // Uninitialized default values are okay since they are only defined to pass - // compilation but not used - defineStrings.emplace_back(NumberToString::kSpecConstantPrefix + - NumberToString::ToStringAsId( - static_cast(moduleConstant.id)), - GetHLSLValueString( - moduleConstant.type, &moduleConstant.defaultValue)); - } - - return defineStrings; - } - ResultOrError> CompileShaderDXC(IDxcLibrary* dxcLibrary, IDxcCompiler* dxcCompiler, const ShaderCompilationRequest& request, @@ -486,8 +411,13 @@ namespace dawn_native { namespace d3d12 { GetDXCArguments(request.compileFlags, request.hasShaderFloat16Feature); // Build defines for overridable constants - const auto& defineStrings = GetOverridableConstantsDefines( - request.pipelineConstantEntries, request.shaderEntryPointConstants); + std::vector> defineStrings; + defineStrings.reserve(request.defineStrings.size()); + for (const auto& it : request.defineStrings) { + defineStrings.emplace_back(UTF8ToWStr(it.first.c_str()), + UTF8ToWStr(it.second.c_str())); + } + std::vector dxcDefines; dxcDefines.reserve(defineStrings.size()); for (const auto& d : defineStrings) { @@ -538,14 +468,11 @@ namespace dawn_native { namespace d3d12 { ComPtr errors; // Build defines for overridable constants - const auto& defineStrings = GetOverridableConstantsDefines( - request.pipelineConstantEntries, request.shaderEntryPointConstants); - const D3D_SHADER_MACRO* pDefines = nullptr; std::vector fxcDefines; - if (defineStrings.size() > 0) { - fxcDefines.reserve(defineStrings.size() + 1); - for (const auto& d : defineStrings) { + if (request.defineStrings.size() > 0) { + fxcDefines.reserve(request.defineStrings.size() + 1); + for (const auto& d : request.defineStrings) { fxcDefines.push_back({d.first.c_str(), d.second.c_str()}); } // d3dCompile D3D_SHADER_MACRO* pDefines is a nullptr terminated array