Implement WaitOnAddress, WakeByAddress*; macOS impl for atomic waits

This commit is contained in:
2025-11-11 10:44:14 -07:00
parent f18f2a929d
commit 09a7452c77
9 changed files with 589 additions and 82 deletions

View File

@@ -258,6 +258,8 @@ target_include_directories(wibo PRIVATE dll src ${WIBO_GENERATED_HEADER_DIR})
target_link_libraries(wibo PRIVATE mimalloc-obj)
if (CMAKE_SYSTEM_NAME STREQUAL "Linux")
target_link_libraries(wibo PRIVATE atomic)
elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin")
target_link_libraries(wibo PRIVATE c++ c++abi)
endif()
if (WIBO_ENABLE_LIBURING AND CMAKE_SYSTEM_NAME STREQUAL "Linux")
target_compile_definitions(wibo PRIVATE WIBO_ENABLE_LIBURING=1)
@@ -465,6 +467,7 @@ if (WIBO_ENABLE_FIXTURE_TESTS)
wibo_add_fixture_bin(NAME test_sysdir SOURCES test/test_sysdir.c)
wibo_add_fixture_bin(NAME test_srw_lock SOURCES test/test_srw_lock.c)
wibo_add_fixture_bin(NAME test_init_once SOURCES test/test_init_once.c)
wibo_add_fixture_bin(NAME test_wait_on_address SOURCES test/test_wait_on_address.c COMPILE_OPTIONS -lsynchronization)
# DLLs for fixture tests
wibo_add_fixture_dll(NAME external_exports SOURCES test/external_exports.c)

View File

@@ -10,7 +10,7 @@ set(CMAKE_ASM_COMPILER_TARGET ${TARGET})
# Force x86_64 architecture
set(CMAKE_OSX_ARCHITECTURES "x86_64" CACHE STRING "Build architecture for macOS" FORCE)
set(CMAKE_OSX_DEPLOYMENT_TARGET "10.15" CACHE STRING "Minimum macOS deployment version" FORCE)
set(CMAKE_OSX_DEPLOYMENT_TARGET "11.0" CACHE STRING "Minimum macOS deployment version" FORCE)
# Search for programs in the build host directories
set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER)

View File

@@ -5,16 +5,20 @@
#include "errors.h"
#include "handles.h"
#include "heap.h"
#include "interlockedapi.h"
#include "internal.h"
#include "processthreadsapi.h"
#include "strutil.h"
#include "types.h"
#include <algorithm>
#include <array>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <cstring>
#include <dlfcn.h>
#include <memory>
#include <mutex>
#include <optional>
@@ -24,6 +28,7 @@
#include <thread>
#include <unistd.h>
#include <unordered_map>
#include <utility>
#include <vector>
namespace {
@@ -33,6 +38,292 @@ constexpr DWORD kSrwLockSharedIncrement = 0x2u;
constexpr GUEST_PTR kInitOnceStateMask = 0x3u;
constexpr GUEST_PTR kInitOnceCompletedFlag = 0x2u;
constexpr GUEST_PTR kInitOnceReservedMask = (1u << INIT_ONCE_CTX_RESERVED_BITS) - 1;
// Number of supported wait operand widths: 1, 2, 4, and 8 bytes (indices 0..3).
constexpr size_t kSupportedAddressSizes = 4; // 1, 2, 4, 8 bytes

// Per-address wait state shared by every thread blocked on the same address.
struct AddressWaitQueue {
	std::mutex mutex;			// guards waiterCount/sizeCounts; pairs with cv
	std::condition_variable cv; // signalled by WakeByAddress* for timed waiters
	size_t waiterCount = 0;		// total registered waiters on this address
	// Waiter count per operand width (indexed by sizeToIndex) so wakers know
	// which platform notifications to issue.
	std::array<size_t, kSupportedAddressSizes> sizeCounts{};
};

// Guards g_waitAddressQueues; the queues themselves are kept alive by the
// shared_ptrs held in WaitRegistration objects.
std::mutex g_waitAddressMutex;
std::unordered_map<void *, std::weak_ptr<AddressWaitQueue>> g_waitAddressQueues;

// Maps an operand size to its index in AddressWaitQueue::sizeCounts.
// Returns SIZE_MAX for unsupported sizes (0, non-power-of-two, or > 8).
constexpr size_t sizeToIndex(size_t size) {
	// Reject 0 up front (__builtin_ctz* on zero is undefined) and reject
	// non-power-of-two sizes, which would otherwise alias the next-smaller
	// supported width (e.g. 3 would be treated as a 1-byte wait).
	if (size == 0 || (size & (size - 1)) != 0) {
		return static_cast<size_t>(-1);
	}
	size_t index = static_cast<size_t>(__builtin_ctzll(size));
	return index >= kSupportedAddressSizes ? static_cast<size_t>(-1) : index;
}
// Returns the wait queue for `address`, creating (or reviving) the registry
// entry if no live queue exists. Never returns nullptr.
std::shared_ptr<AddressWaitQueue> getWaitQueue(void *address) {
	std::lock_guard lk(g_waitAddressMutex);
	auto &entry = g_waitAddressQueues[address];
	if (auto existing = entry.lock()) {
		return existing;
	}
	auto fresh = std::make_shared<AddressWaitQueue>();
	entry = fresh;
	return fresh;
}
// Returns the live queue for `address`, or nullptr when no waiter has ever
// registered (or all are gone). Expired registry entries are pruned here.
std::shared_ptr<AddressWaitQueue> tryGetWaitQueue(void *address) {
	std::lock_guard lk(g_waitAddressMutex);
	auto found = g_waitAddressQueues.find(address);
	if (found == g_waitAddressQueues.end()) {
		return nullptr;
	}
	if (auto live = found->second.lock()) {
		return live;
	}
	// The last shared_ptr owner went away; drop the stale entry.
	g_waitAddressQueues.erase(found);
	return nullptr;
}
// Removes the registry entry for `address` once its queue has no waiters.
// Called from ~WaitRegistration after the waiter unregistered itself.
// Lock order: registry mutex first, then queue mutex — keep this consistent
// with every other path that takes both.
void cleanupWaitQueue(void *address, const std::shared_ptr<AddressWaitQueue> &queue) {
	std::lock_guard lk(g_waitAddressMutex);
	auto it = g_waitAddressQueues.find(address);
	if (it == g_waitAddressQueues.end()) {
		return;
	}
	auto locked = it->second.lock();
	if (!locked) {
		// Entry already expired; just prune it.
		g_waitAddressQueues.erase(it);
		return;
	}
	if (locked.get() != queue.get()) {
		// The slot was reused by a newer queue; leave it alone.
		return;
	}
	std::lock_guard queueLock(queue->mutex);
	if (queue->waiterCount == 0) {
		g_waitAddressQueues.erase(it);
	}
}
// RAII bookkeeping for one waiter on one address. Registering bumps the
// queue's per-width counters so WakeByAddress* knows which operand widths to
// notify; destruction unregisters and garbage-collects the registry entry.
struct WaitRegistration {
	void *address;							 // key into g_waitAddressQueues
	std::shared_ptr<AddressWaitQueue> queue; // keeps the queue alive while waiting
	size_t sizeIndex;						 // index from sizeToIndex(sizeof(T))
	bool registered = false;				 // guards against double unregister
	WaitRegistration(void *addr, std::shared_ptr<AddressWaitQueue> q, size_t idx)
		: address(addr), queue(std::move(q)), sizeIndex(idx) {}
	// Adds this waiter to the queue's counters.
	void registerWaiter() {
		if (!queue) {
			return;
		}
		std::lock_guard lk(queue->mutex);
		queue->waiterCount++;
		queue->sizeCounts[sizeIndex]++;
		registered = true;
	}
	// Removes this waiter from the counters; safe to call more than once.
	void unregister() {
		if (!queue || !registered) {
			return;
		}
		std::lock_guard lk(queue->mutex);
		queue->waiterCount--;
		queue->sizeCounts[sizeIndex]--;
		registered = false;
	}
	~WaitRegistration() {
		// Unregister first, then let cleanupWaitQueue erase the registry
		// entry if we were the last waiter.
		unregister();
		if (queue) {
			cleanupWaitQueue(address, queue);
		}
	}
};
#if defined(__APPLE__)
// libc++ implements C++20 atomic wait/notify on top of these internal entry
// points. Resolve them lazily via dlsym so wibo still runs (degrading to a
// poll loop) if a future libc++ renames them.
using LibcppMonitorFn = long long (*)(const void volatile *);
using LibcppWaitFn = void (*)(const void volatile *, long long);
using LibcppNotifyFn = void (*)(const void volatile *);
// std::__1::__libcpp_atomic_monitor(void const volatile*)
LibcppMonitorFn getLibcppAtomicMonitor() {
	static LibcppMonitorFn fn =
		reinterpret_cast<LibcppMonitorFn>(dlsym(RTLD_DEFAULT, "_ZNSt3__123__libcpp_atomic_monitorEPVKv"));
	return fn;
}
// std::__1::__libcpp_atomic_wait(void const volatile*, long long)
LibcppWaitFn getLibcppAtomicWait() {
	static LibcppWaitFn fn =
		reinterpret_cast<LibcppWaitFn>(dlsym(RTLD_DEFAULT, "_ZNSt3__120__libcpp_atomic_waitEPVKvx"));
	return fn;
}
// std::__1::__cxx_atomic_notify_one(void const volatile*)
LibcppNotifyFn getLibcppAtomicNotifyOne() {
	static LibcppNotifyFn fn =
		reinterpret_cast<LibcppNotifyFn>(dlsym(RTLD_DEFAULT, "_ZNSt3__123__cxx_atomic_notify_oneEPVKv"));
	return fn;
}
// std::__1::__cxx_atomic_notify_all(void const volatile*)
LibcppNotifyFn getLibcppAtomicNotifyAll() {
	static LibcppNotifyFn fn =
		reinterpret_cast<LibcppNotifyFn>(dlsym(RTLD_DEFAULT, "_ZNSt3__123__cxx_atomic_notify_allEPVKv"));
	return fn;
}
// Blocks until *address no longer equals `expected`, using libc++'s
// monitor/wait protocol: fetch a monitor token, re-check the value, then wait
// on the token. Falls back to a 50us poll loop if the symbols are missing.
template <typename T> void platformWaitIndefinite(T volatile *address, T expected) {
	auto monitorFn = getLibcppAtomicMonitor();
	auto waitFn = getLibcppAtomicWait();
	if (!monitorFn || !waitFn) {
		// Poll fallback; spurious wakeups are harmless, the caller re-checks.
		while (__atomic_load_n(address, __ATOMIC_ACQUIRE) == expected) {
			std::this_thread::sleep_for(std::chrono::microseconds(50));
		}
		return;
	}
	while (true) {
		T current = __atomic_load_n(address, __ATOMIC_ACQUIRE);
		if (current != expected) {
			return;
		}
		auto monitor = monitorFn(address);
		// Re-check after taking the monitor token so a change + notify that
		// lands between the first load and the wait cannot be missed.
		current = __atomic_load_n(address, __ATOMIC_ACQUIRE);
		if (current != expected) {
			continue;
		}
		waitFn(address, monitor);
	}
}
// Wakes libc++ waiters on `address`. The size parameter is unused here; the
// libc++ entry points take only the address.
inline void platformNotifyAddress(void *address, size_t, bool wakeOne) {
	auto notifyFn = wakeOne ? getLibcppAtomicNotifyOne() : getLibcppAtomicNotifyAll();
	if (notifyFn) {
		notifyFn(address);
	}
}
#elif defined(__linux__)
// Linux: delegate to std::atomic_ref (C++20) wait/notify. The caller re-checks
// the value, so a spurious return here is harmless.
template <typename T> void platformWaitIndefinite(T volatile *address, T expected) {
	std::atomic_ref<T> ref(*address);
	ref.wait(expected, std::memory_order_relaxed);
}
// Notifies waiters treating `address` as a T-sized atomic.
template <typename T> void linuxNotify(void *address, bool wakeOne) {
	auto *typed = reinterpret_cast<T *>(address);
	std::atomic_ref<T> ref(*typed);
	if (wakeOne) {
		ref.notify_one();
	} else {
		ref.notify_all();
	}
}
// Dispatches to the notify helper matching the waiter's operand width.
inline void platformNotifyAddress(void *address, size_t size, bool wakeOne) {
	switch (size) {
	case 1:
		linuxNotify<uint8_t>(address, wakeOne);
		break;
	case 2:
		linuxNotify<uint16_t>(address, wakeOne);
		break;
	case 4:
		linuxNotify<uint32_t>(address, wakeOne);
		break;
	case 8:
		linuxNotify<uint64_t>(address, wakeOne);
		break;
	default:
		// Unsupported width: nothing registered, nothing to notify.
		break;
	}
}
#else
// Portable fallback: poll the address until it no longer holds `expected`.
template <typename T> void platformWaitIndefinite(T volatile *address, T expected) {
	for (;;) {
		if (__atomic_load_n(address, __ATOMIC_ACQUIRE) != expected) {
			return;
		}
		std::this_thread::sleep_for(std::chrono::milliseconds(1));
	}
}
// No native per-address notification primitive here; waiters poll instead.
inline void platformNotifyAddress(void *, size_t, bool) {}
#endif
// Issues a platform notification for every operand width that currently has
// waiters on `address`, skipping widths the address is not aligned for.
void notifyAtomicWaiters(void *address, const std::array<size_t, kSupportedAddressSizes> &sizeCounts, bool wakeOne) {
	const auto addrValue = reinterpret_cast<uintptr_t>(address);
	for (size_t idx = 0; idx < sizeCounts.size(); ++idx) {
		if (sizeCounts[idx] == 0) {
			continue;
		}
		const size_t width = size_t{1} << idx;
		const bool aligned = (addrValue & (width - 1)) == 0;
		if (aligned) {
			platformNotifyAddress(address, width, wakeOne);
		}
	}
}
// Core WaitOnAddress implementation for a fixed operand width T.
// Returns true when the wait is satisfied (the value differed, or a wake
// arrived); returns false with the thread's last error set otherwise.
template <typename T> bool waitOnAddressTyped(VOID volatile *addressVoid, PVOID comparePtr, DWORD dwMilliseconds) {
	auto *address = reinterpret_cast<T volatile *>(addressVoid);
	if (!address) {
		kernel32::setLastError(ERROR_INVALID_PARAMETER);
		return false;
	}
	// The address must be naturally aligned for its operand size.
	if (reinterpret_cast<uintptr_t>(address) % alignof(T) != 0) {
		kernel32::setLastError(ERROR_INVALID_PARAMETER);
		return false;
	}
	const T compareValue = *reinterpret_cast<const T *>(comparePtr);
	// Fast path: the value already differs, no need to block.
	if (__atomic_load_n(address, __ATOMIC_ACQUIRE) != compareValue) {
		return true;
	}
	if (dwMilliseconds == 0) {
		kernel32::setLastError(ERROR_TIMEOUT);
		return false;
	}
	void *queueKey = const_cast<void *>(addressVoid);
	auto queue = getWaitQueue(queueKey);
	if (!queue) {
		kernel32::setLastError(ERROR_GEN_FAILURE);
		return false;
	}
	// Keep the index as size_t; sizeToIndex reports failure as SIZE_MAX.
	const size_t sizeIdx = sizeToIndex(sizeof(T));
	// %zu: sizeof(T) and sizeIdx are size_t — passing them through %d is
	// undefined behavior in varargs.
	DEBUG_LOG("size: %zu, index %zu\n", sizeof(T), sizeIdx);
	if (sizeIdx >= kSupportedAddressSizes) {
		kernel32::setLastError(ERROR_INVALID_PARAMETER);
		return false;
	}
	WaitRegistration registration(queueKey, queue, sizeIdx);
	registration.registerWaiter();
	if (dwMilliseconds == INFINITE) {
		// Indefinite waits use the native futex-style primitive; the
		// WakeByAddress* side issues the matching platform notification.
		while (__atomic_load_n(address, __ATOMIC_ACQUIRE) == compareValue) {
			platformWaitIndefinite(address, compareValue);
		}
		return true;
	}
	const auto deadline =
		std::chrono::steady_clock::now() + std::chrono::milliseconds(static_cast<uint64_t>(dwMilliseconds));
	bool timedOut = false;
	{
		// Timed waits use the queue's condition variable. WakeByAddress*
		// acquires queue->mutex before signalling, so this check-then-wait
		// cannot miss a wake.
		std::unique_lock lk(queue->mutex);
		while (__atomic_load_n(address, __ATOMIC_ACQUIRE) == compareValue) {
			if (queue->cv.wait_until(lk, deadline) == std::cv_status::timeout) {
				if (__atomic_load_n(address, __ATOMIC_ACQUIRE) == compareValue) {
					timedOut = true;
					break;
				}
			}
		}
	}
	if (timedOut && __atomic_load_n(address, __ATOMIC_ACQUIRE) == compareValue) {
		kernel32::setLastError(ERROR_TIMEOUT);
		return false;
	}
	return true;
}
std::u16string makeU16String(LPCWSTR name) {
if (!name) {
@@ -167,29 +458,25 @@ void eraseInitOnceState(LPINIT_ONCE once) {
g_initOnceStates.erase(once);
}
inline DWORD owningThreadId(LPCRITICAL_SECTION crit) {
std::atomic_ref<DWORD> owner(*reinterpret_cast<DWORD *>(&crit->OwningThread));
return owner.load(std::memory_order_acquire);
}
inline DWORD owningThreadId(LPCRITICAL_SECTION crit) { return __atomic_load_n(&crit->OwningThread, __ATOMIC_ACQUIRE); }
inline void setOwningThread(LPCRITICAL_SECTION crit, DWORD threadId) {
std::atomic_ref<DWORD> owner(*reinterpret_cast<DWORD *>(&crit->OwningThread));
owner.store(threadId, std::memory_order_release);
__atomic_store_n(&crit->OwningThread, threadId, __ATOMIC_RELEASE);
}
void waitForCriticalSection(LPCRITICAL_SECTION cs) {
std::atomic_ref<LONG> sequence(reinterpret_cast<LONG &>(cs->LockSemaphore));
LONG observed = sequence.load(std::memory_order_acquire);
auto *sequence = reinterpret_cast<LONG volatile *>(&cs->LockSemaphore);
LONG observed = __atomic_load_n(sequence, __ATOMIC_ACQUIRE);
while (owningThreadId(cs) != 0) {
sequence.wait(observed, std::memory_order_relaxed);
observed = sequence.load(std::memory_order_acquire);
kernel32::WaitOnAddress(sequence, &observed, sizeof(observed), INFINITE);
observed = __atomic_load_n(sequence, __ATOMIC_ACQUIRE);
}
}
void signalCriticalSection(LPCRITICAL_SECTION cs) {
std::atomic_ref<LONG> sequence(reinterpret_cast<LONG &>(cs->LockSemaphore));
sequence.fetch_add(1, std::memory_order_release);
sequence.notify_one();
auto *sequence = reinterpret_cast<LONG *>(&cs->LockSemaphore);
kernel32::InterlockedIncrement(const_cast<LONG volatile *>(sequence));
kernel32::WakeByAddressSingle(sequence);
}
inline bool trySpinAcquireCriticalSection(LPCRITICAL_SECTION cs, DWORD threadId) {
@@ -634,6 +921,71 @@ DWORD WINAPI WaitForMultipleObjects(DWORD nCount, const HANDLE *lpHandles, BOOL
return waitResult;
}
// kernelbase WaitOnAddress: blocks until *Address differs from *CompareAddress
// or the timeout elapses. Per the Windows contract, AddressSize must be
// exactly 1, 2, 4, or 8 bytes.
BOOL WINAPI WaitOnAddress(VOID volatile *Address, PVOID CompareAddress, SIZE_T AddressSize, DWORD dwMilliseconds) {
	HOST_CONTEXT_GUARD();
	VERBOSE_LOG("WaitOnAddress(%p, %p, %zu, %u)\n", Address, CompareAddress, AddressSize, dwMilliseconds);
	if (!Address || !CompareAddress) {
		setLastError(ERROR_INVALID_PARAMETER);
		return FALSE;
	}
	// Dispatch on the raw size rather than sizeToIndex(AddressSize): sizes
	// like 3, 5, 6, or 7 must be rejected, not silently rounded down to a
	// smaller supported width via the trailing-zero count.
	switch (AddressSize) {
	case 1:
		return waitOnAddressTyped<uint8_t>(Address, CompareAddress, dwMilliseconds) ? TRUE : FALSE;
	case 2:
		return waitOnAddressTyped<uint16_t>(Address, CompareAddress, dwMilliseconds) ? TRUE : FALSE;
	case 4:
		return waitOnAddressTyped<uint32_t>(Address, CompareAddress, dwMilliseconds) ? TRUE : FALSE;
	case 8:
		return waitOnAddressTyped<uint64_t>(Address, CompareAddress, dwMilliseconds) ? TRUE : FALSE;
	default:
		setLastError(ERROR_INVALID_PARAMETER);
		return FALSE;
	}
}
// Wakes at most one thread blocked in WaitOnAddress on `Address`.
void WINAPI WakeByAddressSingle(PVOID Address) {
	HOST_CONTEXT_GUARD();
	VERBOSE_LOG("WakeByAddressSingle(%p)\n", Address);
	if (!Address) {
		return;
	}
	auto queue = tryGetWaitQueue(Address);
	// Snapshot the per-width waiter counts under the queue lock. Taking the
	// lock here also serializes with a waiter's check-then-wait.
	std::array<size_t, kSupportedAddressSizes> pendingSizes{};
	if (queue) {
		std::lock_guard lk(queue->mutex);
		pendingSizes = queue->sizeCounts;
	}
	// Wake native (indefinite) waiters, then timed waiters on the cv.
	notifyAtomicWaiters(Address, pendingSizes, /*wakeOne=*/true);
	if (queue) {
		queue->cv.notify_one();
	}
}
// Wakes every thread blocked in WaitOnAddress on `Address`.
void WINAPI WakeByAddressAll(PVOID Address) {
	HOST_CONTEXT_GUARD();
	VERBOSE_LOG("WakeByAddressAll(%p)\n", Address);
	if (!Address) {
		return;
	}
	auto queue = tryGetWaitQueue(Address);
	// Snapshot the per-width waiter counts under the queue lock. Taking the
	// lock here also serializes with a waiter's check-then-wait.
	std::array<size_t, kSupportedAddressSizes> pendingSizes{};
	if (queue) {
		std::lock_guard lk(queue->mutex);
		pendingSizes = queue->sizeCounts;
	}
	// Wake native (indefinite) waiters, then timed waiters on the cv.
	notifyAtomicWaiters(Address, pendingSizes, /*wakeOne=*/false);
	if (queue) {
		queue->cv.notify_all();
	}
}
void WINAPI InitializeCriticalSection(LPCRITICAL_SECTION lpCriticalSection) {
HOST_CONTEXT_GUARD();
DEBUG_LOG("InitializeCriticalSection(%p)\n", lpCriticalSection);
@@ -689,13 +1041,13 @@ void WINAPI DeleteCriticalSection(LPCRITICAL_SECTION lpCriticalSection) {
lpCriticalSection->SpinCount = 0;
setOwningThread(lpCriticalSection, 0);
std::atomic_ref<LONG> sequence(reinterpret_cast<LONG &>(lpCriticalSection->LockSemaphore));
sequence.store(0, std::memory_order_release);
sequence.notify_all();
auto *sequence = reinterpret_cast<LONG *>(&lpCriticalSection->LockSemaphore);
kernel32::InterlockedExchange(const_cast<LONG volatile *>(sequence), 0);
kernel32::WakeByAddressAll(sequence);
std::atomic_ref<LONG> lockCount(lpCriticalSection->LockCount);
lockCount.store(-1, std::memory_order_release);
lockCount.notify_all();
auto *lockCount = reinterpret_cast<LONG *>(&lpCriticalSection->LockCount);
kernel32::InterlockedExchange(const_cast<LONG volatile *>(lockCount), -1);
kernel32::WakeByAddressAll(lockCount);
}
BOOL WINAPI TryEnterCriticalSection(LPCRITICAL_SECTION lpCriticalSection) {
@@ -706,18 +1058,18 @@ BOOL WINAPI TryEnterCriticalSection(LPCRITICAL_SECTION lpCriticalSection) {
return FALSE;
}
std::atomic_ref<LONG> lockCount(lpCriticalSection->LockCount);
auto *lockCount = const_cast<LONG volatile *>(&lpCriticalSection->LockCount);
const DWORD threadId = GetCurrentThreadId();
LONG expected = -1;
if (lockCount.compare_exchange_strong(expected, 0, std::memory_order_acq_rel, std::memory_order_acquire)) {
LONG previous = kernel32::InterlockedCompareExchange(lockCount, 0, -1);
if (previous == -1) {
setOwningThread(lpCriticalSection, threadId);
lpCriticalSection->RecursionCount = 1;
return TRUE;
}
if (owningThreadId(lpCriticalSection) == threadId) {
lockCount.fetch_add(1, std::memory_order_acq_rel);
kernel32::InterlockedIncrement(lockCount);
lpCriticalSection->RecursionCount++;
return TRUE;
}
@@ -737,8 +1089,8 @@ void WINAPI EnterCriticalSection(LPCRITICAL_SECTION lpCriticalSection) {
return;
}
std::atomic_ref<LONG> lockCount(lpCriticalSection->LockCount);
LONG result = lockCount.fetch_add(1, std::memory_order_acq_rel) + 1;
auto *lockCount = const_cast<LONG volatile *>(&lpCriticalSection->LockCount);
LONG result = kernel32::InterlockedIncrement(lockCount);
if (result) {
if (owningThreadId(lpCriticalSection) == threadId) {
lpCriticalSection->RecursionCount++;
@@ -766,15 +1118,14 @@ void WINAPI LeaveCriticalSection(LPCRITICAL_SECTION lpCriticalSection) {
return;
}
auto *lockCount = const_cast<LONG volatile *>(&lpCriticalSection->LockCount);
if (--lpCriticalSection->RecursionCount > 0) {
std::atomic_ref<LONG> lockCount(lpCriticalSection->LockCount);
lockCount.fetch_sub(1, std::memory_order_acq_rel);
kernel32::InterlockedDecrement(lockCount);
return;
}
setOwningThread(lpCriticalSection, 0);
std::atomic_ref<LONG> lockCount(lpCriticalSection->LockCount);
LONG newValue = lockCount.fetch_sub(1, std::memory_order_acq_rel) - 1;
LONG newValue = kernel32::InterlockedDecrement(lockCount);
if (newValue >= 0) {
signalCriticalSection(lpCriticalSection);
}
@@ -792,14 +1143,14 @@ BOOL WINAPI InitOnceBeginInitialize(LPINIT_ONCE lpInitOnce, DWORD dwFlags, PBOOL
return FALSE;
}
std::atomic_ref<GUEST_PTR> state(lpInitOnce->Ptr);
auto *state = &lpInitOnce->Ptr;
if (dwFlags & INIT_ONCE_CHECK_ONLY) {
if (dwFlags & INIT_ONCE_ASYNC) {
setLastError(ERROR_INVALID_PARAMETER);
return FALSE;
}
GUEST_PTR val = state.load(std::memory_order_acquire);
GUEST_PTR val = __atomic_load_n(state, __ATOMIC_ACQUIRE);
if ((val & kInitOnceStateMask) != kInitOnceCompletedFlag) {
if (fPending) {
*fPending = TRUE;
@@ -817,13 +1168,13 @@ BOOL WINAPI InitOnceBeginInitialize(LPINIT_ONCE lpInitOnce, DWORD dwFlags, PBOOL
}
while (true) {
GUEST_PTR val = state.load(std::memory_order_acquire);
GUEST_PTR val = __atomic_load_n(state, __ATOMIC_ACQUIRE);
switch (val & kInitOnceStateMask) {
case 0: { // first time
if (dwFlags & INIT_ONCE_ASYNC) {
GUEST_PTR expected = 0;
if (state.compare_exchange_strong(expected, static_cast<GUEST_PTR>(3), std::memory_order_acq_rel,
std::memory_order_acquire)) {
if (__atomic_compare_exchange_n(state, &expected, static_cast<GUEST_PTR>(3), false, __ATOMIC_ACQ_REL,
__ATOMIC_ACQUIRE)) {
if (fPending) {
*fPending = TRUE;
}
@@ -832,8 +1183,8 @@ BOOL WINAPI InitOnceBeginInitialize(LPINIT_ONCE lpInitOnce, DWORD dwFlags, PBOOL
} else {
auto syncState = std::make_shared<InitOnceState>();
GUEST_PTR expected = 0;
if (state.compare_exchange_strong(expected, static_cast<GUEST_PTR>(1), std::memory_order_acq_rel,
std::memory_order_acquire)) {
if (__atomic_compare_exchange_n(state, &expected, static_cast<GUEST_PTR>(1), false, __ATOMIC_ACQ_REL,
__ATOMIC_ACQUIRE)) {
insertInitOnceState(lpInitOnce, syncState);
if (fPending) {
*fPending = TRUE;
@@ -924,11 +1275,11 @@ BOOL WINAPI InitOnceComplete(LPINIT_ONCE lpInitOnce, DWORD dwFlags, LPVOID lpCon
return FALSE;
}
std::atomic_ref<GUEST_PTR> state(lpInitOnce->Ptr);
auto *state = &lpInitOnce->Ptr;
const GUEST_PTR finalValue = markFailed ? 0 : (contextValue | kInitOnceCompletedFlag);
while (true) {
GUEST_PTR val = state.load(std::memory_order_acquire);
GUEST_PTR val = __atomic_load_n(state, __ATOMIC_ACQUIRE);
switch (val & kInitOnceStateMask) {
case 1: {
auto syncState = getInitOnceState(lpInitOnce);
@@ -936,7 +1287,7 @@ BOOL WINAPI InitOnceComplete(LPINIT_ONCE lpInitOnce, DWORD dwFlags, LPVOID lpCon
setLastError(ERROR_GEN_FAILURE);
return FALSE;
}
if (!state.compare_exchange_strong(val, finalValue, std::memory_order_acq_rel, std::memory_order_acquire)) {
if (!__atomic_compare_exchange_n(state, &val, finalValue, false, __ATOMIC_ACQ_REL, __ATOMIC_ACQUIRE)) {
continue;
}
{
@@ -954,7 +1305,7 @@ BOOL WINAPI InitOnceComplete(LPINIT_ONCE lpInitOnce, DWORD dwFlags, LPVOID lpCon
setLastError(ERROR_INVALID_PARAMETER);
return FALSE;
}
if (!state.compare_exchange_strong(val, finalValue, std::memory_order_acq_rel, std::memory_order_acquire)) {
if (!__atomic_compare_exchange_n(state, &val, finalValue, false, __ATOMIC_ACQ_REL, __ATOMIC_ACQUIRE)) {
continue;
}
return TRUE;
@@ -971,15 +1322,16 @@ void WINAPI AcquireSRWLockShared(PSRWLOCK SRWLock) {
if (!SRWLock) {
return;
}
std::atomic_ref<ULONG> value(SRWLock->Value);
auto *value = &SRWLock->Value;
while (true) {
ULONG current = value.load(std::memory_order_acquire);
ULONG current = __atomic_load_n(value, __ATOMIC_ACQUIRE);
if (current & kSrwLockExclusive) {
value.wait(current, std::memory_order_relaxed);
ULONG observed = current;
kernel32::WaitOnAddress(reinterpret_cast<VOID volatile *>(value), &observed, sizeof(observed), INFINITE);
continue;
}
ULONG desired = current + kSrwLockSharedIncrement;
if (value.compare_exchange_weak(current, desired, std::memory_order_acq_rel, std::memory_order_acquire)) {
if (__atomic_compare_exchange_n(value, &current, desired, true, __ATOMIC_ACQ_REL, __ATOMIC_ACQUIRE)) {
return;
}
}
@@ -991,11 +1343,11 @@ void WINAPI ReleaseSRWLockShared(PSRWLOCK SRWLock) {
if (!SRWLock) {
return;
}
std::atomic_ref<ULONG> value(SRWLock->Value);
ULONG previous = value.fetch_sub(kSrwLockSharedIncrement, std::memory_order_acq_rel);
auto *value = &SRWLock->Value;
ULONG previous = __atomic_fetch_sub(value, kSrwLockSharedIncrement, __ATOMIC_ACQ_REL);
ULONG newValue = previous - kSrwLockSharedIncrement;
if (newValue == 0) {
value.notify_all();
kernel32::WakeByAddressAll(value);
}
}
@@ -1005,14 +1357,14 @@ void WINAPI AcquireSRWLockExclusive(PSRWLOCK SRWLock) {
if (!SRWLock) {
return;
}
std::atomic_ref<ULONG> value(SRWLock->Value);
auto *value = &SRWLock->Value;
while (true) {
ULONG expected = 0;
if (value.compare_exchange_strong(expected, kSrwLockExclusive, std::memory_order_acq_rel,
std::memory_order_acquire)) {
if (__atomic_compare_exchange_n(value, &expected, kSrwLockExclusive, false, __ATOMIC_ACQ_REL,
__ATOMIC_ACQUIRE)) {
return;
}
value.wait(expected, std::memory_order_relaxed);
kernel32::WaitOnAddress(reinterpret_cast<VOID volatile *>(value), &expected, sizeof(expected), INFINITE);
}
}
@@ -1022,9 +1374,8 @@ void WINAPI ReleaseSRWLockExclusive(PSRWLOCK SRWLock) {
if (!SRWLock) {
return;
}
std::atomic_ref<ULONG> value(SRWLock->Value);
value.store(0, std::memory_order_release);
value.notify_all();
__atomic_store_n(&SRWLock->Value, 0u, __ATOMIC_RELEASE);
kernel32::WakeByAddressAll(&SRWLock->Value);
}
BOOLEAN WINAPI TryAcquireSRWLockExclusive(PSRWLOCK SRWLock) {
@@ -1033,10 +1384,9 @@ BOOLEAN WINAPI TryAcquireSRWLockExclusive(PSRWLOCK SRWLock) {
if (!SRWLock) {
return FALSE;
}
std::atomic_ref<ULONG> value(SRWLock->Value);
ULONG expected = 0;
if (value.compare_exchange_strong(expected, kSrwLockExclusive, std::memory_order_acq_rel,
std::memory_order_acquire)) {
if (__atomic_compare_exchange_n(&SRWLock->Value, &expected, kSrwLockExclusive, false, __ATOMIC_ACQ_REL,
__ATOMIC_ACQUIRE)) {
return TRUE;
}
return FALSE;
@@ -1048,11 +1398,10 @@ BOOLEAN WINAPI TryAcquireSRWLockShared(PSRWLOCK SRWLock) {
if (!SRWLock) {
return FALSE;
}
std::atomic_ref<ULONG> value(SRWLock->Value);
ULONG current = value.load(std::memory_order_acquire);
ULONG current = __atomic_load_n(&SRWLock->Value, __ATOMIC_ACQUIRE);
while (!(current & kSrwLockExclusive)) {
ULONG desired = current + kSrwLockSharedIncrement;
if (value.compare_exchange_weak(current, desired, std::memory_order_acq_rel, std::memory_order_acquire)) {
if (__atomic_compare_exchange_n(&SRWLock->Value, &current, desired, true, __ATOMIC_ACQ_REL, __ATOMIC_ACQUIRE)) {
return TRUE;
}
}

View File

@@ -109,5 +109,8 @@ void WINAPI AcquireSRWLockExclusive(PSRWLOCK SRWLock);
void WINAPI ReleaseSRWLockExclusive(PSRWLOCK SRWLock);
BOOLEAN WINAPI TryAcquireSRWLockExclusive(PSRWLOCK SRWLock);
BOOLEAN WINAPI TryAcquireSRWLockShared(PSRWLOCK SRWLock);
BOOL WINAPI WaitOnAddress(volatile VOID *Address, PVOID CompareAddress, SIZE_T AddressSize, DWORD dwMilliseconds);
void WINAPI WakeByAddressSingle(PVOID Address);
void WINAPI WakeByAddressAll(PVOID Address);
} // namespace kernel32

View File

@@ -44,6 +44,7 @@
#define ERROR_INVALID_FLAGS 1004
#define ERROR_ALREADY_EXISTS 183
#define ERROR_NOT_OWNER 288
#define ERROR_TIMEOUT 1460
#define ERROR_TOO_MANY_POSTS 298
#define ERROR_SEM_TIMEOUT 121
#define ERROR_SXS_KEY_NOT_FOUND 14007

View File

@@ -46,7 +46,7 @@ template <size_t Index> void stubThunk();
namespace {
const std::array<std::pair<std::string_view, std::string_view>, 17> kApiSet = {
const std::array<std::pair<std::string_view, std::string_view>, 18> kApiSet = {
std::pair{"api-ms-win-core-crt-l1-1-0.dll", "msvcrt.dll"},
std::pair{"api-ms-win-core-crt-l2-1-0.dll", "msvcrt.dll"},
std::pair{"api-ms-win-crt-conio-l1-1-0.dll", "msvcrt.dll"},
@@ -64,6 +64,7 @@ const std::array<std::pair<std::string_view, std::string_view>, 17> kApiSet = {
std::pair{"api-ms-win-crt-string-l1-1-0.dll", "ucrtbase.dll"},
std::pair{"api-ms-win-crt-time-l1-1-0.dll", "ucrtbase.dll"},
std::pair{"api-ms-win-crt-utility-l1-1-0.dll", "ucrtbase.dll"},
std::pair{"api-ms-win-core-synch-l1-2-0.dll", "kernelbase.dll"},
};
constexpr DWORD DLL_PROCESS_DETACH = 0;

View File

@@ -28,7 +28,7 @@ ASM_END(installSelectors)
# int setThreadArea64(int entryNumber, TEB *teb)
# Runs syscall SYS_set_thread_area in 32-bit mode
ASM_GLOBAL(setThreadArea64, @function)
push rbx # save rbx
push rbx # save rbx
mov r8, rsp # save host stack
mov rdx, qword ptr [rsi+TEB_SP] # fetch guest stack
LJMP32 rsi # far jump into 32-bit code
@@ -47,7 +47,7 @@ ASM_GLOBAL(setThreadArea64, @function)
1:
add esp, 0x10 # cleanup stack
LJMP64 esi # far jump into 64-bit code
cdqe # sign-extend eax to rax
cdqe # sign-extend eax to rax
mov rsp, r8 # switch to host stack
pop rbx # restore rbx
ret
@@ -59,15 +59,17 @@ ASM_END(setThreadArea64)
# bool installSelectors(TEB *teb)
ASM_GLOBAL(installSelectors, @function)
mov ax, word ptr [rdi+TEB_DS_SEL] # fetch data segment selector
mov dx, word ptr [rdi+TEB_FS_SEL] # fetch fs selector
LJMP32 rdi # far jump into 32-bit code (sets cs)
mov ds, ax # setup data segment
mov es, ax # setup extra segment
mov fs, dx # setup fs segment
LJMP64 edi # far jump into 64-bit code
mov rax, 1 # return true
ret
mov ax, cs # fetch host code segment selector
mov word ptr [rdi+TEB_HOST_CS_SEL], ax # store host code segment selector
mov ax, word ptr [rdi+TEB_DS_SEL] # fetch data segment selector
mov dx, word ptr [rdi+TEB_FS_SEL] # fetch fs selector
LJMP32 rdi # far jump into 32-bit code (sets cs)
mov ds, ax # setup data segment
mov es, ax # setup extra segment
mov fs, dx # setup fs segment
LJMP64 edi # far jump into 64-bit code
mov rax, 1 # return true
ret
ASM_END(installSelectors)
#endif
@@ -80,10 +82,10 @@ ASM_GLOBAL(STUB_THUNK_SYMBOL, @function)
#define STUB_THUNK_SYMBOL _Z9stubThunkILj\()\number\()EEvv
ASM_GLOBAL(STUB_THUNK_SYMBOL, @function)
#endif
pop eax
push \number
push eax
jmp SYMBOL_NAME(thunk_entry_stubBase)
pop eax
push \number
push eax
jmp SYMBOL_NAME(thunk_entry_stubBase)
ASM_END(STUB_THUNK_SYMBOL)
.endm

View File

@@ -29,6 +29,13 @@ std::array<uint32_t, kMaxLdtEntries / kBitsPerWord> g_ldtBitmap{};
bool g_ldtBitmapInitialized = false;
int g_ldtHint = 1;
// Returns a fully zeroed LDT entry. memset (rather than value-init) is used
// so compiler-inserted padding bytes are zeroed too before the struct is
// passed to i386_set_ldt.
inline ldt_entry newLdtEntry() {
	ldt_entry entry; // NOLINT(cppcoreguidelines-pro-type-member-init)
	// Must memset to zero to avoid uninitialized padding bytes
	std::memset(&entry, 0, sizeof(ldt_entry));
	return entry;
}
inline ldt_entry createLdtEntry(uint32_t base, uint32_t size, bool code) {
uint32_t limit;
uint8_t granular;
@@ -39,9 +46,7 @@ inline ldt_entry createLdtEntry(uint32_t base, uint32_t size, bool code) {
limit = size - 1;
granular = DESC_GRAN_BYTE;
}
ldt_entry entry; // NOLINT(cppcoreguidelines-pro-type-member-init)
// Must memset to zero to avoid uninitialized padding bytes
std::memset(&entry, 0, sizeof(ldt_entry));
ldt_entry entry = newLdtEntry();
entry.code.limit00 = static_cast<uint16_t>(limit);
entry.code.base00 = static_cast<uint16_t>(base);
entry.code.base16 = static_cast<uint8_t>(base >> 16);
@@ -228,7 +233,7 @@ bool tebThreadSetup(TEB *teb) {
return false;
}
teb->CurrentFsSelector = createSelector(entryNumber);
DEBUG_LOG("Installing cs %d, ds %d, fs %d\n", teb->CodeSelector, teb->DataSelector, teb->CurrentFsSelector);
DEBUG_LOG("setup_darwin: Installing cs %d, ds %d, fs %d\n", teb->CodeSelector, teb->DataSelector, teb->CurrentFsSelector);
installSelectors(teb);
return true;
}
@@ -244,7 +249,8 @@ bool tebThreadTeardown(TEB *teb) {
return true;
}
int entryNumber = selector >> 3;
int ret = i386_set_ldt(entryNumber, nullptr, 1);
ldt_entry entry = newLdtEntry();
int ret = i386_set_ldt(entryNumber, &entry, 1);
if (ret < 0) {
return false;
}

142
test/test_wait_on_address.c Normal file
View File

@@ -0,0 +1,142 @@
#include "test_assert.h"
#include <synchapi.h>
#include <windows.h>
/* Shared state handed to wait_thread. */
typedef struct {
	volatile LONG *value; /* address the worker waits on */
	LONG expected;        /* value the worker compares against */
	LONG triggerValue;    /* value the main thread stores before waking */
	HANDLE readyEvent;    /* worker -> main: about to call WaitOnAddress */
	HANDLE doneEvent;     /* worker -> main: wait satisfied */
} WaitContext;
/* Worker: signals readyEvent, blocks in WaitOnAddress until *value no longer
 * equals ctx->expected, then checks the trigger value and signals doneEvent. */
static DWORD WINAPI wait_thread(LPVOID param) {
	WaitContext *ctx = (WaitContext *)param;
	TEST_CHECK(SetEvent(ctx->readyEvent));
	BOOL ok = WaitOnAddress((volatile VOID *)ctx->value, &ctx->expected, sizeof(LONG), INFINITE);
	TEST_CHECK(ok);
	TEST_CHECK_EQ(ctx->triggerValue, *ctx->value);
	TEST_CHECK(SetEvent(ctx->doneEvent));
	return 0;
}
/* Closes each non-NULL handle in the array, asserting every close succeeds. */
static void close_handles(HANDLE *handles, size_t count) {
	size_t idx;
	for (idx = 0; idx < count; ++idx) {
		HANDLE h = handles[idx];
		if (h != NULL) {
			TEST_CHECK(CloseHandle(h));
		}
	}
}
/* Creates an auto-reset, initially non-signaled event; asserts on failure. */
static HANDLE create_event(void) {
	HANDLE evt = CreateEventA(NULL, FALSE, FALSE, NULL);
	TEST_CHECK(evt != NULL);
	return evt;
}
/* One waiter: a worker blocks on `value`; the main thread flips the value and
 * wakes exactly one waiter with WakeByAddressSingle. */
static void test_wait_on_address_single(void) {
	volatile LONG value = 0;
	LONG expected = 0;
	HANDLE ready = create_event();
	HANDLE done = create_event();
	WaitContext ctx = {
		.value = &value,
		.expected = expected,
		.triggerValue = 1,
		.readyEvent = ready,
		.doneEvent = done,
	};
	HANDLE thread = CreateThread(NULL, 0, wait_thread, &ctx, 0, NULL);
	TEST_CHECK(thread != NULL);
	DWORD wait = WaitForSingleObject(ready, 1000);
	TEST_CHECK_EQ(WAIT_OBJECT_0, wait);
	/* Best effort: give the worker time to actually enter WaitOnAddress.
	 * Correctness does not depend on it — if the store lands first, the
	 * worker's WaitOnAddress returns immediately (value already differs). */
	Sleep(10);
	TEST_CHECK_EQ(0, InterlockedExchange((LONG *)&value, ctx.triggerValue));
	WakeByAddressSingle((PVOID)&value);
	wait = WaitForSingleObject(done, 1000);
	TEST_CHECK_EQ(WAIT_OBJECT_0, wait);
	wait = WaitForSingleObject(thread, 1000);
	TEST_CHECK_EQ(WAIT_OBJECT_0, wait);
	HANDLE handles[] = {thread, ready, done};
	close_handles(handles, sizeof(handles) / sizeof(handles[0]));
}
/* Two waiters on the same address: a single WakeByAddressAll after the value
 * changes must release both of them. */
static void test_wait_on_address_all(void) {
	volatile LONG value = 0;
	const LONG expected = 0;
	const LONG finalValue = 42;
	HANDLE ready[2] = {create_event(), create_event()};
	HANDLE done[2] = {create_event(), create_event()};
	WaitContext ctx[2] = {
		{&value, expected, finalValue, ready[0], done[0]},
		{&value, expected, finalValue, ready[1], done[1]},
	};
	HANDLE threads[2];
	for (int i = 0; i < 2; ++i) {
		threads[i] = CreateThread(NULL, 0, wait_thread, &ctx[i], 0, NULL);
		TEST_CHECK(threads[i] != NULL);
	}
	/* Wait until both workers are about to (or already) block. */
	for (int i = 0; i < 2; ++i) {
		DWORD wait = WaitForSingleObject(ready[i], 1000);
		TEST_CHECK_EQ(WAIT_OBJECT_0, wait);
	}
	/* Best effort: let both workers enter WaitOnAddress before the store. */
	Sleep(10);
	TEST_CHECK_EQ(0, InterlockedExchange((LONG *)&value, finalValue));
	WakeByAddressAll((PVOID)&value);
	for (int i = 0; i < 2; ++i) {
		DWORD wait = WaitForSingleObject(done[i], 1000);
		TEST_CHECK_EQ(WAIT_OBJECT_0, wait);
		wait = WaitForSingleObject(threads[i], 1000);
		TEST_CHECK_EQ(WAIT_OBJECT_0, wait);
	}
	HANDLE handles[] = {
		threads[0], threads[1], ready[0], ready[1], done[0], done[1],
	};
	close_handles(handles, sizeof(handles) / sizeof(handles[0]));
}
/* No waker: WaitOnAddress must time out, return FALSE, set ERROR_TIMEOUT,
 * and leave the value untouched. */
static void test_wait_on_address_timeout(void) {
	volatile LONG value = 7;
	LONG expected = 7;
	SetLastError(0);
	BOOL ok = WaitOnAddress((volatile VOID *)&value, &expected, sizeof(LONG), 50);
	TEST_CHECK(!ok);
	TEST_CHECK_EQ(ERROR_TIMEOUT, GetLastError());
	TEST_CHECK_EQ(7, value);
}
/* Compare value already differs from *Address: WaitOnAddress must return
 * success immediately without consuming the timeout. */
static void test_wait_on_address_immediate(void) {
	volatile LONG value = 10;
	LONG expected = 11;
	BOOL ok = WaitOnAddress((volatile VOID *)&value, &expected, sizeof(LONG), 1000);
	TEST_CHECK(ok);
	TEST_CHECK_EQ(10, value);
}
/* Exercises WaitOnAddress/WakeByAddressSingle/WakeByAddressAll end to end. */
int main(void) {
	test_wait_on_address_single();
	test_wait_on_address_all();
	test_wait_on_address_timeout();
	test_wait_on_address_immediate();
	return EXIT_SUCCESS;
}