Skip to content

Commit

Permalink
foldLeftM without Free. (#1117)
Browse files Browse the repository at this point in the history
  • Loading branch information
TomasMikula authored and kailuowang committed Jun 21, 2017
1 parent 4ba557c commit e3a969e
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 46 deletions.
61 changes: 53 additions & 8 deletions core/src/main/scala/cats/Foldable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -186,16 +186,31 @@ import simulacrum.typeclass
foldLeft(fa, B.empty)((b, a) => B.combine(b, f(a)))

/**
* Left associative monadic folding on `F`.
* Perform a stack-safe monadic left fold from the source context `F`
* into the target monad `G`.
*
* The default implementation of this is based on `foldLeft`, and thus will
* always fold across the entire structure. Certain structures are able to
* implement this in such a way that folds can be short-circuited (not
* traverse the entirety of the structure), depending on the `G` result
* produced at a given step.
* This method can express short-circuiting semantics. Even when
* `fa` is an infinite structure, this method can potentially
* terminate if the `foldRight` implementation for `F` and the
* `tailRecM` implementation for `G` are sufficiently lazy.
*
* Instances for concrete structures (e.g. `List`) will often
* have a more efficient implementation than the default one
* in terms of `foldRight`.
*/
def foldM[G[_], A, B](fa: F[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] = {
val src = Foldable.Source.fromFoldable(fa)(self)
G.tailRecM((z, src)) { case (b, src) => src.uncons match {
case Some((a, src)) => G.map(f(b, a))(b => Left((b, src)))
case None => G.pure(Right(b))
}}
}

/**
* Alias for [[foldM]].
*/
def foldM[G[_], A, B](fa: F[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] =
foldLeft(fa, G.pure(z))((gb, a) => G.flatMap(gb)(f(_, a)))
final def foldLeftM[G[_], A, B](fa: F[A], z: B)(f: (B, A) => G[B])(implicit G: Monad[G]): G[B] =
foldM(fa, z)(f)

/**
* Monadic folding on `F` by mapping `A` values to `G[B]`, combining the `B`
Expand Down Expand Up @@ -433,4 +448,34 @@ object Foldable {
}
M.tailRecM(z)(go)
}


/**
* Isomorphic to
*
* type Source[+A] = () => Option[(A, Source[A])]
*
* (except that recursive type aliases are not allowed).
*
* It could be made a value class after
* https://github.com/scala/bug/issues/9600 is resolved.
*/
private sealed abstract class Source[+A] {
def uncons: Option[(A, Source[A])]
}

private object Source {
val Empty: Source[Nothing] = new Source[Nothing] {
def uncons = None
}

def cons[A](a: A, src: Eval[Source[A]]): Source[A] = new Source[A] {
def uncons = Some((a, src.value))
}

def fromFoldable[F[_], A](fa: F[A])(implicit F: Foldable[F]): Source[A] =
F.foldRight[A, Source[A]](fa, Now(Empty))((a, evalSrc) =>
Later(cons(a, evalSrc))
).value
}
}
19 changes: 0 additions & 19 deletions free/src/main/scala/cats/free/Free.scala
Original file line number Diff line number Diff line change
Expand Up @@ -249,23 +249,4 @@ object Free {
override def map[A, B](fa: Free[S, A])(f: A => B): Free[S, B] = fa.map(f)
def flatMap[A, B](a: Free[S, A])(f: A => Free[S, B]): Free[S, B] = a.flatMap(f)
}

/**
* Perform a stack-safe monadic fold from the source context `F`
* into the target monad `G`.
*
* This method can express short-circuiting semantics. Even when
* `fa` is an infinite structure, this method can potentially
* terminate if the `foldRight` implementation for `F` and the
* `tailRecM` implementation for `G` are sufficiently lazy.
*/
def foldLeftM[F[_]: Foldable, G[_]: Monad, A, B](fa: F[A], z: B)(f: (B, A) => G[B]): G[B] =
unsafeFoldLeftM[F, Free[G, ?], A, B](fa, z) { (b, a) =>
Free.liftF(f(b, a))
}.runTailRec

private def unsafeFoldLeftM[F[_], G[_], A, B](fa: F[A], z: B)(f: (B, A) => G[B])(implicit F: Foldable[F], G: Monad[G]): G[B] =
F.foldRight(fa, Always((w: B) => G.pure(w))) { (a, lb) =>
Always((w: B) => G.flatMap(f(w, a))(lb.value))
}.value.apply(z)
}
19 changes: 0 additions & 19 deletions free/src/test/scala/cats/free/FreeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -90,25 +90,6 @@ class FreeTests extends CatsSuite {
assert(res == List(112358))
}

test(".foldLeftM") {
// you can see .foldLeftM traversing the entire structure by
// changing the constant argument to .take and observing the time
// this test takes.
val ns = Stream.from(1).take(1000)
val res = Free.foldLeftM[Stream, Either[Int, ?], Int, Int](ns, 0) { (sum, n) =>
if (sum >= 2) Either.left(sum) else Either.right(sum + n)
}
assert(res == Either.left(3))
}

test(".foldLeftM short-circuiting") {
val ns = Stream.continually(1)
val res = Free.foldLeftM[Stream, Either[Int, ?], Int, Int](ns, 0) { (sum, n) =>
if (sum >= 100000) Either.left(sum) else Either.right(sum + n)
}
assert(res == Either.left(100000))
}

sealed trait Test1Algebra[A]

case class Test1[A](value : Int, f: Int => A) extends Test1Algebra[A]
Expand Down
22 changes: 22 additions & 0 deletions tests/src/test/scala/cats/tests/FoldableTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,28 @@ class FoldableTestsAdditional extends CatsSuite {
// test laziness of foldM
dangerous.foldM(0)((acc, a) => if (a < 2) Some(acc + a) else None) should === (None)
}

test(".foldLeftM short-circuiting") {
val ns = Stream.continually(1)
val res = Foldable[Stream].foldLeftM[Either[Int, ?], Int, Int](ns, 0) { (sum, n) =>
if (sum >= 100000) Left(sum) else Right(sum + n)
}
assert(res == Left(100000))
}

test(".foldLeftM short-circuiting optimality") {
// test that no more elements are evaluated than absolutely necessary

def concatUntil(ss: Stream[String], stop: String): Either[String, String] =
Foldable[Stream].foldLeftM[Either[String, ?], String, String](ss, "") { (acc, s) =>
if (s == stop) Left(acc) else Right(acc + s)
}

def boom: Stream[String] = sys.error("boom")
assert(concatUntil("STOP" #:: boom, "STOP") == Left(""))
assert(concatUntil("Zero" #:: "STOP" #:: boom, "STOP") == Left("Zero"))
assert(concatUntil("Zero" #:: "One" #:: "STOP" #:: boom, "STOP") == Left("ZeroOne"))
}
}

class FoldableListCheck extends FoldableCheck[List]("list") {
Expand Down

0 comments on commit e3a969e

Please sign in to comment.