Skip to content

Commit

Permalink
Speed up QTree
Browse files Browse the repository at this point in the history
Heavily driven by benchmarks. Avoid allocations, dealing with options or similar. Attempt to keep as much external  source compatibility as possible
  • Loading branch information
ianoc committed Jul 31, 2015
1 parent 0e52dbc commit e8f869b
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 91 deletions.
262 changes: 172 additions & 90 deletions algebird-core/src/main/scala/com/twitter/algebird/QTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,21 +42,39 @@ object QTree {
/**
* level gives a bin size of 2^level. By default the bin size is 1/65536 (level = -16)
*/
def apply[A](kv: (Double, A), level: Int = DefaultLevel): QTree[A] =
QTree(math.floor(kv._1 / math.pow(2.0, level)).toLong,
def apply[A](kv: (Double, A), level: Int = DefaultLevel): QTree[A] = {
val offset = math.floor(kv._1 / math.pow(2.0, level)).toLong
require(offset >= 0, "QTree can not accept negative values")

new QTree(offset,
level,
1,
kv._2,
None,
None)
null,
null)
}

def apply[A](kv: (Long, A)): QTree[A] = {
require(kv._1 >= 0, "QTree can not accept negative values")

def apply[A](kv: (Long, A)): QTree[A] =
QTree(kv._1,
new QTree(kv._1,
0,
1,
kv._2,
None,
None)
null,
null)
}

def apply[A](offset: Long,
level: Int,
count: Long,
sum: A, //the sum at just this node (*not* including its children)
lowerChild: Option[QTree[A]],
upperChild: Option[QTree[A]]): QTree[A] = {
require(offset >= 0, "QTree can not accept negative values")

new QTree(offset, level, count, sum, lowerChild.orNull, upperChild.orNull)
}

/**
* The common case of wanting an offset and sum for the same value
Expand All @@ -73,6 +91,12 @@ object QTree {
*/
def apply(k: Double): QTree[Double] = apply(k -> k)

/**
* End user consumable unapply for QTree
*/
def unapply[A](qtree: QTree[A]): Option[(Long, Int, Long, A, Option[QTree[A]], Option[QTree[A]])] =
Some((qtree.offset, qtree.level, qtree.count, qtree.sum, qtree.lowerChild, qtree.upperChild))

/**
* If you are sure you only care about the approximate histogram
* features of QTree, you can save some space by using QTree[Unit]
Expand All @@ -84,11 +108,35 @@ object QTree {
* level gives a bin size of 2^level. By default this is 1/65536 (level = -16)
*/
def value(v: Double, level: Int = DefaultLevel): QTree[Unit] = apply(v -> (), level)

private[algebird] def mergePeers[@specialized(Int, Long, Float, Double) A](left: QTree[A], right: QTree[A])(implicit monoid: Monoid[A]): QTree[A] = {
assert(right.lowerBound == left.lowerBound, "lowerBound " + right.lowerBound + " != " + left.lowerBound)
assert(right.level == left.level, "level " + right.level + " != " + left.level)

new QTree[A](left.offset,
left.level, left.count + right.count,
monoid.plus(left.sum, right.sum),
mergeOptions(left.lowerChildNullable, right.lowerChildNullable),
mergeOptions(left.upperChildNullable, right.upperChildNullable))
}

private def mergeOptions[A](aNullable: QTree[A], bNullable: QTree[A])(implicit monoid: Monoid[A]): QTree[A] =
if (aNullable != null) {
if (bNullable != null) {
mergePeers(aNullable, bNullable)
} else aNullable
} else bNullable

private[algebird] val cachedRangeCacheSize: Int = 20
private[algebird] val cachedRangeLowerBound: Int = cachedRangeCacheSize * -1
private[algebird] val rangeLut: Array[Double] = (cachedRangeLowerBound until cachedRangeCacheSize).map { level =>
math.pow(2.0, level)
}.toArray[Double]
}

class QTreeSemigroup[A](k: Int)(implicit val underlyingMonoid: Monoid[A]) extends Semigroup[QTree[A]] {
/** Override this if you want to change how frequently sumOption calls compress */
def compressBatchSize: Int = 25
def compressBatchSize: Int = 50
def plus(left: QTree[A], right: QTree[A]) = left.merge(right).compress(k)
override def sumOption(items: TraversableOnce[QTree[A]]): Option[QTree[A]] = if (items.isEmpty) None
else {
Expand All @@ -108,32 +156,69 @@ class QTreeSemigroup[A](k: Int)(implicit val underlyingMonoid: Monoid[A]) extend
}
}

case class QTree[A](
offset: Long, //the range this tree covers is offset*(2^level) ... (offset+1)*(2^level)
level: Int,
count: Long, //the total count for this node and all of its children
sum: A, //the sum at just this node (*not* including its children)
lowerChild: Option[QTree[A]],
upperChild: Option[QTree[A]]) {
class QTree[@specialized(Int, Long, Float, Double) A] private[algebird] (
_offset: Long, //the range this tree covers is offset*(2^level) ... (offset+1)*(2^level)
_level: Int,
_count: Long, //the total count for this node and all of its children
_sum: A, //the sum at just this node (*not* including its children)
_lowerChildNullable: QTree[A],
_upperChildNullable: QTree[A])
extends scala.Product6[Long, Int, Long, A, Option[QTree[A]], Option[QTree[A]]] with java.io.Serializable {
import QTree._

val range: Double =
if (_level < cachedRangeCacheSize && level > cachedRangeLowerBound)
rangeLut(_level + cachedRangeCacheSize)
else
math.pow(2.0, level)

def lowerBound: Double = range * _offset
def upperBound: Double = range * (_offset + 1)

def lowerChild: Option[QTree[A]] = Option(_lowerChildNullable)
def upperChild: Option[QTree[A]] = Option(_upperChildNullable)

// Helpers to access the nullable ones from inside the QTree work
@inline private[algebird] def lowerChildNullable: QTree[A] = _lowerChildNullable
@inline private[algebird] def upperChildNullable: QTree[A] = _upperChildNullable

require(offset >= 0, "QTree can not accept negative values")
@inline def offset: Long = _offset
@inline def level: Int = _level
@inline def count: Long = _count
@inline def sum: A = _sum

def range: Double = math.pow(2.0, level)
def lowerBound: Double = range * offset
def upperBound: Double = range * (offset + 1)
@inline def _1: Long = _offset
@inline def _2: Int = _level
@inline def _3: Long = _count
@inline def _4: A = _sum
@inline def _5: Option[QTree[A]] = lowerChild
@inline def _6: Option[QTree[A]] = upperChild

private def extendToLevel(n: Int)(implicit monoid: Monoid[A]): QTree[A] = {
override lazy val hashCode: Int = _root_.scala.runtime.ScalaRunTime._hashCode(this)

override def toString: String = _root_.scala.runtime.ScalaRunTime._toString(this)

override def equals(other: Any): Boolean = _root_.scala.runtime.ScalaRunTime._equals(this, other)

override def canEqual(other: Any): Boolean = other.isInstanceOf[QTree[A]]

override def productArity: Int = 6

@annotation.tailrec
private[algebird] final def extendToLevel(n: Int)(implicit monoid: Monoid[A]): QTree[A] = {
if (n <= level)
this
else {
val nextLevel = level + 1
val nextOffset = offset / 2
val nextLevel = _level + 1
val nextOffset = _offset / 2

// See benchmark in QTreeMicroBenchmark for why do this rather than the single if
// with 2 calls to QTree[A] in it.
val l = if (offset % 2 == 0) this else null
val r = if (offset % 2 == 0) null else this

val parent =
if (offset % 2 == 0)
QTree[A](nextOffset, nextLevel, count, monoid.zero, Some(this), None)
else
QTree[A](nextOffset, nextLevel, count, monoid.zero, None, Some(this))
new QTree[A](nextOffset, nextLevel, _count, monoid.zero, l, r)

parent.extendToLevel(n)
}
Expand All @@ -146,16 +231,16 @@ case class QTree[A](
* level (that is, the power of 2 for the interval).
*/
private def commonAncestorLevel(other: QTree[A]) = {
val minLevel = level.min(other.level)
val leftOffset = offset << (level - minLevel)
val minLevel = _level.min(other.level)
val leftOffset = offset << (_level - minLevel)
val rightOffset = other.offset << (other.level - minLevel)
var offsetDiff = leftOffset ^ rightOffset
var ancestorLevel = minLevel
while (offsetDiff > 0) {
ancestorLevel += 1
offsetDiff >>= 1
}
ancestorLevel.max(level).max(other.level)
ancestorLevel.max(_level).max(other.level)
}

/**
Expand All @@ -169,26 +254,9 @@ case class QTree[A](
val commonAncestor = commonAncestorLevel(other)
val left = extendToLevel(commonAncestor)
val right = other.extendToLevel(commonAncestor)
left.mergeWithPeer(right)
mergePeers(left, right)
}

private def mergeWithPeer(other: QTree[A])(implicit monoid: Monoid[A]): QTree[A] = {
assert(other.lowerBound == lowerBound, "lowerBound " + other.lowerBound + " != " + lowerBound)
assert(other.level == level, "level " + other.level + " != " + level)

copy(count = count + other.count,
sum = monoid.plus(sum, other.sum),
lowerChild = mergeOptions(lowerChild, other.lowerChild),
upperChild = mergeOptions(upperChild, other.upperChild))
}

private def mergeOptions(a: Option[QTree[A]], b: Option[QTree[A]])(implicit monoid: Monoid[A]): Option[QTree[A]] =
(a, b) match {
case (Some(qa), Some(qb)) => Some(qa.mergeWithPeer(qb))
case (None, right) => right
case (left, None) => left
}

/**
* give lower and upper bounds respectively of the percentile
* value given. For instance, quantileBounds(0.5) would give
Expand All @@ -197,36 +265,49 @@ case class QTree[A](
def quantileBounds(p: Double): (Double, Double) = {
require(p >= 0.0 && p <= 1.0, "The given percentile must be of the form 0 <= p <= 1.0")

val rank = math.floor(count * p).toLong
val rank = math.floor(_count * p).toLong
// get is safe below, because findRankLowerBound only returns
// None if rank > count, but due to construction rank <= count
(findRankLowerBound(rank).get, findRankUpperBound(rank).get)
(findRankLowerBound(rank), findRankUpperBound(rank))
}

private def findRankLowerBound(rank: Long): Option[Double] =
if (rank > count)
None
private def findRankLowerBound(rank: Long): java.lang.Double =
if (rank > _count)
null
else {
val childCounts = mapChildrenWithDefault(0L)(_.count)
val parentCount = count - childCounts._1 - childCounts._2
lowerChild.flatMap { _.findRankLowerBound(rank - parentCount) }
.orElse {
val newRank = rank - childCounts._1 - parentCount
if (newRank <= 0)
Some(lowerBound)
else
upperChild.flatMap{ _.findRankLowerBound(newRank) }
}
val parentCount = _count - childCounts._1 - childCounts._2
val r2 = if (lowerChildNullable != null) lowerChildNullable.findRankLowerBound(rank - parentCount) else null

if (r2 == null) {
val newRank = rank - childCounts._1 - parentCount
if (newRank <= 0)
lowerBound
else if (upperChildNullable != null)
upperChildNullable.findRankLowerBound(newRank)
else
null
} else r2
}

private def findRankUpperBound(rank: Long): Option[Double] = {
if (rank > count)
None
private def findRankUpperBound(rank: Long): java.lang.Double = {
if (rank > _count)
null
else {
lowerChild.flatMap{ _.findRankUpperBound(rank) }.orElse {
val lowerCount = lowerChild.map{ _.count }.getOrElse(0L)
upperChild.flatMap{ _.findRankUpperBound(rank - lowerCount) }.orElse(Some(upperBound))
}
val r = if (lowerChildNullable != null) {
lowerChildNullable.findRankUpperBound(rank)
} else null
if (r == null) {
val lowerCount = if (lowerChildNullable == null) 0L else lowerChildNullable.count

val r2: java.lang.Double = if (upperChildNullable != null) {
upperChildNullable.findRankUpperBound(rank - lowerCount)
} else null

if (r2 == null) {
upperBound
} else r2
} else r
}
}

Expand Down Expand Up @@ -271,8 +352,8 @@ case class QTree[A](
* are at most 2^k nodes, but usually fewer.
*/
def compress(k: Int)(implicit m: Monoid[A]): QTree[A] = {
val minCount = count >> k
if ((minCount > 1L) || (count < 1L)) {
val minCount = _count >> k
if ((minCount > 1L) || (_count < 1L)) {
pruneChildren(minCount)
} else {
// count > 0, so for all nodes, if minCount <= 1, then count >= minCount
Expand All @@ -285,29 +366,30 @@ case class QTree[A](

// If we don't prune we MUST return this
private def pruneChildren(minCount: Long)(implicit m: Monoid[A]): QTree[A] =
if (count < minCount) {
copy(sum = totalSum, lowerChild = None, upperChild = None)
if (_count < minCount) {
new QTree[A](_offset, _level, _count, totalSum, null, null)
} else {
val newLower = pruneChild(minCount, lowerChild)
val lowerNotPruned = newLower eq lowerChild
val newUpper = pruneChild(minCount, upperChild)
val upperNotPruned = newUpper eq upperChild
val newLower = pruneChild(minCount, lowerChildNullable)
val lowerNotPruned = newLower eq lowerChildNullable
val newUpper = pruneChild(minCount, upperChildNullable)
val upperNotPruned = newUpper eq upperChildNullable
if (lowerNotPruned && upperNotPruned)
this
else
copy(lowerChild = newLower, upperChild = newUpper)
new QTree[A](_offset, _level, _count, _sum, newLower, newUpper)
}

// If we don't prune we MUST return child
@inline
private def pruneChild(minCount: Long,
child: Option[QTree[A]])(implicit m: Monoid[A]): Option[QTree[A]] = child match {
case exists @ Some(oldChild) =>
val newChild = oldChild.pruneChildren(minCount)
if (newChild eq oldChild) exists // need to pass the same reference if we don't change
else Some(newChild)
case n @ None => n // make sure we pass the same ref out
}
childNullable: QTree[A])(implicit m: Monoid[A]): QTree[A] =
if (childNullable == null)
null
else {
val newChild = childNullable.pruneChildren(minCount)
if (newChild eq childNullable) childNullable // need to pass the same reference if we don't change
else newChild
}

/**
* How many total nodes are there in the QTree.
Expand All @@ -334,20 +416,20 @@ case class QTree[A](

private def parentCount = {
val childCounts = mapChildrenWithDefault(0L){ _.count }
count - childCounts._1 - childCounts._2
_count - childCounts._1 - childCounts._2
}

/**
* A debug method that prints the QTree to standard out using print/println
*/
def dump {
for (i <- (20 to level by -1))
def dump() {
for (i <- (20 to _level by -1))
print(" ")
print(lowerBound + " - " + upperBound + ": " + count)
print(lowerBound + " - " + upperBound + ": " + _count)
if (lowerChild.isDefined || upperChild.isDefined) {
print(" (" + parentCount + ")")
}
println(" {" + sum + "}")
println(" {" + _sum + "}")
lowerChild.foreach{ _.dump }
upperChild.foreach{ _.dump }
}
Expand Down
2 changes: 1 addition & 1 deletion project/Build.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ object AlgebirdBuild extends Build {

javacOptions ++= Seq("-target", "1.6", "-source", "1.6"),

scalacOptions ++= Seq("-unchecked", "-deprecation", "-language:implicitConversions", "-language:higherKinds", "-language:existentials"),
scalacOptions ++= Seq("-unchecked", "-deprecation", "-optimize", "-Xlint", "-language:implicitConversions", "-language:higherKinds", "-language:existentials"),

scalacOptions <++= (scalaVersion) map { sv =>
if (sv startsWith "2.10")
Expand Down

0 comments on commit e8f869b

Please sign in to comment.