Skip to content

Commit

Permalink
upd
Browse files Browse the repository at this point in the history
  • Loading branch information
ctb committed Nov 2, 2024
1 parent 10ba8f1 commit 22a9f62
Showing 1 changed file with 89 additions and 5 deletions.
94 changes: 89 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ use std::fmt;
use std::fs::File;
use std::hash::{BuildHasher, BuildHasherDefault, Hash, Hasher};
use std::io::{BufReader, BufWriter, Write};
use std::sync::mpsc;
use std::thread;

//use std::path::Path;

// External crate imports
Expand Down Expand Up @@ -77,8 +80,8 @@ impl Hasher for HashIntoType {
panic!("This hasher only takes u64");
}

fn write_u64(&mut self, _i: u64) { // no-op implementation!
// self.0 = i;
/// Store the given `u64` as this hasher's state, replacing any prior value.
fn write_u64(&mut self, i: u64) {
// Overwrite the wrapped value directly; the input is kept as-is with no mixing step.
self.0 = i;
}
}

Expand Down Expand Up @@ -733,12 +736,12 @@ impl KmerCountTable {

// build KmerCountTables in parallel
let tables: Vec<KmerCountTable> = coord_pairs
.par_iter()
.into_par_iter()
.map(|(start, end)| {
let mut t = KmerCountTable::new(self.ksize, self.store_kmers);

let start = *start as usize;
let end = *end as usize;
let start = start as usize;
let end = end as usize;
t._consume(&seq[start..end], skip_bad_kmers)
.expect("fail in sub consume");
t
Expand All @@ -761,6 +764,87 @@ impl KmerCountTable {
Ok(total_consumed)
}

/// Count the k-mers of `seq`, hashing fixed-size chunks on worker threads.
///
/// The sequence is split into chunks of `chunk_size` bases (raised to at
/// least `ksize`). Every chunk except the last is extended by `ksize - 1`
/// bases so that k-mers straddling a chunk boundary are hashed exactly once.
/// Worker threads hash their chunk and send the hash values over a channel;
/// the calling thread receives them and updates the table via `count_hash`.
///
/// # Arguments
/// * `seq` - sequence to consume.
/// * `chunk_size` - desired number of bases per chunk.
/// * `skip_bad_kmers` - forwarded to `SeqToHashes::new`.
///
/// # Returns
/// The number of hash values that were counted.
///
/// # Panics
/// Panics if a worker thread panics; `thread::scope` propagates the panic
/// once all workers have finished.
#[pyo3(signature = (seq, chunk_size, skip_bad_kmers=true))]
pub fn parallel_consume2(
    &mut self,
    seq: &str,
    chunk_size: u64,
    skip_bad_kmers: bool,
) -> PyResult<u64> {
    /// Build the `(start, end)` byte ranges of the chunks to hash.
    fn chunk_coords(seq_len: u64, chunk_size: u64, ksize: u64) -> Vec<(u64, u64)> {
        let mut num_chunks = seq_len / chunk_size;

        // do entire sequence in one? all good.
        if num_chunks <= 1 {
            return vec![(0, seq_len)];
        }

        // A partial trailing chunk is merged with the last full chunk, so
        // one fewer regular chunk is emitted and the remainder is collected
        // separately below.
        let final_chunk = seq_len % chunk_size > 0;
        if final_chunk {
            num_chunks -= 1;
        }

        let mut coord_pairs: Vec<(u64, u64)> = (0..num_chunks)
            .map(|i| {
                let start = i * chunk_size;
                // Overlap the next chunk by ksize - 1 bases, but never run
                // past the end of the sequence: when seq_len is an exact
                // multiple of chunk_size the last chunk has nothing to
                // overlap with.
                let end = ((i + 1) * chunk_size + ksize - 1).min(seq_len);
                (start, end)
            })
            .collect();

        if final_chunk {
            // collect up the remainder
            coord_pairs.push((num_chunks * chunk_size, seq_len));
        }
        coord_pairs
    }

    let ksize: u64 = self.ksize.into();
    let chunk_size = chunk_size.max(ksize);
    let seq_len = seq.len() as u64;
    let coord_pairs = chunk_coords(seq_len, chunk_size, ksize);

    // Slice the raw bytes rather than the `str`: the hasher takes bytes
    // anyway, and byte slicing cannot panic on a non-char-boundary offset.
    let seq_bytes = seq.as_bytes();
    let ksize = ksize as usize;

    let (sender, receiver) = mpsc::channel();

    // Scoped threads may borrow `seq_bytes` directly (no per-chunk copy) and
    // are all joined before `scope` returns, so a worker panic surfaces here
    // instead of silently yielding partial counts.
    let total_consumed = thread::scope(|scope| {
        for &(start, end) in &coord_pairs {
            let sender = sender.clone();
            let chunk = &seq_bytes[start as usize..end as usize];

            scope.spawn(move || {
                let hashes = SeqToHashes::new(
                    chunk,
                    ksize,
                    skip_bad_kmers,
                    false,
                    HashFunctions::Murmur64Dna,
                    42,
                );

                for hash_value in hashes {
                    match hash_value {
                        Ok(0) => continue,
                        Ok(x) => {
                            sender.send(x).expect("send failed?!");
                        }
                        // NOTE(review): hashing errors are dropped here, as
                        // in the original; confirm whether they should be
                        // reported to the caller when skip_bad_kmers is
                        // false.
                        Err(_) => continue,
                    }
                }
            });
        }
        // Drop the original sender so the receive loop ends once every
        // worker has finished and dropped its clone.
        drop(sender);

        let mut consumed: u64 = 0;
        for received in receiver {
            self.count_hash(received);
            consumed += 1;
        }
        consumed
    });

    Ok(total_consumed)
}

// Helper method to get hash set of k-mers
fn hash_set(&self) -> HashSet<HashIntoType> {
self.counts.keys().cloned().collect()
Expand Down

0 comments on commit 22a9f62

Please sign in to comment.