Implement WaitOnAddress, WakeByAddress*; macOS impl for atomic waits

2025-11-11 10:44:14 -07:00
parent f18f2a929d
commit 09a7452c77
9 changed files with 589 additions and 82 deletions
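The exported semantics match Win32: WaitOnAddress blocks while the value at Address still equals the value at CompareAddress, and it may wake spuriously, so callers re-check in a loop. A minimal caller-side sketch of the contract these exports implement (the g_flag variable and both helper functions are hypothetical, for illustration only):

static LONG g_flag = 0; // hypothetical shared variable, 4 bytes wide
void waitUntilFlagSet() {
LONG undesired = 0;
// Loop: WaitOnAddress can return spuriously or after an unrelated wake.
while (__atomic_load_n(&g_flag, __ATOMIC_ACQUIRE) == undesired) {
WaitOnAddress(&g_flag, &undesired, sizeof(g_flag), INFINITE);
}
}
void setFlagAndWake() {
__atomic_store_n(&g_flag, 1, __ATOMIC_RELEASE);
WakeByAddressSingle(&g_flag); // WakeByAddressAll would release every waiter
}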


@@ -5,16 +5,20 @@
#include "errors.h"
#include "handles.h"
#include "heap.h"
#include "interlockedapi.h"
#include "internal.h"
#include "processthreadsapi.h"
#include "strutil.h"
#include "types.h"
#include <algorithm>
#include <array>
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <cstring>
#include <dlfcn.h>
#include <memory>
#include <mutex>
#include <optional>
@@ -24,6 +28,7 @@
#include <thread>
#include <unistd.h>
#include <unordered_map>
#include <utility>
#include <vector>
namespace {
@@ -33,6 +38,292 @@ constexpr DWORD kSrwLockSharedIncrement = 0x2u;
constexpr GUEST_PTR kInitOnceStateMask = 0x3u;
constexpr GUEST_PTR kInitOnceCompletedFlag = 0x2u;
constexpr GUEST_PTR kInitOnceReservedMask = (1u << INIT_ONCE_CTX_RESERVED_BITS) - 1;
constexpr size_t kSupportedAddressSizes = 4; // number of supported widths: 1, 2, 4, 8 bytes
struct AddressWaitQueue {
std::mutex mutex;
std::condition_variable cv;
size_t waiterCount = 0;
std::array<size_t, kSupportedAddressSizes> sizeCounts{};
};
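// Registry mapping a watched address to its wait queue; entries are weak_ptr
// so the map can be pruned once the last waiter for an address goes away.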
std::mutex g_waitAddressMutex;
std::unordered_map<void *, std::weak_ptr<AddressWaitQueue>> g_waitAddressQueues;
constexpr int sizeToIndex(size_t size) {
if (size == 0 || (size & (size - 1)) != 0) {
return -1; // __builtin_ctzll(0) is undefined, and e.g. size 3 must not map to the 1-byte bucket
}
int index = __builtin_ctzll(size);
return index >= static_cast<int>(kSupportedAddressSizes) ? -1 : index;
}
std::shared_ptr<AddressWaitQueue> getWaitQueue(void *address) {
std::lock_guard lk(g_waitAddressMutex);
auto &slot = g_waitAddressQueues[address];
auto queue = slot.lock();
if (!queue) {
queue = std::make_shared<AddressWaitQueue>();
slot = queue;
}
return queue;
}
std::shared_ptr<AddressWaitQueue> tryGetWaitQueue(void *address) {
std::lock_guard lk(g_waitAddressMutex);
auto it = g_waitAddressQueues.find(address);
if (it == g_waitAddressQueues.end()) {
return nullptr;
}
auto queue = it->second.lock();
if (!queue) {
g_waitAddressQueues.erase(it);
return nullptr;
}
return queue;
}
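// Drop the registry entry once a queue has no waiters left; called from
// WaitRegistration's destructor after the waiter unregisters itself.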
void cleanupWaitQueue(void *address, const std::shared_ptr<AddressWaitQueue> &queue) {
std::lock_guard lk(g_waitAddressMutex);
auto it = g_waitAddressQueues.find(address);
if (it == g_waitAddressQueues.end()) {
return;
}
auto locked = it->second.lock();
if (!locked) {
g_waitAddressQueues.erase(it);
return;
}
if (locked.get() != queue.get()) {
return;
}
std::lock_guard queueLock(queue->mutex);
if (queue->waiterCount == 0) {
g_waitAddressQueues.erase(it);
}
}
struct WaitRegistration {
void *address;
std::shared_ptr<AddressWaitQueue> queue;
size_t sizeIndex;
bool registered = false;
WaitRegistration(void *addr, std::shared_ptr<AddressWaitQueue> q, size_t idx)
: address(addr), queue(std::move(q)), sizeIndex(idx) {}
void registerWaiter() {
if (!queue) {
return;
}
std::lock_guard lk(queue->mutex);
queue->waiterCount++;
queue->sizeCounts[sizeIndex]++;
registered = true;
}
void unregister() {
if (!queue || !registered) {
return;
}
std::lock_guard lk(queue->mutex);
queue->waiterCount--;
queue->sizeCounts[sizeIndex]--;
registered = false;
}
~WaitRegistration() {
unregister();
if (queue) {
cleanupWaitQueue(address, queue);
}
}
};
#if defined(__APPLE__)
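// macOS exposes no public futex-style wait, so resolve libc++'s internal
// atomic-wait entry points via dlsym. These mangled names are libc++
// implementation details and could vanish in a future release; the polling
// fallback below covers that case.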
using LibcppMonitorFn = long long (*)(const void volatile *);
using LibcppWaitFn = void (*)(const void volatile *, long long);
using LibcppNotifyFn = void (*)(const void volatile *);
LibcppMonitorFn getLibcppAtomicMonitor() {
static LibcppMonitorFn fn =
reinterpret_cast<LibcppMonitorFn>(dlsym(RTLD_DEFAULT, "_ZNSt3__123__libcpp_atomic_monitorEPVKv"));
return fn;
}
LibcppWaitFn getLibcppAtomicWait() {
static LibcppWaitFn fn =
reinterpret_cast<LibcppWaitFn>(dlsym(RTLD_DEFAULT, "_ZNSt3__120__libcpp_atomic_waitEPVKvx"));
return fn;
}
LibcppNotifyFn getLibcppAtomicNotifyOne() {
static LibcppNotifyFn fn =
reinterpret_cast<LibcppNotifyFn>(dlsym(RTLD_DEFAULT, "_ZNSt3__123__cxx_atomic_notify_oneEPVKv"));
return fn;
}
LibcppNotifyFn getLibcppAtomicNotifyAll() {
static LibcppNotifyFn fn =
reinterpret_cast<LibcppNotifyFn>(dlsym(RTLD_DEFAULT, "_ZNSt3__123__cxx_atomic_notify_allEPVKv"));
return fn;
}
template <typename T> void platformWaitIndefinite(T volatile *address, T expected) {
auto monitorFn = getLibcppAtomicMonitor();
auto waitFn = getLibcppAtomicWait();
if (!monitorFn || !waitFn) {
while (__atomic_load_n(address, __ATOMIC_ACQUIRE) == expected) {
std::this_thread::sleep_for(std::chrono::microseconds(50));
}
return;
}
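// Two-phase wait: sample the monitor "version" for this address, re-check the
// value, then block until the version changes. The re-check between monitor
// and wait closes the window where a notification could otherwise be lost.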
while (true) {
T current = __atomic_load_n(address, __ATOMIC_ACQUIRE);
if (current != expected) {
return;
}
auto monitor = monitorFn(address);
current = __atomic_load_n(address, __ATOMIC_ACQUIRE);
if (current != expected) {
continue;
}
waitFn(address, monitor);
}
}
inline void platformNotifyAddress(void *address, size_t, bool wakeOne) {
auto notifyFn = wakeOne ? getLibcppAtomicNotifyOne() : getLibcppAtomicNotifyAll();
if (notifyFn) {
notifyFn(address);
}
}
#elif defined(__linux__)
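// On Linux, libstdc++'s atomic wait/notify lowers to futex(2), which matches
// WaitOnAddress semantics directly.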
template <typename T> void platformWaitIndefinite(T volatile *address, T expected) {
// atomic_ref cannot bind a volatile lvalue; cast the qualifier away for the wait (the access itself stays atomic).
std::atomic_ref<T> ref(*const_cast<T *>(address));
ref.wait(expected, std::memory_order_relaxed);
}
template <typename T> void linuxNotify(void *address, bool wakeOne) {
auto *typed = reinterpret_cast<T *>(address);
std::atomic_ref<T> ref(*typed);
if (wakeOne) {
ref.notify_one();
} else {
ref.notify_all();
}
}
inline void platformNotifyAddress(void *address, size_t size, bool wakeOne) {
switch (size) {
case 1:
linuxNotify<uint8_t>(address, wakeOne);
break;
case 2:
linuxNotify<uint16_t>(address, wakeOne);
break;
case 4:
linuxNotify<uint32_t>(address, wakeOne);
break;
case 8:
linuxNotify<uint64_t>(address, wakeOne);
break;
default:
break;
}
}
#else
template <typename T> void platformWaitIndefinite(T volatile *address, T expected) {
while (__atomic_load_n(address, __ATOMIC_ACQUIRE) == expected) {
std::this_thread::sleep_for(std::chrono::milliseconds(1));
}
}
inline void platformNotifyAddress(void *, size_t, bool) {}
#endif
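// Waiters may watch the same address at different widths (1/2/4/8 bytes).
// Notify each width that currently has registered waiters, skipping widths
// the address is not suitably aligned for.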
void notifyAtomicWaiters(void *address, const std::array<size_t, kSupportedAddressSizes> &sizeCounts, bool wakeOne) {
uintptr_t addrValue = reinterpret_cast<uintptr_t>(address);
for (size_t i = 0; i < sizeCounts.size(); ++i) {
if (sizeCounts[i] == 0) {
continue;
}
size_t size = 1 << i;
if (addrValue & (size - 1)) {
continue;
}
platformNotifyAddress(address, size, wakeOne);
}
}
template <typename T> bool waitOnAddressTyped(VOID volatile *addressVoid, PVOID comparePtr, DWORD dwMilliseconds) {
auto *address = reinterpret_cast<T volatile *>(addressVoid);
if (!address) {
kernel32::setLastError(ERROR_INVALID_PARAMETER);
return false;
}
if (reinterpret_cast<uintptr_t>(address) % alignof(T) != 0) {
kernel32::setLastError(ERROR_INVALID_PARAMETER);
return false;
}
const T compareValue = *reinterpret_cast<const T *>(comparePtr);
if (__atomic_load_n(address, __ATOMIC_ACQUIRE) != compareValue) {
return true;
}
if (dwMilliseconds == 0) {
kernel32::setLastError(ERROR_TIMEOUT);
return false;
}
void *queueKey = const_cast<void *>(addressVoid);
auto queue = getWaitQueue(queueKey);
if (!queue) {
kernel32::setLastError(ERROR_GEN_FAILURE);
return false;
}
int sizeIdx = sizeToIndex(sizeof(T));
DEBUG_LOG("size: %d, index %d\n", sizeof(T), sizeIdx);
if (sizeIdx < 0) {
kernel32::setLastError(ERROR_INVALID_PARAMETER);
return false;
}
WaitRegistration registration(queueKey, queue, static_cast<size_t>(sizeIdx));
registration.registerWaiter();
if (dwMilliseconds == INFINITE) {
while (__atomic_load_n(address, __ATOMIC_ACQUIRE) == compareValue) {
platformWaitIndefinite(address, compareValue);
}
return true;
}
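// Timed path: block on the queue's condition variable instead of the platform
// wait primitive so the deadline can be enforced portably; WakeByAddress*
// signals this cv as well as the platform primitive.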
const auto deadline =
std::chrono::steady_clock::now() + std::chrono::milliseconds(static_cast<uint64_t>(dwMilliseconds));
bool timedOut = false;
{
std::unique_lock lk(queue->mutex);
while (__atomic_load_n(address, __ATOMIC_ACQUIRE) == compareValue) {
if (queue->cv.wait_until(lk, deadline) == std::cv_status::timeout) {
if (__atomic_load_n(address, __ATOMIC_ACQUIRE) == compareValue) {
timedOut = true;
break;
}
}
}
}
if (timedOut && __atomic_load_n(address, __ATOMIC_ACQUIRE) == compareValue) {
kernel32::setLastError(ERROR_TIMEOUT);
return false;
}
return true;
}
std::u16string makeU16String(LPCWSTR name) {
if (!name) {
@@ -167,29 +458,25 @@ void eraseInitOnceState(LPINIT_ONCE once) {
g_initOnceStates.erase(once);
}
-inline DWORD owningThreadId(LPCRITICAL_SECTION crit) {
-std::atomic_ref<DWORD> owner(*reinterpret_cast<DWORD *>(&crit->OwningThread));
-return owner.load(std::memory_order_acquire);
-}
+inline DWORD owningThreadId(LPCRITICAL_SECTION crit) { return __atomic_load_n(&crit->OwningThread, __ATOMIC_ACQUIRE); }
inline void setOwningThread(LPCRITICAL_SECTION crit, DWORD threadId) {
-std::atomic_ref<DWORD> owner(*reinterpret_cast<DWORD *>(&crit->OwningThread));
-owner.store(threadId, std::memory_order_release);
+__atomic_store_n(&crit->OwningThread, threadId, __ATOMIC_RELEASE);
}
void waitForCriticalSection(LPCRITICAL_SECTION cs) {
-std::atomic_ref<LONG> sequence(reinterpret_cast<LONG &>(cs->LockSemaphore));
-LONG observed = sequence.load(std::memory_order_acquire);
+auto *sequence = reinterpret_cast<LONG volatile *>(&cs->LockSemaphore);
+LONG observed = __atomic_load_n(sequence, __ATOMIC_ACQUIRE);
while (owningThreadId(cs) != 0) {
-sequence.wait(observed, std::memory_order_relaxed);
-observed = sequence.load(std::memory_order_acquire);
+kernel32::WaitOnAddress(sequence, &observed, sizeof(observed), INFINITE);
+observed = __atomic_load_n(sequence, __ATOMIC_ACQUIRE);
}
}
void signalCriticalSection(LPCRITICAL_SECTION cs) {
-std::atomic_ref<LONG> sequence(reinterpret_cast<LONG &>(cs->LockSemaphore));
-sequence.fetch_add(1, std::memory_order_release);
-sequence.notify_one();
+auto *sequence = reinterpret_cast<LONG *>(&cs->LockSemaphore);
+kernel32::InterlockedIncrement(const_cast<LONG volatile *>(sequence));
+kernel32::WakeByAddressSingle(sequence);
}
inline bool trySpinAcquireCriticalSection(LPCRITICAL_SECTION cs, DWORD threadId) {
@@ -634,6 +921,71 @@ DWORD WINAPI WaitForMultipleObjects(DWORD nCount, const HANDLE *lpHandles, BOOL
return waitResult;
}
BOOL WINAPI WaitOnAddress(VOID volatile *Address, PVOID CompareAddress, SIZE_T AddressSize, DWORD dwMilliseconds) {
HOST_CONTEXT_GUARD();
VERBOSE_LOG("WaitOnAddress(%p, %p, %zu, %u)\n", Address, CompareAddress, AddressSize, dwMilliseconds);
if (!Address || !CompareAddress) {
setLastError(ERROR_INVALID_PARAMETER);
return FALSE;
}
BOOL result = FALSE;
switch (sizeToIndex(AddressSize)) {
case 0:
result = waitOnAddressTyped<uint8_t>(Address, CompareAddress, dwMilliseconds) ? TRUE : FALSE;
break;
case 1:
result = waitOnAddressTyped<uint16_t>(Address, CompareAddress, dwMilliseconds) ? TRUE : FALSE;
break;
case 2:
result = waitOnAddressTyped<uint32_t>(Address, CompareAddress, dwMilliseconds) ? TRUE : FALSE;
break;
case 3:
result = waitOnAddressTyped<uint64_t>(Address, CompareAddress, dwMilliseconds) ? TRUE : FALSE;
break;
default:
setLastError(ERROR_INVALID_PARAMETER);
return FALSE;
}
return result;
}
void WINAPI WakeByAddressSingle(PVOID Address) {
HOST_CONTEXT_GUARD();
VERBOSE_LOG("WakeByAddressSingle(%p)\n", Address);
if (!Address) {
return;
}
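// Snapshot the per-width waiter counts under the queue lock, then notify
// outside it; over-waking is harmless because WaitOnAddress callers must
// tolerate spurious wakes.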
std::array<size_t, kSupportedAddressSizes> sizeCounts{};
auto queue = tryGetWaitQueue(Address);
if (queue) {
std::lock_guard lk(queue->mutex);
sizeCounts = queue->sizeCounts;
}
notifyAtomicWaiters(Address, sizeCounts, true);
if (queue) {
queue->cv.notify_one();
}
}
void WINAPI WakeByAddressAll(PVOID Address) {
HOST_CONTEXT_GUARD();
VERBOSE_LOG("WakeByAddressAll(%p)\n", Address);
if (!Address) {
return;
}
std::array<size_t, kSupportedAddressSizes> sizeCounts{};
auto queue = tryGetWaitQueue(Address);
if (queue) {
std::lock_guard lk(queue->mutex);
sizeCounts = queue->sizeCounts;
}
notifyAtomicWaiters(Address, sizeCounts, false);
if (queue) {
queue->cv.notify_all();
}
}
void WINAPI InitializeCriticalSection(LPCRITICAL_SECTION lpCriticalSection) {
HOST_CONTEXT_GUARD();
DEBUG_LOG("InitializeCriticalSection(%p)\n", lpCriticalSection);
@@ -689,13 +1041,13 @@ void WINAPI DeleteCriticalSection(LPCRITICAL_SECTION lpCriticalSection) {
lpCriticalSection->SpinCount = 0;
setOwningThread(lpCriticalSection, 0);
-std::atomic_ref<LONG> sequence(reinterpret_cast<LONG &>(lpCriticalSection->LockSemaphore));
-sequence.store(0, std::memory_order_release);
-sequence.notify_all();
+auto *sequence = reinterpret_cast<LONG *>(&lpCriticalSection->LockSemaphore);
+kernel32::InterlockedExchange(const_cast<LONG volatile *>(sequence), 0);
+kernel32::WakeByAddressAll(sequence);
-std::atomic_ref<LONG> lockCount(lpCriticalSection->LockCount);
-lockCount.store(-1, std::memory_order_release);
-lockCount.notify_all();
+auto *lockCount = reinterpret_cast<LONG *>(&lpCriticalSection->LockCount);
+kernel32::InterlockedExchange(const_cast<LONG volatile *>(lockCount), -1);
+kernel32::WakeByAddressAll(lockCount);
}
BOOL WINAPI TryEnterCriticalSection(LPCRITICAL_SECTION lpCriticalSection) {
@@ -706,18 +1058,18 @@ BOOL WINAPI TryEnterCriticalSection(LPCRITICAL_SECTION lpCriticalSection) {
return FALSE;
}
-std::atomic_ref<LONG> lockCount(lpCriticalSection->LockCount);
+auto *lockCount = const_cast<LONG volatile *>(&lpCriticalSection->LockCount);
const DWORD threadId = GetCurrentThreadId();
-LONG expected = -1;
-if (lockCount.compare_exchange_strong(expected, 0, std::memory_order_acq_rel, std::memory_order_acquire)) {
+LONG previous = kernel32::InterlockedCompareExchange(lockCount, 0, -1);
+if (previous == -1) {
setOwningThread(lpCriticalSection, threadId);
lpCriticalSection->RecursionCount = 1;
return TRUE;
}
if (owningThreadId(lpCriticalSection) == threadId) {
-lockCount.fetch_add(1, std::memory_order_acq_rel);
+kernel32::InterlockedIncrement(lockCount);
lpCriticalSection->RecursionCount++;
return TRUE;
}
@@ -737,8 +1089,8 @@ void WINAPI EnterCriticalSection(LPCRITICAL_SECTION lpCriticalSection) {
return;
}
-std::atomic_ref<LONG> lockCount(lpCriticalSection->LockCount);
-LONG result = lockCount.fetch_add(1, std::memory_order_acq_rel) + 1;
+auto *lockCount = const_cast<LONG volatile *>(&lpCriticalSection->LockCount);
+LONG result = kernel32::InterlockedIncrement(lockCount);
if (result) {
if (owningThreadId(lpCriticalSection) == threadId) {
lpCriticalSection->RecursionCount++;
@@ -766,15 +1118,14 @@ void WINAPI LeaveCriticalSection(LPCRITICAL_SECTION lpCriticalSection) {
return;
}
+auto *lockCount = const_cast<LONG volatile *>(&lpCriticalSection->LockCount);
if (--lpCriticalSection->RecursionCount > 0) {
-std::atomic_ref<LONG> lockCount(lpCriticalSection->LockCount);
-lockCount.fetch_sub(1, std::memory_order_acq_rel);
+kernel32::InterlockedDecrement(lockCount);
return;
}
setOwningThread(lpCriticalSection, 0);
-std::atomic_ref<LONG> lockCount(lpCriticalSection->LockCount);
-LONG newValue = lockCount.fetch_sub(1, std::memory_order_acq_rel) - 1;
+LONG newValue = kernel32::InterlockedDecrement(lockCount);
if (newValue >= 0) {
signalCriticalSection(lpCriticalSection);
}
@@ -792,14 +1143,14 @@ BOOL WINAPI InitOnceBeginInitialize(LPINIT_ONCE lpInitOnce, DWORD dwFlags, PBOOL
return FALSE;
}
-std::atomic_ref<GUEST_PTR> state(lpInitOnce->Ptr);
+auto *state = &lpInitOnce->Ptr;
if (dwFlags & INIT_ONCE_CHECK_ONLY) {
if (dwFlags & INIT_ONCE_ASYNC) {
setLastError(ERROR_INVALID_PARAMETER);
return FALSE;
}
-GUEST_PTR val = state.load(std::memory_order_acquire);
+GUEST_PTR val = __atomic_load_n(state, __ATOMIC_ACQUIRE);
if ((val & kInitOnceStateMask) != kInitOnceCompletedFlag) {
if (fPending) {
*fPending = TRUE;
@@ -817,13 +1168,13 @@ BOOL WINAPI InitOnceBeginInitialize(LPINIT_ONCE lpInitOnce, DWORD dwFlags, PBOOL
}
while (true) {
-GUEST_PTR val = state.load(std::memory_order_acquire);
+GUEST_PTR val = __atomic_load_n(state, __ATOMIC_ACQUIRE);
switch (val & kInitOnceStateMask) {
case 0: { // first time
if (dwFlags & INIT_ONCE_ASYNC) {
GUEST_PTR expected = 0;
-if (state.compare_exchange_strong(expected, static_cast<GUEST_PTR>(3), std::memory_order_acq_rel,
-std::memory_order_acquire)) {
+if (__atomic_compare_exchange_n(state, &expected, static_cast<GUEST_PTR>(3), false, __ATOMIC_ACQ_REL,
+__ATOMIC_ACQUIRE)) {
if (fPending) {
*fPending = TRUE;
}
@@ -832,8 +1183,8 @@ BOOL WINAPI InitOnceBeginInitialize(LPINIT_ONCE lpInitOnce, DWORD dwFlags, PBOOL
} else {
auto syncState = std::make_shared<InitOnceState>();
GUEST_PTR expected = 0;
-if (state.compare_exchange_strong(expected, static_cast<GUEST_PTR>(1), std::memory_order_acq_rel,
-std::memory_order_acquire)) {
+if (__atomic_compare_exchange_n(state, &expected, static_cast<GUEST_PTR>(1), false, __ATOMIC_ACQ_REL,
+__ATOMIC_ACQUIRE)) {
insertInitOnceState(lpInitOnce, syncState);
if (fPending) {
*fPending = TRUE;
@@ -924,11 +1275,11 @@ BOOL WINAPI InitOnceComplete(LPINIT_ONCE lpInitOnce, DWORD dwFlags, LPVOID lpCon
return FALSE;
}
-std::atomic_ref<GUEST_PTR> state(lpInitOnce->Ptr);
+auto *state = &lpInitOnce->Ptr;
const GUEST_PTR finalValue = markFailed ? 0 : (contextValue | kInitOnceCompletedFlag);
while (true) {
-GUEST_PTR val = state.load(std::memory_order_acquire);
+GUEST_PTR val = __atomic_load_n(state, __ATOMIC_ACQUIRE);
switch (val & kInitOnceStateMask) {
case 1: {
auto syncState = getInitOnceState(lpInitOnce);
@@ -936,7 +1287,7 @@ BOOL WINAPI InitOnceComplete(LPINIT_ONCE lpInitOnce, DWORD dwFlags, LPVOID lpCon
setLastError(ERROR_GEN_FAILURE);
return FALSE;
}
-if (!state.compare_exchange_strong(val, finalValue, std::memory_order_acq_rel, std::memory_order_acquire)) {
+if (!__atomic_compare_exchange_n(state, &val, finalValue, false, __ATOMIC_ACQ_REL, __ATOMIC_ACQUIRE)) {
continue;
}
{
@@ -954,7 +1305,7 @@ BOOL WINAPI InitOnceComplete(LPINIT_ONCE lpInitOnce, DWORD dwFlags, LPVOID lpCon
setLastError(ERROR_INVALID_PARAMETER);
return FALSE;
}
-if (!state.compare_exchange_strong(val, finalValue, std::memory_order_acq_rel, std::memory_order_acquire)) {
+if (!__atomic_compare_exchange_n(state, &val, finalValue, false, __ATOMIC_ACQ_REL, __ATOMIC_ACQUIRE)) {
continue;
}
return TRUE;
@@ -971,15 +1322,16 @@ void WINAPI AcquireSRWLockShared(PSRWLOCK SRWLock) {
if (!SRWLock) {
return;
}
-std::atomic_ref<ULONG> value(SRWLock->Value);
+auto *value = &SRWLock->Value;
while (true) {
-ULONG current = value.load(std::memory_order_acquire);
+ULONG current = __atomic_load_n(value, __ATOMIC_ACQUIRE);
if (current & kSrwLockExclusive) {
-value.wait(current, std::memory_order_relaxed);
+ULONG observed = current;
+kernel32::WaitOnAddress(reinterpret_cast<VOID volatile *>(value), &observed, sizeof(observed), INFINITE);
continue;
}
ULONG desired = current + kSrwLockSharedIncrement;
-if (value.compare_exchange_weak(current, desired, std::memory_order_acq_rel, std::memory_order_acquire)) {
+if (__atomic_compare_exchange_n(value, &current, desired, true, __ATOMIC_ACQ_REL, __ATOMIC_ACQUIRE)) {
return;
}
}
@@ -991,11 +1343,11 @@ void WINAPI ReleaseSRWLockShared(PSRWLOCK SRWLock) {
if (!SRWLock) {
return;
}
-std::atomic_ref<ULONG> value(SRWLock->Value);
-ULONG previous = value.fetch_sub(kSrwLockSharedIncrement, std::memory_order_acq_rel);
+auto *value = &SRWLock->Value;
+ULONG previous = __atomic_fetch_sub(value, kSrwLockSharedIncrement, __ATOMIC_ACQ_REL);
ULONG newValue = previous - kSrwLockSharedIncrement;
if (newValue == 0) {
-value.notify_all();
+kernel32::WakeByAddressAll(value);
}
}
@@ -1005,14 +1357,14 @@ void WINAPI AcquireSRWLockExclusive(PSRWLOCK SRWLock) {
if (!SRWLock) {
return;
}
-std::atomic_ref<ULONG> value(SRWLock->Value);
+auto *value = &SRWLock->Value;
while (true) {
ULONG expected = 0;
-if (value.compare_exchange_strong(expected, kSrwLockExclusive, std::memory_order_acq_rel,
-std::memory_order_acquire)) {
+if (__atomic_compare_exchange_n(value, &expected, kSrwLockExclusive, false, __ATOMIC_ACQ_REL,
+__ATOMIC_ACQUIRE)) {
return;
}
-value.wait(expected, std::memory_order_relaxed);
+kernel32::WaitOnAddress(reinterpret_cast<VOID volatile *>(value), &expected, sizeof(expected), INFINITE);
}
}
@@ -1022,9 +1374,8 @@ void WINAPI ReleaseSRWLockExclusive(PSRWLOCK SRWLock) {
if (!SRWLock) {
return;
}
-std::atomic_ref<ULONG> value(SRWLock->Value);
-value.store(0, std::memory_order_release);
-value.notify_all();
+__atomic_store_n(&SRWLock->Value, 0u, __ATOMIC_RELEASE);
+kernel32::WakeByAddressAll(&SRWLock->Value);
}
BOOLEAN WINAPI TryAcquireSRWLockExclusive(PSRWLOCK SRWLock) {
@@ -1033,10 +1384,9 @@ BOOLEAN WINAPI TryAcquireSRWLockExclusive(PSRWLOCK SRWLock) {
if (!SRWLock) {
return FALSE;
}
-std::atomic_ref<ULONG> value(SRWLock->Value);
ULONG expected = 0;
-if (value.compare_exchange_strong(expected, kSrwLockExclusive, std::memory_order_acq_rel,
-std::memory_order_acquire)) {
+if (__atomic_compare_exchange_n(&SRWLock->Value, &expected, kSrwLockExclusive, false, __ATOMIC_ACQ_REL,
+__ATOMIC_ACQUIRE)) {
return TRUE;
}
return FALSE;
@@ -1048,11 +1398,10 @@ BOOLEAN WINAPI TryAcquireSRWLockShared(PSRWLOCK SRWLock) {
if (!SRWLock) {
return FALSE;
}
-std::atomic_ref<ULONG> value(SRWLock->Value);
-ULONG current = value.load(std::memory_order_acquire);
+ULONG current = __atomic_load_n(&SRWLock->Value, __ATOMIC_ACQUIRE);
while (!(current & kSrwLockExclusive)) {
ULONG desired = current + kSrwLockSharedIncrement;
-if (value.compare_exchange_weak(current, desired, std::memory_order_acq_rel, std::memory_order_acquire)) {
+if (__atomic_compare_exchange_n(&SRWLock->Value, &current, desired, true, __ATOMIC_ACQ_REL, __ATOMIC_ACQUIRE)) {
return TRUE;
}
}