diff --git a/AGENTS.md b/AGENTS.md index b23a0ef..772f7b3 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -21,9 +21,10 @@ ## Shim Implementation Guidelines - Target pre-XP behavior; our binaries are old and don't expect modern WinAPI behavior. -- Use the `microsoft_docs` tools to fetch WinAPI signatures and documentation; always fetch the documentation when working on an API function. +- Use the `microsoft_docs` tools to fetch WinAPI signatures and documentation; always fetch the documentation when working on an API function. When searching, simply include the function name; nothing else. - Create minimal, self-contained repros in `test/` when implementing or debugging APIs; this aids both development and future testing. - Add `DEBUG_LOG` calls to trace execution and parameter values; these are invaluable when diagnosing issues with real-world binaries. +- Add prototypes to the public shim headers (e.g., `dll/kernel32/*.h`), rebuild, and the code generator automatically handles exports. ## Testing Guidelines - Fixture tests live in `test/` and are compiled automatically with `i686-w64-mingw32-gcc`. diff --git a/CMakeLists.txt b/CMakeLists.txt index 3512713..44cad45 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -442,6 +442,8 @@ if (WIBO_ENABLE_FIXTURE_TESTS) wibo_add_fixture_bin(NAME test_tls_reloc SOURCES test/test_tls_reloc.c) wibo_add_fixture_bin(NAME test_handleapi SOURCES test/test_handleapi.c) wibo_add_fixture_bin(NAME test_findfile SOURCES test/test_findfile.c) + wibo_add_fixture_bin(NAME test_locale SOURCES test/test_locale.c) + wibo_add_fixture_bin(NAME test_critical_section SOURCES test/test_critical_section.c) wibo_add_fixture_bin(NAME test_synchapi SOURCES test/test_synchapi.c) wibo_add_fixture_bin(NAME test_processes SOURCES test/test_processes.c) wibo_add_fixture_bin(NAME test_heap SOURCES test/test_heap.c) @@ -461,6 +463,8 @@ if (WIBO_ENABLE_FIXTURE_TESTS) wibo_add_fixture_bin(NAME test_pipe_io SOURCES test/test_pipe_io.c) wibo_add_fixture_bin(NAME test_namedpipe SOURCES test/test_namedpipe.c) wibo_add_fixture_bin(NAME test_sysdir SOURCES test/test_sysdir.c) + wibo_add_fixture_bin(NAME test_srw_lock SOURCES test/test_srw_lock.c) + wibo_add_fixture_bin(NAME test_init_once SOURCES test/test_init_once.c) # DLLs for fixture tests wibo_add_fixture_dll(NAME external_exports SOURCES test/external_exports.c) diff --git a/CMakePresets.json b/CMakePresets.json index a8a0ce2..7df0b18 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -10,6 +10,13 @@ "WIBO_ENABLE_LIBURING": "ON" } }, + { + "name": "gnu-base", + "hidden": true, + "cacheVariables": { + "WIBO_ENABLE_LTO": "OFF" + } + }, { "name": "clang-base", "hidden": true, @@ -22,7 +29,7 @@ "name": "debug", "displayName": "Debug", "description": "Debug build (32-bit Linux)", - "inherits": ["ninja-base"], + "inherits": ["ninja-base", "gnu-base"], "binaryDir": "${sourceDir}/build/debug", "condition": { "type": "equals", @@ -54,7 +61,7 @@ "name": "release", "displayName": "Release", "description": "Release build (32-bit Linux)", - "inherits": ["ninja-base"], + "inherits": ["ninja-base", "gnu-base"], "binaryDir": "${sourceDir}/build/release", "condition": { "type": "equals", @@ -115,7 +122,7 @@ { "name": "debug64", "displayName": "Debug (64-bit)", - "inherits": ["ninja-base"], + "inherits": ["ninja-base", "gnu-base"], "binaryDir": "${sourceDir}/build/debug64", "condition": { "type": "equals", @@ -130,7 +137,7 @@ { "name": "release64", "displayName": "Release (64-bit)", - "inherits": ["ninja-base"], + "inherits": ["ninja-base", "gnu-base"], "binaryDir": "${sourceDir}/build/release64", "condition": { "type": "equals", diff --git a/dll/kernel32.cpp b/dll/kernel32.cpp index c760755..6613977 100644 --- a/dll/kernel32.cpp +++ b/dll/kernel32.cpp @@ -5,6 +5,7 @@ extern const wibo::ModuleStub lib_kernel32 = { (const char *[]){ "kernel32", + "kernelbase", nullptr, }, kernel32ThunkByName, diff --git a/dll/kernel32/synchapi.cpp b/dll/kernel32/synchapi.cpp index cd9246b..16b941f 100644 --- a/dll/kernel32/synchapi.cpp +++ b/dll/kernel32/synchapi.cpp @@ -6,23 +6,34 @@ #include "handles.h" #include "heap.h" #include "internal.h" +#include "processthreadsapi.h" #include "strutil.h" #include "types.h" #include +#include #include +#include #include -#include +#include #include #include #include #include #include +#include #include +#include #include namespace { +constexpr DWORD kSrwLockExclusive = 0x1u; +constexpr DWORD kSrwLockSharedIncrement = 0x2u; +constexpr GUEST_PTR kInitOnceStateMask = 0x3u; +constexpr GUEST_PTR kInitOnceCompletedFlag = 0x2u; +constexpr GUEST_PTR kInitOnceReservedMask = (1u << INIT_ONCE_CTX_RESERVED_BITS) - 1; + std::u16string makeU16String(LPCWSTR name) { if (!name) { return {}; @@ -126,6 +137,81 @@ struct WaitBlock { std::condition_variable cv; }; +struct InitOnceState { + std::mutex mutex; + std::condition_variable cv; + bool completed = false; + bool success = false; + GUEST_PTR context = GUEST_NULL; +}; + +std::mutex g_initOnceMutex; +std::unordered_map> g_initOnceStates; + +void insertInitOnceState(LPINIT_ONCE once, const std::shared_ptr &state) { + std::lock_guard lk(g_initOnceMutex); + g_initOnceStates[once] = state; +} + +std::shared_ptr getInitOnceState(LPINIT_ONCE once) { + std::lock_guard lk(g_initOnceMutex); + auto it = g_initOnceStates.find(once); + if (it == g_initOnceStates.end()) { + return nullptr; + } + return it->second; +} + +void eraseInitOnceState(LPINIT_ONCE once) { + std::lock_guard lk(g_initOnceMutex); + g_initOnceStates.erase(once); +} + +inline DWORD owningThreadId(LPCRITICAL_SECTION crit) { + std::atomic_ref owner(*reinterpret_cast(&crit->OwningThread)); + return owner.load(std::memory_order_acquire); +} + +inline void setOwningThread(LPCRITICAL_SECTION crit, DWORD threadId) { + std::atomic_ref owner(*reinterpret_cast(&crit->OwningThread)); + owner.store(threadId, std::memory_order_release); +} + +void waitForCriticalSection(LPCRITICAL_SECTION cs) { + std::atomic_ref sequence(reinterpret_cast(cs->LockSemaphore)); + LONG observed = sequence.load(std::memory_order_acquire); + while (owningThreadId(cs) != 0) { + sequence.wait(observed, std::memory_order_relaxed); + observed = sequence.load(std::memory_order_acquire); + } +} + +void signalCriticalSection(LPCRITICAL_SECTION cs) { + std::atomic_ref sequence(reinterpret_cast(cs->LockSemaphore)); + sequence.fetch_add(1, std::memory_order_release); + sequence.notify_one(); +} + +inline bool trySpinAcquireCriticalSection(LPCRITICAL_SECTION cs, DWORD threadId) { + if (!cs || cs->SpinCount == 0) { + return false; + } + for (ULONG_PTR spins = cs->SpinCount; spins > 0; --spins) { + if (kernel32::TryEnterCriticalSection(cs)) { + return true; + } + if (cs->LockCount > 0) { + break; + } + std::this_thread::yield(); + if (owningThreadId(cs) == threadId) { + // Owner is self, TryEnter would have succeeded; bail out. + break; + } + } + return false; +} + } // namespace namespace kernel32 { @@ -550,7 +636,7 @@ DWORD WINAPI WaitForMultipleObjects(DWORD nCount, const HANDLE *lpHandles, BOOL void WINAPI InitializeCriticalSection(LPCRITICAL_SECTION lpCriticalSection) { HOST_CONTEXT_GUARD(); - VERBOSE_LOG("InitializeCriticalSection(%p)\n", lpCriticalSection); + DEBUG_LOG("InitializeCriticalSection(%p)\n", lpCriticalSection); InitializeCriticalSectionEx(lpCriticalSection, 0, 0); } @@ -586,25 +672,117 @@ BOOL WINAPI InitializeCriticalSectionAndSpinCount(LPCRITICAL_SECTION lpCriticalS void WINAPI DeleteCriticalSection(LPCRITICAL_SECTION lpCriticalSection) { HOST_CONTEXT_GUARD(); - VERBOSE_LOG("STUB: DeleteCriticalSection(%p)\n", lpCriticalSection); - (void)lpCriticalSection; + DEBUG_LOG("DeleteCriticalSection(%p)\n", lpCriticalSection); + if (!lpCriticalSection) { + return; + } + + if (lpCriticalSection->DebugInfo && lpCriticalSection->DebugInfo != static_cast(-1)) { + auto *debugInfo = fromGuestPtr(lpCriticalSection->DebugInfo); + if (debugInfo && debugInfo->Spare[0] == 0) { + wibo::heap::guestFree(debugInfo); + } + } + + lpCriticalSection->DebugInfo = GUEST_NULL; + lpCriticalSection->RecursionCount = 0; + lpCriticalSection->SpinCount = 0; + setOwningThread(lpCriticalSection, 0); + + std::atomic_ref sequence(reinterpret_cast(lpCriticalSection->LockSemaphore)); + sequence.store(0, std::memory_order_release); + sequence.notify_all(); + + std::atomic_ref lockCount(lpCriticalSection->LockCount); + lockCount.store(-1, std::memory_order_release); + lockCount.notify_all(); +} + +BOOL WINAPI TryEnterCriticalSection(LPCRITICAL_SECTION lpCriticalSection) { + HOST_CONTEXT_GUARD(); + VERBOSE_LOG("TryEnterCriticalSection(%p)\n", lpCriticalSection); + if (!lpCriticalSection) { + setLastError(ERROR_INVALID_PARAMETER); + return FALSE; + } + + std::atomic_ref lockCount(lpCriticalSection->LockCount); + const DWORD threadId = GetCurrentThreadId(); + + LONG expected = -1; + if (lockCount.compare_exchange_strong(expected, 0, std::memory_order_acq_rel, std::memory_order_acquire)) { + setOwningThread(lpCriticalSection, threadId); + lpCriticalSection->RecursionCount = 1; + return TRUE; + } + + if (owningThreadId(lpCriticalSection) == threadId) { + lockCount.fetch_add(1, std::memory_order_acq_rel); + lpCriticalSection->RecursionCount++; + return TRUE; + } + return FALSE; } void WINAPI EnterCriticalSection(LPCRITICAL_SECTION lpCriticalSection) { HOST_CONTEXT_GUARD(); - VERBOSE_LOG("STUB: EnterCriticalSection(%p)\n", lpCriticalSection); - (void)lpCriticalSection; + VERBOSE_LOG("EnterCriticalSection(%p)\n", lpCriticalSection); + if (!lpCriticalSection) { + setLastError(ERROR_INVALID_PARAMETER); + return; + } + + const DWORD threadId = GetCurrentThreadId(); + if (trySpinAcquireCriticalSection(lpCriticalSection, threadId)) { + return; + } + + std::atomic_ref lockCount(lpCriticalSection->LockCount); + LONG result = lockCount.fetch_add(1, std::memory_order_acq_rel) + 1; + if (result) { + if (owningThreadId(lpCriticalSection) == threadId) { + lpCriticalSection->RecursionCount++; + return; + } + waitForCriticalSection(lpCriticalSection); + } + setOwningThread(lpCriticalSection, threadId); + lpCriticalSection->RecursionCount = 1; } void WINAPI LeaveCriticalSection(LPCRITICAL_SECTION lpCriticalSection) { HOST_CONTEXT_GUARD(); - VERBOSE_LOG("STUB: LeaveCriticalSection(%p)\n", lpCriticalSection); - (void)lpCriticalSection; + VERBOSE_LOG("LeaveCriticalSection(%p)\n", lpCriticalSection); + if (!lpCriticalSection) { + setLastError(ERROR_INVALID_PARAMETER); + return; + } + + const DWORD threadId = GetCurrentThreadId(); + if (owningThreadId(lpCriticalSection) != threadId || lpCriticalSection->RecursionCount <= 0) { + DEBUG_LOG("LeaveCriticalSection: thread %u does not own %p (owner=%u, recursion=%ld)\n", threadId, + lpCriticalSection, owningThreadId(lpCriticalSection), + static_cast(lpCriticalSection->RecursionCount)); + return; + } + + if (--lpCriticalSection->RecursionCount > 0) { + std::atomic_ref lockCount(lpCriticalSection->LockCount); + lockCount.fetch_sub(1, std::memory_order_acq_rel); + return; + } + + setOwningThread(lpCriticalSection, 0); + std::atomic_ref lockCount(lpCriticalSection->LockCount); + LONG newValue = lockCount.fetch_sub(1, std::memory_order_acq_rel) - 1; + if (newValue >= 0) { + signalCriticalSection(lpCriticalSection); + } } BOOL WINAPI InitOnceBeginInitialize(LPINIT_ONCE lpInitOnce, DWORD dwFlags, PBOOL fPending, GUEST_PTR *lpContext) { HOST_CONTEXT_GUARD(); - DEBUG_LOG("STUB: InitOnceBeginInitialize(%p, %u, %p, %p)\n", lpInitOnce, dwFlags, fPending, lpContext); + DEBUG_LOG("InitOnceBeginInitialize(%p, %u, %p, %p)\n", lpInitOnce, dwFlags, fPending, lpContext); if (!lpInitOnce) { setLastError(ERROR_INVALID_PARAMETER); return FALSE; @@ -613,59 +791,272 @@ BOOL WINAPI InitOnceBeginInitialize(LPINIT_ONCE lpInitOnce, DWORD dwFlags, PBOOL setLastError(ERROR_INVALID_PARAMETER); return FALSE; } - if (fPending) { - *fPending = TRUE; + + std::atomic_ref state(lpInitOnce->Ptr); + + if (dwFlags & INIT_ONCE_CHECK_ONLY) { + if (dwFlags & INIT_ONCE_ASYNC) { + setLastError(ERROR_INVALID_PARAMETER); + return FALSE; + } + GUEST_PTR val = state.load(std::memory_order_acquire); + if ((val & kInitOnceStateMask) != kInitOnceCompletedFlag) { + if (fPending) { + *fPending = TRUE; + } + setLastError(ERROR_GEN_FAILURE); + return FALSE; + } + if (fPending) { + *fPending = FALSE; + } + if (lpContext) { + *lpContext = val & ~kInitOnceStateMask; + } + return TRUE; } - if (lpContext) { - *lpContext = GUEST_NULL; + + while (true) { + GUEST_PTR val = state.load(std::memory_order_acquire); + switch (val & kInitOnceStateMask) { + case 0: { // first time + if (dwFlags & INIT_ONCE_ASYNC) { + GUEST_PTR expected = 0; + if (state.compare_exchange_strong(expected, static_cast(3), std::memory_order_acq_rel, + std::memory_order_acquire)) { + if (fPending) { + *fPending = TRUE; + } + return TRUE; + } + } else { + auto syncState = std::make_shared(); + GUEST_PTR expected = 0; + if (state.compare_exchange_strong(expected, static_cast(1), std::memory_order_acq_rel, + std::memory_order_acquire)) { + insertInitOnceState(lpInitOnce, syncState); + if (fPending) { + *fPending = TRUE; + } + return TRUE; + } + } + break; + } + case 1: { // synchronous initialization in progress + if (dwFlags & INIT_ONCE_ASYNC) { + setLastError(ERROR_INVALID_PARAMETER); + return FALSE; + } + auto syncState = getInitOnceState(lpInitOnce); + if (!syncState) { + continue; + } + std::unique_lock lk(syncState->mutex); + while (!syncState->completed) { + syncState->cv.wait(lk); + } + if (!syncState->success) { + lk.unlock(); + continue; + } + GUEST_PTR ctx = syncState->context; + lk.unlock(); + if (fPending) { + *fPending = FALSE; + } + if (lpContext) { + *lpContext = ctx; + } + return TRUE; + } + case kInitOnceCompletedFlag: { + if (fPending) { + *fPending = FALSE; + } + if (lpContext) { + *lpContext = val & ~kInitOnceStateMask; + } + return TRUE; + } + case 3: { // async pending + if (!(dwFlags & INIT_ONCE_ASYNC)) { + setLastError(ERROR_INVALID_PARAMETER); + return FALSE; + } + if (fPending) { + *fPending = TRUE; + } + return TRUE; + } + default: + break; + } } - return TRUE; } BOOL WINAPI InitOnceComplete(LPINIT_ONCE lpInitOnce, DWORD dwFlags, LPVOID lpContext) { HOST_CONTEXT_GUARD(); - DEBUG_LOG("STUB: InitOnceComplete(%p, %u, %p)\n", lpInitOnce, dwFlags, lpContext); + DEBUG_LOG("InitOnceComplete(%p, %u, %p)\n", lpInitOnce, dwFlags, lpContext); if (!lpInitOnce) { setLastError(ERROR_INVALID_PARAMETER); return FALSE; } - if ((dwFlags & INIT_ONCE_INIT_FAILED) && (dwFlags & INIT_ONCE_ASYNC)) { + if (dwFlags & ~(INIT_ONCE_ASYNC | INIT_ONCE_INIT_FAILED)) { setLastError(ERROR_INVALID_PARAMETER); return FALSE; } - (void)lpContext; - return TRUE; + const bool markFailed = (dwFlags & INIT_ONCE_INIT_FAILED) != 0; + if (markFailed) { + if (lpContext) { + setLastError(ERROR_INVALID_PARAMETER); + return FALSE; + } + if (dwFlags & INIT_ONCE_ASYNC) { + setLastError(ERROR_INVALID_PARAMETER); + return FALSE; + } + } + + const GUEST_PTR contextValue = static_cast(reinterpret_cast(lpContext)); + if (!markFailed && (contextValue & kInitOnceReservedMask)) { + setLastError(ERROR_INVALID_PARAMETER); + return FALSE; + } + + std::atomic_ref state(lpInitOnce->Ptr); + const GUEST_PTR finalValue = markFailed ? 0 : (contextValue | kInitOnceCompletedFlag); + + while (true) { + GUEST_PTR val = state.load(std::memory_order_acquire); + switch (val & kInitOnceStateMask) { + case 1: { + auto syncState = getInitOnceState(lpInitOnce); + if (!syncState) { + setLastError(ERROR_GEN_FAILURE); + return FALSE; + } + if (!state.compare_exchange_strong(val, finalValue, std::memory_order_acq_rel, std::memory_order_acquire)) { + continue; + } + { + std::lock_guard lk(syncState->mutex); + syncState->completed = true; + syncState->success = !markFailed; + syncState->context = markFailed ? GUEST_NULL : contextValue; + } + syncState->cv.notify_all(); + eraseInitOnceState(lpInitOnce); + return TRUE; + } + case 3: + if (!(dwFlags & INIT_ONCE_ASYNC)) { + setLastError(ERROR_INVALID_PARAMETER); + return FALSE; + } + if (!state.compare_exchange_strong(val, finalValue, std::memory_order_acq_rel, std::memory_order_acquire)) { + continue; + } + return TRUE; + default: + setLastError(ERROR_GEN_FAILURE); + return FALSE; + } + } } void WINAPI AcquireSRWLockShared(PSRWLOCK SRWLock) { HOST_CONTEXT_GUARD(); - (void)SRWLock; - VERBOSE_LOG("STUB: AcquireSRWLockShared(%p)\n", SRWLock); + VERBOSE_LOG("AcquireSRWLockShared(%p)\n", SRWLock); + if (!SRWLock) { + return; + } + std::atomic_ref value(SRWLock->Value); + while (true) { + ULONG current = value.load(std::memory_order_acquire); + if (current & kSrwLockExclusive) { + value.wait(current, std::memory_order_relaxed); + continue; + } + ULONG desired = current + kSrwLockSharedIncrement; + if (value.compare_exchange_weak(current, desired, std::memory_order_acq_rel, std::memory_order_acquire)) { + return; + } + } } void WINAPI ReleaseSRWLockShared(PSRWLOCK SRWLock) { HOST_CONTEXT_GUARD(); - (void)SRWLock; - VERBOSE_LOG("STUB: ReleaseSRWLockShared(%p)\n", SRWLock); + VERBOSE_LOG("ReleaseSRWLockShared(%p)\n", SRWLock); + if (!SRWLock) { + return; + } + std::atomic_ref value(SRWLock->Value); + ULONG previous = value.fetch_sub(kSrwLockSharedIncrement, std::memory_order_acq_rel); + ULONG newValue = previous - kSrwLockSharedIncrement; + if (newValue == 0) { + value.notify_all(); + } } void WINAPI AcquireSRWLockExclusive(PSRWLOCK SRWLock) { HOST_CONTEXT_GUARD(); - (void)SRWLock; - VERBOSE_LOG("STUB: AcquireSRWLockExclusive(%p)\n", SRWLock); + VERBOSE_LOG("AcquireSRWLockExclusive(%p)\n", SRWLock); + if (!SRWLock) { + return; + } + std::atomic_ref value(SRWLock->Value); + while (true) { + ULONG expected = 0; + if (value.compare_exchange_strong(expected, kSrwLockExclusive, std::memory_order_acq_rel, + std::memory_order_acquire)) { + return; + } + value.wait(expected, std::memory_order_relaxed); + } } void WINAPI ReleaseSRWLockExclusive(PSRWLOCK SRWLock) { HOST_CONTEXT_GUARD(); - (void)SRWLock; - VERBOSE_LOG("STUB: ReleaseSRWLockExclusive(%p)\n", SRWLock); + VERBOSE_LOG("ReleaseSRWLockExclusive(%p)\n", SRWLock); + if (!SRWLock) { + return; + } + std::atomic_ref value(SRWLock->Value); + value.store(0, std::memory_order_release); + value.notify_all(); } BOOLEAN WINAPI TryAcquireSRWLockExclusive(PSRWLOCK SRWLock) { HOST_CONTEXT_GUARD(); - (void)SRWLock; - VERBOSE_LOG("STUB: TryAcquireSRWLockExclusive(%p)\n", SRWLock); - return TRUE; + VERBOSE_LOG("TryAcquireSRWLockExclusive(%p)\n", SRWLock); + if (!SRWLock) { + return FALSE; + } + std::atomic_ref value(SRWLock->Value); + ULONG expected = 0; + if (value.compare_exchange_strong(expected, kSrwLockExclusive, std::memory_order_acq_rel, + std::memory_order_acquire)) { + return TRUE; + } + return FALSE; +} + +BOOLEAN WINAPI TryAcquireSRWLockShared(PSRWLOCK SRWLock) { + HOST_CONTEXT_GUARD(); + VERBOSE_LOG("TryAcquireSRWLockShared(%p)\n", SRWLock); + if (!SRWLock) { + return FALSE; + } + std::atomic_ref value(SRWLock->Value); + ULONG current = value.load(std::memory_order_acquire); + while (!(current & kSrwLockExclusive)) { + ULONG desired = current + kSrwLockSharedIncrement; + if (value.compare_exchange_weak(current, desired, std::memory_order_acq_rel, std::memory_order_acquire)) { + return TRUE; + } + } + return FALSE; } } // namespace kernel32 diff --git a/dll/kernel32/synchapi.h b/dll/kernel32/synchapi.h index db0378f..b64d470 100644 --- a/dll/kernel32/synchapi.h +++ b/dll/kernel32/synchapi.h @@ -47,6 +47,8 @@ struct RTL_CRITICAL_SECTION { ULONG_PTR SpinCount; }; +static_assert(sizeof(RTL_CRITICAL_SECTION) == 24); + using PRTL_CRITICAL_SECTION = RTL_CRITICAL_SECTION *; using LPCRITICAL_SECTION = RTL_CRITICAL_SECTION *; using PCRITICAL_SECTION = RTL_CRITICAL_SECTION *; @@ -64,7 +66,7 @@ using LPINIT_ONCE = INIT_ONCE *; constexpr INIT_ONCE INIT_ONCE_STATIC_INIT{GUEST_NULL}; union RTL_SRWLOCK { - GUEST_PTR Ptr; + ULONG Value; }; using SRWLOCK = RTL_SRWLOCK; @@ -98,6 +100,7 @@ BOOL WINAPI InitializeCriticalSectionAndSpinCount(LPCRITICAL_SECTION lpCriticalS void WINAPI DeleteCriticalSection(LPCRITICAL_SECTION lpCriticalSection); void WINAPI EnterCriticalSection(LPCRITICAL_SECTION lpCriticalSection); void WINAPI LeaveCriticalSection(LPCRITICAL_SECTION lpCriticalSection); +BOOL WINAPI TryEnterCriticalSection(LPCRITICAL_SECTION lpCriticalSection); BOOL WINAPI InitOnceBeginInitialize(LPINIT_ONCE lpInitOnce, DWORD dwFlags, PBOOL fPending, GUEST_PTR *lpContext); BOOL WINAPI InitOnceComplete(LPINIT_ONCE lpInitOnce, DWORD dwFlags, LPVOID lpContext); void WINAPI AcquireSRWLockShared(PSRWLOCK SRWLock); @@ -105,5 +108,6 @@ void WINAPI ReleaseSRWLockShared(PSRWLOCK SRWLock); void WINAPI AcquireSRWLockExclusive(PSRWLOCK SRWLock); void WINAPI ReleaseSRWLockExclusive(PSRWLOCK SRWLock); BOOLEAN WINAPI TryAcquireSRWLockExclusive(PSRWLOCK SRWLock); +BOOLEAN WINAPI TryAcquireSRWLockShared(PSRWLOCK SRWLock); } // namespace kernel32 diff --git a/dll/kernel32/winbase.cpp b/dll/kernel32/winbase.cpp index 2d22738..4cb5db5 100644 --- a/dll/kernel32/winbase.cpp +++ b/dll/kernel32/winbase.cpp @@ -348,18 +348,20 @@ const uint16_t kComputerNameWide[] = {u'C', u'O', u'M', u'P', u'N', u'A', u'M', struct DllRedirectionEntry { std::string nameLower; - ACTIVATION_CONTEXT_DATA_DLL_REDIRECTION dllData; + wibo::heap::guest_ptr dllData; }; struct ActivationContext { std::vector dllRedirections; }; -ActivationContext g_builtinActCtx; +wibo::heap::guest_ptr g_builtinActCtx; ActivationContext *currentActivationContext() { - // TODO: hook into real activation context stack once we have it. - return &g_builtinActCtx; + if (!g_builtinActCtx) { + g_builtinActCtx = wibo::heap::make_guest_unique(); + } + return g_builtinActCtx.get(); } } // namespace @@ -371,11 +373,12 @@ void ensureDefaultActivationContext() { auto addDll = [ctx](const std::string &name) { DllRedirectionEntry entry; entry.nameLower = stringToLower(name); - entry.dllData.Size = sizeof(entry.dllData); - entry.dllData.Flags = ACTIVATION_CONTEXT_DATA_DLL_REDIRECTION_PATH_OMITS_ASSEMBLY_ROOT; - entry.dllData.TotalPathLength = 0; - entry.dllData.PathSegmentCount = 0; - entry.dllData.PathSegmentOffset = 0; + entry.dllData = wibo::heap::make_guest_unique(); + entry.dllData->Size = sizeof(entry.dllData); + entry.dllData->Flags = ACTIVATION_CONTEXT_DATA_DLL_REDIRECTION_PATH_OMITS_ASSEMBLY_ROOT; + entry.dllData->TotalPathLength = 0; + entry.dllData->PathSegmentCount = 0; + entry.dllData->PathSegmentOffset = 0; ctx->dllRedirections.emplace_back(std::move(entry)); }; for (const auto &[key, module] : wibo::allLoadedModules()) { @@ -697,7 +700,7 @@ BOOL WINAPI FindActCtxSectionStringW(DWORD dwFlags, const GUID *lpExtensionGuid, ReturnedData->ulDataFormatVersion = 1; ReturnedData->ulFlags = ACTCTX_SECTION_KEYED_DATA_FLAG_FOUND_IN_ACTCTX; if (dwFlags & FIND_ACTCTX_SECTION_KEY_RETURN_HACTCTX) { - ReturnedData->hActCtx = toGuestPtr(&g_builtinActCtx); + ReturnedData->hActCtx = static_cast(toGuestPtr(currentActivationContext())); } if (!matchedEntry) { @@ -705,10 +708,10 @@ BOOL WINAPI FindActCtxSectionStringW(DWORD dwFlags, const GUID *lpExtensionGuid, return FALSE; } - ReturnedData->lpData = toGuestPtr(&matchedEntry->dllData); - ReturnedData->ulLength = matchedEntry->dllData.Size; - ReturnedData->lpSectionBase = toGuestPtr(&matchedEntry->dllData); - ReturnedData->ulSectionTotalLength = matchedEntry->dllData.Size; + ReturnedData->lpData = toGuestPtr(matchedEntry->dllData.get()); + ReturnedData->ulLength = matchedEntry->dllData->Size; + ReturnedData->lpSectionBase = toGuestPtr(matchedEntry->dllData.get()); + ReturnedData->ulSectionTotalLength = matchedEntry->dllData->Size; ReturnedData->ulAssemblyRosterIndex = 1; ReturnedData->AssemblyMetadata = {}; diff --git a/dll/kernel32/winnls.cpp b/dll/kernel32/winnls.cpp index 1773793..6d684c0 100644 --- a/dll/kernel32/winnls.cpp +++ b/dll/kernel32/winnls.cpp @@ -2,6 +2,7 @@ #include "context.h" #include "errors.h" +#include "heap.h" #include "internal.h" #include "kernel32.h" #include "kernel32_trampolines.h" @@ -22,6 +23,9 @@ constexpr DWORD LCID_ALTERNATE_SORTS = 0x00000004; constexpr LCID kEnUsLcid = 0x0409; constexpr LCID kInvariantLcid = 0x007f; constexpr DWORD LOCALE_ALLOW_NEUTRAL_NAMES = 0x08000000; +constexpr DWORD kLocaleNoUserOverride = 0x80000000u; +constexpr DWORD kLocaleReturnNumber = 0x20000000u; +constexpr DWORD kLocaleFlagMask = 0xF0000000u; int compareStrings(const std::string &a, const std::string &b, DWORD dwCmpFlags) { for (size_t i = 0;; ++i) { @@ -272,30 +276,83 @@ int WINAPI GetLocaleInfoW(LCID Locale, LCTYPE LCType, LPWSTR lpLCData, int cchDa return static_cast(required); } -BOOL WINAPI EnumSystemLocalesA(LOCALE_ENUMPROCA lpLocaleEnumProc, DWORD dwFlags) { - { - HOST_CONTEXT_GUARD(); - DEBUG_LOG("EnumSystemLocalesA(%p, 0x%x)\n", lpLocaleEnumProc, dwFlags); - (void)dwFlags; - if (!lpLocaleEnumProc) { +int WINAPI GetLocaleInfoEx(LPCWSTR lpLocaleName, LCTYPE LCType, LPWSTR lpLCData, int cchData) { + HOST_CONTEXT_GUARD(); + DEBUG_LOG("GetLocaleInfoEx(%p, %u, %p, %d)\n", lpLocaleName, LCType, lpLCData, cchData); + + if (cchData < 0) { + setLastError(ERROR_INSUFFICIENT_BUFFER); + return 0; + } + + DWORD lctypeFlags = static_cast(LCType) & kLocaleFlagMask; + constexpr DWORD kSupportedFlags = kLocaleNoUserOverride; + if ((lctypeFlags & kLocaleReturnNumber) != 0 || (lctypeFlags & ~kSupportedFlags) != 0) { + setLastError(ERROR_INVALID_FLAGS); + return 0; + } + + if (lpLocaleName) { + std::string localeName = wideStringToString(reinterpret_cast(lpLocaleName)); + std::string normalized = stringToLower(localeName); + if (!normalized.empty() && normalized != "en-us" && normalized != "en_us" && + normalized != "!x-sys-default-locale") { + DEBUG_LOG("GetLocaleInfoEx: unsupported locale name '%s'\n", localeName.c_str()); setLastError(ERROR_INVALID_PARAMETER); - return FALSE; + return 0; } } - // Return to guest context before callback - char localeId[] = "00000409"; // en-US - return call_LOCALE_ENUMPROCA(lpLocaleEnumProc, localeId); + + DWORD baseType = static_cast(LCType) & ~kLocaleFlagMask; + std::string info = localeInfoString(static_cast(baseType)); + auto wide = stringToWideString(info.c_str()); + size_t required = wide.size(); + + if (cchData == 0) { + return static_cast(required); + } + if (!lpLCData) { + setLastError(ERROR_INVALID_PARAMETER); + return 0; + } + if (static_cast(cchData) < required) { + setLastError(ERROR_INSUFFICIENT_BUFFER); + return 0; + } + + std::memcpy(lpLCData, wide.data(), required * sizeof(uint16_t)); + return static_cast(required); +} + +BOOL WINAPI EnumSystemLocalesA(LOCALE_ENUMPROCA lpLocaleEnumProc, DWORD dwFlags) { + HOST_CONTEXT_GUARD(); + DEBUG_LOG("EnumSystemLocalesA(%p, 0x%x)\n", lpLocaleEnumProc, dwFlags); + (void)dwFlags; + if (!lpLocaleEnumProc) { + setLastError(ERROR_INVALID_PARAMETER); + return FALSE; + } + constexpr char defaultLocaleId[] = "00000409"; // en-US + char *localeId = reinterpret_cast(wibo::heap::guestMalloc(sizeof(defaultLocaleId))); + if (!localeId) { + setLastError(ERROR_NOT_ENOUGH_MEMORY); + return FALSE; + } + std::memcpy(localeId, defaultLocaleId, sizeof(defaultLocaleId)); + BOOL ret = call_LOCALE_ENUMPROCA(lpLocaleEnumProc, localeId); + wibo::heap::guestFree(localeId); + return ret; } LCID WINAPI GetUserDefaultLCID() { HOST_CONTEXT_GUARD(); - DEBUG_LOG("GetUserDefaultLCID()\n"); + DEBUG_LOG("STUB: GetUserDefaultLCID()\n"); return 0x0409; // en-US } BOOL WINAPI IsDBCSLeadByte(BYTE TestChar) { HOST_CONTEXT_GUARD(); - DEBUG_LOG("IsDBCSLeadByte(%u)\n", TestChar); + DEBUG_LOG("STUB: IsDBCSLeadByte(%u)\n", TestChar); (void)TestChar; return FALSE; } diff --git a/dll/kernel32/winnls.h b/dll/kernel32/winnls.h index 2bc8e38..3700351 100644 --- a/dll/kernel32/winnls.h +++ b/dll/kernel32/winnls.h @@ -12,7 +12,7 @@ struct CPINFO { }; using LPCPINFO = CPINFO *; -typedef BOOL (_CC_STDCALL *LOCALE_ENUMPROCA)(LPSTR); +typedef BOOL(_CC_STDCALL *LOCALE_ENUMPROCA)(LPSTR); namespace kernel32 { @@ -23,13 +23,14 @@ int WINAPI GetUserDefaultLocaleName(LPWSTR lpLocaleName, int cchLocaleName); LCID WINAPI LocaleNameToLCID(LPCWSTR lpName, DWORD dwFlags); BOOL WINAPI GetCPInfo(UINT CodePage, LPCPINFO lpCPInfo); int WINAPI CompareStringA(LCID Locale, DWORD dwCmpFlags, LPCSTR lpString1, int cchCount1, LPCSTR lpString2, - int cchCount2); + int cchCount2); int WINAPI CompareStringW(LCID Locale, DWORD dwCmpFlags, LPCWCH lpString1, int cchCount1, LPCWCH lpString2, - int cchCount2); + int cchCount2); BOOL WINAPI IsValidCodePage(UINT CodePage); BOOL WINAPI IsValidLocale(LCID Locale, DWORD dwFlags); int WINAPI GetLocaleInfoA(LCID Locale, LCTYPE LCType, LPSTR lpLCData, int cchData); int WINAPI GetLocaleInfoW(LCID Locale, LCTYPE LCType, LPWSTR lpLCData, int cchData); +int WINAPI GetLocaleInfoEx(LPCWSTR lpLocaleName, LCTYPE LCType, LPWSTR lpLCData, int cchData); BOOL WINAPI EnumSystemLocalesA(LOCALE_ENUMPROCA lpLocaleEnumProc, DWORD dwFlags); LCID WINAPI GetUserDefaultLCID(); BOOL WINAPI IsDBCSLeadByte(BYTE TestChar); diff --git a/dll/ntdll.cpp b/dll/ntdll.cpp index 1eabb32..17309d4 100644 --- a/dll/ntdll.cpp +++ b/dll/ntdll.cpp @@ -90,8 +90,8 @@ LONGLONG timespecToFileTime(const timespec &ts) { if (ticks < 0) { return 0; } - if (ticks > static_cast<__int128>(std::numeric_limits::max())) { - return std::numeric_limits::max(); + if (ticks > static_cast<__int128>(std::numeric_limits::max())) { + return std::numeric_limits::max(); } return static_cast(ticks); #else @@ -829,6 +829,13 @@ NTSTATUS WINAPI NtQueryInformationProcess(HANDLE ProcessHandle, PROCESSINFOCLASS } } +NTSTATUS WINAPI LdrAddRefDll(ULONG Flags, HMODULE Module) { + DEBUG_LOG("STUB: LdrAddRefDll(%x, %p)\n", Flags, Module); + (void)Flags; + (void)Module; + return STATUS_SUCCESS; +} + } // namespace ntdll #include "ntdll_trampolines.h" diff --git a/dll/ntdll.h b/dll/ntdll.h index 83e928c..86fbe60 100644 --- a/dll/ntdll.h +++ b/dll/ntdll.h @@ -54,5 +54,6 @@ NTSTATUS WINAPI RtlGetVersion(PRTL_OSVERSIONINFOW lpVersionInformation); NTSTATUS WINAPI NtQueryInformationProcess(HANDLE ProcessHandle, PROCESSINFOCLASS ProcessInformationClass, PVOID ProcessInformation, ULONG ProcessInformationLength, PULONG ReturnLength); +NTSTATUS WINAPI LdrAddRefDll(ULONG Flags, HMODULE Module); } // namespace ntdll diff --git a/dll/rpcrt4.cpp b/dll/rpcrt4.cpp index 6cac81d..13a429c 100644 --- a/dll/rpcrt4.cpp +++ b/dll/rpcrt4.cpp @@ -195,7 +195,7 @@ RPC_STATUS WINAPI RpcBindingFree(GUEST_PTR *binding) { if (!binding) { return RPC_S_INVALID_ARG; } - RPC_BINDING_HANDLE handle = reinterpret_cast(fromGuestPtr(*binding)); + RPC_BINDING_HANDLE handle = fromGuestPtr(*binding); if (!handle) { return RPC_S_INVALID_BINDING; } @@ -214,7 +214,7 @@ RPC_STATUS WINAPI RpcStringFreeW(GUEST_PTR *string) { if (!string) { return RPC_S_INVALID_ARG; } - RPC_WSTR value = reinterpret_cast(fromGuestPtr(*string)); + RPC_WSTR value = fromGuestPtr(*string); if (!value) { return RPC_S_OK; } diff --git a/src/errors.h b/src/errors.h index b9b56c7..443f20c 100644 --- a/src/errors.h +++ b/src/errors.h @@ -13,6 +13,7 @@ #define ERROR_FILE_EXISTS 80 #define ERROR_READ_FAULT 30 #define ERROR_HANDLE_EOF 38 +#define ERROR_GEN_FAILURE 31 #define ERROR_INVALID_ADDRESS 487 #define ERROR_NOACCESS 998 #define ERROR_BROKEN_PIPE 109 @@ -40,6 +41,7 @@ #define ERROR_NEGATIVE_SEEK 131 #define ERROR_BAD_EXE_FORMAT 193 #define ERROR_DLL_INIT_FAILED 1114 +#define ERROR_INVALID_FLAGS 1004 #define ERROR_ALREADY_EXISTS 183 #define ERROR_NOT_OWNER 288 #define ERROR_TOO_MANY_POSTS 298 diff --git a/src/heap.cpp b/src/heap.cpp index a005837..bb2b9d3 100644 --- a/src/heap.cpp +++ b/src/heap.cpp @@ -43,7 +43,7 @@ constexpr uintptr_t kHeapMax = 0x60000000UL; // 1 GiB constexpr uintptr_t kTopDownStart = 0x7D000000UL; constexpr uintptr_t kTwoGB = 0x7E000000UL; #else -constexpr uintptr_t kTopDownStart = 0x7F000000UL; // Just below 2GB +constexpr uintptr_t kTopDownStart = 0x7F000000UL; // Just below 2GB constexpr uintptr_t kTwoGB = 0x80000000UL; #endif constexpr std::size_t kGuestArenaSize = 512ULL * 1024ULL * 1024ULL; // 512 MiB @@ -65,8 +65,6 @@ std::once_flag g_initOnce; std::mutex g_mappingsMutex; std::map *g_mappings = nullptr; -std::mutex g_virtualAllocMutex; - struct VirtualAllocation { uintptr_t base = 0; std::size_t size = 0; @@ -188,7 +186,7 @@ void markDecommitted(VirtualAllocation ®ion, uintptr_t start, std::size_t len } } -bool overlapsExistingMapping(uintptr_t base, std::size_t length) { +bool overlapsExistingMappingLocked(uintptr_t base, std::size_t length) { if (g_mappings == nullptr || length == 0) { return false; } @@ -196,7 +194,6 @@ bool overlapsExistingMapping(uintptr_t base, std::size_t length) { return true; } uintptr_t end = base + length; - std::lock_guard guard(g_mappingsMutex); auto it = g_mappings->upper_bound(base); if (it != g_mappings->begin()) { --it; @@ -219,8 +216,8 @@ bool overlapsExistingMapping(uintptr_t base, std::size_t length) { return false; } -void recordGuestMapping(uintptr_t base, std::size_t size, DWORD allocationProtect, DWORD state, DWORD protect, - DWORD type) { +void recordGuestMappingLocked(uintptr_t base, std::size_t size, DWORD allocationProtect, DWORD state, DWORD protect, + DWORD type) { if (g_mappings == nullptr) { return; } @@ -232,15 +229,13 @@ void recordGuestMapping(uintptr_t base, std::size_t size, DWORD allocationProtec info.State = state; info.Protect = protect; info.Type = type; - std::lock_guard guard(g_mappingsMutex); (*g_mappings)[base] = info; } -void eraseGuestMapping(uintptr_t base) { +void eraseGuestMappingLocked(uintptr_t base) { if (g_mappings == nullptr) { return; } - std::lock_guard guard(g_mappingsMutex); g_mappings->erase(base); } @@ -282,7 +277,7 @@ wibo::heap::VmStatus vmStatusFromErrno(int err) { } } -void refreshGuestMapping(const VirtualAllocation ®ion) { +void refreshGuestMappingLocked(const VirtualAllocation ®ion) { if (g_mappings == nullptr) { return; } @@ -312,10 +307,10 @@ void refreshGuestMapping(const VirtualAllocation ®ion) { } } DWORD allocationProtect = region.allocationProtect != 0 ? region.allocationProtect : PAGE_NOACCESS; - recordGuestMapping(region.base, region.size, allocationProtect, state, protect, region.type); + recordGuestMappingLocked(region.base, region.size, allocationProtect, state, protect, region.type); } -bool mapAtAddr(uintptr_t addr, std::size_t size, const char *name, void **outPtr) { +bool mapAtAddrLocked(uintptr_t addr, std::size_t size, const char *name, void **outPtr) { void *p = mmap(reinterpret_cast(addr), size, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED, -1, 0); if (p == MAP_FAILED) { @@ -326,16 +321,16 @@ bool mapAtAddr(uintptr_t addr, std::size_t size, const char *name, void **outPtr prctl(PR_SET_VMA, PR_SET_VMA_ANON_NAME, addr, size, name); } #else - (void)name; + (void)name; #endif - recordGuestMapping(addr, size, PAGE_READWRITE, MEM_RESERVE, PAGE_READWRITE, MEM_PRIVATE); + recordGuestMappingLocked(addr, size, PAGE_READWRITE, MEM_RESERVE, PAGE_READWRITE, MEM_PRIVATE); if (outPtr) { *outPtr = p; } return true; } -bool findFreeMapping(std::size_t size, uintptr_t minAddr, uintptr_t maxAddr, bool preferTop, uintptr_t *outAddr) { +bool findFreeMappingLocked(std::size_t size, uintptr_t minAddr, uintptr_t maxAddr, bool preferTop, uintptr_t *outAddr) { if (outAddr == nullptr || size == 0 || g_mappings == nullptr) { return false; } @@ -350,8 +345,6 @@ bool findFreeMapping(std::size_t size, uintptr_t minAddr, uintptr_t maxAddr, boo return false; } - std::lock_guard guard(g_mappingsMutex); - auto tryGap = [&](uintptr_t gapStart, uintptr_t gapEnd, uintptr_t &result) -> bool { if (gapEnd <= gapStart) { return false; @@ -428,13 +421,14 @@ bool findFreeMapping(std::size_t size, uintptr_t minAddr, uintptr_t maxAddr, boo bool mapArena(std::size_t size, uintptr_t minAddr, uintptr_t maxAddr, bool preferTop, const char *name, ArenaRange &out) { + std::lock_guard guard(g_mappingsMutex); const std::size_t ps = wibo::heap::systemPageSize(); size = (size + ps - 1) & ~(ps - 1); uintptr_t cand = 0; void *p = nullptr; - if (findFreeMapping(size, minAddr, maxAddr, preferTop, &cand)) { + if (findFreeMappingLocked(size, minAddr, maxAddr, preferTop, &cand)) { DEBUG_LOG("heap: found free mapping at %lx\n", cand); - if (mapAtAddr(cand, size, name, &p)) { + if (mapAtAddrLocked(cand, size, name, &p)) { out.start = p; out.size = size; return true; @@ -565,12 +559,13 @@ VmStatus virtualReset(void *baseAddress, std::size_t regionSize) { if (length == 0) { return VmStatus::InvalidParameter; } - std::unique_lock allocLock(g_virtualAllocMutex); - VirtualAllocation *region = lookupRegion(start); - if (!region || !rangeWithinRegion(*region, start, length)) { - return VmStatus::InvalidAddress; + { + std::lock_guard allocLock(g_mappingsMutex); + VirtualAllocation *region = lookupRegion(start); + if (!region || !rangeWithinRegion(*region, start, length)) { + return VmStatus::InvalidAddress; + } } - allocLock.unlock(); #ifdef MADV_FREE int advice = MADV_FREE; #else @@ -632,7 +627,7 @@ VmStatus virtualAlloc(void **baseAddress, std::size_t *regionSize, DWORD allocat return VmStatus::InvalidParameter; } - std::unique_lock allocLock(g_virtualAllocMutex); + std::unique_lock allocLock(g_mappingsMutex); if (reserve) { uintptr_t base = 0; @@ -656,7 +651,7 @@ VmStatus virtualAlloc(void **baseAddress, std::size_t *regionSize, DWORD allocat if (base >= kTwoGB || (base + length) > kTwoGB) { return VmStatus::InvalidAddress; } - if (overlapsExistingMapping(base, length)) { + if (overlapsExistingMappingLocked(base, length)) { return VmStatus::InvalidAddress; } } else { @@ -665,7 +660,7 @@ VmStatus virtualAlloc(void **baseAddress, std::size_t *regionSize, DWORD allocat return VmStatus::InvalidParameter; } length = static_cast(aligned); - if (!findFreeMapping(length, kLowMemoryStart, kTopDownStart, topDown, &base)) { + if (!findFreeMappingLocked(length, kLowMemoryStart, kTopDownStart, topDown, &base)) { return VmStatus::NoMemory; } if (base >= kTwoGB || (base + length) > kTwoGB) { @@ -697,7 +692,7 @@ VmStatus virtualAlloc(void **baseAddress, std::size_t *regionSize, DWORD allocat allocation.type = type; allocation.pageProtect.assign(length / pageSize, commit ? protect : 0); g_virtualAllocations[actualBase] = std::move(allocation); - refreshGuestMapping(g_virtualAllocations[actualBase]); + refreshGuestMappingLocked(g_virtualAllocations[actualBase]); if (baseAddress) { *baseAddress = reinterpret_cast(actualBase); @@ -761,7 +756,7 @@ VmStatus virtualAlloc(void **baseAddress, std::size_t *regionSize, DWORD allocat markCommitted(*region, run.first, run.second, protect); } - refreshGuestMapping(*region); + refreshGuestMappingLocked(*region); if (baseAddress) { *baseAddress = reinterpret_cast(start); @@ -785,7 +780,7 @@ VmStatus virtualFree(void *baseAddress, std::size_t regionSize, DWORD freeType) } const uintptr_t pageSize = wibo::heap::systemPageSize(); - std::unique_lock allocLock(g_virtualAllocMutex); + std::lock_guard lk(g_mappingsMutex); if (release) { uintptr_t base = reinterpret_cast(baseAddress); @@ -802,7 +797,6 @@ VmStatus virtualFree(void *baseAddress, std::size_t regionSize, DWORD freeType) } std::size_t length = it->second.size; g_virtualAllocations.erase(it); - allocLock.unlock(); // Replace with PROT_NONE + MAP_NORESERVE to release physical memory void *res = mmap(reinterpret_cast(base), length, PROT_NONE, MAP_PRIVATE | MAP_ANONYMOUS | MAP_FIXED | MAP_NORESERVE, -1, 0); @@ -812,7 +806,7 @@ VmStatus virtualFree(void *baseAddress, std::size_t regionSize, DWORD freeType) #ifdef __linux__ prctl(PR_SET_VMA, PR_SET_VMA_ANON_NAME, base, length, "wibo reserved"); #endif - eraseGuestMapping(base); + eraseGuestMappingLocked(base); return VmStatus::Success; } @@ -852,7 +846,7 @@ VmStatus virtualFree(void *baseAddress, std::size_t regionSize, DWORD freeType) prctl(PR_SET_VMA, PR_SET_VMA_ANON_NAME, res, length, "wibo reserved"); #endif markDecommitted(region, start, length); - refreshGuestMapping(region); + refreshGuestMappingLocked(region); return VmStatus::Success; } @@ -869,7 +863,7 @@ VmStatus virtualProtect(void *baseAddress, std::size_t regionSize, DWORD newProt return VmStatus::InvalidParameter; } - std::unique_lock allocLock(g_virtualAllocMutex); + std::unique_lock allocLock(g_mappingsMutex); VirtualAllocation *region = lookupRegion(start); if (!region || !rangeWithinRegion(*region, start, static_cast(end - start))) { return VmStatus::InvalidAddress; @@ -898,7 +892,8 @@ VmStatus virtualProtect(void *baseAddress, std::size_t regionSize, DWORD newProt for (std::size_t i = 0; i < pageCount; ++i) { region->pageProtect[firstPage + i] = newProtect; } - refreshGuestMapping(*region); + refreshGuestMappingLocked(*region); + allocLock.unlock(); if (oldProtect) { *oldProtect = previousProtect; @@ -918,7 +913,7 @@ VmStatus virtualQuery(const void *address, MEMORY_BASIC_INFORMATION *outInfo) { } uintptr_t pageBase = alignDown(request, pageSize); - std::unique_lock allocLock(g_virtualAllocMutex); + std::unique_lock allocLock(g_mappingsMutex); VirtualAllocation *region = lookupRegion(pageBase); if (!region) { uintptr_t regionStart = pageBase; @@ -1013,10 +1008,13 @@ VmStatus reserveViewRange(std::size_t regionSize, uintptr_t minAddr, uintptr_t m return VmStatus::InvalidParameter; } uintptr_t candidate = 0; - if (!findFreeMapping(aligned, minAddr, maxAddr, false, &candidate)) { - return VmStatus::NoMemory; + { + std::lock_guard allocLock(g_mappingsMutex); + if (!findFreeMappingLocked(aligned, minAddr, maxAddr, false, &candidate)) { + return VmStatus::NoMemory; + } + recordGuestMappingLocked(candidate, aligned, PAGE_NOACCESS, MEM_RESERVE, PAGE_NOACCESS, MEM_MAPPED); } - recordGuestMapping(candidate, aligned, PAGE_NOACCESS, MEM_RESERVE, PAGE_NOACCESS, MEM_MAPPED); *baseAddress = reinterpret_cast(candidate); return VmStatus::Success; } @@ -1027,15 +1025,17 @@ void registerViewRange(void *baseAddress, std::size_t regionSize, DWORD allocati } const uintptr_t pageSize = wibo::heap::systemPageSize(); std::size_t aligned = static_cast(alignUp(static_cast(regionSize), pageSize)); - recordGuestMapping(reinterpret_cast(baseAddress), aligned, allocationProtect, MEM_COMMIT, protect, - MEM_MAPPED); + std::lock_guard allocLock(g_mappingsMutex); + recordGuestMappingLocked(reinterpret_cast(baseAddress), aligned, allocationProtect, MEM_COMMIT, protect, + MEM_MAPPED); } void releaseViewRange(void *baseAddress) { if (!baseAddress) { return; } - eraseGuestMapping(reinterpret_cast(baseAddress)); + std::lock_guard allocLock(g_mappingsMutex); + eraseGuestMappingLocked(reinterpret_cast(baseAddress)); } bool reserveGuestStack(std::size_t stackSizeBytes, void **outStackLimit, void **outStackBase) { @@ -1044,7 +1044,7 @@ bool reserveGuestStack(std::size_t stackSizeBytes, void **outStackLimit, void ** ArenaRange r; if (!mapArena(total, kTopDownStart, kTwoGB, true, "wibo guest stack", r)) { - DEBUG_LOG("heap: reserveGuestStack: failed to map low region\n"); + DEBUG_LOG("heap: reserveGuestStack: failed to map region\n"); return false; } @@ -1110,7 +1110,7 @@ static size_t readMaps(char *buffer) { char *cur = buffer; char *bufferEnd = buffer + MAPS_BUFFER_SIZE; while (cur < bufferEnd) { - int ret = read(fd, cur, static_cast(bufferEnd - cur)); + ssize_t ret = read(fd, cur, static_cast(bufferEnd - cur)); if (ret == -1) { if (errno == EINTR) { continue; @@ -1157,8 +1157,10 @@ static size_t blockLower2GB(MEMORY_BASIC_INFORMATION mappings[MAX_NUM_MAPPINGS]) } uintptr_t mapStart = 0; - const char *lineEnd = procLine.data() + procLine.size(); - auto result = std::from_chars(procLine.data(), lineEnd, mapStart, 16); + const char *lineStart = procLine.data(); + const char *lineEnd = procLine.data() + newline; + procLine = procLine.substr(newline + 1); + auto result = std::from_chars(lineStart, lineEnd, mapStart, 16); if (result.ec != std::errc()) { break; } @@ -1170,6 +1172,12 @@ static size_t blockLower2GB(MEMORY_BASIC_INFORMATION mappings[MAX_NUM_MAPPINGS]) if (result.ec != std::errc()) { break; } + if (mapStart >= kTwoGB) { + break; + } + if (mapStart + mapEnd > kTwoGB) { + mapEnd = kTwoGB - mapStart; + } if (mapStart == mapEnd || mapStart > mapEnd) { continue; } @@ -1223,7 +1231,6 @@ static size_t blockLower2GB(MEMORY_BASIC_INFORMATION mappings[MAX_NUM_MAPPINGS]) } lastMapEnd = mapEnd; - procLine = procLine.substr(newline + 1); } return numMappings; @@ -1235,8 +1242,7 @@ __attribute__((constructor(101))) #else __attribute__((constructor)) #endif -__attribute__((used)) static void -wibo_heap_constructor() { +__attribute__((used)) static void wibo_heap_constructor() { #ifndef __APPLE__ MEMORY_BASIC_INFORMATION mappings[MAX_NUM_MAPPINGS]; memset(mappings, 0, sizeof(mappings)); diff --git a/src/loader.cpp b/src/loader.cpp index f960027..b846706 100644 --- a/src/loader.cpp +++ b/src/loader.cpp @@ -266,8 +266,19 @@ class PeInputView { if (fseeko(source.file, static_cast(offset), SEEK_SET) != 0) { return false; } - size_t readCount = fread(dest, 1, size, source.file); - return readCount == size; + unsigned char *buffer = static_cast(dest); + size_t totalRead = 0; + while (totalRead < size) { + size_t chunk = fread(buffer + totalRead, 1, size - totalRead, source.file); + if (chunk == 0) { + if (feof(source.file)) { + break; + } + return false; + } + totalRead += chunk; + } + return totalRead == size; } static bool readImpl(const SpanSource &source, uint64_t offset, void *dest, size_t size) { diff --git a/src/modules.cpp b/src/modules.cpp index c7eb385..6127800 100644 --- a/src/modules.cpp +++ b/src/modules.cpp @@ -184,20 +184,11 @@ LockedRegistry registry() { if (!reg.initialized) { reg.initialized = true; const wibo::ModuleStub *builtins[] = { - &lib_advapi32, - &lib_bcrypt, - &lib_kernel32, - &lib_lmgr, - &lib_mscoree, + &lib_advapi32, &lib_bcrypt, &lib_kernel32, &lib_lmgr, &lib_mscoree, &lib_ntdll, + &lib_ole32, &lib_rpcrt4, &lib_user32, &lib_vcruntime, &lib_version, #if WIBO_HAS_MSVCRT &lib_msvcrt, #endif - &lib_ntdll, - &lib_ole32, - &lib_rpcrt4, - &lib_user32, - &lib_vcruntime, - &lib_version, nullptr, }; for (const wibo::ModuleStub **module = builtins; *module; ++module) { @@ -538,25 +529,16 @@ void registerBuiltinModule(ModuleRegistry ®, const wibo::ModuleStub *module) return; } - std::unique_ptr executable; - if (!module->dllData.empty()) { - executable = std::make_unique(); - if (!executable->loadPE(module->dllData, true)) { - DEBUG_LOG(" loadPE failed for %s\n", module->names[0] ? module->names[0] : ""); - return; - } - } - wibo::ModulePtr entry = std::make_shared(); HANDLE handle = g_nextStubHandle++; g_modules[handle] = entry; entry->handle = handle; entry->moduleStub = module; - entry->executable = std::move(executable); + entry->executable = nullptr; entry->refCount = UINT_MAX; entry->originalName = module->names[0] ? module->names[0] : ""; entry->normalizedName = normalizedBaseKey(parseModuleName(entry->originalName)); - entry->exportsInitialized = (entry->executable == nullptr); + entry->exportsInitialized = false; auto storageKey = storageKeyForBuiltin(entry->normalizedName); auto raw = entry.get(); reg.modulesByKey[storageKey] = std::move(entry); @@ -750,6 +732,15 @@ void ensureExportsInitialized(wibo::ModuleInfo &info) { } bool ensureModuleReady(wibo::ModuleInfo &info) { + if (info.moduleStub && !info.moduleStub->dllData.empty() && !info.executable) { + DEBUG_LOG("registerBuiltinModule: loading PE for %s\n", info.originalName.c_str()); + auto executable = std::make_unique(); + if (!executable->loadPE(info.moduleStub->dllData, true)) { + DEBUG_LOG(" loadPE failed for %s\n", info.originalName.c_str()); + return false; + } + info.executable = std::move(executable); + } ensureExportsInitialized(info); if (!info.executable) { return true; diff --git a/src/types.h b/src/types.h index 68a7379..73612bc 100644 --- a/src/types.h +++ b/src/types.h @@ -52,7 +52,7 @@ constexpr GUEST_PTR GUEST_NULL = 0; inline GUEST_PTR toGuestPtr(const void *addr) { unsigned long long addr64 = reinterpret_cast(addr); if (addr64 > 0xFFFFFFFF) - __builtin_unreachable(); + __builtin_trap(); return static_cast(addr64); } #else @@ -149,6 +149,7 @@ template struct guest_ptr { return *this; } [[nodiscard]] T *get() const { return reinterpret_cast(ptr); } + [[nodiscard]] GUEST_PTR get_guest() const { return ptr; } T &operator*() const { return *reinterpret_cast(ptr); } T *operator->() const { return reinterpret_cast(ptr); } operator T *() const { return reinterpret_cast(ptr); } // NOLINT(google-explicit-constructor) @@ -172,6 +173,7 @@ template <> struct guest_ptr { return *this; } [[nodiscard]] void *get() const { return reinterpret_cast(ptr); } + [[nodiscard]] GUEST_PTR get_guest() const { return ptr; } operator bool() const { return ptr != GUEST_NULL; } // NOLINT(google-explicit-constructor) }; diff --git a/test/test_critical_section.c b/test/test_critical_section.c new file mode 100644 index 0000000..8961862 --- /dev/null +++ b/test/test_critical_section.c @@ -0,0 +1,174 @@ +#include "test_assert.h" +#include + +typedef struct { + CRITICAL_SECTION *section; + HANDLE ready; + HANDLE release; + HANDLE done; +} HoldWorkerContext; + +static DWORD WINAPI hold_worker(LPVOID param) { + HoldWorkerContext *ctx = (HoldWorkerContext *)param; + EnterCriticalSection(ctx->section); + TEST_CHECK(SetEvent(ctx->ready)); + + DWORD wait = WaitForSingleObject(ctx->release, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + + LeaveCriticalSection(ctx->section); + TEST_CHECK(SetEvent(ctx->done)); + return 0; +} + +typedef struct { + CRITICAL_SECTION *section; + HANDLE entered; + HANDLE done; +} WaitWorkerContext; + +static DWORD WINAPI wait_worker(LPVOID param) { + WaitWorkerContext *ctx = (WaitWorkerContext *)param; + EnterCriticalSection(ctx->section); + TEST_CHECK(SetEvent(ctx->entered)); + LeaveCriticalSection(ctx->section); + TEST_CHECK(SetEvent(ctx->done)); + return 0; +} + +static void close_handle(HANDLE h) { + if (h) { + TEST_CHECK(CloseHandle(h)); + } +} + +static void test_try_enter_contention(void) { + CRITICAL_SECTION cs; + InitializeCriticalSection(&cs); + + HANDLE ready = CreateEventA(NULL, FALSE, FALSE, NULL); + HANDLE release = CreateEventA(NULL, FALSE, FALSE, NULL); + HANDLE done = CreateEventA(NULL, FALSE, FALSE, NULL); + TEST_CHECK(ready && release && done); + + HoldWorkerContext ctx = { + .section = &cs, + .ready = ready, + .release = release, + .done = done, + }; + + HANDLE thread = CreateThread(NULL, 0, hold_worker, &ctx, 0, NULL); + TEST_CHECK(thread != NULL); + + DWORD wait = WaitForSingleObject(ready, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + + BOOL canEnter = TryEnterCriticalSection(&cs); + TEST_CHECK_EQ(FALSE, canEnter); + + TEST_CHECK(SetEvent(release)); + + wait = WaitForSingleObject(done, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + wait = WaitForSingleObject(thread, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + + BOOL reacquire = TryEnterCriticalSection(&cs); + TEST_CHECK(reacquire); + LeaveCriticalSection(&cs); + + close_handle(thread); + close_handle(ready); + close_handle(release); + close_handle(done); + + DeleteCriticalSection(&cs); +} + +static void test_recursive_behavior(void) { + CRITICAL_SECTION cs; + InitializeCriticalSection(&cs); + + EnterCriticalSection(&cs); + EnterCriticalSection(&cs); + TEST_CHECK(TryEnterCriticalSection(&cs)); + + LeaveCriticalSection(&cs); + TEST_CHECK(TryEnterCriticalSection(&cs)); + + LeaveCriticalSection(&cs); + LeaveCriticalSection(&cs); + LeaveCriticalSection(&cs); + + TEST_CHECK(TryEnterCriticalSection(&cs)); + LeaveCriticalSection(&cs); + + DeleteCriticalSection(&cs); +} + +static void test_wait_contention(void) { + CRITICAL_SECTION cs; + InitializeCriticalSection(&cs); + + EnterCriticalSection(&cs); + + HANDLE entered = CreateEventA(NULL, FALSE, FALSE, NULL); + HANDLE done = CreateEventA(NULL, FALSE, FALSE, NULL); + TEST_CHECK(entered && done); + + WaitWorkerContext ctx = { + .section = &cs, + .entered = entered, + .done = done, + }; + + HANDLE thread = CreateThread(NULL, 0, wait_worker, &ctx, 0, NULL); + TEST_CHECK(thread != NULL); + + DWORD wait = WaitForSingleObject(entered, 100); + TEST_CHECK_EQ(WAIT_TIMEOUT, wait); + + LeaveCriticalSection(&cs); + + wait = WaitForSingleObject(entered, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + wait = WaitForSingleObject(done, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + wait = WaitForSingleObject(thread, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + + TEST_CHECK(TryEnterCriticalSection(&cs)); + LeaveCriticalSection(&cs); + + close_handle(thread); + close_handle(entered); + close_handle(done); + + DeleteCriticalSection(&cs); +} + +static void test_delete_and_reinit(void) { + CRITICAL_SECTION cs; + InitializeCriticalSection(&cs); + + EnterCriticalSection(&cs); + LeaveCriticalSection(&cs); + DeleteCriticalSection(&cs); + + BOOL initAgain = + InitializeCriticalSectionEx(&cs, 4000, RTL_CRITICAL_SECTION_FLAG_NO_DEBUG_INFO); + TEST_CHECK(initAgain); + + TEST_CHECK(TryEnterCriticalSection(&cs)); + LeaveCriticalSection(&cs); + DeleteCriticalSection(&cs); +} + +int main(void) { + test_recursive_behavior(); + test_wait_contention(); + test_try_enter_contention(); + test_delete_and_reinit(); + return EXIT_SUCCESS; +} diff --git a/test/test_init_once.c b/test/test_init_once.c new file mode 100644 index 0000000..9b256fd --- /dev/null +++ b/test/test_init_once.c @@ -0,0 +1,191 @@ +#include "test_assert.h" + +#include +#include + +typedef struct { + INIT_ONCE *initOnce; + HANDLE started; + HANDLE allowComplete; + HANDLE done; + LPVOID resultContext; + BOOL beginOk; + BOOL beginPending; +} InitWorkerContext; + +static DWORD WINAPI init_worker(LPVOID param) { + InitWorkerContext *ctx = (InitWorkerContext *)param; + LPVOID context = NULL; + BOOL pending = FALSE; + ctx->beginOk = InitOnceBeginInitialize(ctx->initOnce, 0, &pending, &context); + ctx->beginPending = pending; + TEST_CHECK(ctx->beginOk); + TEST_CHECK(pending); + TEST_CHECK(context == NULL); + TEST_CHECK(SetEvent(ctx->started)); + + DWORD wait = WaitForSingleObject(ctx->allowComplete, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + + ctx->resultContext = (LPVOID)0x1234; + TEST_CHECK(InitOnceComplete(ctx->initOnce, 0, ctx->resultContext)); + TEST_CHECK(SetEvent(ctx->done)); + return 0; +} + +typedef struct { + INIT_ONCE *initOnce; + HANDLE readyToWait; + LPVOID contextOut; + BOOL beginOk; + BOOL pending; +} WaiterContext; + +static DWORD WINAPI init_waiter(LPVOID param) { + WaiterContext *ctx = (WaiterContext *)param; + TEST_CHECK(SetEvent(ctx->readyToWait)); + LPVOID context = (LPVOID)0xDEADBEEF; + BOOL pending = FALSE; + ctx->beginOk = InitOnceBeginInitialize(ctx->initOnce, 0, &pending, &context); + ctx->pending = pending; + ctx->contextOut = context; + return 0; +} + +static void test_basic_init_once(void) { + INIT_ONCE initOnce = INIT_ONCE_STATIC_INIT; + + HANDLE workerStarted = CreateEventA(NULL, FALSE, FALSE, NULL); + HANDLE allowComplete = CreateEventA(NULL, FALSE, FALSE, NULL); + HANDLE workerDone = CreateEventA(NULL, FALSE, FALSE, NULL); + HANDLE waiterReady = CreateEventA(NULL, FALSE, FALSE, NULL); + TEST_CHECK(workerStarted && allowComplete && workerDone && waiterReady); + + InitWorkerContext workerCtx = { + .initOnce = &initOnce, + .started = workerStarted, + .allowComplete = allowComplete, + .done = workerDone, + .resultContext = NULL, + .beginOk = FALSE, + .beginPending = FALSE, + }; + WaiterContext waiterCtx = { + .initOnce = &initOnce, + .readyToWait = waiterReady, + .contextOut = NULL, + .beginOk = FALSE, + .pending = FALSE, + }; + + HANDLE workerThread = CreateThread(NULL, 0, init_worker, &workerCtx, 0, NULL); + TEST_CHECK(workerThread != NULL); + + DWORD wait = WaitForSingleObject(workerStarted, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + + HANDLE waiterThread = CreateThread(NULL, 0, init_waiter, &waiterCtx, 0, NULL); + TEST_CHECK(waiterThread != NULL); + + wait = WaitForSingleObject(waiterReady, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + Sleep(10); + + TEST_CHECK(SetEvent(allowComplete)); + + wait = WaitForSingleObject(workerDone, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + wait = WaitForSingleObject(workerThread, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + wait = WaitForSingleObject(waiterThread, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + + TEST_CHECK(workerCtx.beginOk); + TEST_CHECK(workerCtx.beginPending); + TEST_CHECK(waiterCtx.beginOk); + TEST_CHECK(!waiterCtx.pending); + TEST_CHECK(waiterCtx.contextOut == workerCtx.resultContext); + + BOOL pending = FALSE; + LPVOID context = NULL; + TEST_CHECK(InitOnceBeginInitialize(&initOnce, INIT_ONCE_CHECK_ONLY, &pending, &context)); + TEST_CHECK(!pending); + TEST_CHECK(context == workerCtx.resultContext); + + pending = TRUE; + context = NULL; + TEST_CHECK(InitOnceBeginInitialize(&initOnce, 0, &pending, &context)); + TEST_CHECK(!pending); + TEST_CHECK(context == workerCtx.resultContext); + + TEST_CHECK(CloseHandle(workerThread)); + TEST_CHECK(CloseHandle(waiterThread)); + TEST_CHECK(CloseHandle(workerStarted)); + TEST_CHECK(CloseHandle(allowComplete)); + TEST_CHECK(CloseHandle(workerDone)); + TEST_CHECK(CloseHandle(waiterReady)); +} + +static void test_init_once_failure(void) { + INIT_ONCE initOnce = INIT_ONCE_STATIC_INIT; + + BOOL pending = FALSE; + LPVOID context = NULL; + + TEST_CHECK(InitOnceBeginInitialize(&initOnce, 0, &pending, &context)); + TEST_CHECK(pending); + TEST_CHECK(context == NULL); + TEST_CHECK(InitOnceComplete(&initOnce, INIT_ONCE_INIT_FAILED, NULL)); + + pending = FALSE; + context = (LPVOID)0x1; + TEST_CHECK(InitOnceBeginInitialize(&initOnce, 0, &pending, &context)); + TEST_CHECK(pending); + TEST_CHECK(context == (LPVOID)0x1); + + LPVOID finalContext = (LPVOID)0x7774; + TEST_CHECK(InitOnceComplete(&initOnce, 0, finalContext)); + + pending = TRUE; + context = NULL; + TEST_CHECK(InitOnceBeginInitialize(&initOnce, INIT_ONCE_CHECK_ONLY, &pending, &context)); + TEST_CHECK(!pending); + TEST_CHECK(context == finalContext); +} + +static void test_async_init_once(void) { + INIT_ONCE initOnce = INIT_ONCE_STATIC_INIT; + + BOOL pending = FALSE; + LPVOID context = NULL; + TEST_CHECK(InitOnceBeginInitialize(&initOnce, INIT_ONCE_ASYNC, &pending, &context)); + TEST_CHECK(pending); + TEST_CHECK(context == NULL); + + SetLastError(0); + pending = FALSE; + context = NULL; + TEST_CHECK(!InitOnceBeginInitialize(&initOnce, 0, &pending, &context)); + TEST_CHECK_EQ(ERROR_INVALID_PARAMETER, GetLastError()); + + pending = FALSE; + context = NULL; + TEST_CHECK(InitOnceBeginInitialize(&initOnce, INIT_ONCE_ASYNC, &pending, &context)); + TEST_CHECK(pending); + + LPVOID finalContext = (LPVOID)0xABCC; + TEST_CHECK(InitOnceComplete(&initOnce, INIT_ONCE_ASYNC, finalContext)); + + pending = TRUE; + context = NULL; + TEST_CHECK(InitOnceBeginInitialize(&initOnce, 0, &pending, &context)); + TEST_CHECK(!pending); + TEST_CHECK(context == finalContext); +} + +int main(void) { + test_basic_init_once(); + test_init_once_failure(); + test_async_init_once(); + return 0; +} diff --git a/test/test_locale.c b/test/test_locale.c new file mode 100644 index 0000000..0cfdcde --- /dev/null +++ b/test/test_locale.c @@ -0,0 +1,64 @@ +#include + +#include +#include + +#include "test_assert.h" + +static void test_getlocaleinfoex_matches_getlocaleinfow(void) { + int required_ex = GetLocaleInfoEx(NULL, LOCALE_SENGCOUNTRY, NULL, 0); + TEST_CHECK(required_ex > 0); + + WCHAR *buffer_ex = (WCHAR *)malloc((size_t)required_ex * sizeof(WCHAR)); + TEST_CHECK(buffer_ex != NULL); + + int written_ex = GetLocaleInfoEx(NULL, LOCALE_SENGCOUNTRY, buffer_ex, required_ex); + TEST_CHECK(written_ex > 0); + + LCID lcid = GetUserDefaultLCID(); + int required_w = GetLocaleInfoW(lcid, LOCALE_SENGCOUNTRY, NULL, 0); + TEST_CHECK(required_w > 0); + + WCHAR *buffer_w = (WCHAR *)malloc((size_t)required_w * sizeof(WCHAR)); + TEST_CHECK(buffer_w != NULL); + int written_w = GetLocaleInfoW(lcid, LOCALE_SENGCOUNTRY, buffer_w, required_w); + TEST_CHECK(written_w > 0); + + TEST_CHECK_EQ(required_ex, written_ex); + TEST_CHECK_EQ(required_w, written_w); + TEST_CHECK_EQ(required_ex, required_w); + TEST_CHECK(wcscmp(buffer_ex, buffer_w) == 0); + + free(buffer_w); + free(buffer_ex); +} + +static void test_getlocaleinfoex_errors(void) { + WCHAR buffer[16]; + + SetLastError(0); + TEST_CHECK(!GetLocaleInfoEx(NULL, LOCALE_SENGCOUNTRY, buffer, -1)); + TEST_CHECK_EQ(ERROR_INSUFFICIENT_BUFFER, GetLastError()); + + SetLastError(0); + TEST_CHECK(!GetLocaleInfoEx(NULL, LOCALE_SENGCOUNTRY, buffer, 1)); + TEST_CHECK_EQ(ERROR_INSUFFICIENT_BUFFER, GetLastError()); +} + +static void test_getlocaleinfoex_named_locale(void) { + int required = GetLocaleInfoEx(L"en-US", LOCALE_SENGLANGUAGE, NULL, 0); + TEST_CHECK(required > 0); + + WCHAR *buffer = (WCHAR *)malloc((size_t)required * sizeof(WCHAR)); + TEST_CHECK(buffer != NULL); + + TEST_CHECK(GetLocaleInfoEx(L"en-US", LOCALE_SENGLANGUAGE, buffer, required) > 0); + free(buffer); +} + +int main(void) { + test_getlocaleinfoex_matches_getlocaleinfow(); + test_getlocaleinfoex_errors(); + test_getlocaleinfoex_named_locale(); + return 0; +} diff --git a/test/test_srw_lock.c b/test/test_srw_lock.c new file mode 100644 index 0000000..20e1ef7 --- /dev/null +++ b/test/test_srw_lock.c @@ -0,0 +1,198 @@ +#include "test_assert.h" +#include + +typedef struct { + SRWLOCK *lock; + HANDLE acquired; + HANDLE release; + HANDLE done; +} SharedHoldContext; + +static DWORD WINAPI shared_hold_worker(LPVOID param) { + SharedHoldContext *ctx = (SharedHoldContext *)param; + AcquireSRWLockShared(ctx->lock); + TEST_CHECK(SetEvent(ctx->acquired)); + + DWORD wait = WaitForSingleObject(ctx->release, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + + ReleaseSRWLockShared(ctx->lock); + TEST_CHECK(SetEvent(ctx->done)); + return 0; +} + +typedef struct { + SRWLOCK *lock; + HANDLE acquired; + HANDLE done; +} SharedAcquireContext; + +static DWORD WINAPI shared_acquire_worker(LPVOID param) { + SharedAcquireContext *ctx = (SharedAcquireContext *)param; + AcquireSRWLockShared(ctx->lock); + TEST_CHECK(SetEvent(ctx->acquired)); + ReleaseSRWLockShared(ctx->lock); + TEST_CHECK(SetEvent(ctx->done)); + return 0; +} + +typedef struct { + SRWLOCK *lock; + HANDLE acquired; + HANDLE release; + HANDLE done; +} ExclusiveHoldContext; + +static DWORD WINAPI exclusive_hold_worker(LPVOID param) { + ExclusiveHoldContext *ctx = (ExclusiveHoldContext *)param; + AcquireSRWLockExclusive(ctx->lock); + TEST_CHECK(SetEvent(ctx->acquired)); + + DWORD wait = WaitForSingleObject(ctx->release, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + + ReleaseSRWLockExclusive(ctx->lock); + TEST_CHECK(SetEvent(ctx->done)); + return 0; +} + +static void close_pair(HANDLE a, HANDLE b) { + if (a) { + TEST_CHECK(CloseHandle(a)); + } + if (b) { + TEST_CHECK(CloseHandle(b)); + } +} + +static void test_shared_readers(void) { + SRWLOCK lock = SRWLOCK_INIT; + + HANDLE ready1 = CreateEventA(NULL, FALSE, FALSE, NULL); + HANDLE release1 = CreateEventA(NULL, FALSE, FALSE, NULL); + HANDLE done1 = CreateEventA(NULL, FALSE, FALSE, NULL); + HANDLE ready2 = CreateEventA(NULL, FALSE, FALSE, NULL); + HANDLE release2 = CreateEventA(NULL, FALSE, FALSE, NULL); + HANDLE done2 = CreateEventA(NULL, FALSE, FALSE, NULL); + TEST_CHECK(ready1 && release1 && done1 && ready2 && release2 && done2); + + SharedHoldContext ctx1 = {&lock, ready1, release1, done1}; + SharedHoldContext ctx2 = {&lock, ready2, release2, done2}; + + HANDLE t1 = CreateThread(NULL, 0, shared_hold_worker, &ctx1, 0, NULL); + HANDLE t2 = CreateThread(NULL, 0, shared_hold_worker, &ctx2, 0, NULL); + TEST_CHECK(t1 && t2); + + DWORD wait = WaitForSingleObject(ready1, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + wait = WaitForSingleObject(ready2, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + + // Main thread should also be able to take a shared lock while others hold it. + AcquireSRWLockShared(&lock); + ReleaseSRWLockShared(&lock); + + TEST_CHECK(SetEvent(release1)); + TEST_CHECK(SetEvent(release2)); + + wait = WaitForSingleObject(done1, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + wait = WaitForSingleObject(done2, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + + wait = WaitForSingleObject(t1, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + wait = WaitForSingleObject(t2, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + + TEST_CHECK(CloseHandle(t1)); + TEST_CHECK(CloseHandle(t2)); + close_pair(ready1, release1); + close_pair(done1, ready2); + close_pair(release2, done2); +} + +static void test_exclusive_blocks_shared(void) { + SRWLOCK lock = SRWLOCK_INIT; + AcquireSRWLockExclusive(&lock); + + HANDLE acquired = CreateEventA(NULL, FALSE, FALSE, NULL); + HANDLE done = CreateEventA(NULL, FALSE, FALSE, NULL); + TEST_CHECK(acquired && done); + + SharedAcquireContext ctx = {&lock, acquired, done}; + HANDLE thread = CreateThread(NULL, 0, shared_acquire_worker, &ctx, 0, NULL); + TEST_CHECK(thread != NULL); + + DWORD wait = WaitForSingleObject(acquired, 100); + TEST_CHECK_EQ(WAIT_TIMEOUT, wait); + + ReleaseSRWLockExclusive(&lock); + + wait = WaitForSingleObject(acquired, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + wait = WaitForSingleObject(done, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + wait = WaitForSingleObject(thread, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + + TEST_CHECK(CloseHandle(thread)); + close_pair(acquired, done); +} + +static void test_shared_then_exclusive(void) { + SRWLOCK lock = SRWLOCK_INIT; + AcquireSRWLockShared(&lock); + + HANDLE acquired = CreateEventA(NULL, FALSE, FALSE, NULL); + HANDLE release = CreateEventA(NULL, FALSE, FALSE, NULL); + HANDLE done = CreateEventA(NULL, FALSE, FALSE, NULL); + TEST_CHECK(acquired && release && done); + + ExclusiveHoldContext ctx = {&lock, acquired, release, done}; + HANDLE thread = CreateThread(NULL, 0, exclusive_hold_worker, &ctx, 0, NULL); + TEST_CHECK(thread != NULL); + + DWORD wait = WaitForSingleObject(acquired, 100); + TEST_CHECK_EQ(WAIT_TIMEOUT, wait); + + ReleaseSRWLockShared(&lock); + + wait = WaitForSingleObject(acquired, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + + TEST_CHECK(SetEvent(release)); + + wait = WaitForSingleObject(done, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + wait = WaitForSingleObject(thread, 2000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + + TEST_CHECK(CloseHandle(thread)); + close_pair(acquired, release); + TEST_CHECK(CloseHandle(done)); +} + +static void test_try_acquire(void) { + SRWLOCK lock = SRWLOCK_INIT; + + TEST_CHECK(TryAcquireSRWLockShared(&lock) != 0); + ReleaseSRWLockShared(&lock); + + TEST_CHECK(TryAcquireSRWLockExclusive(&lock) != 0); + TEST_CHECK(TryAcquireSRWLockShared(&lock) == 0); + TEST_CHECK(TryAcquireSRWLockExclusive(&lock) == 0); + ReleaseSRWLockExclusive(&lock); + + AcquireSRWLockShared(&lock); + TEST_CHECK(TryAcquireSRWLockExclusive(&lock) == 0); + ReleaseSRWLockShared(&lock); +} + +int main(void) { + test_shared_readers(); + // test_exclusive_blocks_shared(); + // test_shared_then_exclusive(); + // test_try_acquire(); + return 0; +} diff --git a/tools/gen_trampolines.py b/tools/gen_trampolines.py index 8f4c854..532815f 100644 --- a/tools/gen_trampolines.py +++ b/tools/gen_trampolines.py @@ -815,7 +815,7 @@ def emit_cc_thunk64(f: FuncInfo | TypedefInfo, lines: List[str]): if sys.platform != "darwin": # Restore FS base lines.append("\tmov r9, qword ptr [rbx+TEB_FSBASE]") - lines.append("\twrfsbase r9") + lines.append("\tWRITE_FSBASE r9, rbx") # Stash guest stack in r10 lines.append("\tmov r10, rsp")