Skip to content

Commit

Permalink
feat: Arbitrary for union types
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinHH committed Jun 14, 2024
1 parent efe725a commit 48b8ce8
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 22 deletions.
9 changes: 2 additions & 7 deletions src/main/derived/ArbitraryDeriving.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,7 @@ import scala.deriving.*

private case class Gens[+T](gens: List[Gen[T]]):

def combine[U >: T](that: Gens[U]): Gens[U] =
Gens[U](this.gens ++ that.gens)

def gen: Gen[T] = gens match
case List(gen) => gen
case gens => Gen.choose(0, gens.size - 1).flatMap(i => gens(i))
def gen: Gen[T] = genOneOf(gens)

private object Gens:

Expand Down Expand Up @@ -63,7 +58,7 @@ private object Gens:
inline def sumInstance[T](s: Mirror.SumOf[T]): Gens[T] =
// must be lazy for support of recursive structures
lazy val elems = summonSumInstances[T, s.MirroredElemTypes]
elems.reduce(_.combine(_))
elems.reduce((a, b) => Gens(a.gens ++ b.gens))

inline def derive[T](m: Mirror.Of[T]): Gens[T] =
inline m match
Expand Down
6 changes: 6 additions & 0 deletions src/main/derived/derived.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package io.github.martinhh.derived

import org.scalacheck.Gen

import scala.compiletime.error
import scala.deriving.Mirror

Expand All @@ -10,3 +12,7 @@ private inline def productToMirroredElemTypes[T](p: Mirror.ProductOf[T])(

// should be impossible to reach (called in case Mirror.SumOf[T].MirroredElemTypes contains T)
private inline def endlessRecursionError: Nothing = error("infinite recursive derivation")

private def genOneOf[A](gens: List[Gen[A]]): Gen[A] = gens match
case List(gen) => gen
case gens => Gen.choose(0, gens.size - 1).flatMap(i => gens(i))
24 changes: 24 additions & 0 deletions src/main/derived/extras/UnionArbitraries.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package io.github.martinhh.derived.extras

import io.github.martinhh.derived.genOneOf

import org.scalacheck.Arbitrary
import org.scalacheck.Gen

import scala.compiletime.summonInline

// Serves more or less the same purpose as io.github.martinhh.derived.Gens (just in the context of unions).
// Using a separate type here (instead of reusing Gens) is intended to reduce the compile-time for implicit resolution.
private case class UnionGens[+A](gens: List[Gen[A]])

private object UnionGens:
inline given derived[A]: UnionGens[A] =
UnionGens(List(summonInline[Arbitrary[A]].arbitrary))

private trait UnionArbitraries:

transparent inline given unionGensMacro[X]: UnionGens[X] =
io.github.martinhh.derived.extras.unionGensMacro

transparent inline given arbUnion[X](using inline bg: UnionGens[X]): Arbitrary[X] =
Arbitrary(genOneOf(bg.gens))
3 changes: 3 additions & 0 deletions src/main/derived/extras/api.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package io.github.martinhh.derived.extras

object union extends UnionArbitraries
23 changes: 23 additions & 0 deletions src/main/derived/extras/macros.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package io.github.martinhh.derived.extras

import scala.quoted.*

// macro based on this StackOverflow answer by Dmytro Mitin: https://stackoverflow.com/a/78567397/6152669
private def unionGens[X: Type](using Quotes): Expr[UnionGens[X]] =
import quotes.reflect.*
TypeRepr.of[X] match
case OrType(l, r) =>
(l.asType, r.asType) match
case ('[a], '[b]) =>
(Expr.summon[UnionGens[a]], Expr.summon[UnionGens[b]]) match
case (Some(aInst), Some(bInst)) =>
'{
val x = $aInst
val y = $bInst
UnionGens(x.gens ++ y.gens)
}.asExprOf[UnionGens[X]]
case (_, _) =>
report.errorAndAbort(s"Could not summon UnionGens")

private transparent inline given unionGensMacro[X]: UnionGens[X] =
${ unionGens[X] }
17 changes: 2 additions & 15 deletions src/test/ArbitraryDerivingSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,8 @@ package io.github.martinhh
import org.scalacheck
import org.scalacheck.Arbitrary
import org.scalacheck.Gen
import org.scalacheck.Gen.Parameters
import org.scalacheck.rng.Seed

class ArbitraryDerivingSuite extends munit.FunSuite:

private def equalValues[T](
expectedGen: Gen[T],
nTests: Int = 100
)(using derivedArb: Arbitrary[T]): Unit =
(0 until nTests).foldLeft(Seed.random()) { case (seed, _) =>
val expected = expectedGen(Parameters.default, seed)
val derived = derivedArb.arbitrary(Parameters.default, seed)
assertEquals(derived, expected, s"Differing values for seed $seed")
seed.next
}

class ArbitraryDerivingSuite extends ArbitrarySuite:

test("deriveArbitrary allows to derive a given without loop of given definition") {
given arb: Arbitrary[SimpleCaseClass] = derived.arbitrary.deriveArbitrary
Expand Down
19 changes: 19 additions & 0 deletions src/test/ArbitrarySuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.github.martinhh

import org.scalacheck.Arbitrary
import org.scalacheck.Gen
import org.scalacheck.Gen.Parameters
import org.scalacheck.rng.Seed

class ArbitrarySuite extends munit.FunSuite:

protected def equalValues[T](
expectedGen: Gen[T],
nTests: Int = 100
)(using arbUnderTest: Arbitrary[T]): Unit =
(0 until nTests).foldLeft(Seed.random()) { case (seed, _) =>
val expected = expectedGen(Parameters.default, seed)
val derived = arbUnderTest.arbitrary(Parameters.default, seed)
assertEquals(derived, expected, s"Differing values for seed $seed")
seed.next
}
20 changes: 20 additions & 0 deletions src/test/UnionExtrasSuite.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package io.github.martinhh

import org.scalacheck.Arbitrary.arbitrary
import org.scalacheck.Gen

class UnionExtrasSuite extends ArbitrarySuite:

test("Arbitrary for union of two types") {
import io.github.martinhh.derived.extras.union.given
type TheUnion = String | Int
val expectedGen = Gen.oneOf[TheUnion](arbitrary[String], arbitrary[Int])
equalValues(expectedGen)
}

test("Arbitrary for union of three types") {
import io.github.martinhh.derived.extras.union.given
type TheUnion = Boolean | String | Int
val expectedGen = Gen.oneOf[TheUnion](arbitrary[Boolean], arbitrary[String], arbitrary[Int])
equalValues(expectedGen)
}

0 comments on commit 48b8ce8

Please sign in to comment.