diff --git a/core/src/main/scala/cats/data/AndThen.scala b/core/src/main/scala/cats/data/AndThen.scala index 70a91558a0..a9296abb16 100644 --- a/core/src/main/scala/cats/data/AndThen.scala +++ b/core/src/main/scala/cats/data/AndThen.scala @@ -89,7 +89,7 @@ sealed abstract class AndThen[-T, +R] extends (T => R) with Product with Seriali private def runLoop(start: T): R = { @tailrec - def loop[A, B](self: AndThen[A, B], current: A): B = + def loop[A](self: AndThen[A, R], current: A): R = self match { case Single(f, _) => f(current) @@ -111,7 +111,7 @@ sealed abstract class AndThen[-T, +R] extends (T => R) with Product with Seriali // converts left-leaning to right-leaning final protected def rotateAccum[E](_right: AndThen[R, E]): AndThen[T, E] = { @tailrec - def loop[A, B, C](left: AndThen[A, B], right: AndThen[B, C]): AndThen[A, C] = + def loop[A](left: AndThen[T, A], right: AndThen[A, E]): AndThen[T, E] = left match { case Concat(left1, right1) => loop(left1, Concat(right1, right)) @@ -152,6 +152,41 @@ object AndThen extends AndThenInstances0 { * to be in danger of triggering a stack-overflow error. */ final private val fusionMaxStackDepth = 127 + + /** + * If you are going to call this function many times, right associating it + * once can be a significant performance improvement for VERY long chains. + */ + def toRightAssociated[A, B](fn: AndThen[A, B]): AndThen[A, B] = { + @tailrec + def loop[X, Y](beg: AndThen[A, X], middle: AndThen[X, Y], end: AndThen[Y, B], endDone: Boolean): AndThen[A, B] = + if (endDone) { + // end is right associated + middle match { + case sm @ Single(_, _) => + val newEnd = Concat(sm, end) + beg match { + case sb @ Single(_, _) => Concat(sb, newEnd) + case Concat(begA, begB) => loop(begA, begB, newEnd, true) + } + case Concat(mA, mB) => + // rotate mA onto beg: + loop(Concat(beg, mA), mB, end, true) + } + } else { + // we are still right-associating the end + end match { + case se @ Single(_, _) => loop(beg, middle, se, true) + case Concat(endA, endB) => loop(beg, Concat(middle, endA), endB, false) + } + } + + fn match { + case Concat(Concat(a, b), c) => loop(a, b, c, false) + case Concat(a, Concat(b, c)) => loop(a, b, c, false) + case Concat(Single(_, _), Single(_, _)) | Single(_, _) => fn + } + } } abstract private[data] class AndThenInstances0 extends AndThenInstances1 { diff --git a/tests/src/test/scala/cats/tests/AndThenSuite.scala b/tests/src/test/scala/cats/tests/AndThenSuite.scala index 0b831d21f8..4790506b7f 100644 --- a/tests/src/test/scala/cats/tests/AndThenSuite.scala +++ b/tests/src/test/scala/cats/tests/AndThenSuite.scala @@ -8,9 +8,11 @@ import cats.laws.discipline._ import cats.laws.discipline.arbitrary._ import cats.laws.discipline.eq._ import cats.platform.Platform +import munit.ScalaCheckSuite +import org.scalacheck.{Arbitrary, Cogen, Gen} import org.scalacheck.Prop._ -class AndThenSuite extends CatsSuite { +class AndThenSuite extends CatsSuite with ScalaCheckSuite { checkAll("AndThen[MiniInt, Int]", SemigroupalTests[AndThen[MiniInt, *]].semigroupal[Int, Int, Int]) checkAll("Semigroupal[AndThen[Int, *]]", SerializableTests.serializable(Semigroupal[AndThen[Int, *]])) @@ -88,4 +90,84 @@ class AndThenSuite extends CatsSuite { test("toString") { assert(AndThen((x: Int) => x).toString.startsWith("AndThen$")) } + + // generate a general AndThen which may not be right associated + def genAndThen[A: Cogen: Arbitrary]: Gen[AndThen[A, A]] = { + val gfn = Gen.function1[A, A](Arbitrary.arbitrary[A]) + // if we don't have a long list we don't see any Concat + Gen + .choose(128, 1 << 13) + .flatMap { size => + Gen.listOfN(size, gfn).flatMap { fns => + val ary = fns.toArray + + def loop(start: Int, end: Int): Gen[AndThen[A, A]] = + if (start == (end - 1)) Gen.const(AndThen(ary(start))) + else if (start >= end) Gen.const(AndThen(identity[A])) + else { + Gen.choose(start, end - 1).flatMap { middle => + for { + left <- loop(start, middle) + right <- loop(middle, end) + } yield left.andThen(right) + } + } + + loop(0, ary.length) + } + } + } + + // generate a right associated function by construction + def genRight[A: Cogen: Arbitrary]: Gen[AndThen[A, A]] = { + val gfn = Gen.function1[A, A](Arbitrary.arbitrary[A]) + // if we don't have a long list we don't see any Concat + Gen + .choose(128, 1 << 13) + .flatMap { size => + Gen.listOfN(size, gfn).map { + case Nil => AndThen(identity[A]) + case h :: tail => + tail.foldRight(AndThen(h)) { (fn, at) => AndThen(fn).andThen(at) } + } + } + } + + // generate a left associated function by construction + def genLeft[A: Cogen: Arbitrary]: Gen[AndThen[A, A]] = { + val gfn = Gen.function1[A, A](Arbitrary.arbitrary[A]) + // if we don't have a long list we don't see any Concat + Gen + .choose(1024, 1 << 13) + .flatMap { size => + Gen.listOfN(size, gfn).map { + case Nil => AndThen(identity[A]) + case h :: tail => + tail.foldLeft(AndThen(h)) { (at, fn) => at.andThen(fn) } + } + } + } + + property("toRightAssociated works") { + // we pass explicit Gens here rather than use the Arbitrary + // instance which just wraps a function + + // Right associated should be identity + forAll(genRight[Int]) { at => + AndThen.toRightAssociated(at) == at + } && + // Left associated is never right associated + forAll(genLeft[Int]) { at => + AndThen.toRightAssociated(at) != at + } && + // check that right associating doesn't change the function value + forAll(genAndThen[Int], Gen.choose(Int.MinValue, Int.MaxValue)) { (at, i) => + AndThen.toRightAssociated(at)(i) == at(i) + } && + // in the worst case of a left associated AndThen, values should still match + forAll(genLeft[Int], Gen.choose(Int.MinValue, Int.MaxValue)) { (at, i) => + AndThen.toRightAssociated(at)(i) == at(i) + } + } + }