From 9d8cd980b8ff798c8be1f1194eb6af7a62554b85 Mon Sep 17 00:00:00 2001 From: Luke Street Date: Mon, 31 Mar 2025 22:53:08 -0600 Subject: [PATCH] DiscStream rework & threading improvements --- Cargo.lock | 40 ---------- nod/Cargo.toml | 1 - nod/src/build/gc.rs | 23 ++++-- nod/src/disc/direct.rs | 7 ++ nod/src/disc/gcn.rs | 35 +++------ nod/src/disc/reader.rs | 26 +------ nod/src/disc/wii.rs | 9 +-- nod/src/disc/writer.rs | 126 +++++++++++++++--------------- nod/src/io/block.rs | 76 +++++++++++------- nod/src/io/ciso.rs | 18 ++--- nod/src/io/gcz.rs | 38 +++++---- nod/src/io/iso.rs | 12 +-- nod/src/io/nfs.rs | 7 +- nod/src/io/nkit.rs | 55 ++++++++++--- nod/src/io/split.rs | 89 ++++++++------------- nod/src/io/tgc.rs | 37 +++++---- nod/src/io/wbfs.rs | 47 +++++------ nod/src/io/wia.rs | 94 +++++++++++++--------- nod/src/lib.rs | 2 +- nod/src/read.rs | 155 ++++++++++++++++++++++++++++++++----- nod/src/util/compress.rs | 27 +------ nod/src/util/lfg.rs | 2 +- nod/src/util/read.rs | 95 +++++++++++++++++++++++ nod/src/write.rs | 2 +- nodtool/src/cmd/gen.rs | 2 +- nodtool/src/util/shared.rs | 4 +- 26 files changed, 602 insertions(+), 427 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3e50658..09dcd0f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -204,25 +204,6 @@ dependencies = [ "crossbeam-utils", ] -[[package]] -name = "crossbeam-deque" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" -dependencies = [ - "crossbeam-epoch", - "crossbeam-utils", -] - -[[package]] -name = "crossbeam-epoch" -version = "0.9.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" -dependencies = [ - "crossbeam-utils", -] - [[package]] name = "crossbeam-utils" version = "0.8.21" @@ -577,7 +558,6 @@ dependencies = [ "miniz_oxide", "openssl", "polonius-the-crab", - "rayon", "sha1", "simple_moving_average", "thiserror", @@ -778,26 +758,6 @@ dependencies = [ "proc-macro2", ] -[[package]] -name = "rayon" -version = "1.10.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" -dependencies = [ - "either", - "rayon-core", -] - -[[package]] -name = "rayon-core" -version = "1.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" -dependencies = [ - "crossbeam-deque", - "crossbeam-utils", -] - [[package]] name = "regex" version = "1.11.1" diff --git a/nod/Cargo.toml b/nod/Cargo.toml index 2e0d70d..28041b1 100644 --- a/nod/Cargo.toml +++ b/nod/Cargo.toml @@ -44,7 +44,6 @@ md-5 = { workspace = true } miniz_oxide = { version = "0.8", optional = true } openssl = { version = "0.10", optional = true } polonius-the-crab = "0.4" -rayon = "1.10" sha1 = { workspace = true } simple_moving_average = "1.0" thiserror = "2.0" diff --git a/nod/src/build/gc.rs b/nod/src/build/gc.rs index a8ebedc..a9150e6 100644 --- a/nod/src/build/gc.rs +++ b/nod/src/build/gc.rs @@ -15,11 +15,11 @@ use crate::{ WII_MAGIC, fst::{Fst, FstBuilder}, }, - read::DiscStream, + read::{CloneableStream, DiscStream, NonCloneableStream}, util::{Align, array_ref, array_ref_mut, lfg::LaggedFibonacci}, }; -pub trait FileCallback: Clone + Send + Sync { +pub trait FileCallback: Send { fn read_file(&mut self, out: &mut [u8], name: &str, offset: u64) -> io::Result<()>; } @@ -629,15 
+629,26 @@ impl GCPartitionWriter { Ok(()) } - pub fn into_stream(self, file_callback: Cb) -> Result> - where Cb: FileCallback + 'static { - Ok(Box::new(GCPartitionStream::new( + pub fn into_cloneable_stream(self, file_callback: Cb) -> Result> + where Cb: FileCallback + Clone + 'static { + Ok(Box::new(CloneableStream::new(GCPartitionStream::new( file_callback, Arc::from(self.write_info), self.disc_size, self.disc_id, self.disc_num, - ))) + )))) + } + + pub fn into_non_cloneable_stream(self, file_callback: Cb) -> Result> + where Cb: FileCallback + 'static { + Ok(Box::new(NonCloneableStream::new(GCPartitionStream::new( + file_callback, + Arc::from(self.write_info), + self.disc_size, + self.disc_id, + self.disc_num, + )))) } } diff --git a/nod/src/disc/direct.rs b/nod/src/disc/direct.rs index 5ae90b7..a0dd91e 100644 --- a/nod/src/disc/direct.rs +++ b/nod/src/disc/direct.rs @@ -11,6 +11,7 @@ use crate::{ common::KeyBytes, disc::{DiscHeader, SECTOR_SIZE, wii::SECTOR_DATA_SIZE}, io::block::{Block, BlockReader}, + read::{PartitionMeta, PartitionReader}, util::impl_read_for_bufread, }; @@ -136,3 +137,9 @@ impl Seek for DirectDiscReader { fn stream_position(&mut self) -> io::Result { Ok(self.pos) } } + +impl PartitionReader for DirectDiscReader { + fn is_wii(&self) -> bool { unimplemented!() } + + fn meta(&mut self) -> Result { unimplemented!() } +} diff --git a/nod/src/disc/gcn.rs b/nod/src/disc/gcn.rs index 7c656b9..6f76d35 100644 --- a/nod/src/disc/gcn.rs +++ b/nod/src/disc/gcn.rs @@ -14,8 +14,7 @@ use crate::{ SECTOR_SIZE, preloader::{Preloader, SectorGroup, SectorGroupRequest, fetch_sector_group}, }, - io::block::BlockReader, - read::{DiscStream, PartitionEncryption, PartitionMeta, PartitionReader}, + read::{PartitionEncryption, PartitionMeta, PartitionReader}, util::{ impl_read_for_bufread, read::{read_arc, read_arc_slice, read_from}, @@ -23,7 +22,6 @@ use crate::{ }; pub struct PartitionReaderGC { - io: Box, preloader: Arc, pos: u64, disc_size: u64, @@ -34,7 +32,6 @@ pub struct PartitionReaderGC { impl Clone for PartitionReaderGC { fn clone(&self) -> Self { Self { - io: self.io.clone(), preloader: self.preloader.clone(), pos: 0, disc_size: self.disc_size, @@ -45,19 +42,8 @@ impl Clone for PartitionReaderGC { } impl PartitionReaderGC { - pub fn new( - inner: Box, - preloader: Arc, - disc_size: u64, - ) -> Result> { - Ok(Box::new(Self { - io: inner, - preloader, - pos: 0, - disc_size, - sector_group: None, - meta: None, - })) + pub fn new(preloader: Arc, disc_size: u64) -> Result> { + Ok(Box::new(Self { preloader, pos: 0, disc_size, sector_group: None, meta: None })) } } @@ -131,9 +117,7 @@ impl PartitionReader for PartitionReaderGC { } pub(crate) fn read_dol( - // TODO: replace with &dyn mut DiscStream when trait_upcasting is stabilized - // https://github.com/rust-lang/rust/issues/65991 - reader: &mut (impl DiscStream + ?Sized), + reader: &mut dyn PartitionReader, boot_header: &BootHeader, is_wii: bool, ) -> Result> { @@ -153,9 +137,7 @@ pub(crate) fn read_dol( } pub(crate) fn read_fst( - // TODO: replace with &dyn mut DiscStream when trait_upcasting is stabilized - // https://github.com/rust-lang/rust/issues/65991 - reader: &mut (impl DiscStream + ?Sized), + reader: &mut dyn PartitionReader, boot_header: &BootHeader, is_wii: bool, ) -> Result> { @@ -173,7 +155,7 @@ pub(crate) fn read_fst( Ok(raw_fst) } -pub(crate) fn read_apploader(reader: &mut dyn DiscStream) -> Result> { +pub(crate) fn read_apploader(reader: &mut dyn PartitionReader) -> Result> { reader .seek(SeekFrom::Start(BOOT_SIZE as 
u64 + BI2_SIZE as u64)) .context("Seeking to apploader offset")?; @@ -190,7 +172,10 @@ pub(crate) fn read_apploader(reader: &mut dyn DiscStream) -> Result> { Ok(Arc::from(raw_apploader)) } -pub(crate) fn read_part_meta(reader: &mut dyn DiscStream, is_wii: bool) -> Result { +pub(crate) fn read_part_meta( + reader: &mut dyn PartitionReader, + is_wii: bool, +) -> Result { // boot.bin let raw_boot: Arc<[u8; BOOT_SIZE]> = read_arc(reader).context("Reading boot.bin")?; let boot_header = BootHeader::ref_from_bytes(&raw_boot[BB2_OFFSET..]).unwrap(); diff --git a/nod/src/disc/reader.rs b/nod/src/disc/reader.rs index a589af1..adf2b80 100644 --- a/nod/src/disc/reader.rs +++ b/nod/src/disc/reader.rs @@ -252,23 +252,14 @@ impl DiscReader { match &self.disc_data { DiscReaderData::GameCube { .. } => { if index == 0 { - Ok(PartitionReaderGC::new( - self.io.clone(), - self.preloader.clone(), - self.disc_size(), - )?) + Ok(PartitionReaderGC::new(self.preloader.clone(), self.disc_size())?) } else { Err(Error::DiscFormat("GameCube discs only have one partition".to_string())) } } DiscReaderData::Wii { partitions, .. } => { if let Some(part) = partitions.get(index) { - Ok(PartitionReaderWii::new( - self.io.clone(), - self.preloader.clone(), - part, - options, - )?) + Ok(PartitionReaderWii::new(self.preloader.clone(), part, options)?) } else { Err(Error::DiscFormat(format!("Partition {index} not found"))) } @@ -286,23 +277,14 @@ impl DiscReader { match &self.disc_data { DiscReaderData::GameCube { .. } => { if kind == PartitionKind::Data { - Ok(PartitionReaderGC::new( - self.io.clone(), - self.preloader.clone(), - self.disc_size(), - )?) + Ok(PartitionReaderGC::new(self.preloader.clone(), self.disc_size())?) } else { Err(Error::DiscFormat("GameCube discs only have a data partition".to_string())) } } DiscReaderData::Wii { partitions, .. } => { if let Some(part) = partitions.iter().find(|v| v.kind == kind) { - Ok(PartitionReaderWii::new( - self.io.clone(), - self.preloader.clone(), - part, - options, - )?) + Ok(PartitionReaderWii::new(self.preloader.clone(), part, options)?) 
} else { Err(Error::DiscFormat(format!("Partition type {kind} not found"))) } diff --git a/nod/src/disc/wii.rs b/nod/src/disc/wii.rs index b09e2d0..a6064c5 100644 --- a/nod/src/disc/wii.rs +++ b/nod/src/disc/wii.rs @@ -18,7 +18,6 @@ use crate::{ gcn::{PartitionReaderGC, read_part_meta}, preloader::{Preloader, SectorGroup, SectorGroupRequest, fetch_sector_group}, }, - io::block::BlockReader, read::{PartitionEncryption, PartitionMeta, PartitionOptions, PartitionReader}, util::{ aes::aes_cbc_decrypt, @@ -300,7 +299,6 @@ impl WiiPartitionHeader { } pub(crate) struct PartitionReaderWii { - io: Box, preloader: Arc, partition: PartitionInfo, pos: u64, @@ -312,7 +310,6 @@ pub(crate) struct PartitionReaderWii { impl Clone for PartitionReaderWii { fn clone(&self) -> Self { Self { - io: self.io.clone(), preloader: self.preloader.clone(), partition: self.partition.clone(), pos: 0, @@ -325,13 +322,11 @@ impl Clone for PartitionReaderWii { impl PartitionReaderWii { pub fn new( - io: Box, preloader: Arc, partition: &PartitionInfo, options: &PartitionOptions, ) -> Result> { let mut reader = Self { - io, preloader, partition: partition.clone(), pos: 0, @@ -498,12 +493,12 @@ impl PartitionReader for PartitionReaderWii { if let Some(meta) = &self.meta { return Ok(meta.clone()); } - self.seek(SeekFrom::Start(0)).context("Seeking to partition header")?; + self.rewind().context("Seeking to partition header")?; let mut meta = read_part_meta(self, true)?; meta.raw_ticket = Some(Arc::from(self.partition.header.ticket.as_bytes())); // Read TMD, cert chain, and H3 table - let mut reader = PartitionReaderGC::new(self.io.clone(), self.preloader.clone(), u64::MAX)?; + let mut reader = PartitionReaderGC::new(self.preloader.clone(), u64::MAX)?; let offset = self.partition.start_sector as u64 * SECTOR_SIZE as u64; meta.raw_tmd = if self.partition.header.tmd_size() != 0 { reader diff --git a/nod/src/disc/writer.rs b/nod/src/disc/writer.rs index 843818a..1899f79 100644 --- a/nod/src/disc/writer.rs +++ b/nod/src/disc/writer.rs @@ -1,11 +1,11 @@ use std::{ + collections::VecDeque, io, io::{BufRead, Read}, }; use bytes::{Bytes, BytesMut}; use dyn_clone::DynClone; -use rayon::prelude::*; use crate::{ Error, Result, ResultContext, @@ -25,7 +25,7 @@ use crate::{ /// writing fails. The second and third arguments are the current bytes processed and the total /// bytes to process, respectively. For most formats, this has no relation to the written disc size, /// but can be used to display progress. -pub type DataCallback<'a> = dyn FnMut(Bytes, u64, u64) -> io::Result<()> + Send + 'a; +pub type DataCallback<'a> = dyn FnMut(Bytes, u64, u64) -> io::Result<()> + 'a; /// A trait for writing disc images. pub trait DiscWriter: DynClone { @@ -67,7 +67,7 @@ pub struct BlockResult { pub meta: T, } -pub trait BlockProcessor: Clone + Send + Sync { +pub trait BlockProcessor: Clone + Send { type BlockMeta; fn process_block(&mut self, block_idx: u32) -> io::Result>; @@ -106,10 +106,10 @@ pub fn read_block(reader: &mut DiscReader, block_size: usize) -> io::Result<(Byt /// Process blocks in parallel, ensuring that they are written in order. 
 pub(crate) fn par_process<P, T>(
-    create_processor: impl Fn() -> P + Sync,
+    mut processor: P,
     block_count: u32,
     num_threads: usize,
-    mut callback: impl FnMut(BlockResult<T>) -> Result<()> + Send,
+    mut callback: impl FnMut(BlockResult<T>) -> Result<()>,
 ) -> Result<()>
 where
     T: Send,
@@ -117,7 +117,6 @@ where
 {
     if num_threads == 0 {
         // Fall back to single-threaded processing
-        let mut processor = create_processor();
         for block_idx in 0..block_count {
             let block = processor
                 .process_block(block_idx)
@@ -127,69 +126,70 @@ where
         return Ok(());
     }
 
-    let (block_tx, block_rx) = crossbeam_channel::bounded(block_count as usize);
-    for block_idx in 0..block_count {
-        block_tx.send(block_idx).unwrap();
-    }
-    drop(block_tx); // Disconnect channel
+    std::thread::scope(|s| {
+        let (block_tx, block_rx) = crossbeam_channel::bounded(block_count as usize);
+        for block_idx in 0..block_count {
+            block_tx.send(block_idx).unwrap();
+        }
+        drop(block_tx); // Disconnect channel
 
-    let (result_tx, result_rx) = crossbeam_channel::bounded(0);
-    let mut process_error = None;
-    let mut write_error = None;
-    rayon::join(
-        || {
-            if let Err(e) = (0..num_threads).into_par_iter().try_for_each_init(
-                || (block_rx.clone(), result_tx.clone(), create_processor()),
-                |(receiver, block_tx, processor), _| {
-                    while let Ok(block_idx) = receiver.recv() {
-                        let block = processor
-                            .process_block(block_idx)
-                            .with_context(|| format!("Failed to process block {block_idx}"))?;
-                        if block_tx.send(block).is_err() {
-                            break;
-                        }
-                    }
-                    Ok::<_, Error>(())
-                },
-            ) {
-                process_error = Some(e);
-            }
-            drop(result_tx); // Disconnect channel
-        },
-        || {
-            let mut current_block = 0;
-            let mut out_of_order = Vec::<BlockResult<T>>::new();
-            'outer: while let Ok(result) = result_rx.recv() {
-                if result.block_idx == current_block {
-                    if let Err(e) = callback(result) {
-                        write_error = Some(e);
+        let (result_tx, result_rx) = crossbeam_channel::bounded(0);
+
+        // Spawn threads to process blocks
+        for _ in 0..num_threads - 1 {
+            let block_rx = block_rx.clone();
+            let result_tx = result_tx.clone();
+            let mut processor = processor.clone();
+            s.spawn(move || {
+                while let Ok(block_idx) = block_rx.recv() {
+                    let result = processor
+                        .process_block(block_idx)
+                        .with_context(|| format!("Failed to process block {block_idx}"));
+                    let failed = result.is_err(); // Stop processing if an error occurs
+                    if result_tx.send(result).is_err() || failed {
                         break;
                     }
-                    current_block += 1;
-                    // Check if any out of order blocks can be written
-                    while out_of_order.first().is_some_and(|r| r.block_idx == current_block) {
-                        let result = out_of_order.remove(0);
-                        if let Err(e) = callback(result) {
-                            write_error = Some(e);
-                            break 'outer;
-                        }
-                        current_block += 1;
-                    }
-                } else {
-                    out_of_order.push(result);
-                    out_of_order.sort_unstable_by_key(|r| r.block_idx);
+                }
+            });
+        }
+
+        // Last iteration moves instead of cloning
+        s.spawn(move || {
+            while let Ok(block_idx) = block_rx.recv() {
+                let result = processor
+                    .process_block(block_idx)
+                    .with_context(|| format!("Failed to process block {block_idx}"));
+                let failed = result.is_err(); // Stop processing if an error occurs
+                if result_tx.send(result).is_err() || failed {
+                    break;
                 }
             }
-        },
-    );
-    if let Some(e) = process_error {
-        return Err(e);
-    }
-    if let Some(e) = write_error {
-        return Err(e);
-    }
+        });
 
-    Ok(())
+        // Main thread processes results
+        let mut current_block = 0;
+        let mut out_of_order = VecDeque::<BlockResult<T>>::new();
+        while let Ok(result) = result_rx.recv() {
+            let result = result?;
+            if result.block_idx == current_block {
+                callback(result)?;
+                current_block += 1;
+                // Check if any out of order blocks can be written
+                while out_of_order.front().is_some_and(|r| r.block_idx == current_block) {
+                    callback(out_of_order.pop_front().unwrap())?;
+                    current_block += 1;
+                }
+            } else {
+                // Insert sorted
+                match out_of_order.binary_search_by_key(&result.block_idx, |r| r.block_idx) {
+                    Ok(idx) => Err(Error::Other(format!("Unexpected duplicate block {idx}")))?,
+                    Err(idx) => out_of_order.insert(idx, result),
+                }
+            }
+        }
+
+        Ok(())
+    })
 }
 
 /// The determined block type.
diff --git a/nod/src/io/block.rs b/nod/src/io/block.rs
index 1e44336..319b47f 100644
--- a/nod/src/io/block.rs
+++ b/nod/src/io/block.rs
@@ -14,11 +14,16 @@ use crate::{
         wia::{WIAException, WIAExceptionList},
     },
     read::{DiscMeta, DiscStream},
-    util::{aes::decrypt_sector, array_ref, array_ref_mut, lfg::LaggedFibonacci, read::read_from},
+    util::{
+        aes::decrypt_sector,
+        array_ref, array_ref_mut,
+        lfg::LaggedFibonacci,
+        read::{read_at, read_from},
+    },
 };
 
 /// Block reader trait for reading disc images.
-pub trait BlockReader: DynClone + Send + Sync {
+pub trait BlockReader: DynClone + Send {
     /// Reads a block from the disc image containing the specified sector.
     fn read_block(&mut self, out: &mut [u8], sector: u32) -> io::Result<Block>;
 
@@ -33,25 +38,26 @@ dyn_clone::clone_trait_object!(BlockReader);
 
 /// Creates a new [`BlockReader`] instance from a stream.
 pub fn new(mut stream: Box<dyn DiscStream>) -> Result<Box<dyn BlockReader>> {
-    let io: Box<dyn BlockReader> = match detect(stream.as_mut()).context("Detecting file type")? {
-        Some(Format::Iso) => crate::io::iso::BlockReaderISO::new(stream)?,
-        Some(Format::Ciso) => crate::io::ciso::BlockReaderCISO::new(stream)?,
-        Some(Format::Gcz) => {
-            #[cfg(feature = "compress-zlib")]
-            {
-                crate::io::gcz::BlockReaderGCZ::new(stream)?
+    let io: Box<dyn BlockReader> =
+        match detect_stream(stream.as_mut()).context("Detecting file type")? {
+            Some(Format::Iso) => crate::io::iso::BlockReaderISO::new(stream)?,
+            Some(Format::Ciso) => crate::io::ciso::BlockReaderCISO::new(stream)?,
+            Some(Format::Gcz) => {
+                #[cfg(feature = "compress-zlib")]
+                {
+                    crate::io::gcz::BlockReaderGCZ::new(stream)?
+                }
+                #[cfg(not(feature = "compress-zlib"))]
+                return Err(Error::DiscFormat("GCZ support is disabled".to_string()));
             }
-            #[cfg(not(feature = "compress-zlib"))]
-            return Err(Error::DiscFormat("GCZ support is disabled".to_string()));
-        }
-        Some(Format::Nfs) => {
-            return Err(Error::DiscFormat("NFS requires a filesystem path".to_string()));
-        }
-        Some(Format::Wbfs) => crate::io::wbfs::BlockReaderWBFS::new(stream)?,
-        Some(Format::Wia | Format::Rvz) => crate::io::wia::BlockReaderWIA::new(stream)?,
-        Some(Format::Tgc) => crate::io::tgc::BlockReaderTGC::new(stream)?,
-        None => return Err(Error::DiscFormat("Unknown disc format".to_string())),
-    };
+            Some(Format::Nfs) => {
+                return Err(Error::DiscFormat("NFS requires a filesystem path".to_string()));
+            }
+            Some(Format::Wbfs) => crate::io::wbfs::BlockReaderWBFS::new(stream)?,
+            Some(Format::Wia | Format::Rvz) => crate::io::wia::BlockReaderWIA::new(stream)?,
+            Some(Format::Tgc) => crate::io::tgc::BlockReaderTGC::new(stream)?,
+            None => return Err(Error::DiscFormat("Unknown disc format".to_string())),
+        };
    check_block_size(io.as_ref())?;
    Ok(io)
 }
@@ -71,7 +77,9 @@ pub fn open(filename: &Path) -> Result<Box<dyn BlockReader>> {
         return Err(Error::DiscFormat(format!("Input is not a file: {}", filename.display())));
     }
     let mut stream = Box::new(SplitFileReader::new(filename)?);
-    let io: Box<dyn BlockReader> = match detect(stream.as_mut()).context("Detecting file type")?
{ + let io: Box = match detect_stream(stream.as_mut()) + .context("Detecting file type")? + { Some(Format::Iso) => crate::io::iso::BlockReaderISO::new(stream)?, Some(Format::Ciso) => crate::io::ciso::BlockReaderCISO::new(stream)?, Some(Format::Gcz) => { @@ -109,12 +117,23 @@ pub const RVZ_MAGIC: MagicBytes = *b"RVZ\x01"; pub fn detect(stream: &mut R) -> io::Result> where R: Read + ?Sized { - let data: [u8; 0x20] = match read_from(stream) { - Ok(magic) => magic, - Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), - Err(e) => return Err(e), - }; - let out = match *array_ref!(data, 0, 4) { + match read_from(stream) { + Ok(ref magic) => Ok(detect_internal(magic)), + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(None), + Err(e) => Err(e), + } +} + +fn detect_stream(stream: &mut dyn DiscStream) -> io::Result> { + match read_at(stream, 0) { + Ok(ref magic) => Ok(detect_internal(magic)), + Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(None), + Err(e) => Err(e), + } +} + +fn detect_internal(data: &[u8; 0x20]) -> Option { + match *array_ref!(data, 0, 4) { CISO_MAGIC => Some(Format::Ciso), GCZ_MAGIC => Some(Format::Gcz), NFS_MAGIC => Some(Format::Nfs), @@ -126,8 +145,7 @@ where R: Read + ?Sized { Some(Format::Iso) } _ => None, - }; - Ok(out) + } } fn check_block_size(io: &dyn BlockReader) -> Result<()> { diff --git a/nod/src/io/ciso.rs b/nod/src/io/ciso.rs index 7eaa05d..4317556 100644 --- a/nod/src/io/ciso.rs +++ b/nod/src/io/ciso.rs @@ -1,6 +1,6 @@ use std::{ io, - io::{Read, Seek, SeekFrom}, + io::{Seek, SeekFrom}, mem::size_of, sync::Arc, }; @@ -28,7 +28,7 @@ use crate::{ array_ref, digest::DigestManager, lfg::LaggedFibonacci, - read::{box_to_bytes, read_arc}, + read::{box_to_bytes, read_arc_at}, static_assert, }, write::{DiscFinalization, DiscWriterWeight, FormatOptions, ProcessOptions}, @@ -58,8 +58,8 @@ pub struct BlockReaderCISO { impl BlockReaderCISO { pub fn new(mut inner: Box) -> Result> { // Read header - inner.seek(SeekFrom::Start(0)).context("Seeking to start")?; - let header: Arc = read_arc(inner.as_mut()).context("Reading CISO header")?; + let header: Arc = + read_arc_at(inner.as_mut(), 0).context("Reading CISO header")?; if header.magic != CISO_MAGIC { return Err(Error::DiscFormat("Invalid CISO magic".to_string())); } @@ -76,7 +76,7 @@ impl BlockReaderCISO { } } let file_size = SECTOR_SIZE as u64 + block as u64 * header.block_size.get() as u64; - let len = inner.seek(SeekFrom::End(0)).context("Determining stream length")?; + let len = inner.stream_len().context("Determining stream length")?; if file_size > len { return Err(Error::DiscFormat(format!( "CISO file size mismatch: expected at least {} bytes, got {}", @@ -86,8 +86,7 @@ impl BlockReaderCISO { // Read NKit header if present (after CISO data) let nkit_header = if len > file_size + 12 { - inner.seek(SeekFrom::Start(file_size)).context("Seeking to NKit header")?; - NKitHeader::try_read_from(inner.as_mut(), header.block_size.get(), true) + NKitHeader::try_read_from(inner.as_mut(), file_size, header.block_size.get(), true) } else { None }; @@ -119,8 +118,7 @@ impl BlockReader for BlockReaderCISO { // Read block let file_offset = size_of::() as u64 + phys_block as u64 * block_size as u64; - self.inner.seek(SeekFrom::Start(file_offset))?; - self.inner.read_exact(out)?; + self.inner.read_exact_at(out, file_offset)?; Ok(Block::new(block_idx, block_size, BlockKind::Raw)) } @@ -259,7 +257,7 @@ impl DiscWriter for DiscWriterCISO { header.magic = CISO_MAGIC; header.block_size = block_size.into(); 
par_process( - || BlockProcessorCISO { + BlockProcessorCISO { inner: self.inner.clone(), block_size, decrypted_block: <[u8]>::new_box_zeroed_with_elems(block_size as usize).unwrap(), diff --git a/nod/src/io/gcz.rs b/nod/src/io/gcz.rs index d072a97..52fa988 100644 --- a/nod/src/io/gcz.rs +++ b/nod/src/io/gcz.rs @@ -1,6 +1,6 @@ use std::{ io, - io::{Read, Seek, SeekFrom}, + io::{Seek, SeekFrom}, mem::size_of, sync::Arc, }; @@ -22,7 +22,7 @@ use crate::{ util::{ compress::{Compressor, DecompressionKind, Decompressor}, digest::DigestManager, - read::{read_arc_slice, read_from}, + read::{read_arc_slice_at, read_at}, static_assert, }, write::{DiscFinalization, DiscWriterWeight, FormatOptions, ProcessOptions}, @@ -69,18 +69,22 @@ impl Clone for BlockReaderGCZ { impl BlockReaderGCZ { pub fn new(mut inner: Box) -> Result> { // Read header - inner.seek(SeekFrom::Start(0)).context("Seeking to start")?; - let header: GCZHeader = read_from(inner.as_mut()).context("Reading GCZ header")?; + let header: GCZHeader = read_at(inner.as_mut(), 0).context("Reading GCZ header")?; if header.magic != GCZ_MAGIC { return Err(Error::DiscFormat("Invalid GCZ magic".to_string())); } // Read block map and hashes let block_count = header.block_count.get(); - let block_map = read_arc_slice(inner.as_mut(), block_count as usize) - .context("Reading GCZ block map")?; - let block_hashes = read_arc_slice(inner.as_mut(), block_count as usize) - .context("Reading GCZ block hashes")?; + let block_map = + read_arc_slice_at(inner.as_mut(), block_count as usize, size_of::() as u64) + .context("Reading GCZ block map")?; + let block_hashes = read_arc_slice_at( + inner.as_mut(), + block_count as usize, + size_of::() as u64 + block_count as u64 * 8, + ) + .context("Reading GCZ block hashes")?; // header + block_count * (u64 + u32) let data_offset = size_of::() as u64 + block_count as u64 * 12; @@ -121,29 +125,29 @@ impl BlockReader for BlockReaderGCZ { .get() & !(1 << 63)) - file_offset) as usize; - if compressed_size > self.block_buf.len() { + if compressed_size > block_size as usize { return Err(io::Error::new( io::ErrorKind::InvalidData, format!( "Compressed block size exceeds block size: {} > {}", - compressed_size, - self.block_buf.len() + compressed_size, block_size ), )); - } else if !compressed && compressed_size != self.block_buf.len() { + } else if !compressed && compressed_size != block_size as usize { return Err(io::Error::new( io::ErrorKind::InvalidData, format!( "Uncompressed block size does not match block size: {} != {}", - compressed_size, - self.block_buf.len() + compressed_size, block_size ), )); } // Read block - self.inner.seek(SeekFrom::Start(self.data_offset + file_offset))?; - self.inner.read_exact(&mut self.block_buf[..compressed_size])?; + self.inner.read_exact_at( + &mut self.block_buf[..compressed_size], + self.data_offset + file_offset, + )?; // Verify block checksum let checksum = adler32_slice(&self.block_buf[..compressed_size]); @@ -315,7 +319,7 @@ impl DiscWriter for DiscWriterGCZ { let mut input_position = 0; let mut data_position = 0; par_process( - || BlockProcessorGCZ { + BlockProcessorGCZ { inner: self.inner.clone(), header: self.header.clone(), compressor: Compressor::new(self.compression, block_size as usize), diff --git a/nod/src/io/iso.rs b/nod/src/io/iso.rs index 68c7668..13cb704 100644 --- a/nod/src/io/iso.rs +++ b/nod/src/io/iso.rs @@ -1,7 +1,4 @@ -use std::{ - io, - io::{BufRead, Read, Seek, SeekFrom}, -}; +use std::{io, io::BufRead}; use crate::{ Result, ResultContext, @@ -25,7 +22,7 @@ pub 
struct BlockReaderISO {
 
 impl BlockReaderISO {
     pub fn new(mut inner: Box<dyn DiscStream>) -> Result<Box<Self>> {
-        let disc_size = inner.seek(SeekFrom::End(0)).context("Determining stream length")?;
+        let disc_size = inner.stream_len().context("Determining stream length")?;
         Ok(Box::new(Self { inner, disc_size }))
     }
 }
@@ -38,14 +35,13 @@ impl BlockReader for BlockReaderISO {
             return Ok(Block::sector(sector, BlockKind::None));
         }
 
-        self.inner.seek(SeekFrom::Start(pos))?;
         if pos + SECTOR_SIZE as u64 > self.disc_size {
             // If the last block is not a full sector, fill the rest with zeroes
             let read = (self.disc_size - pos) as usize;
-            self.inner.read_exact(&mut out[..read])?;
+            self.inner.read_exact_at(&mut out[..read], pos)?;
             out[read..].fill(0);
         } else {
-            self.inner.read_exact(out)?;
+            self.inner.read_exact_at(out, pos)?;
         }
 
         Ok(Block::sector(sector, BlockKind::Raw))
diff --git a/nod/src/io/nfs.rs b/nod/src/io/nfs.rs
index 861acea..030d56f 100644
--- a/nod/src/io/nfs.rs
+++ b/nod/src/io/nfs.rs
@@ -1,7 +1,7 @@
 use std::{
     fs::File,
     io,
-    io::{BufReader, Read, Seek, SeekFrom},
+    io::{BufReader, Read},
     mem::size_of,
     path::{Component, Path, PathBuf},
     sync::Arc,
@@ -17,7 +17,7 @@ use crate::{
         block::{Block, BlockKind, BlockReader, NFS_MAGIC},
         split::SplitFileReader,
     },
-    read::DiscMeta,
+    read::{DiscMeta, DiscStream},
     util::{aes::aes_cbc_decrypt, array_ref_mut, read::read_arc, static_assert},
 };
 
@@ -116,8 +116,7 @@ impl BlockReader for BlockReaderNFS {
 
         // Read sector
         let offset = size_of::<NFSHeader>() as u64 + phys_sector as u64 * SECTOR_SIZE as u64;
-        self.inner.seek(SeekFrom::Start(offset))?;
-        self.inner.read_exact(out)?;
+        self.inner.read_exact_at(out, offset)?;
 
         // Decrypt
         let mut iv = [0u8; 0x10];
diff --git a/nod/src/io/nkit.rs b/nod/src/io/nkit.rs
index 8b14325..f9137ff 100644
--- a/nod/src/io/nkit.rs
+++ b/nod/src/io/nkit.rs
@@ -1,15 +1,12 @@
-use std::{
-    io,
-    io::{Read, Seek, SeekFrom, Write},
-};
+use std::io::{self, Read, Seek, Write};
 
 use tracing::warn;
 
 use crate::{
     common::MagicBytes,
     disc::DL_DVD_SIZE,
-    read::DiscMeta,
-    util::read::{read_from, read_u16_be, read_u32_be, read_u64_be, read_vec},
+    read::{DiscMeta, DiscStream},
+    util::read::{read_at, read_from, read_u16_be, read_u32_be, read_u64_be, read_vec},
 };
 
 #[allow(unused)]
@@ -89,12 +86,16 @@ impl Default for NKitHeader {
 const VERSION_PREFIX: [u8; 7] = *b"NKIT v";
 
 impl NKitHeader {
-    pub fn try_read_from<R>(reader: &mut R, block_size: u32, has_junk_bits: bool) -> Option<Self>
-    where R: Read + Seek + ?Sized {
-        let magic: MagicBytes = read_from(reader).ok()?;
+    pub fn try_read_from(
+        reader: &mut dyn DiscStream,
+        pos: u64,
+        block_size: u32,
+        has_junk_bits: bool,
+    ) -> Option<Self> {
+        let magic: MagicBytes = read_at(reader, pos).ok()?;
         if magic == *b"NKIT" {
-            reader.seek(SeekFrom::Current(-4)).ok()?;
-            match NKitHeader::read_from(reader, block_size, has_junk_bits) {
+            let mut reader = ReadAdapter::new(reader, pos);
+            match NKitHeader::read_from(&mut reader, block_size, has_junk_bits) {
                 Ok(header) => Some(header),
                 Err(e) => {
                     warn!("Failed to read NKit header: {}", e);
@@ -299,3 +300,35 @@ impl JunkBits {
         DL_DVD_SIZE.div_ceil(block_size as u64).div_ceil(8) as usize
     }
 }
+
+pub struct ReadAdapter<'a> {
+    reader: &'a mut dyn DiscStream,
+    pos: u64,
+}
+
+impl<'a> ReadAdapter<'a> {
+    pub fn new(reader: &'a mut dyn DiscStream, offset: u64) -> Self { Self { reader, pos: offset } }
+}
+
+impl Read for ReadAdapter<'_> {
+    fn read(&mut self, _buf: &mut [u8]) -> io::Result<usize> {
+        Err(io::Error::from(io::ErrorKind::Unsupported))
+    }
+
+    fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> {
+        self.reader.read_exact_at(buf, self.pos)?;
+        self.pos += buf.len() as u64;
+        Ok(())
+    }
+}
+
+impl Seek for ReadAdapter<'_> {
+    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
+        self.pos = match pos {
+            io::SeekFrom::Start(pos) => pos,
+            io::SeekFrom::End(v) => self.reader.stream_len()?.saturating_add_signed(v),
+            io::SeekFrom::Current(v) => self.pos.saturating_add_signed(v),
+        };
+        Ok(self.pos)
+    }
+}
diff --git a/nod/src/io/split.rs b/nod/src/io/split.rs
index beb16d2..6d07298 100644
--- a/nod/src/io/split.rs
+++ b/nod/src/io/split.rs
@@ -1,19 +1,15 @@
 use std::{
     fs::File,
     io,
-    io::{BufReader, Read, Seek, SeekFrom},
     path::{Path, PathBuf},
 };
 
-use tracing::instrument;
-
-use crate::{ErrorContext, Result, ResultContext};
+use crate::{ErrorContext, Result, ResultContext, read::DiscStream};
 
 #[derive(Debug)]
 pub struct SplitFileReader {
     files: Vec<Split<PathBuf>>,
-    open_file: Option<Split<BufReader<File>>>,
-    pos: u64,
+    open_file: Option<Split<File>>,
 }
 
 #[derive(Debug, Clone)]
@@ -60,7 +56,7 @@ fn split_path_3(input: &Path, index: u32) -> PathBuf {
 }
 
 impl SplitFileReader {
-    pub fn empty() -> Self { Self { files: Vec::new(), open_file: None, pos: 0 } }
+    pub fn empty() -> Self { Self { files: Vec::new(), open_file: Default::default() } }
 
     pub fn new(path: &Path) -> Result<Self> {
         let mut files = vec![];
@@ -90,7 +86,7 @@ impl SplitFileReader {
                 break;
             }
         }
-        Ok(Self { files, open_file: None, pos: 0 })
+        Ok(Self { files, open_file: Default::default() })
     }
 
     pub fn add(&mut self, path: &Path) -> Result<()> {
@@ -102,54 +98,37 @@ impl SplitFileReader {
     }
 
     pub fn len(&self) -> u64 { self.files.last().map_or(0, |f| f.begin + f.size) }
-
-    fn check_open_file(&mut self) -> io::Result<Option<&mut Split<BufReader<File>>>> {
-        if self.open_file.is_none() || !self.open_file.as_ref().unwrap().contains(self.pos) {
-            self.open_file = if let Some(split) = self.files.iter().find(|f| f.contains(self.pos)) {
-                let mut file = BufReader::new(File::open(&split.inner)?);
-                // log::info!("Opened file {} at pos {}", split.inner.display(), self.pos);
-                file.seek(SeekFrom::Start(self.pos - split.begin))?;
-                Some(Split { inner: file, begin: split.begin, size: split.size })
-            } else {
-                None
-            };
-        }
-        Ok(self.open_file.as_mut())
-    }
-}
-
-impl Read for SplitFileReader {
-    #[instrument(name = "SplitFileReader::read", skip_all)]
-    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
-        let pos = self.pos;
-        let Some(split) = self.check_open_file()? else {
-            return Ok(0);
-        };
-        let to_read = buf.len().min((split.begin + split.size - pos) as usize);
-        let read = split.inner.read(&mut buf[..to_read])?;
-        self.pos += read as u64;
-        Ok(read)
-    }
-}
-
-impl Seek for SplitFileReader {
-    #[instrument(name = "SplitFileReader::seek", skip_all)]
-    fn seek(&mut self, pos: SeekFrom) -> io::Result<u64> {
-        self.pos = match pos {
-            SeekFrom::Start(pos) => pos,
-            SeekFrom::Current(offset) => self.pos.saturating_add_signed(offset),
-            SeekFrom::End(offset) => self.len().saturating_add_signed(offset),
-        };
-        if let Some(split) = &mut self.open_file {
-            if split.contains(self.pos) {
-                // Seek within the open file
-                split.inner.seek(SeekFrom::Start(self.pos - split.begin))?;
-            }
-        }
-        Ok(self.pos)
-    }
 }
 
 impl Clone for SplitFileReader {
-    fn clone(&self) -> Self { Self { files: self.files.clone(), open_file: None, pos: 0 } }
+    fn clone(&self) -> Self { Self { files: self.files.clone(), open_file: Default::default() } }
+}
+
+impl DiscStream for SplitFileReader {
+    fn read_exact_at(&mut self, buf: &mut [u8], offset: u64) -> io::Result<()> {
+        let split = if self.open_file.as_ref().is_none_or(|s| !s.contains(offset)) {
+            let split = if let Some(split) = self.files.iter().find(|f| f.contains(offset)) {
+                let file = File::open(&split.inner)?;
+                Split { inner: file, begin: split.begin, size: split.size }
+            } else {
+                return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
+            };
+            self.open_file.insert(split)
+        } else {
+            self.open_file.as_mut().unwrap()
+        };
+        #[cfg(unix)]
+        {
+            use std::os::unix::fs::FileExt;
+            split.inner.read_exact_at(buf, offset - split.begin)
+        }
+        #[cfg(not(unix))]
+        {
+            use std::io::{Read, Seek, SeekFrom};
+            split.inner.seek(SeekFrom::Start(offset - split.begin))?;
+            split.inner.read_exact(buf)
+        }
+    }
+
+    fn stream_len(&mut self) -> io::Result<u64> { Ok(self.len()) }
 }
diff --git a/nod/src/io/tgc.rs b/nod/src/io/tgc.rs
index f9353f0..8abc53f 100644
--- a/nod/src/io/tgc.rs
+++ b/nod/src/io/tgc.rs
@@ -22,7 +22,7 @@ use crate::{
     read::{DiscMeta, DiscStream, PartitionOptions, PartitionReader},
     util::{
         Align, array_ref,
-        read::{read_arc, read_arc_slice, read_from, read_with_zero_fill},
+        read::{read_arc_at, read_arc_slice_at, read_at, read_with_zero_fill},
         static_assert,
     },
     write::{DiscFinalization, DiscWriterWeight, FormatOptions, ProcessOptions},
@@ -73,21 +73,19 @@ pub struct BlockReaderTGC {
 
 impl BlockReaderTGC {
     pub fn new(mut inner: Box<dyn DiscStream>) -> Result<Box<Self>> {
-        inner.seek(SeekFrom::Start(0)).context("Seeking to start")?;
-
         // Read header
-        let header: TGCHeader = read_from(inner.as_mut()).context("Reading TGC header")?;
+        let header: TGCHeader = read_at(inner.as_mut(), 0).context("Reading TGC header")?;
         if header.magic != TGC_MAGIC {
             return Err(Error::DiscFormat("Invalid TGC magic".to_string()));
         }
         let disc_size = (header.gcm_files_start.get() + header.user_size.get()) as u64;
 
         // Read GCM header
-        inner
-            .seek(SeekFrom::Start(header.header_offset.get() as u64))
-            .context("Seeking to GCM header")?;
-        let raw_header =
-            read_arc::<[u8; GCM_HEADER_SIZE], _>(inner.as_mut()).context("Reading GCM header")?;
+        let raw_header = read_arc_at::<[u8; GCM_HEADER_SIZE], _>(
+            inner.as_mut(),
+            header.header_offset.get() as u64,
+        )
+        .context("Reading GCM header")?;
 
         let disc_header =
             DiscHeader::ref_from_bytes(array_ref![raw_header, 0, size_of::<DiscHeader>()])
@@ -99,14 +97,20 @@ impl BlockReaderTGC {
         let boot_header = boot_header.clone();
 
         // Read DOL
-        inner.seek(SeekFrom::Start(header.dol_offset.get() as u64)).context("Seeking to DOL")?;
-        let raw_dol =
read_arc_slice::(inner.as_mut(), header.dol_size.get() as usize) - .context("Reading DOL")?; + let raw_dol = read_arc_slice_at::( + inner.as_mut(), + header.dol_size.get() as usize, + header.dol_offset.get() as u64, + ) + .context("Reading DOL")?; // Read FST - inner.seek(SeekFrom::Start(header.fst_offset.get() as u64)).context("Seeking to FST")?; - let raw_fst = read_arc_slice::(inner.as_mut(), header.fst_size.get() as usize) - .context("Reading FST")?; + let raw_fst = read_arc_slice_at::( + inner.as_mut(), + header.fst_size.get() as usize, + header.fst_offset.get() as u64, + ) + .context("Reading FST")?; let fst = Fst::new(&raw_fst)?; let mut write_info = Vec::with_capacity(5 + fst.num_files()); @@ -198,8 +202,7 @@ impl FileCallback for FileCallbackTGC { // Calculate file offset in TGC let file_start = (node.offset(false) as u32 - self.header.gcm_files_start.get()) + self.header.user_offset.get(); - self.inner.seek(SeekFrom::Start(file_start as u64 + offset))?; - self.inner.read_exact(out)?; + self.inner.read_exact_at(out, file_start as u64 + offset)?; Ok(()) } } diff --git a/nod/src/io/wbfs.rs b/nod/src/io/wbfs.rs index a977cd4..98cc367 100644 --- a/nod/src/io/wbfs.rs +++ b/nod/src/io/wbfs.rs @@ -1,6 +1,6 @@ use std::{ io, - io::{Read, Seek, SeekFrom}, + io::{Seek, SeekFrom}, mem::size_of, sync::Arc, }; @@ -28,7 +28,7 @@ use crate::{ array_ref, digest::DigestManager, lfg::LaggedFibonacci, - read::{read_arc_slice, read_box_slice, read_from}, + read::{read_arc_slice_at, read_at, read_box_slice_at}, }, write::{DiscFinalization, DiscWriterWeight, FormatOptions, ProcessOptions}, }; @@ -69,12 +69,11 @@ pub struct BlockReaderWBFS { impl BlockReaderWBFS { pub fn new(mut inner: Box) -> Result> { - inner.seek(SeekFrom::Start(0)).context("Seeking to start")?; - let header: WBFSHeader = read_from(inner.as_mut()).context("Reading WBFS header")?; + let header: WBFSHeader = read_at(inner.as_mut(), 0).context("Reading WBFS header")?; if header.magic != WBFS_MAGIC { return Err(Error::DiscFormat("Invalid WBFS magic".to_string())); } - let file_len = inner.seek(SeekFrom::End(0)).context("Determining stream length")?; + let file_len = inner.stream_len().context("Determining stream length")?; let expected_file_len = header.num_sectors.get() as u64 * header.sector_size() as u64; if file_len != expected_file_len { return Err(Error::DiscFormat(format!( @@ -83,12 +82,12 @@ impl BlockReaderWBFS { ))); } - inner - .seek(SeekFrom::Start(size_of::() as u64)) - .context("Seeking to WBFS disc table")?; - let disc_table: Box<[u8]> = - read_box_slice(inner.as_mut(), header.sector_size() as usize - size_of::()) - .context("Reading WBFS disc table")?; + let disc_table: Box<[u8]> = read_box_slice_at( + inner.as_mut(), + header.sector_size() as usize - size_of::(), + size_of::() as u64, + ) + .context("Reading WBFS disc table")?; if disc_table[0] != 1 { return Err(Error::DiscFormat("WBFS doesn't contain a disc".to_string())); } @@ -97,15 +96,20 @@ impl BlockReaderWBFS { } // Read WBFS LBA map - inner - .seek(SeekFrom::Start(header.sector_size() as u64 + DISC_HEADER_SIZE as u64)) - .context("Seeking to WBFS LBA table")?; // Skip header - let block_map: Arc<[U16]> = read_arc_slice(inner.as_mut(), header.max_blocks() as usize) - .context("Reading WBFS LBA table")?; + let block_map: Arc<[U16]> = read_arc_slice_at( + inner.as_mut(), + header.max_blocks() as usize, + header.sector_size() as u64 + DISC_HEADER_SIZE as u64, + ) + .context("Reading WBFS LBA table")?; // Read NKit header if present (always at 0x10000) - 
inner.seek(SeekFrom::Start(NKIT_HEADER_OFFSET)).context("Seeking to NKit header")?; - let nkit_header = NKitHeader::try_read_from(inner.as_mut(), header.block_size(), true); + let nkit_header = NKitHeader::try_read_from( + inner.as_mut(), + NKIT_HEADER_OFFSET, + header.block_size(), + true, + ); Ok(Box::new(Self { inner, header, block_map, nkit_header })) } @@ -134,8 +138,7 @@ impl BlockReader for BlockReaderWBFS { // Read block let block_start = block_size as u64 * phys_block as u64; - self.inner.seek(SeekFrom::Start(block_start))?; - self.inner.read_exact(out)?; + self.inner.read_exact_at(out, block_start)?; Ok(Block::new(block_idx, block_size, BlockKind::Raw)) } @@ -271,7 +274,7 @@ impl DiscWriterWBFS { return Err(Error::Other("WBFS info too large for block".to_string())); } - inner.seek(SeekFrom::Start(0)).context("Seeking to start")?; + inner.rewind().context("Seeking to start")?; Ok(Box::new(Self { inner, header, disc_table, block_count })) } } @@ -307,7 +310,7 @@ impl DiscWriter for DiscWriterWBFS { let mut phys_block = 1; par_process( - || BlockProcessorWBFS { + BlockProcessorWBFS { inner: self.inner.clone(), header: self.header.clone(), decrypted_block: <[u8]>::new_box_zeroed_with_elems(block_size as usize).unwrap(), diff --git a/nod/src/io/wia.rs b/nod/src/io/wia.rs index 84d928b..db90aa4 100644 --- a/nod/src/io/wia.rs +++ b/nod/src/io/wia.rs @@ -2,7 +2,7 @@ use std::{ borrow::Cow, collections::{BTreeSet, HashMap, hash_map::Entry}, io, - io::{Read, Seek, SeekFrom}, + io::{Seek, SeekFrom}, mem::size_of, sync::Arc, time::Instant, @@ -34,7 +34,10 @@ use crate::{ compress::{Compressor, DecompressionKind, Decompressor}, digest::{DigestManager, sha1_hash, xxh64_hash}, lfg::{LaggedFibonacci, SEED_SIZE, SEED_SIZE_BYTES}, - read::{read_arc_slice, read_from, read_vec}, + read::{ + read_arc_slice_at, read_at, read_box_slice_at, read_into_arc_slice, + read_into_box_slice, read_vec_at, + }, static_assert, }, write::{DiscFinalization, DiscWriterWeight, FormatOptions, ProcessOptions}, @@ -588,16 +591,19 @@ fn verify_hash(buf: &[u8], expected: &HashBytes) -> Result<()> { impl BlockReaderWIA { pub fn new(mut inner: Box) -> Result> { // Load & verify file header - inner.seek(SeekFrom::Start(0)).context("Seeking to start")?; let header: WIAFileHeader = - read_from(inner.as_mut()).context("Reading WIA/RVZ file header")?; + read_at(inner.as_mut(), 0).context("Reading WIA/RVZ file header")?; header.validate()?; let is_rvz = header.is_rvz(); debug!("Header: {:?}", header); // Load & verify disc header - let mut disc_buf: Vec = read_vec(inner.as_mut(), header.disc_size.get() as usize) - .context("Reading WIA/RVZ disc header")?; + let mut disc_buf: Vec = read_vec_at( + inner.as_mut(), + header.disc_size.get() as usize, + size_of::() as u64, + ) + .context("Reading WIA/RVZ disc header")?; verify_hash(&disc_buf, &header.disc_hash)?; disc_buf.resize(size_of::(), 0); let disc = WIADisc::read_from_bytes(disc_buf.as_slice()).unwrap(); @@ -605,15 +611,20 @@ impl BlockReaderWIA { debug!("Disc: {:?}", disc); // Read NKit header if present (after disc header) - let nkit_header = NKitHeader::try_read_from(inner.as_mut(), disc.chunk_size.get(), false); + let nkit_header = NKitHeader::try_read_from( + inner.as_mut(), + size_of::() as u64 + header.disc_size.get() as u64, + disc.chunk_size.get(), + false, + ); // Load & verify partition headers - inner - .seek(SeekFrom::Start(disc.partition_offset.get())) - .context("Seeking to WIA/RVZ partition headers")?; - let partitions: Arc<[WIAPartition]> = - 
read_arc_slice(inner.as_mut(), disc.num_partitions.get() as usize) - .context("Reading WIA/RVZ partition headers")?; + let partitions: Arc<[WIAPartition]> = read_arc_slice_at( + inner.as_mut(), + disc.num_partitions.get() as usize, + disc.partition_offset.get(), + ) + .context("Reading WIA/RVZ partition headers")?; verify_hash(partitions.as_ref().as_bytes(), &disc.partition_hash)?; debug!("Partitions: {:?}", partitions); @@ -622,15 +633,18 @@ impl BlockReaderWIA { // Load raw data headers let raw_data: Arc<[WIARawData]> = { - inner - .seek(SeekFrom::Start(disc.raw_data_offset.get())) - .context("Seeking to WIA/RVZ raw data headers")?; - let mut reader = decompressor - .kind - .wrap(inner.as_mut().take(disc.raw_data_size.get() as u64)) - .context("Creating WIA/RVZ decompressor")?; - read_arc_slice(&mut reader, disc.num_raw_data.get() as usize) - .context("Reading WIA/RVZ raw data headers")? + let compressed_data = read_box_slice_at::( + inner.as_mut(), + disc.raw_data_size.get() as usize, + disc.raw_data_offset.get(), + ) + .context("Reading WIA/RVZ raw data headers")?; + read_into_arc_slice(disc.num_raw_data.get() as usize, |out| { + decompressor + .decompress(&compressed_data, out) + .context("Decompressing WIA/RVZ raw data headers") + .map(|_| ()) + })? }; // Validate raw data alignment for (idx, rd) in raw_data.iter().enumerate() { @@ -652,20 +666,27 @@ impl BlockReaderWIA { // Load group headers let groups = { - inner - .seek(SeekFrom::Start(disc.group_offset.get())) - .context("Seeking to WIA/RVZ group headers")?; - let mut reader = decompressor - .kind - .wrap(inner.as_mut().take(disc.group_size.get() as u64)) - .context("Creating WIA/RVZ decompressor")?; + let compressed_data = read_box_slice_at::( + inner.as_mut(), + disc.group_size.get() as usize, + disc.group_offset.get(), + ) + .context("Reading WIA/RVZ group headers")?; if is_rvz { - read_arc_slice(&mut reader, disc.num_groups.get() as usize) - .context("Reading WIA/RVZ group headers")? + read_into_arc_slice(disc.num_groups.get() as usize, |out| { + decompressor + .decompress(&compressed_data, out) + .context("Decompressing WIA/RVZ group headers") + .map(|_| ()) + })? } else { - let wia_groups: Arc<[WIAGroup]> = - read_arc_slice(&mut reader, disc.num_groups.get() as usize) - .context("Reading WIA/RVZ group headers")?; + let wia_groups = + read_into_box_slice::(disc.num_groups.get() as usize, |out| { + decompressor + .decompress(&compressed_data, out) + .context("Decompressing WIA/RVZ group headers") + .map(|_| ()) + })?; wia_groups.iter().map(RVZGroup::from).collect() } }; @@ -878,8 +899,7 @@ impl BlockReader for BlockReaderWIA { let group_data_start = group.data_offset.get() as u64 * 4; let mut group_data = BytesMut::zeroed(group.data_size() as usize); let io_start = Instant::now(); - self.inner.seek(SeekFrom::Start(group_data_start))?; - self.inner.read_exact(group_data.as_mut())?; + self.inner.read_exact_at(group_data.as_mut(), group_data_start)?; let io_duration = io_start.elapsed(); let mut group_data = group_data.freeze(); @@ -1698,7 +1718,7 @@ impl DiscWriter for DiscWriterWIA { let mut group_hashes = HashMap::::new(); let mut reuse_size = 0; par_process( - || BlockProcessorWIA { + BlockProcessorWIA { inner: self.inner.clone(), header: self.header.clone(), disc: self.disc.clone(), diff --git a/nod/src/lib.rs b/nod/src/lib.rs index ba750c1..29c92a3 100644 --- a/nod/src/lib.rs +++ b/nod/src/lib.rs @@ -134,7 +134,7 @@ //! // Some disc writers calculate data during processing. //! 
// If the finalization returns header data, seek to the beginning of the file and write it.
 //! if !finalization.header.is_empty() {
-//!     output_file.seek(std::io::SeekFrom::Start(0))
+//!     output_file.rewind()
 //!         .expect("Failed to seek");
 //!     output_file.write_all(finalization.header.as_ref())
 //!         .expect("Failed to write header");
 //! }
diff --git a/nod/src/read.rs b/nod/src/read.rs
index ca21a3d..f0fc25e 100644
--- a/nod/src/read.rs
+++ b/nod/src/read.rs
@@ -1,8 +1,8 @@
 //! [`DiscReader`] and associated types.
 
 use std::{
-    io::{BufRead, Read, Seek},
+    io::{self, BufRead, Read, Seek},
     path::Path,
-    sync::Arc,
+    sync::{Arc, Mutex},
 };
 
 use dyn_clone::DynClone;
@@ -62,13 +62,107 @@ pub struct PartitionOptions {
     pub validate_hashes: bool,
 }
 
-/// Required trait bounds for reading disc images.
-pub trait DiscStream: Read + Seek + DynClone + Send + Sync {}
+/// Trait for reading disc images.
+///
+/// Disc images are read in blocks, often in the hundred kilobyte to several megabyte range,
+/// making the standard [`Read`] and [`Seek`] traits a poor fit for this use case. This trait
+/// provides a simplified interface for reading disc images, with a focus on large, random
+/// access reads.
+///
+/// For multithreading support, an implementation must be [`Send`] and [`Clone`].
+/// [`Sync`] is _not_ required: the stream will be cloned if used in multiple threads.
+///
+/// Rather than implement this trait directly, you'll likely use one of the following
+/// [`DiscReader`] functions:
+/// - [`DiscReader::new`]: to open a disc image from a file path.
+/// - [`DiscReader::new_stream`]: when you can provide a [`Box<dyn DiscStream>`].
+/// - [`DiscReader::new_from_cloneable_read`]: when you can provide a [`Read`] + [`Seek`] +
+///   [`Clone`] stream.
+/// - [`DiscReader::new_from_non_cloneable_read`]: when you can provide a [`Read`] + [`Seek`]
+///   stream. (Accesses will be synchronized, limiting multithreaded performance.)
+pub trait DiscStream: DynClone + Send {
+    /// Reads the exact number of bytes required to fill `buf` from the given offset.
+    fn read_exact_at(&mut self, buf: &mut [u8], offset: u64) -> io::Result<()>;
 
-impl<T> DiscStream for T where T: Read + Seek + DynClone + Send + Sync + ?Sized {}
+    /// Returns the length of the stream in bytes.
+    fn stream_len(&mut self) -> io::Result<u64>;
+}
 
 dyn_clone::clone_trait_object!(DiscStream);
 
+impl<T> DiscStream for T
+where T: AsRef<[u8]> + Send + Clone
+{
+    fn read_exact_at(&mut self, buf: &mut [u8], offset: u64) -> io::Result<()> {
+        let data = self.as_ref();
+        let len = data.len() as u64;
+        let end = offset + buf.len() as u64;
+        if offset >= len || end > len {
+            return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
+        }
+        buf.copy_from_slice(&data[offset as usize..end as usize]);
+        Ok(())
+    }
+
+    fn stream_len(&mut self) -> io::Result<u64> { Ok(self.as_ref().len() as u64) }
+}
+
+#[derive(Debug, Clone)]
+pub(crate) struct CloneableStream<T>(pub T)
+where T: Read + Seek + Clone + Send;
+
+impl<T> CloneableStream<T>
+where T: Read + Seek + Clone + Send
+{
+    pub fn new(stream: T) -> Self { Self(stream) }
+}
+
+impl<T> DiscStream for CloneableStream<T>
+where T: Read + Seek + Clone + Send
+{
+    fn read_exact_at(&mut self, buf: &mut [u8], offset: u64) -> io::Result<()> {
+        self.0.seek(io::SeekFrom::Start(offset))?;
+        self.0.read_exact(buf)
+    }
+
+    fn stream_len(&mut self) -> io::Result<u64> { self.0.seek(io::SeekFrom::End(0)) }
+}
+
+#[derive(Debug)]
+pub(crate) struct NonCloneableStream<T>(pub Arc<Mutex<T>>)
+where T: Read + Seek + Send;
+
+impl<T> Clone for NonCloneableStream<T>
+where T: Read + Seek + Send
+{
+    fn clone(&self) -> Self { Self(self.0.clone()) }
+}
+
+impl<T> NonCloneableStream<T>
+where T: Read + Seek + Send
+{
+    pub fn new(stream: T) -> Self { Self(Arc::new(Mutex::new(stream))) }
+
+    fn lock(&self) -> io::Result<std::sync::MutexGuard<'_, T>> {
+        self.0.lock().map_err(|_| io::Error::other("NonCloneableStream mutex poisoned"))
+    }
+}
+
+impl<T> DiscStream for NonCloneableStream<T>
+where T: Read + Seek + Send
+{
+    fn read_exact_at(&mut self, buf: &mut [u8], offset: u64) -> io::Result<()> {
+        let mut stream = self.lock()?;
+        stream.seek(io::SeekFrom::Start(offset))?;
+        stream.read_exact(buf)
+    }
+
+    fn stream_len(&mut self) -> io::Result<u64> {
+        let mut stream = self.lock()?;
+        stream.seek(io::SeekFrom::End(0))
+    }
+}
+
 /// An open disc image and read stream.
 ///
 /// This is the primary entry point for reading disc images.
@@ -79,24 +173,44 @@ pub struct DiscReader {
 
 impl DiscReader {
     /// Opens a disc image from a file path.
-    #[inline]
     pub fn new<P: AsRef<Path>>(path: P, options: &DiscOptions) -> Result<DiscReader> {
         let io = block::open(path.as_ref())?;
         let inner = disc::reader::DiscReader::new(io, options)?;
         Ok(DiscReader { inner })
     }
 
-    /// Opens a disc image from a read stream.
-    #[inline]
+    /// Opens a disc image from a [`DiscStream`]. This allows low-overhead, multithreaded
+    /// access to disc images stored in memory, archives, or other non-file sources.
+    ///
+    /// See [`DiscStream`] for more information.
     pub fn new_stream(stream: Box<dyn DiscStream>, options: &DiscOptions) -> Result<DiscReader> {
         let io = block::new(stream)?;
         let reader = disc::reader::DiscReader::new(io, options)?;
         Ok(DiscReader { inner: reader })
     }
 
+    /// Opens a disc image from a [`Read`] + [`Seek`] stream that can be cloned.
+    ///
+    /// The stream will be cloned for each thread that reads from it, allowing for multithreaded
+    /// access (e.g. for preloading blocks during reading or parallel block processing during
+    /// conversion).
+    pub fn new_from_cloneable_read<R>(stream: R, options: &DiscOptions) -> Result<DiscReader>
+    where R: Read + Seek + Clone + Send + 'static {
+        Self::new_stream(Box::new(CloneableStream::new(stream)), options)
+    }
+
+    /// Opens a disc image from a [`Read`] + [`Seek`] stream that cannot be cloned.
+    ///
+    /// Multithreaded accesses will be synchronized, which will limit performance (e.g. for
+    /// preloading blocks during reading or parallel block processing during conversion).
+    pub fn new_from_non_cloneable_read<R>(stream: R, options: &DiscOptions) -> Result<DiscReader>
+    where R: Read + Seek + Send + 'static {
+        Self::new_stream(Box::new(NonCloneableStream::new(stream)), options)
+    }
+
     /// Detects the format of a disc image from a read stream.
     #[inline]
-    pub fn detect<R>(stream: &mut R) -> std::io::Result<Option<Format>>
+    pub fn detect<R>(stream: &mut R) -> io::Result<Option<Format>>
     where R: Read + ?Sized {
         block::detect(stream)
     }
@@ -155,7 +269,7 @@ impl DiscReader {
 
 impl BufRead for DiscReader {
     #[inline]
-    fn fill_buf(&mut self) -> std::io::Result<&[u8]> { self.inner.fill_buf() }
+    fn fill_buf(&mut self) -> io::Result<&[u8]> { self.inner.fill_buf() }
 
     #[inline]
     fn consume(&mut self, amt: usize) { self.inner.consume(amt) }
@@ -163,12 +277,12 @@ impl BufRead for DiscReader {
 
 impl Read for DiscReader {
     #[inline]
-    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> { self.inner.read(buf) }
+    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> { self.inner.read(buf) }
 }
 
 impl Seek for DiscReader {
     #[inline]
-    fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> { self.inner.seek(pos) }
+    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> { self.inner.seek(pos) }
 }
 
 /// Extra metadata about the underlying disc file format.
@@ -199,7 +313,7 @@ pub struct DiscMeta {
 }
 
 /// An open disc partition.
-pub trait PartitionReader: BufRead + DiscStream {
+pub trait PartitionReader: DynClone + BufRead + Seek + Send {
     /// Whether this is a Wii partition. (GameCube otherwise)
     fn is_wii(&self) -> bool;
 
@@ -246,10 +360,10 @@ impl dyn PartitionReader + '_ {
     ///     Ok(())
     /// }
     /// ```
-    pub fn open_file(&mut self, node: Node) -> std::io::Result<FileReader<'_>> {
+    pub fn open_file(&mut self, node: Node) -> io::Result<FileReader<'_>> {
         if !node.is_file() {
-            return Err(std::io::Error::new(
-                std::io::ErrorKind::InvalidInput,
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
                 "Node is not a file".to_string(),
             ));
         }
@@ -279,21 +393,20 @@ impl dyn PartitionReader {
     ///     let fst = meta.fst()?;
     ///     if let Some((_, node)) = fst.find("/disc.tgc") {
     ///         let file: OwnedFileReader = partition
-    ///             .clone() // Clone the Box
     ///             .into_open_file(node) // Get an OwnedFileStream
    ///             .expect("Failed to open file stream");
     ///         // Open the inner disc image using the owned stream
-    ///         let inner_disc = DiscReader::new_stream(Box::new(file), &DiscOptions::default())
+    ///         let inner_disc = DiscReader::new_from_cloneable_read(file, &DiscOptions::default())
     ///             .expect("Failed to open inner disc");
     ///         // ...
/// }
 ///     Ok(())
 /// }
 /// ```
-    pub fn into_open_file(self: Box<Self>, node: Node) -> std::io::Result<OwnedFileReader> {
+    pub fn into_open_file(self: Box<Self>, node: Node) -> io::Result<OwnedFileReader> {
         if !node.is_file() {
-            return Err(std::io::Error::new(
-                std::io::ErrorKind::InvalidInput,
+            return Err(io::Error::new(
+                io::ErrorKind::InvalidInput,
                 "Node is not a file".to_string(),
             ));
         }
diff --git a/nod/src/util/compress.rs b/nod/src/util/compress.rs
index 71e061a..31969cd 100644
--- a/nod/src/util/compress.rs
+++ b/nod/src/util/compress.rs
@@ -1,4 +1,4 @@
-use std::{io, io::Read};
+use std::io;
 
 use tracing::instrument;
 
@@ -182,31 +182,6 @@ impl DecompressionKind {
             comp => Err(Error::DiscFormat(format!("Unsupported WIA/RVZ compression: {:?}", comp))),
         }
     }
-
-    pub fn wrap<'a, R>(&mut self, reader: R) -> io::Result<Box<dyn Read + 'a>>
-    where R: Read + 'a {
-        Ok(match self {
-            DecompressionKind::None => Box::new(reader),
-            #[cfg(feature = "compress-zlib")]
-            DecompressionKind::Deflate => unimplemented!("DecompressionKind::Deflate.wrap"),
-            #[cfg(feature = "compress-bzip2")]
-            DecompressionKind::Bzip2 => Box::new(bzip2::read::BzDecoder::new(reader)),
-            #[cfg(feature = "compress-lzma")]
-            DecompressionKind::Lzma(data) => {
-                use lzma_util::{lzma_props_decode, new_lzma_decoder};
-                let stream = new_lzma_decoder(&lzma_props_decode(data)?)?;
-                Box::new(liblzma::read::XzDecoder::new_stream(reader, stream))
-            }
-            #[cfg(feature = "compress-lzma")]
-            DecompressionKind::Lzma2(data) => {
-                use lzma_util::{lzma2_props_decode, new_lzma2_decoder};
-                let stream = new_lzma2_decoder(&lzma2_props_decode(data)?)?;
-                Box::new(liblzma::read::XzDecoder::new_stream(reader, stream))
-            }
-            #[cfg(feature = "compress-zstd")]
-            DecompressionKind::Zstandard => Box::new(zstd::stream::Decoder::new(reader)?),
-        })
-    }
 }
 
 pub struct Compressor {
diff --git a/nod/src/util/lfg.rs b/nod/src/util/lfg.rs
index 8bd6449..8b27a86 100644
--- a/nod/src/util/lfg.rs
+++ b/nod/src/util/lfg.rs
@@ -77,7 +77,7 @@ impl LaggedFibonacci {
         out[16] ^= (out[0] >> 9) ^ (out[16] << 23);
     }
 
-    /// Same as [`generate_seed`], but ensures the resulting seed is big-endian.
+    /// Same as [`Self::generate_seed`], but ensures the resulting seed is big-endian.
     pub fn generate_seed_be(
         out: &mut [u32; SEED_SIZE],
         disc_id: [u8; 4],
diff --git a/nod/src/util/read.rs b/nod/src/util/read.rs
index 51052e0..e3037c3 100644
--- a/nod/src/util/read.rs
+++ b/nod/src/util/read.rs
@@ -2,6 +2,8 @@ use std::{io, io::Read, sync::Arc};
 
 use zerocopy::{FromBytes, FromZeros, IntoBytes};
 
+use crate::read::DiscStream;
+
 #[inline(always)]
 pub fn read_from<T, R>(reader: &mut R) -> io::Result<T>
 where
     T: FromBytes + IntoBytes,
     R: Read + ?Sized,
 {
     let mut ret = <T>::new_zeroed();
     reader.read_exact(ret.as_mut_bytes())?;
     Ok(ret)
 }
 
+#[inline(always)]
+pub fn read_at<T, R>(reader: &mut R, offset: u64) -> io::Result<T>
+where
+    T: FromBytes + IntoBytes,
+    R: DiscStream + ?Sized,
+{
+    let mut ret = <T>::new_zeroed();
+    reader.read_exact_at(ret.as_mut_bytes(), offset)?;
+    Ok(ret)
+}
+
 #[inline(always)]
 pub fn read_vec<T, R>(reader: &mut R, count: usize) -> io::Result<Vec<T>>
 where
     T: FromBytes + IntoBytes,
     R: Read + ?Sized,
 {
     let mut ret =
         <T>::new_vec_zeroed(count).map_err(|_| io::Error::from(io::ErrorKind::OutOfMemory))?;
     reader.read_exact(ret.as_mut_slice().as_mut_bytes())?;
     Ok(ret)
 }
 
+#[inline(always)]
+pub fn read_vec_at<T, R>(reader: &mut R, count: usize, offset: u64) -> io::Result<Vec<T>>
+where
+    T: FromBytes + IntoBytes,
+    R: DiscStream + ?Sized,
+{
+    let mut ret =
+        <T>::new_vec_zeroed(count).map_err(|_| io::Error::from(io::ErrorKind::OutOfMemory))?;
+    reader.read_exact_at(ret.as_mut_slice().as_mut_bytes(), offset)?;
+    Ok(ret)
+}
+
 #[inline(always)]
 pub fn read_box<T, R>(reader: &mut R) -> io::Result<Box<T>>
 where
     T: FromBytes + IntoBytes,
     R: Read + ?Sized,
 {
     let mut ret = <T>::new_box_zeroed().map_err(|_| io::Error::from(io::ErrorKind::OutOfMemory))?;
     reader.read_exact(ret.as_mut().as_mut_bytes())?;
     Ok(ret)
 }
 
+#[inline(always)]
+pub fn read_box_at<T, R>(reader: &mut R, offset: u64) -> io::Result<Box<T>>
+where
+    T: FromBytes + IntoBytes,
+    R: DiscStream + ?Sized,
+{
+    let mut ret = <T>::new_box_zeroed().map_err(|_| io::Error::from(io::ErrorKind::OutOfMemory))?;
+    reader.read_exact_at(ret.as_mut().as_mut_bytes(), offset)?;
+    Ok(ret)
+}
+
 #[inline(always)]
 pub fn read_arc<T, R>(reader: &mut R) -> io::Result<Arc<T>>
 where
     T: FromBytes + IntoBytes,
     R: Read + ?Sized,
 {
     // TODO use Arc::new_zeroed once it's stable
     read_box(reader).map(Arc::from)
 }
 
+#[inline(always)]
+pub fn read_arc_at<T, R>(reader: &mut R, offset: u64) -> io::Result<Arc<T>>
+where
+    T: FromBytes + IntoBytes,
+    R: DiscStream + ?Sized,
+{
+    // TODO use Arc::new_zeroed once it's stable
+    read_box_at(reader, offset).map(Arc::from)
+}
+
 #[inline(always)]
 pub fn read_box_slice<T, R>(reader: &mut R, count: usize) -> io::Result<Box<[T]>>
 where
     T: FromBytes + IntoBytes,
     R: Read + ?Sized,
 {
     let mut ret = <[T]>::new_box_zeroed_with_elems(count)
         .map_err(|_| io::Error::from(io::ErrorKind::OutOfMemory))?;
     reader.read_exact(ret.as_mut().as_mut_bytes())?;
     Ok(ret)
 }
 
+#[inline(always)]
+pub fn read_box_slice_at<T, R>(reader: &mut R, count: usize, offset: u64) -> io::Result<Box<[T]>>
+where
+    T: FromBytes + IntoBytes,
+    R: DiscStream + ?Sized,
+{
+    let mut ret = <[T]>::new_box_zeroed_with_elems(count)
+        .map_err(|_| io::Error::from(io::ErrorKind::OutOfMemory))?;
+    reader.read_exact_at(ret.as_mut().as_mut_bytes(), offset)?;
+    Ok(ret)
+}
+
 #[inline(always)]
 pub fn read_arc_slice<T, R>(reader: &mut R, count: usize) -> io::Result<Arc<[T]>>
 where
     T: FromBytes + IntoBytes,
     R: Read + ?Sized,
 {
     // TODO use Arc::new_zeroed once it's stable
     read_box_slice(reader, count).map(Arc::from)
 }
 
+#[inline(always)]
+pub fn read_arc_slice_at<T, R>(reader: &mut R, count: usize, offset: u64) -> io::Result<Arc<[T]>>
+where
+    T: FromBytes + IntoBytes,
+    R: DiscStream + ?Sized,
+{
+    // TODO use Arc::new_zeroed once it's stable
+    read_box_slice_at(reader, count, offset).map(Arc::from)
+}
+
 #[inline(always)]
 pub fn read_u16_be<R>(reader: &mut R) -> io::Result<u16>
 where R: Read + ?Sized {
@@ -114,3 +182,30 @@ where T: IntoBytes {
     let sp = unsafe { std::slice::from_raw_parts_mut(p as *mut u8, size_of::<T>()) };
     unsafe { Box::from_raw(sp) }
 }
+
+pub fn read_into_box_slice<T, E>(
+    count: usize,
+    init: impl FnOnce(&mut [u8]) -> Result<(), E>,
+) -> Result<Box<[T]>, E>
+where
+    T: FromBytes + IntoBytes,
+{
+    let mut out = <[T]>::new_box_zeroed_with_elems(count).unwrap();
+    init(out.as_mut_bytes())?;
+    Ok(out)
+}
+
+pub fn read_into_arc_slice<T, E>(
+    count: usize,
+    init: impl FnOnce(&mut [u8]) -> Result<(), E>,
+) -> Result<Arc<[T]>, E>
+where
+    T: FromBytes + IntoBytes,
+{
+    let mut arc = Arc::<[T]>::new_uninit_slice(count);
+    let ptr = Arc::get_mut(&mut arc).unwrap().as_mut_ptr() as *mut u8;
+    let slice = unsafe { std::slice::from_raw_parts_mut(ptr, count * size_of::<T>()) };
+    slice.fill(0);
+    init(slice)?;
+    Ok(unsafe { arc.assume_init() })
+}
diff --git a/nod/src/write.rs b/nod/src/write.rs
index d74bab1..92c80d9 100644
--- a/nod/src/write.rs
+++ b/nod/src/write.rs
@@ -110,7 +110,7 @@ impl DiscWriter {
     #[inline]
     pub fn process(
         &self,
-        mut data_callback: impl FnMut(Bytes, u64, u64) -> std::io::Result<()> + Send,
+        mut data_callback: impl FnMut(Bytes, u64, u64) -> std::io::Result<()>,
         options: &ProcessOptions,
     ) -> Result<DiscFinalization> {
         self.inner.process(&mut data_callback, options)
diff --git a/nodtool/src/cmd/gen.rs b/nodtool/src/cmd/gen.rs
index 8347fbd..40cd501 100644
--- a/nodtool/src/cmd/gen.rs
+++ b/nodtool/src/cmd/gen.rs
@@ -570,7 +570,7 @@ fn in_memory_test(
             Ok(())
         }
     })?;
-    let disc_stream = writer.into_stream(PartitionFileReader { partition, meta })?;
+    let disc_stream = writer.into_cloneable_stream(PartitionFileReader { partition, meta })?;
     let disc_reader = DiscReader::new_stream(disc_stream, &DiscOptions::default())?;
     let disc_writer = DiscWriter::new(disc_reader, &FormatOptions::default())?;
     let process_options = ProcessOptions { digest_crc32: true, ..Default::default() };
diff --git a/nodtool/src/util/shared.rs b/nodtool/src/util/shared.rs
index b98e285..93d8218 100644
--- a/nodtool/src/util/shared.rs
+++ b/nodtool/src/util/shared.rs
@@ -1,7 +1,7 @@
 use std::{
     fmt,
     fs::File,
-    io::{Seek, SeekFrom, Write},
+    io::{Seek, Write},
     path::Path,
 };
 
@@ -130,7 +130,7 @@ pub fn convert_and_verify(
     // Finalize disc writer
     if !finalization.header.is_empty() {
         if let Some(file) = &mut file {
-            file.seek(SeekFrom::Start(0)).context("Seeking to start of output file")?;
+            file.rewind().context("Seeking to start of output file")?;
             file.write_all(finalization.header.as_ref()).context("Writing header")?;
         } else {
             return Err(nod::Error::Other("No output file, but requires finalization".to_string()));
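---
The sketches below are commentary on the reworked API, not part of the diff.

The ordered fan-out/fan-in used by par_process is worth seeing in isolation: workers drain an index channel, results cross a rendezvous channel, and the consumer restores order with a sorted VecDeque. A self-contained sketch under the same crossbeam-channel assumption — `process` is a stand-in closure, not the crate's BlockProcessor:

use std::collections::VecDeque;

fn process_ordered<F>(count: u32, threads: usize, process: F)
where F: Fn(u32) -> u32 + Send + Sync {
    std::thread::scope(|s| {
        // Queue every block index up front, then disconnect the channel.
        let (idx_tx, idx_rx) = crossbeam_channel::bounded(count as usize);
        for i in 0..count {
            idx_tx.send(i).unwrap();
        }
        drop(idx_tx);

        // Rendezvous channel: workers block until the consumer is ready.
        let (out_tx, out_rx) = crossbeam_channel::bounded::<(u32, u32)>(0);
        for _ in 0..threads {
            let idx_rx = idx_rx.clone();
            let out_tx = out_tx.clone();
            let process = &process;
            s.spawn(move || {
                while let Ok(i) = idx_rx.recv() {
                    if out_tx.send((i, process(i))).is_err() {
                        break; // Consumer hung up
                    }
                }
            });
        }
        drop(out_tx); // recv() below ends once all workers finish

        // Consume results in block order, buffering out-of-order arrivals.
        let (mut next, mut pending) = (0u32, VecDeque::new());
        while let Ok((i, value)) = out_rx.recv() {
            if i == next {
                println!("block {i}: {value}");
                next += 1;
                while pending.front().is_some_and(|&(j, _)| j == next) {
                    let (j, v) = pending.pop_front().unwrap();
                    println!("block {j}: {v}");
                    next += 1;
                }
            } else {
                // Insert sorted; duplicates cannot occur since each index is sent once.
                let at = pending.binary_search_by_key(&i, |&(j, _)| j).unwrap_err();
                pending.insert(at, (i, value));
            }
        }
    });
}

The zero-capacity result channel keeps workers from racing far ahead of the consumer, so memory stays bounded by the out-of-order window rather than the whole disc.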
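Implementing DiscStream by hand takes only the two methods, as the SplitFileReader impl shows. A sketch for exposing an image embedded at a fixed offset inside a larger container — `WindowedStream` is hypothetical and not provided by nod:

use std::io;
use nod::read::DiscStream;

// Exposes `len` bytes starting at `start` of an inner stream as its own
// DiscStream. Clone + Send satisfy the trait's DynClone/Send bounds.
#[derive(Clone)]
struct WindowedStream<T> {
    inner: T,
    start: u64,
    len: u64,
}

impl<T: DiscStream + Clone> DiscStream for WindowedStream<T> {
    fn read_exact_at(&mut self, buf: &mut [u8], offset: u64) -> io::Result<()> {
        if offset + buf.len() as u64 > self.len {
            return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
        }
        // Translate window-relative offsets into inner-stream offsets.
        self.inner.read_exact_at(buf, self.start + offset)
    }

    fn stream_len(&mut self) -> io::Result<u64> { Ok(self.len) }
}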
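Because of the blanket `impl<T> DiscStream for T where T: AsRef<[u8]> + Send + Clone`, an image already in memory needs no wrapper at all. A minimal sketch — the function and its error handling are illustrative, assuming the `nod::read` paths shown in this diff:

use nod::read::{DiscOptions, DiscReader};

// Vec<u8> is AsRef<[u8]> + Send + Clone, so it is a DiscStream as-is.
// Note that cloning a Vec copies the buffer per thread; Arc<[u8]> also
// qualifies and makes clones a cheap refcount bump for large images.
fn open_in_memory(data: Vec<u8>) -> nod::Result<DiscReader> {
    DiscReader::new_stream(Box::new(data), &DiscOptions::default())
}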
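The two new Read + Seek constructors split along the Clone axis. A sketch with placeholder paths — Cursor is cloneable, File is not:

use std::{fs::File, io::Cursor};
use nod::read::{DiscOptions, DiscReader};

fn open_streams() -> nod::Result<()> {
    // Cursor<Vec<u8>> is Read + Seek + Clone: every preloader or processing
    // thread gets its own independent cursor, so reads never contend.
    let cursor = Cursor::new(std::fs::read("game.iso").expect("read failed"));
    let _cloneable = DiscReader::new_from_cloneable_read(cursor, &DiscOptions::default())?;

    // File is Read + Seek + Send but not Clone: accesses are serialized
    // through a mutex. Prefer DiscReader::new for plain paths; this variant
    // suits readers that genuinely cannot be duplicated.
    let file = File::open("game.iso").expect("open failed");
    let _synced = DiscReader::new_from_non_cloneable_read(file, &DiscOptions::default())?;
    Ok(())
}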