From 09a7452c777a9d4020d4d35e1b86d120fd6445dc Mon Sep 17 00:00:00 2001 From: Luke Street Date: Tue, 11 Nov 2025 10:44:14 -0700 Subject: [PATCH] Implement WaitOnAddress, WakeByAddress*; macOS impl for atomic waits --- CMakeLists.txt | 3 + cmake/toolchains/x86_64-darwin.cmake | 2 +- dll/kernel32/synchapi.cpp | 469 +++++++++++++++++++++++---- dll/kernel32/synchapi.h | 3 + src/errors.h | 1 + src/modules.cpp | 3 +- src/setup.S | 32 +- src/setup_darwin.cpp | 16 +- test/test_wait_on_address.c | 142 ++++++++ 9 files changed, 589 insertions(+), 82 deletions(-) create mode 100644 test/test_wait_on_address.c diff --git a/CMakeLists.txt b/CMakeLists.txt index 44cad45..43694ae 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -258,6 +258,8 @@ target_include_directories(wibo PRIVATE dll src ${WIBO_GENERATED_HEADER_DIR}) target_link_libraries(wibo PRIVATE mimalloc-obj) if (CMAKE_SYSTEM_NAME STREQUAL "Linux") target_link_libraries(wibo PRIVATE atomic) +elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin") + target_link_libraries(wibo PRIVATE c++ c++abi) endif() if (WIBO_ENABLE_LIBURING AND CMAKE_SYSTEM_NAME STREQUAL "Linux") target_compile_definitions(wibo PRIVATE WIBO_ENABLE_LIBURING=1) @@ -465,6 +467,7 @@ if (WIBO_ENABLE_FIXTURE_TESTS) wibo_add_fixture_bin(NAME test_sysdir SOURCES test/test_sysdir.c) wibo_add_fixture_bin(NAME test_srw_lock SOURCES test/test_srw_lock.c) wibo_add_fixture_bin(NAME test_init_once SOURCES test/test_init_once.c) + wibo_add_fixture_bin(NAME test_wait_on_address SOURCES test/test_wait_on_address.c COMPILE_OPTIONS -lsynchronization) # DLLs for fixture tests wibo_add_fixture_dll(NAME external_exports SOURCES test/external_exports.c) diff --git a/cmake/toolchains/x86_64-darwin.cmake b/cmake/toolchains/x86_64-darwin.cmake index 1cb242a..9829948 100644 --- a/cmake/toolchains/x86_64-darwin.cmake +++ b/cmake/toolchains/x86_64-darwin.cmake @@ -10,7 +10,7 @@ set(CMAKE_ASM_COMPILER_TARGET ${TARGET}) # Force x86_64 architecture set(CMAKE_OSX_ARCHITECTURES "x86_64" CACHE STRING "Build architecture for macOS" FORCE) -set(CMAKE_OSX_DEPLOYMENT_TARGET "10.15" CACHE STRING "Minimum macOS deployment version" FORCE) +set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS deployment version" FORCE) # Search for programs in the build host directories set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) diff --git a/dll/kernel32/synchapi.cpp b/dll/kernel32/synchapi.cpp index 16b941f..88e08d3 100644 --- a/dll/kernel32/synchapi.cpp +++ b/dll/kernel32/synchapi.cpp @@ -5,16 +5,20 @@ #include "errors.h" #include "handles.h" #include "heap.h" +#include "interlockedapi.h" #include "internal.h" #include "processthreadsapi.h" #include "strutil.h" #include "types.h" #include +#include #include #include #include +#include #include +#include #include #include #include @@ -24,6 +28,7 @@ #include #include #include +#include #include namespace { @@ -33,6 +38,292 @@ constexpr DWORD kSrwLockSharedIncrement = 0x2u; constexpr GUEST_PTR kInitOnceStateMask = 0x3u; constexpr GUEST_PTR kInitOnceCompletedFlag = 0x2u; constexpr GUEST_PTR kInitOnceReservedMask = (1u << INIT_ONCE_CTX_RESERVED_BITS) - 1; +constexpr size_t kSupportedAddressSizes = 4; // 1, 2, 4, 8 bytes + +struct AddressWaitQueue { + std::mutex mutex; + std::condition_variable cv; + size_t waiterCount = 0; + std::array sizeCounts{}; +}; + +std::mutex g_waitAddressMutex; +std::unordered_map> g_waitAddressQueues; + +constexpr size_t sizeToIndex(size_t size) { + size_t index = __builtin_ctz(size); + return index >= kSupportedAddressSizes ? -1 : index; +} + +std::shared_ptr getWaitQueue(void *address) { + std::lock_guard lk(g_waitAddressMutex); + auto &slot = g_waitAddressQueues[address]; + auto queue = slot.lock(); + if (!queue) { + queue = std::make_shared(); + slot = queue; + } + return queue; +} + +std::shared_ptr tryGetWaitQueue(void *address) { + std::lock_guard lk(g_waitAddressMutex); + auto it = g_waitAddressQueues.find(address); + if (it == g_waitAddressQueues.end()) { + return nullptr; + } + auto queue = it->second.lock(); + if (!queue) { + g_waitAddressQueues.erase(it); + return nullptr; + } + return queue; +} + +void cleanupWaitQueue(void *address, const std::shared_ptr &queue) { + std::lock_guard lk(g_waitAddressMutex); + auto it = g_waitAddressQueues.find(address); + if (it == g_waitAddressQueues.end()) { + return; + } + auto locked = it->second.lock(); + if (!locked) { + g_waitAddressQueues.erase(it); + return; + } + if (locked.get() != queue.get()) { + return; + } + std::lock_guard queueLock(queue->mutex); + if (queue->waiterCount == 0) { + g_waitAddressQueues.erase(it); + } +} + +struct WaitRegistration { + void *address; + std::shared_ptr queue; + size_t sizeIndex; + bool registered = false; + + WaitRegistration(void *addr, std::shared_ptr q, size_t idx) + : address(addr), queue(std::move(q)), sizeIndex(idx) {} + + void registerWaiter() { + if (!queue) { + return; + } + std::lock_guard lk(queue->mutex); + queue->waiterCount++; + queue->sizeCounts[sizeIndex]++; + registered = true; + } + + void unregister() { + if (!queue || !registered) { + return; + } + std::lock_guard lk(queue->mutex); + queue->waiterCount--; + queue->sizeCounts[sizeIndex]--; + registered = false; + } + + ~WaitRegistration() { + unregister(); + if (queue) { + cleanupWaitQueue(address, queue); + } + } +}; + +#if defined(__APPLE__) + +using LibcppMonitorFn = long long (*)(const void volatile *); +using LibcppWaitFn = void (*)(const void volatile *, long long); +using LibcppNotifyFn = void (*)(const void volatile *); + +LibcppMonitorFn getLibcppAtomicMonitor() { + static LibcppMonitorFn fn = + reinterpret_cast(dlsym(RTLD_DEFAULT, "_ZNSt3__123__libcpp_atomic_monitorEPVKv")); + return fn; +} + +LibcppWaitFn getLibcppAtomicWait() { + static LibcppWaitFn fn = + reinterpret_cast(dlsym(RTLD_DEFAULT, "_ZNSt3__120__libcpp_atomic_waitEPVKvx")); + return fn; +} + +LibcppNotifyFn getLibcppAtomicNotifyOne() { + static LibcppNotifyFn fn = + reinterpret_cast(dlsym(RTLD_DEFAULT, "_ZNSt3__123__cxx_atomic_notify_oneEPVKv")); + return fn; +} + +LibcppNotifyFn getLibcppAtomicNotifyAll() { + static LibcppNotifyFn fn = + reinterpret_cast(dlsym(RTLD_DEFAULT, "_ZNSt3__123__cxx_atomic_notify_allEPVKv")); + return fn; +} + +template void platformWaitIndefinite(T volatile *address, T expected) { + auto monitorFn = getLibcppAtomicMonitor(); + auto waitFn = getLibcppAtomicWait(); + if (!monitorFn || !waitFn) { + while (__atomic_load_n(address, __ATOMIC_ACQUIRE) == expected) { + std::this_thread::sleep_for(std::chrono::microseconds(50)); + } + return; + } + while (true) { + T current = __atomic_load_n(address, __ATOMIC_ACQUIRE); + if (current != expected) { + return; + } + auto monitor = monitorFn(address); + current = __atomic_load_n(address, __ATOMIC_ACQUIRE); + if (current != expected) { + continue; + } + waitFn(address, monitor); + } +} + +inline void platformNotifyAddress(void *address, size_t, bool wakeOne) { + auto notifyFn = wakeOne ? getLibcppAtomicNotifyOne() : getLibcppAtomicNotifyAll(); + if (notifyFn) { + notifyFn(address); + } +} + +#elif defined(__linux__) + +template void platformWaitIndefinite(T volatile *address, T expected) { + std::atomic_ref ref(*address); + ref.wait(expected, std::memory_order_relaxed); +} + +template void linuxNotify(void *address, bool wakeOne) { + auto *typed = reinterpret_cast(address); + std::atomic_ref ref(*typed); + if (wakeOne) { + ref.notify_one(); + } else { + ref.notify_all(); + } +} + +inline void platformNotifyAddress(void *address, size_t size, bool wakeOne) { + switch (size) { + case 1: + linuxNotify(address, wakeOne); + break; + case 2: + linuxNotify(address, wakeOne); + break; + case 4: + linuxNotify(address, wakeOne); + break; + case 8: + linuxNotify(address, wakeOne); + break; + default: + break; + } +} + +#else + +template void platformWaitIndefinite(T volatile *address, T expected) { + while (__atomic_load_n(address, __ATOMIC_ACQUIRE) == expected) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } +} + +inline void platformNotifyAddress(void *, size_t, bool) {} + +#endif + +void notifyAtomicWaiters(void *address, const std::array &sizeCounts, bool wakeOne) { + uintptr_t addrValue = reinterpret_cast(address); + for (size_t i = 0; i < sizeCounts.size(); ++i) { + if (sizeCounts[i] == 0) { + continue; + } + size_t size = 1 << i; + if (addrValue & (size - 1)) { + continue; + } + platformNotifyAddress(address, size, wakeOne); + } +} + +template bool waitOnAddressTyped(VOID volatile *addressVoid, PVOID comparePtr, DWORD dwMilliseconds) { + auto *address = reinterpret_cast(addressVoid); + if (!address) { + kernel32::setLastError(ERROR_INVALID_PARAMETER); + return false; + } + if (reinterpret_cast(address) % alignof(T) != 0) { + kernel32::setLastError(ERROR_INVALID_PARAMETER); + return false; + } + + const T compareValue = *reinterpret_cast(comparePtr); + if (__atomic_load_n(address, __ATOMIC_ACQUIRE) != compareValue) { + return true; + } + + if (dwMilliseconds == 0) { + kernel32::setLastError(ERROR_TIMEOUT); + return false; + } + + void *queueKey = const_cast(addressVoid); + auto queue = getWaitQueue(queueKey); + if (!queue) { + kernel32::setLastError(ERROR_GEN_FAILURE); + return false; + } + + int sizeIdx = sizeToIndex(sizeof(T)); + DEBUG_LOG("size: %d, index %d\n", sizeof(T), sizeIdx); + if (sizeIdx < 0) { + kernel32::setLastError(ERROR_INVALID_PARAMETER); + return false; + } + + WaitRegistration registration(queueKey, queue, static_cast(sizeIdx)); + registration.registerWaiter(); + + if (dwMilliseconds == INFINITE) { + while (__atomic_load_n(address, __ATOMIC_ACQUIRE) == compareValue) { + platformWaitIndefinite(address, compareValue); + } + return true; + } + + const auto deadline = + std::chrono::steady_clock::now() + std::chrono::milliseconds(static_cast(dwMilliseconds)); + bool timedOut = false; + { + std::unique_lock lk(queue->mutex); + while (__atomic_load_n(address, __ATOMIC_ACQUIRE) == compareValue) { + if (queue->cv.wait_until(lk, deadline) == std::cv_status::timeout) { + if (__atomic_load_n(address, __ATOMIC_ACQUIRE) == compareValue) { + timedOut = true; + break; + } + } + } + } + if (timedOut && __atomic_load_n(address, __ATOMIC_ACQUIRE) == compareValue) { + kernel32::setLastError(ERROR_TIMEOUT); + return false; + } + return true; +} std::u16string makeU16String(LPCWSTR name) { if (!name) { @@ -167,29 +458,25 @@ void eraseInitOnceState(LPINIT_ONCE once) { g_initOnceStates.erase(once); } -inline DWORD owningThreadId(LPCRITICAL_SECTION crit) { - std::atomic_ref owner(*reinterpret_cast(&crit->OwningThread)); - return owner.load(std::memory_order_acquire); -} +inline DWORD owningThreadId(LPCRITICAL_SECTION crit) { return __atomic_load_n(&crit->OwningThread, __ATOMIC_ACQUIRE); } inline void setOwningThread(LPCRITICAL_SECTION crit, DWORD threadId) { - std::atomic_ref owner(*reinterpret_cast(&crit->OwningThread)); - owner.store(threadId, std::memory_order_release); + __atomic_store_n(&crit->OwningThread, threadId, __ATOMIC_RELEASE); } void waitForCriticalSection(LPCRITICAL_SECTION cs) { - std::atomic_ref sequence(reinterpret_cast(cs->LockSemaphore)); - LONG observed = sequence.load(std::memory_order_acquire); + auto *sequence = reinterpret_cast(&cs->LockSemaphore); + LONG observed = __atomic_load_n(sequence, __ATOMIC_ACQUIRE); while (owningThreadId(cs) != 0) { - sequence.wait(observed, std::memory_order_relaxed); - observed = sequence.load(std::memory_order_acquire); + kernel32::WaitOnAddress(sequence, &observed, sizeof(observed), INFINITE); + observed = __atomic_load_n(sequence, __ATOMIC_ACQUIRE); } } void signalCriticalSection(LPCRITICAL_SECTION cs) { - std::atomic_ref sequence(reinterpret_cast(cs->LockSemaphore)); - sequence.fetch_add(1, std::memory_order_release); - sequence.notify_one(); + auto *sequence = reinterpret_cast(&cs->LockSemaphore); + kernel32::InterlockedIncrement(const_cast(sequence)); + kernel32::WakeByAddressSingle(sequence); } inline bool trySpinAcquireCriticalSection(LPCRITICAL_SECTION cs, DWORD threadId) { @@ -634,6 +921,71 @@ DWORD WINAPI WaitForMultipleObjects(DWORD nCount, const HANDLE *lpHandles, BOOL return waitResult; } +BOOL WINAPI WaitOnAddress(VOID volatile *Address, PVOID CompareAddress, SIZE_T AddressSize, DWORD dwMilliseconds) { + HOST_CONTEXT_GUARD(); + VERBOSE_LOG("WaitOnAddress(%p, %p, %zu, %u)\n", Address, CompareAddress, AddressSize, dwMilliseconds); + if (!Address || !CompareAddress) { + setLastError(ERROR_INVALID_PARAMETER); + return FALSE; + } + + BOOL result = FALSE; + switch (sizeToIndex(AddressSize)) { + case 0: + result = waitOnAddressTyped(Address, CompareAddress, dwMilliseconds) ? TRUE : FALSE; + break; + case 1: + result = waitOnAddressTyped(Address, CompareAddress, dwMilliseconds) ? TRUE : FALSE; + break; + case 2: + result = waitOnAddressTyped(Address, CompareAddress, dwMilliseconds) ? TRUE : FALSE; + break; + case 3: + result = waitOnAddressTyped(Address, CompareAddress, dwMilliseconds) ? TRUE : FALSE; + break; + default: + setLastError(ERROR_INVALID_PARAMETER); + return FALSE; + } + return result; +} + +void WINAPI WakeByAddressSingle(PVOID Address) { + HOST_CONTEXT_GUARD(); + VERBOSE_LOG("WakeByAddressSingle(%p)\n", Address); + if (!Address) { + return; + } + std::array sizeCounts{}; + auto queue = tryGetWaitQueue(Address); + if (queue) { + std::lock_guard lk(queue->mutex); + sizeCounts = queue->sizeCounts; + } + notifyAtomicWaiters(Address, sizeCounts, true); + if (queue) { + queue->cv.notify_one(); + } +} + +void WINAPI WakeByAddressAll(PVOID Address) { + HOST_CONTEXT_GUARD(); + VERBOSE_LOG("WakeByAddressAll(%p)\n", Address); + if (!Address) { + return; + } + std::array sizeCounts{}; + auto queue = tryGetWaitQueue(Address); + if (queue) { + std::lock_guard lk(queue->mutex); + sizeCounts = queue->sizeCounts; + } + notifyAtomicWaiters(Address, sizeCounts, false); + if (queue) { + queue->cv.notify_all(); + } +} + void WINAPI InitializeCriticalSection(LPCRITICAL_SECTION lpCriticalSection) { HOST_CONTEXT_GUARD(); DEBUG_LOG("InitializeCriticalSection(%p)\n", lpCriticalSection); @@ -689,13 +1041,13 @@ void WINAPI DeleteCriticalSection(LPCRITICAL_SECTION lpCriticalSection) { lpCriticalSection->SpinCount = 0; setOwningThread(lpCriticalSection, 0); - std::atomic_ref sequence(reinterpret_cast(lpCriticalSection->LockSemaphore)); - sequence.store(0, std::memory_order_release); - sequence.notify_all(); + auto *sequence = reinterpret_cast(&lpCriticalSection->LockSemaphore); + kernel32::InterlockedExchange(const_cast(sequence), 0); + kernel32::WakeByAddressAll(sequence); - std::atomic_ref lockCount(lpCriticalSection->LockCount); - lockCount.store(-1, std::memory_order_release); - lockCount.notify_all(); + auto *lockCount = reinterpret_cast(&lpCriticalSection->LockCount); + kernel32::InterlockedExchange(const_cast(lockCount), -1); + kernel32::WakeByAddressAll(lockCount); } BOOL WINAPI TryEnterCriticalSection(LPCRITICAL_SECTION lpCriticalSection) { @@ -706,18 +1058,18 @@ BOOL WINAPI TryEnterCriticalSection(LPCRITICAL_SECTION lpCriticalSection) { return FALSE; } - std::atomic_ref lockCount(lpCriticalSection->LockCount); + auto *lockCount = const_cast(&lpCriticalSection->LockCount); const DWORD threadId = GetCurrentThreadId(); - LONG expected = -1; - if (lockCount.compare_exchange_strong(expected, 0, std::memory_order_acq_rel, std::memory_order_acquire)) { + LONG previous = kernel32::InterlockedCompareExchange(lockCount, 0, -1); + if (previous == -1) { setOwningThread(lpCriticalSection, threadId); lpCriticalSection->RecursionCount = 1; return TRUE; } if (owningThreadId(lpCriticalSection) == threadId) { - lockCount.fetch_add(1, std::memory_order_acq_rel); + kernel32::InterlockedIncrement(lockCount); lpCriticalSection->RecursionCount++; return TRUE; } @@ -737,8 +1089,8 @@ void WINAPI EnterCriticalSection(LPCRITICAL_SECTION lpCriticalSection) { return; } - std::atomic_ref lockCount(lpCriticalSection->LockCount); - LONG result = lockCount.fetch_add(1, std::memory_order_acq_rel) + 1; + auto *lockCount = const_cast(&lpCriticalSection->LockCount); + LONG result = kernel32::InterlockedIncrement(lockCount); if (result) { if (owningThreadId(lpCriticalSection) == threadId) { lpCriticalSection->RecursionCount++; @@ -766,15 +1118,14 @@ void WINAPI LeaveCriticalSection(LPCRITICAL_SECTION lpCriticalSection) { return; } + auto *lockCount = const_cast(&lpCriticalSection->LockCount); if (--lpCriticalSection->RecursionCount > 0) { - std::atomic_ref lockCount(lpCriticalSection->LockCount); - lockCount.fetch_sub(1, std::memory_order_acq_rel); + kernel32::InterlockedDecrement(lockCount); return; } setOwningThread(lpCriticalSection, 0); - std::atomic_ref lockCount(lpCriticalSection->LockCount); - LONG newValue = lockCount.fetch_sub(1, std::memory_order_acq_rel) - 1; + LONG newValue = kernel32::InterlockedDecrement(lockCount); if (newValue >= 0) { signalCriticalSection(lpCriticalSection); } @@ -792,14 +1143,14 @@ BOOL WINAPI InitOnceBeginInitialize(LPINIT_ONCE lpInitOnce, DWORD dwFlags, PBOOL return FALSE; } - std::atomic_ref state(lpInitOnce->Ptr); + auto *state = &lpInitOnce->Ptr; if (dwFlags & INIT_ONCE_CHECK_ONLY) { if (dwFlags & INIT_ONCE_ASYNC) { setLastError(ERROR_INVALID_PARAMETER); return FALSE; } - GUEST_PTR val = state.load(std::memory_order_acquire); + GUEST_PTR val = __atomic_load_n(state, __ATOMIC_ACQUIRE); if ((val & kInitOnceStateMask) != kInitOnceCompletedFlag) { if (fPending) { *fPending = TRUE; @@ -817,13 +1168,13 @@ BOOL WINAPI InitOnceBeginInitialize(LPINIT_ONCE lpInitOnce, DWORD dwFlags, PBOOL } while (true) { - GUEST_PTR val = state.load(std::memory_order_acquire); + GUEST_PTR val = __atomic_load_n(state, __ATOMIC_ACQUIRE); switch (val & kInitOnceStateMask) { case 0: { // first time if (dwFlags & INIT_ONCE_ASYNC) { GUEST_PTR expected = 0; - if (state.compare_exchange_strong(expected, static_cast(3), std::memory_order_acq_rel, - std::memory_order_acquire)) { + if (__atomic_compare_exchange_n(state, &expected, static_cast(3), false, __ATOMIC_ACQ_REL, + __ATOMIC_ACQUIRE)) { if (fPending) { *fPending = TRUE; } @@ -832,8 +1183,8 @@ BOOL WINAPI InitOnceBeginInitialize(LPINIT_ONCE lpInitOnce, DWORD dwFlags, PBOOL } else { auto syncState = std::make_shared(); GUEST_PTR expected = 0; - if (state.compare_exchange_strong(expected, static_cast(1), std::memory_order_acq_rel, - std::memory_order_acquire)) { + if (__atomic_compare_exchange_n(state, &expected, static_cast(1), false, __ATOMIC_ACQ_REL, + __ATOMIC_ACQUIRE)) { insertInitOnceState(lpInitOnce, syncState); if (fPending) { *fPending = TRUE; @@ -924,11 +1275,11 @@ BOOL WINAPI InitOnceComplete(LPINIT_ONCE lpInitOnce, DWORD dwFlags, LPVOID lpCon return FALSE; } - std::atomic_ref state(lpInitOnce->Ptr); + auto *state = &lpInitOnce->Ptr; const GUEST_PTR finalValue = markFailed ? 0 : (contextValue | kInitOnceCompletedFlag); while (true) { - GUEST_PTR val = state.load(std::memory_order_acquire); + GUEST_PTR val = __atomic_load_n(state, __ATOMIC_ACQUIRE); switch (val & kInitOnceStateMask) { case 1: { auto syncState = getInitOnceState(lpInitOnce); @@ -936,7 +1287,7 @@ BOOL WINAPI InitOnceComplete(LPINIT_ONCE lpInitOnce, DWORD dwFlags, LPVOID lpCon setLastError(ERROR_GEN_FAILURE); return FALSE; } - if (!state.compare_exchange_strong(val, finalValue, std::memory_order_acq_rel, std::memory_order_acquire)) { + if (!__atomic_compare_exchange_n(state, &val, finalValue, false, __ATOMIC_ACQ_REL, __ATOMIC_ACQUIRE)) { continue; } { @@ -954,7 +1305,7 @@ BOOL WINAPI InitOnceComplete(LPINIT_ONCE lpInitOnce, DWORD dwFlags, LPVOID lpCon setLastError(ERROR_INVALID_PARAMETER); return FALSE; } - if (!state.compare_exchange_strong(val, finalValue, std::memory_order_acq_rel, std::memory_order_acquire)) { + if (!__atomic_compare_exchange_n(state, &val, finalValue, false, __ATOMIC_ACQ_REL, __ATOMIC_ACQUIRE)) { continue; } return TRUE; @@ -971,15 +1322,16 @@ void WINAPI AcquireSRWLockShared(PSRWLOCK SRWLock) { if (!SRWLock) { return; } - std::atomic_ref value(SRWLock->Value); + auto *value = &SRWLock->Value; while (true) { - ULONG current = value.load(std::memory_order_acquire); + ULONG current = __atomic_load_n(value, __ATOMIC_ACQUIRE); if (current & kSrwLockExclusive) { - value.wait(current, std::memory_order_relaxed); + ULONG observed = current; + kernel32::WaitOnAddress(reinterpret_cast(value), &observed, sizeof(observed), INFINITE); continue; } ULONG desired = current + kSrwLockSharedIncrement; - if (value.compare_exchange_weak(current, desired, std::memory_order_acq_rel, std::memory_order_acquire)) { + if (__atomic_compare_exchange_n(value, ¤t, desired, true, __ATOMIC_ACQ_REL, __ATOMIC_ACQUIRE)) { return; } } @@ -991,11 +1343,11 @@ void WINAPI ReleaseSRWLockShared(PSRWLOCK SRWLock) { if (!SRWLock) { return; } - std::atomic_ref value(SRWLock->Value); - ULONG previous = value.fetch_sub(kSrwLockSharedIncrement, std::memory_order_acq_rel); + auto *value = &SRWLock->Value; + ULONG previous = __atomic_fetch_sub(value, kSrwLockSharedIncrement, __ATOMIC_ACQ_REL); ULONG newValue = previous - kSrwLockSharedIncrement; if (newValue == 0) { - value.notify_all(); + kernel32::WakeByAddressAll(value); } } @@ -1005,14 +1357,14 @@ void WINAPI AcquireSRWLockExclusive(PSRWLOCK SRWLock) { if (!SRWLock) { return; } - std::atomic_ref value(SRWLock->Value); + auto *value = &SRWLock->Value; while (true) { ULONG expected = 0; - if (value.compare_exchange_strong(expected, kSrwLockExclusive, std::memory_order_acq_rel, - std::memory_order_acquire)) { + if (__atomic_compare_exchange_n(value, &expected, kSrwLockExclusive, false, __ATOMIC_ACQ_REL, + __ATOMIC_ACQUIRE)) { return; } - value.wait(expected, std::memory_order_relaxed); + kernel32::WaitOnAddress(reinterpret_cast(value), &expected, sizeof(expected), INFINITE); } } @@ -1022,9 +1374,8 @@ void WINAPI ReleaseSRWLockExclusive(PSRWLOCK SRWLock) { if (!SRWLock) { return; } - std::atomic_ref value(SRWLock->Value); - value.store(0, std::memory_order_release); - value.notify_all(); + __atomic_store_n(&SRWLock->Value, 0u, __ATOMIC_RELEASE); + kernel32::WakeByAddressAll(&SRWLock->Value); } BOOLEAN WINAPI TryAcquireSRWLockExclusive(PSRWLOCK SRWLock) { @@ -1033,10 +1384,9 @@ BOOLEAN WINAPI TryAcquireSRWLockExclusive(PSRWLOCK SRWLock) { if (!SRWLock) { return FALSE; } - std::atomic_ref value(SRWLock->Value); ULONG expected = 0; - if (value.compare_exchange_strong(expected, kSrwLockExclusive, std::memory_order_acq_rel, - std::memory_order_acquire)) { + if (__atomic_compare_exchange_n(&SRWLock->Value, &expected, kSrwLockExclusive, false, __ATOMIC_ACQ_REL, + __ATOMIC_ACQUIRE)) { return TRUE; } return FALSE; @@ -1048,11 +1398,10 @@ BOOLEAN WINAPI TryAcquireSRWLockShared(PSRWLOCK SRWLock) { if (!SRWLock) { return FALSE; } - std::atomic_ref value(SRWLock->Value); - ULONG current = value.load(std::memory_order_acquire); + ULONG current = __atomic_load_n(&SRWLock->Value, __ATOMIC_ACQUIRE); while (!(current & kSrwLockExclusive)) { ULONG desired = current + kSrwLockSharedIncrement; - if (value.compare_exchange_weak(current, desired, std::memory_order_acq_rel, std::memory_order_acquire)) { + if (__atomic_compare_exchange_n(&SRWLock->Value, ¤t, desired, true, __ATOMIC_ACQ_REL, __ATOMIC_ACQUIRE)) { return TRUE; } } diff --git a/dll/kernel32/synchapi.h b/dll/kernel32/synchapi.h index b64d470..22a75df 100644 --- a/dll/kernel32/synchapi.h +++ b/dll/kernel32/synchapi.h @@ -109,5 +109,8 @@ void WINAPI AcquireSRWLockExclusive(PSRWLOCK SRWLock); void WINAPI ReleaseSRWLockExclusive(PSRWLOCK SRWLock); BOOLEAN WINAPI TryAcquireSRWLockExclusive(PSRWLOCK SRWLock); BOOLEAN WINAPI TryAcquireSRWLockShared(PSRWLOCK SRWLock); +BOOL WINAPI WaitOnAddress(volatile VOID *Address, PVOID CompareAddress, SIZE_T AddressSize, DWORD dwMilliseconds); +void WINAPI WakeByAddressSingle(PVOID Address); +void WINAPI WakeByAddressAll(PVOID Address); } // namespace kernel32 diff --git a/src/errors.h b/src/errors.h index 443f20c..422db2d 100644 --- a/src/errors.h +++ b/src/errors.h @@ -44,6 +44,7 @@ #define ERROR_INVALID_FLAGS 1004 #define ERROR_ALREADY_EXISTS 183 #define ERROR_NOT_OWNER 288 +#define ERROR_TIMEOUT 1460 #define ERROR_TOO_MANY_POSTS 298 #define ERROR_SEM_TIMEOUT 121 #define ERROR_SXS_KEY_NOT_FOUND 14007 diff --git a/src/modules.cpp b/src/modules.cpp index 6127800..d7390d8 100644 --- a/src/modules.cpp +++ b/src/modules.cpp @@ -46,7 +46,7 @@ template void stubThunk(); namespace { -const std::array, 17> kApiSet = { +const std::array, 18> kApiSet = { std::pair{"api-ms-win-core-crt-l1-1-0.dll", "msvcrt.dll"}, std::pair{"api-ms-win-core-crt-l2-1-0.dll", "msvcrt.dll"}, std::pair{"api-ms-win-crt-conio-l1-1-0.dll", "msvcrt.dll"}, @@ -64,6 +64,7 @@ const std::array, 17> kApiSet = { std::pair{"api-ms-win-crt-string-l1-1-0.dll", "ucrtbase.dll"}, std::pair{"api-ms-win-crt-time-l1-1-0.dll", "ucrtbase.dll"}, std::pair{"api-ms-win-crt-utility-l1-1-0.dll", "ucrtbase.dll"}, + std::pair{"api-ms-win-core-synch-l1-2-0.dll", "kernelbase.dll"}, }; constexpr DWORD DLL_PROCESS_DETACH = 0; diff --git a/src/setup.S b/src/setup.S index a3d7e37..57d0b7c 100644 --- a/src/setup.S +++ b/src/setup.S @@ -28,7 +28,7 @@ ASM_END(installSelectors) # int setThreadArea64(int entryNumber, TEB *teb) # Runs syscall SYS_set_thread_area in 32-bit mode ASM_GLOBAL(setThreadArea64, @function) - push rbx # save rbx + push rbx # save rbx mov r8, rsp # save host stack mov rdx, qword ptr [rsi+TEB_SP] # fetch guest stack LJMP32 rsi # far jump into 32-bit code @@ -47,7 +47,7 @@ ASM_GLOBAL(setThreadArea64, @function) 1: add esp, 0x10 # cleanup stack LJMP64 esi # far jump into 64-bit code - cdqe # sign-extend eax to rax + cdqe # sign-extend eax to rax mov rsp, r8 # switch to host stack pop rbx # restore rbx ret @@ -59,15 +59,17 @@ ASM_END(setThreadArea64) # bool installSelectors(TEB *teb) ASM_GLOBAL(installSelectors, @function) - mov ax, word ptr [rdi+TEB_DS_SEL] # fetch data segment selector - mov dx, word ptr [rdi+TEB_FS_SEL] # fetch fs selector - LJMP32 rdi # far jump into 32-bit code (sets cs) - mov ds, ax # setup data segment - mov es, ax # setup extra segment - mov fs, dx # setup fs segment - LJMP64 edi # far jump into 64-bit code - mov rax, 1 # return true - ret + mov ax, cs # fetch host code segment selector + mov word ptr [rdi+TEB_HOST_CS_SEL], ax # store host code segment selector + mov ax, word ptr [rdi+TEB_DS_SEL] # fetch data segment selector + mov dx, word ptr [rdi+TEB_FS_SEL] # fetch fs selector + LJMP32 rdi # far jump into 32-bit code (sets cs) + mov ds, ax # setup data segment + mov es, ax # setup extra segment + mov fs, dx # setup fs segment + LJMP64 edi # far jump into 64-bit code + mov rax, 1 # return true + ret ASM_END(installSelectors) #endif @@ -80,10 +82,10 @@ ASM_GLOBAL(STUB_THUNK_SYMBOL, @function) #define STUB_THUNK_SYMBOL _Z9stubThunkILj\()\number\()EEvv ASM_GLOBAL(STUB_THUNK_SYMBOL, @function) #endif - pop eax - push \number - push eax - jmp SYMBOL_NAME(thunk_entry_stubBase) + pop eax + push \number + push eax + jmp SYMBOL_NAME(thunk_entry_stubBase) ASM_END(STUB_THUNK_SYMBOL) .endm diff --git a/src/setup_darwin.cpp b/src/setup_darwin.cpp index 0ad4b73..8aa559e 100644 --- a/src/setup_darwin.cpp +++ b/src/setup_darwin.cpp @@ -29,6 +29,13 @@ std::array g_ldtBitmap{}; bool g_ldtBitmapInitialized = false; int g_ldtHint = 1; +inline ldt_entry newLdtEntry() { + ldt_entry entry; // NOLINT(cppcoreguidelines-pro-type-member-init) + // Must memset to zero to avoid uninitialized padding bytes + std::memset(&entry, 0, sizeof(ldt_entry)); + return entry; +} + inline ldt_entry createLdtEntry(uint32_t base, uint32_t size, bool code) { uint32_t limit; uint8_t granular; @@ -39,9 +46,7 @@ inline ldt_entry createLdtEntry(uint32_t base, uint32_t size, bool code) { limit = size - 1; granular = DESC_GRAN_BYTE; } - ldt_entry entry; // NOLINT(cppcoreguidelines-pro-type-member-init) - // Must memset to zero to avoid uninitialized padding bytes - std::memset(&entry, 0, sizeof(ldt_entry)); + ldt_entry entry = newLdtEntry(); entry.code.limit00 = static_cast(limit); entry.code.base00 = static_cast(base); entry.code.base16 = static_cast(base >> 16); @@ -228,7 +233,7 @@ bool tebThreadSetup(TEB *teb) { return false; } teb->CurrentFsSelector = createSelector(entryNumber); - DEBUG_LOG("Installing cs %d, ds %d, fs %d\n", teb->CodeSelector, teb->DataSelector, teb->CurrentFsSelector); + DEBUG_LOG("setup_darwin: Installing cs %d, ds %d, fs %d\n", teb->CodeSelector, teb->DataSelector, teb->CurrentFsSelector); installSelectors(teb); return true; } @@ -244,7 +249,8 @@ bool tebThreadTeardown(TEB *teb) { return true; } int entryNumber = selector >> 3; - int ret = i386_set_ldt(entryNumber, nullptr, 1); + ldt_entry entry = newLdtEntry(); + int ret = i386_set_ldt(entryNumber, &entry, 1); if (ret < 0) { return false; } diff --git a/test/test_wait_on_address.c b/test/test_wait_on_address.c new file mode 100644 index 0000000..61ace46 --- /dev/null +++ b/test/test_wait_on_address.c @@ -0,0 +1,142 @@ +#include "test_assert.h" + +#include +#include + +typedef struct { + volatile LONG *value; + LONG expected; + LONG triggerValue; + HANDLE readyEvent; + HANDLE doneEvent; +} WaitContext; + +static DWORD WINAPI wait_thread(LPVOID param) { + WaitContext *ctx = (WaitContext *)param; + TEST_CHECK(SetEvent(ctx->readyEvent)); + BOOL ok = WaitOnAddress((volatile VOID *)ctx->value, &ctx->expected, sizeof(LONG), INFINITE); + TEST_CHECK(ok); + TEST_CHECK_EQ(ctx->triggerValue, *ctx->value); + TEST_CHECK(SetEvent(ctx->doneEvent)); + return 0; +} + +static void close_handles(HANDLE *handles, size_t count) { + for (size_t i = 0; i < count; ++i) { + if (handles[i]) { + TEST_CHECK(CloseHandle(handles[i])); + } + } +} + +static HANDLE create_event(void) { + HANDLE evt = CreateEventA(NULL, FALSE, FALSE, NULL); + TEST_CHECK(evt != NULL); + return evt; +} + +static void test_wait_on_address_single(void) { + volatile LONG value = 0; + LONG expected = 0; + + HANDLE ready = create_event(); + HANDLE done = create_event(); + + WaitContext ctx = { + .value = &value, + .expected = expected, + .triggerValue = 1, + .readyEvent = ready, + .doneEvent = done, + }; + + HANDLE thread = CreateThread(NULL, 0, wait_thread, &ctx, 0, NULL); + TEST_CHECK(thread != NULL); + + DWORD wait = WaitForSingleObject(ready, 1000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + + Sleep(10); + + TEST_CHECK_EQ(0, InterlockedExchange((LONG *)&value, ctx.triggerValue)); + WakeByAddressSingle((PVOID)&value); + + wait = WaitForSingleObject(done, 1000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + + wait = WaitForSingleObject(thread, 1000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + + HANDLE handles[] = {thread, ready, done}; + close_handles(handles, sizeof(handles) / sizeof(handles[0])); +} + +static void test_wait_on_address_all(void) { + volatile LONG value = 0; + const LONG expected = 0; + const LONG finalValue = 42; + + HANDLE ready[2] = {create_event(), create_event()}; + HANDLE done[2] = {create_event(), create_event()}; + + WaitContext ctx[2] = { + {&value, expected, finalValue, ready[0], done[0]}, + {&value, expected, finalValue, ready[1], done[1]}, + }; + + HANDLE threads[2]; + for (int i = 0; i < 2; ++i) { + threads[i] = CreateThread(NULL, 0, wait_thread, &ctx[i], 0, NULL); + TEST_CHECK(threads[i] != NULL); + } + + for (int i = 0; i < 2; ++i) { + DWORD wait = WaitForSingleObject(ready[i], 1000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + } + + Sleep(10); + + TEST_CHECK_EQ(0, InterlockedExchange((LONG *)&value, finalValue)); + WakeByAddressAll((PVOID)&value); + + for (int i = 0; i < 2; ++i) { + DWORD wait = WaitForSingleObject(done[i], 1000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + wait = WaitForSingleObject(threads[i], 1000); + TEST_CHECK_EQ(WAIT_OBJECT_0, wait); + } + + HANDLE handles[] = { + threads[0], threads[1], ready[0], ready[1], done[0], done[1], + }; + close_handles(handles, sizeof(handles) / sizeof(handles[0])); +} + +static void test_wait_on_address_timeout(void) { + volatile LONG value = 7; + LONG expected = 7; + + SetLastError(0); + BOOL ok = WaitOnAddress((volatile VOID *)&value, &expected, sizeof(LONG), 50); + TEST_CHECK(!ok); + TEST_CHECK_EQ(ERROR_TIMEOUT, GetLastError()); + TEST_CHECK_EQ(7, value); +} + +static void test_wait_on_address_immediate(void) { + volatile LONG value = 10; + LONG expected = 11; + + BOOL ok = WaitOnAddress((volatile VOID *)&value, &expected, sizeof(LONG), 1000); + TEST_CHECK(ok); + TEST_CHECK_EQ(10, value); +} + +int main(void) { + test_wait_on_address_single(); + test_wait_on_address_all(); + test_wait_on_address_timeout(); + test_wait_on_address_immediate(); + return EXIT_SUCCESS; +}