Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Addition of threading in GroupReadsByUmi and some other performance optimizations. #950

Merged
merged 5 commits into from
Dec 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 111 additions & 32 deletions src/main/scala/com/fulcrumgenomics/umi/GroupReadsByUmi.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,12 @@ import htsjdk.samtools._
import htsjdk.samtools.util.SequenceUtil

import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.{Executors, ForkJoinPool}
import scala.collection.immutable.IndexedSeq
import scala.collection.mutable.ListBuffer
import scala.collection.parallel.ExecutionContextTaskSupport
import scala.collection.{BufferedIterator, Iterator, mutable}
import scala.concurrent.ExecutionContext


object GroupReadsByUmi {
Expand Down Expand Up @@ -210,9 +213,14 @@ object GroupReadsByUmi {
* Class that implements the directed adjacency graph method from umi_tools.
* See: https://github.com/CGATOxford/UMI-tools
*/
private[umi] class AdjacencyUmiAssigner(val maxMismatches: Int) extends UmiAssigner {
private[umi] class AdjacencyUmiAssigner(final val maxMismatches: Int, val threads: Int = 1) extends UmiAssigner {
private val taskSupport = if (threads < 2) None else {
val ctx = ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(threads))
Some(new ExecutionContextTaskSupport(ctx))
}

/** Represents a node in the adjacency graph; equality is just by UMI sequence. */
class Node(val umi: Umi, val count: Long, val children: mutable.Buffer[Node] = mutable.Buffer()) {
final class Node(val umi: Umi, val count: Int, val children: mutable.Buffer[Node] = mutable.Buffer(), var assigned: Boolean = false) {
/** Gets the full set of descendants from this node. */
def descendants: List[Node] = {
val buffer = ListBuffer[Node]()
Expand All @@ -235,12 +243,12 @@ object GroupReadsByUmi {

/** Returns whether or not a pair of UMIs match closely enough to be considered adjacent in the graph. */
protected def matches(lhs: Umi, rhs: Umi): Boolean = {
nh13 marked this conversation as resolved.
Show resolved Hide resolved
require(lhs.length == rhs.length, s"UMIs of different length detected: $lhs vs. $rhs")
nh13 marked this conversation as resolved.
Show resolved Hide resolved
val len = lhs.length
var idx = 0
var mismatches = 0
val len = lhs.length
while (idx < len && mismatches <= this.maxMismatches) {
if (lhs(idx) != rhs(idx)) mismatches += 1
val tooManyMismatches = this.maxMismatches + 1
nh13 marked this conversation as resolved.
Show resolved Hide resolved
while (mismatches < tooManyMismatches && idx < len) {
if (lhs.charAt(idx) != rhs.charAt(idx)) mismatches += 1
idx += 1
}

Expand All @@ -260,39 +268,97 @@ object GroupReadsByUmi {
}

override def assign(rawUmis: Seq[Umi]): Map[Umi, MoleculeId] = {
tfenne marked this conversation as resolved.
Show resolved Hide resolved
// A list of all the root UMIs/Nodes that we find
val roots = IndexedSeq.newBuilder[Node]

// Make a list of counts of all UMIs in order from most to least abundant; we'll consume from this buffer
var remaining = count(rawUmis).map{ case(umi,count) => new Node(umi, count) }.toBuffer.sortBy((n:Node) => -n.count)

// Now build one or more graphs starting with the most abundant remaining umi
while (remaining.nonEmpty) {
val nextRoot = remaining.remove(0)
roots += nextRoot
val working = mutable.Buffer[Node](nextRoot)

while (working.nonEmpty) {
val root = working.remove(0)
val (hits, misses) = remaining.partition(other => root.count >= 2 * other.count - 1 && matches(root.umi, other.umi))
root.children ++= hits
working ++= hits
remaining = misses
val orderedNodes = count(rawUmis).map{ case(umi,count) => new Node(umi, count.toInt) }.toIndexedSeq.sortBy((n:Node) => -n.count)

if (orderedNodes.length == 1) {
orderedNodes.head.assigned = true
assignIdsToNodes(orderedNodes)
}
else {
val umiLength = orderedNodes.head.umi.length
require(orderedNodes.forall(_.umi.length == umiLength), f"Multiple UMI lengths: ${orderedNodes.map(_.umi).mkString(", ")}")
val lookup = countIndexLookup(orderedNodes) // Seq of (count, firstIdx) pairs

// A list of all the root UMIs/Nodes that we find
val roots = Seq.newBuilder[Node]

// Now build one or more graphs starting with the most abundant remaining umi
val working = mutable.Queue[Node]()
forloop (from=0, until=orderedNodes.length) { rootIdx =>
val nextRoot = orderedNodes(rootIdx)

if (!nextRoot.assigned) {
roots += nextRoot
working.enqueue(nextRoot)

while (working.nonEmpty) {
val root = working.dequeue()
root.assigned = true
val maxChildCountPlusOne = (root.count / 2 + 1) + 1
val searchFromIdx = lookup
.find { case (count, _) => count < maxChildCountPlusOne }
.map { case (_, idx) => idx }
.getOrElse(-1)

if (searchFromIdx >= 0) {
val hits = taskSupport match {
case None =>
orderedNodes
.drop(searchFromIdx)
.filter(other => !other.assigned && matches(root.umi, other.umi))
case Some(ts) =>
orderedNodes
.drop(searchFromIdx)
.parWith(ts)
.filter(other => !other.assigned && matches(root.umi, other.umi))
.seq
}

root.children ++= hits
working.enqueueAll(hits)
hits.foreach(_.assigned = true)
}
}
}
}

assignIdsToNodes(roots.result())
}
}

/**
* Generates an indexed seq to enable fast identification of the first index in `nodes` where
nh13 marked this conversation as resolved.
Show resolved Hide resolved
* a given UMI count is observed. Assumes that the input is sorted from most abundant to least
* abundant. The generated Seq will contain one entry for every unique count seen; the second value
* in the tuple is the first index in the list of input nodes with that count.
*
* E.g. given nodes with counts [10, 10, 10, 9, 3, 3, 3, 3, 2, 1], the output would be:
* [(10, 0), (9, 3), (3, 4), (2, 8), (1, 9)]
*/
private def countIndexLookup(nodes: IndexedSeq[Node]): IndexedSeq[(Int, Int)] = {
val builder = IndexedSeq.newBuilder[(Int, Int)]
val iter = nodes.iterator.zipWithIndex.bufferBetter

while (iter.hasNext) {
val (currNode, currIdx) = iter.next
builder += ((currNode.count, currIdx))
iter.dropWhile { case (node, _) => node.count == currNode.count}
}

assignIdsToNodes(roots.result())
builder.result()
}
}



/**
* Version of the adjacency assigner that works for paired UMIs stored as a single tag of
* the form A-B where reads with A-B and B-A are related but not identical.
*
* @param maxMismatches the maximum number of mismatches between UMIs
*/
class PairedUmiAssigner(maxMismatches: Int) extends AdjacencyUmiAssigner(maxMismatches) {
class PairedUmiAssigner(maxMismatches: Int, threads: Int = 1) extends AdjacencyUmiAssigner(maxMismatches, threads) {
/** String that is prefixed onto the UMI from the read that maps to a lower coordinate in the genome. */
private[umi] val lowerReadUmiPrefix: String = ("a" * (maxMismatches+1)) + ":"

Expand Down Expand Up @@ -402,27 +468,27 @@ case class TagFamilySizeMetric(family_size: Int,

/** The strategies implemented by [[GroupReadsByUmi]] to identify reads from the same source molecule.*/
sealed trait Strategy extends EnumEntry {
def newStrategy(edits: Int): UmiAssigner
def newStrategy(edits: Int, threads: Int): UmiAssigner
}
object Strategy extends FgBioEnum[Strategy] {
def values: IndexedSeq[Strategy] = findValues
/** Strategy in which only reads with identical UMI sequences are grouped together. */
case object Identity extends Strategy {
def newStrategy(edits: Int = 0): UmiAssigner = {
def newStrategy(edits: Int = 0, threads: Int = 0): UmiAssigner = {
require(edits == 0, "Edits should be zero when using the identity UMI assigner.")
new IdentityUmiAssigner
}
}
/** Strategy to cluster reads into groups based on mismatches between reads in clusters. */
case object Edit extends Strategy { def newStrategy(edits: Int): UmiAssigner = new SimpleErrorUmiAssigner(edits) }
case object Edit extends Strategy { def newStrategy(edits: Int, threads: Int = 0): UmiAssigner = new SimpleErrorUmiAssigner(edits) }
/** Strategy based on the directed adjacency method described in [umi_tools](http://dx.doi.org/10.1101/051755)
* that allows for errors between UMIs but only when there is a count gradient.
*/
case object Adjacency extends Strategy { def newStrategy(edits: Int): UmiAssigner = new AdjacencyUmiAssigner(edits) }
case object Adjacency extends Strategy { def newStrategy(edits: Int, threads: Int = 1): UmiAssigner = new AdjacencyUmiAssigner(edits, threads) }
/** Strategy similar to the [[Adjacency]] strategy, but for methods that produce templates with a
* pair of UMIs such that a read with A-B is related to but not identical to a read with B-A.
*/
case object Paired extends Strategy { def newStrategy(edits: Int): UmiAssigner = new PairedUmiAssigner(edits)}
case object Paired extends Strategy { def newStrategy(edits: Int, threads: Int = 1): UmiAssigner = new PairedUmiAssigner(edits, threads)}
}

@clp(group=ClpGroups.Umi, description =
Expand Down Expand Up @@ -491,6 +557,11 @@ object Strategy extends FgBioEnum[Strategy] {
| 1. `--min-map-q` defaults to 0 in duplicate marking mode and 1 otherwise
| 2. `--include-secondary` defaults to true in duplicate marking mode and false otherwise
| 3. `--include-supplementary` defaults to true in duplicate marking mode and false otherwise
|
|Multi-threaded operation is supported via the `--threads/-@` option. This only applies to the Adjacency and Paired
|strategies. Additionally the only operation that is multi-threaded is the comparisons of UMIs at the same genomic
|position. Running with e.g. `--threads 8` can provide a _substantial_ reduction in runtime when there are many
|UMIs observed at the same genomic location, such as can occur in amplicon sequencing or ultra-deep coverage data.
"""
)
class GroupReadsByUmi
Expand All @@ -513,13 +584,14 @@ class GroupReadsByUmi
@arg(flag='x', doc= """
|DEPRECATED: this option will be removed in future versions and inter-contig reads will be
|automatically processed.""")
@deprecated val allowInterContig: Boolean = true
@deprecated val allowInterContig: Boolean = true,
@arg(flag='@', doc="Number of threads to use when comparing UMIs. Only recommended for amplicon or similar data.") val threads: Int = 1,
)extends FgBioTool with LazyLogging {
import GroupReadsByUmi._

require(this.minUmiLength.forall(_ => this.strategy != Strategy.Paired), "Paired strategy cannot be used with --min-umi-length")

private val assigner = strategy.newStrategy(this.edits)
private val assigner = strategy.newStrategy(this.edits, this.threads)

// Give values to unset parameters that are different in duplicate marking mode
private val _minMapQ = this.minMapQ.getOrElse(if (this.markDuplicates) 0 else 1)
Expand Down Expand Up @@ -705,13 +777,20 @@ class GroupReadsByUmi
* sub-grouping into UMI groups by original molecule.
*/
def assignUmiGroups(templates: Seq[Template]): Unit = {
  // Looks up the ID assigned to each template's UMI and writes it to the assign tag on the template's
  // primary reads; emits a debug log line when the whole operation was slow.
  //
  // Elapsed time is taken from System.nanoTime, which is meant for measuring intervals. The wall-clock
  // System.currentTimeMillis can jump (forwards or backwards) when the system clock is adjusted.
  val slowGroupingThresholdMs = 2500L
  val startNanos              = System.nanoTime()

  val umis    = truncateUmis(templates.map { t => umiForRead(t) })
  val rawToId = this.assigner.assign(umis)

  // NOTE(review): presumably `umis` holds one entry per template, in template order, so the two
  // iterators can be walked in lock-step -- confirm against truncateUmis.
  templates.iterator.zip(umis.iterator).foreach { case (template, umi) =>
    val id = rawToId(umi)
    template.primaryReads.foreach(r => r(this.assignTag) = id)
  }

  // Only report groupings that took long enough to be worth noticing.
  val durationMs = (System.nanoTime() - startNanos) / 1000000L
  if (durationMs >= slowGroupingThresholdMs) {
    logger.debug(f"Grouped ${rawToId.size}%,d UMIs from ${templates.size}%,d templates in ${durationMs}%,d ms." )
  }
}

/** When a minimum UMI length is specified, truncates all the UMIs to the length of the shortest UMI. For the paired
Expand Down
26 changes: 13 additions & 13 deletions src/test/scala/com/fulcrumgenomics/umi/GroupReadsByUmiTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -84,31 +84,31 @@ class GroupReadsByUmiTest extends UnitSpec with OptionValues with PrivateMethodT
group(assigner.assign(umis)) should contain theSameElementsAs Set(Set("AAAAAA", "AAAATA", "AAAATT", "GGCGGC", "TGCACC", "TGCACG"))
}

{
"AdjacencyUmiAssigner" should "assign each UMI to separate groups" in {
Seq(1, 4).foreach { threads =>
"AdjacencyUmiAssigner" should s"assign each UMI to separate groups with $threads thread(s)" in {
val umis = Seq("AAAAAA", "CCCCCC", "GGGGGG", "TTTTTT", "AAATTT", "TTTAAA", "AGAGAG")
val groups = group(new AdjacencyUmiAssigner(maxMismatches=2).assign(umis))
val groups = group(new AdjacencyUmiAssigner(maxMismatches=2, threads=threads).assign(umis))
groups shouldBe umis.map(Set(_)).toSet
}

it should "assign everything into one group when all counts=1 and within mismatch threshold" in {
it should f"assign everything into one group when all counts=1 and within mismatch threshold with $threads thread(s)" in {
val umis = Seq("AAAAAA", "AAAAAc", "AAAAAg").map(_.toUpperCase)
val groups = group(new AdjacencyUmiAssigner(maxMismatches=1).assign(umis))
val groups = group(new AdjacencyUmiAssigner(maxMismatches=1, threads=threads).assign(umis))
groups shouldBe Set(umis.toSet)
}

it should "assign everything into one group" in {
it should f"assign everything into one group with $threads thread(s)" in {
val umis = Seq("AAAAAA", "AAAAAA", "AAAAAA", "AAAAAc", "AAAAAc", "AAAAAg", "AAAtAA").map(_.toUpperCase)
val groups = group(new AdjacencyUmiAssigner(maxMismatches=1).assign(umis))
val groups = group(new AdjacencyUmiAssigner(maxMismatches=1, threads=threads).assign(umis))
groups shouldBe Set(umis.toSet)
}

it should "make three groups" in {
it should f"make three groups with $threads thread(s)" in {
val umis: Seq[String] = n("AAAAAA", 4) ++ n("AAAAAT", 2) ++ n("AATAAT", 1) ++ n("AATAAA", 2) ++
n("GACGAC", 9) ++ n("GACGAT", 1) ++ n("GACGCC", 4) ++
n("TACGAC", 7)

val groups = group(new AdjacencyUmiAssigner(maxMismatches=2).assign(umis))
val groups = group(new AdjacencyUmiAssigner(maxMismatches=2, threads=threads).assign(umis))
groups shouldBe Set(
Set("AAAAAA", "AAAAAT", "AATAAT", "AATAAA"),
Set("GACGAC", "GACGAT", "GACGCC"),
Expand All @@ -117,15 +117,15 @@ class GroupReadsByUmiTest extends UnitSpec with OptionValues with PrivateMethodT
}

// Unit test for something that failed when running on real data
it should "correctly assign the following UMIs" in {
it should f"correctly assign the following UMIs with $threads thread(s)" in {
val umis = Seq("CGGGGG", "GTGGGG", "GGGGGG", "CTCACA", "TGCAGT", "CTCACA", "CGGGGG")
val groups = group(new AdjacencyUmiAssigner(maxMismatches=1).assign(umis))
val groups = group(new AdjacencyUmiAssigner(maxMismatches=1, threads=threads).assign(umis))
groups shouldBe Set(Set("CGGGGG", "GGGGGG", "GTGGGG"), Set("CTCACA"), Set("TGCAGT"))
}

it should "handle a deep tree of UMIs" in {
it should f"handle a deep tree of UMIs with $threads thread(s)" in {
val umis = n("AAAAAA", 256) ++ n("TAAAAA", 128) ++ n("TTAAAA", 64) ++ n("TTTAAA", 32) ++ n("TTTTAA", 16)
val groups = group(new AdjacencyUmiAssigner(maxMismatches=1).assign(umis))
val groups = group(new AdjacencyUmiAssigner(maxMismatches=1, threads=threads).assign(umis))
groups shouldBe Set(Set("AAAAAA", "TAAAAA", "TTAAAA", "TTTAAA", "TTTTAA"))
}
}
Expand Down
Loading