diff --git a/core/src/main/scala/cats/Eval.scala b/core/src/main/scala/cats/Eval.scala index 91c142bf0e..bf6529d3c3 100644 --- a/core/src/main/scala/cats/Eval.scala +++ b/core/src/main/scala/cats/Eval.scala @@ -72,28 +72,28 @@ sealed abstract class Eval[+A] extends Serializable { self => */ def flatMap[B](f: A => Eval[B]): Eval[B] = this match { - case c: Eval.Compute[A] => - new Eval.Compute[B] { + case c: Eval.FlatMap[A] => + new Eval.FlatMap[B] { type Start = c.Start // See https://issues.scala-lang.org/browse/SI-9931 for an explanation // of why the type annotations are necessary in these two lines on // Scala 2.12.0. val start: () => Eval[Start] = c.start val run: Start => Eval[B] = (s: c.Start) => - new Eval.Compute[B] { + new Eval.FlatMap[B] { type Start = A val start = () => c.run(s) val run = f } } - case c: Eval.Call[A] => - new Eval.Compute[B] { + case c: Eval.Defer[A] => + new Eval.FlatMap[B] { type Start = A val start = c.thunk val run = f } case _ => - new Eval.Compute[B] { + new Eval.FlatMap[B] { type Start = A val start = () => self val run = f @@ -203,7 +203,7 @@ object Eval extends EvalInstances { * which produces an Eval[A] value. Like .flatMap, it is stack-safe. */ def defer[A](a: => Eval[A]): Eval[A] = - new Eval.Call[A](a _) {} + new Eval.Defer[A](a _) {} /** * Static Eval instance for common value `Unit`. @@ -246,47 +246,52 @@ object Eval extends EvalInstances { val One: Eval[Int] = Now(1) /** - * Call is a type of Eval[A] that is used to defer computations + * Defer is a type of Eval[A] that is used to defer computations * which produce Eval[A]. * - * Users should not instantiate Call instances themselves. Instead, + * Users should not instantiate Defer instances themselves. Instead, * they will be automatically created when needed. */ - sealed abstract class Call[A](val thunk: () => Eval[A]) extends Eval[A] { - def memoize: Eval[A] = new Later(() => value) - def value: A = Call.loop(this).value - } + sealed abstract class Defer[A](val thunk: () => Eval[A]) extends Eval[A] { - object Call { + def memoize: Eval[A] = Memoize(this) + def value: A = evaluate(this) + } - /** - * Collapse the call stack for eager evaluations. - */ - @tailrec private def loop[A](fa: Eval[A]): Eval[A] = fa match { - case call: Eval.Call[A] => - loop(call.thunk()) - case compute: Eval.Compute[A] => - new Eval.Compute[A] { + /** + * Advance until we find a non-deferred Eval node. + * + * Often we may have deep chains of Defer nodes; the goal here is to + * advance through those to find the underlying "work" (in the case + * of FlatMap nodes) or "value" (in the case of Now, Later, or + * Always nodes). + */ + @tailrec private def advance[A](fa: Eval[A]): Eval[A] = + fa match { + case call: Eval.Defer[A] => + advance(call.thunk()) + case compute: Eval.FlatMap[A] => + new Eval.FlatMap[A] { type Start = compute.Start val start: () => Eval[Start] = () => compute.start() - val run: Start => Eval[A] = s => loop1(compute.run(s)) + val run: Start => Eval[A] = s => advance1(compute.run(s)) } case other => other } - /** - * Alias for loop that can be called in a non-tail position - * from an otherwise tailrec-optimized loop. - */ - private def loop1[A](fa: Eval[A]): Eval[A] = loop(fa) - } + /** + * Alias for advance that can be called in a non-tail position + * from an otherwise tailrec-optimized advance. + */ + private def advance1[A](fa: Eval[A]): Eval[A] = + advance(fa) /** - * Compute is a type of Eval[A] that is used to chain computations + * FlatMap is a type of Eval[A] that is used to chain computations * involving .map and .flatMap. Along with Eval#flatMap it * implements the trampoline that guarantees stack-safety. * - * Users should not instantiate Compute instances + * Users should not instantiate FlatMap instances * themselves. Instead, they will be automatically created when * needed. * @@ -294,35 +299,77 @@ object Eval extends EvalInstances { * trampoline are not exposed. This allows a slightly more efficient * implementation of the .value method. */ - sealed abstract class Compute[A] extends Eval[A] { + sealed abstract class FlatMap[A] extends Eval[A] { self => type Start val start: () => Eval[Start] val run: Start => Eval[A] - def memoize: Eval[A] = Later(value) - - def value: A = { - type L = Eval[Any] - type C = Any => Eval[Any] - @tailrec def loop(curr: L, fs: List[C]): Any = - curr match { - case c: Compute[_] => - c.start() match { - case cc: Compute[_] => - loop( - cc.start().asInstanceOf[L], - cc.run.asInstanceOf[C] :: c.run.asInstanceOf[C] :: fs) - case xx => - loop(c.run(xx.value), fs) - } - case x => - fs match { - case f :: fs => loop(f(x.value), fs) - case Nil => x.value - } - } - loop(this.asInstanceOf[L], Nil).asInstanceOf[A] + def memoize: Eval[A] = Memoize(this) + def value: A = evaluate(this) + } + + private case class Memoize[A](eval: Eval[A]) extends Eval[A] { + var result: Option[A] = None + def memoize: Eval[A] = this + def value: A = + result match { + case Some(a) => a + case None => + val a = evaluate(this) + result = Some(a) + a + } + } + + + private def evaluate[A](e: Eval[A]): A = { + type L = Eval[Any] + type M = Memoize[Any] + type C = Any => Eval[Any] + + def addToMemo(m: M): C = { a: Any => + m.result = Some(a) + Now(a) } + + @tailrec def loop(curr: L, fs: List[C]): Any = + curr match { + case c: FlatMap[_] => + c.start() match { + case cc: FlatMap[_] => + loop( + cc.start().asInstanceOf[L], + cc.run.asInstanceOf[C] :: c.run.asInstanceOf[C] :: fs) + case mm@Memoize(eval) => + mm.result match { + case Some(a) => + loop(Now(a), c.run.asInstanceOf[C] :: fs) + case None => + loop(eval, addToMemo(mm.asInstanceOf[M]) :: c.run.asInstanceOf[C] :: fs) + } + case xx => + loop(c.run(xx.value), fs) + } + case call: Defer[_] => + loop(advance(call), fs) + case m@Memoize(eval) => + m.result match { + case Some(a) => + fs match { + case f :: fs => loop(f(a), fs) + case Nil => a + } + case None => + loop(eval, addToMemo(m) :: fs) + } + case x => + fs match { + case f :: fs => loop(f(x.value), fs) + case Nil => x.value + } + } + + loop(e.asInstanceOf[L], Nil).asInstanceOf[A] } } diff --git a/laws/src/main/scala/cats/laws/discipline/Arbitrary.scala b/laws/src/main/scala/cats/laws/discipline/Arbitrary.scala index 679d2028a5..ba14df7cd5 100644 --- a/laws/src/main/scala/cats/laws/discipline/Arbitrary.scala +++ b/laws/src/main/scala/cats/laws/discipline/Arbitrary.scala @@ -90,9 +90,9 @@ object arbitrary extends ArbitraryInstances0 { implicit def catsLawsArbitraryForEval[A: Arbitrary]: Arbitrary[Eval[A]] = Arbitrary(Gen.oneOf( - getArbitrary[A].map(Eval.now(_)), - getArbitrary[A].map(Eval.later(_)), - getArbitrary[A].map(Eval.always(_)))) + getArbitrary[A].map(a => Eval.now(a)), + getArbitrary[() => A].map(f => Eval.later(f())), + getArbitrary[() => A].map(f => Eval.always(f())))) implicit def catsLawsCogenForEval[A: Cogen]: Cogen[Eval[A]] = Cogen[A].contramap(_.value) diff --git a/tests/src/test/scala/cats/tests/EvalTests.scala b/tests/src/test/scala/cats/tests/EvalTests.scala index d67d6500a6..a49ba4a3af 100644 --- a/tests/src/test/scala/cats/tests/EvalTests.scala +++ b/tests/src/test/scala/cats/tests/EvalTests.scala @@ -1,11 +1,14 @@ package cats package tests -import scala.math.min import cats.laws.ComonadLaws import cats.laws.discipline.{BimonadTests, CartesianTests, ReducibleTests, SerializableTests} import cats.laws.discipline.arbitrary._ import cats.kernel.laws.{GroupLaws, OrderLaws} +import org.scalacheck.{Arbitrary, Cogen, Gen} +import org.scalacheck.Arbitrary.arbitrary +import scala.annotation.tailrec +import scala.math.min class EvalTests extends CatsSuite { implicit val eqThrow: Eq[Throwable] = Eq.allEqual @@ -140,4 +143,108 @@ class EvalTests extends CatsSuite { isEq.lhs should === (isEq.rhs) } } + + // the following machinery is all to faciliate testing deeply-nested + // eval values for stack safety. the idea is that we want to + // randomly generate deep chains of eval operations. + // + // there are three ways to construct Eval[A] values from expressions + // returning A (and which are generated by Arbitrary[Eval[A]]): + // + // - Eval.now(...) + // - Eval.later(...) + // - Eval.always(...) + // + // there are four operations that transform expressions returning + // Eval[A] into a new Eval[A] value: + // + // - (...).map(f) + // - (...).flatMap(g) + // - (...).memoize + // - Eval.defer(...) + // + // the O[A] ast represents these four operations. we generate a very + // long Vector[O[A]] and a starting () => Eval[A] expression (which + // we call a "leaf") and then compose these to produce one + // (deeply-nested) Eval[A] value, which we wrap in DeepEval(_). + + case class DeepEval[A](eval: Eval[A]) + + object DeepEval { + + sealed abstract class O[A] + + case class OMap[A](f: A => A) extends O[A] + case class OFlatMap[A](f: A => Eval[A]) extends O[A] + case class OMemoize[A]() extends O[A] + case class ODefer[A]() extends O[A] + + implicit def arbitraryO[A: Arbitrary: Cogen]: Arbitrary[O[A]] = + Arbitrary(Gen.oneOf( + arbitrary[A => A].map(OMap(_)), + arbitrary[A => Eval[A]].map(OFlatMap(_)), + Gen.const(OMemoize[A]), + Gen.const(ODefer[A]))) + + def build[A](leaf: () => Eval[A], os: Vector[O[A]]): DeepEval[A] = { + + def restart(i: Int, leaf: () => Eval[A], cbs: List[Eval[A] => Eval[A]]): Eval[A] = + step(i, leaf, cbs) + + @tailrec def step(i: Int, leaf: () => Eval[A], cbs: List[Eval[A] => Eval[A]]): Eval[A] = + if (i >= os.length) cbs.foldLeft(leaf())((e, f) => f(e)) + else os(i) match { + case ODefer() => Eval.defer(restart(i + 1, leaf, cbs)) + case OMemoize() => step(i + 1, leaf, ((e: Eval[A]) => e.memoize) :: cbs) + case OMap(f) => step(i + 1, leaf, ((e: Eval[A]) => e.map(f)) :: cbs) + case OFlatMap(f) => step(i + 1, leaf, ((e: Eval[A]) => e.flatMap(f)) :: cbs) + } + + DeepEval(step(0, leaf, Nil)) + } + + // we keep this low in master to keep travis happy. + // for an actual stress test increase to 200K or so. + val MaxDepth = 100 + + implicit def arbitraryDeepEval[A: Arbitrary: Cogen]: Arbitrary[DeepEval[A]] = { + val gen: Gen[O[A]] = arbitrary[O[A]] + Arbitrary(for { + leaf <- arbitrary[() => Eval[A]] + xs <- Gen.containerOfN[Vector, O[A]](MaxDepth, gen) + } yield DeepEval.build(leaf, xs)) + } + } + + // all that work for this one little test. + + test("stack safety stress test") { + forAll { (d: DeepEval[Int]) => + try { + d.eval.value + succeed + } catch { case (e: StackOverflowError) => + fail(s"stack overflowed with eval-depth ${DeepEval.MaxDepth}") + } + } + } + + test("memoize handles branched evaluation correctly") { + forAll { (e: Eval[Int], fn: Int => Eval[Int]) => + var n0 = 0 + val a0 = e.flatMap { i => n0 += 1; fn(i); }.memoize + assert(a0.flatMap(i1 => a0.map(i1 == _)).value == true) + assert(n0 == 1) + + var n1 = 0 + val a1 = Eval.defer { n1 += 1; fn(0) }.memoize + assert(a1.flatMap(i1 => a1.map(i1 == _)).value == true) + assert(n1 == 1) + + var n2 = 0 + val a2 = Eval.defer { n2 += 1; fn(0) }.memoize + assert(Eval.defer(a2).value == Eval.defer(a2).value) + assert(n2 == 1) + } + } }