Skip to content

Commit

Permalink
Adds an optimized path when creating a consensus from a single input …
Browse files Browse the repository at this point in the history
…read. (#790)

* Adds an optimized path when creating a consensus from a single input read.
* Additional consensus speedups.
* Switched to using the new ParIterator in commons inside of ConsensusCallingIterator
  • Loading branch information
tfenne authored Feb 26, 2022
1 parent adecf24 commit 4b981f2
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 104 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ lazy val root = Project(id="fgbio", base=file("."))
"org.scala-lang" % "scala-compiler" % scalaVersion.value,
"org.scala-lang.modules" %% "scala-xml" % "1.2.0",
"org.scala-lang.modules" %% "scala-collection-compat" % "2.1.1",
"com.fulcrumgenomics" %% "commons" % "1.4.0-014e196-SNAPSHOT",
"com.fulcrumgenomics" %% "commons" % "1.4.0-ffd57b8-SNAPSHOT",
"com.fulcrumgenomics" %% "sopt" % "1.1.0",
"com.github.samtools" % "htsjdk" % "2.23.0" excludeAll(htsjdkExcludes: _*),
"org.apache.commons" % "commons-math3" % "3.6.1",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ class CallMolecularConsensusReads
if (tag.length != 2) throw new ValidationException("attribute must be of length 2")
if (errorRatePreUmi < 0) throw new ValidationException("Phred-scaled error rate pre UMI must be >= 0")
if (errorRatePostUmi < 0) throw new ValidationException("Phred-scaled error rate post UMI must be >= 0")
validate(this.maxReads.forall(max => max >= this.minReads), "--max-reads must be >= --min-reads.")

/** Main method that does the work of reading input files, creating the consensus reads, and writing the output file. */
override def execute(): Unit = {
Expand Down Expand Up @@ -164,9 +165,11 @@ class CallMolecularConsensusReads
rejects = rej
)

val iterator = new ConsensusCallingIterator(in.iterator, caller, Some(ProgressLogger(logger, unit=5e5.toInt)), threads=threads)
val progress = ProgressLogger(logger, unit=1e6.toInt)
val iterator = new ConsensusCallingIterator(in.iterator, caller, Some(progress), threads=threads)
out ++= iterator

progress.logLast()
in.safelyClose()
out.close()
rej.foreach(_.close())
Expand Down
17 changes: 12 additions & 5 deletions src/main/scala/com/fulcrumgenomics/umi/ConsensusCaller.scala
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,16 @@ class ConsensusCaller(errorRatePreLabeling: PhredScore,
}

/** Adds a base and un-adjusted base quality to the consensus likelihoods. */
def add(base: Base, qual: PhredScore): Unit = add(base, pError=phredToAdjustedLogProbError(qual), pTruth=phredToAdjustedLogProbCorrect(qual))
def add(base: Base, qual: PhredScore): Unit =
add(base, pErrorPerBase=phredToOneThirdAdjustedLogProbError(qual), pTruth=phredToAdjustedLogProbCorrect(qual))

/** Adds a base with adjusted error and truth probabilities to the consensus likelihoods. */
def add(base: Base, pError: LogProbability, pTruth: LogProbability) = {
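/**
* Adds a base with adjusted error and truth probabilities to the consensus likelihoods.
*
* @param base          the observed base; it is upper-cased and ignored if it is an 'N'
* @param pErrorPerBase the adjusted log probability of error, already divided by three, which is added to
*                      the likelihood of each candidate base that does not match the observed base
* @param pTruth        the adjusted log probability that the observed base is correct
*/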
private def add(base: Base, pErrorPerBase: LogProbability, pTruth: LogProbability) = {
val b = SequenceUtil.upperCase(base)
if (b != 'N') {
val pErrorNormalized = LogProbability.normalizeByScalar(pError, 3)
var i = 0
while (i < DnaBaseCount) {
val candidateBase = DnaBasesUpperCase(i)
Expand All @@ -105,7 +108,7 @@ class ConsensusCaller(errorRatePreLabeling: PhredScore,
observations(i) += 1
}
else {
likelihoods(i) += pErrorNormalized
likelihoods(i) += pErrorPerBase
}

i += 1
Expand Down Expand Up @@ -179,6 +182,10 @@ class ConsensusCaller(errorRatePreLabeling: PhredScore,
LogProbability.probabilityOfErrorTwoTrials(e1, e2)
})

/** One third of the adjusted error as a log probability - i.e. the probability that the true base is one specific other base. */
private val phredToOneThirdAdjustedLogProbError: Array[LogProbability] =
phredToAdjustedLogProbError.map(e => LogProbability.normalizeByScalar(e, 3))

/** Pre-computes the log-scale probability of not being an error for each phred-scaled base quality from 0-127. */
private val phredToAdjustedLogProbCorrect: Array[Double] = phredToAdjustedLogProbError.map(LogProbability.not)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,13 @@
package com.fulcrumgenomics.umi

import com.fulcrumgenomics.FgBioDef._
import com.fulcrumgenomics.bam.api
import com.fulcrumgenomics.bam.api.SamRecord
import com.fulcrumgenomics.commons.async.AsyncIterator
import com.fulcrumgenomics.commons.collection.ParIterator
import com.fulcrumgenomics.commons.util.LazyLogging
import com.fulcrumgenomics.commons.util.Threads.IterableThreadLocal
import com.fulcrumgenomics.umi.UmiConsensusCaller.SimpleRead
import com.fulcrumgenomics.util.ProgressLogger

import java.util.concurrent.ForkJoinPool

/**
* An iterator that consumes from an incoming iterator of [[SamRecord]]s and generates consensus
* read [[SamRecord]]s using the supplied consensus caller.
Expand All @@ -43,67 +40,52 @@ import java.util.concurrent.ForkJoinPool
* @param caller the consensus caller to use to call consensus reads
* @param progress an optional progress logger to which to log progress in input reads
* @param threads the number of threads to use.
* @param maxRecordsInRam the approximate maximum number of input records to store in RAM across multiple threads.
* @param chunkSize the unit of work for parallel processing; when multi-threaded, roughly 8 * chunkSize records may be held in memory
*/
class ConsensusCallingIterator[ConsensusRead <: SimpleRead](sourceIterator: Iterator[SamRecord],
caller: UmiConsensusCaller[ConsensusRead],
progress: Option[ProgressLogger] = None,
threads: Int = 1,
maxRecordsInRam: Int = 128000)
chunkSize: Int = ParIterator.DefaultChunkSize)
extends Iterator[SamRecord] with LazyLogging {

private val progressIterator = progress match {
case Some(p) => sourceIterator.tapEach { r => p.record(r) }
case None => sourceIterator
}
private val callers = new IterableThreadLocal[UmiConsensusCaller[ConsensusRead]](() => caller.emptyClone())
private var collectedStats: Boolean = false

protected val iter: Iterator[SamRecord] = {
// Wrap our input iterator in a progress logging iterator if we have a progress logger
val progressIterator = progress match {
case Some(p) => sourceIterator.tapEach { r => p.record(r) }
case None => sourceIterator
}

// Then turn it into a grouping iterator
val groupingIterator = new SamRecordGroupedIterator(progressIterator, caller.sourceMoleculeId)

// Then call consensus either single-threaded or multi-threaded
if (threads <= 1) {
val groupingIterator = new SamRecordGroupedIterator(progressIterator, caller.sourceMoleculeId)
groupingIterator.flatMap(caller.consensusReadsFromSamRecords)
}
else {
val halfMaxRecords = maxRecordsInRam / 2
val pool = new ForkJoinPool(threads - 1, ForkJoinPool.defaultForkJoinWorkerThreadFactory, null, true)
val callers = new IterableThreadLocal[UmiConsensusCaller[ConsensusRead]](() => caller.emptyClone())
val groupingIterator = {
val async = AsyncIterator(progressIterator, Some(halfMaxRecords))
val grouped = new SamRecordGroupedIterator(async, caller.sourceMoleculeId)
grouped.bufferBetter
}

// Create an iterator that will pull in input records (up to the half the maximum number in memory)
// to be consensus called. Each chunk of records will then be called in parallel.
new Iterator[Seq[SamRecord]] {
private var statisticsCollected = false
ParIterator(groupingIterator, threads=threads).flatMap { rs =>
val caller = callers.get()
caller.synchronized { caller.consensusReadsFromSamRecords(rs) }
}.toAsync(chunkSize * 8)
}
}

override def hasNext: Boolean = {
if (groupingIterator.hasNext) true else {
// If we've hit the end then aggregate statistics
if (!statisticsCollected) {
callers.foreach(caller.addStatistics)
statisticsCollected = true
}
false
}
}
// Responsible for adding statistics to the main caller once we hit the end of the iterator.
override def hasNext: Boolean = {
if (this.iter.hasNext) true else {
if (!collectedStats) {
callers.foreach(c => caller.addStatistics(c))
collectedStats = true
}

override def next(): Seq[SamRecord] = {
var total = 0L
groupingIterator
.takeWhile { chunk => if (halfMaxRecords <= total) false else {total += chunk.length; true } }
.toSeq
.parWith(pool)
.flatMap { records =>
val caller = callers.get()
caller.synchronized { caller.consensusReadsFromSamRecords(records) }
}
.seq
}
}.flatten
false
}
}
override def hasNext: Boolean = this.iter.hasNext

override def next(): SamRecord = this.iter.next
}

Expand Down
45 changes: 24 additions & 21 deletions src/main/scala/com/fulcrumgenomics/umi/SimpleConsensusCaller.scala
Original file line number Diff line number Diff line change
Expand Up @@ -58,31 +58,34 @@ private[umi] class SimpleConsensusCaller(val errorRatePreLabeling: Byte = 90.toB
/** Calls a simple consensus sequence from a set of sequences that are all the same length. */
def callConsensus(sequences: Seq[String]): String = {
require(sequences.nonEmpty, "Can't call consensus on an empty set of sequences!")
require(sequences.forall(_.length == sequences.head.length), "Sequences must all have the same length")
val buffer = new StringBuilder
val firstRead = sequences.head
val readLength = firstRead.length
val sequencesLength = sequences.length

forloop (from=0, until=readLength) { i =>
this.consensusBuilder.reset()
var nonDna = 0
sequences.foreach { sequence =>
val char = sequence.charAt(i)
if (!this.DnaBasesBitSet.contains(char.toInt)) {
nonDna += 1
// verify that all non-DNA bases are the same character
require(firstRead.charAt(i) == char,
s"Sequences must have character '${firstRead.charAt(i)}' at position $i, found '$char'")
if (sequences.length == 1) sequences.head else {
require(sequences.forall(_.length == sequences.head.length), "Sequences must all have the same length")
val buffer = new StringBuilder
val firstRead = sequences.head
val readLength = firstRead.length
val sequencesLength = sequences.length

forloop (from=0, until=readLength) { i =>
this.consensusBuilder.reset()
var nonDna = 0
sequences.foreach { sequence =>
val char = sequence.charAt(i)
if (!this.DnaBasesBitSet.contains(char.toInt)) {
nonDna += 1
// verify that all non-DNA bases are the same character
require(firstRead.charAt(i) == char,
s"Sequences must have character '${firstRead.charAt(i)}' at position $i, found '$char'")
}
else this.consensusBuilder.add(char.toByte, qError)
}
else this.consensusBuilder.add(char.toByte, pError=this.pError, pTruth=this.pTruth)

if (nonDna == 0) buffer.append(this.consensusBuilder.call()._1.toChar)
else if (nonDna == sequencesLength) buffer.append(firstRead.charAt(i)) // NB: we have previously verified they are all the same character
else throw new IllegalStateException(s"Sequences contained a mix of DNA and non-DNA characters at offset $i: $sequences")
}

if (nonDna == 0) buffer.append(this.consensusBuilder.call()._1.toChar)
else if (nonDna == sequencesLength) buffer.append(firstRead.charAt(i)) // NB: we have previously verified they are all the same character
else throw new IllegalStateException(s"Sequences contained a mix of DNA and non-DNA characters at offset $i: $sequences")
buffer.toString()
}

buffer.toString()
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@

package com.fulcrumgenomics.umi

import com.fulcrumgenomics.FgBioDef.forloop
import com.fulcrumgenomics.bam.api.{SamRecord, SamWriter}
import com.fulcrumgenomics.commons.util.LazyLogging
import com.fulcrumgenomics.umi.ConsensusCaller.Base
import com.fulcrumgenomics.umi.UmiConsensusCaller.ReadType._
import com.fulcrumgenomics.umi.UmiConsensusCaller._
import com.fulcrumgenomics.umi.VanillaUmiConsensusCallerOptions._
import com.fulcrumgenomics.util.NumericTypes._
import scala.collection.mutable.ListBuffer

import scala.util.Random

/**
Expand Down Expand Up @@ -110,6 +111,13 @@ class VanillaUmiConsensusCaller(override val readNamePrefix: String,
private val caller = new ConsensusCaller(errorRatePreLabeling = options.errorRatePreUmi,
errorRatePostLabeling = options.errorRatePostUmi)

/** Map from input qual score to output qual score in the case where there is only one read going into the consensus. */
private val SingleInputConsensusQuals: Array[Byte] = Range.inclusive(0, PhredScore.MaxValue).map { q =>
val lnProbOne = LogProbability.fromPhredScore(q)
val lnProbTwo = LogProbability.fromPhredScore(Math.min(this.options.errorRatePreUmi, this.options.errorRatePostUmi))
PhredScore.fromLogProbability(LogProbability.probabilityOfErrorTwoTrials(lnProbOne, lnProbTwo))
}.toArray

private val random = new Random(42)

/** Returns a clone of this consensus caller in a state where no previous reads were processed. I.e. all counters
Expand Down Expand Up @@ -196,39 +204,57 @@ class VanillaUmiConsensusCaller(override val readNamePrefix: String,
val consensusDepths = new Array[Short](consensusLength)
val consensusErrors = new Array[Short](consensusLength)

var positionInRead = 0
val builder = this.caller.builder()
while (positionInRead < consensusLength) {
// Add the evidence from all reads that are long enough to cover this base
capped.foreach { read =>
if (read.length > positionInRead) {
val base = read.bases(positionInRead)
val qual = read.quals(positionInRead)
if (base != NoCall) builder.add(base=base, qual=qual)
}
}
if (capped.length == 1) {
val inBases = capped.head.bases
val inQuals = capped.head.quals

val depth = builder.contributions // NB: cache this value, as it is re-computed each time
forloop (from=0, until=consensusLength) { i =>
val rawBase = inBases(i)
val rawQual = SingleInputConsensusQuals(inQuals(i))
val (base, qual) = if (rawQual < this.options.minConsensusBaseQuality) (NoCall, TooLowQualityQual) else (rawBase, rawQual)
val isNoCall = base == NoCall

// Call the consensus and do any additional filtering
val (rawBase, rawQual) = builder.call()
val (base, qual) = {
if (depth < this.options.minReads) (NoCall, NotEnoughReadsQual)
else if (rawQual < this.options.minConsensusBaseQuality) (NoCall, TooLowQualityQual)
else (rawBase, rawQual)
consensusBases(i) = base
consensusQuals(i) = qual
consensusDepths(i) = if (isNoCall) 0 else 1
consensusErrors(i) = 0
}
}
else {
var positionInRead = 0
val builder = this.caller.builder()
while (positionInRead < consensusLength) {
// Add the evidence from all reads that are long enough to cover this base
capped.foreach { read =>
if (read.length > positionInRead) {
val base = read.bases(positionInRead)
val qual = read.quals(positionInRead)
if (base != NoCall) builder.add(base=base, qual=qual)
}
}

val depth = builder.contributions // NB: cache this value, as it is re-computed each time

consensusBases(positionInRead) = base
consensusQuals(positionInRead) = qual
// Call the consensus and do any additional filtering
val (rawBase, rawQual) = builder.call()
val (base, qual) = {
if (depth < this.options.minReads) (NoCall, NotEnoughReadsQual)
else if (rawQual < this.options.minConsensusBaseQuality) (NoCall, TooLowQualityQual)
else (rawBase, rawQual)
}

// Generate the values for depth and count of errors
val errors = if (rawBase == NoCall) depth else depth - builder.observations(rawBase)
consensusDepths(positionInRead) = if (depth > Short.MaxValue) Short.MaxValue else depth.toShort
consensusErrors(positionInRead) = if (errors > Short.MaxValue) Short.MaxValue else errors.toShort
consensusBases(positionInRead) = base
consensusQuals(positionInRead) = qual

// Get ready for the next pass
builder.reset()
positionInRead += 1
// Generate the values for depth and count of errors
val errors = if (rawBase == NoCall) depth else depth - builder.observations(rawBase)
consensusDepths(positionInRead) = if (depth > Short.MaxValue) Short.MaxValue else depth.toShort
consensusErrors(positionInRead) = if (errors > Short.MaxValue) Short.MaxValue else errors.toShort

// Get ready for the next pass
builder.reset()
positionInRead += 1
}
}

Some(VanillaConsensusRead(id=capped.head.id, bases=consensusBases, quals=consensusQuals, depths=consensusDepths, errors=consensusErrors))
Expand Down

0 comments on commit 4b981f2

Please sign in to comment.