From bc33bae65943c339f7a11ed482436b3f32bf3c06 Mon Sep 17 00:00:00 2001 From: Luke Street Date: Sun, 28 Sep 2025 17:00:38 -0600 Subject: [PATCH] Formatting, fixes, deduplication --- dll/advapi32.cpp | 27 ++-- dll/crt.cpp | 13 +- dll/kernel32.cpp | 32 +--- dll/msvcrt.cpp | 26 +--- dll/psapi.cpp | 76 +++++----- dll/rpcrt4.cpp | 41 ++---- dll/version.cpp | 25 ++-- files.cpp | 6 +- module_registry.cpp | 349 +++++++++++++++++++------------------------- strutil.cpp | 37 +++++ strutil.h | 6 + 11 files changed, 293 insertions(+), 345 deletions(-) diff --git a/dll/advapi32.cpp b/dll/advapi32.cpp index d74bbee..0263be0 100644 --- a/dll/advapi32.cpp +++ b/dll/advapi32.cpp @@ -1,7 +1,8 @@ #include "common.h" #include "handles.h" +#include "strutil.h" +#include #include -#include #include namespace { @@ -294,7 +295,8 @@ namespace advapi32 { return TRUE; } - BOOL WIN_FUNC CryptAcquireContextW(void** phProv, const wchar_t* pszContainer, const wchar_t* pszProvider, unsigned int dwProvType, unsigned int dwFlags){ + BOOL WIN_FUNC CryptAcquireContextW(void **phProv, const uint16_t *pszContainer, const uint16_t *pszProvider, + unsigned int dwProvType, unsigned int dwFlags) { DEBUG_LOG("STUB: CryptAcquireContextW(%p)\n", phProv); // to quote the guy above me: screw them for now @@ -494,7 +496,7 @@ namespace advapi32 { if (TokenInformationClass == TokenUserClass) { constexpr size_t sidSize = sizeof(Sid); constexpr size_t tokenUserSize = sizeof(TokenUserData); - const unsigned int required = static_cast(tokenUserSize + sidSize); + const auto required = static_cast(tokenUserSize + sidSize); *ReturnLength = required; if (!TokenInformation || TokenInformationLength < required) { wibo::lastError = ERROR_INSUFFICIENT_BUFFER; @@ -544,8 +546,11 @@ namespace advapi32 { return FALSE; } - BOOL WIN_FUNC LookupAccountSidW(const wchar_t *lpSystemName, const void *sidPointer, wchar_t *Name, unsigned long *cchName, wchar_t *ReferencedDomainName, unsigned long *cchReferencedDomainName, SID_NAME_USE *peUse) { - DEBUG_LOG("LookupAccountSidW(system=%ls, sid=%p)\n", lpSystemName ? lpSystemName : L"(null)", sidPointer); + BOOL WIN_FUNC LookupAccountSidW(const uint16_t *lpSystemName, const void *sidPointer, uint16_t *Name, + unsigned long *cchName, uint16_t *ReferencedDomainName, + unsigned long *cchReferencedDomainName, SID_NAME_USE *peUse) { + std::string systemName = lpSystemName ? wideStringToString(lpSystemName) : std::string("(null)"); + DEBUG_LOG("LookupAccountSidW(system=%s, sid=%p)\n", systemName.c_str(), sidPointer); (void) lpSystemName; // Only local lookup supported if (!sidPointer || !cchName || !cchReferencedDomainName || !peUse) { wibo::lastError = ERROR_INVALID_PARAMETER; @@ -556,18 +561,18 @@ namespace advapi32 { wibo::lastError = ERROR_NONE_MAPPED; return FALSE; } - const wchar_t *accountName = L"SYSTEM"; - const wchar_t *domainName = L"NT AUTHORITY"; - unsigned long requiredAccount = static_cast(std::wcslen(accountName) + 1); - unsigned long requiredDomain = static_cast(std::wcslen(domainName) + 1); + static constexpr uint16_t accountName[] = {u'S', u'Y', u'S', u'T', u'E', u'M', u'\0'}; + static constexpr uint16_t domainName[] = {u'N', u'T', u' ', u'A', u'U', u'T', u'H', u'O', u'R', u'I', u'T', u'Y', u'\0'}; + unsigned long requiredAccount = wstrlen(accountName) + 1; + unsigned long requiredDomain = wstrlen(domainName) + 1; if (!Name || *cchName < requiredAccount || !ReferencedDomainName || *cchReferencedDomainName < requiredDomain) { *cchName = requiredAccount; *cchReferencedDomainName = requiredDomain; wibo::lastError = ERROR_INSUFFICIENT_BUFFER; return FALSE; } - std::wmemcpy(Name, accountName, requiredAccount); - std::wmemcpy(ReferencedDomainName, domainName, requiredDomain); + std::copy_n(accountName, requiredAccount, Name); + std::copy_n(domainName, requiredDomain, ReferencedDomainName); *peUse = SidTypeWellKnownGroup; *cchName = requiredAccount - 1; *cchReferencedDomainName = requiredDomain - 1; diff --git a/dll/crt.cpp b/dll/crt.cpp index f7208cd..99b8b3a 100644 --- a/dll/crt.cpp +++ b/dll/crt.cpp @@ -5,17 +5,18 @@ #include #include #include -#include #include +#include typedef void (*_PVFV)(); typedef int (*_PIFV)(); -typedef void (*_invalid_parameter_handler)(const wchar_t *, const wchar_t *, const wchar_t *, unsigned int, uintptr_t); +typedef void (*_invalid_parameter_handler)(const uint16_t *, const uint16_t *, const uint16_t *, unsigned int, + uintptr_t); extern char **environ; namespace msvcrt { - int WIN_ENTRY puts(const char *str); +int WIN_ENTRY puts(const char *str); } typedef enum _crt_app_type { @@ -195,11 +196,13 @@ void *WIN_ENTRY __acrt_iob_func(unsigned int index) { return nullptr; } -int WIN_ENTRY __stdio_common_vfprintf(unsigned long long /*options*/, FILE *stream, const char *format, void * /*locale*/, va_list args) { +int WIN_ENTRY __stdio_common_vfprintf(unsigned long long /*options*/, FILE *stream, const char *format, + void * /*locale*/, va_list args) { return vfprintf(stream, format, args); } -int WIN_ENTRY __stdio_common_vsprintf(unsigned long long /*options*/, char *buffer, size_t len, const char *format, void * /*locale*/, va_list args) { +int WIN_ENTRY __stdio_common_vsprintf(unsigned long long /*options*/, char *buffer, size_t len, const char *format, + void * /*locale*/, va_list args) { if (!buffer || !format) return -1; int result = vsnprintf(buffer, len, format, args); diff --git a/dll/kernel32.cpp b/dll/kernel32.cpp index 27e422a..574cafa 100644 --- a/dll/kernel32.cpp +++ b/dll/kernel32.cpp @@ -349,37 +349,11 @@ namespace kernel32 { std::memset(lpSystemInfo, 0, sizeof(*lpSystemInfo)); - WORD architecture = PROCESSOR_ARCHITECTURE_UNKNOWN; - DWORD processorType = 0; - WORD processorLevel = 0; - -#if defined(__x86_64__) || defined(_M_X64) - architecture = PROCESSOR_ARCHITECTURE_AMD64; - processorType = PROCESSOR_AMD_X8664; - processorLevel = 6; -#elif defined(__i386__) || defined(_M_IX86) - architecture = PROCESSOR_ARCHITECTURE_INTEL; - processorType = PROCESSOR_INTEL_PENTIUM; - processorLevel = 6; -#elif defined(__aarch64__) - architecture = PROCESSOR_ARCHITECTURE_ARM64; - processorType = 0; - processorLevel = 8; -#elif defined(__arm__) - architecture = PROCESSOR_ARCHITECTURE_ARM; - processorType = 0; - processorLevel = 7; -#else - architecture = PROCESSOR_ARCHITECTURE_UNKNOWN; - processorType = 0; - processorLevel = 0; -#endif - - lpSystemInfo->wProcessorArchitecture = architecture; + lpSystemInfo->wProcessorArchitecture = PROCESSOR_ARCHITECTURE_INTEL; lpSystemInfo->wReserved = 0; lpSystemInfo->dwOemId = lpSystemInfo->wProcessorArchitecture; - lpSystemInfo->dwProcessorType = processorType; - lpSystemInfo->wProcessorLevel = processorLevel; + lpSystemInfo->dwProcessorType = PROCESSOR_INTEL_PENTIUM; + lpSystemInfo->wProcessorLevel = 6; // Pentium lpSystemInfo->wProcessorRevision = 0; long pageSize = sysconf(_SC_PAGESIZE); diff --git a/dll/msvcrt.cpp b/dll/msvcrt.cpp index 8617ca2..b8f17d0 100644 --- a/dll/msvcrt.cpp +++ b/dll/msvcrt.cpp @@ -163,10 +163,8 @@ namespace msvcrt { } } if (!exeDir.empty()) { - std::string loweredResult = result; - std::string loweredExe = exeDir; - std::transform(loweredResult.begin(), loweredResult.end(), loweredResult.begin(), [](unsigned char c) { return static_cast(std::tolower(c)); }); - std::transform(loweredExe.begin(), loweredExe.end(), loweredExe.begin(), [](unsigned char c) { return static_cast(std::tolower(c)); }); + std::string loweredResult = stringToLower(result); + std::string loweredExe = stringToLower(exeDir); bool present = false; size_t start = 0; while (start <= loweredResult.size()) { @@ -648,18 +646,6 @@ char* WIN_ENTRY setlocale(int category, const char *locale){ return 0; } - static uint16_t toLower(uint16_t ch) { - if (ch >= 'A' && ch <= 'Z') { - return static_cast(ch + ('a' - 'A')); - } - wchar_t wide = static_cast(ch); - wchar_t lowered = std::towlower(wide); - if (lowered < 0 || lowered > 0xFFFF) { - return ch; - } - return static_cast(lowered); - } - int WIN_ENTRY _wcsicmp(const uint16_t *lhs, const uint16_t *rhs) { if (lhs == rhs) { return 0; @@ -672,15 +658,15 @@ char* WIN_ENTRY setlocale(int category, const char *locale){ } while (*lhs && *rhs) { - uint16_t a = toLower(*lhs++); - uint16_t b = toLower(*rhs++); + uint16_t a = wcharToLower(*lhs++); + uint16_t b = wcharToLower(*rhs++); if (a != b) { return static_cast(a) - static_cast(b); } } - uint16_t a = toLower(*lhs); - uint16_t b = toLower(*rhs); + uint16_t a = wcharToLower(*lhs); + uint16_t b = wcharToLower(*rhs); return static_cast(a) - static_cast(b); } diff --git a/dll/psapi.cpp b/dll/psapi.cpp index bd7e4ff..70657f3 100644 --- a/dll/psapi.cpp +++ b/dll/psapi.cpp @@ -2,55 +2,55 @@ #include "handles.h" namespace psapi { - BOOL WIN_FUNC EnumProcessModules(HANDLE hProcess, HMODULE *lphModule, DWORD cb, DWORD *lpcbNeeded) { - DEBUG_LOG("EnumProcessModules(hProcess=%p, cb=%u)\n", hProcess, cb); +BOOL WIN_FUNC EnumProcessModules(HANDLE hProcess, HMODULE *lphModule, DWORD cb, DWORD *lpcbNeeded) { + DEBUG_LOG("EnumProcessModules(hProcess=%p, cb=%u)\n", hProcess, cb); - bool recognizedHandle = false; - if (hProcess == (HANDLE)0xFFFFFFFF) { - recognizedHandle = true; - } else { - auto data = handles::dataFromHandle(hProcess, false); - recognizedHandle = (data.type == handles::TYPE_PROCESS); - } - if (!recognizedHandle) { - wibo::lastError = ERROR_ACCESS_DENIED; - return FALSE; - } - - HMODULE currentModule = wibo::mainModule ? reinterpret_cast(wibo::mainModule->imageBuffer) : nullptr; - DWORD required = currentModule ? sizeof(HMODULE) : 0; - if (lpcbNeeded) { - *lpcbNeeded = required; - } - - if (required == 0) { - wibo::lastError = ERROR_INVALID_HANDLE; - return FALSE; - } - - if (!lphModule || cb < required) { - wibo::lastError = ERROR_INSUFFICIENT_BUFFER; - return FALSE; - } - - lphModule[0] = currentModule; - wibo::lastError = ERROR_SUCCESS; - return TRUE; + bool recognizedHandle = false; + if (hProcess == (HANDLE)0xFFFFFFFF) { + recognizedHandle = true; + } else { + auto data = handles::dataFromHandle(hProcess, false); + recognizedHandle = (data.type == handles::TYPE_PROCESS); } + if (!recognizedHandle) { + wibo::lastError = ERROR_ACCESS_DENIED; + return FALSE; + } + + HMODULE currentModule = wibo::mainModule ? reinterpret_cast(wibo::mainModule->imageBuffer) : nullptr; + DWORD required = currentModule ? sizeof(HMODULE) : 0; + if (lpcbNeeded) { + *lpcbNeeded = required; + } + + if (required == 0) { + wibo::lastError = ERROR_INVALID_HANDLE; + return FALSE; + } + + if (!lphModule || cb < required) { + wibo::lastError = ERROR_INSUFFICIENT_BUFFER; + return FALSE; + } + + lphModule[0] = currentModule; + wibo::lastError = ERROR_SUCCESS; + return TRUE; } +} // namespace psapi static void *resolveByName(const char *name) { - DEBUG_LOG("psapi resolveByName(%s)\n", name); - if (strcmp(name, "EnumProcessModules") == 0) return (void *) psapi::EnumProcessModules; - if (strcmp(name, "K32EnumProcessModules") == 0) return (void *) psapi::EnumProcessModules; + if (strcmp(name, "EnumProcessModules") == 0) + return (void *)psapi::EnumProcessModules; + if (strcmp(name, "K32EnumProcessModules") == 0) + return (void *)psapi::EnumProcessModules; return nullptr; } static void *resolveByOrdinal(uint16_t ordinal) { - DEBUG_LOG("psapi resolveByOrdinal(%u)\n", ordinal); switch (ordinal) { case 4: // EnumProcessModules - return (void *) psapi::EnumProcessModules; + return (void *)psapi::EnumProcessModules; default: return nullptr; } diff --git a/dll/rpcrt4.cpp b/dll/rpcrt4.cpp index 4a77962..9ee967f 100644 --- a/dll/rpcrt4.cpp +++ b/dll/rpcrt4.cpp @@ -128,14 +128,8 @@ BindingHandleData *getBinding(RPC_BINDING_HANDLE handle) { extern "C" { -RPC_STATUS WIN_FUNC RpcStringBindingComposeW( - RPC_WSTR objUuid, - RPC_WSTR protSeq, - RPC_WSTR networkAddr, - RPC_WSTR endpoint, - RPC_WSTR options, - RPC_WSTR *stringBinding -) { +RPC_STATUS WIN_FUNC RpcStringBindingComposeW(RPC_WSTR objUuid, RPC_WSTR protSeq, RPC_WSTR networkAddr, + RPC_WSTR endpoint, RPC_WSTR options, RPC_WSTR *stringBinding) { BindingComponents components; components.objectUuid = toU16(objUuid); components.protocolSequence = toU16(protSeq); @@ -187,15 +181,10 @@ RPC_STATUS WIN_FUNC RpcBindingFromStringBindingW(RPC_WSTR stringBinding, RPC_BIN return RPC_S_OK; } -RPC_STATUS WIN_FUNC RpcBindingSetAuthInfoExW( - RPC_BINDING_HANDLE binding, - RPC_WSTR serverPrincName, - unsigned long authnLevel, - unsigned long authnSvc, - RPC_AUTH_IDENTITY_HANDLE authIdentity, - unsigned long authzSvc, - RPC_SECURITY_QOS *securityQos -) { +RPC_STATUS WIN_FUNC RpcBindingSetAuthInfoExW(RPC_BINDING_HANDLE binding, RPC_WSTR serverPrincName, + unsigned long authnLevel, unsigned long authnSvc, + RPC_AUTH_IDENTITY_HANDLE authIdentity, unsigned long authzSvc, + RPC_SECURITY_QOS *securityQos) { BindingHandleData *data = getBinding(binding); if (!data) { return RPC_S_INVALID_BINDING; @@ -260,9 +249,7 @@ NdrClientCall2(PMIDL_STUB_DESC stubDescriptor, PFORMAT_STRING format, ...) { return result; } -void WIN_FUNC NdrServerCall2(PRPC_MESSAGE message) { - DEBUG_LOG("STUB: NdrServerCall2 message=%p\n", message); -} +void WIN_FUNC NdrServerCall2(PRPC_MESSAGE message) { DEBUG_LOG("STUB: NdrServerCall2 message=%p\n", message); } } // extern "C" @@ -270,19 +257,19 @@ namespace { void *resolveByName(const char *name) { if (std::strcmp(name, "RpcStringBindingComposeW") == 0) - return (void *) RpcStringBindingComposeW; + return (void *)RpcStringBindingComposeW; if (std::strcmp(name, "RpcBindingFromStringBindingW") == 0) - return (void *) RpcBindingFromStringBindingW; + return (void *)RpcBindingFromStringBindingW; if (std::strcmp(name, "RpcStringFreeW") == 0) - return (void *) RpcStringFreeW; + return (void *)RpcStringFreeW; if (std::strcmp(name, "RpcBindingFree") == 0) - return (void *) RpcBindingFree; + return (void *)RpcBindingFree; if (std::strcmp(name, "RpcBindingSetAuthInfoExW") == 0) - return (void *) RpcBindingSetAuthInfoExW; + return (void *)RpcBindingSetAuthInfoExW; if (std::strcmp(name, "NdrClientCall2") == 0) - return (void *) NdrClientCall2; + return (void *)NdrClientCall2; if (std::strcmp(name, "NdrServerCall2") == 0) - return (void *) NdrServerCall2; + return (void *)NdrServerCall2; return nullptr; } diff --git a/dll/version.cpp b/dll/version.cpp index 06521c7..2bbefcf 100644 --- a/dll/version.cpp +++ b/dll/version.cpp @@ -3,8 +3,6 @@ #include "resources.h" #include "strutil.h" -#include -#include #include #include #include @@ -15,7 +13,7 @@ namespace { constexpr uint32_t RT_VERSION = 16; -static uint16_t read_u16(const uint8_t *ptr) { +static uint16_t readU16(const uint8_t *ptr) { return static_cast(ptr[0] | (ptr[1] << 8)); } @@ -49,9 +47,9 @@ static bool parseVersionBlock(const uint8_t *block, size_t available, VersionBlo return false; } - uint16_t totalLength = read_u16(block); - uint16_t valueLength = read_u16(block + sizeof(uint16_t)); - uint16_t type = read_u16(block + sizeof(uint16_t) * 2); + uint16_t totalLength = readU16(block); + uint16_t valueLength = readU16(block + sizeof(uint16_t)); + uint16_t type = readU16(block + sizeof(uint16_t) * 2); if (totalLength == 0 || totalLength > available) { DEBUG_LOG("invalid totalLength=%u available=%zu\n", totalLength, available); return false; @@ -61,7 +59,7 @@ static bool parseVersionBlock(const uint8_t *block, size_t available, VersionBlo const uint8_t *cursor = block + sizeof(uint16_t) * 3; out.key.clear(); while (cursor + sizeof(uint16_t) <= end) { - uint16_t ch = read_u16(cursor); + uint16_t ch = readU16(cursor); cursor += sizeof(uint16_t); if (!ch) break; @@ -101,11 +99,6 @@ static bool parseVersionBlock(const uint8_t *block, size_t available, VersionBlo return true; } -static std::string toLowerCopy(std::string str) { - std::transform(str.begin(), str.end(), str.begin(), [](unsigned char c) { return static_cast(std::tolower(c)); }); - return str; -} - static bool queryVersionBlock(const uint8_t *block, size_t available, const std::vector &segments, size_t depth, @@ -126,7 +119,7 @@ static bool queryVersionBlock(const uint8_t *block, size_t available, return true; } - const std::string targetLower = toLowerCopy(segments[depth]); + const std::string targetLower = stringToLower(segments[depth]); const uint8_t *cursor = view.childrenPtr; const uint8_t *end = view.childrenPtr + view.childrenBytes; @@ -137,12 +130,12 @@ static bool queryVersionBlock(const uint8_t *block, size_t available, break; if (child.totalLength == 0) break; - std::string childKeyLower = toLowerCopy(narrowKey(child.key)); + std::string childKeyLower = stringToLower(narrowKey(child.key)); if (childKeyLower == targetLower) { if (queryVersionBlock(childStart, child.totalLength, segments, depth + 1, outPtr, outLen, outType)) return true; } - size_t offset = static_cast(child.totalLength); + const auto offset = static_cast(child.totalLength); cursor = childStart + align4(offset); if (cursor <= childStart || cursor > end) break; @@ -252,7 +245,7 @@ static unsigned int VerQueryValueImpl(const void *pBlock, const std::string &sub return 0; const uint8_t *base = static_cast(pBlock); - uint16_t totalLength = read_u16(base); + uint16_t totalLength = readU16(base); if (totalLength < 6) return 0; diff --git a/files.cpp b/files.cpp index 579538b..034ca98 100644 --- a/files.cpp +++ b/files.cpp @@ -1,8 +1,8 @@ #include "common.h" #include "files.h" #include "handles.h" +#include "strutil.h" #include -#include #include #include #include @@ -182,13 +182,13 @@ namespace files { return std::nullopt; } std::string needle = filename; - std::transform(needle.begin(), needle.end(), needle.begin(), [](unsigned char ch) { return std::tolower(ch); }); + toLowerInPlace(needle); for (const auto &entry : std::filesystem::directory_iterator(directory, ec)) { if (ec) { break; } std::string candidate = entry.path().filename().string(); - std::transform(candidate.begin(), candidate.end(), candidate.begin(), [](unsigned char ch) { return std::tolower(ch); }); + toLowerInPlace(candidate); if (candidate == needle) { return canonicalPath(entry.path()); } diff --git a/module_registry.cpp b/module_registry.cpp index f0f6730..926f647 100644 --- a/module_registry.cpp +++ b/module_registry.cpp @@ -4,7 +4,6 @@ #include #include -#include #include #include #include @@ -50,45 +49,29 @@ struct PEExportDirectory { uint32_t addressOfNameOrdinals; }; -#define FOR_256_3(a, b, c, d) FOR_ITER((a << 6 | b << 4 | c << 2 | d)) -#define FOR_256_2(a, b) \ - FOR_256_3(a, b, 0, 0) \ - FOR_256_3(a, b, 0, 1) \ - FOR_256_3(a, b, 0, 2) \ - FOR_256_3(a, b, 0, 3) FOR_256_3(a, b, 1, 0) FOR_256_3(a, b, 1, 1) FOR_256_3(a, b, 1, 2) FOR_256_3(a, b, 1, 3) \ - FOR_256_3(a, b, 2, 0) FOR_256_3(a, b, 2, 1) FOR_256_3(a, b, 2, 2) FOR_256_3(a, b, 2, 3) FOR_256_3(a, b, 3, 0) \ - FOR_256_3(a, b, 3, 1) FOR_256_3(a, b, 3, 2) FOR_256_3(a, b, 3, 3) -#define FOR_256 \ - FOR_256_2(0, 0) \ - FOR_256_2(0, 1) \ - FOR_256_2(0, 2) \ - FOR_256_2(0, 3) FOR_256_2(1, 0) FOR_256_2(1, 1) FOR_256_2(1, 2) FOR_256_2(1, 3) FOR_256_2(2, 0) FOR_256_2(2, 1) \ - FOR_256_2(2, 2) FOR_256_2(2, 3) FOR_256_2(3, 0) FOR_256_2(3, 1) FOR_256_2(3, 2) FOR_256_2(3, 3) +using StubFuncType = void (*)(); +constexpr size_t MAX_STUBS = 0x100; +size_t stubIndex = 0; +std::array stubDlls; +std::array stubFuncNames; +std::unordered_map stubCache; -static constexpr size_t MAX_STUBS = 0x100; -static int stubIndex = 0; -static std::array stubDlls; -static std::array stubFuncNames; -static std::unordered_map stubCache; - -static std::string makeStubKey(const char *dllName, const char *funcName) { +std::string makeStubKey(const char *dllName, const char *funcName) { std::string key; if (dllName) { key.assign(dllName); - std::transform(key.begin(), key.end(), key.begin(), - [](unsigned char c) { return static_cast(std::tolower(c)); }); + toLowerInPlace(key); } key.push_back(':'); if (funcName) { std::string func(funcName); - std::transform(func.begin(), func.end(), func.begin(), - [](unsigned char c) { return static_cast(std::tolower(c)); }); + toLowerInPlace(func); key += func; } return key; } -static void stubBase(int index) { +void stubBase(size_t index) { const char *func = stubFuncNames[index].empty() ? "" : stubFuncNames[index].c_str(); const char *dll = stubDlls[index].empty() ? "" : stubDlls[index].c_str(); fprintf(stderr, "wibo: call reached missing import %s from %s\n", func, dll); @@ -96,47 +79,42 @@ static void stubBase(int index) { abort(); } -void (*stubFuncs[MAX_STUBS])(void) = { -#define FOR_ITER(i) []() { stubBase(i); }, - FOR_256 -#undef FOR_ITER -}; +template void stubThunk() { stubBase(Index); } -#undef FOR_256_3 -#undef FOR_256_2 -#undef FOR_256 +template +constexpr std::array makeStubTable(std::index_sequence) { + return {{stubThunk...}}; +} -void *resolveMissingFuncName(const char *dllName, const char *funcName) { +constexpr auto stubFuncs = makeStubTable(std::make_index_sequence{}); + +StubFuncType resolveMissingFuncName(const char *dllName, const char *funcName) { DEBUG_LOG("Missing function: %s (%s)\n", dllName, funcName); std::string key = makeStubKey(dllName, funcName); auto existing = stubCache.find(key); if (existing != stubCache.end()) { return existing->second; } - if (stubIndex >= static_cast(MAX_STUBS)) { - fprintf(stderr, - "Too many missing functions encountered (>%zu). Last failure: %s (%s)\n", - MAX_STUBS, funcName, dllName); - exit(1); + if (stubIndex >= MAX_STUBS) { + fprintf(stderr, "wibo: too many missing functions encountered (>%zu). Last failure: %s (%s)\n", MAX_STUBS, + funcName, dllName); + fflush(stderr); + abort(); } stubFuncNames[stubIndex] = funcName ? funcName : ""; stubDlls[stubIndex] = dllName ? dllName : ""; - void *stub = (void *)stubFuncs[stubIndex]; + StubFuncType stub = stubFuncs[stubIndex]; stubCache.emplace(std::move(key), stub); stubIndex++; return stub; } -void *resolveMissingFuncOrdinal(const char *dllName, uint16_t ordinal) { +StubFuncType resolveMissingFuncOrdinal(const char *dllName, uint16_t ordinal) { char buf[16]; sprintf(buf, "%d", ordinal); return resolveMissingFuncName(dllName, buf); } -} // namespace - -namespace { - using ModulePtr = std::unique_ptr; struct ModuleRegistry { @@ -152,23 +130,45 @@ struct ModuleRegistry { std::unordered_set pinnedModules; }; -ModuleRegistry ®istry() { - static ModuleRegistry reg; - return reg; -} +struct LockedRegistry { + ModuleRegistry *reg; + std::unique_lock lock; -std::string toLowerCopy(const std::string &value) { - std::string out = value; - std::transform(out.begin(), out.end(), out.begin(), - [](unsigned char c) { return static_cast(std::tolower(c)); }); - return out; + LockedRegistry(ModuleRegistry ®istryRef, std::unique_lock &&guard) + : reg(®istryRef), lock(std::move(guard)) {} + + LockedRegistry(const LockedRegistry &) = delete; + LockedRegistry &operator=(const LockedRegistry &) = delete; + LockedRegistry(LockedRegistry &&) = default; + LockedRegistry &operator=(LockedRegistry &&) = default; + + [[nodiscard]] ModuleRegistry &get() const { return *reg; } + ModuleRegistry *operator->() const { return reg; } + ModuleRegistry &operator*() const { return *reg; } +}; + +void registerBuiltinModule(ModuleRegistry ®, const wibo::Module *module); + +LockedRegistry registry() { + static ModuleRegistry reg; + std::unique_lock guard(reg.mutex); + if (!reg.initialized) { + reg.initialized = true; + const wibo::Module *builtins[] = { + &lib_advapi32, &lib_bcrypt, &lib_crt, &lib_kernel32, &lib_lmgr, &lib_mscoree, &lib_msvcrt, + &lib_ntdll, &lib_ole32, &lib_rpcrt4, &lib_user32, &lib_vcruntime, &lib_version, nullptr, + }; + for (const wibo::Module **module = builtins; *module; ++module) { + registerBuiltinModule(reg, *module); + } + } + return {reg, std::move(guard)}; } std::string normalizeAlias(const std::string &value) { std::string out = value; std::replace(out.begin(), out.end(), '/', '\\'); - std::transform(out.begin(), out.end(), out.begin(), - [](unsigned char c) { return static_cast(std::tolower(c)); }); + toLowerInPlace(out); return out; } @@ -211,7 +211,7 @@ std::vector candidateModuleNames(const ParsedModuleName &parsed) { std::string normalizedBaseKey(const ParsedModuleName &parsed) { if (parsed.base.empty()) { - return std::string(); + return {}; } std::string base = parsed.base; if (!parsed.hasExtension && !parsed.endsWithDot) { @@ -231,48 +231,36 @@ std::optional combineAndFind(const std::filesystem::path return files::findCaseInsensitiveFile(directory, filename); } -std::vector collectSearchDirectories(bool alteredSearchPath) { +std::vector collectSearchDirectories(ModuleRegistry ®, bool alteredSearchPath) { std::vector dirs; std::unordered_set seen; - auto addDirectory = [&](const std::filesystem::path &dir) { - if (dir.empty()) - return; - std::error_code ec; - auto canonical = std::filesystem::weakly_canonical(dir, ec); + auto addDirectory = [&](const std::filesystem::path &dir) { + if (dir.empty()) + return; + std::error_code ec; + auto canonical = std::filesystem::weakly_canonical(dir, ec); if (ec) { canonical = std::filesystem::absolute(dir, ec); } if (ec) return; - if (!std::filesystem::exists(canonical, ec) || ec) - return; - std::string key = toLowerCopy(canonical.string()); - if (seen.insert(key).second) { - dirs.push_back(canonical); - } - }; - - auto ® = registry(); - - if (wibo::argv && wibo::argc > 0 && wibo::argv[0]) { - std::filesystem::path mainBinary = std::filesystem::absolute(wibo::argv[0]); - if (mainBinary.has_parent_path()) { - addDirectory(mainBinary.parent_path()); - } - } + if (!std::filesystem::exists(canonical, ec) || ec) + return; + std::string key = stringToLower(canonical.string()); + if (seen.insert(key).second) { + dirs.push_back(canonical); + } + }; if (reg.dllDirectory.has_value()) { addDirectory(*reg.dllDirectory); } - addDirectory(files::pathFromWindows("Z:/Windows/System32")); - addDirectory(files::pathFromWindows("Z:/Windows")); - if (!alteredSearchPath) { addDirectory(std::filesystem::current_path()); } - if (const char *envPath = std::getenv("PATH")) { + if (const char *envPath = std::getenv("WIBO_PATH")) { std::string pathList = envPath; size_t start = 0; while (start <= pathList.size()) { @@ -302,7 +290,9 @@ std::vector collectSearchDirectories(bool alteredSearchPa return dirs; } -std::optional resolveModuleOnDisk(const std::string &requestedName, bool alteredSearchPath) { + +std::optional resolveModuleOnDisk(ModuleRegistry ®, const std::string &requestedName, + bool alteredSearchPath) { ParsedModuleName parsed = parseModuleName(requestedName); auto names = candidateModuleNames(parsed); @@ -310,9 +300,9 @@ std::optional resolveModuleOnDisk(const std::string &requ for (const auto &candidate : names) { auto combined = parsed.directory + "\\" + candidate; auto posixPath = files::pathFromWindows(combined.c_str()); - if (!posixPath.empty()) { - auto resolved = files::findCaseInsensitiveFile(std::filesystem::path(posixPath).parent_path(), - std::filesystem::path(posixPath).filename().string()); + if (!posixPath.empty()) { + auto resolved = files::findCaseInsensitiveFile(std::filesystem::path(posixPath).parent_path(), + std::filesystem::path(posixPath).filename().string()); if (resolved) { return files::canonicalPath(*resolved); } @@ -321,7 +311,7 @@ std::optional resolveModuleOnDisk(const std::string &requ return std::nullopt; } - auto dirs = collectSearchDirectories(alteredSearchPath); + auto dirs = collectSearchDirectories(reg, alteredSearchPath); for (const auto &dir : dirs) { for (const auto &candidate : names) { auto resolved = combineAndFind(dir, candidate); @@ -340,8 +330,7 @@ std::string storageKeyForPath(const std::filesystem::path &path) { std::string storageKeyForBuiltin(const std::string &normalizedName) { return normalizedName; } -wibo::ModuleInfo *findByAlias(const std::string &alias) { - auto ® = registry(); +wibo::ModuleInfo *findByAlias(ModuleRegistry ®, const std::string &alias) { auto it = reg.modulesByAlias.find(alias); if (it != reg.modulesByAlias.end()) { return it->second; @@ -349,11 +338,10 @@ wibo::ModuleInfo *findByAlias(const std::string &alias) { return nullptr; } -void registerAlias(const std::string &alias, wibo::ModuleInfo *info) { +void registerAlias(ModuleRegistry ®, const std::string &alias, wibo::ModuleInfo *info) { if (alias.empty() || !info) { return; } - auto ® = registry(); auto it = reg.modulesByAlias.find(alias); if (it == reg.modulesByAlias.end()) { reg.modulesByAlias[alias] = info; @@ -368,7 +356,7 @@ void registerAlias(const std::string &alias, wibo::ModuleInfo *info) { } } -void registerBuiltinModule(const wibo::Module *module) { +void registerBuiltinModule(ModuleRegistry ®, const wibo::Module *module) { if (!module) { return; } @@ -380,7 +368,6 @@ void registerBuiltinModule(const wibo::Module *module) { entry->exportsInitialized = true; auto storageKey = storageKeyForBuiltin(entry->normalizedName); auto raw = entry.get(); - auto ® = registry(); reg.modulesByKey[storageKey] = std::move(entry); reg.builtinAliasLists[module] = {}; @@ -395,7 +382,7 @@ void registerBuiltinModule(const wibo::Module *module) { if (pinModule) { reg.pinnedAliases.insert(alias); } - registerAlias(alias, raw); + registerAlias(reg, alias, raw); reg.builtinAliasMap[alias] = raw; ParsedModuleName parsed = parseModuleName(module->names[i]); std::string baseAlias = normalizedBaseKey(parsed); @@ -404,7 +391,7 @@ void registerBuiltinModule(const wibo::Module *module) { if (pinModule) { reg.pinnedAliases.insert(baseAlias); } - registerAlias(baseAlias, raw); + registerAlias(reg, baseAlias, raw); reg.builtinAliasMap[baseAlias] = raw; } } @@ -447,35 +434,17 @@ void callDllMain(wibo::ModuleInfo &info, DWORD reason) { } } -void ensureInitialized() { - auto ® = registry(); - if (reg.initialized) { - return; - } - reg.initialized = true; - - const wibo::Module *builtins[] = { - &lib_advapi32, &lib_bcrypt, &lib_crt, &lib_kernel32, &lib_lmgr, &lib_mscoree, &lib_msvcrt, - &lib_ntdll, &lib_ole32, &lib_rpcrt4, &lib_user32, &lib_vcruntime, &lib_version, nullptr, - }; - - for (const wibo::Module **module = builtins; *module; ++module) { - registerBuiltinModule(*module); - } -} - -void registerExternalModuleAliases(const std::string &requestedName, const std::filesystem::path &resolvedPath, - wibo::ModuleInfo *info) { +void registerExternalModuleAliases(ModuleRegistry ®, const std::string &requestedName, + const std::filesystem::path &resolvedPath, wibo::ModuleInfo *info) { ParsedModuleName parsed = parseModuleName(requestedName); - registerAlias(normalizedBaseKey(parsed), info); - registerAlias(normalizeAlias(requestedName), info); - registerAlias(storageKeyForPath(resolvedPath), info); + registerAlias(reg, normalizedBaseKey(parsed), info); + registerAlias(reg, normalizeAlias(requestedName), info); + registerAlias(reg, storageKeyForPath(resolvedPath), info); } -wibo::ModuleInfo *moduleFromAddress(void *addr) { +wibo::ModuleInfo *moduleFromAddress(ModuleRegistry ®, void *addr) { if (!addr) return nullptr; - auto ® = registry(); for (auto &pair : reg.modulesByKey) { wibo::ModuleInfo *info = pair.second.get(); if (!info) @@ -491,7 +460,7 @@ wibo::ModuleInfo *moduleFromAddress(void *addr) { } if (!base || size == 0) continue; - uint8_t *ptr = static_cast(addr); + auto *ptr = static_cast(addr); if (ptr >= base && ptr < base + size) { return info; } @@ -523,7 +492,8 @@ void ensureExportsInitialized(wibo::ModuleInfo &info) { } if (rva >= exe->exportDirectoryRVA && rva < exe->exportDirectoryRVA + exe->exportDirectorySize) { const char *forward = exe->fromRVA(rva); - info.exportsByOrdinal[i] = resolveMissingFuncName(info.originalName.c_str(), forward); + info.exportsByOrdinal[i] = + reinterpret_cast(resolveMissingFuncName(info.originalName.c_str(), forward)); } else { info.exportsByOrdinal[i] = exe->fromRVA(rva); } @@ -536,7 +506,7 @@ void ensureExportsInitialized(wibo::ModuleInfo &info) { auto *ordinals = exe->fromRVA(dir->addressOfNameOrdinals); for (uint32_t i = 0; i < nameCount; ++i) { uint16_t index = ordinals[i]; - uint16_t ordinal = static_cast(dir->base + index); + auto ordinal = static_cast(dir->base + index); if (index < info.exportsByOrdinal.size()) { const char *namePtr = exe->fromRVA(names[i]); info.exportNameToOrdinal[std::string(namePtr)] = ordinal; @@ -550,14 +520,11 @@ void ensureExportsInitialized(wibo::ModuleInfo &info) { namespace wibo { -void initializeModuleRegistry() { - std::lock_guard lock(registry().mutex); - ensureInitialized(); -} +void initializeModuleRegistry() { registry(); } void shutdownModuleRegistry() { - std::lock_guard lock(registry().mutex); - for (auto &pair : registry().modulesByKey) { + auto reg = registry(); + for (auto &pair : reg->modulesByKey) { ModuleInfo *info = pair.second.get(); if (!info || info->module) { continue; @@ -567,40 +534,38 @@ void shutdownModuleRegistry() { callDllMain(*info, DLL_PROCESS_DETACH); } } - registry().modulesByKey.clear(); - registry().modulesByAlias.clear(); - registry().dllDirectory.reset(); - registry().initialized = false; - registry().onExitTables.clear(); + reg->modulesByKey.clear(); + reg->modulesByAlias.clear(); + reg->dllDirectory.reset(); + reg->initialized = false; + reg->onExitTables.clear(); } ModuleInfo *moduleInfoFromHandle(HMODULE module) { return static_cast(module); } void setDllDirectoryOverride(const std::filesystem::path &path) { auto canonical = files::canonicalPath(path); - std::lock_guard lock(registry().mutex); - registry().dllDirectory = canonical; + auto reg = registry(); + reg->dllDirectory = canonical; } void clearDllDirectoryOverride() { - std::lock_guard lock(registry().mutex); - registry().dllDirectory.reset(); + auto reg = registry(); + reg->dllDirectory.reset(); } std::optional dllDirectoryOverride() { - std::lock_guard lock(registry().mutex); - return registry().dllDirectory; + auto reg = registry(); + return reg->dllDirectory; } void registerOnExitTable(void *table) { if (!table) return; - std::lock_guard lock(registry().mutex); - ensureInitialized(); - auto ® = registry(); - if (reg.onExitTables.find(table) == reg.onExitTables.end()) { - if (auto *info = moduleFromAddress(table)) { - reg.onExitTables[table] = info; + auto reg = registry(); + if (reg->onExitTables.find(table) == reg->onExitTables.end()) { + if (auto *info = moduleFromAddress(*reg, table)) { + reg->onExitTables[table] = info; } } } @@ -608,16 +573,15 @@ void registerOnExitTable(void *table) { void addOnExitFunction(void *table, void (*func)()) { if (!func) return; - std::lock_guard lock(registry().mutex); - auto ® = registry(); + auto reg = registry(); ModuleInfo *info = nullptr; - auto it = reg.onExitTables.find(table); - if (it != reg.onExitTables.end()) { + auto it = reg->onExitTables.find(table); + if (it != reg->onExitTables.end()) { info = it->second; } else if (table) { - info = moduleFromAddress(table); + info = moduleFromAddress(*reg, table); if (info) - reg.onExitTables[table] = info; + reg->onExitTables[table] = info; } if (info) { info->onExitFunctions.push_back(reinterpret_cast(func)); @@ -635,16 +599,15 @@ void runPendingOnExit(ModuleInfo &info) { } void executeOnExitTable(void *table) { - std::lock_guard lock(registry().mutex); - auto ® = registry(); + auto reg = registry(); ModuleInfo *info = nullptr; if (table) { - auto it = reg.onExitTables.find(table); - if (it != reg.onExitTables.end()) { + auto it = reg->onExitTables.find(table); + if (it != reg->onExitTables.end()) { info = it->second; - reg.onExitTables.erase(it); + reg->onExitTables.erase(it); } else { - info = moduleFromAddress(table); + info = moduleFromAddress(*reg, table); } } if (info) { @@ -656,13 +619,12 @@ HMODULE findLoadedModule(const char *name) { if (!name) { return nullptr; } - std::lock_guard lock(registry().mutex); - ensureInitialized(); + auto reg = registry(); ParsedModuleName parsed = parseModuleName(name); std::string alias = normalizedBaseKey(parsed); - ModuleInfo *info = findByAlias(alias); + ModuleInfo *info = findByAlias(*reg, alias); if (!info) { - info = findByAlias(normalizeAlias(name)); + info = findByAlias(*reg, normalizeAlias(name)); } return info; } @@ -675,23 +637,21 @@ HMODULE loadModule(const char *dllName) { std::string requested(dllName); DEBUG_LOG("loadModule(%s)\n", requested.c_str()); - std::lock_guard lock(registry().mutex); - ensureInitialized(); + auto reg = registry(); ParsedModuleName parsed = parseModuleName(requested); - auto ® = registry(); DWORD diskError = ERROR_SUCCESS; auto tryLoadExternal = [&](const std::filesystem::path &path) -> ModuleInfo * { std::string key = storageKeyForPath(path); - auto existingIt = reg.modulesByKey.find(key); - if (existingIt != reg.modulesByKey.end()) { + auto existingIt = reg->modulesByKey.find(key); + if (existingIt != reg->modulesByKey.end()) { ModuleInfo *info = existingIt->second.get(); if (info->refCount != UINT_MAX) { info->refCount++; } - registerExternalModuleAliases(requested, files::canonicalPath(path), info); + registerExternalModuleAliases(*reg, requested, files::canonicalPath(path), info); return info; } @@ -725,15 +685,15 @@ HMODULE loadModule(const char *dllName) { info->dontResolveReferences = false; ModuleInfo *raw = info.get(); - reg.modulesByKey[key] = std::move(info); - registerExternalModuleAliases(requested, raw->resolvedPath, raw); + reg->modulesByKey[key] = std::move(info); + registerExternalModuleAliases(*reg, requested, raw->resolvedPath, raw); ensureExportsInitialized(*raw); callDllMain(*raw, DLL_PROCESS_ATTACH); return raw; }; auto resolveAndLoadExternal = [&]() -> ModuleInfo * { - auto resolvedPath = resolveModuleOnDisk(requested, false); + auto resolvedPath = resolveModuleOnDisk(*reg, requested, false); if (!resolvedPath) { DEBUG_LOG(" module not found on disk\n"); return nullptr; @@ -742,9 +702,9 @@ HMODULE loadModule(const char *dllName) { }; std::string alias = normalizedBaseKey(parsed); - ModuleInfo *existing = findByAlias(alias); + ModuleInfo *existing = findByAlias(*reg, alias); if (!existing) { - existing = findByAlias(normalizeAlias(requested)); + existing = findByAlias(*reg, normalizeAlias(requested)); } if (existing) { DEBUG_LOG(" found existing module alias %s (builtin=%d)\n", alias.c_str(), existing->module != nullptr); @@ -756,7 +716,7 @@ HMODULE loadModule(const char *dllName) { lastError = ERROR_SUCCESS; return existing; } - bool pinned = reg.pinnedModules.count(existing) != 0; + bool pinned = reg->pinnedModules.count(existing) != 0; if (!pinned) { if (ModuleInfo *external = resolveAndLoadExternal()) { DEBUG_LOG(" replaced builtin module %s with external copy\n", requested.c_str()); @@ -777,13 +737,13 @@ HMODULE loadModule(const char *dllName) { auto fallbackAlias = normalizedBaseKey(parsed); ModuleInfo *builtin = nullptr; - auto builtinIt = reg.builtinAliasMap.find(fallbackAlias); - if (builtinIt != reg.builtinAliasMap.end()) { + auto builtinIt = reg->builtinAliasMap.find(fallbackAlias); + if (builtinIt != reg->builtinAliasMap.end()) { builtin = builtinIt->second; } if (!builtin) { - builtinIt = reg.builtinAliasMap.find(normalizeAlias(requested)); - if (builtinIt != reg.builtinAliasMap.end()) { + builtinIt = reg->builtinAliasMap.find(normalizeAlias(requested)); + if (builtinIt != reg->builtinAliasMap.end()) { builtin = builtinIt->second; } } @@ -801,7 +761,7 @@ void freeModule(HMODULE module) { if (!module) { return; } - std::lock_guard lock(registry().mutex); + auto reg = registry(); ModuleInfo *info = moduleInfoFromHandle(module); if (!info || info->refCount == UINT_MAX) { return; @@ -811,10 +771,9 @@ void freeModule(HMODULE module) { } info->refCount--; if (info->refCount == 0) { - auto ® = registry(); - for (auto it = reg.onExitTables.begin(); it != reg.onExitTables.end();) { + for (auto it = reg->onExitTables.begin(); it != reg->onExitTables.end();) { if (it->second == info) { - it = reg.onExitTables.erase(it); + it = reg->onExitTables.erase(it); } else { ++it; } @@ -823,10 +782,10 @@ void freeModule(HMODULE module) { callDllMain(*info, DLL_PROCESS_DETACH); std::string key = info->resolvedPath.empty() ? storageKeyForBuiltin(info->normalizedName) : storageKeyForPath(info->resolvedPath); - reg.modulesByKey.erase(key); - for (auto it = reg.modulesByAlias.begin(); it != reg.modulesByAlias.end();) { + reg->modulesByKey.erase(key); + for (auto it = reg->modulesByAlias.begin(); it != reg->modulesByAlias.end();) { if (it->second == info) { - it = reg.modulesByAlias.erase(it); + it = reg->modulesByAlias.erase(it); } else { ++it; } @@ -852,7 +811,7 @@ void *resolveFuncByName(HMODULE module, const char *funcName) { return resolveFuncByOrdinal(module, it->second); } } - return resolveMissingFuncName(info->originalName.c_str(), funcName); + return reinterpret_cast(resolveMissingFuncName(info->originalName.c_str(), funcName)); } void *resolveFuncByOrdinal(HMODULE module, uint16_t ordinal) { @@ -869,7 +828,7 @@ void *resolveFuncByOrdinal(HMODULE module, uint16_t ordinal) { if (!info->module) { ensureExportsInitialized(*info); if (!info->exportsByOrdinal.empty() && ordinal >= info->exportOrdinalBase) { - size_t index = static_cast(ordinal - info->exportOrdinalBase); + auto index = static_cast(ordinal - info->exportOrdinalBase); if (index < info->exportsByOrdinal.size()) { void *addr = info->exportsByOrdinal[index]; if (addr) { @@ -878,22 +837,20 @@ void *resolveFuncByOrdinal(HMODULE module, uint16_t ordinal) { } } } - return resolveMissingFuncOrdinal(info->originalName.c_str(), ordinal); + return reinterpret_cast(resolveMissingFuncOrdinal(info->originalName.c_str(), ordinal)); } void *resolveMissingImportByName(const char *dllName, const char *funcName) { const char *safeDll = dllName ? dllName : ""; const char *safeFunc = funcName ? funcName : ""; - std::lock_guard lock(registry().mutex); - ensureInitialized(); - return resolveMissingFuncName(safeDll, safeFunc); + [[maybe_unused]] auto reg = registry(); + return reinterpret_cast(resolveMissingFuncName(safeDll, safeFunc)); } void *resolveMissingImportByOrdinal(const char *dllName, uint16_t ordinal) { const char *safeDll = dllName ? dllName : ""; - std::lock_guard lock(registry().mutex); - ensureInitialized(); - return resolveMissingFuncOrdinal(safeDll, ordinal); + [[maybe_unused]] auto reg = registry(); + return reinterpret_cast(resolveMissingFuncOrdinal(safeDll, ordinal)); } Executable *executableFromModule(HMODULE module) { diff --git a/strutil.cpp b/strutil.cpp index 1c0708b..b6c8e9d 100644 --- a/strutil.cpp +++ b/strutil.cpp @@ -1,10 +1,47 @@ #include "strutil.h" #include "common.h" +#include +#include +#include #include #include #include #include +void toLowerInPlace(std::string &str) { + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return static_cast(std::tolower(c)); }); +} + +void toUpperInPlace(std::string &str) { + std::transform(str.begin(), str.end(), str.begin(), + [](unsigned char c) { return static_cast(std::toupper(c)); }); +} + +std::string stringToLower(std::string_view str) { + std::string result(str); + toLowerInPlace(result); + return result; +} + +std::string stringToUpper(std::string_view str) { + std::string result(str); + toUpperInPlace(result); + return result; +} + +uint16_t wcharToLower(uint16_t ch) { + if (ch >= 'A' && ch <= 'Z') { + return static_cast(ch + ('a' - 'A')); + } + wchar_t wide = static_cast(ch); + wchar_t lowered = std::towlower(wide); + if (lowered < 0 || lowered > 0xFFFF) { + return ch; + } + return static_cast(lowered); +} + size_t wstrlen(const uint16_t *str) { if (!str) return 0; diff --git a/strutil.h b/strutil.h index 897ae4a..69289f4 100644 --- a/strutil.h +++ b/strutil.h @@ -2,6 +2,7 @@ #include #include +#include #include size_t wstrlen(const uint16_t *str); @@ -18,3 +19,8 @@ std::string wideStringToString(const uint16_t *src, int len = -1); std::vector stringToWideString(const char *src); long wstrtol(const uint16_t *string, uint16_t **end_ptr, int base); unsigned long wstrtoul(const uint16_t *string, uint16_t **end_ptr, int base); +void toLowerInPlace(std::string &str); +void toUpperInPlace(std::string &str); +std::string stringToLower(std::string_view str); +std::string stringToUpper(std::string_view str); +uint16_t wcharToLower(uint16_t ch);