diff --git a/core/src/main/scala/cats/FlatMapRec.scala b/core/src/main/scala/cats/FlatMapRec.scala new file mode 100644 index 0000000000..5658b3fe00 --- /dev/null +++ b/core/src/main/scala/cats/FlatMapRec.scala @@ -0,0 +1,24 @@ +package cats + +import simulacrum.typeclass + +import cats.data.Xor + +/** + * Version of [[cats.FlatMap]] capable of stack-safe recursive `flatMap`s. + * + * Based on Phil Freeman's + * [[http://functorial.com/stack-safety-for-free/index.pdf Stack Safety for Free]]. + */ +@typeclass trait FlatMapRec[F[_]] extends FlatMap[F] { + + /** + * Keeps calling `f` until a `[[cats.data.Xor.Right Right]][B]` is returned. + * + * Implementations of this method must use constant stack space. + * + * `f` must use constant stack space. (It is OK to use a constant number of + * `map`s and `flatMap`s inside `f`.) + */ + def tailRecM[A, B](a: A)(f: A => F[A Xor B]): F[B] +} diff --git a/core/src/main/scala/cats/MonadRec.scala b/core/src/main/scala/cats/MonadRec.scala new file mode 100644 index 0000000000..e61588c8da --- /dev/null +++ b/core/src/main/scala/cats/MonadRec.scala @@ -0,0 +1,5 @@ +package cats + +import simulacrum.typeclass + +@typeclass trait MonadRec[F[_]] extends Monad[F] with FlatMapRec[F] diff --git a/core/src/main/scala/cats/data/OptionT.scala b/core/src/main/scala/cats/data/OptionT.scala index 91575a4510..2754604656 100644 --- a/core/src/main/scala/cats/data/OptionT.scala +++ b/core/src/main/scala/cats/data/OptionT.scala @@ -132,7 +132,7 @@ object OptionT extends OptionTInstances { def liftF[F[_], A](fa: F[A])(implicit F: Functor[F]): OptionT[F, A] = OptionT(F.map(fa)(Some(_))) } -private[data] sealed trait OptionTInstances1 { +private[data] sealed trait OptionTInstances2 { implicit def catsDataFunctorForOptionT[F[_]:Functor]: Functor[OptionT[F, ?]] = new Functor[OptionT[F, ?]] { override def map[A, B](fa: OptionT[F, A])(f: A => B): OptionT[F, B] = @@ -148,18 +148,15 @@ private[data] sealed trait OptionTInstances1 { } } -private[data] sealed trait OptionTInstances extends OptionTInstances1 { - - implicit def catsDataMonadForOptionT[F[_]](implicit F: Monad[F]): Monad[OptionT[F, ?]] = - new Monad[OptionT[F, ?]] { - def pure[A](a: A): OptionT[F, A] = OptionT.pure(a) +private[data] sealed trait OptionTInstances1 extends OptionTInstances2 { - def flatMap[A, B](fa: OptionT[F, A])(f: A => OptionT[F, B]): OptionT[F, B] = - fa.flatMap(f) + implicit def catsDataMonadForOptionT[F[_]](implicit F0: Monad[F]): Monad[OptionT[F, ?]] = + new OptionTMonad[F] { implicit val F = F0 } +} - override def map[A, B](fa: OptionT[F, A])(f: A => B): OptionT[F, B] = - fa.map(f) - } +private[data] sealed trait OptionTInstances extends OptionTInstances1 { + implicit def catsDataMonadRecForOptionT[F[_]](implicit F0: MonadRec[F]): MonadRec[OptionT[F, ?]] = + new OptionTMonadRec[F] { implicit val F = F0 } implicit def catsDataEqForOptionT[F[_], A](implicit FA: Eq[F[Option[A]]]): Eq[OptionT[F, A]] = FA.on(_.value) @@ -167,3 +164,26 @@ private[data] sealed trait OptionTInstances extends OptionTInstances1 { implicit def catsDataShowForOptionT[F[_], A](implicit F: Show[F[Option[A]]]): Show[OptionT[F, A]] = functor.Contravariant[Show].contramap(F)(_.value) } + +private[data] trait OptionTMonad[F[_]] extends Monad[OptionT[F, ?]] { + implicit val F: Monad[F] + + def pure[A](a: A): OptionT[F, A] = OptionT.pure(a) + + def flatMap[A, B](fa: OptionT[F, A])(f: A => OptionT[F, B]): OptionT[F, B] = + fa.flatMap(f) + + override def map[A, B](fa: OptionT[F, A])(f: A => B): OptionT[F, B] = + fa.map(f) +} + +private[data] trait OptionTMonadRec[F[_]] extends MonadRec[OptionT[F, ?]] with OptionTMonad[F] { + implicit val F: MonadRec[F] + + def tailRecM[A, B](a: A)(f: A => OptionT[F, A Xor B]): OptionT[F, B] = + OptionT(F.tailRecM(a)(a0 => F.map(f(a0).value){ + case None => Xor.Right(None) + case Some(Xor.Left(a1)) => Xor.Left(a1) + case Some(Xor.Right(b)) => Xor.Right(Some(b)) + })) +} diff --git a/core/src/main/scala/cats/data/Xor.scala b/core/src/main/scala/cats/data/Xor.scala index d8a6a4a2b3..08136b4eb0 100644 --- a/core/src/main/scala/cats/data/Xor.scala +++ b/core/src/main/scala/cats/data/Xor.scala @@ -1,6 +1,7 @@ package cats package data +import scala.annotation.tailrec import scala.reflect.ClassTag import scala.util.{Failure, Success, Try} @@ -233,13 +234,19 @@ private[data] sealed abstract class XorInstances extends XorInstances1 { } } - implicit def catsDataInstancesForXor[A]: Traverse[A Xor ?] with MonadError[Xor[A, ?], A] = - new Traverse[A Xor ?] with MonadError[Xor[A, ?], A] { + implicit def catsDataInstancesForXor[A]: Traverse[A Xor ?] with MonadRec[A Xor ?] with MonadError[Xor[A, ?], A] = + new Traverse[A Xor ?] with MonadRec[A Xor ?] with MonadError[Xor[A, ?], A] { def traverse[F[_]: Applicative, B, C](fa: A Xor B)(f: B => F[C]): F[A Xor C] = fa.traverse(f) def foldLeft[B, C](fa: A Xor B, c: C)(f: (C, B) => C): C = fa.foldLeft(c)(f) def foldRight[B, C](fa: A Xor B, lc: Eval[C])(f: (B, Eval[C]) => Eval[C]): Eval[C] = fa.foldRight(lc)(f) def flatMap[B, C](fa: A Xor B)(f: B => A Xor C): A Xor C = fa.flatMap(f) def pure[B](b: B): A Xor B = Xor.right(b) + @tailrec def tailRecM[B, C](b: B)(f: B => A Xor (B Xor C)): A Xor C = + f(b) match { + case Xor.Left(a) => Xor.Left(a) + case Xor.Right(Xor.Left(b1)) => tailRecM(b1)(f) + case Xor.Right(Xor.Right(c)) => Xor.Right(c) + } def handleErrorWith[B](fea: Xor[A, B])(f: A => Xor[A, B]): Xor[A, B] = fea match { case Xor.Left(e) => f(e) diff --git a/core/src/main/scala/cats/data/XorT.scala b/core/src/main/scala/cats/data/XorT.scala index 555f9c1b39..91feb688c4 100644 --- a/core/src/main/scala/cats/data/XorT.scala +++ b/core/src/main/scala/cats/data/XorT.scala @@ -279,6 +279,11 @@ private[data] abstract class XorTInstances1 extends XorTInstances2 { } private[data] abstract class XorTInstances2 extends XorTInstances3 { + implicit def catsDataMonadRecForXorT[F[_], L](implicit F0: MonadRec[F]): MonadRec[XorT[F, L, ?]] = + new XorTMonadRec[F, L] { implicit val F = F0 } +} + +private[data] abstract class XorTInstances3 extends XorTInstances4 { implicit def catsDataMonadErrorForXorT[F[_], L](implicit F: Monad[F]): MonadError[XorT[F, L, ?], L] = { implicit val F0 = F new XorTMonadError[F, L] { implicit val F = F0 } @@ -299,7 +304,7 @@ private[data] abstract class XorTInstances2 extends XorTInstances3 { } } -private[data] abstract class XorTInstances3 { +private[data] abstract class XorTInstances4 { implicit def catsDataFunctorForXorT[F[_], L](implicit F: Functor[F]): Functor[XorT[F, L, ?]] = { implicit val F0 = F new XorTFunctor[F, L] { implicit val F = F0 } @@ -311,10 +316,13 @@ private[data] trait XorTFunctor[F[_], L] extends Functor[XorT[F, L, ?]] { override def map[A, B](fa: XorT[F, L, A])(f: A => B): XorT[F, L, B] = fa map f } -private[data] trait XorTMonadError[F[_], L] extends MonadError[XorT[F, L, ?], L] with XorTFunctor[F, L] { +private[data] trait XorTMonad[F[_], L] extends Monad[XorT[F, L, ?]] with XorTFunctor[F, L] { implicit val F: Monad[F] def pure[A](a: A): XorT[F, L, A] = XorT.pure[F, L, A](a) def flatMap[A, B](fa: XorT[F, L, A])(f: A => XorT[F, L, B]): XorT[F, L, B] = fa flatMap f +} + +private[data] trait XorTMonadError[F[_], L] extends MonadError[XorT[F, L, ?], L] with XorTMonad[F, L] { def handleErrorWith[A](fea: XorT[F, L, A])(f: L => XorT[F, L, A]): XorT[F, L, A] = XorT(F.flatMap(fea.value) { case Xor.Left(e) => f(e).value @@ -333,6 +341,16 @@ private[data] trait XorTMonadError[F[_], L] extends MonadError[XorT[F, L, ?], L] fla.recoverWith(pf) } +private[data] trait XorTMonadRec[F[_], L] extends MonadRec[XorT[F, L, ?]] with XorTMonad[F, L] { + implicit val F: MonadRec[F] + def tailRecM[A, B](a: A)(f: A => XorT[F, L, A Xor B]): XorT[F, L, B] = + XorT(F.tailRecM(a)(a0 => F.map(f(a0).value){ + case Xor.Left(l) => Xor.Right(Xor.Left(l)) + case Xor.Right(Xor.Left(a1)) => Xor.Left(a1) + case Xor.Right(Xor.Right(b)) => Xor.Right(Xor.Right(b)) + })) +} + private[data] trait XorTMonadFilter[F[_], L] extends MonadFilter[XorT[F, L, ?]] with XorTMonadError[F, L] { implicit val F: Monad[F] implicit val L: Monoid[L] diff --git a/core/src/main/scala/cats/package.scala b/core/src/main/scala/cats/package.scala index d1e7ad5730..cc90780828 100644 --- a/core/src/main/scala/cats/package.scala +++ b/core/src/main/scala/cats/package.scala @@ -1,3 +1,6 @@ +import scala.annotation.tailrec +import cats.data.Xor + /** * Symbolic aliases for various types are defined here. */ @@ -26,12 +29,16 @@ package object cats { * encodes pure unary function application. */ type Id[A] = A - implicit val idInstances: Bimonad[Id] with Traverse[Id] = - new Bimonad[Id] with Traverse[Id] { + implicit val idInstances: Bimonad[Id] with MonadRec[Id] with Traverse[Id] = + new Bimonad[Id] with MonadRec[Id] with Traverse[Id] { def pure[A](a: A): A = a def extract[A](a: A): A = a def flatMap[A, B](a: A)(f: A => B): B = f(a) def coflatMap[A, B](a: A)(f: A => B): B = f(a) + @tailrec def tailRecM[A, B](a: A)(f: A => A Xor B): B = f(a) match { + case Xor.Left(a1) => tailRecM(a1)(f) + case Xor.Right(b) => b + } override def map[A, B](fa: A)(f: A => B): B = f(fa) override def ap[A, B](ff: A => B)(fa: A): B = ff(fa) override def flatten[A](ffa: A): A = ffa diff --git a/core/src/main/scala/cats/std/either.scala b/core/src/main/scala/cats/std/either.scala index 43962597bf..019a4dc29f 100644 --- a/core/src/main/scala/cats/std/either.scala +++ b/core/src/main/scala/cats/std/either.scala @@ -1,6 +1,9 @@ package cats package std +import scala.annotation.tailrec +import cats.data.Xor + trait EitherInstances extends EitherInstances1 { implicit val catsStdBitraverseForEither: Bitraverse[Either] = new Bitraverse[Either] { @@ -23,8 +26,8 @@ trait EitherInstances extends EitherInstances1 { } } - implicit def catsStdInstancesForEither[A]: Monad[Either[A, ?]] with Traverse[Either[A, ?]] = - new Monad[Either[A, ?]] with Traverse[Either[A, ?]] { + implicit def catsStdInstancesForEither[A]: MonadRec[Either[A, ?]] with Traverse[Either[A, ?]] = + new MonadRec[Either[A, ?]] with Traverse[Either[A, ?]] { def pure[B](b: B): Either[A, B] = Right(b) def flatMap[B, C](fa: Either[A, B])(f: B => Either[A, C]): Either[A, C] = @@ -33,6 +36,14 @@ trait EitherInstances extends EitherInstances1 { override def map[B, C](fa: Either[A, B])(f: B => C): Either[A, C] = fa.right.map(f) + @tailrec + def tailRecM[B, C](b: B)(f: B => Either[A, B Xor C]): Either[A, C] = + f(b) match { + case Left(a) => Left(a) + case Right(Xor.Left(b1)) => tailRecM(b1)(f) + case Right(Xor.Right(c)) => Right(c) + } + override def map2Eval[B, C, Z](fb: Either[A, B], fc: Eval[Either[A, C]])(f: (B, C) => Z): Eval[Either[A, Z]] = fb match { // This should be safe, but we are forced to use `asInstanceOf`, diff --git a/core/src/main/scala/cats/std/list.scala b/core/src/main/scala/cats/std/list.scala index 102ad05e71..6154c7d494 100644 --- a/core/src/main/scala/cats/std/list.scala +++ b/core/src/main/scala/cats/std/list.scala @@ -6,10 +6,12 @@ import cats.syntax.show._ import scala.annotation.tailrec import scala.collection.mutable.ListBuffer +import cats.data.Xor + trait ListInstances extends cats.kernel.std.ListInstances { - implicit val catsStdInstancesForList: Traverse[List] with MonadCombine[List] with CoflatMap[List] = - new Traverse[List] with MonadCombine[List] with CoflatMap[List] { + implicit val catsStdInstancesForList: Traverse[List] with MonadCombine[List] with MonadRec[List] with CoflatMap[List] = + new Traverse[List] with MonadCombine[List] with MonadRec[List] with CoflatMap[List] { def empty[A]: List[A] = Nil @@ -26,6 +28,20 @@ trait ListInstances extends cats.kernel.std.ListInstances { override def map2[A, B, Z](fa: List[A], fb: List[B])(f: (A, B) => Z): List[Z] = fa.flatMap(a => fb.map(b => f(a, b))) + def tailRecM[A, B](a: A)(f: A => List[A Xor B]): List[B] = { + val buf = List.newBuilder[B] + @tailrec def go(lists: List[List[A Xor B]]): Unit = lists match { + case (ab :: abs) :: tail => ab match { + case Xor.Right(b) => buf += b; go(abs :: tail) + case Xor.Left(a) => go(f(a) :: abs :: tail) + } + case Nil :: tail => go(tail) + case Nil => () + } + go(f(a) :: Nil) + buf.result + } + def coflatMap[A, B](fa: List[A])(f: List[A] => B): List[B] = { @tailrec def loop(buf: ListBuffer[B], as: List[A]): List[B] = as match { diff --git a/core/src/main/scala/cats/std/option.scala b/core/src/main/scala/cats/std/option.scala index af943b923b..41204c5326 100644 --- a/core/src/main/scala/cats/std/option.scala +++ b/core/src/main/scala/cats/std/option.scala @@ -1,10 +1,13 @@ package cats package std +import scala.annotation.tailrec +import cats.data.Xor + trait OptionInstances extends cats.kernel.std.OptionInstances { - implicit val catsStdInstancesForOption: Traverse[Option] with MonadError[Option, Unit] with MonadCombine[Option] with CoflatMap[Option] with Alternative[Option] = - new Traverse[Option] with MonadError[Option, Unit] with MonadCombine[Option] with CoflatMap[Option] with Alternative[Option] { + implicit val catsStdInstancesForOption: Traverse[Option] with MonadError[Option, Unit] with MonadCombine[Option] with MonadRec[Option] with CoflatMap[Option] with Alternative[Option] = + new Traverse[Option] with MonadError[Option, Unit] with MonadCombine[Option] with MonadRec[Option] with CoflatMap[Option] with Alternative[Option] { def empty[A]: Option[A] = None @@ -18,6 +21,14 @@ trait OptionInstances extends cats.kernel.std.OptionInstances { def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] = fa.flatMap(f) + @tailrec + def tailRecM[A, B](a: A)(f: A => Option[A Xor B]): Option[B] = + f(a) match { + case None => None + case Some(Xor.Left(a1)) => tailRecM(a1)(f) + case Some(Xor.Right(b)) => Some(b) + } + override def map2[A, B, Z](fa: Option[A], fb: Option[B])(f: (A, B) => Z): Option[Z] = fa.flatMap(a => fb.map(b => f(a, b))) diff --git a/free/src/main/scala/cats/free/Free.scala b/free/src/main/scala/cats/free/Free.scala index d6a14d3a70..c4e2233b7b 100644 --- a/free/src/main/scala/cats/free/Free.scala +++ b/free/src/main/scala/cats/free/Free.scala @@ -40,11 +40,16 @@ object Free { /** * `Free[S, ?]` has a monad for any type constructor `S[_]`. */ - implicit def freeMonad[S[_]]: Monad[Free[S, ?]] = - new Monad[Free[S, ?]] { + implicit def freeMonad[S[_]]: MonadRec[Free[S, ?]] = + new MonadRec[Free[S, ?]] { def pure[A](a: A): Free[S, A] = Free.pure(a) override def map[A, B](fa: Free[S, A])(f: A => B): Free[S, B] = fa.map(f) def flatMap[A, B](a: Free[S, A])(f: A => Free[S, B]): Free[S, B] = a.flatMap(f) + def tailRecM[A, B](a: A)(f: A => Free[S, A Xor B]): Free[S, B] = + f(a).flatMap(_ match { + case Xor.Left(a1) => tailRecM(a1)(f) // recursion OK here, since Free is lazy + case Xor.Right(b) => pure(b) + }) } } diff --git a/free/src/test/scala/cats/free/FreeTests.scala b/free/src/test/scala/cats/free/FreeTests.scala index 1b9c7d63be..5291c60d5c 100644 --- a/free/src/test/scala/cats/free/FreeTests.scala +++ b/free/src/test/scala/cats/free/FreeTests.scala @@ -3,7 +3,8 @@ package free import cats.tests.CatsSuite import cats.arrow.NaturalTransformation -import cats.laws.discipline.{CartesianTests, MonadTests, SerializableTests} +import cats.data.Xor +import cats.laws.discipline.{CartesianTests, MonadRecTests, SerializableTests} import cats.laws.discipline.arbitrary.function0Arbitrary import org.scalacheck.{Arbitrary, Gen} @@ -14,8 +15,8 @@ class FreeTests extends CatsSuite { implicit val iso = CartesianTests.Isomorphisms.invariant[Free[Option, ?]] - checkAll("Free[Option, ?]", MonadTests[Free[Option, ?]].monad[Int, Int, Int]) - checkAll("Monad[Free[Option, ?]]", SerializableTests.serializable(Monad[Free[Option, ?]])) + checkAll("Free[Option, ?]", MonadRecTests[Free[Option, ?]].monadRec[Int, Int, Int]) + checkAll("MonadRec[Free[Option, ?]]", SerializableTests.serializable(MonadRec[Free[Option, ?]])) test("mapSuspension id"){ forAll { x: Free[List, Int] => @@ -43,6 +44,13 @@ class FreeTests extends CatsSuite { } } + test("tailRecM is stack safe") { + val n = 50000 + val fa = MonadRec[Free[Option, ?]].tailRecM(0)(i => + Free.pure[Option, Int Xor Int](if(i < n) Xor.Left(i+1) else Xor.Right(i))) + fa should === (Free.pure[Option, Int](n)) + } + ignore("foldMap is stack safe") { trait FTestApi[A] case class TB(i: Int) extends FTestApi[Int] diff --git a/laws/src/main/scala/cats/laws/FlatMapRecLaws.scala b/laws/src/main/scala/cats/laws/FlatMapRecLaws.scala new file mode 100644 index 0000000000..8ff0254262 --- /dev/null +++ b/laws/src/main/scala/cats/laws/FlatMapRecLaws.scala @@ -0,0 +1,26 @@ +package cats +package laws + +import cats.data.Xor +import cats.syntax.flatMap._ +import cats.syntax.functor._ + +/** + * Laws that must be obeyed by any `FlatMapRec`. + */ +trait FlatMapRecLaws[F[_]] extends FlatMapLaws[F] { + implicit override def F: FlatMapRec[F] + + def tailRecMConsistentFlatMap[A](a: A, f: A => F[A]): IsEq[F[A]] = { + val bounce = F.tailRecM[(A, Int), A]((a, 1)) { case (a0, i) => + if(i > 0) f(a0).map(a1 => Xor.left((a1, i-1))) + else f(a0).map(Xor.right) + } + bounce <-> f(a).flatMap(f) + } +} + +object FlatMapRecLaws { + def apply[F[_]](implicit ev: FlatMapRec[F]): FlatMapRecLaws[F] = + new FlatMapRecLaws[F] { def F: FlatMapRec[F] = ev } +} diff --git a/laws/src/main/scala/cats/laws/MonadRecLaws.scala b/laws/src/main/scala/cats/laws/MonadRecLaws.scala new file mode 100644 index 0000000000..2b2e2e90a1 --- /dev/null +++ b/laws/src/main/scala/cats/laws/MonadRecLaws.scala @@ -0,0 +1,14 @@ +package cats +package laws + +/** + * Laws that must be obeyed by any `MonadRec`. + */ +trait MonadRecLaws[F[_]] extends MonadLaws[F] with FlatMapRecLaws[F] { + implicit override def F: MonadRec[F] +} + +object MonadRecLaws { + def apply[F[_]](implicit ev: MonadRec[F]): MonadRecLaws[F] = + new MonadRecLaws[F] { def F: MonadRec[F] = ev } +} diff --git a/laws/src/main/scala/cats/laws/discipline/FlatMapRecTests.scala b/laws/src/main/scala/cats/laws/discipline/FlatMapRecTests.scala new file mode 100644 index 0000000000..c4d2c3a2b5 --- /dev/null +++ b/laws/src/main/scala/cats/laws/discipline/FlatMapRecTests.scala @@ -0,0 +1,35 @@ +package cats +package laws +package discipline + +import cats.laws.discipline.CartesianTests.Isomorphisms +import org.scalacheck.Arbitrary +import org.scalacheck.Prop +import Prop._ + +trait FlatMapRecTests[F[_]] extends FlatMapTests[F] { + def laws: FlatMapRecLaws[F] + + def flatMapRec[A: Arbitrary, B: Arbitrary, C: Arbitrary](implicit + ArbFA: Arbitrary[F[A]], + ArbFB: Arbitrary[F[B]], + ArbFC: Arbitrary[F[C]], + ArbFAtoB: Arbitrary[F[A => B]], + ArbFBtoC: Arbitrary[F[B => C]], + EqFA: Eq[F[A]], + EqFB: Eq[F[B]], + EqFC: Eq[F[C]], + EqFABC: Eq[F[(A, B, C)]], + iso: Isomorphisms[F] + ): RuleSet = { + new DefaultRuleSet( + name = "flatMapRec", + parent = Some(flatMap[A, B, C]), + "tailRecM consistent flatMap" -> forAll(laws.tailRecMConsistentFlatMap[A] _)) + } +} + +object FlatMapRecTests { + def apply[F[_]: FlatMapRec]: FlatMapRecTests[F] = + new FlatMapRecTests[F] { def laws: FlatMapRecLaws[F] = FlatMapRecLaws[F] } +} diff --git a/laws/src/main/scala/cats/laws/discipline/MonadRecTests.scala b/laws/src/main/scala/cats/laws/discipline/MonadRecTests.scala new file mode 100644 index 0000000000..f6904ca6e4 --- /dev/null +++ b/laws/src/main/scala/cats/laws/discipline/MonadRecTests.scala @@ -0,0 +1,38 @@ +package cats +package laws +package discipline + +import cats.laws.discipline.CartesianTests.Isomorphisms +import org.scalacheck.Arbitrary +import org.scalacheck.Prop + +trait MonadRecTests[F[_]] extends MonadTests[F] with FlatMapRecTests[F] { + def laws: MonadRecLaws[F] + + def monadRec[A: Arbitrary: Eq, B: Arbitrary: Eq, C: Arbitrary: Eq](implicit + ArbFA: Arbitrary[F[A]], + ArbFB: Arbitrary[F[B]], + ArbFC: Arbitrary[F[C]], + ArbFAtoB: Arbitrary[F[A => B]], + ArbFBtoC: Arbitrary[F[B => C]], + EqFA: Eq[F[A]], + EqFB: Eq[F[B]], + EqFC: Eq[F[C]], + EqFABC: Eq[F[(A, B, C)]], + iso: Isomorphisms[F] + ): RuleSet = { + new RuleSet { + def name: String = "monadRec" + def bases: Seq[(String, RuleSet)] = Nil + def parents: Seq[RuleSet] = Seq(monad[A, B, C], flatMapRec[A, B, C]) + def props: Seq[(String, Prop)] = Nil + } + } +} + +object MonadRecTests { + def apply[F[_]: MonadRec]: MonadRecTests[F] = + new MonadRecTests[F] { + def laws: MonadRecLaws[F] = MonadRecLaws[F] + } +} diff --git a/tests/src/test/scala/cats/tests/EitherTests.scala b/tests/src/test/scala/cats/tests/EitherTests.scala index 97d7d74e08..b39fc91040 100644 --- a/tests/src/test/scala/cats/tests/EitherTests.scala +++ b/tests/src/test/scala/cats/tests/EitherTests.scala @@ -1,7 +1,7 @@ package cats package tests -import cats.laws.discipline.{BitraverseTests, TraverseTests, MonadTests, SerializableTests, CartesianTests} +import cats.laws.discipline.{BitraverseTests, TraverseTests, MonadRecTests, SerializableTests, CartesianTests} import cats.kernel.laws.OrderLaws class EitherTests extends CatsSuite { @@ -11,8 +11,8 @@ class EitherTests extends CatsSuite { checkAll("Either[Int, Int]", CartesianTests[Either[Int, ?]].cartesian[Int, Int, Int]) checkAll("Cartesian[Either[Int, ?]]", SerializableTests.serializable(Cartesian[Either[Int, ?]])) - checkAll("Either[Int, Int]", MonadTests[Either[Int, ?]].monad[Int, Int, Int]) - checkAll("Monad[Either[Int, ?]]", SerializableTests.serializable(Monad[Either[Int, ?]])) + checkAll("Either[Int, Int]", MonadRecTests[Either[Int, ?]].monadRec[Int, Int, Int]) + checkAll("MonadRec[Either[Int, ?]]", SerializableTests.serializable(MonadRec[Either[Int, ?]])) checkAll("Either[Int, Int] with Option", TraverseTests[Either[Int, ?]].traverse[Int, Int, Int, Int, Option, Option]) checkAll("Traverse[Either[Int, ?]", SerializableTests.serializable(Traverse[Either[Int, ?]])) diff --git a/tests/src/test/scala/cats/tests/IdTests.scala b/tests/src/test/scala/cats/tests/IdTests.scala index 115ccc875e..c7421e1369 100644 --- a/tests/src/test/scala/cats/tests/IdTests.scala +++ b/tests/src/test/scala/cats/tests/IdTests.scala @@ -9,6 +9,9 @@ class IdTests extends CatsSuite { checkAll("Id[Int]", BimonadTests[Id].bimonad[Int, Int, Int]) checkAll("Bimonad[Id]", SerializableTests.serializable(Bimonad[Id])) + checkAll("Id[Int]", MonadRecTests[Id].monadRec[Int, Int, Int]) + checkAll("MonadRec[Id]", SerializableTests.serializable(MonadRec[Id])) + checkAll("Id[Int]", TraverseTests[Id].traverse[Int, Int, Int, Int, Option, Option]) checkAll("Traverse[Id]", SerializableTests.serializable(Traverse[Id])) } diff --git a/tests/src/test/scala/cats/tests/MonadRecInstancesTests.scala b/tests/src/test/scala/cats/tests/MonadRecInstancesTests.scala new file mode 100644 index 0000000000..649cd60c01 --- /dev/null +++ b/tests/src/test/scala/cats/tests/MonadRecInstancesTests.scala @@ -0,0 +1,41 @@ +package cats +package tests + +import cats.data.{OptionT, Xor, XorT} + +class MonadRecInstancesTests extends CatsSuite { + def tailRecMStackSafety[M[_]](implicit M: MonadRec[M], Eq: Eq[M[Int]]): Unit = { + val n = 50000 + val res = M.tailRecM(0)(i => M.pure(if(i < n) Xor.Left(i + 1) else Xor.Right(i))) + res should === (M.pure(n)) + } + + test("tailRecM stack-safety for Id") { + tailRecMStackSafety[Id] + } + + test("tailRecM stack-safety for Option") { + tailRecMStackSafety[Option] + } + + test("tailRecM stack-safety for OptionT") { + tailRecMStackSafety[OptionT[Option, ?]] + } + + test("tailRecM stack-safety for Either") { + tailRecMStackSafety[Either[String, ?]] + } + + test("tailRecM stack-safety for Xor") { + tailRecMStackSafety[String Xor ?] + } + + test("tailRecM stack-safety for XorT") { + tailRecMStackSafety[XorT[Option, String, ?]] + } + + test("tailRecM stack-safety for List") { + tailRecMStackSafety[List] + } + +} diff --git a/tests/src/test/scala/cats/tests/OptionTTests.scala b/tests/src/test/scala/cats/tests/OptionTTests.scala index 95cb4fac7d..9d12864e88 100644 --- a/tests/src/test/scala/cats/tests/OptionTTests.scala +++ b/tests/src/test/scala/cats/tests/OptionTTests.scala @@ -1,8 +1,8 @@ package cats.tests -import cats.{Id, Monad, Cartesian, Show} +import cats.{Id, MonadRec, Cartesian, Show} import cats.data.{OptionT, Xor} -import cats.laws.discipline.{FunctorTests, SerializableTests, CartesianTests, MonadTests} +import cats.laws.discipline.{FunctorTests, SerializableTests, CartesianTests, MonadRecTests} import cats.laws.discipline.arbitrary._ class OptionTTests extends CatsSuite { @@ -160,8 +160,8 @@ class OptionTTests extends CatsSuite { } } - checkAll("Monad[OptionT[List, Int]]", MonadTests[OptionT[List, ?]].monad[Int, Int, Int]) - checkAll("Monad[OptionT[List, ?]]", SerializableTests.serializable(Monad[OptionT[List, ?]])) + checkAll("OptionT[List, Int]", MonadRecTests[OptionT[List, ?]].monadRec[Int, Int, Int]) + checkAll("MonadRec[OptionT[List, ?]]", SerializableTests.serializable(MonadRec[OptionT[List, ?]])) { implicit val F = ListWrapper.functor diff --git a/tests/src/test/scala/cats/tests/OptionTests.scala b/tests/src/test/scala/cats/tests/OptionTests.scala index 8c6ef464b9..a7e0d69d94 100644 --- a/tests/src/test/scala/cats/tests/OptionTests.scala +++ b/tests/src/test/scala/cats/tests/OptionTests.scala @@ -14,6 +14,9 @@ class OptionTests extends CatsSuite { checkAll("Option[Int]", MonadCombineTests[Option].monadCombine[Int, Int, Int]) checkAll("MonadCombine[Option]", SerializableTests.serializable(MonadCombine[Option])) + checkAll("Option[Int]", MonadRecTests[Option].monadRec[Int, Int, Int]) + checkAll("MonadRec[Option]", SerializableTests.serializable(MonadRec[Option])) + checkAll("Option[Int] with Option", TraverseTests[Option].traverse[Int, Int, Int, Int, Option, Option]) checkAll("Traverse[Option]", SerializableTests.serializable(Traverse[Option])) diff --git a/tests/src/test/scala/cats/tests/XorTTests.scala b/tests/src/test/scala/cats/tests/XorTTests.scala index b88e4c2958..80c96a6601 100644 --- a/tests/src/test/scala/cats/tests/XorTTests.scala +++ b/tests/src/test/scala/cats/tests/XorTTests.scala @@ -13,6 +13,8 @@ class XorTTests extends CatsSuite { implicit val iso = CartesianTests.Isomorphisms.invariant[XorT[List, String, ?]] checkAll("XorT[List, String, Int]", MonadErrorTests[XorT[List, String, ?], String].monadError[Int, Int, Int]) checkAll("MonadError[XorT[List, ?, ?]]", SerializableTests.serializable(MonadError[XorT[List, String, ?], String])) + checkAll("XorT[List, String, Int]", MonadRecTests[XorT[List, String, ?]].monadRec[Int, Int, Int]) + checkAll("MonadRec[XorT[List, String, ?]]", SerializableTests.serializable(MonadRec[XorT[List, String, ?]])) checkAll("XorT[List, ?, ?]", BifunctorTests[XorT[List, ?, ?]].bifunctor[Int, Int, Int, String, String, String]) checkAll("Bifunctor[XorT[List, ?, ?]]", SerializableTests.serializable(Bifunctor[XorT[List, ?, ?]])) checkAll("XorT[List, ?, ?]", BitraverseTests[XorT[List, ?, ?]].bitraverse[Option, Int, Int, Int, String, String, String]) diff --git a/tests/src/test/scala/cats/tests/XorTests.scala b/tests/src/test/scala/cats/tests/XorTests.scala index 6afa7308e0..ac984230f6 100644 --- a/tests/src/test/scala/cats/tests/XorTests.scala +++ b/tests/src/test/scala/cats/tests/XorTests.scala @@ -5,7 +5,7 @@ import cats.data.{NonEmptyList, Xor, XorT} import cats.data.Xor._ import cats.laws.discipline.{SemigroupKTests} import cats.laws.discipline.arbitrary._ -import cats.laws.discipline.{BitraverseTests, TraverseTests, MonadErrorTests, SerializableTests, CartesianTests} +import cats.laws.discipline.{BitraverseTests, TraverseTests, MonadErrorTests, MonadRecTests, SerializableTests, CartesianTests} import cats.kernel.laws.{GroupLaws, OrderLaws} import org.scalacheck.Arbitrary import org.scalacheck.Arbitrary._ @@ -27,6 +27,9 @@ class XorTests extends CatsSuite { checkAll("Xor[String, Int]", MonadErrorTests[Xor[String, ?], String].monadError[Int, Int, Int]) checkAll("MonadError[Xor, String]", SerializableTests.serializable(MonadError[Xor[String, ?], String])) + checkAll("Xor[String, Int]", MonadRecTests[Xor[String, ?]].monadRec[Int, Int, Int]) + checkAll("MonadRec[Xor[String, ?]]", SerializableTests.serializable(MonadRec[Xor[String, ?]])) + checkAll("Xor[String, Int] with Option", TraverseTests[Xor[String, ?]].traverse[Int, Int, Int, Int, Option, Option]) checkAll("Traverse[Xor[String,?]]", SerializableTests.serializable(Traverse[Xor[String, ?]]))