diff --git a/core/src/main/scala/cats/NonEmptyTraverse.scala b/core/src/main/scala/cats/NonEmptyTraverse.scala index cc7516269d..64527ef97c 100644 --- a/core/src/main/scala/cats/NonEmptyTraverse.scala +++ b/core/src/main/scala/cats/NonEmptyTraverse.scala @@ -53,9 +53,8 @@ import simulacrum.typeclass * {{{ * scala> import cats.implicits._ * scala> import cats.data.NonEmptyList - * scala> def countWords(words: List[String]): Map[String, Int] = words.groupBy(identity).mapValues(_.length) * scala> val x = NonEmptyList.of(List("How", "do", "you", "fly"), List("What", "do", "you", "do")) - * scala> x.nonEmptyFlatTraverse(_.groupByNel(identity)) + * scala> x.nonEmptyFlatTraverse(_.groupByNel(identity) : Map[String, NonEmptyList[String]]) * res0: Map[String,cats.data.NonEmptyList[String]] = Map(do -> NonEmptyList(do, do, do), you -> NonEmptyList(you, you)) * }}} */ diff --git a/core/src/main/scala/cats/data/NonEmptyList.scala b/core/src/main/scala/cats/data/NonEmptyList.scala index 875609c236..4dfc37dc61 100644 --- a/core/src/main/scala/cats/data/NonEmptyList.scala +++ b/core/src/main/scala/cats/data/NonEmptyList.scala @@ -3,11 +3,10 @@ package data import cats.instances.list._ import cats.syntax.order._ - import scala.annotation.tailrec -import scala.collection.immutable.TreeSet +import scala.collection.immutable.{ SortedMap, TreeMap, TreeSet } +import scala.collection.mutable import scala.collection.mutable.ListBuffer -import scala.collection.{immutable, mutable} /** * A data type which represents a non empty list of A, with @@ -321,28 +320,35 @@ final case class NonEmptyList[+A](head: A, tail: List[A]) { } /** - * Groups elements inside of this `NonEmptyList` using a mapping function + * Groups elements inside this `NonEmptyList` according to the `Order` + * of the keys produced by the given mapping function. * * {{{ + * scala> import scala.collection.immutable.SortedMap * scala> import cats.data.NonEmptyList + * scala> import cats.instances.boolean._ * scala> val nel = NonEmptyList.of(12, -2, 3, -5) * scala> nel.groupBy(_ >= 0) - * res0: Map[Boolean, cats.data.NonEmptyList[Int]] = Map(false -> NonEmptyList(-2, -5), true -> NonEmptyList(12, 3)) + * res0: SortedMap[Boolean, cats.data.NonEmptyList[Int]] = Map(false -> NonEmptyList(-2, -5), true -> NonEmptyList(12, 3)) * }}} */ - def groupBy[B](f: A => B): Map[B, NonEmptyList[A]] = { - val m = mutable.Map.empty[B, mutable.Builder[A, List[A]]] + def groupBy[B](f: A => B)(implicit B: Order[B]): SortedMap[B, NonEmptyList[A]] = { + implicit val ordering: Ordering[B] = B.toOrdering + var m = TreeMap.empty[B, mutable.Builder[A, List[A]]] + for { elem <- toList } { - m.getOrElseUpdate(f(elem), List.newBuilder[A]) += elem - } - val b = immutable.Map.newBuilder[B, NonEmptyList[A]] - for { (k, v) <- m } { - val head :: tail = v.result // we only create non empty list inside of the map `m` - b += ((k, NonEmptyList(head, tail))) + val k = f(elem) + + m.get(k) match { + case None => m += ((k, List.newBuilder[A] += elem)) + case Some(builder) => builder += elem + } } - b.result - } + m.map { + case (k, v) => (k, NonEmptyList.fromListUnsafe(v.result)) + } : TreeMap[B, NonEmptyList[A]] + } } object NonEmptyList extends NonEmptyListInstances { diff --git a/core/src/main/scala/cats/syntax/list.scala b/core/src/main/scala/cats/syntax/list.scala index d354eb7c7c..b8e31dc83e 100644 --- a/core/src/main/scala/cats/syntax/list.scala +++ b/core/src/main/scala/cats/syntax/list.scala @@ -1,6 +1,7 @@ package cats package syntax +import scala.collection.immutable.SortedMap import cats.data.NonEmptyList trait ListSyntax { @@ -27,6 +28,8 @@ final class ListOps[A](val la: List[A]) extends AnyVal { * }}} */ def toNel: Option[NonEmptyList[A]] = NonEmptyList.fromList(la) - def groupByNel[B](f: A => B): Map[B, NonEmptyList[A]] = - toNel.fold(Map.empty[B, NonEmptyList[A]])(_.groupBy(f)) + def groupByNel[B](f: A => B)(implicit B: Order[B]): SortedMap[B, NonEmptyList[A]] = { + implicit val ordering = B.toOrdering + toNel.fold(SortedMap.empty[B, NonEmptyList[A]])(_.groupBy(f)) + } }