diff --git a/src/lib.rs b/src/lib.rs index 569695d..fbae67a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,9 @@ use std::fmt; use std::fs::File; use std::hash::{BuildHasher, BuildHasherDefault, Hash, Hasher}; use std::io::{BufReader, BufWriter, Write}; +use std::sync::mpsc; +use std::thread; + //use std::path::Path; // External crate imports @@ -77,8 +80,8 @@ impl Hasher for HashIntoType { panic!("This hasher only takes u64"); } - fn write_u64(&mut self, _i: u64) { // no-op implementation! - // self.0 = i; + fn write_u64(&mut self, i: u64) { + self.0 = i; } } @@ -733,12 +736,12 @@ impl KmerCountTable { // build KmerCountTables in parallel let tables: Vec = coord_pairs - .par_iter() + .into_par_iter() .map(|(start, end)| { let mut t = KmerCountTable::new(self.ksize, self.store_kmers); - let start = *start as usize; - let end = *end as usize; + let start = start as usize; + let end = end as usize; t._consume(&seq[start..end], skip_bad_kmers) .expect("fail in sub consume"); t @@ -761,6 +764,87 @@ impl KmerCountTable { Ok(total_consumed) } + #[pyo3(signature = (seq, chunk_size, skip_bad_kmers=true))] + pub fn parallel_consume2( + &mut self, + seq: &str, + chunk_size: u64, + skip_bad_kmers: bool, + ) -> PyResult { + let ksize: u64 = self.ksize.into(); + let chunk_size = max(chunk_size, ksize); + + // figure out the number of chunks, given the desired chunk size. + // @CTB: factor out into own function! + let seq_len = seq.len() as u64; + let mut num_chunks: u64 = seq_len / chunk_size; + + // build a vec of (start, end) pairs. + let mut coord_pairs: Vec<(u64, u64)> = vec![]; + + // do entire sequence in one? all good. + if num_chunks <= 1 { + coord_pairs.push((0, seq_len)); + } else { + // more than one chunk: do more complicated stuff :). + let mut final_chunk: bool = false; + if seq_len % chunk_size > 0 { + num_chunks = num_chunks - 1; + final_chunk = true; + } + + for i in 0..num_chunks { + let start = i * chunk_size; + let end = (i + 1) * chunk_size + ksize - 1; + coord_pairs.push((start, end)); + } + if final_chunk { + // @CTB eprintln!("final chunk!"); + // collect up the remainder + coord_pairs.push((num_chunks * chunk_size, seq_len)); + } + } + + // @CTB eprintln!("{:?}", coord_pairs); + + let (sender, receiver) = mpsc::channel(); + + for (start, end) in coord_pairs.into_iter() { + let sender = sender.clone(); + let start = start as usize; + let end = end as usize; + let subseq = String::from(&seq[start..end]); + + thread::spawn(move || { + let hashes = SeqToHashes::new( + subseq.as_str().as_bytes(), + ksize as usize, + skip_bad_kmers, + false, + HashFunctions::Murmur64Dna, + 42, + ); + + for hash_value in hashes { + match hash_value { + Ok(0) => continue, + Ok(x) => { + sender.send(x).expect("send failed?!"); + }, + Err(_) => continue, + } + } + }); + } + drop(sender); + + for received in receiver { + self.count_hash(received); + } + + Ok(1) // total_consumed @CTB + } + // Helper method to get hash set of k-mers fn hash_set(&self) -> HashSet { self.counts.keys().cloned().collect()