Skip to content

Commit

Permalink
Clean up ReaderWriterStateT (#1706)
Browse files Browse the repository at this point in the history
- Check if type class instances are serializable
- Move Arbitrary instance to cats-laws
- Move unamibiguous type class instances
  • Loading branch information
peterneyens authored and kailuowang committed Jun 1, 2017
1 parent 4e55158 commit 4845e68
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 61 deletions.
61 changes: 18 additions & 43 deletions core/src/main/scala/cats/data/ReaderWriterStateT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,6 @@ final class ReaderWriterStateT[F[_], E, S, L, A](val runF: F[(E, S) => F[(L, S,
def mapWritten[LL](f: L => LL)(implicit F: Functor[F]): ReaderWriterStateT[F, E, S, LL, A] =
transform { (l, s, a) => (f(l), s, a) }

/**
* Combine this computation with `rwsb` using `fn`. The state will be be threaded
* through the computations and the log values will be combined.
*/
def map2[B, Z](rwsb: ReaderWriterStateT[F, E, S, L, B])(fn: (A, B) => Z)(
implicit F: FlatMap[F], L: Semigroup[L]): ReaderWriterStateT[F, E, S, L, Z] =
flatMap { a =>
rwsb.map { b =>
fn(a, b)
}
}

/**
* Modify the result of the computation by feeding it into `f`, threading the state
* through the resulting computation and combining the log values.
Expand Down Expand Up @@ -343,6 +331,21 @@ private[data] sealed trait RWSTInstances extends RWSTInstances1 {
new RWSTMonadTrans[E, S, L] {
implicit def L: Monoid[L] = L0
}

implicit def catsDataProfunctorForRWST[F[_], S, L](implicit F0: Functor[F]): Profunctor[ReaderWriterStateT[F, ?, S, L, ?]] =
new RWSTProfunctor[F, S, L] {
implicit def F: Functor[F] = F0
}

implicit def catsDataBifunctorForRWST[F[_], E, S](implicit F0: Functor[F]): Bifunctor[ReaderWriterStateT[F, E, S, ?, ?]] =
new RWSTBifunctor[F, E, S] {
implicit def F: Functor[F] = F0
}

implicit def catsDataContravariantForRWST[F[_], S, L, A](implicit F0: Functor[F]): Contravariant[ReaderWriterStateT[F, ?, S, L, A]] =
new RWSTContravariant[F, S, L, A] {
implicit def F: Functor[F] = F0
}
}

private[data] sealed trait RWSTInstances1 extends RWSTInstances2 {
Expand Down Expand Up @@ -396,38 +399,17 @@ private[data] sealed trait RWSTInstances5 extends RWSTInstances6 {
}
}

private[data] sealed trait RWSTInstances6 extends RWSTInstances7 {
private[data] sealed trait RWSTInstances6 {
implicit def catsDataFunctorForRWST[F[_], E, S, L](implicit F0: Functor[F]): Functor[ReaderWriterStateT[F, E, S, L, ?]] =
new RWSTFunctor[F, E, S, L] {
implicit def F: Functor[F] = F0
}
}

private[data] sealed trait RWSTInstances7 extends RWSTInstances8 {
implicit def catsDataContravariantForRWST[F[_], S, L, A](implicit F0: Functor[F]): Contravariant[ReaderWriterStateT[F, ?, S, L, A]] =
new RWSTContravariant[F, S, L, A] {
implicit def F: Functor[F] = F0
}
}

private[data] sealed trait RWSTInstances8 extends RWSTInstances9 {
implicit def catsDataBifunctorForRWST[F[_], E, S](implicit F0: Functor[F]): Bifunctor[ReaderWriterStateT[F, E, S, ?, ?]] =
new RWSTBifunctor[F, E, S] {
implicit def F: Functor[F] = F0
}
}

private[data] sealed trait RWSTInstances9 {
implicit def catsDataProfunctorForRWST[F[_], S, L](implicit F0: Functor[F]): Profunctor[ReaderWriterStateT[F, ?, S, L, ?]] =
new RWSTProfunctor[F, S, L] {
implicit def F: Functor[F] = F0
}
}

private[data] sealed trait RWSTFunctor[F[_], E, S, L] extends Functor[ReaderWriterStateT[F, E, S, L, ?]] {
implicit def F: Functor[F]

def map[A, B](fa: ReaderWriterStateT[F, E, S, L, A])(f: A => B): ReaderWriterStateT[F, E, S, L, B] =
override def map[A, B](fa: ReaderWriterStateT[F, E, S, L, A])(f: A => B): ReaderWriterStateT[F, E, S, L, B] =
fa.map(f)
}

Expand All @@ -452,7 +434,7 @@ private[data] sealed trait RWSTProfunctor[F[_], S, L] extends Profunctor[ReaderW
fab.contramap(f).map(g)
}

private[data] sealed trait RWSTMonad[F[_], E, S, L] extends Monad[ReaderWriterStateT[F, E, S, L, ?]] {
private[data] sealed trait RWSTMonad[F[_], E, S, L] extends Monad[ReaderWriterStateT[F, E, S, L, ?]] with RWSTFunctor[F, E, S, L] {
implicit def F: Monad[F]
implicit def L: Monoid[L]

Expand All @@ -470,13 +452,6 @@ private[data] sealed trait RWSTMonad[F[_], E, S, L] extends Monad[ReaderWriterSt
}
}
}

override def map[A, B](fa: ReaderWriterStateT[F, E, S, L, A])(f: A => B): ReaderWriterStateT[F, E, S, L, B] =
fa.map(f)

override def map2[A, B, Z](fa: ReaderWriterStateT[F, E, S, L, A],
fb: ReaderWriterStateT[F, E, S, L, B])(f: (A, B) => Z): ReaderWriterStateT[F, E, S, L, Z] =
fa.map2(fb)(f)
}

private[data] sealed trait RWSTMonadState[F[_], E, S, L]
Expand Down
3 changes: 3 additions & 0 deletions laws/src/main/scala/cats/laws/discipline/Arbitrary.scala
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ object arbitrary extends ArbitraryInstances0 {

implicit def catsLawArbitraryForReader[A: Arbitrary: Cogen, B: Arbitrary]: Arbitrary[Reader[A, B]] =
catsLawsArbitraryForKleisli[Id, A, B]

implicit def catsLawsAribtraryForReaderWriterStateT[F[_]: Applicative, E, S, L, A](implicit F: Arbitrary[(E, S) => F[(L, S, A)]]): Arbitrary[ReaderWriterStateT[F, E, S, L, A]] =
Arbitrary(F.arbitrary.map(ReaderWriterStateT(_)))
}

private[discipline] sealed trait ArbitraryInstances0 {
Expand Down
5 changes: 4 additions & 1 deletion tests/src/test/scala/cats/tests/KleisliTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,10 @@ class KleisliTests extends CatsSuite {

checkAll("Kleisli[Option, ?, Int]", ContravariantTests[Kleisli[Option, ?, Int]].contravariant[Int, Int, Int])
checkAll("Contravariant[Kleisli[Option, ?, Int]]", SerializableTests.serializable(Contravariant[Kleisli[Option, ?, Int]]))
checkAll("MonadTrans[Kleisli[?[_], Int, ?]]", MonadTransTests[Kleisli[?[_], Int, ?]].monadTrans[Option, Int, Int])

checkAll("Kleisli[Option, Int, ?]]", MonadTransTests[Kleisli[?[_], Int, ?]].monadTrans[Option, Int, Int])
checkAll("MonadTrans[Kleisli[?[_], Int, ?]]", SerializableTests.serializable(MonadTrans[Kleisli[?[_], Int, ?]]))


test("local composes functions") {
forAll { (f: Int => Option[String], g: Int => Int, i: Int) =>
Expand Down
2 changes: 2 additions & 0 deletions tests/src/test/scala/cats/tests/ListWrapper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ object ListWrapper {

val monad: Monad[ListWrapper] = monadCombine

val flatMap: FlatMap[ListWrapper] = monadCombine

val applicative: Applicative[ListWrapper] = monadCombine

/** apply is taken due to ListWrapper being a case class */
Expand Down
46 changes: 36 additions & 10 deletions tests/src/test/scala/cats/tests/ReaderWriterStateTTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@ package cats
package tests

import cats.data.{ ReaderWriterStateT, ReaderWriterState, EitherT }
import cats.functor.{ Bifunctor, Contravariant, Profunctor }
import cats.laws.discipline._
import cats.laws.discipline.eq._
import cats.laws.discipline.arbitrary._
import org.scalacheck.{ Arbitrary }
import org.scalacheck.Arbitrary

class ReaderWriterStateTTests extends CatsSuite {
import ReaderWriterStateTTests._
Expand Down Expand Up @@ -279,53 +280,77 @@ class ReaderWriterStateTTests extends CatsSuite {
}

implicit val iso = CartesianTests.Isomorphisms
.invariant[ReaderWriterStateT[ListWrapper, String, Int, String, ?]](ReaderWriterStateT.catsDataFunctorForRWST(ListWrapper.monad))
.invariant[ReaderWriterStateT[ListWrapper, String, Int, String, ?]](ReaderWriterStateT.catsDataFunctorForRWST(ListWrapper.functor))

{
implicit val F: Monad[ListWrapper] = ListWrapper.monad

checkAll("ReaderWriterStateT[ListWrapper, String, Int, String, Int]",
FunctorTests[ReaderWriterStateT[ListWrapper, String, Int, String, ?]].functor[Int, Int, Int])
checkAll("Functor[ReaderWriterStateT[ListWrapper, String, Int, String, ?]]",
SerializableTests.serializable(Functor[ReaderWriterStateT[ListWrapper, String, Int, String, ?]]))

checkAll("ReaderWriterStateT[ListWrapper, String, Int, String, Int]",
ContravariantTests[ReaderWriterStateT[ListWrapper, ?, Int, String, Int]].contravariant[String, String, String])
checkAll("Contravariant[ReaderWriterStateT[ListWrapper, ?, Int, String, Int]]",
SerializableTests.serializable(Contravariant[ReaderWriterStateT[ListWrapper, ?, Int, String, Int]]))

checkAll("ReaderWriterStateT[ListWrapper, Int, Int, String, Int]",
ProfunctorTests[ReaderWriterStateT[ListWrapper, ?, Int, String, ?]].profunctor[Int, Int, Int, Int, Int, Int])
checkAll("Profunctor[ReaderWriterStateT[ListWrapper, ?, Int, String, ?]]",
SerializableTests.serializable(Profunctor[ReaderWriterStateT[ListWrapper, ?, Int, String, ?]]))

checkAll("ReaderWriterStateT[ListWrapper, Int, Int, Int, Int]",
BifunctorTests[ReaderWriterStateT[ListWrapper, String, Int, ?, ?]].bifunctor[Int, Int, Int, Int, Int, Int])
checkAll("Bifunctor[ReaderWriterStateT[ListWrapper, String, Int, ?, ?]]",
SerializableTests.serializable(Bifunctor[ReaderWriterStateT[ListWrapper, String, Int, ?, ?]]))
}

{
implicit val LWM: Monad[ListWrapper] = ListWrapper.monad

checkAll("ReaderWriterStateT[ListWrapper, String, Int, String, Int]",
MonadTests[ReaderWriterStateT[ListWrapper, String, Int, String, ?]].monad[Int, Int, Int])
checkAll("Monad[ReaderWriterStateT[ListWrapper, String, Int, String, ?]]",
SerializableTests.serializable(Monad[ReaderWriterStateT[ListWrapper, String, Int, String, ?]]))
}

{
implicit val LWM: Monad[ListWrapper] = ListWrapper.monad

checkAll("ReaderWriterStateT[ListWrapper, String, Int, String, Int]",
MonadStateTests[ReaderWriterStateT[ListWrapper, String, Int, String, ?], Int].monadState[Int, Int, Int])
checkAll("MonadState[ReaderWriterStateT[ListWrapper, String, Int, String, ?]. Int]",
SerializableTests.serializable(MonadState[ReaderWriterStateT[ListWrapper, String, Int, String, ?], Int]))
}

{
implicit val LWM: MonadCombine[ListWrapper] = ListWrapper.monadCombine

checkAll("ReaderWriterStateT[ListWrapper, String, Int, String, Int]",
MonadCombineTests[ReaderWriterStateT[ListWrapper, String, Int, String, ?]].monadCombine[Int, Int, Int])
checkAll("MonadCombine[ReaderWriterStateT[ListWrapper, String, Int, String, ?]]",
SerializableTests.serializable(MonadCombine[ReaderWriterStateT[ListWrapper, String, Int, String, ?]]))
}

{
implicit val LWM: Monad[ListWrapper] = ListWrapper.monad

checkAll("ReaderWriterStateT[ListWrapper, String, Int, String, Int]",
MonadReaderTests[ReaderWriterStateT[ListWrapper, String, Int, String, ?], String].monadReader[String, String, String])

// check serializable using Option
checkAll("MonadReader[ReaderWriterStateT[Option, String, Int, String, ?], String]",
SerializableTests.serializable(MonadReader[ReaderWriterStateT[Option, String, Int, String, ?], String]))
}

{
implicit val LWM: Monad[ListWrapper] = ListWrapper.monad

checkAll("ReaderWriterStateT[ListWrapper, String, Int, String, Int]",
MonadWriterTests[ReaderWriterStateT[ListWrapper, String, Int, String, ?], String].monadWriter[String, String, String])
checkAll("MonadWriter[ReaderWriterStateT[ListWrapper, String, Int, String, ?], String]",
SerializableTests.serializable(MonadWriter[ReaderWriterStateT[ListWrapper, String, Int, String, ?], String]))
}

{
Expand All @@ -335,24 +360,29 @@ class ReaderWriterStateTTests extends CatsSuite {

checkAll("ReaderWriterStateT[Option, String, Int, String, Int]",
MonadErrorTests[ReaderWriterStateT[Option, String, Int, String, ?], Unit].monadError[Int, Int, Int])
checkAll("MonadError[ReaderWriterStateT[Option, String, Int, String, ?], Unit]",
SerializableTests.serializable(MonadError[ReaderWriterStateT[Option, String, Int, String, ?], Unit]))
}

{
implicit val F = ListWrapper.monad
implicit val S = ListWrapper.semigroupK
implicit val F: Monad[ListWrapper] = ListWrapper.monad
implicit val S: SemigroupK[ListWrapper] = ListWrapper.semigroupK

checkAll("ReaderWriterStateT[ListWrapper, String, Int, String, Int]",
SemigroupKTests[ReaderWriterStateT[ListWrapper, String, Int, String, ?]].semigroupK[Int])
checkAll("SemigroupK[ReaderWriterStateT[ListWrapper, String, Int, String, ?]]",
SerializableTests.serializable(SemigroupK[ReaderWriterStateT[ListWrapper, String, Int, String, ?]]))
}

{
implicit val F = ListWrapper.monad
implicit val F: Monad[ListWrapper] = ListWrapper.monad

checkAll("MonadTrans[ReaderWriterStateT[?[_], String, Int, String, ?]]",
checkAll("ReaderWriterStateT[?[_], String, Int, String, ?]",
MonadTransTests[ReaderWriterStateT[?[_], String, Int, String, ?]].monadTrans[ListWrapper, Int, Int])
checkAll("MonadTrans[ReaderWriterStateT[?[_], String, Int, String, ?]]",
SerializableTests.serializable(MonadTrans[ReaderWriterStateT[?[_], String, Int, String, ?]]))
}

}

object ReaderWriterStateTTests {
Expand All @@ -364,10 +394,6 @@ object ReaderWriterStateTTests {
}
}

implicit def RWSTArbitrary[F[_]: Applicative, E, S, L, A](
implicit F: Arbitrary[(E, S) => F[(L, S, A)]]): Arbitrary[ReaderWriterStateT[F, E, S, L, A]] =
Arbitrary(F.arbitrary.map(ReaderWriterStateT(_)))

implicit def RWSTEq[F[_], E, S, L, A](implicit S: Arbitrary[S], E: Arbitrary[E], FLSA: Eq[F[(L, S, A)]],
F: Monad[F]): Eq[ReaderWriterStateT[F, E, S, L, A]] =
Eq.by[ReaderWriterStateT[F, E, S, L, A], (E, S) => F[(L, S, A)]] { state =>
Expand Down
4 changes: 2 additions & 2 deletions tests/src/test/scala/cats/tests/StateTTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ class StateTTests extends CatsSuite {

{
// F has a Functor
implicit val F: Functor[ListWrapper] = ListWrapper.monad
implicit val F: Functor[ListWrapper] = ListWrapper.functor
// We only need a Functor on F to find a Functor on StateT
Functor[StateT[ListWrapper, Int, ?]]
}
Expand Down Expand Up @@ -235,7 +235,7 @@ class StateTTests extends CatsSuite {

checkAll("StateT[ListWrapper, Int, Int]", MonadTests[StateT[ListWrapper, Int, ?]].monad[Int, Int, Int])
checkAll("Monad[StateT[ListWrapper, Int, ?]]", SerializableTests.serializable(Monad[StateT[ListWrapper, Int, ?]]))
checkAll("MonadTrans[StateT[?[_], Int, ?]]", MonadTransTests[StateT[?[_], String, ?]].monadTrans[ListWrapper, Int, Int])
checkAll("StateT[ListWrapper, Int, Int]", MonadTransTests[StateT[?[_], String, ?]].monadTrans[ListWrapper, Int, Int])
checkAll("MonadTrans[StateT[?[_], Int, ?]]", SerializableTests.serializable(MonadTrans[StateT[?[_], Int, ?]]))

Monad[StateT[ListWrapper, Int, ?]]
Expand Down
10 changes: 5 additions & 5 deletions tests/src/test/scala/cats/tests/WriterTTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class WriterTTests extends CatsSuite {
// resolution and the laws of these various instances.
{
// F has an Apply and L has a Semigroup
implicit val F: Apply[ListWrapper] = ListWrapper.monadCombine
implicit val F: Apply[ListWrapper] = ListWrapper.applyInstance
implicit val L: Semigroup[ListWrapper[Int]] = ListWrapper.semigroup[Int]

Functor[WriterT[ListWrapper, ListWrapper[Int], ?]]
Expand All @@ -166,7 +166,7 @@ class WriterTTests extends CatsSuite {

{
// F has a Monad and L has a Semigroup
implicit val F: Monad[ListWrapper] = ListWrapper.monadCombine
implicit val F: Monad[ListWrapper] = ListWrapper.monad
implicit val L: Semigroup[ListWrapper[Int]] = ListWrapper.semigroup[Int]

Functor[WriterT[ListWrapper, ListWrapper[Int], ?]]
Expand All @@ -192,7 +192,7 @@ class WriterTTests extends CatsSuite {
}
{
// F has a FlatMap and L has a Monoid
implicit val F: FlatMap[ListWrapper] = ListWrapper.monadCombine
implicit val F: FlatMap[ListWrapper] = ListWrapper.flatMap
implicit val L: Monoid[ListWrapper[Int]] = ListWrapper.monoid[Int]

Functor[WriterT[ListWrapper, ListWrapper[Int], ?]]
Expand All @@ -219,7 +219,7 @@ class WriterTTests extends CatsSuite {

{
// F has an Applicative and L has a Monoid
implicit val F: Applicative[ListWrapper] = ListWrapper.monadCombine
implicit val F: Applicative[ListWrapper] = ListWrapper.applicative
implicit val L: Monoid[ListWrapper[Int]] = ListWrapper.monoid[Int]

Functor[WriterT[ListWrapper, ListWrapper[Int], ?]]
Expand All @@ -246,7 +246,7 @@ class WriterTTests extends CatsSuite {

{
// F has a Monad and L has a Monoid
implicit val F: Monad[ListWrapper] = ListWrapper.monadCombine
implicit val F: Monad[ListWrapper] = ListWrapper.monad
implicit val L: Monoid[ListWrapper[Int]] = ListWrapper.monoid[Int]

Functor[WriterT[ListWrapper, ListWrapper[Int], ?]]
Expand Down

0 comments on commit 4845e68

Please sign in to comment.