-
Notifications
You must be signed in to change notification settings - Fork 50
Bounded-memory local training #89
Changes from all commits
dec1655
48deb83
732c107
fdd9500
605ca39
9e5e712
ab147a8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
package com.stripe.brushfire.local | ||
|
||
import java.io._ | ||
import scala.collection.generic.CanBuildFrom | ||
|
||
/**
 * A re-iterable sequence of values derived from `file`.
 *
 * Subclasses supply `iterator`; the combinators below (`filter`, `map`,
 * `flatMap`) return new `Lines` whose `iterator` asks this instance for a
 * fresh iterator each time it is called and transforms it lazily.
 *
 * @param file    the file this instance is associated with
 * @param charset the charset name this instance is associated with
 * @tparam A      the element type produced by `iterator`
 */
abstract class Lines[A](val file: File, val charset: String = "UTF-8") { self =>

  /** Returns an iterator over the elements of this instance. */
  def iterator: Iterator[A]

  /** Wraps this instance in an `Iterable` backed by `iterator`. */
  def toIterable: Iterable[A] =
    new IterableLines(this)

  /** Keeps only the elements for which `f` returns true. */
  def filter(f: A => Boolean): Lines[A] =
    // Propagate `charset` so the derived instance reports the same encoding
    // as its source instead of silently falling back to the default.
    new Lines[A](file, charset) {
      def iterator: Iterator[A] = self.iterator.filter(f)
    }

  /** Transforms every element with `f`. */
  def map[B](f: A => B): Lines[B] =
    new Lines[B](file, charset) {
      def iterator: Iterator[B] = self.iterator.map(f)
    }

  /** Expands every element into the elements of `f(element)`, in order. */
  def flatMap[B](f: A => Iterable[B]): Lines[B] =
    new Lines[B](file, charset) {
      def iterator: Iterator[B] = self.iterator.flatMap(a => f(a).iterator)
    }

  /** Folds the elements of one `iterator` pass from the left, starting at `b0`. */
  def foldLeft[B](b0: B)(f: (B, A) => B): B =
    iterator.foldLeft(b0)(f)

  /** Applies `f` to every element of one `iterator` pass. */
  def foreach(f: A => Unit): Unit =
    iterator.foreach(f)

  override def toString(): String =
    s"Lines(<over $file with $charset>)"
}
|
||
object Lines {

  /**
   * Builds a `Lines[String]` over the text lines of `f`, decoded with charset `cs`.
   *
   * Every call to `iterator` opens the file again. The underlying reader is
   * closed as soon as the last line has been read, or when opening/reading
   * fails; an iterator that is abandoned before exhaustion keeps its reader open.
   *
   * @param f  the file to read
   * @param cs the charset name used to decode the file
   */
  def apply(f: File, cs: String = "UTF-8"): Lines[String] =
    new Lines[String](f, cs) {
      def iterator: Iterator[String] =
        new Iterator[String] {
          println("Opening " + f)
          private val reader: BufferedReader = open()
          // `null` marks end of input, mirroring BufferedReader.readLine().
          private var line: String = advance()

          /** Opens the file, releasing the raw stream if the charset is rejected. */
          private def open(): BufferedReader = {
            val in = new FileInputStream(file)
            try new BufferedReader(new InputStreamReader(in, charset))
            catch {
              // An unsupported charset name is reported as an IOException subtype
              // after the stream is already open, so close it before rethrowing.
              case e: IOException =>
                in.close()
                throw e
            }
          }

          /** Reads the next line, closing the reader at end of input or on failure. */
          private def advance(): String = {
            val nextLine =
              try reader.readLine()
              catch {
                case e: IOException =>
                  // Best-effort close; the read failure is the error worth reporting.
                  try reader.close() catch { case _: IOException => () }
                  throw e
              }
            if (nextLine == null) reader.close()
            nextLine
          }

          def hasNext: Boolean = line != null

          def next(): String = {
            if (line == null) throw new NoSuchElementException("next on empty iterator")
            val out = line
            line = advance()
            out
          }
        }
    }

  /** Builds a `Lines[String]` over the file at `pathname`, using the default charset argument. */
  def apply(pathname: String): Lines[String] = apply(new File(pathname))
}
|
||
/** Adapts a `Lines[A]` to `Iterable[A]` by delegating `iterator` to it. */
class IterableLines[A](lines: Lines[A]) extends Iterable[A] {
  override def iterator: Iterator[A] = {
    val source = lines
    source.iterator
  }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,81 +6,145 @@ import com.twitter.algebird._ | |
|
||
import AnnotatedTree.AnnotatedTreeTraversal | ||
|
||
|
||
//map with a reservoir of up to `capacity` randomly chosen keys | ||
case class SampledMap[A,B](capacity: Int) { | ||
var mapValues = Map[A,B]() | ||
var randValues = Map[A,Double]() | ||
var threshold = 0.0 | ||
val rand = new util.Random | ||
|
||
/**
 * Returns the random score assigned to `key`, drawing and recording a new
 * one the first time the key is seen.
 *
 * While at most `capacity` keys have been scored, `threshold` tracks the
 * largest score seen. Once more than `capacity` keys have been scored and a
 * new score falls below `threshold`, the `capacity` lowest-scored keys are
 * kept, `threshold` drops to the highest score among them, and all other
 * entries are removed from `mapValues`.
 */
private def randValue(key: A): Double =
  randValues.get(key) match {
    case Some(known) => known
    case None =>
      val r = rand.nextDouble()
      randValues += key -> r

      if (randValues.size <= capacity && r >= threshold)
        threshold = r
      else if (randValues.size > capacity && r < threshold) {
        println("evicting")
        val bottomK = randValues.toList.sortBy(_._2).take(capacity)
        val keep = bottomK.map(_._1).toSet
        threshold = bottomK.last._2
        // Use a strict filter: `filterKeys` only wraps the existing map in a
        // lazy view, so evicted entries would stay reachable and every
        // eviction would stack another wrapper on top of the previous one.
        mapValues = mapValues.filter { case (k, _) => keep(k) }
      }

      r
  }
|
||
def containsKey(key: A): Boolean = randValue(key) <= threshold | ||
def update(key: A, value: B) { | ||
if(containsKey(key)) | ||
mapValues += key -> value | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since we're just mutating these structures in place why don't we use |
||
} | ||
|
||
def get(key: A): Option[B] = mapValues.get(key) | ||
} | ||
|
||
case class Trainer[K: Ordering, V: Ordering, T: Monoid]( | ||
trainingData: Iterable[Instance[K, V, T]], | ||
sampler: Sampler[K], | ||
trees: List[Tree[K, V, T]])(implicit traversal: AnnotatedTreeTraversal[K, V, T, Unit]) { | ||
|
||
private def updateTrees(fn: (Tree[K, V, T], Int, Map[(Int, T, Unit), Iterable[Instance[K, V, T]]]) => Tree[K, V, T]): Trainer[K, V, T] = { | ||
val newTrees = trees.zipWithIndex.par.map { | ||
case (tree, index) => | ||
val byLeaf = | ||
trainingData.flatMap { instance => | ||
val repeats = sampler.timesInTrainingSet(instance.id, instance.timestamp, index) | ||
if (repeats > 0) { | ||
tree.leafFor(instance.features).map { leaf => | ||
(1 to repeats).toList.map { i => (instance, leaf) } | ||
}.getOrElse(Nil) | ||
} else { | ||
Nil | ||
} | ||
}.groupBy { _._2 } | ||
.mapValues { _.map { _._1 } } | ||
fn(tree, index, byLeaf) | ||
val treeMap = trees.zipWithIndex.map{case (t,i) => i->t}.toMap | ||
|
||
def expand(maxLeavesPerTree: Int)(implicit splitter: Splitter[V, T], evaluator: Evaluator[V, T], stopper: Stopper[T]): Trainer[K, V, T] = { | ||
val allStats = treeMap.map{case (treeIndex, tree) => | ||
treeIndex -> SampledMap[Int,Map[K,splitter.S]](maxLeavesPerTree) | ||
} | ||
copy(trees = newTrees.toList) | ||
} | ||
|
||
private def updateLeaves(fn: (Int, (Int, T, Unit), Iterable[Instance[K, V, T]]) => Node[K, V, T, Unit]): Trainer[K, V, T] = { | ||
updateTrees { | ||
case (tree, treeIndex, byLeaf) => | ||
val newNodes = byLeaf.map { | ||
case (leaf, instances) => | ||
val (index, _, _) = leaf | ||
index -> fn(treeIndex, leaf, instances) | ||
trainingData.foreach{instance => | ||
lazy val features = instance.features.mapValues { value => splitter.create(value, instance.target) } | ||
|
||
for ( | ||
(treeIndex, tree) <- treeMap.toList; | ||
treeStats <- allStats.get(treeIndex).toList; | ||
i <- 1.to(sampler.timesInTrainingSet(instance.id, instance.timestamp, treeIndex)).toList; | ||
(leafIndex, target, annotation) <- tree.leafFor(instance.features).toList | ||
if stopper.shouldSplit(target) && treeStats.containsKey(leafIndex); | ||
(feature, stats) <- features | ||
if sampler.includeFeature(feature, treeIndex, leafIndex) | ||
) { | ||
var leafStats = treeStats.get(leafIndex).getOrElse(Map[K,splitter.S]()) | ||
val combined = leafStats.get(feature) match { | ||
case Some(old) => splitter.semigroup.plus(old, stats) | ||
case None => stats | ||
} | ||
leafStats += feature -> combined | ||
treeStats.update(leafIndex, leafStats) | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Since you are using this imperatively, I think it would be more efficient to do this explicitly (you wouldn't need to call treeMap.foreach { case (treeIndex, tree) =>
allStats.get(treeIndex).foreach { treeStats =>
val times = sampler.timesInTrainingSet(instance.id, instance.timestamp, treeIndex)
(1 to times).foreach { i =>
tree.leafFor(instance.features).foreach { case (leafIndex, target, annotation) =>
if (stopper.shouldSplit(target) && treeStats.containsKey(leafIndex)) {
features.foreach { case (feature, stats) =>
if (sampler.includeFeature(feature, treeIndex, leafIndex)) {
var leafStats = treeStats.get(leafIndex).getOrElse(Map[K,splitter.S]())
val combined = leafStats.get(feature) match {
case Some(old) => splitter.semigroup.plus(old, stats)
case None => stats
}
leafStats += feature -> combined
treeStats.update(leafIndex, leafStats)
}
}
}
}
}
}
} (I realize that syntactically it looks uglier, but it's also a bit clearer what is happening and will have much better performance than the previous code which was essentially doing |
||
|
||
tree.updateByLeafIndex(newNodes.lift) | ||
val newTreeMap = allStats.map{case (treeIndex, treeStats) => | ||
val tree = treeMap(treeIndex) | ||
treeIndex -> tree.growByLeafIndex{leafIndex => | ||
val candidates = for( | ||
leafStats <- treeStats.get(leafIndex).toList; | ||
parent <- tree.leafAt(leafIndex).toList; | ||
(feature, stats) <- leafStats.toList; | ||
split <- splitter.split(parent.target, stats); | ||
(newSplit, score) <- evaluator.evaluate(split).toList | ||
) yield (newSplit.createSplitNode(feature), score) | ||
|
||
if(candidates.isEmpty) | ||
None | ||
else | ||
Some(candidates.maxBy{_._2}._1) | ||
} | ||
} | ||
|
||
val newTrees = 0.until(trees.size).toList.map{i => newTreeMap(i)} | ||
Trainer(trainingData, sampler, newTrees) | ||
} | ||
|
||
def updateTargets: Trainer[K, V, T] = | ||
updateLeaves { | ||
case (treeIndex, (index, _, annotation), instances) => | ||
val target = implicitly[Monoid[T]].sum(instances.map { _.target }) | ||
LeafNode(index, target, annotation) | ||
def updateTargets: Trainer[K, V, T] = { | ||
var targets = treeMap.mapValues{tree => Map[Int, T]()} | ||
trainingData.foreach{instance => | ||
for ( | ||
(treeIndex, tree) <- treeMap.toList; | ||
i <- 1.to(sampler.timesInTrainingSet(instance.id, instance.timestamp, treeIndex)).toList; | ||
leafIndex <- tree.leafIndexFor(instance.features).toList | ||
) { | ||
val treeTargets = targets(treeIndex) | ||
val old = treeTargets.getOrElse(leafIndex, Monoid.zero[T]) | ||
val combined = Monoid.plus(instance.target, old) | ||
targets += treeIndex -> (treeTargets + (leafIndex -> combined)) | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same comment as above -- I think converting this to |
||
} | ||
|
||
def expand(times: Int)(implicit splitter: Splitter[V, T], evaluator: Evaluator[V, T], stopper: Stopper[T]): Trainer[K, V, T] = | ||
updateLeaves { | ||
case (treeIndex, (index, target, annotation), instances) => | ||
Tree.expand(times, treeIndex, LeafNode[K, V, T, Unit](index, target, annotation), splitter, evaluator, stopper, sampler, instances) | ||
val newTrees = trees.zipWithIndex.map{case (tree, index) => | ||
tree.updateByLeafIndex{leafIndex => | ||
val target = targets(index).getOrElse(leafIndex, Monoid.zero[T]) | ||
Some(LeafNode(leafIndex, target, ())) | ||
} | ||
} | ||
|
||
def prune[P, E](error: Error[T, P, E])(implicit voter: Voter[T, P], ord: Ordering[E]): Trainer[K, V, T] = | ||
updateTrees { | ||
case (tree, treeIndex, byLeaf) => | ||
val byLeafIndex = byLeaf.map { | ||
case ((index, _, _), instances) => | ||
index -> implicitly[Monoid[T]].sum(instances.map { _.target }) | ||
} | ||
tree.prune(byLeafIndex, voter, error) | ||
} | ||
copy(trees = newTrees) | ||
} | ||
|
||
def validate[P, E](error: Error[T, P, E])(implicit voter: Voter[T, P]): Option[E] = { | ||
val errors = trainingData.flatMap { instance => | ||
val useTrees = trees.zipWithIndex.filter { | ||
case (tree, i) => | ||
sampler.includeInValidationSet(instance.id, instance.timestamp, i) | ||
}.map { _._1 } | ||
if(useTrees.isEmpty) | ||
None | ||
else { | ||
val prediction = voter.predict(useTrees, instance.features) | ||
Some(error.create(instance.target, prediction)) | ||
var output: Option[E] = None | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do we expect that we won't have a monoid for error? It seems like using |
||
trainingData.foreach{instance => | ||
val predictions = | ||
for ( | ||
(treeIndex, tree) <- treeMap | ||
if sampler.includeInValidationSet(instance.id, instance.timestamp, treeIndex); | ||
target <- tree.targetFor(instance.features).toList | ||
) yield target | ||
|
||
if(!predictions.isEmpty) { | ||
val newError = error.create(instance.target, voter.combine(predictions)) | ||
output = output | ||
.map{old => error.semigroup.plus(old, newError)} | ||
.orElse(Some(newError)) | ||
} | ||
} | ||
error.semigroup.sumOption(errors) | ||
|
||
output | ||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
#!/bin/sh | ||
java -Xmx2G -cp ../brushfire-scalding/target/target/brushfire-scalding-0.6.0-SNAPSHOT-jar-with-dependencies.jar \ | ||
java -Xmx2G -cp ../brushfire-scalding/target/scala-2.10/brushfire-scalding-0.7.1-SNAPSHOT-jar-with-dependencies.jar \ | ||
com.stripe.brushfire.local.Example \ | ||
petal-width petal-length sepal-width sepal-length \ | ||
< iris.data | ||
iris.data \ | ||
petal-width petal-length sepal-width sepal-length |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really small thing here but you could say: