diff --git a/core/src/main/scala/cats/data/Chain.scala b/core/src/main/scala/cats/data/Chain.scala index 6dc6301fd7..99c95eaa89 100644 --- a/core/src/main/scala/cats/data/Chain.scala +++ b/core/src/main/scala/cats/data/Chain.scala @@ -1,11 +1,11 @@ package cats package data -import cats.implicits._ import Chain._ import scala.annotation.tailrec import scala.collection.immutable.SortedMap +import scala.collection.mutable.ListBuffer /** * Trivial catenable sequence. Supports O(1) append, and (amortized) @@ -345,6 +345,27 @@ sealed abstract class Chain[+A] { final def toVector: Vector[A] = iterator.toVector + /** + * Typesafe equality operator. + * + * This method is similar to == except that it only allows two + * Chain[A] values to be compared to each other, and uses + * equality provided by Eq[_] instances, rather than using the + * universal equality provided by .equals. + */ + def ===[AA >: A](that: Chain[AA])(implicit A: Eq[AA]): Boolean = + (this eq that) || { + val iterX = iterator + val iterY = that.iterator + while (iterX.hasNext && iterY.hasNext) { + // scalastyle:off return + if (!A.eqv(iterX.next, iterY.next)) return false + // scalastyle:on return + } + + iterX.hasNext == iterY.hasNext + } + def show[AA >: A](implicit AA: Show[AA]): String = { val builder = new StringBuilder("Chain(") var first = true @@ -508,14 +529,15 @@ object Chain extends ChainInstances { // scalastyle:on null } -private[data] sealed abstract class ChainInstances { +private[data] sealed abstract class ChainInstances extends ChainInstances1 { implicit def catsDataMonoidForChain[A]: Monoid[Chain[A]] = new Monoid[Chain[A]] { def empty: Chain[A] = Chain.nil def combine(c: Chain[A], c2: Chain[A]): Chain[A] = Chain.concat(c, c2) } - implicit val catsDataInstancesForChain: Traverse[Chain] with Alternative[Chain] with Monad[Chain] = - new Traverse[Chain] with Alternative[Chain] with Monad[Chain] { + implicit val catsDataInstancesForChain: Traverse[Chain] with Alternative[Chain] + with Monad[Chain] with CoflatMap[Chain] = + new Traverse[Chain] with Alternative[Chain] with Monad[Chain] with CoflatMap[Chain] { def foldLeft[A, B](fa: Chain[A], b: B)(f: (B, A) => B): B = fa.foldLeft(b)(f) def foldRight[A, B](fa: Chain[A], lb: Eval[B])(f: (A, Eval[B]) => Eval[B]): Eval[B] = @@ -530,6 +552,16 @@ private[data] sealed abstract class ChainInstances { override def forall[A](fa: Chain[A])(p: A => Boolean): Boolean = fa.forall(p) override def find[A](fa: Chain[A])(f: A => Boolean): Option[A] = fa.find(f) + def coflatMap[A, B](fa: Chain[A])(f: Chain[A] => B): Chain[B] = { + @tailrec def go(as: Chain[A], res: ListBuffer[B]): Chain[B] = + as.uncons match { + case Some((h, t)) => go(t, res += f(t)) + case None => Chain.fromSeq(res.result()) + } + + go(fa, ListBuffer.empty) + } + def traverse[G[_], A, B](fa: Chain[A])(f: A => G[B])(implicit G: Applicative[G]): G[Chain[B]] = fa.foldRight[G[Chain[B]]](G.pure(nil)) { (a, gcatb) => G.map2(f(a), gcatb)(_ +: _) @@ -566,9 +598,55 @@ private[data] sealed abstract class ChainInstances { implicit def catsDataShowForChain[A](implicit A: Show[A]): Show[Chain[A]] = Show.show[Chain[A]](_.show) + implicit def catsDataOrderForChain[A](implicit A0: Order[A]): Order[Chain[A]] = + new Order[Chain[A]] with ChainPartialOrder[A] { + implicit def A: PartialOrder[A] = A0 + def compare(x: Chain[A], y: Chain[A]): Int = if (x eq y) 0 else { + val iterX = x.iterator + val iterY = y.iterator + while (iterX.hasNext && iterY.hasNext) { + val n = A0.compare(iterX.next, iterY.next) + // scalastyle:off return + if (n != 0) return n + // scalastyle:on return + } + + if (iterX.hasNext) 1 + else if (iterY.hasNext) -1 + else 0 + } + } + +} + +private[data] sealed abstract class ChainInstances1 extends ChainInstances2 { + implicit def catsDataPartialOrderForChain[A](implicit A0: PartialOrder[A]): PartialOrder[Chain[A]] = + new ChainPartialOrder[A] { implicit def A: PartialOrder[A] = A0 } +} + +private[data] sealed abstract class ChainInstances2 { implicit def catsDataEqForChain[A](implicit A: Eq[A]): Eq[Chain[A]] = new Eq[Chain[A]] { - def eqv(x: Chain[A], y: Chain[A]): Boolean = - (x eq y) || x.toList === y.toList + def eqv(x: Chain[A], y: Chain[A]): Boolean = x === y + } +} + +private[data] trait ChainPartialOrder[A] extends PartialOrder[Chain[A]] { + implicit def A: PartialOrder[A] + + override def partialCompare(x: Chain[A], y: Chain[A]): Double = if (x eq y) 0.0 else { + val iterX = x.iterator + val iterY = y.iterator + while (iterX.hasNext && iterY.hasNext) { + val n = A.partialCompare(iterX.next, iterY.next) + // scalastyle:off return + if (n != 0.0) return n + // scalastyle:on return + } + + if (iterX.hasNext) 1.0 + else if (iterY.hasNext) -1.0 + else 0.0 } + override def eqv(x: Chain[A], y: Chain[A]): Boolean = x === y } diff --git a/laws/src/main/scala/cats/laws/discipline/Arbitrary.scala b/laws/src/main/scala/cats/laws/discipline/Arbitrary.scala index 82630e63a4..fcb6d16ac7 100644 --- a/laws/src/main/scala/cats/laws/discipline/Arbitrary.scala +++ b/laws/src/main/scala/cats/laws/discipline/Arbitrary.scala @@ -287,6 +287,9 @@ object arbitrary extends ArbitraryInstances0 { } }) + implicit def catsLawsCogenForChain[A](implicit A: Cogen[A]): Cogen[Chain[A]] = + Cogen[List[A]].contramap(_.toList) + } private[discipline] sealed trait ArbitraryInstances0 { diff --git a/tests/src/test/scala/cats/tests/ChainSuite.scala b/tests/src/test/scala/cats/tests/ChainSuite.scala index ceef5b080e..a83a3d77d1 100644 --- a/tests/src/test/scala/cats/tests/ChainSuite.scala +++ b/tests/src/test/scala/cats/tests/ChainSuite.scala @@ -2,8 +2,8 @@ package cats package tests import cats.data.Chain -import cats.kernel.laws.discipline.MonoidTests -import cats.laws.discipline.{AlternativeTests, MonadTests, SerializableTests, TraverseTests} +import cats.kernel.laws.discipline.{MonoidTests, OrderTests} +import cats.laws.discipline.{AlternativeTests, CoflatMapTests, MonadTests, SerializableTests, TraverseTests} import cats.laws.discipline.arbitrary._ class ChainSuite extends CatsSuite { @@ -16,9 +16,15 @@ class ChainSuite extends CatsSuite { checkAll("Chain[Int]", MonadTests[Chain].monad[Int, Int, Int]) checkAll("Monad[Chain]", SerializableTests.serializable(Monad[Chain])) + checkAll("Chain[Int]", CoflatMapTests[Chain].coflatMap[Int, Int, Int]) + checkAll("Coflatmap[Chain]", SerializableTests.serializable(CoflatMap[Chain])) + checkAll("Chain[Int]", MonoidTests[Chain[Int]].monoid) checkAll("Monoid[Chain]", SerializableTests.serializable(Monoid[Chain[Int]])) + checkAll("Chain[Int]", OrderTests[Chain[Int]].order) + checkAll("Order[Chain]", SerializableTests.serializable(Order[Chain[Int]])) + test("show"){ Show[Chain[Int]].show(Chain(1, 2, 3)) should === ("Chain(1, 2, 3)") Chain.empty[Int].show should === ("Chain()")