diff --git a/CMakeLists.txt b/CMakeLists.txt index f70c4c8..b1d6d10 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -164,7 +164,7 @@ add_executable(wibo dll/version.cpp src/access.cpp src/async_io.cpp - src/async_io_threadpool.cpp + src/async_io_epoll.cpp src/context.cpp src/errors.cpp src/files.cpp diff --git a/src/async_io.cpp b/src/async_io.cpp index 02b4e3b..6730175 100644 --- a/src/async_io.cpp +++ b/src/async_io.cpp @@ -49,7 +49,7 @@ static constexpr BackendEntry kBackends[] = { #if WIBO_ENABLE_LIBURING {"io_uring", detail::createIoUringBackend}, #endif - {"thread pool", detail::createThreadPoolBackend}, + {"epoll", detail::createEpollBackend}, }; AsyncIOBackend &asyncIO() { diff --git a/src/async_io.h b/src/async_io.h index 4c52fd1..715c8f2 100644 --- a/src/async_io.h +++ b/src/async_io.h @@ -24,7 +24,7 @@ namespace detail { #if WIBO_ENABLE_LIBURING std::unique_ptr createIoUringBackend(); #endif -std::unique_ptr createThreadPoolBackend(); +std::unique_ptr createEpollBackend(); } // namespace detail diff --git a/src/async_io_epoll.cpp b/src/async_io_epoll.cpp new file mode 100644 index 0000000..7ed9493 --- /dev/null +++ b/src/async_io_epoll.cpp @@ -0,0 +1,763 @@ +#include "async_io.h" + +#include "errors.h" +#include "files.h" +#include "kernel32/overlapped_util.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace { + +constexpr NTSTATUS kStatusCancelled = static_cast(0xC0000120); + +struct AsyncRequest { + enum class Kind { Read, Write }; + + Kind kind; + Pin file; + OVERLAPPED *overlapped = nullptr; + std::optional offset; + bool isPipe = false; + bool updateFilePointer = false; + void *readBuffer = nullptr; + const uint8_t *writeBuffer = nullptr; + size_t length = 0; + size_t progress = 0; +}; + +struct FileState { + explicit FileState(int fdIn) : fd(fdIn) {} + + int fd; + bool registered = false; + uint32_t events = 
0; + int originalFlags = -1; + std::deque> readQueue; + std::deque> writeQueue; +}; + +struct ProcessResult { + bool completed = false; + bool requeue = false; + NTSTATUS status = STATUS_SUCCESS; + size_t bytesTransferred = 0; +}; + +struct Completion { + std::unique_ptr req; + NTSTATUS status = STATUS_SUCCESS; + size_t bytesTransferred = 0; +}; + +class EpollBackend : public wibo::AsyncIOBackend { + public: + ~EpollBackend() override { shutdown(); } + + bool init() override; + void shutdown() override; + [[nodiscard]] bool running() const noexcept override { return mRunning.load(std::memory_order_acquire); } + + bool queueRead(Pin file, OVERLAPPED *ov, void *buffer, DWORD length, + const std::optional &offset, bool isPipe) override; + bool queueWrite(Pin file, OVERLAPPED *ov, const void *buffer, DWORD length, + const std::optional &offset, bool isPipe) override; + + private: + bool enqueueRequest(std::unique_ptr req); + bool enqueueFileRequest(std::unique_ptr req); + void workerLoop(); + void fileWorkerLoop(); + void handleFileEvents(FileState &state, uint32_t events); + void notifyWorker() const; + void drainEventFd() const; + void updateRegistrationLocked(FileState &state) const; + static void ensureNonBlocking(FileState &state); + static void restoreOriginalFlags(FileState &state); + void processCompletions(); + void failAllPending(); + void completeRequest(const AsyncRequest &req, NTSTATUS status, size_t bytesTransferred); + static Completion processBlockingRequest(AsyncRequest &req); + + static ProcessResult tryProcessPipeRead(AsyncRequest &req); + static ProcessResult tryProcessPipeWrite(AsyncRequest &req); + + std::atomic mRunning{false}; + std::atomic mPending{0}; + int mEpollFd = -1; + int mEventFd = -1; + std::thread mThread; + + std::mutex mMutex; + std::unordered_map> mFileStates; + + std::mutex mFileQueueMutex; + std::condition_variable mFileQueueCv; + std::deque> mFileQueue; + bool mFileStopping = false; + std::vector mFileWorkers; + + std::mutex 
mCompletionMutex; + std::deque mCompletions; +}; + +bool EpollBackend::init() { + if (mRunning.load(std::memory_order_acquire)) { + return true; + } + + mEpollFd = epoll_create1(EPOLL_CLOEXEC); + if (mEpollFd < 0) { + DEBUG_LOG("AsyncIO(epoll): epoll_create1 failed: %d\n", errno); + return false; + } + + mEventFd = eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK); + if (mEventFd < 0) { + DEBUG_LOG("AsyncIO(epoll): eventfd failed: %d\n", errno); + close(mEpollFd); + mEpollFd = -1; + return false; + } + + struct epoll_event event{}; + event.events = EPOLLIN; + event.data.fd = mEventFd; + if (epoll_ctl(mEpollFd, EPOLL_CTL_ADD, mEventFd, &event) != 0) { + DEBUG_LOG("AsyncIO(epoll): epoll_ctl add eventfd failed: %d\n", errno); + close(mEventFd); + close(mEpollFd); + mEventFd = -1; + mEpollFd = -1; + return false; + } + + unsigned int workerCount = std::thread::hardware_concurrency(); + if (workerCount == 0) { + workerCount = 1; + } + workerCount = std::min(workerCount, 4u); + + { + std::lock_guard lk(mFileQueueMutex); + mFileStopping = false; + } + mFileWorkers.reserve(workerCount); + for (unsigned int i = 0; i < workerCount; ++i) { + mFileWorkers.emplace_back(&EpollBackend::fileWorkerLoop, this); + } + + mRunning.store(true, std::memory_order_release); + mThread = std::thread(&EpollBackend::workerLoop, this); + DEBUG_LOG("AsyncIO: epoll backend initialized\n"); + return true; +} + +void EpollBackend::shutdown() { + if (!mRunning.exchange(false, std::memory_order_acq_rel)) { + return; + } + + { + std::lock_guard lk(mFileQueueMutex); + mFileStopping = true; + } + mFileQueueCv.notify_all(); + notifyWorker(); + + if (mThread.joinable()) { + mThread.join(); + } + + for (auto &worker : mFileWorkers) { + if (worker.joinable()) { + worker.join(); + } + } + mFileWorkers.clear(); + + if (mEventFd >= 0) { + close(mEventFd); + mEventFd = -1; + } + if (mEpollFd >= 0) { + close(mEpollFd); + mEpollFd = -1; + } + + { + std::lock_guard lk(mMutex); + for (auto &entry : mFileStates) { + 
restoreOriginalFlags(*entry.second); + } + mFileStates.clear(); + } + { + std::lock_guard lk(mFileQueueMutex); + mFileQueue.clear(); + } + { + std::lock_guard lk(mCompletionMutex); + mCompletions.clear(); + } + mPending.store(0, std::memory_order_release); +} + +bool EpollBackend::queueRead(Pin file, OVERLAPPED *ov, void *buffer, DWORD length, + const std::optional &offset, bool isPipe) { + auto req = std::make_unique(AsyncRequest::Kind::Read); + req->file = std::move(file); + req->overlapped = ov; + req->offset = offset; + req->isPipe = isPipe; + req->updateFilePointer = req->file ? !req->file->overlapped : true; + req->readBuffer = buffer; + req->length = length; + return enqueueRequest(std::move(req)); +} + +bool EpollBackend::queueWrite(Pin file, OVERLAPPED *ov, const void *buffer, DWORD length, + const std::optional &offset, bool isPipe) { + auto req = std::make_unique(AsyncRequest::Kind::Write); + req->file = std::move(file); + req->overlapped = ov; + req->offset = offset; + req->isPipe = isPipe; + req->updateFilePointer = req->file ? 
!req->file->overlapped : true; + req->writeBuffer = static_cast(buffer); + req->length = length; + return enqueueRequest(std::move(req)); +} + +bool EpollBackend::enqueueRequest(std::unique_ptr req) { + if (!req || !req->file || !req->file->valid()) { + return false; + } + if (!mRunning.load(std::memory_order_acquire)) { + return false; + } + + if (req->isPipe) { + std::lock_guard lk(mMutex); + if (!mRunning.load(std::memory_order_acquire)) { + return false; + } + mPending.fetch_add(1, std::memory_order_acq_rel); + const int fd = req->file->fd; + auto &statePtr = mFileStates[fd]; + if (!statePtr) { + statePtr = std::make_unique(fd); + } + FileState &state = *statePtr; + ensureNonBlocking(state); + if (req->kind == AsyncRequest::Kind::Read) { + state.readQueue.emplace_back(std::move(req)); + } else { + state.writeQueue.emplace_back(std::move(req)); + } + updateRegistrationLocked(state); + notifyWorker(); + return true; + } + + mPending.fetch_add(1, std::memory_order_acq_rel); + if (enqueueFileRequest(std::move(req))) { + return true; + } + mPending.fetch_sub(1, std::memory_order_acq_rel); + return false; +} + +bool EpollBackend::enqueueFileRequest(std::unique_ptr req) { + std::lock_guard lk(mFileQueueMutex); + if (mFileStopping) { + return false; + } + mFileQueue.emplace_back(std::move(req)); + mFileQueueCv.notify_one(); + return true; +} + +void EpollBackend::workerLoop() { + std::array events{}; + + while (true) { + processCompletions(); + + if (!mRunning.load(std::memory_order_acquire) && mPending.load(std::memory_order_acquire) == 0) { + break; + } + + int timeout = mRunning.load(std::memory_order_acquire) ? 
-1 : 10; + int count = epoll_wait(mEpollFd, events.data(), static_cast(events.size()), timeout); + if (count < 0) { + if (errno == EINTR) { + continue; + } + DEBUG_LOG("AsyncIO(epoll): epoll_wait failed: %d\n", errno); + continue; + } + if (count == 0) { + continue; + } + + for (int i = 0; i < count; ++i) { + auto &ev = events[static_cast(i)]; + if (ev.data.fd == mEventFd) { + drainEventFd(); + processCompletions(); + continue; + } + if (auto *state = static_cast(ev.data.ptr)) { + handleFileEvents(*state, ev.events); + } + } + } + + processCompletions(); + failAllPending(); +} + +void EpollBackend::fileWorkerLoop() { + while (true) { + std::unique_ptr req; + { + std::unique_lock lk(mFileQueueMutex); + mFileQueueCv.wait(lk, [&] { return mFileStopping || !mFileQueue.empty(); }); + if (mFileStopping && mFileQueue.empty()) { + break; + } + req = std::move(mFileQueue.front()); + mFileQueue.pop_front(); + } + + if (!req) { + continue; + } + + Completion completion = processBlockingRequest(*req); + completion.req = std::move(req); + { + std::lock_guard lk(mCompletionMutex); + mCompletions.emplace_back(std::move(completion)); + } + notifyWorker(); + } +} + +void EpollBackend::handleFileEvents(FileState &state, uint32_t events) { + const bool canRead = (events & (EPOLLIN | EPOLLERR | EPOLLHUP)) != 0; + const bool canWrite = (events & (EPOLLOUT | EPOLLERR | EPOLLHUP)) != 0; + + if (canRead) { + while (true) { + std::unique_ptr req; + { + std::lock_guard lk(mMutex); + if (state.readQueue.empty()) { + break; + } + req = std::move(state.readQueue.front()); + state.readQueue.pop_front(); + } + + auto result = tryProcessPipeRead(*req); + if (result.requeue) { + std::lock_guard lk(mMutex); + state.readQueue.emplace_front(std::move(req)); + updateRegistrationLocked(state); + break; + } + if (result.completed) { + completeRequest(*req, result.status, result.bytesTransferred); + } + { + std::lock_guard lk(mMutex); + updateRegistrationLocked(state); + } + } + } + + if (canWrite) { + 
while (true) { + std::unique_ptr req; + { + std::lock_guard lk(mMutex); + if (state.writeQueue.empty()) { + break; + } + req = std::move(state.writeQueue.front()); + state.writeQueue.pop_front(); + } + + auto result = tryProcessPipeWrite(*req); + if (result.requeue) { + std::lock_guard lk(mMutex); + state.writeQueue.emplace_front(std::move(req)); + updateRegistrationLocked(state); + break; + } + if (result.completed) { + completeRequest(*req, result.status, result.bytesTransferred); + } + { + std::lock_guard lk(mMutex); + updateRegistrationLocked(state); + } + } + } + + const int fd = state.fd; + { + std::lock_guard lk(mMutex); + auto it = mFileStates.find(fd); + if (it != mFileStates.end() && it->second.get() == &state) { + FileState *ptr = it->second.get(); + if (!ptr->registered && ptr->readQueue.empty() && ptr->writeQueue.empty()) { + restoreOriginalFlags(*ptr); + mFileStates.erase(it); + } + } + } +} + +void EpollBackend::notifyWorker() const { + if (mEventFd < 0) { + return; + } + uint64_t value = 1; + ssize_t rc; + do { + rc = write(mEventFd, &value, sizeof(value)); + } while (rc == -1 && errno == EINTR); +} + +void EpollBackend::drainEventFd() const { + uint64_t value; + while (true) { + ssize_t rc = read(mEventFd, &value, sizeof(value)); + if (rc == -1) { + if (errno == EINTR) { + continue; + } + if (errno == EAGAIN) { + break; + } + } + if (rc == sizeof(value)) { + continue; + } + break; + } +} + +void EpollBackend::updateRegistrationLocked(FileState &state) const { + uint32_t desired = 0; + if (!state.readQueue.empty()) { + desired |= EPOLLIN; + } + if (!state.writeQueue.empty()) { + desired |= EPOLLOUT; + } + + if (desired == state.events && state.registered) { + return; + } + + if (desired == 0) { + if (state.registered) { + if (epoll_ctl(mEpollFd, EPOLL_CTL_DEL, state.fd, nullptr) != 0) { + DEBUG_LOG("AsyncIO(epoll): epoll_ctl del fd %d failed: %d\n", state.fd, errno); + } + state.registered = false; + state.events = 0; + } + 
restoreOriginalFlags(state); + return; + } + + struct epoll_event ev{}; + ev.events = desired; + ev.data.ptr = &state; + int op = state.registered ? EPOLL_CTL_MOD : EPOLL_CTL_ADD; + if (epoll_ctl(mEpollFd, op, state.fd, &ev) != 0) { + DEBUG_LOG("AsyncIO(epoll): epoll_ctl op=%d fd=%d failed: %d\n", op, state.fd, errno); + return; + } + state.registered = true; + state.events = desired; +} + +void EpollBackend::ensureNonBlocking(FileState &state) { + if (state.originalFlags >= 0) { + return; + } + + int flags = fcntl(state.fd, F_GETFL, 0); + if (flags < 0) { + DEBUG_LOG("AsyncIO(epoll): fcntl(F_GETFL) failed for fd %d: %d\n", state.fd, errno); + return; + } + + if ((flags & O_NONBLOCK) != 0) { + return; + } + + state.originalFlags = flags; + if (fcntl(state.fd, F_SETFL, flags | O_NONBLOCK) != 0) { /* F_SETFL cmd was missing */ + DEBUG_LOG("AsyncIO(epoll): fcntl(F_SETFL) failed for fd %d: %d\n", state.fd, errno); + state.originalFlags = -1; + } +} + +void EpollBackend::restoreOriginalFlags(FileState &state) { + if (state.originalFlags < 0) { + return; + } + + if (fcntl(state.fd, F_SETFL, state.originalFlags) != 0) { + DEBUG_LOG("AsyncIO(epoll): restoring flags for fd %d failed: %d\n", state.fd, errno); + } + + state.originalFlags = -1; +} + +void EpollBackend::processCompletions() { + std::deque pending; + { + std::lock_guard lk(mCompletionMutex); + if (mCompletions.empty()) { + return; + } + pending.swap(mCompletions); + } + + for (auto &entry : pending) { + if (entry.req) { + completeRequest(*entry.req, entry.status, entry.bytesTransferred); + } + } +} + +void EpollBackend::failAllPending() { + processCompletions(); + + std::vector> pending; + { + std::lock_guard lk(mMutex); + for (auto &entry : mFileStates) { + auto &state = *entry.second; + while (!state.readQueue.empty()) { + pending.emplace_back(std::move(state.readQueue.front())); + state.readQueue.pop_front(); + } + while (!state.writeQueue.empty()) { + pending.emplace_back(std::move(state.writeQueue.front())); + state.writeQueue.pop_front(); + } + 
state.registered = false; + state.events = 0; + restoreOriginalFlags(state); + } + } + + { + std::lock_guard lk(mFileQueueMutex); + while (!mFileQueue.empty()) { + pending.emplace_back(std::move(mFileQueue.front())); + mFileQueue.pop_front(); + } + } + + std::deque completions; + { + std::lock_guard lk(mCompletionMutex); + completions.swap(mCompletions); + } + for (auto &entry : completions) { + if (entry.req) { + completeRequest(*entry.req, entry.status, entry.bytesTransferred); + } + } + + for (auto &req : pending) { + if (req) { + completeRequest(*req, kStatusCancelled, 0); + } + } +} + +void EpollBackend::completeRequest(const AsyncRequest &req, NTSTATUS status, size_t bytesTransferred) { + kernel32::detail::signalOverlappedEvent(req.file.get(), req.overlapped, status, bytesTransferred); + mPending.fetch_sub(1, std::memory_order_acq_rel); +} + +Completion EpollBackend::processBlockingRequest(AsyncRequest &req) { + Completion result{}; + if (!req.file || !req.file->valid()) { + result.status = STATUS_INVALID_HANDLE; + return result; + } + + files::IOResult io{}; + if (req.kind == AsyncRequest::Kind::Read) { + io = files::read(req.file.get(), req.readBuffer, req.length, req.offset, req.updateFilePointer); + } else { + io = files::write(req.file.get(), req.writeBuffer, req.length, req.offset, req.updateFilePointer); + } + + result.bytesTransferred = io.bytesTransferred; + + if (io.unixError != 0) { + result.status = wibo::statusFromErrno(io.unixError); + if (result.status == STATUS_SUCCESS) { + result.status = STATUS_UNEXPECTED_IO_ERROR; + } + } else if (req.kind == AsyncRequest::Kind::Read && io.bytesTransferred == 0 && io.reachedEnd) { + result.status = req.isPipe ? 
STATUS_PIPE_BROKEN : STATUS_END_OF_FILE; + } else if (req.kind == AsyncRequest::Kind::Write && io.bytesTransferred == 0 && io.reachedEnd) { + result.status = STATUS_END_OF_FILE; + } else { + result.status = STATUS_SUCCESS; + } + + return result; +} + +ProcessResult EpollBackend::tryProcessPipeRead(AsyncRequest &req) { + ProcessResult result{}; + if (!req.file || !req.file->valid()) { + result.completed = true; + result.status = STATUS_INVALID_HANDLE; + return result; + } + const int fd = req.file->fd; + if (req.length == 0) { + result.completed = true; + result.status = STATUS_SUCCESS; + return result; + } + + uint8_t *buffer = static_cast(req.readBuffer); + size_t toRead = req.length; + while (true) { + size_t chunk = std::min(toRead, static_cast(SSIZE_MAX)); + ssize_t rc = ::read(fd, buffer, chunk); + if (rc > 0) { + result.completed = true; + result.status = STATUS_SUCCESS; + result.bytesTransferred = static_cast(rc); + return result; + } + if (rc == 0) { + result.completed = true; + result.status = req.isPipe ? STATUS_PIPE_BROKEN : STATUS_END_OF_FILE; + result.bytesTransferred = 0; + return result; + } + if (rc == -1) { + if (errno == EINTR) { + continue; + } + if (errno == EAGAIN || errno == EWOULDBLOCK) { + result.requeue = true; + return result; + } + int err = errno ? errno : EIO; + result.completed = true; + if (err == EPIPE || err == ECONNRESET) { + result.status = STATUS_PIPE_BROKEN; + } else { + result.status = wibo::statusFromErrno(err); + if (result.status == STATUS_SUCCESS) { + result.status = STATUS_UNEXPECTED_IO_ERROR; + } + } + result.bytesTransferred = 0; + return result; + } + } +} + +ProcessResult EpollBackend::tryProcessPipeWrite(AsyncRequest &req) { + ProcessResult result{}; + if (!req.file || !req.file->valid()) { + result.completed = true; + result.status = STATUS_INVALID_HANDLE; + return result; + } + + const int fd = req.file->fd; + size_t remaining = req.length - req.progress; + const uint8_t *buffer = req.writeBuffer ? 
req.writeBuffer + req.progress : nullptr; + + while (remaining > 0) { + size_t chunk = std::min(remaining, static_cast(SSIZE_MAX)); + ssize_t rc = ::write(fd, buffer, chunk); + if (rc > 0) { + size_t written = static_cast(rc); + req.progress += written; + remaining -= written; + buffer += written; + if (req.offset.has_value()) { + *req.offset += static_cast(written); + } + continue; + } + if (rc == 0) { + break; + } + if (errno == EINTR) { + continue; + } + if (errno == EAGAIN || errno == EWOULDBLOCK) { + result.requeue = true; + return result; + } + int err = errno ? errno : EIO; + result.completed = true; + if (err == EPIPE || err == ECONNRESET) { + result.status = STATUS_PIPE_BROKEN; + } else { + result.status = wibo::statusFromErrno(err); + if (result.status == STATUS_SUCCESS) { + result.status = STATUS_UNEXPECTED_IO_ERROR; + } + } + result.bytesTransferred = req.progress; + return result; + } + + if (remaining == 0) { + result.completed = true; + result.status = STATUS_SUCCESS; + result.bytesTransferred = req.progress; + } else { + result.requeue = true; + } + return result; +} + +} // namespace + +namespace wibo::detail { + +std::unique_ptr createEpollBackend() { return std::make_unique(); } + +} // namespace wibo::detail diff --git a/src/async_io_threadpool.cpp b/src/async_io_threadpool.cpp deleted file mode 100644 index 44382ca..0000000 --- a/src/async_io_threadpool.cpp +++ /dev/null @@ -1,214 +0,0 @@ -#include "async_io.h" - -#include "errors.h" -#include "files.h" -#include "kernel32/overlapped_util.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace { - -struct AsyncRequest { - enum class Kind { Read, Write }; - - Kind kind; - Pin file; - OVERLAPPED *overlapped = nullptr; - void *buffer = nullptr; - DWORD length = 0; - std::optional offset; - bool isPipe = false; - bool updateFilePointer = false; - - explicit AsyncRequest(Kind k) : kind(k) {} -}; - -class ThreadPoolBackend : public wibo::AsyncIOBackend { - 
public: - ~ThreadPoolBackend() override { shutdown(); } - - bool init() override; - void shutdown() override; - [[nodiscard]] bool running() const noexcept override { return mActive.load(std::memory_order_acquire); } - - bool queueRead(Pin file, OVERLAPPED *ov, void *buffer, DWORD length, - const std::optional &offset, bool isPipe) override; - bool queueWrite(Pin file, OVERLAPPED *ov, const void *buffer, DWORD length, - const std::optional &offset, bool isPipe) override; - - private: - bool enqueueRequest(std::unique_ptr req); - void workerLoop(); - static void processRequest(const AsyncRequest &req); - - std::atomic mActive{false}; - std::mutex mQueueMutex; - std::condition_variable mQueueCv; - std::deque> mQueue; - std::vector mWorkers; - std::atomic mPending{0}; - bool mStopping = false; // guarded by mQueueMutex -}; - -bool ThreadPoolBackend::init() { - bool expected = false; - if (!mActive.compare_exchange_strong(expected, true, std::memory_order_acq_rel)) { - return true; - } - - unsigned int threadCount = std::thread::hardware_concurrency(); - if (threadCount == 0) { - threadCount = 1; - } - threadCount = std::min(threadCount, 4u); // cap to avoid oversubscription - - { - std::lock_guard lk(mQueueMutex); - mStopping = false; - } - mWorkers.reserve(threadCount); - for (unsigned int i = 0; i < threadCount; ++i) { - mWorkers.emplace_back(&ThreadPoolBackend::workerLoop, this); - } - DEBUG_LOG("thread pool backend initialized (workers=%u)\n", threadCount); - return true; -} - -void ThreadPoolBackend::shutdown() { - if (!mActive.exchange(false, std::memory_order_acq_rel)) { - return; - } - - { - std::lock_guard lk(mQueueMutex); - mStopping = true; - } - mQueueCv.notify_all(); - - for (auto &worker : mWorkers) { - if (worker.joinable()) { - worker.join(); - } - } - mWorkers.clear(); - - { - std::lock_guard lk(mQueueMutex); - mQueue.clear(); - mStopping = false; - } - mPending.store(0, std::memory_order_release); - DEBUG_LOG("thread-pool async backend shut down\n"); 
-} - -bool ThreadPoolBackend::queueRead(Pin file, OVERLAPPED *ov, void *buffer, DWORD length, - const std::optional &offset, bool isPipe) { - auto req = std::make_unique(AsyncRequest::Kind::Read); - req->file = std::move(file); - req->overlapped = ov; - req->buffer = buffer; - req->length = length; - req->offset = offset; - req->isPipe = isPipe; - req->updateFilePointer = req->file ? !req->file->overlapped : true; - return enqueueRequest(std::move(req)); -} - -bool ThreadPoolBackend::queueWrite(Pin file, OVERLAPPED *ov, const void *buffer, DWORD length, - const std::optional &offset, bool isPipe) { - auto req = std::make_unique(AsyncRequest::Kind::Write); - req->file = std::move(file); - req->overlapped = ov; - req->buffer = const_cast(buffer); - req->length = length; - req->offset = offset; - req->isPipe = isPipe; - req->updateFilePointer = req->file ? !req->file->overlapped : true; - return enqueueRequest(std::move(req)); -} - -bool ThreadPoolBackend::enqueueRequest(std::unique_ptr req) { - if (!running()) { - return false; - } - if (!req || !req->file) { - return false; - } - - { - std::lock_guard lk(mQueueMutex); - if (mStopping) { - return false; - } - mQueue.emplace_back(std::move(req)); - mPending.fetch_add(1, std::memory_order_acq_rel); - } - mQueueCv.notify_one(); - return true; -} - -void ThreadPoolBackend::workerLoop() { - while (true) { - std::unique_ptr req; - { - std::unique_lock lk(mQueueMutex); - mQueueCv.wait(lk, [&] { return mStopping || !mQueue.empty(); }); - if (mStopping && mQueue.empty()) { - break; - } - req = std::move(mQueue.front()); - mQueue.pop_front(); - } - - if (req) { - processRequest(*req); - } - mPending.fetch_sub(1, std::memory_order_acq_rel); - } -} - -void ThreadPoolBackend::processRequest(const AsyncRequest &req) { - if (!req.file || !req.file->valid()) { - kernel32::detail::signalOverlappedEvent(req.file.get(), req.overlapped, STATUS_INVALID_HANDLE, 0); - return; - } - - files::IOResult io{}; - if (req.kind == 
AsyncRequest::Kind::Read) { - io = files::read(req.file.get(), req.buffer, req.length, req.offset, req.updateFilePointer); - } else { - const void *ptr = req.buffer; - io = files::write(req.file.get(), ptr, req.length, req.offset, req.updateFilePointer); - } - - NTSTATUS completionStatus = STATUS_SUCCESS; - size_t bytesTransferred = io.bytesTransferred; - - if (io.unixError != 0) { - completionStatus = wibo::statusFromErrno(io.unixError); - if (completionStatus == STATUS_SUCCESS) { - completionStatus = STATUS_UNEXPECTED_IO_ERROR; - } - } else if (req.kind == AsyncRequest::Kind::Read && bytesTransferred == 0 && io.reachedEnd) { - completionStatus = req.isPipe ? STATUS_PIPE_BROKEN : STATUS_END_OF_FILE; - } else if (req.kind == AsyncRequest::Kind::Write && bytesTransferred == 0 && io.reachedEnd) { - completionStatus = STATUS_END_OF_FILE; - } - - kernel32::detail::signalOverlappedEvent(req.file.get(), req.overlapped, completionStatus, bytesTransferred); -} - -} // namespace - -namespace wibo::detail { - -std::unique_ptr createThreadPoolBackend() { return std::make_unique(); } - -} // namespace wibo::detail