Skip to content

Commit

Permalink
Refactor & rewrite group to tag reads
Browse files Browse the repository at this point in the history
  • Loading branch information
olliecheng committed Dec 2, 2024
1 parent 8983c32 commit bd27a9e
Show file tree
Hide file tree
Showing 5 changed files with 366 additions and 238 deletions.
301 changes: 98 additions & 203 deletions src/call.rs
Original file line number Diff line number Diff line change
@@ -1,110 +1,36 @@
use crate::duplicates::DuplicateMap;
use crate::duplicates::RecordIdentifier;
use crate::duplicates::{DuplicateMap, RecordIdentifier};
use crate::io::{iter_duplicates, until_err, write_read, ReadType, Record, UMIGroup};

use bio::io::fastq;
use bio::io::fastq::FastqRead;
use needletail::parser::{FastqReader, FastxReader, SequenceRecord};
use spoa::{AlignmentEngine, AlignmentType};

use std::fs::File;
use std::io::prelude::*;
use std::io::{Seek, SeekFrom};
use std::io::{Cursor, Seek, SeekFrom};
use std::process::{Command, Stdio};

use anyhow::{anyhow, Context, Result};

// required for writeln! on a string
use std::fmt::Write as FmtWrite;

use crate::io;
use pariter::IteratorExt as _;

struct DuplicateRecord {
id: RecordIdentifier,
records: Vec<fastq::Record>,
}

fn run_command(stdin: &str, shell: &str, command: &str) -> Result<Vec<u8>> {
let mut child = Command::new(shell)
.args(["-c", command])
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.spawn()
.expect("Could not execute process");

let mut stdin_pipe = child
.stdin
.take()
.with_context(|| format!("Failed to take stdin pipe on input instance:\n{}", stdin))?;
stdin_pipe
.write_all(stdin.as_bytes())
.with_context(|| format!("Failed to write to stdin pipe on instance:\n{}", stdin))?;

// drop the stdin pipe to close the stream
drop(stdin_pipe);

let output = child
.wait_with_output()
.with_context(|| format!("Failed to wait on output for input instance:\n{}", stdin))?;
if output.status.success() {
Ok(output.stdout)
} else {
let stderr = String::from_utf8(output.stderr)?;
match output.status.code() {
Some(code) => Err(anyhow!(
"Exited with status code: {code}. Stderr:\n{stderr}"
)),
None => Err(anyhow!("Process terminated by signal. Stderr:\n{stderr}")),
}
}
}

pub fn custom_command(
input: &str,
writer: &mut impl Write,
duplicates: DuplicateMap,
threads: usize,
shell: &str,
command: &str,
) -> Result<()> {
let cache_size = threads * 3;

let scope_obj = crossbeam::thread::scope(|scope| -> Result<()> {
// this will store any errors
let mut err = Ok(());

iter_duplicates(input, duplicates, true)?
.parallel_map_scoped_custom(
scope,
|o| o.threads(threads).buffer_size(cache_size),
|rec| {
// propagate errors
let rec = rec?;

assert_ne!(rec.records.len(), 1);
let mut fastq_str = String::new();

for record in rec.records.iter() {
write!(fastq_str, "{}", record).context("Could not format string")?;
}

run_command(input, shell, command)
},
)
.scan(&mut err, until_err)
.for_each(|output| {
writer
.write_all(&output)
.expect("Could not write to output");
});
err
});

scope_obj.unwrap_or_else(|e| {
error!("Caught a panic which is unrecoverable");
std::panic::resume_unwind(e)
})
}

/// Generates consensus sequences from the input in a thread-stable manner.
///
/// # Arguments
///
/// * `input` - A string slice that holds the path to the input file.
/// * `writer` - A mutable reference to an object that implements the `Write` trait,
/// used for writing the output.
/// * `duplicates` - A `DuplicateMap` containing the duplicate reads.
/// * `threads` - The number of threads to use for parallel processing.
/// * `duplicates_only` - A boolean indicating whether to process only duplicate reads.
/// * `output_originals` - A boolean indicating whether to include the original reads in the output.
///
/// # Returns
///
/// * `Result<()>` - Returns `Ok(())` if successful, or an error if an error occurs
/// during processing.
pub fn consensus(
input: &str,
writer: &mut impl Write,
Expand All @@ -113,135 +39,104 @@ pub fn consensus(
duplicates_only: bool,
output_originals: bool,
) -> Result<()> {
// Start with a placeholder error object. This will be mutated if there are errors during
// iteration through the reads.
let mut err = Ok(());

let cache_size = threads * 3;

let result = crossbeam::thread::scope(|s| -> Result<()> {
let duplicate_iterator = iter_duplicates(input, duplicates, duplicates_only)?;
let result = crossbeam::thread::scope(|scope| -> Result<()> {
let duplicate_iterator = iter_duplicates(
input,
duplicates,
duplicates_only,
)?;

// convert this sequential iterator into a parallel one for consensus calling
duplicate_iterator
// convert this sequential iterator into a parallel one for consensus calling
.scan(&mut err, until_err)
.scan(&mut err, until_err) // iterate until an error is found, writing into &err
.parallel_map_scoped_custom(
s,
scope,
|o| o.threads(threads).buffer_size(cache_size),
|rec| {
let single = rec.records.len() == 1;

let mut poa_graph;
if single {
let consensus = std::str::from_utf8(rec.records[0].seq()).unwrap();

format!(">{0}_SIN\n{1}\n", rec.id, consensus)
} else {
let mut output = String::new();

// TODO: find a way to move this outside of the parallel map
let mut alignment_engine =
AlignmentEngine::new(AlignmentType::kOV, 5, -4, -8, -6, -10, -4);
poa_graph = spoa::Graph::new();

let record_count = rec.records.len();

for (index, record) in rec.records.iter().enumerate() {
let seq = record.seq();
let qual = record.qual();

if output_originals {
writeln!(
output,
">{0}_DUP_{1}_of_{2}\n{3}",
rec.id,
index + 1,
record_count,
std::str::from_utf8(seq).unwrap()
)
.expect("string writing should not fail");
}

let align = alignment_engine.align_from_bytes(seq, &poa_graph);
poa_graph.add_alignment_from_bytes(&align, seq, qual);
}

let consensus = poa_graph.consensus();
let consensus = consensus
.to_str()
.expect("spoa module should produce valid utf-8");

writeln!(
output,
">{0}_CON_{1}\n{2}",
rec.id, record_count, consensus
)
.expect("string writing should not fail");

output
}
},
|r| call_record(r, output_originals),
)
// write every read in a global thread in order
.for_each(|output| {
writer.write_all(output.as_bytes()).unwrap();
writer.write_all(&output).unwrap();
});

Ok(())
});

// Threads can't send regular errors well between them, so
// if there is an issue here we panic
result.unwrap_or_else(|e| {
error!("Caught a panic which is unrecoverable");
std::panic::resume_unwind(e)
})?;

err
}

fn iter_duplicates(
input: &str,
duplicates: DuplicateMap,
duplicates_only: bool,
) -> Result<impl Iterator<Item=Result<DuplicateRecord>> + '_> {
let mut file = File::open(input).with_context(|| format!("Unable to open file {input}"))?;

Ok(duplicates
.into_iter()
// first, read from the file (sequential)
.filter_map(move |(id, positions)| -> Option<Result<DuplicateRecord>> {
// if we choose not to only output duplicates, we can skip over this
if (positions.len() == 1) && (duplicates_only) {
return None;
}

let mut rec = DuplicateRecord {
id,
records: Vec::new(),
};

for pos in positions.iter() {
let mut record = fastq::Record::new();
let err = file.seek(SeekFrom::Start(*pos as u64));
if let Err(e) = err {
let context = format!("Unable to seek to file {} at position {}", input, pos);
return Some(Err(anyhow::Error::new(e).context(context)));
}

let mut reader = fastq::Reader::new(&mut file);

let err = reader.read(&mut record);
if let Err(e) = err {
let context = format!("Unable to read from file {} at position {}", input, pos);
return Some(Err(anyhow::Error::new(e).context(context)));
}

rec.records.push(record);
}
Some(Ok(rec))
}))
}
/// Generates a consensus sequence from a group of reads.
///
/// # Arguments
///
/// * `group` - A `UMIGroup` containing the reads to be processed.
/// * `output_originals` - A boolean indicating whether to include the original reads in the
/// output alongside the consensus read.
///
/// # Returns
///
/// A `String` containing the consensus sequence in FASTQ format.
fn call_record(group: UMIGroup, output_originals: bool) -> Vec<u8> {
let length = group.records.len();
let mut output = Cursor::new(Vec::new());

// for singletons, the read is its own consensus
if length == 1 {
let record = &group.records[0];
io::write_read(&mut output, record, &group, ReadType::CONSENSUS, false).unwrap();
return output.into_inner();
}

/// Utility function to extract the error from an iterator
fn until_err<T>(err: &mut &mut Result<()>, item: Result<T>) -> Option<T> {
match item {
Ok(item) => Some(item),
Err(e) => {
**err = Err(e);
None
// initialise `spoa` machinery
let mut alignment_engine =
AlignmentEngine::new(AlignmentType::kOV, 5, -4, -8, -6, -10, -4);
let mut poa_graph = spoa::Graph::new();

// add each read in the duplicate group to the graph
for record in group.records.iter() {
if output_originals {
// Write the original reads as well
io::write_read(&mut output, record, &group, ReadType::ORIGINAL, false).unwrap();
}

// Align to the graph
let align = alignment_engine.align_from_bytes(record.seq.as_ref(), &poa_graph);
poa_graph.add_alignment_from_bytes(&align, record.seq.as_ref(), record.qual.as_ref());
}

// Create a consensus read
let consensus_str = poa_graph.consensus();
let consensus_str = consensus_str
.to_str()
.expect("spoa module did not produce valid utf-8");

let id_string = format!(
"consensus_{} avg_input_quality={:.2}",
group.index,
group.avg_qual
);

let consensus = Record {
id: id_string,
seq: consensus_str.to_string(),
qual: "".to_string(),
};

io::write_read(&mut output, &consensus, &group, ReadType::CONSENSUS, false).unwrap();

output.into_inner()
}

18 changes: 1 addition & 17 deletions src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,33 +106,17 @@ pub enum Commands {
report_original_reads: bool,
},

/// 'Group' duplicate reads, and pass to downstream applications.
/// Tag each read by its UMI group, and write to a .fastq file
#[command(arg_required_else_help = true)]
Group {
/// the index file
#[arg(long)]
index: String,

/// the input .fastq
#[arg(long)]
input: String,

/// the output location, or default to stdout
#[arg(long)]
output: Option<String>,

/// the shell used to run the given command
#[arg(long, default_value = "bash")]
shell: String,

/// the number of threads to use. this will not guard against race conditions in any
/// downstream applications used. this will effectively set the number of individual
/// processes to launch
#[arg(short, long, default_value_t = 1)]
threads: usize,

/// the command to run. any groups will be passed as .fastq standard input.
#[arg(trailing_var_arg = true, default_value = "cat")]
command: Vec<String>,
},
}
Loading

0 comments on commit bd27a9e

Please sign in to comment.