diff --git a/src/main/scala/com/fulcrumgenomics/umi/GroupReadsByUmi.scala b/src/main/scala/com/fulcrumgenomics/umi/GroupReadsByUmi.scala index 1559b9f7c..2a0520ef1 100644 --- a/src/main/scala/com/fulcrumgenomics/umi/GroupReadsByUmi.scala +++ b/src/main/scala/com/fulcrumgenomics/umi/GroupReadsByUmi.scala @@ -43,9 +43,12 @@ import htsjdk.samtools._ import htsjdk.samtools.util.SequenceUtil import java.util.concurrent.atomic.AtomicLong +import java.util.concurrent.{Executors, ForkJoinPool} import scala.collection.immutable.IndexedSeq import scala.collection.mutable.ListBuffer +import scala.collection.parallel.ExecutionContextTaskSupport import scala.collection.{BufferedIterator, Iterator, mutable} +import scala.concurrent.ExecutionContext object GroupReadsByUmi { @@ -210,9 +213,14 @@ object GroupReadsByUmi { * Class that implements the directed adjacency graph method from umi_tools. * See: https://github.com/CGATOxford/UMI-tools */ - private[umi] class AdjacencyUmiAssigner(val maxMismatches: Int) extends UmiAssigner { + private[umi] class AdjacencyUmiAssigner(final val maxMismatches: Int, val threads: Int = 1) extends UmiAssigner { + private val taskSupport = if (threads < 2) None else { + val ctx = ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(threads)) + Some(new ExecutionContextTaskSupport(ctx)) + } + /** Represents a node in the adjacency graph; equality is just by UMI sequence. */ - class Node(val umi: Umi, val count: Long, val children: mutable.Buffer[Node] = mutable.Buffer()) { + final class Node(val umi: Umi, val count: Int, val children: mutable.Buffer[Node] = mutable.Buffer(), var assigned: Boolean = false) { /** Gets the full set of descendants from this node. */ def descendants: List[Node] = { val buffer = ListBuffer[Node]() @@ -235,12 +243,12 @@ object GroupReadsByUmi { /** Returns whether or not a pair of UMIs match closely enough to be considered adjacent in the graph. */ protected def matches(lhs: Umi, rhs: Umi): Boolean = { - require(lhs.length == rhs.length, s"UMIs of different length detected: $lhs vs. $rhs") + val len = lhs.length var idx = 0 var mismatches = 0 - val len = lhs.length - while (idx < len && mismatches <= this.maxMismatches) { - if (lhs(idx) != rhs(idx)) mismatches += 1 + val tooManyMismatches = this.maxMismatches + 1 + while (mismatches < tooManyMismatches && idx < len) { + if (lhs.charAt(idx) != rhs.charAt(idx)) mismatches += 1 idx += 1 } @@ -260,39 +268,97 @@ object GroupReadsByUmi { } override def assign(rawUmis: Seq[Umi]): Map[Umi, MoleculeId] = { - // A list of all the root UMIs/Nodes that we find - val roots = IndexedSeq.newBuilder[Node] - // Make a list of counts of all UMIs in order from most to least abundant; we'll consume from this buffer - var remaining = count(rawUmis).map{ case(umi,count) => new Node(umi, count) }.toBuffer.sortBy((n:Node) => -n.count) - - // Now build one or more graphs starting with the most abundant remaining umi - while (remaining.nonEmpty) { - val nextRoot = remaining.remove(0) - roots += nextRoot - val working = mutable.Buffer[Node](nextRoot) - - while (working.nonEmpty) { - val root = working.remove(0) - val (hits, misses) = remaining.partition(other => root.count >= 2 * other.count - 1 && matches(root.umi, other.umi)) - root.children ++= hits - working ++= hits - remaining = misses + val orderedNodes = count(rawUmis).map{ case(umi,count) => new Node(umi, count.toInt) }.toIndexedSeq.sortBy((n:Node) => -n.count) + + if (orderedNodes.length == 1) { + orderedNodes.head.assigned = true + assignIdsToNodes(orderedNodes) + } + else { + val umiLength = orderedNodes.head.umi.length + require(orderedNodes.forall(_.umi.length == umiLength), f"Multiple UMI lengths: ${orderedNodes.map(_.umi).mkString(", ")}") + val lookup = countIndexLookup(orderedNodes) // Seq of (count, firstIdx) pairs + + // A list of all the root UMIs/Nodes that we find + val roots = Seq.newBuilder[Node] + + // Now build one or more graphs starting with the most abundant remaining umi + val working = mutable.Queue[Node]() + forloop (from=0, until=orderedNodes.length) { rootIdx => + val nextRoot = orderedNodes(rootIdx) + + if (!nextRoot.assigned) { + roots += nextRoot + working.enqueue(nextRoot) + + while (working.nonEmpty) { + val root = working.dequeue() + root.assigned = true + val maxChildCountPlusOne = (root.count / 2 + 1) + 1 + val searchFromIdx = lookup + .find { case (count, _) => count < maxChildCountPlusOne } + .map { case (_, idx) => idx } + .getOrElse(-1) + + if (searchFromIdx >= 0) { + val hits = taskSupport match { + case None => + orderedNodes + .drop(searchFromIdx) + .filter(other => !other.assigned && matches(root.umi, other.umi)) + case Some(ts) => + orderedNodes + .drop(searchFromIdx) + .parWith(ts) + .filter(other => !other.assigned && matches(root.umi, other.umi)) + .seq + } + + root.children ++= hits + working.enqueueAll(hits) + hits.foreach(_.assigned = true) + } + } + } } + + assignIdsToNodes(roots.result()) + } + } + + /** + * Generates an indexed seq to enable fast identification of the first index in `nodes` where + * a given UMI count is observed. Assumes that the input is sorted from most abundant to least + * abundant. The generated Seq will contain one entry for every unique count seen; the second value + * in the tuple is the first index in the list of input nodes with that count. + * + * E.g. given nodes with counts [10, 10, 10, 9, 3, 3, 3, 3, 2, 1], the output would be: + * [(10, 0), (9, 3), (3, 4) (2, 8), (1, 9)] + */ + private def countIndexLookup(nodes: IndexedSeq[Node]): IndexedSeq[(Int, Int)] = { + val builder = IndexedSeq.newBuilder[(Int, Int)] + val iter = nodes.iterator.zipWithIndex.bufferBetter + + while (iter.hasNext) { + val (currNode, currIdx) = iter.next + builder += ((currNode.count, currIdx)) + iter.dropWhile { case (node, _) => node.count == currNode.count} } - assignIdsToNodes(roots.result()) + builder.result() } } + /** * Version of the adjacency assigner that works for paired UMIs stored as a single tag of * the form A-B where reads with A-B and B-A are related but not identical. * * @param maxMismatches the maximum number of mismatches between UMIs */ - class PairedUmiAssigner(maxMismatches: Int) extends AdjacencyUmiAssigner(maxMismatches) { + class PairedUmiAssigner(maxMismatches: Int, threads: Int = 1) extends AdjacencyUmiAssigner(maxMismatches, threads) { /** String that is prefixed onto the UMI from the read with that maps to a lower coordinate in the genome.. */ private[umi] val lowerReadUmiPrefix: String = ("a" * (maxMismatches+1)) + ":" @@ -402,27 +468,27 @@ case class TagFamilySizeMetric(family_size: Int, /** The strategies implemented by [[GroupReadsByUmi]] to identify reads from the same source molecule.*/ sealed trait Strategy extends EnumEntry { - def newStrategy(edits: Int): UmiAssigner + def newStrategy(edits: Int, threads: Int): UmiAssigner } object Strategy extends FgBioEnum[Strategy] { def values: IndexedSeq[Strategy] = findValues /** Strategy to only reads with identical UMI sequences are grouped together. */ case object Identity extends Strategy { - def newStrategy(edits: Int = 0): UmiAssigner = { + def newStrategy(edits: Int = 0, threads: Int = 0): UmiAssigner = { require(edits == 0, "Edits should be zero when using the identity UMI assigner.") new IdentityUmiAssigner } } /** Strategy to cluster reads into groups based on mismatches between reads in clusters. */ - case object Edit extends Strategy { def newStrategy(edits: Int): UmiAssigner = new SimpleErrorUmiAssigner(edits) } + case object Edit extends Strategy { def newStrategy(edits: Int, threads: Int = 0): UmiAssigner = new SimpleErrorUmiAssigner(edits) } /** Strategy based on the directed adjacency method described in [umi_tools](http://dx.doi.org/10.1101/051755) * that allows for errors between UMIs but only when there is a count gradient. */ - case object Adjacency extends Strategy { def newStrategy(edits: Int): UmiAssigner = new AdjacencyUmiAssigner(edits) } + case object Adjacency extends Strategy { def newStrategy(edits: Int, threads: Int = 1): UmiAssigner = new AdjacencyUmiAssigner(edits, threads) } /** Strategy similar to the [[Adjacency]] strategy similar to adjacency but for methods that produce template with a * pair of UMIs such that a read with A-B is related to but not identical to a read with B-A. */ - case object Paired extends Strategy { def newStrategy(edits: Int): UmiAssigner = new PairedUmiAssigner(edits)} + case object Paired extends Strategy { def newStrategy(edits: Int, threads: Int = 1): UmiAssigner = new PairedUmiAssigner(edits, threads)} } @clp(group=ClpGroups.Umi, description = @@ -491,6 +557,11 @@ object Strategy extends FgBioEnum[Strategy] { | 1. `--min-map-q` defaults to 0 in duplicate marking mode and 1 otherwise | 2. `--include-secondary` defaults to true in duplicate marking mode and false otherwise | 3. `--include-supplementary` defaults to true in duplicate marking mode and false otherwise + | + |Multi-threaded operation is supported via the `--threads/-@` option. This only applies to the Adjacency and Paired + |strategies. Additionally the only operation that is multi-threaded is the comparisons of UMIs at the same genomic + |position. Running with e.g. `--threads 8` can provide a _substantial_ reduction in runtime when there are many + |UMIs observed at the same genomic location, such as can occur in amplicon sequencing or ultra-deep coverage data. """ ) class GroupReadsByUmi @@ -513,13 +584,14 @@ class GroupReadsByUmi @arg(flag='x', doc= """ |DEPRECATED: this option will be removed in future versions and inter-contig reads will be |automatically processed.""") - @deprecated val allowInterContig: Boolean = true + @deprecated val allowInterContig: Boolean = true, + @arg(flag='@', doc="Number of threads to use when comparing UMIs. Only recommended for amplicon or similar data.") val threads: Int = 1, )extends FgBioTool with LazyLogging { import GroupReadsByUmi._ require(this.minUmiLength.forall(_ => this.strategy != Strategy.Paired), "Paired strategy cannot be used with --min-umi-length") - private val assigner = strategy.newStrategy(this.edits) + private val assigner = strategy.newStrategy(this.edits, this.threads) // Give values to unset parameters that are different in duplicate marking mode private val _minMapQ = this.minMapQ.getOrElse(if (this.markDuplicates) 0 else 1) @@ -705,6 +777,7 @@ class GroupReadsByUmi * sub-grouping into UMI groups by original molecule. */ def assignUmiGroups(templates: Seq[Template]): Unit = { + val startMs = System.currentTimeMillis val umis = truncateUmis(templates.map { t => umiForRead(t) }) val rawToId = this.assigner.assign(umis) @@ -712,6 +785,12 @@ class GroupReadsByUmi val id = rawToId(umi) template.primaryReads.foreach(r => r(this.assignTag) = id) } + + val endMs = System.currentTimeMillis() + val durationMs = endMs - startMs + if (durationMs >= 2500) { + logger.debug(f"Grouped ${rawToId.size}%,d UMIs from ${templates.size}%,d templates in ${durationMs}%,d ms." ) + } } /** When a minimum UMI length is specified, truncates all the UMIs to the length of the shortest UMI. For the paired diff --git a/src/test/scala/com/fulcrumgenomics/umi/GroupReadsByUmiTest.scala b/src/test/scala/com/fulcrumgenomics/umi/GroupReadsByUmiTest.scala index baec12e8e..77a5dc8f6 100644 --- a/src/test/scala/com/fulcrumgenomics/umi/GroupReadsByUmiTest.scala +++ b/src/test/scala/com/fulcrumgenomics/umi/GroupReadsByUmiTest.scala @@ -84,31 +84,31 @@ class GroupReadsByUmiTest extends UnitSpec with OptionValues with PrivateMethodT group(assigner.assign(umis)) should contain theSameElementsAs Set(Set("AAAAAA", "AAAATA", "AAAATT", "GGCGGC", "TGCACC", "TGCACG")) } - { - "AdjacencyUmiAssigner" should "assign each UMI to separate groups" in { + Seq(1, 4).foreach { threads => + "AdjacencyUmiAssigner" should s"assign each UMI to separate groups with $threads thread(s)" in { val umis = Seq("AAAAAA", "CCCCCC", "GGGGGG", "TTTTTT", "AAATTT", "TTTAAA", "AGAGAG") - val groups = group(new AdjacencyUmiAssigner(maxMismatches=2).assign(umis)) + val groups = group(new AdjacencyUmiAssigner(maxMismatches=2, threads=threads).assign(umis)) groups shouldBe umis.map(Set(_)).toSet } - it should "assign everything into one group when all counts=1 and within mismatch threshold" in { + it should f"assign everything into one group when all counts=1 and within mismatch threshold with $threads thread(s)" in { val umis = Seq("AAAAAA", "AAAAAc", "AAAAAg").map(_.toUpperCase) - val groups = group(new AdjacencyUmiAssigner(maxMismatches=1).assign(umis)) + val groups = group(new AdjacencyUmiAssigner(maxMismatches=1, threads=threads).assign(umis)) groups shouldBe Set(umis.toSet) } - it should "assign everything into one group" in { + it should f"assign everything into one group with $threads thread(s)" in { val umis = Seq("AAAAAA", "AAAAAA", "AAAAAA", "AAAAAc", "AAAAAc", "AAAAAg", "AAAtAA").map(_.toUpperCase) - val groups = group(new AdjacencyUmiAssigner(maxMismatches=1).assign(umis)) + val groups = group(new AdjacencyUmiAssigner(maxMismatches=1, threads=threads).assign(umis)) groups shouldBe Set(umis.toSet) } - it should "make three groups" in { + it should f"make three groups with $threads thread(s)" in { val umis: Seq[String] = n("AAAAAA", 4) ++ n("AAAAAT", 2) ++ n("AATAAT", 1) ++ n("AATAAA", 2) ++ n("GACGAC", 9) ++ n("GACGAT", 1) ++ n("GACGCC", 4) ++ n("TACGAC", 7) - val groups = group(new AdjacencyUmiAssigner(maxMismatches=2).assign(umis)) + val groups = group(new AdjacencyUmiAssigner(maxMismatches=2, threads=threads).assign(umis)) groups shouldBe Set( Set("AAAAAA", "AAAAAT", "AATAAT", "AATAAA"), Set("GACGAC", "GACGAT", "GACGCC"), @@ -117,15 +117,15 @@ class GroupReadsByUmiTest extends UnitSpec with OptionValues with PrivateMethodT } // Unit test for something that failed when running on real data - it should "correctly assign the following UMIs" in { + it should f"correctly assign the following UMIs with $threads thread(s)" in { val umis = Seq("CGGGGG", "GTGGGG", "GGGGGG", "CTCACA", "TGCAGT", "CTCACA", "CGGGGG") - val groups = group(new AdjacencyUmiAssigner(maxMismatches=1).assign(umis)) + val groups = group(new AdjacencyUmiAssigner(maxMismatches=1, threads=threads).assign(umis)) groups shouldBe Set(Set("CGGGGG", "GGGGGG", "GTGGGG"), Set("CTCACA"), Set("TGCAGT")) } - it should "handle a deep tree of UMIs" in { + it should f"handle a deep tree of UMIs with $threads thread(s)" in { val umis = n("AAAAAA", 256) ++ n("TAAAAA", 128) ++ n("TTAAAA", 64) ++ n("TTTAAA", 32) ++ n("TTTTAA", 16) - val groups = group(new AdjacencyUmiAssigner(maxMismatches=1).assign(umis)) + val groups = group(new AdjacencyUmiAssigner(maxMismatches=1, threads=threads).assign(umis)) groups shouldBe Set(Set("AAAAAA", "TAAAAA", "TTAAAA", "TTTAAA", "TTTTAA")) } }