use std::path::{Path, PathBuf};
use std::fs::File;
use std::io::Read;
use std::thread;
use std::cmp::min;
use std::sync::Arc;

extern crate failure;
use failure::Fallible as Result;

extern crate tempfile;

extern crate sequoia_openpgp as openpgp;
use openpgp::Packet;
use openpgp::parse::{PacketParser, PacketParserResult, Parse};

extern crate hagrid_database as database;
use database::{Database, KeyDatabase, ImportResult};

use indicatif::{MultiProgress, ProgressBar, ProgressStyle};

use HagridConfig;

// Parsing TPKs takes time, so we benefit from some parallelism. However, the
// database is locked during the entire merge operation, so we get diminishing
// returns after the first few threads.
const NUM_THREADS_MAX: usize = 3;

pub fn do_import(config: &HagridConfig, dry_run: bool, input_files: Vec<PathBuf>) -> Result<()> {
    // Use at least one thread, even for an empty input list: `setup_chunks`
    // divides by the thread count.
    let num_threads = min(NUM_THREADS_MAX, input_files.len()).max(1);
    let input_file_chunks = setup_chunks(input_files, num_threads);

    let multi_progress = Arc::new(MultiProgress::new());
    let progress_bar = multi_progress.add(ProgressBar::new(0));

    let threads: Vec<_> = input_file_chunks
        .into_iter()
        .map(|input_file_chunk| {
            let config = config.clone();
            let multi_progress = multi_progress.clone();
            thread::spawn(move || {
                import_from_files(
                    &config, dry_run, input_file_chunk, multi_progress).unwrap();
            })
        })
        .collect();

    eprintln!("Importing in {} threads", num_threads);

    thread::spawn(move || multi_progress.join().unwrap());
    threads.into_iter().for_each(|t| t.join().unwrap());
    progress_bar.finish();

    Ok(())
}

// Split the input files into one contiguous chunk per thread, using ceiling
// division so that any remainder goes to the earlier chunks.
fn setup_chunks(
    mut input_files: Vec<PathBuf>,
    num_threads: usize,
) -> Vec<Vec<PathBuf>> {
    let chunk_size = (input_files.len() + (num_threads - 1)) / num_threads;
    (0..num_threads)
        .map(|_| {
            let len = input_files.len();
            input_files.drain(0..min(chunk_size, len)).collect()
        })
        .collect()
}
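// A minimal sketch of the chunking behaviour, under the assumption that the
// ceiling division above is meant to spread the remainder across the first
// chunks: five files on two threads become one chunk of three and one of two.
#[cfg(test)]
mod chunk_tests {
    use super::*;

    #[test]
    fn chunks_cover_all_files() {
        let files: Vec<PathBuf> = (0..5)
            .map(|i| PathBuf::from(format!("{}.pgp", i)))
            .collect();
        let chunks = setup_chunks(files, 2);
        // ceil(5 / 2) = 3, so the two chunks have sizes 3 and 2.
        assert_eq!(chunks.iter().map(|c| c.len()).collect::<Vec<_>>(),
                   vec![3, 2]);
    }
}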
struct ImportStats<'a> {
    progress: &'a ProgressBar,
    filename: String,
    count_total: u64,
    count_err: u64,
    count_new: u64,
    count_updated: u64,
    count_unchanged: u64,
}

impl<'a> ImportStats<'a> {
    fn new(progress: &'a ProgressBar, filename: String) -> Self {
        ImportStats {
            progress,
            filename,
            count_total: 0,
            count_err: 0,
            count_new: 0,
            count_updated: 0,
            count_unchanged: 0,
        }
    }

    fn update(&mut self, result: Result<ImportResult>) {
        self.count_total += 1;
        match result {
            Err(_) => self.count_err += 1,
            Ok(ImportResult::New(_)) => self.count_new += 1,
            Ok(ImportResult::Updated(_)) => self.count_updated += 1,
            Ok(ImportResult::Unchanged(_)) => self.count_unchanged += 1,
        }
        self.progress_update();
    }

    fn progress_update(&self) {
        // Only redraw every tenth key to keep progress bar overhead low.
        if (self.count_total % 10) != 0 {
            return;
        }
        self.progress.set_message(&format!(
            "{}, imported {:5} keys, {:5} New {:5} Updated {:5} Unchanged {:5} Errors",
            &self.filename,
            self.count_total,
            self.count_new,
            self.count_updated,
            self.count_unchanged,
            self.count_err));
    }
}

fn import_from_files(
    config: &HagridConfig,
    dry_run: bool,
    input_files: Vec<PathBuf>,
    multi_progress: Arc<MultiProgress>,
) -> Result<()> {
    let db = KeyDatabase::new_internal(
        config.keys_internal_dir.as_ref().unwrap(),
        config.keys_external_dir.as_ref().unwrap(),
        config.tmp_dir.as_ref().unwrap(),
        dry_run,
    )?;

    for input_file in input_files {
        import_from_file(&db, &input_file, &multi_progress)?;
    }

    Ok(())
}

fn import_from_file(db: &KeyDatabase, input: &Path, multi_progress: &MultiProgress) -> Result<()> {
    let input_file = File::open(input)?;
    let bytes_total = input_file.metadata()?.len();
    let progress_bar = multi_progress.add(ProgressBar::new(bytes_total));
    progress_bar
        .set_style(ProgressStyle::default_bar()
            .template("[{elapsed_precise}] {bar:40.cyan/blue} {msg}")
            .progress_chars("##-"));
    progress_bar.set_message("Starting…");

    let input_reader = &mut progress_bar.wrap_read(input_file);
    let filename = input.file_name().unwrap().to_string_lossy().to_string();
    let mut stats = ImportStats::new(&progress_bar, filename);

    read_file_to_tpks(input_reader, &mut |acc| {
        let result = import_key(&db, acc);
        if let Err(ref e) = result {
            progress_bar.println(e.to_string());
        }
        stats.update(result);
    })?;

    progress_bar.finish_and_clear();
    Ok(())
}

fn read_file_to_tpks(
    reader: impl Read,
    callback: &mut impl FnMut(Vec<Packet>) -> (),
) -> Result<()> {
    let mut ppr = PacketParser::from_reader(reader)?;
    let mut acc = Vec::new();

    // Iterate over all packets.
    while let PacketParserResult::Some(pp) = ppr {
        // Get the packet and advance the parser.
        let (packet, tmp) = pp.next()?;
        ppr = tmp;

        // If a new TPK starts, hand the accumulated packets to the callback.
        if !acc.is_empty() {
            if let Packet::PublicKey(_) | Packet::SecretKey(_) = packet {
                callback(acc);
                acc = vec!();
            }
        }
        acc.push(packet);
    }

    // Flush the last TPK: it is never followed by another primary key packet,
    // so the loop above never hands it to the callback.
    if !acc.is_empty() {
        callback(acc);
    }

    Ok(())
}

fn import_key(db: &KeyDatabase, packets: Vec<Packet>) -> Result<ImportResult> {
    let packet_pile = openpgp::PacketPile::from(packets);
    openpgp::TPK::from_packet_pile(packet_pile)
        .and_then(|tpk| db.merge(tpk))
}
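// A hedged sketch of the splitting logic in `read_file_to_tpks`: serialize two
// generated TPKs back to back (mirroring the key generation in the disabled
// test below) and check that the callback fires once per TPK. The
// `TPKBuilder::autocrypt` and `Serialize` calls follow the same
// sequoia-openpgp 0.x usage as the rest of this file.
#[cfg(test)]
mod split_tests {
    use super::*;
    use openpgp::serialize::Serialize;

    #[test]
    fn splits_keyring_into_tpks() {
        let mut keyring = Vec::new();
        for addr in &["a@invalid.example.com", "b@invalid.example.com"] {
            let (tpk, _) = openpgp::tpk::TPKBuilder::autocrypt(
                None, Some((*addr).into()))
                .generate().unwrap();
            tpk.serialize(&mut keyring).unwrap();
        }

        let mut count = 0;
        read_file_to_tpks(&keyring[..], &mut |_| count += 1).unwrap();
        assert_eq!(count, 2);
    }
}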
// Disabled: this test predates the current signature of `do_import`.
/*
#[cfg(test)]
mod import_tests {
    use std::fs::File;
    use tempfile::tempdir;
    use openpgp::serialize::Serialize;
    use super::*;

    #[test]
    fn import() {
        let root = tempdir().unwrap();
        let db = KeyDatabase::new_from_base(root.path().to_path_buf()).unwrap();

        // Generate a key and import it.
        let (tpk, _) = openpgp::tpk::TPKBuilder::autocrypt(
            None, Some("foo@invalid.example.com".into()))
            .generate().unwrap();
        let import_me = root.path().join("import-me");
        tpk.serialize(&mut File::create(&import_me).unwrap()).unwrap();

        do_import(root.path().to_path_buf(), vec![import_me]).unwrap();

        let check = |query: &str| {
            let tpk_ = db.lookup(&query.parse().unwrap()).unwrap().unwrap();
            assert_eq!(tpk.fingerprint(), tpk_.fingerprint());
            assert_eq!(tpk.subkeys().map(|skb| skb.subkey().fingerprint())
                           .collect::<Vec<_>>(),
                       tpk_.subkeys().map(|skb| skb.subkey().fingerprint())
                           .collect::<Vec<_>>());
            assert_eq!(tpk_.userids().count(), 0);
        };

        check(&format!("{}", tpk.primary().fingerprint()));
        check(&format!("{}", tpk.primary().fingerprint().to_keyid()));
        check(&format!("{}", tpk.subkeys().nth(0).unwrap().subkey()
                       .fingerprint()));
        check(&format!("{}", tpk.subkeys().nth(0).unwrap().subkey()
                       .fingerprint().to_keyid()));
    }
}
*/
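// A small sketch exercising the `ImportStats` counters in isolation, assuming
// only APIs already used above plus `ProgressBar::hidden()`, which draws
// nothing and keeps test output clean. `ImportResult` values come from the
// database, so only the error path is constructed here, via
// `failure::err_msg`.
#[cfg(test)]
mod stats_tests {
    use super::*;

    #[test]
    fn stats_count_errors_and_totals() {
        let progress = ProgressBar::hidden();
        let mut stats = ImportStats::new(&progress, "test.pgp".to_owned());

        stats.update(Err(failure::err_msg("unparseable packet")));
        stats.update(Err(failure::err_msg("another unparseable packet")));

        assert_eq!(stats.count_total, 2);
        assert_eq!(stats.count_err, 2);
        assert_eq!(stats.count_new, 0);
        assert_eq!(stats.count_updated, 0);
        assert_eq!(stats.count_unchanged, 0);
    }
}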