diff --git a/CMakeLists.txt b/CMakeLists.txt index b1d6d10..765473e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -175,6 +175,7 @@ add_executable(wibo src/processes.cpp src/resources.cpp src/strutil.cpp + src/tls.cpp ) target_compile_definitions(wibo PRIVATE _GNU_SOURCE _FILE_OFFSET_BITS=64 _TIME_BITS=64) target_compile_features(wibo PRIVATE cxx_std_20) @@ -261,6 +262,8 @@ if (WIBO_ENABLE_FIXTURE_TESTS) wibo_add_fixture_bin(NAME test_bcrypt SOURCES test/test_bcrypt.c COMPILE_OPTIONS -lbcrypt) wibo_add_fixture_bin(NAME test_resources SOURCES test/test_resources.c ${WIBO_TEST_BIN_DIR}/test_resources_res.o COMPILE_OPTIONS -lversion) wibo_add_fixture_bin(NAME test_threading SOURCES test/test_threading.c) + wibo_add_fixture_bin(NAME test_tls SOURCES test/test_tls.c) + wibo_add_fixture_bin(NAME test_tls_reloc SOURCES test/test_tls_reloc.c) wibo_add_fixture_bin(NAME test_handleapi SOURCES test/test_handleapi.c) wibo_add_fixture_bin(NAME test_findfile SOURCES test/test_findfile.c) wibo_add_fixture_bin(NAME test_synchapi SOURCES test/test_synchapi.c) @@ -283,6 +286,7 @@ if (WIBO_ENABLE_FIXTURE_TESTS) wibo_add_fixture_dll(NAME external_exports SOURCES test/external_exports.c) wibo_add_fixture_dll(NAME dll_attach_failure SOURCES test/dll_attach_failure.c) wibo_add_fixture_dll(NAME thread_notifications SOURCES test/thread_notifications.c) + wibo_add_fixture_dll(NAME tls_reloc SOURCES test/tls_reloc_dll.c) # Resources for test_resources add_custom_command( diff --git a/dll/kernel32/processthreadsapi.cpp b/dll/kernel32/processthreadsapi.cpp index f97d13c..8073f1e 100644 --- a/dll/kernel32/processthreadsapi.cpp +++ b/dll/kernel32/processthreadsapi.cpp @@ -10,8 +10,8 @@ #include "processes.h" #include "strutil.h" #include "timeutil.h" +#include "tls.h" -#include #include #include #include @@ -33,10 +33,6 @@ namespace { using kernel32::ThreadObject; -constexpr DWORD kMaxTlsValues = 100; -bool g_tlsSlotsUsed[kMaxTlsValues] = {false}; -LPVOID g_tlsSlots[kMaxTlsValues] = {nullptr}; - DWORD_PTR g_processAffinityMask = 0; bool g_processAffinityMaskInitialized = false; @@ -386,47 +382,50 @@ BOOL WIN_FUNC GetExitCodeProcess(HANDLE hProcess, LPDWORD lpExitCode) { DWORD WIN_FUNC TlsAlloc() { HOST_CONTEXT_GUARD(); VERBOSE_LOG("TlsAlloc()\n"); - for (DWORD i = 0; i < kMaxTlsValues; ++i) { - if (!g_tlsSlotsUsed[i]) { - g_tlsSlotsUsed[i] = true; - g_tlsSlots[i] = nullptr; - return i; - } + DWORD index = wibo::tls::reserveSlot(); + if (index == wibo::tls::kInvalidTlsIndex) { + wibo::lastError = ERROR_NOT_ENOUGH_MEMORY; + return TLS_OUT_OF_INDEXES; } - wibo::lastError = ERROR_NOT_ENOUGH_MEMORY; - return TLS_OUT_OF_INDEXES; + wibo::tls::setValue(index, nullptr); + wibo::lastError = ERROR_SUCCESS; + return index; } BOOL WIN_FUNC TlsFree(DWORD dwTlsIndex) { HOST_CONTEXT_GUARD(); VERBOSE_LOG("TlsFree(%u)\n", dwTlsIndex); - if (dwTlsIndex >= kMaxTlsValues || !g_tlsSlotsUsed[dwTlsIndex]) { + if (!wibo::tls::releaseSlot(dwTlsIndex)) { wibo::lastError = ERROR_INVALID_PARAMETER; return FALSE; } - g_tlsSlotsUsed[dwTlsIndex] = false; - g_tlsSlots[dwTlsIndex] = nullptr; + wibo::tls::setValue(dwTlsIndex, nullptr); + wibo::lastError = ERROR_SUCCESS; return TRUE; } LPVOID WIN_FUNC TlsGetValue(DWORD dwTlsIndex) { HOST_CONTEXT_GUARD(); VERBOSE_LOG("TlsGetValue(%u)\n", dwTlsIndex); - if (dwTlsIndex >= kMaxTlsValues || !g_tlsSlotsUsed[dwTlsIndex]) { + if (!wibo::tls::isSlotAllocated(dwTlsIndex)) { wibo::lastError = ERROR_INVALID_PARAMETER; return nullptr; } - return g_tlsSlots[dwTlsIndex]; + return wibo::tls::getValue(dwTlsIndex); } BOOL WIN_FUNC TlsSetValue(DWORD dwTlsIndex, LPVOID lpTlsValue) { HOST_CONTEXT_GUARD(); VERBOSE_LOG("TlsSetValue(%u, %p)\n", dwTlsIndex, lpTlsValue); - if (dwTlsIndex >= kMaxTlsValues || !g_tlsSlotsUsed[dwTlsIndex]) { + if (!wibo::tls::isSlotAllocated(dwTlsIndex)) { wibo::lastError = ERROR_INVALID_PARAMETER; return FALSE; } - g_tlsSlots[dwTlsIndex] = lpTlsValue; + if (!wibo::tls::setValue(dwTlsIndex, lpTlsValue)) { + wibo::lastError = ERROR_INVALID_PARAMETER; + return FALSE; + } + wibo::lastError = ERROR_SUCCESS; return TRUE; } diff --git a/src/common.h b/src/common.h index e3075fd..479d95f 100644 --- a/src/common.h +++ b/src/common.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -220,6 +221,8 @@ constexpr DWORD PIPE_ACCEPT_REMOTE_CLIENTS = 0x00000000; constexpr DWORD PIPE_REJECT_REMOTE_CLIENTS = 0x00000008; constexpr DWORD PIPE_UNLIMITED_INSTANCES = 255; +constexpr size_t kTlsSlotCount = 64; + struct UNICODE_STRING { unsigned short Length; unsigned short MaximumLength; @@ -248,23 +251,150 @@ struct PEB { unsigned int SessionId; }; +struct ClientId { + void *uniqueProcess; + void *uniqueThread; +}; + +struct ListEntry { + void *flink; + void *blink; +}; + +struct ActivationContextStack { + void *activeFrame; + ListEntry frameListCache; + uint32_t flags; + uint32_t nextCookieSequenceNumber; + uint32_t stackId; +}; + +struct GdiTebBatch { + uint32_t offset; + uint32_t hdc; + uint32_t buffer[310]; +}; + struct TIB { - void *sehFrame; + void *exceptionList; void *stackBase; void *stackLimit; void *subSystemTib; - void *fiberData; - void *arbitraryDataSlot; - TIB *tib; - char reserved1[0x14]; + union { + void *fiberData; + uint32_t version; + } fiber; + void *arbitraryUserPointer; + TIB *self; + void *environmentPointer; + ClientId clientId; + void *activeRpcHandle; + void *threadLocalStoragePointer; PEB *peb; - char reserved2[0x1000]; + uint32_t lastErrorValue; + uint32_t countOfOwnedCriticalSections; + void *csrClientThread; + void *win32ThreadInfo; + uint32_t user32Reserved[26]; + uint32_t userReserved[5]; + void *wow32Reserved; + uint32_t currentLocale; + uint32_t fpSoftwareStatusRegister; + void *reservedForDebuggerInstrumentation[16]; + void *systemReserved1[26]; + uint8_t placeholderCompatibilityMode; + uint8_t placeholderHydrationAlwaysExplicit; + uint8_t placeholderReserved[10]; + uint32_t proxiedProcessId; + ActivationContextStack activationContextStack; + uint8_t workingOnBehalfOfTicket[8]; + int32_t exceptionCode; + ActivationContextStack *activationContextStackPointer; + uintptr_t instrumentationCallbackSp; + uintptr_t instrumentationCallbackPreviousPc; + uintptr_t instrumentationCallbackPreviousSp; + uint8_t instrumentationCallbackDisabled; + uint8_t spareBytes1[23]; + uint32_t txFsContext; + GdiTebBatch gdiTebBatch; + ClientId realClientId; + void *gdiCachedProcessHandle; + uint32_t gdiClientPID; + uint32_t gdiClientTID; + void *gdiThreadLocaleInfo; + uintptr_t win32ClientInfo[62]; + void *glDispatchTable[233]; + void *glReserved1[29]; + void *glReserved2; + void *glSectionInfo; + void *glSection; + void *glTable; + void *glCurrentRC; + void *glContext; + uint32_t lastStatusValue; + UNICODE_STRING staticUnicodeString; + WCHAR staticUnicodeBuffer[261]; + void *deallocationStack; + void *tlsSlots[kTlsSlotCount]; + ListEntry tlsLinks; + void *vdm; + void *reservedForNtRpc; + void *dbgSsReserved[2]; + uint32_t hardErrorMode; + void *instrumentation[9]; + GUID activityId; + void *subProcessTag; + void *perflibData; + void *etwTraceData; + void *winSockData; + uint32_t gdiBatchCount; + uint32_t idealProcessorValue; + uint32_t guaranteedStackBytes; + void *reservedForPerf; + void *reservedForOle; + uint32_t waitingOnLoaderLock; + void *savedPriorityState; + uintptr_t reservedForCodeCoverage; + void *threadPoolData; + void **tlsExpansionSlots; + uint32_t muiGeneration; + uint32_t isImpersonating; + void *nlsCache; + void *shimData; + uint32_t heapVirtualAffinity; + void *currentTransactionHandle; + void *activeFrame; + void *flsSlots; + void *preferredLanguages; + void *userPrefLanguages; + void *mergedPrefLanguages; + uint32_t muiImpersonation; + uint16_t crossTebFlags; + uint16_t sameTebFlags; + void *txnScopeEnterCallback; + void *txnScopeExitCallback; + void *txnScopeContext; + uint32_t lockCount; + long wowTebOffset; + void *resourceRetValue; + void *reservedForWdf; + uint64_t reservedForCrt; + GUID effectiveContainerId; uint16_t hostFsSelector; uint16_t hostGsSelector; uint8_t hostSegmentsValid; uint8_t padding[3]; }; +static_assert(offsetof(TIB, self) == 0x18, "Self pointer offset mismatch"); +static_assert(offsetof(TIB, threadLocalStoragePointer) == 0x2C, "TLS pointer offset mismatch"); +static_assert(offsetof(TIB, peb) == 0x30, "PEB pointer offset mismatch"); +static_assert(offsetof(TIB, lastErrorValue) == 0x34, "LastErrorValue offset mismatch"); +static_assert(offsetof(TIB, gdiTebBatch) == 0x1D4, "GdiTebBatch offset mismatch"); +static_assert(offsetof(TIB, deallocationStack) == 0xE0C, "DeallocationStack offset mismatch"); +static_assert(offsetof(TIB, tlsSlots) == 0xE10, "TLS slots offset mismatch"); +static_assert(sizeof(TIB) >= 0x1000, "TIB too small"); + namespace wibo { extern thread_local uint32_t lastError; diff --git a/src/loader.cpp b/src/loader.cpp index c1207fe..4697be4 100644 --- a/src/loader.cpp +++ b/src/loader.cpp @@ -230,6 +230,8 @@ bool wibo::Executable::loadPE(FILE *file, bool exec) { importDirectorySize = header32.importTable.size; delayImportDirectoryRVA = header32.delayImportDescriptor.virtualAddress; delayImportDirectorySize = header32.delayImportDescriptor.size; + tlsDirectoryRVA = header32.tlsTable.virtualAddress; + tlsDirectorySize = header32.tlsTable.size; execMapped = exec; importsResolved = false; importsResolving = false; diff --git a/src/main.cpp b/src/main.cpp index 10c91e2..8f6d07a 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,4 +1,3 @@ -#include "async_io.h" #include "common.h" #include "context.h" #include "files.h" @@ -6,6 +5,7 @@ #include "processes.h" #include "strutil.h" #include "version_info.h" +#include "tls.h" #include #include @@ -57,7 +57,8 @@ TIB *wibo::allocateTib() { if (!newTib) { return nullptr; } - newTib->tib = newTib; + tls::initializeTib(newTib); + newTib->self = newTib; newTib->peb = processPeb; return newTib; } @@ -451,7 +452,8 @@ int main(int argc, char **argv) { // Create TIB memset(&tib, 0, sizeof(tib)); - tib.tib = &tib; + wibo::tls::initializeTib(&tib); + tib.self = &tib; tib.peb = static_cast(calloc(1, sizeof(PEB))); tib.peb->ProcessParameters = static_cast(calloc(1, sizeof(RTL_USER_PROCESS_PARAMETERS))); @@ -596,6 +598,10 @@ int main(int argc, char **argv) { fprintf(stderr, "Failed to resolve imports for main module (DLL initialization failure?)\n"); abort(); } + if (!wibo::initializeModuleTls(*wibo::mainModule)) { + fprintf(stderr, "Failed to initialize TLS for main module\n"); + return 1; + } // Reset last error wibo::lastError = 0; diff --git a/src/modules.cpp b/src/modules.cpp index 1f1d2f6..8d19d8c 100644 --- a/src/modules.cpp +++ b/src/modules.cpp @@ -5,6 +5,7 @@ #include "errors.h" #include "files.h" #include "strutil.h" +#include "tls.h" #include #include @@ -40,6 +41,10 @@ constexpr DWORD DLL_PROCESS_DETACH = 0; constexpr DWORD DLL_PROCESS_ATTACH = 1; constexpr DWORD DLL_THREAD_ATTACH = 2; constexpr DWORD DLL_THREAD_DETACH = 3; +constexpr DWORD TLS_PROCESS_ATTACH = DLL_PROCESS_ATTACH; +constexpr DWORD TLS_PROCESS_DETACH = DLL_PROCESS_DETACH; +constexpr DWORD TLS_THREAD_ATTACH = DLL_THREAD_ATTACH; +constexpr DWORD TLS_THREAD_DETACH = DLL_THREAD_DETACH; struct PEExportDirectory { uint32_t characteristics; @@ -224,6 +229,103 @@ std::string normalizedBaseKey(const ParsedModuleName &parsed) { return normalizeAlias(base); } +struct ImageTlsDirectory32 { + uint32_t StartAddressOfRawData; + uint32_t EndAddressOfRawData; + uint32_t AddressOfIndex; + uint32_t AddressOfCallBacks; + uint32_t SizeOfZeroFill; + uint32_t Characteristics; +}; + +uintptr_t resolveModuleAddress(const wibo::Executable &exec, uintptr_t address) { + if (address == 0) { + return 0; + } + const uintptr_t actualBase = reinterpret_cast(exec.imageBase); + if (address >= actualBase) { + uintptr_t offset = address - actualBase; + if (offset < exec.imageSize) { + return address; + } + } + const uintptr_t preferredBase = static_cast(exec.preferredImageBase); + if (address >= preferredBase) { + return actualBase + (address - preferredBase); + } + return static_cast(static_cast(address) + exec.relocationDelta); +} + +void allocateModuleTlsForThread(wibo::ModuleInfo &module, TIB *tib) { + if (!tib) { + return; + } + auto &info = module.tlsInfo; + if (!info.hasTls || info.index == wibo::tls::kInvalidTlsIndex || info.index >= kTlsSlotCount) { + return; + } + if (info.threadAllocations.find(tib) != info.threadAllocations.end()) { + return; + } + void *block = nullptr; + const size_t allocationSize = info.allocationSize; + if (allocationSize > 0) { + block = std::malloc(allocationSize); + if (!block) { + DEBUG_LOG(" allocateModuleTlsForThread: failed to allocate %zu bytes for %s\n", allocationSize, + module.originalName.c_str()); + return; + } + std::memset(block, 0, allocationSize); + if (info.templateData && info.templateSize > 0) { + std::memcpy(block, info.templateData, info.templateSize); + } + } + info.threadAllocations.emplace(tib, block); + wibo::tls::setValue(tib, info.index, block); +} + +void freeModuleTlsForThread(wibo::ModuleInfo &module, TIB *tib) { + if (!tib) { + return; + } + auto &info = module.tlsInfo; + if (!info.hasTls) { + return; + } + auto it = info.threadAllocations.find(tib); + if (it == info.threadAllocations.end()) { + return; + } + void *block = it->second; + info.threadAllocations.erase(it); + if (info.index < kTlsSlotCount && wibo::tls::getValue(tib, info.index) == block) { + wibo::tls::setValue(tib, info.index, nullptr); + } + if (block) { + std::free(block); + } +} + +void runModuleTlsCallbacks(wibo::ModuleInfo &module, DWORD reason) { + if (!module.tlsInfo.hasTls || module.tlsInfo.callbacks.empty()) { + return; + } + TIB *tib = wibo::getThreadTibForHost(); + if (!tib) { + return; + } + GUEST_CONTEXT_GUARD(tib); + using TlsCallback = void(WIN_FUNC *)(void *, DWORD, void *); + for (void *callbackAddr : module.tlsInfo.callbacks) { + if (!callbackAddr) { + continue; + } + auto callback = reinterpret_cast(callbackAddr); + callback(module.handle, reason, nullptr); + } +} + std::optional combineAndFind(const std::filesystem::path &directory, const std::string &filename) { if (filename.empty()) { @@ -663,9 +765,13 @@ void shutdownModuleRegistry() { continue; } runPendingOnExit(*info); + if (info->tlsInfo.hasTls) { + runModuleTlsCallbacks(*info, TLS_PROCESS_DETACH); + } if (info->processAttachCalled && info->processAttachSucceeded) { callDllMain(*info, DLL_PROCESS_DETACH, reinterpret_cast(1)); } + releaseModuleTls(*info); } reg->modulesByKey.clear(); reg->modulesByAlias.clear(); @@ -762,6 +868,79 @@ void runPendingOnExit(ModuleInfo &info) { info.onExitFunctions.clear(); } +bool initializeModuleTls(ModuleInfo &module) { + if (module.tlsInfo.hasTls) { + return true; + } + if (!module.executable) { + return true; + } + Executable &exec = *module.executable; + if (exec.tlsDirectoryRVA == 0 || exec.tlsDirectorySize < sizeof(ImageTlsDirectory32)) { + return true; + } + auto tlsDirectory = exec.fromRVA(exec.tlsDirectoryRVA); + if (!tlsDirectory) { + return false; + } + + auto &info = module.tlsInfo; + info.templateSize = (tlsDirectory->EndAddressOfRawData > tlsDirectory->StartAddressOfRawData) + ? tlsDirectory->EndAddressOfRawData - tlsDirectory->StartAddressOfRawData + : 0; + info.zeroFillSize = tlsDirectory->SizeOfZeroFill; + info.characteristics = tlsDirectory->Characteristics; + info.templateData = reinterpret_cast(resolveModuleAddress(exec, tlsDirectory->StartAddressOfRawData)); + info.indexLocation = reinterpret_cast(resolveModuleAddress(exec, tlsDirectory->AddressOfIndex)); + info.callbacks.clear(); + uintptr_t callbacksArray = resolveModuleAddress(exec, tlsDirectory->AddressOfCallBacks); + if (callbacksArray) { + auto callbackPtr = reinterpret_cast(callbacksArray); + while (callbackPtr && *callbackPtr) { + info.callbacks.push_back(reinterpret_cast(resolveModuleAddress(exec, *callbackPtr))); + ++callbackPtr; + } + } + info.allocationSize = info.templateSize + info.zeroFillSize; + DWORD index = tls::reserveSlot(); + if (index == tls::kInvalidTlsIndex) { + wibo::lastError = ERROR_NOT_ENOUGH_MEMORY; + return false; + } + info.index = index; + if (info.indexLocation) { + *info.indexLocation = index; + } + info.hasTls = true; + info.threadAllocations.clear(); + + if (TIB *tib = wibo::getThreadTibForHost()) { + allocateModuleTlsForThread(module, tib); + } + runModuleTlsCallbacks(module, TLS_PROCESS_ATTACH); + wibo::lastError = ERROR_SUCCESS; + return true; +} + +void releaseModuleTls(ModuleInfo &module) { + if (!module.tlsInfo.hasTls) { + return; + } + for (auto &[tib, block] : module.tlsInfo.threadAllocations) { + if (tib && module.tlsInfo.index < kTlsSlotCount && wibo::tls::getValue(tib, module.tlsInfo.index) == block) { + wibo::tls::setValue(tib, module.tlsInfo.index, nullptr); + } + if (block) { + std::free(block); + } + } + module.tlsInfo.threadAllocations.clear(); + if (module.tlsInfo.index != wibo::tls::kInvalidTlsIndex) { + wibo::tls::releaseSlot(module.tlsInfo.index); + } + module.tlsInfo = wibo::ModuleTlsInfo{}; +} + void executeOnExitTable(void *table) { auto reg = registry(); ModuleInfo *info = nullptr; @@ -789,6 +968,13 @@ void notifyDllThreadAttach() { targets.push_back(info); } } + TIB *tib = wibo::getThreadTibForHost(); + for (wibo::ModuleInfo *info : targets) { + if (info && info->tlsInfo.hasTls && tib) { + allocateModuleTlsForThread(*info, tib); + runModuleTlsCallbacks(*info, TLS_THREAD_ATTACH); + } + } for (wibo::ModuleInfo *info : targets) { callDllMain(*info, DLL_THREAD_ATTACH, nullptr); } @@ -805,9 +991,20 @@ void notifyDllThreadDetach() { targets.push_back(info); } } + TIB *tib = wibo::getThreadTibForHost(); + for (auto it = targets.rbegin(); it != targets.rend(); ++it) { + if (*it && (*it)->tlsInfo.hasTls && tib) { + runModuleTlsCallbacks(**it, TLS_THREAD_DETACH); + } + } for (auto it = targets.rbegin(); it != targets.rend(); ++it) { callDllMain(**it, DLL_THREAD_DETACH, nullptr); } + for (auto it = targets.rbegin(); it != targets.rend(); ++it) { + if (*it && (*it)->tlsInfo.hasTls && tib) { + freeModuleTlsForThread(**it, tib); + } + } wibo::lastError = ERROR_SUCCESS; } @@ -903,9 +1100,17 @@ ModuleInfo *loadModule(const char *dllName) { diskError = wibo::lastError; return nullptr; } + if (!initializeModuleTls(*raw)) { + DEBUG_LOG(" initializeModuleTls failed for %s\n", raw->originalName.c_str()); + reg.lock.lock(); + reg->modulesByKey.erase(key); + diskError = wibo::lastError; + return nullptr; + } reg.lock.lock(); if (!callDllMain(*raw, DLL_PROCESS_ATTACH, nullptr)) { DEBUG_LOG(" DllMain failed for %s\n", raw->originalName.c_str()); + releaseModuleTls(*raw); runPendingOnExit(*raw); for (auto it = reg->onExitTables.begin(); it != reg->onExitTables.end();) { if (it->second == raw) { @@ -1015,7 +1220,11 @@ void freeModule(ModuleInfo *info) { } } runPendingOnExit(*info); + if (info->tlsInfo.hasTls) { + runModuleTlsCallbacks(*info, TLS_PROCESS_DETACH); + } callDllMain(*info, DLL_PROCESS_DETACH, nullptr); + releaseModuleTls(*info); std::string key = info->resolvedPath.empty() ? storageKeyForBuiltin(info->normalizedName) : storageKeyForPath(info->resolvedPath); reg->modulesByKey.erase(key); diff --git a/src/modules.h b/src/modules.h index 8e24834..a06700f 100644 --- a/src/modules.h +++ b/src/modules.h @@ -1,6 +1,7 @@ #pragma once #include "common.h" +#include "tls.h" #include #include @@ -54,12 +55,27 @@ class Executable { uint32_t importDirectorySize = 0; uint32_t delayImportDirectoryRVA = 0; uint32_t delayImportDirectorySize = 0; + uint32_t tlsDirectoryRVA = 0; + uint32_t tlsDirectorySize = 0; bool execMapped = false; bool importsResolved = false; bool importsResolving = false; std::vector sections; }; +struct ModuleTlsInfo { + bool hasTls = false; + DWORD index = tls::kInvalidTlsIndex; + DWORD *indexLocation = nullptr; + uint8_t *templateData = nullptr; + size_t templateSize = 0; + size_t zeroFillSize = 0; + uint32_t characteristics = 0; + size_t allocationSize = 0; + std::vector callbacks; + std::unordered_map threadAllocations; +}; + struct ModuleInfo { // Windows-style handle to the module. For the main module, this is the image base. // For other modules, this is a pointer to the ModuleInfo structure. @@ -84,6 +100,7 @@ struct ModuleInfo { std::unordered_map exportNameToOrdinal; bool exportsInitialized = false; std::vector onExitFunctions; + ModuleTlsInfo tlsInfo; }; extern ModuleInfo *mainModule; @@ -104,6 +121,8 @@ void notifyDllThreadAttach(); void notifyDllThreadDetach(); BOOL disableThreadNotifications(ModuleInfo *info); std::unordered_map allLoadedModules(); +bool initializeModuleTls(ModuleInfo &module); +void releaseModuleTls(ModuleInfo &module); ModuleInfo *loadModule(const char *name); void freeModule(ModuleInfo *info); diff --git a/src/tls.cpp b/src/tls.cpp new file mode 100644 index 0000000..9df2de9 --- /dev/null +++ b/src/tls.cpp @@ -0,0 +1,74 @@ +#include "tls.h" + +#include +#include +#include + +namespace { + +std::array g_slotUsed{}; +std::mutex g_slotMutex; + +} // namespace + +namespace wibo::tls { + +void initializeTib(TIB *tib) { + if (!tib) { + return; + } + std::fill(std::begin(tib->tlsSlots), std::end(tib->tlsSlots), nullptr); + tib->tlsLinks.flink = nullptr; + tib->tlsLinks.blink = nullptr; + tib->tlsExpansionSlots = nullptr; + tib->flsSlots = nullptr; +} + +DWORD reserveSlot() { + std::lock_guard lock(g_slotMutex); + for (DWORD index = 0; index < static_cast(kTlsSlotCount); ++index) { + if (!g_slotUsed[index]) { + g_slotUsed[index] = true; + return index; + } + } + return kInvalidTlsIndex; +} + +bool releaseSlot(DWORD index) { + if (index >= static_cast(kTlsSlotCount)) { + return false; + } + std::lock_guard lock(g_slotMutex); + if (!g_slotUsed[index]) { + return false; + } + g_slotUsed[index] = false; + return true; +} + +bool isSlotAllocated(DWORD index) { + std::lock_guard lock(g_slotMutex); + return index < kTlsSlotCount && g_slotUsed[index]; +} + +void *getValue(TIB *tib, DWORD index) { + if (!tib || index >= static_cast(kTlsSlotCount)) { + return nullptr; + } + return tib->tlsSlots[index]; +} + +bool setValue(TIB *tib, DWORD index, void *value) { + if (!tib || index >= static_cast(kTlsSlotCount)) { + return false; + } + tib->tlsSlots[index] = value; + return true; +} + +void *getValue(DWORD index) { return getValue(getThreadTibForHost(), index); } + +bool setValue(DWORD index, void *value) { return setValue(getThreadTibForHost(), index, value); } + +} // namespace wibo::tls diff --git a/src/tls.h b/src/tls.h new file mode 100644 index 0000000..0203f2f --- /dev/null +++ b/src/tls.h @@ -0,0 +1,22 @@ +#pragma once + +#include "common.h" + +namespace wibo::tls { + +constexpr DWORD kInvalidTlsIndex = 0xFFFFFFFFu; + +void initializeTib(TIB *tib); + +DWORD reserveSlot(); +bool releaseSlot(DWORD index); +bool isSlotAllocated(DWORD index); + +void *getValue(TIB *tib, DWORD index); +bool setValue(TIB *tib, DWORD index, void *value); + +void *getValue(DWORD index); +bool setValue(DWORD index, void *value); + +} // namespace wibo::tls + diff --git a/test/test_tls.c b/test/test_tls.c new file mode 100644 index 0000000..2c2974d --- /dev/null +++ b/test/test_tls.c @@ -0,0 +1,96 @@ +#include "test_assert.h" +#include +#include +#include + +static void *current_teb(void) { + void *teb = NULL; + __asm__ __volatile__("movl %%fs:0x18, %0" : "=r"(teb)); + return teb; +} + +static void **tls_slots(void) { + uint8_t *teb = (uint8_t *)current_teb(); + return (void **)(teb + 0xE10); +} + +typedef struct { + DWORD tlsIndex; + int threadValue; + HANDLE readyEvent; + HANDLE continueEvent; +} ThreadCtx; + +static DWORD WINAPI tls_thread_proc(LPVOID param) { + ThreadCtx *ctx = (ThreadCtx *)param; + TEST_CHECK(ctx != NULL); + + /* TLS initially zero for a new thread */ + TEST_CHECK_EQ(NULL, TlsGetValue(ctx->tlsIndex)); + TEST_CHECK_EQ(NULL, tls_slots()[ctx->tlsIndex]); + + void *threadPtr = &ctx->threadValue; + TEST_CHECK(TlsSetValue(ctx->tlsIndex, threadPtr)); + TEST_CHECK_EQ(threadPtr, TlsGetValue(ctx->tlsIndex)); + TEST_CHECK_EQ(threadPtr, tls_slots()[ctx->tlsIndex]); + + TEST_CHECK(SetEvent(ctx->readyEvent)); + + TEST_CHECK_EQ(WAIT_OBJECT_0, WaitForSingleObject(ctx->continueEvent, 1000)); + + /* Clear before exit */ + TEST_CHECK(TlsSetValue(ctx->tlsIndex, NULL)); + return 0; +} + +int main(void) { + DWORD tlsIndex = TlsAlloc(); + TEST_CHECK(tlsIndex != TLS_OUT_OF_INDEXES); + + TEST_CHECK_EQ(NULL, TlsGetValue(tlsIndex)); + + void **tlsArray = tls_slots(); + TEST_CHECK(tlsArray != NULL); + + int mainValue = 12345; + void *mainPtr = &mainValue; + TEST_CHECK(TlsSetValue(tlsIndex, mainPtr)); + TEST_CHECK_EQ(mainPtr, TlsGetValue(tlsIndex)); + TEST_CHECK_EQ(mainPtr, tlsArray[tlsIndex]); + + ThreadCtx ctx; + ctx.tlsIndex = tlsIndex; + ctx.threadValue = 0x4242; + ctx.readyEvent = CreateEventA(NULL, FALSE, FALSE, NULL); + ctx.continueEvent = CreateEventA(NULL, FALSE, FALSE, NULL); + TEST_CHECK(ctx.readyEvent != NULL); + TEST_CHECK(ctx.continueEvent != NULL); + + HANDLE thread = CreateThread(NULL, 0, tls_thread_proc, &ctx, 0, NULL); + TEST_CHECK(thread != NULL); + + TEST_CHECK_EQ(WAIT_OBJECT_0, WaitForSingleObject(ctx.readyEvent, 1000)); + + /* Main thread value should be unchanged by worker */ + TEST_CHECK_EQ(mainPtr, TlsGetValue(tlsIndex)); + TEST_CHECK_EQ(mainPtr, tlsArray[tlsIndex]); + + TEST_CHECK(SetEvent(ctx.continueEvent)); + + TEST_CHECK_EQ(WAIT_OBJECT_0, WaitForSingleObject(thread, 1000)); + TEST_CHECK(CloseHandle(thread)); + + TEST_CHECK(CloseHandle(ctx.readyEvent)); + TEST_CHECK(CloseHandle(ctx.continueEvent)); + + /* Ensure worker cleanup didn't disturb main thread */ + TEST_CHECK_EQ(mainPtr, TlsGetValue(tlsIndex)); + TEST_CHECK_EQ(mainPtr, tlsArray[tlsIndex]); + + TEST_CHECK(TlsSetValue(tlsIndex, NULL)); + TEST_CHECK_EQ(NULL, TlsGetValue(tlsIndex)); + TEST_CHECK(TlsFree(tlsIndex)); + TEST_CHECK_EQ(NULL, tlsArray[tlsIndex]); + + return EXIT_SUCCESS; +} diff --git a/test/test_tls_reloc.c b/test/test_tls_reloc.c new file mode 100644 index 0000000..1c1a760 --- /dev/null +++ b/test/test_tls_reloc.c @@ -0,0 +1,61 @@ +#include "test_assert.h" +#include +#include + +#ifndef TLS_RELOC_PREFERRED_BASE +#define TLS_RELOC_PREFERRED_BASE 0x30000000u +#endif + +#ifndef TLS_RELOC_INITIAL_VALUE +#define TLS_RELOC_INITIAL_VALUE 0x2468ACEDu +#endif + +typedef int(__stdcall *tls_get_template_value_fn)(void); +typedef void *(__stdcall *tls_template_address_fn)(void); +typedef int(__stdcall *tls_callback_hits_fn)(void); + +static void *reserve_preferred_region(size_t size) { + void *preferred = (void *)(uintptr_t)TLS_RELOC_PREFERRED_BASE; + void *reservation = VirtualAlloc(preferred, size, MEM_RESERVE, PAGE_NOACCESS); + return reservation; +} + +int main(void) { + const size_t reservationSize = 0x200000; // 2 MB + void *preferred = (void *)(uintptr_t)TLS_RELOC_PREFERRED_BASE; + void *reservation = reserve_preferred_region(reservationSize); + TEST_CHECK_MSG(reservation == preferred, "VirtualAlloc(%p) failed: %lu", preferred, + (unsigned long)GetLastError()); + + HMODULE mod = LoadLibraryA("tls_reloc.dll"); + TEST_CHECK_MSG(mod != NULL, "LoadLibraryA failed: %lu", (unsigned long)GetLastError()); + + TEST_CHECK_MSG(VirtualFree(reservation, 0, MEM_RELEASE) != 0, "VirtualFree failed: %lu", + (unsigned long)GetLastError()); + + TEST_CHECK((uintptr_t)mod != (uintptr_t)preferred); + + FARPROC rawGet = GetProcAddress(mod, "tls_get_template_value@0"); + FARPROC rawAddr = GetProcAddress(mod, "tls_template_address@0"); + FARPROC rawHits = GetProcAddress(mod, "tls_callback_hits@0"); + TEST_CHECK(rawGet != NULL); + TEST_CHECK(rawAddr != NULL); + TEST_CHECK(rawHits != NULL); + + tls_get_template_value_fn tls_get_template_value = (tls_get_template_value_fn)(uintptr_t)rawGet; + tls_template_address_fn tls_template_address = (tls_template_address_fn)(uintptr_t)rawAddr; + tls_callback_hits_fn tls_callback_hits = (tls_callback_hits_fn)(uintptr_t)rawHits; + + void *templateAddr = tls_template_address(); + TEST_CHECK(templateAddr != NULL); + + int initial = tls_get_template_value(); + TEST_CHECK_EQ(TLS_RELOC_INITIAL_VALUE, (unsigned int)initial); + + int hits = tls_callback_hits(); + TEST_CHECK_EQ(1, hits); + + TEST_CHECK(FreeLibrary(mod)); + + return EXIT_SUCCESS; +} diff --git a/test/tls_reloc_dll.c b/test/tls_reloc_dll.c new file mode 100644 index 0000000..4c13f48 --- /dev/null +++ b/test/tls_reloc_dll.c @@ -0,0 +1,40 @@ +#include + +#ifndef TLS_RELOC_INITIAL_VALUE +#define TLS_RELOC_INITIAL_VALUE 0x2468ACEDu +#endif + +__attribute__((section(".tls$AAA"), used)) static int g_tlsInitialValue = (int)TLS_RELOC_INITIAL_VALUE; +__attribute__((section(".tls$ZZZ"), used)) static const int g_tlsTerminator = 0; + +static int g_tlsCallbackCount = 0; + +static void NTAPI tls_callback(PVOID module, DWORD reason, PVOID reserved) { + (void)module; + (void)reserved; + if (reason == DLL_PROCESS_ATTACH) { + ++g_tlsCallbackCount; + } +} + +__attribute__((section(".CRT$XLB"), used)) static const PIMAGE_TLS_CALLBACK g_tlsCallback = tls_callback; + +BOOL WINAPI DllMain(HINSTANCE hinstDLL, DWORD fdwReason, LPVOID lpReserved) { + (void)lpReserved; + if (fdwReason == DLL_PROCESS_ATTACH) { + DisableThreadLibraryCalls(hinstDLL); + } + return TRUE; +} + +__declspec(dllexport) int __stdcall tls_get_template_value(void) { + return g_tlsInitialValue; +} + +__declspec(dllexport) void *__stdcall tls_template_address(void) { + return &g_tlsInitialValue; +} + +__declspec(dllexport) int __stdcall tls_callback_hits(void) { + return g_tlsCallbackCount; +}