diff --git a/CMakeLists.txt b/CMakeLists.txt index e46fe24..3512713 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -223,9 +223,11 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Linux") target_sources(wibo PRIVATE src/async_io_epoll.cpp src/processes_linux.cpp + src/setup_linux.cpp ) elseif (CMAKE_SYSTEM_NAME STREQUAL "Darwin") target_sources(wibo PRIVATE + src/async_io_kqueue.cpp src/processes_darwin.cpp src/setup_darwin.cpp ) diff --git a/src/async_io.cpp b/src/async_io.cpp index 6b99333..a238590 100644 --- a/src/async_io.cpp +++ b/src/async_io.cpp @@ -52,6 +52,9 @@ static constexpr BackendEntry kBackends[] = { #ifdef __linux__ {"epoll", detail::createEpollBackend}, #endif +#ifdef __APPLE__ + {"kqueue", detail::createKqueueBackend}, +#endif }; AsyncIOBackend &asyncIO() { diff --git a/src/async_io.h b/src/async_io.h index d407186..7268041 100644 --- a/src/async_io.h +++ b/src/async_io.h @@ -27,6 +27,9 @@ std::unique_ptr createIoUringBackend(); #ifdef __linux__ std::unique_ptr createEpollBackend(); #endif +#ifdef __APPLE__ +std::unique_ptr createKqueueBackend(); +#endif } // namespace detail diff --git a/src/async_io_kqueue.cpp b/src/async_io_kqueue.cpp new file mode 100644 index 0000000..95bc90e --- /dev/null +++ b/src/async_io_kqueue.cpp @@ -0,0 +1,781 @@ +#include "async_io.h" + +#include "errors.h" +#include "files.h" +#include "kernel32/overlapped_util.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace { + +constexpr NTSTATUS kStatusCancelled = static_cast(0xC0000120); + +struct AsyncRequest { + enum class Kind { Read, Write }; + + Kind kind; + Pin file; + OVERLAPPED *overlapped = nullptr; + std::optional offset; + bool isPipe = false; + bool updateFilePointer = false; + void *readBuffer = nullptr; + const uint8_t *writeBuffer = nullptr; + size_t length = 0; + size_t progress = 0; +}; + +struct FileState { + explicit 
FileState(int fdIn) : fd(fdIn) {} + + int fd; + bool readRegistered = false; + bool writeRegistered = false; + int originalFlags = -1; + std::deque> readQueue; + std::deque> writeQueue; +}; + +struct ProcessResult { + bool completed = false; + bool requeue = false; + NTSTATUS status = STATUS_SUCCESS; + size_t bytesTransferred = 0; +}; + +struct Completion { + std::unique_ptr req; + NTSTATUS status = STATUS_SUCCESS; + size_t bytesTransferred = 0; +}; + +class KqueueBackend : public wibo::AsyncIOBackend { + public: + ~KqueueBackend() override { shutdown(); } + + bool init() override; + void shutdown() override; + [[nodiscard]] bool running() const noexcept override { return mRunning.load(std::memory_order_acquire); } + + bool queueRead(Pin file, OVERLAPPED *ov, void *buffer, DWORD length, + const std::optional &offset, bool isPipe) override; + bool queueWrite(Pin file, OVERLAPPED *ov, const void *buffer, DWORD length, + const std::optional &offset, bool isPipe) override; + + private: + bool enqueueRequest(std::unique_ptr req); + bool enqueueFileRequest(std::unique_ptr req); + void workerLoop(); + void fileWorkerLoop(); + void handleFileEvent(FileState &state, const struct kevent &kev); + void notifyWorker() const; + void updateRegistrationLocked(FileState &state) const; + static void ensureNonBlocking(FileState &state); + static void restoreOriginalFlags(FileState &state); + void processCompletions(); + void failAllPending(); + void completeRequest(const AsyncRequest &req, NTSTATUS status, size_t bytesTransferred); + static Completion processBlockingRequest(AsyncRequest &req); + + static ProcessResult tryProcessPipeRead(AsyncRequest &req); + static ProcessResult tryProcessPipeWrite(AsyncRequest &req); + + std::atomic mRunning{false}; + std::atomic mPending{0}; + int mKqueueFd = -1; + const uintptr_t mWakeIdent = 1; + std::thread mThread; + + std::mutex mMutex; + std::unordered_map> mFileStates; + + std::mutex mFileQueueMutex; + std::condition_variable mFileQueueCv; + 
std::deque> mFileQueue; + bool mFileStopping = false; + std::vector mFileWorkers; + + std::mutex mCompletionMutex; + std::deque mCompletions; +}; + +bool KqueueBackend::init() { + if (mRunning.load(std::memory_order_acquire)) { + return true; + } + + mKqueueFd = kqueue(); + if (mKqueueFd < 0) { + DEBUG_LOG("AsyncIO(kqueue): kqueue() failed: %d (%s)\n", errno, strerror(errno)); + return false; + } + + struct kevent kev; + EV_SET(&kev, mWakeIdent, EVFILT_USER, EV_ADD | EV_CLEAR, 0, 0, nullptr); + if (kevent(mKqueueFd, &kev, 1, nullptr, 0, nullptr) < 0) { + DEBUG_LOG("AsyncIO(kqueue): kevent(EV_ADD user) failed: %d (%s)\n", errno, strerror(errno)); + close(mKqueueFd); + mKqueueFd = -1; + return false; + } + + unsigned int workerCount = std::thread::hardware_concurrency(); + if (workerCount == 0) { + workerCount = 1; + } + workerCount = std::min(workerCount, 4u); + + { + std::lock_guard lk(mFileQueueMutex); + mFileStopping = false; + } + mFileWorkers.reserve(workerCount); + for (unsigned int i = 0; i < workerCount; ++i) { + mFileWorkers.emplace_back(&KqueueBackend::fileWorkerLoop, this); + } + + mRunning.store(true, std::memory_order_release); + mThread = std::thread(&KqueueBackend::workerLoop, this); + DEBUG_LOG("AsyncIO: kqueue backend initialized\n"); + return true; +} + +void KqueueBackend::shutdown() { + if (!mRunning.exchange(false, std::memory_order_acq_rel)) { + return; + } + + { + std::lock_guard lk(mFileQueueMutex); + mFileStopping = true; + } + mFileQueueCv.notify_all(); + notifyWorker(); + + if (mThread.joinable()) { + mThread.join(); + } + + for (auto &worker : mFileWorkers) { + if (worker.joinable()) { + worker.join(); + } + } + mFileWorkers.clear(); + + if (mKqueueFd >= 0) { + close(mKqueueFd); + mKqueueFd = -1; + } + + { + std::lock_guard lk(mMutex); + for (auto &entry : mFileStates) { + restoreOriginalFlags(*entry.second); + } + mFileStates.clear(); + } + { + std::lock_guard lk(mFileQueueMutex); + mFileQueue.clear(); + } + { + std::lock_guard 
lk(mCompletionMutex); + mCompletions.clear(); + } + mPending.store(0, std::memory_order_release); +} + +bool KqueueBackend::queueRead(Pin file, OVERLAPPED *ov, void *buffer, DWORD length, + const std::optional &offset, bool isPipe) { + auto req = std::make_unique(AsyncRequest::Kind::Read); + req->file = std::move(file); + req->overlapped = ov; + req->offset = offset; + req->isPipe = isPipe; + req->updateFilePointer = req->file ? !req->file->overlapped : true; + req->readBuffer = buffer; + req->length = length; + return enqueueRequest(std::move(req)); +} + +bool KqueueBackend::queueWrite(Pin file, OVERLAPPED *ov, const void *buffer, DWORD length, + const std::optional &offset, bool isPipe) { + auto req = std::make_unique(AsyncRequest::Kind::Write); + req->file = std::move(file); + req->overlapped = ov; + req->offset = offset; + req->isPipe = isPipe; + req->updateFilePointer = req->file ? !req->file->overlapped : true; + req->writeBuffer = static_cast(buffer); + req->length = length; + return enqueueRequest(std::move(req)); +} + +bool KqueueBackend::enqueueRequest(std::unique_ptr req) { + if (!req || !req->file || !req->file->valid()) { + return false; + } + if (!mRunning.load(std::memory_order_acquire)) { + return false; + } + + if (req->isPipe) { + std::lock_guard lk(mMutex); + if (!mRunning.load(std::memory_order_acquire)) { + return false; + } + mPending.fetch_add(1, std::memory_order_acq_rel); + const int fd = req->file->fd; + auto &statePtr = mFileStates[fd]; + if (!statePtr) { + statePtr = std::make_unique(fd); + } + FileState &state = *statePtr; + ensureNonBlocking(state); + if (req->kind == AsyncRequest::Kind::Read) { + state.readQueue.emplace_back(std::move(req)); + } else { + state.writeQueue.emplace_back(std::move(req)); + } + updateRegistrationLocked(state); + notifyWorker(); + return true; + } + + mPending.fetch_add(1, std::memory_order_acq_rel); + if (enqueueFileRequest(std::move(req))) { + return true; + } + mPending.fetch_sub(1, 
std::memory_order_acq_rel); + return false; +} + +bool KqueueBackend::enqueueFileRequest(std::unique_ptr req) { + std::lock_guard lk(mFileQueueMutex); + if (mFileStopping) { + return false; + } + mFileQueue.emplace_back(std::move(req)); + mFileQueueCv.notify_one(); + return true; +} + +void KqueueBackend::workerLoop() { + std::array events{}; + + while (true) { + processCompletions(); + + if (!mRunning.load(std::memory_order_acquire) && mPending.load(std::memory_order_acquire) == 0) { + break; + } + + const bool waitIndefinitely = mRunning.load(std::memory_order_acquire); + struct timespec ts; + struct timespec *timeout = nullptr; + if (!waitIndefinitely) { + ts.tv_sec = 0; + ts.tv_nsec = 10 * 1000 * 1000; // 10ms + timeout = &ts; + } + + int count = kevent(mKqueueFd, nullptr, 0, events.data(), static_cast(events.size()), timeout); + if (count < 0) { + if (errno == EINTR) { + continue; + } + DEBUG_LOG("AsyncIO(kqueue): kevent wait failed: %d (%s)\n", errno, strerror(errno)); + continue; + } + if (count == 0) { + continue; + } + + for (int i = 0; i < count; ++i) { + auto &ev = events[static_cast(i)]; + if (ev.filter == EVFILT_USER) { + processCompletions(); + continue; + } + auto *state = static_cast(ev.udata); + if (!state) { + continue; + } + handleFileEvent(*state, ev); + } + } + + processCompletions(); + failAllPending(); +} + +void KqueueBackend::fileWorkerLoop() { + while (true) { + std::unique_ptr req; + { + std::unique_lock lk(mFileQueueMutex); + mFileQueueCv.wait(lk, [&] { return mFileStopping || !mFileQueue.empty(); }); + if (mFileStopping && mFileQueue.empty()) { + break; + } + req = std::move(mFileQueue.front()); + mFileQueue.pop_front(); + } + + if (!req) { + continue; + } + + Completion completion = processBlockingRequest(*req); + completion.req = std::move(req); + { + std::lock_guard lk(mCompletionMutex); + mCompletions.emplace_back(std::move(completion)); + } + notifyWorker(); + } +} + +void KqueueBackend::handleFileEvent(FileState &state, const 
struct kevent &kev) { + if ((kev.flags & EV_ERROR) != 0 && kev.data != 0) { + DEBUG_LOG("AsyncIO(kqueue): event error on fd %d filter %d: %s\n", state.fd, kev.filter, + strerror(static_cast(kev.data))); + } + + const bool canRead = kev.filter == EVFILT_READ || (kev.flags & EV_EOF) != 0; + const bool canWrite = kev.filter == EVFILT_WRITE || (kev.flags & EV_EOF) != 0; + + if (canRead) { + while (true) { + std::unique_ptr req; + { + std::lock_guard lk(mMutex); + if (state.readQueue.empty()) { + break; + } + req = std::move(state.readQueue.front()); + state.readQueue.pop_front(); + } + + auto result = tryProcessPipeRead(*req); + if (result.requeue) { + std::lock_guard lk(mMutex); + state.readQueue.emplace_front(std::move(req)); + updateRegistrationLocked(state); + break; + } + if (result.completed) { + completeRequest(*req, result.status, result.bytesTransferred); + } + { + std::lock_guard lk(mMutex); + updateRegistrationLocked(state); + } + } + } + + if (canWrite) { + while (true) { + std::unique_ptr req; + { + std::lock_guard lk(mMutex); + if (state.writeQueue.empty()) { + break; + } + req = std::move(state.writeQueue.front()); + state.writeQueue.pop_front(); + } + + auto result = tryProcessPipeWrite(*req); + if (result.requeue) { + std::lock_guard lk(mMutex); + state.writeQueue.emplace_front(std::move(req)); + updateRegistrationLocked(state); + break; + } + if (result.completed) { + completeRequest(*req, result.status, result.bytesTransferred); + } + { + std::lock_guard lk(mMutex); + updateRegistrationLocked(state); + } + } + } + + const int fd = state.fd; + { + std::lock_guard lk(mMutex); + auto it = mFileStates.find(fd); + if (it != mFileStates.end() && it->second.get() == &state) { + FileState *ptr = it->second.get(); + if (!ptr->readRegistered && !ptr->writeRegistered && ptr->readQueue.empty() && ptr->writeQueue.empty()) { + restoreOriginalFlags(*ptr); + mFileStates.erase(it); + } + } + } +} + +void KqueueBackend::notifyWorker() const { + if (mKqueueFd < 0) { + 
return; + } + struct kevent kev; + EV_SET(&kev, mWakeIdent, EVFILT_USER, 0, NOTE_TRIGGER, 0, nullptr); + while (kevent(mKqueueFd, &kev, 1, nullptr, 0, nullptr) < 0) { + if (errno == EINTR) { + continue; + } + break; + } +} + +void KqueueBackend::updateRegistrationLocked(FileState &state) const { + const bool needRead = !state.readQueue.empty(); + const bool needWrite = !state.writeQueue.empty(); + + if (!needRead && !needWrite) { + if (state.readRegistered) { + struct kevent kev; + EV_SET(&kev, state.fd, EVFILT_READ, EV_DELETE, 0, 0, nullptr); + if (kevent(mKqueueFd, &kev, 1, nullptr, 0, nullptr) != 0 && errno != ENOENT) { + DEBUG_LOG("AsyncIO(kqueue): kevent delete read fd %d failed: %d (%s)\n", state.fd, errno, + strerror(errno)); + } + state.readRegistered = false; + } + if (state.writeRegistered) { + struct kevent kev; + EV_SET(&kev, state.fd, EVFILT_WRITE, EV_DELETE, 0, 0, nullptr); + if (kevent(mKqueueFd, &kev, 1, nullptr, 0, nullptr) != 0 && errno != ENOENT) { + DEBUG_LOG("AsyncIO(kqueue): kevent delete write fd %d failed: %d (%s)\n", state.fd, errno, + strerror(errno)); + } + state.writeRegistered = false; + } + restoreOriginalFlags(state); + return; + } + + if (needRead && !state.readRegistered) { + struct kevent kev; + EV_SET(&kev, state.fd, EVFILT_READ, EV_ADD | EV_ENABLE | EV_CLEAR, 0, 0, &state); + if (kevent(mKqueueFd, &kev, 1, nullptr, 0, nullptr) != 0) { + DEBUG_LOG("AsyncIO(kqueue): kevent add read fd %d failed: %d (%s)\n", state.fd, errno, strerror(errno)); + } else { + state.readRegistered = true; + } + } + + if (!needRead && state.readRegistered) { + struct kevent kev; + EV_SET(&kev, state.fd, EVFILT_READ, EV_DELETE, 0, 0, nullptr); + if (kevent(mKqueueFd, &kev, 1, nullptr, 0, nullptr) != 0 && errno != ENOENT) { + DEBUG_LOG("AsyncIO(kqueue): kevent delete read fd %d failed: %d (%s)\n", state.fd, errno, strerror(errno)); + } + state.readRegistered = false; + } + + if (needWrite && !state.writeRegistered) { + struct kevent kev; + EV_SET(&kev, 
state.fd, EVFILT_WRITE, EV_ADD | EV_ENABLE | EV_CLEAR, 0, 0, &state); + if (kevent(mKqueueFd, &kev, 1, nullptr, 0, nullptr) != 0) { + DEBUG_LOG("AsyncIO(kqueue): kevent add write fd %d failed: %d (%s)\n", state.fd, errno, strerror(errno)); + } else { + state.writeRegistered = true; + } + } + + if (!needWrite && state.writeRegistered) { + struct kevent kev; + EV_SET(&kev, state.fd, EVFILT_WRITE, EV_DELETE, 0, 0, nullptr); + if (kevent(mKqueueFd, &kev, 1, nullptr, 0, nullptr) != 0 && errno != ENOENT) { + DEBUG_LOG("AsyncIO(kqueue): kevent delete write fd %d failed: %d (%s)\n", state.fd, errno, strerror(errno)); + } + state.writeRegistered = false; + } + + if (state.readRegistered || state.writeRegistered) { + ensureNonBlocking(state); + } else { + restoreOriginalFlags(state); + } +} + +void KqueueBackend::ensureNonBlocking(FileState &state) { + if (state.originalFlags >= 0) { + return; + } + + int flags = fcntl(state.fd, F_GETFL, 0); + if (flags < 0) { + DEBUG_LOG("AsyncIO(kqueue): fcntl(F_GETFL) failed for fd %d: %d (%s)\n", state.fd, errno, strerror(errno)); + return; + } + + if ((flags & O_NONBLOCK) != 0) { + return; + } + + state.originalFlags = flags; + if (fcntl(state.fd, F_SETFL, flags | O_NONBLOCK) != 0) { + DEBUG_LOG("AsyncIO(kqueue): fcntl(F_SETFL) failed for fd %d: %d (%s)\n", state.fd, errno, strerror(errno)); + state.originalFlags = -1; + } +} + +void KqueueBackend::restoreOriginalFlags(FileState &state) { + if (state.originalFlags < 0) { + return; + } + + if (fcntl(state.fd, F_SETFL, state.originalFlags) != 0) { + DEBUG_LOG("AsyncIO(kqueue): restoring flags for fd %d failed: %d (%s)\n", state.fd, errno, strerror(errno)); + } + + state.originalFlags = -1; +} + +void KqueueBackend::processCompletions() { + std::deque pending; + { + std::lock_guard lk(mCompletionMutex); + if (mCompletions.empty()) { + return; + } + pending.swap(mCompletions); + } + + for (auto &entry : pending) { + if (entry.req) { + completeRequest(*entry.req, entry.status, 
entry.bytesTransferred); + } + } +} + +void KqueueBackend::failAllPending() { + processCompletions(); + + std::vector> pending; + { + std::lock_guard lk(mMutex); + for (auto &entry : mFileStates) { + auto &state = *entry.second; + while (!state.readQueue.empty()) { + pending.emplace_back(std::move(state.readQueue.front())); + state.readQueue.pop_front(); + } + while (!state.writeQueue.empty()) { + pending.emplace_back(std::move(state.writeQueue.front())); + state.writeQueue.pop_front(); + } + state.readRegistered = false; + state.writeRegistered = false; + restoreOriginalFlags(state); + } + } + + { + std::lock_guard lk(mFileQueueMutex); + while (!mFileQueue.empty()) { + pending.emplace_back(std::move(mFileQueue.front())); + mFileQueue.pop_front(); + } + } + + std::deque completions; + { + std::lock_guard lk(mCompletionMutex); + completions.swap(mCompletions); + } + for (auto &entry : completions) { + if (entry.req) { + completeRequest(*entry.req, entry.status, entry.bytesTransferred); + } + } + + for (auto &req : pending) { + if (req) { + completeRequest(*req, kStatusCancelled, 0); + } + } +} + +void KqueueBackend::completeRequest(const AsyncRequest &req, NTSTATUS status, size_t bytesTransferred) { + kernel32::detail::signalOverlappedEvent(req.file.get(), req.overlapped, status, bytesTransferred); + mPending.fetch_sub(1, std::memory_order_acq_rel); +} + +Completion KqueueBackend::processBlockingRequest(AsyncRequest &req) { + Completion result{}; + if (!req.file || !req.file->valid()) { + result.status = STATUS_INVALID_HANDLE; + return result; + } + + files::IOResult io{}; + if (req.kind == AsyncRequest::Kind::Read) { + io = files::read(req.file.get(), req.readBuffer, req.length, req.offset, req.updateFilePointer); + } else { + io = files::write(req.file.get(), req.writeBuffer, req.length, req.offset, req.updateFilePointer); + } + + result.bytesTransferred = io.bytesTransferred; + + if (io.unixError != 0) { + result.status = wibo::statusFromErrno(io.unixError); + if 
(result.status == STATUS_SUCCESS) { + result.status = STATUS_UNEXPECTED_IO_ERROR; + } + } else if (req.kind == AsyncRequest::Kind::Read && io.bytesTransferred == 0 && io.reachedEnd) { + result.status = req.isPipe ? STATUS_PIPE_BROKEN : STATUS_END_OF_FILE; + } else if (req.kind == AsyncRequest::Kind::Write && io.bytesTransferred == 0 && io.reachedEnd) { + result.status = STATUS_END_OF_FILE; + } else { + result.status = STATUS_SUCCESS; + } + + return result; +} + +ProcessResult KqueueBackend::tryProcessPipeRead(AsyncRequest &req) { + ProcessResult result{}; + if (!req.file || !req.file->valid()) { + result.completed = true; + result.status = STATUS_INVALID_HANDLE; + return result; + } + const int fd = req.file->fd; + if (req.length == 0) { + result.completed = true; + result.status = STATUS_SUCCESS; + return result; + } + + uint8_t *buffer = static_cast(req.readBuffer); + size_t toRead = req.length; + while (true) { + size_t chunk = std::min(toRead, static_cast(SSIZE_MAX)); + ssize_t rc = ::read(fd, buffer, chunk); + if (rc > 0) { + result.completed = true; + result.status = STATUS_SUCCESS; + result.bytesTransferred = static_cast(rc); + return result; + } + if (rc == 0) { + result.completed = true; + result.status = req.isPipe ? STATUS_PIPE_BROKEN : STATUS_END_OF_FILE; + result.bytesTransferred = 0; + return result; + } + if (rc == -1) { + if (errno == EINTR) { + continue; + } + if (errno == EAGAIN || errno == EWOULDBLOCK) { + result.requeue = true; + return result; + } + int err = errno ? 
errno : EIO; + result.completed = true; + if (err == EPIPE || err == ECONNRESET) { + result.status = STATUS_PIPE_BROKEN; + } else { + result.status = wibo::statusFromErrno(err); + if (result.status == STATUS_SUCCESS) { + result.status = STATUS_UNEXPECTED_IO_ERROR; + } + } + result.bytesTransferred = 0; + return result; + } + } +} + +ProcessResult KqueueBackend::tryProcessPipeWrite(AsyncRequest &req) { + ProcessResult result{}; + if (!req.file || !req.file->valid()) { + result.completed = true; + result.status = STATUS_INVALID_HANDLE; + return result; + } + + const int fd = req.file->fd; + size_t remaining = req.length - req.progress; + const uint8_t *buffer = req.writeBuffer ? req.writeBuffer + req.progress : nullptr; + + while (remaining > 0) { + size_t chunk = std::min(remaining, static_cast(SSIZE_MAX)); + ssize_t rc = ::write(fd, buffer, chunk); + if (rc > 0) { + size_t written = static_cast(rc); + req.progress += written; + remaining -= written; + buffer += written; + if (req.offset.has_value()) { + *req.offset += static_cast(written); + } + continue; + } + if (rc == 0) { + break; + } + if (errno == EINTR) { + continue; + } + if (errno == EAGAIN || errno == EWOULDBLOCK) { + result.requeue = true; + return result; + } + int err = errno ? 
errno : EIO; + result.completed = true; + if (err == EPIPE || err == ECONNRESET) { + result.status = STATUS_PIPE_BROKEN; + } else { + result.status = wibo::statusFromErrno(err); + if (result.status == STATUS_SUCCESS) { + result.status = STATUS_UNEXPECTED_IO_ERROR; + } + } + result.bytesTransferred = req.progress; + return result; + } + + if (remaining == 0) { + result.completed = true; + result.status = STATUS_SUCCESS; + result.bytesTransferred = req.progress; + } else { + result.requeue = true; + } + return result; +} + +} // namespace + +namespace wibo::detail { + +std::unique_ptr createKqueueBackend() { return std::make_unique(); } + +} // namespace wibo::detail diff --git a/src/common.h b/src/common.h index 0362605..a0f1952 100644 --- a/src/common.h +++ b/src/common.h @@ -41,6 +41,7 @@ TEB *allocateTib(); void destroyTib(TEB *tib); void initializeTibStackInfo(TEB *tib); bool installTibForCurrentThread(TEB *tib); +void uninstallTebForCurrentThread(); void debug_log(const char *fmt, ...); diff --git a/src/macros.S b/src/macros.S index 461c586..9e98481 100644 --- a/src/macros.S +++ b/src/macros.S @@ -41,8 +41,7 @@ m1632: .long 1f # 32-bit code offset .long 0 # 32-bit code segment (filled in at runtime) .text - mov r10w, word ptr [\teb_reg+TEB_FS_SEL] - sub r10w, 16 + mov r10w, word ptr [\teb_reg+TEB_CS_SEL] mov word ptr [rip+m1632+4], r10w jmp fword ptr [rip+m1632] #else diff --git a/src/macros.h b/src/macros.h index a2e51d8..1a76a5c 100644 --- a/src/macros.h +++ b/src/macros.h @@ -12,6 +12,8 @@ #ifdef __x86_64__ +#define TEB_CS_SEL 0xf9c // CodeSelector +#define TEB_DS_SEL 0xf9e // DataSelector #define TEB_SP 0xfa0 // CurrentStackPointer #define TEB_FSBASE 0xfa8 // HostFsBase #define TEB_GSBASE 0xfb0 // HostGsBase @@ -78,4 +80,4 @@ static_cast(reinterpret_cast(&GLUE(symbol, End)) - \ reinterpret_cast(&symbol)) #define INCLUDE_BIN_SPAN(symbol) std::span(symbol, INCLUDE_BIN_SIZE(symbol)) -#endif \ No newline at end of file +#endif diff --git a/src/main.cpp 
b/src/main.cpp index 7de5e9c..fa95bf7 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -5,6 +5,7 @@ #include "heap.h" #include "modules.h" #include "processes.h" +#include "setup.h" #include "strutil.h" #include "tls.h" #include "types.h" @@ -14,23 +15,8 @@ #include #include #include -#include #include #include -#include -#include -#include -#include - -#ifdef __x86_64__ -#include "setup.h" -#endif - -#ifdef __linux__ -#include -#include -#include -#endif char **wibo::argv; int wibo::argc; @@ -104,41 +90,13 @@ bool wibo::installTibForCurrentThread(TEB *tibPtr) { if (!tibPtr) { return false; } - currentThreadTeb = tibPtr; -#ifdef __x86_64__ - tibEntryNumber = tebThreadSetup(tibEntryNumber, tibPtr); - if (tibEntryNumber < 0 || tibPtr->CurrentFsSelector == 0) { - perror("x86_64_thread_setup failed"); - return false; - } -#else - struct user_desc desc; - std::memset(&desc, 0, sizeof(desc)); - desc.entry_number = tibEntryNumber; - desc.base_addr = reinterpret_cast(tibPtr); - desc.limit = static_cast(sizeof(TEB) - 1); - desc.seg_32bit = 1; - desc.contents = 0; - desc.read_exec_only = 0; - desc.limit_in_pages = 0; - desc.seg_not_present = 0; - desc.useable = 1; - if (syscall(SYS_set_thread_area, &desc) != 0) { - perror("set_thread_area failed"); - return false; - } - if (tibEntryNumber != static_cast(desc.entry_number)) { - tibEntryNumber = static_cast(desc.entry_number); - DEBUG_LOG("set_thread_area: allocated entry=%d base=%p\n", tibEntryNumber, tibPtr); - } else { - DEBUG_LOG("set_thread_area: reused entry=%d base=%p\n", tibEntryNumber, tibPtr); - } + return tebThreadSetup(tibPtr); +} - tibPtr->CurrentFsSelector = static_cast((desc.entry_number << 3) | 3); - tibPtr->CurrentGsSelector = 0; -#endif - return true; +void wibo::uninstallTebForCurrentThread() { + TEB* teb = std::exchange(currentThreadTeb, nullptr); + tebThreadTeardown(teb); } static std::string getExeName(const char *argv0) { @@ -375,7 +333,7 @@ int main(int argc, char **argv) { wibo::processPeb = peb; 
wibo::initializeTibStackInfo(tib); if (!wibo::installTibForCurrentThread(tib)) { - fprintf(stderr, "Failed to install TIB for main thread\n"); + perror("Failed to install TIB for main thread"); return 1; } diff --git a/src/setup.S b/src/setup.S index 3fb4556..7e34f0d 100644 --- a/src/setup.S +++ b/src/setup.S @@ -8,8 +8,8 @@ #if defined(__x86_64__) && defined(__linux__) -# int tebThreadSetup(int entryNumber, TEB *teb) -ASM_GLOBAL(tebThreadSetup, @function) +# int tebThreadSetup64(int entryNumber, TEB *teb) +ASM_GLOBAL(tebThreadSetup64, @function) push rbx # save rbx mov r8, rsp # save host stack rdfsbase r9 # read host FS base @@ -28,17 +28,9 @@ ASM_GLOBAL(tebThreadSetup, @function) mov eax, 0xf3 # SYS_set_thread_area int 0x80 # syscall test eax, eax # check for error - jnz 1f # skip selector setup + jnz 1f # skip if error mov eax, dword ptr [esp] # entry_number - cmp eax, -1 # check for invalid entry_number - jz 2f # skip selector setup - lea ebx, [eax*8+3] # create selector - mov fs, bx # setup fs segment - mov word ptr [esi+TEB_FS_SEL], bx # save selector - jmp 2f # skip error handling 1: - mov eax, -1 # return -1 -2: add esp, 0x10 # cleanup stack LJMP64 esi # far jump into 64-bit code cdqe # sign-extend eax to rax @@ -46,11 +38,26 @@ ASM_GLOBAL(tebThreadSetup, @function) wrfsbase r9 # restore host FS base pop rbx # restore rbx ret -ASM_END(tebThreadSetup) +ASM_END(tebThreadSetup64) -#endif // __x86_64__ +#endif // defined(__x86_64__) && defined(__linux__) -.code32 +#if defined(__x86_64__) && defined(__APPLE__) + +# bool installSelectors(TEB *teb) +ASM_GLOBAL(installSelectors, @function) + mov ax, word ptr [rdi+TEB_DS_SEL] # fetch data segment selector + mov dx, word ptr [rdi+TEB_FS_SEL] # fetch fs selector + LJMP32 rdi # far jump into 32-bit code (sets cs) + mov ds, ax # setup data segment + mov es, ax # setup extra segment + mov fs, dx # setup fs segment + LJMP64 edi # far jump into 64-bit code + mov rax, 1 # return true + ret +ASM_END(installSelectors) + 
+#endif .macro stubThunkX number #if defined(__x86_64__) @@ -67,6 +74,7 @@ ASM_GLOBAL(STUB_THUNK_SYMBOL, @function) ASM_END(STUB_THUNK_SYMBOL) .endm +.code32 stubThunkX 0 stubThunkX 1 stubThunkX 2 diff --git a/src/setup.h b/src/setup.h index 5d8b48f..bf4c4dc 100644 --- a/src/setup.h +++ b/src/setup.h @@ -2,13 +2,14 @@ #include "types.h" +#define USER_PRIVILEGE 3 + #ifdef __cplusplus extern "C" { #endif -#ifdef __x86_64__ -int tebThreadSetup(int entryNumber, TEB *teb); -#endif +bool tebThreadSetup(TEB *teb); +bool tebThreadTeardown(TEB *teb); #ifdef __cplusplus } diff --git a/src/setup_darwin.cpp b/src/setup_darwin.cpp index 53c8e7c..c0e1d3e 100644 --- a/src/setup_darwin.cpp +++ b/src/setup_darwin.cpp @@ -3,8 +3,11 @@ #include "types.h" +#include +#include #include #include +#include #include #include @@ -12,10 +15,20 @@ // https://github.com/apple/darwin-libpthread/blob/03c4628c8940cca6fd6a82957f683af804f62e7f/private/tsd_private.h#L92-L97 #define _PTHREAD_TSD_SLOT_RESERVED_WIN64 6 -#define USER_PRIVILEGE 3 +// Implemented in setup.S +extern "C" int installSelectors(TEB *teb); namespace { +std::mutex g_tebSetupMutex; +uint16_t g_codeSelector = 0; +uint16_t g_dataSelector = 0; +constexpr int kMaxLdtEntries = 8192; +constexpr int kBitsPerWord = 32; +std::array g_ldtBitmap{}; +bool g_ldtBitmapInitialized = false; +int g_ldtHint = 1; + inline ldt_entry createLdtEntry(uint32_t base, uint32_t size, bool code) { uint32_t limit; uint8_t granular; @@ -49,45 +62,158 @@ inline void writeTsdSlot(uint32_t slot, uint64_t val) { *(volatile uint64_t __seg_gs *)(slot * sizeof(void *)) = val; } +inline bool isLdtEntryValid(int entry) { return entry >= 0 && entry < kMaxLdtEntries; } + +inline void markLdtEntryUsed(int entry) { + if (!isLdtEntryValid(entry)) { + return; + } + g_ldtBitmap[entry / kBitsPerWord] |= (1u << (entry % kBitsPerWord)); +} + +inline void markLdtEntryFree(int entry) { + if (!isLdtEntryValid(entry)) { + return; + } + g_ldtBitmap[entry / kBitsPerWord] &= ~(1u << 
(entry % kBitsPerWord)); +} + +inline bool isLdtEntryUsed(int entry) { + if (!isLdtEntryValid(entry)) { + return true; + } + return (g_ldtBitmap[entry / kBitsPerWord] & (1u << (entry % kBitsPerWord))) != 0; +} + +bool initializeLdtBitmapLocked() { + if (g_ldtBitmapInitialized) { + return true; + } + ldt_entry unused{}; + int count = i386_get_ldt(0, &unused, 1); + if (count < 0) { + DEBUG_LOG("setup_darwin: i386_get_ldt failed during bitmap init (%d), assuming empty table\n", count); + return false; + } + if (count > kMaxLdtEntries) { + DEBUG_LOG("setup_darwin: i386_get_ldt returned too many entries (%d), truncating to %d\n", count, + kMaxLdtEntries); + errno = ENOSPC; + return false; + } + for (int i = 0; i < count; ++i) { + markLdtEntryUsed(i); + } + g_ldtBitmapInitialized = true; + return true; +} + +int allocateLdtEntryLocked() { + if (!initializeLdtBitmapLocked()) { + errno = ENOSPC; + return -1; + } + auto tryAllocate = [&](int start) -> int { + for (int entry = start; entry < kMaxLdtEntries; ++entry) { + if (!isLdtEntryUsed(entry)) { + markLdtEntryUsed(entry); + g_ldtHint = entry + 1; + if (g_ldtHint >= kMaxLdtEntries) { + g_ldtHint = 1; + } + DEBUG_LOG("setup_darwin: Allocating LDT entry %d\n", entry); + return entry; + } + } + return -1; + }; + int entry = tryAllocate(std::max(g_ldtHint, 1)); + if (entry >= 0) { + return entry; + } + entry = tryAllocate(1); + if (entry >= 0) { + return entry; + } + errno = ENOSPC; + return -1; +} + +void freeLdtEntryLocked(int entryNumber) { + if (!g_ldtBitmapInitialized || !isLdtEntryValid(entryNumber)) { + return; + } + markLdtEntryFree(entryNumber); + if (entryNumber < g_ldtHint) { + g_ldtHint = std::max(entryNumber, 1); + } +} + +bool segmentSetupLocked(TEB *teb) { + // Create code LDT entry + if (g_codeSelector == 0) { + int entryNumber = allocateLdtEntryLocked(); + if (entryNumber < 0) { + return false; + } + ldt_entry codeLdt = createLdtEntry(0, 0xFFFFFFFF, true); + int ret = i386_set_ldt(entryNumber, &codeLdt, 1); + 
		if (ret < 0) {
+			freeLdtEntryLocked(entryNumber);
+			return false;
+		} else if (ret != entryNumber) {
+			freeLdtEntryLocked(entryNumber);
+			errno = EALREADY;
+			return false;
+		}
+		g_codeSelector = createSelector(ret);
+		DEBUG_LOG("setup_darwin: Code LDT selector %x\n", g_codeSelector);
+	}
+	// Create data LDT entry
+	if (g_dataSelector == 0) {
+		int entryNumber = allocateLdtEntryLocked();
+		if (entryNumber < 0) {
+			return false;
+		}
+		ldt_entry dataLdt = createLdtEntry(0, 0xFFFFFFFF, false);
+		int ret = i386_set_ldt(entryNumber, &dataLdt, 1);
+		if (ret < 0) {
+			freeLdtEntryLocked(entryNumber);
+			return false;
+		} else if (ret != entryNumber) {
+			freeLdtEntryLocked(entryNumber);
+			errno = EALREADY;
+			return false;
+		}
+		g_dataSelector = createSelector(ret);
+		DEBUG_LOG("setup_darwin: Data LDT selector %x\n", g_dataSelector);
+	}
+	teb->CodeSelector = g_codeSelector;
+	teb->DataSelector = g_dataSelector;
+	return true;
+}
+
 } // namespace
 
-int tebThreadSetup(int entryNumber, TEB *teb) {
-	bool alloc = entryNumber == -1;
-	if (alloc) {
-		ldt_entry unused{};
-		entryNumber = i386_get_ldt(0, &unused, 1);
-		if (entryNumber < 0) {
-			return entryNumber;
-		}
-		DEBUG_LOG("Allocating LDT entry %d\n", entryNumber);
-		// Create code LDT entry at entry_number + 1
-		ldt_entry codeLdt = createLdtEntry(0, 0xFFFFFFFF, true);
-		int codeLdtEntry = entryNumber++;
-		int ret = i386_set_ldt(codeLdtEntry, &codeLdt, 1);
-		if (ret < 0) {
-			return ret;
-		} else if (ret != codeLdtEntry) {
-			errno = EALREADY;
-			return -EALREADY;
-		}
-		DEBUG_LOG("Code selector %x\n", createSelector(ret));
-		// Create data LDT entry at entry_number + 2
-		ldt_entry dataLdt = createLdtEntry(0, 0xFFFFFFFF, false);
-		int dataLdtEntry = entryNumber++;
-		ret = i386_set_ldt(dataLdtEntry, &dataLdt, 1);
-		if (ret < 0) {
-			return ret;
-		} else if (ret != dataLdtEntry) {
-			errno = EALREADY;
-			return -EALREADY;
-		}
-		DEBUG_LOG("Data selector %x\n", createSelector(dataLdtEntry));
+bool tebThreadSetup(TEB *teb) {
+	if (!teb) {
+		return false;
+	}
+	std::lock_guard lk(g_tebSetupMutex);
+	// Perform global segment setup if not already done
+	if (!segmentSetupLocked(teb)) {
+		return false;
+	}
+	int entryNumber = allocateLdtEntryLocked();
+	if (entryNumber < 0) {
+		return false;
 	}
 	uintptr_t tebBase = reinterpret_cast<uintptr_t>(teb);
 	if (tebBase > 0xFFFFFFFF) {
-		DEBUG_LOG("TEB base address exceeds 32-bit limit\n");
+		fprintf(stderr, "setup_darwin: TEB base address exceeds 32-bit limit\n");
+		freeLdtEntryLocked(entryNumber);
 		errno = EINVAL;
-		return -EINVAL;
+		return false;
 	}
 	// Store the TEB base address in the reserved slot for Windows 64-bit (gs:[0x30])
 	writeTsdSlot(_PTHREAD_TSD_SLOT_RESERVED_WIN64, static_cast<uint32_t>(tebBase));
@@ -95,11 +221,35 @@ int tebThreadSetup(int entryNumber, TEB *teb) {
 	ldt_entry fsLdt = createLdtEntry(static_cast<uint32_t>(tebBase), 0x1000, false);
 	int ret = i386_set_ldt(entryNumber, &fsLdt, 1);
 	if (ret < 0) {
-		return ret;
+		freeLdtEntryLocked(entryNumber);
+		return false;
 	} else if (ret != entryNumber) {
+		freeLdtEntryLocked(entryNumber);
 		errno = EALREADY;
-		return -EALREADY;
+		return false;
 	}
 	teb->CurrentFsSelector = createSelector(entryNumber);
-	return entryNumber;
+	DEBUG_LOG("Installing cs %d, ds %d, fs %d\n", teb->CodeSelector, teb->DataSelector, teb->CurrentFsSelector);
+	installSelectors(teb);
+	return true;
+}
+
+bool tebThreadTeardown(TEB *teb) {
+	if (!teb) {
+		return true;
+	}
+	std::lock_guard lk(g_tebSetupMutex);
+	writeTsdSlot(_PTHREAD_TSD_SLOT_RESERVED_WIN64, 0);
+	uint16_t selector = teb->CurrentFsSelector;
+	if (selector == 0) {
+		return true;
+	}
+	int entryNumber = selector >> 3;
+	int ret = i386_set_ldt(entryNumber, nullptr, 1);
+	if (ret < 0) {
+		return false;
+	}
+	freeLdtEntryLocked(entryNumber);
+	teb->CurrentFsSelector = 0;
+	return true;
 }
diff --git a/src/setup_linux.cpp b/src/setup_linux.cpp
new file mode 100644
index 0000000..1826b32
--- /dev/null
+++ b/src/setup_linux.cpp
@@ -0,0 +1,82 @@
+#include "setup.h"
+
+#include "common.h"
+
+#include <mutex>
+#include <sys/syscall.h>
+#include <unistd.h>
+
+namespace {
+
+std::mutex g_tebSetupMutex;
+int g_entryNumber = -1;
+
+} // namespace
+
+constexpr uint16_t createSelector(int entryNumber) {
+	return static_cast<uint16_t>((entryNumber << 3) | USER_PRIVILEGE);
+}
+
+#if defined(__x86_64__)
+
+// Implemented in setup.S
+extern "C" int tebThreadSetup64(int entryNumber, TEB *teb);
+
+bool tebThreadSetup(TEB *teb) {
+	std::lock_guard guard(g_tebSetupMutex);
+	int ret = tebThreadSetup64(g_entryNumber, teb);
+	if (ret < 0) {
+		return false;
+	}
+	if (g_entryNumber != ret) {
+		g_entryNumber = ret;
+		DEBUG_LOG("set_thread_area: allocated entry=%d base=%p\n", g_entryNumber, teb);
+	} else {
+		DEBUG_LOG("set_thread_area: reused entry=%d base=%p\n", g_entryNumber, teb);
+	}
+
+	teb->CurrentFsSelector = createSelector(ret);
+	teb->CurrentGsSelector = 0;
+	return true;
+}
+
+#elif defined(__i386__)
+
+#include <asm/ldt.h>
+
+bool tebThreadSetup(TEB *teb) {
+	std::lock_guard guard(g_tebSetupMutex);
+
+	struct user_desc desc; // NOLINT(cppcoreguidelines-pro-type-member-init)
+	std::memset(&desc, 0, sizeof(desc));
+	desc.entry_number = g_entryNumber;
+	desc.base_addr = reinterpret_cast<uintptr_t>(teb);
+	desc.limit = static_cast<unsigned int>(sizeof(TEB) - 1);
+	desc.seg_32bit = 1;
+	desc.contents = 0;
+	desc.read_exec_only = 0;
+	desc.limit_in_pages = 0;
+	desc.seg_not_present = 0;
+	desc.useable = 1;
+	if (syscall(SYS_set_thread_area, &desc) != 0) {
+		return false;
+	}
+	if (g_entryNumber != static_cast<int>(desc.entry_number)) {
+		g_entryNumber = static_cast<int>(desc.entry_number);
+		DEBUG_LOG("setup_linux: allocated GDT entry=%d base=%p\n", g_entryNumber, teb);
+	} else {
+		DEBUG_LOG("setup_linux: reused GDT entry=%d base=%p\n", g_entryNumber, teb);
+	}
+
+	teb->CurrentFsSelector = createSelector(desc.entry_number);
+	teb->CurrentGsSelector = 0;
+	return true;
+}
+
+#endif
+
+bool tebThreadTeardown(TEB *teb) {
+	(void)teb;
+	// no-op on Linux
+	return true;
+}
diff --git a/src/types.h b/src/types.h
index c745805..04203d7 100644
--- a/src/types.h
+++ b/src/types.h
@@ -541,6 +541,10 @@ typedef struct _TEB {
 	// wibo
 	WORD CurrentFsSelector;
 	WORD CurrentGsSelector;
+#ifdef __x86_64__
+	WORD CodeSelector;
+	WORD DataSelector;
+#endif
 	void *CurrentStackPointer;
 #ifdef __x86_64__
 	void *HostFsBase;
@@ -558,6 +562,12 @@ static_assert(offsetof(TEB, DeallocationStack) == 0xE0C, "DeallocationStack offs
 static_assert(offsetof(TEB, TlsSlots) == 0xE10, "TLS slots offset mismatch");
 static_assert(offsetof(TEB, CurrentFsSelector) == TEB_FS_SEL);
 static_assert(offsetof(TEB, CurrentGsSelector) == TEB_GS_SEL);
+#ifdef TEB_CS_SEL
+static_assert(offsetof(TEB, CodeSelector) == TEB_CS_SEL);
+#endif
+#ifdef TEB_DS_SEL
+static_assert(offsetof(TEB, DataSelector) == TEB_DS_SEL);
+#endif
 static_assert(offsetof(TEB, CurrentStackPointer) == TEB_SP);
 #ifdef TEB_FSBASE
 static_assert(offsetof(TEB, HostFsBase) == TEB_FSBASE);
diff --git a/tools/gen_trampolines.py b/tools/gen_trampolines.py
index 70d646a..b853a86 100644
--- a/tools/gen_trampolines.py
+++ b/tools/gen_trampolines.py
@@ -761,9 +761,10 @@ def emit_cc_thunk64(f: FuncInfo | TypedefInfo, lines: List[str]):
     # Jump to 32-bit mode
     lines.append("\tLJMP32 rbx")
 
-    # Setup FS selector
-    lines.append("\tmov ax, word ptr [ebx+TEB_FS_SEL]")
-    lines.append("\tmov fs, ax")
+    if sys.platform != "darwin":
+        # Setup FS selector
+        lines.append("\tmov ax, word ptr [ebx+TEB_FS_SEL]")
+        lines.append("\tmov fs, ax")
 
     # Call into target
     lines.append(f"\tcall {call_target}")