From df36de18bfc7dedb3aeca38efb880631601738cc Mon Sep 17 00:00:00 2001 From: Luke Street Date: Mon, 6 Oct 2025 23:51:09 -0600 Subject: [PATCH] Implement async (overlapped) I/O with io_uring --- CMakeLists.txt | 40 ++++- Dockerfile | 12 +- Dockerfile.ubuntu | 1 + dll/kernel32.cpp | 2 + dll/kernel32/fileapi.cpp | 75 +++++--- dll/kernel32/internal.h | 10 +- dll/kernel32/ioapiset.cpp | 18 +- dll/kernel32/overlapped_util.h | 53 ++++++ dll/kernel32/processthreadsapi.cpp | 11 +- dll/kernel32/synchapi.cpp | 211 ++++++++++++++++++++-- dll/kernel32/synchapi.h | 2 + dll/msvcrt.cpp | 4 +- src/async_io.cpp | 277 +++++++++++++++++++++++++++++ src/async_io.h | 19 ++ src/errors.cpp | 19 ++ src/errors.h | 1 + src/handles.cpp | 39 +++- src/handles.h | 177 ++++++++++-------- src/main.cpp | 2 + src/processes.cpp | 3 +- test/test_overlapped_io.c | 76 ++++++++ 21 files changed, 909 insertions(+), 143 deletions(-) create mode 100644 dll/kernel32/overlapped_util.h create mode 100644 src/async_io.cpp create mode 100644 src/async_io.h diff --git a/CMakeLists.txt b/CMakeLists.txt index ad32dd0..681ac64 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,11 +10,12 @@ set(CMAKE_CXX_FLAGS_INIT "-m32") set(CMAKE_EXE_LINKER_FLAGS_INIT "-m32") set(CMAKE_SHARED_LINKER_FLAGS_INIT "-m32") -project(wibo LANGUAGES CXX) +project(wibo LANGUAGES C CXX) list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") -set(CMAKE_CXX_STANDARD 17) +set(CMAKE_C_STANDARD 11) +set(CMAKE_CXX_STANDARD 20) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wall -fno-pie -no-pie -D_LARGEFILE64_SOURCE") find_package(Filesystem REQUIRED) @@ -30,7 +31,36 @@ FetchContent_Declare( ) FetchContent_MakeAvailable(mimalloc) -include_directories(.) +FetchContent_Declare( + liburing + GIT_REPOSITORY https://github.com/axboe/liburing.git + GIT_TAG liburing-2.4 +) +FetchContent_MakeAvailable(liburing) + +set(LIBURING_COMPAT ${liburing_SOURCE_DIR}/src/include/liburing/compat.h) +add_custom_command( + OUTPUT ${LIBURING_COMPAT} + COMMAND ${CMAKE_COMMAND} -E env + CC=${CMAKE_C_COMPILER} + AR=${CMAKE_AR} + RANLIB=${CMAKE_RANLIB} + ./configure --cc=${CMAKE_C_COMPILER} + WORKING_DIRECTORY ${liburing_SOURCE_DIR} + COMMENT "Running liburing configure" + VERBATIM +) +add_custom_target(liburing_configure DEPENDS ${LIBURING_COMPAT}) +add_library(liburing STATIC + ${liburing_SOURCE_DIR}/src/queue.c + ${liburing_SOURCE_DIR}/src/register.c + ${liburing_SOURCE_DIR}/src/setup.c + ${liburing_SOURCE_DIR}/src/syscall.c + ${liburing_SOURCE_DIR}/src/version.c) +add_dependencies(liburing liburing_configure) +target_include_directories(liburing PUBLIC ${liburing_SOURCE_DIR}/src/include) +target_compile_definitions(liburing PRIVATE _GNU_SOURCE) + add_executable(wibo dll/advapi32.cpp dll/advapi32/processthreadsapi.cpp @@ -74,6 +104,7 @@ add_executable(wibo dll/vcruntime.cpp dll/version.cpp src/access.cpp + src/async_io.cpp src/context.cpp src/errors.cpp src/files.cpp @@ -86,7 +117,8 @@ add_executable(wibo src/strutil.cpp ) target_include_directories(wibo PRIVATE dll src) -target_link_libraries(wibo PRIVATE std::filesystem mimalloc-static) +target_compile_features(wibo PRIVATE cxx_std_20) +target_link_libraries(wibo PRIVATE std::filesystem mimalloc-static liburing) install(TARGETS wibo DESTINATION bin) include(CTest) diff --git a/Dockerfile b/Dockerfile index a64e3f2..476f67c 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,14 +4,15 @@ FROM --platform=linux/i386 alpine:latest AS build # Install dependencies RUN apk add --no-cache \ bash \ - cmake \ - ninja \ - g++ \ - linux-headers \ binutils \ + cmake \ + g++ \ git \ + linux-headers \ + make \ mingw-w64-binutils \ - mingw-w64-gcc + mingw-w64-gcc \ + ninja # Copy source files WORKDIR /wibo @@ -23,6 +24,7 @@ ARG build_type=Release # Build static binary RUN cmake -S /wibo -B /wibo/build -G Ninja \ -DCMAKE_BUILD_TYPE="$build_type" \ + -DCMAKE_C_FLAGS="-static" \ -DCMAKE_CXX_FLAGS="-static" \ -DBUILD_TESTING=ON \ -DWIBO_ENABLE_FIXTURE_TESTS=ON \ diff --git a/Dockerfile.ubuntu b/Dockerfile.ubuntu index 8e1efc8..09d0a8c 100644 --- a/Dockerfile.ubuntu +++ b/Dockerfile.ubuntu @@ -14,6 +14,7 @@ RUN apt-get update \ gcc-mingw-w64-i686 \ gdb \ git \ + make \ ninja-build \ unzip \ wget \ diff --git a/dll/kernel32.cpp b/dll/kernel32.cpp index 63cd8c9..114bfbe 100644 --- a/dll/kernel32.cpp +++ b/dll/kernel32.cpp @@ -170,6 +170,8 @@ void *resolveByName(const char *name) { return (void *)kernel32::TryAcquireSRWLockExclusive; if (strcmp(name, "WaitForSingleObject") == 0) return (void *)kernel32::WaitForSingleObject; + if (strcmp(name, "WaitForMultipleObjects") == 0) + return (void *)kernel32::WaitForMultipleObjects; if (strcmp(name, "CreateMutexA") == 0) return (void *)kernel32::CreateMutexA; if (strcmp(name, "CreateMutexW") == 0) diff --git a/dll/kernel32/fileapi.cpp b/dll/kernel32/fileapi.cpp index 47cc5f4..92ec82f 100644 --- a/dll/kernel32/fileapi.cpp +++ b/dll/kernel32/fileapi.cpp @@ -1,12 +1,14 @@ #include "fileapi.h" #include "access.h" +#include "async_io.h" #include "common.h" #include "context.h" #include "errors.h" #include "files.h" #include "handles.h" #include "internal.h" +#include "overlapped_util.h" #include "strutil.h" #include "timeutil.h" @@ -291,7 +293,7 @@ template void populateFromStat(const FindSearchEntry &entry, template void populateFindData(const FindSearchEntry &entry, FindData &out) { resetFindDataStruct(out); - std::string nativePath = entry.fullPath.empty() ? std::string() : entry.fullPath.u8string(); + std::string nativePath = entry.fullPath.empty() ? std::string() : entry.fullPath.string(); struct stat st{}; if (!nativePath.empty() && stat(nativePath.c_str(), &st) == 0) { populateFromStat(entry, st, out); @@ -548,26 +550,6 @@ bool tryOpenConsoleDevice(DWORD dwDesiredAccess, DWORD dwShareMode, DWORD dwCrea namespace kernel32 { -namespace { - -void signalOverlappedEvent(OVERLAPPED *ov) { - if (ov && ov->hEvent) { - if (auto ev = wibo::handles().getAs(ov->hEvent)) { - ev->set(); - } - } -} - -void resetOverlappedEvent(OVERLAPPED *ov) { - if (ov && ov->hEvent) { - if (auto ev = wibo::handles().getAs(ov->hEvent)) { - ev->reset(); - } - } -} - -} // namespace - DWORD WIN_FUNC GetFileAttributesA(LPCSTR lpFileName) { HOST_CONTEXT_GUARD(); if (!lpFileName) { @@ -743,7 +725,27 @@ BOOL WIN_FUNC WriteFile(HANDLE hFile, LPCVOID lpBuffer, DWORD nNumberOfBytesToWr lpOverlapped->Internal = STATUS_PENDING; lpOverlapped->InternalHigh = 0; updateFilePointer = !file->overlapped; - resetOverlappedEvent(lpOverlapped); + detail::resetOverlappedEvent(lpOverlapped); + if (file->overlapped) { + if (nNumberOfBytesToWrite == 0) { + lpOverlapped->Internal = STATUS_SUCCESS; + lpOverlapped->InternalHigh = 0; + detail::signalOverlappedEvent(lpOverlapped); + if (lpNumberOfBytesWritten) { + *lpNumberOfBytesWritten = 0; + } + return TRUE; + } + auto asyncFile = file.clone(); + if (async_io::queueWrite(std::move(asyncFile), lpOverlapped, lpBuffer, nNumberOfBytesToWrite, offset, + file->isPipe)) { + if (lpNumberOfBytesWritten) { + *lpNumberOfBytesWritten = 0; + } + wibo::lastError = ERROR_IO_PENDING; + return FALSE; + } + } } auto io = files::write(file.get(), lpBuffer, nNumberOfBytesToWrite, offset, updateFilePointer); @@ -762,7 +764,7 @@ BOOL WIN_FUNC WriteFile(HANDLE hFile, LPCVOID lpBuffer, DWORD nNumberOfBytesToWr if (lpOverlapped != nullptr) { lpOverlapped->Internal = completionStatus; lpOverlapped->InternalHigh = io.bytesTransferred; - signalOverlappedEvent(lpOverlapped); + detail::signalOverlappedEvent(lpOverlapped); } return io.unixError == 0; @@ -820,7 +822,25 @@ BOOL WIN_FUNC ReadFile(HANDLE hFile, LPVOID lpBuffer, DWORD nNumberOfBytesToRead lpOverlapped->Internal = STATUS_PENDING; lpOverlapped->InternalHigh = 0; updateFilePointer = !file->overlapped; - resetOverlappedEvent(lpOverlapped); + detail::resetOverlappedEvent(lpOverlapped); + if (file->overlapped) { + if (nNumberOfBytesToRead == 0) { + lpOverlapped->Internal = STATUS_SUCCESS; + lpOverlapped->InternalHigh = 0; + detail::signalOverlappedEvent(lpOverlapped); + if (lpNumberOfBytesRead) { + *lpNumberOfBytesRead = 0; + } + return TRUE; + } + if (async_io::queueRead(file.clone(), lpOverlapped, lpBuffer, nNumberOfBytesToRead, offset, file->isPipe)) { + if (lpNumberOfBytesRead) { + *lpNumberOfBytesRead = 0; + } + wibo::lastError = ERROR_IO_PENDING; + return FALSE; + } + } } auto io = files::read(file.get(), lpBuffer, nNumberOfBytesToRead, offset, updateFilePointer); @@ -829,17 +849,18 @@ BOOL WIN_FUNC ReadFile(HANDLE hFile, LPVOID lpBuffer, DWORD nNumberOfBytesToRead completionStatus = wibo::statusFromErrno(io.unixError); wibo::lastError = wibo::winErrorFromErrno(io.unixError); } else if (io.reachedEnd && io.bytesTransferred == 0) { - completionStatus = STATUS_END_OF_FILE; if (file->isPipe) { + completionStatus = STATUS_PIPE_BROKEN; wibo::lastError = ERROR_BROKEN_PIPE; if (lpOverlapped != nullptr) { lpOverlapped->Internal = completionStatus; lpOverlapped->InternalHigh = 0; - signalOverlappedEvent(lpOverlapped); + detail::signalOverlappedEvent(lpOverlapped); } DEBUG_LOG("-> ERROR_BROKEN_PIPE\n"); return FALSE; } + completionStatus = STATUS_END_OF_FILE; } if (lpNumberOfBytesRead && (!file->overlapped || lpOverlapped == nullptr)) { @@ -849,7 +870,7 @@ BOOL WIN_FUNC ReadFile(HANDLE hFile, LPVOID lpBuffer, DWORD nNumberOfBytesToRead if (lpOverlapped != nullptr) { lpOverlapped->Internal = completionStatus; lpOverlapped->InternalHigh = io.bytesTransferred; - signalOverlappedEvent(lpOverlapped); + detail::signalOverlappedEvent(lpOverlapped); } DEBUG_LOG("-> %u bytes read, error %d\n", io.bytesTransferred, io.unixError == 0 ? 0 : wibo::lastError); diff --git a/dll/kernel32/internal.h b/dll/kernel32/internal.h index f2cbda4..919ba39 100644 --- a/dll/kernel32/internal.h +++ b/dll/kernel32/internal.h @@ -98,7 +98,7 @@ struct MutexObject final : WaitableObject { unsigned int recursionCount = 0; bool abandoned = false; // Owner exited without releasing - MutexObject() : WaitableObject(kType) { signaled.store(true, std::memory_order_relaxed); } + MutexObject() : WaitableObject(kType) { signaled = true; } void noteOwnerExit(/*pthread_t tid or emu tid*/) { std::lock_guard lk(m); @@ -106,8 +106,9 @@ struct MutexObject final : WaitableObject { ownerValid = false; recursionCount = 0; abandoned = true; - signaled.store(true, std::memory_order_release); + signaled = true; cv.notify_one(); + notifyWaiters(true); } } }; @@ -123,7 +124,7 @@ struct EventObject final : WaitableObject { bool resetAll = false; { std::lock_guard lk(m); - signaled.store(true, std::memory_order_release); + signaled = true; resetAll = manualReset; } if (resetAll) { @@ -131,11 +132,12 @@ struct EventObject final : WaitableObject { } else { cv.notify_one(); } + notifyWaiters(false); } void reset() { std::lock_guard lk(m); - signaled.store(false, std::memory_order_release); + signaled = false; } }; diff --git a/dll/kernel32/ioapiset.cpp b/dll/kernel32/ioapiset.cpp index 24efdd2..cf3d085 100644 --- a/dll/kernel32/ioapiset.cpp +++ b/dll/kernel32/ioapiset.cpp @@ -2,6 +2,7 @@ #include "context.h" #include "errors.h" +#include "overlapped_util.h" #include "synchapi.h" namespace kernel32 { @@ -15,8 +16,11 @@ BOOL WIN_FUNC GetOverlappedResult(HANDLE hFile, LPOVERLAPPED lpOverlapped, LPDWO wibo::lastError = ERROR_INVALID_PARAMETER; return FALSE; } - if (bWait && lpOverlapped->Internal == STATUS_PENDING && lpOverlapped->hEvent) { - WaitForSingleObject(lpOverlapped->hEvent, INFINITE); + if (bWait && lpOverlapped->Internal == STATUS_PENDING && + kernel32::detail::shouldSignalOverlappedEvent(lpOverlapped)) { + if (HANDLE waitHandle = kernel32::detail::normalizedOverlappedEventHandle(lpOverlapped)) { + WaitForSingleObject(waitHandle, INFINITE); + } } const auto status = static_cast(lpOverlapped->Internal); @@ -29,15 +33,11 @@ BOOL WIN_FUNC GetOverlappedResult(HANDLE hFile, LPOVERLAPPED lpOverlapped, LPDWO *lpNumberOfBytesTransferred = static_cast(lpOverlapped->InternalHigh); } - if (status == STATUS_SUCCESS) { + DWORD error = wibo::winErrorFromNtStatus(status); + if (error == ERROR_SUCCESS) { return TRUE; } - if (status == STATUS_END_OF_FILE) { - wibo::lastError = ERROR_HANDLE_EOF; - return FALSE; - } - - wibo::lastError = status; + wibo::lastError = error; return FALSE; } diff --git a/dll/kernel32/overlapped_util.h b/dll/kernel32/overlapped_util.h new file mode 100644 index 0000000..bf1cfb1 --- /dev/null +++ b/dll/kernel32/overlapped_util.h @@ -0,0 +1,53 @@ +#pragma once + +#include "context.h" +#include "handles.h" +#include "internal.h" +#include "minwinbase.h" + +#include + +namespace kernel32::detail { + +inline bool shouldSignalOverlappedEvent(const OVERLAPPED *ov) { + if (!ov) { + return false; + } + auto raw = reinterpret_cast(ov->hEvent); + return (raw & 1U) == 0 && raw != 0; +} + +inline HANDLE normalizedOverlappedEventHandle(const OVERLAPPED *ov) { + if (!ov) { + return nullptr; + } + auto raw = reinterpret_cast(ov->hEvent); + raw &= ~static_cast(1); + return reinterpret_cast(raw); +} + +inline void signalOverlappedEvent(OVERLAPPED *ov) { + if (!shouldSignalOverlappedEvent(ov)) { + return; + } + HANDLE handle = normalizedOverlappedEventHandle(ov); + if (handle) { + if (auto ev = wibo::handles().getAs(handle)) { + ev->set(); + } + } +} + +inline void resetOverlappedEvent(OVERLAPPED *ov) { + if (!ov) { + return; + } + HANDLE handle = normalizedOverlappedEventHandle(ov); + if (handle) { + if (auto ev = wibo::handles().getAs(handle)) { + ev->reset(); + } + } +} + +} // namespace kernel32::detail diff --git a/dll/kernel32/processthreadsapi.cpp b/dll/kernel32/processthreadsapi.cpp index b8bdd5f..87db7ac 100644 --- a/dll/kernel32/processthreadsapi.cpp +++ b/dll/kernel32/processthreadsapi.cpp @@ -106,7 +106,7 @@ void threadCleanup(void *param) { } { std::lock_guard lk(obj->m); - obj->signaled.store(true, std::memory_order_release); + obj->signaled = true; // Exit code set before pthread_exit } g_currentThreadObject = nullptr; @@ -114,6 +114,7 @@ void threadCleanup(void *param) { wibo::setThreadTibForHost(nullptr); // TODO: mark mutexes owned by this thread as abandoned obj->cv.notify_all(); + obj->notifyWaiters(false); detail::deref(obj); } @@ -328,10 +329,10 @@ BOOL WIN_FUNC TerminateProcess(HANDLE hProcess, UINT uExitCode) { wibo::lastError = ERROR_INVALID_HANDLE; return FALSE; } - if (process->signaled.load(std::memory_order_acquire)) { + std::lock_guard lk(process->m); + if (process->signaled) { return TRUE; } - std::lock_guard lk(process->m); if (syscall(SYS_pidfd_send_signal, process->pidfd, SIGKILL, nullptr, 0) != 0) { int err = errno; DEBUG_LOG("TerminateProcess: pidfd_send_signal(%d) failed: %s\n", process->pidfd, strerror(err)); @@ -368,8 +369,8 @@ BOOL WIN_FUNC GetExitCodeProcess(HANDLE hProcess, LPDWORD lpExitCode) { return FALSE; } DWORD exitCode = STILL_ACTIVE; - if (process->signaled.load(std::memory_order_acquire)) { - std::lock_guard lk(process->m); + std::lock_guard lk(process->m); + if (process->signaled) { exitCode = process->exitCode; } *lpExitCode = exitCode; diff --git a/dll/kernel32/synchapi.cpp b/dll/kernel32/synchapi.cpp index d831f76..9dd0bd4 100644 --- a/dll/kernel32/synchapi.cpp +++ b/dll/kernel32/synchapi.cpp @@ -7,9 +7,12 @@ #include "internal.h" #include "strutil.h" +#include #include #include +#include #include +#include #include #include #include @@ -34,6 +37,93 @@ void makeWideNameFromAnsi(LPCSTR ansiName, std::vector &outWide) { outWide = stringToWideString(ansiName); } +struct WaitBlock { + explicit WaitBlock(bool waitAllIn, DWORD count) : waitAll(waitAllIn != FALSE), satisfied(count, false) {} + + static void notify(void *context, WaitableObject *obj, DWORD index, bool abandoned) { + auto *self = static_cast(context); + if (self) { + self->handleSignal(obj, index, abandoned, true); + } + } + + void noteInitial(WaitableObject *obj, DWORD index, bool abandoned) { handleSignal(obj, index, abandoned, false); } + + bool isCompleted(DWORD &outResult) { + std::lock_guard lk(mutex); + if (!completed) { + return false; + } + outResult = result; + return true; + } + + bool waitUntil(const std::optional &deadline, DWORD &outResult) { + std::unique_lock lk(mutex); + if (!completed) { + if (deadline) { + if (!cv.wait_until(lk, *deadline, [&] { return completed; })) { + return false; + } + } else { + cv.wait(lk, [&] { return completed; }); + } + } + outResult = result; + return true; + } + + void handleSignal(WaitableObject *obj, DWORD index, bool abandoned, bool fromWaiter) { + if (!obj) { + return; + } + bool notify = false; + { + std::lock_guard lk(mutex); + if (index >= satisfied.size()) { + return; + } + if (satisfied[index]) { + // Already satisfied; nothing to do aside from cleanup below. + } else if (!completed) { + satisfied[index] = true; + if (waitAll) { + if (abandoned) { + result = WAIT_ABANDONED + index; + completed = true; + notify = true; + } else if (std::all_of(satisfied.begin(), satisfied.end(), [](bool v) { return v; })) { + result = WAIT_OBJECT_0; + completed = true; + notify = true; + } + } else { + result = abandoned ? (WAIT_ABANDONED + index) : (WAIT_OBJECT_0 + index); + completed = true; + notify = true; + } + } + } + // Always unregister once we've observed a signal for this waiter. + if (fromWaiter) { + obj->unregisterWaiter(this); + } else if (!waitAll || satisfied[index]) { + // Initial state satisfaction can drop registration immediately. + obj->unregisterWaiter(this); + } + if (notify) { + cv.notify_all(); + } + } + + const bool waitAll; + std::vector satisfied; + bool completed = false; + DWORD result = WAIT_TIMEOUT; + std::mutex mutex; + std::condition_variable cv; +}; + } // namespace namespace kernel32 { @@ -61,7 +151,7 @@ HANDLE WIN_FUNC CreateMutexW(LPSECURITY_ATTRIBUTES lpMutexAttributes, BOOL bInit mu->owner = pthread_self(); mu->ownerValid = true; mu->recursionCount = 1; - mu->signaled.store(false, std::memory_order_release); + mu->signaled = false; } return mu; }); @@ -102,12 +192,13 @@ BOOL WIN_FUNC ReleaseMutex(HANDLE hMutex) { } if (--mu->recursionCount == 0) { mu->ownerValid = false; - mu->signaled.store(true, std::memory_order_release); + mu->signaled = true; notify = true; } } if (notify) { mu->cv.notify_one(); + mu->notifyWaiters(false); } return TRUE; } @@ -125,7 +216,7 @@ HANDLE WIN_FUNC CreateEventW(LPSECURITY_ATTRIBUTES lpEventAttributes, BOOL bManu } auto [ev, created] = wibo::g_namespace.getOrCreate(name, [&]() { auto e = new EventObject(bManualReset); - e->signaled.store(bInitialState, std::memory_order_relaxed); + e->signaled = bInitialState; return e; }); if (!ev) { @@ -200,6 +291,7 @@ BOOL WIN_FUNC ReleaseSemaphore(HANDLE hSemaphore, LONG lReleaseCount, PLONG lpPr } LONG prev = 0; + bool shouldNotifyWaitBlocks = false; { std::lock_guard lk(sem->m); if (lpPreviousCount) { @@ -210,11 +302,15 @@ BOOL WIN_FUNC ReleaseSemaphore(HANDLE hSemaphore, LONG lReleaseCount, PLONG lpPr return FALSE; } sem->count += lReleaseCount; - sem->signaled.store(sem->count > 0, std::memory_order_release); + sem->signaled = sem->count > 0; + shouldNotifyWaitBlocks = sem->count > 0; } for (LONG i = 0; i < lReleaseCount; ++i) { sem->cv.notify_one(); } + if (shouldNotifyWaitBlocks) { + sem->notifyWaiters(false); + } if (lpPreviousCount) { *lpPreviousCount = prev; @@ -279,12 +375,12 @@ DWORD WIN_FUNC WaitForSingleObject(HANDLE hHandle, DWORD dwMilliseconds) { case ObjectType::Event: { auto ev = std::move(obj).downcast(); std::unique_lock lk(ev->m); - bool ok = doWait(lk, ev->cv, [&] { return ev->signaled.load(std::memory_order_acquire); }); + bool ok = doWait(lk, ev->cv, [&] { return ev->signaled; }); if (!ok) { return WAIT_TIMEOUT; } if (!ev->manualReset) { - ev->signaled.store(false, std::memory_order_release); + ev->signaled = false; } return WAIT_OBJECT_0; } @@ -297,7 +393,7 @@ DWORD WIN_FUNC WaitForSingleObject(HANDLE hHandle, DWORD dwMilliseconds) { } --sem->count; if (sem->count == 0) { - sem->signaled.store(false, std::memory_order_release); + sem->signaled = false; } return WAIT_OBJECT_0; } @@ -322,7 +418,7 @@ DWORD WIN_FUNC WaitForSingleObject(HANDLE hHandle, DWORD dwMilliseconds) { mu->owner = self; mu->ownerValid = true; mu->recursionCount = 1; - mu->signaled.store(false, std::memory_order_release); + mu->signaled = false; return ret; } case ObjectType::Thread: { @@ -333,7 +429,7 @@ DWORD WIN_FUNC WaitForSingleObject(HANDLE hHandle, DWORD dwMilliseconds) { // Windows actually allows you to wait on your own thread, but why bother? return WAIT_TIMEOUT; } - bool ok = doWait(lk, th->cv, [&] { return th->signaled.load(std::memory_order_acquire); }); + bool ok = doWait(lk, th->cv, [&] { return th->signaled; }); return ok ? WAIT_OBJECT_0 : WAIT_TIMEOUT; } case ObjectType::Process: { @@ -343,7 +439,7 @@ DWORD WIN_FUNC WaitForSingleObject(HANDLE hHandle, DWORD dwMilliseconds) { // Windows actually allows you to wait on your own process, but why bother? return WAIT_TIMEOUT; } - bool ok = doWait(lk, po->cv, [&] { return po->signaled.load(std::memory_order_acquire); }); + bool ok = doWait(lk, po->cv, [&] { return po->signaled; }); return ok ? WAIT_OBJECT_0 : WAIT_TIMEOUT; } default: @@ -352,6 +448,101 @@ DWORD WIN_FUNC WaitForSingleObject(HANDLE hHandle, DWORD dwMilliseconds) { } } +DWORD WIN_FUNC WaitForMultipleObjects(DWORD nCount, const HANDLE *lpHandles, BOOL bWaitAll, DWORD dwMilliseconds) { + HOST_CONTEXT_GUARD(); + DEBUG_LOG("WaitForMultipleObjects(%u, %p, %d, %u)\n", nCount, lpHandles, static_cast(bWaitAll), + dwMilliseconds); + + if (nCount == 0 || nCount > MAXIMUM_WAIT_OBJECTS || !lpHandles) { + wibo::lastError = ERROR_INVALID_PARAMETER; + return WAIT_FAILED; + } + + std::vector> objects(nCount); + for (DWORD i = 0; i < nCount; ++i) { + HandleMeta meta{}; + auto obj = wibo::handles().getAs(lpHandles[i], &meta); + if (!obj) { + wibo::lastError = ERROR_INVALID_HANDLE; + return WAIT_FAILED; + } + objects[i] = std::move(obj); + } + + WaitBlock block(bWaitAll, nCount); + for (DWORD i = 0; i < objects.size(); ++i) { + objects[i]->registerWaiter(&block, i, &WaitBlock::notify); + } + + for (DWORD i = 0; i < objects.size(); ++i) { + auto *obj = objects[i].get(); + bool isSignaled = obj->signaled; + bool isAbandoned = false; + if (auto *mu = detail::castTo(obj)) { + isAbandoned = mu->abandoned; + } + if (isSignaled) { + block.noteInitial(obj, i, isAbandoned); + } + } + + DWORD waitResult = WAIT_TIMEOUT; + if (!block.isCompleted(waitResult)) { + if (dwMilliseconds == 0) { + waitResult = WAIT_TIMEOUT; + } else { + std::optional deadline; + if (dwMilliseconds != INFINITE) { + deadline = + std::chrono::steady_clock::now() + std::chrono::milliseconds(static_cast(dwMilliseconds)); + } + DWORD signaledResult = WAIT_TIMEOUT; + bool completed = block.waitUntil(deadline, signaledResult); + if (completed) { + waitResult = signaledResult; + } else { + waitResult = WAIT_TIMEOUT; + } + } + } + + for (const auto &object : objects) { + object->unregisterWaiter(&block); + } + + if (waitResult == WAIT_TIMEOUT) { + return WAIT_TIMEOUT; + } + + if (waitResult == WAIT_FAILED) { + return WAIT_FAILED; + } + + auto consume = [&](DWORD index) { + if (index < nCount) { + WaitForSingleObject(lpHandles[index], 0); + } + }; + + if (bWaitAll) { + if (waitResult == WAIT_OBJECT_0) { + for (DWORD i = 0; i < nCount; ++i) { + consume(i); + } + } else if (waitResult >= WAIT_ABANDONED && waitResult < WAIT_ABANDONED + nCount) { + consume(waitResult - WAIT_ABANDONED); + } + } else { + if (waitResult >= WAIT_OBJECT_0 && waitResult < WAIT_OBJECT_0 + nCount) { + consume(waitResult - WAIT_OBJECT_0); + } else if (waitResult >= WAIT_ABANDONED && waitResult < WAIT_ABANDONED + nCount) { + consume(waitResult - WAIT_ABANDONED); + } + } + + return waitResult; +} + void WIN_FUNC InitializeCriticalSection(LPCRITICAL_SECTION lpCriticalSection) { HOST_CONTEXT_GUARD(); VERBOSE_LOG("STUB: InitializeCriticalSection(%p)\n", lpCriticalSection); diff --git a/dll/kernel32/synchapi.h b/dll/kernel32/synchapi.h index 8b1dd32..1514e90 100644 --- a/dll/kernel32/synchapi.h +++ b/dll/kernel32/synchapi.h @@ -8,6 +8,7 @@ constexpr DWORD WAIT_ABANDONED = 0x00000080; constexpr DWORD WAIT_TIMEOUT = 0x00000102; constexpr DWORD WAIT_FAILED = 0xFFFFFFFF; constexpr DWORD INFINITE = 0xFFFFFFFF; +constexpr DWORD MAXIMUM_WAIT_OBJECTS = 64; constexpr DWORD INIT_ONCE_CHECK_ONLY = 0x00000001UL; constexpr DWORD INIT_ONCE_ASYNC = 0x00000002UL; @@ -89,6 +90,7 @@ HANDLE WIN_FUNC CreateSemaphoreW(LPSECURITY_ATTRIBUTES lpSemaphoreAttributes, LO LPCWSTR lpName); BOOL WIN_FUNC ReleaseSemaphore(HANDLE hSemaphore, LONG lReleaseCount, PLONG lpPreviousCount); DWORD WIN_FUNC WaitForSingleObject(HANDLE hHandle, DWORD dwMilliseconds); +DWORD WIN_FUNC WaitForMultipleObjects(DWORD nCount, const HANDLE *lpHandles, BOOL bWaitAll, DWORD dwMilliseconds); void WIN_FUNC InitializeCriticalSection(LPCRITICAL_SECTION lpCriticalSection); BOOL WIN_FUNC InitializeCriticalSectionEx(LPCRITICAL_SECTION lpCriticalSection, DWORD dwSpinCount, DWORD Flags); BOOL WIN_FUNC InitializeCriticalSectionAndSpinCount(LPCRITICAL_SECTION lpCriticalSection, DWORD dwSpinCount); diff --git a/dll/msvcrt.cpp b/dll/msvcrt.cpp index 69ad602..0080676 100644 --- a/dll/msvcrt.cpp +++ b/dll/msvcrt.cpp @@ -2906,7 +2906,7 @@ namespace msvcrt { if (mode == P_WAIT) { std::unique_lock lk(po->m); - po->cv.wait(lk, [&] { return po->signaled.load(); }); + po->cv.wait(lk, [&] { return po->signaled; }); return static_cast(po->exitCode); } @@ -2955,7 +2955,7 @@ namespace msvcrt { if (mode == P_WAIT) { std::unique_lock lk(po->m); - po->cv.wait(lk, [&] { return po->signaled.load(); }); + po->cv.wait(lk, [&] { return po->signaled; }); return static_cast(po->exitCode); } diff --git a/src/async_io.cpp b/src/async_io.cpp new file mode 100644 index 0000000..bad2a6c --- /dev/null +++ b/src/async_io.cpp @@ -0,0 +1,277 @@ +#include "async_io.h" + +#include "common.h" +#include "errors.h" +#include "kernel32/overlapped_util.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace async_io { +namespace { + +constexpr unsigned kQueueDepth = 64; + +struct AsyncRequest { + enum class Kind { Read, Write, Shutdown }; + + Kind kind; + Pin file; + OVERLAPPED *overlapped = nullptr; + bool isPipe = false; + struct iovec vec{}; +}; + +class IoUringBackend { + public: + ~IoUringBackend() { shutdown(); } + bool init(); + void shutdown(); + [[nodiscard]] bool running() const noexcept { return mRunning.load(std::memory_order_acquire); } + + bool queueRead(Pin file, OVERLAPPED *ov, void *buffer, DWORD length, + const std::optional &offset, bool isPipe); + bool queueWrite(Pin file, OVERLAPPED *ov, const void *buffer, DWORD length, + const std::optional &offset, bool isPipe); + + private: + bool enqueueRequest(AsyncRequest *req, void *buffer, DWORD length, const std::optional &offset, + bool isWrite); + void requestStop(); + void workerLoop(); + void handleCompletion(struct io_uring_cqe *cqe); + void notifySpace(); + + struct io_uring mRing{}; + std::mutex mSubmitMutex; + std::condition_variable mQueueCv; + std::atomic mRunning{false}; + std::atomic mPending{0}; + std::thread mThread; +}; + +IoUringBackend gBackend; + +} // namespace + +bool initialize() { + if (gBackend.running()) { + return true; + } + return gBackend.init(); +} + +void shutdown() { gBackend.shutdown(); } + +bool running() { return gBackend.running(); } + +bool queueRead(Pin file, OVERLAPPED *ov, void *buffer, DWORD length, + const std::optional &offset, bool isPipe) { + if (!gBackend.running()) { + return false; + } + return gBackend.queueRead(std::move(file), ov, buffer, length, offset, isPipe); +} + +bool queueWrite(Pin file, OVERLAPPED *ov, const void *buffer, DWORD length, + const std::optional &offset, bool isPipe) { + if (!gBackend.running()) { + return false; + } + return gBackend.queueWrite(std::move(file), ov, buffer, length, offset, isPipe); +} + +bool IoUringBackend::init() { + if (mRunning.load(std::memory_order_acquire)) { + return true; + } + int rc = io_uring_queue_init(kQueueDepth, &mRing, 0); + if (rc < 0) { + DEBUG_LOG("io_uring_queue_init failed: %d\n", rc); + return false; + } + mRunning.store(true, std::memory_order_release); + mThread = std::thread(&IoUringBackend::workerLoop, this); + DEBUG_LOG("io_uring backend initialized (depth=%u)\n", kQueueDepth); + return true; +} + +void IoUringBackend::shutdown() { + if (!mRunning.exchange(false, std::memory_order_acq_rel)) { + return; + } + requestStop(); + if (mThread.joinable()) { + mThread.join(); + } + io_uring_queue_exit(&mRing); +} + +bool IoUringBackend::queueRead(Pin file, OVERLAPPED *ov, void *buffer, DWORD length, + const std::optional &offset, bool isPipe) { + auto *req = new AsyncRequest{AsyncRequest::Kind::Read, std::move(file), ov, isPipe}; + if (!enqueueRequest(req, buffer, length, offset, false)) { + delete req; + return false; + } + return true; +} + +bool IoUringBackend::queueWrite(Pin file, OVERLAPPED *ov, const void *buffer, DWORD length, + const std::optional &offset, bool isPipe) { + auto *req = new AsyncRequest{AsyncRequest::Kind::Write, std::move(file), ov, isPipe}; + if (!enqueueRequest(req, const_cast(buffer), length, offset, true)) { + delete req; + return false; + } + return true; +} + +bool IoUringBackend::enqueueRequest(AsyncRequest *req, void *buffer, DWORD length, const std::optional &offset, + bool isWrite) { + std::unique_lock lock(mSubmitMutex); + if (!mRunning.load(std::memory_order_acquire) && req->kind != AsyncRequest::Kind::Shutdown) { + return false; + } + + struct io_uring_sqe *sqe; + while (true) { + sqe = io_uring_get_sqe(&mRing); + if (!sqe) { + mQueueCv.wait(lock); + if (!mRunning.load(std::memory_order_acquire) && req->kind != AsyncRequest::Kind::Shutdown) { + return false; + } + continue; + } + io_uring_sqe_set_data(sqe, req); + if (req->kind == AsyncRequest::Kind::Shutdown) { + io_uring_prep_nop(sqe); + } else { + req->vec.iov_base = buffer; + req->vec.iov_len = length; + off64_t fileOffset = -1; + if (!req->isPipe && offset.has_value()) { + fileOffset = *offset; + } + int fd = req->file ? req->file->fd : -1; + if (isWrite) { + io_uring_prep_writev(sqe, fd, &req->vec, 1, fileOffset); + } else { + io_uring_prep_readv(sqe, fd, &req->vec, 1, fileOffset); + } + } + mPending.fetch_add(1, std::memory_order_relaxed); + break; + } + + while (true) { + int res = io_uring_submit(&mRing); + if (res >= 0) { + break; + } else if (res == -EINTR) { + continue; + } else if (res == -EBUSY || res == -EAGAIN) { + lock.unlock(); + std::this_thread::yield(); + lock.lock(); + continue; + } + DEBUG_LOG("io_uring_submit failed (will retry): %d\n", res); + } + + lock.unlock(); + mQueueCv.notify_one(); + return true; +} + +void IoUringBackend::requestStop() { + mRunning.store(false, std::memory_order_release); + auto *req = new AsyncRequest{AsyncRequest::Kind::Shutdown, Pin{}, nullptr, false}; + if (!enqueueRequest(req, nullptr, 0, std::nullopt, false)) { + delete req; + } +} + +void IoUringBackend::workerLoop() { + while (mRunning.load(std::memory_order_acquire) || mPending.load(std::memory_order_acquire) > 0) { + struct io_uring_cqe *cqe = nullptr; + int ret = io_uring_wait_cqe(&mRing, &cqe); + if (ret == -EINTR) { + continue; + } + if (ret < 0) { + DEBUG_LOG("io_uring_wait_cqe failed: %d\n", ret); + continue; + } + handleCompletion(cqe); + io_uring_cqe_seen(&mRing, cqe); + notifySpace(); + } + + while (mPending.load(std::memory_order_acquire) > 0) { + struct io_uring_cqe *cqe = nullptr; + int ret = io_uring_peek_cqe(&mRing, &cqe); + if (ret != 0 || !cqe) { + break; + } + handleCompletion(cqe); + io_uring_cqe_seen(&mRing, cqe); + notifySpace(); + } +} + +void IoUringBackend::handleCompletion(struct io_uring_cqe *cqe) { + auto *req = static_cast(io_uring_cqe_get_data(cqe)); + if (!req) { + return; + } + + if (req->kind == AsyncRequest::Kind::Shutdown) { + delete req; + mPending.fetch_sub(1, std::memory_order_acq_rel); + return; + } + + OVERLAPPED *ov = req->overlapped; + if (ov) { + if (cqe->res >= 0) { + ov->InternalHigh = static_cast(cqe->res); + if (req->kind == AsyncRequest::Kind::Read && cqe->res == 0) { + ov->Internal = req->isPipe ? STATUS_PIPE_BROKEN : STATUS_END_OF_FILE; + } else { + ov->Internal = STATUS_SUCCESS; + } + } else { + int err = -cqe->res; + ov->InternalHigh = 0; + if (err == EPIPE) { + ov->Internal = STATUS_PIPE_BROKEN; + } else { + NTSTATUS status = wibo::statusFromErrno(err); + if (status == STATUS_SUCCESS) { + status = STATUS_UNEXPECTED_IO_ERROR; + } + ov->Internal = status; + } + } + kernel32::detail::signalOverlappedEvent(ov); + } + + delete req; + mPending.fetch_sub(1, std::memory_order_acq_rel); +} + +void IoUringBackend::notifySpace() { + std::lock_guard lk(mSubmitMutex); + mQueueCv.notify_all(); +} + +} // namespace async_io diff --git a/src/async_io.h b/src/async_io.h new file mode 100644 index 0000000..c922ec8 --- /dev/null +++ b/src/async_io.h @@ -0,0 +1,19 @@ +#pragma once + +#include "kernel32/internal.h" +#include "kernel32/minwinbase.h" + +#include + +namespace async_io { + +bool initialize(); +void shutdown(); +bool running(); + +bool queueRead(Pin file, OVERLAPPED *ov, void *buffer, DWORD length, + const std::optional &offset, bool isPipe); +bool queueWrite(Pin file, OVERLAPPED *ov, const void *buffer, DWORD length, + const std::optional &offset, bool isPipe); + +} // namespace async_io diff --git a/src/errors.cpp b/src/errors.cpp index a916eeb..955bf98 100644 --- a/src/errors.cpp +++ b/src/errors.cpp @@ -54,4 +54,23 @@ NTSTATUS statusFromErrno(int err) { return statusFromWinError(winErrorFromErrno(err)); } +DWORD winErrorFromNtStatus(NTSTATUS status) { + switch (status) { + case STATUS_SUCCESS: + return ERROR_SUCCESS; + case STATUS_PENDING: + return ERROR_IO_PENDING; + case STATUS_END_OF_FILE: + return ERROR_HANDLE_EOF; + case STATUS_INVALID_HANDLE: + return ERROR_INVALID_HANDLE; + case STATUS_INVALID_PARAMETER: + return ERROR_INVALID_PARAMETER; + case STATUS_PIPE_BROKEN: + return ERROR_BROKEN_PIPE; + default: + return ERROR_NOT_SUPPORTED; + } +} + } // namespace wibo diff --git a/src/errors.h b/src/errors.h index 10d2c24..44e7da2 100644 --- a/src/errors.h +++ b/src/errors.h @@ -65,4 +65,5 @@ namespace wibo { DWORD winErrorFromErrno(int err); NTSTATUS statusFromWinError(DWORD error); NTSTATUS statusFromErrno(int err); +DWORD winErrorFromNtStatus(NTSTATUS status); } // namespace wibo diff --git a/src/handles.cpp b/src/handles.cpp index 3e880e7..8f06115 100644 --- a/src/handles.cpp +++ b/src/handles.cpp @@ -1,4 +1,5 @@ #include "handles.h" +#include #include #include @@ -10,7 +11,7 @@ constexpr uint32_t kCompatMaxIndex = (0xFFFFu >> kHandleAlignShift) - 1; // Delay reuse of small handles to avoid accidental stale aliasing constexpr uint32_t kQuarantineLen = 64; -inline uint32_t indexOf(HANDLE h) { +inline uint32_t indexOf(HANDLE h) noexcept { uint32_t v = static_cast(reinterpret_cast(h)); if (v == 0 || (v & ((1U << kHandleAlignShift) - 1)) != 0) { return UINT32_MAX; @@ -18,12 +19,12 @@ inline uint32_t indexOf(HANDLE h) { return (v >> kHandleAlignShift) - 1; } -inline HANDLE makeHandle(uint32_t index) { +inline HANDLE makeHandle(uint32_t index) noexcept { uint32_t v = (index + 1) << kHandleAlignShift; return reinterpret_cast(static_cast(v)); } -inline bool isPseudo(HANDLE h) { return reinterpret_cast(h) < 0; } +inline bool isPseudo(HANDLE h) noexcept { return reinterpret_cast(h) < 0; } } // namespace @@ -63,7 +64,7 @@ HANDLE Handles::alloc(Pin<> obj, uint32_t grantedAccess, uint32_t flags) { } HANDLE h = makeHandle(idx); - e.obj->handleCount.fetch_add(1, std::memory_order_acq_rel); + e.obj->handleCount.fetch_add(1, std::memory_order_relaxed); return h; } @@ -107,7 +108,7 @@ bool Handles::release(HANDLE h) { const auto generation = e.meta.generation + 1; e = {}; // Clear entry e.meta.generation = generation; - uint32_t handleCount = obj->handleCount.fetch_sub(1, std::memory_order_acq_rel) - 1; + uint32_t handleCount = obj->handleCount.fetch_sub(1, std::memory_order_relaxed) - 1; if (idx <= kCompatMaxIndex) { mQuarantine.push_back(idx); @@ -234,6 +235,34 @@ Pin<> Namespace::get(const std::u16string &name) { return Pin<>::acquire(it->second.obj); } +void WaitableObject::registerWaiter(void *context, DWORD index, WaiterCallback cb) { + if (!cb) { + return; + } + std::lock_guard lk(waitersMutex); + waiters.emplace_back(cb, context, index); +} + +void WaitableObject::unregisterWaiter(void *context) { + std::lock_guard lk(waitersMutex); + waiters.erase( + std::remove_if(waiters.begin(), waiters.end(), [context](const Waiter &w) { return w.context == context; }), + waiters.end()); +} + +void WaitableObject::notifyWaiters(bool abandoned) { + std::vector snapshot; + { + std::lock_guard lk(waitersMutex); + snapshot = waiters; + } + for (const auto &w : snapshot) { + if (w.callback) { + w.callback(w.context, this, w.index, abandoned); + } + } +} + namespace wibo { Namespace g_namespace; diff --git a/src/handles.h b/src/handles.h index 261939c..b8cde8d 100644 --- a/src/handles.h +++ b/src/handles.h @@ -2,11 +2,11 @@ #include "common.h" +#include #include #include #include #include -#include #include #include #include @@ -27,111 +27,145 @@ enum class ObjectType : uint16_t { RegistryKey, }; +enum ObjectFlags : uint16_t { + Of_None = 0x0, + Of_Waitable = 0x1, +}; + struct ObjectBase { const ObjectType type; + uint16_t flags = Of_None; std::atomic pointerCount{0}; std::atomic handleCount{0}; - explicit ObjectBase(ObjectType t) : type(t) {} - virtual ~ObjectBase() = default; + explicit ObjectBase(ObjectType t) noexcept : type(t) {} + virtual ~ObjectBase() noexcept = default; +}; - [[nodiscard]] virtual bool isWaitable() const { return false; } +template +concept ObjectBaseType = std::is_base_of_v; + +struct WaitableObject : ObjectBase { + bool signaled = false; // protected by m + std::mutex m; + std::condition_variable cv; + + using WaiterCallback = void (*)(void *, WaitableObject *, DWORD, bool); + struct Waiter { + WaiterCallback callback = nullptr; + void *context = nullptr; + DWORD index = 0; + }; + std::mutex waitersMutex; + std::vector waiters; + + explicit WaitableObject(ObjectType t) : ObjectBase(t) { flags |= Of_Waitable; } + + void registerWaiter(void *context, DWORD index, WaiterCallback cb); + void unregisterWaiter(void *context); + void notifyWaiters(bool abandoned); }; namespace detail { -inline void ref(ObjectBase *o) { o->pointerCount.fetch_add(1, std::memory_order_acq_rel); } -inline void deref(ObjectBase *o) { - if (o->pointerCount.fetch_sub(1, std::memory_order_acq_rel) == 1) { +inline void ref(ObjectBase *o) noexcept { o->pointerCount.fetch_add(1, std::memory_order_relaxed); } +inline void deref(ObjectBase *o) noexcept { + if (o->pointerCount.fetch_sub(1, std::memory_order_release) == 1) { + std::atomic_thread_fence(std::memory_order_acquire); delete o; } } +template constexpr bool typeMatches(const ObjectBase *o) noexcept { + if constexpr (requires { T::kType; }) { + return o && o->type == T::kType; + } else { + static_assert(false, "No kType on U and no typeMatches specialization provided"); + } +} +template <> constexpr bool typeMatches(const ObjectBase *o) noexcept { + return o && (o->flags & Of_Waitable); +} + +template T *castTo(ObjectBase *o) noexcept { + return typeMatches(o) ? static_cast(o) : nullptr; +} + } // namespace detail -struct WaitableObject : ObjectBase { - std::atomic signaled{false}; - std::mutex m; - std::condition_variable_any cv; - - using ObjectBase::ObjectBase; - [[nodiscard]] bool isWaitable() const override { return true; } -}; - -template struct Pin { - static_assert(std::is_base_of_v || std::is_same_v, - "Pin: T must be ObjectBase or derive from it"); - - T *obj = nullptr; +template class Pin { + public: + enum class Tag { Acquire, Adopt }; Pin() = default; - enum class Tag { Acquire, Adopt }; - template ::value>> - explicit Pin(U *p, Tag t) : obj(static_cast(p)) { + template + requires std::is_convertible_v + explicit constexpr Pin(U *p, Tag t) noexcept : obj(static_cast(p)) { if (obj && t == Tag::Acquire) { detail::ref(obj); } } Pin(const Pin &) = delete; - Pin(Pin &&other) noexcept : obj(std::exchange(other.obj, nullptr)) {} - template ::value>> Pin &operator=(Pin &&other) noexcept { + Pin(Pin &&other) noexcept : obj(other.release()) {} + template + requires std::is_base_of_v + Pin &operator=(Pin &&other) noexcept { reset(); - obj = std::exchange(other.obj, nullptr); + obj = other.release(); return *this; } - template ::value>> - Pin(Pin &&other) noexcept : obj(std::exchange(other.obj, nullptr)) {} // NOLINT(google-explicit-constructor) + template + requires std::is_convertible_v + Pin(Pin &&other) noexcept : obj(other.release()) {} // NOLINT(google-explicit-constructor) Pin &operator=(Pin &&other) noexcept { if (this != &other) { reset(); - obj = std::exchange(other.obj, nullptr); + obj = other.release(); } return *this; } - ~Pin() { reset(); } + ~Pin() noexcept { reset(); } - static Pin acquire(T *o) { return Pin{o, Tag::Acquire}; } - static Pin adopt(T *o) { return Pin{o, Tag::Adopt}; } + static Pin acquire(T *o) noexcept { return Pin{o, Tag::Acquire}; } + static constexpr Pin adopt(T *o) noexcept { return Pin{o, Tag::Adopt}; } - [[nodiscard]] T *release() { return std::exchange(obj, nullptr); } - void reset() { - if (obj) { + [[nodiscard]] constexpr T *release() noexcept { return std::exchange(obj, nullptr); } + void reset() noexcept { + if (auto *obj = release()) { detail::deref(obj); - obj = nullptr; } } - T *operator->() const { + constexpr T *operator->() const noexcept { assert(obj); return obj; } - T &operator*() const { + constexpr T &operator*() const noexcept { assert(obj); return *obj; } - [[nodiscard]] T *get() const { return obj; } - [[nodiscard]] Pin clone() const { return Pin::acquire(obj); } - explicit operator bool() const { return obj != nullptr; } + [[nodiscard]] constexpr T *get() const noexcept { return obj; } + [[nodiscard]] Pin clone() const noexcept { return Pin::acquire(obj); } + explicit constexpr operator bool() const noexcept { return obj != nullptr; } - template Pin downcast() && { - static_assert(std::is_base_of_v, "U must derive from ObjectBase"); - if constexpr (std::is_same_v) { + template Pin downcast() && noexcept { + if constexpr (std::is_convertible_v) { return std::move(*this); - } - if (obj && obj->type == U::kType) { - auto *u = static_cast(obj); - obj = nullptr; - return Pin::adopt(u); + } else if (detail::typeMatches(obj)) { + return Pin::adopt(static_cast(std::exchange(obj, nullptr))); } return Pin{}; } + + private: + T *obj = nullptr; }; -template -Pin make_pin(Args &&...args) noexcept(std::is_nothrow_constructible_v) { - T *p = new T(std::forward(args)...); - return Pin::acquire(p); +template + requires std::is_constructible_v +inline Pin make_pin(Args &&...args) noexcept(std::is_nothrow_constructible_v) { + return Pin::acquire(new T(std::forward(args)...)); } constexpr DWORD HANDLE_FLAG_INHERIT = 0x1; @@ -159,8 +193,7 @@ class Handles { HANDLE alloc(Pin<> obj, uint32_t grantedAccess, uint32_t flags); bool release(HANDLE h); Pin<> get(HANDLE h, HandleMeta *metaOut = nullptr); - template Pin getAs(HANDLE h, HandleMeta *metaOut = nullptr) { - static_assert(std::is_base_of_v, "T must derive from ObjectBase"); + template Pin getAs(HANDLE h, HandleMeta *metaOut = nullptr) { HandleMeta metaOutLocal{}; if (!metaOut) { metaOut = &metaOutLocal; @@ -169,13 +202,12 @@ class Handles { if (!obj) { return {}; } - if constexpr (std::is_same_v) { - return std::move(obj); - } else if (metaOut->typeCache != T::kType || obj->type != T::kType) { - return {}; - } else { - return Pin::adopt(static_cast(obj.release())); + if constexpr (requires { T::kType; }) { + if (metaOut->typeCache != T::kType) { + return {}; + } } + return std::move(obj).downcast(); } bool setInformation(HANDLE h, uint32_t mask, uint32_t value); bool getInformation(HANDLE h, uint32_t *outFlags) const; @@ -196,34 +228,37 @@ class Handles { uint32_t nextIndex = 0; }; +template using factory_ptr_t = std::remove_cvref_t>; +template using factory_obj_t = std::remove_pointer_t>; +template +concept ObjectFactoryFn = + std::invocable && std::is_pointer_v> && ObjectBaseType>; + class Namespace { public: bool insert(const std::u16string &name, ObjectBase *obj, bool permanent = false); void remove(ObjectBase *obj); Pin<> get(const std::u16string &name); - template Pin getAs(const std::u16string &name) { + template Pin getAs(const std::u16string &name) { if (auto pin = get(name)) { return std::move(pin).downcast(); } return {}; } - template , - typename T = std::remove_pointer_t>, - std::enable_if_t>::value, int> = 0> - std::pair, bool> getOrCreate(const std::u16string &name, F &&make) { + auto getOrCreate(const std::u16string &name, ObjectFactoryFn auto &&make) + -> std::pair>, bool> { + using T = factory_obj_t; if (name.empty()) { // No name: create unconditionally - T *raw = std::invoke(std::forward(make)); - return {Pin::acquire(raw), true}; + return {Pin::acquire(std::invoke(make)), true}; } if (auto existing = get(name)) { // Return even if downcast fails (don't use getAs) return {std::move(existing).downcast(), false}; } - T *raw = std::invoke(std::forward(make)); - Pin newObj = Pin::acquire(raw); + auto newObj = Pin::acquire(std::invoke(make)); if (!newObj) { return {Pin{}, false}; } diff --git a/src/main.cpp b/src/main.cpp index 93e236d..4099a3b 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,4 +1,5 @@ #include "common.h" +#include "async_io.h" #include "context.h" #include "files.h" #include "modules.h" @@ -408,6 +409,7 @@ int main(int argc, char **argv) { blockUpper2GB(); files::init(); wibo::processes().init(); + async_io::initialize(); // Create TIB memset(&tib, 0, sizeof(tib)); diff --git a/src/processes.cpp b/src/processes.cpp index 28403a8..14eaeb5 100644 --- a/src/processes.cpp +++ b/src/processes.cpp @@ -192,13 +192,14 @@ void ProcessManager::checkPidfd(int pidfd) { } { std::lock_guard lk(po->m); - po->signaled.store(true, std::memory_order_release); + po->signaled = true; po->pidfd = -1; if (!po->forcedExitCode) { po->exitCode = decodeExitCode(si); } } po->cv.notify_all(); + po->notifyWaiters(false); { std::lock_guard lk(m); diff --git a/test/test_overlapped_io.c b/test/test_overlapped_io.c index 6426d35..95ada6f 100644 --- a/test/test_overlapped_io.c +++ b/test/test_overlapped_io.c @@ -112,6 +112,80 @@ static void test_getoverlappedresult_pending(void) { TEST_CHECK_EQ(0U, transferred); // No update if the operation is still pending } +static void test_overlapped_multiple_reads(void) { + HANDLE file = CreateFileA(kFilename, GENERIC_READ, FILE_SHARE_READ, NULL, + OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL | FILE_FLAG_OVERLAPPED, NULL); + TEST_CHECK(file != INVALID_HANDLE_VALUE); + + OVERLAPPED ov1 = {0}; + OVERLAPPED ov2 = {0}; + ov1.Offset = 0; + ov2.Offset = 16; + ov1.hEvent = CreateEventA(NULL, TRUE, FALSE, NULL); + ov2.hEvent = CreateEventA(NULL, TRUE, FALSE, NULL); + TEST_CHECK(ov1.hEvent != NULL); + TEST_CHECK(ov2.hEvent != NULL); + + char head[8] = {0}; + char tail[8] = {0}; + + BOOL issued1 = ReadFile(file, head, 5, NULL, &ov1); + if (!issued1) { + TEST_CHECK_EQ(ERROR_IO_PENDING, GetLastError()); + } + + BOOL issued2 = ReadFile(file, tail, 5, NULL, &ov2); + if (!issued2) { + TEST_CHECK_EQ(ERROR_IO_PENDING, GetLastError()); + } + + HANDLE events[2] = {ov1.hEvent, ov2.hEvent}; + DWORD waitResult = WaitForMultipleObjects(2, events, TRUE, 1000); + TEST_CHECK_EQ(WAIT_OBJECT_0, waitResult); + + DWORD transferred = 0; + TEST_CHECK(GetOverlappedResult(file, &ov1, &transferred, FALSE)); + TEST_CHECK_EQ(5U, transferred); + head[5] = '\0'; + TEST_CHECK_STR_EQ("01234", head); + + transferred = 0; + TEST_CHECK(GetOverlappedResult(file, &ov2, &transferred, FALSE)); + TEST_CHECK_EQ(5U, transferred); + tail[5] = '\0'; + TEST_CHECK_STR_EQ("GHIJK", tail); + + TEST_CHECK(CloseHandle(ov2.hEvent)); + TEST_CHECK(CloseHandle(ov1.hEvent)); + TEST_CHECK(CloseHandle(file)); +} + +static void test_getoverlappedresult_wait(void) { + HANDLE file = CreateFileA(kFilename, GENERIC_READ, FILE_SHARE_READ, NULL, + OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL | FILE_FLAG_OVERLAPPED, NULL); + TEST_CHECK(file != INVALID_HANDLE_VALUE); + + OVERLAPPED ov = {0}; + ov.Offset = 20; + ov.hEvent = CreateEventA(NULL, FALSE, FALSE, NULL); + TEST_CHECK(ov.hEvent != NULL); + + char buffer[8] = {0}; + BOOL issued = ReadFile(file, buffer, 6, NULL, &ov); + if (!issued) { + TEST_CHECK_EQ(ERROR_IO_PENDING, GetLastError()); + } + + DWORD transferred = 0; + TEST_CHECK(GetOverlappedResult(file, &ov, &transferred, TRUE)); + TEST_CHECK_EQ(6U, transferred); + buffer[6] = '\0'; + TEST_CHECK_STR_EQ("KLMNOP", buffer); + + TEST_CHECK(CloseHandle(ov.hEvent)); + TEST_CHECK(CloseHandle(file)); +} + static void test_overlapped_write(void) { HANDLE file = CreateFileA(kFilename, GENERIC_WRITE, 0, NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL | FILE_FLAG_OVERLAPPED, NULL); @@ -157,6 +231,8 @@ int main(void) { test_overlapped_read_with_event(); test_overlapped_eof(); test_getoverlappedresult_pending(); + test_overlapped_multiple_reads(); + test_getoverlappedresult_wait(); test_overlapped_write(); TEST_CHECK(DeleteFileA(kFilename)); return 0;