Skip to content
This repository has been archived by the owner on Apr 8, 2021. It is now read-only.

Bounded-memory local training #89

Merged
merged 7 commits into from
Apr 6, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,25 @@ import AnnotatedTree.{AnnotatedTreeTraversal, fullBinaryTreeOpsForAnnotatedTree}
object Example extends Defaults {

def main(args: Array[String]) {
val cols = args.toList

val trainingData = io.Source.stdin.getLines.map { line =>
val path = args.head
val cols = args.tail.toList
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really small thing here but you could say:

val path :: cols = args.toList


val trainingData = Lines(path).map { line =>
val parts = line.split(",").reverse.toList
val label = parts.head
val values = parts.tail.map { s => s.toDouble }
Instance(line, 0L, Map(cols.zip(values): _*), Map(label -> 1L))
}.toList
}.toIterable

var trainer =
Trainer(trainingData, KFoldSampler(4))
.updateTargets

println(trainer.validate(AccuracyError()))
println(trainer.validate(BrierScoreError()))


1.to(10).foreach { i =>
trainer = trainer.expand(1)
println(trainer.validate(AccuracyError()))
println(trainer.validate(BrierScoreError()))
}

implicit val ord = Ordering.by[AveragedValue, Double] { _.value }
trainer = trainer.prune(BrierScoreError())
println(trainer.validate(AccuracyError()))
println(trainer.validate(BrierScoreError()))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package com.stripe.brushfire.local

import java.io._
import scala.collection.generic.CanBuildFrom

/**
 * A re-iterable sequence of values of type `A` associated with a file.
 *
 * Subclasses supply `iterator`; every transformation defined here (`filter`,
 * `map`, `flatMap`) returns a new `Lines` whose `iterator` calls this
 * instance's `iterator` again each time it is invoked, so nothing is
 * materialised by the transformation itself.
 *
 * @param file    the file this instance is associated with
 * @param charset name of the character set associated with `file`
 */
abstract class Lines[A](val file: File, val charset: String = "UTF-8") { self =>

  /** Produces an iterator over the elements. */
  def iterator: Iterator[A]

  /** Exposes this instance through the standard `Iterable` interface. */
  def toIterable: Iterable[A] =
    new IterableLines(this)

  /**
   * Builds a `Lines` that shares this instance's `file` AND `charset`.
   *
   * `underlying` is by-name so that each call to `iterator` on the derived
   * instance evaluates it afresh rather than reusing a spent iterator.
   */
  private def derived[B](underlying: => Iterator[B]): Lines[B] =
    new Lines[B](file, charset) {
      def iterator: Iterator[B] = underlying
    }

  /** Keeps only the elements for which `f` returns true. */
  def filter(f: A => Boolean): Lines[A] =
    derived(self.iterator.filter(f))

  /** Applies `f` to every element. */
  def map[B](f: A => B): Lines[B] =
    derived(self.iterator.map(f))

  /** Applies `f` to every element and concatenates the results. */
  def flatMap[B](f: A => Iterable[B]): Lines[B] =
    derived(self.iterator.flatMap(a => f(a).iterator))

  /** Folds over a single pass of `iterator`, starting from `b0`. */
  def foldLeft[B](b0: B)(f: (B, A) => B): B =
    iterator.foldLeft(b0)(f)

  /** Runs `f` on every element of a single pass of `iterator`. */
  def foreach(f: A => Unit): Unit =
    iterator.foreach(f)

  override def toString(): String =
    s"Lines(<over $file with $charset>)"
}

/** Factories for `Lines[String]` instances that read a text file line by line. */
object Lines {

  /** Closes `c`, ignoring any `IOException` raised by the close itself. */
  private def closeQuietly(c: Closeable): Unit =
    try c.close() catch { case _: IOException => () }

  /**
   * Opens `file` for buffered reading using the named charset.
   *
   * If wrapping the stream fails (for example because the charset name is
   * rejected), the already-opened file stream is closed before the exception
   * propagates.
   */
  private def open(file: File, charset: String): BufferedReader = {
    val in = new FileInputStream(file)
    try new BufferedReader(new InputStreamReader(in, charset))
    catch {
      case e: IOException =>
        closeQuietly(in)
        throw e
    }
  }

  /**
   * Creates a `Lines[String]` over the lines of `f`.
   *
   * Every call to `iterator` opens the file again. The reader is closed when
   * the end of the file is reached or when a read fails; an iterator that is
   * abandoned before either happens keeps its reader open.
   *
   * @param f  the file to read
   * @param cs name of the charset used to decode the file
   */
  def apply(f: File, cs: String = "UTF-8"): Lines[String] =
    new Lines[String](f, cs) {
      def iterator: Iterator[String] =
        new Iterator[String] {
          println("Opening " + f)
          private val reader = open(file, charset)
          // One line of look-ahead; null means the input is exhausted.
          private var line: String = advance()

          // Reads the next line, closing the reader at end of input or on failure.
          private def advance(): String =
            try {
              val read = reader.readLine()
              if (read == null) reader.close()
              read
            } catch {
              case e: IOException =>
                closeQuietly(reader)
                line = null // a failed read ends the iteration
                throw e
            }

          def hasNext: Boolean = line != null

          def next(): String = {
            if (line == null) throw new NoSuchElementException("next on empty iterator")
            val out = line
            line = advance()
            out
          }
        }
    }

  /** Creates a `Lines[String]` over the file at `pathname`, using the default charset of the `File` overload. */
  def apply(pathname: String): Lines[String] = apply(new File(pathname))
}

/**
 * Adapter presenting a `Lines` as a standard `Iterable`.
 *
 * Each call to `iterator` is delegated to the wrapped `Lines`.
 */
class IterableLines[A](lines: Lines[A]) extends Iterable[A] {
  def iterator: Iterator[A] = {
    val source = lines
    source.iterator
  }
}
174 changes: 119 additions & 55 deletions brushfire-core/src/main/scala/com/stripe/brushfire/local/Trainer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,81 +6,145 @@ import com.twitter.algebird._

import AnnotatedTree.AnnotatedTreeTraversal


//map with a reservoir of up to `capacity` randomly chosen keys
case class SampledMap[A,B](capacity: Int) {
var mapValues = Map[A,B]()
var randValues = Map[A,Double]()
var threshold = 0.0
val rand = new util.Random

/**
 * Returns the random draw associated with `key`, assigning one on first sight.
 *
 * A previously seen key always gets its recorded draw back. For a new key a
 * fresh draw is recorded and then:
 *  - while at most `capacity` keys have been seen, `threshold` is raised to
 *    the draw if the draw is at least the current threshold;
 *  - once more than `capacity` keys have been seen and the draw is below
 *    `threshold`, the `capacity` smallest draws are kept, `threshold` drops to
 *    the largest of them, and values stored under any other key are removed.
 */
private def randValue(key: A): Double =
  randValues.get(key) match {
    case Some(known) => known
    case None =>
      val r = rand.nextDouble
      randValues += key -> r

      if (randValues.size <= capacity && r >= threshold)
        threshold = r
      else if (randValues.size > capacity && r < threshold) {
        println("evicting")
        val bottomK = randValues.toList.sortBy { _._2 }.take(capacity)
        val keep = bottomK.map { _._1 }.toSet
        // lastOption: bottomK is empty when capacity is not positive.
        bottomK.lastOption.foreach { case (_, draw) => threshold = draw }
        // Strict filter: filterKeys only wraps the old map, which would keep
        // the evicted entries reachable instead of releasing them.
        mapValues = mapValues.filter { case (k, _) => keep(k) }
      }

      r
  }

/** Reports whether the draw assigned to `key` is at or below the current `threshold`. */
def containsKey(key: A): Boolean = {
  // randValue may adjust `threshold` as a side effect, so it has to run
  // before the comparison reads `threshold`.
  val draw = randValue(key)
  draw <= threshold
}
def update(key: A, value: B) {
if(containsKey(key))
mapValues += key -> value
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're just mutating these structures in place why don't we use mutable.Map for better performance?

}

/** Looks up the value currently stored under `key`, if there is one. */
def get(key: A): Option[B] = {
  val stored = mapValues
  stored.get(key)
}
}

case class Trainer[K: Ordering, V: Ordering, T: Monoid](
trainingData: Iterable[Instance[K, V, T]],
sampler: Sampler[K],
trees: List[Tree[K, V, T]])(implicit traversal: AnnotatedTreeTraversal[K, V, T, Unit]) {

private def updateTrees(fn: (Tree[K, V, T], Int, Map[(Int, T, Unit), Iterable[Instance[K, V, T]]]) => Tree[K, V, T]): Trainer[K, V, T] = {
val newTrees = trees.zipWithIndex.par.map {
case (tree, index) =>
val byLeaf =
trainingData.flatMap { instance =>
val repeats = sampler.timesInTrainingSet(instance.id, instance.timestamp, index)
if (repeats > 0) {
tree.leafFor(instance.features).map { leaf =>
(1 to repeats).toList.map { i => (instance, leaf) }
}.getOrElse(Nil)
} else {
Nil
}
}.groupBy { _._2 }
.mapValues { _.map { _._1 } }
fn(tree, index, byLeaf)
val treeMap = trees.zipWithIndex.map{case (t,i) => i->t}.toMap

def expand(maxLeavesPerTree: Int)(implicit splitter: Splitter[V, T], evaluator: Evaluator[V, T], stopper: Stopper[T]): Trainer[K, V, T] = {
val allStats = treeMap.map{case (treeIndex, tree) =>
treeIndex -> SampledMap[Int,Map[K,splitter.S]](maxLeavesPerTree)
}
copy(trees = newTrees.toList)
}

private def updateLeaves(fn: (Int, (Int, T, Unit), Iterable[Instance[K, V, T]]) => Node[K, V, T, Unit]): Trainer[K, V, T] = {
updateTrees {
case (tree, treeIndex, byLeaf) =>
val newNodes = byLeaf.map {
case (leaf, instances) =>
val (index, _, _) = leaf
index -> fn(treeIndex, leaf, instances)
trainingData.foreach{instance =>
lazy val features = instance.features.mapValues { value => splitter.create(value, instance.target) }

for (
(treeIndex, tree) <- treeMap.toList;
treeStats <- allStats.get(treeIndex).toList;
i <- 1.to(sampler.timesInTrainingSet(instance.id, instance.timestamp, treeIndex)).toList;
(leafIndex, target, annotation) <- tree.leafFor(instance.features).toList
if stopper.shouldSplit(target) && treeStats.containsKey(leafIndex);
(feature, stats) <- features
if sampler.includeFeature(feature, treeIndex, leafIndex)
) {
var leafStats = treeStats.get(leafIndex).getOrElse(Map[K,splitter.S]())
val combined = leafStats.get(feature) match {
case Some(old) => splitter.semigroup.plus(old, stats)
case None => stats
}
leafStats += feature -> combined
treeStats.update(leafIndex, leafStats)
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you are using this imperatively, I think it would be more efficient to do this explicitly (you wouldn't need to call .toList internally everywhere):

treeMap.foreach { case (treeIndex, tree) =>
  allStats.get(treeIndex).foreach { treeStats =>
    val times = sampler.timesInTrainingSet(instance.id, instance.timestamp, treeIndex)
    (1 to times).foreach { i =>
      tree.leafFor(instance.features).foreach { case (leafIndex, target, annotation) =>
        if (stopper.shouldSplit(target) && treeStats.containsKey(leafIndex)) {
          features.foreach { case (feature, stats) =>
            if (sampler.includeFeature(feature, treeIndex, leafIndex)) {
              var leafStats = treeStats.get(leafIndex).getOrElse(Map[K,splitter.S]())
              val combined = leafStats.get(feature) match {
                case Some(old) => splitter.semigroup.plus(old, stats)
                case None => stats
              }
              leafStats += feature -> combined
              treeStats.update(leafIndex, leafStats)
            }
          }
        }
      }
    }
  }
}

(I realize that syntactically it looks uglier, but it's also a bit clearer what is happening and will have much better performance than the previous code which was essentially doing .toList.foreach instead of .foreach, and with extra .filter calls.)


tree.updateByLeafIndex(newNodes.lift)
val newTreeMap = allStats.map{case (treeIndex, treeStats) =>
val tree = treeMap(treeIndex)
treeIndex -> tree.growByLeafIndex{leafIndex =>
val candidates = for(
leafStats <- treeStats.get(leafIndex).toList;
parent <- tree.leafAt(leafIndex).toList;
(feature, stats) <- leafStats.toList;
split <- splitter.split(parent.target, stats);
(newSplit, score) <- evaluator.evaluate(split).toList
) yield (newSplit.createSplitNode(feature), score)

if(candidates.isEmpty)
None
else
Some(candidates.maxBy{_._2}._1)
}
}

val newTrees = 0.until(trees.size).toList.map{i => newTreeMap(i)}
Trainer(trainingData, sampler, newTrees)
}

def updateTargets: Trainer[K, V, T] =
updateLeaves {
case (treeIndex, (index, _, annotation), instances) =>
val target = implicitly[Monoid[T]].sum(instances.map { _.target })
LeafNode(index, target, annotation)
def updateTargets: Trainer[K, V, T] = {
var targets = treeMap.mapValues{tree => Map[Int, T]()}
trainingData.foreach{instance =>
for (
(treeIndex, tree) <- treeMap.toList;
i <- 1.to(sampler.timesInTrainingSet(instance.id, instance.timestamp, treeIndex)).toList;
leafIndex <- tree.leafIndexFor(instance.features).toList
) {
val treeTargets = targets(treeIndex)
val old = treeTargets.getOrElse(leafIndex, Monoid.zero[T])
val combined = Monoid.plus(instance.target, old)
targets += treeIndex -> (treeTargets + (leafIndex -> combined))
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above -- I think converting this to .foreach would be more efficient.

}

def expand(times: Int)(implicit splitter: Splitter[V, T], evaluator: Evaluator[V, T], stopper: Stopper[T]): Trainer[K, V, T] =
updateLeaves {
case (treeIndex, (index, target, annotation), instances) =>
Tree.expand(times, treeIndex, LeafNode[K, V, T, Unit](index, target, annotation), splitter, evaluator, stopper, sampler, instances)
val newTrees = trees.zipWithIndex.map{case (tree, index) =>
tree.updateByLeafIndex{leafIndex =>
val target = targets(index).getOrElse(leafIndex, Monoid.zero[T])
Some(LeafNode(leafIndex, target, ()))
}
}

def prune[P, E](error: Error[T, P, E])(implicit voter: Voter[T, P], ord: Ordering[E]): Trainer[K, V, T] =
updateTrees {
case (tree, treeIndex, byLeaf) =>
val byLeafIndex = byLeaf.map {
case ((index, _, _), instances) =>
index -> implicitly[Monoid[T]].sum(instances.map { _.target })
}
tree.prune(byLeafIndex, voter, error)
}
copy(trees = newTrees)
}

def validate[P, E](error: Error[T, P, E])(implicit voter: Voter[T, P]): Option[E] = {
val errors = trainingData.flatMap { instance =>
val useTrees = trees.zipWithIndex.filter {
case (tree, i) =>
sampler.includeInValidationSet(instance.id, instance.timestamp, i)
}.map { _._1 }
if(useTrees.isEmpty)
None
else {
val prediction = voter.predict(useTrees, instance.features)
Some(error.create(instance.target, prediction))
var output: Option[E] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we expect that we won't have a monoid for error? It seems like using Monoid[E] here would be preferable unless that was a non-starter.

trainingData.foreach{instance =>
val predictions =
for (
(treeIndex, tree) <- treeMap
if sampler.includeInValidationSet(instance.id, instance.timestamp, treeIndex);
target <- tree.targetFor(instance.features).toList
) yield target

if(!predictions.isEmpty) {
val newError = error.create(instance.target, voter.combine(predictions))
output = output
.map{old => error.semigroup.plus(old, newError)}
.orElse(Some(newError))
}
}
error.semigroup.sumOption(errors)

output
}
}

Expand Down
6 changes: 3 additions & 3 deletions example/iris-local
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/sh
java -Xmx2G -cp ../brushfire-scalding/target/target/brushfire-scalding-0.6.0-SNAPSHOT-jar-with-dependencies.jar \
java -Xmx2G -cp ../brushfire-scalding/target/scala-2.10/brushfire-scalding-0.7.1-SNAPSHOT-jar-with-dependencies.jar \
com.stripe.brushfire.local.Example \
petal-width petal-length sepal-width sepal-length \
< iris.data
iris.data \
petal-width petal-length sepal-width sepal-length