From 48b8ce87e2869cc2a65956f48fb644ab26cb1a39 Mon Sep 17 00:00:00 2001 From: Martin Hansen Date: Thu, 13 Jun 2024 21:20:22 +0200 Subject: [PATCH] feat: Arbitrary for union types --- src/main/derived/ArbitraryDeriving.scala | 9 ++----- src/main/derived/derived.scala | 6 +++++ .../derived/extras/UnionArbitraries.scala | 24 +++++++++++++++++++ src/main/derived/extras/api.scala | 3 +++ src/main/derived/extras/macros.scala | 23 ++++++++++++++++++ src/test/ArbitraryDerivingSuite.scala | 17 ++----------- src/test/ArbitrarySuite.scala | 19 +++++++++++++++ src/test/UnionExtrasSuite.scala | 20 ++++++++++++++++ 8 files changed, 99 insertions(+), 22 deletions(-) create mode 100644 src/main/derived/extras/UnionArbitraries.scala create mode 100644 src/main/derived/extras/api.scala create mode 100644 src/main/derived/extras/macros.scala create mode 100644 src/test/ArbitrarySuite.scala create mode 100644 src/test/UnionExtrasSuite.scala diff --git a/src/main/derived/ArbitraryDeriving.scala b/src/main/derived/ArbitraryDeriving.scala index 86b9f6a..fe29013 100644 --- a/src/main/derived/ArbitraryDeriving.scala +++ b/src/main/derived/ArbitraryDeriving.scala @@ -10,12 +10,7 @@ import scala.deriving.* private case class Gens[+T](gens: List[Gen[T]]): - def combine[U >: T](that: Gens[U]): Gens[U] = - Gens[U](this.gens ++ that.gens) - - def gen: Gen[T] = gens match - case List(gen) => gen - case gens => Gen.choose(0, gens.size - 1).flatMap(i => gens(i)) + def gen: Gen[T] = genOneOf(gens) private object Gens: @@ -63,7 +58,7 @@ private object Gens: inline def sumInstance[T](s: Mirror.SumOf[T]): Gens[T] = // must be lazy for support of recursive structures lazy val elems = summonSumInstances[T, s.MirroredElemTypes] - elems.reduce(_.combine(_)) + elems.reduce((a, b) => Gens(a.gens ++ b.gens)) inline def derive[T](m: Mirror.Of[T]): Gens[T] = inline m match diff --git a/src/main/derived/derived.scala b/src/main/derived/derived.scala index 66dfc3f..47f6ec3 100644 --- a/src/main/derived/derived.scala +++ b/src/main/derived/derived.scala @@ -1,5 +1,7 @@ package io.github.martinhh.derived +import org.scalacheck.Gen + import scala.compiletime.error import scala.deriving.Mirror @@ -10,3 +12,7 @@ private inline def productToMirroredElemTypes[T](p: Mirror.ProductOf[T])( // should be impossible to reach (called in case Mirror.SumOf[T].MirroredElemTypes contains T) private inline def endlessRecursionError: Nothing = error("infinite recursive derivation") + +private def genOneOf[A](gens: List[Gen[A]]): Gen[A] = gens match + case List(gen) => gen + case gens => Gen.choose(0, gens.size - 1).flatMap(i => gens(i)) diff --git a/src/main/derived/extras/UnionArbitraries.scala b/src/main/derived/extras/UnionArbitraries.scala new file mode 100644 index 0000000..df1d861 --- /dev/null +++ b/src/main/derived/extras/UnionArbitraries.scala @@ -0,0 +1,24 @@ +package io.github.martinhh.derived.extras + +import io.github.martinhh.derived.genOneOf + +import org.scalacheck.Arbitrary +import org.scalacheck.Gen + +import scala.compiletime.summonInline + +// Serves more or less the same purpose as io.github.martinhh.derived.Gens (just in the context of unions). +// Using a separate type here (instead of reusing Gens) is intended to reduce the compile-time for implicit resolution. +private case class UnionGens[+A](gens: List[Gen[A]]) + +private object UnionGens: + inline given derived[A]: UnionGens[A] = + UnionGens(List(summonInline[Arbitrary[A]].arbitrary)) + +private trait UnionArbitraries: + + transparent inline given unionGensMacro[X]: UnionGens[X] = + io.github.martinhh.derived.extras.unionGensMacro + + transparent inline given arbUnion[X](using inline bg: UnionGens[X]): Arbitrary[X] = + Arbitrary(genOneOf(bg.gens)) diff --git a/src/main/derived/extras/api.scala b/src/main/derived/extras/api.scala new file mode 100644 index 0000000..25271ea --- /dev/null +++ b/src/main/derived/extras/api.scala @@ -0,0 +1,3 @@ +package io.github.martinhh.derived.extras + +object union extends UnionArbitraries diff --git a/src/main/derived/extras/macros.scala b/src/main/derived/extras/macros.scala new file mode 100644 index 0000000..ea85bc0 --- /dev/null +++ b/src/main/derived/extras/macros.scala @@ -0,0 +1,23 @@ +package io.github.martinhh.derived.extras + +import scala.quoted.* + +// macro based on this StackOverflow answer by Dmytro Mitin: https://stackoverflow.com/a/78567397/6152669 +private def unionGens[X: Type](using Quotes): Expr[UnionGens[X]] = + import quotes.reflect.* + TypeRepr.of[X] match + case OrType(l, r) => + (l.asType, r.asType) match + case ('[a], '[b]) => + (Expr.summon[UnionGens[a]], Expr.summon[UnionGens[b]]) match + case (Some(aInst), Some(bInst)) => + '{ + val x = $aInst + val y = $bInst + UnionGens(x.gens ++ y.gens) + }.asExprOf[UnionGens[X]] + case (_, _) => + report.errorAndAbort(s"Could not summon UnionGens") + +private transparent inline given unionGensMacro[X]: UnionGens[X] = + ${ unionGens[X] } diff --git a/src/test/ArbitraryDerivingSuite.scala b/src/test/ArbitraryDerivingSuite.scala index 31810b1..ca3bd05 100644 --- a/src/test/ArbitraryDerivingSuite.scala +++ b/src/test/ArbitraryDerivingSuite.scala @@ -3,21 +3,8 @@ package io.github.martinhh import org.scalacheck import org.scalacheck.Arbitrary import org.scalacheck.Gen -import org.scalacheck.Gen.Parameters -import org.scalacheck.rng.Seed - -class ArbitraryDerivingSuite extends munit.FunSuite: - - private def equalValues[T]( - expectedGen: Gen[T], - nTests: Int = 100 - )(using derivedArb: Arbitrary[T]): Unit = - (0 until nTests).foldLeft(Seed.random()) { case (seed, _) => - val expected = expectedGen(Parameters.default, seed) - val derived = derivedArb.arbitrary(Parameters.default, seed) - assertEquals(derived, expected, s"Differing values for seed $seed") - seed.next - } + +class ArbitraryDerivingSuite extends ArbitrarySuite: test("deriveArbitrary allows to derive a given without loop of given definition") { given arb: Arbitrary[SimpleCaseClass] = derived.arbitrary.deriveArbitrary diff --git a/src/test/ArbitrarySuite.scala b/src/test/ArbitrarySuite.scala new file mode 100644 index 0000000..3c806b9 --- /dev/null +++ b/src/test/ArbitrarySuite.scala @@ -0,0 +1,19 @@ +package io.github.martinhh + +import org.scalacheck.Arbitrary +import org.scalacheck.Gen +import org.scalacheck.Gen.Parameters +import org.scalacheck.rng.Seed + +class ArbitrarySuite extends munit.FunSuite: + + protected def equalValues[T]( + expectedGen: Gen[T], + nTests: Int = 100 + )(using arbUnderTest: Arbitrary[T]): Unit = + (0 until nTests).foldLeft(Seed.random()) { case (seed, _) => + val expected = expectedGen(Parameters.default, seed) + val derived = arbUnderTest.arbitrary(Parameters.default, seed) + assertEquals(derived, expected, s"Differing values for seed $seed") + seed.next + } diff --git a/src/test/UnionExtrasSuite.scala b/src/test/UnionExtrasSuite.scala new file mode 100644 index 0000000..3a3b98b --- /dev/null +++ b/src/test/UnionExtrasSuite.scala @@ -0,0 +1,20 @@ +package io.github.martinhh + +import org.scalacheck.Arbitrary.arbitrary +import org.scalacheck.Gen + +class UnionExtrasSuite extends ArbitrarySuite: + + test("Arbitrary for union of two types") { + import io.github.martinhh.derived.extras.union.given + type TheUnion = String | Int + val expectedGen = Gen.oneOf[TheUnion](arbitrary[String], arbitrary[Int]) + equalValues(expectedGen) + } + + test("Arbitrary for union of three types") { + import io.github.martinhh.derived.extras.union.given + type TheUnion = Boolean | String | Int + val expectedGen = Gen.oneOf[TheUnion](arbitrary[Boolean], arbitrary[String], arbitrary[Int]) + equalValues(expectedGen) + }