From dab28d71a7a106767fe557f3647ab03d4721466f Mon Sep 17 00:00:00 2001 From: Erik Osheim Date: Wed, 27 Sep 2017 20:39:24 -0400 Subject: [PATCH] Add a stack safety stress test for Eval. (#1888) * Add a stack safety stress test for Eval. The basic idea is to create very deep chains of operations to try to expose any stack problems. We know that for things like map/flatMap Eval seems to work, so let's just randomly construct deeply-nested Eval expressions to see what happens. This test exposed a weakness with .memoize which is fixed. As part of this commit, I noticed that our Arbitrary[Eval[A]] instances were somewhat weak, so I upgraded them. * Reduce depth to get Travis passing. * Better memoization support for flatMap. This fix was proposed by @johnynek and fixes a bug I introduced with how memoization was handled during flatMap evaluation. Our old implementation did memoize intermediate values correctly (which meant that three different evals mapped from the same source shared a memoized value). However, it was not stack-safe. My fix introduced stack-safety but broke this intermediate memoization. The new approach uses a new node (Eval.Memoize) which can be mutably-updated (much like Later). It's confirmed to be stack-safe and to handle intermediate values correctly. While it was unlikely that anyone was doing enough intermediate memoization to cause actual stack overflows, it's nice to know that this is now impossible. * Rename Compute to FlatMap and Call to Defer. --- core/src/main/scala/cats/Eval.scala | 157 ++++++++++++------ .../cats/laws/discipline/Arbitrary.scala | 6 +- .../src/test/scala/cats/tests/EvalTests.scala | 109 +++++++++++- 3 files changed, 213 insertions(+), 59 deletions(-) diff --git a/core/src/main/scala/cats/Eval.scala b/core/src/main/scala/cats/Eval.scala index 91c142bf0e..bf6529d3c3 100644 --- a/core/src/main/scala/cats/Eval.scala +++ b/core/src/main/scala/cats/Eval.scala @@ -72,28 +72,28 @@ sealed abstract class Eval[+A] extends Serializable { self => */ def flatMap[B](f: A => Eval[B]): Eval[B] = this match { - case c: Eval.Compute[A] => - new Eval.Compute[B] { + case c: Eval.FlatMap[A] => + new Eval.FlatMap[B] { type Start = c.Start // See https://issues.scala-lang.org/browse/SI-9931 for an explanation // of why the type annotations are necessary in these two lines on // Scala 2.12.0. val start: () => Eval[Start] = c.start val run: Start => Eval[B] = (s: c.Start) => - new Eval.Compute[B] { + new Eval.FlatMap[B] { type Start = A val start = () => c.run(s) val run = f } } - case c: Eval.Call[A] => - new Eval.Compute[B] { + case c: Eval.Defer[A] => + new Eval.FlatMap[B] { type Start = A val start = c.thunk val run = f } case _ => - new Eval.Compute[B] { + new Eval.FlatMap[B] { type Start = A val start = () => self val run = f @@ -203,7 +203,7 @@ object Eval extends EvalInstances { * which produces an Eval[A] value. Like .flatMap, it is stack-safe. */ def defer[A](a: => Eval[A]): Eval[A] = - new Eval.Call[A](a _) {} + new Eval.Defer[A](a _) {} /** * Static Eval instance for common value `Unit`. @@ -246,47 +246,52 @@ object Eval extends EvalInstances { val One: Eval[Int] = Now(1) /** - * Call is a type of Eval[A] that is used to defer computations + * Defer is a type of Eval[A] that is used to defer computations * which produce Eval[A]. * - * Users should not instantiate Call instances themselves. Instead, + * Users should not instantiate Defer instances themselves. Instead, * they will be automatically created when needed. */ - sealed abstract class Call[A](val thunk: () => Eval[A]) extends Eval[A] { - def memoize: Eval[A] = new Later(() => value) - def value: A = Call.loop(this).value - } + sealed abstract class Defer[A](val thunk: () => Eval[A]) extends Eval[A] { - object Call { + def memoize: Eval[A] = Memoize(this) + def value: A = evaluate(this) + } - /** - * Collapse the call stack for eager evaluations. - */ - @tailrec private def loop[A](fa: Eval[A]): Eval[A] = fa match { - case call: Eval.Call[A] => - loop(call.thunk()) - case compute: Eval.Compute[A] => - new Eval.Compute[A] { + /** + * Advance until we find a non-deferred Eval node. + * + * Often we may have deep chains of Defer nodes; the goal here is to + * advance through those to find the underlying "work" (in the case + * of FlatMap nodes) or "value" (in the case of Now, Later, or + * Always nodes). + */ + @tailrec private def advance[A](fa: Eval[A]): Eval[A] = + fa match { + case call: Eval.Defer[A] => + advance(call.thunk()) + case compute: Eval.FlatMap[A] => + new Eval.FlatMap[A] { type Start = compute.Start val start: () => Eval[Start] = () => compute.start() - val run: Start => Eval[A] = s => loop1(compute.run(s)) + val run: Start => Eval[A] = s => advance1(compute.run(s)) } case other => other } - /** - * Alias for loop that can be called in a non-tail position - * from an otherwise tailrec-optimized loop. - */ - private def loop1[A](fa: Eval[A]): Eval[A] = loop(fa) - } + /** + * Alias for advance that can be called in a non-tail position + * from an otherwise tailrec-optimized advance. + */ + private def advance1[A](fa: Eval[A]): Eval[A] = + advance(fa) /** - * Compute is a type of Eval[A] that is used to chain computations + * FlatMap is a type of Eval[A] that is used to chain computations * involving .map and .flatMap. Along with Eval#flatMap it * implements the trampoline that guarantees stack-safety. * - * Users should not instantiate Compute instances + * Users should not instantiate FlatMap instances * themselves. Instead, they will be automatically created when * needed. * @@ -294,35 +299,77 @@ object Eval extends EvalInstances { * trampoline are not exposed. This allows a slightly more efficient * implementation of the .value method. */ - sealed abstract class Compute[A] extends Eval[A] { + sealed abstract class FlatMap[A] extends Eval[A] { self => type Start val start: () => Eval[Start] val run: Start => Eval[A] - def memoize: Eval[A] = Later(value) - - def value: A = { - type L = Eval[Any] - type C = Any => Eval[Any] - @tailrec def loop(curr: L, fs: List[C]): Any = - curr match { - case c: Compute[_] => - c.start() match { - case cc: Compute[_] => - loop( - cc.start().asInstanceOf[L], - cc.run.asInstanceOf[C] :: c.run.asInstanceOf[C] :: fs) - case xx => - loop(c.run(xx.value), fs) - } - case x => - fs match { - case f :: fs => loop(f(x.value), fs) - case Nil => x.value - } - } - loop(this.asInstanceOf[L], Nil).asInstanceOf[A] + def memoize: Eval[A] = Memoize(this) + def value: A = evaluate(this) + } + + private case class Memoize[A](eval: Eval[A]) extends Eval[A] { + var result: Option[A] = None + def memoize: Eval[A] = this + def value: A = + result match { + case Some(a) => a + case None => + val a = evaluate(this) + result = Some(a) + a + } + } + + + private def evaluate[A](e: Eval[A]): A = { + type L = Eval[Any] + type M = Memoize[Any] + type C = Any => Eval[Any] + + def addToMemo(m: M): C = { a: Any => + m.result = Some(a) + Now(a) } + + @tailrec def loop(curr: L, fs: List[C]): Any = + curr match { + case c: FlatMap[_] => + c.start() match { + case cc: FlatMap[_] => + loop( + cc.start().asInstanceOf[L], + cc.run.asInstanceOf[C] :: c.run.asInstanceOf[C] :: fs) + case mm@Memoize(eval) => + mm.result match { + case Some(a) => + loop(Now(a), c.run.asInstanceOf[C] :: fs) + case None => + loop(eval, addToMemo(mm.asInstanceOf[M]) :: c.run.asInstanceOf[C] :: fs) + } + case xx => + loop(c.run(xx.value), fs) + } + case call: Defer[_] => + loop(advance(call), fs) + case m@Memoize(eval) => + m.result match { + case Some(a) => + fs match { + case f :: fs => loop(f(a), fs) + case Nil => a + } + case None => + loop(eval, addToMemo(m) :: fs) + } + case x => + fs match { + case f :: fs => loop(f(x.value), fs) + case Nil => x.value + } + } + + loop(e.asInstanceOf[L], Nil).asInstanceOf[A] } } diff --git a/laws/src/main/scala/cats/laws/discipline/Arbitrary.scala b/laws/src/main/scala/cats/laws/discipline/Arbitrary.scala index 679d2028a5..ba14df7cd5 100644 --- a/laws/src/main/scala/cats/laws/discipline/Arbitrary.scala +++ b/laws/src/main/scala/cats/laws/discipline/Arbitrary.scala @@ -90,9 +90,9 @@ object arbitrary extends ArbitraryInstances0 { implicit def catsLawsArbitraryForEval[A: Arbitrary]: Arbitrary[Eval[A]] = Arbitrary(Gen.oneOf( - getArbitrary[A].map(Eval.now(_)), - getArbitrary[A].map(Eval.later(_)), - getArbitrary[A].map(Eval.always(_)))) + getArbitrary[A].map(a => Eval.now(a)), + getArbitrary[() => A].map(f => Eval.later(f())), + getArbitrary[() => A].map(f => Eval.always(f())))) implicit def catsLawsCogenForEval[A: Cogen]: Cogen[Eval[A]] = Cogen[A].contramap(_.value) diff --git a/tests/src/test/scala/cats/tests/EvalTests.scala b/tests/src/test/scala/cats/tests/EvalTests.scala index d67d6500a6..a49ba4a3af 100644 --- a/tests/src/test/scala/cats/tests/EvalTests.scala +++ b/tests/src/test/scala/cats/tests/EvalTests.scala @@ -1,11 +1,14 @@ package cats package tests -import scala.math.min import cats.laws.ComonadLaws import cats.laws.discipline.{BimonadTests, CartesianTests, ReducibleTests, SerializableTests} import cats.laws.discipline.arbitrary._ import cats.kernel.laws.{GroupLaws, OrderLaws} +import org.scalacheck.{Arbitrary, Cogen, Gen} +import org.scalacheck.Arbitrary.arbitrary +import scala.annotation.tailrec +import scala.math.min class EvalTests extends CatsSuite { implicit val eqThrow: Eq[Throwable] = Eq.allEqual @@ -140,4 +143,108 @@ class EvalTests extends CatsSuite { isEq.lhs should === (isEq.rhs) } } + + // the following machinery is all to faciliate testing deeply-nested + // eval values for stack safety. the idea is that we want to + // randomly generate deep chains of eval operations. + // + // there are three ways to construct Eval[A] values from expressions + // returning A (and which are generated by Arbitrary[Eval[A]]): + // + // - Eval.now(...) + // - Eval.later(...) + // - Eval.always(...) + // + // there are four operations that transform expressions returning + // Eval[A] into a new Eval[A] value: + // + // - (...).map(f) + // - (...).flatMap(g) + // - (...).memoize + // - Eval.defer(...) + // + // the O[A] ast represents these four operations. we generate a very + // long Vector[O[A]] and a starting () => Eval[A] expression (which + // we call a "leaf") and then compose these to produce one + // (deeply-nested) Eval[A] value, which we wrap in DeepEval(_). + + case class DeepEval[A](eval: Eval[A]) + + object DeepEval { + + sealed abstract class O[A] + + case class OMap[A](f: A => A) extends O[A] + case class OFlatMap[A](f: A => Eval[A]) extends O[A] + case class OMemoize[A]() extends O[A] + case class ODefer[A]() extends O[A] + + implicit def arbitraryO[A: Arbitrary: Cogen]: Arbitrary[O[A]] = + Arbitrary(Gen.oneOf( + arbitrary[A => A].map(OMap(_)), + arbitrary[A => Eval[A]].map(OFlatMap(_)), + Gen.const(OMemoize[A]), + Gen.const(ODefer[A]))) + + def build[A](leaf: () => Eval[A], os: Vector[O[A]]): DeepEval[A] = { + + def restart(i: Int, leaf: () => Eval[A], cbs: List[Eval[A] => Eval[A]]): Eval[A] = + step(i, leaf, cbs) + + @tailrec def step(i: Int, leaf: () => Eval[A], cbs: List[Eval[A] => Eval[A]]): Eval[A] = + if (i >= os.length) cbs.foldLeft(leaf())((e, f) => f(e)) + else os(i) match { + case ODefer() => Eval.defer(restart(i + 1, leaf, cbs)) + case OMemoize() => step(i + 1, leaf, ((e: Eval[A]) => e.memoize) :: cbs) + case OMap(f) => step(i + 1, leaf, ((e: Eval[A]) => e.map(f)) :: cbs) + case OFlatMap(f) => step(i + 1, leaf, ((e: Eval[A]) => e.flatMap(f)) :: cbs) + } + + DeepEval(step(0, leaf, Nil)) + } + + // we keep this low in master to keep travis happy. + // for an actual stress test increase to 200K or so. + val MaxDepth = 100 + + implicit def arbitraryDeepEval[A: Arbitrary: Cogen]: Arbitrary[DeepEval[A]] = { + val gen: Gen[O[A]] = arbitrary[O[A]] + Arbitrary(for { + leaf <- arbitrary[() => Eval[A]] + xs <- Gen.containerOfN[Vector, O[A]](MaxDepth, gen) + } yield DeepEval.build(leaf, xs)) + } + } + + // all that work for this one little test. + + test("stack safety stress test") { + forAll { (d: DeepEval[Int]) => + try { + d.eval.value + succeed + } catch { case (e: StackOverflowError) => + fail(s"stack overflowed with eval-depth ${DeepEval.MaxDepth}") + } + } + } + + test("memoize handles branched evaluation correctly") { + forAll { (e: Eval[Int], fn: Int => Eval[Int]) => + var n0 = 0 + val a0 = e.flatMap { i => n0 += 1; fn(i); }.memoize + assert(a0.flatMap(i1 => a0.map(i1 == _)).value == true) + assert(n0 == 1) + + var n1 = 0 + val a1 = Eval.defer { n1 += 1; fn(0) }.memoize + assert(a1.flatMap(i1 => a1.map(i1 == _)).value == true) + assert(n1 == 1) + + var n2 = 0 + val a2 = Eval.defer { n2 += 1; fn(0) }.memoize + assert(Eval.defer(a2).value == Eval.defer(a2).value) + assert(n2 == 1) + } + } }