Skip to content

Commit

Permalink
Add toRightAssociated to AndThen (#3527)
Browse files Browse the repository at this point in the history
* Add toRightAssociated to AndThen

* fix a comment
  • Loading branch information
johnynek authored Aug 11, 2020
1 parent 9db7a3e commit e1046dc
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 3 deletions.
39 changes: 37 additions & 2 deletions core/src/main/scala/cats/data/AndThen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ sealed abstract class AndThen[-T, +R] extends (T => R) with Product with Seriali

private def runLoop(start: T): R = {
@tailrec
def loop[A, B](self: AndThen[A, B], current: A): B =
def loop[A](self: AndThen[A, R], current: A): R =
self match {
case Single(f, _) => f(current)

Expand All @@ -111,7 +111,7 @@ sealed abstract class AndThen[-T, +R] extends (T => R) with Product with Seriali
// converts left-leaning to right-leaning
final protected def rotateAccum[E](_right: AndThen[R, E]): AndThen[T, E] = {
@tailrec
def loop[A, B, C](left: AndThen[A, B], right: AndThen[B, C]): AndThen[A, C] =
def loop[A](left: AndThen[T, A], right: AndThen[A, E]): AndThen[T, E] =
left match {
case Concat(left1, right1) =>
loop(left1, Concat(right1, right))
Expand Down Expand Up @@ -152,6 +152,41 @@ object AndThen extends AndThenInstances0 {
* to be in danger of triggering a stack-overflow error.
*/
final private val fusionMaxStackDepth = 127

/**
* If you are going to call this function many times, right associating it
* once can be a significant performance improvement for VERY long chains.
*/
def toRightAssociated[A, B](fn: AndThen[A, B]): AndThen[A, B] = {
@tailrec
def loop[X, Y](beg: AndThen[A, X], middle: AndThen[X, Y], end: AndThen[Y, B], endDone: Boolean): AndThen[A, B] =
if (endDone) {
// end is right associated
middle match {
case sm @ Single(_, _) =>
val newEnd = Concat(sm, end)
beg match {
case sb @ Single(_, _) => Concat(sb, newEnd)
case Concat(begA, begB) => loop(begA, begB, newEnd, true)
}
case Concat(mA, mB) =>
// rotate mA onto beg:
loop(Concat(beg, mA), mB, end, true)
}
} else {
// we are still right-associating the end
end match {
case se @ Single(_, _) => loop(beg, middle, se, true)
case Concat(endA, endB) => loop(beg, Concat(middle, endA), endB, false)
}
}

fn match {
case Concat(Concat(a, b), c) => loop(a, b, c, false)
case Concat(a, Concat(b, c)) => loop(a, b, c, false)
case Concat(Single(_, _), Single(_, _)) | Single(_, _) => fn
}
}
}

abstract private[data] class AndThenInstances0 extends AndThenInstances1 {
Expand Down
84 changes: 83 additions & 1 deletion tests/src/test/scala/cats/tests/AndThenSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ import cats.laws.discipline._
import cats.laws.discipline.arbitrary._
import cats.laws.discipline.eq._
import cats.platform.Platform
import munit.ScalaCheckSuite
import org.scalacheck.{Arbitrary, Cogen, Gen}
import org.scalacheck.Prop._

class AndThenSuite extends CatsSuite {
class AndThenSuite extends CatsSuite with ScalaCheckSuite {
checkAll("AndThen[MiniInt, Int]", SemigroupalTests[AndThen[MiniInt, *]].semigroupal[Int, Int, Int])
checkAll("Semigroupal[AndThen[Int, *]]", SerializableTests.serializable(Semigroupal[AndThen[Int, *]]))

Expand Down Expand Up @@ -88,4 +90,84 @@ class AndThenSuite extends CatsSuite {
test("toString") {
assert(AndThen((x: Int) => x).toString.startsWith("AndThen$"))
}

// generate a general AndThen which may not be right associated
def genAndThen[A: Cogen: Arbitrary]: Gen[AndThen[A, A]] = {
val gfn = Gen.function1[A, A](Arbitrary.arbitrary[A])
// if we don't have a long list we don't see any Concat
Gen
.choose(128, 1 << 13)
.flatMap { size =>
Gen.listOfN(size, gfn).flatMap { fns =>
val ary = fns.toArray

def loop(start: Int, end: Int): Gen[AndThen[A, A]] =
if (start == (end - 1)) Gen.const(AndThen(ary(start)))
else if (start >= end) Gen.const(AndThen(identity[A]))
else {
Gen.choose(start, end - 1).flatMap { middle =>
for {
left <- loop(start, middle)
right <- loop(middle, end)
} yield left.andThen(right)
}
}

loop(0, ary.length)
}
}
}

// generate a right associated function by construction
def genRight[A: Cogen: Arbitrary]: Gen[AndThen[A, A]] = {
val gfn = Gen.function1[A, A](Arbitrary.arbitrary[A])
// if we don't have a long list we don't see any Concat
Gen
.choose(128, 1 << 13)
.flatMap { size =>
Gen.listOfN(size, gfn).map {
case Nil => AndThen(identity[A])
case h :: tail =>
tail.foldRight(AndThen(h)) { (fn, at) => AndThen(fn).andThen(at) }
}
}
}

// generate a left associated function by construction
def genLeft[A: Cogen: Arbitrary]: Gen[AndThen[A, A]] = {
val gfn = Gen.function1[A, A](Arbitrary.arbitrary[A])
// if we don't have a long list we don't see any Concat
Gen
.choose(1024, 1 << 13)
.flatMap { size =>
Gen.listOfN(size, gfn).map {
case Nil => AndThen(identity[A])
case h :: tail =>
tail.foldLeft(AndThen(h)) { (at, fn) => at.andThen(fn) }
}
}
}

property("toRightAssociated works") {
// we pass explicit Gens here rather than use the Arbitrary
// instance which just wraps a function

// Right associated should be identity
forAll(genRight[Int]) { at =>
AndThen.toRightAssociated(at) == at
} &&
// Left associated is never right associated
forAll(genLeft[Int]) { at =>
AndThen.toRightAssociated(at) != at
} &&
// check that right associating doesn't change the function value
forAll(genAndThen[Int], Gen.choose(Int.MinValue, Int.MaxValue)) { (at, i) =>
AndThen.toRightAssociated(at)(i) == at(i)
} &&
// in the worst case of a left associated AndThen, values should still match
forAll(genLeft[Int], Gen.choose(Int.MinValue, Int.MaxValue)) { (at, i) =>
AndThen.toRightAssociated(at)(i) == at(i)
}
}

}

0 comments on commit e1046dc

Please sign in to comment.