diff --git a/src/main/derived/ArbitraryDeriving.scala b/src/main/derived/ArbitraryDeriving.scala index 569f750..c33528e 100644 --- a/src/main/derived/ArbitraryDeriving.scala +++ b/src/main/derived/ArbitraryDeriving.scala @@ -21,20 +21,24 @@ private object Gens: def apply[T](gen: Gen[T]): Gens[T] = Gens(List(gen)) - inline private def tupleInstance[T <: Tuple]: Gens[T] = + inline private def tupleInstance[T <: Tuple](isHead: Boolean): Gens[T] = inline erasedValue[T] match case _: EmptyTuple => Gens(Gen.const(EmptyTuple.asInstanceOf[T])) case _: (t *: ts) => val gen: Gen[T] = for { - tVal <- scalacheck.anyGivenArbitrary[t].arbitrary - tsVal <- tupleInstance[ts].gen + tVal <- + // only for the very first member of a product, some extra lazyness is needed to + // ensure we don't end up in an endless loop in case of recursive structures + if (isHead) Gen.lzy(scalacheck.anyGivenArbitrary[t].arbitrary) + else scalacheck.anyGivenArbitrary[t].arbitrary + tsVal <- tupleInstance[ts](false).gen } yield (tVal *: tsVal).asInstanceOf[T] Gens(gen) inline def productInstance[T](p: Mirror.ProductOf[T]): Gens[T] = - Gens(tupleInstance[p.MirroredElemTypes].gen.map(p.fromProduct(_))) + Gens(tupleInstance[p.MirroredElemTypes](true).gen.map(p.fromProduct(_))) private inline def summonSumInstances[T, Elems <: Tuple]: List[Gens[T]] = inline erasedValue[Elems] match diff --git a/src/test/ArbitraryDerivingSuite.scala b/src/test/ArbitraryDerivingSuite.scala index 9b9318d..3099f7f 100644 --- a/src/test/ArbitraryDerivingSuite.scala +++ b/src/test/ArbitraryDerivingSuite.scala @@ -82,6 +82,10 @@ class ArbitraryDerivingSuite extends munit.FunSuite: equalValues(MaybeMaybeList.expectedGen[Int]) } + test("supports direct recursion)") { + equalValues(DirectRecursion.expectedGen) + } + // not a hard requirement (just guarding against accidental worsening by refactoring) test("supports case classes with up to 26 fields (if -Xmax-inlines=32)") { summon[Arbitrary[MaxCaseClass]] diff --git a/src/test/CogenDerivingSuite.scala b/src/test/CogenDerivingSuite.scala index 88c61c1..21b1f2c 100644 --- a/src/test/CogenDerivingSuite.scala +++ b/src/test/CogenDerivingSuite.scala @@ -73,6 +73,10 @@ class CogenDerivingSuite extends munit.ScalaCheckSuite: equalValues(MaybeMaybeList.expectedCogen[Int]) } + test("given derivation supports direct recursion") { + equalValues(DirectRecursion.expectedCogen) + } + test("enables derivation of Arbitrary instances for functions") { val arbFunction1: Arbitrary[ComplexADTWithNestedMembers => ABC] = summon diff --git a/src/test/ShrinkDerivingSuite.scala b/src/test/ShrinkDerivingSuite.scala index 9231bcf..e8bde0a 100644 --- a/src/test/ShrinkDerivingSuite.scala +++ b/src/test/ShrinkDerivingSuite.scala @@ -71,6 +71,10 @@ class ShrinkDerivingSuite extends munit.ScalaCheckSuite: equalValues(MaybeMaybeList.expectedShrink[Int]) } + property("supports direct recursion") { + equalValues(DirectRecursion.expectedShrink) + } + // seems there is no feasible way to get up to par with ArbitraryDeriving, so this is just a // guard against making things even worse test("supports case classes with up to 25 fields (if -Xmax-inlines=32)") { diff --git a/src/test/test_classes.scala b/src/test/test_classes.scala index 0c91038..a5ca303 100644 --- a/src/test/test_classes.scala +++ b/src/test/test_classes.scala @@ -471,6 +471,46 @@ object MaybeMaybeList: mml => (mml.head, mml.tail) )(shrinkTuple) +enum DirectRecursion: + case Continue(next: DirectRecursion) + case Stop + +object DirectRecursion: + + def expectedGen: Gen[DirectRecursion] = + Gen.oneOf(Gen.lzy(expectedGen.map(Continue(_))), Gen.const(Stop)) + + def expectedCogen: Cogen[DirectRecursion] = + Cogen { (seed, value) => + value match + case Continue(dr) => + perturb( + perturb[Unit]( + perturb[DirectRecursion]( + seed, + dr + )(expectedCogen), + () + ), + 0 + ) + case Stop => + perturbSingletonInSum(1, seed, Stop) + } + + @annotation.nowarn("msg=Stream .* is deprecated") + def expectedShrink: Shrink[DirectRecursion] = + Shrink { + case c: Continue => + Shrink + .xmap[DirectRecursion, Continue]( + Continue.apply, + _.next + )(expectedShrink) + .shrink(c) + case Stop => Stream.empty + } + // format: off case class MaxCaseClass( a1: Int, b1: Int, c1: Int, d1: Int, e1: Int, f1: Int, g1: Int, h1: Int, i1: Int, j1: Int,