use std::{
    collections::HashMap,
    fmt::{Display, Formatter},
    io,
    num::NonZeroUsize,
    sync::{Arc, Mutex},
    thread::JoinHandle,
    time::{Duration, Instant},
};

use bytes::{Bytes, BytesMut};
use crossbeam_channel::{Receiver, Sender};
use crossbeam_utils::sync::WaitGroup;
use lru::LruCache;
use polonius_the_crab::{polonius, polonius_return};
use simple_moving_average::{SMA, SingleSumSMA};
use tracing::{Level, debug, error, instrument, span};
use zerocopy::FromZeros;

use crate::{
    IoResultContext,
    common::PartitionInfo,
    disc::{
        DiscHeader, SECTOR_GROUP_SIZE, SECTOR_SIZE,
        hashes::{GroupHashes, hash_sector_group},
        wii::HASHES_SIZE,
    },
    io::{
        block::{Block, BlockKind, BlockReader},
        wia::WIAException,
    },
    read::PartitionEncryption,
    util::{
        aes::{decrypt_sector, encrypt_sector},
        array_ref_mut,
    },
};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SectorGroupRequest {
    pub group_idx: u32,
    pub partition_idx: Option<u8>,
    pub mode: PartitionEncryption,
    pub force_rehash: bool,
}

impl Display for SectorGroupRequest {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        match self.partition_idx {
            Some(idx) => write!(f, "Partition {} group {}", idx, self.group_idx),
            None => write!(f, "Group {}", self.group_idx),
        }
    }
}

#[derive(Clone)]
pub struct SectorGroup {
    pub request: SectorGroupRequest,
    pub start_sector: u32,
    pub data: Bytes,
    pub sector_bitmap: u64,
    pub io_duration: Option<Duration>,
    #[allow(unused)] // TODO WIA hash exceptions
    pub group_hashes: Option<Arc<GroupHashes>>,
}

impl SectorGroup {
    /// Calculate the number of consecutive sectors starting from `start`.
    #[inline]
    pub fn consecutive_sectors(&self, start: u32) -> u32 {
        (self.sector_bitmap >> start).trailing_ones()
    }
}

pub type SectorGroupResult = io::Result<SectorGroup>;

#[allow(unused)]
pub struct Preloader {
    request_tx: Sender<SectorGroupRequest>,
    request_rx: Receiver<SectorGroupRequest>,
    stat_tx: Sender<PreloaderThreadStats>,
    stat_rx: Receiver<PreloaderThreadStats>,
    threads: Mutex<PreloaderThreads>,
    cache: Arc<Mutex<PreloaderCache>>,
    // Fallback single-threaded loader
    loader: Mutex<SectorGroupLoader>,
}

#[allow(unused)]
struct PreloaderThreads {
    join_handles: Vec<JoinHandle<()>>,
    last_adjust: Instant,
    num_samples: usize,
    wait_time_avg: SingleSumSMA<Duration, u32, 100>,
    req_time_avg: SingleSumSMA<Duration, u32, 100>,
    io_time_avg: SingleSumSMA<Duration, u32, 100>,
}

impl PreloaderThreads {
    fn new(join_handles: Vec<JoinHandle<()>>) -> Self {
        Self {
            join_handles,
            last_adjust: Instant::now(),
            num_samples: 0,
            wait_time_avg: SingleSumSMA::<_, _, 100>::from_zero(Duration::default()),
            req_time_avg: SingleSumSMA::<_, _, 100>::from_zero(Duration::default()),
            io_time_avg: SingleSumSMA::<_, _, 100>::from_zero(Duration::default()),
        }
    }

    fn push_stats(&mut self, stat: PreloaderThreadStats, _outer: &Preloader) {
        self.wait_time_avg.add_sample(stat.wait_time);
        self.req_time_avg.add_sample(stat.req_time);
        self.io_time_avg.add_sample(stat.io_time);
        self.num_samples += 1;
        if self.num_samples % 100 == 0 {
            let avg_wait = self.wait_time_avg.get_average();
            let avg_req = self.req_time_avg.get_average();
            let avg_io = self.io_time_avg.get_average();
            let utilization =
                avg_req.as_secs_f64() / (avg_req.as_secs_f64() + avg_wait.as_secs_f64());
            let io_time = avg_io.as_secs_f64() / avg_req.as_secs_f64();
            debug!(
                "Preloader stats: count {}, wait: {:?}, req: {:?}, util: {:.2}%, io: {:.2}%",
                self.num_samples,
                avg_wait,
                avg_req,
                utilization * 100.0,
                io_time * 100.0
            );
            // if self.last_adjust.elapsed() > Duration::from_secs(2) {
            //     if utilization > 0.9 && io_time < 0.1 {
            //         println!("Preloader is CPU-bound, increasing thread count");
            //         let id = self.join_handles.len();
            //         self.join_handles.push(preloader_thread(
            //             id,
            //             outer.request_rx.clone(),
            //             outer.cache.clone(),
            //             outer.loader.lock().unwrap().clone(),
            //             outer.stat_tx.clone(),
            //         ));
            //         self.last_adjust = Instant::now();
            //     } /*else if io_time > 0.9 {
            //         println!("Preloader is I/O-bound, decreasing thread count");
            //         if self.join_handles.len() > 1 {
            //             let handle = self.join_handles.pop().unwrap();
            //
            //         }
            //     }*/
            // }
        }
    }
}
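/// Shared cache state for the preloader. Completed groups live in a small LRU
/// cache, while requests currently being loaded are tracked in `inflight` so
/// that concurrent fetches for the same group can wait on a [`WaitGroup`]
/// instead of duplicating the load.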
struct PreloaderCache {
    inflight: HashMap<SectorGroupRequest, WaitGroup>,
    lru_cache: LruCache<SectorGroupRequest, SectorGroup>,
}

impl Default for PreloaderCache {
    fn default() -> Self {
        Self {
            inflight: Default::default(),
            lru_cache: LruCache::new(NonZeroUsize::new(64).unwrap()),
        }
    }
}

impl PreloaderCache {
    fn push(&mut self, request: SectorGroupRequest, group: SectorGroup) {
        self.lru_cache.push(request, group);
        self.inflight.remove(&request);
    }

    fn remove(&mut self, request: &SectorGroupRequest) {
        self.inflight.remove(request);
    }

    fn contains(&self, request: &SectorGroupRequest) -> bool {
        self.lru_cache.contains(request) || self.inflight.contains_key(request)
    }
}

#[allow(unused)]
struct PreloaderThreadStats {
    thread_id: usize,
    wait_time: Duration,
    req_time: Duration,
    io_time: Duration,
}

fn preloader_thread(
    thread_id: usize,
    request_rx: Receiver<SectorGroupRequest>,
    cache: Arc<Mutex<PreloaderCache>>,
    mut loader: SectorGroupLoader,
    stat_tx: Sender<PreloaderThreadStats>,
) -> JoinHandle<()> {
    std::thread::Builder::new()
        .name(format!("Preloader {thread_id}"))
        .spawn(move || {
            let mut last_request_end: Option<Instant> = None;
            while let Ok(request) = request_rx.recv() {
                let wait_time = if let Some(last_request) = last_request_end {
                    last_request.elapsed()
                } else {
                    Duration::default()
                };
                let start = Instant::now();
                let mut io_time = Duration::default();
                match loader.load(request) {
                    Ok(group) => {
                        let Ok(mut cache_guard) = cache.lock() else {
                            break;
                        };
                        io_time = group.io_duration.unwrap_or_default();
                        cache_guard.push(request, group);
                    }
                    Err(_) => {
                        let Ok(mut cache_guard) = cache.lock() else {
                            break;
                        };
                        // Just drop the request if it failed
                        cache_guard.remove(&request);
                    }
                }
                let end = Instant::now();
                last_request_end = Some(end);
                let req_time = end - start;
                if stat_tx
                    .send(PreloaderThreadStats { thread_id, wait_time, req_time, io_time })
                    .is_err()
                {
                    break;
                }
            }
        })
        .expect("Failed to spawn preloader thread")
}

impl Preloader {
    pub fn new(loader: SectorGroupLoader, num_threads: usize) -> Arc<Self> {
        debug!("Creating preloader with {} threads", num_threads);
        let (request_tx, request_rx) = crossbeam_channel::unbounded();
        let (stat_tx, stat_rx) = crossbeam_channel::unbounded();
        let cache = Arc::new(Mutex::new(PreloaderCache::default()));
        let mut join_handles = Vec::with_capacity(num_threads);
        for i in 0..num_threads {
            join_handles.push(preloader_thread(
                i,
                request_rx.clone(),
                cache.clone(),
                loader.clone(),
                stat_tx.clone(),
            ));
        }
        let threads = Mutex::new(PreloaderThreads::new(join_handles));
        let loader = Mutex::new(loader);
        Arc::new(Self { request_tx, request_rx, stat_tx, stat_rx, threads, cache, loader })
    }

    #[allow(unused)]
    pub fn shutdown(self) {
        let guard = self.threads.into_inner().unwrap();
        for handle in guard.join_handles {
            handle.join().unwrap();
        }
    }
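    /// Fetch a sector group, preloading subsequent groups on worker threads.
    ///
    /// One read-ahead request is queued per worker thread, starting at the
    /// requested group and capped at `max_groups`. If the requested group is
    /// already cached it is returned immediately; if it is in flight, this
    /// waits for the owning worker to finish. Otherwise (or if the worker
    /// failed), the group is loaded synchronously on the calling thread.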
    #[instrument(name = "Preloader::fetch", skip_all)]
    pub fn fetch(&self, request: SectorGroupRequest, max_groups: u32) -> SectorGroupResult {
        let num_threads = {
            let mut threads_guard = self.threads.lock().map_err(map_poisoned)?;
            while let Ok(stat) = self.stat_rx.try_recv() {
                threads_guard.push_stats(stat, self);
            }
            threads_guard.join_handles.len()
        };
        let mut cache_guard = self.cache.lock().map_err(map_poisoned)?;
        // Preload n groups ahead
        for i in 0..num_threads as u32 {
            let group_idx = request.group_idx + i;
            if group_idx >= max_groups {
                break;
            }
            let request = SectorGroupRequest { group_idx, ..request };
            if cache_guard.contains(&request) {
                continue;
            }
            if self.request_tx.send(request).is_ok() {
                cache_guard.inflight.insert(request, WaitGroup::new());
            }
        }
        if let Some(cached) = cache_guard.lru_cache.get(&request) {
            return Ok(cached.clone());
        }
        if let Some(wg) = cache_guard.inflight.get(&request) {
            // Wait for inflight request to finish
            let wg = wg.clone();
            drop(cache_guard);
            {
                let _span = span!(Level::TRACE, "wg.wait").entered();
                wg.wait();
            }
            let mut cache_guard = self.cache.lock().map_err(map_poisoned)?;
            if let Some(cached) = cache_guard.lru_cache.get(&request) {
                return Ok(cached.clone());
            }
        } else {
            drop(cache_guard);
        }
        // No threads are running, fallback to single-threaded loader
        let result = {
            let mut loader = self.loader.lock().map_err(map_poisoned)?;
            loader.load(request)
        };
        match result {
            Ok(group) => {
                let mut cache_guard = self.cache.lock().map_err(map_poisoned)?;
                cache_guard.push(request, group.clone());
                Ok(group)
            }
            Err(e) => Err(e),
        }
    }
}

#[inline]
fn map_poisoned<T>(_: std::sync::PoisonError<T>) -> io::Error {
    io::Error::other("Mutex poisoned")
}

pub struct SectorGroupLoader {
    io: Box<dyn BlockReader>,
    disc_header: Arc<DiscHeader>,
    partitions: Arc<[PartitionInfo]>,
    block: Block,
    block_buf: Box<[u8]>,
}

impl Clone for SectorGroupLoader {
    fn clone(&self) -> Self {
        let block_size = self.io.block_size() as usize;
        Self {
            io: self.io.clone(),
            disc_header: self.disc_header.clone(),
            partitions: self.partitions.clone(),
            block: Block::default(),
            block_buf: <[u8]>::new_box_zeroed_with_elems(block_size).unwrap(),
        }
    }
}

#[derive(Default)]
struct LoadedSectorGroup {
    /// Start sector of the group
    start_sector: u32,
    /// Bitmap of sectors that were read
    sector_bitmap: u64,
    /// Total duration of I/O operations
    io_duration: Option<Duration>,
    /// Calculated sector group hashes
    group_hashes: Option<Arc<GroupHashes>>,
}

impl SectorGroupLoader {
    pub fn new(
        io: Box<dyn BlockReader>,
        disc_header: Arc<DiscHeader>,
        partitions: Arc<[PartitionInfo]>,
    ) -> Self {
        let block_buf = <[u8]>::new_box_zeroed_with_elems(io.block_size() as usize).unwrap();
        Self { io, disc_header, partitions, block: Block::default(), block_buf }
    }

    #[instrument(name = "SectorGroupLoader::load", skip_all)]
    pub fn load(&mut self, request: SectorGroupRequest) -> SectorGroupResult {
        let mut sector_group_buf = BytesMut::zeroed(SECTOR_GROUP_SIZE);
        let out = array_ref_mut![sector_group_buf, 0, SECTOR_GROUP_SIZE];
        let LoadedSectorGroup { start_sector, sector_bitmap, io_duration, group_hashes } =
            if request.partition_idx.is_some() {
                self.load_partition_group(request, out)?
            } else {
                self.load_raw_group(request, out)?
            };
        Ok(SectorGroup {
            request,
            start_sector,
            data: sector_group_buf.freeze(),
            sector_bitmap,
            io_duration,
            group_hashes,
        })
    }

    /// Load a sector group from a partition.
    ///
    /// This will handle encryption, decryption, and hash recovery as needed.
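    ///
    /// Per-sector state is tracked in three bitmaps: `sector_bitmap` (sectors
    /// successfully read), `decrypted_sectors` (sectors currently in plaintext
    /// form), and `hash_recovery_sectors` (sectors that came back without hash
    /// data for a partition that expects hashes, so their hash tables must be
    /// recomputed).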
    fn load_partition_group(
        &mut self,
        request: SectorGroupRequest,
        sector_group_buf: &mut [u8; SECTOR_GROUP_SIZE],
    ) -> io::Result<LoadedSectorGroup> {
        let Some(partition) =
            request.partition_idx.and_then(|idx| self.partitions.get(idx as usize))
        else {
            return Ok(LoadedSectorGroup::default());
        };
        let abs_group_sector = partition.data_start_sector + request.group_idx * 64;
        if abs_group_sector >= partition.data_end_sector {
            return Ok(LoadedSectorGroup::default());
        }

        // Bitmap of sectors that were read
        let mut sector_bitmap = 0u64;
        // Bitmap of sectors that are decrypted
        let mut decrypted_sectors = 0u64;
        // Bitmap of sectors that need hash recovery
        let mut hash_recovery_sectors = 0u64;
        // Hash exceptions
        let mut hash_exceptions = Vec::<WIAException>::new();
        // Total duration of I/O operations
        let mut io_duration = None;
        // Calculated sector group hashes
        let mut group_hashes = None;

        // Read sector group
        for sector in 0..64 {
            let sector_data =
                array_ref_mut![sector_group_buf, sector as usize * SECTOR_SIZE, SECTOR_SIZE];
            let abs_sector = abs_group_sector + sector;
            if abs_sector >= partition.data_end_sector {
                // Already zeroed
                decrypted_sectors |= 1 << sector;
                hash_recovery_sectors |= 1 << sector;
                continue;
            }

            // Read new block
            if !self.block.contains(abs_sector) {
                self.block = self
                    .io
                    .read_block(self.block_buf.as_mut(), abs_sector)
                    .io_with_context(|| format!("Reading block for sector {abs_sector}"))?;
                if let Some(duration) = self.block.io_duration {
                    *io_duration.get_or_insert_with(Duration::default) += duration;
                }
                if self.block.kind == BlockKind::None {
                    error!("Failed to read block for sector {}", abs_sector);
                    break;
                }
            }

            // Add hash exceptions
            self.block
                .append_hash_exceptions(abs_sector, sector, &mut hash_exceptions)
                .io_with_context(|| format!("Appending hash exceptions for sector {abs_sector}"))?;

            // Read new sector into buffer
            let (encrypted, has_hashes) = self
                .block
                .copy_sector(
                    sector_data,
                    self.block_buf.as_mut(),
                    abs_sector,
                    partition.disc_header(),
                    Some(partition),
                )
                .io_with_context(|| format!("Copying sector {abs_sector} from block"))?;
            if !encrypted {
                decrypted_sectors |= 1 << sector;
            }
            if !has_hashes && partition.has_hashes {
                hash_recovery_sectors |= 1 << sector;
            }
            sector_bitmap |= 1 << sector;
        }

        // Recover hashes
        if request.force_rehash
            || (request.mode != PartitionEncryption::ForceDecryptedNoHashes
                && hash_recovery_sectors != 0)
        {
            // Decrypt any encrypted sectors
            if decrypted_sectors != u64::MAX {
                for sector in 0..64 {
                    let sector_data =
                        array_ref_mut![sector_group_buf, sector * SECTOR_SIZE, SECTOR_SIZE];
                    if (decrypted_sectors >> sector) & 1 == 0 {
                        decrypt_sector(sector_data, &partition.key);
                    }
                }
                decrypted_sectors = u64::MAX;
            }

            // Recover hashes
            let hashes = hash_sector_group(sector_group_buf, request.force_rehash);

            // Apply hashes
            for sector in 0..64 {
                let sector_data =
                    array_ref_mut![sector_group_buf, sector * SECTOR_SIZE, SECTOR_SIZE];
                if (hash_recovery_sectors >> sector) & 1 == 1 {
                    hashes.apply(sector_data, sector);
                }
            }

            // Persist hashes
            group_hashes = Some(Arc::from(hashes));
        }

        // Apply hash exceptions
        if request.mode != PartitionEncryption::ForceDecryptedNoHashes
            && !hash_exceptions.is_empty()
        {
            for exception in hash_exceptions {
                let offset = exception.offset.get();
                let sector = offset / HASHES_SIZE as u16;

                // Decrypt sector if needed
                let sector_data =
                    array_ref_mut![sector_group_buf, sector as usize * SECTOR_SIZE, SECTOR_SIZE];
                if (decrypted_sectors >> sector) & 1 == 0 {
                    decrypt_sector(sector_data, &partition.key);
                    decrypted_sectors |= 1 << sector;
                }

                let sector_offset = (offset - (sector * HASHES_SIZE as u16)) as usize;
                *array_ref_mut![sector_data, sector_offset, 20] = exception.hash;
            }
        }
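        // The requested mode decides the final encryption state: `Original`
        // keeps the partition's native state, `ForceEncrypted` always
        // encrypts, and both decrypted modes leave plaintext. Sectors already
        // in the target state are skipped via `decrypted_sectors`.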
        // Encrypt/decrypt sectors
        if match request.mode {
            PartitionEncryption::Original => partition.has_encryption,
            PartitionEncryption::ForceEncrypted => true,
            PartitionEncryption::ForceDecrypted | PartitionEncryption::ForceDecryptedNoHashes => {
                false
            }
        } {
            // Encrypt any decrypted sectors
            if decrypted_sectors != 0 {
                for sector in 0..64 {
                    let sector_data = array_ref_mut![
                        sector_group_buf,
                        sector as usize * SECTOR_SIZE,
                        SECTOR_SIZE
                    ];
                    if (decrypted_sectors >> sector) & 1 == 1 {
                        encrypt_sector(sector_data, &partition.key);
                    }
                }
            }
        } else if decrypted_sectors != u64::MAX {
            // Decrypt any encrypted sectors
            for sector in 0..64 {
                let sector_data =
                    array_ref_mut![sector_group_buf, sector as usize * SECTOR_SIZE, SECTOR_SIZE];
                if (decrypted_sectors >> sector) & 1 == 0 {
                    decrypt_sector(sector_data, &partition.key);
                }
            }
        }

        Ok(LoadedSectorGroup {
            start_sector: abs_group_sector,
            sector_bitmap,
            io_duration,
            group_hashes,
        })
    }

    /// Loads a non-partition sector group.
    fn load_raw_group(
        &mut self,
        request: SectorGroupRequest,
        sector_group_buf: &mut [u8; SECTOR_GROUP_SIZE],
    ) -> io::Result<LoadedSectorGroup> {
        let abs_group_sector = request.group_idx * 64;

        // Bitmap of sectors that were read
        let mut sector_bitmap = 0u64;
        // Total duration of I/O operations
        let mut io_duration = None;

        for sector in 0..64 {
            let sector_data =
                array_ref_mut![sector_group_buf, sector as usize * SECTOR_SIZE, SECTOR_SIZE];
            let abs_sector = abs_group_sector + sector;
            if self.partitions.iter().any(|p| p.data_contains_sector(abs_sector)) {
                continue;
            }

            // Read new block
            if !self.block.contains(abs_sector) {
                self.block = self
                    .io
                    .read_block(self.block_buf.as_mut(), abs_sector)
                    .io_with_context(|| format!("Reading block for sector {abs_sector}"))?;
                if let Some(duration) = self.block.io_duration {
                    *io_duration.get_or_insert_with(Duration::default) += duration;
                }
                if self.block.kind == BlockKind::None {
                    break;
                }
            }

            // Read new sector into buffer
            self.block
                .copy_sector(
                    sector_data,
                    self.block_buf.as_mut(),
                    abs_sector,
                    self.disc_header.as_ref(),
                    None,
                )
                .io_with_context(|| format!("Copying sector {abs_sector} from block"))?;
            sector_bitmap |= 1 << sector;
        }

        Ok(LoadedSectorGroup {
            start_sector: abs_group_sector,
            sector_bitmap,
            io_duration,
            group_hashes: None,
        })
    }
}

/// Fetch a sector group from the cache or from the preloader.
/// Returns a boolean indicating if the group was updated.
pub fn fetch_sector_group<'a>(
    request: SectorGroupRequest,
    max_groups: u32,
    mut cached: &'a mut Option<SectorGroup>,
    preloader: &Preloader,
) -> io::Result<(&'a SectorGroup, bool)> {
    polonius!(|cached| -> io::Result<(&'polonius SectorGroup, bool)> {
        if let Some(sector_group) = cached {
            if sector_group.request == request {
                polonius_return!(Ok((sector_group, false)));
            }
        }
    });
    let sector_group = preloader.fetch(request, max_groups)?;
    Ok((cached.insert(sector_group), true))
}
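
// A minimal sanity check for `SectorGroup::consecutive_sectors`, added as an
// illustration of the bitmap math; the `SectorGroup` field values used here
// are placeholders, not a real loaded group.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn consecutive_sectors_counts_run_from_start() {
        let group = SectorGroup {
            request: SectorGroupRequest {
                group_idx: 0,
                partition_idx: None,
                mode: PartitionEncryption::Original,
                force_rehash: false,
            },
            start_sector: 0,
            data: Bytes::new(),
            // Sectors 0..=3 and 5 present; sector 4 missing
            sector_bitmap: 0b10_1111,
            io_duration: None,
            group_hashes: None,
        };
        assert_eq!(group.consecutive_sectors(0), 4); // run of 4 from sector 0
        assert_eq!(group.consecutive_sectors(4), 0); // sector 4 was not read
        assert_eq!(group.consecutive_sectors(5), 1); // lone sector 5
    }
}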