Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #2186: make IndexedStateT stack safe #2187

Merged
merged 10 commits into from
Mar 14, 2018
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 63 additions & 0 deletions bench/src/main/scala/cats/bench/StateTBench.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
package cats.bench

import cats.Eval
import cats.data.StateT
import org.openjdk.jmh.annotations._

/**
* To run:
*
* bench/jmh:run -i 10 -wi 10 -f 2 -t 1 cats.bench.StateTBench
*/
@State(Scope.Thread)
@BenchmarkMode(Array(Mode.Throughput))
class StateTBench {
@Param(Array("10"))
var count: Int = _

@Benchmark
def single(): Long = {
randLong.run(32311).value._2
}

@Benchmark
def repeatedLeftBinds(): Int = {
var state = randInt
var i = 0
while (i < count) {
state = state.flatMap(int => randInt.map(_ + int))
i += 1
}
state.run(32312).value._2
}

@Benchmark
def repeatedRightBinds(): Int = {
var state = randInt
var i = 0
while (i < count) {
val oldS = state
state = randInt.flatMap(int => oldS.map(_ + int))
i += 1
}
state.run(32313).value._2
}

def fn(seed: Long): Eval[(Long, Int)] =
Eval.now {
val newSeed = (seed * 0x5DEECE66DL + 0xBL) & 0xFFFFFFFFFFFFL
val n = (newSeed >>> 16).toInt
(newSeed, n)
}

val randInt: StateT[Eval, Long, Int] =
StateT(fn)

val randLong: StateT[Eval, Long, Long] =
for {
int1 <- randInt
int2 <- randInt
} yield {
(int1.toLong << 32) | int2
}
}
112 changes: 112 additions & 0 deletions core/src/main/scala/cats/data/AndThen.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package cats.data

import java.io.Serializable

/**
* A function type of a single input that can do function composition
* (via `andThen` and `compose`) in constant stack space with amortized
* linear time application (in the number of constituent functions).
*
* Example:
*
* {{{
* val seed = AndThen((x: Int) => x + 1))
* val f = (0 until 10000).foldLeft(seed)((acc, _) => acc.andThen(_ + 1))
*
* // This should not trigger stack overflow ;-)
* f(0)
* }}}
*/
private[cats] sealed abstract class AndThen[-T, +R]
extends (T => R) with Product with Serializable {

import AndThen._

final def apply(a: T): R =
runLoop(a)

override def andThen[A](g: R => A): AndThen[T, A] = {
// Fusing calls up to a certain threshold, using the fusion
// technique implemented for `cats.effect.IO#map`
this match {
case Single(f, index) if index != 127 =>
Single(f.andThen(g), index + 1)
case _ =>
andThenF(AndThen(g))
}
}

override def compose[A](g: A => T): AndThen[A, R] = {
// Fusing calls up to a certain threshold, using the fusion
// technique implemented for `cats.effect.IO#map`
this match {
case Single(f, index) if index != 127 =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick, can we make this 127 (and at line 32) a constant with a sensible name?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion. I've updated the code with a documented constant.

Single(f.compose(g), index + 1)
case _ =>
composeF(AndThen(g))
}
}

private def runLoop(start: T): R = {
var self: AndThen[Any, Any] = this.asInstanceOf[AndThen[Any, Any]]
var current: Any = start.asInstanceOf[Any]
var continue = true

while (continue) {
self match {
case Single(f, _) =>
current = f(current)
continue = false

case Concat(Single(f, _), right) =>
current = f(current)
self = right.asInstanceOf[AndThen[Any, Any]]

case Concat(left @ Concat(_, _), right) =>
self = left.rotateAccum(right)
}
}
current.asInstanceOf[R]
}

private final def andThenF[X](right: AndThen[R, X]): AndThen[T, X] =
Concat(this, right)
private final def composeF[X](right: AndThen[X, T]): AndThen[X, R] =
Concat(right, this)

// converts left-leaning to right-leaning
protected final def rotateAccum[E](_right: AndThen[R, E]): AndThen[T, E] = {
var self: AndThen[Any, Any] = this.asInstanceOf[AndThen[Any, Any]]
var right: AndThen[Any, Any] = _right.asInstanceOf[AndThen[Any, Any]]
var continue = true
while (continue) {
self match {
case Concat(left, inner) =>
self = left.asInstanceOf[AndThen[Any, Any]]
right = inner.andThenF(right)

case _ => // Single
self = self.andThenF(right)
continue = false
}
}
self.asInstanceOf[AndThen[T, E]]
}

override def toString: String =
"AndThen$" + System.identityHashCode(this)
}

private[cats] object AndThen {
/** Builds an [[AndThen]] reference by wrapping a plain function. */
def apply[A, B](f: A => B): AndThen[A, B] =
f match {
case ref: AndThen[A, B] @unchecked => ref
case _ => Single(f, 0)
}

private final case class Single[-A, +B](f: A => B, index: Int)
extends AndThen[A, B]
private final case class Concat[-A, E, +B](left: AndThen[A, E], right: AndThen[E, B])
extends AndThen[A, B]
}
4 changes: 2 additions & 2 deletions core/src/main/scala/cats/data/IndexedStateT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ final class IndexedStateT[F[_], SA, SB, A](val runF: F[SA => F[(SB, A)]]) extend

def flatMap[B, SC](fas: A => IndexedStateT[F, SB, SC, B])(implicit F: FlatMap[F]): IndexedStateT[F, SA, SC, B] =
IndexedStateT.applyF(F.map(runF) { safsba =>
safsba.andThen { fsba =>
AndThen(safsba).andThen { fsba =>
F.flatMap(fsba) { case (sb, a) =>
fas(a).run(sb)
}
Expand All @@ -31,7 +31,7 @@ final class IndexedStateT[F[_], SA, SB, A](val runF: F[SA => F[(SB, A)]]) extend

def flatMapF[B](faf: A => F[B])(implicit F: FlatMap[F]): IndexedStateT[F, SA, SB, B] =
IndexedStateT.applyF(F.map(runF) { sfsa =>
sfsa.andThen { fsa =>
AndThen(sfsa).andThen { fsa =>
F.flatMap(fsa) { case (s, a) => F.map(faf(a))((s, _)) }
}
})
Expand Down
53 changes: 53 additions & 0 deletions tests/src/test/scala/cats/tests/AndThenSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package cats.tests

import catalysts.Platform
import cats.data._

class AndThenSuite extends CatsSuite {
test("compose a chain of functions with andThen") {
check { (i: Int, fs: List[Int => Int]) =>
val result = fs.map(AndThen(_)).reduceOption(_.andThen(_)).map(_(i))
val expect = fs.reduceOption(_.andThen(_)).map(_(i))

result == expect
}
}

test("compose a chain of functions with compose") {
check { (i: Int, fs: List[Int => Int]) =>
val result = fs.map(AndThen(_)).reduceOption(_.compose(_)).map(_(i))
val expect = fs.reduceOption(_.compose(_)).map(_(i))

result == expect
}
}

test("andThen is stack safe") {
val count = if (Platform.isJvm) 500000 else 1000
val fs = (0 until count).map(_ => { i: Int => i + 1 })
val result = fs.foldLeft(AndThen((x: Int) => x))(_.andThen(_))(42)

result shouldEqual (count + 42)
}

test("compose is stack safe") {
val count = if (Platform.isJvm) 500000 else 1000
val fs = (0 until count).map(_ => { i: Int => i + 1 })
val result = fs.foldLeft(AndThen((x: Int) => x))(_.compose(_))(42)

result shouldEqual (count + 42)
}

test("Function1 andThen is stack safe") {
val count = if (Platform.isJvm) 50000 else 1000
val start: (Int => Int) = AndThen((x: Int) => x)
val fs = (0 until count).foldLeft(start) { (acc, _) =>
acc.andThen(_ + 1)
}
fs(0) shouldEqual count
}

test("toString") {
AndThen((x: Int) => x).toString should startWith("AndThen$")
}
}
19 changes: 18 additions & 1 deletion tests/src/test/scala/cats/tests/IndexedStateTSuite.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package cats
package tests

import catalysts.Platform
import cats.arrow.{Profunctor, Strong}
import cats.data.{EitherT, IndexedStateT, State, StateT}

import cats.arrow.Profunctor
import cats.kernel.instances.tuple._
import cats.laws.discipline._
Expand Down Expand Up @@ -251,6 +251,23 @@ class IndexedStateTSuite extends CatsSuite {
got should === (expected)
}

test("flatMap is stack safe on repeated left binds when F is") {
val unit = StateT.pure[Eval, Unit, Unit](())
val count = if (Platform.isJvm) 100000 else 100
val result = (0 until count).foldLeft(unit) { (acc, _) =>
acc.flatMap(_ => unit)
}
result.run(()).value should === (((), ()))
}

test("flatMap is stack safe on repeated right binds when F is") {
val unit = StateT.pure[Eval, Unit, Unit](())
val count = if (Platform.isJvm) 100000 else 100
val result = (0 until count).foldLeft(unit) { (acc, _) =>
unit.flatMap(_ => acc)
}
result.run(()).value should === (((), ()))
}

implicit val iso = SemigroupalTests.Isomorphisms.invariant[IndexedStateT[ListWrapper, String, Int, ?]](IndexedStateT.catsDataFunctorForIndexedStateT(ListWrapper.monad))

Expand Down