From 4f3d5454add7b9584137a584b117a94a0a62e3a8 Mon Sep 17 00:00:00 2001 From: Nils Homer Date: Fri, 14 Jun 2019 15:23:37 -0700 Subject: [PATCH] Speeding up CallDuplexConsensusReads. Changes to CallDuplexConsensusReads: - added the --threads option to support multi-threading; 4-8 threads seems like a decent trade-off. - added the --max-reads-per-strand option, for when the per-molecule coverage is very high, thus causing the tool to run slowly. Consensus calling API Implemented many performance optmizations found during profiling for consensus calling: - faster grouping of raw reads based on simplified cigars Both @nh13 and @tfenne contributed to this commit. --- .travis.yml | 1 + .../fulcrumgenomics/bam/api/SamRecord.scala | 4 + .../umi/CallDuplexConsensusReads.scala | 16 ++- .../umi/CallMolecularConsensusReads.scala | 2 +- .../fulcrumgenomics/umi/ConsensusCaller.scala | 9 +- .../umi/ConsensusCallingIterator.scala | 132 ++++++++++++------ .../umi/DuplexConsensusCaller.scala | 97 ++++++++++--- .../umi/SimpleConsensusCaller.scala | 16 ++- .../umi/UmiConsensusCaller.scala | 98 +++++++------ .../umi/VanillaUmiConsensusCaller.scala | 23 ++- .../umi/CallDuplexConsensusReadsTest.scala | 64 +++++---- .../umi/VanillaUmiConsensusCallerTest.scala | 7 +- 12 files changed, 317 insertions(+), 152 deletions(-) diff --git a/.travis.yml b/.travis.yml index 10fe7b545..8ac313456 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,6 @@ sudo: required language: scala +dist: trusty scala: - 2.12.2 jdk: diff --git a/src/main/scala/com/fulcrumgenomics/bam/api/SamRecord.scala b/src/main/scala/com/fulcrumgenomics/bam/api/SamRecord.scala index c0b9d1fd9..ed601c53c 100644 --- a/src/main/scala/com/fulcrumgenomics/bam/api/SamRecord.scala +++ b/src/main/scala/com/fulcrumgenomics/bam/api/SamRecord.scala @@ -60,6 +60,10 @@ class TransientAttrs(private val rec: SamRecord) { if (value == null) rec.asSam.removeTransientAttribute(key) else rec.asSam.setTransientAttribute(key, value) } def get[A](key: Any): Option[A] = Option(apply(key)) + def getOrElse[A](key: Any, default: => A): A = rec.asSam.getTransientAttribute(key) match { + case null => default + case value => value.asInstanceOf[A] + } } /** diff --git a/src/main/scala/com/fulcrumgenomics/umi/CallDuplexConsensusReads.scala b/src/main/scala/com/fulcrumgenomics/umi/CallDuplexConsensusReads.scala index 70173fb8b..ad8a966e5 100644 --- a/src/main/scala/com/fulcrumgenomics/umi/CallDuplexConsensusReads.scala +++ b/src/main/scala/com/fulcrumgenomics/umi/CallDuplexConsensusReads.scala @@ -99,7 +99,13 @@ class CallDuplexConsensusReads @arg(flag='m', doc="Ignore bases in raw reads that have Q below this value.") val minInputBaseQuality: PhredScore = DefaultMinInputBaseQuality, @arg(flag='t', doc="If true, quality trim input reads in addition to masking low Q bases.") val trim: Boolean = false, @arg(flag='S', doc="The sort order of the output, if `:none:` then the same as the input.") val sortOrder: Option[SamOrder] = Some(SamOrder.Queryname), - @arg(flag='M', minElements=1, maxElements=3, doc="The minimum number of input reads to a consensus read.") val minReads: Seq[Int] = Seq(1) + @arg(flag='M', minElements=1, maxElements=3, doc="The minimum number of input reads to a consensus read.") val minReads: Seq[Int] = Seq(1), + @arg(doc=""" + |The maximum number of reads to use when building a single-strand consensus. If more than this many reads are + |present in a tag family, the family is randomly downsampled to exactly max-reads reads. + """) + val maxReadsPerStrand: Option[Int] = None, + @arg(doc="The number of threads to use while consensus calling.") val threads: Int = 1 ) extends FgBioTool with LazyLogging { Io.assertReadable(input) @@ -122,11 +128,13 @@ class CallDuplexConsensusReads trim = trim, errorRatePreUmi = errorRatePreUmi, errorRatePostUmi = errorRatePostUmi, - minReads = minReads + minReads = minReads, + maxReadsPerStrand = maxReadsPerStrand.getOrElse(VanillaUmiConsensusCallerOptions.DefaultMaxReads) ) - - val iterator = new ConsensusCallingIterator(in.toIterator, caller, Some(ProgressLogger(logger))) + val progress = ProgressLogger(logger, unit=1000000) + val iterator = new ConsensusCallingIterator(in.toIterator, caller, Some(progress), threads) out ++= iterator + progress.logLast() in.safelyClose() out.close() diff --git a/src/main/scala/com/fulcrumgenomics/umi/CallMolecularConsensusReads.scala b/src/main/scala/com/fulcrumgenomics/umi/CallMolecularConsensusReads.scala index d919af0a4..777398b9e 100644 --- a/src/main/scala/com/fulcrumgenomics/umi/CallMolecularConsensusReads.scala +++ b/src/main/scala/com/fulcrumgenomics/umi/CallMolecularConsensusReads.scala @@ -152,7 +152,7 @@ class CallMolecularConsensusReads minInputBaseQuality = minInputBaseQuality, minConsensusBaseQuality = minConsensusBaseQuality, minReads = minReads, - maxReads = maxReads.getOrElse(Int.MaxValue), + maxReads = maxReads.getOrElse(VanillaUmiConsensusCallerOptions.DefaultMaxReads), producePerBaseTags = outputPerBaseTags ) diff --git a/src/main/scala/com/fulcrumgenomics/umi/ConsensusCaller.scala b/src/main/scala/com/fulcrumgenomics/umi/ConsensusCaller.scala index 356ffba14..665338f25 100644 --- a/src/main/scala/com/fulcrumgenomics/umi/ConsensusCaller.scala +++ b/src/main/scala/com/fulcrumgenomics/umi/ConsensusCaller.scala @@ -96,16 +96,16 @@ class ConsensusCaller(errorRatePreLabeling: PhredScore, def add(base: Base, pError: LogProbability, pTruth: LogProbability) = { val b = SequenceUtil.upperCase(base) if (b != 'N') { + val pErrorNormalized = LogProbability.normalizeByScalar(pError, 3) var i = 0 while (i < DnaBaseCount) { val candidateBase = DnaBasesUpperCase(i) - if (base == candidateBase) { likelihoods(i) += pTruth observations(i) += 1 } else { - likelihoods(i) += LogProbability.normalizeByScalar(pError, 3) + likelihoods(i) += pErrorNormalized } i += 1 @@ -117,7 +117,9 @@ class ConsensusCaller(errorRatePreLabeling: PhredScore, * Returns the number of reads that contributed evidence to the consensus. The value is equal * to the number of times add() was called with non-ambiguous bases. */ - def contributions: Int = this.observations.sum + def contributions: Int = { + this.observations(0) + this.observations(1) + this.observations(2) + this.observations(3) + } /** Gets the number of observations of the base in question. */ def observations(base: Base): Int = base match { @@ -170,7 +172,6 @@ class ConsensusCaller(errorRatePreLabeling: PhredScore, } } - /** Pre-computes the the log-scale probabilities of an error for each a phred-scaled base quality from 0-127. */ private val phredToAdjustedLogProbError: Array[LogProbability] = Range(0, Byte.MaxValue).toArray.map(q => { val e1 = LogProbability.fromPhredScore(this.errorRatePostLabeling) diff --git a/src/main/scala/com/fulcrumgenomics/umi/ConsensusCallingIterator.scala b/src/main/scala/com/fulcrumgenomics/umi/ConsensusCallingIterator.scala index f89bad9f9..a354dab05 100644 --- a/src/main/scala/com/fulcrumgenomics/umi/ConsensusCallingIterator.scala +++ b/src/main/scala/com/fulcrumgenomics/umi/ConsensusCallingIterator.scala @@ -26,62 +26,112 @@ package com.fulcrumgenomics.umi import com.fulcrumgenomics.FgBioDef._ import com.fulcrumgenomics.bam.api.SamRecord +import com.fulcrumgenomics.commons.async.AsyncIterator +import com.fulcrumgenomics.commons.util.LazyLogging +import com.fulcrumgenomics.umi.UmiConsensusCaller.SimpleRead import com.fulcrumgenomics.util.ProgressLogger -import scala.collection.mutable +import scala.concurrent.forkjoin.ForkJoinPool /** - * An iterator that consumes from an incoming iterator of SAMRecords and generates consensus - * read SAMRecords using the supplied caller. + * An iterator that consumes from an incoming iterator of [[SamRecord]]s and generates consensus + * read [[SamRecord]]s using the supplied consensus caller. * - * @param sourceIterator an iterator of SAMRecords - * @param caller the consensus caller to use to call consensus reads + * @param sourceIterator the iterator over input [[SamRecord]]s. + * @param caller the consensus caller to use to call consensus reads * @param progress an optional progress logger to which to log progress in input reads + * @param threads the number of threads to use. + * @param maxRecordsInRam the approximate maximum number of input records to store in RAM across multiple threads. */ -class ConsensusCallingIterator(sourceIterator: Iterator[SamRecord], - val caller: UmiConsensusCaller[_], - val progress: Option[ProgressLogger] = None - ) extends Iterator[SamRecord] { +class ConsensusCallingIterator[ConsensusRead <: SimpleRead](sourceIterator: Iterator[SamRecord], + caller: UmiConsensusCaller[ConsensusRead], + progress: Option[ProgressLogger] = None, + threads: Int = 1, + maxRecordsInRam: Int = 128000) + extends Iterator[SamRecord] with LazyLogging { - private val input = sourceIterator.bufferBetter - private val outputQueue: mutable.Queue[SamRecord] = mutable.Queue[SamRecord]() - - /** True if there are more consensus reads, false otherwise. */ - def hasNext(): Boolean = this.outputQueue.nonEmpty || (this.input.nonEmpty && advance()) - - /** Returns the next consensus read. */ - def next(): SamRecord = { - if (!this.hasNext()) throw new NoSuchElementException("Calling next() when hasNext() is false.") - this.outputQueue.dequeue() + private val progressIterator = progress match { + case Some(p) => sourceIterator.map { r => p.record(r); r } + case None => sourceIterator } - /** - * Consumes the next group of records from the input iterator, based on molecule id - * and returns them as a Seq. - */ - private def nextGroupOfRecords(): Seq[SamRecord] = { - if (this.input.isEmpty) { - Nil + protected val iterator: Iterator[SamRecord] = { + if (threads <= 1) { + val groupingIterator = new SamRecordGroupedIterator(progressIterator, caller.sourceMoleculeId) + groupingIterator.flatMap(caller.consensusReadsFromSamRecords) } else { - val idToMatch = this.caller.sourceMoleculeId(this.input.head) - this.input.takeWhile(this.caller.sourceMoleculeId(_) == idToMatch).toSeq + val halfMaxRecords = maxRecordsInRam / 2 + val groupingIterator = new SamRecordGroupedIterator(new AsyncIterator(progressIterator, Some(halfMaxRecords)).start(), caller.sourceMoleculeId) + val pool = new ForkJoinPool(threads - 1, ForkJoinPool.defaultForkJoinWorkerThreadFactory, null, true) + val bufferedIter = groupingIterator.bufferBetter + val callers = new IterableThreadLocal[UmiConsensusCaller[ConsensusRead]](() => caller.emptyClone()) + + // Read in groups of records (each from the same source molecule) until we have the maximum number of + // individual records in RAM. We have a few more records in RAM, if the group that pushes us over the limit is + // large. Then process the collected groups in parallel. + Iterator.continually { + var total = 0L + bufferedIter + .takeWhile { chunk => if (halfMaxRecords <= total) false else {total += chunk.length; true } } + .toSeq + .parWith(pool) + .flatMap { records => + val caller = callers.get() + caller.synchronized { caller.consensusReadsFromSamRecords(records) } + } + .seq + }.takeWhile { records => + if (records.nonEmpty) true else { + // add the statistics to the original caller since there are no more reads + require(bufferedIter.isEmpty, "Bug: input is not empty") + callers.foreach(caller.addStatistics) + false + } + }.flatten + } } + override def hasNext: Boolean = this.iterator.hasNext + override def next(): SamRecord = this.iterator.next +} + +// TODO: migrate to the commons version of this class after the next commons release +private class IterableThreadLocal[A](factory: () => A) extends ThreadLocal[A] with Iterable[A] { + private val all = new java.util.concurrent.ConcurrentLinkedQueue[A]() + + override def initialValue(): A = { + val a = factory() + all.add(a) + a + } - /** Consumes input records until one or more consensus reads can be created, or no more input records are available. - * Returns true if a consensus read was created and enqueued, false otherwise. */ - @annotation.tailrec + /** Care should be taken accessing the iterator since objects may be in use by other threads. */ + def iterator: Iterator[A] = all.toIterator +} + + +/** Groups consecutive records based on a method to group records. */ +private class SamRecordGroupedIterator[Key](sourceIterator: Iterator[SamRecord], + toKey: SamRecord => Key) extends Iterator[Seq[SamRecord]] { + private val input = sourceIterator.bufferBetter + private var nextChunk = IndexedSeq.empty[SamRecord] + + /** True if there are more consensus reads, false otherwise. */ + def hasNext(): Boolean = this.nextChunk.nonEmpty || (this.input.nonEmpty && advance()) + + /** Returns the next consensus read. */ + def next(): Seq[SamRecord] = { + if (!this.hasNext()) throw new NoSuchElementException("Calling next() when hasNext() is false.") + yieldAndThen { nextChunk } { nextChunk = IndexedSeq.empty } + } + + /** Consumes the next group of records from the input iterator, based on the vkey, and returns them as a [[IndexedSeq]]. */ private def advance(): Boolean = { - // get the records to create the consensus read - val inputs = nextGroupOfRecords() - val outputs = this.caller.consensusReadsFromSamRecords(inputs) - this.outputQueue ++= outputs - - // Log progress on the _input_ reads and then return/recurse - for (p <- progress; r <- inputs) p.record(r) - if (outputs.nonEmpty) true - else if (this.input.hasNext) advance() - else false + this.input.headOption.exists { head => + val idToMatch = this.toKey(head) + this.nextChunk = this.input.takeWhile(this.toKey(_) == idToMatch).toIndexedSeq + true + } } } diff --git a/src/main/scala/com/fulcrumgenomics/umi/DuplexConsensusCaller.scala b/src/main/scala/com/fulcrumgenomics/umi/DuplexConsensusCaller.scala index 3b66169ac..377bae550 100644 --- a/src/main/scala/com/fulcrumgenomics/umi/DuplexConsensusCaller.scala +++ b/src/main/scala/com/fulcrumgenomics/umi/DuplexConsensusCaller.scala @@ -43,7 +43,7 @@ object DuplexConsensusCaller { val ErrorRatePostUmi: PhredScore = 40.toByte val MinInputBaseQuality: PhredScore = 15.toByte val NoCall: Byte = 'N'.toByte - val NoCallQual = PhredScore.MinValue + val NoCallQual: PhredScore = PhredScore.MinValue /** Additional filter strings used when rejecting reads. */ val FilterMinReads = "Not Enough Reads (Either Total, AB, or BA)" @@ -103,12 +103,12 @@ class DuplexConsensusCaller(override val readNamePrefix: String, val trim: Boolean = false, val errorRatePreUmi: PhredScore = DuplexConsensusCaller.ErrorRatePreUmi, val errorRatePostUmi: PhredScore = DuplexConsensusCaller.ErrorRatePostUmi, - val minReads: Seq[Int] = Seq(1) + val minReads: Seq[Int] = Seq(1), + val maxReadsPerStrand: Int = VanillaUmiConsensusCallerOptions.DefaultMaxReads ) extends UmiConsensusCaller[DuplexConsensusRead] with LazyLogging { private val Seq(minTotalReads, minXyReads, minYxReads) = this.minReads.padTo(3, this.minReads.last) - // For depth thresholds it's required that ba <= ab <= cc require(minXyReads <= minTotalReads, "min-reads values must be specified high to low.") require(minYxReads <= minXyReads, "min-reads values must be specified high to low.") @@ -117,25 +117,60 @@ class DuplexConsensusCaller(override val readNamePrefix: String, errorRatePreUmi = this.errorRatePreUmi, errorRatePostUmi = this.errorRatePostUmi, minReads = 1, + maxReads = maxReadsPerStrand, minInputBaseQuality = this.minInputBaseQuality, minConsensusBaseQuality = PhredScore.MinValue, producePerBaseTags = true )) /** - * Returns the MI tag minus the trailing suffix that identifies /A vs /B + * Returns the MI tag **with** the trailing suffix that identifies /A vs /B */ - override protected[umi] def sourceMoleculeId(rec: SamRecord): String = { - rec.get[String](ConsensusTags.MolecularId) match { + private def sourceMoleculeAndStrandId(rec: SamRecord): String = { + // Optimization: speed up retrieving this tag by storing it in the transient attributes + rec.transientAttrs.get[String](ConsensusTags) match { + case Some(mi) => mi case None => - throw new IllegalStateException(s"Read ${rec.name} is missing it's ${ConsensusTags.MolecularId} tag.") - case Some(mi) => - val index = mi.lastIndexOf('/') - require(index > 0, s"Read ${rec.name}'s $ConsensusTags tag doesn't look like a duplex id: $mi") - mi.substring(0, index) + rec.get[String](ConsensusTags.MolecularId) match { + case Some(mi) => mi + case None => throw new IllegalStateException(s"Read ${rec.name} is missing it's ${ConsensusTags.MolecularId} tag.") + } } } + /** Returns a clone of this consensus caller in a state where no previous reads were processed. I.e. all counters + * are set to zero.*/ + def emptyClone(): DuplexConsensusCaller = { + new DuplexConsensusCaller( + readNamePrefix = readNamePrefix, + readGroupId = readGroupId, + minInputBaseQuality = minInputBaseQuality, + trim = trim, + errorRatePreUmi = errorRatePreUmi, + errorRatePostUmi = errorRatePostUmi, + minReads = minReads, + maxReadsPerStrand = maxReadsPerStrand + ) + } + + // The key in a [[SamRecord]]'s transient attributes that caches the molecular identifier. The molecular identifier is + // cached by the `sourceMoleculeId` method. + private val MolecularIdNoTrailingSuffix: String = "__" + ConsensusTags.MolecularId + "__" + + /** + * Returns the MI tag **minus** the trailing suffix that identifies /A vs /B + */ + override protected[umi] def sourceMoleculeId(rec: SamRecord): String = { + // Optimization: speed up retrieving this tag by storing it in the transient attributes + rec.transientAttrs.getOrElse[String](MolecularIdNoTrailingSuffix, { + val mi = sourceMoleculeAndStrandId(rec) + val index = mi.lastIndexOf('/') + val miRoot = mi.substring(0, index) + rec.transientAttrs(MolecularIdNoTrailingSuffix) = miRoot + miRoot + }) + } + /** * Takes in all the reads for a source molecule and, if possible, generates one or more * output consensus reads as SAM records. @@ -152,7 +187,7 @@ class DuplexConsensusCaller(override val readNamePrefix: String, } else { // Group the reads by /A vs. /B and ensure that /A is the first group and /B the second - val groups = pairs.groupBy(r => r[String](ConsensusTags.MolecularId)).toSeq.sortBy { case (mi, _) => mi }.map(_._2) + val groups = pairs.groupBy(r => sourceMoleculeAndStrandId(r)).toSeq.sortBy { case (mi, _) => mi }.map(_._2) require(groups.length <= 2, "SamRecords supplied with more than two distinct MI values.") @@ -191,6 +226,22 @@ class DuplexConsensusCaller(override val readNamePrefix: String, } } + // An empty sequence of [[SamRecord]]s, used in subGroupRecords to improve performance + private val NoSamRecords: Seq[SamRecord] = Seq.empty[SamRecord] + + /** Split records into those that should make a single-end consensus read, first of pair consensus read, + * and second of pair consensus read, respectively. This method is overridden in [[DuplexConsensusCaller]] to + * improve performance since no fragment reads should be given to this method. + */ + override protected def subGroupRecords(records: Seq[SamRecord]): (Seq[SamRecord], Seq[SamRecord], Seq[SamRecord]) = { + // NB: the input records should not have fragments + val (firstOfPair, secondOfPair) = records.partition { r => + require(r.paired, "Fragment reads should not be given to subGroupRecords in DuplexConsensusCaller.") + r.firstOfPair + } + (NoSamRecords, firstOfPair, secondOfPair) + } + /** Attempts to call a duplex consensus reads from the two sets of reads, one for each strand. */ private def callDuplexConsensusRead(ab: Seq[SamRecord], ba: Seq[SamRecord]): Seq[SamRecord] = { // Fragments have no place in duplex land (and are filtered out previously anyway)! @@ -209,12 +260,12 @@ class DuplexConsensusCaller(override val readNamePrefix: String, // Check for this explicitly here. (areAllSameStrand(singleStrand1), areAllSameStrand(singleStrand2)) match { case (false, _) => - val ss1Mi = singleStrand1.head.apply[String](ConsensusTags.MolecularId) + val ss1Mi = sourceMoleculeId(singleStrand1.head) rejectRecords(ab ++ ba, FilterCollision) logger.debug(s"Not all AB-R1s and BA-R2s were on the same strand for molecule with id: $ss1Mi") Nil case (_, false) => - val ss2Mi = singleStrand2.head.apply[String](ConsensusTags.MolecularId) + val ss2Mi = sourceMoleculeId(singleStrand2.head) rejectRecords(ab ++ ba, FilterCollision) logger.debug(s"Not all AB-R2s and BA-R1s were on the same strand for molecule with id: $ss2Mi") Nil @@ -305,10 +356,20 @@ class DuplexConsensusCaller(override val readNamePrefix: String, // Then mask it if appropriate val (base, qual) = if (aBase == NoCall || bBase == NoCall || rawQual == PhredScore.MinValue) (NoCall, NoCallQual) else (rawBase, rawQual) - bases(i) = base - quals(i) = qual - - errors(i) = min(sourceReads.count(s => s.length > i && isError(s.bases(i), rawBase)), Short.MaxValue).toShort + bases(i) = base + quals(i) = qual + + // NB: optimized based on profiling; was previously: + // sourceReads.count(s => s.length > i && isError(s.bases(i), rawBase)) + var numErrors = 0 + val sourceReadsArray = sourceReads.toArray + forloop(from=0, until=sourceReadsArray.length) { j => + val sourceRead = sourceReadsArray(j) + if (sourceRead.length > i && isError(sourceRead.bases(i), rawBase)) { + numErrors += 1 + } + } + errors(i) = min(numErrors, Short.MaxValue).toShort } Some(DuplexConsensusRead(id=id, bases, quals, errors, a.truncate(bases.length), Some(b.truncate(bases.length)))) diff --git a/src/main/scala/com/fulcrumgenomics/umi/SimpleConsensusCaller.scala b/src/main/scala/com/fulcrumgenomics/umi/SimpleConsensusCaller.scala index fbb97fbaa..7776ea05a 100644 --- a/src/main/scala/com/fulcrumgenomics/umi/SimpleConsensusCaller.scala +++ b/src/main/scala/com/fulcrumgenomics/umi/SimpleConsensusCaller.scala @@ -27,6 +27,8 @@ package com.fulcrumgenomics.umi import com.fulcrumgenomics.FgBioDef.forloop import com.fulcrumgenomics.util.NumericTypes.{LogProbability, PhredScore} +import scala.collection.immutable.BitSet + /** * A class that can be mixed in to call sequences represented as strings of the same length. * @@ -51,29 +53,33 @@ private[umi] class SimpleConsensusCaller(val errorRatePreLabeling: Byte = 90.toB ).builder() private val DnaBases = Set('A', 'C', 'G', 'T', 'N', 'a', 'c', 'g', 't', 'n') + private val DnaBasesBitSet = BitSet(DnaBases.map(_.toInt).toSeq:_*) /** Calls a simple consensus sequences from a set of sequences all the same length. */ def callConsensus(sequences: Seq[String]): String = { require(sequences.nonEmpty, "Can't call consensus on an empty set of sequences!") require(sequences.forall(_.length == sequences.head.length), "Sequences must all have the same length") val buffer = new StringBuilder + val firstRead = sequences.head + val readLength = firstRead.length + val sequencesLength = sequences.length - forloop (from=0, until=sequences.head.length) { i => + forloop (from=0, until=readLength) { i => this.consensusBuilder.reset() var nonDna = 0 sequences.foreach { sequence => val char = sequence.charAt(i) - if (!this.DnaBases.contains(char)) { + if (!this.DnaBasesBitSet.contains(char.toInt)) { nonDna += 1 // verify that all non-DNA bases are the same character - require(sequences.head.charAt(i) == char, - s"Sequences must have character '${sequences.head.charAt(i)}' at position $i, found '$char'") + require(firstRead.charAt(i) == char, + s"Sequences must have character '${firstRead.charAt(i)}' at position $i, found '$char'") } else this.consensusBuilder.add(char.toByte, pError=this.pError, pTruth=this.pTruth) } if (nonDna == 0) buffer.append(this.consensusBuilder.call()._1.toChar) - else if (nonDna == sequences.length) buffer.append(sequences.head.charAt(i)) // NB: we have previously verified they are all the same character + else if (nonDna == sequencesLength) buffer.append(firstRead.charAt(i)) // NB: we have previously verified they are all the same character else throw new IllegalStateException(s"Sequences contained a mix of DNA and non-DNA characters at offset $i: $sequences") } diff --git a/src/main/scala/com/fulcrumgenomics/umi/UmiConsensusCaller.scala b/src/main/scala/com/fulcrumgenomics/umi/UmiConsensusCaller.scala index 3ce82e25d..9603c63f2 100644 --- a/src/main/scala/com/fulcrumgenomics/umi/UmiConsensusCaller.scala +++ b/src/main/scala/com/fulcrumgenomics/umi/UmiConsensusCaller.scala @@ -31,13 +31,12 @@ import com.fulcrumgenomics.commons.util.{Logger, SimpleCounter} import com.fulcrumgenomics.umi.UmiConsensusCaller._ import com.fulcrumgenomics.util.NumericTypes.PhredScore import htsjdk.samtools.SAMFileHeader.{GroupOrder, SortOrder} -import htsjdk.samtools.util.{Murmur3, SequenceUtil, TrimmingUtil} import htsjdk.samtools._ +import htsjdk.samtools.util.{Murmur3, SequenceUtil, TrimmingUtil} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.math.{abs, min} - /** * Contains shared types and functions used when writing UMI-driven consensus * callers that take in SamRecords and emit SamRecords. @@ -47,6 +46,7 @@ object UmiConsensusCaller { object ReadType extends Enumeration { type ReadType = Value val Fragment, FirstOfPair, SecondOfPair = Value + val ReadTypeKey: String = "__read_type__" } /** Filter reason for when there are too few reads to form a consensus. */ @@ -72,9 +72,9 @@ object UmiConsensusCaller { /** Gets the length of the consensus read. */ def length: Int = bases.length /** Returns the consensus read a String - mostly useful for testing. */ - def baseString = new String(bases) + def baseString: String = new String(bases) /** Retrieves the quals as a phred+33/fastq ascii String. */ - def qualString = SAMUtils.phredToFastq(this.quals) + def qualString: String = SAMUtils.phredToFastq(this.quals) } /** Stores information about a read to be fed into a consensus. */ @@ -151,15 +151,15 @@ object UmiConsensusCaller { * A trait that can be mixed in by any consensus caller that works at the read level, * mapping incoming SamRecords into consensus SamRecords. * - * @tparam C Internally, the type of lightweight consensus read that is used prior to + * @tparam ConsensusRead Internally, the type of lightweight consensus read that is used prior to * rebuilding [[com.fulcrumgenomics.bam.api.SamRecord]]s. */ -trait UmiConsensusCaller[C <: SimpleRead] { +trait UmiConsensusCaller[ConsensusRead <: SimpleRead] { import com.fulcrumgenomics.umi.UmiConsensusCaller.ReadType._ // vars to track how many reads meet various fates private var _totalReads: Long = 0 - private val filteredReads = new SimpleCounter[String]() + private val _filteredReads = new SimpleCounter[String]() private var _consensusReadsConstructed: Long = 0 protected val NoCall: Byte = 'N'.toByte @@ -168,30 +168,35 @@ trait UmiConsensusCaller[C <: SimpleRead] { /** A consensus caller used to generate consensus UMI sequences */ private val consensusBuilder = new SimpleConsensusCaller() + /** Returns a clone of this consensus caller in a state where no previous reads were processed. I.e. all counters + * are set to zero.*/ + def emptyClone(): UmiConsensusCaller[ConsensusRead] + /** Returns the total number of input reads examined by the consensus caller so far. */ def totalReads: Long = _totalReads /** Returns the total number of reads filtered for any reason. */ - def totalFiltered: Long = filteredReads.total + def totalFiltered: Long = _filteredReads.total /** * Returns the number of raw reads filtered out due to there being insufficient reads present * to build the necessary set of consensus reads. */ - def readsFilteredInsufficientSupport: Long = this.filteredReads.countOf(FilterInsufficientSupport) + def readsFilteredInsufficientSupport: Long = this._filteredReads.countOf(FilterInsufficientSupport) /** Returns the number of raw reads filtered out because their alignment disagreed with the majority alignment of * all raw reads for the same source molecule. */ - def readsFilteredMinorityAlignment: Long = this.filteredReads.countOf(FilterMinorityAlignment) + def readsFilteredMinorityAlignment: Long = this._filteredReads.countOf(FilterMinorityAlignment) /** Returns the number of consensus reads constructed by this caller. */ def consensusReadsConstructed: Long = _consensusReadsConstructed /** Records that the supplied records were rejected, and not used to build a consensus read. */ - protected def rejectRecords(recs: Traversable[SamRecord], reason: String) : Unit = { - this.filteredReads.count(reason, recs.size) - } + protected def rejectRecords(recs: Traversable[SamRecord], reason: String) : Unit = this._filteredReads.count(reason, recs.size) + + /** Records that the supplied records were rejected, and not used to build a consensus read. */ + protected def rejectRecords(reason: String, rec: SamRecord*) : Unit = rejectRecords(rec, reason) /** A RG.ID to apply to all generated reads. */ protected def readGroupId: String @@ -298,6 +303,17 @@ trait UmiConsensusCaller[C <: SimpleRead] { (fragments, firstOfPair, secondOfPair) } + /** Used it [[filterToMostCommonAlignment()]] to store a cigar string and a set of flags for which reads match. */ + private final case class AlignmentGroup(cigar: Cigar, flags: mutable.BitSet, var size: Int = 0) { + /** Adds the read at `idx` to the set included. */ + @inline def add(idx: Int): Unit = { + flags(idx) = true + size += 1 + } + + @inline def contains(idx: Int): Boolean = flags(idx) + } + /** * Takes in a non-empty seq of SamRecords and filters them such that the returned seq only contains * those reads that share the most common alignment of the read sequence to the reference. @@ -313,23 +329,18 @@ trait UmiConsensusCaller[C <: SimpleRead] { * NOTE: filtered out reads are sent to the [[rejectRecords]] method and do not need further handling */ protected[umi] def filterToMostCommonAlignment(recs: Seq[SourceRead]): Seq[SourceRead] = { - val groups = new ArrayBuffer[mutable.Buffer[SourceRead]] - val cigars = new ArrayBuffer[Cigar] - - recs.sortBy(r => -r.length).foreach { rec => - var compatible = 0 - val simpleCigar = simplifyCigar(rec.cigar) - - groups.iterator.zip(cigars.iterator).foreach { case(group, cigar) => - if (simpleCigar.isPrefixOf(cigar)) { - group += rec - compatible += 1 - } - } - - if (compatible == 0) { - groups += ArrayBuffer(rec) - cigars += simpleCigar + val groups = new ArrayBuffer[AlignmentGroup] + val sorted = recs.sortBy(r => -r.length).toIndexedSeq + + forloop (from=0, until=sorted.length) { i => + val simpleCigar = simplifyCigar(sorted(i).cigar) + var found = false + groups.foreach { g => if (simpleCigar.isPrefixOf(g.cigar)) { g.add(i); found = true } } + + if (!found) { + val newGroup = AlignmentGroup(simpleCigar, new mutable.BitSet(sorted.size)) + newGroup.add(i) + groups += newGroup } } @@ -337,11 +348,12 @@ trait UmiConsensusCaller[C <: SimpleRead] { Seq.empty } else { - val sorted = groups.sortBy(g => - g.size) - val keepers = sorted.head - val rejects = recs.filter(r => !keepers.contains(r)) - rejectRecords(rejects.flatMap(_.sam), FilterMinorityAlignment) - + val bestGroup = groups.maxBy(_.size) + val keepers = new ArrayBuffer[SourceRead](bestGroup.size) + forloop (from=0, until=sorted.length) { i => + if (bestGroup.contains(i)) keepers += sorted(i) + else sorted(i).sam.foreach(rejectRecords(FilterMinorityAlignment, _)) + } keepers } } @@ -366,7 +378,7 @@ trait UmiConsensusCaller[C <: SimpleRead] { } /** Creates a `SamRecord` from the called consensus base and qualities. */ - protected def createSamRecord(read: C, readType: ReadType, umis: Seq[String] = Seq.empty): SamRecord = { + protected def createSamRecord(read: ConsensusRead, readType: ReadType, umis: Seq[String] = Seq.empty): SamRecord = { val rec = SamRecord(null) rec.name = this.readNamePrefix + ":" + read.id rec.unmapped = true @@ -397,16 +409,22 @@ trait UmiConsensusCaller[C <: SimpleRead] { total } + /** Adds the given caller's statistics (counts) to this caller. */ + def addStatistics(caller: UmiConsensusCaller[ConsensusRead]): Unit = { + this._totalReads += caller.totalReads + this._consensusReadsConstructed += caller.consensusReadsConstructed + this._filteredReads += caller._filteredReads + } /** * Logs statistics about how many reads were seen, and how many were filtered/discarded due * to various filters. */ def logStatistics(logger: Logger): Unit = { - logger.info(f"Total Raw Reads Considered: ${totalReads}%,d.") - this.filteredReads.foreach { case (filter, count) => - logger.info(f"Raw Reads Filtered Due to $filter: ${count}%,d (${count/totalReads.toDouble}%.4f).") + logger.info(f"Total Raw Reads Considered: $totalReads%,d.") + this._filteredReads.foreach { case (filter, count) => + logger.info(f"Raw Reads Filtered Due to $filter: $count%,d (${count/totalReads.toDouble}%.4f).") } - logger.info(f"Consensus reads emitted: ${consensusReadsConstructed}%,d.") + logger.info(f"Consensus reads emitted: $consensusReadsConstructed%,d.") } } diff --git a/src/main/scala/com/fulcrumgenomics/umi/VanillaUmiConsensusCaller.scala b/src/main/scala/com/fulcrumgenomics/umi/VanillaUmiConsensusCaller.scala index 5d256a062..9b0e173dc 100644 --- a/src/main/scala/com/fulcrumgenomics/umi/VanillaUmiConsensusCaller.scala +++ b/src/main/scala/com/fulcrumgenomics/umi/VanillaUmiConsensusCaller.scala @@ -32,7 +32,6 @@ import com.fulcrumgenomics.umi.UmiConsensusCaller.ReadType._ import com.fulcrumgenomics.umi.UmiConsensusCaller._ import com.fulcrumgenomics.umi.VanillaUmiConsensusCallerOptions._ import com.fulcrumgenomics.util.NumericTypes._ - import scala.collection.mutable.ListBuffer import scala.util.Random @@ -107,14 +106,23 @@ class VanillaUmiConsensusCaller(override val readNamePrefix: String, private val NotEnoughReadsQual: PhredScore = 0.toByte // Score output when masking to N due to insufficient input reads private val TooLowQualityQual: PhredScore = 2.toByte // Score output when masking to N due to too low consensus quality - private val DnaBasesUpperCase: Array[Byte] = Array('A', 'C', 'G', 'T').map(_.toByte) - private val LogThree = LogProbability.toLogProbability(3.0) private val caller = new ConsensusCaller(errorRatePreLabeling = options.errorRatePreUmi, errorRatePostLabeling = options.errorRatePostUmi) private val random = new Random(42) + /** Returns a clone of this consensus caller in a state where no previous reads were processed. I.e. all counters + * are set to zero.*/ + def emptyClone(): VanillaUmiConsensusCaller = { + new VanillaUmiConsensusCaller( + readNamePrefix = readNamePrefix, + readGroupId = readGroupId, + options = options, + rejects = rejects + ) + } + /** Returns the value of the SAM tag directly. */ override def sourceMoleculeId(rec: SamRecord): String = rec(this.options.tag) @@ -131,8 +139,8 @@ class VanillaUmiConsensusCaller(override val readNamePrefix: String, // pairs (consensusFromSamRecords(firstOfPair), consensusFromSamRecords(secondOfPair)) match { - case (None, Some(r2)) => rejectRecords(secondOfPair, UmiConsensusCaller.FilterOrphan) - case (Some(r1), None) => rejectRecords(firstOfPair, UmiConsensusCaller.FilterOrphan) + case (None, Some(_)) => rejectRecords(secondOfPair, UmiConsensusCaller.FilterOrphan) + case (Some(_), None) => rejectRecords(firstOfPair, UmiConsensusCaller.FilterOrphan) case (None, None) => rejectRecords(firstOfPair ++ secondOfPair, UmiConsensusCaller.FilterOrphan) case (Some(r1), Some(r2)) => buffer += createSamRecord(r1, FirstOfPair, firstOfPair.flatMap(_.get[String](ConsensusTags.UmiBases))) @@ -200,10 +208,12 @@ class VanillaUmiConsensusCaller(override val readNamePrefix: String, } } + val depth = builder.contributions // NB: cache this value, as it is re-computed each time + // Call the consensus and do any additional filtering val (rawBase, rawQual) = builder.call() val (base, qual) = { - if (builder.contributions < this.options.minReads) (NoCall, NotEnoughReadsQual) + if (depth < this.options.minReads) (NoCall, NotEnoughReadsQual) else if (rawQual < this.options.minConsensusBaseQuality) (NoCall, TooLowQualityQual) else (rawBase, rawQual) } @@ -212,7 +222,6 @@ class VanillaUmiConsensusCaller(override val readNamePrefix: String, consensusQuals(positionInRead) = qual // Generate the values for depth and count of errors - val depth = builder.contributions val errors = if (rawBase == NoCall) depth else depth - builder.observations(rawBase) consensusDepths(positionInRead) = if (depth > Short.MaxValue) Short.MaxValue else depth.toShort consensusErrors(positionInRead) = if (errors > Short.MaxValue) Short.MaxValue else errors.toShort diff --git a/src/test/scala/com/fulcrumgenomics/umi/CallDuplexConsensusReadsTest.scala b/src/test/scala/com/fulcrumgenomics/umi/CallDuplexConsensusReadsTest.scala index 3ea009a81..9a234129c 100644 --- a/src/test/scala/com/fulcrumgenomics/umi/CallDuplexConsensusReadsTest.scala +++ b/src/test/scala/com/fulcrumgenomics/umi/CallDuplexConsensusReadsTest.scala @@ -91,39 +91,41 @@ class CallDuplexConsensusReadsTest extends UnitSpec { recs should have size 0 } - it should "run successfully and create consensus reads" in { - val builder = new SamBuilder(readLength=10, sort=Some(SamOrder.TemplateCoordinate)) - builder.addPair(name="ab1", start1=100, start2=100, attrs=Map(MI -> "1/A"), bases1="AAAAAAAAAA", bases2="AAAAAAAAAA") - builder.addPair(name="ab2", start1=100, start2=100, attrs=Map(MI -> "1/A"), bases1="AAAAAAAAAA", bases2="AAAAAAAAAA") - builder.addPair(name="ab3", start1=100, start2=100, attrs=Map(MI -> "1/A"), bases1="AAAAAAAAAA", bases2="AAAAAAAAAA") - builder.addPair(name="ba1", start1=100, start2=100, strand1=Minus, strand2=Plus, attrs=Map(MI -> "1/B"), bases1="AAAAAAAAAA", bases2="AAAAAAAAAA") - builder.addPair(name="ba2", start1=100, start2=100, strand1=Minus, strand2=Plus, attrs=Map(MI -> "1/B"), bases1="AAAAAAAAAA", bases2="AAAAAAAAAA") - builder.addPair(name="ba3", start1=100, start2=100, strand1=Minus, strand2=Plus, attrs=Map(MI -> "1/B"), bases1="AAAAAAAAAA", bases2="AAAAAAAAAA") - - // Add the original UMI bases to each read - builder.foreach { rec => - val mi = rec[String](MI) - // first of pair ABs and second of pair BAs - if ((rec.firstOfPair && mi.endsWith("/A")) || (rec.secondOfPair && mi.endsWith("/B"))) { - rec(RX) = "AAT-CCG" - } - else { - rec(RX) = "CCG-AAT" + Seq(1, 2, 4).foreach { threads => + it should s"run successfully and create consensus reads with $threads threads" in { + val builder = new SamBuilder(readLength=10, sort=Some(SamOrder.TemplateCoordinate)) + builder.addPair(name="ab1", start1=100, start2=100, attrs=Map(MI -> "1/A"), bases1="AAAAAAAAAA", bases2="AAAAAAAAAA") + builder.addPair(name="ab2", start1=100, start2=100, attrs=Map(MI -> "1/A"), bases1="AAAAAAAAAA", bases2="AAAAAAAAAA") + builder.addPair(name="ab3", start1=100, start2=100, attrs=Map(MI -> "1/A"), bases1="AAAAAAAAAA", bases2="AAAAAAAAAA") + builder.addPair(name="ba1", start1=100, start2=100, strand1=Minus, strand2=Plus, attrs=Map(MI -> "1/B"), bases1="AAAAAAAAAA", bases2="AAAAAAAAAA") + builder.addPair(name="ba2", start1=100, start2=100, strand1=Minus, strand2=Plus, attrs=Map(MI -> "1/B"), bases1="AAAAAAAAAA", bases2="AAAAAAAAAA") + builder.addPair(name="ba3", start1=100, start2=100, strand1=Minus, strand2=Plus, attrs=Map(MI -> "1/B"), bases1="AAAAAAAAAA", bases2="AAAAAAAAAA") + + // Add the original UMI bases to each read + builder.foreach { rec => + val mi = rec[String](MI) + // first of pair ABs and second of pair BAs + if ((rec.firstOfPair && mi.endsWith("/A")) || (rec.secondOfPair && mi.endsWith("/B"))) { + rec(RX) = "AAT-CCG" + } + else { + rec(RX) = "CCG-AAT" + } } - } - - val in = builder.toTempFile() - val out = makeTempFile("duplex.", ".bam") - new CallDuplexConsensusReads(input=in, output=out, readGroupId="ZZ").execute() - val reader = SamSource(out) - val recs = reader.toSeq - reader.header.getReadGroups should have size 1 - reader.header.getReadGroups.iterator().next().getId shouldBe "ZZ" - recs should have size 2 - recs.foreach { rec => - rec[String](MI) shouldBe "1" - rec[String](RX) shouldBe (if (rec.firstOfPair) "AAT-CCG" else "CCG-AAT") + val in = builder.toTempFile() + val out = makeTempFile("duplex.", ".bam") + new CallDuplexConsensusReads(input=in, output=out, readGroupId="ZZ", threads=threads).execute() + val reader = SamSource(out) + val recs = reader.toSeq + + reader.header.getReadGroups should have size 1 + reader.header.getReadGroups.iterator().next().getId shouldBe "ZZ" + recs should have size 2 + recs.foreach { rec => + rec[String](MI) shouldBe "1" + rec[String](RX) shouldBe (if (rec.firstOfPair) "AAT-CCG" else "CCG-AAT") + } } } } diff --git a/src/test/scala/com/fulcrumgenomics/umi/VanillaUmiConsensusCallerTest.scala b/src/test/scala/com/fulcrumgenomics/umi/VanillaUmiConsensusCallerTest.scala index 3a4edd576..5a78f83a7 100644 --- a/src/test/scala/com/fulcrumgenomics/umi/VanillaUmiConsensusCallerTest.scala +++ b/src/test/scala/com/fulcrumgenomics/umi/VanillaUmiConsensusCallerTest.scala @@ -33,8 +33,8 @@ import com.fulcrumgenomics.umi.VanillaUmiConsensusCallerOptions._ import com.fulcrumgenomics.util.NumericTypes._ import htsjdk.samtools.SAMUtils import htsjdk.samtools.util.CloserUtil -import net.jafama.FastMath._ import org.scalatest.OptionValues +import org.apache.commons.math3.util.FastMath._ import scala.collection.mutable.ArrayBuffer @@ -453,6 +453,11 @@ class VanillaUmiConsensusCallerTest extends UnitSpec with OptionValues { output.map(_.cigar.toString()).sorted shouldBe expected } + it should "return a single read if a single read was given" in { + val srcs = Seq(src(cigar="50M")) + val recs = cc().filterToMostCommonAlignment(srcs) + recs should have size 1 + } "VanillaConsensusCaller.toSourceRead" should "mask bases that are below the quality threshold" in { val builder = new SamBuilder(readLength=10)