From a240e3dc4b60be1c65e24e65f46664aac013ff61 Mon Sep 17 00:00:00 2001 From: Luke Street Date: Mon, 27 Oct 2025 12:27:26 -0600 Subject: [PATCH] Support TlsExpansionSlots (>64 TLS slots) --- dll/kernel32/processthreadsapi.cpp | 6 +- src/main.cpp | 2 + src/modules.cpp | 10 +- src/tls.cpp | 237 ++++++++++++++++++++++++++--- src/tls.h | 3 +- test/test_tls.c | 85 ++++++++++- 6 files changed, 313 insertions(+), 30 deletions(-) diff --git a/dll/kernel32/processthreadsapi.cpp b/dll/kernel32/processthreadsapi.cpp index 8073f1e..f16b309 100644 --- a/dll/kernel32/processthreadsapi.cpp +++ b/dll/kernel32/processthreadsapi.cpp @@ -387,7 +387,6 @@ DWORD WIN_FUNC TlsAlloc() { wibo::lastError = ERROR_NOT_ENOUGH_MEMORY; return TLS_OUT_OF_INDEXES; } - wibo::tls::setValue(index, nullptr); wibo::lastError = ERROR_SUCCESS; return index; } @@ -399,7 +398,6 @@ BOOL WIN_FUNC TlsFree(DWORD dwTlsIndex) { wibo::lastError = ERROR_INVALID_PARAMETER; return FALSE; } - wibo::tls::setValue(dwTlsIndex, nullptr); wibo::lastError = ERROR_SUCCESS; return TRUE; } @@ -411,7 +409,9 @@ LPVOID WIN_FUNC TlsGetValue(DWORD dwTlsIndex) { wibo::lastError = ERROR_INVALID_PARAMETER; return nullptr; } - return wibo::tls::getValue(dwTlsIndex); + void *result = wibo::tls::getValue(dwTlsIndex); + wibo::lastError = ERROR_SUCCESS; + return result; } BOOL WIN_FUNC TlsSetValue(DWORD dwTlsIndex, LPVOID lpTlsValue) { diff --git a/src/main.cpp b/src/main.cpp index 537255b..53be2fd 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -67,6 +67,7 @@ void wibo::destroyTib(TIB *tibPtr) { if (!tibPtr) { return; } + tls::cleanupTib(tibPtr); std::free(tibPtr); } @@ -613,6 +614,7 @@ int main(int argc, char **argv) { } DEBUG_LOG("We came back\n"); wibo::shutdownModuleRegistry(); + wibo::tls::cleanupTib(&tib); return 1; } diff --git a/src/modules.cpp b/src/modules.cpp index 8d19d8c..4694189 100644 --- a/src/modules.cpp +++ b/src/modules.cpp @@ -282,7 +282,10 @@ void allocateModuleTlsForThread(wibo::ModuleInfo &module, TIB *tib) { } } info.threadAllocations.emplace(tib, block); - wibo::tls::setValue(tib, info.index, block); + if (!wibo::tls::setValue(tib, info.index, block)) { + DEBUG_LOG(" allocateModuleTlsForThread: failed to publish TLS pointer for %s (index %u)\n", + module.originalName.c_str(), info.index); + } } void freeModuleTlsForThread(wibo::ModuleInfo &module, TIB *tib) { @@ -300,7 +303,10 @@ void freeModuleTlsForThread(wibo::ModuleInfo &module, TIB *tib) { void *block = it->second; info.threadAllocations.erase(it); if (info.index < kTlsSlotCount && wibo::tls::getValue(tib, info.index) == block) { - wibo::tls::setValue(tib, info.index, nullptr); + if (!wibo::tls::setValue(tib, info.index, nullptr)) { + DEBUG_LOG(" freeModuleTlsForThread: failed to clear TLS pointer for %s (index %u)\n", + module.originalName.c_str(), info.index); + } } if (block) { std::free(block); diff --git a/src/tls.cpp b/src/tls.cpp index 9df2de9..b530740 100644 --- a/src/tls.cpp +++ b/src/tls.cpp @@ -1,13 +1,155 @@ #include "tls.h" +#include "common.h" #include #include +#include +#include #include +#include namespace { -std::array g_slotUsed{}; -std::mutex g_slotMutex; +constexpr size_t kMaxExpansionSlots = wibo::tls::kTlsMaxSlotCount - kTlsSlotCount; + +std::mutex g_tlsMutex; +std::array g_slotUsed{}; +std::vector g_activeTibs; +size_t g_expansionCapacity = 0; + +struct TlsVector { + size_t capacity; + void *slots[]; +}; + +TlsVector *allocateVector(size_t capacity) { + if (capacity == 0 || capacity > kMaxExpansionSlots) { + return nullptr; + } + const size_t bytes = sizeof(TlsVector) + capacity * sizeof(void *); + auto *vector = static_cast(std::calloc(1, bytes)); + if (!vector) { + return nullptr; + } + vector->capacity = capacity; + return vector; +} + +TlsVector *vectorFromSlots(void **slots) { + if (!slots) { + return nullptr; + } + auto *base = reinterpret_cast(slots) - offsetof(TlsVector, slots); + return reinterpret_cast(base); +} + +TlsVector *getExpansionVector(TIB *tib) { + if (!tib) { + return nullptr; + } + return vectorFromSlots(tib->tlsExpansionSlots); +} + +void setExpansionVector(TIB *tib, TlsVector *vector) { + if (!tib) { + return; + } + tib->tlsExpansionSlots = vector ? vector->slots : nullptr; +} + +size_t chooseCapacity(size_t current, size_t required) { + if (required == 0) { + return current; + } + if (required > kMaxExpansionSlots) { + return 0; + } + size_t capacity = current; + if (capacity == 0) { + capacity = 1; + } + while (capacity < required) { + size_t next = capacity * 2; + if (next <= capacity || next > kMaxExpansionSlots) { + capacity = kMaxExpansionSlots; + } else { + capacity = next; + } + } + if (capacity > kMaxExpansionSlots) { + capacity = kMaxExpansionSlots; + } + if (capacity < required) { + return 0; + } + return capacity; +} + +struct PendingResize { + TIB *tib; + TlsVector *oldVector; + TlsVector *newVector; +}; + +bool ensureGlobalExpansionCapacityLocked(size_t required) { + if (required == 0) { + return true; + } + if (required <= g_expansionCapacity) { + return true; + } + size_t target = chooseCapacity(g_expansionCapacity, required); + if (target == 0) { + return false; + } + std::vector pending; + pending.reserve(g_activeTibs.size()); + for (TIB *tib : g_activeTibs) { + TlsVector *currentVector = getExpansionVector(tib); + size_t currentCapacity = currentVector ? currentVector->capacity : 0; + if (currentCapacity >= target) { + continue; + } + TlsVector *newVector = allocateVector(target); + if (!newVector) { + for (auto &entry : pending) { + std::free(entry.newVector); + } + return false; + } + if (currentVector) { + std::copy_n(currentVector->slots, std::min(currentVector->capacity, newVector->capacity), newVector->slots); + } + pending.emplace_back(tib, currentVector, newVector); + } + for (auto &entry : pending) { + setExpansionVector(entry.tib, entry.newVector); + } + for (auto &entry : pending) { + if (entry.oldVector) { + std::free(entry.oldVector); + } + } + g_expansionCapacity = target; + return true; +} + +void zeroSlotForAllTibs(size_t index) { + if (index < kTlsSlotCount) { + for (TIB *tib : g_activeTibs) { + tib->tlsSlots[index] = nullptr; + } + return; + } + size_t expansionIndex = index - kTlsSlotCount; + for (TIB *tib : g_activeTibs) { + TlsVector *vector = getExpansionVector(tib); + if (!vector || expansionIndex >= vector->capacity) { + continue; + } + vector->slots[expansionIndex] = nullptr; + } +} } // namespace @@ -17,53 +159,108 @@ void initializeTib(TIB *tib) { if (!tib) { return; } - std::fill(std::begin(tib->tlsSlots), std::end(tib->tlsSlots), nullptr); - tib->tlsLinks.flink = nullptr; - tib->tlsLinks.blink = nullptr; - tib->tlsExpansionSlots = nullptr; - tib->flsSlots = nullptr; + std::lock_guard lock(g_tlsMutex); + if (std::find(g_activeTibs.begin(), g_activeTibs.end(), tib) != g_activeTibs.end()) { + return; + } + g_activeTibs.push_back(tib); + if (g_expansionCapacity > 0 && !getExpansionVector(tib)) { + if (TlsVector *vector = allocateVector(g_expansionCapacity)) { + setExpansionVector(tib, vector); + } + } +} + +void cleanupTib(TIB *tib) { + if (!tib) { + return; + } + std::lock_guard lock(g_tlsMutex); + if (TlsVector *vector = getExpansionVector(tib)) { + std::free(vector); + setExpansionVector(tib, nullptr); + } + auto it = std::find(g_activeTibs.begin(), g_activeTibs.end(), tib); + if (it != g_activeTibs.end()) { + g_activeTibs.erase(it); + } } DWORD reserveSlot() { - std::lock_guard lock(g_slotMutex); - for (DWORD index = 0; index < static_cast(kTlsSlotCount); ++index) { - if (!g_slotUsed[index]) { - g_slotUsed[index] = true; - return index; + std::lock_guard lock(g_tlsMutex); + for (DWORD index = 0; index < static_cast(wibo::tls::kTlsMaxSlotCount); ++index) { + if (g_slotUsed[index]) { + continue; } + if (index >= static_cast(kTlsSlotCount)) { + size_t required = static_cast(index) - kTlsSlotCount + 1; + if (!ensureGlobalExpansionCapacityLocked(required)) { + return kInvalidTlsIndex; + } + } + g_slotUsed[index] = true; + zeroSlotForAllTibs(index); + return index; } return kInvalidTlsIndex; } bool releaseSlot(DWORD index) { - if (index >= static_cast(kTlsSlotCount)) { + if (index >= static_cast(wibo::tls::kTlsMaxSlotCount)) { return false; } - std::lock_guard lock(g_slotMutex); + std::lock_guard lock(g_tlsMutex); if (!g_slotUsed[index]) { return false; } g_slotUsed[index] = false; + zeroSlotForAllTibs(index); return true; } bool isSlotAllocated(DWORD index) { - std::lock_guard lock(g_slotMutex); - return index < kTlsSlotCount && g_slotUsed[index]; + std::lock_guard lock(g_tlsMutex); + return index < wibo::tls::kTlsMaxSlotCount && g_slotUsed[index]; } void *getValue(TIB *tib, DWORD index) { - if (!tib || index >= static_cast(kTlsSlotCount)) { + if (!tib || index >= static_cast(wibo::tls::kTlsMaxSlotCount)) { return nullptr; } - return tib->tlsSlots[index]; + if (index < static_cast(kTlsSlotCount)) { + return tib->tlsSlots[index]; + } + std::lock_guard lock(g_tlsMutex); + TlsVector *vector = getExpansionVector(tib); + if (!vector) { + return nullptr; + } + size_t expansionIndex = static_cast(index) - kTlsSlotCount; + if (expansionIndex >= vector->capacity) { + return nullptr; + } + return vector->slots[expansionIndex]; } bool setValue(TIB *tib, DWORD index, void *value) { - if (!tib || index >= static_cast(kTlsSlotCount)) { + if (!tib || index >= static_cast(wibo::tls::kTlsMaxSlotCount)) { return false; } - tib->tlsSlots[index] = value; + if (index < static_cast(kTlsSlotCount)) { + tib->tlsSlots[index] = value; + return true; + } + std::lock_guard lock(g_tlsMutex); + size_t expansionIndex = static_cast(index) - kTlsSlotCount; + TlsVector *vector = getExpansionVector(tib); + if ((!vector || expansionIndex >= vector->capacity) && !ensureGlobalExpansionCapacityLocked(expansionIndex + 1)) { + return false; + } + vector = getExpansionVector(tib); + if (!vector || expansionIndex >= vector->capacity) { + return false; + } + vector->slots[expansionIndex] = value; return true; } diff --git a/src/tls.h b/src/tls.h index 0203f2f..dbf28b2 100644 --- a/src/tls.h +++ b/src/tls.h @@ -5,8 +5,10 @@ namespace wibo::tls { constexpr DWORD kInvalidTlsIndex = 0xFFFFFFFFu; +constexpr size_t kTlsMaxSlotCount = 1088; void initializeTib(TIB *tib); +void cleanupTib(TIB *tib); DWORD reserveSlot(); bool releaseSlot(DWORD index); @@ -19,4 +21,3 @@ void *getValue(DWORD index); bool setValue(DWORD index, void *value); } // namespace wibo::tls - diff --git a/test/test_tls.c b/test/test_tls.c index 2c2974d..d83b8d1 100644 --- a/test/test_tls.c +++ b/test/test_tls.c @@ -16,7 +16,9 @@ static void **tls_slots(void) { typedef struct { DWORD tlsIndex; + DWORD tlsExpansionIndex; int threadValue; + int expansionValue; HANDLE readyEvent; HANDLE continueEvent; } ThreadCtx; @@ -34,12 +36,23 @@ static DWORD WINAPI tls_thread_proc(LPVOID param) { TEST_CHECK_EQ(threadPtr, TlsGetValue(ctx->tlsIndex)); TEST_CHECK_EQ(threadPtr, tls_slots()[ctx->tlsIndex]); + if (ctx->tlsExpansionIndex != TLS_OUT_OF_INDEXES) { + DWORD expansionIndex = ctx->tlsExpansionIndex; + TEST_CHECK(expansionIndex >= TLS_MINIMUM_AVAILABLE); + void *expansionPtr = &ctx->expansionValue; + TEST_CHECK(TlsSetValue(expansionIndex, expansionPtr)); + TEST_CHECK_EQ(expansionPtr, TlsGetValue(expansionIndex)); + } + TEST_CHECK(SetEvent(ctx->readyEvent)); TEST_CHECK_EQ(WAIT_OBJECT_0, WaitForSingleObject(ctx->continueEvent, 1000)); /* Clear before exit */ TEST_CHECK(TlsSetValue(ctx->tlsIndex, NULL)); + if (ctx->tlsExpansionIndex != TLS_OUT_OF_INDEXES) { + TEST_CHECK(TlsSetValue(ctx->tlsExpansionIndex, NULL)); + } return 0; } @@ -58,9 +71,11 @@ int main(void) { TEST_CHECK_EQ(mainPtr, TlsGetValue(tlsIndex)); TEST_CHECK_EQ(mainPtr, tlsArray[tlsIndex]); - ThreadCtx ctx; + ThreadCtx ctx = {0}; ctx.tlsIndex = tlsIndex; + ctx.tlsExpansionIndex = TLS_OUT_OF_INDEXES; ctx.threadValue = 0x4242; + ctx.expansionValue = 0; ctx.readyEvent = CreateEventA(NULL, FALSE, FALSE, NULL); ctx.continueEvent = CreateEventA(NULL, FALSE, FALSE, NULL); TEST_CHECK(ctx.readyEvent != NULL); @@ -76,20 +91,82 @@ int main(void) { TEST_CHECK_EQ(mainPtr, tlsArray[tlsIndex]); TEST_CHECK(SetEvent(ctx.continueEvent)); - TEST_CHECK_EQ(WAIT_OBJECT_0, WaitForSingleObject(thread, 1000)); TEST_CHECK(CloseHandle(thread)); - TEST_CHECK(CloseHandle(ctx.readyEvent)); TEST_CHECK(CloseHandle(ctx.continueEvent)); /* Ensure worker cleanup didn't disturb main thread */ + tlsArray = tls_slots(); TEST_CHECK_EQ(mainPtr, TlsGetValue(tlsIndex)); TEST_CHECK_EQ(mainPtr, tlsArray[tlsIndex]); + /* Allocate additional slots to cross the TLS_MINIMUM_AVAILABLE boundary */ + const size_t extraCount = 80; + DWORD extraSlots[extraCount]; + size_t extraUsed = 0; + DWORD expansionIndex = TLS_OUT_OF_INDEXES; + + for (; extraUsed < extraCount; ++extraUsed) { + DWORD index = TlsAlloc(); + TEST_CHECK(index != TLS_OUT_OF_INDEXES); + extraSlots[extraUsed] = index; + if (index >= TLS_MINIMUM_AVAILABLE) { + expansionIndex = index; + ++extraUsed; + break; + } + } + TEST_CHECK(expansionIndex != TLS_OUT_OF_INDEXES); + + tlsArray = tls_slots(); + + int mainExpansionValue = 0x5678; + void *mainExpansionPtr = &mainExpansionValue; + TEST_CHECK(TlsSetValue(expansionIndex, mainExpansionPtr)); + TEST_CHECK_EQ(mainExpansionPtr, TlsGetValue(expansionIndex)); + + ThreadCtx expansionCtx = {0}; + expansionCtx.tlsIndex = tlsIndex; + expansionCtx.tlsExpansionIndex = expansionIndex; + expansionCtx.threadValue = 0x3535; + expansionCtx.expansionValue = 0x2626; + expansionCtx.readyEvent = CreateEventA(NULL, FALSE, FALSE, NULL); + expansionCtx.continueEvent = CreateEventA(NULL, FALSE, FALSE, NULL); + TEST_CHECK(expansionCtx.readyEvent != NULL); + TEST_CHECK(expansionCtx.continueEvent != NULL); + + thread = CreateThread(NULL, 0, tls_thread_proc, &expansionCtx, 0, NULL); + TEST_CHECK(thread != NULL); + TEST_CHECK_EQ(WAIT_OBJECT_0, WaitForSingleObject(expansionCtx.readyEvent, 1000)); + + tlsArray = tls_slots(); + TEST_CHECK_EQ(mainPtr, TlsGetValue(tlsIndex)); + TEST_CHECK_EQ(mainPtr, tlsArray[tlsIndex]); + TEST_CHECK_EQ(mainExpansionPtr, TlsGetValue(expansionIndex)); + + TEST_CHECK(SetEvent(expansionCtx.continueEvent)); + TEST_CHECK_EQ(WAIT_OBJECT_0, WaitForSingleObject(thread, 1000)); + TEST_CHECK(CloseHandle(thread)); + TEST_CHECK(CloseHandle(expansionCtx.readyEvent)); + TEST_CHECK(CloseHandle(expansionCtx.continueEvent)); + + /* Ensure worker cleanup didn't disturb main thread values */ + tlsArray = tls_slots(); + TEST_CHECK_EQ(mainPtr, TlsGetValue(tlsIndex)); + TEST_CHECK_EQ(mainPtr, tlsArray[tlsIndex]); + TEST_CHECK_EQ(mainExpansionPtr, TlsGetValue(expansionIndex)); + + /* Clear and free all slots */ TEST_CHECK(TlsSetValue(tlsIndex, NULL)); - TEST_CHECK_EQ(NULL, TlsGetValue(tlsIndex)); + TEST_CHECK(TlsSetValue(expansionIndex, NULL)); + + for (size_t i = 0; i < extraUsed; ++i) { + TEST_CHECK(TlsFree(extraSlots[i])); + } TEST_CHECK(TlsFree(tlsIndex)); + + tlsArray = tls_slots(); TEST_CHECK_EQ(NULL, tlsArray[tlsIndex]); return EXIT_SUCCESS;