Skip to content

Commit

Permalink
Merge pull request #19 from travisbrown/topic/tree-ops
Browse files Browse the repository at this point in the history
Add label collection methods, clean up API a bit
  • Loading branch information
travisbrown-stripe authored Jun 6, 2017
2 parents ae3b525 + fabe077 commit 02c18a3
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 29 deletions.
30 changes: 15 additions & 15 deletions bonsai-core/src/main/scala/com/stripe/bonsai/FullBinaryTree.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class FullBinaryTree[A, B](
}

final def reduce[X](f: (A, X, X) => X)(g: B => X): Option[X] =
if (nonEmpty) Some(FullBinaryTree.reduceNode(this, 0)(f)(g)) else None
if (nonEmpty) Some(FullBinaryTree.reduceNode(this, 0)(f, g)) else None

final def root: Option[NodeRef] =
if (nonEmpty) Some(mkNodeRef(0)) else None
Expand All @@ -113,8 +113,8 @@ class FullBinaryTree[A, B](
}

sealed abstract class NodeRef {
def fold[R](f: (NodeRef, NodeRef, A) => R, g: B => R): R
def reduce[X](f: (A, X, X) => X)(g: B => X): X
def fold[R](f: (A, NodeRef, NodeRef) => R, g: B => R): R
def reduce[X](f: (A, X, X) => X, g: B => X): X
}

final def leafLabel(index: Int): B =
Expand All @@ -126,9 +126,9 @@ class FullBinaryTree[A, B](
final class LeafRef private[FullBinaryTree] (index: Int) extends NodeRef {
def label: B =
tree.leafLabel(index)
def fold[R](f: (NodeRef, NodeRef, A) => R, g: B => R): R =
def fold[R](f: (A, NodeRef, NodeRef) => R, g: B => R): R =
g(label)
def reduce[X](f: (A, X, X) => X)(g: B => X): X =
def reduce[X](f: (A, X, X) => X, g: B => X): X =
g(label)
}

Expand All @@ -139,10 +139,10 @@ class FullBinaryTree[A, B](
tree.mkNodeRef(2 * index + 1)
def rightChild: NodeRef =
tree.mkNodeRef(2 * index + 2)
def fold[R](f: (NodeRef, NodeRef, A) => R, g: B => R): R =
f(leftChild, rightChild, label)
def reduce[X](f: (A, X, X) => X)(g: B => X): X =
FullBinaryTree.reduceNode(tree, index)(f)(g)
def fold[R](f: (A, NodeRef, NodeRef) => R, g: B => R): R =
f(label, leftChild, rightChild)
def reduce[X](f: (A, X, X) => X, g: B => X): X =
FullBinaryTree.reduceNode(tree, index)(f, g)
}
}

Expand All @@ -154,11 +154,11 @@ object FullBinaryTree {
def root(t: FullBinaryTree[A, B]): Option[Node] =
t.root

def foldNode[X](node: Node)(f: (Node, Node, A) => X, g: B => X): X =
def foldNode[X](node: Node)(f: (A, Node, Node) => X, g: B => X): X =
node.fold(f, g)

override def reduce[X](node: Node)(f: (Either[A, B], Iterable[X]) => X): X =
node.reduce[X]((a, x1, x2) => f(Left(a), x1 :: x2 :: Nil))(b => f(Right(b), Nil))
node.reduce[X]((a, x1, x2) => f(Left(a), x1 :: x2 :: Nil), b => f(Right(b), Nil))
}

/**
Expand Down Expand Up @@ -190,7 +190,7 @@ object FullBinaryTree {
build(rest)
case (Some(node), rest) =>
bitsBldr += true
val (ol, or) = foldNode(node)({ (lc, rc, bl) =>
val (ol, or) = foldNode(node)({ (bl, lc, rc) =>
leafBldr += false
branchLabelBldr += bl
(Some(lc), Some(rc))
Expand All @@ -214,13 +214,13 @@ object FullBinaryTree {
}
}

final def reduceNode[A, B, X](tree: FullBinaryTree[A, B], index: Int)(f: (A, X, X) => X)(g: B => X): X =
final def reduceNode[A, B, X](tree: FullBinaryTree[A, B], index: Int)(f: (A, X, X) => X, g: B => X): X =
if (tree.isLeaf(index)) {
g(tree.leafLabel(index))
} else {
val label = tree.branchLabel(index)
val x1 = reduceNode(tree, tree.bitset.rank(2 * index + 1) - 1)(f)(g)
val x2 = reduceNode(tree, tree.bitset.rank(2 * index + 2) - 1)(f)(g)
val x1 = reduceNode(tree, tree.bitset.rank(2 * index + 1) - 1)(f, g)
val x2 = reduceNode(tree, tree.bitset.rank(2 * index + 2) - 1)(f, g)
f(label, x1, x2)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,25 @@ package com.stripe.bonsai
trait FullBinaryTreeOps[T, BL, LL] extends TreeOps[T, Either[BL, LL]] {

override def reduce[A](node: Node)(f: (Either[BL, LL], Iterable[A]) => A): A =
foldNode(node)({ (lc, rc, lbl) =>
foldNode(node)({ (lbl, lc, rc) =>
f(Left(lbl), reduce(lc)(f) :: reduce(rc)(f) :: Nil)
}, lbl => f(Right(lbl), Nil))

def foldNode[A](node: Node)(f: (Node, Node, BL) => A, g: LL => A): A
def foldNode[A](node: Node)(f: (BL, Node, Node) => A, g: LL => A): A

def reduceNode[A](node: Node)(f: (BL, Iterable[A]) => A)(g: LL => A): A =
foldNode(node)({ (lc, rc, lbl) =>
f(lbl, reduceNode(lc)(f)(g) :: reduceNode(rc)(f)(g) :: Nil)
}, g)
def reduceNode[A](node: Node)(f: (BL, A, A) => A, g: LL => A): A =
foldNode(node)((lbl, rc, lc) => f(lbl, reduceNode(lc)(f, g), reduceNode(rc)(f, g)), g)

def label(node: Node): Either[BL, LL] =
foldNode(node)({ case (_, _, bl) => Left(bl) }, ll => Right(ll))
foldNode(node)((bl, _, _) => Left(bl), ll => Right(ll))

def children(node: Node): Iterable[Node] =
foldNode(node)({ case (lc, rc, _) => lc :: rc :: Nil }, _ => Nil)
foldNode(node)((_, lc, rc) => lc :: rc :: Nil, _ => Nil)

def collectLeafLabelsF[A](node: Node)(f: LL => A): Set[A] =
reduceNode[Set[A]](node)((_, lc, rc) => lc ++ rc, ll => Set(f(ll)))

def collectLeafLabels(node: Node): Set[LL] = collectLeafLabelsF(node)(identity)
}

object FullBinaryTreeOps {
Expand Down
5 changes: 5 additions & 0 deletions bonsai-core/src/main/scala/com/stripe/bonsai/TreeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ trait TreeOps[Tree, Label] { ops =>
def reduce[A](tree: Tree)(f: (Label, Iterable[A]) => A): Option[A] =
root(tree).map(n => reduce(n)(f))

def collectLabelsF[A](node: Node)(f: Label => A): Set[A] =
reduce[Set[A]](node)((lbl, cs) => cs.foldLeft(Set(f(lbl)))(_ ++ _))

def collectLabels(node: Node): Set[Label] = collectLabelsF(node)(identity)

implicit class OpsForTree(tree: Tree) {
def root: Option[Node] = ops.root(tree)
def reduce[A](f: (Label, Iterable[A]) => A): Option[A] = root.map(ops.reduce(_)(f))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ class FullBinaryTreeSpec extends WordSpec with Matchers with Checkers with Prope
}

def sumNode(n: ops.Node): Long =
ops.foldNode(n)((lc, rc, n) => n + sumNode(lc) + sumNode(rc), n => n)
ops.foldNode(n)((n, lc, rc) => n + sumNode(lc) + sumNode(rc), n => n)

def nodeElems(n: ops.Node): Set[Int] =
ops.foldNode(n)((lc, rc, n) => Set(n) | nodeElems(lc) | nodeElems(rc), Set(_))
ops.foldNode(n)((n, lc, rc) => Set(n) | nodeElems(lc) | nodeElems(rc), Set(_))

def minNode(n: ops.Node): Int =
ops.foldNode(n)((lc, rc, n) => (n min (minNode(lc) min minNode(rc))), n => n)
ops.foldNode(n)((n, lc, rc) => (n min (minNode(lc) min minNode(rc))), n => n)

"write" should {
"round-trip through read" in {
Expand Down Expand Up @@ -93,4 +93,20 @@ class FullBinaryTreeSpec extends WordSpec with Matchers with Checkers with Prope
}
}
}

"FullBinaryTreeOps" should {
"collect leaf labels" in {
val genTree = GenericBinTree.branch(2, GenericBinTree.leaf(1), GenericBinTree.leaf(3))
val bt1 = FullBinaryTree(genTree)
ops.collectLeafLabels(ops.root(bt1).get) shouldBe Set(1, 3)
}

"stream all labels" in {
import GenericBinTree._
val genTree = branch(0, branch(2, leaf(1), leaf(3)), branch(6, leaf(5), leaf(0)))
val bt1 = FullBinaryTree(genTree)
val expected = List(1, 2, 3, 0, 5, 6, 0)
ops.collectLabelsF(ops.root(bt1).get)(_.merge) shouldBe expected.toSet
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,16 @@ object GenericBinTree {
def root(t: GenericBinTree[A]): Option[Node] =
Some(t)

def foldNode[X](node: Node)(f: (Node, Node, A) => X, g: A => X): X =
def foldNode[X](node: Node)(f: (A, Node, Node) => X, g: A => X): X =
node.children match {
case Some((lc, rc)) => f(lc, rc, node.label)
case Some((lc, rc)) => f(node.label, lc, rc)
case None => g(node.label)
}
}

def fromTree[A](tree: FullBinaryTree[A, A]): Option[GenericBinTree[A]] = {
def construct(n: tree.NodeRef): GenericBinTree[A] =
n.fold({ (lc, rc, a) =>
n.fold({ (a, lc, rc) =>
GenericBinTree.branch(a, construct(lc), construct(rc))
}, GenericBinTree.leaf)
tree.root.map(construct)
Expand Down

0 comments on commit 02c18a3

Please sign in to comment.