diff --git a/src/lib.rs b/src/lib.rs index c3123cb..80daade 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,8 +2,10 @@ use std::collections::BTreeMap; use std::fs::{self, File, Metadata}; use std::io::{self, BufReader, Read, Write}; use std::path::{Path, PathBuf}; +use std::sync::mpsc; +use std::thread; -use ignore::WalkBuilder; +use ignore::{WalkBuilder, WalkState}; use rayon::prelude::*; use serde::{Serialize, Serializer}; @@ -19,6 +21,7 @@ pub struct ScanConfig { pub hash_bytes: u64, pub follow_links: bool, pub verify_full: bool, + pub threads: Option, } #[derive(Debug, Clone, Serialize)] @@ -26,6 +29,7 @@ pub struct ScanReport { #[serde(serialize_with = "serialize_paths")] pub scanned_paths: Vec, pub hash_bytes: u64, + pub worker_threads: usize, pub followed_symlinks: bool, pub full_verification: bool, pub summary: ScanSummary, @@ -135,6 +139,15 @@ enum HashOutcome { Issue(ScanIssue), } +#[derive(Debug, Clone)] +enum ScannedEntry { + File(FileEntry), + Directory, + Symlink(SymlinkInfo), + Special(SpecialEntry), + Issue(ScanIssue), +} + pub fn parse_byte_count(input: &str) -> Result { let trimmed = input.trim(); if trimmed.is_empty() { @@ -184,6 +197,7 @@ pub fn parse_byte_count(input: &str) -> Result { pub fn scan_paths(config: ScanConfig) -> ScanReport { let hash_bytes = config.hash_bytes.max(1); + let worker_threads = worker_threads(config.threads); let mut files = Vec::new(); let mut symlinks = Vec::new(); let mut special_entries = Vec::new(); @@ -194,6 +208,7 @@ pub fn scan_paths(config: ScanConfig) -> ScanReport { for root in &config.paths { let mut builder = WalkBuilder::new(root); builder + .threads(worker_threads) .follow_links(config.follow_links) .hidden(false) .ignore(false) @@ -202,57 +217,31 @@ pub fn scan_paths(config: ScanConfig) -> ScanReport { .git_exclude(false) .parents(false); - for entry in builder.build() { - match entry { - Ok(entry) => { - let path = entry.path().to_path_buf(); - let metadata = match fs::symlink_metadata(&path) { - Ok(metadata) => metadata, - Err(error) => { - errors.push(issue(path, format!("could not read metadata: {error}"))); - continue; - } - }; - - if metadata.file_type().is_symlink() { - symlinks.push(describe_symlink(&path)); - if !config.follow_links { - continue; - } - - match fs::metadata(&path) { - Ok(target_metadata) => { - process_non_symlink_entry( - path, - &target_metadata, - &mut files, - &mut special_entries, - &mut directories, - &mut total_file_bytes, - ); - } - Err(error) => { - errors.push(issue( - path, - format!("could not follow symlink target: {error}"), - )); - } - } - } else { - process_non_symlink_entry( - path, - &metadata, - &mut files, - &mut special_entries, - &mut directories, - &mut total_file_bytes, - ); + let (sender, receiver) = mpsc::channel(); + builder.build_parallel().run(|| { + let sender = sender.clone(); + let follow_links = config.follow_links; + Box::new(move |entry| { + for scanned_entry in classify_walk_entry(entry, follow_links) { + if sender.send(scanned_entry).is_err() { + return WalkState::Quit; } } - Err(error) => { - errors.push(issue(PathBuf::from(""), error.to_string())); - } - } + WalkState::Continue + }) + }); + drop(sender); + + for scanned_entry in receiver { + collect_scanned_entry( + scanned_entry, + &mut files, + &mut symlinks, + &mut special_entries, + &mut errors, + &mut directories, + &mut total_file_bytes, + ); } } @@ -301,6 +290,7 @@ pub fn scan_paths(config: ScanConfig) -> ScanReport { ScanReport { scanned_paths: config.paths, hash_bytes, + worker_threads, followed_symlinks: config.follow_links, full_verification: config.verify_full, summary: ScanSummary { @@ -327,30 +317,93 @@ pub fn scan_paths(config: ScanConfig) -> ScanReport { } } -fn process_non_symlink_entry( - path: PathBuf, - metadata: &Metadata, - files: &mut Vec, - special_entries: &mut Vec, - directories: &mut usize, - total_file_bytes: &mut u64, -) { +fn worker_threads(configured_threads: Option) -> usize { + configured_threads.unwrap_or_else(|| { + thread::available_parallelism() + .map(usize::from) + .unwrap_or(1) + }) +} + +fn classify_walk_entry( + entry: Result, + follow_links: bool, +) -> Vec { + match entry { + Ok(entry) => classify_path(entry.path().to_path_buf(), follow_links), + Err(error) => vec![ScannedEntry::Issue(issue( + PathBuf::from(""), + error.to_string(), + ))], + } +} + +fn classify_path(path: PathBuf, follow_links: bool) -> Vec { + let metadata = match fs::symlink_metadata(&path) { + Ok(metadata) => metadata, + Err(error) => { + return vec![ScannedEntry::Issue(issue( + path, + format!("could not read metadata: {error}"), + ))]; + } + }; + + if !metadata.file_type().is_symlink() { + return vec![non_symlink_entry(path, &metadata)]; + } + + let mut entries = vec![ScannedEntry::Symlink(describe_symlink(&path))]; + if follow_links { + match fs::metadata(&path) { + Ok(target_metadata) => entries.push(non_symlink_entry(path, &target_metadata)), + Err(error) => entries.push(ScannedEntry::Issue(issue( + path, + format!("could not follow symlink target: {error}"), + ))), + } + } + + entries +} + +fn non_symlink_entry(path: PathBuf, metadata: &Metadata) -> ScannedEntry { let file_type = metadata.file_type(); if file_type.is_file() { - *total_file_bytes = total_file_bytes.saturating_add(metadata.len()); - files.push(FileEntry { + ScannedEntry::File(FileEntry { path, size: metadata.len(), device: metadata.dev(), inode: metadata.ino(), - }); + }) } else if file_type.is_dir() { - *directories += 1; + ScannedEntry::Directory } else { - special_entries.push(SpecialEntry { + ScannedEntry::Special(SpecialEntry { path, kind: special_entry_kind(&file_type), - }); + }) + } +} + +fn collect_scanned_entry( + entry: ScannedEntry, + files: &mut Vec, + symlinks: &mut Vec, + special_entries: &mut Vec, + errors: &mut Vec, + directories: &mut usize, + total_file_bytes: &mut u64, +) { + match entry { + ScannedEntry::File(file) => { + *total_file_bytes = total_file_bytes.saturating_add(file.size); + files.push(file); + } + ScannedEntry::Directory => *directories += 1, + ScannedEntry::Symlink(symlink) => symlinks.push(symlink), + ScannedEntry::Special(special_entry) => special_entries.push(special_entry), + ScannedEntry::Issue(error) => errors.push(error), } } @@ -555,6 +608,7 @@ pub fn write_human_report(mut writer: impl Write, report: &ScanReport) -> io::Re join_paths(&report.scanned_paths) )?; writeln!(writer, "Hash window: {}", format_bytes(report.hash_bytes))?; + writeln!(writer, "Worker threads: {}", report.worker_threads)?; writeln!( writer, "Symlink traversal: {}", @@ -874,6 +928,7 @@ mod tests { hash_bytes: 3, follow_links: false, verify_full: false, + threads: None, }); assert_eq!(report.summary.files, 3); @@ -897,6 +952,7 @@ mod tests { hash_bytes: 3, follow_links: false, verify_full: true, + threads: None, }); assert_eq!(report.possible_duplicates.len(), 1); @@ -919,6 +975,7 @@ mod tests { hash_bytes: DEFAULT_HASH_BYTES, follow_links: false, verify_full: false, + threads: None, }); assert_eq!(report.summary.files, 1); @@ -941,6 +998,7 @@ mod tests { hash_bytes: DEFAULT_HASH_BYTES, follow_links: false, verify_full: false, + threads: None, }); assert_eq!(report.summary.files, 2); @@ -965,6 +1023,7 @@ mod tests { hash_bytes: DEFAULT_HASH_BYTES, follow_links: false, verify_full: false, + threads: None, }); let json = serde_json::to_string(&report).expect("serialize report with lossy path"); @@ -977,6 +1036,7 @@ mod tests { let report = ScanReport { scanned_paths: vec![PathBuf::from(".")], hash_bytes: DEFAULT_HASH_BYTES, + worker_threads: 1, followed_symlinks: false, full_verification: false, summary: ScanSummary { diff --git a/src/main.rs b/src/main.rs index 2a59242..68857f5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -31,7 +31,7 @@ struct Cli { #[arg(long)] verify_full: bool, - /// Number of worker threads used for hashing. Defaults to Rayon automatic sizing. + /// Number of worker threads used for scanning and hashing. Defaults to CPU parallelism. #[arg(long, value_parser = parse_thread_count)] threads: Option, @@ -72,6 +72,7 @@ fn main() -> anyhow::Result { hash_bytes: cli.hash_bytes, follow_links: cli.follow_links, verify_full: cli.verify_full, + threads: cli.threads, }); let stdout = io::stdout();