diff --git a/bonsai-core/src/main/scala/com/stripe/bonsai/FullBinaryTree.scala b/bonsai-core/src/main/scala/com/stripe/bonsai/FullBinaryTree.scala index f85323b..9d8cbe1 100644 --- a/bonsai-core/src/main/scala/com/stripe/bonsai/FullBinaryTree.scala +++ b/bonsai-core/src/main/scala/com/stripe/bonsai/FullBinaryTree.scala @@ -88,7 +88,7 @@ class FullBinaryTree[A, B]( } final def reduce[X](f: (A, X, X) => X)(g: B => X): Option[X] = - if (nonEmpty) Some(FullBinaryTree.reduceNode(this, 0)(f)(g)) else None + if (nonEmpty) Some(FullBinaryTree.reduceNode(this, 0)(f, g)) else None final def root: Option[NodeRef] = if (nonEmpty) Some(mkNodeRef(0)) else None @@ -113,8 +113,8 @@ class FullBinaryTree[A, B]( } sealed abstract class NodeRef { - def fold[R](f: (NodeRef, NodeRef, A) => R, g: B => R): R - def reduce[X](f: (A, X, X) => X)(g: B => X): X + def fold[R](f: (A, NodeRef, NodeRef) => R, g: B => R): R + def reduce[X](f: (A, X, X) => X, g: B => X): X } final def leafLabel(index: Int): B = @@ -126,9 +126,9 @@ class FullBinaryTree[A, B]( final class LeafRef private[FullBinaryTree] (index: Int) extends NodeRef { def label: B = tree.leafLabel(index) - def fold[R](f: (NodeRef, NodeRef, A) => R, g: B => R): R = + def fold[R](f: (A, NodeRef, NodeRef) => R, g: B => R): R = g(label) - def reduce[X](f: (A, X, X) => X)(g: B => X): X = + def reduce[X](f: (A, X, X) => X, g: B => X): X = g(label) } @@ -139,10 +139,10 @@ class FullBinaryTree[A, B]( tree.mkNodeRef(2 * index + 1) def rightChild: NodeRef = tree.mkNodeRef(2 * index + 2) - def fold[R](f: (NodeRef, NodeRef, A) => R, g: B => R): R = - f(leftChild, rightChild, label) - def reduce[X](f: (A, X, X) => X)(g: B => X): X = - FullBinaryTree.reduceNode(tree, index)(f)(g) + def fold[R](f: (A, NodeRef, NodeRef) => R, g: B => R): R = + f(label, leftChild, rightChild) + def reduce[X](f: (A, X, X) => X, g: B => X): X = + FullBinaryTree.reduceNode(tree, index)(f, g) } } @@ -154,11 +154,11 @@ object FullBinaryTree { def root(t: FullBinaryTree[A, B]): Option[Node] = t.root - def foldNode[X](node: Node)(f: (Node, Node, A) => X, g: B => X): X = + def foldNode[X](node: Node)(f: (A, Node, Node) => X, g: B => X): X = node.fold(f, g) override def reduce[X](node: Node)(f: (Either[A, B], Iterable[X]) => X): X = - node.reduce[X]((a, x1, x2) => f(Left(a), x1 :: x2 :: Nil))(b => f(Right(b), Nil)) + node.reduce[X]((a, x1, x2) => f(Left(a), x1 :: x2 :: Nil), b => f(Right(b), Nil)) } /** @@ -190,7 +190,7 @@ object FullBinaryTree { build(rest) case (Some(node), rest) => bitsBldr += true - val (ol, or) = foldNode(node)({ (lc, rc, bl) => + val (ol, or) = foldNode(node)({ (bl, lc, rc) => leafBldr += false branchLabelBldr += bl (Some(lc), Some(rc)) @@ -214,13 +214,13 @@ object FullBinaryTree { } } - final def reduceNode[A, B, X](tree: FullBinaryTree[A, B], index: Int)(f: (A, X, X) => X)(g: B => X): X = + final def reduceNode[A, B, X](tree: FullBinaryTree[A, B], index: Int)(f: (A, X, X) => X, g: B => X): X = if (tree.isLeaf(index)) { g(tree.leafLabel(index)) } else { val label = tree.branchLabel(index) - val x1 = reduceNode(tree, tree.bitset.rank(2 * index + 1) - 1)(f)(g) - val x2 = reduceNode(tree, tree.bitset.rank(2 * index + 2) - 1)(f)(g) + val x1 = reduceNode(tree, tree.bitset.rank(2 * index + 1) - 1)(f, g) + val x2 = reduceNode(tree, tree.bitset.rank(2 * index + 2) - 1)(f, g) f(label, x1, x2) } diff --git a/bonsai-core/src/main/scala/com/stripe/bonsai/FullBinaryTreeOps.scala b/bonsai-core/src/main/scala/com/stripe/bonsai/FullBinaryTreeOps.scala index e3c15f8..9b1b1ae 100644 --- a/bonsai-core/src/main/scala/com/stripe/bonsai/FullBinaryTreeOps.scala +++ b/bonsai-core/src/main/scala/com/stripe/bonsai/FullBinaryTreeOps.scala @@ -3,22 +3,25 @@ package com.stripe.bonsai trait FullBinaryTreeOps[T, BL, LL] extends TreeOps[T, Either[BL, LL]] { override def reduce[A](node: Node)(f: (Either[BL, LL], Iterable[A]) => A): A = - foldNode(node)({ (lc, rc, lbl) => + foldNode(node)({ (lbl, lc, rc) => f(Left(lbl), reduce(lc)(f) :: reduce(rc)(f) :: Nil) }, lbl => f(Right(lbl), Nil)) - def foldNode[A](node: Node)(f: (Node, Node, BL) => A, g: LL => A): A + def foldNode[A](node: Node)(f: (BL, Node, Node) => A, g: LL => A): A - def reduceNode[A](node: Node)(f: (BL, Iterable[A]) => A)(g: LL => A): A = - foldNode(node)({ (lc, rc, lbl) => - f(lbl, reduceNode(lc)(f)(g) :: reduceNode(rc)(f)(g) :: Nil) - }, g) + def reduceNode[A](node: Node)(f: (BL, A, A) => A, g: LL => A): A = + foldNode(node)((lbl, rc, lc) => f(lbl, reduceNode(lc)(f, g), reduceNode(rc)(f, g)), g) def label(node: Node): Either[BL, LL] = - foldNode(node)({ case (_, _, bl) => Left(bl) }, ll => Right(ll)) + foldNode(node)((bl, _, _) => Left(bl), ll => Right(ll)) def children(node: Node): Iterable[Node] = - foldNode(node)({ case (lc, rc, _) => lc :: rc :: Nil }, _ => Nil) + foldNode(node)((_, lc, rc) => lc :: rc :: Nil, _ => Nil) + + def collectLeafLabelsF[A](node: Node)(f: LL => A): Set[A] = + reduceNode[Set[A]](node)((_, lc, rc) => lc ++ rc, ll => Set(f(ll))) + + def collectLeafLabels(node: Node): Set[LL] = collectLeafLabelsF(node)(identity) } object FullBinaryTreeOps { diff --git a/bonsai-core/src/main/scala/com/stripe/bonsai/TreeOps.scala b/bonsai-core/src/main/scala/com/stripe/bonsai/TreeOps.scala index c55d688..6995fc1 100644 --- a/bonsai-core/src/main/scala/com/stripe/bonsai/TreeOps.scala +++ b/bonsai-core/src/main/scala/com/stripe/bonsai/TreeOps.scala @@ -66,6 +66,11 @@ trait TreeOps[Tree, Label] { ops => def reduce[A](tree: Tree)(f: (Label, Iterable[A]) => A): Option[A] = root(tree).map(n => reduce(n)(f)) + def collectLabelsF[A](node: Node)(f: Label => A): Set[A] = + reduce[Set[A]](node)((lbl, cs) => cs.foldLeft(Set(f(lbl)))(_ ++ _)) + + def collectLabels(node: Node): Set[Label] = collectLabelsF(node)(identity) + implicit class OpsForTree(tree: Tree) { def root: Option[Node] = ops.root(tree) def reduce[A](f: (Label, Iterable[A]) => A): Option[A] = root.map(ops.reduce(_)(f)) diff --git a/bonsai-core/src/test/scala/com/stripe/bonsai/FullBinaryTreeSpec.scala b/bonsai-core/src/test/scala/com/stripe/bonsai/FullBinaryTreeSpec.scala index 9a785fe..78edda2 100644 --- a/bonsai-core/src/test/scala/com/stripe/bonsai/FullBinaryTreeSpec.scala +++ b/bonsai-core/src/test/scala/com/stripe/bonsai/FullBinaryTreeSpec.scala @@ -17,13 +17,13 @@ class FullBinaryTreeSpec extends WordSpec with Matchers with Checkers with Prope } def sumNode(n: ops.Node): Long = - ops.foldNode(n)((lc, rc, n) => n + sumNode(lc) + sumNode(rc), n => n) + ops.foldNode(n)((n, lc, rc) => n + sumNode(lc) + sumNode(rc), n => n) def nodeElems(n: ops.Node): Set[Int] = - ops.foldNode(n)((lc, rc, n) => Set(n) | nodeElems(lc) | nodeElems(rc), Set(_)) + ops.foldNode(n)((n, lc, rc) => Set(n) | nodeElems(lc) | nodeElems(rc), Set(_)) def minNode(n: ops.Node): Int = - ops.foldNode(n)((lc, rc, n) => (n min (minNode(lc) min minNode(rc))), n => n) + ops.foldNode(n)((n, lc, rc) => (n min (minNode(lc) min minNode(rc))), n => n) "write" should { "round-trip through read" in { @@ -93,4 +93,20 @@ class FullBinaryTreeSpec extends WordSpec with Matchers with Checkers with Prope } } } + + "FullBinaryTreeOps" should { + "collect leaf labels" in { + val genTree = GenericBinTree.branch(2, GenericBinTree.leaf(1), GenericBinTree.leaf(3)) + val bt1 = FullBinaryTree(genTree) + ops.collectLeafLabels(ops.root(bt1).get) shouldBe Set(1, 3) + } + + "stream all labels" in { + import GenericBinTree._ + val genTree = branch(0, branch(2, leaf(1), leaf(3)), branch(6, leaf(5), leaf(0))) + val bt1 = FullBinaryTree(genTree) + val expected = List(1, 2, 3, 0, 5, 6, 0) + ops.collectLabelsF(ops.root(bt1).get)(_.merge) shouldBe expected.toSet + } + } } diff --git a/bonsai-core/src/test/scala/com/stripe/bonsai/GenericTree.scala b/bonsai-core/src/test/scala/com/stripe/bonsai/GenericTree.scala index 4439c55..4bde07c 100644 --- a/bonsai-core/src/test/scala/com/stripe/bonsai/GenericTree.scala +++ b/bonsai-core/src/test/scala/com/stripe/bonsai/GenericTree.scala @@ -57,16 +57,16 @@ object GenericBinTree { def root(t: GenericBinTree[A]): Option[Node] = Some(t) - def foldNode[X](node: Node)(f: (Node, Node, A) => X, g: A => X): X = + def foldNode[X](node: Node)(f: (A, Node, Node) => X, g: A => X): X = node.children match { - case Some((lc, rc)) => f(lc, rc, node.label) + case Some((lc, rc)) => f(node.label, lc, rc) case None => g(node.label) } } def fromTree[A](tree: FullBinaryTree[A, A]): Option[GenericBinTree[A]] = { def construct(n: tree.NodeRef): GenericBinTree[A] = - n.fold({ (lc, rc, a) => + n.fold({ (a, lc, rc) => GenericBinTree.branch(a, construct(lc), construct(rc)) }, GenericBinTree.leaf) tree.root.map(construct)