Skip to content

Commit

Permalink
Speeding up CallDuplexConsensusReads.
Browse files Browse the repository at this point in the history
Changes to CallDuplexConsensusReads:
- added the --threads option to support multi-threading; 4-8 threads
  seem like a decent trade-off.
- added the --max-reads-per-strand option, for when the per-molecule
  coverage is very high, thus causing the tool to run slowly.

Consensus calling API
Implemented many performance optimizations found during profiling for
consensus calling:
- faster grouping of raw reads based on simplified cigars

Both @nh13 and @tfenne contributed to this commit.
  • Loading branch information
nh13 committed Jul 23, 2019
1 parent ea40d22 commit 4f3d545
Show file tree
Hide file tree
Showing 12 changed files with 317 additions and 152 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
sudo: required
language: scala
dist: trusty
scala:
- 2.12.2
jdk:
Expand Down
4 changes: 4 additions & 0 deletions src/main/scala/com/fulcrumgenomics/bam/api/SamRecord.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ class TransientAttrs(private val rec: SamRecord) {
if (value == null) rec.asSam.removeTransientAttribute(key) else rec.asSam.setTransientAttribute(key, value)
}
def get[A](key: Any): Option[A] = Option(apply(key))
/** Returns the transient attribute stored under `key`, cast to `A`, or `default` when no value is set.
  * `default` is passed by name and is only evaluated when the attribute is absent.
  */
def getOrElse[A](key: Any, default: => A): A = {
  val stored = rec.asSam.getTransientAttribute(key)
  if (stored == null) default else stored.asInstanceOf[A]
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,13 @@ class CallDuplexConsensusReads
@arg(flag='m', doc="Ignore bases in raw reads that have Q below this value.") val minInputBaseQuality: PhredScore = DefaultMinInputBaseQuality,
@arg(flag='t', doc="If true, quality trim input reads in addition to masking low Q bases.") val trim: Boolean = false,
@arg(flag='S', doc="The sort order of the output, if `:none:` then the same as the input.") val sortOrder: Option[SamOrder] = Some(SamOrder.Queryname),
@arg(flag='M', minElements=1, maxElements=3, doc="The minimum number of input reads to a consensus read.") val minReads: Seq[Int] = Seq(1)
@arg(flag='M', minElements=1, maxElements=3, doc="The minimum number of input reads to a consensus read.") val minReads: Seq[Int] = Seq(1),
@arg(doc="""
|The maximum number of reads to use when building a single-strand consensus. If more than this many reads are
|present in a tag family, the family is randomly downsampled to exactly max-reads reads.
""")
val maxReadsPerStrand: Option[Int] = None,
@arg(doc="The number of threads to use while consensus calling.") val threads: Int = 1
) extends FgBioTool with LazyLogging {

Io.assertReadable(input)
Expand All @@ -122,11 +128,13 @@ class CallDuplexConsensusReads
trim = trim,
errorRatePreUmi = errorRatePreUmi,
errorRatePostUmi = errorRatePostUmi,
minReads = minReads
minReads = minReads,
maxReadsPerStrand = maxReadsPerStrand.getOrElse(VanillaUmiConsensusCallerOptions.DefaultMaxReads)
)

val iterator = new ConsensusCallingIterator(in.toIterator, caller, Some(ProgressLogger(logger)))
val progress = ProgressLogger(logger, unit=1000000)
val iterator = new ConsensusCallingIterator(in.toIterator, caller, Some(progress), threads)
out ++= iterator
progress.logLast()

in.safelyClose()
out.close()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ class CallMolecularConsensusReads
minInputBaseQuality = minInputBaseQuality,
minConsensusBaseQuality = minConsensusBaseQuality,
minReads = minReads,
maxReads = maxReads.getOrElse(Int.MaxValue),
maxReads = maxReads.getOrElse(VanillaUmiConsensusCallerOptions.DefaultMaxReads),
producePerBaseTags = outputPerBaseTags
)

Expand Down
9 changes: 5 additions & 4 deletions src/main/scala/com/fulcrumgenomics/umi/ConsensusCaller.scala
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,16 @@ class ConsensusCaller(errorRatePreLabeling: PhredScore,
def add(base: Base, pError: LogProbability, pTruth: LogProbability) = {
val b = SequenceUtil.upperCase(base)
if (b != 'N') {
val pErrorNormalized = LogProbability.normalizeByScalar(pError, 3)
var i = 0
while (i < DnaBaseCount) {
val candidateBase = DnaBasesUpperCase(i)

if (base == candidateBase) {
likelihoods(i) += pTruth
observations(i) += 1
}
else {
likelihoods(i) += LogProbability.normalizeByScalar(pError, 3)
likelihoods(i) += pErrorNormalized
}

i += 1
Expand All @@ -117,7 +117,9 @@ class ConsensusCaller(errorRatePreLabeling: PhredScore,
* Returns the number of reads that contributed evidence to the consensus. The value is equal
* to the number of times add() was called with non-ambiguous bases.
*/
def contributions: Int = this.observations.sum
def contributions: Int = {
// Adds the four per-base counters by explicit index instead of calling `observations.sum`;
// NOTE(review): presumably a hot-path optimization, and it assumes exactly four entries
// (one per DNA base) — confirm against DnaBaseCount.
this.observations(0) + this.observations(1) + this.observations(2) + this.observations(3)
}

/** Gets the number of observations of the base in question. */
def observations(base: Base): Int = base match {
Expand Down Expand Up @@ -170,7 +172,6 @@ class ConsensusCaller(errorRatePreLabeling: PhredScore,
}
}


/** Pre-computes the log-scale probabilities of an error for each phred-scaled base quality from 0-127. */
private val phredToAdjustedLogProbError: Array[LogProbability] = Range(0, Byte.MaxValue).toArray.map(q => {
val e1 = LogProbability.fromPhredScore(this.errorRatePostLabeling)
Expand Down
132 changes: 91 additions & 41 deletions src/main/scala/com/fulcrumgenomics/umi/ConsensusCallingIterator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,62 +26,112 @@ package com.fulcrumgenomics.umi

import com.fulcrumgenomics.FgBioDef._
import com.fulcrumgenomics.bam.api.SamRecord
import com.fulcrumgenomics.commons.async.AsyncIterator
import com.fulcrumgenomics.commons.util.LazyLogging
import com.fulcrumgenomics.umi.UmiConsensusCaller.SimpleRead
import com.fulcrumgenomics.util.ProgressLogger

import scala.collection.mutable
import scala.concurrent.forkjoin.ForkJoinPool

/**
* An iterator that consumes from an incoming iterator of SAMRecords and generates consensus
* read SAMRecords using the supplied caller.
* An iterator that consumes from an incoming iterator of [[SamRecord]]s and generates consensus
* read [[SamRecord]]s using the supplied consensus caller.
*
* @param sourceIterator an iterator of SAMRecords
* @param caller the consensus caller to use to call consensus reads
* @param sourceIterator the iterator over input [[SamRecord]]s.
* @param caller the consensus caller to use to call consensus reads
* @param progress an optional progress logger to which to log progress in input reads
* @param threads the number of threads to use.
* @param maxRecordsInRam the approximate maximum number of input records to store in RAM across multiple threads.
*/
class ConsensusCallingIterator(sourceIterator: Iterator[SamRecord],
val caller: UmiConsensusCaller[_],
val progress: Option[ProgressLogger] = None
) extends Iterator[SamRecord] {
class ConsensusCallingIterator[ConsensusRead <: SimpleRead](sourceIterator: Iterator[SamRecord],
caller: UmiConsensusCaller[ConsensusRead],
progress: Option[ProgressLogger] = None,
threads: Int = 1,
maxRecordsInRam: Int = 128000)
extends Iterator[SamRecord] with LazyLogging {

private val input = sourceIterator.bufferBetter
private val outputQueue: mutable.Queue[SamRecord] = mutable.Queue[SamRecord]()

/** True if there are more consensus reads, false otherwise. */
def hasNext(): Boolean = this.outputQueue.nonEmpty || (this.input.nonEmpty && advance())

/** Returns the next consensus read. */
def next(): SamRecord = {
if (!this.hasNext()) throw new NoSuchElementException("Calling next() when hasNext() is false.")
this.outputQueue.dequeue()
private val progressIterator = progress match {
case Some(p) => sourceIterator.map { r => p.record(r); r }
case None => sourceIterator
}

/**
* Consumes the next group of records from the input iterator, based on molecule id
* and returns them as a Seq.
*/
private def nextGroupOfRecords(): Seq[SamRecord] = {
if (this.input.isEmpty) {
Nil
protected val iterator: Iterator[SamRecord] = {
if (threads <= 1) {
val groupingIterator = new SamRecordGroupedIterator(progressIterator, caller.sourceMoleculeId)
groupingIterator.flatMap(caller.consensusReadsFromSamRecords)
}
else {
val idToMatch = this.caller.sourceMoleculeId(this.input.head)
this.input.takeWhile(this.caller.sourceMoleculeId(_) == idToMatch).toSeq
val halfMaxRecords = maxRecordsInRam / 2
val groupingIterator = new SamRecordGroupedIterator(new AsyncIterator(progressIterator, Some(halfMaxRecords)).start(), caller.sourceMoleculeId)
val pool = new ForkJoinPool(threads - 1, ForkJoinPool.defaultForkJoinWorkerThreadFactory, null, true)
val bufferedIter = groupingIterator.bufferBetter
val callers = new IterableThreadLocal[UmiConsensusCaller[ConsensusRead]](() => caller.emptyClone())

// Read in groups of records (each from the same source molecule) until we have the maximum number of
// individual records in RAM. We have a few more records in RAM, if the group that pushes us over the limit is
// large. Then process the collected groups in parallel.
Iterator.continually {
var total = 0L
bufferedIter
.takeWhile { chunk => if (halfMaxRecords <= total) false else {total += chunk.length; true } }
.toSeq
.parWith(pool)
.flatMap { records =>
val caller = callers.get()
caller.synchronized { caller.consensusReadsFromSamRecords(records) }
}
.seq
}.takeWhile { records =>
if (records.nonEmpty) true else {
// add the statistics to the original caller since there are no more reads
require(bufferedIter.isEmpty, "Bug: input is not empty")
callers.foreach(caller.addStatistics)
false
}
}.flatten

}
}
override def hasNext: Boolean = this.iterator.hasNext
override def next(): SamRecord = this.iterator.next
}

// TODO: migrate to the commons version of this class after the next commons release
/**
 * A [[ThreadLocal]] that also records every value it creates, so that all of the per-thread
 * values can later be visited through the [[Iterable]] interface.
 *
 * @param factory function invoked by `initialValue()` to build a new value
 * @tparam A the type of the per-thread value
 */
private class IterableThreadLocal[A](factory: () => A) extends ThreadLocal[A] with Iterable[A] {
// Holds every value produced by `factory`; presumably a concurrent queue because
// `initialValue()` can be called from several threads — TODO confirm.
private val all = new java.util.concurrent.ConcurrentLinkedQueue[A]()

/** Builds a new value with `factory`, records it in `all`, and returns it. */
override def initialValue(): A = {
val a = factory()
all.add(a)
a
}

/** Consumes input records until one or more consensus reads can be created, or no more input records are available.
* Returns true if a consensus read was created and enqueued, false otherwise. */
@annotation.tailrec
/** Care should be taken accessing the iterator since objects may be in use by other threads. */
def iterator: Iterator[A] = all.toIterator
}


/** Groups consecutive records based on a method to group records. */
private class SamRecordGroupedIterator[Key](sourceIterator: Iterator[SamRecord],
toKey: SamRecord => Key) extends Iterator[Seq[SamRecord]] {
private val input = sourceIterator.bufferBetter
private var nextChunk = IndexedSeq.empty[SamRecord]

/** True if there are more groups of records, false otherwise. */
def hasNext(): Boolean = this.nextChunk.nonEmpty || (this.input.nonEmpty && advance())

/** Returns the next group of records. */
def next(): Seq[SamRecord] = {
if (!this.hasNext()) throw new NoSuchElementException("Calling next() when hasNext() is false.")
yieldAndThen { nextChunk } { nextChunk = IndexedSeq.empty }
}

/** Consumes the next group of consecutive records that share the same key from the input iterator and stores them as the next chunk. Returns true if a group was read, false if the input is empty. */
private def advance(): Boolean = {
// get the records to create the consensus read
val inputs = nextGroupOfRecords()
val outputs = this.caller.consensusReadsFromSamRecords(inputs)
this.outputQueue ++= outputs

// Log progress on the _input_ reads and then return/recurse
for (p <- progress; r <- inputs) p.record(r)
if (outputs.nonEmpty) true
else if (this.input.hasNext) advance()
else false
this.input.headOption.exists { head =>
val idToMatch = this.toKey(head)
this.nextChunk = this.input.takeWhile(this.toKey(_) == idToMatch).toIndexedSeq
true
}
}
}
Loading

0 comments on commit 4f3d545

Please sign in to comment.